mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 优化对话逻辑
This commit is contained in:
parent
7349f00c54
commit
3f87335c80
|
|
@ -0,0 +1,55 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: I_base_chat_pipeline.py
|
||||
@date:2024/1/9 17:25
|
||||
@desc:
|
||||
"""
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
class IBaseChatPipelineStep:
|
||||
def __init__(self):
|
||||
# 当前步骤上下文,用于存储当前步骤信息
|
||||
self.context = {}
|
||||
|
||||
@abstractmethod
|
||||
def get_step_serializer(self, manage) -> Type[serializers.Serializer]:
|
||||
pass
|
||||
|
||||
def valid_args(self, manage):
|
||||
step_serializer_clazz = self.get_step_serializer(manage)
|
||||
step_serializer = step_serializer_clazz(data=manage.context)
|
||||
step_serializer.is_valid(raise_exception=True)
|
||||
self.context['step_args'] = step_serializer.data
|
||||
|
||||
def run(self, manage):
|
||||
"""
|
||||
|
||||
:param manage: 步骤管理器
|
||||
:return: 执行结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
# 校验参数,
|
||||
self.valid_args(manage)
|
||||
self._run(manage)
|
||||
self.context['start_time'] = start_time
|
||||
self.context['run_time'] = time.time() - start_time
|
||||
|
||||
def _run(self, manage):
|
||||
pass
|
||||
|
||||
def execute(self, **kwargs):
|
||||
pass
|
||||
|
||||
def get_details(self, manage, **kwargs):
|
||||
"""
|
||||
运行详情
|
||||
:return: 步骤详情
|
||||
"""
|
||||
return None
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/1/9 17:23
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: pipeline_manage.py
|
||||
@date:2024/1/9 17:40
|
||||
@desc:
|
||||
"""
|
||||
import time
|
||||
from functools import reduce
|
||||
from typing import List, Type, Dict
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||
|
||||
|
||||
class PiplineManage:
|
||||
def __init__(self, step_list: List[Type[IBaseChatPipelineStep]]):
|
||||
# 步骤执行器
|
||||
self.step_list = [step() for step in step_list]
|
||||
# 上下文
|
||||
self.context = {'message_tokens': 0, 'answer_tokens': 0}
|
||||
|
||||
def run(self, context: Dict = None):
|
||||
self.context['start_time'] = time.time()
|
||||
if context is not None:
|
||||
for key, value in context.items():
|
||||
self.context[key] = value
|
||||
for step in self.step_list:
|
||||
step.run(self)
|
||||
|
||||
def get_details(self):
|
||||
return reduce(lambda x, y: {**x, **y}, [{item.get('step_type'): item} for item in
|
||||
filter(lambda r: r is not None,
|
||||
[row.get_details(self) for row in self.step_list])], {})
|
||||
|
||||
class builder:
|
||||
def __init__(self):
|
||||
self.step_list: List[Type[IBaseChatPipelineStep]] = []
|
||||
|
||||
def append_step(self, step: Type[IBaseChatPipelineStep]):
|
||||
self.step_list.append(step)
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return PiplineManage(step_list=self.step_list)
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/1/9 18:23
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/1/9 18:23
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_chat_step.py
|
||||
@date:2024/1/9 18:17
|
||||
@desc: 对话
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Type, List
|
||||
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import BaseMessage
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||
from common.field.common import InstanceField
|
||||
from dataset.models import Paragraph
|
||||
|
||||
|
||||
class ModelField(serializers.Field):
|
||||
def to_internal_value(self, data):
|
||||
if not isinstance(data, BaseChatModel):
|
||||
self.fail('模型类型错误', value=data)
|
||||
return data
|
||||
|
||||
def to_representation(self, value):
|
||||
return value
|
||||
|
||||
|
||||
class MessageField(serializers.Field):
|
||||
def to_internal_value(self, data):
|
||||
if not isinstance(data, BaseMessage):
|
||||
self.fail('message类型错误', value=data)
|
||||
return data
|
||||
|
||||
def to_representation(self, value):
|
||||
return value
|
||||
|
||||
|
||||
class PostResponseHandler:
|
||||
@abstractmethod
|
||||
def handler(self, chat_id, chat_record_id, paragraph_list: List[Paragraph], problem_text: str, answer_text,
|
||||
manage, step, padding_problem_text: str = None, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class IChatStep(IBaseChatPipelineStep):
|
||||
class InstanceSerializer(serializers.Serializer):
|
||||
# 对话列表
|
||||
message_list = serializers.ListField(required=True, child=MessageField(required=True))
|
||||
# 大语言模型
|
||||
chat_model = ModelField()
|
||||
# 段落列表
|
||||
paragraph_list = serializers.ListField()
|
||||
# 对话id
|
||||
chat_id = serializers.UUIDField(required=True)
|
||||
# 用户问题
|
||||
problem_text = serializers.CharField(required=True)
|
||||
# 后置处理器
|
||||
post_response_handler = InstanceField(model_type=PostResponseHandler)
|
||||
# 补全问题
|
||||
padding_problem_text = serializers.CharField(required=False)
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
message_list: List = self.initial_data.get('message_list')
|
||||
for message in message_list:
|
||||
if not isinstance(message, BaseMessage):
|
||||
raise Exception("message 类型错误")
|
||||
|
||||
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
|
||||
return self.InstanceSerializer
|
||||
|
||||
def _run(self, manage: PiplineManage):
|
||||
chat_result = self.execute(**self.context['step_args'], manage=manage)
|
||||
manage.context['chat_result'] = chat_result
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, message_list: List[BaseMessage],
|
||||
chat_id, problem_text,
|
||||
post_response_handler: PostResponseHandler,
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
manage: PiplineManage = None,
|
||||
padding_problem_text: str = None, **kwargs):
|
||||
pass
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_chat_step.py
|
||||
@date:2024/1/9 18:25
|
||||
@desc: 对话step Base实现
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from django.http import StreamingHttpResponse
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import BaseMessage
|
||||
from langchain.schema.messages import BaseMessageChunk, HumanMessage
|
||||
|
||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
|
||||
from dataset.models import Paragraph
|
||||
|
||||
|
||||
def event_content(response,
|
||||
chat_id,
|
||||
chat_record_id,
|
||||
paragraph_list: List[Paragraph],
|
||||
post_response_handler: PostResponseHandler,
|
||||
manage,
|
||||
step,
|
||||
chat_model,
|
||||
message_list: List[BaseMessage],
|
||||
problem_text: str,
|
||||
padding_problem_text: str = None):
|
||||
all_text = ''
|
||||
try:
|
||||
for chunk in response:
|
||||
all_text += chunk.content
|
||||
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||
'content': chunk.content, 'is_end': False}) + "\n\n"
|
||||
|
||||
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||
'content': '', 'is_end': True}) + "\n\n"
|
||||
# 获取token
|
||||
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
||||
response_token = chat_model.get_num_tokens(all_text)
|
||||
step.context['message_tokens'] = request_token
|
||||
step.context['answer_tokens'] = response_token
|
||||
current_time = time.time()
|
||||
step.context['answer_text'] = all_text
|
||||
step.context['run_time'] = current_time - step.context['start_time']
|
||||
manage.context['run_time'] = current_time - manage.context['start_time']
|
||||
manage.context['message_tokens'] = manage.context['message_tokens'] + request_token
|
||||
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
|
||||
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
||||
all_text, manage, step, padding_problem_text)
|
||||
except Exception as e:
|
||||
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
|
||||
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||
'content': '异常' + str(e), 'is_end': True}) + "\n\n"
|
||||
|
||||
|
||||
class BaseChatStep(IChatStep):
|
||||
def execute(self, message_list: List[BaseMessage],
|
||||
chat_id,
|
||||
problem_text,
|
||||
post_response_handler: PostResponseHandler,
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
manage: PiplineManage = None,
|
||||
padding_problem_text: str = None,
|
||||
**kwargs):
|
||||
# 调用模型
|
||||
if chat_model is None:
|
||||
chat_result = iter(
|
||||
[BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
|
||||
else:
|
||||
chat_result = chat_model.stream(message_list)
|
||||
|
||||
chat_record_id = uuid.uuid1()
|
||||
r = StreamingHttpResponse(
|
||||
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
|
||||
post_response_handler, manage, self, chat_model, message_list, problem_text,
|
||||
padding_problem_text),
|
||||
content_type='text/event-stream;charset=utf-8')
|
||||
|
||||
r['Cache-Control'] = 'no-cache'
|
||||
return r
|
||||
|
||||
def get_details(self, manage, **kwargs):
|
||||
return {
|
||||
'step_type': 'chat_step',
|
||||
'run_time': self.context['run_time'],
|
||||
'model_id': str(manage.context['model_id']),
|
||||
'message_list': self.reset_message_list(self.context['step_args'].get('message_list'),
|
||||
self.context['answer_text']),
|
||||
'message_tokens': self.context['message_tokens'],
|
||||
'answer_tokens': self.context['answer_tokens'],
|
||||
'cost': 0,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
|
||||
message
|
||||
in
|
||||
message_list]
|
||||
result.append({'role': 'ai', 'content': answer_text})
|
||||
return result
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/1/9 18:23
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_generate_human_message_step.py
|
||||
@date:2024/1/9 18:15
|
||||
@desc: 生成对话模板
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Type, List
|
||||
|
||||
from langchain.schema import BaseMessage
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||
from application.models import ChatRecord
|
||||
from common.field.common import InstanceField
|
||||
from dataset.models import Paragraph
|
||||
|
||||
|
||||
class IGenerateHumanMessageStep(IBaseChatPipelineStep):
|
||||
class InstanceSerializer(serializers.Serializer):
|
||||
# 问题
|
||||
problem_text = serializers.CharField(required=True)
|
||||
# 段落列表
|
||||
paragraph_list = serializers.ListField(child=InstanceField(model_type=Paragraph, required=True))
|
||||
# 历史对答
|
||||
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True)
|
||||
# 最大携带知识库段落长度
|
||||
max_paragraph_char_number = serializers.IntegerField(required=True)
|
||||
# 模板
|
||||
prompt = serializers.CharField(required=True)
|
||||
# 补齐问题
|
||||
padding_problem_text = serializers.CharField(required=False)
|
||||
|
||||
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
|
||||
return self.InstanceSerializer
|
||||
|
||||
def _run(self, manage: PiplineManage):
|
||||
message_list = self.execute(**self.context['step_args'])
|
||||
manage.context['message_list'] = message_list
|
||||
|
||||
@abstractmethod
|
||||
def execute(self,
|
||||
problem_text: str,
|
||||
paragraph_list: List[Paragraph],
|
||||
history_chat_record: List[ChatRecord],
|
||||
dialogue_number: int,
|
||||
max_paragraph_char_number: int,
|
||||
prompt: str,
|
||||
padding_problem_text: str = None,
|
||||
**kwargs) -> List[BaseMessage]:
|
||||
"""
|
||||
|
||||
:param problem_text: 原始问题文本
|
||||
:param paragraph_list: 段落列表
|
||||
:param history_chat_record: 历史对话记录
|
||||
:param dialogue_number: 多轮对话数量
|
||||
:param max_paragraph_char_number: 最大段落长度
|
||||
:param prompt: 模板
|
||||
:param padding_problem_text 用户修改文本
|
||||
:param kwargs: 其他参数
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_generate_human_message_step.py.py
|
||||
@date:2024/1/10 17:50
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from langchain.schema import BaseMessage, HumanMessage
|
||||
|
||||
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
|
||||
IGenerateHumanMessageStep
|
||||
from application.models import ChatRecord
|
||||
from common.util.split_model import flat_map
|
||||
from dataset.models import Paragraph
|
||||
|
||||
|
||||
class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
|
||||
|
||||
def execute(self, problem_text: str,
|
||||
paragraph_list: List[Paragraph],
|
||||
history_chat_record: List[ChatRecord],
|
||||
dialogue_number: int,
|
||||
max_paragraph_char_number: int,
|
||||
prompt: str,
|
||||
padding_problem_text: str = None,
|
||||
**kwargs) -> List[BaseMessage]:
|
||||
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))]
|
||||
return [*flat_map(history_message),
|
||||
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list)]
|
||||
|
||||
@staticmethod
|
||||
def to_human_message(prompt: str,
|
||||
problem: str,
|
||||
max_paragraph_char_number: int,
|
||||
paragraph_list: List[Paragraph]):
|
||||
if paragraph_list is None or len(paragraph_list) == 0:
|
||||
return HumanMessage(content=problem)
|
||||
temp_data = ""
|
||||
data_list = []
|
||||
for p in paragraph_list:
|
||||
content = f"{p.title}:{p.content}"
|
||||
temp_data += content
|
||||
if len(temp_data) > max_paragraph_char_number:
|
||||
row_data = content[0:max_paragraph_char_number - len(temp_data)]
|
||||
data_list.append(f"<data>{row_data}</data>")
|
||||
break
|
||||
else:
|
||||
data_list.append(f"<data>{content}</data>")
|
||||
data = "\n".join(data_list)
|
||||
return HumanMessage(content=prompt.format(**{'data': data, 'question': problem}))
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/1/9 18:23
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_reset_problem_step.py
|
||||
@date:2024/1/9 18:12
|
||||
@desc: 重写处理问题
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Type, List
|
||||
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||
from application.chat_pipeline.step.chat_step.i_chat_step import ModelField
|
||||
from application.models import ChatRecord
|
||||
from common.field.common import InstanceField
|
||||
|
||||
|
||||
class IResetProblemStep(IBaseChatPipelineStep):
|
||||
class InstanceSerializer(serializers.Serializer):
|
||||
# 问题文本
|
||||
problem_text = serializers.CharField(required=True)
|
||||
# 历史对答
|
||||
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
|
||||
# 大语言模型
|
||||
chat_model = ModelField()
|
||||
|
||||
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
|
||||
return self.InstanceSerializer
|
||||
|
||||
def _run(self, manage: PiplineManage):
|
||||
padding_problem = self.execute(**self.context.get('step_args'))
|
||||
# 用户输入问题
|
||||
source_problem_text = self.context.get('step_args').get('problem_text')
|
||||
self.context['problem_text'] = source_problem_text
|
||||
self.context['padding_problem_text'] = padding_problem
|
||||
manage.context['problem_text'] = source_problem_text
|
||||
manage.context['padding_problem_text'] = padding_problem
|
||||
# 累加tokens
|
||||
manage.context['message_tokens'] = manage.context['message_tokens'] + self.context.get('message_tokens')
|
||||
manage.context['answer_tokens'] = manage.context['answer_tokens'] + self.context.get('answer_tokens')
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
|
||||
**kwargs):
|
||||
pass
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_reset_problem_step.py
|
||||
@date:2024/1/10 14:35
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
|
||||
from application.models import ChatRecord
|
||||
from common.util.split_model import flat_map
|
||||
|
||||
prompt = (
|
||||
'()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中')
|
||||
|
||||
|
||||
class BaseResetProblemStep(IResetProblemStep):
|
||||
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
|
||||
**kwargs) -> str:
|
||||
start_index = len(history_chat_record) - 3
|
||||
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))]
|
||||
message_list = [*flat_map(history_message),
|
||||
HumanMessage(content=prompt.format(**{'question': problem_text}))]
|
||||
response = chat_model(message_list)
|
||||
padding_problem = response.content[response.content.index('<data>') + 6:response.content.index('</data>')]
|
||||
self.context['message_tokens'] = chat_model.get_num_tokens_from_messages(message_list)
|
||||
self.context['answer_tokens'] = chat_model.get_num_tokens(padding_problem)
|
||||
return padding_problem
|
||||
|
||||
def get_details(self, manage, **kwargs):
|
||||
return {
|
||||
'step_type': 'problem_padding',
|
||||
'run_time': self.context['run_time'],
|
||||
'model_id': str(manage.context['model_id']) if 'model_id' in manage.context else None,
|
||||
'message_tokens': self.context['message_tokens'],
|
||||
'answer_tokens': self.context['answer_tokens'],
|
||||
'cost': 0,
|
||||
'padding_problem_text': self.context.get('padding_problem_text'),
|
||||
'problem_text': self.context.get("step_args").get('problem_text'),
|
||||
}
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/1/9 18:24
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: i_search_dataset_step.py
|
||||
@date:2024/1/9 18:10
|
||||
@desc: 检索知识库
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import List, Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||
from dataset.models import Paragraph
|
||||
|
||||
|
||||
class ISearchDatasetStep(IBaseChatPipelineStep):
|
||||
class InstanceSerializer(serializers.Serializer):
|
||||
# 原始问题文本
|
||||
problem_text = serializers.CharField(required=True)
|
||||
# 系统补全问题文本
|
||||
padding_problem_text = serializers.CharField(required=False)
|
||||
# 需要查询的数据集id列表
|
||||
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
|
||||
# 需要排除的文档id
|
||||
exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
|
||||
# 需要排除向量id
|
||||
exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
|
||||
# 需要查询的条数
|
||||
top_n = serializers.IntegerField(required=True)
|
||||
# 相似度 0-1之间
|
||||
similarity = serializers.FloatField(required=True, max_value=1, min_value=0)
|
||||
|
||||
def get_step_serializer(self, manage: PiplineManage) -> Type[InstanceSerializer]:
|
||||
return self.InstanceSerializer
|
||||
|
||||
def _run(self, manage: PiplineManage):
|
||||
paragraph_list = self.execute(**self.context['step_args'])
|
||||
manage.context['paragraph_list'] = paragraph_list
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||
**kwargs) -> List[Paragraph]:
|
||||
"""
|
||||
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
|
||||
:param similarity: 相关性
|
||||
:param top_n: 查询多少条
|
||||
:param problem_text: 用户问题
|
||||
:param dataset_id_list: 需要查询的数据集id列表
|
||||
:param exclude_document_id_list: 需要排除的文档id
|
||||
:param exclude_paragraph_id_list: 需要排除段落id
|
||||
:param padding_problem_text 补全问题
|
||||
:return: 段落列表
|
||||
"""
|
||||
pass
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: base_search_dataset_step.py
|
||||
@date:2024/1/10 10:33
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
|
||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||
from dataset.models import Paragraph
|
||||
|
||||
|
||||
class BaseSearchDatasetStep(ISearchDatasetStep):
|
||||
|
||||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||
**kwargs) -> List[Paragraph]:
|
||||
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
||||
embedding_model = EmbeddingModel.get_embedding_model()
|
||||
embedding_value = embedding_model.embed_query(exec_problem_text)
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
embedding_list = vector.query(embedding_value, dataset_id_list, exclude_document_id_list,
|
||||
exclude_paragraph_id_list, True, top_n, similarity)
|
||||
if embedding_list is None:
|
||||
return []
|
||||
return self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)
|
||||
|
||||
@staticmethod
|
||||
def list_paragraph(paragraph_id_list: List, vector):
|
||||
if paragraph_id_list is None or len(paragraph_id_list) == 0:
|
||||
return []
|
||||
paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list)
|
||||
# 如果向量库中存在脏数据 直接删除
|
||||
if len(paragraph_list) != len(paragraph_id_list):
|
||||
exist_paragraph_list = [str(row.id) for row in paragraph_list]
|
||||
for paragraph_id in paragraph_id_list:
|
||||
if not exist_paragraph_list.__contains__(paragraph_id):
|
||||
vector.delete_by_paragraph_id(paragraph_id)
|
||||
return paragraph_list
|
||||
|
||||
def get_details(self, manage, **kwargs):
|
||||
step_args = self.context['step_args']
|
||||
|
||||
return {
|
||||
'step_type': 'search_step',
|
||||
'run_time': self.context['run_time'],
|
||||
'problem_text': step_args.get(
|
||||
'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
|
||||
'model_name': EmbeddingModel.get_embedding_model().model_name,
|
||||
'message_tokens': 0,
|
||||
'answer_tokens': 0,
|
||||
'cost': 0
|
||||
}
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
# Generated by Django 4.1.10 on 2024-01-12 18:46
|
||||
|
||||
import django.contrib.postgres.fields
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('application', '0002_alter_chatrecord_dataset'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RemoveField(
|
||||
model_name='chatrecord',
|
||||
name='dataset',
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name='chatrecord',
|
||||
name='paragraph',
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name='chatrecord',
|
||||
name='source_id',
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name='chatrecord',
|
||||
name='source_type',
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='chatrecord',
|
||||
name='const',
|
||||
field=models.IntegerField(default=0, verbose_name='总费用'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='chatrecord',
|
||||
name='details',
|
||||
field=models.JSONField(default=list, verbose_name='对话详情'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='chatrecord',
|
||||
name='paragraph_id_list',
|
||||
field=django.contrib.postgres.fields.ArrayField(base_field=models.UUIDField(blank=True), default=list, size=None, verbose_name='引用段落id列表'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='chatrecord',
|
||||
name='run_time',
|
||||
field=models.FloatField(default=0, verbose_name='运行时长'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='chatrecord',
|
||||
name='answer_text',
|
||||
field=models.CharField(max_length=4096, verbose_name='答案'),
|
||||
),
|
||||
]
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
# Generated by Django 4.1.10 on 2024-01-15 16:07
|
||||
|
||||
import application.models.application
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('application', '0003_remove_chatrecord_dataset_and_more'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RemoveField(
|
||||
model_name='application',
|
||||
name='example',
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='application',
|
||||
name='dataset_setting',
|
||||
field=models.JSONField(default=application.models.application.get_dataset_setting_dict, verbose_name='数据集参数设置'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='application',
|
||||
name='model_setting',
|
||||
field=models.JSONField(default=application.models.application.get_model_setting_dict, verbose_name='模型参数相关设置'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='application',
|
||||
name='problem_optimization',
|
||||
field=models.BooleanField(default=False, verbose_name='问题优化'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='chatrecord',
|
||||
name='details',
|
||||
field=models.JSONField(default={}, verbose_name='对话详情'),
|
||||
),
|
||||
]
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
# Generated by Django 4.1.10 on 2024-01-16 11:22
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('application', '0004_remove_application_example_and_more'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='chatrecord',
|
||||
name='details',
|
||||
field=models.JSONField(default=dict, verbose_name='对话详情'),
|
||||
),
|
||||
]
|
||||
|
|
@ -10,23 +10,47 @@ import uuid
|
|||
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.db import models
|
||||
from langchain.schema import HumanMessage, AIMessage
|
||||
|
||||
from common.mixins.app_model_mixin import AppModelMixin
|
||||
from dataset.models.data_set import DataSet, Paragraph
|
||||
from embedding.models import SourceType
|
||||
from dataset.models.data_set import DataSet
|
||||
from setting.models.model_management import Model
|
||||
from users.models import User
|
||||
|
||||
|
||||
def get_dataset_setting_dict():
|
||||
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000}
|
||||
|
||||
|
||||
def get_model_setting_dict():
|
||||
return {'prompt': Application.get_default_model_prompt()}
|
||||
|
||||
|
||||
class Application(AppModelMixin):
|
||||
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
||||
name = models.CharField(max_length=128, verbose_name="应用名称")
|
||||
desc = models.CharField(max_length=128, verbose_name="引用描述", default="")
|
||||
prologue = models.CharField(max_length=1024, verbose_name="开场白", default="")
|
||||
example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True), default=list)
|
||||
dialogue_number = models.IntegerField(default=0, verbose_name="会话数量")
|
||||
user = models.ForeignKey(User, on_delete=models.DO_NOTHING)
|
||||
model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
|
||||
dataset_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict)
|
||||
model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
|
||||
problem_optimization = models.BooleanField(verbose_name="问题优化", default=False)
|
||||
|
||||
@staticmethod
|
||||
def get_default_model_prompt():
|
||||
return ('已知信息:'
|
||||
'\n{data}'
|
||||
'\n回答要求:'
|
||||
'\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。'
|
||||
'\n- 避免提及你是从<data></data>中获得的知识。'
|
||||
'\n- 请保持答案与<data></data>中描述的一致。'
|
||||
'\n- 请使用markdown 语法优化答案的格式。'
|
||||
'\n- <data></data>中的图片链接、链接地址和脚本语言请完整返回。'
|
||||
'\n- 请使用与问题相同的语言来回答。'
|
||||
'\n问题:'
|
||||
'\n{question}')
|
||||
|
||||
class Meta:
|
||||
db_table = "application"
|
||||
|
|
@ -65,20 +89,28 @@ class ChatRecord(AppModelMixin):
|
|||
chat = models.ForeignKey(Chat, on_delete=models.CASCADE)
|
||||
vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices,
|
||||
default=VoteChoices.UN_VOTE)
|
||||
dataset = models.ForeignKey(DataSet, on_delete=models.SET_NULL, verbose_name="数据集", blank=True, null=True)
|
||||
paragraph = models.ForeignKey(Paragraph, on_delete=models.SET_NULL, verbose_name="段落id", blank=True, null=True)
|
||||
source_id = models.UUIDField(max_length=128, verbose_name="资源id 段落/问题 id ", null=True)
|
||||
source_type = models.CharField(verbose_name='资源类型', max_length=2, choices=SourceType.choices,
|
||||
default=SourceType.PROBLEM, blank=True, null=True)
|
||||
paragraph_id_list = ArrayField(verbose_name="引用段落id列表",
|
||||
base_field=models.UUIDField(max_length=128, blank=True)
|
||||
, default=list)
|
||||
problem_text = models.CharField(max_length=1024, verbose_name="问题")
|
||||
answer_text = models.CharField(max_length=4096, verbose_name="答案")
|
||||
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
|
||||
answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0)
|
||||
problem_text = models.CharField(max_length=1024, verbose_name="问题")
|
||||
answer_text = models.CharField(max_length=1024, verbose_name="答案")
|
||||
const = models.IntegerField(verbose_name="总费用", default=0)
|
||||
details = models.JSONField(verbose_name="对话详情", default=dict)
|
||||
improve_paragraph_id_list = ArrayField(verbose_name="改进标注列表",
|
||||
base_field=models.UUIDField(max_length=128, blank=True)
|
||||
, default=list)
|
||||
|
||||
run_time = models.FloatField(verbose_name="运行时长", default=0)
|
||||
index = models.IntegerField(verbose_name="对话下标")
|
||||
|
||||
def get_human_message(self):
|
||||
if 'problem_padding' in self.details:
|
||||
return HumanMessage(content=self.details.get('problem_padding').get('padding_problem_text'))
|
||||
return HumanMessage(content=self.problem_text)
|
||||
|
||||
def get_ai_message(self):
|
||||
return AIMessage(content=self.answer_text)
|
||||
|
||||
class Meta:
|
||||
db_table = "application_chat_record"
|
||||
|
|
|
|||
|
|
@ -63,15 +63,35 @@ class ApplicationSerializerModel(serializers.ModelSerializer):
|
|||
fields = "__all__"
|
||||
|
||||
|
||||
class DatasetSettingSerializer(serializers.Serializer):
|
||||
top_n = serializers.FloatField(required=True)
|
||||
similarity = serializers.FloatField(required=True, max_value=1, min_value=0)
|
||||
max_paragraph_char_number = serializers.IntegerField(required=True, max_value=10000)
|
||||
|
||||
|
||||
class ModelSettingSerializer(serializers.Serializer):
|
||||
prompt = serializers.CharField(required=True, max_length=4096)
|
||||
|
||||
|
||||
class ApplicationSerializer(serializers.Serializer):
|
||||
name = serializers.CharField(required=True)
|
||||
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True)
|
||||
model_id = serializers.CharField(required=True)
|
||||
multiple_rounds_dialogue = serializers.BooleanField(required=True)
|
||||
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
|
||||
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True), allow_null=True)
|
||||
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
|
||||
allow_null=True)
|
||||
# 数据集相关设置
|
||||
dataset_setting = DatasetSettingSerializer(required=True)
|
||||
# 模型相关设置
|
||||
model_setting = ModelSettingSerializer(required=True)
|
||||
# 问题补全
|
||||
problem_optimization = serializers.BooleanField(required=True)
|
||||
|
||||
def is_valid(self, *, user_id=None, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
|
||||
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
|
||||
|
||||
class AccessTokenSerializer(serializers.Serializer):
|
||||
application_id = serializers.UUIDField(required=True)
|
||||
|
|
@ -135,13 +155,13 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
model_id = serializers.CharField(required=False)
|
||||
multiple_rounds_dialogue = serializers.BooleanField(required=False)
|
||||
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
|
||||
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True))
|
||||
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
|
||||
|
||||
def is_valid(self, *, user_id=None, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
|
||||
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
|
||||
# 数据集相关设置
|
||||
dataset_setting = serializers.JSONField(required=False, allow_null=True)
|
||||
# 模型相关设置
|
||||
model_setting = serializers.JSONField(required=False, allow_null=True)
|
||||
# 问题补全
|
||||
problem_optimization = serializers.BooleanField(required=False, allow_null=True)
|
||||
|
||||
class Create(serializers.Serializer):
|
||||
user_id = serializers.UUIDField(required=True)
|
||||
|
|
@ -168,9 +188,12 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
@staticmethod
|
||||
def to_application_model(user_id: str, application: Dict):
|
||||
return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'),
|
||||
prologue=application.get('prologue'), example=application.get('example'),
|
||||
prologue=application.get('prologue'),
|
||||
dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0,
|
||||
user_id=user_id, model_id=application.get('model_id'),
|
||||
dataset_setting=application.get('dataset_setting'),
|
||||
model_setting=application.get('model_setting'),
|
||||
problem_optimization=application.get('problem_optimization')
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -267,7 +290,7 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
class ApplicationModel(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Application
|
||||
fields = ['id', 'name', 'desc', 'prologue', 'example', 'dialogue_number']
|
||||
fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number']
|
||||
|
||||
class Operate(serializers.Serializer):
|
||||
application_id = serializers.UUIDField(required=True)
|
||||
|
|
@ -317,8 +340,9 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
|
||||
model = QuerySet(Model).get(id=instance.get('model_id') if 'model_id' in instance else application.model_id)
|
||||
|
||||
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'example', 'status',
|
||||
'api_key_is_active']
|
||||
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
|
||||
'dataset_setting', 'model_setting', 'problem_optimization'
|
||||
'api_key_is_active']
|
||||
for update_key in update_keys:
|
||||
if update_key in instance and instance.get(update_key) is not None:
|
||||
if update_key == 'multiple_rounds_dialogue':
|
||||
|
|
|
|||
|
|
@ -7,194 +7,181 @@
|
|||
@desc:
|
||||
"""
|
||||
import json
|
||||
import uuid
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.http import StreamingHttpResponse
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import HumanMessage
|
||||
from rest_framework import serializers, status
|
||||
from django.core.cache import cache
|
||||
from common import event
|
||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||
from common.response import result
|
||||
from dataset.models import Paragraph
|
||||
from embedding.models import SourceType
|
||||
from setting.models.model_management import Model
|
||||
from django.db.models import QuerySet
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||
from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
|
||||
from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep
|
||||
from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \
|
||||
BaseGenerateHumanMessageStep
|
||||
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
|
||||
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
|
||||
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.util.rsa_util import decrypt
|
||||
from common.util.split_model import flat_map
|
||||
from dataset.models import Paragraph, Document
|
||||
from setting.models import Model
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
|
||||
chat_cache = cache
|
||||
|
||||
|
||||
class MessageManagement:
|
||||
@staticmethod
|
||||
def get_message(title: str, content: str, message: str):
|
||||
if content is None:
|
||||
return HumanMessage(content=message)
|
||||
return HumanMessage(content=(
|
||||
f'已知信息:{title}:{content} '
|
||||
'根据上述已知信息,请简洁和专业的来回答用户的问题。已知信息中的图片、链接地址和脚本语言请直接返回。如果无法从已知信息中得到答案,请说 “没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作” 或 “根据已知信息无法回答该问题,建议联系官方技术支持人员”,不允许在答案中添加编造成分,答案请使用中文。'
|
||||
f'问题是:{message}'))
|
||||
|
||||
|
||||
class ChatMessage:
|
||||
def __init__(self, id: str, problem: str, title: str, paragraph: str, embedding_id: str, dataset_id: str,
|
||||
document_id: str,
|
||||
paragraph_id,
|
||||
source_type: SourceType,
|
||||
source_id: str,
|
||||
answer: str,
|
||||
message_tokens: int,
|
||||
answer_token: int,
|
||||
chat_model=None,
|
||||
chat_message=None):
|
||||
self.id = id
|
||||
self.problem = problem
|
||||
self.title = title
|
||||
self.paragraph = paragraph
|
||||
self.embedding_id = embedding_id
|
||||
self.dataset_id = dataset_id
|
||||
self.document_id = document_id
|
||||
self.paragraph_id = paragraph_id
|
||||
self.source_type = source_type
|
||||
self.source_id = source_id
|
||||
self.answer = answer
|
||||
self.message_tokens = message_tokens
|
||||
self.answer_token = answer_token
|
||||
self.chat_model = chat_model
|
||||
self.chat_message = chat_message
|
||||
|
||||
def get_chat_message(self):
|
||||
return MessageManagement.get_message(self.problem, self.paragraph, self.problem)
|
||||
|
||||
|
||||
class ChatInfo:
|
||||
def __init__(self,
|
||||
chat_id: str,
|
||||
model: Model,
|
||||
chat_model: BaseChatModel,
|
||||
application_id: str | None,
|
||||
dataset_id_list: List[str],
|
||||
exclude_document_id_list: list[str],
|
||||
dialogue_number: int):
|
||||
application: Application):
|
||||
"""
|
||||
:param chat_id: 对话id
|
||||
:param chat_model: 对话模型
|
||||
:param dataset_id_list: 数据集列表
|
||||
:param exclude_document_id_list: 排除的文档
|
||||
:param application: 应用信息
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.application_id = application_id
|
||||
self.model = model
|
||||
self.application = application
|
||||
self.chat_model = chat_model
|
||||
self.dataset_id_list = dataset_id_list
|
||||
self.exclude_document_id_list = exclude_document_id_list
|
||||
self.dialogue_number = dialogue_number
|
||||
self.chat_message_list: List[ChatMessage] = []
|
||||
self.chat_record_list: List[ChatRecord] = []
|
||||
|
||||
def append_chat_message(self, chat_message: ChatMessage):
|
||||
self.chat_message_list.append(chat_message)
|
||||
if self.application_id is not None:
|
||||
def to_base_pipeline_manage_params(self):
|
||||
dataset_setting = self.application.dataset_setting
|
||||
model_setting = self.application.model_setting
|
||||
return {
|
||||
'dataset_id_list': self.dataset_id_list,
|
||||
'exclude_document_id_list': self.exclude_document_id_list,
|
||||
'exclude_paragraph_id_list': [],
|
||||
'top_n': dataset_setting.get('top_n') if 'top_n' in dataset_setting else 3,
|
||||
'similarity': dataset_setting.get('similarity') if 'similarity' in dataset_setting else 0.6,
|
||||
'max_paragraph_char_number': dataset_setting.get(
|
||||
'max_paragraph_char_number') if 'max_paragraph_char_number' in dataset_setting else 5000,
|
||||
'history_chat_record': self.chat_record_list,
|
||||
'chat_id': self.chat_id,
|
||||
'dialogue_number': self.application.dialogue_number,
|
||||
'prompt': model_setting.get(
|
||||
'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
|
||||
'chat_model': self.chat_model,
|
||||
'model_id': self.application.model.id if self.application.model is not None else None,
|
||||
'problem_optimization': self.application.problem_optimization
|
||||
|
||||
}
|
||||
|
||||
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
|
||||
exclude_paragraph_id_list):
|
||||
params = self.to_base_pipeline_manage_params()
|
||||
return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
|
||||
'exclude_paragraph_id_list': exclude_paragraph_id_list}
|
||||
|
||||
def append_chat_record(self, chat_record: ChatRecord):
|
||||
# 存入缓存中
|
||||
self.chat_record_list.append(chat_record)
|
||||
if self.application.id is not None:
|
||||
# 插入数据库
|
||||
event.ListenerChatMessage.record_chat_message_signal.send(
|
||||
event.RecordChatMessageArgs(len(self.chat_message_list) - 1, self.chat_id, self.application_id,
|
||||
chat_message)
|
||||
)
|
||||
# 异步更新token
|
||||
event.ListenerChatMessage.update_chat_message_token_signal.send(chat_message)
|
||||
if not QuerySet(Chat).filter(id=self.chat_id).exists():
|
||||
Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text).save()
|
||||
# 插入会话记录
|
||||
chat_record.save()
|
||||
|
||||
def get_context_message(self):
|
||||
start_index = len(self.chat_message_list) - self.dialogue_number
|
||||
return [self.chat_message_list[index].get_chat_message() for index in
|
||||
range(start_index if start_index > 0 else 0, len(self.chat_message_list))]
|
||||
|
||||
def get_post_handler(chat_info: ChatInfo):
|
||||
class PostHandler(PostResponseHandler):
|
||||
|
||||
def handler(self,
|
||||
chat_id: UUID,
|
||||
chat_record_id,
|
||||
paragraph_list: List[Paragraph],
|
||||
problem_text: str,
|
||||
answer_text,
|
||||
manage: PiplineManage,
|
||||
step: BaseChatStep,
|
||||
padding_problem_text: str = None,
|
||||
**kwargs):
|
||||
chat_record = ChatRecord(id=chat_record_id,
|
||||
chat_id=chat_id,
|
||||
paragraph_id_list=[str(p.id) for p in paragraph_list],
|
||||
problem_text=problem_text,
|
||||
answer_text=answer_text,
|
||||
details=manage.get_details(),
|
||||
message_tokens=manage.context['message_tokens'],
|
||||
answer_tokens=manage.context['answer_tokens'],
|
||||
run_time=manage.context['run_time'],
|
||||
index=len(chat_info.chat_record_list) + 1)
|
||||
chat_info.append_chat_record(chat_record)
|
||||
# 重新设置缓存
|
||||
chat_cache.set(chat_id,
|
||||
chat_info, timeout=60 * 30)
|
||||
|
||||
return PostHandler()
|
||||
|
||||
|
||||
class ChatMessageSerializer(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True)
|
||||
|
||||
def chat(self, message):
|
||||
def chat(self, message, re_chat: bool):
|
||||
self.is_valid(raise_exception=True)
|
||||
chat_id = self.data.get('chat_id')
|
||||
chat_info: ChatInfo = chat_cache.get(chat_id)
|
||||
if chat_info is None:
|
||||
return result.Result(response_status=status.HTTP_404_NOT_FOUND, code=404, message="会话过期")
|
||||
chat_info = self.re_open_chat(chat_id)
|
||||
chat_cache.set(chat_id,
|
||||
chat_info, timeout=60 * 30)
|
||||
|
||||
chat_model = chat_info.chat_model
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
# 向量库检索
|
||||
_value = vector.search(message, chat_info.dataset_id_list, chat_info.exclude_document_id_list,
|
||||
[chat_message.embedding_id for chat_message in
|
||||
(list(filter(lambda row: row.problem == message, chat_info.chat_message_list)))],
|
||||
True,
|
||||
EmbeddingModel.get_embedding_model())
|
||||
# 查询段落id详情
|
||||
paragraph = None
|
||||
if _value is not None:
|
||||
paragraph = QuerySet(Paragraph).get(id=_value.get('paragraph_id'))
|
||||
if paragraph is None:
|
||||
vector.delete_by_paragraph_id(_value.get('paragraph_id'))
|
||||
pipline_manage_builder = PiplineManage.builder()
|
||||
# 如果开启了问题优化,则添加上问题优化步骤
|
||||
if chat_info.application.problem_optimization:
|
||||
pipline_manage_builder.append_step(BaseResetProblemStep)
|
||||
# 构建流水线管理器
|
||||
pipline_message = (pipline_manage_builder.append_step(BaseSearchDatasetStep)
|
||||
.append_step(BaseGenerateHumanMessageStep)
|
||||
.append_step(BaseChatStep)
|
||||
.build())
|
||||
exclude_paragraph_id_list = []
|
||||
# 相同问题是否需要排除已经查询到的段落
|
||||
if re_chat:
|
||||
paragraph_id_list = flat_map([row.paragraph_id_list for row in
|
||||
filter(lambda chat_record: chat_record == message,
|
||||
chat_info.chat_record_list)])
|
||||
exclude_paragraph_id_list = list(set(paragraph_id_list))
|
||||
# 构建运行参数
|
||||
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list)
|
||||
# 运行流水线作业
|
||||
pipline_message.run(params)
|
||||
return pipline_message.context['chat_result']
|
||||
|
||||
title, content = (None, None) if paragraph is None else (paragraph.title, paragraph.content)
|
||||
_id = str(uuid.uuid1())
|
||||
@staticmethod
|
||||
def re_open_chat(chat_id: str):
|
||||
chat = QuerySet(Chat).filter(id=chat_id).first()
|
||||
if chat is None:
|
||||
raise AppApiException(500, "会话不存在")
|
||||
application = QuerySet(Application).filter(id=chat.application_id).first()
|
||||
if application is None:
|
||||
raise AppApiException(500, "应用不存在")
|
||||
model = QuerySet(Model).filter(id=application.model_id).first()
|
||||
chat_model = None
|
||||
if model is not None:
|
||||
# 对话模型
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
decrypt(model.credential)),
|
||||
streaming=True)
|
||||
# 数据集id列表
|
||||
dataset_id_list = [str(row.dataset_id) for row in
|
||||
QuerySet(ApplicationDatasetMapping).filter(
|
||||
application_id=application.id)]
|
||||
|
||||
embedding_id, dataset_id, document_id, paragraph_id, source_type, source_id = (_value.get(
|
||||
'id'), _value.get(
|
||||
'dataset_id'), _value.get(
|
||||
'document_id'), _value.get(
|
||||
'paragraph_id'), _value.get(
|
||||
'source_type'), _value.get(
|
||||
'source_id')) if _value is not None else (None, None, None, None, None, None)
|
||||
|
||||
if chat_model is None:
|
||||
def event_block_content(c: str):
|
||||
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
|
||||
'is_end': True,
|
||||
'content': c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~'}) + "\n\n"
|
||||
chat_info.append_chat_message(
|
||||
ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
|
||||
paragraph_id,
|
||||
source_type,
|
||||
source_id,
|
||||
c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~',
|
||||
0,
|
||||
0))
|
||||
# 重新设置缓存
|
||||
chat_cache.set(chat_id,
|
||||
chat_info, timeout=60 * 30)
|
||||
|
||||
r = StreamingHttpResponse(streaming_content=event_block_content(content),
|
||||
content_type='text/event-stream;charset=utf-8')
|
||||
|
||||
r['Cache-Control'] = 'no-cache'
|
||||
return r
|
||||
# 获取上下文
|
||||
history_message = chat_info.get_context_message()
|
||||
|
||||
# 构建会话请求问题
|
||||
chat_message = [*history_message, MessageManagement.get_message(title, content, message)]
|
||||
# 对话
|
||||
result_data = chat_model.stream(chat_message)
|
||||
|
||||
def event_content(response):
|
||||
all_text = ''
|
||||
try:
|
||||
for chunk in response:
|
||||
all_text += chunk.content
|
||||
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
|
||||
'content': chunk.content, 'is_end': False}) + "\n\n"
|
||||
|
||||
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
|
||||
'content': '', 'is_end': True}) + "\n\n"
|
||||
chat_info.append_chat_message(
|
||||
ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
|
||||
paragraph_id,
|
||||
source_type,
|
||||
source_id, all_text,
|
||||
0,
|
||||
0,
|
||||
chat_message=chat_message, chat_model=chat_model))
|
||||
# 重新设置缓存
|
||||
chat_cache.set(chat_id,
|
||||
chat_info, timeout=60 * 30)
|
||||
except Exception as e:
|
||||
yield e
|
||||
|
||||
r = StreamingHttpResponse(streaming_content=event_content(result_data),
|
||||
content_type='text/event-stream;charset=utf-8')
|
||||
|
||||
r['Cache-Control'] = 'no-cache'
|
||||
return r
|
||||
# 需要排除的文档
|
||||
exclude_document_id_list = [str(document.id) for document in
|
||||
QuerySet(Document).filter(
|
||||
dataset_id__in=dataset_id_list,
|
||||
is_active=False)]
|
||||
return ChatInfo(chat_id, chat_model, dataset_id_list, exclude_document_id_list, application)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ import datetime
|
|||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict
|
||||
from functools import reduce
|
||||
from typing import Dict, List
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db import transaction
|
||||
|
|
@ -18,7 +19,8 @@ from django.db.models import QuerySet
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
|
||||
from application.serializers.application_serializers import ModelDatasetAssociation
|
||||
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
|
||||
ModelSettingSerializer
|
||||
from application.serializers.chat_message_serializers import ChatInfo
|
||||
from common.db.search import native_search, native_page_search, page_search
|
||||
from common.event import ListenerManagement
|
||||
|
|
@ -26,8 +28,8 @@ from common.exception.app_exception import AppApiException
|
|||
from common.util.file_util import get_file_content
|
||||
from common.util.lock import try_lock, un_lock
|
||||
from common.util.rsa_util import decrypt
|
||||
from common.util.split_model import flat_map
|
||||
from dataset.models import Document, Problem, Paragraph
|
||||
from embedding.models import SourceType, Embedding
|
||||
from setting.models import Model
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
|
@ -106,12 +108,12 @@ class ChatSerializers(serializers.Serializer):
|
|||
|
||||
chat_id = str(uuid.uuid1())
|
||||
chat_cache.set(chat_id,
|
||||
ChatInfo(chat_id, model, chat_model, application_id, dataset_id_list,
|
||||
ChatInfo(chat_id, chat_model, dataset_id_list,
|
||||
[str(document.id) for document in
|
||||
QuerySet(Document).filter(
|
||||
dataset_id__in=dataset_id_list,
|
||||
is_active=False)],
|
||||
application.dialogue_number), timeout=60 * 30)
|
||||
application), timeout=60 * 30)
|
||||
return chat_id
|
||||
|
||||
class OpenTempChat(serializers.Serializer):
|
||||
|
|
@ -122,6 +124,12 @@ class ChatSerializers(serializers.Serializer):
|
|||
multiple_rounds_dialogue = serializers.BooleanField(required=True)
|
||||
|
||||
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
|
||||
# 数据集相关设置
|
||||
dataset_setting = DatasetSettingSerializer(required=True)
|
||||
# 模型相关设置
|
||||
model_setting = ModelSettingSerializer(required=True)
|
||||
# 问题补全
|
||||
problem_optimization = serializers.BooleanField(required=True)
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
|
@ -140,42 +148,62 @@ class ChatSerializers(serializers.Serializer):
|
|||
json.loads(
|
||||
decrypt(model.credential)),
|
||||
streaming=True)
|
||||
application = Application(id=None, dialogue_number=3, model=model,
|
||||
dataset_setting=self.data.get('dataset_setting'),
|
||||
model_setting=self.data.get('model_setting'),
|
||||
problem_optimization=self.data.get('problem_optimization'))
|
||||
chat_cache.set(chat_id,
|
||||
ChatInfo(chat_id, model, chat_model, None, dataset_id_list,
|
||||
ChatInfo(chat_id, chat_model, dataset_id_list,
|
||||
[str(document.id) for document in
|
||||
QuerySet(Document).filter(
|
||||
dataset_id__in=dataset_id_list,
|
||||
is_active=False)],
|
||||
3 if self.data.get('multiple_rounds_dialogue') else 1), timeout=60 * 30)
|
||||
application), timeout=60 * 30)
|
||||
return chat_id
|
||||
|
||||
|
||||
def vote_exec(source_type: SourceType, source_id: str, field: str, post_handler):
|
||||
if source_type == SourceType.PROBLEM:
|
||||
problem = QuerySet(Problem).get(id=source_id)
|
||||
if problem is not None:
|
||||
problem.__setattr__(field, post_handler(problem))
|
||||
problem.save()
|
||||
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
|
||||
embedding.__setattr__(field, problem.__getattribute__(field))
|
||||
embedding.save()
|
||||
if source_type == SourceType.PARAGRAPH:
|
||||
paragraph = QuerySet(Paragraph).get(id=source_id)
|
||||
if paragraph is not None:
|
||||
paragraph.__setattr__(field, post_handler(paragraph))
|
||||
paragraph.save()
|
||||
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
|
||||
embedding.__setattr__(field, paragraph.__getattribute__(field))
|
||||
embedding.save()
|
||||
|
||||
|
||||
class ChatRecordSerializerModel(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = ChatRecord
|
||||
fields = "__all__"
|
||||
fields = ['id', 'chat_id', 'vote_status', 'problem_text', 'answer_text',
|
||||
'message_tokens', 'answer_tokens', 'const', 'improve_paragraph_id_list', 'run_time', 'index']
|
||||
|
||||
|
||||
class ChatRecordSerializer(serializers.Serializer):
|
||||
class Operate(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True)
|
||||
|
||||
chat_record_id = serializers.UUIDField(required=True)
|
||||
|
||||
def one(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
chat_record_id = self.data.get('chat_record_id')
|
||||
chat_id = self.data.get('chat_id')
|
||||
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
|
||||
dataset_list = []
|
||||
paragraph_list = []
|
||||
if len(chat_record.paragraph_id_list) > 0:
|
||||
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=chat_record.paragraph_id_list),
|
||||
get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
|
||||
'list_dataset_paragraph_by_paragraph_id.sql')),
|
||||
with_table_name=True)
|
||||
dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
|
||||
[{row.get(
|
||||
'dataset_id'): row.get(
|
||||
"dataset_name")} for
|
||||
row in
|
||||
paragraph_list],
|
||||
{}).items()]
|
||||
|
||||
return {
|
||||
**ChatRecordSerializerModel(chat_record).data,
|
||||
'padding_problem_text': chat_record.details.get(
|
||||
'padding_problem_text') if 'problem_padding' in chat_record.details else None,
|
||||
'dataset_list': dataset_list,
|
||||
'paragraph_list': paragraph_list}
|
||||
|
||||
class Query(serializers.Serializer):
|
||||
application_id = serializers.UUIDField(required=True)
|
||||
chat_id = serializers.UUIDField(required=True)
|
||||
|
|
@ -183,15 +211,57 @@ class ChatRecordSerializer(serializers.Serializer):
|
|||
def list(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))
|
||||
return [ChatRecordSerializerModel(chat_record).data for chat_record in
|
||||
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))]
|
||||
|
||||
def reset_chat_record_list(self, chat_record_list: List[ChatRecord]):
|
||||
paragraph_id_list = flat_map([chat_record.paragraph_id_list for chat_record in chat_record_list])
|
||||
# 去重
|
||||
paragraph_id_list = list(set(paragraph_id_list))
|
||||
paragraph_list = self.search_paragraph(paragraph_id_list)
|
||||
return [self.reset_chat_record(chat_record, paragraph_list) for chat_record in chat_record_list]
|
||||
|
||||
@staticmethod
|
||||
def search_paragraph(paragraph_id_list: List[str]):
|
||||
paragraph_list = []
|
||||
if len(paragraph_id_list) > 0:
|
||||
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
|
||||
get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
|
||||
'list_dataset_paragraph_by_paragraph_id.sql')),
|
||||
with_table_name=True)
|
||||
return paragraph_list
|
||||
|
||||
@staticmethod
|
||||
def reset_chat_record(chat_record, all_paragraph_list):
|
||||
paragraph_list = list(
|
||||
filter(lambda paragraph: chat_record.paragraph_id_list.__contains__(str(paragraph.get('id'))),
|
||||
all_paragraph_list))
|
||||
dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
|
||||
[{row.get(
|
||||
'dataset_id'): row.get(
|
||||
"dataset_name")} for
|
||||
row in
|
||||
paragraph_list],
|
||||
{}).items()]
|
||||
return {
|
||||
**ChatRecordSerializerModel(chat_record).data,
|
||||
'padding_problem_text': chat_record.details.get('problem_padding').get(
|
||||
'padding_problem_text') if 'problem_padding' in chat_record.details else None,
|
||||
'dataset_list': dataset_list,
|
||||
'paragraph_list': paragraph_list
|
||||
}
|
||||
|
||||
def page(self, current_page: int, page_size: int, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
return page_search(current_page, page_size,
|
||||
page = page_search(current_page, page_size,
|
||||
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"),
|
||||
post_records_handler=lambda chat_record: ChatRecordSerializerModel(chat_record).data)
|
||||
post_records_handler=lambda chat_record: chat_record)
|
||||
records = page.get('records')
|
||||
page['records'] = self.reset_chat_record_list(records)
|
||||
return page
|
||||
|
||||
class Vote(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True)
|
||||
|
|
@ -216,38 +286,20 @@ class ChatRecordSerializer(serializers.Serializer):
|
|||
if vote_status == VoteChoices.STAR:
|
||||
# 点赞
|
||||
chat_record_details_model.vote_status = VoteChoices.STAR
|
||||
# 点赞数量 +1
|
||||
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
|
||||
'star_num',
|
||||
lambda r: r.star_num + 1)
|
||||
|
||||
if vote_status == VoteChoices.TRAMPLE:
|
||||
# 点踩
|
||||
chat_record_details_model.vote_status = VoteChoices.TRAMPLE
|
||||
# 点踩数量+1
|
||||
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
|
||||
'trample_num',
|
||||
lambda r: r.trample_num + 1)
|
||||
chat_record_details_model.save()
|
||||
else:
|
||||
if vote_status == VoteChoices.UN_VOTE:
|
||||
# 取消点赞
|
||||
chat_record_details_model.vote_status = VoteChoices.UN_VOTE
|
||||
chat_record_details_model.save()
|
||||
if chat_record_details_model.vote_status == VoteChoices.STAR:
|
||||
# 点赞数量 -1
|
||||
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
|
||||
'star_num', lambda r: r.star_num - 1)
|
||||
if chat_record_details_model.vote_status == VoteChoices.TRAMPLE:
|
||||
# 点踩数量 -1
|
||||
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
|
||||
'trample_num', lambda r: r.trample_num - 1)
|
||||
|
||||
else:
|
||||
raise AppApiException(500, "已经投票过,请先取消后再进行投票")
|
||||
finally:
|
||||
un_lock(self.data.get('chat_record_id'))
|
||||
|
||||
return True
|
||||
|
||||
class ImproveSerializer(serializers.Serializer):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
SELECT * FROM ( SELECT * FROM application ${application_custom_sql} UNION
|
||||
SELECT *,to_json(dataset_setting) as dataset_setting,to_json(model_setting) as model_setting FROM ( SELECT * FROM application ${application_custom_sql} UNION
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
SELECT
|
||||
paragraph.*,
|
||||
dataset."name" AS "dataset_name"
|
||||
FROM
|
||||
paragraph paragraph
|
||||
LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id
|
||||
|
|
@ -58,6 +58,7 @@ class ApplicationApi(ApiMixin):
|
|||
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'),
|
||||
|
||||
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间'),
|
||||
|
||||
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||
items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||
title="关联知识库Id列表",
|
||||
|
|
@ -133,7 +134,7 @@ class ApplicationApi(ApiMixin):
|
|||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'],
|
||||
required=[],
|
||||
properties={
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
|
||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
|
||||
|
|
@ -141,11 +142,52 @@ class ApplicationApi(ApiMixin):
|
|||
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
|
||||
description="是否开启多轮对话"),
|
||||
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
|
||||
'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||
title="示例列表", description="示例列表"),
|
||||
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||
items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||
title="关联知识库Id列表", description="关联知识库Id列表"),
|
||||
'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
|
||||
'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
|
||||
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
|
||||
description="是否开启问题优化", default=True)
|
||||
|
||||
}
|
||||
)
|
||||
|
||||
class DatasetSetting(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=[''],
|
||||
properties={
|
||||
'top_n': openapi.Schema(type=openapi.TYPE_NUMBER, title="引用分段数", description="引用分段数",
|
||||
default=5),
|
||||
'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title='相似度', description="相似度",
|
||||
default=0.6),
|
||||
'max_paragraph_char_number': openapi.Schema(type=openapi.TYPE_NUMBER, title='最多引用字符数',
|
||||
description="最多引用字符数", default=3000),
|
||||
}
|
||||
)
|
||||
|
||||
class ModelSetting(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['prompt'],
|
||||
properties={
|
||||
'prompt': openapi.Schema(type=openapi.TYPE_STRING, title="提示词", description="提示词",
|
||||
default=('已知信息:'
|
||||
'\n{data}'
|
||||
'\n回答要求:'
|
||||
'\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。'
|
||||
'\n- 避免提及你是从<data></data>中获得的知识。'
|
||||
'\n- 请保持答案与<data></data>中描述的一致。'
|
||||
'\n- 请使用markdown 语法优化答案的格式。'
|
||||
'\n- <data></data>中的图片链接、链接地址和脚本语言请完整返回。'
|
||||
'\n- 请使用与问题相同的语言来回答。'
|
||||
'\n问题:'
|
||||
'\n{question}')),
|
||||
|
||||
}
|
||||
)
|
||||
|
|
@ -155,7 +197,8 @@ class ApplicationApi(ApiMixin):
|
|||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'],
|
||||
required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting',
|
||||
'problem_optimization'],
|
||||
properties={
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
|
||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
|
||||
|
|
@ -163,11 +206,13 @@ class ApplicationApi(ApiMixin):
|
|||
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
|
||||
description="是否开启多轮对话"),
|
||||
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
|
||||
'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||
title="示例列表", description="示例列表"),
|
||||
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||
items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||
title="关联知识库Id列表", description="关联知识库Id列表")
|
||||
title="关联知识库Id列表", description="关联知识库Id列表"),
|
||||
'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
|
||||
'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
|
||||
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
|
||||
description="是否开启问题优化", default=True)
|
||||
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
"""
|
||||
from drf_yasg import openapi
|
||||
|
||||
from application.swagger_api.application_api import ApplicationApi
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
|
||||
|
||||
|
|
@ -19,6 +20,7 @@ class ChatApi(ApiMixin):
|
|||
required=['message'],
|
||||
properties={
|
||||
'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
|
||||
're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default="重新生成")
|
||||
|
||||
}
|
||||
)
|
||||
|
|
@ -73,14 +75,19 @@ class ChatApi(ApiMixin):
|
|||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
required=['model_id', 'multiple_rounds_dialogue'],
|
||||
required=['model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting',
|
||||
'problem_optimization'],
|
||||
properties={
|
||||
'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
|
||||
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||
items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||
title="关联知识库Id列表", description="关联知识库Id列表"),
|
||||
'multiple_rounds_dialogue': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮会话",
|
||||
description="是否开启多轮会话")
|
||||
description="是否开启多轮会话"),
|
||||
'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
|
||||
'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
|
||||
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
|
||||
description="是否开启问题优化", default=True)
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ urlpatterns = [
|
|||
path('application/<str:application_id>/chat/<chat_id>/chat_record/', views.ChatView.ChatRecord.as_view()),
|
||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/<int:current_page>/<int:page_size>',
|
||||
views.ChatView.ChatRecord.Page.as_view()),
|
||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/<chat_record_id>',
|
||||
views.ChatView.ChatRecord.Operate.as_view()),
|
||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/vote',
|
||||
views.ChatView.ChatRecord.Vote.as_view(),
|
||||
name=''),
|
||||
|
|
|
|||
|
|
@ -71,7 +71,8 @@ class ChatView(APIView):
|
|||
dynamic_tag=keywords.get('application_id'))])
|
||||
)
|
||||
def post(self, request: Request, chat_id: str):
|
||||
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'))
|
||||
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'), request.data.get(
|
||||
're_chat') if 're_chat' in request.data else False)
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取对话列表",
|
||||
|
|
@ -134,6 +135,27 @@ class ChatView(APIView):
|
|||
class ChatRecord(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取对话记录详情",
|
||||
operation_id="获取对话记录详情",
|
||||
manual_parameters=ChatRecordApi.get_request_params_api(),
|
||||
responses=result.get_api_array_response(ChatRecordApi.get_response_body_api()),
|
||||
tags=["应用/对话日志"]
|
||||
)
|
||||
@has_permissions(
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))])
|
||||
)
|
||||
def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str):
|
||||
return result.success(ChatRecordSerializer.Operate(
|
||||
data={'application_id': application_id,
|
||||
'chat_id': chat_id,
|
||||
'chat_record_id': chat_record_id}).one())
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取对话记录列表",
|
||||
operation_id="获取对话记录列表",
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, s
|
|||
field_replace_dict = get_field_replace_dict(queryset)
|
||||
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
|
||||
field_replace_dict=field_replace_dict)
|
||||
sql, params = app_sql_compiler.get_query_str(with_table_name)
|
||||
sql, params = app_sql_compiler.get_query_str(with_table_name=with_table_name)
|
||||
return sql, params
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,10 +7,8 @@
|
|||
@desc:
|
||||
"""
|
||||
from .listener_manage import *
|
||||
from .listener_chat_message import *
|
||||
|
||||
|
||||
def run():
|
||||
listener_manage.ListenerManagement().run()
|
||||
listener_chat_message.ListenerChatMessage().run()
|
||||
QuerySet(Document).filter(status=Status.embedding).update(**{'status': Status.error})
|
||||
|
|
|
|||
|
|
@ -1,67 +0,0 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: listener_manage.py
|
||||
@date:2023/10/20 14:01
|
||||
@desc:
|
||||
"""
|
||||
import logging
|
||||
|
||||
from blinker import signal
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from application.models import ChatRecord, Chat
|
||||
from application.serializers.chat_message_serializers import ChatMessage
|
||||
from common.event.common import poxy
|
||||
|
||||
|
||||
class RecordChatMessageArgs:
|
||||
def __init__(self, index: int, chat_id: str, application_id: str, chat_message: ChatMessage):
|
||||
self.index = index
|
||||
self.chat_id = chat_id
|
||||
self.application_id = application_id
|
||||
self.chat_message = chat_message
|
||||
|
||||
|
||||
class ListenerChatMessage:
|
||||
record_chat_message_signal = signal("record_chat_message")
|
||||
update_chat_message_token_signal = signal("update_chat_message_token")
|
||||
|
||||
@staticmethod
|
||||
def record_chat_message(args: RecordChatMessageArgs):
|
||||
if not QuerySet(Chat).filter(id=args.chat_id).exists():
|
||||
Chat(id=args.chat_id, application_id=args.application_id, abstract=args.chat_message.problem).save()
|
||||
# 插入会话记录
|
||||
try:
|
||||
chat_record = ChatRecord(
|
||||
id=args.chat_message.id,
|
||||
chat_id=args.chat_id,
|
||||
dataset_id=args.chat_message.dataset_id,
|
||||
paragraph_id=args.chat_message.paragraph_id,
|
||||
source_id=args.chat_message.source_id,
|
||||
source_type=args.chat_message.source_type,
|
||||
problem_text=args.chat_message.problem,
|
||||
answer_text=args.chat_message.answer,
|
||||
index=args.index,
|
||||
message_tokens=args.chat_message.message_tokens,
|
||||
answer_tokens=args.chat_message.answer_token)
|
||||
chat_record.save()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
@staticmethod
|
||||
@poxy
|
||||
def update_token(chat_message: ChatMessage):
|
||||
if chat_message.chat_model is not None:
|
||||
logging.getLogger("max_kb").info("开始更新token")
|
||||
message_token = chat_message.chat_model.get_num_tokens_from_messages(chat_message.chat_message)
|
||||
answer_token = chat_message.chat_model.get_num_tokens(chat_message.answer)
|
||||
# 修改token数量
|
||||
QuerySet(ChatRecord).filter(id=chat_message.id).update(
|
||||
**{'message_tokens': message_token, 'answer_tokens': answer_token})
|
||||
|
||||
def run(self):
|
||||
# 记录会话
|
||||
ListenerChatMessage.record_chat_message_signal.connect(self.record_chat_message)
|
||||
ListenerChatMessage.update_chat_message_token_signal.connect(self.update_token)
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: common.py
|
||||
@date:2024/1/11 18:44
|
||||
@desc:
|
||||
"""
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
class InstanceField(serializers.Field):
|
||||
def __init__(self, model_type, **kwargs):
|
||||
self.model_type = model_type
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
if not isinstance(data, self.model_type):
|
||||
self.fail('message类型错误', value=data)
|
||||
return data
|
||||
|
||||
def to_representation(self, value):
|
||||
return value
|
||||
|
||||
|
||||
class FunctionField(serializers.Field):
|
||||
|
||||
def to_internal_value(self, data):
|
||||
if not callable(data):
|
||||
self.fail('不是一個函數', value=data)
|
||||
return data
|
||||
|
||||
def to_representation(self, value):
|
||||
return value
|
||||
|
|
@ -5,9 +5,7 @@ SELECT
|
|||
problem.dataset_id AS dataset_id,
|
||||
0 AS source_type,
|
||||
problem."content" AS "text",
|
||||
paragraph.is_active AS is_active,
|
||||
problem.star_num as star_num,
|
||||
problem.trample_num as trample_num
|
||||
paragraph.is_active AS is_active
|
||||
FROM
|
||||
problem problem
|
||||
LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
|
||||
|
|
@ -23,9 +21,7 @@ SELECT
|
|||
concat_ws('
|
||||
',concat_ws('
|
||||
',paragraph.title,paragraph."content"),paragraph.title) AS "text",
|
||||
paragraph.is_active AS is_active,
|
||||
paragraph.star_num as star_num,
|
||||
paragraph.trample_num as trample_num
|
||||
paragraph.is_active AS is_active
|
||||
FROM
|
||||
paragraph paragraph
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,37 @@
|
|||
# Generated by Django 4.1.10 on 2024-01-16 11:22
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('dataset', '0003_alter_paragraph_content'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RemoveField(
|
||||
model_name='paragraph',
|
||||
name='hit_num',
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name='paragraph',
|
||||
name='star_num',
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name='paragraph',
|
||||
name='trample_num',
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name='problem',
|
||||
name='hit_num',
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name='problem',
|
||||
name='star_num',
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name='problem',
|
||||
name='trample_num',
|
||||
),
|
||||
]
|
||||
|
|
@ -74,9 +74,6 @@ class Paragraph(AppModelMixin):
|
|||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
|
||||
content = models.CharField(max_length=4096, verbose_name="段落内容")
|
||||
title = models.CharField(max_length=256, verbose_name="标题", default="")
|
||||
hit_num = models.IntegerField(verbose_name="命中数量", default=0)
|
||||
star_num = models.IntegerField(verbose_name="点赞数", default=0)
|
||||
trample_num = models.IntegerField(verbose_name="点踩数", default=0)
|
||||
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
|
||||
default=Status.embedding)
|
||||
is_active = models.BooleanField(default=True)
|
||||
|
|
@ -94,9 +91,6 @@ class Problem(AppModelMixin):
|
|||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||
content = models.CharField(max_length=256, verbose_name="问题内容")
|
||||
hit_num = models.IntegerField(verbose_name="命中数量", default=0)
|
||||
star_num = models.IntegerField(verbose_name="点赞数", default=0)
|
||||
trample_num = models.IntegerField(verbose_name="点踩数", default=0)
|
||||
|
||||
class Meta:
|
||||
db_table = "problem"
|
||||
|
|
|
|||
|
|
@ -26,12 +26,11 @@ from application.models import ApplicationDatasetMapping
|
|||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||
from common.db.search import get_dynamics_model, native_page_search, native_search
|
||||
from common.db.sql_execute import select_list
|
||||
from common.event.listener_manage import ListenerManagement, SyncWebDatasetArgs
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
from common.util.common import post
|
||||
from common.util.file_util import get_file_content
|
||||
from common.util.fork import ChildLink, Fork, ForkManage
|
||||
from common.util.fork import ChildLink, Fork
|
||||
from common.util.split_model import get_split_model
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
|
||||
from dataset.serializers.common_serializers import list_paragraph
|
||||
|
|
@ -286,7 +285,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
properties={
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
|
||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
|
||||
'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url", description="web站点url"),
|
||||
'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url",
|
||||
description="web站点url"),
|
||||
'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
|
||||
}
|
||||
)
|
||||
|
|
@ -369,7 +369,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
dataset_id = uuid.uuid1()
|
||||
dataset = DataSet(
|
||||
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
|
||||
'type': Type.web, 'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}})
|
||||
'type': Type.web,
|
||||
'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}})
|
||||
dataset.save()
|
||||
ListenerManagement.sync_web_dataset_signal.send(
|
||||
SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'),
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from dataset.serializers.problem_serializers import ProblemInstanceSerializer, P
|
|||
class ParagraphSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Paragraph
|
||||
fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title',
|
||||
fields = ['id', 'content', 'is_active', 'document_id', 'title',
|
||||
'create_time', 'update_time']
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from embedding.vector.pg_vector import PGVector
|
|||
class ProblemSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Problem
|
||||
fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id', 'document_id',
|
||||
fields = ['id', 'content', 'dataset_id', 'document_id',
|
||||
'create_time', 'update_time']
|
||||
|
||||
|
||||
|
|
@ -77,8 +77,8 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
|
|||
'document_id': self.data.get('document_id'),
|
||||
'paragraph_id': self.data.get('paragraph_id'),
|
||||
'dataset_id': self.data.get('dataset_id'),
|
||||
'star_num': 0,
|
||||
'trample_num': 0})
|
||||
|
||||
})
|
||||
|
||||
return ProblemSerializers.Operate(
|
||||
data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,37 @@
|
|||
# Generated by Django 4.1.10 on 2024-01-16 11:22
|
||||
|
||||
import django.contrib.postgres.fields
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('embedding', '0001_initial'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RemoveField(
|
||||
model_name='embedding',
|
||||
name='star_num',
|
||||
),
|
||||
migrations.RemoveField(
|
||||
model_name='embedding',
|
||||
name='trample_num',
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='embedding',
|
||||
name='keywords',
|
||||
field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=256), default=list, size=None, verbose_name='关键词列表'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='embedding',
|
||||
name='meta',
|
||||
field=models.JSONField(default=dict, verbose_name='元数据'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='embedding',
|
||||
name='source_type',
|
||||
field=models.CharField(choices=[('0', '问题'), ('1', '段落'), ('2', '标题')], default='0', max_length=5, verbose_name='资源类型'),
|
||||
),
|
||||
]
|
||||
|
|
@ -6,6 +6,7 @@
|
|||
@date:2023/9/21 15:46
|
||||
@desc:
|
||||
"""
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.db import models
|
||||
|
||||
from common.field.vector_field import VectorField
|
||||
|
|
@ -16,6 +17,7 @@ class SourceType(models.TextChoices):
|
|||
"""订单类型"""
|
||||
PROBLEM = 0, '问题'
|
||||
PARAGRAPH = 1, '段落'
|
||||
TITLE = 2, '标题'
|
||||
|
||||
|
||||
class Embedding(models.Model):
|
||||
|
|
@ -36,10 +38,10 @@ class Embedding(models.Model):
|
|||
|
||||
embedding = VectorField(verbose_name="向量")
|
||||
|
||||
star_num = models.IntegerField(default=0, verbose_name="点赞数量")
|
||||
keywords = ArrayField(verbose_name="关键词列表",
|
||||
base_field=models.CharField(max_length=256), default=list)
|
||||
|
||||
trample_num = models.IntegerField(default=0,
|
||||
verbose_name="点踩数量")
|
||||
meta = models.JSONField(verbose_name="元数据", default=dict)
|
||||
|
||||
class Meta:
|
||||
db_table = "embedding"
|
||||
|
|
|
|||
|
|
@ -1,15 +1,17 @@
|
|||
SELECT * FROM (SELECT
|
||||
*,
|
||||
( 1 - ( embedding.embedding <=> %s ) ) AS similarity,
|
||||
CASE
|
||||
|
||||
WHEN embedding.star_num - embedding.trample_num = 0 THEN
|
||||
0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
|
||||
END AS score
|
||||
SELECT
|
||||
paragraph_id,
|
||||
comprehensive_score,
|
||||
comprehensive_score as similarity
|
||||
FROM
|
||||
embedding,
|
||||
( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs
|
||||
${embedding_query}
|
||||
) temp
|
||||
WHERE similarity>0.5
|
||||
ORDER BY (similarity + score) DESC LIMIT 1
|
||||
(
|
||||
SELECT DISTINCT ON
|
||||
("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score
|
||||
FROM
|
||||
( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query}) TEMP
|
||||
ORDER BY
|
||||
paragraph_id,
|
||||
similarity DESC
|
||||
) DISTINCT_TEMP
|
||||
WHERE comprehensive_score>%s
|
||||
ORDER BY comprehensive_score DESC
|
||||
LIMIT %s
|
||||
|
|
@ -1,34 +1,17 @@
|
|||
SELECT
|
||||
similarity,
|
||||
paragraph_id,
|
||||
comprehensive_score
|
||||
comprehensive_score,
|
||||
comprehensive_score as similarity
|
||||
FROM
|
||||
(
|
||||
SELECT DISTINCT ON
|
||||
( "paragraph_id" ) ( similarity + score ),*,
|
||||
( similarity + score ) AS comprehensive_score
|
||||
("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
*,
|
||||
( 1 - ( embedding.embedding <=> %s ) ) AS similarity,
|
||||
CASE
|
||||
|
||||
WHEN embedding.star_num - embedding.trample_num = 0 THEN
|
||||
0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
|
||||
END AS score
|
||||
FROM
|
||||
embedding,
|
||||
( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs
|
||||
${embedding_query}
|
||||
) TEMP
|
||||
WHERE
|
||||
similarity > %s
|
||||
( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query} ) TEMP
|
||||
ORDER BY
|
||||
paragraph_id,
|
||||
( similarity + score )
|
||||
DESC
|
||||
) ss
|
||||
ORDER BY
|
||||
comprehensive_score DESC
|
||||
LIMIT %s
|
||||
similarity DESC
|
||||
) DISTINCT_TEMP
|
||||
WHERE comprehensive_score>%s
|
||||
ORDER BY comprehensive_score DESC
|
||||
LIMIT %s
|
||||
|
|
@ -99,8 +99,6 @@ class BaseVectorStore(ABC):
|
|||
@abstractmethod
|
||||
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||
is_active: bool,
|
||||
star_num: int,
|
||||
trample_num: int,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
pass
|
||||
|
||||
|
|
@ -108,11 +106,20 @@ class BaseVectorStore(ABC):
|
|||
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_id_list: list[str],
|
||||
exclude_paragraph_list: list[str],
|
||||
is_active: bool,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
if dataset_id_list is None or len(dataset_id_list) == 0:
|
||||
return []
|
||||
embedding_query = embedding.embed_query(query_text)
|
||||
result = self.query(embedding_query, dataset_id_list, exclude_document_id_list, exclude_paragraph_list,
|
||||
is_active, 1, 0.65)
|
||||
return result[0]
|
||||
|
||||
@abstractmethod
|
||||
def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -33,8 +33,6 @@ class PGVector(BaseVectorStore):
|
|||
|
||||
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||
is_active: bool,
|
||||
star_num: int,
|
||||
trample_num: int,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
text_embedding = embedding.embed_query(text)
|
||||
embedding = Embedding(id=uuid.uuid1(),
|
||||
|
|
@ -44,10 +42,7 @@ class PGVector(BaseVectorStore):
|
|||
paragraph_id=paragraph_id,
|
||||
source_id=source_id,
|
||||
embedding=text_embedding,
|
||||
source_type=source_type,
|
||||
star_num=star_num,
|
||||
trample_num=trample_num
|
||||
)
|
||||
source_type=source_type)
|
||||
embedding.save()
|
||||
return True
|
||||
|
||||
|
|
@ -61,8 +56,6 @@ class PGVector(BaseVectorStore):
|
|||
is_active=text_list[index].get('is_active', True),
|
||||
source_id=text_list[index].get('source_id'),
|
||||
source_type=text_list[index].get('source_type'),
|
||||
star_num=text_list[index].get('star_num'),
|
||||
trample_num=text_list[index].get('trample_num'),
|
||||
embedding=embeddings[index]) for index in
|
||||
range(0, len(text_list))]
|
||||
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
|
||||
|
|
@ -78,29 +71,27 @@ class PGVector(BaseVectorStore):
|
|||
'hit_test.sql')),
|
||||
with_table_name=True)
|
||||
embedding_model = select_list(exec_sql,
|
||||
[json.dumps(embedding_query), *exec_params, *exec_params, similarity, top_number])
|
||||
[json.dumps(embedding_query), *exec_params, similarity, top_number])
|
||||
return embedding_model
|
||||
|
||||
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_id_list: list[str],
|
||||
is_active: bool,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
|
||||
exclude_dict = {}
|
||||
if dataset_id_list is None or len(dataset_id_list) == 0:
|
||||
return None
|
||||
return []
|
||||
query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active)
|
||||
embedding_query = embedding.embed_query(query_text)
|
||||
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
|
||||
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
|
||||
if exclude_id_list is not None and len(exclude_id_list) > 0:
|
||||
exclude_dict.__setitem__('id__in', exclude_id_list)
|
||||
if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
|
||||
exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list)
|
||||
query_set = query_set.exclude(**exclude_dict)
|
||||
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
|
||||
select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
|
||||
'embedding_search.sql')),
|
||||
with_table_name=True)
|
||||
embedding_model = select_one(exec_sql, (json.dumps(embedding_query), *exec_params, *exec_params))
|
||||
embedding_model = select_list(exec_sql,
|
||||
[json.dumps(query_embedding), *exec_params, similarity, top_n])
|
||||
return embedding_model
|
||||
|
||||
def update_by_source_id(self, source_id: str, instance: Dict):
|
||||
|
|
|
|||
|
|
@ -14,13 +14,23 @@ from langchain.chat_models.base import BaseChatModel
|
|||
from langchain.load import dumpd
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage
|
||||
from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string
|
||||
from langchain.schema.output import ChatGenerationChunk
|
||||
from langchain.schema.runnable import RunnableConfig
|
||||
from transformers import GPT2TokenizerFast
|
||||
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
|
||||
force_download=False)
|
||||
|
||||
|
||||
class QianfanChatModel(QianfanChatEndpoint):
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
return len(tokenizer.encode(text))
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
|
|
@ -30,7 +40,7 @@ class QianfanChatModel(QianfanChatEndpoint):
|
|||
**kwargs: Any,
|
||||
) -> Iterator[BaseMessageChunk]:
|
||||
if len(input) % 2 == 0:
|
||||
input = [HumanMessage(content='占位'), *input]
|
||||
input = [HumanMessage(content='padding'), *input]
|
||||
input = [
|
||||
HumanMessage(content=input[index].content) if index % 2 == 0 else AIMessage(content=input[index].content)
|
||||
for index in range(0, len(input))]
|
||||
|
|
|
|||
Loading…
Reference in New Issue