From 466175fb1940e4ba0c5cf01ee97dc62a9da58920 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Fri, 13 Jun 2025 12:01:08 +0800 Subject: [PATCH] perf: workflow chat (#3247) --- apps/application/flow/common.py | 220 ++++++++++++++++++++ apps/application/flow/workflow_manage.py | 212 ++----------------- apps/application/serializers/application.py | 4 +- apps/chat/serializers/chat.py | 6 +- ui/package.json | 2 +- ui/src/workflow/plugins/dagre.ts | 10 +- 6 files changed, 254 insertions(+), 200 deletions(-) diff --git a/apps/application/flow/common.py b/apps/application/flow/common.py index f5d4cb9b0..10f612300 100644 --- a/apps/application/flow/common.py +++ b/apps/application/flow/common.py @@ -7,6 +7,21 @@ @desc: """ +from typing import List, Dict + +from django.db.models import QuerySet +from django.utils.translation import gettext as _ +from rest_framework.exceptions import ErrorDetail, ValidationError + +from common.exception.app_exception import AppApiException +from common.utils.common import group_by +from models_provider.models import Model +from models_provider.tools import get_model_credential +from tools.models.tool import Tool + +end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node', + 'image-understand-node', 'speech-to-text-node', 'text-to-speech-node', 'image-generate-node'] + class Answer: def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id, @@ -42,3 +57,208 @@ class NodeChunk: def is_end(self): return self.status == 200 + + +class Edge: + def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords): + self.id = _id + self.type = _type + self.sourceNodeId = sourceNodeId + self.targetNodeId = targetNodeId + for keyword in keywords: + self.__setattr__(keyword, keywords.get(keyword)) + + +class Node: + def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs): + self.id = _id + self.type = _type + self.x = x + self.y = y + self.properties = properties + for keyword in kwargs: + self.__setattr__(keyword, kwargs.get(keyword)) + + +class EdgeNode: + edge: Edge + node: Node + + def __init__(self, edge, node): + self.edge = edge + self.node = node + + +class Workflow: + """ + 节点列表 + """ + nodes: List[Node] + """ + 线列表 + """ + edges: List[Edge] + """ + 节点id:node + """ + node_map: Dict[str, Node] + """ + 节点id:当前节点id上面的所有节点 + """ + up_node_map: Dict[str, List[EdgeNode]] + """ + 节点id:当前节点id下面的所有节点 + """ + next_node_map: Dict[str, List[EdgeNode]] + + def __init__(self, nodes: List[Node], edges: List[Edge]): + self.nodes = nodes + self.edges = edges + self.node_map = {node.id: node for node in nodes} + + self.up_node_map = {key: [EdgeNode(edge, self.node_map.get(edge.sourceNodeId)) for + edge in edges] for + key, edges in + group_by(edges, key=lambda edge: edge.targetNodeId).items()} + + self.next_node_map = {key: [EdgeNode(edge, self.node_map.get(edge.targetNodeId)) for edge in edges] for + key, edges in + group_by(edges, key=lambda edge: edge.sourceNodeId).items()} + + def get_node(self, node_id): + """ + 根据node_id 获取节点信息 + @param node_id: node_id + @return: 节点信息 + """ + return self.node_map.get(node_id) + + def get_up_edge_nodes(self, node_id) -> List[EdgeNode]: + """ + 根据节点id 获取当前连接前置节点和连线 + @param node_id: 节点id + @return: 节点连线列表 + """ + return self.up_node_map.get(node_id) + + def get_next_edge_nodes(self, node_id) -> List[EdgeNode]: + """ + 根据节点id 获取当前连接目标节点和连线 + @param node_id: 节点id + @return: 节点连线列表 + """ + return self.next_node_map.get(node_id) + + def get_up_nodes(self, node_id) -> List[Node]: + """ + 根据节点id 获取当前连接前置节点 + @param node_id: 节点id + @return: 节点列表 + """ + return [en.node for en in self.up_node_map.get(node_id)] + + def get_next_nodes(self, node_id) -> List[Node]: + """ + 根据节点id 获取当前连接目标节点 + @param node_id: 节点id + @return: 节点列表 + """ + return [en.node for en in self.next_node_map.get(node_id, [])] + + @staticmethod + def new_instance(flow_obj: Dict): + nodes = flow_obj.get('nodes') + edges = flow_obj.get('edges') + nodes = [Node(node.get('id'), node.get('type'), **node) + for node in nodes] + edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in edges] + return Workflow(nodes, edges) + + def get_start_node(self): + start_node_list = [node for node in self.nodes if node.id == 'start-node'] + return start_node_list[0] + + def get_search_node(self): + return [node for node in self.nodes if node.type == 'search-dataset-node'] + + def is_valid(self): + """ + 校验工作流数据 + """ + self.is_valid_model_params() + self.is_valid_start_node() + self.is_valid_base_node() + self.is_valid_work_flow() + + @staticmethod + def is_valid_node_params(node: Node): + from application.flow.step_node import get_node + get_node(node.type)(node, None, None) + + def is_valid_node(self, node: Node): + self.is_valid_node_params(node) + if node.type == 'condition-node': + branch_list = node.properties.get('node_data').get('branch') + for branch in branch_list: + source_anchor_id = f"{node.id}_{branch.get('id')}_right" + edge_list = [edge for edge in self.edges if edge.sourceAnchorId == source_anchor_id] + if len(edge_list) == 0: + raise AppApiException(500, + _('The branch {branch} of the {node} node needs to be connected').format( + node=node.properties.get("stepName"), branch=branch.get("type"))) + + else: + edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id] + if len(edge_list) == 0 and not end_nodes.__contains__(node.type): + raise AppApiException(500, _("{node} Nodes cannot be considered as end nodes").format( + node=node.properties.get("stepName"))) + + def is_valid_work_flow(self, up_node=None): + if up_node is None: + up_node = self.get_start_node() + self.is_valid_node(up_node) + next_nodes = self.get_next_nodes(up_node) + for next_node in next_nodes: + self.is_valid_work_flow(next_node) + + def is_valid_start_node(self): + start_node_list = [node for node in self.nodes if node.id == 'start-node'] + if len(start_node_list) == 0: + raise AppApiException(500, _('The starting node is required')) + if len(start_node_list) > 1: + raise AppApiException(500, _('There can only be one starting node')) + + def is_valid_model_params(self): + node_list = [node for node in self.nodes if (node.type == 'ai-chat-node' or node.type == 'question-node')] + for node in node_list: + model = QuerySet(Model).filter(id=node.properties.get('node_data', {}).get('model_id')).first() + if model is None: + raise ValidationError(ErrorDetail( + _('The node {node} model does not exist').format(node=node.properties.get("stepName")))) + credential = get_model_credential(model.provider, model.model_type, model.model_name) + model_params_setting = node.properties.get('node_data', {}).get('model_params_setting') + model_params_setting_form = credential.get_model_params_setting_form( + model.model_name) + if model_params_setting is None: + model_params_setting = model_params_setting_form.get_default_form_data() + node.properties.get('node_data', {})['model_params_setting'] = model_params_setting + if node.properties.get('status', 200) != 200: + raise ValidationError( + ErrorDetail(_("Node {node} is unavailable").format(node.properties.get("stepName")))) + node_list = [node for node in self.nodes if (node.type == 'function-lib-node')] + for node in node_list: + function_lib_id = node.properties.get('node_data', {}).get('function_lib_id') + if function_lib_id is None: + raise ValidationError(ErrorDetail( + _('The library ID of node {node} cannot be empty').format(node=node.properties.get("stepName")))) + f_lib = QuerySet(Tool).filter(id=function_lib_id).first() + if f_lib is None: + raise ValidationError(ErrorDetail(_("The function library for node {node} is not available").format( + node=node.properties.get("stepName")))) + + def is_valid_base_node(self): + base_node_list = [node for node in self.nodes if node.id == 'base-node'] + if len(base_node_list) == 0: + raise AppApiException(500, _('Basic information node is required')) + if len(base_node_list) > 1: + raise AppApiException(500, _('There can only be one basic information node')) diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 7535be2a9..e0065e6f4 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -15,165 +15,21 @@ from functools import reduce from typing import List, Dict from django.db import close_old_connections -from django.db.models import QuerySet from django.utils import translation from django.utils.translation import get_language -from django.utils.translation import gettext as _ from langchain_core.prompts import PromptTemplate from rest_framework import status -from rest_framework.exceptions import ErrorDetail, ValidationError from application.flow import tools +from application.flow.common import Workflow from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult from application.flow.step_node import get_node -from common.exception.app_exception import AppApiException from common.handle.base_to_response import BaseToResponse from common.handle.impl.response.system_to_response import SystemToResponse -from tools.models.tool import Tool -from models_provider.models import Model -from models_provider.tools import get_model_credential executor = ThreadPoolExecutor(max_workers=200) -class Edge: - def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords): - self.id = _id - self.type = _type - self.sourceNodeId = sourceNodeId - self.targetNodeId = targetNodeId - for keyword in keywords: - self.__setattr__(keyword, keywords.get(keyword)) - - -class Node: - def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs): - self.id = _id - self.type = _type - self.x = x - self.y = y - self.properties = properties - for keyword in kwargs: - self.__setattr__(keyword, kwargs.get(keyword)) - - -end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node', - 'image-understand-node', 'speech-to-text-node', 'text-to-speech-node', 'image-generate-node'] - - -class Flow: - def __init__(self, nodes: List[Node], edges: List[Edge]): - self.nodes = nodes - self.edges = edges - - @staticmethod - def new_instance(flow_obj: Dict): - nodes = flow_obj.get('nodes') - edges = flow_obj.get('edges') - nodes = [Node(node.get('id'), node.get('type'), **node) - for node in nodes] - edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in edges] - return Flow(nodes, edges) - - def get_start_node(self): - start_node_list = [node for node in self.nodes if node.id == 'start-node'] - return start_node_list[0] - - def get_search_node(self): - return [node for node in self.nodes if node.type == 'search-dataset-node'] - - def is_valid(self): - """ - 校验工作流数据 - """ - self.is_valid_model_params() - self.is_valid_start_node() - self.is_valid_base_node() - self.is_valid_work_flow() - - @staticmethod - def is_valid_node_params(node: Node): - get_node(node.type)(node, None, None) - - def is_valid_node(self, node: Node): - self.is_valid_node_params(node) - if node.type == 'condition-node': - branch_list = node.properties.get('node_data').get('branch') - for branch in branch_list: - source_anchor_id = f"{node.id}_{branch.get('id')}_right" - edge_list = [edge for edge in self.edges if edge.sourceAnchorId == source_anchor_id] - if len(edge_list) == 0: - raise AppApiException(500, - _('The branch {branch} of the {node} node needs to be connected').format( - node=node.properties.get("stepName"), branch=branch.get("type"))) - - else: - edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id] - if len(edge_list) == 0 and not end_nodes.__contains__(node.type): - raise AppApiException(500, _("{node} Nodes cannot be considered as end nodes").format( - node=node.properties.get("stepName"))) - - def get_next_nodes(self, node: Node): - edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id] - node_list = reduce(lambda x, y: [*x, *y], - [[node for node in self.nodes if node.id == edge.targetNodeId] for edge in edge_list], - []) - if len(node_list) == 0 and not end_nodes.__contains__(node.type): - raise AppApiException(500, - _("The next node that does not exist")) - return node_list - - def is_valid_work_flow(self, up_node=None): - if up_node is None: - up_node = self.get_start_node() - self.is_valid_node(up_node) - next_nodes = self.get_next_nodes(up_node) - for next_node in next_nodes: - self.is_valid_work_flow(next_node) - - def is_valid_start_node(self): - start_node_list = [node for node in self.nodes if node.id == 'start-node'] - if len(start_node_list) == 0: - raise AppApiException(500, _('The starting node is required')) - if len(start_node_list) > 1: - raise AppApiException(500, _('There can only be one starting node')) - - def is_valid_model_params(self): - node_list = [node for node in self.nodes if (node.type == 'ai-chat-node' or node.type == 'question-node')] - for node in node_list: - model = QuerySet(Model).filter(id=node.properties.get('node_data', {}).get('model_id')).first() - if model is None: - raise ValidationError(ErrorDetail( - _('The node {node} model does not exist').format(node=node.properties.get("stepName")))) - credential = get_model_credential(model.provider, model.model_type, model.model_name) - model_params_setting = node.properties.get('node_data', {}).get('model_params_setting') - model_params_setting_form = credential.get_model_params_setting_form( - model.model_name) - if model_params_setting is None: - model_params_setting = model_params_setting_form.get_default_form_data() - node.properties.get('node_data', {})['model_params_setting'] = model_params_setting - if node.properties.get('status', 200) != 200: - raise ValidationError( - ErrorDetail(_("Node {node} is unavailable").format(node.properties.get("stepName")))) - node_list = [node for node in self.nodes if (node.type == 'function-lib-node')] - for node in node_list: - function_lib_id = node.properties.get('node_data', {}).get('function_lib_id') - if function_lib_id is None: - raise ValidationError(ErrorDetail( - _('The library ID of node {node} cannot be empty').format(node=node.properties.get("stepName")))) - f_lib = QuerySet(Tool).filter(id=function_lib_id).first() - if f_lib is None: - raise ValidationError(ErrorDetail(_("The function library for node {node} is not available").format( - node=node.properties.get("stepName")))) - - def is_valid_base_node(self): - base_node_list = [node for node in self.nodes if node.id == 'base-node'] - if len(base_node_list) == 0: - raise AppApiException(500, _('Basic information node is required')) - if len(base_node_list) > 1: - raise AppApiException(500, _('There can only be one basic information node')) - - class NodeResultFuture: def __init__(self, r, e, status=200): self.r = r @@ -234,7 +90,7 @@ class NodeChunkManage: class WorkflowManage: - def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler, + def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostHandler, base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None, document_list=None, audio_list=None, @@ -598,15 +454,14 @@ class WorkflowManage: """ 是否有下一个可运行的节点 """ - if node_result is not None and node_result.is_assertion_result(): - for edge in self.flow.edges: + next_edge_node_list = self.flow.get_next_edge_nodes(current_node.id) or [] + for next_edge_node in next_edge_node_list: + if node_result is not None and node_result.is_assertion_result(): + edge = next_edge_node.edge if (edge.sourceNodeId == current_node.id and f"{edge.sourceNodeId}_{node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): return True - else: - for edge in self.flow.edges: - if edge.sourceNodeId == current_node.id: - return True + return len(next_edge_node_list) > 0 def has_next_node(self, node_result: NodeResult | None): """ @@ -656,26 +511,6 @@ class WorkflowManage: return [[]] return [[item.to_dict() for item in r] for r in result] - def get_next_node(self): - """ - 获取下一个可运行的所有节点 - """ - if self.current_node is None: - node = self.get_start_node() - node_instance = get_node(node.type)(node, self.params, self) - return node_instance - if self.current_result is not None and self.current_result.is_assertion_result(): - for edge in self.flow.edges: - if (edge.sourceNodeId == self.current_node.id and - f"{edge.sourceNodeId}_{self.current_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): - return self.get_node_cls_by_id(edge.targetNodeId) - else: - for edge in self.flow.edges: - if edge.sourceNodeId == self.current_node.id: - return self.get_node_cls_by_id(edge.targetNodeId) - - return None - @staticmethod def dependent_node(up_node_id, node): if not node.node_chunk.is_end(): @@ -712,14 +547,14 @@ class WorkflowManage: if current_node_result.is_interrupt_exec(current_node): return [] node_list = [] + next_edge_node_list = self.flow.get_next_edge_nodes(current_node.id) or [] if current_node_result is not None and current_node_result.is_assertion_result(): - for edge in self.flow.edges: - if (edge.sourceNodeId == current_node.id and + for edge_node in next_edge_node_list: + edge = edge_node.edge + next_node = edge_node.node + if ( f"{edge.sourceNodeId}_{current_node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): - next_node = [node for node in self.flow.nodes if node.id == edge.targetNodeId] - if len(next_node) == 0: - continue - if next_node[0].properties.get('condition', "AND") == 'AND': + if next_node.properties.get('condition', "AND") == 'AND': if self.dependent_node_been_executed(edge.targetNodeId): node_list.append( self.get_node_cls_by_id(edge.targetNodeId, @@ -729,20 +564,19 @@ class WorkflowManage: self.get_node_cls_by_id(edge.targetNodeId, [*current_node.up_node_id_list, current_node.node.id])) else: - for edge in self.flow.edges: - if edge.sourceNodeId == current_node.id: - next_node = [node for node in self.flow.nodes if node.id == edge.targetNodeId] - if len(next_node) == 0: - continue - if next_node[0].properties.get('condition', "AND") == 'AND': - if self.dependent_node_been_executed(edge.targetNodeId): - node_list.append( - self.get_node_cls_by_id(edge.targetNodeId, - [*current_node.up_node_id_list, current_node.node.id])) - else: + for edge_node in next_edge_node_list: + edge = edge_node.edge + next_node = edge_node.node + if next_node.properties.get('condition', "AND") == 'AND': + if self.dependent_node_been_executed(edge.targetNodeId): node_list.append( self.get_node_cls_by_id(edge.targetNodeId, [*current_node.up_node_id_list, current_node.node.id])) + else: + node_list.append( + self.get_node_cls_by_id(edge.targetNodeId, + [*current_node.up_node_id_list, current_node.node.id])) + return node_list def get_reference_field(self, node_id: str, fields: List[str]): diff --git a/apps/application/serializers/application.py b/apps/application/serializers/application.py index 3f60a5b12..aafaf3823 100644 --- a/apps/application/serializers/application.py +++ b/apps/application/serializers/application.py @@ -22,10 +22,10 @@ from django.utils.translation import gettext_lazy as _ from rest_framework import serializers, status from rest_framework.utils.formatting import lazy_format +from application.flow.common import Workflow from application.models.application import Application, ApplicationTypeChoices, ApplicationKnowledgeMapping, \ ApplicationFolder, WorkFlowVersion from application.models.application_access_token import ApplicationAccessToken -from application.flow.workflow_manage import Flow from common import result from common.database_model_manage.database_model_manage import DatabaseModelManage from common.db.search import native_search, native_page_search @@ -589,7 +589,7 @@ class ApplicationOperateSerializer(serializers.Serializer): work_flow = instance.get('work_flow') if work_flow is None: raise AppApiException(500, _("work_flow is a required field")) - Flow.new_instance(work_flow).is_valid() + Workflow.new_instance(work_flow).is_valid() base_node = get_base_node_work_flow(work_flow) if base_node is not None: node_data = base_node.get('properties').get('node_data') diff --git a/apps/chat/serializers/chat.py b/apps/chat/serializers/chat.py index 2e8d7289f..047ae564a 100644 --- a/apps/chat/serializers/chat.py +++ b/apps/chat/serializers/chat.py @@ -22,9 +22,9 @@ from application.chat_pipeline.step.generate_human_message_step.impl.base_genera BaseGenerateHumanMessageStep from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep -from application.flow.common import Answer +from application.flow.common import Answer, Workflow from application.flow.i_step_node import WorkFlowPostHandler -from application.flow.workflow_manage import WorkflowManage, Flow +from application.flow.workflow_manage import WorkflowManage from application.models import Application, ApplicationTypeChoices, WorkFlowVersion, ApplicationKnowledgeMapping, \ ChatUserType, ApplicationChatUserStats, ApplicationAccessToken, ChatRecord, Chat from application.serializers.common import ChatInfo @@ -229,7 +229,7 @@ class ChatSerializers(serializers.Serializer): work_flow = chat_info.work_flow_version.work_flow else: work_flow = chat_info.application.work_flow - work_flow_manage = WorkflowManage(Flow.new_instance(work_flow), + work_flow_manage = WorkflowManage(Workflow.new_instance(work_flow), {'history_chat_record': history_chat_record, 'question': message, 'chat_id': chat_info.chat_id, 'chat_record_id': str( uuid.uuid1()) if chat_record is None else chat_record.id, diff --git a/ui/package.json b/ui/package.json index 98d65a406..272bc79d9 100644 --- a/ui/package.json +++ b/ui/package.json @@ -13,7 +13,7 @@ "format": "prettier --write src/" }, "dependencies": { - "@antv/layout": "^1.2.14-beta.8", + "@antv/layout": "^0.3.1", "@codemirror/lang-json": "^6.0.1", "@codemirror/lang-python": "^6.2.1", "@codemirror/theme-one-dark": "^6.1.2", diff --git a/ui/src/workflow/plugins/dagre.ts b/ui/src/workflow/plugins/dagre.ts index 67b72fdad..c89813daa 100644 --- a/ui/src/workflow/plugins/dagre.ts +++ b/ui/src/workflow/plugins/dagre.ts @@ -40,7 +40,7 @@ export default class Dagre { nodesep, ranksep, begin: [120, 120], - ...option + ...option, } const layoutInstance = new DagreLayout(this.option) const layoutData = layoutInstance.layout({ @@ -48,15 +48,15 @@ export default class Dagre { id: node.id, size: { width: node.width, - height: node.height + height: node.height, }, - model: node + model: node, })), edges: edges.map((edge: any) => ({ source: edge.sourceNodeId, target: edge.targetNodeId, - model: edge - })) + model: edge, + })), }) layoutData.nodes?.forEach((node: any) => {