mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: add MCP tool ID and source fields to chat node for enhanced configuration
This commit is contained in:
parent
1875368ea8
commit
e9c8c9581f
|
|
@ -31,9 +31,14 @@ class ChatNodeSerializer(serializers.Serializer):
|
|||
label='Model settings')
|
||||
dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
label=_("Context Type"))
|
||||
mcp_enable = serializers.BooleanField(required=False,
|
||||
label=_("Whether to enable MCP"))
|
||||
mcp_enable = serializers.BooleanField(required=False, label=_("Whether to enable MCP"))
|
||||
mcp_servers = serializers.JSONField(required=False, label=_("MCP Server"))
|
||||
mcp_tool_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Tool ID"))
|
||||
mcp_source = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Source"))
|
||||
|
||||
tool_enable = serializers.BooleanField(required=False, default=False, label=_("Whether to enable tools"))
|
||||
tool_ids = serializers.ListField(child=serializers.UUIDField(), required=False, allow_empty=True,
|
||||
label=_("Tool IDs"), )
|
||||
|
||||
|
||||
class IChatNode(INode):
|
||||
|
|
@ -52,5 +57,9 @@ class IChatNode(INode):
|
|||
model_setting=None,
|
||||
mcp_enable=False,
|
||||
mcp_servers=None,
|
||||
mcp_tool_id=None,
|
||||
mcp_source=None,
|
||||
tool_enable=False,
|
||||
tool_ids=None,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from functools import reduce
|
||||
|
|
@ -23,9 +23,11 @@ from langgraph.prebuilt import create_react_agent
|
|||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
|
||||
from application.flow.tools import Reasoning
|
||||
from common.utils.logger import maxkb_logger
|
||||
from common.utils.tool_code import ToolExecutor
|
||||
from models_provider.models import Model
|
||||
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
|
||||
from common.utils.logger import maxkb_logger
|
||||
from tools.models import Tool
|
||||
|
||||
tool_message_template = """
|
||||
<details>
|
||||
|
|
@ -211,6 +213,10 @@ class BaseChatNode(IChatNode):
|
|||
model_setting=None,
|
||||
mcp_enable=False,
|
||||
mcp_servers=None,
|
||||
mcp_tool_id=None,
|
||||
mcp_source=None,
|
||||
tool_enable=False,
|
||||
tool_ids=None,
|
||||
**kwargs) -> NodeResult:
|
||||
if dialogue_type is None:
|
||||
dialogue_type = 'WORKFLOW'
|
||||
|
|
@ -234,12 +240,13 @@ class BaseChatNode(IChatNode):
|
|||
message_list = self.generate_message_list(system, prompt, history_message)
|
||||
self.context['message_list'] = message_list
|
||||
|
||||
if mcp_enable and mcp_servers is not None and '"stdio"' not in mcp_servers:
|
||||
r = mcp_response_generator(chat_model, message_list, mcp_servers)
|
||||
return NodeResult(
|
||||
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context_stream)
|
||||
# 处理 MCP 请求
|
||||
mcp_result = self._handle_mcp_request(
|
||||
mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, tool_ids, chat_model, message_list,
|
||||
history_message, question
|
||||
)
|
||||
if mcp_result:
|
||||
return mcp_result
|
||||
|
||||
if stream:
|
||||
r = chat_model.stream(message_list)
|
||||
|
|
@ -252,6 +259,48 @@ class BaseChatNode(IChatNode):
|
|||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context)
|
||||
|
||||
def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, tool_ids,
|
||||
chat_model, message_list, history_message, question):
|
||||
if not mcp_enable and not tool_enable:
|
||||
return None
|
||||
|
||||
mcp_servers_config = {}
|
||||
|
||||
if mcp_enable:
|
||||
if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
|
||||
mcp_servers_config = json.loads(mcp_servers)
|
||||
elif mcp_tool_id:
|
||||
mcp_tool = QuerySet(Tool).filter(id=mcp_tool_id).first()
|
||||
if mcp_tool:
|
||||
mcp_servers_config = json.loads(mcp_tool.code)
|
||||
|
||||
if tool_enable:
|
||||
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
|
||||
self.context['tool_ids'] = tool_ids
|
||||
for tool_id in tool_ids:
|
||||
tool = QuerySet(Tool).filter(id=tool_id).first()
|
||||
executor = ToolExecutor()
|
||||
code = executor.generate_mcp_server_code(tool.code)
|
||||
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
|
||||
with open(code_path, 'w') as f:
|
||||
f.write(code)
|
||||
|
||||
tool_config = {
|
||||
'command': 'python',
|
||||
'args': [code_path],
|
||||
'transport': 'stdio',
|
||||
}
|
||||
mcp_servers_config[str(tool.id)] = tool_config
|
||||
|
||||
if len(mcp_servers_config) > 0:
|
||||
r = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config))
|
||||
return NodeResult(
|
||||
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context_stream)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
|
|
@ -284,6 +333,14 @@ class BaseChatNode(IChatNode):
|
|||
return result
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
# 删除临时生成的MCP代码文件
|
||||
if self.context.get('tool_ids'):
|
||||
executor = ToolExecutor()
|
||||
# 清理工具代码文件,延时删除,避免文件被占用
|
||||
for tool_id in self.context.get('tool_ids'):
|
||||
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
|
||||
if os.path.exists(code_path):
|
||||
os.remove(code_path)
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
|
|
|
|||
Loading…
Reference in New Issue