From c7883cd3fd34c6aa016a8ac3b6957ed16c70a863 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Mon, 24 Jun 2024 11:48:32 +0800 Subject: [PATCH] =?UTF-8?q?=20feat:=20=E5=AF=B9=E8=AF=9D=E7=BA=AA=E8=A6=81?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=AF=A6=E6=83=85=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/application/flow/i_step_node.py | 11 ++++++ .../ai_chat_step_node/impl/base_chat_node.py | 8 +++- .../impl/base_condition_node.py | 8 +++- .../direct_reply_node/impl/base_reply_node.py | 6 ++- .../question_node/impl/base_question_node.py | 8 +++- .../impl/base_search_dataset_node.py | 8 +++- .../start_node/impl/base_start_node.py | 8 +++- apps/application/flow/workflow_manage.py | 37 ++++++++++++------- .../serializers/chat_serializers.py | 2 +- 9 files changed, 71 insertions(+), 25 deletions(-) diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 6a8e1868f..72fb178e7 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -113,6 +113,8 @@ class FlowParamsSerializer(serializers.Serializer): class INode: def __init__(self, node, workflow_params, workflow_manage): # 当前步骤上下文,用于存储当前步骤信息 + self.status = 200 + self.err_message = '' self.node = node self.node_params = node.properties.get('node_data') self.workflow_manage = workflow_manage @@ -152,6 +154,15 @@ class INode: def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]: return FlowParamsSerializer + def get_write_error_context(self, e): + self.status = 500 + self.err_message = str(e) + + def write_error_context(answer): + pass + + return write_error_context + def run(self) -> NodeResult: """ :return: 执行结果 diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index ca1f3f828..deef0a40e 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -176,5 +176,9 @@ class BaseChatNode(IChatNode): 'answer': self.context.get('answer'), 'type': self.node.type, 'message_tokens': self.context['message_tokens'], - 'answer_tokens': self.context['answer_tokens'] - } + 'answer_tokens': self.context['answer_tokens'], + 'status': self.status, + 'err_message': self.err_message + } if self.status == 200 else {"index": index, 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message} diff --git a/apps/application/flow/step_node/condition_node/impl/base_condition_node.py b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py index 2920e00c1..62dfa905e 100644 --- a/apps/application/flow/step_node/condition_node/impl/base_condition_node.py +++ b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py @@ -43,5 +43,9 @@ class BaseConditionNode(IConditionNode): 'run_time': self.context.get('run_time'), 'branch_id': self.context.get('branch_id'), 'branch_name': self.context.get('branch_name'), - 'type': self.node.type - } + 'type': self.node.type, + 'status': self.context.get('status'), + 'err_message': self.context.get('err_message') + } if self.status == 200 else {"index": index, 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message} diff --git a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py index 51dca9dc9..5675fa832 100644 --- a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py +++ b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py @@ -84,4 +84,8 @@ class BaseReplyNode(IReplyNode): 'run_time': self.context.get('run_time'), 'type': self.node.type, 'answer': self.context.get('answer'), - } + 'status': self.status, + 'err_message': self.err_message + } if self.status == 200 else {"index": index, 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message} diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index 36071ec49..6192d13e9 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -173,5 +173,9 @@ class BaseQuestionNode(IQuestionNode): 'answer': self.context.get('answer'), 'type': self.node.type, 'message_tokens': self.context['message_tokens'], - 'answer_tokens': self.context['answer_tokens'] - } + 'answer_tokens': self.context['answer_tokens'], + 'status': self.status, + 'err_message': self.err_message + } if self.status == 200 else {"index": index, 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message} diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py index cf652b54c..1f22e8603 100644 --- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -82,5 +82,9 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode): "index": index, 'run_time': self.context.get('run_time'), 'paragraph_list': self.context.get('paragraph_list'), - 'type': self.node.type - } + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message + } if self.status == 200 else {"index": index, 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message} diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index f75a03e68..67cd2a48e 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -24,5 +24,9 @@ class BaseStartStepNode(IStarNode): "index": index, "question": self.context.get('question'), 'run_time': self.context.get('run_time'), - 'type': self.node.type - } + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message + } if self.status == 200 else {"index": index, 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message} diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 7b49bcc3e..e1b28f57d 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -9,8 +9,10 @@ from functools import reduce from typing import List, Dict +from langchain_core.messages import AIMessageChunk, AIMessage from langchain_core.prompts import PromptTemplate +from application.flow import tools from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult from application.flow.step_node import get_node from common.exception.app_exception import AppApiException @@ -94,8 +96,6 @@ class Flow: f'不存在的下一个节点') return node_list - - def is_valid_work_flow(self, up_node=None): if up_node is None: up_node = self.get_start_node() @@ -133,17 +133,28 @@ class WorkflowManage: """ 运行工作流 """ - while self.has_next_node(self.current_result): - self.current_node = self.get_next_node() - self.node_context.append(self.current_node) - self.current_result = self.current_node.run() - if self.has_next_node(self.current_result): - self.current_result.write_context(self.current_node, self) + try: + while self.has_next_node(self.current_result): + self.current_node = self.get_next_node() + self.node_context.append(self.current_node) + self.current_result = self.current_node.run() + if self.has_next_node(self.current_result): + self.current_result.write_context(self.current_node, self) + else: + r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'], + self.current_node, self, + self.work_flow_post_handler) + return r + except Exception as e: + if self.params.get('stream'): + return tools.to_stream_response(self.params['chat_id'], self.params['chat_record_id'], + iter([AIMessageChunk(str(e))]), self, + self.current_node.get_write_error_context(e), + self.work_flow_post_handler) else: - r = self.current_result.to_response(self.params['chat_id'], self.params['chat_record_id'], - self.current_node, self, - self.work_flow_post_handler) - return r + return tools.to_response(self.params['chat_id'], self.params['chat_record_id'], + AIMessage(str(e)), self, self.current_node.get_write_error_context(e), + self.work_flow_post_handler) def has_next_node(self, node_result: NodeResult | None): """ @@ -168,7 +179,7 @@ class WorkflowManage: details_result = {} for index in range(len(self.node_context)): node = self.node_context[index] - details = node.get_details({'index': index}) + details = node.get_details(index) details_result[node.id] = details return details_result diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index ea97fedd9..8c37caaf7 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -425,7 +425,7 @@ class ChatRecordSerializer(serializers.Serializer): 'padding_problem_text') if 'problem_padding' in chat_record.details else None, 'dataset_list': dataset_list, 'paragraph_list': paragraph_list, - 'details': chat_record.details + 'execution_details': [chat_record.details[key] for key in chat_record.details] } def page(self, current_page: int, page_size: int, with_valid=True):