From bfb03a62397c8a5ddadc46d93b7a8631dc1f8c34 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Thu, 13 Mar 2025 18:54:17 +0800 Subject: [PATCH] feat: loopNode --- .../loop_node/impl/base_loop_node.py | 122 ++++++++++++++---- .../start_node/impl/base_start_node.py | 4 +- apps/application/flow/workflow_manage.py | 5 +- ui/src/workflow/common/data.ts | 28 ++++ .../workflow/nodes/loop-body-node/index.vue | 4 + ui/src/workflow/nodes/loop-node/index.vue | 8 +- 6 files changed, 140 insertions(+), 31 deletions(-) diff --git a/apps/application/flow/step_node/loop_node/impl/base_loop_node.py b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py index 9882b6209..13240d79d 100644 --- a/apps/application/flow/step_node/loop_node/impl/base_loop_node.py +++ b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py @@ -32,7 +32,9 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo @param node: 节点 @param workflow: 工作流管理器 """ + response = node_variable.get('result') + workflow_manage = node_variable.get('workflow_manage') answer = '' reasoning_content = '' for chunk in response: @@ -42,7 +44,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo answer += content_chunk yield {'content': content_chunk, 'reasoning_content': reasoning_content_chunk} - + runtime_details = workflow_manage.get_runtime_details() _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content) @@ -69,28 +71,53 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor _write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content) -def loop_number(number, loop_body): - """ - 指定次数循环 - @return: - """ - pass +def loop_number(number: int, loop_body, workflow_manage_new_instance, workflow): + loop_global_data = {} + for index in range(number): + """ + 指定次数循环 + @return: + """ + instance = workflow_manage_new_instance({'index': index}, loop_global_data) + response = instance.stream() + answer = '' + reasoning_content = '' + for chunk in response: + content_chunk = chunk.get('content', '') + reasoning_content_chunk = chunk.get('reasoning_content', '') + reasoning_content += reasoning_content_chunk + answer += content_chunk + yield chunk + loop_global_data = instance.context -def loop_array(array, loop_body): - """ - 循环数组 - @return: - """ - pass +def loop_array(array, loop_body, workflow_manage_new_instance, workflow): + loop_global_data = {} + for item, index in zip(array, range(len(array))): + """ + 指定次数循环 + @return: + """ + instance = workflow_manage_new_instance({'index': index, 'item': item}, loop_global_data) + response = instance.stream() + answer = '' + reasoning_content = '' + for chunk in response: + content_chunk = chunk.get('content', '') + reasoning_content_chunk = chunk.get('reasoning_content', '') + reasoning_content += reasoning_content_chunk + answer += content_chunk + yield chunk + loop_global_data = instance.context -def loop_loop(loop_body): - """ - 无线循环 - @return: - """ - pass +def get_write_context(loop_type, array, number, loop_body, stream): + def inner_write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + if loop_type == 'ARRAY': + return loop_array(array, loop_body, node_variable['workflow_manage_new_instance'], workflow) + return loop_number(number, loop_body, node_variable['workflow_manage_new_instance'], workflow) + + return inner_write_context class LoopWorkFlowPostHandler(WorkFlowPostHandler): @@ -108,14 +135,55 @@ class BaseLoopNode(ILoopNode): def execute(self, loop_type, array, number, loop_body, stream, **kwargs) -> NodeResult: from application.flow.workflow_manage import WorkflowManage, Flow - workflow_manage = WorkflowManage(Flow.new_instance(loop_body), self.workflow_manage.params, - LoopWorkFlowPostHandler(self.workflow_manage.work_flow_post_handler.chat_info - , - self.workflow_manage.work_flow_post_handler.client_id, - self.workflow_manage.work_flow_post_handler.client_type) - , base_to_response=LoopToResponse()) - result = workflow_manage.stream() - return NodeResult({"result": result}, {}, _write_context=write_context_stream) + def workflow_manage_new_instance(start_data, global_data): + workflow_manage = WorkflowManage(Flow.new_instance(loop_body), self.workflow_manage.params, + LoopWorkFlowPostHandler( + self.workflow_manage.work_flow_post_handler.chat_info + , + self.workflow_manage.work_flow_post_handler.client_id, + self.workflow_manage.work_flow_post_handler.client_type) + , base_to_response=LoopToResponse(), + start_data=start_data, + form_data=global_data) + + return workflow_manage + + return NodeResult({'workflow_manage_new_instance': workflow_manage_new_instance}, {}, + _write_context=get_write_context(loop_type, array, number, loop_body, stream)) + + def loop_number(self, number: int, loop_body, stream): + for index in range(number): + """ + 指定次数循环 + @return: + """ + from application.flow.workflow_manage import WorkflowManage, Flow + workflow_manage = WorkflowManage(Flow.new_instance(loop_body), self.workflow_manage.params, + LoopWorkFlowPostHandler( + self.workflow_manage.work_flow_post_handler.chat_info + , + self.workflow_manage.work_flow_post_handler.client_id, + self.workflow_manage.work_flow_post_handler.client_type) + , base_to_response=LoopToResponse(), + start_data={'index': index}) + result = workflow_manage.stream() + return NodeResult({"result": result, "workflow_manage": workflow_manage}, {}, + _write_context=write_context_stream) + pass + + def loop_array(self, array, loop_body, stream): + """ + 循环数组 + @return: + """ + pass + + def loop_loop(self, loop_body, stream): + """ + 无线循环 + @return: + """ + pass def get_details(self, index: int, **kwargs): return { diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index c7f5a252b..d4a54a25b 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -64,7 +64,9 @@ class BaseStartStepNode(IStarNode): 'question': question, 'image': self.workflow_manage.image_list, 'document': self.workflow_manage.document_list, - 'audio': self.workflow_manage.audio_list + 'audio': self.workflow_manage.audio_list, + **self.workflow_manage.start_data + } return NodeResult(node_variable, workflow_variable) diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 9a1f1284e..ef2272a5e 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -239,9 +239,11 @@ class WorkflowManage: document_list=None, audio_list=None, start_node_id=None, - start_node_data=None, chat_record=None, child_node=None): + start_node_data=None, chat_record=None, child_node=None, start_data=None): if form_data is None: form_data = {} + if start_data is None: + start_data = {} if image_list is None: image_list = [] if document_list is None: @@ -272,6 +274,7 @@ class WorkflowManage: self.field_list = [] self.global_field_list = [] self.init_fields() + self.start_data = start_data if start_node_id is not None: self.load_node(chat_record, start_node_id, start_node_data) else: diff --git a/ui/src/workflow/common/data.ts b/ui/src/workflow/common/data.ts index 5fa3e1b09..5eb276f6d 100644 --- a/ui/src/workflow/common/data.ts +++ b/ui/src/workflow/common/data.ts @@ -35,6 +35,34 @@ export const startNode = { showNode: true } } +export const loopStartNode = { + id: WorkflowType.Start, + type: WorkflowType.Start, + x: 480, + y: 3340, + properties: { + height: 364, + stepName: t('views.applicationWorkflow.nodes.startNode.label'), + config: { + fields: [ + { + label: t('views.applicationWorkflow.nodes.startNode.index', '下标'), + value: 'index' + }, + { + label: t('views.applicationWorkflow.nodes.startNode.item', '循环元素'), + value: 'item' + } + ], + globalFields: [] + }, + fields: [{ label: t('views.applicationWorkflow.nodes.startNode.question'), value: 'question' }], + globalFields: [ + { label: t('views.applicationWorkflow.nodes.startNode.currentTime'), value: 'time' } + ], + showNode: true + } +} export const baseNode = { id: WorkflowType.Base, type: WorkflowType.Base, diff --git a/ui/src/workflow/nodes/loop-body-node/index.vue b/ui/src/workflow/nodes/loop-body-node/index.vue index f88714e4b..99dbb0f8b 100644 --- a/ui/src/workflow/nodes/loop-body-node/index.vue +++ b/ui/src/workflow/nodes/loop-body-node/index.vue @@ -79,6 +79,10 @@ const renderGraphData = (data?: any) => { lf.value.graphModel.eventCenter.on('history:change', (data: any) => { set(props.nodeModel.properties, 'workflow', lf.value.getGraphData()) }) + + lf.value.graphModel.eventCenter.on('loop:change', (data: any) => { + console.log('xx') + }) setTimeout(() => { lf.value?.fitView() }, 500) diff --git a/ui/src/workflow/nodes/loop-node/index.vue b/ui/src/workflow/nodes/loop-node/index.vue index 8bb6c14c5..16ffdabf3 100644 --- a/ui/src/workflow/nodes/loop-node/index.vue +++ b/ui/src/workflow/nodes/loop-node/index.vue @@ -87,7 +87,7 @@ import { set } from 'lodash' import NodeContainer from '@/workflow/common/NodeContainer.vue' import { ref, computed, onMounted } from 'vue' import { isLastNode } from '@/workflow/common/data' -import { loopBodyNode } from '@/workflow/common/data' +import { loopBodyNode, loopStartNode } from '@/workflow/common/data' import NodeCascader from '@/workflow/common/NodeCascader.vue' const props = defineProps<{ nodeModel: any }>() @@ -131,11 +131,15 @@ onMounted(() => { set(props.nodeModel, 'validate', validate) const nodeOutgoingNode = props.nodeModel.graphModel.getNodeOutgoingNode(props.nodeModel.id) if (!nodeOutgoingNode.some((item: any) => item.type == loopBodyNode.type)) { + let workflow = { nodes: [loopStartNode], edges: [] } + if (props.nodeModel.properties.node_data.loop_body) { + workflow = props.nodeModel.properties.node_data.loop_body + } const nodeModel = props.nodeModel.graphModel.addNode({ type: loopBodyNode.type, properties: { ...loopBodyNode.properties, - workflow: props.nodeModel.properties.node_data.loop_body, + workflow: workflow, loop_node_id: props.nodeModel.id }, x: props.nodeModel.x,