diff --git a/apps/application/chat_pipeline/pipeline_manage.py b/apps/application/chat_pipeline/pipeline_manage.py index 7c4acb3a3..4c0b10505 100644 --- a/apps/application/chat_pipeline/pipeline_manage.py +++ b/apps/application/chat_pipeline/pipeline_manage.py @@ -17,12 +17,14 @@ from common.handle.impl.response.system_to_response import SystemToResponse class PipelineManage: def __init__(self, step_list: List[Type[IBaseChatPipelineStep]], - base_to_response: BaseToResponse = SystemToResponse()): + base_to_response: BaseToResponse = SystemToResponse(), + debug=False): # 步骤执行器 self.step_list = [step() for step in step_list] # 上下文 self.context = {'message_tokens': 0, 'answer_tokens': 0} self.base_to_response = base_to_response + self.debug = debug def run(self, context: Dict = None): self.context['start_time'] = time.time() @@ -44,6 +46,7 @@ class PipelineManage: def __init__(self): self.step_list: List[Type[IBaseChatPipelineStep]] = [] self.base_to_response = SystemToResponse() + self.debug = False def append_step(self, step: Type[IBaseChatPipelineStep]): self.step_list.append(step) @@ -53,5 +56,9 @@ class PipelineManage: self.base_to_response = base_to_response return self + def add_debug(self, debug): + self.debug = debug + return self + def build(self): - return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response) + return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response, debug=self.debug) diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 2fd61408a..bea03986a 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -18,9 +18,8 @@ from rest_framework import serializers from rest_framework.exceptions import ValidationError, ErrorDetail from application.flow.common import Answer, NodeChunk -from application.models import ChatRecord, ChatUserType from application.models import ApplicationChatUserStats -from common.constants.authentication_type import AuthenticationType +from application.models import ChatRecord, ChatUserType from common.field.common import InstanceField chat_cache = cache @@ -45,16 +44,14 @@ def is_interrupt(node, step_variable: Dict, global_variable: Dict): class WorkFlowPostHandler: - def __init__(self, chat_info, chat_user_id, chat_user_type): + def __init__(self, chat_info): self.chat_info = chat_info - self.chat_user_id = chat_user_id - self.chat_user_type = chat_user_type - def handler(self, chat_id, - chat_record_id, - answer, - workflow): - question = workflow.params['question'] + def handler(self, workflow): + workflow_body = workflow.get_body() + question = workflow_body.get('question') + chat_record_id = workflow_body.get('chat_record_id') + chat_id = workflow_body.get('chat_id') details = workflow.get_runtime_details() message_tokens = sum([row.get('message_tokens') for row in details.values() if 'message_tokens' in row and row.get('message_tokens') is not None]) @@ -83,14 +80,14 @@ class WorkFlowPostHandler: answer_text_list=answer_text_list, run_time=time.time() - workflow.context['start_time'], index=0) - asker = workflow.context.get('asker', None) + self.chat_info.append_chat_record(chat_record) self.chat_info.set_cahce() if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__( - self.chat_user_type): + workflow_body.get('chat_user_type')): application_public_access_client = (QuerySet(ApplicationChatUserStats) - .filter(chat_user_id=self.chat_user_id, - chat_user_type=self.chat_user_type, + .filter(chat_user_id=workflow_body.get('chat_user_id'), + chat_user_type=workflow_body.get('chat_user_type'), application_id=self.chat_info.application.id).first()) if application_public_access_client is not None: application_public_access_client.access_num = application_public_access_client.access_num + 1 @@ -141,13 +138,16 @@ class FlowParamsSerializer(serializers.Serializer): stream = serializers.BooleanField(required=True, label="流式输出") - client_id = serializers.CharField(required=False, label="客户端id") + chat_user_id = serializers.CharField(required=False, label="对话用户id") - client_type = serializers.CharField(required=False, label="客户端类型") + chat_user_type = serializers.CharField(required=False, label="对话用户类型") + + workspace_id = serializers.CharField(required=True, label="工作空间id") - user_id = serializers.UUIDField(required=True, label="用户id") re_chat = serializers.BooleanField(required=True, label="换个答案") + debug = serializers.BooleanField(required=True, label="是否debug") + class INode: view_type = 'many_view' diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index d0e96680f..7535be2a9 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -367,9 +367,7 @@ class WorkflowManage: '\n\n'.join([a.get('content') for a in answer]) for answer in answer_text_list) answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, []) - self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], - answer_text, - self) + self.work_flow_post_handler.handler(self) return self.base_to_response.to_block_response(self.params['chat_id'], self.params['chat_record_id'], answer_text, True , message_tokens, answer_tokens, @@ -384,6 +382,9 @@ class WorkflowManage: self.run_chain_async(current_node, node_result_future, language) return tools.to_stream_response_simple(self.await_result()) + def get_body(self): + return self.params + def is_run(self, timeout=0.5): future_list_len = len(self.future_list) try: @@ -420,9 +421,7 @@ class WorkflowManage: 'message_tokens' in row and row.get('message_tokens') is not None]) answer_tokens = sum([row.get('answer_tokens') for row in details.values() if 'answer_tokens' in row and row.get('answer_tokens') is not None]) - self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], - self.answer, - self) + self.work_flow_post_handler.handler(self) yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'], self.params['chat_record_id'], '', diff --git a/apps/application/serializers/common.py b/apps/application/serializers/common.py index 8af009077..29d720d09 100644 --- a/apps/application/serializers/common.py +++ b/apps/application/serializers/common.py @@ -27,15 +27,19 @@ class ChatInfo: chat_user_type: str, knowledge_id_list: List[str], exclude_document_id_list: list[str], + application_id: str, application: Application, - work_flow_version: WorkFlowVersion = None): + work_flow_version: WorkFlowVersion = None, + debug=False): """ - :param chat_id: 对话id + :param chat_id: 对话id :param chat_user_id 对话用户id :param chat_user_type 对话用户类型 :param knowledge_id_list: 知识库列表 :param exclude_document_id_list: 排除的文档 + :param application_id 应用id :param application: 应用信息 + :param debug 是否是调试 """ self.chat_id = chat_id self.chat_user_id = chat_user_id @@ -43,8 +47,10 @@ class ChatInfo: self.application = application self.knowledge_id_list = knowledge_id_list self.exclude_document_id_list = exclude_document_id_list + self.application_id = application_id self.chat_record_list: List[ChatRecord] = [] self.work_flow_version = work_flow_version + self.debug = debug @staticmethod def get_no_references_setting(knowledge_setting, model_setting): @@ -116,17 +122,17 @@ class ChatInfo: if record.id == chat_record.id: self.chat_record_list[index] = chat_record is_save = False + break if is_save: self.chat_record_list.append(chat_record) - cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(), - timeout=60 * 30) - if self.application.id is not None: - Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text[0:1024], - chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type).save() - else: - QuerySet(Chat).filter(id=self.chat_id).update(update_time=datetime.now()) + if not self.debug: + if not QuerySet(Chat).filter(id=self.chat_id).exists(): + Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text[0:1024], + chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type).save() + else: + QuerySet(Chat).filter(id=self.chat_id).update(update_time=datetime.now()) # 插入会话记录 - chat_record.save() + chat_record.save() def set_cache(self): cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(), diff --git a/apps/application/urls.py b/apps/application/urls.py index 9e0a39485..4cf1fc7f0 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -52,4 +52,7 @@ urlpatterns = [ path( 'workspace//application//work_flow_version/', views.ApplicationVersionView.Operate.as_view()), + path('workspace//application//open', views.OpenView.as_view()), + path('chat_message/', views.ChatView.as_view()), + ] diff --git a/apps/application/views/application_chat.py b/apps/application/views/application_chat.py index a9b41c4d3..584539581 100644 --- a/apps/application/views/application_chat.py +++ b/apps/application/views/application_chat.py @@ -6,6 +6,8 @@ @date:2025/6/10 11:00 @desc: """ +import uuid + from django.utils.translation import gettext_lazy as _ from drf_spectacular.utils import extend_schema from rest_framework.request import Request @@ -13,7 +15,11 @@ from rest_framework.views import APIView from application.api.application_chat import ApplicationChatQueryAPI, ApplicationChatQueryPageAPI, \ ApplicationChatExportAPI +from application.models import ChatUserType from application.serializers.application_chat import ApplicationChatQuerySerializers +from chat.api.chat_api import ChatAPI +from chat.api.chat_authentication_api import ChatOpenAPI +from chat.serializers.chat import OpenChatSerializers, ChatSerializers, DebugChatSerializers from common.auth import TokenAuth from common.auth.authentication import has_permissions from common.constants.permission_constants import PermissionConstants, RoleConstants @@ -81,3 +87,39 @@ class ApplicationChat(APIView): return ApplicationChatQuerySerializers( data={**query_params_to_single_dict(request.query_params), 'application_id': application_id, }).export(request.data) + + +class OpenView(APIView): + authentication_classes = [TokenAuth] + + @extend_schema( + methods=['GET'], + description=_("Get a temporary session id based on the application id"), + summary=_("Get a temporary session id based on the application id"), + operation_id=_("Get a temporary session id based on the application id"), # type: ignore + parameters=ChatOpenAPI.get_parameters(), + responses=None, + tags=[_('Application')] # type: ignore + ) + def get(self, request: Request, workspace_id: str, application_id: str): + return result.success(OpenChatSerializers( + data={'workspace_id': workspace_id, 'application_id': application_id, + 'chat_user_id': str(uuid.uuid1()), 'chat_user_type': ChatUserType.ANONYMOUS_USER, + 'debug': True}).open()) + + +class ChatView(APIView): + authentication_classes = [TokenAuth] + + @extend_schema( + methods=['POST'], + description=_("dialogue"), + summary=_("dialogue"), + operation_id=_("dialogue"), # type: ignore + request=ChatAPI.get_request(), + parameters=ChatAPI.get_parameters(), + responses=None, + tags=[_('Application')] # type: ignore + ) + def post(self, request: Request, chat_id: str): + return DebugChatSerializers(data={'chat_id': chat_id}).chat(request.data) diff --git a/apps/chat/serializers/chat.py b/apps/chat/serializers/chat.py index 54bc79aa1..b47ba160e 100644 --- a/apps/chat/serializers/chat.py +++ b/apps/chat/serializers/chat.py @@ -94,12 +94,27 @@ def get_post_handler(chat_info: ChatInfo): return PostHandler() +class DebugChatSerializers(serializers.Serializer): + chat_id = serializers.UUIDField(required=True, label=_("Conversation ID")) + + def chat(self, instance: dict, base_to_response: BaseToResponse = SystemToResponse()): + self.is_valid(raise_exception=True) + chat_id = self.data.get('chat_id') + chat_info: ChatInfo = ChatInfo.get_cache(chat_id) + return ChatSerializers(data={ + 'chat_id': chat_id, "chat_user_id": chat_info.chat_user_id, + "chat_user_type": chat_info.chat_user_type, + "application_id": chat_info.application.id, "debug": True + }).chat(instance, base_to_response) + + class ChatSerializers(serializers.Serializer): chat_id = serializers.UUIDField(required=True, label=_("Conversation ID")) chat_user_id = serializers.CharField(required=True, label=_("Client id")) chat_user_type = serializers.CharField(required=True, label=_("Client Type")) application_id = serializers.UUIDField(required=True, allow_null=True, label=_("Application ID")) + debug = serializers.BooleanField(required=False, label=_("Debug")) def is_valid_application_workflow(self, *, raise_exception=False): self.is_valid_intraday_access_num() @@ -158,6 +173,7 @@ class ChatSerializers(serializers.Serializer): .append_step(BaseGenerateHumanMessageStep) .append_step(BaseChatStep) .add_base_to_response(base_to_response) + .add_debug(self.data.get('debug', False)) .build()) exclude_paragraph_id_list = [] # 相同问题是否需要排除已经查询到的段落 @@ -189,18 +205,18 @@ class ChatSerializers(serializers.Serializer): return chat_record def chat_work_flow(self, chat_info: ChatInfo, instance: dict, base_to_response): - message = self.data.get('message') - re_chat = self.data.get('re_chat') - stream = self.data.get('stream') - chat_user_id = instance.get('chat_user_id') - chat_user_type = instance.get('chat_user_type') - form_data = self.data.get('form_data') - image_list = self.data.get('image_list') - document_list = self.data.get('document_list') - audio_list = self.data.get('audio_list') - other_list = self.data.get('other_list') - user_id = chat_info.application.user_id - chat_record_id = self.data.get('chat_record_id') + message = instance.get('message') + re_chat = instance.get('re_chat') + stream = instance.get('stream') + chat_user_id = self.data.get("chat_user_id") + chat_user_type = self.data.get('chat_user_type') + form_data = instance.get('form_data') + image_list = instance.get('image_list') + document_list = instance.get('document_list') + audio_list = instance.get('audio_list') + other_list = instance.get('other_list') + workspace_id = chat_info.application.workspace_id + chat_record_id = instance.get('chat_record_id') chat_record = None history_chat_record = chat_info.chat_record_list if chat_record_id is not None: @@ -214,8 +230,9 @@ class ChatSerializers(serializers.Serializer): 're_chat': re_chat, 'chat_user_id': chat_user_id, 'chat_user_type': chat_user_type, - 'user_id': user_id}, - WorkFlowPostHandler(chat_info, chat_user_id, chat_user_type), + 'workspace_id': workspace_id, + 'debug': self.data.get('debug', False)}, + WorkFlowPostHandler(chat_info), base_to_response, form_data, image_list, document_list, audio_list, other_list, self.data.get('runtime_node_id'), @@ -229,7 +246,7 @@ class ChatSerializers(serializers.Serializer): chat_info = self.get_chat_info() self.is_valid_chat_id(chat_info) if chat_info.application.type == ApplicationTypeChoices.SIMPLE: - self.is_valid_application_simple(raise_exception=True, chat_info=chat_info), + self.is_valid_application_simple(raise_exception=True, chat_info=chat_info) return self.chat_simple(chat_info, instance, base_to_response) else: self.is_valid_application_workflow(raise_exception=True) @@ -295,6 +312,7 @@ class OpenChatSerializers(serializers.Serializer): application_id = serializers.UUIDField(required=True) chat_user_id = serializers.CharField(required=True, label=_("Client id")) chat_user_type = serializers.CharField(required=True, label=_("Client Type")) + debug = serializers.BooleanField(required=True, label=_("Debug")) def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) @@ -317,6 +335,7 @@ class OpenChatSerializers(serializers.Serializer): application_id = self.data.get('application_id') chat_user_id = self.data.get("chat_user_id") chat_user_type = self.data.get("chat_user_type") + debug = self.data.get("debug") chat_id = str(uuid.uuid7()) work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application_id).order_by( '-create_time')[0:1].first() @@ -326,13 +345,15 @@ class OpenChatSerializers(serializers.Serializer): "The application has not been published. Please use it after publishing.")) ChatInfo(chat_id, chat_user_id, chat_user_type, [], [], - application, work_flow_version).set_cache() + application_id, + application, work_flow_version, debug).set_cache() return chat_id def open_simple(self, application): application_id = self.data.get('application_id') chat_user_id = self.data.get("chat_user_id") chat_user_type = self.data.get("chat_user_type") + debug = self.data.get("debug") knowledge_id_list = [str(row.dataset_id) for row in QuerySet(ApplicationKnowledgeMapping).filter( application_id=application_id)] @@ -342,5 +363,6 @@ class OpenChatSerializers(serializers.Serializer): QuerySet(Document).filter( knowledge_id__in=knowledge_id_list, is_active=False)], - application).set_cache() + application_id, + application, debug=debug).set_cache() return chat_id