diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
index 6c3012f8e..25cff9d54 100644
--- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
@@ -82,6 +82,12 @@ class IChatStep(IBaseChatPipelineStep):
model_params_setting = serializers.DictField(required=False, allow_null=True,
label=_("Model parameter settings"))
+ mcp_enable = serializers.BooleanField(label=_("Whether MCP is enabled"), required=False, default=False)
+ mcp_tool_ids = serializers.JSONField(label=_("MCP tool ID list"), required=False, default=list)
+ mcp_servers = serializers.JSONField(label=_("MCP server list"), required=False, default=dict)
+ mcp_source = serializers.CharField(label=_("MCP source"), required=False, default="referencing")
+ tool_enable = serializers.BooleanField(label=_("Whether tools are enabled"), required=False, default=False)
+ tool_ids = serializers.JSONField(label=_("Tool ID list"), required=False, default=list)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@@ -106,5 +112,8 @@ class IChatStep(IBaseChatPipelineStep):
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,
- no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
+ no_references_setting=None, model_params_setting=None, model_setting=None,
+ mcp_enable=False, mcp_tool_ids=None, mcp_servers='', mcp_source="referencing",
+ tool_enable=False, tool_ids=None,
+ **kwargs):
pass
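
All six new fields default to "off"/empty, so requests from existing clients validate unchanged. A minimal standalone sketch of the validation behavior (the ChatStepSerializer name and all payload values here are invented for illustration; the real fields live on the instance serializer in the hunk above):

import django
from django.conf import settings

settings.configure()  # minimal config so DRF fields can run outside a project
django.setup()

from rest_framework import serializers

class ChatStepSerializer(serializers.Serializer):
    # Mirrors the fields added above
    mcp_enable = serializers.BooleanField(required=False, default=False)
    mcp_tool_ids = serializers.JSONField(required=False, default=list)
    mcp_servers = serializers.JSONField(required=False, default=dict)
    mcp_source = serializers.CharField(required=False, default="referencing")
    tool_enable = serializers.BooleanField(required=False, default=False)
    tool_ids = serializers.JSONField(required=False, default=list)

s = ChatStepSerializer(data={
    "mcp_enable": True,
    "mcp_source": "custom",
    "mcp_servers": {"time": {"transport": "sse", "url": "http://127.0.0.1:8001/sse"}},
})
s.is_valid(raise_exception=True)
assert s.validated_data["tool_enable"] is False  # omitted fields fall back to their defaults
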
diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
index 8e4c410f0..ef5e31704 100644
--- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -6,7 +6,8 @@
@date:2024/1/9 18:25
@desc: Base implementation of the chat step
"""
-import logging
+import json
+import os
import time
import traceback
import uuid_utils.compat as uuid
@@ -24,10 +25,14 @@ from rest_framework import status
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
-from application.flow.tools import Reasoning
+from application.flow.tools import Reasoning, mcp_response_generator
from application.models import ApplicationChatUserStats, ChatUserType
from common.utils.logger import maxkb_logger
+from common.utils.rsa_util import rsa_long_decrypt
+from common.utils.tool_code import ToolExecutor
+from maxkb.const import CONFIG
from models_provider.tools import get_model_instance_by_model_workspace_id
+from tools.models import Tool
def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
@@ -54,6 +59,7 @@ def write_context(step, manage, request_token, response_token, all_text):
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
+
def event_content(response,
chat_id,
chat_record_id,
@@ -169,6 +175,12 @@ class BaseChatStep(IChatStep):
no_references_setting=None,
model_params_setting=None,
model_setting=None,
+ mcp_enable=False,
+ mcp_tool_ids=None,
+ mcp_servers='',
+ mcp_source="referencing",
+ tool_enable=False,
+ tool_ids=None,
**kwargs):
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting) if model_id is not None else None
@@ -177,14 +189,24 @@ class BaseChatStep(IChatStep):
paragraph_list,
manage, padding_problem_text, chat_user_id, chat_user_type,
no_references_setting,
- model_setting)
+ model_setting,
+ mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)
else:
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
- model_setting)
+ model_setting,
+ mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)
def get_details(self, manage, **kwargs):
+ # Delete the temporarily generated MCP code files
+ if self.context.get('execute_ids'):
+ executor = ToolExecutor(CONFIG.get('SANDBOX'))
+ # Clean up the tool code files; deletion is deferred to this point so files are not removed while still in use
+ for execute_id in self.context.get('execute_ids'):
+ code_path = f'{executor.sandbox_path}/execute/{execute_id}.py'
+ if os.path.exists(code_path):
+ os.remove(code_path)
return {
'step_type': 'chat_step',
'run_time': self.context['run_time'],
@@ -206,12 +228,63 @@ class BaseChatStep(IChatStep):
result.append({'role': 'ai', 'content': answer_text})
return result
- @staticmethod
- def get_stream_result(message_list: List[BaseMessage],
+ def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
+ chat_model, message_list):
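+ # Returns an MCP response generator when any MCP/tool servers end up configured, otherwise None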
+ if not mcp_enable and not tool_enable:
+ return None
+
+ mcp_servers_config = {}
+
+ # Data migrated from older versions arrives with mcp_source set to None
+ if mcp_source is None:
+ mcp_source = 'custom'
+ if mcp_enable:
+ # Backward compatibility with legacy data
+ if not mcp_tool_ids:
+ mcp_tool_ids = []
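+ # Two sources of server configs: a user-supplied JSON blob ('custom'), or configs stored on referenced MCP tools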
+ if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
+ mcp_servers_config = json.loads(mcp_servers)
+ elif mcp_tool_ids:
+ mcp_tools = QuerySet(Tool).filter(id__in=mcp_tool_ids).values()
+ for mcp_tool in mcp_tools:
+ if mcp_tool and mcp_tool['is_active']:
+ mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
+
+ if tool_enable:
+ if tool_ids:  # If tool IDs are present, convert the tools to MCP servers
+ self.context['tool_ids'] = tool_ids
+ self.context['execute_ids'] = []
+ for tool_id in tool_ids:
+ tool = QuerySet(Tool).filter(id=tool_id).first()
+ if tool is None or not tool.is_active:  # skip missing or disabled tools
+ continue
+ executor = ToolExecutor(CONFIG.get('SANDBOX'))
+ if tool.init_params is not None:
+ params = json.loads(rsa_long_decrypt(tool.init_params))
+ else:
+ params = {}
+ _id, tool_config = executor.get_tool_mcp_config(tool.code, params)
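+ # get_tool_mcp_config materializes the tool code as a temporary sandbox file and returns an MCP server entry for it; the file is removed later in get_details()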
+
+ self.context['execute_ids'].append(_id)
+ mcp_servers_config[str(tool.id)] = tool_config
+
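+ # Only start the MCP generator when at least one server config survived the checks above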
+ if len(mcp_servers_config) > 0:
+ return mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config))
+
+ return None
+
+ def get_stream_result(self, message_list: List[BaseMessage],
chat_model: BaseChatModel = None,
paragraph_list=None,
no_references_setting=None,
- problem_text=None):
+ problem_text=None,
+ mcp_enable=False,
+ mcp_tool_ids=None,
+ mcp_servers='',
+ mcp_source="referencing",
+ tool_enable=False,
+ tool_ids=None):
if paragraph_list is None:
paragraph_list = []
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
@@ -227,6 +300,12 @@ class BaseChatStep(IChatStep):
return iter([AIMessageChunk(
_('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.'))]), False
else:
+ # Handle the MCP request, if enabled
+ mcp_result = self._handle_mcp_request(
+ mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, chat_model, message_list,
+ )
+ if mcp_result:
+ return mcp_result, True
return chat_model.stream(message_list), True
def execute_stream(self, message_list: List[BaseMessage],
@@ -239,9 +318,15 @@ class BaseChatStep(IChatStep):
padding_problem_text: str = None,
chat_user_id=None, chat_user_type=None,
no_references_setting=None,
- model_setting=None):
+ model_setting=None,
+ mcp_enable=False,
+ mcp_tool_ids=None,
+ mcp_servers='',
+ mcp_source="referencing",
+ tool_enable=False,
+ tool_ids=None):
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
- no_references_setting, problem_text)
+ no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)
chat_record_id = uuid.uuid7()
r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
@@ -253,12 +338,17 @@ class BaseChatStep(IChatStep):
r['Cache-Control'] = 'no-cache'
return r
- @staticmethod
- def get_block_result(message_list: List[BaseMessage],
+ def get_block_result(self, message_list: List[BaseMessage],
chat_model: BaseChatModel = None,
paragraph_list=None,
no_references_setting=None,
- problem_text=None):
+ problem_text=None,
+ mcp_enable=False,
+ mcp_tool_ids=None,
+ mcp_servers='',
+ mcp_source="referencing",
+ tool_enable=False,
+ tool_ids=None):
if paragraph_list is None:
paragraph_list = []
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
@@ -273,6 +363,12 @@ class BaseChatStep(IChatStep):
return AIMessage(
_('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.')), False
else:
+ # Handle the MCP request, if enabled
+ mcp_result = self._handle_mcp_request(
+ mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, chat_model, message_list,
+ )
+ if mcp_result:
+ return mcp_result, True
return chat_model.invoke(message_list), True
def execute_block(self, message_list: List[BaseMessage],
@@ -284,7 +380,13 @@ class BaseChatStep(IChatStep):
manage: PipelineManage = None,
padding_problem_text: str = None,
chat_user_id=None, chat_user_type=None, no_references_setting=None,
- model_setting=None):
+ model_setting=None,
+ mcp_enable=False,
+ mcp_tool_ids=None,
+ mcp_servers='',
+ mcp_source="referencing",
+ tool_enable=False,
+ tool_ids=None):
reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
[truncated hunk: the remainder of the diff edits the markup-stripped "Called MCP Tool: %s" / "%s" response template used to render MCP tool calls; the removed and added lines are textually identical here, so the underlying change (likely whitespace or stripped HTML tags) is not recoverable]
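
For reference, the server-config merging in _handle_mcp_request reduces to the standalone sketch below (simplified: the real method also converts local tools to MCP servers via ToolExecutor, and every value here is an invented example):

import json

def merge_mcp_servers(mcp_source, mcp_servers, referenced_tools):
    """Merge MCP server configs the way _handle_mcp_request does (simplified)."""
    config = {}
    # 'custom': parse the user-supplied JSON blob, refusing stdio transports
    if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
        config = json.loads(mcp_servers)
    else:
        # 'referencing': merge the config stored on each active referenced tool
        for tool in referenced_tools:
            if tool.get('is_active'):
                config.update(json.loads(tool['code']))
    return config

tools = [{'is_active': True,
          'code': '{"time": {"transport": "sse", "url": "http://127.0.0.1:8001/sse"}}'}]
print(merge_mcp_servers('referencing', None, tools))
# -> {'time': {'transport': 'sse', 'url': 'http://127.0.0.1:8001/sse'}}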