diff --git a/apps/application/chat_pipeline/I_base_chat_pipeline.py b/apps/application/chat_pipeline/I_base_chat_pipeline.py index eaa255d13..3a7f43060 100644 --- a/apps/application/chat_pipeline/I_base_chat_pipeline.py +++ b/apps/application/chat_pipeline/I_base_chat_pipeline.py @@ -8,10 +8,91 @@ """ import time from abc import abstractmethod -from typing import Type +from typing import Type, Dict from rest_framework import serializers +from dataset.models import Paragraph + + +class ParagraphPipelineModel: + + def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str, + is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str): + self.id = _id + self.document_id = document_id + self.dataset_id = dataset_id + self.content = content + self.title = title + self.status = status + self.is_active = is_active + self.comprehensive_score = comprehensive_score + self.similarity = similarity + self.dataset_name = dataset_name + self.document_name = document_name + + def to_dict(self): + return { + 'id': self.id, + 'document_id': self.document_id, + 'dataset_id': self.dataset_id, + 'content': self.content, + 'title': self.title, + 'status': self.status, + 'is_active': self.is_active, + 'comprehensive_score': self.comprehensive_score, + 'similarity': self.similarity, + 'dataset_name': self.dataset_name, + 'document_name': self.document_name + } + + class builder: + def __init__(self): + self.similarity = None + self.paragraph = {} + self.comprehensive_score = None + self.document_name = None + self.dataset_name = None + + def add_paragraph(self, paragraph): + if isinstance(paragraph, Paragraph): + self.paragraph = {'id': paragraph.id, + 'document_id': paragraph.document_id, + 'dataset_id': paragraph.dataset_id, + 'content': paragraph.content, + 'title': paragraph.title, + 'status': paragraph.status, + 'is_active': paragraph.is_active, + } + else: + self.paragraph = paragraph + return 
self + + def add_dataset_name(self, dataset_name): + self.dataset_name = dataset_name + return self + + def add_document_name(self, document_name): + self.document_name = document_name + return self + + def add_comprehensive_score(self, comprehensive_score: float): + self.comprehensive_score = comprehensive_score + return self + + def add_similarity(self, similarity: float): + self.similarity = similarity + return self + + def build(self): + return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')), + str(self.paragraph.get('dataset_id')), + self.paragraph.get('content'), self.paragraph.get('title'), + self.paragraph.get('status'), + self.paragraph.get('is_active'), + self.comprehensive_score, self.similarity, self.dataset_name, + self.document_name) + class IBaseChatPipelineStep: def __init__(self): diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py index b02273ae8..1168fb9af 100644 --- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -13,7 +13,7 @@ from langchain.chat_models.base import BaseChatModel from langchain.schema import BaseMessage from rest_framework import serializers -from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel from application.chat_pipeline.pipeline_manage import PiplineManage from common.field.common import InstanceField from dataset.models import Paragraph @@ -41,7 +41,8 @@ class MessageField(serializers.Field): class PostResponseHandler: @abstractmethod - def handler(self, chat_id, chat_record_id, paragraph_list: List[Paragraph], problem_text: str, answer_text, + def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str, + answer_text, manage, step, 
padding_problem_text: str = None, **kwargs): pass diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index 9fa4c511e..080142893 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -18,15 +18,15 @@ from langchain.chat_models.base import BaseChatModel from langchain.schema import BaseMessage from langchain.schema.messages import BaseMessageChunk, HumanMessage +from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel from application.chat_pipeline.pipeline_manage import PiplineManage from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler -from dataset.models import Paragraph def event_content(response, chat_id, chat_record_id, - paragraph_list: List[Paragraph], + paragraph_list: List[ParagraphPipelineModel], post_response_handler: PostResponseHandler, manage, step, diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py index e26b06347..b5b93d05d 100644 --- a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py +++ b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py @@ -12,7 +12,7 @@ from typing import Type, List from langchain.schema import BaseMessage from rest_framework import serializers -from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel from application.chat_pipeline.pipeline_manage import PiplineManage from application.models import ChatRecord from common.field.common import InstanceField @@ -24,7 +24,7 @@ class 
IGenerateHumanMessageStep(IBaseChatPipelineStep): # 问题 problem_text = serializers.CharField(required=True) # 段落列表 - paragraph_list = serializers.ListField(child=InstanceField(model_type=Paragraph, required=True)) + paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True)) # 历史对答 history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True)) # 多轮对话数量 @@ -46,7 +46,7 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep): @abstractmethod def execute(self, problem_text: str, - paragraph_list: List[Paragraph], + paragraph_list: List[ParagraphPipelineModel], history_chat_record: List[ChatRecord], dialogue_number: int, max_paragraph_char_number: int, diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py index fcd35ac75..c7a56bc6b 100644 --- a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py +++ b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py @@ -10,17 +10,17 @@ from typing import List from langchain.schema import BaseMessage, HumanMessage +from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \ IGenerateHumanMessageStep from application.models import ChatRecord from common.util.split_model import flat_map -from dataset.models import Paragraph class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep): def execute(self, problem_text: str, - paragraph_list: List[Paragraph], + paragraph_list: List[ParagraphPipelineModel], history_chat_record: List[ChatRecord], dialogue_number: int, max_paragraph_char_number: int, @@ -39,7 +39,7 @@ class 
BaseGenerateHumanMessageStep(IGenerateHumanMessageStep): def to_human_message(prompt: str, problem: str, max_paragraph_char_number: int, - paragraph_list: List[Paragraph]): + paragraph_list: List[ParagraphPipelineModel]): if paragraph_list is None or len(paragraph_list) == 0: return HumanMessage(content=problem) temp_data = "" diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py index a79bad8ba..08ea08d1f 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -11,7 +11,7 @@ from typing import List, Type from rest_framework import serializers -from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel from application.chat_pipeline.pipeline_manage import PiplineManage from dataset.models import Paragraph @@ -39,11 +39,12 @@ class ISearchDatasetStep(IBaseChatPipelineStep): def _run(self, manage: PiplineManage): paragraph_list = self.execute(**self.context['step_args']) manage.context['paragraph_list'] = paragraph_list + self.context['paragraph_list'] = paragraph_list @abstractmethod def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, - **kwargs) -> List[Paragraph]: + **kwargs) -> List[ParagraphPipelineModel]: """ 关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询 :param similarity: 相关性 diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py index 4fa13639a..1781d4f3d 100644 --- 
a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py @@ -6,20 +6,25 @@ @date:2024/1/10 10:33 @desc: """ -from typing import List +import os +from typing import List, Dict from django.db.models import QuerySet +from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep from common.config.embedding_config import VectorStore, EmbeddingModel +from common.db.search import native_search +from common.util.file_util import get_file_content from dataset.models import Paragraph +from smartdoc.conf import PROJECT_DIR class BaseSearchDatasetStep(ISearchDatasetStep): def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, - **kwargs) -> List[Paragraph]: + **kwargs) -> List[ParagraphPipelineModel]: exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text embedding_model = EmbeddingModel.get_embedding_model() embedding_value = embedding_model.embed_query(exec_problem_text) @@ -28,16 +33,35 @@ class BaseSearchDatasetStep(ISearchDatasetStep): exclude_paragraph_id_list, True, top_n, similarity) if embedding_list is None: return [] - return self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector) + paragraph_list = self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector) + return [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list] + + @staticmethod + def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineModel: + filter_embedding_list = [embedding for embedding in embedding_list if + str(embedding.get('paragraph_id')) == 
str(paragraph.get('id'))] + if filter_embedding_list is not None and len(filter_embedding_list) > 0: + find_embedding = filter_embedding_list[-1] + return (ParagraphPipelineModel.builder() + .add_paragraph(paragraph) + .add_similarity(find_embedding.get('similarity')) + .add_comprehensive_score(find_embedding.get('comprehensive_score')) + .add_dataset_name(paragraph.get('dataset_name')) + .add_document_name(paragraph.get('document_name')) + .build()) @staticmethod def list_paragraph(paragraph_id_list: List, vector): if paragraph_id_list is None or len(paragraph_id_list) == 0: return [] - paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list) + paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list), + get_file_content( + os.path.join(PROJECT_DIR, "apps", "application", 'sql', + 'list_dataset_paragraph_by_paragraph_id.sql')), + with_table_name=True) # 如果向量库中存在脏数据 直接删除 if len(paragraph_list) != len(paragraph_id_list): - exist_paragraph_list = [str(row.id) for row in paragraph_list] + exist_paragraph_list = [row.get('id') for row in paragraph_list] for paragraph_id in paragraph_id_list: if not exist_paragraph_list.__contains__(paragraph_id): vector.delete_by_paragraph_id(paragraph_id) @@ -48,6 +72,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep): return { 'step_type': 'search_step', + 'paragraph_list': [row.to_dict() for row in self.context['paragraph_list']], 'run_time': self.context['run_time'], 'problem_text': step_args.get( 'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'), diff --git a/apps/application/migrations/0006_remove_chatrecord_paragraph_id_list.py b/apps/application/migrations/0006_remove_chatrecord_paragraph_id_list.py new file mode 100644 index 000000000..86a785731 --- /dev/null +++ b/apps/application/migrations/0006_remove_chatrecord_paragraph_id_list.py @@ -0,0 +1,17 @@ +# Generated by Django 4.1.10 on 2024-01-19 14:02 + +from django.db import migrations + 
+ +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0005_alter_chatrecord_details'), + ] + + operations = [ + migrations.RemoveField( + model_name='chatrecord', + name='paragraph_id_list', + ), + ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index 03b7eb142..b6851274f 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -89,9 +89,6 @@ class ChatRecord(AppModelMixin): chat = models.ForeignKey(Chat, on_delete=models.CASCADE) vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices, default=VoteChoices.UN_VOTE) - paragraph_id_list = ArrayField(verbose_name="引用段落id列表", - base_field=models.UUIDField(max_length=128, blank=True) - , default=list) problem_text = models.CharField(max_length=1024, verbose_name="问题") answer_text = models.CharField(max_length=4096, verbose_name="答案") message_tokens = models.IntegerField(verbose_name="请求token数量", default=0) diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index c4069a7d4..3d03e57c3 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -108,7 +108,6 @@ def get_post_handler(chat_info: ChatInfo): **kwargs): chat_record = ChatRecord(id=chat_record_id, chat_id=chat_id, - paragraph_id_list=[str(p.id) for p in paragraph_list], problem_text=problem_text, answer_text=answer_text, details=manage.get_details(), diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 3baf7783e..f7a1825de 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -192,28 +192,7 @@ class ChatRecordSerializer(serializers.Serializer): chat_record = self.get_chat_record() if chat_record is None: raise AppApiException(500, 
"对话不存在") - dataset_list = [] - paragraph_list = [] - if len(chat_record.paragraph_id_list) > 0: - paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=chat_record.paragraph_id_list), - get_file_content( - os.path.join(PROJECT_DIR, "apps", "application", 'sql', - 'list_dataset_paragraph_by_paragraph_id.sql')), - with_table_name=True) - dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y}, - [{row.get( - 'dataset_id'): row.get( - "dataset_name")} for - row in - paragraph_list], - {}).items()] - - return { - **ChatRecordSerializerModel(chat_record).data, - 'padding_problem_text': chat_record.details.get('problem_padding').get( - 'padding_problem_text') if 'problem_padding' in chat_record.details else None, - 'dataset_list': dataset_list, - 'paragraph_list': paragraph_list} + return ChatRecordSerializer.Query.reset_chat_record(chat_record) class Query(serializers.Serializer): application_id = serializers.UUIDField(required=True) @@ -226,37 +205,22 @@ class ChatRecordSerializer(serializers.Serializer): return [ChatRecordSerializerModel(chat_record).data for chat_record in QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))] - def reset_chat_record_list(self, chat_record_list: List[ChatRecord]): - paragraph_id_list = flat_map([chat_record.paragraph_id_list for chat_record in chat_record_list]) - # 去重 - paragraph_id_list = list(set(paragraph_id_list)) - paragraph_list = self.search_paragraph(paragraph_id_list) - return [self.reset_chat_record(chat_record, paragraph_list) for chat_record in chat_record_list] - @staticmethod - def search_paragraph(paragraph_id_list: List[str]): + def reset_chat_record(chat_record): + dataset_list = [] paragraph_list = [] - if len(paragraph_id_list) > 0: - paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list), - get_file_content( - os.path.join(PROJECT_DIR, "apps", "application", 'sql', - 'list_dataset_paragraph_by_paragraph_id.sql')), - 
with_table_name=True) - return paragraph_list + if 'search_step' in chat_record.details and chat_record.details.get('search_step').get( + 'paragraph_list') is not None: + paragraph_list = chat_record.details.get('search_step').get( + 'paragraph_list') + dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y}, + [{row.get( + 'dataset_id'): row.get( + "dataset_name")} for + row in + paragraph_list], + {}).items()] - @staticmethod - def reset_chat_record(chat_record, all_paragraph_list): - paragraph_list = list( - filter( - lambda paragraph: chat_record.paragraph_id_list.__contains__(uuid.UUID(str(paragraph.get('id')))), - all_paragraph_list)) - dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y}, - [{row.get( - 'dataset_id'): row.get( - "dataset_name")} for - row in - paragraph_list], - {}).items()] return { **ChatRecordSerializerModel(chat_record).data, 'padding_problem_text': chat_record.details.get('problem_padding').get( @@ -270,9 +234,7 @@ class ChatRecordSerializer(serializers.Serializer): self.is_valid(raise_exception=True) page = page_search(current_page, page_size, QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"), - post_records_handler=lambda chat_record: chat_record) - records = page.get('records') - page['records'] = self.reset_chat_record_list(records) + post_records_handler=lambda chat_record: self.reset_chat_record(chat_record)) return page class Vote(serializers.Serializer): diff --git a/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql b/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql index 4c27ca07d..b0843b452 100644 --- a/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql +++ b/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql @@ -1,6 +1,8 @@ SELECT paragraph.*, - dataset."name" AS "dataset_name" + dataset."name" AS "dataset_name", + "document"."name" AS 
"document_name" FROM paragraph paragraph - LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id \ No newline at end of file + LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id + LEFT JOIN "document" "document" ON "document"."id" =paragraph.document_id \ No newline at end of file