From ac659676f72118fa5d7e5783f85d543ce28466d3 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Wed, 26 Jun 2024 13:52:56 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=A3=80=E7=B4=A2?= =?UTF-8?q?=E5=BC=95=E7=94=A8=E5=86=85=E5=AE=B9=E4=B8=BA=E5=AD=97=E7=AC=A6?= =?UTF-8?q?=E4=B8=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/application/flow/i_step_node.py | 2 +- .../ai_chat_step_node/impl/base_chat_node.py | 6 +++--- .../question_node/impl/base_question_node.py | 13 ++++++++++--- .../search_dataset_node/i_search_dataset_node.py | 2 +- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 440ff0703..2b48869e0 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -160,7 +160,7 @@ class INode: self.status = 500 self.err_message = str(e) - def write_error_context(answer): + def write_error_context(answer, status=200): pass return write_error_context diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index 6265704ea..7f8f61bbd 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -67,18 +67,18 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor def get_to_response_write_context(node_variable: Dict, node: INode): def _write_context(answer, status=200): chat_model = node_variable.get('chat_model') - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) + if status == 200: answer_tokens = chat_model.get_num_tokens(answer) + message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) else: answer_tokens = 0 + message_tokens = 0 node.err_message = answer node.status = status node.context['message_tokens'] = message_tokens node.context['answer_tokens'] = answer_tokens node.context['answer'] = answer - node.context['history_message'] = node_variable['history_message'] - node.context['question'] = node_variable['question'] return _write_context diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index f75bf914d..16679ee0e 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -64,10 +64,17 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor def get_to_response_write_context(node_variable: Dict, node: INode): - def _write_context(answer): + def _write_context(answer, status=200): chat_model = node_variable.get('chat_model') - message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) - answer_tokens = chat_model.get_num_tokens(answer) + + if status == 200: + answer_tokens = chat_model.get_num_tokens(answer) + message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) + else: + answer_tokens = 0 + message_tokens = 0 + node.err_message = answer + node.status = status node.context['message_tokens'] = message_tokens node.context['answer_tokens'] = answer_tokens node.context['answer'] = answer diff --git a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py index 020273b71..0a134527c 100644 --- a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py @@ -53,7 +53,7 @@ class ISearchDatasetStepNode(INode): question = self.workflow_manage.get_reference_field( self.node_params_serializer.data.get('question_reference_address')[0], self.node_params_serializer.data.get('question_reference_address')[1:]) - return self.execute(**self.node_params_serializer.data, question=question, exclude_paragraph_id_list=[]) + return self.execute(**self.node_params_serializer.data, question=str(question), exclude_paragraph_id_list=[]) def execute(self, dataset_id_list, dataset_setting, question, exclude_paragraph_id_list=None,