From 2e331dcf56fdd1de75bd31f03d194bf850fef258 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Sun, 13 Oct 2024 12:00:13 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=B7=A5=E4=BD=9C=E6=B5=81=E7=BC=96?= =?UTF-8?q?=E6=8E=92=E6=94=AF=E6=8C=81=E5=B9=B6=E8=A1=8C=20#1154=20(#1362)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/application/flow/i_step_node.py | 3 +- .../ai_chat_step_node/impl/base_chat_node.py | 2 +- .../impl/base_function_lib_node.py | 2 +- apps/application/flow/workflow_manage.py | 367 +++++++++++++----- ui/src/workflow/common/app-node.ts | 27 +- ui/src/workflow/common/validate.ts | 4 - 6 files changed, 289 insertions(+), 116 deletions(-) diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 2d5e38881..48c373307 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -28,7 +28,7 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow): if step_variable is not None: for key in step_variable: node.context[key] = step_variable[key] - if workflow.is_result() and 'answer' in step_variable: + if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'answer' in step_variable: answer = step_variable['answer'] yield answer workflow.answer += answer @@ -166,6 +166,7 @@ class INode: def get_write_error_context(self, e): self.status = 500 self.err_message = str(e) + self.context['run_time'] = time.time() - self.context['start_time'] def write_error_context(answer, status=200): pass diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index 9b41fb83f..daa7b452f 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -31,7 +31,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo node.context['history_message'] = node_variable['history_message'] node.context['question'] = node_variable['question'] node.context['run_time'] = time.time() - node.context['start_time'] - if workflow.is_result(): + if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): workflow.answer += answer diff --git a/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py index 94e2f5c54..3ae9f7b89 100644 --- a/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py +++ b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py @@ -26,7 +26,7 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow): if step_variable is not None: for key in step_variable: node.context[key] = step_variable[key] - if workflow.is_result() and 'result' in step_variable: + if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable: result = str(step_variable['result']) + '\n' yield result workflow.answer += result diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 340baa724..070c86525 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -7,7 +7,9 @@ @desc: """ import json +import threading import traceback +from concurrent.futures import ThreadPoolExecutor from functools import reduce from typing import List, Dict @@ -26,6 +28,8 @@ from function_lib.models.function import FunctionLib from setting.models import Model from setting.models_provider import get_model_credential +executor = ThreadPoolExecutor(max_workers=50) + class Edge: def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords): @@ -95,17 +99,11 @@ class Flow: if len(edge_list) == 0: raise AppApiException(500, f'{node.properties.get("stepName")} 节点的{branch.get("type")}分支需要连接') - elif len(edge_list) > 1: - raise AppApiException(500, - f'{node.properties.get("stepName")} 节点的{branch.get("type")}分支不能连接俩个节点') else: edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id] if len(edge_list) == 0 and not end_nodes.__contains__(node.type): raise AppApiException(500, f'{node.properties.get("stepName")} 节点不能当做结束节点') - elif len(edge_list) > 1: - raise AppApiException(500, - f'{node.properties.get("stepName")} 节点不能连接俩个节点') def get_next_nodes(self, node: Node): edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id] @@ -165,6 +163,77 @@ class Flow: raise AppApiException(500, '基本信息节点只能有一个') +class NodeResultFuture: + def __init__(self, r, e, status=200): + self.r = r + self.e = e + self.status = status + + def result(self): + if self.status == 200: + return self.r + else: + raise self.e + + +def await_result(result, timeout=1): + try: + result.result(timeout) + return False + except Exception as e: + return True + + +class NodeChunkManage: + + def __init__(self, work_flow): + self.node_chunk_list = [] + self.current_node_chunk = None + self.work_flow = work_flow + + def add_node_chunk(self, node_chunk): + self.node_chunk_list.append(node_chunk) + + def pop(self): + if self.current_node_chunk is None: + try: + current_node_chunk = self.node_chunk_list.pop(0) + self.current_node_chunk = current_node_chunk + except IndexError as e: + pass + if self.current_node_chunk is not None: + try: + chunk = self.current_node_chunk.chunk_list.pop(0) + return chunk + except IndexError as e: + if self.current_node_chunk.is_end(): + self.current_node_chunk = None + if len(self.work_flow.answer) > 0: + chunk = self.work_flow.base_to_response.to_stream_chunk_response( + self.work_flow.params['chat_id'], + self.work_flow.params['chat_record_id'], + '\n\n', False, 0, 0) + self.work_flow.answer += '\n\n' + return chunk + return self.pop() + return None + + +class NodeChunk: + def __init__(self): + self.status = 0 + self.chunk_list = [] + + def add_chunk(self, chunk): + self.chunk_list.append(chunk) + + def end(self): + self.status = 200 + + def is_end(self): + return self.status == 200 + + class WorkflowManage: def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler, base_to_response: BaseToResponse = SystemToResponse(), form_data=None): @@ -173,12 +242,15 @@ class WorkflowManage: self.form_data = form_data self.params = params self.flow = flow + self.lock = threading.Lock() self.context = {} self.node_context = [] + self.node_chunk_manage = NodeChunkManage(self) self.work_flow_post_handler = work_flow_post_handler self.current_node = None self.current_result = None self.answer = "" + self.status = 0 self.base_to_response = base_to_response def run(self): @@ -187,115 +259,180 @@ class WorkflowManage: return self.run_block() def run_block(self): - try: - while self.has_next_node(self.current_result): - self.current_node = self.get_next_node() - self.current_node.valid_args(self.current_node.node_params, self.current_node.workflow_params) - self.node_context.append(self.current_node) - self.current_result = self.current_node.run() - result = self.current_result.write_context(self.current_node, self) - if result is not None: - list(result) - if not self.has_next_node(self.current_result): - details = self.get_runtime_details() - message_tokens = sum([row.get('message_tokens') for row in details.values() if - 'message_tokens' in row and row.get('message_tokens') is not None]) - answer_tokens = sum([row.get('answer_tokens') for row in details.values() if - 'answer_tokens' in row and row.get('answer_tokens') is not None]) - self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], - self.answer, - self) - return self.base_to_response.to_block_response(self.params['chat_id'], - self.params['chat_record_id'], self.answer, True - , message_tokens, answer_tokens) - except Exception as e: - traceback.print_exc() - self.current_node.get_write_error_context(e) - self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], - self.answer, - self) - return self.base_to_response.to_block_response(self.params['chat_id'], self.params['chat_record_id'], - str(e), True, - 0, 0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR) + """ + 非流式响应 + @return: 结果 + """ + result = self.run_chain_async(None) + result.result() + details = self.get_runtime_details() + message_tokens = sum([row.get('message_tokens') for row in details.values() if + 'message_tokens' in row and row.get('message_tokens') is not None]) + answer_tokens = sum([row.get('answer_tokens') for row in details.values() if + 'answer_tokens' in row and row.get('answer_tokens') is not None]) + self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], + self.answer, + self) + return self.base_to_response.to_block_response(self.params['chat_id'], + self.params['chat_record_id'], self.answer, True + , message_tokens, answer_tokens, + _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR) def run_stream(self): - return tools.to_stream_response_simple(self.stream_event()) - - def stream_event(self): - try: - while self.has_next_node(self.current_result): - self.current_node = self.get_next_node() - self.node_context.append(self.current_node) - self.current_node.valid_args(self.current_node.node_params, self.current_node.workflow_params) - self.current_result = self.current_node.run() - result = self.current_result.write_context(self.current_node, self) - has_next_node = self.has_next_node(self.current_result) - if result is not None: - if self.is_result(): - for r in result: - yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'], - self.params['chat_record_id'], - r, False, 0, 0) - if has_next_node: - yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'], - self.params['chat_record_id'], - '\n', False, 0, 0) - self.answer += '\n' - else: - list(result) - if not has_next_node: - details = self.get_runtime_details() - message_tokens = sum([row.get('message_tokens') for row in details.values() if - 'message_tokens' in row and row.get('message_tokens') is not None]) - answer_tokens = sum([row.get('answer_tokens') for row in details.values() if - 'answer_tokens' in row and row.get('answer_tokens') is not None]) - yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'], - self.params['chat_record_id'], - '', True, message_tokens, answer_tokens) - break - self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], - self.answer, - self) - except Exception as e: - self.current_node.get_write_error_context(e) - self.answer += str(e) - self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], - self.answer, - self) - yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'], self.params['chat_record_id'], - str(e), True, 0, 0) - - def is_result(self): """ - 判断是否是返回节点 + 流式响应 @return: """ - return self.current_node.node_params.get('is_result', not self.has_next_node( - self.current_result)) if self.current_node.node_params is not None else False + result = self.run_chain_async(None) + return tools.to_stream_response_simple(self.await_result(result)) + + def await_result(self, result): + try: + while await_result(result): + while True: + chunk = self.node_chunk_manage.pop() + if chunk is not None: + yield chunk + else: + break + while True: + chunk = self.node_chunk_manage.pop() + if chunk is None: + break + yield chunk + finally: + self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], + self.answer, + self) + + def run_chain_async(self, current_node): + future = executor.submit(self.run_chain, current_node) + return future + + def run_chain(self, current_node): + if current_node is None: + start_node = self.get_start_node() + current_node = get_node(start_node.type)(start_node, self.params, self) + node_result_future = self.run_node_future(current_node) + try: + is_stream = self.params.get('stream', True) + # 处理节点响应 + result = self.hand_event_node_result(current_node, + node_result_future) if is_stream else self.hand_node_result( + current_node, node_result_future) + with self.lock: + if current_node.status == 500: + return + node_list = self.get_next_node_list(current_node, result) + # 获取到可执行的子节点 + result_list = [] + for node in node_list: + result = self.run_chain_async(node) + result_list.append(result) + [r.result() for r in result_list] + if self.status == 0: + self.status = 200 + except Exception as e: + traceback.print_exc() + + def hand_node_result(self, current_node, node_result_future): + try: + current_result = node_result_future.result() + result = current_result.write_context(current_node, self) + if result is not None: + # 阻塞获取结果 + list(result) + # 添加节点 + self.node_context.append(current_node) + return current_result + except Exception as e: + # 添加节点 + self.node_context.append(current_node) + traceback.print_exc() + self.status = 500 + current_node.get_write_error_context(e) + self.answer += str(e) + + def hand_event_node_result(self, current_node, node_result_future): + try: + current_result = node_result_future.result() + result = current_result.write_context(current_node, self) + if result is not None: + if self.is_result(current_node, current_result): + node_chunk = NodeChunk() + self.node_chunk_manage.add_node_chunk(node_chunk) + for r in result: + chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], + self.params['chat_record_id'], + r, False, 0, 0) + node_chunk.add_chunk(chunk) + node_chunk.end() + else: + list(result) + # 添加节点 + self.node_context.append(current_node) + return current_result + except Exception as e: + # 添加节点 + self.node_context.append(current_node) + traceback.print_exc() + self.status = 500 + current_node.get_write_error_context(e) + self.answer += str(e) + chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], + self.params['chat_record_id'], + str(e), False, 0, 0) + node_chunk = NodeChunk() + self.node_chunk_manage.add_node_chunk(node_chunk) + node_chunk.add_chunk(chunk) + node_chunk.end() + + def run_node_async(self, node): + future = executor.submit(self.run_node, node) + return future + + def run_node_future(self, node): + try: + node.valid_args(node.node_params, node.workflow_params) + result = self.run_node(node) + return NodeResultFuture(result, None, 200) + except Exception as e: + return NodeResultFuture(None, e, 500) + + def run_node(self, node): + result = node.run() + result.write_context(node, self) + return result + + def is_result(self, current_node, current_node_result): + return current_node.node_params.get('is_result', not self._has_next_node( + current_node, current_node_result)) if current_node.node_params is not None else False def get_chunk_content(self, chunk, is_end=False): return 'data: ' + json.dumps( {'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True, 'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n" + def _has_next_node(self, current_node, node_result: NodeResult | None): + """ + 是否有下一个可运行的节点 + """ + if node_result is not None and node_result.is_assertion_result(): + for edge in self.flow.edges: + if (edge.sourceNodeId == current_node.id and + f"{edge.sourceNodeId}_{node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): + return True + else: + for edge in self.flow.edges: + if edge.sourceNodeId == current_node.id: + return True + def has_next_node(self, node_result: NodeResult | None): """ 是否有下一个可运行的节点 """ - if self.current_node is None: - if self.get_start_node() is not None: - return True - else: - if node_result is not None and node_result.is_assertion_result(): - for edge in self.flow.edges: - if (edge.sourceNodeId == self.current_node.id and - f"{edge.sourceNodeId}_{node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): - return True - else: - for edge in self.flow.edges: - if edge.sourceNodeId == self.current_node.id: - return True - return False + return self._has_next_node(self.get_start_node() if self.current_node is None else self.current_node, + node_result) def get_runtime_details(self): details_result = {} @@ -325,9 +462,37 @@ class WorkflowManage: return None + def dependent_node_been_executed(self, node_id): + """ + 判断依赖节点是否都已执行 + @param node_id: 需要判断的节点id + @return: + """ + up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id] + return all([any([node.id == up_node_id for node in self.node_context]) for up_node_id in up_node_id_list]) + + def get_next_node_list(self, current_node, current_node_result): + """ + 获取下一个可执行节点列表 + @param current_node: 当前可执行节点 + @param current_node_result: 当前可执行节点结果 + @return: 可执行节点列表 + """ + node_list = [] + if current_node_result is not None and current_node_result.is_assertion_result(): + for edge in self.flow.edges: + if (edge.sourceNodeId == current_node.id and + f"{edge.sourceNodeId}_{current_node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): + if self.dependent_node_been_executed(edge.targetNodeId): + node_list.append(self.get_node_cls_by_id(edge.targetNodeId)) + else: + for edge in self.flow.edges: + if edge.sourceNodeId == current_node.id and self.dependent_node_been_executed(edge.targetNodeId): + node_list.append(self.get_node_cls_by_id(edge.targetNodeId)) + return node_list + def get_reference_field(self, node_id: str, fields: List[str]): """ - @param node_id: 节点id @param fields: 字段 @return: diff --git a/ui/src/workflow/common/app-node.ts b/ui/src/workflow/common/app-node.ts index b84093816..0528da182 100644 --- a/ui/src/workflow/common/app-node.ts +++ b/ui/src/workflow/common/app-node.ts @@ -73,7 +73,7 @@ class AppNode extends HtmlResize.view { lh('div', { style: { zindex: 0 }, onClick: () => { - if (!isConnect && type == 'right') { + if (type == 'right') { this.props.model.openNodeMenu(anchorData) } }, @@ -193,23 +193,34 @@ class AppNodeModel extends HtmlResize.model { get_width() { return this.properties?.width || 340 } + setAttributes() { this.width = this.get_width() - + const isLoop=(node_id:string,target_node_id:string)=>{ + const up_node_list=this.graphModel.getNodeIncomingNode(node_id) + for (const index in up_node_list) { + const item=up_node_list[index] + if(item.id===target_node_id){ + return true + }else{ + const result= isLoop(item.id,target_node_id) + if(result){ + return true + } + } + } + return false + } const circleOnlyAsTarget = { message: '只允许从右边的锚点连出', validate: (sourceNode: any, targetNode: any, sourceAnchor: any) => { return sourceAnchor.type === 'right' } } - this.sourceRules.push({ - message: '只允许连一个节点', + message: '不可循环连线', validate: (sourceNode: any, targetNode: any, sourceAnchor: any, targetAnchor: any) => { - return !this.graphModel.edges.some( - (item) => - item.sourceAnchorId === sourceAnchor.id || item.targetAnchorId === targetAnchor.id - ) + return !isLoop(sourceNode.id,targetNode.id) } }) diff --git a/ui/src/workflow/common/validate.ts b/ui/src/workflow/common/validate.ts index 3c8e6a494..000d13c0c 100644 --- a/ui/src/workflow/common/validate.ts +++ b/ui/src/workflow/common/validate.ts @@ -129,16 +129,12 @@ export class WorkFlowInstance { const edge_list = this.edges.filter((edge) => edge.sourceAnchorId == source_anchor_id) if (edge_list.length == 0) { throw `${node.properties.stepName} 节点的${branch.type}分支需要连接` - } else if (edge_list.length > 1) { - throw `${node.properties.stepName} 节点的${branch.type}分支不能连接俩个节点` } } } else { const edge_list = this.edges.filter((edge) => edge.sourceNodeId == node.id) if (edge_list.length == 0 && !end_nodes.includes(node.type)) { throw `${node.properties.stepName} 节点不能当做结束节点` - } else if (edge_list.length > 1) { - throw `${node.properties.stepName} 节点不能连接俩个节点` } } if (node.properties.status && node.properties.status !== 200) {