From 3f87335c808b9ec85666ec775e949342ec930a9c Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Tue, 16 Jan 2024 16:46:54 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E5=AF=B9=E8=AF=9D?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat_pipeline/I_base_chat_pipeline.py | 55 ++++ apps/application/chat_pipeline/__init__.py | 8 + .../chat_pipeline/pipeline_manage.py | 45 +++ .../chat_pipeline/step/__init__.py | 8 + .../chat_pipeline/step/chat_step/__init__.py | 8 + .../step/chat_step/i_chat_step.py | 88 +++++ .../step/chat_step/impl/base_chat_step.py | 111 +++++++ .../generate_human_message_step/__init__.py | 8 + .../i_generate_human_message_step.py | 68 ++++ .../impl/base_generate_human_message_step.py | 57 ++++ .../step/reset_problem_step/__init__.py | 8 + .../i_reset_problem_step.py | 49 +++ .../impl/base_reset_problem_step.py | 47 +++ .../step/search_dataset_step/__init__.py | 8 + .../i_search_dataset_step.py | 58 ++++ .../impl/base_search_dataset_step.py | 58 ++++ ...0003_remove_chatrecord_dataset_and_more.py | 55 ++++ ...004_remove_application_example_and_more.py | 38 +++ .../0005_alter_chatrecord_details.py | 18 ++ apps/application/models/application.py | 54 +++- .../serializers/application_serializers.py | 46 ++- .../serializers/chat_message_serializers.py | 305 +++++++++--------- .../serializers/chat_serializers.py | 146 ++++++--- apps/application/sql/list_application.sql | 2 +- ...list_dataset_paragraph_by_paragraph_id.sql | 6 + .../swagger_api/application_api.py | 59 +++- apps/application/swagger_api/chat_api.py | 11 +- apps/application/urls.py | 2 + apps/application/views/chat_views.py | 24 +- apps/common/db/search.py | 2 +- apps/common/event/__init__.py | 2 - apps/common/event/listener_chat_message.py | 67 ---- apps/common/field/common.py | 34 ++ apps/common/sql/list_embedding_text.sql | 8 +- ..._num_remove_paragraph_star_num_and_more.py | 37 +++ 
apps/dataset/models/data_set.py | 6 - .../serializers/dataset_serializers.py | 9 +- .../serializers/paragraph_serializers.py | 2 +- .../serializers/problem_serializers.py | 6 +- ...0002_remove_embedding_star_num_and_more.py | 37 +++ apps/embedding/models/embedding.py | 8 +- apps/embedding/sql/embedding_search.sql | 30 +- apps/embedding/sql/hit_test.sql | 35 +- apps/embedding/vector/base_vector.py | 15 +- apps/embedding/vector/pg_vector.py | 27 +- .../model/qian_fan_chat_model.py | 14 +- 46 files changed, 1393 insertions(+), 396 deletions(-) create mode 100644 apps/application/chat_pipeline/I_base_chat_pipeline.py create mode 100644 apps/application/chat_pipeline/__init__.py create mode 100644 apps/application/chat_pipeline/pipeline_manage.py create mode 100644 apps/application/chat_pipeline/step/__init__.py create mode 100644 apps/application/chat_pipeline/step/chat_step/__init__.py create mode 100644 apps/application/chat_pipeline/step/chat_step/i_chat_step.py create mode 100644 apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py create mode 100644 apps/application/chat_pipeline/step/generate_human_message_step/__init__.py create mode 100644 apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py create mode 100644 apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py create mode 100644 apps/application/chat_pipeline/step/reset_problem_step/__init__.py create mode 100644 apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py create mode 100644 apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py create mode 100644 apps/application/chat_pipeline/step/search_dataset_step/__init__.py create mode 100644 apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py create mode 100644 apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py create mode 
100644 apps/application/migrations/0003_remove_chatrecord_dataset_and_more.py create mode 100644 apps/application/migrations/0004_remove_application_example_and_more.py create mode 100644 apps/application/migrations/0005_alter_chatrecord_details.py create mode 100644 apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql delete mode 100644 apps/common/event/listener_chat_message.py create mode 100644 apps/common/field/common.py create mode 100644 apps/dataset/migrations/0004_remove_paragraph_hit_num_remove_paragraph_star_num_and_more.py create mode 100644 apps/embedding/migrations/0002_remove_embedding_star_num_and_more.py diff --git a/apps/application/chat_pipeline/I_base_chat_pipeline.py b/apps/application/chat_pipeline/I_base_chat_pipeline.py new file mode 100644 index 000000000..eaa255d13 --- /dev/null +++ b/apps/application/chat_pipeline/I_base_chat_pipeline.py @@ -0,0 +1,55 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: I_base_chat_pipeline.py + @date:2024/1/9 17:25 + @desc: +""" +import time +from abc import abstractmethod +from typing import Type + +from rest_framework import serializers + + +class IBaseChatPipelineStep: + def __init__(self): + # 当前步骤上下文,用于存储当前步骤信息 + self.context = {} + + @abstractmethod + def get_step_serializer(self, manage) -> Type[serializers.Serializer]: + pass + + def valid_args(self, manage): + step_serializer_clazz = self.get_step_serializer(manage) + step_serializer = step_serializer_clazz(data=manage.context) + step_serializer.is_valid(raise_exception=True) + self.context['step_args'] = step_serializer.data + + def run(self, manage): + """ + + :param manage: 步骤管理器 + :return: 执行结果 + """ + start_time = time.time() + # 校验参数, + self.valid_args(manage) + self._run(manage) + self.context['start_time'] = start_time + self.context['run_time'] = time.time() - start_time + + def _run(self, manage): + pass + + def execute(self, **kwargs): + pass + + def get_details(self, manage, **kwargs): + """ + 运行详情 + :return: 
步骤详情 + """ + return None diff --git a/apps/application/chat_pipeline/__init__.py b/apps/application/chat_pipeline/__init__.py new file mode 100644 index 000000000..719a7e29c --- /dev/null +++ b/apps/application/chat_pipeline/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/1/9 17:23 + @desc: +""" diff --git a/apps/application/chat_pipeline/pipeline_manage.py b/apps/application/chat_pipeline/pipeline_manage.py new file mode 100644 index 000000000..2b94290d2 --- /dev/null +++ b/apps/application/chat_pipeline/pipeline_manage.py @@ -0,0 +1,45 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: pipeline_manage.py + @date:2024/1/9 17:40 + @desc: +""" +import time +from functools import reduce +from typing import List, Type, Dict + +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep + + +class PiplineManage: + def __init__(self, step_list: List[Type[IBaseChatPipelineStep]]): + # 步骤执行器 + self.step_list = [step() for step in step_list] + # 上下文 + self.context = {'message_tokens': 0, 'answer_tokens': 0} + + def run(self, context: Dict = None): + self.context['start_time'] = time.time() + if context is not None: + for key, value in context.items(): + self.context[key] = value + for step in self.step_list: + step.run(self) + + def get_details(self): + return reduce(lambda x, y: {**x, **y}, [{item.get('step_type'): item} for item in + filter(lambda r: r is not None, + [row.get_details(self) for row in self.step_list])], {}) + + class builder: + def __init__(self): + self.step_list: List[Type[IBaseChatPipelineStep]] = [] + + def append_step(self, step: Type[IBaseChatPipelineStep]): + self.step_list.append(step) + return self + + def build(self): + return PiplineManage(step_list=self.step_list) diff --git a/apps/application/chat_pipeline/step/__init__.py b/apps/application/chat_pipeline/step/__init__.py new file mode 100644 index 000000000..5d9549cdc --- /dev/null +++ 
b/apps/application/chat_pipeline/step/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/1/9 18:23 + @desc: +""" diff --git a/apps/application/chat_pipeline/step/chat_step/__init__.py b/apps/application/chat_pipeline/step/chat_step/__init__.py new file mode 100644 index 000000000..5d9549cdc --- /dev/null +++ b/apps/application/chat_pipeline/step/chat_step/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/1/9 18:23 + @desc: +""" diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py new file mode 100644 index 000000000..b02273ae8 --- /dev/null +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -0,0 +1,88 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_chat_step.py + @date:2024/1/9 18:17 + @desc: 对话 +""" +from abc import abstractmethod +from typing import Type, List + +from langchain.chat_models.base import BaseChatModel +from langchain.schema import BaseMessage +from rest_framework import serializers + +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep +from application.chat_pipeline.pipeline_manage import PiplineManage +from common.field.common import InstanceField +from dataset.models import Paragraph + + +class ModelField(serializers.Field): + def to_internal_value(self, data): + if not isinstance(data, BaseChatModel): + self.fail('模型类型错误', value=data) + return data + + def to_representation(self, value): + return value + + +class MessageField(serializers.Field): + def to_internal_value(self, data): + if not isinstance(data, BaseMessage): + self.fail('message类型错误', value=data) + return data + + def to_representation(self, value): + return value + + +class PostResponseHandler: + @abstractmethod + def handler(self, chat_id, chat_record_id, paragraph_list: List[Paragraph], problem_text: str, 
answer_text, + manage, step, padding_problem_text: str = None, **kwargs): + pass + + +class IChatStep(IBaseChatPipelineStep): + class InstanceSerializer(serializers.Serializer): + # 对话列表 + message_list = serializers.ListField(required=True, child=MessageField(required=True)) + # 大语言模型 + chat_model = ModelField() + # 段落列表 + paragraph_list = serializers.ListField() + # 对话id + chat_id = serializers.UUIDField(required=True) + # 用户问题 + problem_text = serializers.CharField(required=True) + # 后置处理器 + post_response_handler = InstanceField(model_type=PostResponseHandler) + # 补全问题 + padding_problem_text = serializers.CharField(required=False) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + message_list: List = self.initial_data.get('message_list') + for message in message_list: + if not isinstance(message, BaseMessage): + raise Exception("message 类型错误") + + def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]: + return self.InstanceSerializer + + def _run(self, manage: PiplineManage): + chat_result = self.execute(**self.context['step_args'], manage=manage) + manage.context['chat_result'] = chat_result + + @abstractmethod + def execute(self, message_list: List[BaseMessage], + chat_id, problem_text, + post_response_handler: PostResponseHandler, + chat_model: BaseChatModel = None, + paragraph_list=None, + manage: PiplineManage = None, + padding_problem_text: str = None, **kwargs): + pass diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py new file mode 100644 index 000000000..123574459 --- /dev/null +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -0,0 +1,111 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_chat_step.py + @date:2024/1/9 18:25 + @desc: 对话step Base实现 +""" +import json +import logging +import time +import traceback +import uuid +from typing 
import List + +from django.http import StreamingHttpResponse +from langchain.chat_models.base import BaseChatModel +from langchain.schema import BaseMessage +from langchain.schema.messages import BaseMessageChunk, HumanMessage + +from application.chat_pipeline.pipeline_manage import PiplineManage +from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler +from dataset.models import Paragraph + + +def event_content(response, + chat_id, + chat_record_id, + paragraph_list: List[Paragraph], + post_response_handler: PostResponseHandler, + manage, + step, + chat_model, + message_list: List[BaseMessage], + problem_text: str, + padding_problem_text: str = None): + all_text = '' + try: + for chunk in response: + all_text += chunk.content + yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + 'content': chunk.content, 'is_end': False}) + "\n\n" + + yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + 'content': '', 'is_end': True}) + "\n\n" + # 获取token + request_token = chat_model.get_num_tokens_from_messages(message_list) + response_token = chat_model.get_num_tokens(all_text) + step.context['message_tokens'] = request_token + step.context['answer_tokens'] = response_token + current_time = time.time() + step.context['answer_text'] = all_text + step.context['run_time'] = current_time - step.context['start_time'] + manage.context['run_time'] = current_time - manage.context['start_time'] + manage.context['message_tokens'] = manage.context['message_tokens'] + request_token + manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token + post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, + all_text, manage, step, padding_problem_text) + except Exception as e: + logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') + yield 'data: ' + json.dumps({'chat_id': str(chat_id), 
'id': str(chat_record_id), 'operate': True, + 'content': '异常' + str(e), 'is_end': True}) + "\n\n" + + +class BaseChatStep(IChatStep): + def execute(self, message_list: List[BaseMessage], + chat_id, + problem_text, + post_response_handler: PostResponseHandler, + chat_model: BaseChatModel = None, + paragraph_list=None, + manage: PiplineManage = None, + padding_problem_text: str = None, + **kwargs): + # 调用模型 + if chat_model is None: + chat_result = iter( + [BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list]) + else: + chat_result = chat_model.stream(message_list) + + chat_record_id = uuid.uuid1() + r = StreamingHttpResponse( + streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list, + post_response_handler, manage, self, chat_model, message_list, problem_text, + padding_problem_text), + content_type='text/event-stream;charset=utf-8') + + r['Cache-Control'] = 'no-cache' + return r + + def get_details(self, manage, **kwargs): + return { + 'step_type': 'chat_step', + 'run_time': self.context['run_time'], + 'model_id': str(manage.context['model_id']), + 'message_list': self.reset_message_list(self.context['step_args'].get('message_list'), + self.context['answer_text']), + 'message_tokens': self.context['message_tokens'], + 'answer_tokens': self.context['answer_tokens'], + 'cost': 0, + } + + @staticmethod + def reset_message_list(message_list: List[BaseMessage], answer_text): + result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for + message + in + message_list] + result.append({'role': 'ai', 'content': answer_text}) + return result diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/__init__.py b/apps/application/chat_pipeline/step/generate_human_message_step/__init__.py new file mode 100644 index 000000000..5d9549cdc --- /dev/null +++ b/apps/application/chat_pipeline/step/generate_human_message_step/__init__.py @@ 
-0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/1/9 18:23 + @desc: +""" diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py new file mode 100644 index 000000000..e26b06347 --- /dev/null +++ b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py @@ -0,0 +1,68 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_generate_human_message_step.py + @date:2024/1/9 18:15 + @desc: 生成对话模板 +""" +from abc import abstractmethod +from typing import Type, List + +from langchain.schema import BaseMessage +from rest_framework import serializers + +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep +from application.chat_pipeline.pipeline_manage import PiplineManage +from application.models import ChatRecord +from common.field.common import InstanceField +from dataset.models import Paragraph + + +class IGenerateHumanMessageStep(IBaseChatPipelineStep): + class InstanceSerializer(serializers.Serializer): + # 问题 + problem_text = serializers.CharField(required=True) + # 段落列表 + paragraph_list = serializers.ListField(child=InstanceField(model_type=Paragraph, required=True)) + # 历史对答 + history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True)) + # 多轮对话数量 + dialogue_number = serializers.IntegerField(required=True) + # 最大携带知识库段落长度 + max_paragraph_char_number = serializers.IntegerField(required=True) + # 模板 + prompt = serializers.CharField(required=True) + # 补齐问题 + padding_problem_text = serializers.CharField(required=False) + + def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]: + return self.InstanceSerializer + + def _run(self, manage: PiplineManage): + message_list = self.execute(**self.context['step_args']) + 
manage.context['message_list'] = message_list + + @abstractmethod + def execute(self, + problem_text: str, + paragraph_list: List[Paragraph], + history_chat_record: List[ChatRecord], + dialogue_number: int, + max_paragraph_char_number: int, + prompt: str, + padding_problem_text: str = None, + **kwargs) -> List[BaseMessage]: + """ + + :param problem_text: 原始问题文本 + :param paragraph_list: 段落列表 + :param history_chat_record: 历史对话记录 + :param dialogue_number: 多轮对话数量 + :param max_paragraph_char_number: 最大段落长度 + :param prompt: 模板 + :param padding_problem_text 用户修改文本 + :param kwargs: 其他参数 + :return: + """ + pass diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py new file mode 100644 index 000000000..fcd35ac75 --- /dev/null +++ b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py @@ -0,0 +1,57 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_generate_human_message_step.py.py + @date:2024/1/10 17:50 + @desc: +""" +from typing import List + +from langchain.schema import BaseMessage, HumanMessage + +from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \ + IGenerateHumanMessageStep +from application.models import ChatRecord +from common.util.split_model import flat_map +from dataset.models import Paragraph + + +class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep): + + def execute(self, problem_text: str, + paragraph_list: List[Paragraph], + history_chat_record: List[ChatRecord], + dialogue_number: int, + max_paragraph_char_number: int, + prompt: str, + padding_problem_text: str = None, + **kwargs) -> List[BaseMessage]: + exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text + start_index = len(history_chat_record) - dialogue_number + 
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))] + return [*flat_map(history_message), + self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list)] + + @staticmethod + def to_human_message(prompt: str, + problem: str, + max_paragraph_char_number: int, + paragraph_list: List[Paragraph]): + if paragraph_list is None or len(paragraph_list) == 0: + return HumanMessage(content=problem) + temp_data = "" + data_list = [] + for p in paragraph_list: + content = f"{p.title}:{p.content}" + temp_data += content + if len(temp_data) > max_paragraph_char_number: + row_data = content[0:max_paragraph_char_number - len(temp_data)] + data_list.append(f"{row_data}") + break + else: + data_list.append(f"{content}") + data = "\n".join(data_list) + return HumanMessage(content=prompt.format(**{'data': data, 'question': problem})) diff --git a/apps/application/chat_pipeline/step/reset_problem_step/__init__.py b/apps/application/chat_pipeline/step/reset_problem_step/__init__.py new file mode 100644 index 000000000..5d9549cdc --- /dev/null +++ b/apps/application/chat_pipeline/step/reset_problem_step/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/1/9 18:23 + @desc: +""" diff --git a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py new file mode 100644 index 000000000..c8f9143dd --- /dev/null +++ b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py @@ -0,0 +1,49 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_reset_problem_step.py + @date:2024/1/9 18:12 + @desc: 重写处理问题 +""" +from abc import abstractmethod +from typing import Type, List + +from langchain.chat_models.base import 
BaseChatModel +from rest_framework import serializers + +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep +from application.chat_pipeline.pipeline_manage import PiplineManage +from application.chat_pipeline.step.chat_step.i_chat_step import ModelField +from application.models import ChatRecord +from common.field.common import InstanceField + + +class IResetProblemStep(IBaseChatPipelineStep): + class InstanceSerializer(serializers.Serializer): + # 问题文本 + problem_text = serializers.CharField(required=True) + # 历史对答 + history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True)) + # 大语言模型 + chat_model = ModelField() + + def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]: + return self.InstanceSerializer + + def _run(self, manage: PiplineManage): + padding_problem = self.execute(**self.context.get('step_args')) + # 用户输入问题 + source_problem_text = self.context.get('step_args').get('problem_text') + self.context['problem_text'] = source_problem_text + self.context['padding_problem_text'] = padding_problem + manage.context['problem_text'] = source_problem_text + manage.context['padding_problem_text'] = padding_problem + # 累加tokens + manage.context['message_tokens'] = manage.context['message_tokens'] + self.context.get('message_tokens') + manage.context['answer_tokens'] = manage.context['answer_tokens'] + self.context.get('answer_tokens') + + @abstractmethod + def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None, + **kwargs): + pass diff --git a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py new file mode 100644 index 000000000..14a68d49a --- /dev/null +++ b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py @@ -0,0 +1,47 @@ +# coding=utf-8 +""" 
+ @project: maxkb + @Author:虎 + @file: base_reset_problem_step.py + @date:2024/1/10 14:35 + @desc: +""" +from typing import List + +from langchain.chat_models.base import BaseChatModel +from langchain.schema import HumanMessage + +from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep +from application.models import ChatRecord +from common.util.split_model import flat_map + +prompt = ( + '()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中') + + +class BaseResetProblemStep(IResetProblemStep): + def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None, + **kwargs) -> str: + start_index = len(history_chat_record) - 3 + history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))] + message_list = [*flat_map(history_message), + HumanMessage(content=prompt.format(**{'question': problem_text}))] + response = chat_model(message_list) + padding_problem = response.content[response.content.index('') + 6:response.content.index('')] + self.context['message_tokens'] = chat_model.get_num_tokens_from_messages(message_list) + self.context['answer_tokens'] = chat_model.get_num_tokens(padding_problem) + return padding_problem + + def get_details(self, manage, **kwargs): + return { + 'step_type': 'problem_padding', + 'run_time': self.context['run_time'], + 'model_id': str(manage.context['model_id']) if 'model_id' in manage.context else None, + 'message_tokens': self.context['message_tokens'], + 'answer_tokens': self.context['answer_tokens'], + 'cost': 0, + 'padding_problem_text': self.context.get('padding_problem_text'), + 'problem_text': self.context.get("step_args").get('problem_text'), + } diff --git a/apps/application/chat_pipeline/step/search_dataset_step/__init__.py 
b/apps/application/chat_pipeline/step/search_dataset_step/__init__.py new file mode 100644 index 000000000..023c4bc38 --- /dev/null +++ b/apps/application/chat_pipeline/step/search_dataset_step/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/1/9 18:24 + @desc: +""" diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py new file mode 100644 index 000000000..a79bad8ba --- /dev/null +++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -0,0 +1,58 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_search_dataset_step.py + @date:2024/1/9 18:10 + @desc: 检索知识库 +""" +from abc import abstractmethod +from typing import List, Type + +from rest_framework import serializers + +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep +from application.chat_pipeline.pipeline_manage import PiplineManage +from dataset.models import Paragraph + + +class ISearchDatasetStep(IBaseChatPipelineStep): + class InstanceSerializer(serializers.Serializer): + # 原始问题文本 + problem_text = serializers.CharField(required=True) + # 系统补全问题文本 + padding_problem_text = serializers.CharField(required=False) + # 需要查询的数据集id列表 + dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True)) + # 需要排除的文档id + exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True)) + # 需要排除向量id + exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True)) + # 需要查询的条数 + top_n = serializers.IntegerField(required=True) + # 相似度 0-1之间 + similarity = serializers.FloatField(required=True, max_value=1, min_value=0) + + def get_step_serializer(self, manage: PiplineManage) -> Type[InstanceSerializer]: + return self.InstanceSerializer + + def 
_run(self, manage: PiplineManage): + paragraph_list = self.execute(**self.context['step_args']) + manage.context['paragraph_list'] = paragraph_list + + @abstractmethod + def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], + exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, + **kwargs) -> List[Paragraph]: + """ + 关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询 + :param similarity: 相关性 + :param top_n: 查询多少条 + :param problem_text: 用户问题 + :param dataset_id_list: 需要查询的数据集id列表 + :param exclude_document_id_list: 需要排除的文档id + :param exclude_paragraph_id_list: 需要排除段落id + :param padding_problem_text 补全问题 + :return: 段落列表 + """ + pass diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py new file mode 100644 index 000000000..4fa13639a --- /dev/null +++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py @@ -0,0 +1,58 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_search_dataset_step.py + @date:2024/1/10 10:33 + @desc: +""" +from typing import List + +from django.db.models import QuerySet + +from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep +from common.config.embedding_config import VectorStore, EmbeddingModel +from dataset.models import Paragraph + + +class BaseSearchDatasetStep(ISearchDatasetStep): + + def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], + exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, + **kwargs) -> List[Paragraph]: + exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text + embedding_model = EmbeddingModel.get_embedding_model() + embedding_value = 
embedding_model.embed_query(exec_problem_text) + vector = VectorStore.get_embedding_vector() + embedding_list = vector.query(embedding_value, dataset_id_list, exclude_document_id_list, + exclude_paragraph_id_list, True, top_n, similarity) + if embedding_list is None: + return [] + return self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector) + + @staticmethod + def list_paragraph(paragraph_id_list: List, vector): + if paragraph_id_list is None or len(paragraph_id_list) == 0: + return [] + paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list) + # 如果向量库中存在脏数据 直接删除 + if len(paragraph_list) != len(paragraph_id_list): + exist_paragraph_list = [str(row.id) for row in paragraph_list] + for paragraph_id in paragraph_id_list: + if not exist_paragraph_list.__contains__(paragraph_id): + vector.delete_by_paragraph_id(paragraph_id) + return paragraph_list + + def get_details(self, manage, **kwargs): + step_args = self.context['step_args'] + + return { + 'step_type': 'search_step', + 'run_time': self.context['run_time'], + 'problem_text': step_args.get( + 'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'), + 'model_name': EmbeddingModel.get_embedding_model().model_name, + 'message_tokens': 0, + 'answer_tokens': 0, + 'cost': 0 + } diff --git a/apps/application/migrations/0003_remove_chatrecord_dataset_and_more.py b/apps/application/migrations/0003_remove_chatrecord_dataset_and_more.py new file mode 100644 index 000000000..ec041426f --- /dev/null +++ b/apps/application/migrations/0003_remove_chatrecord_dataset_and_more.py @@ -0,0 +1,55 @@ +# Generated by Django 4.1.10 on 2024-01-12 18:46 + +import django.contrib.postgres.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0002_alter_chatrecord_dataset'), + ] + + operations = [ + migrations.RemoveField( + model_name='chatrecord', + name='dataset', + ), + 
migrations.RemoveField( + model_name='chatrecord', + name='paragraph', + ), + migrations.RemoveField( + model_name='chatrecord', + name='source_id', + ), + migrations.RemoveField( + model_name='chatrecord', + name='source_type', + ), + migrations.AddField( + model_name='chatrecord', + name='const', + field=models.IntegerField(default=0, verbose_name='总费用'), + ), + migrations.AddField( + model_name='chatrecord', + name='details', + field=models.JSONField(default=list, verbose_name='对话详情'), + ), + migrations.AddField( + model_name='chatrecord', + name='paragraph_id_list', + field=django.contrib.postgres.fields.ArrayField(base_field=models.UUIDField(blank=True), default=list, size=None, verbose_name='引用段落id列表'), + ), + migrations.AddField( + model_name='chatrecord', + name='run_time', + field=models.FloatField(default=0, verbose_name='运行时长'), + ), + migrations.AlterField( + model_name='chatrecord', + name='answer_text', + field=models.CharField(max_length=4096, verbose_name='答案'), + ), + ] diff --git a/apps/application/migrations/0004_remove_application_example_and_more.py b/apps/application/migrations/0004_remove_application_example_and_more.py new file mode 100644 index 000000000..02647f80e --- /dev/null +++ b/apps/application/migrations/0004_remove_application_example_and_more.py @@ -0,0 +1,38 @@ +# Generated by Django 4.1.10 on 2024-01-15 16:07 + +import application.models.application +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0003_remove_chatrecord_dataset_and_more'), + ] + + operations = [ + migrations.RemoveField( + model_name='application', + name='example', + ), + migrations.AddField( + model_name='application', + name='dataset_setting', + field=models.JSONField(default=application.models.application.get_dataset_setting_dict, verbose_name='数据集参数设置'), + ), + migrations.AddField( + model_name='application', + name='model_setting', + 
field=models.JSONField(default=application.models.application.get_model_setting_dict, verbose_name='模型参数相关设置'), + ), + migrations.AddField( + model_name='application', + name='problem_optimization', + field=models.BooleanField(default=False, verbose_name='问题优化'), + ), + migrations.AlterField( + model_name='chatrecord', + name='details', + field=models.JSONField(default={}, verbose_name='对话详情'), + ), + ] diff --git a/apps/application/migrations/0005_alter_chatrecord_details.py b/apps/application/migrations/0005_alter_chatrecord_details.py new file mode 100644 index 000000000..5deb1e1a1 --- /dev/null +++ b/apps/application/migrations/0005_alter_chatrecord_details.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.10 on 2024-01-16 11:22 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0004_remove_application_example_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='chatrecord', + name='details', + field=models.JSONField(default=dict, verbose_name='对话详情'), + ), + ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index 5c875296d..03b7eb142 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -10,23 +10,47 @@ import uuid from django.contrib.postgres.fields import ArrayField from django.db import models +from langchain.schema import HumanMessage, AIMessage from common.mixins.app_model_mixin import AppModelMixin -from dataset.models.data_set import DataSet, Paragraph -from embedding.models import SourceType +from dataset.models.data_set import DataSet from setting.models.model_management import Model from users.models import User +def get_dataset_setting_dict(): + return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000} + + +def get_model_setting_dict(): + return {'prompt': Application.get_default_model_prompt()} + + class Application(AppModelMixin): id = 
models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") name = models.CharField(max_length=128, verbose_name="应用名称") desc = models.CharField(max_length=128, verbose_name="引用描述", default="") prologue = models.CharField(max_length=1024, verbose_name="开场白", default="") - example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True), default=list) dialogue_number = models.IntegerField(default=0, verbose_name="会话数量") user = models.ForeignKey(User, on_delete=models.DO_NOTHING) model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True) + dataset_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict) + model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict) + problem_optimization = models.BooleanField(verbose_name="问题优化", default=False) + + @staticmethod + def get_default_model_prompt(): + return ('已知信息:' + '\n{data}' + '\n回答要求:' + '\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。' + '\n- 避免提及你是从中获得的知识。' + '\n- 请保持答案与中描述的一致。' + '\n- 请使用markdown 语法优化答案的格式。' + '\n- 中的图片链接、链接地址和脚本语言请完整返回。' + '\n- 请使用与问题相同的语言来回答。' + '\n问题:' + '\n{question}') class Meta: db_table = "application" @@ -65,20 +89,28 @@ class ChatRecord(AppModelMixin): chat = models.ForeignKey(Chat, on_delete=models.CASCADE) vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices, default=VoteChoices.UN_VOTE) - dataset = models.ForeignKey(DataSet, on_delete=models.SET_NULL, verbose_name="数据集", blank=True, null=True) - paragraph = models.ForeignKey(Paragraph, on_delete=models.SET_NULL, verbose_name="段落id", blank=True, null=True) - source_id = models.UUIDField(max_length=128, verbose_name="资源id 段落/问题 id ", null=True) - source_type = models.CharField(verbose_name='资源类型', max_length=2, choices=SourceType.choices, - default=SourceType.PROBLEM, blank=True, null=True) + 
paragraph_id_list = ArrayField(verbose_name="引用段落id列表", + base_field=models.UUIDField(max_length=128, blank=True) + , default=list) + problem_text = models.CharField(max_length=1024, verbose_name="问题") + answer_text = models.CharField(max_length=4096, verbose_name="答案") message_tokens = models.IntegerField(verbose_name="请求token数量", default=0) answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0) - problem_text = models.CharField(max_length=1024, verbose_name="问题") - answer_text = models.CharField(max_length=1024, verbose_name="答案") + const = models.IntegerField(verbose_name="总费用", default=0) + details = models.JSONField(verbose_name="对话详情", default=dict) improve_paragraph_id_list = ArrayField(verbose_name="改进标注列表", base_field=models.UUIDField(max_length=128, blank=True) , default=list) - + run_time = models.FloatField(verbose_name="运行时长", default=0) index = models.IntegerField(verbose_name="对话下标") + def get_human_message(self): + if 'problem_padding' in self.details: + return HumanMessage(content=self.details.get('problem_padding').get('padding_problem_text')) + return HumanMessage(content=self.problem_text) + + def get_ai_message(self): + return AIMessage(content=self.answer_text) + class Meta: db_table = "application_chat_record" diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 8bd858358..680360dd4 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -63,15 +63,35 @@ class ApplicationSerializerModel(serializers.ModelSerializer): fields = "__all__" +class DatasetSettingSerializer(serializers.Serializer): + top_n = serializers.FloatField(required=True) + similarity = serializers.FloatField(required=True, max_value=1, min_value=0) + max_paragraph_char_number = serializers.IntegerField(required=True, max_value=10000) + + +class ModelSettingSerializer(serializers.Serializer): + prompt = 
serializers.CharField(required=True, max_length=4096) + + class ApplicationSerializer(serializers.Serializer): name = serializers.CharField(required=True) desc = serializers.CharField(required=False, allow_null=True, allow_blank=True) model_id = serializers.CharField(required=True) multiple_rounds_dialogue = serializers.BooleanField(required=True) prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True) - example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True), allow_null=True) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True), allow_null=True) + # 数据集相关设置 + dataset_setting = DatasetSettingSerializer(required=True) + # 模型相关设置 + model_setting = ModelSettingSerializer(required=True) + # 问题补全 + problem_optimization = serializers.BooleanField(required=True) + + def is_valid(self, *, user_id=None, raise_exception=False): + super().is_valid(raise_exception=True) + ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'), + 'dataset_id_list': self.data.get('dataset_id_list')}).is_valid() class AccessTokenSerializer(serializers.Serializer): application_id = serializers.UUIDField(required=True) @@ -135,13 +155,13 @@ class ApplicationSerializer(serializers.Serializer): model_id = serializers.CharField(required=False) multiple_rounds_dialogue = serializers.BooleanField(required=False) prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True) - example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True)) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) - - def is_valid(self, *, user_id=None, raise_exception=False): - super().is_valid(raise_exception=True) - ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'), - 'dataset_id_list': self.data.get('dataset_id_list')}).is_valid() + 
# 数据集相关设置 + dataset_setting = serializers.JSONField(required=False, allow_null=True) + # 模型相关设置 + model_setting = serializers.JSONField(required=False, allow_null=True) + # 问题补全 + problem_optimization = serializers.BooleanField(required=False, allow_null=True) class Create(serializers.Serializer): user_id = serializers.UUIDField(required=True) @@ -168,9 +188,12 @@ class ApplicationSerializer(serializers.Serializer): @staticmethod def to_application_model(user_id: str, application: Dict): return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'), - prologue=application.get('prologue'), example=application.get('example'), + prologue=application.get('prologue'), dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0, user_id=user_id, model_id=application.get('model_id'), + dataset_setting=application.get('dataset_setting'), + model_setting=application.get('model_setting'), + problem_optimization=application.get('problem_optimization') ) @staticmethod @@ -267,7 +290,7 @@ class ApplicationSerializer(serializers.Serializer): class ApplicationModel(serializers.ModelSerializer): class Meta: model = Application - fields = ['id', 'name', 'desc', 'prologue', 'example', 'dialogue_number'] + fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number'] class Operate(serializers.Serializer): application_id = serializers.UUIDField(required=True) @@ -317,8 +340,9 @@ class ApplicationSerializer(serializers.Serializer): model = QuerySet(Model).get(id=instance.get('model_id') if 'model_id' in instance else application.model_id) - update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'example', 'status', - 'api_key_is_active'] + update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status', + 'dataset_setting', 'model_setting', 'problem_optimization', + 'api_key_is_active'] for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: if 
update_key == 'multiple_rounds_dialogue': diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 13eae7cb1..c4069a7d4 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -7,194 +7,181 @@ @desc: """ import json -import uuid from typing import List +from uuid import UUID -from django.db.models import QuerySet -from django.http import StreamingHttpResponse -from langchain.chat_models.base import BaseChatModel -from langchain.schema import HumanMessage -from rest_framework import serializers, status from django.core.cache import cache -from common import event -from common.config.embedding_config import VectorStore, EmbeddingModel -from common.response import result -from dataset.models import Paragraph -from embedding.models import SourceType -from setting.models.model_management import Model +from django.db.models import QuerySet +from langchain.chat_models.base import BaseChatModel +from rest_framework import serializers + +from application.chat_pipeline.pipeline_manage import PiplineManage +from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler +from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep +from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \ + BaseGenerateHumanMessageStep +from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep +from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep +from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping +from common.exception.app_exception import AppApiException +from common.util.rsa_util import decrypt +from common.util.split_model import flat_map +from dataset.models import Paragraph, Document 
+from setting.models import Model +from setting.models_provider.constants.model_provider_constants import ModelProvideConstants chat_cache = cache -class MessageManagement: - @staticmethod - def get_message(title: str, content: str, message: str): - if content is None: - return HumanMessage(content=message) - return HumanMessage(content=( - f'已知信息:{title}:{content} ' - '根据上述已知信息,请简洁和专业的来回答用户的问题。已知信息中的图片、链接地址和脚本语言请直接返回。如果无法从已知信息中得到答案,请说 “没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作” 或 “根据已知信息无法回答该问题,建议联系官方技术支持人员”,不允许在答案中添加编造成分,答案请使用中文。' - f'问题是:{message}')) - - -class ChatMessage: - def __init__(self, id: str, problem: str, title: str, paragraph: str, embedding_id: str, dataset_id: str, - document_id: str, - paragraph_id, - source_type: SourceType, - source_id: str, - answer: str, - message_tokens: int, - answer_token: int, - chat_model=None, - chat_message=None): - self.id = id - self.problem = problem - self.title = title - self.paragraph = paragraph - self.embedding_id = embedding_id - self.dataset_id = dataset_id - self.document_id = document_id - self.paragraph_id = paragraph_id - self.source_type = source_type - self.source_id = source_id - self.answer = answer - self.message_tokens = message_tokens - self.answer_token = answer_token - self.chat_model = chat_model - self.chat_message = chat_message - - def get_chat_message(self): - return MessageManagement.get_message(self.problem, self.paragraph, self.problem) - - class ChatInfo: def __init__(self, chat_id: str, - model: Model, chat_model: BaseChatModel, - application_id: str | None, dataset_id_list: List[str], exclude_document_id_list: list[str], - dialogue_number: int): + application: Application): + """ + :param chat_id: 对话id + :param chat_model: 对话模型 + :param dataset_id_list: 数据集列表 + :param exclude_document_id_list: 排除的文档 + :param application: 应用信息 + """ self.chat_id = chat_id - self.application_id = application_id - self.model = model + self.application = application self.chat_model = chat_model 
self.dataset_id_list = dataset_id_list self.exclude_document_id_list = exclude_document_id_list - self.dialogue_number = dialogue_number - self.chat_message_list: List[ChatMessage] = [] + self.chat_record_list: List[ChatRecord] = [] - def append_chat_message(self, chat_message: ChatMessage): - self.chat_message_list.append(chat_message) - if self.application_id is not None: + def to_base_pipeline_manage_params(self): + dataset_setting = self.application.dataset_setting + model_setting = self.application.model_setting + return { + 'dataset_id_list': self.dataset_id_list, + 'exclude_document_id_list': self.exclude_document_id_list, + 'exclude_paragraph_id_list': [], + 'top_n': dataset_setting.get('top_n') if 'top_n' in dataset_setting else 3, + 'similarity': dataset_setting.get('similarity') if 'similarity' in dataset_setting else 0.6, + 'max_paragraph_char_number': dataset_setting.get( + 'max_paragraph_char_number') if 'max_paragraph_char_number' in dataset_setting else 5000, + 'history_chat_record': self.chat_record_list, + 'chat_id': self.chat_id, + 'dialogue_number': self.application.dialogue_number, + 'prompt': model_setting.get( + 'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(), + 'chat_model': self.chat_model, + 'model_id': self.application.model.id if self.application.model is not None else None, + 'problem_optimization': self.application.problem_optimization + + } + + def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler, + exclude_paragraph_id_list): + params = self.to_base_pipeline_manage_params() + return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler, + 'exclude_paragraph_id_list': exclude_paragraph_id_list} + + def append_chat_record(self, chat_record: ChatRecord): + # 存入缓存中 + self.chat_record_list.append(chat_record) + if self.application.id is not None: # 插入数据库 - event.ListenerChatMessage.record_chat_message_signal.send( - 
event.RecordChatMessageArgs(len(self.chat_message_list) - 1, self.chat_id, self.application_id, - chat_message) - ) - # 异步更新token - event.ListenerChatMessage.update_chat_message_token_signal.send(chat_message) + if not QuerySet(Chat).filter(id=self.chat_id).exists(): + Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text).save() + # 插入会话记录 + chat_record.save() - def get_context_message(self): - start_index = len(self.chat_message_list) - self.dialogue_number - return [self.chat_message_list[index].get_chat_message() for index in - range(start_index if start_index > 0 else 0, len(self.chat_message_list))] + +def get_post_handler(chat_info: ChatInfo): + class PostHandler(PostResponseHandler): + + def handler(self, + chat_id: UUID, + chat_record_id, + paragraph_list: List[Paragraph], + problem_text: str, + answer_text, + manage: PiplineManage, + step: BaseChatStep, + padding_problem_text: str = None, + **kwargs): + chat_record = ChatRecord(id=chat_record_id, + chat_id=chat_id, + paragraph_id_list=[str(p.id) for p in paragraph_list], + problem_text=problem_text, + answer_text=answer_text, + details=manage.get_details(), + message_tokens=manage.context['message_tokens'], + answer_tokens=manage.context['answer_tokens'], + run_time=manage.context['run_time'], + index=len(chat_info.chat_record_list) + 1) + chat_info.append_chat_record(chat_record) + # 重新设置缓存 + chat_cache.set(chat_id, + chat_info, timeout=60 * 30) + + return PostHandler() class ChatMessageSerializer(serializers.Serializer): chat_id = serializers.UUIDField(required=True) - def chat(self, message): + def chat(self, message, re_chat: bool): self.is_valid(raise_exception=True) chat_id = self.data.get('chat_id') chat_info: ChatInfo = chat_cache.get(chat_id) if chat_info is None: - return result.Result(response_status=status.HTTP_404_NOT_FOUND, code=404, message="会话过期") + chat_info = self.re_open_chat(chat_id) + chat_cache.set(chat_id, + chat_info, timeout=60 * 30) - 
chat_model = chat_info.chat_model - vector = VectorStore.get_embedding_vector() - # 向量库检索 - _value = vector.search(message, chat_info.dataset_id_list, chat_info.exclude_document_id_list, - [chat_message.embedding_id for chat_message in - (list(filter(lambda row: row.problem == message, chat_info.chat_message_list)))], - True, - EmbeddingModel.get_embedding_model()) - # 查询段落id详情 - paragraph = None - if _value is not None: - paragraph = QuerySet(Paragraph).get(id=_value.get('paragraph_id')) - if paragraph is None: - vector.delete_by_paragraph_id(_value.get('paragraph_id')) + pipline_manage_builder = PiplineManage.builder() + # 如果开启了问题优化,则添加上问题优化步骤 + if chat_info.application.problem_optimization: + pipline_manage_builder.append_step(BaseResetProblemStep) + # 构建流水线管理器 + pipline_message = (pipline_manage_builder.append_step(BaseSearchDatasetStep) + .append_step(BaseGenerateHumanMessageStep) + .append_step(BaseChatStep) + .build()) + exclude_paragraph_id_list = [] + # 相同问题是否需要排除已经查询到的段落 + if re_chat: + paragraph_id_list = flat_map([row.paragraph_id_list for row in + filter(lambda chat_record: chat_record.problem_text == message, + chat_info.chat_record_list)]) + exclude_paragraph_id_list = list(set(paragraph_id_list)) + # 构建运行参数 + params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list) + # 运行流水线作业 + pipline_message.run(params) + return pipline_message.context['chat_result'] - title, content = (None, None) if paragraph is None else (paragraph.title, paragraph.content) - _id = str(uuid.uuid1()) + @staticmethod + def re_open_chat(chat_id: str): + chat = QuerySet(Chat).filter(id=chat_id).first() + if chat is None: + raise AppApiException(500, "会话不存在") + application = QuerySet(Application).filter(id=chat.application_id).first() + if application is None: + raise AppApiException(500, "应用不存在") + model = QuerySet(Model).filter(id=application.model_id).first() + chat_model = None + if model is not None: + # 对话模型 + chat_model = 
ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, + json.loads( + decrypt(model.credential)), + streaming=True) + # 数据集id列表 + dataset_id_list = [str(row.dataset_id) for row in + QuerySet(ApplicationDatasetMapping).filter( + application_id=application.id)] - embedding_id, dataset_id, document_id, paragraph_id, source_type, source_id = (_value.get( - 'id'), _value.get( - 'dataset_id'), _value.get( - 'document_id'), _value.get( - 'paragraph_id'), _value.get( - 'source_type'), _value.get( - 'source_id')) if _value is not None else (None, None, None, None, None, None) - - if chat_model is None: - def event_block_content(c: str): - yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None, - 'is_end': True, - 'content': c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~'}) + "\n\n" - chat_info.append_chat_message( - ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id, - paragraph_id, - source_type, - source_id, - c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~', - 0, - 0)) - # 重新设置缓存 - chat_cache.set(chat_id, - chat_info, timeout=60 * 30) - - r = StreamingHttpResponse(streaming_content=event_block_content(content), - content_type='text/event-stream;charset=utf-8') - - r['Cache-Control'] = 'no-cache' - return r - # 获取上下文 - history_message = chat_info.get_context_message() - - # 构建会话请求问题 - chat_message = [*history_message, MessageManagement.get_message(title, content, message)] - # 对话 - result_data = chat_model.stream(chat_message) - - def event_content(response): - all_text = '' - try: - for chunk in response: - all_text += chunk.content - yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None, - 'content': chunk.content, 'is_end': False}) + "\n\n" - - yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None, - 'content': '', 'is_end': True}) + "\n\n" - 
chat_info.append_chat_message( - ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id, - paragraph_id, - source_type, - source_id, all_text, - 0, - 0, - chat_message=chat_message, chat_model=chat_model)) - # 重新设置缓存 - chat_cache.set(chat_id, - chat_info, timeout=60 * 30) - except Exception as e: - yield e - - r = StreamingHttpResponse(streaming_content=event_content(result_data), - content_type='text/event-stream;charset=utf-8') - - r['Cache-Control'] = 'no-cache' - return r + # 需要排除的文档 + exclude_document_id_list = [str(document.id) for document in + QuerySet(Document).filter( + dataset_id__in=dataset_id_list, + is_active=False)] + return ChatInfo(chat_id, chat_model, dataset_id_list, exclude_document_id_list, application) diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 474de0ee8..2204a8b07 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -10,7 +10,8 @@ import datetime import json import os import uuid -from typing import Dict +from functools import reduce +from typing import Dict, List from django.core.cache import cache from django.db import transaction @@ -18,7 +19,8 @@ from django.db.models import QuerySet from rest_framework import serializers from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord -from application.serializers.application_serializers import ModelDatasetAssociation +from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \ + ModelSettingSerializer from application.serializers.chat_message_serializers import ChatInfo from common.db.search import native_search, native_page_search, page_search from common.event import ListenerManagement @@ -26,8 +28,8 @@ from common.exception.app_exception import AppApiException from common.util.file_util import get_file_content from common.util.lock 
import try_lock, un_lock from common.util.rsa_util import decrypt +from common.util.split_model import flat_map from dataset.models import Document, Problem, Paragraph -from embedding.models import SourceType, Embedding from setting.models import Model from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from smartdoc.conf import PROJECT_DIR @@ -106,12 +108,12 @@ class ChatSerializers(serializers.Serializer): chat_id = str(uuid.uuid1()) chat_cache.set(chat_id, - ChatInfo(chat_id, model, chat_model, application_id, dataset_id_list, + ChatInfo(chat_id, chat_model, dataset_id_list, [str(document.id) for document in QuerySet(Document).filter( dataset_id__in=dataset_id_list, is_active=False)], - application.dialogue_number), timeout=60 * 30) + application), timeout=60 * 30) return chat_id class OpenTempChat(serializers.Serializer): @@ -122,6 +124,12 @@ class ChatSerializers(serializers.Serializer): multiple_rounds_dialogue = serializers.BooleanField(required=True) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) + # 数据集相关设置 + dataset_setting = DatasetSettingSerializer(required=True) + # 模型相关设置 + model_setting = ModelSettingSerializer(required=True) + # 问题补全 + problem_optimization = serializers.BooleanField(required=True) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -140,42 +148,62 @@ class ChatSerializers(serializers.Serializer): json.loads( decrypt(model.credential)), streaming=True) + application = Application(id=None, dialogue_number=3, model=model, + dataset_setting=self.data.get('dataset_setting'), + model_setting=self.data.get('model_setting'), + problem_optimization=self.data.get('problem_optimization')) chat_cache.set(chat_id, - ChatInfo(chat_id, model, chat_model, None, dataset_id_list, + ChatInfo(chat_id, chat_model, dataset_id_list, [str(document.id) for document in QuerySet(Document).filter( 
dataset_id__in=dataset_id_list, is_active=False)], - 3 if self.data.get('multiple_rounds_dialogue') else 1), timeout=60 * 30) + application), timeout=60 * 30) return chat_id -def vote_exec(source_type: SourceType, source_id: str, field: str, post_handler): - if source_type == SourceType.PROBLEM: - problem = QuerySet(Problem).get(id=source_id) - if problem is not None: - problem.__setattr__(field, post_handler(problem)) - problem.save() - embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type) - embedding.__setattr__(field, problem.__getattribute__(field)) - embedding.save() - if source_type == SourceType.PARAGRAPH: - paragraph = QuerySet(Paragraph).get(id=source_id) - if paragraph is not None: - paragraph.__setattr__(field, post_handler(paragraph)) - paragraph.save() - embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type) - embedding.__setattr__(field, paragraph.__getattribute__(field)) - embedding.save() - - class ChatRecordSerializerModel(serializers.ModelSerializer): class Meta: model = ChatRecord - fields = "__all__" + fields = ['id', 'chat_id', 'vote_status', 'problem_text', 'answer_text', + 'message_tokens', 'answer_tokens', 'const', 'improve_paragraph_id_list', 'run_time', 'index'] class ChatRecordSerializer(serializers.Serializer): + class Operate(serializers.Serializer): + chat_id = serializers.UUIDField(required=True) + + chat_record_id = serializers.UUIDField(required=True) + + def one(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + chat_record_id = self.data.get('chat_record_id') + chat_id = self.data.get('chat_id') + chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first() + dataset_list = [] + paragraph_list = [] + if len(chat_record.paragraph_id_list) > 0: + paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=chat_record.paragraph_id_list), + get_file_content( + os.path.join(PROJECT_DIR, "apps", "application", 'sql', + 
'list_dataset_paragraph_by_paragraph_id.sql')), + with_table_name=True) + dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y}, + [{row.get( + 'dataset_id'): row.get( + "dataset_name")} for + row in + paragraph_list], + {}).items()] + + return { + **ChatRecordSerializerModel(chat_record).data, + 'padding_problem_text': chat_record.details.get('problem_padding').get( + 'padding_problem_text') if 'problem_padding' in chat_record.details else None, + 'dataset_list': dataset_list, + 'paragraph_list': paragraph_list} + class Query(serializers.Serializer): application_id = serializers.UUIDField(required=True) chat_id = serializers.UUIDField(required=True) @@ -183,15 +211,57 @@ class ChatRecordSerializer(serializers.Serializer): def list(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) + QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')) return [ChatRecordSerializerModel(chat_record).data for chat_record in QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))] + def reset_chat_record_list(self, chat_record_list: List[ChatRecord]): + paragraph_id_list = flat_map([chat_record.paragraph_id_list for chat_record in chat_record_list]) + # 去重 + paragraph_id_list = list(set(paragraph_id_list)) + paragraph_list = self.search_paragraph(paragraph_id_list) + return [self.reset_chat_record(chat_record, paragraph_list) for chat_record in chat_record_list] + + @staticmethod + def search_paragraph(paragraph_id_list: List[str]): + paragraph_list = [] + if len(paragraph_id_list) > 0: + paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list), + get_file_content( + os.path.join(PROJECT_DIR, "apps", "application", 'sql', + 'list_dataset_paragraph_by_paragraph_id.sql')), + with_table_name=True) + return paragraph_list + + @staticmethod + def reset_chat_record(chat_record, all_paragraph_list): + paragraph_list = list( + filter(lambda paragraph: 
chat_record.paragraph_id_list.__contains__(str(paragraph.get('id'))), + all_paragraph_list)) + dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y}, + [{row.get( + 'dataset_id'): row.get( + "dataset_name")} for + row in + paragraph_list], + {}).items()] + return { + **ChatRecordSerializerModel(chat_record).data, + 'padding_problem_text': chat_record.details.get('problem_padding').get( + 'padding_problem_text') if 'problem_padding' in chat_record.details else None, + 'dataset_list': dataset_list, + 'paragraph_list': paragraph_list + } + def page(self, current_page: int, page_size: int, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - return page_search(current_page, page_size, + page = page_search(current_page, page_size, QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"), - post_records_handler=lambda chat_record: ChatRecordSerializerModel(chat_record).data) + post_records_handler=lambda chat_record: chat_record) + records = page.get('records') + page['records'] = self.reset_chat_record_list(records) + return page class Vote(serializers.Serializer): chat_id = serializers.UUIDField(required=True) @@ -216,38 +286,20 @@ class ChatRecordSerializer(serializers.Serializer): if vote_status == VoteChoices.STAR: # 点赞 chat_record_details_model.vote_status = VoteChoices.STAR - # 点赞数量 +1 - vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id, - 'star_num', - lambda r: r.star_num + 1) if vote_status == VoteChoices.TRAMPLE: # 点踩 chat_record_details_model.vote_status = VoteChoices.TRAMPLE - # 点踩数量+1 - vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id, - 'trample_num', - lambda r: r.trample_num + 1) chat_record_details_model.save() else: if vote_status == VoteChoices.UN_VOTE: # 取消点赞 chat_record_details_model.vote_status = VoteChoices.UN_VOTE chat_record_details_model.save() - if 
chat_record_details_model.vote_status == VoteChoices.STAR: - # 点赞数量 -1 - vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id, - 'star_num', lambda r: r.star_num - 1) - if chat_record_details_model.vote_status == VoteChoices.TRAMPLE: - # 点踩数量 -1 - vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id, - 'trample_num', lambda r: r.trample_num - 1) - else: raise AppApiException(500, "已经投票过,请先取消后再进行投票") finally: un_lock(self.data.get('chat_record_id')) - return True class ImproveSerializer(serializers.Serializer): diff --git a/apps/application/sql/list_application.sql b/apps/application/sql/list_application.sql index 283c7bb92..b7aa0fbe9 100644 --- a/apps/application/sql/list_application.sql +++ b/apps/application/sql/list_application.sql @@ -1,4 +1,4 @@ -SELECT * FROM ( SELECT * FROM application ${application_custom_sql} UNION +SELECT *,to_json(dataset_setting) as dataset_setting,to_json(model_setting) as model_setting FROM ( SELECT * FROM application ${application_custom_sql} UNION SELECT * FROM diff --git a/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql b/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql new file mode 100644 index 000000000..4c27ca07d --- /dev/null +++ b/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql @@ -0,0 +1,6 @@ +SELECT + paragraph.*, + dataset."name" AS "dataset_name" +FROM + paragraph paragraph + LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id \ No newline at end of file diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py index 54eccc9ef..2a7419496 100644 --- a/apps/application/swagger_api/application_api.py +++ b/apps/application/swagger_api/application_api.py @@ -58,6 +58,7 @@ class ApplicationApi(ApiMixin): 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'), 'update_time': openapi.Schema(type=openapi.TYPE_STRING, 
title="修改时间", description='修改时间'), + 'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), title="关联知识库Id列表", @@ -133,7 +134,7 @@ class ApplicationApi(ApiMixin): def get_request_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'], + required=[], properties={ 'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"), @@ -141,11 +142,52 @@ class ApplicationApi(ApiMixin): "multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话", description="是否开启多轮对话"), 'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"), - 'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), - title="示例列表", description="示例列表"), 'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), title="关联知识库Id列表", description="关联知识库Id列表"), + 'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(), + 'model_setting': ApplicationApi.ModelSetting.get_request_body_api(), + 'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化", + description="是否开启问题优化", default=True) + + } + ) + + class DatasetSetting(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=[''], + properties={ + 'top_n': openapi.Schema(type=openapi.TYPE_NUMBER, title="引用分段数", description="引用分段数", + default=5), + 'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title='相似度', description="相似度", + default=0.6), + 'max_paragraph_char_number': openapi.Schema(type=openapi.TYPE_NUMBER, title='最多引用字符数', + description="最多引用字符数", default=3000), + } + ) + + class ModelSetting(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + 
type=openapi.TYPE_OBJECT, + required=['prompt'], + properties={ + 'prompt': openapi.Schema(type=openapi.TYPE_STRING, title="提示词", description="提示词", + default=('已知信息:' + '\n{data}' + '\n回答要求:' + '\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。' + '\n- 避免提及你是从中获得的知识。' + '\n- 请保持答案与中描述的一致。' + '\n- 请使用markdown 语法优化答案的格式。' + '\n- 中的图片链接、链接地址和脚本语言请完整返回。' + '\n- 请使用与问题相同的语言来回答。' + '\n问题:' + '\n{question}')), } ) @@ -155,7 +197,8 @@ class ApplicationApi(ApiMixin): def get_request_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'], + required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting', + 'problem_optimization'], properties={ 'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"), @@ -163,11 +206,13 @@ class ApplicationApi(ApiMixin): "multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话", description="是否开启多轮对话"), 'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"), - 'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), - title="示例列表", description="示例列表"), 'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), - title="关联知识库Id列表", description="关联知识库Id列表") + title="关联知识库Id列表", description="关联知识库Id列表"), + 'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(), + 'model_setting': ApplicationApi.ModelSetting.get_request_body_api(), + 'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化", + description="是否开启问题优化", default=True) } ) diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py index 0cde547e1..bc36599c8 100644 --- a/apps/application/swagger_api/chat_api.py +++ 
b/apps/application/swagger_api/chat_api.py @@ -8,6 +8,7 @@ """ from drf_yasg import openapi +from application.swagger_api.application_api import ApplicationApi from common.mixins.api_mixin import ApiMixin @@ -19,6 +20,7 @@ class ChatApi(ApiMixin): required=['message'], properties={ 'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"), + 're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default="重新生成") } ) @@ -73,14 +75,19 @@ class ChatApi(ApiMixin): def get_request_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['model_id', 'multiple_rounds_dialogue'], + required=['model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting', + 'problem_optimization'], properties={ 'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"), 'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), title="关联知识库Id列表", description="关联知识库Id列表"), 'multiple_rounds_dialogue': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮会话", - description="是否开启多轮会话") + description="是否开启多轮会话"), + 'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(), + 'model_setting': ApplicationApi.ModelSetting.get_request_body_api(), + 'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化", + description="是否开启问题优化", default=True) } ) diff --git a/apps/application/urls.py b/apps/application/urls.py index ef74a8830..6944b4e6a 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -25,6 +25,8 @@ urlpatterns = [ path('application//chat//chat_record/', views.ChatView.ChatRecord.as_view()), path('application//chat//chat_record//', views.ChatView.ChatRecord.Page.as_view()), + path('application//chat//chat_record/', + views.ChatView.ChatRecord.Operate.as_view()), path('application//chat//chat_record//vote', views.ChatView.ChatRecord.Vote.as_view(), name=''), diff --git 
a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 3ee0223c3..c39c149c9 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -71,7 +71,8 @@ class ChatView(APIView): dynamic_tag=keywords.get('application_id'))]) ) def post(self, request: Request, chat_id: str): - return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message')) + return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'), request.data.get( + 're_chat') if 're_chat' in request.data else False) @action(methods=['GET'], detail=False) @swagger_auto_schema(operation_summary="获取对话列表", @@ -134,6 +135,27 @@ class ChatView(APIView): class ChatRecord(APIView): authentication_classes = [TokenAuth] + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取对话记录详情", + operation_id="获取对话记录详情", + manual_parameters=ChatRecordApi.get_request_params_api(), + responses=result.get_api_array_response(ChatRecordApi.get_response_body_api()), + tags=["应用/对话日志"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str): + return result.success(ChatRecordSerializer.Operate( + data={'application_id': application_id, + 'chat_id': chat_id, + 'chat_record_id': chat_record_id}).one()) + @action(methods=['GET'], detail=False) @swagger_auto_schema(operation_summary="获取对话记录列表", operation_id="获取对话记录列表", diff --git a/apps/common/db/search.py b/apps/common/db/search.py index f8d4a6878..763667154 100644 --- a/apps/common/db/search.py +++ b/apps/common/db/search.py @@ -83,7 +83,7 @@ def compiler_queryset(queryset: QuerySet, field_replace_dict: 
None | Dict[str, s field_replace_dict = get_field_replace_dict(queryset) app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, field_replace_dict=field_replace_dict) - sql, params = app_sql_compiler.get_query_str(with_table_name) + sql, params = app_sql_compiler.get_query_str(with_table_name=with_table_name) return sql, params diff --git a/apps/common/event/__init__.py b/apps/common/event/__init__.py index e00026645..f98e84f56 100644 --- a/apps/common/event/__init__.py +++ b/apps/common/event/__init__.py @@ -7,10 +7,8 @@ @desc: """ from .listener_manage import * -from .listener_chat_message import * def run(): listener_manage.ListenerManagement().run() - listener_chat_message.ListenerChatMessage().run() QuerySet(Document).filter(status=Status.embedding).update(**{'status': Status.error}) diff --git a/apps/common/event/listener_chat_message.py b/apps/common/event/listener_chat_message.py deleted file mode 100644 index 26c9fe3bc..000000000 --- a/apps/common/event/listener_chat_message.py +++ /dev/null @@ -1,67 +0,0 @@ -# coding=utf-8 -""" - @project: maxkb - @Author:虎 - @file: listener_manage.py - @date:2023/10/20 14:01 - @desc: -""" -import logging - -from blinker import signal -from django.db.models import QuerySet - -from application.models import ChatRecord, Chat -from application.serializers.chat_message_serializers import ChatMessage -from common.event.common import poxy - - -class RecordChatMessageArgs: - def __init__(self, index: int, chat_id: str, application_id: str, chat_message: ChatMessage): - self.index = index - self.chat_id = chat_id - self.application_id = application_id - self.chat_message = chat_message - - -class ListenerChatMessage: - record_chat_message_signal = signal("record_chat_message") - update_chat_message_token_signal = signal("update_chat_message_token") - - @staticmethod - def record_chat_message(args: RecordChatMessageArgs): - if not QuerySet(Chat).filter(id=args.chat_id).exists(): - 
Chat(id=args.chat_id, application_id=args.application_id, abstract=args.chat_message.problem).save() - # 插入会话记录 - try: - chat_record = ChatRecord( - id=args.chat_message.id, - chat_id=args.chat_id, - dataset_id=args.chat_message.dataset_id, - paragraph_id=args.chat_message.paragraph_id, - source_id=args.chat_message.source_id, - source_type=args.chat_message.source_type, - problem_text=args.chat_message.problem, - answer_text=args.chat_message.answer, - index=args.index, - message_tokens=args.chat_message.message_tokens, - answer_tokens=args.chat_message.answer_token) - chat_record.save() - except Exception as e: - print(e) - - @staticmethod - @poxy - def update_token(chat_message: ChatMessage): - if chat_message.chat_model is not None: - logging.getLogger("max_kb").info("开始更新token") - message_token = chat_message.chat_model.get_num_tokens_from_messages(chat_message.chat_message) - answer_token = chat_message.chat_model.get_num_tokens(chat_message.answer) - # 修改token数量 - QuerySet(ChatRecord).filter(id=chat_message.id).update( - **{'message_tokens': message_token, 'answer_tokens': answer_token}) - - def run(self): - # 记录会话 - ListenerChatMessage.record_chat_message_signal.connect(self.record_chat_message) - ListenerChatMessage.update_chat_message_token_signal.connect(self.update_token) diff --git a/apps/common/field/common.py b/apps/common/field/common.py new file mode 100644 index 000000000..a97469891 --- /dev/null +++ b/apps/common/field/common.py @@ -0,0 +1,34 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: common.py + @date:2024/1/11 18:44 + @desc: +""" +from rest_framework import serializers + + +class InstanceField(serializers.Field): + def __init__(self, model_type, **kwargs): + self.model_type = model_type + super().__init__(**kwargs) + + def to_internal_value(self, data): + if not isinstance(data, self.model_type): + self.fail('message类型错误', value=data) + return data + + def to_representation(self, value): + return value + + +class 
FunctionField(serializers.Field): + + def to_internal_value(self, data): + if not callable(data): + self.fail('不是一個函數', value=data) + return data + + def to_representation(self, value): + return value diff --git a/apps/common/sql/list_embedding_text.sql b/apps/common/sql/list_embedding_text.sql index e720041f2..b051da1c6 100644 --- a/apps/common/sql/list_embedding_text.sql +++ b/apps/common/sql/list_embedding_text.sql @@ -5,9 +5,7 @@ SELECT problem.dataset_id AS dataset_id, 0 AS source_type, problem."content" AS "text", - paragraph.is_active AS is_active, - problem.star_num as star_num, - problem.trample_num as trample_num + paragraph.is_active AS is_active FROM problem problem LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id @@ -23,9 +21,7 @@ SELECT concat_ws(' ',concat_ws(' ',paragraph.title,paragraph."content"),paragraph.title) AS "text", - paragraph.is_active AS is_active, - paragraph.star_num as star_num, - paragraph.trample_num as trample_num + paragraph.is_active AS is_active FROM paragraph paragraph diff --git a/apps/dataset/migrations/0004_remove_paragraph_hit_num_remove_paragraph_star_num_and_more.py b/apps/dataset/migrations/0004_remove_paragraph_hit_num_remove_paragraph_star_num_and_more.py new file mode 100644 index 000000000..3440fa39f --- /dev/null +++ b/apps/dataset/migrations/0004_remove_paragraph_hit_num_remove_paragraph_star_num_and_more.py @@ -0,0 +1,37 @@ +# Generated by Django 4.1.10 on 2024-01-16 11:22 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0003_alter_paragraph_content'), + ] + + operations = [ + migrations.RemoveField( + model_name='paragraph', + name='hit_num', + ), + migrations.RemoveField( + model_name='paragraph', + name='star_num', + ), + migrations.RemoveField( + model_name='paragraph', + name='trample_num', + ), + migrations.RemoveField( + model_name='problem', + name='hit_num', + ), + migrations.RemoveField( + model_name='problem', + 
name='star_num', + ), + migrations.RemoveField( + model_name='problem', + name='trample_num', + ), + ] diff --git a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py index 2b493c0c4..7a69e2753 100644 --- a/apps/dataset/models/data_set.py +++ b/apps/dataset/models/data_set.py @@ -74,9 +74,6 @@ class Paragraph(AppModelMixin): dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING) content = models.CharField(max_length=4096, verbose_name="段落内容") title = models.CharField(max_length=256, verbose_name="标题", default="") - hit_num = models.IntegerField(verbose_name="命中数量", default=0) - star_num = models.IntegerField(verbose_name="点赞数", default=0) - trample_num = models.IntegerField(verbose_name="点踩数", default=0) status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices, default=Status.embedding) is_active = models.BooleanField(default=True) @@ -94,9 +91,6 @@ class Problem(AppModelMixin): dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False) paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False) content = models.CharField(max_length=256, verbose_name="问题内容") - hit_num = models.IntegerField(verbose_name="命中数量", default=0) - star_num = models.IntegerField(verbose_name="点赞数", default=0) - trample_num = models.IntegerField(verbose_name="点踩数", default=0) class Meta: db_table = "problem" diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 1c0524099..72147db2e 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -26,12 +26,11 @@ from application.models import ApplicationDatasetMapping from common.config.embedding_config import VectorStore, EmbeddingModel from common.db.search import get_dynamics_model, native_page_search, native_search from common.db.sql_execute import select_list -from common.event.listener_manage import 
ListenerManagement, SyncWebDatasetArgs from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.util.common import post from common.util.file_util import get_file_content -from common.util.fork import ChildLink, Fork, ForkManage +from common.util.fork import ChildLink, Fork from common.util.split_model import get_split_model from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type from dataset.serializers.common_serializers import list_paragraph @@ -286,7 +285,8 @@ class DataSetSerializers(serializers.ModelSerializer): properties={ 'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"), - 'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url", description="web站点url"), + 'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url", + description="web站点url"), 'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器") } ) @@ -369,7 +369,8 @@ class DataSetSerializers(serializers.ModelSerializer): dataset_id = uuid.uuid1() dataset = DataSet( **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id, - 'type': Type.web, 'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}}) + 'type': Type.web, + 'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}}) dataset.save() ListenerManagement.sync_web_dataset_signal.send( SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'), diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index 48a1ef434..48694a7d8 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -28,7 +28,7 @@ from 
dataset.serializers.problem_serializers import ProblemInstanceSerializer, P class ParagraphSerializer(serializers.ModelSerializer): class Meta: model = Paragraph - fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title', + fields = ['id', 'content', 'is_active', 'document_id', 'title', 'create_time', 'update_time'] diff --git a/apps/dataset/serializers/problem_serializers.py b/apps/dataset/serializers/problem_serializers.py index 5948ef5b0..2f5fdc0a4 100644 --- a/apps/dataset/serializers/problem_serializers.py +++ b/apps/dataset/serializers/problem_serializers.py @@ -24,7 +24,7 @@ from embedding.vector.pg_vector import PGVector class ProblemSerializer(serializers.ModelSerializer): class Meta: model = Problem - fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id', 'document_id', + fields = ['id', 'content', 'dataset_id', 'document_id', 'create_time', 'update_time'] @@ -77,8 +77,8 @@ class ProblemSerializers(ApiMixin, serializers.Serializer): 'document_id': self.data.get('document_id'), 'paragraph_id': self.data.get('paragraph_id'), 'dataset_id': self.data.get('dataset_id'), - 'star_num': 0, - 'trample_num': 0}) + + }) return ProblemSerializers.Operate( data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'), diff --git a/apps/embedding/migrations/0002_remove_embedding_star_num_and_more.py b/apps/embedding/migrations/0002_remove_embedding_star_num_and_more.py new file mode 100644 index 000000000..5597c78b9 --- /dev/null +++ b/apps/embedding/migrations/0002_remove_embedding_star_num_and_more.py @@ -0,0 +1,37 @@ +# Generated by Django 4.1.10 on 2024-01-16 11:22 + +import django.contrib.postgres.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('embedding', '0001_initial'), + ] + + operations = [ + migrations.RemoveField( + model_name='embedding', + name='star_num', + ), + migrations.RemoveField( + 
model_name='embedding', + name='trample_num', + ), + migrations.AddField( + model_name='embedding', + name='keywords', + field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=256), default=list, size=None, verbose_name='关键词列表'), + ), + migrations.AddField( + model_name='embedding', + name='meta', + field=models.JSONField(default=dict, verbose_name='元数据'), + ), + migrations.AlterField( + model_name='embedding', + name='source_type', + field=models.CharField(choices=[('0', '问题'), ('1', '段落'), ('2', '标题')], default='0', max_length=5, verbose_name='资源类型'), + ), + ] diff --git a/apps/embedding/models/embedding.py b/apps/embedding/models/embedding.py index 83f786c68..caa6d7247 100644 --- a/apps/embedding/models/embedding.py +++ b/apps/embedding/models/embedding.py @@ -6,6 +6,7 @@ @date:2023/9/21 15:46 @desc: """ +from django.contrib.postgres.fields import ArrayField from django.db import models from common.field.vector_field import VectorField @@ -16,6 +17,7 @@ class SourceType(models.TextChoices): """订单类型""" PROBLEM = 0, '问题' PARAGRAPH = 1, '段落' + TITLE = 2, '标题' class Embedding(models.Model): @@ -36,10 +38,10 @@ class Embedding(models.Model): embedding = VectorField(verbose_name="向量") - star_num = models.IntegerField(default=0, verbose_name="点赞数量") + keywords = ArrayField(verbose_name="关键词列表", + base_field=models.CharField(max_length=256), default=list) - trample_num = models.IntegerField(default=0, - verbose_name="点踩数量") + meta = models.JSONField(verbose_name="元数据", default=dict) class Meta: db_table = "embedding" diff --git a/apps/embedding/sql/embedding_search.sql b/apps/embedding/sql/embedding_search.sql index 7787eb60b..ce3d4a580 100644 --- a/apps/embedding/sql/embedding_search.sql +++ b/apps/embedding/sql/embedding_search.sql @@ -1,15 +1,17 @@ -SELECT * FROM (SELECT - *, - ( 1 - ( embedding.embedding <=> %s ) ) AS similarity, -CASE - - WHEN embedding.star_num - embedding.trample_num = 0 THEN - 0 ELSE ( ( ( embedding.star_num - 
embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) ) - END AS score +SELECT + paragraph_id, + comprehensive_score, + comprehensive_score as similarity FROM - embedding, - ( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs - ${embedding_query} - ) temp - WHERE similarity>0.5 - ORDER BY (similarity + score) DESC LIMIT 1 \ No newline at end of file + ( + SELECT DISTINCT ON + ("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score + FROM + ( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query}) TEMP + ORDER BY + paragraph_id, + similarity DESC + ) DISTINCT_TEMP +WHERE comprehensive_score>%s +ORDER BY comprehensive_score DESC +LIMIT %s \ No newline at end of file diff --git a/apps/embedding/sql/hit_test.sql b/apps/embedding/sql/hit_test.sql index 141feee4e..8feffc86d 100644 --- a/apps/embedding/sql/hit_test.sql +++ b/apps/embedding/sql/hit_test.sql @@ -1,34 +1,17 @@ SELECT - similarity, paragraph_id, - comprehensive_score + comprehensive_score, + comprehensive_score as similarity FROM ( SELECT DISTINCT ON - ( "paragraph_id" ) ( similarity + score ),*, - ( similarity + score ) AS comprehensive_score + ("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score FROM - ( - SELECT - *, - ( 1 - ( embedding.embedding <=> %s ) ) AS similarity, - CASE - - WHEN embedding.star_num - embedding.trample_num = 0 THEN - 0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) ) - END AS score - FROM - embedding, - ( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs - ${embedding_query} - ) TEMP - WHERE - similarity > %s + ( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query} ) TEMP ORDER BY paragraph_id, - ( similarity 
+ score ) - DESC - ) ss -ORDER BY - comprehensive_score DESC - LIMIT %s \ No newline at end of file + similarity DESC + ) DISTINCT_TEMP +WHERE comprehensive_score>%s +ORDER BY comprehensive_score DESC +LIMIT %s \ No newline at end of file diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index a5c14f413..ce96aa3df 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -99,8 +99,6 @@ class BaseVectorStore(ABC): @abstractmethod def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, is_active: bool, - star_num: int, - trample_num: int, embedding: HuggingFaceEmbeddings): pass @@ -108,11 +106,20 @@ class BaseVectorStore(ABC): def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings): pass - @abstractmethod def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], - exclude_id_list: list[str], + exclude_paragraph_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings): + if dataset_id_list is None or len(dataset_id_list) == 0: + return [] + embedding_query = embedding.embed_query(query_text) + result = self.query(embedding_query, dataset_id_list, exclude_document_id_list, exclude_paragraph_list, + is_active, 1, 0.65) + return result[0] + + @abstractmethod + def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str], + exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float): pass @abstractmethod diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index 79f8be9bf..0091e31da 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -33,8 +33,6 @@ class PGVector(BaseVectorStore): def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, is_active: bool, - star_num: 
int, - trample_num: int, embedding: HuggingFaceEmbeddings): text_embedding = embedding.embed_query(text) embedding = Embedding(id=uuid.uuid1(), @@ -44,10 +42,7 @@ class PGVector(BaseVectorStore): paragraph_id=paragraph_id, source_id=source_id, embedding=text_embedding, - source_type=source_type, - star_num=star_num, - trample_num=trample_num - ) + source_type=source_type) embedding.save() return True @@ -61,8 +56,6 @@ class PGVector(BaseVectorStore): is_active=text_list[index].get('is_active', True), source_id=text_list[index].get('source_id'), source_type=text_list[index].get('source_type'), - star_num=text_list[index].get('star_num'), - trample_num=text_list[index].get('trample_num'), embedding=embeddings[index]) for index in range(0, len(text_list))] QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None @@ -78,29 +71,27 @@ class PGVector(BaseVectorStore): 'hit_test.sql')), with_table_name=True) embedding_model = select_list(exec_sql, - [json.dumps(embedding_query), *exec_params, *exec_params, similarity, top_number]) + [json.dumps(embedding_query), *exec_params, similarity, top_number]) return embedding_model - def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], - exclude_id_list: list[str], - is_active: bool, - embedding: HuggingFaceEmbeddings): + def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str], + exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float): exclude_dict = {} if dataset_id_list is None or len(dataset_id_list) == 0: - return None + return [] query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active) - embedding_query = embedding.embed_query(query_text) if exclude_document_id_list is not None and len(exclude_document_id_list) > 0: exclude_dict.__setitem__('document_id__in', exclude_document_id_list) - if exclude_id_list is not None and len(exclude_id_list) > 
0: - exclude_dict.__setitem__('id__in', exclude_id_list) + if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0: + exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list) query_set = query_set.exclude(**exclude_dict) exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set}, select_string=get_file_content( os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'embedding_search.sql')), with_table_name=True) - embedding_model = select_one(exec_sql, (json.dumps(embedding_query), *exec_params, *exec_params)) + embedding_model = select_list(exec_sql, + [json.dumps(query_embedding), *exec_params, similarity, top_n]) return embedding_model def update_by_source_id(self, source_id: str, instance: Dict): diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py b/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py index bd6a5d1cf..3aeb0cdb3 100644 --- a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py @@ -14,13 +14,23 @@ from langchain.chat_models.base import BaseChatModel from langchain.load import dumpd from langchain.schema import LLMResult from langchain.schema.language_model import LanguageModelInput -from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage +from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string from langchain.schema.output import ChatGenerationChunk from langchain.schema.runnable import RunnableConfig +from transformers import GPT2TokenizerFast + +tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False, + force_download=False) class QianfanChatModel(QianfanChatEndpoint): + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + 
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + return len(tokenizer.encode(text)) + def stream( self, input: LanguageModelInput, @@ -30,7 +40,7 @@ class QianfanChatModel(QianfanChatEndpoint): **kwargs: Any, ) -> Iterator[BaseMessageChunk]: if len(input) % 2 == 0: - input = [HumanMessage(content='占位'), *input] + input = [HumanMessage(content='padding'), *input] input = [ HumanMessage(content=input[index].content) if index % 2 == 0 else AIMessage(content=input[index].content) for index in range(0, len(input))]