MaxKB/apps/application/serializers/chat_message_serializers.py
wxg0103 5b969ef861
Some checks failed
sync2gitee / repo-sync (push) Has been cancelled
Typos Check / Spell Check with Typos (push) Has been cancelled
feat: add speech_to_text node and text_to_speech node
2024-12-13 15:02:06 +08:00

430 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
@project: maxkb
@Author
@file chat_message_serializers.py
@date2023/11/14 13:51
@desc:
"""
from datetime import datetime
import uuid
from typing import List, Dict
from uuid import UUID
from django.core.cache import caches
from django.db.models import QuerySet
from rest_framework import serializers
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep
from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \
BaseGenerateHumanMessageStep
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
from application.flow.i_step_node import WorkFlowPostHandler
from application.flow.workflow_manage import WorkflowManage, Flow
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping, ApplicationTypeChoices, \
WorkFlowVersion
from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken
from common.constants.authentication_type import AuthenticationType
from common.exception.app_exception import AppChatNumOutOfBoundsFailed, ChatException
from common.handle.base_to_response import BaseToResponse
from common.handle.impl.response.openai_to_response import OpenaiToResponse
from common.handle.impl.response.system_to_response import SystemToResponse
from common.util.field_message import ErrMessage
from common.util.split_model import flat_map
from dataset.models import Paragraph, Document
from setting.models import Model, Status
from setting.models_provider import get_model_credential
# Django cache (alias 'chat_cache') that stores ChatInfo objects keyed by chat id;
# the code below writes entries with a 30-minute timeout (60 * 30 seconds).
chat_cache = caches['chat_cache']
class ChatInfo:
    """Per-conversation state that is stored in ``chat_cache`` between messages.

    Holds the owning application, the datasets / excluded documents used by
    simple applications, the recent chat records used as conversation history
    and, for workflow applications, the published workflow version.
    """

    def __init__(self,
                 chat_id: str,
                 dataset_id_list: List[str],
                 exclude_document_id_list: list[str],
                 application: Application,
                 work_flow_version: WorkFlowVersion = None):
        """
        :param chat_id:                  conversation id
        :param dataset_id_list:          ids of the datasets to search
        :param exclude_document_id_list: ids of documents excluded from search
        :param application:              application this conversation belongs to
        :param work_flow_version:        published workflow version (workflow applications only)
        """
        self.chat_id = chat_id
        self.application = application
        self.dataset_id_list = dataset_id_list
        self.exclude_document_id_list = exclude_document_id_list
        self.chat_record_list: List[ChatRecord] = []
        self.work_flow_version = work_flow_version

    @staticmethod
    def get_no_references_setting(dataset_setting, model_setting):
        """Return the behaviour to use when the dataset search finds no references.

        Returns a shallow copy of ``dataset_setting['no_references_setting']``
        (or an ``ai_questioning`` default when the key is absent), so the
        application's stored settings are never modified. In ``ai_questioning``
        mode, ``value`` is replaced with the model's ``no_references_prompt``,
        falling back to ``'{question}'`` when that prompt is missing, None or empty.
        """
        stored_setting = dataset_setting.get('no_references_setting')
        if stored_setting is None:
            stored_setting = {
                'status': 'ai_questioning',
                'value': '{question}'}
        # Copy so that writing 'value' below does not leak into application.dataset_setting.
        no_references_setting = dict(stored_setting)
        if no_references_setting.get('status') == 'ai_questioning':
            no_references_prompt = model_setting.get('no_references_prompt', '{question}')
            no_references_setting['value'] = no_references_prompt or "{question}"
        return no_references_setting

    def to_base_pipeline_manage_params(self):
        """Build the request-independent parameters for the simple-application pipeline.

        Missing dataset settings fall back to top_n=3, similarity=0.6,
        max_paragraph_char_number=5000 and search_mode='embedding'. Empty or
        None prompts fall back to the built-in prompts. The application's own
        model_params_setting wins over the model credential's default form data
        unless it is None or empty.
        """
        dataset_setting = self.application.dataset_setting
        model_setting = self.application.model_setting
        model_id = self.application.model.id if self.application.model is not None else None
        model_params_setting = None
        if model_id is not None:
            model = QuerySet(Model).filter(id=model_id).first()
            if model is not None:
                credential = get_model_credential(model.provider, model.model_type, model.model_name)
                model_params_setting = credential.get_model_params_setting_form(
                    model.model_name).get_default_form_data()
        return {
            'dataset_id_list': self.dataset_id_list,
            'exclude_document_id_list': self.exclude_document_id_list,
            'exclude_paragraph_id_list': [],
            'top_n': dataset_setting.get('top_n', 3),
            'similarity': dataset_setting.get('similarity', 0.6),
            'max_paragraph_char_number': dataset_setting.get('max_paragraph_char_number', 5000),
            'history_chat_record': self.chat_record_list,
            'chat_id': self.chat_id,
            'dialogue_number': self.application.dialogue_number,
            'problem_optimization_prompt': self.application.problem_optimization_prompt or '()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中',
            'prompt': model_setting.get('prompt') or Application.get_default_model_prompt(),
            'system': model_setting.get('system', None),
            'model_id': model_id,
            'problem_optimization': self.application.problem_optimization,
            'stream': True,
            # An empty dict on the application means "not configured": use the credential defaults.
            'model_params_setting': self.application.model_params_setting or model_params_setting,
            'search_mode': dataset_setting.get('search_mode', 'embedding'),
            'no_references_setting': self.get_no_references_setting(dataset_setting, model_setting),
            'user_id': self.application.user_id
        }

    def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
                                  exclude_paragraph_id_list, client_id: str, client_type, stream=True):
        """Return the base pipeline params overlaid with this request's values.

        :param problem_text:              the user's question
        :param post_response_handler:     hook invoked after the answer is produced
        :param exclude_paragraph_id_list: paragraph ids the search step must skip
        :param client_id:                 id of the calling client
        :param client_type:               type of the calling client
        :param stream:                    whether the answer is streamed
        """
        params = self.to_base_pipeline_manage_params()
        return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
                'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'client_id': client_id,
                'client_type': client_type}

    def append_chat_record(self, chat_record: ChatRecord, client_id=None):
        """Add ``chat_record`` to the cached history and, if possible, persist it.

        Problem text is truncated to 10240 characters and answer text to 40960;
        None becomes "". A cached record with the same id is replaced, otherwise
        the record is appended. When the application has an id, the parent Chat
        row is created (or its update_time refreshed) and the record is saved.

        :param chat_record: record to add
        :param client_id:   client id stored on a newly created Chat row
        """
        chat_record.problem_text = chat_record.problem_text[0:10240] if chat_record.problem_text is not None else ""
        # Guard on answer_text itself: it can be None independently of problem_text.
        chat_record.answer_text = chat_record.answer_text[0:40960] if chat_record.answer_text is not None else ""
        is_new_record = True
        # Update the in-memory (cached) history.
        for index, record in enumerate(self.chat_record_list):
            if record.id == chat_record.id:
                self.chat_record_list[index] = chat_record
                is_new_record = False
        if is_new_record:
            self.chat_record_list.append(chat_record)
        if self.application.id is not None:
            # Persist the conversation header first, then the record itself.
            if not QuerySet(Chat).filter(id=self.chat_id).exists():
                Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text[0:1024],
                     client_id=client_id, update_time=datetime.now()).save()
            else:
                Chat.objects.filter(id=self.chat_id).update(update_time=datetime.now())
            chat_record.save()
def get_post_handler(chat_info: ChatInfo):
    """Create the post-response hook for one simple-application pipeline run.

    The returned handler turns the finished run into a ``ChatRecord``, hands it
    to ``chat_info.append_chat_record`` (which persists it when the application
    has an id) and re-caches ``chat_info`` for 30 minutes.

    :param chat_info: conversation state the new record is attached to
    :return: a ``PostResponseHandler`` instance
    """

    class PostHandler(PostResponseHandler):
        def handler(self,
                    chat_id: UUID,
                    chat_record_id,
                    paragraph_list: List[Paragraph],
                    problem_text: str,
                    answer_text,
                    manage: PipelineManage,
                    step: BaseChatStep,
                    padding_problem_text: str = None,
                    client_id=None,
                    **kwargs):
            run_details = manage.get_details()
            run_context = manage.context
            new_record = ChatRecord(id=chat_record_id,
                                    chat_id=chat_id,
                                    problem_text=problem_text,
                                    answer_text=answer_text,
                                    details=run_details,
                                    message_tokens=run_context['message_tokens'],
                                    answer_tokens=run_context['answer_tokens'],
                                    answer_text_list=[answer_text],
                                    run_time=run_context['run_time'],
                                    # 1-based position after the records already in history.
                                    index=len(chat_info.chat_record_list) + 1)
            chat_info.append_chat_record(new_record, client_id)
            # Refresh the cache entry so the updated history is kept for another 30 minutes.
            chat_cache.set(chat_id, chat_info, timeout=60 * 30)

    return PostHandler()
class OpenAIMessage(serializers.Serializer):
    """One OpenAI-style chat message, i.e. ``{"role": ..., "content": ...}``."""

    # Text of the message; required.
    content = serializers.CharField(
        required=True,
        error_messages=ErrMessage.char('内容'),
    )
    # Role of the message author; required, but its value is not restricted here.
    role = serializers.CharField(
        required=True,
        error_messages=ErrMessage.char('角色'),
    )
class OpenAIInstanceSerializer(serializers.Serializer):
    """Validates the body of an OpenAI-compatible chat request."""
    # The conversation; OpenAIChatSerializer.get_message reads the last entry,
    # so an empty list is rejected here instead of failing later with IndexError.
    messages = serializers.ListField(child=OpenAIMessage(), allow_empty=False)
    # UUID field, so use the UUID-specific validation messages.
    chat_id = serializers.UUIDField(required=False, error_messages=ErrMessage.uuid("对话id"))
    re_chat = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("重新生成"))
    stream = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("流式输出"))
class OpenAIChatSerializer(serializers.Serializer):
    """Adapts an OpenAI-compatible chat request to ``ChatMessageSerializer``."""
    application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
    client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
    client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))

    @staticmethod
    def get_message(instance):
        """Return the ``content`` of the last entry in ``instance['messages']``."""
        return instance.get('messages')[-1].get('content')

    @staticmethod
    def generate_chat(chat_id, application_id, message, client_id):
        """Make sure a Chat row exists for the conversation and return its id.

        :param chat_id:        existing conversation id, or None to create a new one (uuid1)
        :param application_id: owning application id
        :param message:        the user's question; its first 1024 chars become the abstract
        :param client_id:      id of the calling client
        :return: the (possibly newly generated) chat id
        """
        if chat_id is None:
            chat_id = str(uuid.uuid1())
        # Only an existence check is needed; avoid materialising the row.
        if not QuerySet(Chat).filter(id=chat_id).exists():
            Chat(id=chat_id, application_id=application_id, abstract=message[0:1024], client_id=client_id).save()
        return chat_id

    def chat(self, instance: Dict, with_valid=True):
        """Validate the OpenAI-style ``instance`` and run it as a regular chat message.

        The last message is used as the question; ``re_chat`` and ``stream``
        default to False. The response is rendered with ``OpenaiToResponse``.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            OpenAIInstanceSerializer(data=instance).is_valid(raise_exception=True)
        chat_id = instance.get('chat_id')
        message = self.get_message(instance)
        re_chat = instance.get('re_chat', False)
        stream = instance.get('stream', False)
        application_id = self.data.get('application_id')
        client_id = self.data.get('client_id')
        client_type = self.data.get('client_type')
        chat_id = self.generate_chat(chat_id, application_id, message, client_id)
        return ChatMessageSerializer(
            data={'chat_id': chat_id, 'message': message,
                  're_chat': re_chat,
                  'stream': stream,
                  'application_id': application_id,
                  'client_id': client_id,
                  'client_type': client_type, 'form_data': instance.get('form_data', {})}).chat(
            base_to_response=OpenaiToResponse())
class ChatMessageSerializer(serializers.Serializer):
    """Validate one chat request and run it through the simple pipeline or the workflow engine.

    Each field uses the ``ErrMessage`` helper that matches its type
    (uuid / char / boolean / dict / list) so validation errors read correctly.
    """
    chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
    message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))
    stream = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("是否流式回答"))
    re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("是否重新回答"))
    chat_record_id = serializers.UUIDField(required=False, allow_null=True,
                                           error_messages=ErrMessage.uuid("对话记录id"))
    node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                    error_messages=ErrMessage.char("节点id"))
    runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                            error_messages=ErrMessage.char("运行时节点id"))
    node_data = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("节点参数"))
    application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
    client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
    client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
    form_data = serializers.DictField(required=False, error_messages=ErrMessage.dict("全局变量"))
    image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
    document_list = serializers.ListField(required=False, error_messages=ErrMessage.list("文档"))
    audio_list = serializers.ListField(required=False, error_messages=ErrMessage.list("音频"))
    child_node = serializers.DictField(required=False, allow_null=True, error_messages=ErrMessage.dict("子节点"))

    def is_valid_application_workflow(self, *, raise_exception=False):
        """Validate a workflow-application request; only the daily access limit is checked.

        ``raise_exception`` is accepted but not used.
        """
        self.is_valid_intraday_access_num()

    def is_valid_chat_id(self, chat_info: ChatInfo):
        """Raise ChatException when an ``application_id`` was supplied that does not own the chat."""
        if self.data.get('application_id') is not None and self.data.get('application_id') != str(
                chat_info.application.id):
            raise ChatException(500, "会话不存在")

    def is_valid_intraday_access_num(self):
        """Enforce the per-client daily access limit for access-token clients.

        Creates the ApplicationPublicAccessClient row on first use and raises
        AppChatNumOutOfBoundsFailed when the client's intraday_access_num has
        reached the application access token's access_num.
        """
        if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
            access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first()
            if access_client is None:
                access_client = ApplicationPublicAccessClient(id=self.data.get('client_id'),
                                                              application_id=self.data.get('application_id'),
                                                              access_num=0,
                                                              intraday_access_num=0)
                access_client.save()

            application_access_token = QuerySet(ApplicationAccessToken).filter(
                application_id=self.data.get('application_id')).first()
            if application_access_token.access_num <= access_client.intraday_access_num:
                raise AppChatNumOutOfBoundsFailed(1002, "访问次数超过今日访问量")

    def is_valid_application_simple(self, *, chat_info: ChatInfo, raise_exception=False):
        """Validate a simple-application request.

        Checks the daily access limit and rejects the request when the
        application's model is in ERROR or DOWNLOAD status.

        :return: ``chat_info`` unchanged
        """
        self.is_valid_intraday_access_num()
        model = chat_info.application.model
        if model is None:
            return chat_info
        model = QuerySet(Model).filter(id=model.id).first()
        if model is None:
            return chat_info
        if model.status == Status.ERROR:
            raise ChatException(500, "当前模型不可用")
        if model.status == Status.DOWNLOAD:
            raise ChatException(500, "模型正在下载中,请稍后再发起对话")
        return chat_info

    def chat_simple(self, chat_info: ChatInfo, base_to_response):
        """Answer the question with the retrieval pipeline of a simple application.

        Steps: optional problem rewrite, dataset search, prompt generation, chat.
        On re-chat, paragraphs already returned for the same question in the
        cached history are excluded from the search.

        :return: the pipeline's ``chat_result``
        """
        message = self.data.get('message')
        re_chat = self.data.get('re_chat')
        stream = self.data.get('stream')
        client_id = self.data.get('client_id')
        client_type = self.data.get('client_type')
        pipeline_manage_builder = PipelineManage.builder()
        # Add the problem-optimisation step only when the application enables it.
        if chat_info.application.problem_optimization:
            pipeline_manage_builder.append_step(BaseResetProblemStep)
        pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep)
                            .append_step(BaseGenerateHumanMessageStep)
                            .append_step(BaseChatStep)
                            .add_base_to_response(base_to_response)
                            .build())
        exclude_paragraph_id_list = []
        # When re-asking the same question, skip paragraphs that were already used.
        if re_chat:
            paragraph_id_list = flat_map(
                [[paragraph.get('id') for paragraph in chat_record.details['search_step']['paragraph_list']] for
                 chat_record in chat_info.chat_record_list if
                 chat_record.problem_text == message and 'search_step' in chat_record.details and 'paragraph_list' in
                 chat_record.details['search_step']])
            exclude_paragraph_id_list = list(set(paragraph_id_list))
        params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
                                                     client_id, client_type, stream)
        pipeline_message.run(params)
        return pipeline_message.context['chat_result']

    @staticmethod
    def get_chat_record(chat_info, chat_record_id):
        """Look up a chat record, preferring the cached history.

        With a ``chat_info``, the newest cached record with that id is returned;
        otherwise the record must exist in the database for this chat or
        ChatException is raised. Without ``chat_info``, the record is looked up
        by id alone (None if missing).
        """
        if chat_info is None:
            return QuerySet(ChatRecord).filter(id=chat_record_id).first()
        cached_records = [chat_record for chat_record in chat_info.chat_record_list if
                          str(chat_record.id) == str(chat_record_id)]
        if cached_records:
            return cached_records[-1]
        chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_info.chat_id).first()
        if chat_record is None:
            raise ChatException(500, "对话纪要不存在")
        # Already fetched and scoped to this chat; no need for a second, unscoped query.
        return chat_record

    def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
        """Answer the question by running the application's published workflow.

        When ``chat_record_id`` is given, that record is looked up, passed to
        WorkflowManage and left out of the history; otherwise a new record id
        is generated.

        :return: the result of ``WorkflowManage.run()``
        """
        message = self.data.get('message')
        re_chat = self.data.get('re_chat')
        stream = self.data.get('stream')
        client_id = self.data.get('client_id')
        client_type = self.data.get('client_type')
        form_data = self.data.get('form_data')
        image_list = self.data.get('image_list')
        document_list = self.data.get('document_list')
        audio_list = self.data.get('audio_list')
        user_id = chat_info.application.user_id
        chat_record_id = self.data.get('chat_record_id')
        chat_record = None
        history_chat_record = chat_info.chat_record_list
        if chat_record_id is not None:
            chat_record = self.get_chat_record(chat_info, chat_record_id)
            history_chat_record = [r for r in chat_info.chat_record_list if str(r.id) != chat_record_id]
        work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
                                          {'history_chat_record': history_chat_record, 'question': message,
                                           'chat_id': chat_info.chat_id, 'chat_record_id': str(
                                              uuid.uuid1()) if chat_record is None else chat_record.id,
                                           'stream': stream,
                                           're_chat': re_chat,
                                           'client_id': client_id,
                                           'client_type': client_type,
                                           'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
                                          base_to_response, form_data, image_list, document_list, audio_list,
                                          self.data.get('runtime_node_id'),
                                          self.data.get('node_data'), chat_record, self.data.get('child_node'))
        r = work_flow_manage.run()
        return r

    def chat(self, base_to_response: BaseToResponse = SystemToResponse()):
        """Validate the request and dispatch it by application type (simple or workflow)."""
        super().is_valid(raise_exception=True)
        chat_info = self.get_chat_info()
        self.is_valid_chat_id(chat_info)
        if chat_info.application.type == ApplicationTypeChoices.SIMPLE:
            self.is_valid_application_simple(raise_exception=True, chat_info=chat_info)
            return self.chat_simple(chat_info, base_to_response)
        self.is_valid_application_workflow(raise_exception=True)
        return self.chat_work_flow(chat_info, base_to_response)

    def get_chat_info(self):
        """Return the cached ChatInfo for ``chat_id``, rebuilding and caching it (30 min) on a miss."""
        self.is_valid(raise_exception=True)
        chat_id = self.data.get('chat_id')
        chat_info: ChatInfo = chat_cache.get(chat_id)
        if chat_info is None:
            chat_info: ChatInfo = self.re_open_chat(chat_id)
            chat_cache.set(chat_id,
                           chat_info, timeout=60 * 30)
        return chat_info

    def re_open_chat(self, chat_id: str):
        """Rebuild ChatInfo from the database.

        :raises ChatException: when the chat or its application does not exist
        """
        chat = QuerySet(Chat).filter(id=chat_id).first()
        if chat is None:
            raise ChatException(500, "会话不存在")
        application = QuerySet(Application).filter(id=chat.application_id).first()
        if application is None:
            raise ChatException(500, "应用不存在")
        if application.type == ApplicationTypeChoices.SIMPLE:
            return self.re_open_chat_simple(chat_id, application)
        else:
            return self.re_open_chat_work_flow(chat_id, application)

    @staticmethod
    def _load_recent_chat_records(chat_info: ChatInfo, chat_id):
        """Append the chat's 5 most recent records to ``chat_info`` in chronological order."""
        recent_records = list(QuerySet(ChatRecord).filter(chat_id=chat_id).order_by('-create_time')[0:5])
        recent_records.sort(key=lambda r: r.create_time)
        chat_info.chat_record_list.extend(recent_records)
        return chat_info

    @staticmethod
    def re_open_chat_simple(chat_id, application):
        """Rebuild ChatInfo for a simple application (datasets, inactive documents, recent history)."""
        dataset_id_list = [str(row.dataset_id) for row in
                           QuerySet(ApplicationDatasetMapping).filter(
                               application_id=application.id)]
        # Inactive documents of those datasets are excluded from search.
        exclude_document_id_list = [str(document.id) for document in
                                    QuerySet(Document).filter(
                                        dataset_id__in=dataset_id_list,
                                        is_active=False)]
        chat_info = ChatInfo(chat_id, dataset_id_list, exclude_document_id_list, application)
        return ChatMessageSerializer._load_recent_chat_records(chat_info, chat_id)

    @staticmethod
    def re_open_chat_work_flow(chat_id, application):
        """Rebuild ChatInfo for a workflow application using its latest published version.

        :raises ChatException: when the application has no published workflow version
        """
        work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application.id).order_by(
            '-create_time').first()
        if work_flow_version is None:
            raise ChatException(500, "应用未发布,请发布后再使用")
        chat_info = ChatInfo(chat_id, [], [], application, work_flow_version)
        return ChatMessageSerializer._load_recent_chat_records(chat_info, chat_id)