mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
fix: 修复工作流节点输出等问题 (#1716)
This commit is contained in:
parent
bce2558951
commit
b8aa4756c5
|
|
@ -9,6 +9,7 @@
|
|||
import time
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from hashlib import sha1
|
||||
from typing import Type, Dict, List
|
||||
|
||||
from django.core import cache
|
||||
|
|
@ -131,6 +132,7 @@ class FlowParamsSerializer(serializers.Serializer):
|
|||
|
||||
|
||||
class INode:
|
||||
view_type = 'many_view'
|
||||
|
||||
@abstractmethod
|
||||
def save_context(self, details, workflow_manage):
|
||||
|
|
@ -139,7 +141,7 @@ class INode:
|
|||
def get_answer_text(self):
|
||||
return self.answer_text
|
||||
|
||||
def __init__(self, node, workflow_params, workflow_manage, runtime_node_id=None):
|
||||
def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None):
|
||||
# 当前步骤上下文,用于存储当前步骤信息
|
||||
self.status = 200
|
||||
self.err_message = ''
|
||||
|
|
@ -152,10 +154,13 @@ class INode:
|
|||
self.context = {}
|
||||
self.answer_text = None
|
||||
self.id = node.id
|
||||
if runtime_node_id is None:
|
||||
self.runtime_node_id = str(uuid.uuid1())
|
||||
else:
|
||||
self.runtime_node_id = runtime_node_id
|
||||
if up_node_id_list is None:
|
||||
up_node_id_list = []
|
||||
self.up_node_id_list = up_node_id_list
|
||||
self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
|
||||
"".join([*sorted(up_node_id_list),
|
||||
node.id]))),
|
||||
"utf-8")).hexdigest()
|
||||
|
||||
def valid_args(self, node_params, flow_params):
|
||||
flow_params_serializer_class = self.get_flow_params_serializer_class()
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ class FormNodeParamsSerializer(serializers.Serializer):
|
|||
|
||||
class IFormNode(INode):
|
||||
type = 'form-node'
|
||||
view_type = 'single_view'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return FormNodeParamsSerializer
|
||||
|
|
|
|||
|
|
@ -34,6 +34,8 @@ class BaseFormNode(IFormNode):
|
|||
self.context['form_field_list'] = details.get('form_field_list')
|
||||
self.context['run_time'] = details.get('run_time')
|
||||
self.context['start_time'] = details.get('start_time')
|
||||
self.context['form_data'] = details.get('form_data')
|
||||
self.context['is_submit'] = details.get('is_submit')
|
||||
self.answer_text = details.get('result')
|
||||
|
||||
def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult:
|
||||
|
|
@ -77,6 +79,7 @@ class BaseFormNode(IFormNode):
|
|||
"form_field_list": self.context.get('form_field_list'),
|
||||
'form_data': self.context.get('form_data'),
|
||||
'start_time': self.context.get('start_time'),
|
||||
'is_submit': self.context.get('is_submit'),
|
||||
'run_time': self.context.get('run_time'),
|
||||
'type': self.node.type,
|
||||
'status': self.status,
|
||||
|
|
|
|||
|
|
@ -52,7 +52,8 @@ class Node:
|
|||
self.__setattr__(keyword, kwargs.get(keyword))
|
||||
|
||||
|
||||
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node', 'image-understand-node']
|
||||
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
|
||||
'image-understand-node']
|
||||
|
||||
|
||||
class Flow:
|
||||
|
|
@ -229,7 +230,9 @@ class NodeChunk:
|
|||
def add_chunk(self, chunk):
|
||||
self.chunk_list.append(chunk)
|
||||
|
||||
def end(self):
|
||||
def end(self, chunk=None):
|
||||
if chunk is not None:
|
||||
self.add_chunk(chunk)
|
||||
self.status = 200
|
||||
|
||||
def is_end(self):
|
||||
|
|
@ -266,6 +269,7 @@ class WorkflowManage:
|
|||
self.status = 0
|
||||
self.base_to_response = base_to_response
|
||||
self.chat_record = chat_record
|
||||
self.await_future_map = {}
|
||||
if start_node_id is not None:
|
||||
self.load_node(chat_record, start_node_id, start_node_data)
|
||||
else:
|
||||
|
|
@ -286,14 +290,16 @@ class WorkflowManage:
|
|||
for node_details in sorted(chat_record.details.values(), key=lambda d: d.get('index')):
|
||||
node_id = node_details.get('node_id')
|
||||
if node_details.get('runtime_node_id') == start_node_id:
|
||||
self.start_node = self.get_node_cls_by_id(node_id, node_details.get('runtime_node_id'))
|
||||
self.start_node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'))
|
||||
self.start_node.valid_args(self.start_node.node_params, self.start_node.workflow_params)
|
||||
self.start_node.save_context(node_details, self)
|
||||
node_result = NodeResult({**start_node_data, 'form_data': start_node_data, 'is_submit': True}, {})
|
||||
self.start_node_result_future = NodeResultFuture(node_result, None)
|
||||
return
|
||||
self.node_context.append(self.start_node)
|
||||
continue
|
||||
|
||||
node_id = node_details.get('node_id')
|
||||
node = self.get_node_cls_by_id(node_id, node_details.get('runtime_node_id'))
|
||||
node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'))
|
||||
node.valid_args(node.node_params, node.workflow_params)
|
||||
node.save_context(node_details, self)
|
||||
self.node_context.append(node)
|
||||
|
|
@ -345,17 +351,22 @@ class WorkflowManage:
|
|||
if chunk is None:
|
||||
break
|
||||
yield chunk
|
||||
yield self.get_chunk_content('', True)
|
||||
finally:
|
||||
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
|
||||
self.answer,
|
||||
self)
|
||||
yield self.get_chunk_content('', True)
|
||||
|
||||
def run_chain_async(self, current_node, node_result_future):
|
||||
future = executor.submit(self.run_chain, current_node, node_result_future)
|
||||
return future
|
||||
|
||||
def set_await_map(self, node_run_list):
|
||||
sorted_node_run_list = sorted(node_run_list, key=lambda n: n.get('node').node.y)
|
||||
for index in range(len(sorted_node_run_list)):
|
||||
self.await_future_map[sorted_node_run_list[index].get('node').runtime_node_id] = [
|
||||
sorted_node_run_list[i].get('future')
|
||||
for i in range(index)]
|
||||
|
||||
def run_chain(self, current_node, node_result_future=None):
|
||||
if current_node is None:
|
||||
start_node = self.get_start_node()
|
||||
|
|
@ -365,6 +376,9 @@ class WorkflowManage:
|
|||
try:
|
||||
is_stream = self.params.get('stream', True)
|
||||
# 处理节点响应
|
||||
await_future_list = self.await_future_map.get(current_node.runtime_node_id, None)
|
||||
if await_future_list is not None:
|
||||
[f.result() for f in await_future_list]
|
||||
result = self.hand_event_node_result(current_node,
|
||||
node_result_future) if is_stream else self.hand_node_result(
|
||||
current_node, node_result_future)
|
||||
|
|
@ -373,11 +387,9 @@ class WorkflowManage:
|
|||
return
|
||||
node_list = self.get_next_node_list(current_node, result)
|
||||
# 获取到可执行的子节点
|
||||
result_list = []
|
||||
for node in node_list:
|
||||
result = self.run_chain_async(node, None)
|
||||
result_list.append(result)
|
||||
[r.result() for r in result_list]
|
||||
result_list = [{'node': node, 'future': self.run_chain_async(node, None)} for node in node_list]
|
||||
self.set_await_map(result_list)
|
||||
[r.get('future').result() for r in result_list]
|
||||
if self.status == 0:
|
||||
self.status = 200
|
||||
except Exception as e:
|
||||
|
|
@ -401,6 +413,14 @@ class WorkflowManage:
|
|||
current_node.get_write_error_context(e)
|
||||
self.answer += str(e)
|
||||
|
||||
def append_node(self, current_node):
|
||||
for index in range(len(self.node_context)):
|
||||
n = self.node_context[index]
|
||||
if current_node.id == n.node.id and current_node.runtime_node_id == n.runtime_node_id:
|
||||
self.node_context[index] = current_node
|
||||
return
|
||||
self.node_context.append(current_node)
|
||||
|
||||
def hand_event_node_result(self, current_node, node_result_future):
|
||||
node_chunk = NodeChunk()
|
||||
try:
|
||||
|
|
@ -412,22 +432,35 @@ class WorkflowManage:
|
|||
for r in result:
|
||||
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
|
||||
self.params['chat_record_id'],
|
||||
r, False, 0, 0)
|
||||
current_node.id,
|
||||
current_node.up_node_id_list,
|
||||
r, False, 0, 0,
|
||||
{'node_type': current_node.type,
|
||||
'view_type': current_node.view_type})
|
||||
node_chunk.add_chunk(chunk)
|
||||
node_chunk.end()
|
||||
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
|
||||
self.params['chat_record_id'],
|
||||
current_node.id,
|
||||
current_node.up_node_id_list,
|
||||
'', False, 0, 0, {'node_is_end': True,
|
||||
'node_type': current_node.type,
|
||||
'view_type': current_node.view_type})
|
||||
node_chunk.end(chunk)
|
||||
else:
|
||||
list(result)
|
||||
# 添加节点
|
||||
self.node_context.append(current_node)
|
||||
self.append_node(current_node)
|
||||
return current_result
|
||||
except Exception as e:
|
||||
# 添加节点
|
||||
self.node_context.append(current_node)
|
||||
self.append_node(current_node)
|
||||
traceback.print_exc()
|
||||
self.answer += str(e)
|
||||
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
|
||||
self.params['chat_record_id'],
|
||||
str(e), False, 0, 0)
|
||||
current_node.id,
|
||||
current_node.up_node_id_list,
|
||||
str(e), False, 0, 0, {'node_is_end': True})
|
||||
if not self.node_chunk_manage.contains(node_chunk):
|
||||
self.node_chunk_manage.add_node_chunk(node_chunk)
|
||||
node_chunk.add_chunk(chunk)
|
||||
|
|
@ -492,32 +525,36 @@ class WorkflowManage:
|
|||
continue
|
||||
details = node.get_details(index)
|
||||
details['node_id'] = node.id
|
||||
details['up_node_id_list'] = node.up_node_id_list
|
||||
details['runtime_node_id'] = node.runtime_node_id
|
||||
details_result[node.runtime_node_id] = details
|
||||
return details_result
|
||||
|
||||
def get_answer_text_list(self):
|
||||
answer_text_list = []
|
||||
result = []
|
||||
next_node_id_list = []
|
||||
if self.start_node is not None:
|
||||
next_node_id_list = [edge.targetNodeId for edge in self.flow.edges if
|
||||
edge.sourceNodeId == self.start_node.id]
|
||||
for index in range(len(self.node_context)):
|
||||
node = self.node_context[index]
|
||||
up_node = None
|
||||
if index > 0:
|
||||
up_node = self.node_context[index - 1]
|
||||
answer_text = node.get_answer_text()
|
||||
if answer_text is not None:
|
||||
if self.chat_record is not None and self.chat_record.details is not None:
|
||||
details = self.chat_record.details.get(node.runtime_node_id)
|
||||
if details is not None and self.start_node.runtime_node_id != node.runtime_node_id:
|
||||
continue
|
||||
answer_text_list.append(
|
||||
{'content': answer_text, 'type': 'form' if node.type == 'form-node' else 'md'})
|
||||
result = []
|
||||
for index in range(len(answer_text_list)):
|
||||
answer = answer_text_list[index]
|
||||
if index == 0:
|
||||
result.append(answer.get('content'))
|
||||
continue
|
||||
if answer.get('type') != answer_text_list[index - 1].get('type'):
|
||||
result.append(answer.get('content'))
|
||||
else:
|
||||
result[-1] += answer.get('content')
|
||||
if up_node is None or node.view_type == 'single_view' or (
|
||||
node.view_type == 'many_view' and up_node.view_type == 'single_view'):
|
||||
result.append(node.get_answer_text())
|
||||
elif self.chat_record is not None and next_node_id_list.__contains__(
|
||||
node.id) and up_node is not None and not next_node_id_list.__contains__(
|
||||
up_node.id):
|
||||
result.append(node.get_answer_text())
|
||||
else:
|
||||
content = result[len(result) - 1]
|
||||
answer_text = node.get_answer_text()
|
||||
result[len(result) - 1] += answer_text if len(
|
||||
content) == 0 else ('\n\n' + answer_text)
|
||||
return result
|
||||
|
||||
def get_next_node(self):
|
||||
|
|
@ -540,6 +577,15 @@ class WorkflowManage:
|
|||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def dependent_node(up_node_id, node):
|
||||
if node.id == up_node_id:
|
||||
if node.type == 'form-node':
|
||||
if node.context.get('form_data', None) is not None:
|
||||
return True
|
||||
return False
|
||||
return True
|
||||
|
||||
def dependent_node_been_executed(self, node_id):
|
||||
"""
|
||||
判断依赖节点是否都已执行
|
||||
|
|
@ -547,7 +593,12 @@ class WorkflowManage:
|
|||
@return:
|
||||
"""
|
||||
up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id]
|
||||
return all([any([node.id == up_node_id for node in self.node_context]) for up_node_id in up_node_id_list])
|
||||
return all([any([self.dependent_node(up_node_id, node) for node in self.node_context]) for up_node_id in
|
||||
up_node_id_list])
|
||||
|
||||
def get_up_node_id_list(self, node_id):
|
||||
up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id]
|
||||
return up_node_id_list
|
||||
|
||||
def get_next_node_list(self, current_node, current_node_result):
|
||||
"""
|
||||
|
|
@ -556,6 +607,7 @@ class WorkflowManage:
|
|||
@param current_node_result: 当前可执行节点结果
|
||||
@return: 可执行节点列表
|
||||
"""
|
||||
|
||||
if current_node.type == 'form-node' and 'form_data' not in current_node_result.node_variable:
|
||||
return []
|
||||
node_list = []
|
||||
|
|
@ -564,11 +616,13 @@ class WorkflowManage:
|
|||
if (edge.sourceNodeId == current_node.id and
|
||||
f"{edge.sourceNodeId}_{current_node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
|
||||
if self.dependent_node_been_executed(edge.targetNodeId):
|
||||
node_list.append(self.get_node_cls_by_id(edge.targetNodeId))
|
||||
node_list.append(
|
||||
self.get_node_cls_by_id(edge.targetNodeId, self.get_up_node_id_list(edge.targetNodeId)))
|
||||
else:
|
||||
for edge in self.flow.edges:
|
||||
if edge.sourceNodeId == current_node.id and self.dependent_node_been_executed(edge.targetNodeId):
|
||||
node_list.append(self.get_node_cls_by_id(edge.targetNodeId))
|
||||
node_list.append(
|
||||
self.get_node_cls_by_id(edge.targetNodeId, self.get_up_node_id_list(edge.targetNodeId)))
|
||||
return node_list
|
||||
|
||||
def get_reference_field(self, node_id: str, fields: List[str]):
|
||||
|
|
@ -629,11 +683,11 @@ class WorkflowManage:
|
|||
base_node_list = [node for node in self.flow.nodes if node.type == 'base-node']
|
||||
return base_node_list[0]
|
||||
|
||||
def get_node_cls_by_id(self, node_id, runtime_node_id=None):
|
||||
def get_node_cls_by_id(self, node_id, up_node_id_list=None):
|
||||
for node in self.flow.nodes:
|
||||
if node.id == node_id:
|
||||
node_instance = get_node(node.type)(node,
|
||||
self.params, self, runtime_node_id)
|
||||
self.params, self, up_node_id_list)
|
||||
return node_instance
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -224,8 +224,13 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答"))
|
||||
chat_record_id = serializers.UUIDField(required=False, allow_null=True,
|
||||
error_messages=ErrMessage.uuid("对话记录id"))
|
||||
|
||||
node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||
error_messages=ErrMessage.char("节点id"))
|
||||
|
||||
runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||
error_messages=ErrMessage.char("节点id"))
|
||||
error_messages=ErrMessage.char("运行时节点id"))
|
||||
|
||||
node_data = serializers.DictField(required=False, error_messages=ErrMessage.char("节点参数"))
|
||||
application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
|
||||
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
|
||||
|
|
@ -339,7 +344,8 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
'client_id': client_id,
|
||||
'client_type': client_type,
|
||||
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
|
||||
base_to_response, form_data, image_list, document_list, self.data.get('runtime_node_id'),
|
||||
base_to_response, form_data, image_list, document_list,
|
||||
self.data.get('runtime_node_id'),
|
||||
self.data.get('node_data'), chat_record)
|
||||
r = work_flow_manage.run()
|
||||
return r
|
||||
|
|
|
|||
|
|
@ -135,6 +135,7 @@ class ChatView(APIView):
|
|||
'document_list': request.data.get(
|
||||
'document_list') if 'document_list' in request.data else [],
|
||||
'client_type': request.auth.client_type,
|
||||
'node_id': request.data.get('node_id', None),
|
||||
'runtime_node_id': request.data.get('runtime_node_id', None),
|
||||
'node_data': request.data.get('node_data', {}),
|
||||
'chat_record_id': request.data.get('chat_record_id')}
|
||||
|
|
|
|||
|
|
@ -14,12 +14,15 @@ from rest_framework import status
|
|||
class BaseToResponse(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
|
||||
def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens,
|
||||
prompt_tokens, other_params: dict = None,
|
||||
_status=status.HTTP_200_OK):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
|
||||
def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end,
|
||||
completion_tokens,
|
||||
prompt_tokens, other_params: dict = None):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from common.handle.base_to_response import BaseToResponse
|
|||
|
||||
class OpenaiToResponse(BaseToResponse):
|
||||
def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
|
||||
other_params: dict = None,
|
||||
_status=status.HTTP_200_OK):
|
||||
data = ChatCompletion(id=chat_record_id, choices=[
|
||||
BlockChoice(finish_reason='stop', index=0, chat_id=chat_id,
|
||||
|
|
@ -31,7 +32,8 @@ class OpenaiToResponse(BaseToResponse):
|
|||
).dict()
|
||||
return JsonResponse(data=data, status=_status)
|
||||
|
||||
def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
|
||||
def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, completion_tokens,
|
||||
prompt_tokens, other_params: dict = None):
|
||||
chunk = ChatCompletionChunk(id=chat_record_id, model='', object='chat.completion.chunk',
|
||||
created=datetime.datetime.now().second, choices=[
|
||||
Choice(delta=ChoiceDelta(content=content, chat_id=chat_id), finish_reason='stop' if is_end else None,
|
||||
|
|
|
|||
|
|
@ -15,12 +15,23 @@ from common.response import result
|
|||
|
||||
|
||||
class SystemToResponse(BaseToResponse):
|
||||
def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
|
||||
def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens,
|
||||
prompt_tokens, other_params: dict = None,
|
||||
_status=status.HTTP_200_OK):
|
||||
if other_params is None:
|
||||
other_params = {}
|
||||
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||
'content': content, 'is_end': is_end}, response_status=_status, code=_status)
|
||||
'content': content, 'is_end': is_end, **other_params}, response_status=_status,
|
||||
code=_status)
|
||||
|
||||
def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
|
||||
def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, completion_tokens,
|
||||
prompt_tokens, other_params: dict = None):
|
||||
if other_params is None:
|
||||
other_params = {}
|
||||
chunk = json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||
'content': content, 'is_end': is_end})
|
||||
'content': content, 'node_id': node_id, 'up_node_id_list': up_node_id_list, 'is_end': is_end,
|
||||
'usage': {'completion_tokens': completion_tokens,
|
||||
'prompt_tokens': prompt_tokens,
|
||||
'total_tokens': completion_tokens + prompt_tokens},
|
||||
**other_params})
|
||||
return super().format_stream_chunk(chunk)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ def updateDocumentStatus(apps, schema_editor):
|
|||
ParagraphModel = apps.get_model('dataset', 'Paragraph')
|
||||
DocumentModel = apps.get_model('dataset', 'Document')
|
||||
success_list = QuerySet(DocumentModel).filter(status='2')
|
||||
if len(success_list) == 0:
|
||||
return
|
||||
ListenerManagement.update_status(QuerySet(ParagraphModel).filter(document_id__in=[d.id for d in success_list]),
|
||||
TaskType.EMBEDDING, State.SUCCESS)
|
||||
ListenerManagement.get_aggregation_document_status_by_query_set(QuerySet(DocumentModel))()
|
||||
|
|
|
|||
|
|
@ -22,6 +22,17 @@ interface ApplicationFormType {
|
|||
tts_model_enable?: boolean
|
||||
tts_type?: string
|
||||
}
|
||||
interface Chunk {
|
||||
chat_id: string
|
||||
id: string
|
||||
content: string
|
||||
node_id: string
|
||||
up_node_id: string
|
||||
is_end: boolean
|
||||
node_is_end: boolean
|
||||
node_type: string
|
||||
view_type: string
|
||||
}
|
||||
interface chatType {
|
||||
id: string
|
||||
problem_text: string
|
||||
|
|
@ -47,6 +58,21 @@ interface chatType {
|
|||
}
|
||||
}
|
||||
|
||||
interface Node {
|
||||
buffer: Array<string>
|
||||
node_id: string
|
||||
up_node_id: string
|
||||
node_type: string
|
||||
view_type: string
|
||||
index: number
|
||||
is_end: boolean
|
||||
}
|
||||
interface WriteNodeInfo {
|
||||
current_node: any
|
||||
answer_text_list_index: number
|
||||
current_up_node?: any
|
||||
divider_content?: Array<string>
|
||||
}
|
||||
export class ChatRecordManage {
|
||||
id?: any
|
||||
ms: number
|
||||
|
|
@ -55,6 +81,8 @@ export class ChatRecordManage {
|
|||
write_ed?: boolean
|
||||
is_stop?: boolean
|
||||
loading?: Ref<boolean>
|
||||
node_list: Array<any>
|
||||
write_node_info?: WriteNodeInfo
|
||||
constructor(chat: chatType, ms?: number, loading?: Ref<boolean>) {
|
||||
this.ms = ms ? ms : 10
|
||||
this.chat = chat
|
||||
|
|
@ -62,12 +90,82 @@ export class ChatRecordManage {
|
|||
this.is_stop = false
|
||||
this.is_close = false
|
||||
this.write_ed = false
|
||||
this.node_list = []
|
||||
}
|
||||
append_answer(chunk_answer: String) {
|
||||
this.chat.answer_text_list[this.chat.answer_text_list.length - 1] =
|
||||
this.chat.answer_text_list[this.chat.answer_text_list.length - 1] + chunk_answer
|
||||
append_answer(chunk_answer: string, index?: number) {
|
||||
this.chat.answer_text_list[index != undefined ? index : this.chat.answer_text_list.length - 1] =
|
||||
this.chat.answer_text_list[
|
||||
index !== undefined ? index : this.chat.answer_text_list.length - 1
|
||||
]
|
||||
? this.chat.answer_text_list[
|
||||
index !== undefined ? index : this.chat.answer_text_list.length - 1
|
||||
] + chunk_answer
|
||||
: chunk_answer
|
||||
this.chat.answer_text = this.chat.answer_text + chunk_answer
|
||||
}
|
||||
|
||||
get_run_node() {
|
||||
if (
|
||||
this.write_node_info &&
|
||||
(this.write_node_info.current_node.buffer.length > 0 ||
|
||||
!this.write_node_info.current_node.is_end)
|
||||
) {
|
||||
return this.write_node_info
|
||||
}
|
||||
const run_node = this.node_list.filter((item) => item.buffer.length > 0 || !item.is_end).at(0)
|
||||
|
||||
if (run_node) {
|
||||
const index = this.node_list.indexOf(run_node)
|
||||
let current_up_node = undefined
|
||||
if (index > 0) {
|
||||
current_up_node = this.node_list[index - 1]
|
||||
}
|
||||
let answer_text_list_index = 0
|
||||
|
||||
if (
|
||||
current_up_node == undefined ||
|
||||
run_node.view_type == 'single_view' ||
|
||||
(run_node.view_type == 'many_view' && current_up_node.view_type == 'single_view')
|
||||
) {
|
||||
const none_index = this.chat.answer_text_list.indexOf('')
|
||||
if (none_index > -1) {
|
||||
answer_text_list_index = none_index
|
||||
} else {
|
||||
answer_text_list_index = this.chat.answer_text_list.length
|
||||
}
|
||||
} else {
|
||||
const none_index = this.chat.answer_text_list.indexOf('')
|
||||
if (none_index > -1) {
|
||||
answer_text_list_index = none_index
|
||||
} else {
|
||||
answer_text_list_index = this.chat.answer_text_list.length - 1
|
||||
}
|
||||
}
|
||||
|
||||
this.write_node_info = {
|
||||
current_node: run_node,
|
||||
divider_content: ['\n\n'],
|
||||
current_up_node: current_up_node,
|
||||
answer_text_list_index: answer_text_list_index
|
||||
}
|
||||
return this.write_node_info
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
closeInterval() {
|
||||
this.chat.write_ed = true
|
||||
this.write_ed = true
|
||||
if (this.loading) {
|
||||
this.loading.value = false
|
||||
}
|
||||
if (this.id) {
|
||||
clearInterval(this.id)
|
||||
}
|
||||
const last_index = this.chat.answer_text_list.lastIndexOf('')
|
||||
if (last_index > 0) {
|
||||
this.chat.answer_text_list.splice(last_index, 1)
|
||||
}
|
||||
}
|
||||
write() {
|
||||
this.chat.is_stop = false
|
||||
this.is_stop = false
|
||||
|
|
@ -78,22 +176,45 @@ export class ChatRecordManage {
|
|||
this.loading.value = true
|
||||
}
|
||||
this.id = setInterval(() => {
|
||||
if (this.chat.buffer.length > 20) {
|
||||
this.append_answer(this.chat.buffer.splice(0, this.chat.buffer.length - 20).join(''))
|
||||
const node_info = this.get_run_node()
|
||||
if (node_info == undefined) {
|
||||
if (this.is_close) {
|
||||
this.closeInterval()
|
||||
}
|
||||
return
|
||||
}
|
||||
const { current_node, answer_text_list_index, divider_content } = node_info
|
||||
if (current_node.buffer.length > 20) {
|
||||
const context = current_node.is_end
|
||||
? current_node.buffer.splice(0)
|
||||
: current_node.buffer.splice(
|
||||
0,
|
||||
current_node.is_end ? undefined : current_node.buffer.length - 20
|
||||
)
|
||||
this.append_answer(
|
||||
(divider_content ? divider_content.splice(0).join('') : '') + context.join(''),
|
||||
answer_text_list_index
|
||||
)
|
||||
} else if (this.is_close) {
|
||||
this.append_answer(this.chat.buffer.splice(0).join(''))
|
||||
this.chat.write_ed = true
|
||||
this.write_ed = true
|
||||
if (this.loading) {
|
||||
this.loading.value = false
|
||||
}
|
||||
if (this.id) {
|
||||
clearInterval(this.id)
|
||||
while (true) {
|
||||
const node_info = this.get_run_node()
|
||||
if (node_info == undefined) {
|
||||
break
|
||||
}
|
||||
this.append_answer(
|
||||
(node_info.divider_content ? node_info.divider_content.splice(0).join('') : '') +
|
||||
node_info.current_node.buffer.splice(0).join(''),
|
||||
node_info.answer_text_list_index
|
||||
)
|
||||
}
|
||||
this.closeInterval()
|
||||
} else {
|
||||
const s = this.chat.buffer.shift()
|
||||
const s = current_node.buffer.shift()
|
||||
if (s !== undefined) {
|
||||
this.append_answer(s)
|
||||
this.append_answer(
|
||||
(divider_content ? divider_content.splice(0).join('') : '') + s,
|
||||
answer_text_list_index
|
||||
)
|
||||
}
|
||||
}
|
||||
}, this.ms)
|
||||
|
|
@ -113,6 +234,28 @@ export class ChatRecordManage {
|
|||
this.is_close = false
|
||||
this.is_stop = false
|
||||
}
|
||||
appendChunk(chunk: Chunk) {
|
||||
let n = this.node_list.find(
|
||||
(item) => item.node_id == chunk.node_id && item.up_node_id === chunk.up_node_id
|
||||
)
|
||||
if (n) {
|
||||
n.buffer.push(...chunk.content)
|
||||
} else {
|
||||
n = {
|
||||
buffer: [...chunk.content],
|
||||
node_id: chunk.node_id,
|
||||
up_node_id: chunk.up_node_id,
|
||||
node_type: chunk.node_type,
|
||||
index: this.node_list.length,
|
||||
view_type: chunk.view_type,
|
||||
is_end: false
|
||||
}
|
||||
this.node_list.push(n)
|
||||
}
|
||||
if (chunk.node_is_end) {
|
||||
n['is_end'] = true
|
||||
}
|
||||
}
|
||||
append(answer_text_block: string) {
|
||||
for (let index = 0; index < answer_text_block.length; index++) {
|
||||
this.chat.buffer.push(answer_text_block[index])
|
||||
|
|
@ -126,6 +269,12 @@ export class ChatManagement {
|
|||
static addChatRecord(chat: chatType, ms: number, loading?: Ref<boolean>) {
|
||||
this.chatMessageContainer[chat.id] = new ChatRecordManage(chat, ms, loading)
|
||||
}
|
||||
static appendChunk(chatRecordId: string, chunk: Chunk) {
|
||||
const chatRecord = this.chatMessageContainer[chatRecordId]
|
||||
if (chatRecord) {
|
||||
chatRecord.appendChunk(chunk)
|
||||
}
|
||||
}
|
||||
static append(chatRecordId: string, content: string) {
|
||||
const chatRecord = this.chatMessageContainer[chatRecordId]
|
||||
if (chatRecord) {
|
||||
|
|
@ -144,6 +293,7 @@ export class ChatManagement {
|
|||
*/
|
||||
static write(chatRecordId: string) {
|
||||
const chatRecord = this.chatMessageContainer[chatRecordId]
|
||||
console.log('chatRecord', chatRecordId, this.chatMessageContainer, chatRecord)
|
||||
if (chatRecord) {
|
||||
chatRecord.write()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -223,10 +223,8 @@ const getWrite = (chat: any, reader: any, stream: boolean) => {
|
|||
const chunk = JSON?.parse(split[index].replace('data:', ''))
|
||||
chat.chat_id = chunk.chat_id
|
||||
chat.record_id = chunk.id
|
||||
const content = chunk?.content
|
||||
if (content) {
|
||||
ChatManagement.append(chat.id, content)
|
||||
}
|
||||
ChatManagement.appendChunk(chat.id, chunk)
|
||||
|
||||
if (chunk.is_end) {
|
||||
// 流处理成功 返回成功回调
|
||||
return Promise.resolve()
|
||||
|
|
@ -275,6 +273,7 @@ const errorWrite = (chat: any, message?: string) => {
|
|||
|
||||
function chatMessage(chat?: any, problem?: string, re_chat?: boolean, other_params_data?: any) {
|
||||
loading.value = true
|
||||
console.log(chat)
|
||||
if (!chat) {
|
||||
chat = reactive({
|
||||
id: randomId(),
|
||||
|
|
@ -306,6 +305,10 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean, other_para
|
|||
scrollDiv.value.setScrollTop(getMaxHeight())
|
||||
})
|
||||
}
|
||||
if (chat.run_time) {
|
||||
ChatManagement.addChatRecord(chat, 50, loading)
|
||||
ChatManagement.write(chat.id)
|
||||
}
|
||||
if (!chartOpenId.value) {
|
||||
getChartOpenId(chat).catch(() => {
|
||||
errorWrite(chat)
|
||||
|
|
|
|||
|
|
@ -52,7 +52,6 @@ const is_submit = computed(() => {
|
|||
const _form_data = ref<any>({})
|
||||
const form_data = computed({
|
||||
get: () => {
|
||||
console.log(form_setting_data.value)
|
||||
if (form_setting_data.value.is_submit) {
|
||||
return form_setting_data.value.form_data
|
||||
} else {
|
||||
|
|
|
|||
Loading…
Reference in New Issue