diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index e1bae631a..3a558ace0 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -105,13 +105,14 @@ class FlowParamsSerializer(serializers.Serializer): chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话记录id")) - stream = serializers.BooleanField(required=True, error_messages=ErrMessage.base("流式输出")) + stream = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("流式输出")) client_id = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端id")) client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id")) + re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案")) class INode: diff --git a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py index 436de5a96..0de4e65a0 100644 --- a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py @@ -61,11 +61,14 @@ class ISearchDatasetStepNode(INode): question = self.workflow_manage.get_reference_field( self.node_params_serializer.data.get('question_reference_address')[0], self.node_params_serializer.data.get('question_reference_address')[1:]) - history_chat_record = self.flow_params_serializer.data.get('history_chat_record', []) - paragraph_id_list = [p.get('id') for p in flat_map( - [get_paragraph_list(chat_record, self.node.id) for chat_record in history_chat_record if - chat_record.problem_text == question])] - exclude_paragraph_id_list = list(set(paragraph_id_list)) + exclude_paragraph_id_list = [] + if self.flow_params_serializer.data.get('re_chat', False): + history_chat_record = self.flow_params_serializer.data.get('history_chat_record', []) + paragraph_id_list = [p.get('id') for p in flat_map( + [get_paragraph_list(chat_record, self.node.id) for chat_record in history_chat_record if + chat_record.problem_text == question])] + exclude_paragraph_id_list = list(set(paragraph_id_list)) + return self.execute(**self.node_params_serializer.data, question=str(question), exclude_paragraph_id_list=exclude_paragraph_id_list)