From e9c8c9581ff369f81fac6a9de939c57c426f4cfa Mon Sep 17 00:00:00 2001 From: CaptainB Date: Tue, 12 Aug 2025 14:11:07 +0800 Subject: [PATCH] feat: add MCP tool ID and source fields to chat node for enhanced configuration --- .../ai_chat_step_node/i_chat_node.py | 13 +++- .../ai_chat_step_node/impl/base_chat_node.py | 73 +++++++++++++++++-- 2 files changed, 76 insertions(+), 10 deletions(-) diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py index a19ca3a88..79c2e1337 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py @@ -31,9 +31,14 @@ class ChatNodeSerializer(serializers.Serializer): label='Model settings') dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Context Type")) - mcp_enable = serializers.BooleanField(required=False, - label=_("Whether to enable MCP")) + mcp_enable = serializers.BooleanField(required=False, label=_("Whether to enable MCP")) mcp_servers = serializers.JSONField(required=False, label=_("MCP Server")) + mcp_tool_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Tool ID")) + mcp_source = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Source")) + + tool_enable = serializers.BooleanField(required=False, default=False, label=_("Whether to enable tools")) + tool_ids = serializers.ListField(child=serializers.UUIDField(), required=False, allow_empty=True, + label=_("Tool IDs"), ) class IChatNode(INode): @@ -52,5 +57,9 @@ class IChatNode(INode): model_setting=None, mcp_enable=False, mcp_servers=None, + mcp_tool_id=None, + mcp_source=None, + tool_enable=False, + tool_ids=None, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index b8607cf5d..d42309a8a 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -8,7 +8,7 @@ """ import asyncio import json -import logging +import os import re import time from functools import reduce @@ -23,9 +23,11 @@ from langgraph.prebuilt import create_react_agent from application.flow.i_step_node import NodeResult, INode from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode from application.flow.tools import Reasoning +from common.utils.logger import maxkb_logger +from common.utils.tool_code import ToolExecutor from models_provider.models import Model from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id -from common.utils.logger import maxkb_logger +from tools.models import Tool tool_message_template = """
@@ -211,6 +213,10 @@ class BaseChatNode(IChatNode): model_setting=None, mcp_enable=False, mcp_servers=None, + mcp_tool_id=None, + mcp_source=None, + tool_enable=False, + tool_ids=None, **kwargs) -> NodeResult: if dialogue_type is None: dialogue_type = 'WORKFLOW' @@ -234,12 +240,13 @@ class BaseChatNode(IChatNode): message_list = self.generate_message_list(system, prompt, history_message) self.context['message_list'] = message_list - if mcp_enable and mcp_servers is not None and '"stdio"' not in mcp_servers: - r = mcp_response_generator(chat_model, message_list, mcp_servers) - return NodeResult( - {'result': r, 'chat_model': chat_model, 'message_list': message_list, - 'history_message': history_message, 'question': question.content}, {}, - _write_context=write_context_stream) + # 处理 MCP 请求 + mcp_result = self._handle_mcp_request( + mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, tool_ids, chat_model, message_list, + history_message, question + ) + if mcp_result: + return mcp_result if stream: r = chat_model.stream(message_list) @@ -252,6 +259,48 @@ class BaseChatNode(IChatNode): 'history_message': history_message, 'question': question.content}, {}, _write_context=write_context) + def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, tool_ids, + chat_model, message_list, history_message, question): + if not mcp_enable and not tool_enable: + return None + + mcp_servers_config = {} + + if mcp_enable: + if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers: + mcp_servers_config = json.loads(mcp_servers) + elif mcp_tool_id: + mcp_tool = QuerySet(Tool).filter(id=mcp_tool_id).first() + if mcp_tool: + mcp_servers_config = json.loads(mcp_tool.code) + + if tool_enable: + if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP + self.context['tool_ids'] = tool_ids + for tool_id in tool_ids: + tool = QuerySet(Tool).filter(id=tool_id).first() + executor = ToolExecutor() + code = executor.generate_mcp_server_code(tool.code) + code_path = f'{executor.sandbox_path}/execute/{tool_id}.py' + with open(code_path, 'w') as f: + f.write(code) + + tool_config = { + 'command': 'python', + 'args': [code_path], + 'transport': 'stdio', + } + mcp_servers_config[str(tool.id)] = tool_config + + if len(mcp_servers_config) > 0: + r = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config)) + return NodeResult( + {'result': r, 'chat_model': chat_model, 'message_list': message_list, + 'history_message': history_message, 'question': question.content}, {}, + _write_context=write_context_stream) + + return None + @staticmethod def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id): start_index = len(history_chat_record) - dialogue_number @@ -284,6 +333,14 @@ class BaseChatNode(IChatNode): return result def get_details(self, index: int, **kwargs): + # 删除临时生成的MCP代码文件 + if self.context.get('tool_ids'): + executor = ToolExecutor() + # 清理工具代码文件,延时删除,避免文件被占用 + for tool_id in self.context.get('tool_ids'): + code_path = f'{executor.sandbox_path}/execute/{tool_id}.py' + if os.path.exists(code_path): + os.remove(code_path) return { 'name': self.node.properties.get('stepName'), "index": index,