feat: 对话纪要添加详情字段

This commit is contained in:
shaohuzhang1 2024-06-24 11:48:32 +08:00
parent 7fd4fb08c4
commit c7883cd3fd
9 changed files with 71 additions and 25 deletions

View File

@ -113,6 +113,8 @@ class FlowParamsSerializer(serializers.Serializer):
class INode:
def __init__(self, node, workflow_params, workflow_manage):
# 当前步骤上下文,用于存储当前步骤信息
self.status = 200
self.err_message = ''
self.node = node
self.node_params = node.properties.get('node_data')
self.workflow_manage = workflow_manage
@ -152,6 +154,15 @@ class INode:
def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]:
return FlowParamsSerializer
def get_write_error_context(self, e):
self.status = 500
self.err_message = str(e)
def write_error_context(answer):
pass
return write_error_context
def run(self) -> NodeResult:
"""
:return: 执行结果

View File

@ -176,5 +176,9 @@ class BaseChatNode(IChatNode):
'answer': self.context.get('answer'),
'type': self.node.type,
'message_tokens': self.context['message_tokens'],
'answer_tokens': self.context['answer_tokens']
}
'answer_tokens': self.context['answer_tokens'],
'status': self.status,
'err_message': self.err_message
} if self.status == 200 else {"index": index, 'type': self.node.type,
'status': self.status,
'err_message': self.err_message}

View File

@ -43,5 +43,9 @@ class BaseConditionNode(IConditionNode):
'run_time': self.context.get('run_time'),
'branch_id': self.context.get('branch_id'),
'branch_name': self.context.get('branch_name'),
'type': self.node.type
}
'type': self.node.type,
'status': self.context.get('status'),
'err_message': self.context.get('err_message')
} if self.status == 200 else {"index": index, 'type': self.node.type,
'status': self.status,
'err_message': self.err_message}

View File

@ -84,4 +84,8 @@ class BaseReplyNode(IReplyNode):
'run_time': self.context.get('run_time'),
'type': self.node.type,
'answer': self.context.get('answer'),
}
'status': self.status,
'err_message': self.err_message
} if self.status == 200 else {"index": index, 'type': self.node.type,
'status': self.status,
'err_message': self.err_message}

View File

@ -173,5 +173,9 @@ class BaseQuestionNode(IQuestionNode):
'answer': self.context.get('answer'),
'type': self.node.type,
'message_tokens': self.context['message_tokens'],
'answer_tokens': self.context['answer_tokens']
}
'answer_tokens': self.context['answer_tokens'],
'status': self.status,
'err_message': self.err_message
} if self.status == 200 else {"index": index, 'type': self.node.type,
'status': self.status,
'err_message': self.err_message}

View File

@ -82,5 +82,9 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
"index": index,
'run_time': self.context.get('run_time'),
'paragraph_list': self.context.get('paragraph_list'),
'type': self.node.type
}
'type': self.node.type,
'status': self.status,
'err_message': self.err_message
} if self.status == 200 else {"index": index, 'type': self.node.type,
'status': self.status,
'err_message': self.err_message}

View File

@ -24,5 +24,9 @@ class BaseStartStepNode(IStarNode):
"index": index,
"question": self.context.get('question'),
'run_time': self.context.get('run_time'),
'type': self.node.type
}
'type': self.node.type,
'status': self.status,
'err_message': self.err_message
} if self.status == 200 else {"index": index, 'type': self.node.type,
'status': self.status,
'err_message': self.err_message}

View File

@ -9,8 +9,10 @@
from functools import reduce
from typing import List, Dict
from langchain_core.messages import AIMessageChunk, AIMessage
from langchain_core.prompts import PromptTemplate
from application.flow import tools
from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult
from application.flow.step_node import get_node
from common.exception.app_exception import AppApiException
@ -94,8 +96,6 @@ class Flow:
f'不存在的下一个节点')
return node_list
def is_valid_work_flow(self, up_node=None):
if up_node is None:
up_node = self.get_start_node()
@ -133,17 +133,28 @@ class WorkflowManage:
"""
运行工作流
"""
while self.has_next_node(self.current_result):
self.current_node = self.get_next_node()
self.node_context.append(self.current_node)
self.current_result = self.current_node.run()
if self.has_next_node(self.current_result):
self.current_result.write_context(self.current_node, self)
try:
while self.has_next_node(self.current_result):
self.current_node = self.get_next_node()
self.node_context.append(self.current_node)
self.current_result = self.current_node.run()
if self.has_next_node(self.current_result):
self.current_result.write_context(self.current_node, self)
else:
r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'],
self.current_node, self,
self.work_flow_post_handler)
return r
except Exception as e:
if self.params.get('stream'):
return tools.to_stream_response(self.params['chat_id'], self.params['chat_record_id'],
iter([AIMessageChunk(str(e))]), self,
self.current_node.get_write_error_context(e),
self.work_flow_post_handler)
else:
r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'],
self.current_node, self,
self.work_flow_post_handler)
return r
return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
self.work_flow_post_handler)
def has_next_node(self, node_result: NodeResult | None):
"""
@ -168,7 +179,7 @@ class WorkflowManage:
details_result = {}
for index in range(len(self.node_context)):
node = self.node_context[index]
details = node.get_details({'index': index})
details = node.get_details(index)
details_result[node.id] = details
return details_result

View File

@ -425,7 +425,7 @@ class ChatRecordSerializer(serializers.Serializer):
'padding_problem_text') if 'problem_padding' in chat_record.details else None,
'dataset_list': dataset_list,
'paragraph_list': paragraph_list,
'details': chat_record.details
'execution_details': [chat_record.details[key] for key in chat_record.details]
}
def page(self, current_page: int, page_size: int, with_valid=True):