From df442272e947a728b7673dba13fb646c6a73f1d4 Mon Sep 17 00:00:00 2001 From: CaptainB Date: Sun, 28 Sep 2025 15:33:25 +0800 Subject: [PATCH] feat: add MCP transport validation to ToolExecutor --- apps/common/utils/tool_code.py | 9 ++++++++- apps/tools/serializers/tool.py | 5 ++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/apps/common/utils/tool_code.py b/apps/common/utils/tool_code.py index 1463cca01..414ea23a3 100644 --- a/apps/common/utils/tool_code.py +++ b/apps/common/utils/tool_code.py @@ -1,12 +1,13 @@ # coding=utf-8 import ast -import os import json +import os import subprocess import sys from textwrap import dedent import uuid_utils.compat as uuid +from django.utils.translation import gettext_lazy as _ from maxkb.const import BASE_DIR, CONFIG from maxkb.const import PROJECT_DIR @@ -210,6 +211,12 @@ exec({dedent(code)!a}) if matched: raise Exception(f"keyword '{matched}' is banned in the tool.") + def validate_mcp_transport(self, code_str): + servers = json.loads(code_str) + for server, config in servers.items(): + if config.get('transport') not in ['sse', 'streamable_http']: + raise Exception(_('Only support transport=sse or transport=streamable_http')) + @staticmethod def _exec(_code): return subprocess.run([python_directory, '-c', _code], text=True, capture_output=True) diff --git a/apps/tools/serializers/tool.py b/apps/tools/serializers/tool.py index d50e72206..0cb55e569 100644 --- a/apps/tools/serializers/tool.py +++ b/apps/tools/serializers/tool.py @@ -356,6 +356,7 @@ class ToolSerializer(serializers.Serializer): ToolCreateRequest(data=instance).is_valid(raise_exception=True) # 校验代码是否包括禁止的关键字 ToolExecutor().validate_banned_keywords(instance.get('code', '')) + ToolExecutor().validate_mcp_transport(instance.get('code', '')) tool_id = uuid.uuid7() Tool( @@ -391,6 +392,8 @@ class ToolSerializer(serializers.Serializer): self.is_valid(raise_exception=True) # 校验代码是否包括禁止的关键字 ToolExecutor().validate_banned_keywords(self.data.get('code', '')) + ToolExecutor().validate_mcp_transport(self.data.get('code', '')) + # 校验mcp json validate_mcp_config(json.loads(self.data.get('code'))) return True @@ -484,7 +487,7 @@ class ToolSerializer(serializers.Serializer): ToolEditRequest(data=instance).is_valid(raise_exception=True) # 校验代码是否包括禁止的关键字 ToolExecutor().validate_banned_keywords(instance.get('code', '')) - + ToolExecutor().validate_mcp_transport(instance.get('code', '')) if not QuerySet(Tool).filter(id=self.data.get('id')).exists(): raise serializers.ValidationError(_('Tool not found'))