mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
261 lines
11 KiB
Python
261 lines
11 KiB
Python
# coding=utf-8
|
||
"""
|
||
@project: maxkb
|
||
@Author:虎
|
||
@file: i_step_node.py
|
||
@date:2024/6/3 14:57
|
||
@desc:
|
||
"""
|
||
import time
|
||
import uuid
|
||
from abc import abstractmethod
|
||
from hashlib import sha1
|
||
from typing import Type, Dict, List
|
||
|
||
from django.core import cache
|
||
from django.db.models import QuerySet
|
||
from rest_framework import serializers
|
||
from rest_framework.exceptions import ValidationError, ErrorDetail
|
||
|
||
from application.flow.common import Answer, NodeChunk
|
||
from application.models import ApplicationChatUserStats
|
||
from application.models import ChatRecord, ChatUserType
|
||
from common.field.common import InstanceField
|
||
|
||
# Alias of the ``django.core.cache`` module imported above (the module object itself,
# not a configured cache backend instance). Not referenced elsewhere in the visible
# part of this file; NOTE(review): other modules may import it from here — confirm.
chat_cache = cache
|
||
|
||
|
||
def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
    """
    Generator that merges a node's output into the node and workflow contexts.

    Because this function contains ``yield``, nothing happens until the returned
    generator is iterated.

    :param step_variable: node-level variables copied into ``node.context`` (may be None).
    :param global_variable: workflow-level variables copied into ``workflow.context`` (may be None).
    :param node: the node whose context is updated; must already hold ``start_time``.
    :param workflow: the workflow whose context is updated; ``is_result`` decides
                     whether the node's ``answer`` is emitted.
    :return: a generator yielding the node's ``answer`` at most once.
    """
    if step_variable is not None:
        for name, value in step_variable.items():
            node.context[name] = value
        # is_result is evaluated first, before checking for an 'answer' key.
        emits_result = workflow.is_result(node, NodeResult(step_variable, global_variable))
        if emits_result and 'answer' in step_variable:
            answer = step_variable['answer']
            yield answer
            # Only reached once the consumer resumes the generator after the yield.
            node.answer_text = answer
    if global_variable is not None:
        for name, value in global_variable.items():
            workflow.context[name] = value
    node.context['run_time'] = time.time() - node.context['start_time']
|
||
|
||
|
||
def is_interrupt(node, step_variable: Dict, global_variable: Dict):
    """
    Decide whether workflow execution should pause at ``node``.

    Only a ``'form-node'`` that has not been submitted yet (``is_submit`` missing or
    falsy in its context) interrupts execution. ``step_variable`` and
    ``global_variable`` are accepted for interface compatibility and are unused.
    """
    if node.type != 'form-node':
        return False
    submitted = node.context.get('is_submit', False)
    return not submitted
|
||
|
||
|
||
class WorkFlowPostHandler:
    """
    Post-processing executed when a workflow run finishes.

    It builds (or updates) the ``ChatRecord`` for the run and hands it to ``chat_info``.
    For non-debug runs by anonymous or chat users, it also bumps that user's
    ``ApplicationChatUserStats`` access counters.
    """

    def __init__(self, chat_info):
        self.chat_info = chat_info

    def handler(self, workflow):
        """
        Record the outcome of ``workflow`` into ``self.chat_info``.

        Token counts are summed over all runtime details, skipping missing or None values.
        Answer contents are joined with blank lines.

        The handler is single-use: ``self.chat_info`` is set to None at the end, so a
        second call fails with AttributeError.

        :param workflow: the finished workflow (provides body, runtime details,
                         answer list, optional existing ``chat_record`` and ``context``).
        """
        workflow_body = workflow.get_body()
        question = workflow_body.get('question')
        chat_record_id = workflow_body.get('chat_record_id')
        chat_id = workflow_body.get('chat_id')
        details = workflow.get_runtime_details()
        message_tokens = sum(row.get('message_tokens') for row in details.values()
                             if row.get('message_tokens') is not None)
        answer_tokens = sum(row.get('answer_tokens') for row in details.values()
                            if row.get('answer_tokens') is not None)
        answer_text_list = workflow.get_answer_text_list()
        # Each entry is a list of answer dicts; contents are joined with blank lines,
        # both within an entry and between entries.
        answer_text = '\n\n'.join(
            '\n\n'.join(a.get('content') for a in answer) for answer in answer_text_list)
        run_time = time.time() - workflow.context['start_time']

        chat_record = workflow.chat_record
        if chat_record is not None:
            # Update the record already attached to the workflow in place.
            chat_record.answer_text = answer_text
            chat_record.details = details
            chat_record.message_tokens = message_tokens
            chat_record.answer_tokens = answer_tokens
            chat_record.answer_text_list = answer_text_list
            chat_record.run_time = run_time
        else:
            chat_record = ChatRecord(id=chat_record_id,
                                     chat_id=chat_id,
                                     problem_text=question,
                                     answer_text=answer_text,
                                     details=details,
                                     message_tokens=message_tokens,
                                     answer_tokens=answer_tokens,
                                     answer_text_list=answer_text_list,
                                     run_time=run_time,
                                     index=0)

        self.chat_info.append_chat_record(chat_record)
        self.chat_info.set_cache()

        chat_user_type = workflow_body.get('chat_user_type')
        # Short-circuit keeps ChatUserType untouched for debug runs.
        if not self.chat_info.debug and chat_user_type in (ChatUserType.ANONYMOUS_USER.value,
                                                           ChatUserType.CHAT_USER.value):
            user_stats = (QuerySet(ApplicationChatUserStats)
                          .filter(chat_user_id=workflow_body.get('chat_user_id'),
                                  chat_user_type=chat_user_type,
                                  application_id=self.chat_info.application_id)
                          .first())
            if user_stats is not None:
                # NOTE(review): read-modify-write; concurrent requests for the same user
                # can lose increments. Consider an F()-expression update if that matters.
                user_stats.access_num = user_stats.access_num + 1
                user_stats.intraday_access_num = user_stats.intraday_access_num + 1
                user_stats.save()
        # Release chat_info; this handler instance is not reusable afterwards.
        self.chat_info = None
|
||
|
||
|
||
class NodeResult:
    """
    Result of executing a node.

    Holds the node-level and workflow-level variables produced by the node, plus two
    pluggable strategies. One writes those variables into context; the other decides
    whether execution should be interrupted.
    """

    def __init__(self, node_variable: Dict, workflow_variable: Dict,
                 _write_context=write_context, _is_interrupt=is_interrupt):
        # Strategies default to the module-level write_context / is_interrupt functions.
        self._write_context = _write_context
        self.node_variable = node_variable
        self.workflow_variable = workflow_variable
        self._is_interrupt = _is_interrupt

    def write_context(self, node, workflow):
        """
        Delegate to the configured write-context strategy.

        :return: whatever the strategy returns (the default ``write_context`` returns a generator).
        """
        return self._write_context(self.node_variable, self.workflow_variable, node, workflow)

    def is_assertion_result(self):
        """
        Whether the node variables carry a ``branch_id`` key.

        ``node_variable`` may be None (the default write-context strategy accepts a None
        step variable). That case yields False instead of raising TypeError.
        """
        return self.node_variable is not None and 'branch_id' in self.node_variable

    def is_interrupt_exec(self, current_node):
        """
        Whether execution should be interrupted.

        @param current_node: the node that produced this result
        @return: the configured interrupt strategy's verdict
        """
        return self._is_interrupt(current_node, self.node_variable, self.workflow_variable)
|
||
|
||
|
||
class ReferenceAddressSerializer(serializers.Serializer):
    """Validates a reference to fields of another node: a node id plus a list of field names."""

    # Id of the referenced node.
    node_id = serializers.CharField(label="节点id", required=True)

    # Non-empty-string field names inside the referenced node.
    fields = serializers.ListField(
        label="节点字段数组",
        required=True,
        child=serializers.CharField(label="节点字段", required=True),
    )
|
||
|
||
|
||
class FlowParamsSerializer(serializers.Serializer):
    """Validates the workflow-level parameters passed to every node."""

    # Chat history: previous question/answer records of this conversation.
    history_chat_record = serializers.ListField(
        label="历史对答",
        child=InstanceField(required=True, model_type=ChatRecord),
    )

    # The user's question for this turn.
    question = serializers.CharField(label="用户问题", required=True)

    # Conversation and record identifiers.
    chat_id = serializers.CharField(label="对话id", required=True)
    chat_record_id = serializers.CharField(label="对话记录id", required=True)

    # Whether output is streamed.
    stream = serializers.BooleanField(label="流式输出", required=True)

    # Optional identity of the chatting user.
    chat_user_id = serializers.CharField(label="对话用户id", required=False)
    chat_user_type = serializers.CharField(label="对话用户类型", required=False)

    # Owning workspace and application.
    workspace_id = serializers.CharField(label="工作空间id", required=True)
    application_id = serializers.CharField(label="应用id", required=True)

    # "Try another answer" flag.
    re_chat = serializers.BooleanField(label="换个答案", required=True)

    # Debug-run flag.
    debug = serializers.BooleanField(label="是否debug", required=True)
|
||
|
||
|
||
class INode:
    """
    Base class for a workflow node.

    Concrete nodes implement ``get_node_params_serializer_class``, ``save_context`` and
    ``execute``. This base class validates parameters and times the run. It also derives
    a deterministic ``runtime_node_id`` and records answer/error state in the node.

    NOTE(review): INode does not derive from ``abc.ABC``, so the ``@abstractmethod``
    markers below are not enforced at instantiation time; they only document intent.
    """
    # View type attached to the Answer built by get_answer_list().
    view_type = 'many_view'

    @abstractmethod
    def save_context(self, details, workflow_manage):
        """Populate this node's context from ``details`` (presumably a recorded run — confirm); subclass hook."""
        pass

    def get_answer_list(self) -> List[Answer] | None:
        """
        Build this node's answer list.

        :return: None when the node has no answer text yet. Otherwise a one-element list
                 holding an ``Answer`` for this node. Reasoning content is included only
                 when ``model_setting.reasoning_content_enable`` is truthy.
        """
        if self.answer_text is None:
            return None
        # model_setting may be missing or explicitly None; treat both as "no settings".
        model_setting = self.context.get('model_setting') or {}
        reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
        reasoning_content = self.context.get('reasoning_content', '') if reasoning_content_enable else ''
        return [
            Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {},
                   self.runtime_node_id, reasoning_content)]

    def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
                 get_node_params=lambda node: node.properties.get('node_data'), salt=None):
        """
        :param node: the workflow graph node (provides ``id`` and ``properties``).
        :param workflow_params: workflow-level parameters (must contain ``chat_record_id``
                                for get_answer_list).
        :param workflow_manage: the managing workflow object.
        :param up_node_id_list: ids of upstream nodes; defaults to an empty list.
        :param get_node_params: extracts this node's parameters from ``node``.
        :param salt: optional suffix appended to ``runtime_node_id`` as ``__<salt>``.
        """
        # Per-node execution state.
        self.status = 200
        self.err_message = ''
        self.node = node
        self.node_params = get_node_params(node)
        self.workflow_params = workflow_params
        self.workflow_manage = workflow_manage
        self.node_params_serializer = None
        self.flow_params_serializer = None
        # Context of the current step, used to store this step's information.
        self.context = {}
        self.answer_text = None
        self.id = node.id
        if up_node_id_list is None:
            up_node_id_list = []
        self.up_node_id_list = up_node_id_list
        self.node_chunk = NodeChunk()
        # Deterministic id derived from the sorted upstream ids followed by this node's id.
        seed = "".join([*sorted(up_node_id_list), node.id])
        node_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, seed)
        digest = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(node_uuid), "utf-8")).hexdigest()
        self.runtime_node_id = digest + ("__" + str(salt) if salt is not None else '')

    def valid_args(self, node_params, flow_params):
        """
        Validate flow params (when a serializer class and params are present) and node params.

        :raises ValidationError: if a serializer rejects its data, or if the node's
                                 ``status`` property is not 200 (node unavailable).
        """
        flow_params_serializer_class = self.get_flow_params_serializer_class()
        node_params_serializer_class = self.get_node_params_serializer_class()
        if flow_params_serializer_class is not None and flow_params is not None:
            self.flow_params_serializer = flow_params_serializer_class(data=flow_params)
            self.flow_params_serializer.is_valid(raise_exception=True)
        if node_params_serializer_class is not None:
            self.node_params_serializer = node_params_serializer_class(data=node_params)
            self.node_params_serializer.is_valid(raise_exception=True)
        if self.node.properties.get('status', 200) != 200:
            raise ValidationError(ErrorDetail(f'节点{self.node.properties.get("stepName")} 不可用'))

    def get_reference_field(self, fields: List[str]):
        """Look up a nested value in this node's context by a path of keys."""
        return self.get_field(self.context, fields)

    @staticmethod
    def get_field(obj, fields: List[str]):
        """
        Walk ``obj`` by successive ``.get(field)`` lookups.

        :return: the final value, None as soon as any step is missing or None, or
                 ``obj`` itself when ``fields`` is empty.
        """
        for field in fields:
            value = obj.get(field)
            if value is None:
                return None
            else:
                obj = value
        return obj

    @abstractmethod
    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Serializer class used to validate this node's parameters; subclass hook."""
        pass

    def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Serializer class used to validate the workflow-level parameters."""
        return FlowParamsSerializer

    def get_write_error_context(self, e):
        """
        Mark this node as failed with exception ``e``.

        Sets status 500, stores ``str(e)`` as both the answer text and the error message,
        and records ``run_time``. If no ``start_time`` has been recorded yet (e.g. the
        error came from ``valid_args`` before ``run()``), ``run_time`` is recorded as 0.
        This avoids raising KeyError and masking the original error.

        :return: a no-op ``write_error_context(answer, status=200)`` callback.
        """
        self.status = 500
        self.answer_text = str(e)
        self.err_message = str(e)
        start_time = self.context.get('start_time')
        self.context['run_time'] = 0 if start_time is None else time.time() - start_time

        def write_error_context(answer, status=200):
            pass

        return write_error_context

    def run(self) -> NodeResult:
        """
        Execute the node, storing ``start_time`` and ``run_time`` in the context.

        :return: execution result (whatever ``_run`` returns)
        """
        start_time = time.time()
        self.context['start_time'] = start_time
        result = self._run()
        self.context['run_time'] = time.time() - start_time
        return result

    def _run(self):
        """Invoke ``execute`` with no arguments and return its result."""
        result = self.execute()
        return result

    def execute(self, **kwargs) -> NodeResult:
        """Node-specific logic; the base implementation returns None. Subclass hook."""
        pass

    def get_details(self, index: int, **kwargs):
        """
        Runtime details of this node.

        :return: step details (the base implementation returns an empty dict)
        """
        return {}
|