feat: add MCP tool ID and source fields to chat node for enhanced configuration

This commit is contained in:
CaptainB 2025-08-12 14:11:07 +08:00
parent 1875368ea8
commit e9c8c9581f
2 changed files with 76 additions and 10 deletions

View File

@ -31,9 +31,14 @@ class ChatNodeSerializer(serializers.Serializer):
label='Model settings')
dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
label=_("Context Type"))
mcp_enable = serializers.BooleanField(required=False,
label=_("Whether to enable MCP"))
mcp_enable = serializers.BooleanField(required=False, label=_("Whether to enable MCP"))
mcp_servers = serializers.JSONField(required=False, label=_("MCP Server"))
mcp_tool_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Tool ID"))
mcp_source = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Source"))
tool_enable = serializers.BooleanField(required=False, default=False, label=_("Whether to enable tools"))
tool_ids = serializers.ListField(child=serializers.UUIDField(), required=False, allow_empty=True,
label=_("Tool IDs"), )
class IChatNode(INode):
@ -52,5 +57,9 @@ class IChatNode(INode):
model_setting=None,
mcp_enable=False,
mcp_servers=None,
mcp_tool_id=None,
mcp_source=None,
tool_enable=False,
tool_ids=None,
**kwargs) -> NodeResult:
pass

View File

@ -8,7 +8,7 @@
"""
import asyncio
import json
import logging
import os
import re
import time
from functools import reduce
@ -23,9 +23,11 @@ from langgraph.prebuilt import create_react_agent
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from application.flow.tools import Reasoning
from common.utils.logger import maxkb_logger
from common.utils.tool_code import ToolExecutor
from models_provider.models import Model
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
from common.utils.logger import maxkb_logger
from tools.models import Tool
tool_message_template = """
<details>
@ -211,6 +213,10 @@ class BaseChatNode(IChatNode):
model_setting=None,
mcp_enable=False,
mcp_servers=None,
mcp_tool_id=None,
mcp_source=None,
tool_enable=False,
tool_ids=None,
**kwargs) -> NodeResult:
if dialogue_type is None:
dialogue_type = 'WORKFLOW'
@ -234,12 +240,13 @@ class BaseChatNode(IChatNode):
message_list = self.generate_message_list(system, prompt, history_message)
self.context['message_list'] = message_list
if mcp_enable and mcp_servers is not None and '"stdio"' not in mcp_servers:
r = mcp_response_generator(chat_model, message_list, mcp_servers)
return NodeResult(
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context_stream)
# 处理 MCP 请求
mcp_result = self._handle_mcp_request(
mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, tool_ids, chat_model, message_list,
history_message, question
)
if mcp_result:
return mcp_result
if stream:
r = chat_model.stream(message_list)
@ -252,6 +259,48 @@ class BaseChatNode(IChatNode):
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context)
def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, tool_ids,
chat_model, message_list, history_message, question):
if not mcp_enable and not tool_enable:
return None
mcp_servers_config = {}
if mcp_enable:
if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
mcp_servers_config = json.loads(mcp_servers)
elif mcp_tool_id:
mcp_tool = QuerySet(Tool).filter(id=mcp_tool_id).first()
if mcp_tool:
mcp_servers_config = json.loads(mcp_tool.code)
if tool_enable:
if tool_ids and len(tool_ids) > 0: # 如果有工具ID则将其转换为MCP
self.context['tool_ids'] = tool_ids
for tool_id in tool_ids:
tool = QuerySet(Tool).filter(id=tool_id).first()
executor = ToolExecutor()
code = executor.generate_mcp_server_code(tool.code)
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
with open(code_path, 'w') as f:
f.write(code)
tool_config = {
'command': 'python',
'args': [code_path],
'transport': 'stdio',
}
mcp_servers_config[str(tool.id)] = tool_config
if len(mcp_servers_config) > 0:
r = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config))
return NodeResult(
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context_stream)
return None
@staticmethod
def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id):
start_index = len(history_chat_record) - dialogue_number
@ -284,6 +333,14 @@ class BaseChatNode(IChatNode):
return result
def get_details(self, index: int, **kwargs):
# 删除临时生成的MCP代码文件
if self.context.get('tool_ids'):
executor = ToolExecutor()
# 清理工具代码文件,延时删除,避免文件被占用
for tool_id in self.context.get('tool_ids'):
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
if os.path.exists(code_path):
os.remove(code_path)
return {
'name': self.node.properties.get('stepName'),
"index": index,