From f4a3883be2d69f31b10b0ad6f77a57e50af5eb32 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Mon, 23 Sep 2024 16:42:34 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=B7=A5=E4=BD=9C=E6=B5=81=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=BB=98=E8=AE=A4=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../start_node/impl/base_start_node.py | 32 +++++++++++++++---- apps/application/flow/workflow_manage.py | 13 ++++++-- ui/src/workflow/nodes/reranker-node/index.vue | 7 +--- 3 files changed, 37 insertions(+), 15 deletions(-) diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index f6528660d..39fbfe76a 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -8,23 +8,41 @@ """ import time from datetime import datetime +from typing import List, Type + +from rest_framework import serializers from application.flow.i_step_node import NodeResult from application.flow.step_node.start_node.i_start_node import IStarNode +def get_default_global_variable(input_field_list: List): + return {item.get('variable'): item.get('default_value') for item in input_field_list if + item.get('default_value', None) is not None} + + +def get_global_variable(node): + history_chat_record = node.flow_params_serializer.data.get('history_chat_record', []) + history_context = [{'question': chat_record.problem_text, 'answer': chat_record.answer_text} for chat_record in + history_chat_record] + chat_id = node.flow_params_serializer.data.get('chat_id') + return {'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'start_time': time.time(), + 'history_context': history_context, 'chat_id': str(chat_id), **node.workflow_manage.form_data} + + class BaseStartStepNode(IStarNode): + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + pass + def execute(self, question, **kwargs) -> NodeResult: - history_chat_record = self.flow_params_serializer.data.get('history_chat_record', []) - history_context = [{'question': chat_record.problem_text, 'answer': chat_record.answer_text} for chat_record in - history_chat_record] - chat_id = self.flow_params_serializer.data.get('chat_id') + base_node = self.workflow_manage.get_base_node() + default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', [])) + workflow_variable = {**default_global_variable, **get_global_variable(self)} """ 开始节点 初始化全局变量 """ return NodeResult({'question': question}, - {'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'start_time': time.time(), - 'history_context': history_context, 'chat_id': str(chat_id)}) + workflow_variable) def get_details(self, index: int, **kwargs): global_fields = [] @@ -33,7 +51,7 @@ class BaseStartStepNode(IStarNode): global_fields.append({ 'label': field['label'], 'key': key, - 'value': self.workflow_manage[key] if key in self.workflow_manage else '' + 'value': self.workflow_manage.context[key] if key in self.workflow_manage.context else '' }) return { 'name': self.node.properties.get('stepName'), diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index f8e35a238..920222810 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -169,9 +169,10 @@ class WorkflowManage: base_to_response: BaseToResponse = SystemToResponse(), form_data=None): if form_data is None: form_data = {} + self.form_data = form_data self.params = params self.flow = flow - self.context = form_data + self.context = {} self.node_context = [] self.work_flow_post_handler = work_flow_post_handler self.current_node = None @@ -302,7 +303,7 @@ class WorkflowManage: """ if self.current_node is None: node = self.get_start_node() - node_instance = get_node(node.type)(node, self.params, self.context) + node_instance = get_node(node.type)(node, self.params, self) return node_instance if self.current_result is not None and self.current_result.is_assertion_result(): for edge in self.flow.edges: @@ -368,6 +369,14 @@ class WorkflowManage: start_node_list = [node for node in self.flow.nodes if node.type == 'start-node'] return start_node_list[0] + def get_base_node(self): + """ + 获取基础节点 + @return: + """ + base_node_list = [node for node in self.flow.nodes if node.type == 'base-node'] + return base_node_list[0] + def get_node_cls_by_id(self, node_id): for node in self.flow.nodes: if node.id == node_id: diff --git a/ui/src/workflow/nodes/reranker-node/index.vue b/ui/src/workflow/nodes/reranker-node/index.vue index 5e418f986..2bfd5ad80 100644 --- a/ui/src/workflow/nodes/reranker-node/index.vue +++ b/ui/src/workflow/nodes/reranker-node/index.vue @@ -224,6 +224,7 @@ const { const form = { reranker_reference_list: [[]], reranker_model_id: '', + question_reference_address: [], reranker_setting: { top_n: 3, similarity: 0.6, @@ -306,12 +307,6 @@ const openCreateModel = (provider?: Provider) => { onMounted(() => { getProvider() getModel() - if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') { - if (isLastNode(props.nodeModel)) { - set(props.nodeModel.properties.node_data, 'is_result', true) - } - } - set(props.nodeModel, 'validate', validate) })