feat: application chat debug (#3241)

This commit is contained in:
shaohuzhang1 2025-06-12 15:21:59 +08:00 committed by GitHub
parent 31838f8099
commit e5e993986c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 131 additions and 52 deletions

View File

@ -17,12 +17,14 @@ from common.handle.impl.response.system_to_response import SystemToResponse
class PipelineManage:
def __init__(self, step_list: List[Type[IBaseChatPipelineStep]],
base_to_response: BaseToResponse = SystemToResponse()):
base_to_response: BaseToResponse = SystemToResponse(),
debug=False):
# 步骤执行器
self.step_list = [step() for step in step_list]
# 上下文
self.context = {'message_tokens': 0, 'answer_tokens': 0}
self.base_to_response = base_to_response
self.debug = debug
def run(self, context: Dict = None):
self.context['start_time'] = time.time()
@ -44,6 +46,7 @@ class PipelineManage:
def __init__(self):
self.step_list: List[Type[IBaseChatPipelineStep]] = []
self.base_to_response = SystemToResponse()
self.debug = False
def append_step(self, step: Type[IBaseChatPipelineStep]):
self.step_list.append(step)
@ -53,5 +56,9 @@ class PipelineManage:
self.base_to_response = base_to_response
return self
def add_debug(self, debug):
self.debug = debug
return self
def build(self):
return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response)
return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response, debug=self.debug)

View File

@ -18,9 +18,8 @@ from rest_framework import serializers
from rest_framework.exceptions import ValidationError, ErrorDetail
from application.flow.common import Answer, NodeChunk
from application.models import ChatRecord, ChatUserType
from application.models import ApplicationChatUserStats
from common.constants.authentication_type import AuthenticationType
from application.models import ChatRecord, ChatUserType
from common.field.common import InstanceField
chat_cache = cache
@ -45,16 +44,14 @@ def is_interrupt(node, step_variable: Dict, global_variable: Dict):
class WorkFlowPostHandler:
def __init__(self, chat_info, chat_user_id, chat_user_type):
def __init__(self, chat_info):
self.chat_info = chat_info
self.chat_user_id = chat_user_id
self.chat_user_type = chat_user_type
def handler(self, chat_id,
chat_record_id,
answer,
workflow):
question = workflow.params['question']
def handler(self, workflow):
workflow_body = workflow.get_body()
question = workflow_body.get('question')
chat_record_id = workflow_body.get('chat_record_id')
chat_id = workflow_body.get('chat_id')
details = workflow.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if
'message_tokens' in row and row.get('message_tokens') is not None])
@ -83,14 +80,14 @@ class WorkFlowPostHandler:
answer_text_list=answer_text_list,
run_time=time.time() - workflow.context['start_time'],
index=0)
asker = workflow.context.get('asker', None)
self.chat_info.append_chat_record(chat_record)
self.chat_info.set_cahce()
if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
self.chat_user_type):
workflow_body.get('chat_user_type')):
application_public_access_client = (QuerySet(ApplicationChatUserStats)
.filter(chat_user_id=self.chat_user_id,
chat_user_type=self.chat_user_type,
.filter(chat_user_id=workflow_body.get('chat_user_id'),
chat_user_type=workflow_body.get('chat_user_type'),
application_id=self.chat_info.application.id).first())
if application_public_access_client is not None:
application_public_access_client.access_num = application_public_access_client.access_num + 1
@ -141,13 +138,16 @@ class FlowParamsSerializer(serializers.Serializer):
stream = serializers.BooleanField(required=True, label="流式输出")
client_id = serializers.CharField(required=False, label="客户端id")
chat_user_id = serializers.CharField(required=False, label="对话用户id")
client_type = serializers.CharField(required=False, label="客户端类型")
chat_user_type = serializers.CharField(required=False, label="对话用户类型")
workspace_id = serializers.CharField(required=True, label="工作空间id")
user_id = serializers.UUIDField(required=True, label="用户id")
re_chat = serializers.BooleanField(required=True, label="换个答案")
debug = serializers.BooleanField(required=True, label="是否debug")
class INode:
view_type = 'many_view'

View File

@ -367,9 +367,7 @@ class WorkflowManage:
'\n\n'.join([a.get('content') for a in answer]) for answer in
answer_text_list)
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
answer_text,
self)
self.work_flow_post_handler.handler(self)
return self.base_to_response.to_block_response(self.params['chat_id'],
self.params['chat_record_id'], answer_text, True
, message_tokens, answer_tokens,
@ -384,6 +382,9 @@ class WorkflowManage:
self.run_chain_async(current_node, node_result_future, language)
return tools.to_stream_response_simple(self.await_result())
def get_body(self):
return self.params
def is_run(self, timeout=0.5):
future_list_len = len(self.future_list)
try:
@ -420,9 +421,7 @@ class WorkflowManage:
'message_tokens' in row and row.get('message_tokens') is not None])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
self.answer,
self)
self.work_flow_post_handler.handler(self)
yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
'',

View File

@ -27,15 +27,19 @@ class ChatInfo:
chat_user_type: str,
knowledge_id_list: List[str],
exclude_document_id_list: list[str],
application_id: str,
application: Application,
work_flow_version: WorkFlowVersion = None):
work_flow_version: WorkFlowVersion = None,
debug=False):
"""
:param chat_id: 对话id
:param chat_id: 对话id
:param chat_user_id 对话用户id
:param chat_user_type 对话用户类型
:param knowledge_id_list: 知识库列表
:param exclude_document_id_list: 排除的文档
:param application_id 应用id
:param application: 应用信息
:param debug 是否是调试
"""
self.chat_id = chat_id
self.chat_user_id = chat_user_id
@ -43,8 +47,10 @@ class ChatInfo:
self.application = application
self.knowledge_id_list = knowledge_id_list
self.exclude_document_id_list = exclude_document_id_list
self.application_id = application_id
self.chat_record_list: List[ChatRecord] = []
self.work_flow_version = work_flow_version
self.debug = debug
@staticmethod
def get_no_references_setting(knowledge_setting, model_setting):
@ -116,17 +122,17 @@ class ChatInfo:
if record.id == chat_record.id:
self.chat_record_list[index] = chat_record
is_save = False
break
if is_save:
self.chat_record_list.append(chat_record)
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),
timeout=60 * 30)
if self.application.id is not None:
Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text[0:1024],
chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type).save()
else:
QuerySet(Chat).filter(id=self.chat_id).update(update_time=datetime.now())
if not self.debug:
if not QuerySet(Chat).filter(id=self.chat_id).exists():
Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text[0:1024],
chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type).save()
else:
QuerySet(Chat).filter(id=self.chat_id).update(update_time=datetime.now())
# 插入会话记录
chat_record.save()
chat_record.save()
def set_cache(self):
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),

View File

@ -52,4 +52,7 @@ urlpatterns = [
path(
'workspace/<str:workspace_id>/application/<str:application_id>/work_flow_version/<str:work_flow_version_id>',
views.ApplicationVersionView.Operate.as_view()),
path('workspace/<str:workspace_id>/application/<str:application_id>/open', views.OpenView.as_view()),
path('chat_message/<str:chat_id>', views.ChatView.as_view()),
]

View File

@ -6,6 +6,8 @@
@date2025/6/10 11:00
@desc:
"""
import uuid
from django.utils.translation import gettext_lazy as _
from drf_spectacular.utils import extend_schema
from rest_framework.request import Request
@ -13,7 +15,11 @@ from rest_framework.views import APIView
from application.api.application_chat import ApplicationChatQueryAPI, ApplicationChatQueryPageAPI, \
ApplicationChatExportAPI
from application.models import ChatUserType
from application.serializers.application_chat import ApplicationChatQuerySerializers
from chat.api.chat_api import ChatAPI
from chat.api.chat_authentication_api import ChatOpenAPI
from chat.serializers.chat import OpenChatSerializers, ChatSerializers, DebugChatSerializers
from common.auth import TokenAuth
from common.auth.authentication import has_permissions
from common.constants.permission_constants import PermissionConstants, RoleConstants
@ -81,3 +87,39 @@ class ApplicationChat(APIView):
return ApplicationChatQuerySerializers(
data={**query_params_to_single_dict(request.query_params), 'application_id': application_id,
}).export(request.data)
class OpenView(APIView):
authentication_classes = [TokenAuth]
@extend_schema(
methods=['GET'],
description=_("Get a temporary session id based on the application id"),
summary=_("Get a temporary session id based on the application id"),
operation_id=_("Get a temporary session id based on the application id"), # type: ignore
parameters=ChatOpenAPI.get_parameters(),
responses=None,
tags=[_('Application')] # type: ignore
)
def get(self, request: Request, workspace_id: str, application_id: str):
return result.success(OpenChatSerializers(
data={'workspace_id': workspace_id, 'application_id': application_id,
'chat_user_id': str(uuid.uuid1()), 'chat_user_type': ChatUserType.ANONYMOUS_USER,
'debug': True}).open())
class ChatView(APIView):
authentication_classes = [TokenAuth]
@extend_schema(
methods=['POST'],
description=_("dialogue"),
summary=_("dialogue"),
operation_id=_("dialogue"), # type: ignore
request=ChatAPI.get_request(),
parameters=ChatAPI.get_parameters(),
responses=None,
tags=[_('Application')] # type: ignore
)
def post(self, request: Request, chat_id: str):
return DebugChatSerializers(data={'chat_id': chat_id}).chat(request.data)

View File

@ -94,12 +94,27 @@ def get_post_handler(chat_info: ChatInfo):
return PostHandler()
class DebugChatSerializers(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
def chat(self, instance: dict, base_to_response: BaseToResponse = SystemToResponse()):
self.is_valid(raise_exception=True)
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = ChatInfo.get_cache(chat_id)
return ChatSerializers(data={
'chat_id': chat_id, "chat_user_id": chat_info.chat_user_id,
"chat_user_type": chat_info.chat_user_type,
"application_id": chat_info.application.id, "debug": True
}).chat(instance, base_to_response)
class ChatSerializers(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
chat_user_id = serializers.CharField(required=True, label=_("Client id"))
chat_user_type = serializers.CharField(required=True, label=_("Client Type"))
application_id = serializers.UUIDField(required=True, allow_null=True,
label=_("Application ID"))
debug = serializers.BooleanField(required=False, label=_("Debug"))
def is_valid_application_workflow(self, *, raise_exception=False):
self.is_valid_intraday_access_num()
@ -158,6 +173,7 @@ class ChatSerializers(serializers.Serializer):
.append_step(BaseGenerateHumanMessageStep)
.append_step(BaseChatStep)
.add_base_to_response(base_to_response)
.add_debug(self.data.get('debug', False))
.build())
exclude_paragraph_id_list = []
# 相同问题是否需要排除已经查询到的段落
@ -189,18 +205,18 @@ class ChatSerializers(serializers.Serializer):
return chat_record
def chat_work_flow(self, chat_info: ChatInfo, instance: dict, base_to_response):
message = self.data.get('message')
re_chat = self.data.get('re_chat')
stream = self.data.get('stream')
chat_user_id = instance.get('chat_user_id')
chat_user_type = instance.get('chat_user_type')
form_data = self.data.get('form_data')
image_list = self.data.get('image_list')
document_list = self.data.get('document_list')
audio_list = self.data.get('audio_list')
other_list = self.data.get('other_list')
user_id = chat_info.application.user_id
chat_record_id = self.data.get('chat_record_id')
message = instance.get('message')
re_chat = instance.get('re_chat')
stream = instance.get('stream')
chat_user_id = self.data.get("chat_user_id")
chat_user_type = self.data.get('chat_user_type')
form_data = instance.get('form_data')
image_list = instance.get('image_list')
document_list = instance.get('document_list')
audio_list = instance.get('audio_list')
other_list = instance.get('other_list')
workspace_id = chat_info.application.workspace_id
chat_record_id = instance.get('chat_record_id')
chat_record = None
history_chat_record = chat_info.chat_record_list
if chat_record_id is not None:
@ -214,8 +230,9 @@ class ChatSerializers(serializers.Serializer):
're_chat': re_chat,
'chat_user_id': chat_user_id,
'chat_user_type': chat_user_type,
'user_id': user_id},
WorkFlowPostHandler(chat_info, chat_user_id, chat_user_type),
'workspace_id': workspace_id,
'debug': self.data.get('debug', False)},
WorkFlowPostHandler(chat_info),
base_to_response, form_data, image_list, document_list, audio_list,
other_list,
self.data.get('runtime_node_id'),
@ -229,7 +246,7 @@ class ChatSerializers(serializers.Serializer):
chat_info = self.get_chat_info()
self.is_valid_chat_id(chat_info)
if chat_info.application.type == ApplicationTypeChoices.SIMPLE:
self.is_valid_application_simple(raise_exception=True, chat_info=chat_info),
self.is_valid_application_simple(raise_exception=True, chat_info=chat_info)
return self.chat_simple(chat_info, instance, base_to_response)
else:
self.is_valid_application_workflow(raise_exception=True)
@ -295,6 +312,7 @@ class OpenChatSerializers(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
chat_user_id = serializers.CharField(required=True, label=_("Client id"))
chat_user_type = serializers.CharField(required=True, label=_("Client Type"))
debug = serializers.BooleanField(required=True, label=_("Debug"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -317,6 +335,7 @@ class OpenChatSerializers(serializers.Serializer):
application_id = self.data.get('application_id')
chat_user_id = self.data.get("chat_user_id")
chat_user_type = self.data.get("chat_user_type")
debug = self.data.get("debug")
chat_id = str(uuid.uuid7())
work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application_id).order_by(
'-create_time')[0:1].first()
@ -326,13 +345,15 @@ class OpenChatSerializers(serializers.Serializer):
"The application has not been published. Please use it after publishing."))
ChatInfo(chat_id, chat_user_id, chat_user_type, [],
[],
application, work_flow_version).set_cache()
application_id,
application, work_flow_version, debug).set_cache()
return chat_id
def open_simple(self, application):
application_id = self.data.get('application_id')
chat_user_id = self.data.get("chat_user_id")
chat_user_type = self.data.get("chat_user_type")
debug = self.data.get("debug")
knowledge_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationKnowledgeMapping).filter(
application_id=application_id)]
@ -342,5 +363,6 @@ class OpenChatSerializers(serializers.Serializer):
QuerySet(Document).filter(
knowledge_id__in=knowledge_id_list,
is_active=False)],
application).set_cache()
application_id,
application, debug=debug).set_cache()
return chat_id