mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: application chat debug (#3241)
This commit is contained in:
parent
31838f8099
commit
e5e993986c
|
|
@ -17,12 +17,14 @@ from common.handle.impl.response.system_to_response import SystemToResponse
|
|||
|
||||
class PipelineManage:
|
||||
def __init__(self, step_list: List[Type[IBaseChatPipelineStep]],
|
||||
base_to_response: BaseToResponse = SystemToResponse()):
|
||||
base_to_response: BaseToResponse = SystemToResponse(),
|
||||
debug=False):
|
||||
# 步骤执行器
|
||||
self.step_list = [step() for step in step_list]
|
||||
# 上下文
|
||||
self.context = {'message_tokens': 0, 'answer_tokens': 0}
|
||||
self.base_to_response = base_to_response
|
||||
self.debug = debug
|
||||
|
||||
def run(self, context: Dict = None):
|
||||
self.context['start_time'] = time.time()
|
||||
|
|
@ -44,6 +46,7 @@ class PipelineManage:
|
|||
def __init__(self):
|
||||
self.step_list: List[Type[IBaseChatPipelineStep]] = []
|
||||
self.base_to_response = SystemToResponse()
|
||||
self.debug = False
|
||||
|
||||
def append_step(self, step: Type[IBaseChatPipelineStep]):
|
||||
self.step_list.append(step)
|
||||
|
|
@ -53,5 +56,9 @@ class PipelineManage:
|
|||
self.base_to_response = base_to_response
|
||||
return self
|
||||
|
||||
def add_debug(self, debug):
|
||||
self.debug = debug
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response)
|
||||
return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response, debug=self.debug)
|
||||
|
|
|
|||
|
|
@ -18,9 +18,8 @@ from rest_framework import serializers
|
|||
from rest_framework.exceptions import ValidationError, ErrorDetail
|
||||
|
||||
from application.flow.common import Answer, NodeChunk
|
||||
from application.models import ChatRecord, ChatUserType
|
||||
from application.models import ApplicationChatUserStats
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from application.models import ChatRecord, ChatUserType
|
||||
from common.field.common import InstanceField
|
||||
|
||||
chat_cache = cache
|
||||
|
|
@ -45,16 +44,14 @@ def is_interrupt(node, step_variable: Dict, global_variable: Dict):
|
|||
|
||||
|
||||
class WorkFlowPostHandler:
|
||||
def __init__(self, chat_info, chat_user_id, chat_user_type):
|
||||
def __init__(self, chat_info):
|
||||
self.chat_info = chat_info
|
||||
self.chat_user_id = chat_user_id
|
||||
self.chat_user_type = chat_user_type
|
||||
|
||||
def handler(self, chat_id,
|
||||
chat_record_id,
|
||||
answer,
|
||||
workflow):
|
||||
question = workflow.params['question']
|
||||
def handler(self, workflow):
|
||||
workflow_body = workflow.get_body()
|
||||
question = workflow_body.get('question')
|
||||
chat_record_id = workflow_body.get('chat_record_id')
|
||||
chat_id = workflow_body.get('chat_id')
|
||||
details = workflow.get_runtime_details()
|
||||
message_tokens = sum([row.get('message_tokens') for row in details.values() if
|
||||
'message_tokens' in row and row.get('message_tokens') is not None])
|
||||
|
|
@ -83,14 +80,14 @@ class WorkFlowPostHandler:
|
|||
answer_text_list=answer_text_list,
|
||||
run_time=time.time() - workflow.context['start_time'],
|
||||
index=0)
|
||||
asker = workflow.context.get('asker', None)
|
||||
|
||||
self.chat_info.append_chat_record(chat_record)
|
||||
self.chat_info.set_cahce()
|
||||
if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
|
||||
self.chat_user_type):
|
||||
workflow_body.get('chat_user_type')):
|
||||
application_public_access_client = (QuerySet(ApplicationChatUserStats)
|
||||
.filter(chat_user_id=self.chat_user_id,
|
||||
chat_user_type=self.chat_user_type,
|
||||
.filter(chat_user_id=workflow_body.get('chat_user_id'),
|
||||
chat_user_type=workflow_body.get('chat_user_type'),
|
||||
application_id=self.chat_info.application.id).first())
|
||||
if application_public_access_client is not None:
|
||||
application_public_access_client.access_num = application_public_access_client.access_num + 1
|
||||
|
|
@ -141,13 +138,16 @@ class FlowParamsSerializer(serializers.Serializer):
|
|||
|
||||
stream = serializers.BooleanField(required=True, label="流式输出")
|
||||
|
||||
client_id = serializers.CharField(required=False, label="客户端id")
|
||||
chat_user_id = serializers.CharField(required=False, label="对话用户id")
|
||||
|
||||
client_type = serializers.CharField(required=False, label="客户端类型")
|
||||
chat_user_type = serializers.CharField(required=False, label="对话用户类型")
|
||||
|
||||
workspace_id = serializers.CharField(required=True, label="工作空间id")
|
||||
|
||||
user_id = serializers.UUIDField(required=True, label="用户id")
|
||||
re_chat = serializers.BooleanField(required=True, label="换个答案")
|
||||
|
||||
debug = serializers.BooleanField(required=True, label="是否debug")
|
||||
|
||||
|
||||
class INode:
|
||||
view_type = 'many_view'
|
||||
|
|
|
|||
|
|
@ -367,9 +367,7 @@ class WorkflowManage:
|
|||
'\n\n'.join([a.get('content') for a in answer]) for answer in
|
||||
answer_text_list)
|
||||
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
|
||||
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
|
||||
answer_text,
|
||||
self)
|
||||
self.work_flow_post_handler.handler(self)
|
||||
return self.base_to_response.to_block_response(self.params['chat_id'],
|
||||
self.params['chat_record_id'], answer_text, True
|
||||
, message_tokens, answer_tokens,
|
||||
|
|
@ -384,6 +382,9 @@ class WorkflowManage:
|
|||
self.run_chain_async(current_node, node_result_future, language)
|
||||
return tools.to_stream_response_simple(self.await_result())
|
||||
|
||||
def get_body(self):
|
||||
return self.params
|
||||
|
||||
def is_run(self, timeout=0.5):
|
||||
future_list_len = len(self.future_list)
|
||||
try:
|
||||
|
|
@ -420,9 +421,7 @@ class WorkflowManage:
|
|||
'message_tokens' in row and row.get('message_tokens') is not None])
|
||||
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
|
||||
'answer_tokens' in row and row.get('answer_tokens') is not None])
|
||||
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
|
||||
self.answer,
|
||||
self)
|
||||
self.work_flow_post_handler.handler(self)
|
||||
yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
|
||||
self.params['chat_record_id'],
|
||||
'',
|
||||
|
|
|
|||
|
|
@ -27,15 +27,19 @@ class ChatInfo:
|
|||
chat_user_type: str,
|
||||
knowledge_id_list: List[str],
|
||||
exclude_document_id_list: list[str],
|
||||
application_id: str,
|
||||
application: Application,
|
||||
work_flow_version: WorkFlowVersion = None):
|
||||
work_flow_version: WorkFlowVersion = None,
|
||||
debug=False):
|
||||
"""
|
||||
:param chat_id: 对话id
|
||||
:param chat_id: 对话id
|
||||
:param chat_user_id 对话用户id
|
||||
:param chat_user_type 对话用户类型
|
||||
:param knowledge_id_list: 知识库列表
|
||||
:param exclude_document_id_list: 排除的文档
|
||||
:param application_id 应用id
|
||||
:param application: 应用信息
|
||||
:param debug 是否是调试
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.chat_user_id = chat_user_id
|
||||
|
|
@ -43,8 +47,10 @@ class ChatInfo:
|
|||
self.application = application
|
||||
self.knowledge_id_list = knowledge_id_list
|
||||
self.exclude_document_id_list = exclude_document_id_list
|
||||
self.application_id = application_id
|
||||
self.chat_record_list: List[ChatRecord] = []
|
||||
self.work_flow_version = work_flow_version
|
||||
self.debug = debug
|
||||
|
||||
@staticmethod
|
||||
def get_no_references_setting(knowledge_setting, model_setting):
|
||||
|
|
@ -116,17 +122,17 @@ class ChatInfo:
|
|||
if record.id == chat_record.id:
|
||||
self.chat_record_list[index] = chat_record
|
||||
is_save = False
|
||||
break
|
||||
if is_save:
|
||||
self.chat_record_list.append(chat_record)
|
||||
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),
|
||||
timeout=60 * 30)
|
||||
if self.application.id is not None:
|
||||
Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text[0:1024],
|
||||
chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type).save()
|
||||
else:
|
||||
QuerySet(Chat).filter(id=self.chat_id).update(update_time=datetime.now())
|
||||
if not self.debug:
|
||||
if not QuerySet(Chat).filter(id=self.chat_id).exists():
|
||||
Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text[0:1024],
|
||||
chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type).save()
|
||||
else:
|
||||
QuerySet(Chat).filter(id=self.chat_id).update(update_time=datetime.now())
|
||||
# 插入会话记录
|
||||
chat_record.save()
|
||||
chat_record.save()
|
||||
|
||||
def set_cache(self):
|
||||
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),
|
||||
|
|
|
|||
|
|
@ -52,4 +52,7 @@ urlpatterns = [
|
|||
path(
|
||||
'workspace/<str:workspace_id>/application/<str:application_id>/work_flow_version/<str:work_flow_version_id>',
|
||||
views.ApplicationVersionView.Operate.as_view()),
|
||||
path('workspace/<str:workspace_id>/application/<str:application_id>/open', views.OpenView.as_view()),
|
||||
path('chat_message/<str:chat_id>', views.ChatView.as_view()),
|
||||
|
||||
]
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@
|
|||
@date:2025/6/10 11:00
|
||||
@desc:
|
||||
"""
|
||||
import uuid
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework.request import Request
|
||||
|
|
@ -13,7 +15,11 @@ from rest_framework.views import APIView
|
|||
|
||||
from application.api.application_chat import ApplicationChatQueryAPI, ApplicationChatQueryPageAPI, \
|
||||
ApplicationChatExportAPI
|
||||
from application.models import ChatUserType
|
||||
from application.serializers.application_chat import ApplicationChatQuerySerializers
|
||||
from chat.api.chat_api import ChatAPI
|
||||
from chat.api.chat_authentication_api import ChatOpenAPI
|
||||
from chat.serializers.chat import OpenChatSerializers, ChatSerializers, DebugChatSerializers
|
||||
from common.auth import TokenAuth
|
||||
from common.auth.authentication import has_permissions
|
||||
from common.constants.permission_constants import PermissionConstants, RoleConstants
|
||||
|
|
@ -81,3 +87,39 @@ class ApplicationChat(APIView):
|
|||
return ApplicationChatQuerySerializers(
|
||||
data={**query_params_to_single_dict(request.query_params), 'application_id': application_id,
|
||||
}).export(request.data)
|
||||
|
||||
|
||||
class OpenView(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@extend_schema(
|
||||
methods=['GET'],
|
||||
description=_("Get a temporary session id based on the application id"),
|
||||
summary=_("Get a temporary session id based on the application id"),
|
||||
operation_id=_("Get a temporary session id based on the application id"), # type: ignore
|
||||
parameters=ChatOpenAPI.get_parameters(),
|
||||
responses=None,
|
||||
tags=[_('Application')] # type: ignore
|
||||
)
|
||||
def get(self, request: Request, workspace_id: str, application_id: str):
|
||||
return result.success(OpenChatSerializers(
|
||||
data={'workspace_id': workspace_id, 'application_id': application_id,
|
||||
'chat_user_id': str(uuid.uuid1()), 'chat_user_type': ChatUserType.ANONYMOUS_USER,
|
||||
'debug': True}).open())
|
||||
|
||||
|
||||
class ChatView(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@extend_schema(
|
||||
methods=['POST'],
|
||||
description=_("dialogue"),
|
||||
summary=_("dialogue"),
|
||||
operation_id=_("dialogue"), # type: ignore
|
||||
request=ChatAPI.get_request(),
|
||||
parameters=ChatAPI.get_parameters(),
|
||||
responses=None,
|
||||
tags=[_('Application')] # type: ignore
|
||||
)
|
||||
def post(self, request: Request, chat_id: str):
|
||||
return DebugChatSerializers(data={'chat_id': chat_id}).chat(request.data)
|
||||
|
|
|
|||
|
|
@ -94,12 +94,27 @@ def get_post_handler(chat_info: ChatInfo):
|
|||
return PostHandler()
|
||||
|
||||
|
||||
class DebugChatSerializers(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
|
||||
|
||||
def chat(self, instance: dict, base_to_response: BaseToResponse = SystemToResponse()):
|
||||
self.is_valid(raise_exception=True)
|
||||
chat_id = self.data.get('chat_id')
|
||||
chat_info: ChatInfo = ChatInfo.get_cache(chat_id)
|
||||
return ChatSerializers(data={
|
||||
'chat_id': chat_id, "chat_user_id": chat_info.chat_user_id,
|
||||
"chat_user_type": chat_info.chat_user_type,
|
||||
"application_id": chat_info.application.id, "debug": True
|
||||
}).chat(instance, base_to_response)
|
||||
|
||||
|
||||
class ChatSerializers(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
|
||||
chat_user_id = serializers.CharField(required=True, label=_("Client id"))
|
||||
chat_user_type = serializers.CharField(required=True, label=_("Client Type"))
|
||||
application_id = serializers.UUIDField(required=True, allow_null=True,
|
||||
label=_("Application ID"))
|
||||
debug = serializers.BooleanField(required=False, label=_("Debug"))
|
||||
|
||||
def is_valid_application_workflow(self, *, raise_exception=False):
|
||||
self.is_valid_intraday_access_num()
|
||||
|
|
@ -158,6 +173,7 @@ class ChatSerializers(serializers.Serializer):
|
|||
.append_step(BaseGenerateHumanMessageStep)
|
||||
.append_step(BaseChatStep)
|
||||
.add_base_to_response(base_to_response)
|
||||
.add_debug(self.data.get('debug', False))
|
||||
.build())
|
||||
exclude_paragraph_id_list = []
|
||||
# 相同问题是否需要排除已经查询到的段落
|
||||
|
|
@ -189,18 +205,18 @@ class ChatSerializers(serializers.Serializer):
|
|||
return chat_record
|
||||
|
||||
def chat_work_flow(self, chat_info: ChatInfo, instance: dict, base_to_response):
|
||||
message = self.data.get('message')
|
||||
re_chat = self.data.get('re_chat')
|
||||
stream = self.data.get('stream')
|
||||
chat_user_id = instance.get('chat_user_id')
|
||||
chat_user_type = instance.get('chat_user_type')
|
||||
form_data = self.data.get('form_data')
|
||||
image_list = self.data.get('image_list')
|
||||
document_list = self.data.get('document_list')
|
||||
audio_list = self.data.get('audio_list')
|
||||
other_list = self.data.get('other_list')
|
||||
user_id = chat_info.application.user_id
|
||||
chat_record_id = self.data.get('chat_record_id')
|
||||
message = instance.get('message')
|
||||
re_chat = instance.get('re_chat')
|
||||
stream = instance.get('stream')
|
||||
chat_user_id = self.data.get("chat_user_id")
|
||||
chat_user_type = self.data.get('chat_user_type')
|
||||
form_data = instance.get('form_data')
|
||||
image_list = instance.get('image_list')
|
||||
document_list = instance.get('document_list')
|
||||
audio_list = instance.get('audio_list')
|
||||
other_list = instance.get('other_list')
|
||||
workspace_id = chat_info.application.workspace_id
|
||||
chat_record_id = instance.get('chat_record_id')
|
||||
chat_record = None
|
||||
history_chat_record = chat_info.chat_record_list
|
||||
if chat_record_id is not None:
|
||||
|
|
@ -214,8 +230,9 @@ class ChatSerializers(serializers.Serializer):
|
|||
're_chat': re_chat,
|
||||
'chat_user_id': chat_user_id,
|
||||
'chat_user_type': chat_user_type,
|
||||
'user_id': user_id},
|
||||
WorkFlowPostHandler(chat_info, chat_user_id, chat_user_type),
|
||||
'workspace_id': workspace_id,
|
||||
'debug': self.data.get('debug', False)},
|
||||
WorkFlowPostHandler(chat_info),
|
||||
base_to_response, form_data, image_list, document_list, audio_list,
|
||||
other_list,
|
||||
self.data.get('runtime_node_id'),
|
||||
|
|
@ -229,7 +246,7 @@ class ChatSerializers(serializers.Serializer):
|
|||
chat_info = self.get_chat_info()
|
||||
self.is_valid_chat_id(chat_info)
|
||||
if chat_info.application.type == ApplicationTypeChoices.SIMPLE:
|
||||
self.is_valid_application_simple(raise_exception=True, chat_info=chat_info),
|
||||
self.is_valid_application_simple(raise_exception=True, chat_info=chat_info)
|
||||
return self.chat_simple(chat_info, instance, base_to_response)
|
||||
else:
|
||||
self.is_valid_application_workflow(raise_exception=True)
|
||||
|
|
@ -295,6 +312,7 @@ class OpenChatSerializers(serializers.Serializer):
|
|||
application_id = serializers.UUIDField(required=True)
|
||||
chat_user_id = serializers.CharField(required=True, label=_("Client id"))
|
||||
chat_user_type = serializers.CharField(required=True, label=_("Client Type"))
|
||||
debug = serializers.BooleanField(required=True, label=_("Debug"))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
|
|
@ -317,6 +335,7 @@ class OpenChatSerializers(serializers.Serializer):
|
|||
application_id = self.data.get('application_id')
|
||||
chat_user_id = self.data.get("chat_user_id")
|
||||
chat_user_type = self.data.get("chat_user_type")
|
||||
debug = self.data.get("debug")
|
||||
chat_id = str(uuid.uuid7())
|
||||
work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application_id).order_by(
|
||||
'-create_time')[0:1].first()
|
||||
|
|
@ -326,13 +345,15 @@ class OpenChatSerializers(serializers.Serializer):
|
|||
"The application has not been published. Please use it after publishing."))
|
||||
ChatInfo(chat_id, chat_user_id, chat_user_type, [],
|
||||
[],
|
||||
application, work_flow_version).set_cache()
|
||||
application_id,
|
||||
application, work_flow_version, debug).set_cache()
|
||||
return chat_id
|
||||
|
||||
def open_simple(self, application):
|
||||
application_id = self.data.get('application_id')
|
||||
chat_user_id = self.data.get("chat_user_id")
|
||||
chat_user_type = self.data.get("chat_user_type")
|
||||
debug = self.data.get("debug")
|
||||
knowledge_id_list = [str(row.dataset_id) for row in
|
||||
QuerySet(ApplicationKnowledgeMapping).filter(
|
||||
application_id=application_id)]
|
||||
|
|
@ -342,5 +363,6 @@ class OpenChatSerializers(serializers.Serializer):
|
|||
QuerySet(Document).filter(
|
||||
knowledge_id__in=knowledge_id_list,
|
||||
is_active=False)],
|
||||
application).set_cache()
|
||||
application_id,
|
||||
application, debug=debug).set_cache()
|
||||
return chat_id
|
||||
|
|
|
|||
Loading…
Reference in New Issue