diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 4e35fac2b..286b2b495 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -168,7 +168,7 @@ class INode: self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')] def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None, - get_node_params=lambda node: node.properties.get('node_data')): + get_node_params=lambda node: node.properties.get('node_data'), salt=None): # 当前步骤上下文,用于存储当前步骤信息 self.status = 200 self.err_message = '' @@ -188,7 +188,8 @@ class INode: self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS, "".join([*sorted(up_node_id_list), node.id]))), - "utf-8")).hexdigest() + "utf-8")).hexdigest() + ( + "__" + str(salt) if salt is not None else '') def valid_args(self, node_params, flow_params): flow_params_serializer_class = self.get_flow_params_serializer_class() diff --git a/apps/application/flow/loop_workflow_manage.py b/apps/application/flow/loop_workflow_manage.py new file mode 100644 index 000000000..16087ff8c --- /dev/null +++ b/apps/application/flow/loop_workflow_manage.py @@ -0,0 +1,167 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: workflow_manage.py + @date:2024/1/9 17:40 + @desc: +""" +from concurrent.futures import ThreadPoolExecutor +from typing import List + +from django.db import close_old_connections +from django.utils.translation import get_language +from langchain_core.prompts import PromptTemplate + +from application.flow.common import Workflow +from application.flow.i_step_node import WorkFlowPostHandler, INode +from application.flow.step_node import get_node +from application.flow.workflow_manage import WorkflowManage +from common.handle.base_to_response import BaseToResponse +from common.handle.impl.response.system_to_response import SystemToResponse + +executor = ThreadPoolExecutor(max_workers=200) + + +class NodeResultFuture: + def __init__(self, r, e, status=200): + self.r = r + self.e = e + self.status = status + + def result(self): + if self.status == 200: + return self.r + else: + raise self.e + + +def await_result(result, timeout=1): + try: + result.result(timeout) + return False + except Exception as e: + return True + + +class NodeChunkManage: + + def __init__(self, work_flow): + self.node_chunk_list = [] + self.current_node_chunk = None + self.work_flow = work_flow + + def add_node_chunk(self, node_chunk): + self.node_chunk_list.append(node_chunk) + + def contains(self, node_chunk): + return self.node_chunk_list.__contains__(node_chunk) + + def pop(self): + if self.current_node_chunk is None: + try: + current_node_chunk = self.node_chunk_list.pop(0) + self.current_node_chunk = current_node_chunk + except IndexError as e: + pass + if self.current_node_chunk is not None: + try: + chunk = self.current_node_chunk.chunk_list.pop(0) + return chunk + except IndexError as e: + if self.current_node_chunk.is_end(): + self.current_node_chunk = None + if self.work_flow.answer_is_not_empty(): + chunk = self.work_flow.base_to_response.to_stream_chunk_response( + self.work_flow.params['chat_id'], + self.work_flow.params['chat_record_id'], + '\n\n', False, 0, 0) + self.work_flow.append_answer('\n\n') + return chunk + return self.pop() + return None + + +class LoopWorkflowManage(WorkflowManage): + + def __init__(self, flow: Workflow, + params, + work_flow_post_handler: WorkFlowPostHandler, + parentWorkflowManage, + loop_params, + base_to_response: BaseToResponse = SystemToResponse(), start_node_id=None, + start_node_data=None, chat_record=None, child_node=None): + self.parentWorkflowManage = parentWorkflowManage + self.loop_params = loop_params + super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None, + None, + None, start_node_id, start_node_data, chat_record, child_node) + + def get_node_cls_by_id(self, node_id, up_node_id_list=None, + get_node_params=lambda node: node.properties.get('node_data')): + for node in self.flow.nodes: + if node.id == node_id: + node_instance = get_node(node.type)(node, + self.params, self, up_node_id_list, + get_node_params, + salt=self.get_index()) + return node_instance + return None + + def stream(self): + close_old_connections() + language = get_language() + self.run_chain_async(self.start_node, None, language) + return self.await_result() + + def get_index(self): + return self.loop_params.get('index') + + def get_start_node(self): + start_node_list = [node for node in self.flow.nodes if + ['loop-start-node'].__contains__(node.type)] + return start_node_list[0] + + def get_reference_field(self, node_id: str, fields: List[str]): + """ + @param node_id: 节点id + @param fields: 字段 + @return: + """ + if node_id == 'global': + return self.parentWorkflowManage.get_reference_field(node_id, fields) + elif node_id == 'chat': + return self.parentWorkflowManage.get_reference_field(node_id, fields) + else: + node = self.get_node_by_id(node_id) + if node: + return node.get_reference_field(fields) + return self.parentWorkflowManage.get_reference_field(node_id, fields) + + def get_workflow_content(self): + context = { + 'global': self.context, + 'chat': self.chat_context + } + + for node in self.node_context: + context[node.id] = node.context + return context + + def reset_prompt(self, prompt: str): + prompt = super().reset_prompt(prompt) + prompt = self.parentWorkflowManage.reset_prompt(prompt) + return prompt + + def generate_prompt(self, prompt: str): + """ + 格式化生成提示词 + @param prompt: 提示词信息 + @return: 格式化后的提示词 + """ + + context = {**self.get_workflow_content(), **self.parentWorkflowManage.get_workflow_content()} + prompt = self.reset_prompt(prompt) + prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2') + value = prompt_template.format(context=context) + return value diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py index d816fe6e5..216239850 100644 --- a/apps/application/flow/step_node/__init__.py +++ b/apps/application/flow/step_node/__init__.py @@ -14,6 +14,8 @@ from .document_extract_node import * from .form_node import * from .image_generate_step_node import * from .image_understand_step_node import * +from .loop_node import * +from .loop_start_node import * from .mcp_node import BaseMcpNode from .question_node import * from .reranker_node import * @@ -30,7 +32,7 @@ node_list = [BaseStartStepNode, BaseChatNode, BaseSearchKnowledgeNode, BaseQuest BaseToolNodeNode, BaseToolLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode, BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode, - BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode] + BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode, BaseLoopNode, BaseLoopStartStepNode] def get_node(node_type): diff --git a/apps/application/flow/step_node/loop_node/__init__.py b/apps/application/flow/step_node/loop_node/__init__.py new file mode 100644 index 000000000..a5f59372b --- /dev/null +++ b/apps/application/flow/step_node/loop_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2025/3/11 18:24 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/loop_node/i_loop_node.py b/apps/application/flow/step_node/loop_node/i_loop_node.py new file mode 100644 index 000000000..6b2176513 --- /dev/null +++ b/apps/application/flow/step_node/loop_node/i_loop_node.py @@ -0,0 +1,56 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: i_loop_node.py + @date:2025/3/11 18:19 + @desc: +""" +from typing import Type + +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.exception.app_exception import AppApiException + + +class ILoopNodeSerializer(serializers.Serializer): + loop_type = serializers.CharField(required=True, label=_("loop_type")) + array = serializers.ListField(required=False, allow_null=True, + label=_("array")) + number = serializers.IntegerField(required=False, allow_null=True, + label=_("number")) + loop_body = serializers.DictField(required=True, label="循环体") + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + loop_type = self.data.get('loop_type') + if loop_type == 'ARRAY': + array = self.data.get('array') + if array is None or len(array) == 0: + message = _('{field}, this field is required.', field='array') + raise AppApiException(500, message) + elif loop_type == 'NUMBER': + number = self.data.get('number') + if number is None: + message = _('{field}, this field is required.', field='number') + raise AppApiException(500, message) + + +class ILoopNode(INode): + type = 'loop-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return ILoopNodeSerializer + + def _run(self): + array = self.node_params_serializer.data.get('array') + if self.node_params_serializer.data.get('loop_type') == 'ARRAY': + array = self.workflow_manage.get_reference_field( + array[0], + array[1:]) + return self.execute(**{**self.node_params_serializer.data, "array": array}, **self.flow_params_serializer.data) + + def execute(self, loop_type, array, number, loop_body, stream, **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/loop_node/impl/__init__.py b/apps/application/flow/step_node/loop_node/impl/__init__.py new file mode 100644 index 000000000..3cd082322 --- /dev/null +++ b/apps/application/flow/step_node/loop_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py.py + @date:2025/3/11 18:24 + @desc: +""" +from .base_loop_node import BaseLoopNode diff --git a/apps/application/flow/step_node/loop_node/impl/base_loop_node.py b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py new file mode 100644 index 000000000..f7daa53e9 --- /dev/null +++ b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py @@ -0,0 +1,252 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: base_loop_node.py + @date:2025/3/11 18:24 + @desc: +""" +import time +from typing import Dict, List + +from application.flow.common import Answer +from application.flow.i_step_node import NodeResult, WorkFlowPostHandler, INode +from application.flow.step_node.loop_node.i_loop_node import ILoopNode +from application.flow.tools import Reasoning +from application.models import ChatRecord +from common.handle.impl.response.loop_to_response import LoopToResponse + + +def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict): + return node.context.get('is_interrupt_exec', False) + + +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str, + reasoning_content: str): + node.context['answer'] = answer + node.context['run_time'] = time.time() - node.context['start_time'] + node.context['reasoning_content'] = reasoning_content + if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): + node.answer_text = answer + + +def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 (流式) + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点 + @param workflow: 工作流管理器 + """ + + response = node_variable.get('result') + workflow_manage = node_variable.get('workflow_manage') + answer = '' + reasoning_content = '' + for chunk in response: + content_chunk = chunk.get('content', '') + reasoning_content_chunk = chunk.get('reasoning_content', '') + reasoning_content += reasoning_content_chunk + answer += content_chunk + yield {'content': content_chunk, + 'reasoning_content': reasoning_content_chunk} + runtime_details = workflow_manage.get_runtime_details() + _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content) + + +def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点实例对象 + @param workflow: 工作流管理器 + """ + response = node_variable.get('result') + model_setting = node.context.get('model_setting', + {'reasoning_content_enable': False, 'reasoning_content_end': '', + 'reasoning_content_start': ''}) + reasoning = Reasoning(model_setting.get('reasoning_content_start'), model_setting.get('reasoning_content_end')) + reasoning_result = reasoning.get_reasoning_content(response) + reasoning_result_end = reasoning.get_end_reasoning_content() + content = reasoning_result.get('content') + reasoning_result_end.get('content') + if 'reasoning_content' in response.response_metadata: + reasoning_content = response.response_metadata.get('reasoning_content', '') + else: + reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get('reasoning_content') + _write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content) + + +def get_answer_list(instance, child_node_node_dict, runtime_node_id): + answer_list = instance.get_record_answer_list() + for a in answer_list: + _v = child_node_node_dict.get(a.get('runtime_node_id')) + if _v: + a['runtime_node_id'] = runtime_node_id + a['child_node'] = _v + return answer_list + + +def insert_or_replace(arr, index, value): + if index < len(arr): + arr[index] = value # 替换 + else: + # 在末尾插入足够多的None,然后替换最后一个 + arr.extend([None] * (index - len(arr) + 1)) + arr[index] = value + return arr + + +def loop_number(number: int, workflow_manage_new_instance, node: INode): + loop_global_data = {} + break_outer = False + is_interrupt_exec = False + loop_node_data = node.context.get('loop_node_data') or [] + loop_answer_data = node.context.get("loop_answer_data") or [] + current_index = node.context.get("current_index") or 0 + node_params = node.node_params + start_node_id = node_params.get('child_node', {}).get('runtime_node_id') + start_node_data = None + chat_record = None + child_node = None + if start_node_id: + chat_record_id = node_params.get('child_node', {}).get('chat_record_id') + child_node = node_params.get('child_node', {}).get('child_node') + start_node_data = node_params.get('node_data') + chat_record = ChatRecord(id=chat_record_id, answer_text_list=[], answer_text='', + details=loop_node_data[current_index]) + + for index in range(current_index, number): + """ + 指定次数循环 + @return: + """ + instance = workflow_manage_new_instance({'index': index, 'item': index}, loop_global_data, start_node_id, + start_node_data, chat_record, child_node) + response = instance.stream() + answer = '' + current_index = index + reasoning_content = '' + child_node_node_dict = {} + for chunk in response: + child_node = chunk.get('child_node') + runtime_node_id = chunk.get('runtime_node_id', '') + chat_record_id = chunk.get('chat_record_id', '') + child_node_node_dict[runtime_node_id] = { + 'runtime_node_id': runtime_node_id, + 'chat_record_id': chat_record_id, + 'child_node': child_node} + content_chunk = chunk.get('content', '') + reasoning_content_chunk = chunk.get('reasoning_content', '') + reasoning_content += reasoning_content_chunk + answer += content_chunk + yield chunk + node_type = chunk.get('node_type') + if node_type == 'form-node': + break_outer = True + is_interrupt_exec = True + start_node_id = None + start_node_data = None + chat_record = None + child_node = None + loop_global_data = instance.context + insert_or_replace(loop_node_data, index, instance.get_runtime_details()) + insert_or_replace(loop_answer_data, index, + get_answer_list(instance, child_node_node_dict, node.runtime_node_id)) + if break_outer: + break + node.context['is_interrupt_exec'] = is_interrupt_exec + node.context['loop_node_data'] = loop_node_data + node.context['loop_answer_data'] = loop_answer_data + node.context["index"] = current_index + node.context["item"] = current_index + + +def loop_array(array, workflow_manage_new_instance, node: INode): + loop_global_data = {} + loop_execute_details = [] + for item, index in zip(array, range(len(array))): + """ + 指定次数循环 + @return: + """ + instance = workflow_manage_new_instance({'index': index, 'item': item}, loop_global_data) + response = instance.stream() + for chunk in response: + yield chunk + node_type = chunk.get('node_type') + if node_type == 'form-node': + break + loop_global_data = instance.context + runtime_details = instance.get_runtime_details() + loop_execute_details.append(runtime_details) + node.context['loop_execute_details'] = loop_execute_details + + +def get_write_context(loop_type, array, number, loop_body, stream): + def inner_write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + if loop_type == 'ARRAY': + return loop_array(array, node_variable['workflow_manage_new_instance'], node) + return loop_number(number, node_variable['workflow_manage_new_instance'], node) + + return inner_write_context + + +class LoopWorkFlowPostHandler(WorkFlowPostHandler): + def handler(self, workflow): + pass + + +class BaseLoopNode(ILoopNode): + def save_context(self, details, workflow_manage): + self.context['result'] = details.get('result') + self.answer_text = str(details.get('result')) + + def get_answer_list(self) -> List[Answer] | None: + result = [] + for answer_list in (self.context.get("loop_answer_data") or []): + for a in answer_list: + if isinstance(a, dict): + result.append(Answer(**a)) + + return result + + def execute(self, loop_type, array, number, loop_body, stream, **kwargs) -> NodeResult: + from application.flow.loop_workflow_manage import LoopWorkflowManage, Workflow + def workflow_manage_new_instance(loop_data, global_data, start_node_id=None, + start_node_data=None, chat_record=None, child_node=None): + workflow_manage = LoopWorkflowManage(Workflow.new_instance(loop_body), self.workflow_manage.params, + LoopWorkFlowPostHandler( + self.workflow_manage.work_flow_post_handler.chat_info) + , + self.workflow_manage, + loop_data, + base_to_response=LoopToResponse(), + start_node_id=start_node_id, + start_node_data=start_node_data, + chat_record=chat_record, + child_node=child_node + ) + + return workflow_manage + + return NodeResult({'workflow_manage_new_instance': workflow_manage_new_instance}, {}, + _write_context=get_write_context(loop_type, array, number, loop_body, stream), + _is_interrupt=_is_interrupt_exec) + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + "result": self.context.get('result'), + "params": self.context.get('params'), + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'current_index': self.context.get("index"), + "current_item": self.context.get("item"), + 'loop_type': self.context.get("loop_type"), + 'status': self.status, + 'loop_node_data': self.context.get("loop_node_data"), + 'loop_answer_data': self.context.get("loop_answer_data"), + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/loop_start_node/__init__.py b/apps/application/flow/step_node/loop_start_node/__init__.py new file mode 100644 index 000000000..98a1afcd9 --- /dev/null +++ b/apps/application/flow/step_node/loop_start_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:30 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/loop_start_node/i_loop_start_node.py b/apps/application/flow/step_node/loop_start_node/i_loop_start_node.py new file mode 100644 index 000000000..21a059b76 --- /dev/null +++ b/apps/application/flow/step_node/loop_start_node/i_loop_start_node.py @@ -0,0 +1,20 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_start_node.py + @date:2024/6/3 16:54 + @desc: +""" + +from application.flow.i_step_node import INode, NodeResult + + +class ILoopStarNode(INode): + type = 'loop-start-node' + + def _run(self): + return self.execute(**self.flow_params_serializer.data) + + def execute(self, **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/loop_start_node/impl/__init__.py b/apps/application/flow/step_node/loop_start_node/impl/__init__.py new file mode 100644 index 000000000..76f972fce --- /dev/null +++ b/apps/application/flow/step_node/loop_start_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:36 + @desc: +""" +from .base_start_node import BaseLoopStartStepNode diff --git a/apps/application/flow/step_node/loop_start_node/impl/base_start_node.py b/apps/application/flow/step_node/loop_start_node/impl/base_start_node.py new file mode 100644 index 000000000..4f691f9e7 --- /dev/null +++ b/apps/application/flow/step_node/loop_start_node/impl/base_start_node.py @@ -0,0 +1,55 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_start_node.py + @date:2024/6/3 17:17 + @desc: +""" +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.loop_start_node.i_loop_start_node import ILoopStarNode + + +class BaseLoopStartStepNode(ILoopStarNode): + def save_context(self, details, workflow_manage): + self.context['index'] = details.get('current_index') + self.context['item'] = details.get('current_item') + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + pass + + def execute(self, **kwargs) -> NodeResult: + """ + 开始节点 初始化全局变量 + """ + loop_params = self.workflow_manage.loop_params + node_variable = { + 'index': loop_params.get("index"), + 'item': loop_params.get("item") + } + self.workflow_manage.chat_context = self.workflow_manage.get_chat_info().get_chat_variable() + return NodeResult(node_variable, {}) + + def get_details(self, index: int, **kwargs): + global_fields = [] + for field in self.node.properties.get('config')['globalFields']: + key = field['value'] + global_fields.append({ + 'label': field['label'], + 'key': key, + 'value': self.workflow_manage.context[key] if key in self.workflow_manage.context else '' + }) + return { + 'name': self.node.properties.get('stepName'), + "index": index, + "current_index": self.context.get('index'), + "current_item": self.context.get('item'), + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message, + } diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index d8591d6fc..ff6e09efc 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -187,6 +187,8 @@ class WorkflowManage: is_result = False if n.type == 'application-node': is_result = True + if n.type == 'loop-node': + is_result = True return {**n.properties.get('node_data'), 'form_data': start_node_data, 'node_data': start_node_data, 'child_node': self.child_node, 'is_result': is_result} @@ -194,6 +196,12 @@ class WorkflowManage: get_node_params=get_node_params) self.start_node.valid_args( {**self.start_node.node_params, 'form_data': start_node_data}, self.start_node.workflow_params) + if self.start_node.type == 'loop-node': + loop_node_data = node_details.get('loop_node_data', {}) + self.start_node.context['loop_node_data'] = loop_node_data + self.start_node.context['current_index'] = node_details.get('current_index') + self.start_node.context['current_item'] = node_details.get('current_item') + self.start_node.context['loop_answer_data']=node_details.get('loop_answer_data', {}) if self.start_node.type == 'application-node': application_node_dict = node_details.get('application_node_dict', {}) self.start_node.context['application_node_dict'] = application_node_dict @@ -500,6 +508,10 @@ class WorkflowManage: details_result[node.runtime_node_id] = details return details_result + def get_record_answer_list(self): + answer_text_list = self.get_answer_text_list() + return reduce(lambda pre, _n: [*pre, *_n], answer_text_list, []) + def get_answer_text_list(self): result = [] answer_list = reduce(lambda x, y: [*x, *y], @@ -546,10 +558,6 @@ class WorkflowManage: return all([any([self.dependent_node(up_node_id, node) for node in self.node_context]) for up_node_id in up_node_id_list]) - def get_up_node_id_list(self, node_id): - up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id] - return up_node_id_list - def get_next_node_list(self, current_node, current_node_result): """ 获取下一个可执行节点列表 diff --git a/apps/common/handle/impl/response/loop_to_response.py b/apps/common/handle/impl/response/loop_to_response.py new file mode 100644 index 000000000..7e6553ab7 --- /dev/null +++ b/apps/common/handle/impl/response/loop_to_response.py @@ -0,0 +1,27 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: LoopToResponse.py + @date:2025/3/12 17:21 + @desc: +""" +import json + +from common.handle.impl.response.system_to_response import SystemToResponse + + +class LoopToResponse(SystemToResponse): + + def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, + completion_tokens, + prompt_tokens, other_params: dict = None): + if other_params is None: + other_params = {} + return {'chat_id': str(chat_id), 'chat_record_id': str(chat_record_id), 'operate': True, + 'content': content, 'node_id': node_id, 'up_node_id_list': up_node_id_list, + 'is_end': is_end, + 'usage': {'completion_tokens': completion_tokens, + 'prompt_tokens': prompt_tokens, + 'total_tokens': completion_tokens + prompt_tokens}, + **other_params} diff --git a/ui/src/components/ai-chat/component/answer-content/index.vue b/ui/src/components/ai-chat/component/answer-content/index.vue index aec4ad7de..d42ff9fd9 100644 --- a/ui/src/components/ai-chat/component/answer-content/index.vue +++ b/ui/src/components/ai-chat/component/answer-content/index.vue @@ -2,8 +2,8 @@
diff --git a/ui/src/workflow/common/NodeContainer.vue b/ui/src/workflow/common/NodeContainer.vue index fdcc2ea19..07d5d0c27 100644 --- a/ui/src/workflow/common/NodeContainer.vue +++ b/ui/src/workflow/common/NodeContainer.vue @@ -334,7 +334,9 @@ const nodeFields = computed(() => { }) function showOperate(type: string) { - return type !== WorkflowType.Base && type !== WorkflowType.Start + return ![WorkflowType.Start, WorkflowType.Base, WorkflowType.LoopStartNode.toString()].includes( + type, + ) } const openNodeMenu = (anchorValue: any) => { showAnchor.value = true diff --git a/ui/src/workflow/common/app-node.ts b/ui/src/workflow/common/app-node.ts index 46250a236..253a3d420 100644 --- a/ui/src/workflow/common/app-node.ts +++ b/ui/src/workflow/common/app-node.ts @@ -18,7 +18,6 @@ const getNodeName = (nodes: Array, baseName: string) => { while (true) { if (index > 0) { name = baseName + index - console.log(name) } if (!nodes.some((node: any) => node.properties.stepName === name.trim())) { return name @@ -85,7 +84,7 @@ class AppNode extends HtmlResize.view { if (!this.up_node_field_dict || !use_cache) { const up_node_list = this.props.graphModel.getNodeIncomingNode(this.props.model.id) this.up_node_field_dict = up_node_list - .filter((node) => node.id != 'start-node') + .filter((node) => node.id != 'start-node' && node.id != 'loop-start-node') .map((node) => node.get_up_node_field_dict(true, use_cache)) .reduce((pre, next) => ({ ...pre, ...next }), {}) } @@ -103,9 +102,10 @@ class AppNode extends HtmlResize.view { (pre, next) => [...pre, ...next], [], ) - const start_node_field_list = this.props.graphModel - .getNodeModelById('start-node') - .get_node_field_list() + const start_node_field_list = ( + this.props.graphModel.getNodeModelById('start-node') || + this.props.graphModel.getNodeModelById('loop-start-node') + ).get_node_field_list() return [...start_node_field_list, ...result] } @@ -219,7 +219,15 @@ class AppNode extends HtmlResize.view { if (root) { if (isActive()) { - connect(this.targetId(), this.component, root, model, graphModel) + connect( + this.targetId(), + this.component, + root, + model, + graphModel, + undefined, + this.props.graphModel.get_provide, + ) } else { this.r = h(this.component, { properties: this.props.model.properties, @@ -395,7 +403,7 @@ class AppNodeModel extends HtmlResize.model { const anchors: any = [] if (this.type !== WorkflowType.Base) { - if (this.type !== WorkflowType.Start) { + if (![WorkflowType.Start, WorkflowType.LoopStartNode.toString()].includes(this.type)) { anchors.push({ x: x - width / 2 + 10, y: showNode ? y : y - 15, diff --git a/ui/src/workflow/common/data.ts b/ui/src/workflow/common/data.ts index f8e0e5cca..79709a36b 100644 --- a/ui/src/workflow/common/data.ts +++ b/ui/src/workflow/common/data.ts @@ -1,4 +1,4 @@ -import { WorkflowType } from '@/enums/application' +import { WorkflowType, WorkflowMode } from '@/enums/application' import { t } from '@/locales' export const startNode = { @@ -360,7 +360,113 @@ export const toolNode = { }, }, } +export const loopStartNode = { + id: WorkflowType.LoopStartNode, + type: WorkflowType.LoopStartNode, + x: 480, + y: 3340, + properties: { + height: 364, + stepName: t('views.applicationWorkflow.nodes.loopStartNode.label', '循环开始'), + config: { + fields: [ + { + label: t('views.applicationWorkflow.nodes.startNode.index', '下标'), + value: 'index', + }, + { + label: t('views.applicationWorkflow.nodes.startNode.item', '循环元素'), + value: 'item', + }, + ], + globalFields: [], + }, + showNode: true, + }, +} + +export const loopNode = { + type: WorkflowType.LoopNode, + visible: false, + text: t('views.applicationWorkflow.nodes.loopNode.text', '循环节点'), + label: t('views.applicationWorkflow.nodes.loopNode.label', '循环节点'), + height: 252, + properties: { + stepName: t('views.applicationWorkflow.nodes.loopNode.label', '循环节点'), + workflow: { + edges: [], + nodes: [ + { + x: 480, + y: 3340, + id: 'loop-start-node', + type: 'loop-start-node', + properties: { + config: { + fields: [], + globalFields: [], + }, + fields: [], + height: 361.333, + showNode: true, + stepName: '开始', + globalFields: [], + }, + }, + ], + }, + config: { + fields: [ + { + label: t('loop.item', '循环参数'), + value: 'item', + }, + { + label: t('common.result'), + value: 'result', + }, + ], + }, + }, +} + +export const loopBodyNode = { + type: WorkflowType.LoopBodyNode, + text: t('views.applicationWorkflow.nodes.loopBodyNode.text', '循环体'), + label: t('views.applicationWorkflow.nodes.loopBodyNode.label', '循环体'), + height: 600, + properties: { + width: 1800, + stepName: t('views.applicationWorkflow.nodes.loopBodyNode.label', '循环体'), + config: { + fields: [], + }, + }, +} + export const menuNodes = [ + { + label: t('views.applicationWorkflow.nodes.classify.aiCapability'), + list: [ + aiChatNode, + questionNode, + imageGenerateNode, + imageUnderstandNode, + textToSpeechNode, + speechToTextNode, + ], + }, + { label: t('views.knowledge.title'), list: [searchKnowledgeNode, rerankerNode] }, + { + label: t('views.applicationWorkflow.nodes.classify.businessLogic'), + list: [loopNode, conditionNode, formNode, variableAssignNode, replyNode], + }, + { + label: t('views.applicationWorkflow.nodes.classify.other'), + list: [mcpNode, documentExtractNode, toolNode], + }, +] +export const applicationLoopMenuNodes = [ { label: t('views.applicationWorkflow.nodes.classify.aiCapability'), list: [ @@ -383,6 +489,14 @@ export const menuNodes = [ }, ] +export const getMenuNodes = (workflowMode: WorkflowMode) => { + if (workflowMode == WorkflowMode.Application) { + return menuNodes + } + if (workflowMode == WorkflowMode.ApplicationLoop) { + return applicationLoopMenuNodes + } +} /** * 工具配置数据 @@ -462,6 +576,9 @@ export const nodeDict: any = { [WorkflowType.ImageGenerateNode]: imageGenerateNode, [WorkflowType.VariableAssignNode]: variableAssignNode, [WorkflowType.McpNode]: mcpNode, + [WorkflowType.LoopNode]: loopNode, + [WorkflowType.LoopBodyNode]: loopBodyNode, + [WorkflowType.LoopStartNode]: loopStartNode, } export function isWorkFlow(type: string | undefined) { return type === 'WORK_FLOW' diff --git a/ui/src/workflow/common/loopEdge.ts b/ui/src/workflow/common/loopEdge.ts new file mode 100644 index 000000000..8f53c8dc4 --- /dev/null +++ b/ui/src/workflow/common/loopEdge.ts @@ -0,0 +1,73 @@ +import { BezierEdge, BezierEdgeModel, h } from '@logicflow/core' + +class CustomEdgeModel2 extends BezierEdgeModel { + getArrowStyle() { + const arrowStyle = super.getArrowStyle() + arrowStyle.offset = 0 + arrowStyle.verticalLength = 0 + return arrowStyle + } + + getEdgeStyle() { + const style = super.getEdgeStyle() + // svg属性 + style.strokeWidth = 2 + style.stroke = '#BBBFC4' + style.offset = 0 + return style + } + /** + * 重写此方法,使保存数据是能带上锚点数据。 + */ + getData() { + const data: any = super.getData() + if (data) { + data.sourceAnchorId = this.sourceAnchorId + data.targetAnchorId = this.targetAnchorId + } + return data + } + /** + * 给边自定义方案,使其支持基于锚点的位置更新边的路径 + */ + updatePathByAnchor() { + // TODO + const sourceNodeModel = this.graphModel.getNodeModelById(this.sourceNodeId) + const sourceAnchor = sourceNodeModel + .getDefaultAnchor() + .find((anchor: any) => anchor.id === this.sourceAnchorId) + + const targetNodeModel = this.graphModel.getNodeModelById(this.targetNodeId) + const targetAnchor = targetNodeModel + .getDefaultAnchor() + .find((anchor: any) => anchor.id === this.targetAnchorId) + if (sourceAnchor && targetAnchor) { + const startPoint = { + x: sourceAnchor.x, + y: sourceAnchor.y - 10 + } + this.updateStartPoint(startPoint) + const endPoint = { + x: targetAnchor.x, + y: targetAnchor.y + 3 + } + + this.updateEndPoint(endPoint) + } + + // 这里需要将原有的pointsList设置为空,才能触发bezier的自动计算control点。 + this.pointsList = [] + this.initPoints() + } + setAttributes(): void { + super.setAttributes() + this.isHitable = true + this.zIndex = 0 + } +} + +export default { + type: 'loop-edge', + view: BezierEdge, + model: CustomEdgeModel2 +} diff --git a/ui/src/workflow/common/shortcut.ts b/ui/src/workflow/common/shortcut.ts index 49455c29c..66babfc17 100644 --- a/ui/src/workflow/common/shortcut.ts +++ b/ui/src/workflow/common/shortcut.ts @@ -43,7 +43,7 @@ let CHILDREN_TRANSLATION_DISTANCE = 40 export function initDefaultShortcut(lf: LogicFlow, graph: GraphModel) { const { keyboard } = lf const { - options: { keyboard: keyboardOptions } + options: { keyboard: keyboardOptions }, } = keyboard const copy_node = () => { CHILDREN_TRANSLATION_DISTANCE = TRANSLATION_DISTANCE @@ -57,7 +57,7 @@ export function initDefaultShortcut(lf: LogicFlow, graph: GraphModel) { return true } const base_nodes = elements.nodes.filter( - (node: any) => node.type === WorkflowType.Start || node.type === WorkflowType.Base + (node: any) => node.type === WorkflowType.Start || node.type === WorkflowType.Base, ) if (base_nodes.length > 0) { MsgError(base_nodes[0]?.properties?.stepName + t('views.applicationWorkflow.tip.cannotCopy')) @@ -91,23 +91,39 @@ export function initDefaultShortcut(lf: LogicFlow, graph: GraphModel) { return } if (elements.edges.length > 0 && elements.nodes.length == 0) { - elements.edges.forEach((edge: any) => lf.deleteEdge(edge.id)) + elements.edges + .filter((edge) => !['loop-edge'].includes(edge.type || '')) + .forEach((edge: any) => lf.deleteEdge(edge.id)) return } - const nodes = elements.nodes.filter((node) => ['start-node', 'base-node'].includes(node.type)) + const nodes = elements.nodes.filter((node) => + ['start-node', 'base-node', 'loop-body-node'].includes(node.type), + ) if (nodes.length > 0) { - MsgError(`${nodes[0].properties?.stepName}${t('views.applicationWorkflow.delete.deleteMessage')}`) + MsgError( + `${nodes[0].properties?.stepName}${t('views.applicationWorkflow.delete.deleteMessage')}`, + ) return } MsgConfirm(t('common.tip'), t('views.applicationWorkflow.delete.confirmTitle'), { confirmButtonText: t('common.confirm'), - confirmButtonClass: 'danger' + confirmButtonClass: 'danger', }).then(() => { if (!keyboardOptions?.enabled) return true if (graph.textEditElement) return true elements.edges.forEach((edge: any) => lf.deleteEdge(edge.id)) - elements.nodes.forEach((node: any) => lf.deleteNode(node.id)) + elements.nodes.forEach((node: any) => { + if (node.type === 'loop-node') { + const next = lf.getNodeOutgoingNode(node.id) + next.forEach((n: any) => { + if (n.type === 'loop-body-node') { + lf.deleteNode(n.id) + } + }) + } + lf.deleteNode(node.id) + }) }) return false diff --git a/ui/src/workflow/common/teleport.ts b/ui/src/workflow/common/teleport.ts index 6f88b414f..eb863f166 100644 --- a/ui/src/workflow/common/teleport.ts +++ b/ui/src/workflow/common/teleport.ts @@ -10,22 +10,26 @@ export function connect( container: HTMLDivElement, node: BaseNodeModel | BaseEdgeModel, graph: GraphModel, - get_props?: any + get_props?: any, + get_provide?: any, ) { if (!get_props) { get_props = (node: BaseNodeModel | BaseEdgeModel, graph: GraphModel) => { return { nodeModel: node, graph } } } + if (!get_provide) { + get_provide = (node: BaseNodeModel | BaseEdgeModel, graph: GraphModel) => ({ + getNode: () => node, + getGraph: () => graph, + }) + } if (active) { items[id] = markRaw( defineComponent({ render: () => h(Teleport, { to: container } as any, [h(component, get_props(node, graph))]), - provide: () => ({ - getNode: () => node, - getGraph: () => graph - }) - }) + provide: () => get_provide(node, graph), + }), ) } } @@ -50,8 +54,8 @@ export function getTeleport(): any { props: { flowId: { type: String, - required: true - } + required: true, + }, }, setup(props) { return () => { @@ -65,16 +69,15 @@ export function getTeleport(): any { // 比对当前界面显示的flowId,只更新items[当前页面flowId:nodeId]的数据 // 比如items[0]属于Page1的数据,那么Page2无论active=true/false,都无法执行items[0] - if (id.startsWith(props.flowId)) { - children.push(items[id]) - } + + children.push(items[id]) }) return h( Fragment, {}, - children.map((item) => h(item)) + children.map((item) => h(item)), ) } - } + }, }) } diff --git a/ui/src/workflow/common/validate.ts b/ui/src/workflow/common/validate.ts index e49548c42..89e4b9f78 100644 --- a/ui/src/workflow/common/validate.ts +++ b/ui/src/workflow/common/validate.ts @@ -10,7 +10,8 @@ const end_nodes: Array = [ WorkflowType.Application, WorkflowType.SpeechToTextNode, WorkflowType.TextToSpeechNode, - WorkflowType.ImageGenerateNode, + WorkflowType.ImageGenerateNode, + WorkflowType.LoopBodyNode, ] export class WorkFlowInstance { nodes diff --git a/ui/src/workflow/index.vue b/ui/src/workflow/index.vue index c4cb769f3..47fee2346 100644 --- a/ui/src/workflow/index.vue +++ b/ui/src/workflow/index.vue @@ -8,6 +8,7 @@ import LogicFlow from '@logicflow/core' import { ref, onMounted, computed } from 'vue' import AppEdge from './common/edge' +import loopEdge from './common/loopEdge' import Control from './common/NodeControl.vue' import { baseNodes } from '@/workflow/common/data' import '@logicflow/extension/lib/style/index.css' @@ -35,22 +36,6 @@ const props = defineProps({ data: Object || null, }) -const defaultData = { - nodes: [...baseNodes], -} -const graphData = computed({ - get: () => { - if (props.data) { - return props.data - } else { - return defaultData - } - }, - set: (value) => { - return value - }, -}) - const lf = ref() onMounted(() => { renderGraphData() @@ -82,6 +67,7 @@ const renderGraphData = (data?: any) => { }, isSilentMode: false, container: container, + saa: 'sssssss', }) lf.value.setTheme({ bezier: { @@ -89,11 +75,16 @@ const renderGraphData = (data?: any) => { strokeWidth: 1, }, }) + lf.value.graphModel.get = 'sdasdaad' lf.value.on('graph:rendered', () => { flowId.value = lf.value.graphModel.flowId }) initDefaultShortcut(lf.value, lf.value.graphModel) - lf.value.batchRegister([...Object.keys(nodes).map((key) => nodes[key].default), AppEdge]) + lf.value.batchRegister([ + ...Object.keys(nodes).map((key) => nodes[key].default), + AppEdge, + loopEdge, + ]) lf.value.setDefaultEdgeType('app-edge') lf.value.render(data ? data : {}) @@ -117,7 +108,17 @@ const validate = () => { return Promise.all(lf.value.graphModel.nodes.map((element: any) => element?.validate?.())) } const getGraphData = () => { - return lf.value.getGraphData() + const graph_data = lf.value.getGraphData() + graph_data.nodes.forEach((node: any) => { + if (node.type === 'loop-body-node') { + const node_model = lf.value.getNodeModelById(node.id) + node_model.set_loop_body() + } + }) + const _graph_data = lf.value.getGraphData() + _graph_data.nodes = _graph_data.nodes.filter((node: any) => node.type !== 'loop-body-node') + _graph_data.edges = graph_data.edges.filter((node: any) => node.type !== 'loop-edge') + return _graph_data } const onmousedown = (shapeItem: ShapeItem) => { diff --git a/ui/src/workflow/nodes/loop-body-node/LoopBodyContainer.vue b/ui/src/workflow/nodes/loop-body-node/LoopBodyContainer.vue new file mode 100644 index 000000000..290bc664f --- /dev/null +++ b/ui/src/workflow/nodes/loop-body-node/LoopBodyContainer.vue @@ -0,0 +1,222 @@ + + + diff --git a/ui/src/workflow/nodes/loop-body-node/index.ts b/ui/src/workflow/nodes/loop-body-node/index.ts new file mode 100644 index 000000000..839b7f0f6 --- /dev/null +++ b/ui/src/workflow/nodes/loop-body-node/index.ts @@ -0,0 +1,44 @@ +import { WorkflowMode } from './../../../enums/application' +import LoopNode from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class LoopBodyNodeView extends AppNode { + constructor(props: any) { + super(props, LoopNode) + } + get_up_node_field_list(contain_self: boolean, use_cache: boolean) { + const loop_node_id = this.props.model.properties.loop_node_id + const loop_node = this.props.graphModel.getNodeModelById(loop_node_id) + return loop_node.get_up_node_field_list(contain_self, use_cache) + } +} +class LoopBodyModel extends AppNodeModel { + refreshBranch() { + // 更新节点连接边的path + this.incoming.edges.forEach((edge: any) => { + // 调用自定义的更新方案 + edge.updatePathByAnchor() + }) + this.outgoing.edges.forEach((edge: any) => { + edge.updatePathByAnchor() + }) + } + getDefaultAnchor() { + const { id, x, y, width, height } = this + const showNode = this.properties.showNode === undefined ? true : this.properties.showNode + const anchors: any = [] + anchors.push({ + edgeAddable: false, + x: x, + y: y - height / 2 + 10, + id: `${id}_children`, + type: 'children', + }) + + return anchors + } +} +export default { + type: 'loop-body-node', + model: LoopBodyModel, + view: LoopBodyNodeView, +} diff --git a/ui/src/workflow/nodes/loop-body-node/index.vue b/ui/src/workflow/nodes/loop-body-node/index.vue new file mode 100644 index 000000000..8713c7aa4 --- /dev/null +++ b/ui/src/workflow/nodes/loop-body-node/index.vue @@ -0,0 +1,100 @@ + + + diff --git a/ui/src/workflow/nodes/loop-node/index.ts b/ui/src/workflow/nodes/loop-node/index.ts new file mode 100644 index 000000000..2d0158303 --- /dev/null +++ b/ui/src/workflow/nodes/loop-node/index.ts @@ -0,0 +1,56 @@ +import LoopNode from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +import { WorkflowType } from '@/enums/application' +class LoopNodeView extends AppNode { + constructor(props: any) { + super(props, LoopNode) + } +} +class LoopModel extends AppNodeModel { + refreshBranch() { + // 更新节点连接边的path + this.incoming.edges.forEach((edge: any) => { + // 调用自定义的更新方案 + edge.updatePathByAnchor() + }) + this.outgoing.edges.forEach((edge: any) => { + edge.updatePathByAnchor() + }) + } + getDefaultAnchor() { + const { id, x, y, width, height } = this + const showNode = this.properties.showNode === undefined ? true : this.properties.showNode + const anchors: any = [] + + if (this.type !== WorkflowType.Base) { + if (this.type !== WorkflowType.Start) { + anchors.push({ + x: x - width / 2 + 10, + y: showNode ? y : y - 15, + id: `${id}_left`, + edgeAddable: false, + type: 'left', + }) + } + anchors.push({ + x: x + width / 2 - 10, + y: showNode ? y : y - 15, + id: `${id}_right`, + type: 'right', + }) + } + anchors.push({ + x: x, + y: y + height / 2 - 25, + id: `${id}_children`, + type: 'children', + }) + + return anchors + } +} +export default { + type: 'loop-node', + model: LoopModel, + view: LoopNodeView, +} diff --git a/ui/src/workflow/nodes/loop-node/index.vue b/ui/src/workflow/nodes/loop-node/index.vue new file mode 100644 index 000000000..d4a48e6f0 --- /dev/null +++ b/ui/src/workflow/nodes/loop-node/index.vue @@ -0,0 +1,158 @@ + + + diff --git a/ui/src/workflow/nodes/loop-start-node/index.ts b/ui/src/workflow/nodes/loop-start-node/index.ts new file mode 100644 index 000000000..2e0127a85 --- /dev/null +++ b/ui/src/workflow/nodes/loop-start-node/index.ts @@ -0,0 +1,12 @@ +import LoopStartNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class LoopStartNode extends AppNode { + constructor(props: any) { + super(props, LoopStartNodeVue) + } +} +export default { + type: 'loop-start-node', + model: AppNodeModel, + view: LoopStartNode, +} diff --git a/ui/src/workflow/nodes/loop-start-node/index.vue b/ui/src/workflow/nodes/loop-start-node/index.vue new file mode 100644 index 000000000..135382ed3 --- /dev/null +++ b/ui/src/workflow/nodes/loop-start-node/index.vue @@ -0,0 +1,12 @@ + + +