From 8f529181bd9d3c16e7a344c46ab35611e060a08a Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Mon, 15 Jul 2024 20:24:12 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=8D=A2=E4=B8=AA?= =?UTF-8?q?=E7=AD=94=E6=A1=88=E5=9C=A8=E5=B7=A5=E4=BD=9C=E6=B5=81=E4=B8=AD?= =?UTF-8?q?=E4=B8=80=E7=9B=B4=E7=94=9F=E6=95=88=E9=97=AE=E9=A2=98=20(#766)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/application/flow/i_step_node.py | 3 ++- .../search_dataset_node/i_search_dataset_node.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index e1bae631a..3a558ace0 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -105,13 +105,14 @@ class FlowParamsSerializer(serializers.Serializer): chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话记录id")) - stream = serializers.BooleanField(required=True, error_messages=ErrMessage.base("流式输出")) + stream = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("流式输出")) client_id = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端id")) client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id")) + re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案")) class INode: diff --git a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py index 436de5a96..0de4e65a0 100644 --- a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py @@ -61,11 +61,14 @@ class ISearchDatasetStepNode(INode): question = self.workflow_manage.get_reference_field( self.node_params_serializer.data.get('question_reference_address')[0], self.node_params_serializer.data.get('question_reference_address')[1:]) - history_chat_record = self.flow_params_serializer.data.get('history_chat_record', []) - paragraph_id_list = [p.get('id') for p in flat_map( - [get_paragraph_list(chat_record, self.node.id) for chat_record in history_chat_record if - chat_record.problem_text == question])] - exclude_paragraph_id_list = list(set(paragraph_id_list)) + exclude_paragraph_id_list = [] + if self.flow_params_serializer.data.get('re_chat', False): + history_chat_record = self.flow_params_serializer.data.get('history_chat_record', []) + paragraph_id_list = [p.get('id') for p in flat_map( + [get_paragraph_list(chat_record, self.node.id) for chat_record in history_chat_record if + chat_record.problem_text == question])] + exclude_paragraph_id_list = list(set(paragraph_id_list)) + return self.execute(**self.node_params_serializer.data, question=str(question), exclude_paragraph_id_list=exclude_paragraph_id_list)