diff --git a/apps/application/chat_pipeline/I_base_chat_pipeline.py b/apps/application/chat_pipeline/I_base_chat_pipeline.py
new file mode 100644
index 000000000..eaa255d13
--- /dev/null
+++ b/apps/application/chat_pipeline/I_base_chat_pipeline.py
@@ -0,0 +1,55 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: I_base_chat_pipeline.py
+ @date:2024/1/9 17:25
+ @desc:
+"""
+import time
+from abc import abstractmethod
+from typing import Type
+
+from rest_framework import serializers
+
+
+class IBaseChatPipelineStep:
+ def __init__(self):
+ # 当前步骤上下文,用于存储当前步骤信息
+ self.context = {}
+
+ @abstractmethod
+ def get_step_serializer(self, manage) -> Type[serializers.Serializer]:
+ pass
+
+ def valid_args(self, manage):
+ step_serializer_clazz = self.get_step_serializer(manage)
+ step_serializer = step_serializer_clazz(data=manage.context)
+ step_serializer.is_valid(raise_exception=True)
+ self.context['step_args'] = step_serializer.data
+
+ def run(self, manage):
+ """
+
+ :param manage: 步骤管理器
+ :return: 执行结果
+ """
+ start_time = time.time()
+ # 校验参数,
+ self.valid_args(manage)
+ self._run(manage)
+ self.context['start_time'] = start_time
+ self.context['run_time'] = time.time() - start_time
+
+ def _run(self, manage):
+ pass
+
+ def execute(self, **kwargs):
+ pass
+
+ def get_details(self, manage, **kwargs):
+ """
+ 运行详情
+ :return: 步骤详情
+ """
+ return None
diff --git a/apps/application/chat_pipeline/__init__.py b/apps/application/chat_pipeline/__init__.py
new file mode 100644
index 000000000..719a7e29c
--- /dev/null
+++ b/apps/application/chat_pipeline/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py.py
+ @date:2024/1/9 17:23
+ @desc:
+"""
diff --git a/apps/application/chat_pipeline/pipeline_manage.py b/apps/application/chat_pipeline/pipeline_manage.py
new file mode 100644
index 000000000..2b94290d2
--- /dev/null
+++ b/apps/application/chat_pipeline/pipeline_manage.py
@@ -0,0 +1,45 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: pipeline_manage.py
+ @date:2024/1/9 17:40
+ @desc:
+"""
+import time
+from functools import reduce
+from typing import List, Type, Dict
+
+from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
+
+
+class PiplineManage:
+ def __init__(self, step_list: List[Type[IBaseChatPipelineStep]]):
+ # 步骤执行器
+ self.step_list = [step() for step in step_list]
+ # 上下文
+ self.context = {'message_tokens': 0, 'answer_tokens': 0}
+
+ def run(self, context: Dict = None):
+ self.context['start_time'] = time.time()
+ if context is not None:
+ for key, value in context.items():
+ self.context[key] = value
+ for step in self.step_list:
+ step.run(self)
+
+ def get_details(self):
+ return reduce(lambda x, y: {**x, **y}, [{item.get('step_type'): item} for item in
+ filter(lambda r: r is not None,
+ [row.get_details(self) for row in self.step_list])], {})
+
+ class builder:
+ def __init__(self):
+ self.step_list: List[Type[IBaseChatPipelineStep]] = []
+
+ def append_step(self, step: Type[IBaseChatPipelineStep]):
+ self.step_list.append(step)
+ return self
+
+ def build(self):
+ return PiplineManage(step_list=self.step_list)
diff --git a/apps/application/chat_pipeline/step/__init__.py b/apps/application/chat_pipeline/step/__init__.py
new file mode 100644
index 000000000..5d9549cdc
--- /dev/null
+++ b/apps/application/chat_pipeline/step/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py.py
+ @date:2024/1/9 18:23
+ @desc:
+"""
diff --git a/apps/application/chat_pipeline/step/chat_step/__init__.py b/apps/application/chat_pipeline/step/chat_step/__init__.py
new file mode 100644
index 000000000..5d9549cdc
--- /dev/null
+++ b/apps/application/chat_pipeline/step/chat_step/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py.py
+ @date:2024/1/9 18:23
+ @desc:
+"""
diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
new file mode 100644
index 000000000..b02273ae8
--- /dev/null
+++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
@@ -0,0 +1,88 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_chat_step.py
+ @date:2024/1/9 18:17
+ @desc: 对话
+"""
+from abc import abstractmethod
+from typing import Type, List
+
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import BaseMessage
+from rest_framework import serializers
+
+from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
+from application.chat_pipeline.pipeline_manage import PiplineManage
+from common.field.common import InstanceField
+from dataset.models import Paragraph
+
+
+class ModelField(serializers.Field):
+ def to_internal_value(self, data):
+ if not isinstance(data, BaseChatModel):
+ self.fail('模型类型错误', value=data)
+ return data
+
+ def to_representation(self, value):
+ return value
+
+
+class MessageField(serializers.Field):
+ def to_internal_value(self, data):
+ if not isinstance(data, BaseMessage):
+ self.fail('message类型错误', value=data)
+ return data
+
+ def to_representation(self, value):
+ return value
+
+
+class PostResponseHandler:
+ @abstractmethod
+ def handler(self, chat_id, chat_record_id, paragraph_list: List[Paragraph], problem_text: str, answer_text,
+ manage, step, padding_problem_text: str = None, **kwargs):
+ pass
+
+
+class IChatStep(IBaseChatPipelineStep):
+ class InstanceSerializer(serializers.Serializer):
+ # 对话列表
+ message_list = serializers.ListField(required=True, child=MessageField(required=True))
+ # 大语言模型
+ chat_model = ModelField()
+ # 段落列表
+ paragraph_list = serializers.ListField()
+ # 对话id
+ chat_id = serializers.UUIDField(required=True)
+ # 用户问题
+ problem_text = serializers.CharField(required=True)
+ # 后置处理器
+ post_response_handler = InstanceField(model_type=PostResponseHandler)
+ # 补全问题
+ padding_problem_text = serializers.CharField(required=False)
+
+ def is_valid(self, *, raise_exception=False):
+ super().is_valid(raise_exception=True)
+ message_list: List = self.initial_data.get('message_list')
+ for message in message_list:
+ if not isinstance(message, BaseMessage):
+ raise Exception("message 类型错误")
+
+ def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
+ return self.InstanceSerializer
+
+ def _run(self, manage: PiplineManage):
+ chat_result = self.execute(**self.context['step_args'], manage=manage)
+ manage.context['chat_result'] = chat_result
+
+ @abstractmethod
+ def execute(self, message_list: List[BaseMessage],
+ chat_id, problem_text,
+ post_response_handler: PostResponseHandler,
+ chat_model: BaseChatModel = None,
+ paragraph_list=None,
+ manage: PiplineManage = None,
+ padding_problem_text: str = None, **kwargs):
+ pass
diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
new file mode 100644
index 000000000..123574459
--- /dev/null
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -0,0 +1,111 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_chat_step.py
+ @date:2024/1/9 18:25
+ @desc: 对话step Base实现
+"""
+import json
+import logging
+import time
+import traceback
+import uuid
+from typing import List
+
+from django.http import StreamingHttpResponse
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import BaseMessage
+from langchain.schema.messages import BaseMessageChunk, HumanMessage
+
+from application.chat_pipeline.pipeline_manage import PiplineManage
+from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
+from dataset.models import Paragraph
+
+
+def event_content(response,
+ chat_id,
+ chat_record_id,
+ paragraph_list: List[Paragraph],
+ post_response_handler: PostResponseHandler,
+ manage,
+ step,
+ chat_model,
+ message_list: List[BaseMessage],
+ problem_text: str,
+ padding_problem_text: str = None):
+ all_text = ''
+ try:
+ for chunk in response:
+ all_text += chunk.content
+ yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ 'content': chunk.content, 'is_end': False}) + "\n\n"
+
+ yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ 'content': '', 'is_end': True}) + "\n\n"
+ # 获取token
+ request_token = chat_model.get_num_tokens_from_messages(message_list)
+ response_token = chat_model.get_num_tokens(all_text)
+ step.context['message_tokens'] = request_token
+ step.context['answer_tokens'] = response_token
+ current_time = time.time()
+ step.context['answer_text'] = all_text
+ step.context['run_time'] = current_time - step.context['start_time']
+ manage.context['run_time'] = current_time - manage.context['start_time']
+ manage.context['message_tokens'] = manage.context['message_tokens'] + request_token
+ manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
+ post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
+ all_text, manage, step, padding_problem_text)
+ except Exception as e:
+ logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
+ yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ 'content': '异常' + str(e), 'is_end': True}) + "\n\n"
+
+
+class BaseChatStep(IChatStep):
+ def execute(self, message_list: List[BaseMessage],
+ chat_id,
+ problem_text,
+ post_response_handler: PostResponseHandler,
+ chat_model: BaseChatModel = None,
+ paragraph_list=None,
+ manage: PiplineManage = None,
+ padding_problem_text: str = None,
+ **kwargs):
+ # 调用模型
+ if chat_model is None:
+ chat_result = iter(
+ [BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
+ else:
+ chat_result = chat_model.stream(message_list)
+
+ chat_record_id = uuid.uuid1()
+ r = StreamingHttpResponse(
+ streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
+ post_response_handler, manage, self, chat_model, message_list, problem_text,
+ padding_problem_text),
+ content_type='text/event-stream;charset=utf-8')
+
+ r['Cache-Control'] = 'no-cache'
+ return r
+
+ def get_details(self, manage, **kwargs):
+ return {
+ 'step_type': 'chat_step',
+ 'run_time': self.context['run_time'],
+ 'model_id': str(manage.context['model_id']),
+ 'message_list': self.reset_message_list(self.context['step_args'].get('message_list'),
+ self.context['answer_text']),
+ 'message_tokens': self.context['message_tokens'],
+ 'answer_tokens': self.context['answer_tokens'],
+ 'cost': 0,
+ }
+
+ @staticmethod
+ def reset_message_list(message_list: List[BaseMessage], answer_text):
+ result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
+ message
+ in
+ message_list]
+ result.append({'role': 'ai', 'content': answer_text})
+ return result
diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/__init__.py b/apps/application/chat_pipeline/step/generate_human_message_step/__init__.py
new file mode 100644
index 000000000..5d9549cdc
--- /dev/null
+++ b/apps/application/chat_pipeline/step/generate_human_message_step/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py.py
+ @date:2024/1/9 18:23
+ @desc:
+"""
diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py
new file mode 100644
index 000000000..e26b06347
--- /dev/null
+++ b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py
@@ -0,0 +1,68 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_generate_human_message_step.py
+ @date:2024/1/9 18:15
+ @desc: 生成对话模板
+"""
+from abc import abstractmethod
+from typing import Type, List
+
+from langchain.schema import BaseMessage
+from rest_framework import serializers
+
+from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
+from application.chat_pipeline.pipeline_manage import PiplineManage
+from application.models import ChatRecord
+from common.field.common import InstanceField
+from dataset.models import Paragraph
+
+
+class IGenerateHumanMessageStep(IBaseChatPipelineStep):
+ class InstanceSerializer(serializers.Serializer):
+ # 问题
+ problem_text = serializers.CharField(required=True)
+ # 段落列表
+ paragraph_list = serializers.ListField(child=InstanceField(model_type=Paragraph, required=True))
+ # 历史对答
+ history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
+ # 多轮对话数量
+ dialogue_number = serializers.IntegerField(required=True)
+ # 最大携带知识库段落长度
+ max_paragraph_char_number = serializers.IntegerField(required=True)
+ # 模板
+ prompt = serializers.CharField(required=True)
+ # 补齐问题
+ padding_problem_text = serializers.CharField(required=False)
+
+ def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
+ return self.InstanceSerializer
+
+ def _run(self, manage: PiplineManage):
+ message_list = self.execute(**self.context['step_args'])
+ manage.context['message_list'] = message_list
+
+ @abstractmethod
+ def execute(self,
+ problem_text: str,
+ paragraph_list: List[Paragraph],
+ history_chat_record: List[ChatRecord],
+ dialogue_number: int,
+ max_paragraph_char_number: int,
+ prompt: str,
+ padding_problem_text: str = None,
+ **kwargs) -> List[BaseMessage]:
+ """
+
+ :param problem_text: 原始问题文本
+ :param paragraph_list: 段落列表
+ :param history_chat_record: 历史对话记录
+ :param dialogue_number: 多轮对话数量
+ :param max_paragraph_char_number: 最大段落长度
+ :param prompt: 模板
+ :param padding_problem_text 用户修改文本
+ :param kwargs: 其他参数
+ :return:
+ """
+ pass
diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py
new file mode 100644
index 000000000..fcd35ac75
--- /dev/null
+++ b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py
@@ -0,0 +1,57 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_generate_human_message_step.py.py
+ @date:2024/1/10 17:50
+ @desc:
+"""
+from typing import List
+
+from langchain.schema import BaseMessage, HumanMessage
+
+from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
+ IGenerateHumanMessageStep
+from application.models import ChatRecord
+from common.util.split_model import flat_map
+from dataset.models import Paragraph
+
+
+class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
+
+ def execute(self, problem_text: str,
+ paragraph_list: List[Paragraph],
+ history_chat_record: List[ChatRecord],
+ dialogue_number: int,
+ max_paragraph_char_number: int,
+ prompt: str,
+ padding_problem_text: str = None,
+ **kwargs) -> List[BaseMessage]:
+ exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
+ start_index = len(history_chat_record) - dialogue_number
+ history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
+ for index in
+ range(start_index if start_index > 0 else 0, len(history_chat_record))]
+ return [*flat_map(history_message),
+ self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list)]
+
+ @staticmethod
+ def to_human_message(prompt: str,
+ problem: str,
+ max_paragraph_char_number: int,
+ paragraph_list: List[Paragraph]):
+ if paragraph_list is None or len(paragraph_list) == 0:
+ return HumanMessage(content=problem)
+ temp_data = ""
+ data_list = []
+ for p in paragraph_list:
+ content = f"{p.title}:{p.content}"
+ temp_data += content
+ if len(temp_data) > max_paragraph_char_number:
+ row_data = content[0:max_paragraph_char_number - len(temp_data)]
+ data_list.append(f"{row_data}")
+ break
+ else:
+ data_list.append(f"{content}")
+ data = "\n".join(data_list)
+ return HumanMessage(content=prompt.format(**{'data': data, 'question': problem}))
diff --git a/apps/application/chat_pipeline/step/reset_problem_step/__init__.py b/apps/application/chat_pipeline/step/reset_problem_step/__init__.py
new file mode 100644
index 000000000..5d9549cdc
--- /dev/null
+++ b/apps/application/chat_pipeline/step/reset_problem_step/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py.py
+ @date:2024/1/9 18:23
+ @desc:
+"""
diff --git a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py
new file mode 100644
index 000000000..c8f9143dd
--- /dev/null
+++ b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py
@@ -0,0 +1,49 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_reset_problem_step.py
+ @date:2024/1/9 18:12
+ @desc: 重写处理问题
+"""
+from abc import abstractmethod
+from typing import Type, List
+
+from langchain.chat_models.base import BaseChatModel
+from rest_framework import serializers
+
+from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
+from application.chat_pipeline.pipeline_manage import PiplineManage
+from application.chat_pipeline.step.chat_step.i_chat_step import ModelField
+from application.models import ChatRecord
+from common.field.common import InstanceField
+
+
+class IResetProblemStep(IBaseChatPipelineStep):
+ class InstanceSerializer(serializers.Serializer):
+ # 问题文本
+ problem_text = serializers.CharField(required=True)
+ # 历史对答
+ history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
+ # 大语言模型
+ chat_model = ModelField()
+
+ def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
+ return self.InstanceSerializer
+
+ def _run(self, manage: PiplineManage):
+ padding_problem = self.execute(**self.context.get('step_args'))
+ # 用户输入问题
+ source_problem_text = self.context.get('step_args').get('problem_text')
+ self.context['problem_text'] = source_problem_text
+ self.context['padding_problem_text'] = padding_problem
+ manage.context['problem_text'] = source_problem_text
+ manage.context['padding_problem_text'] = padding_problem
+ # 累加tokens
+ manage.context['message_tokens'] = manage.context['message_tokens'] + self.context.get('message_tokens')
+ manage.context['answer_tokens'] = manage.context['answer_tokens'] + self.context.get('answer_tokens')
+
+ @abstractmethod
+ def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
+ **kwargs):
+ pass
diff --git a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py
new file mode 100644
index 000000000..14a68d49a
--- /dev/null
+++ b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py
@@ -0,0 +1,47 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_reset_problem_step.py
+ @date:2024/1/10 14:35
+ @desc:
+"""
+from typing import List
+
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import HumanMessage
+
+from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
+from application.models import ChatRecord
+from common.util.split_model import flat_map
+
+prompt = (
+ '()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中')
+
+
+class BaseResetProblemStep(IResetProblemStep):
+ def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
+ **kwargs) -> str:
+ start_index = len(history_chat_record) - 3
+ history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
+ for index in
+ range(start_index if start_index > 0 else 0, len(history_chat_record))]
+ message_list = [*flat_map(history_message),
+ HumanMessage(content=prompt.format(**{'question': problem_text}))]
+ response = chat_model(message_list)
+ padding_problem = response.content[response.content.index('') + 6:response.content.index('')]
+ self.context['message_tokens'] = chat_model.get_num_tokens_from_messages(message_list)
+ self.context['answer_tokens'] = chat_model.get_num_tokens(padding_problem)
+ return padding_problem
+
+ def get_details(self, manage, **kwargs):
+ return {
+ 'step_type': 'problem_padding',
+ 'run_time': self.context['run_time'],
+ 'model_id': str(manage.context['model_id']) if 'model_id' in manage.context else None,
+ 'message_tokens': self.context['message_tokens'],
+ 'answer_tokens': self.context['answer_tokens'],
+ 'cost': 0,
+ 'padding_problem_text': self.context.get('padding_problem_text'),
+ 'problem_text': self.context.get("step_args").get('problem_text'),
+ }
diff --git a/apps/application/chat_pipeline/step/search_dataset_step/__init__.py b/apps/application/chat_pipeline/step/search_dataset_step/__init__.py
new file mode 100644
index 000000000..023c4bc38
--- /dev/null
+++ b/apps/application/chat_pipeline/step/search_dataset_step/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: __init__.py.py
+ @date:2024/1/9 18:24
+ @desc:
+"""
diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py
new file mode 100644
index 000000000..a79bad8ba
--- /dev/null
+++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py
@@ -0,0 +1,58 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: i_search_dataset_step.py
+ @date:2024/1/9 18:10
+ @desc: 检索知识库
+"""
+from abc import abstractmethod
+from typing import List, Type
+
+from rest_framework import serializers
+
+from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
+from application.chat_pipeline.pipeline_manage import PiplineManage
+from dataset.models import Paragraph
+
+
+class ISearchDatasetStep(IBaseChatPipelineStep):
+ class InstanceSerializer(serializers.Serializer):
+ # 原始问题文本
+ problem_text = serializers.CharField(required=True)
+ # 系统补全问题文本
+ padding_problem_text = serializers.CharField(required=False)
+ # 需要查询的数据集id列表
+ dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
+ # 需要排除的文档id
+ exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
+ # 需要排除向量id
+ exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
+ # 需要查询的条数
+ top_n = serializers.IntegerField(required=True)
+ # 相似度 0-1之间
+ similarity = serializers.FloatField(required=True, max_value=1, min_value=0)
+
+ def get_step_serializer(self, manage: PiplineManage) -> Type[InstanceSerializer]:
+ return self.InstanceSerializer
+
+ def _run(self, manage: PiplineManage):
+ paragraph_list = self.execute(**self.context['step_args'])
+ manage.context['paragraph_list'] = paragraph_list
+
+ @abstractmethod
+ def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
+ exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
+ **kwargs) -> List[Paragraph]:
+ """
+ 关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
+ :param similarity: 相关性
+ :param top_n: 查询多少条
+ :param problem_text: 用户问题
+ :param dataset_id_list: 需要查询的数据集id列表
+ :param exclude_document_id_list: 需要排除的文档id
+ :param exclude_paragraph_id_list: 需要排除段落id
+ :param padding_problem_text 补全问题
+ :return: 段落列表
+ """
+ pass
diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
new file mode 100644
index 000000000..4fa13639a
--- /dev/null
+++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
@@ -0,0 +1,58 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_search_dataset_step.py
+ @date:2024/1/10 10:33
+ @desc:
+"""
+from typing import List
+
+from django.db.models import QuerySet
+
+from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
+from common.config.embedding_config import VectorStore, EmbeddingModel
+from dataset.models import Paragraph
+
+
+class BaseSearchDatasetStep(ISearchDatasetStep):
+
+ def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
+ exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
+ **kwargs) -> List[Paragraph]:
+ exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
+ embedding_model = EmbeddingModel.get_embedding_model()
+ embedding_value = embedding_model.embed_query(exec_problem_text)
+ vector = VectorStore.get_embedding_vector()
+ embedding_list = vector.query(embedding_value, dataset_id_list, exclude_document_id_list,
+ exclude_paragraph_id_list, True, top_n, similarity)
+ if embedding_list is None:
+ return []
+ return self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)
+
+ @staticmethod
+ def list_paragraph(paragraph_id_list: List, vector):
+ if paragraph_id_list is None or len(paragraph_id_list) == 0:
+ return []
+ paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list)
+ # 如果向量库中存在脏数据 直接删除
+ if len(paragraph_list) != len(paragraph_id_list):
+ exist_paragraph_list = [str(row.id) for row in paragraph_list]
+ for paragraph_id in paragraph_id_list:
+ if not exist_paragraph_list.__contains__(paragraph_id):
+ vector.delete_by_paragraph_id(paragraph_id)
+ return paragraph_list
+
+ def get_details(self, manage, **kwargs):
+ step_args = self.context['step_args']
+
+ return {
+ 'step_type': 'search_step',
+ 'run_time': self.context['run_time'],
+ 'problem_text': step_args.get(
+ 'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
+ 'model_name': EmbeddingModel.get_embedding_model().model_name,
+ 'message_tokens': 0,
+ 'answer_tokens': 0,
+ 'cost': 0
+ }
diff --git a/apps/application/migrations/0003_remove_chatrecord_dataset_and_more.py b/apps/application/migrations/0003_remove_chatrecord_dataset_and_more.py
new file mode 100644
index 000000000..ec041426f
--- /dev/null
+++ b/apps/application/migrations/0003_remove_chatrecord_dataset_and_more.py
@@ -0,0 +1,55 @@
+# Generated by Django 4.1.10 on 2024-01-12 18:46
+
+import django.contrib.postgres.fields
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('application', '0002_alter_chatrecord_dataset'),
+ ]
+
+ operations = [
+ migrations.RemoveField(
+ model_name='chatrecord',
+ name='dataset',
+ ),
+ migrations.RemoveField(
+ model_name='chatrecord',
+ name='paragraph',
+ ),
+ migrations.RemoveField(
+ model_name='chatrecord',
+ name='source_id',
+ ),
+ migrations.RemoveField(
+ model_name='chatrecord',
+ name='source_type',
+ ),
+ migrations.AddField(
+ model_name='chatrecord',
+ name='const',
+ field=models.IntegerField(default=0, verbose_name='总费用'),
+ ),
+ migrations.AddField(
+ model_name='chatrecord',
+ name='details',
+ field=models.JSONField(default=list, verbose_name='对话详情'),
+ ),
+ migrations.AddField(
+ model_name='chatrecord',
+ name='paragraph_id_list',
+ field=django.contrib.postgres.fields.ArrayField(base_field=models.UUIDField(blank=True), default=list, size=None, verbose_name='引用段落id列表'),
+ ),
+ migrations.AddField(
+ model_name='chatrecord',
+ name='run_time',
+ field=models.FloatField(default=0, verbose_name='运行时长'),
+ ),
+ migrations.AlterField(
+ model_name='chatrecord',
+ name='answer_text',
+ field=models.CharField(max_length=4096, verbose_name='答案'),
+ ),
+ ]
diff --git a/apps/application/migrations/0004_remove_application_example_and_more.py b/apps/application/migrations/0004_remove_application_example_and_more.py
new file mode 100644
index 000000000..02647f80e
--- /dev/null
+++ b/apps/application/migrations/0004_remove_application_example_and_more.py
@@ -0,0 +1,38 @@
+# Generated by Django 4.1.10 on 2024-01-15 16:07
+
+import application.models.application
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('application', '0003_remove_chatrecord_dataset_and_more'),
+ ]
+
+ operations = [
+ migrations.RemoveField(
+ model_name='application',
+ name='example',
+ ),
+ migrations.AddField(
+ model_name='application',
+ name='dataset_setting',
+ field=models.JSONField(default=application.models.application.get_dataset_setting_dict, verbose_name='数据集参数设置'),
+ ),
+ migrations.AddField(
+ model_name='application',
+ name='model_setting',
+ field=models.JSONField(default=application.models.application.get_model_setting_dict, verbose_name='模型参数相关设置'),
+ ),
+ migrations.AddField(
+ model_name='application',
+ name='problem_optimization',
+ field=models.BooleanField(default=False, verbose_name='问题优化'),
+ ),
+ migrations.AlterField(
+ model_name='chatrecord',
+ name='details',
+ field=models.JSONField(default={}, verbose_name='对话详情'),
+ ),
+ ]
diff --git a/apps/application/migrations/0005_alter_chatrecord_details.py b/apps/application/migrations/0005_alter_chatrecord_details.py
new file mode 100644
index 000000000..5deb1e1a1
--- /dev/null
+++ b/apps/application/migrations/0005_alter_chatrecord_details.py
@@ -0,0 +1,18 @@
+# Generated by Django 4.1.10 on 2024-01-16 11:22
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('application', '0004_remove_application_example_and_more'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='chatrecord',
+ name='details',
+ field=models.JSONField(default=dict, verbose_name='对话详情'),
+ ),
+ ]
diff --git a/apps/application/models/application.py b/apps/application/models/application.py
index 5c875296d..03b7eb142 100644
--- a/apps/application/models/application.py
+++ b/apps/application/models/application.py
@@ -10,23 +10,47 @@ import uuid
from django.contrib.postgres.fields import ArrayField
from django.db import models
+from langchain.schema import HumanMessage, AIMessage
from common.mixins.app_model_mixin import AppModelMixin
-from dataset.models.data_set import DataSet, Paragraph
-from embedding.models import SourceType
+from dataset.models.data_set import DataSet
from setting.models.model_management import Model
from users.models import User
+def get_dataset_setting_dict():
+ return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000}
+
+
+def get_model_setting_dict():
+ return {'prompt': Application.get_default_model_prompt()}
+
+
class Application(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
name = models.CharField(max_length=128, verbose_name="应用名称")
desc = models.CharField(max_length=128, verbose_name="引用描述", default="")
prologue = models.CharField(max_length=1024, verbose_name="开场白", default="")
- example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True), default=list)
dialogue_number = models.IntegerField(default=0, verbose_name="会话数量")
user = models.ForeignKey(User, on_delete=models.DO_NOTHING)
model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
+ dataset_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict)
+ model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
+ problem_optimization = models.BooleanField(verbose_name="问题优化", default=False)
+
+ @staticmethod
+ def get_default_model_prompt():
+ return ('已知信息:'
+ '\n{data}'
+ '\n回答要求:'
+ '\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。'
+ '\n- 避免提及你是从中获得的知识。'
+ '\n- 请保持答案与中描述的一致。'
+ '\n- 请使用markdown 语法优化答案的格式。'
+ '\n- 中的图片链接、链接地址和脚本语言请完整返回。'
+ '\n- 请使用与问题相同的语言来回答。'
+ '\n问题:'
+ '\n{question}')
class Meta:
db_table = "application"
@@ -65,20 +89,28 @@ class ChatRecord(AppModelMixin):
chat = models.ForeignKey(Chat, on_delete=models.CASCADE)
vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices,
default=VoteChoices.UN_VOTE)
- dataset = models.ForeignKey(DataSet, on_delete=models.SET_NULL, verbose_name="数据集", blank=True, null=True)
- paragraph = models.ForeignKey(Paragraph, on_delete=models.SET_NULL, verbose_name="段落id", blank=True, null=True)
- source_id = models.UUIDField(max_length=128, verbose_name="资源id 段落/问题 id ", null=True)
- source_type = models.CharField(verbose_name='资源类型', max_length=2, choices=SourceType.choices,
- default=SourceType.PROBLEM, blank=True, null=True)
+ paragraph_id_list = ArrayField(verbose_name="引用段落id列表",
+ base_field=models.UUIDField(max_length=128, blank=True)
+ , default=list)
+ problem_text = models.CharField(max_length=1024, verbose_name="问题")
+ answer_text = models.CharField(max_length=4096, verbose_name="答案")
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0)
- problem_text = models.CharField(max_length=1024, verbose_name="问题")
- answer_text = models.CharField(max_length=1024, verbose_name="答案")
+ const = models.IntegerField(verbose_name="总费用", default=0)
+ details = models.JSONField(verbose_name="对话详情", default=dict)
improve_paragraph_id_list = ArrayField(verbose_name="改进标注列表",
base_field=models.UUIDField(max_length=128, blank=True)
, default=list)
-
+ run_time = models.FloatField(verbose_name="运行时长", default=0)
index = models.IntegerField(verbose_name="对话下标")
+ def get_human_message(self):
+ if 'problem_padding' in self.details:
+ return HumanMessage(content=self.details.get('problem_padding').get('padding_problem_text'))
+ return HumanMessage(content=self.problem_text)
+
+ def get_ai_message(self):
+ return AIMessage(content=self.answer_text)
+
class Meta:
db_table = "application_chat_record"
diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py
index 8bd858358..680360dd4 100644
--- a/apps/application/serializers/application_serializers.py
+++ b/apps/application/serializers/application_serializers.py
@@ -63,15 +63,35 @@ class ApplicationSerializerModel(serializers.ModelSerializer):
fields = "__all__"
+class DatasetSettingSerializer(serializers.Serializer):
+ top_n = serializers.FloatField(required=True)
+ similarity = serializers.FloatField(required=True, max_value=1, min_value=0)
+ max_paragraph_char_number = serializers.IntegerField(required=True, max_value=10000)
+
+
+class ModelSettingSerializer(serializers.Serializer):
+ prompt = serializers.CharField(required=True, max_length=4096)
+
+
class ApplicationSerializer(serializers.Serializer):
name = serializers.CharField(required=True)
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True)
model_id = serializers.CharField(required=True)
multiple_rounds_dialogue = serializers.BooleanField(required=True)
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
- example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True), allow_null=True)
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
allow_null=True)
+ # 数据集相关设置
+ dataset_setting = DatasetSettingSerializer(required=True)
+ # 模型相关设置
+ model_setting = ModelSettingSerializer(required=True)
+ # 问题补全
+ problem_optimization = serializers.BooleanField(required=True)
+
+ def is_valid(self, *, user_id=None, raise_exception=False):
+ super().is_valid(raise_exception=True)
+ ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
+ 'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
class AccessTokenSerializer(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
@@ -135,13 +155,13 @@ class ApplicationSerializer(serializers.Serializer):
model_id = serializers.CharField(required=False)
multiple_rounds_dialogue = serializers.BooleanField(required=False)
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
- example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
-
- def is_valid(self, *, user_id=None, raise_exception=False):
- super().is_valid(raise_exception=True)
- ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
- 'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
+ # 数据集相关设置
+ dataset_setting = serializers.JSONField(required=False, allow_null=True)
+ # 模型相关设置
+ model_setting = serializers.JSONField(required=False, allow_null=True)
+ # 问题补全
+ problem_optimization = serializers.BooleanField(required=False, allow_null=True)
class Create(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
@@ -168,9 +188,12 @@ class ApplicationSerializer(serializers.Serializer):
@staticmethod
def to_application_model(user_id: str, application: Dict):
return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'),
- prologue=application.get('prologue'), example=application.get('example'),
+ prologue=application.get('prologue'),
dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0,
user_id=user_id, model_id=application.get('model_id'),
+ dataset_setting=application.get('dataset_setting'),
+ model_setting=application.get('model_setting'),
+ problem_optimization=application.get('problem_optimization')
)
@staticmethod
@@ -267,7 +290,7 @@ class ApplicationSerializer(serializers.Serializer):
class ApplicationModel(serializers.ModelSerializer):
class Meta:
model = Application
- fields = ['id', 'name', 'desc', 'prologue', 'example', 'dialogue_number']
+ fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number']
class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
@@ -317,8 +340,9 @@ class ApplicationSerializer(serializers.Serializer):
model = QuerySet(Model).get(id=instance.get('model_id') if 'model_id' in instance else application.model_id)
- update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'example', 'status',
- 'api_key_is_active']
+ update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
+ 'dataset_setting', 'model_setting', 'problem_optimization'
+ 'api_key_is_active']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'multiple_rounds_dialogue':
diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py
index 13eae7cb1..c4069a7d4 100644
--- a/apps/application/serializers/chat_message_serializers.py
+++ b/apps/application/serializers/chat_message_serializers.py
@@ -7,194 +7,181 @@
@desc:
"""
import json
-import uuid
from typing import List
+from uuid import UUID
-from django.db.models import QuerySet
-from django.http import StreamingHttpResponse
-from langchain.chat_models.base import BaseChatModel
-from langchain.schema import HumanMessage
-from rest_framework import serializers, status
from django.core.cache import cache
-from common import event
-from common.config.embedding_config import VectorStore, EmbeddingModel
-from common.response import result
-from dataset.models import Paragraph
-from embedding.models import SourceType
-from setting.models.model_management import Model
+from django.db.models import QuerySet
+from langchain.chat_models.base import BaseChatModel
+from rest_framework import serializers
+
+from application.chat_pipeline.pipeline_manage import PiplineManage
+from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
+from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep
+from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \
+ BaseGenerateHumanMessageStep
+from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
+from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
+from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping
+from common.exception.app_exception import AppApiException
+from common.util.rsa_util import decrypt
+from common.util.split_model import flat_map
+from dataset.models import Paragraph, Document
+from setting.models import Model
+from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
chat_cache = cache
-class MessageManagement:
- @staticmethod
- def get_message(title: str, content: str, message: str):
- if content is None:
- return HumanMessage(content=message)
- return HumanMessage(content=(
- f'已知信息:{title}:{content} '
- '根据上述已知信息,请简洁和专业的来回答用户的问题。已知信息中的图片、链接地址和脚本语言请直接返回。如果无法从已知信息中得到答案,请说 “没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作” 或 “根据已知信息无法回答该问题,建议联系官方技术支持人员”,不允许在答案中添加编造成分,答案请使用中文。'
- f'问题是:{message}'))
-
-
-class ChatMessage:
- def __init__(self, id: str, problem: str, title: str, paragraph: str, embedding_id: str, dataset_id: str,
- document_id: str,
- paragraph_id,
- source_type: SourceType,
- source_id: str,
- answer: str,
- message_tokens: int,
- answer_token: int,
- chat_model=None,
- chat_message=None):
- self.id = id
- self.problem = problem
- self.title = title
- self.paragraph = paragraph
- self.embedding_id = embedding_id
- self.dataset_id = dataset_id
- self.document_id = document_id
- self.paragraph_id = paragraph_id
- self.source_type = source_type
- self.source_id = source_id
- self.answer = answer
- self.message_tokens = message_tokens
- self.answer_token = answer_token
- self.chat_model = chat_model
- self.chat_message = chat_message
-
- def get_chat_message(self):
- return MessageManagement.get_message(self.problem, self.paragraph, self.problem)
-
-
class ChatInfo:
def __init__(self,
chat_id: str,
- model: Model,
chat_model: BaseChatModel,
- application_id: str | None,
dataset_id_list: List[str],
exclude_document_id_list: list[str],
- dialogue_number: int):
+ application: Application):
+ """
+ :param chat_id: 对话id
+ :param chat_model: 对话模型
+ :param dataset_id_list: 数据集列表
+ :param exclude_document_id_list: 排除的文档
+ :param application: 应用信息
+ """
self.chat_id = chat_id
- self.application_id = application_id
- self.model = model
+ self.application = application
self.chat_model = chat_model
self.dataset_id_list = dataset_id_list
self.exclude_document_id_list = exclude_document_id_list
- self.dialogue_number = dialogue_number
- self.chat_message_list: List[ChatMessage] = []
+ self.chat_record_list: List[ChatRecord] = []
- def append_chat_message(self, chat_message: ChatMessage):
- self.chat_message_list.append(chat_message)
- if self.application_id is not None:
+ def to_base_pipeline_manage_params(self):
+ dataset_setting = self.application.dataset_setting
+ model_setting = self.application.model_setting
+ return {
+ 'dataset_id_list': self.dataset_id_list,
+ 'exclude_document_id_list': self.exclude_document_id_list,
+ 'exclude_paragraph_id_list': [],
+ 'top_n': dataset_setting.get('top_n') if 'top_n' in dataset_setting else 3,
+ 'similarity': dataset_setting.get('similarity') if 'similarity' in dataset_setting else 0.6,
+ 'max_paragraph_char_number': dataset_setting.get(
+ 'max_paragraph_char_number') if 'max_paragraph_char_number' in dataset_setting else 5000,
+ 'history_chat_record': self.chat_record_list,
+ 'chat_id': self.chat_id,
+ 'dialogue_number': self.application.dialogue_number,
+ 'prompt': model_setting.get(
+ 'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
+ 'chat_model': self.chat_model,
+ 'model_id': self.application.model.id if self.application.model is not None else None,
+ 'problem_optimization': self.application.problem_optimization
+
+ }
+
+ def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
+ exclude_paragraph_id_list):
+ params = self.to_base_pipeline_manage_params()
+ return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
+ 'exclude_paragraph_id_list': exclude_paragraph_id_list}
+
+ def append_chat_record(self, chat_record: ChatRecord):
+ # 存入缓存中
+ self.chat_record_list.append(chat_record)
+ if self.application.id is not None:
# 插入数据库
- event.ListenerChatMessage.record_chat_message_signal.send(
- event.RecordChatMessageArgs(len(self.chat_message_list) - 1, self.chat_id, self.application_id,
- chat_message)
- )
- # 异步更新token
- event.ListenerChatMessage.update_chat_message_token_signal.send(chat_message)
+ if not QuerySet(Chat).filter(id=self.chat_id).exists():
+ Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text).save()
+ # 插入会话记录
+ chat_record.save()
- def get_context_message(self):
- start_index = len(self.chat_message_list) - self.dialogue_number
- return [self.chat_message_list[index].get_chat_message() for index in
- range(start_index if start_index > 0 else 0, len(self.chat_message_list))]
+
+def get_post_handler(chat_info: ChatInfo):
+ class PostHandler(PostResponseHandler):
+
+ def handler(self,
+ chat_id: UUID,
+ chat_record_id,
+ paragraph_list: List[Paragraph],
+ problem_text: str,
+ answer_text,
+ manage: PiplineManage,
+ step: BaseChatStep,
+ padding_problem_text: str = None,
+ **kwargs):
+ chat_record = ChatRecord(id=chat_record_id,
+ chat_id=chat_id,
+ paragraph_id_list=[str(p.id) for p in paragraph_list],
+ problem_text=problem_text,
+ answer_text=answer_text,
+ details=manage.get_details(),
+ message_tokens=manage.context['message_tokens'],
+ answer_tokens=manage.context['answer_tokens'],
+ run_time=manage.context['run_time'],
+ index=len(chat_info.chat_record_list) + 1)
+ chat_info.append_chat_record(chat_record)
+ # 重新设置缓存
+ chat_cache.set(chat_id,
+ chat_info, timeout=60 * 30)
+
+ return PostHandler()
class ChatMessageSerializer(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
- def chat(self, message):
+ def chat(self, message, re_chat: bool):
self.is_valid(raise_exception=True)
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id)
if chat_info is None:
- return result.Result(response_status=status.HTTP_404_NOT_FOUND, code=404, message="会话过期")
+ chat_info = self.re_open_chat(chat_id)
+ chat_cache.set(chat_id,
+ chat_info, timeout=60 * 30)
- chat_model = chat_info.chat_model
- vector = VectorStore.get_embedding_vector()
- # 向量库检索
- _value = vector.search(message, chat_info.dataset_id_list, chat_info.exclude_document_id_list,
- [chat_message.embedding_id for chat_message in
- (list(filter(lambda row: row.problem == message, chat_info.chat_message_list)))],
- True,
- EmbeddingModel.get_embedding_model())
- # 查询段落id详情
- paragraph = None
- if _value is not None:
- paragraph = QuerySet(Paragraph).get(id=_value.get('paragraph_id'))
- if paragraph is None:
- vector.delete_by_paragraph_id(_value.get('paragraph_id'))
+ pipline_manage_builder = PiplineManage.builder()
+ # 如果开启了问题优化,则添加上问题优化步骤
+ if chat_info.application.problem_optimization:
+ pipline_manage_builder.append_step(BaseResetProblemStep)
+ # 构建流水线管理器
+ pipline_message = (pipline_manage_builder.append_step(BaseSearchDatasetStep)
+ .append_step(BaseGenerateHumanMessageStep)
+ .append_step(BaseChatStep)
+ .build())
+ exclude_paragraph_id_list = []
+ # 相同问题是否需要排除已经查询到的段落
+ if re_chat:
+ paragraph_id_list = flat_map([row.paragraph_id_list for row in
+ filter(lambda chat_record: chat_record == message,
+ chat_info.chat_record_list)])
+ exclude_paragraph_id_list = list(set(paragraph_id_list))
+ # 构建运行参数
+ params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list)
+ # 运行流水线作业
+ pipline_message.run(params)
+ return pipline_message.context['chat_result']
- title, content = (None, None) if paragraph is None else (paragraph.title, paragraph.content)
- _id = str(uuid.uuid1())
+ @staticmethod
+ def re_open_chat(chat_id: str):
+ chat = QuerySet(Chat).filter(id=chat_id).first()
+ if chat is None:
+ raise AppApiException(500, "会话不存在")
+ application = QuerySet(Application).filter(id=chat.application_id).first()
+ if application is None:
+ raise AppApiException(500, "应用不存在")
+ model = QuerySet(Model).filter(id=application.model_id).first()
+ chat_model = None
+ if model is not None:
+ # 对话模型
+ chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
+ json.loads(
+ decrypt(model.credential)),
+ streaming=True)
+ # 数据集id列表
+ dataset_id_list = [str(row.dataset_id) for row in
+ QuerySet(ApplicationDatasetMapping).filter(
+ application_id=application.id)]
- embedding_id, dataset_id, document_id, paragraph_id, source_type, source_id = (_value.get(
- 'id'), _value.get(
- 'dataset_id'), _value.get(
- 'document_id'), _value.get(
- 'paragraph_id'), _value.get(
- 'source_type'), _value.get(
- 'source_id')) if _value is not None else (None, None, None, None, None, None)
-
- if chat_model is None:
- def event_block_content(c: str):
- yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
- 'is_end': True,
- 'content': c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~'}) + "\n\n"
- chat_info.append_chat_message(
- ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
- paragraph_id,
- source_type,
- source_id,
- c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~',
- 0,
- 0))
- # 重新设置缓存
- chat_cache.set(chat_id,
- chat_info, timeout=60 * 30)
-
- r = StreamingHttpResponse(streaming_content=event_block_content(content),
- content_type='text/event-stream;charset=utf-8')
-
- r['Cache-Control'] = 'no-cache'
- return r
- # 获取上下文
- history_message = chat_info.get_context_message()
-
- # 构建会话请求问题
- chat_message = [*history_message, MessageManagement.get_message(title, content, message)]
- # 对话
- result_data = chat_model.stream(chat_message)
-
- def event_content(response):
- all_text = ''
- try:
- for chunk in response:
- all_text += chunk.content
- yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
- 'content': chunk.content, 'is_end': False}) + "\n\n"
-
- yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
- 'content': '', 'is_end': True}) + "\n\n"
- chat_info.append_chat_message(
- ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
- paragraph_id,
- source_type,
- source_id, all_text,
- 0,
- 0,
- chat_message=chat_message, chat_model=chat_model))
- # 重新设置缓存
- chat_cache.set(chat_id,
- chat_info, timeout=60 * 30)
- except Exception as e:
- yield e
-
- r = StreamingHttpResponse(streaming_content=event_content(result_data),
- content_type='text/event-stream;charset=utf-8')
-
- r['Cache-Control'] = 'no-cache'
- return r
+ # 需要排除的文档
+ exclude_document_id_list = [str(document.id) for document in
+ QuerySet(Document).filter(
+ dataset_id__in=dataset_id_list,
+ is_active=False)]
+ return ChatInfo(chat_id, chat_model, dataset_id_list, exclude_document_id_list, application)
diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py
index 474de0ee8..2204a8b07 100644
--- a/apps/application/serializers/chat_serializers.py
+++ b/apps/application/serializers/chat_serializers.py
@@ -10,7 +10,8 @@ import datetime
import json
import os
import uuid
-from typing import Dict
+from functools import reduce
+from typing import Dict, List
from django.core.cache import cache
from django.db import transaction
@@ -18,7 +19,8 @@ from django.db.models import QuerySet
from rest_framework import serializers
from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
-from application.serializers.application_serializers import ModelDatasetAssociation
+from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
+ ModelSettingSerializer
from application.serializers.chat_message_serializers import ChatInfo
from common.db.search import native_search, native_page_search, page_search
from common.event import ListenerManagement
@@ -26,8 +28,8 @@ from common.exception.app_exception import AppApiException
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from common.util.rsa_util import decrypt
+from common.util.split_model import flat_map
from dataset.models import Document, Problem, Paragraph
-from embedding.models import SourceType, Embedding
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from smartdoc.conf import PROJECT_DIR
@@ -106,12 +108,12 @@ class ChatSerializers(serializers.Serializer):
chat_id = str(uuid.uuid1())
chat_cache.set(chat_id,
- ChatInfo(chat_id, model, chat_model, application_id, dataset_id_list,
+ ChatInfo(chat_id, chat_model, dataset_id_list,
[str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)],
- application.dialogue_number), timeout=60 * 30)
+ application), timeout=60 * 30)
return chat_id
class OpenTempChat(serializers.Serializer):
@@ -122,6 +124,12 @@ class ChatSerializers(serializers.Serializer):
multiple_rounds_dialogue = serializers.BooleanField(required=True)
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
+ # 数据集相关设置
+ dataset_setting = DatasetSettingSerializer(required=True)
+ # 模型相关设置
+ model_setting = ModelSettingSerializer(required=True)
+ # 问题补全
+ problem_optimization = serializers.BooleanField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@@ -140,42 +148,62 @@ class ChatSerializers(serializers.Serializer):
json.loads(
decrypt(model.credential)),
streaming=True)
+ application = Application(id=None, dialogue_number=3, model=model,
+ dataset_setting=self.data.get('dataset_setting'),
+ model_setting=self.data.get('model_setting'),
+ problem_optimization=self.data.get('problem_optimization'))
chat_cache.set(chat_id,
- ChatInfo(chat_id, model, chat_model, None, dataset_id_list,
+ ChatInfo(chat_id, chat_model, dataset_id_list,
[str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)],
- 3 if self.data.get('multiple_rounds_dialogue') else 1), timeout=60 * 30)
+ application), timeout=60 * 30)
return chat_id
-def vote_exec(source_type: SourceType, source_id: str, field: str, post_handler):
- if source_type == SourceType.PROBLEM:
- problem = QuerySet(Problem).get(id=source_id)
- if problem is not None:
- problem.__setattr__(field, post_handler(problem))
- problem.save()
- embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
- embedding.__setattr__(field, problem.__getattribute__(field))
- embedding.save()
- if source_type == SourceType.PARAGRAPH:
- paragraph = QuerySet(Paragraph).get(id=source_id)
- if paragraph is not None:
- paragraph.__setattr__(field, post_handler(paragraph))
- paragraph.save()
- embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
- embedding.__setattr__(field, paragraph.__getattribute__(field))
- embedding.save()
-
-
class ChatRecordSerializerModel(serializers.ModelSerializer):
class Meta:
model = ChatRecord
- fields = "__all__"
+ fields = ['id', 'chat_id', 'vote_status', 'problem_text', 'answer_text',
+ 'message_tokens', 'answer_tokens', 'const', 'improve_paragraph_id_list', 'run_time', 'index']
class ChatRecordSerializer(serializers.Serializer):
+ class Operate(serializers.Serializer):
+ chat_id = serializers.UUIDField(required=True)
+
+ chat_record_id = serializers.UUIDField(required=True)
+
+ def one(self, with_valid=True):
+ if with_valid:
+ self.is_valid(raise_exception=True)
+ chat_record_id = self.data.get('chat_record_id')
+ chat_id = self.data.get('chat_id')
+ chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
+ dataset_list = []
+ paragraph_list = []
+ if len(chat_record.paragraph_id_list) > 0:
+ paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=chat_record.paragraph_id_list),
+ get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "application", 'sql',
+ 'list_dataset_paragraph_by_paragraph_id.sql')),
+ with_table_name=True)
+ dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
+ [{row.get(
+ 'dataset_id'): row.get(
+ "dataset_name")} for
+ row in
+ paragraph_list],
+ {}).items()]
+
+ return {
+ **ChatRecordSerializerModel(chat_record).data,
+ 'padding_problem_text': chat_record.details.get(
+ 'padding_problem_text') if 'problem_padding' in chat_record.details else None,
+ 'dataset_list': dataset_list,
+ 'paragraph_list': paragraph_list}
+
class Query(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
chat_id = serializers.UUIDField(required=True)
@@ -183,15 +211,57 @@ class ChatRecordSerializer(serializers.Serializer):
def list(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
+ QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))
return [ChatRecordSerializerModel(chat_record).data for chat_record in
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))]
+ def reset_chat_record_list(self, chat_record_list: List[ChatRecord]):
+ paragraph_id_list = flat_map([chat_record.paragraph_id_list for chat_record in chat_record_list])
+ # 去重
+ paragraph_id_list = list(set(paragraph_id_list))
+ paragraph_list = self.search_paragraph(paragraph_id_list)
+ return [self.reset_chat_record(chat_record, paragraph_list) for chat_record in chat_record_list]
+
+ @staticmethod
+ def search_paragraph(paragraph_id_list: List[str]):
+ paragraph_list = []
+ if len(paragraph_id_list) > 0:
+ paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
+ get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "application", 'sql',
+ 'list_dataset_paragraph_by_paragraph_id.sql')),
+ with_table_name=True)
+ return paragraph_list
+
+ @staticmethod
+ def reset_chat_record(chat_record, all_paragraph_list):
+ paragraph_list = list(
+ filter(lambda paragraph: chat_record.paragraph_id_list.__contains__(str(paragraph.get('id'))),
+ all_paragraph_list))
+ dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
+ [{row.get(
+ 'dataset_id'): row.get(
+ "dataset_name")} for
+ row in
+ paragraph_list],
+ {}).items()]
+ return {
+ **ChatRecordSerializerModel(chat_record).data,
+ 'padding_problem_text': chat_record.details.get('problem_padding').get(
+ 'padding_problem_text') if 'problem_padding' in chat_record.details else None,
+ 'dataset_list': dataset_list,
+ 'paragraph_list': paragraph_list
+ }
+
def page(self, current_page: int, page_size: int, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
- return page_search(current_page, page_size,
+ page = page_search(current_page, page_size,
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"),
- post_records_handler=lambda chat_record: ChatRecordSerializerModel(chat_record).data)
+ post_records_handler=lambda chat_record: chat_record)
+ records = page.get('records')
+ page['records'] = self.reset_chat_record_list(records)
+ return page
class Vote(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
@@ -216,38 +286,20 @@ class ChatRecordSerializer(serializers.Serializer):
if vote_status == VoteChoices.STAR:
# 点赞
chat_record_details_model.vote_status = VoteChoices.STAR
- # 点赞数量 +1
- vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
- 'star_num',
- lambda r: r.star_num + 1)
if vote_status == VoteChoices.TRAMPLE:
# 点踩
chat_record_details_model.vote_status = VoteChoices.TRAMPLE
- # 点踩数量+1
- vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
- 'trample_num',
- lambda r: r.trample_num + 1)
chat_record_details_model.save()
else:
if vote_status == VoteChoices.UN_VOTE:
# 取消点赞
chat_record_details_model.vote_status = VoteChoices.UN_VOTE
chat_record_details_model.save()
- if chat_record_details_model.vote_status == VoteChoices.STAR:
- # 点赞数量 -1
- vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
- 'star_num', lambda r: r.star_num - 1)
- if chat_record_details_model.vote_status == VoteChoices.TRAMPLE:
- # 点踩数量 -1
- vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
- 'trample_num', lambda r: r.trample_num - 1)
-
else:
raise AppApiException(500, "已经投票过,请先取消后再进行投票")
finally:
un_lock(self.data.get('chat_record_id'))
-
return True
class ImproveSerializer(serializers.Serializer):
diff --git a/apps/application/sql/list_application.sql b/apps/application/sql/list_application.sql
index 283c7bb92..b7aa0fbe9 100644
--- a/apps/application/sql/list_application.sql
+++ b/apps/application/sql/list_application.sql
@@ -1,4 +1,4 @@
-SELECT * FROM ( SELECT * FROM application ${application_custom_sql} UNION
+SELECT *,to_json(dataset_setting) as dataset_setting,to_json(model_setting) as model_setting FROM ( SELECT * FROM application ${application_custom_sql} UNION
SELECT
*
FROM
diff --git a/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql b/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql
new file mode 100644
index 000000000..4c27ca07d
--- /dev/null
+++ b/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql
@@ -0,0 +1,6 @@
+SELECT
+ paragraph.*,
+ dataset."name" AS "dataset_name"
+FROM
+ paragraph paragraph
+ LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id
\ No newline at end of file
diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py
index 54eccc9ef..2a7419496 100644
--- a/apps/application/swagger_api/application_api.py
+++ b/apps/application/swagger_api/application_api.py
@@ -58,6 +58,7 @@ class ApplicationApi(ApiMixin):
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间'),
+
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联知识库Id列表",
@@ -133,7 +134,7 @@ class ApplicationApi(ApiMixin):
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
- required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'],
+ required=[],
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
@@ -141,11 +142,52 @@ class ApplicationApi(ApiMixin):
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
description="是否开启多轮对话"),
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
- 'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
- title="示例列表", description="示例列表"),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联知识库Id列表", description="关联知识库Id列表"),
+ 'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
+ 'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
+ 'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
+ description="是否开启问题优化", default=True)
+
+ }
+ )
+
+ class DatasetSetting(ApiMixin):
+ @staticmethod
+ def get_request_body_api():
+ return openapi.Schema(
+ type=openapi.TYPE_OBJECT,
+ required=[''],
+ properties={
+ 'top_n': openapi.Schema(type=openapi.TYPE_NUMBER, title="引用分段数", description="引用分段数",
+ default=5),
+ 'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title='相似度', description="相似度",
+ default=0.6),
+ 'max_paragraph_char_number': openapi.Schema(type=openapi.TYPE_NUMBER, title='最多引用字符数',
+ description="最多引用字符数", default=3000),
+ }
+ )
+
+ class ModelSetting(ApiMixin):
+ @staticmethod
+ def get_request_body_api():
+ return openapi.Schema(
+ type=openapi.TYPE_OBJECT,
+ required=['prompt'],
+ properties={
+ 'prompt': openapi.Schema(type=openapi.TYPE_STRING, title="提示词", description="提示词",
+ default=('已知信息:'
+ '\n{data}'
+ '\n回答要求:'
+ '\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。'
+ '\n- 避免提及你是从中获得的知识。'
+ '\n- 请保持答案与中描述的一致。'
+ '\n- 请使用markdown 语法优化答案的格式。'
+ '\n- 中的图片链接、链接地址和脚本语言请完整返回。'
+ '\n- 请使用与问题相同的语言来回答。'
+ '\n问题:'
+ '\n{question}')),
}
)
@@ -155,7 +197,8 @@ class ApplicationApi(ApiMixin):
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
- required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'],
+ required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting',
+ 'problem_optimization'],
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
@@ -163,11 +206,13 @@ class ApplicationApi(ApiMixin):
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
description="是否开启多轮对话"),
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
- 'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
- title="示例列表", description="示例列表"),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
- title="关联知识库Id列表", description="关联知识库Id列表")
+ title="关联知识库Id列表", description="关联知识库Id列表"),
+ 'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
+ 'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
+ 'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
+ description="是否开启问题优化", default=True)
}
)
diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py
index 0cde547e1..bc36599c8 100644
--- a/apps/application/swagger_api/chat_api.py
+++ b/apps/application/swagger_api/chat_api.py
@@ -8,6 +8,7 @@
"""
from drf_yasg import openapi
+from application.swagger_api.application_api import ApplicationApi
from common.mixins.api_mixin import ApiMixin
@@ -19,6 +20,7 @@ class ChatApi(ApiMixin):
required=['message'],
properties={
'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
+ 're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default="重新生成")
}
)
@@ -73,14 +75,19 @@ class ChatApi(ApiMixin):
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
- required=['model_id', 'multiple_rounds_dialogue'],
+ required=['model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting',
+ 'problem_optimization'],
properties={
'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联知识库Id列表", description="关联知识库Id列表"),
'multiple_rounds_dialogue': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮会话",
- description="是否开启多轮会话")
+ description="是否开启多轮会话"),
+ 'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
+ 'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
+ 'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
+ description="是否开启问题优化", default=True)
}
)
diff --git a/apps/application/urls.py b/apps/application/urls.py
index ef74a8830..6944b4e6a 100644
--- a/apps/application/urls.py
+++ b/apps/application/urls.py
@@ -25,6 +25,8 @@ urlpatterns = [
path('application//chat//chat_record/', views.ChatView.ChatRecord.as_view()),
path('application//chat//chat_record//',
views.ChatView.ChatRecord.Page.as_view()),
+ path('application//chat//chat_record/',
+ views.ChatView.ChatRecord.Operate.as_view()),
path('application//chat//chat_record//vote',
views.ChatView.ChatRecord.Vote.as_view(),
name=''),
diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py
index 3ee0223c3..c39c149c9 100644
--- a/apps/application/views/chat_views.py
+++ b/apps/application/views/chat_views.py
@@ -71,7 +71,8 @@ class ChatView(APIView):
dynamic_tag=keywords.get('application_id'))])
)
def post(self, request: Request, chat_id: str):
- return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'))
+ return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'), request.data.get(
+ 're_chat') if 're_chat' in request.data else False)
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话列表",
@@ -134,6 +135,27 @@ class ChatView(APIView):
class ChatRecord(APIView):
authentication_classes = [TokenAuth]
+ class Operate(APIView):
+ authentication_classes = [TokenAuth]
+
+ @action(methods=['GET'], detail=False)
+ @swagger_auto_schema(operation_summary="获取对话记录详情",
+ operation_id="获取对话记录详情",
+ manual_parameters=ChatRecordApi.get_request_params_api(),
+ responses=result.get_api_array_response(ChatRecordApi.get_response_body_api()),
+ tags=["应用/对话日志"]
+ )
+ @has_permissions(
+ ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
+ [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
+ dynamic_tag=keywords.get('application_id'))])
+ )
+ def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str):
+ return result.success(ChatRecordSerializer.Operate(
+ data={'application_id': application_id,
+ 'chat_id': chat_id,
+ 'chat_record_id': chat_record_id}).one())
+
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话记录列表",
operation_id="获取对话记录列表",
diff --git a/apps/common/db/search.py b/apps/common/db/search.py
index f8d4a6878..763667154 100644
--- a/apps/common/db/search.py
+++ b/apps/common/db/search.py
@@ -83,7 +83,7 @@ def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, s
field_replace_dict = get_field_replace_dict(queryset)
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
field_replace_dict=field_replace_dict)
- sql, params = app_sql_compiler.get_query_str(with_table_name)
+ sql, params = app_sql_compiler.get_query_str(with_table_name=with_table_name)
return sql, params
diff --git a/apps/common/event/__init__.py b/apps/common/event/__init__.py
index e00026645..f98e84f56 100644
--- a/apps/common/event/__init__.py
+++ b/apps/common/event/__init__.py
@@ -7,10 +7,8 @@
@desc:
"""
from .listener_manage import *
-from .listener_chat_message import *
def run():
listener_manage.ListenerManagement().run()
- listener_chat_message.ListenerChatMessage().run()
QuerySet(Document).filter(status=Status.embedding).update(**{'status': Status.error})
diff --git a/apps/common/event/listener_chat_message.py b/apps/common/event/listener_chat_message.py
deleted file mode 100644
index 26c9fe3bc..000000000
--- a/apps/common/event/listener_chat_message.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# coding=utf-8
-"""
- @project: maxkb
- @Author:虎
- @file: listener_manage.py
- @date:2023/10/20 14:01
- @desc:
-"""
-import logging
-
-from blinker import signal
-from django.db.models import QuerySet
-
-from application.models import ChatRecord, Chat
-from application.serializers.chat_message_serializers import ChatMessage
-from common.event.common import poxy
-
-
-class RecordChatMessageArgs:
- def __init__(self, index: int, chat_id: str, application_id: str, chat_message: ChatMessage):
- self.index = index
- self.chat_id = chat_id
- self.application_id = application_id
- self.chat_message = chat_message
-
-
-class ListenerChatMessage:
- record_chat_message_signal = signal("record_chat_message")
- update_chat_message_token_signal = signal("update_chat_message_token")
-
- @staticmethod
- def record_chat_message(args: RecordChatMessageArgs):
- if not QuerySet(Chat).filter(id=args.chat_id).exists():
- Chat(id=args.chat_id, application_id=args.application_id, abstract=args.chat_message.problem).save()
- # 插入会话记录
- try:
- chat_record = ChatRecord(
- id=args.chat_message.id,
- chat_id=args.chat_id,
- dataset_id=args.chat_message.dataset_id,
- paragraph_id=args.chat_message.paragraph_id,
- source_id=args.chat_message.source_id,
- source_type=args.chat_message.source_type,
- problem_text=args.chat_message.problem,
- answer_text=args.chat_message.answer,
- index=args.index,
- message_tokens=args.chat_message.message_tokens,
- answer_tokens=args.chat_message.answer_token)
- chat_record.save()
- except Exception as e:
- print(e)
-
- @staticmethod
- @poxy
- def update_token(chat_message: ChatMessage):
- if chat_message.chat_model is not None:
- logging.getLogger("max_kb").info("开始更新token")
- message_token = chat_message.chat_model.get_num_tokens_from_messages(chat_message.chat_message)
- answer_token = chat_message.chat_model.get_num_tokens(chat_message.answer)
- # 修改token数量
- QuerySet(ChatRecord).filter(id=chat_message.id).update(
- **{'message_tokens': message_token, 'answer_tokens': answer_token})
-
- def run(self):
- # 记录会话
- ListenerChatMessage.record_chat_message_signal.connect(self.record_chat_message)
- ListenerChatMessage.update_chat_message_token_signal.connect(self.update_token)
diff --git a/apps/common/field/common.py b/apps/common/field/common.py
new file mode 100644
index 000000000..a97469891
--- /dev/null
+++ b/apps/common/field/common.py
@@ -0,0 +1,34 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: common.py
+ @date:2024/1/11 18:44
+ @desc:
+"""
+from rest_framework import serializers
+
+
+class InstanceField(serializers.Field):
+ def __init__(self, model_type, **kwargs):
+ self.model_type = model_type
+ super().__init__(**kwargs)
+
+ def to_internal_value(self, data):
+ if not isinstance(data, self.model_type):
+ self.fail('message类型错误', value=data)
+ return data
+
+ def to_representation(self, value):
+ return value
+
+
+class FunctionField(serializers.Field):
+
+ def to_internal_value(self, data):
+ if not callable(data):
+ self.fail('不是一個函數', value=data)
+ return data
+
+ def to_representation(self, value):
+ return value
diff --git a/apps/common/sql/list_embedding_text.sql b/apps/common/sql/list_embedding_text.sql
index e720041f2..b051da1c6 100644
--- a/apps/common/sql/list_embedding_text.sql
+++ b/apps/common/sql/list_embedding_text.sql
@@ -5,9 +5,7 @@ SELECT
problem.dataset_id AS dataset_id,
0 AS source_type,
problem."content" AS "text",
- paragraph.is_active AS is_active,
- problem.star_num as star_num,
- problem.trample_num as trample_num
+ paragraph.is_active AS is_active
FROM
problem problem
LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
@@ -23,9 +21,7 @@ SELECT
concat_ws('
',concat_ws('
',paragraph.title,paragraph."content"),paragraph.title) AS "text",
- paragraph.is_active AS is_active,
- paragraph.star_num as star_num,
- paragraph.trample_num as trample_num
+ paragraph.is_active AS is_active
FROM
paragraph paragraph
diff --git a/apps/dataset/migrations/0004_remove_paragraph_hit_num_remove_paragraph_star_num_and_more.py b/apps/dataset/migrations/0004_remove_paragraph_hit_num_remove_paragraph_star_num_and_more.py
new file mode 100644
index 000000000..3440fa39f
--- /dev/null
+++ b/apps/dataset/migrations/0004_remove_paragraph_hit_num_remove_paragraph_star_num_and_more.py
@@ -0,0 +1,37 @@
+# Generated by Django 4.1.10 on 2024-01-16 11:22
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('dataset', '0003_alter_paragraph_content'),
+ ]
+
+ operations = [
+ migrations.RemoveField(
+ model_name='paragraph',
+ name='hit_num',
+ ),
+ migrations.RemoveField(
+ model_name='paragraph',
+ name='star_num',
+ ),
+ migrations.RemoveField(
+ model_name='paragraph',
+ name='trample_num',
+ ),
+ migrations.RemoveField(
+ model_name='problem',
+ name='hit_num',
+ ),
+ migrations.RemoveField(
+ model_name='problem',
+ name='star_num',
+ ),
+ migrations.RemoveField(
+ model_name='problem',
+ name='trample_num',
+ ),
+ ]
diff --git a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py
index 2b493c0c4..7a69e2753 100644
--- a/apps/dataset/models/data_set.py
+++ b/apps/dataset/models/data_set.py
@@ -74,9 +74,6 @@ class Paragraph(AppModelMixin):
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
content = models.CharField(max_length=4096, verbose_name="段落内容")
title = models.CharField(max_length=256, verbose_name="标题", default="")
- hit_num = models.IntegerField(verbose_name="命中数量", default=0)
- star_num = models.IntegerField(verbose_name="点赞数", default=0)
- trample_num = models.IntegerField(verbose_name="点踩数", default=0)
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
default=Status.embedding)
is_active = models.BooleanField(default=True)
@@ -94,9 +91,6 @@ class Problem(AppModelMixin):
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
content = models.CharField(max_length=256, verbose_name="问题内容")
- hit_num = models.IntegerField(verbose_name="命中数量", default=0)
- star_num = models.IntegerField(verbose_name="点赞数", default=0)
- trample_num = models.IntegerField(verbose_name="点踩数", default=0)
class Meta:
db_table = "problem"
diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py
index 1c0524099..72147db2e 100644
--- a/apps/dataset/serializers/dataset_serializers.py
+++ b/apps/dataset/serializers/dataset_serializers.py
@@ -26,12 +26,11 @@ from application.models import ApplicationDatasetMapping
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list
-from common.event.listener_manage import ListenerManagement, SyncWebDatasetArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.file_util import get_file_content
-from common.util.fork import ChildLink, Fork, ForkManage
+from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
from dataset.serializers.common_serializers import list_paragraph
@@ -286,7 +285,8 @@ class DataSetSerializers(serializers.ModelSerializer):
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
- 'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url", description="web站点url"),
+ 'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url",
+ description="web站点url"),
'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
}
)
@@ -369,7 +369,8 @@ class DataSetSerializers(serializers.ModelSerializer):
dataset_id = uuid.uuid1()
dataset = DataSet(
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
- 'type': Type.web, 'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}})
+ 'type': Type.web,
+ 'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}})
dataset.save()
ListenerManagement.sync_web_dataset_signal.send(
SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'),
diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py
index 48a1ef434..48694a7d8 100644
--- a/apps/dataset/serializers/paragraph_serializers.py
+++ b/apps/dataset/serializers/paragraph_serializers.py
@@ -28,7 +28,7 @@ from dataset.serializers.problem_serializers import ProblemInstanceSerializer, P
class ParagraphSerializer(serializers.ModelSerializer):
class Meta:
model = Paragraph
- fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title',
+ fields = ['id', 'content', 'is_active', 'document_id', 'title',
'create_time', 'update_time']
diff --git a/apps/dataset/serializers/problem_serializers.py b/apps/dataset/serializers/problem_serializers.py
index 5948ef5b0..2f5fdc0a4 100644
--- a/apps/dataset/serializers/problem_serializers.py
+++ b/apps/dataset/serializers/problem_serializers.py
@@ -24,7 +24,7 @@ from embedding.vector.pg_vector import PGVector
class ProblemSerializer(serializers.ModelSerializer):
class Meta:
model = Problem
- fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id', 'document_id',
+ fields = ['id', 'content', 'dataset_id', 'document_id',
'create_time', 'update_time']
@@ -77,8 +77,8 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
- 'star_num': 0,
- 'trample_num': 0})
+
+ })
return ProblemSerializers.Operate(
data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),
diff --git a/apps/embedding/migrations/0002_remove_embedding_star_num_and_more.py b/apps/embedding/migrations/0002_remove_embedding_star_num_and_more.py
new file mode 100644
index 000000000..5597c78b9
--- /dev/null
+++ b/apps/embedding/migrations/0002_remove_embedding_star_num_and_more.py
@@ -0,0 +1,37 @@
+# Generated by Django 4.1.10 on 2024-01-16 11:22
+
+import django.contrib.postgres.fields
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('embedding', '0001_initial'),
+ ]
+
+ operations = [
+ migrations.RemoveField(
+ model_name='embedding',
+ name='star_num',
+ ),
+ migrations.RemoveField(
+ model_name='embedding',
+ name='trample_num',
+ ),
+ migrations.AddField(
+ model_name='embedding',
+ name='keywords',
+ field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=256), default=list, size=None, verbose_name='关键词列表'),
+ ),
+ migrations.AddField(
+ model_name='embedding',
+ name='meta',
+ field=models.JSONField(default=dict, verbose_name='元数据'),
+ ),
+ migrations.AlterField(
+ model_name='embedding',
+ name='source_type',
+ field=models.CharField(choices=[('0', '问题'), ('1', '段落'), ('2', '标题')], default='0', max_length=5, verbose_name='资源类型'),
+ ),
+ ]
diff --git a/apps/embedding/models/embedding.py b/apps/embedding/models/embedding.py
index 83f786c68..caa6d7247 100644
--- a/apps/embedding/models/embedding.py
+++ b/apps/embedding/models/embedding.py
@@ -6,6 +6,7 @@
@date:2023/9/21 15:46
@desc:
"""
+from django.contrib.postgres.fields import ArrayField
from django.db import models
from common.field.vector_field import VectorField
@@ -16,6 +17,7 @@ class SourceType(models.TextChoices):
"""订单类型"""
PROBLEM = 0, '问题'
PARAGRAPH = 1, '段落'
+ TITLE = 2, '标题'
class Embedding(models.Model):
@@ -36,10 +38,10 @@ class Embedding(models.Model):
embedding = VectorField(verbose_name="向量")
- star_num = models.IntegerField(default=0, verbose_name="点赞数量")
+ keywords = ArrayField(verbose_name="关键词列表",
+ base_field=models.CharField(max_length=256), default=list)
- trample_num = models.IntegerField(default=0,
- verbose_name="点踩数量")
+ meta = models.JSONField(verbose_name="元数据", default=dict)
class Meta:
db_table = "embedding"
diff --git a/apps/embedding/sql/embedding_search.sql b/apps/embedding/sql/embedding_search.sql
index 7787eb60b..ce3d4a580 100644
--- a/apps/embedding/sql/embedding_search.sql
+++ b/apps/embedding/sql/embedding_search.sql
@@ -1,15 +1,17 @@
-SELECT * FROM (SELECT
- *,
- ( 1 - ( embedding.embedding <=> %s ) ) AS similarity,
-CASE
-
- WHEN embedding.star_num - embedding.trample_num = 0 THEN
- 0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
- END AS score
+SELECT
+ paragraph_id,
+ comprehensive_score,
+ comprehensive_score as similarity
FROM
- embedding,
- ( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs
- ${embedding_query}
- ) temp
- WHERE similarity>0.5
- ORDER BY (similarity + score) DESC LIMIT 1
\ No newline at end of file
+ (
+ SELECT DISTINCT ON
+ ("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score
+ FROM
+ ( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query}) TEMP
+ ORDER BY
+ paragraph_id,
+ similarity DESC
+ ) DISTINCT_TEMP
+WHERE comprehensive_score>%s
+ORDER BY comprehensive_score DESC
+LIMIT %s
\ No newline at end of file
diff --git a/apps/embedding/sql/hit_test.sql b/apps/embedding/sql/hit_test.sql
index 141feee4e..8feffc86d 100644
--- a/apps/embedding/sql/hit_test.sql
+++ b/apps/embedding/sql/hit_test.sql
@@ -1,34 +1,17 @@
SELECT
- similarity,
paragraph_id,
- comprehensive_score
+ comprehensive_score,
+ comprehensive_score as similarity
FROM
(
SELECT DISTINCT ON
- ( "paragraph_id" ) ( similarity + score ),*,
- ( similarity + score ) AS comprehensive_score
+ ("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score
FROM
- (
- SELECT
- *,
- ( 1 - ( embedding.embedding <=> %s ) ) AS similarity,
- CASE
-
- WHEN embedding.star_num - embedding.trample_num = 0 THEN
- 0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
- END AS score
- FROM
- embedding,
- ( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs
- ${embedding_query}
- ) TEMP
- WHERE
- similarity > %s
+ ( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query} ) TEMP
ORDER BY
paragraph_id,
- ( similarity + score )
- DESC
- ) ss
-ORDER BY
- comprehensive_score DESC
- LIMIT %s
\ No newline at end of file
+ similarity DESC
+ ) DISTINCT_TEMP
+WHERE comprehensive_score>%s
+ORDER BY comprehensive_score DESC
+LIMIT %s
\ No newline at end of file
diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py
index a5c14f413..ce96aa3df 100644
--- a/apps/embedding/vector/base_vector.py
+++ b/apps/embedding/vector/base_vector.py
@@ -99,8 +99,6 @@ class BaseVectorStore(ABC):
@abstractmethod
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool,
- star_num: int,
- trample_num: int,
embedding: HuggingFaceEmbeddings):
pass
@@ -108,11 +106,20 @@ class BaseVectorStore(ABC):
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
pass
- @abstractmethod
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
- exclude_id_list: list[str],
+ exclude_paragraph_list: list[str],
is_active: bool,
embedding: HuggingFaceEmbeddings):
+ if dataset_id_list is None or len(dataset_id_list) == 0:
+ return []
+ embedding_query = embedding.embed_query(query_text)
+ result = self.query(embedding_query, dataset_id_list, exclude_document_id_list, exclude_paragraph_list,
+ is_active, 1, 0.65)
+ return result[0]
+
+ @abstractmethod
+ def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
+ exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
pass
@abstractmethod
diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py
index 79f8be9bf..0091e31da 100644
--- a/apps/embedding/vector/pg_vector.py
+++ b/apps/embedding/vector/pg_vector.py
@@ -33,8 +33,6 @@ class PGVector(BaseVectorStore):
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool,
- star_num: int,
- trample_num: int,
embedding: HuggingFaceEmbeddings):
text_embedding = embedding.embed_query(text)
embedding = Embedding(id=uuid.uuid1(),
@@ -44,10 +42,7 @@ class PGVector(BaseVectorStore):
paragraph_id=paragraph_id,
source_id=source_id,
embedding=text_embedding,
- source_type=source_type,
- star_num=star_num,
- trample_num=trample_num
- )
+ source_type=source_type)
embedding.save()
return True
@@ -61,8 +56,6 @@ class PGVector(BaseVectorStore):
is_active=text_list[index].get('is_active', True),
source_id=text_list[index].get('source_id'),
source_type=text_list[index].get('source_type'),
- star_num=text_list[index].get('star_num'),
- trample_num=text_list[index].get('trample_num'),
embedding=embeddings[index]) for index in
range(0, len(text_list))]
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
@@ -78,29 +71,27 @@ class PGVector(BaseVectorStore):
'hit_test.sql')),
with_table_name=True)
embedding_model = select_list(exec_sql,
- [json.dumps(embedding_query), *exec_params, *exec_params, similarity, top_number])
+ [json.dumps(embedding_query), *exec_params, similarity, top_number])
return embedding_model
- def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
- exclude_id_list: list[str],
- is_active: bool,
- embedding: HuggingFaceEmbeddings):
+ def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
+ exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
exclude_dict = {}
if dataset_id_list is None or len(dataset_id_list) == 0:
- return None
+ return []
query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active)
- embedding_query = embedding.embed_query(query_text)
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
- if exclude_id_list is not None and len(exclude_id_list) > 0:
- exclude_dict.__setitem__('id__in', exclude_id_list)
+ if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
+ exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list)
query_set = query_set.exclude(**exclude_dict)
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
'embedding_search.sql')),
with_table_name=True)
- embedding_model = select_one(exec_sql, (json.dumps(embedding_query), *exec_params, *exec_params))
+ embedding_model = select_list(exec_sql,
+ [json.dumps(query_embedding), *exec_params, similarity, top_n])
return embedding_model
def update_by_source_id(self, source_id: str, instance: Dict):
diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py b/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py
index bd6a5d1cf..3aeb0cdb3 100644
--- a/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py
+++ b/apps/setting/models_provider/impl/wenxin_model_provider/model/qian_fan_chat_model.py
@@ -14,13 +14,23 @@ from langchain.chat_models.base import BaseChatModel
from langchain.load import dumpd
from langchain.schema import LLMResult
from langchain.schema.language_model import LanguageModelInput
-from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage
+from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig
+from transformers import GPT2TokenizerFast
+
+tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
+ force_download=False)
class QianfanChatModel(QianfanChatEndpoint):
+ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+ return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+ def get_num_tokens(self, text: str) -> int:
+ return len(tokenizer.encode(text))
+
def stream(
self,
input: LanguageModelInput,
@@ -30,7 +40,7 @@ class QianfanChatModel(QianfanChatEndpoint):
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if len(input) % 2 == 0:
- input = [HumanMessage(content='占位'), *input]
+ input = [HumanMessage(content='padding'), *input]
input = [
HumanMessage(content=input[index].content) if index % 2 == 0 else AIMessage(content=input[index].content)
for index in range(0, len(input))]