diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
index 6c3012f8e..25cff9d54 100644
--- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
@@ -82,6 +82,12 @@ class IChatStep(IBaseChatPipelineStep):
         model_params_setting = serializers.DictField(required=False, allow_null=True,
                                                      label=_("Model parameter settings"))
+        mcp_enable = serializers.BooleanField(label="MCP是否启用", required=False, default=False)
+        mcp_tool_ids = serializers.JSONField(label="MCP工具ID列表", required=False, default=list)
+        mcp_servers = serializers.JSONField(label="MCP服务列表", required=False, default=dict)
+        mcp_source = serializers.CharField(label="MCP Source", required=False, default="referencing")
+        tool_enable = serializers.BooleanField(label="工具是否启用", required=False, default=False)
+        tool_ids = serializers.JSONField(label="工具ID列表", required=False, default=list)

         def is_valid(self, *, raise_exception=False):
             super().is_valid(raise_exception=True)
@@ -106,5 +112,8 @@ class IChatStep(IBaseChatPipelineStep):
                 paragraph_list=None,
                 manage: PipelineManage = None,
                 padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,
-                no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
+                no_references_setting=None, model_params_setting=None, model_setting=None,
+                mcp_enable=False, mcp_tool_ids=None, mcp_servers='', mcp_source="referencing",
+                tool_enable=False, tool_ids=None,
+                **kwargs):
         pass
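Note (reading aid, not part of the patch): the six new step parameters mirror the application-level model fields added further down in this diff. A minimal sketch of the values the serializer now accepts; every ID, name, and URL below is a made-up placeholder, and the server-config shape follows the langchain-mcp-adapters convention assumed by the rest of the patch:

    step_args = {
        "mcp_enable": True,
        # "referencing" (the default) pulls servers from saved MCP tool records;
        # "custom" expects mcp_servers to be a JSON string of server configs.
        "mcp_source": "custom",
        "mcp_servers": '{"weather": {"transport": "sse", "url": "http://mcp.example/sse"}}',
        "mcp_tool_ids": [],          # only consulted when mcp_source == "referencing"
        "tool_enable": True,
        "tool_ids": ["11111111-2222-3333-4444-555555555555"],  # placeholder UUID
    }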
diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
index 8e4c410f0..ef5e31704 100644
--- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -6,7 +6,8 @@
     @date:2024/1/9 18:25
     @desc: 对话step Base实现
 """
-import logging
+import json
+import os
 import time
 import traceback
 import uuid_utils.compat as uuid
@@ -24,10 +25,14 @@ from rest_framework import status
 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
 from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
-from application.flow.tools import Reasoning
+from application.flow.tools import Reasoning, mcp_response_generator
 from application.models import ApplicationChatUserStats, ChatUserType
 from common.utils.logger import maxkb_logger
+from common.utils.rsa_util import rsa_long_decrypt
+from common.utils.tool_code import ToolExecutor
+from maxkb.const import CONFIG
 from models_provider.tools import get_model_instance_by_model_workspace_id
+from tools.models import Tool


 def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
@@ -54,6 +59,7 @@ def write_context(step, manage, request_token, response_token, all_text):
     manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token


+
 def event_content(response,
                   chat_id,
                   chat_record_id,
@@ -169,6 +175,12 @@ class BaseChatStep(IChatStep):
                 no_references_setting=None,
                 model_params_setting=None,
                 model_setting=None,
+                mcp_enable=False,
+                mcp_tool_ids=None,
+                mcp_servers='',
+                mcp_source="referencing",
+                tool_enable=False,
+                tool_ids=None,
                 **kwargs):
         chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
                                                               **model_params_setting) if model_id is not None else None
@@ -177,14 +189,24 @@ class BaseChatStep(IChatStep):
                                        paragraph_list, manage, padding_problem_text, chat_user_id, chat_user_type,
                                        no_references_setting,
-                                       model_setting)
+                                       model_setting,
+                                       mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)
         else:
             return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                       paragraph_list, manage, padding_problem_text, chat_user_id, chat_user_type,
                                       no_references_setting,
-                                      model_setting)
+                                      model_setting,
+                                      mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)

     def get_details(self, manage, **kwargs):
+        # Remove the temporarily generated MCP code files
+        if self.context.get('execute_ids'):
+            executor = ToolExecutor(CONFIG.get('SANDBOX'))
+            # Clean up the tool code files; deletion is deferred to this point so files still in use are not removed
+            for tool_id in self.context.get('execute_ids'):
+                code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
+                if os.path.exists(code_path):
+                    os.remove(code_path)
         return {
             'step_type': 'chat_step',
             'run_time': self.context['run_time'],
@@ -206,12 +228,63 @@
             result.append({'role': 'ai', 'content': answer_text})
         return result

-    @staticmethod
-    def get_stream_result(message_list: List[BaseMessage],
+    def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
+                            chat_model, message_list):
+        if not mcp_enable and not tool_enable:
+            return None
+
+        mcp_servers_config = {}
+
+        # mcp_source is None on data migrated from older versions
+        if mcp_source is None:
+            mcp_source = 'custom'
+        if mcp_enable:
+            # Compatibility with legacy data
+            if not mcp_tool_ids:
+                mcp_tool_ids = []
+            if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
+                mcp_servers_config = json.loads(mcp_servers)
+            elif mcp_tool_ids:
+                mcp_tools = QuerySet(Tool).filter(id__in=mcp_tool_ids).values()
+                for mcp_tool in mcp_tools:
+                    if mcp_tool and mcp_tool['is_active']:
+                        mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
+
+        if tool_enable:
+            if tool_ids and len(tool_ids) > 0:  # If tool IDs are present, convert the tools to MCP servers
+                self.context['tool_ids'] = tool_ids
+                self.context['execute_ids'] = []
+                for tool_id in tool_ids:
+                    tool = QuerySet(Tool).filter(id=tool_id).first()
+                    if not tool.is_active:
+                        continue
+                    executor = ToolExecutor(CONFIG.get('SANDBOX'))
+                    if tool.init_params is not None:
+                        params = json.loads(rsa_long_decrypt(tool.init_params))
+                    else:
+                        params = {}
+                    _id, tool_config = executor.get_tool_mcp_config(tool.code, params)
+
+                    self.context['execute_ids'].append(_id)
+                    mcp_servers_config[str(tool.id)] = tool_config
+
+        if len(mcp_servers_config) > 0:
+            return mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config))
+
+        return None
+
+
+    def get_stream_result(self, message_list: List[BaseMessage],
                           chat_model: BaseChatModel = None,
                           paragraph_list=None,
                           no_references_setting=None,
-                          problem_text=None):
+                          problem_text=None,
+                          mcp_enable=False,
+                          mcp_tool_ids=None,
+                          mcp_servers='',
+                          mcp_source="referencing",
+                          tool_enable=False,
+                          tool_ids=None):
         if paragraph_list is None:
             paragraph_list = []
         directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
@@ -227,6 +300,12 @@
             return iter([AIMessageChunk(
                 _('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.'))]), False
         else:
+            # Handle the MCP request
+            mcp_result = self._handle_mcp_request(
+                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, chat_model, message_list,
+            )
+            if mcp_result:
+                return mcp_result, True
             return chat_model.stream(message_list), True

     def execute_stream(self, message_list: List[BaseMessage],
@@ -239,9 +318,15 @@
                        padding_problem_text: str = None,
                        chat_user_id=None, chat_user_type=None,
                        no_references_setting=None,
-                       model_setting=None):
+                       model_setting=None,
+                       mcp_enable=False,
+                       mcp_tool_ids=None,
+                       mcp_servers='',
+                       mcp_source="referencing",
+                       tool_enable=False,
+                       tool_ids=None):
         chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
-                                                         no_references_setting, problem_text)
+                                                         no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)
         chat_record_id = uuid.uuid7()
         r = StreamingHttpResponse(
             streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
@@ -253,12 +338,17 @@
         r['Cache-Control'] = 'no-cache'
         return r

-    @staticmethod
-    def get_block_result(message_list: List[BaseMessage],
+    def get_block_result(self, message_list: List[BaseMessage],
                          chat_model: BaseChatModel = None,
                          paragraph_list=None,
                          no_references_setting=None,
-                         problem_text=None):
+                         problem_text=None,
+                         mcp_enable=False,
+                         mcp_tool_ids=None,
+                         mcp_servers='',
+                         mcp_source="referencing",
+                         tool_enable=False,
+                         tool_ids=None):
         if paragraph_list is None:
             paragraph_list = []
         directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
@@ -273,6 +363,12 @@
             return AIMessage(
                 _('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.')), False
         else:
+            # Handle the MCP request
+            mcp_result = self._handle_mcp_request(
+                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, chat_model, message_list,
+            )
+            if mcp_result:
+                return mcp_result, True
             return chat_model.invoke(message_list), True

     def execute_block(self, message_list: List[BaseMessage],
@@ -284,7 +380,13 @@
                       manage: PipelineManage = None,
                       padding_problem_text: str = None,
                       chat_user_id=None, chat_user_type=None,
                       no_references_setting=None,
-                      model_setting=None):
+                      model_setting=None,
+                      mcp_enable=False,
+                      mcp_tool_ids=None,
+                      mcp_servers='',
+                      mcp_source="referencing",
+                      tool_enable=False,
+                      tool_ids=None):
         reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
         reasoning_content_start = model_setting.get('reasoning_content_start', '')
         reasoning_content_end = model_setting.get('reasoning_content_end', '')
@@ -294,7 +396,7 @@
         # 调用模型
         try:
             chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
-                                                            no_references_setting, problem_text)
+                                                            no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)
             if is_ai_chat:
                 request_token = chat_model.get_num_tokens_from_messages(message_list)
                 response_token = chat_model.get_num_tokens(chat_result.content)
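Note (not part of the patch): to make the branchy logic in `_handle_mcp_request` concrete, a sketch of the server map it assembles before handing off to `mcp_response_generator`. The server names, URL, path, and the exact entry shape returned by `ToolExecutor.get_tool_mcp_config` are assumptions for illustration only:

    # From referenced MCP tools: each active Tool row's `code` column already
    # holds a JSON server map, so it is merged in wholesale.
    mcp_servers_config = {"weather": {"transport": "sse", "url": "http://mcp.example/sse"}}

    # From sandbox tools: get_tool_mcp_config() writes a temporary
    # {sandbox}/execute/{id}.py file (cleaned up later in get_details) and
    # returns an MCP server entry for it, keyed here by the tool's UUID.
    # Presumably a stdio launcher for the generated file; exact shape depends on ToolExecutor.
    mcp_servers_config["7f3c0e7a-..."] = {"transport": "stdio", "command": "python",
                                          "args": ["/opt/sandbox/execute/7f3c0e7a-....py"]}

    # The merged map is serialized and drives the agent:
    # mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config))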
diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
index a2839cd90..31e0cb834 100644
--- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
+++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
@@ -6,27 +6,21 @@
     @date:2024/6/4 14:30
     @desc:
 """
-import asyncio
 import json
 import os
 import re
-import sys
 import time
-import traceback
 from functools import reduce
 from typing import List, Dict

-import uuid_utils.compat as uuid
 from django.db.models import QuerySet
 from langchain.schema import HumanMessage, SystemMessage
-from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk, ToolMessage
-from langchain_mcp_adapters.client import MultiServerMCPClient
-from langgraph.prebuilt import create_react_agent
+from langchain_core.messages import BaseMessage, AIMessage
+
 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
-from application.flow.tools import Reasoning
-from common.utils.logger import maxkb_logger
+from application.flow.tools import Reasoning, mcp_response_generator
 from common.utils.rsa_util import rsa_long_decrypt
 from common.utils.tool_code import ToolExecutor
 from maxkb.const import CONFIG
@@ -34,31 +28,6 @@ from models_provider.models import Model
 from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
 from tools.models import Tool

-tool_message_template = """
-<details>
-    <summary>
-        Called MCP Tool: %s
-    </summary>
-
-%s
-
-</details>
-
-"""
-
-tool_message_json_template = """
-```json
-%s
-```
-"""
-
-
-def generate_tool_message_template(name, context):
-    if '```' in context:
-        return tool_message_template % (name, context)
-    else:
-        return tool_message_template % (name, tool_message_json_template % (context))
-

 def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
                    reasoning_content: str):
@@ -122,39 +91,6 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
     _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)


-async def _yield_mcp_response(chat_model, message_list, mcp_servers):
-    client = MultiServerMCPClient(json.loads(mcp_servers))
-    tools = await client.get_tools()
-    agent = create_react_agent(chat_model, tools)
-    response = agent.astream({"messages": message_list}, stream_mode='messages')
-    async for chunk in response:
-        if isinstance(chunk[0], ToolMessage):
-            content = generate_tool_message_template(chunk[0].name, chunk[0].content)
-            chunk[0].content = content
-            yield chunk[0]
-        if isinstance(chunk[0], AIMessageChunk):
-            yield chunk[0]
-
-
-def mcp_response_generator(chat_model, message_list, mcp_servers):
-    loop = asyncio.new_event_loop()
-    try:
-        async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers)
-        while True:
-            try:
-                chunk = loop.run_until_complete(anext_async(async_gen))
-                yield chunk
-            except StopAsyncIteration:
-                break
-    except Exception as e:
-        maxkb_logger.error(f'Exception: {e}', traceback.format_exc())
-    finally:
-        loop.close()
-
-
-async def anext_async(agen):
-    return await agen.__anext__()
-

 def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
     """
diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py
index 60cdbb408..a76d35308 100644
--- a/apps/application/flow/tools.py
+++ b/apps/application/flow/tools.py
@@ -6,14 +6,18 @@
     @date:2024/6/6 15:15
     @desc:
 """
+import asyncio
 import json
+import traceback
 from typing import Iterator

 from django.http import StreamingHttpResponse
-from langchain_core.messages import BaseMessageChunk, BaseMessage
-
+from langchain_core.messages import BaseMessageChunk, BaseMessage, ToolMessage, AIMessageChunk
+from langchain_mcp_adapters.client import MultiServerMCPClient
+from langgraph.prebuilt import create_react_agent
 from application.flow.i_step_node import WorkFlowPostHandler
 from common.result import result
+from common.utils.logger import maxkb_logger


 class Reasoning:
@@ -196,3 +200,62 @@ def to_stream_response_simple(stream_event):
     r['Cache-Control'] = 'no-cache'
     return r

+
+tool_message_template = """
+<details>
+    <summary>
+        Called MCP Tool: %s
+    </summary>
+
+%s
+
+</details>
+
+"""
+
+tool_message_json_template = """
+```json
+%s
+```
+"""
+
+
+def generate_tool_message_template(name, context):
+    if '```' in context:
+        return tool_message_template % (name, context)
+    else:
+        return tool_message_template % (name, tool_message_json_template % (context))
+
+
+async def _yield_mcp_response(chat_model, message_list, mcp_servers):
+    client = MultiServerMCPClient(json.loads(mcp_servers))
+    tools = await client.get_tools()
+    agent = create_react_agent(chat_model, tools)
+    response = agent.astream({"messages": message_list}, stream_mode='messages')
+    async for chunk in response:
+        if isinstance(chunk[0], ToolMessage):
+            content = generate_tool_message_template(chunk[0].name, chunk[0].content)
+            chunk[0].content = content
+            yield chunk[0]
+        if isinstance(chunk[0], AIMessageChunk):
+            yield chunk[0]
+
+
+def mcp_response_generator(chat_model, message_list, mcp_servers):
+    loop = asyncio.new_event_loop()
+    try:
+        async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers)
+        while True:
+            try:
+                chunk = loop.run_until_complete(anext_async(async_gen))
+                yield chunk
+            except StopAsyncIteration:
+                break
+    except Exception as e:
+        maxkb_logger.error(f'Exception: {e}', traceback.format_exc())
+    finally:
+        loop.close()
+
+
+async def anext_async(agen):
+    return await agen.__anext__()
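Note (not part of the patch): `mcp_response_generator` bridges the async LangGraph agent stream into a plain synchronous generator, so the existing streaming code can consume it exactly like `chat_model.stream(...)`. A minimal consumption sketch; the `chat_model` object and the server URL are assumptions, not something this diff defines:

    from langchain_core.messages import HumanMessage

    # `chat_model` is any configured LangChain chat model instance.
    servers = '{"demo": {"transport": "sse", "url": "http://localhost:8000/sse"}}'
    message_list = [HumanMessage("What is the weather in Paris?")]

    for chunk in mcp_response_generator(chat_model, message_list, servers):
        # ToolMessage chunks arrive pre-wrapped in the <details> template;
        # AIMessageChunk chunks carry the streamed answer text.
        print(chunk.content, end="")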
diff --git a/apps/application/migrations/0002_application_mcp_enable_application_mcp_servers_and_more.py b/apps/application/migrations/0002_application_mcp_enable_application_mcp_servers_and_more.py
new file mode 100644
index 000000000..6e27dab8e
--- /dev/null
+++ b/apps/application/migrations/0002_application_mcp_enable_application_mcp_servers_and_more.py
@@ -0,0 +1,73 @@
+# Generated by Django 5.2.4 on 2025-09-08 03:16
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('application', '0001_initial'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='application',
+            name='mcp_enable',
+            field=models.BooleanField(default=False, verbose_name='MCP是否启用'),
+        ),
+        migrations.AddField(
+            model_name='application',
+            name='mcp_servers',
+            field=models.JSONField(default=dict, verbose_name='MCP服务列表'),
+        ),
+        migrations.AddField(
+            model_name='application',
+            name='mcp_source',
+            field=models.CharField(default='referencing', max_length=20, verbose_name='MCP Source'),
+        ),
+        migrations.AddField(
+            model_name='application',
+            name='mcp_tool_ids',
+            field=models.JSONField(default=list, verbose_name='MCP工具ID列表'),
+        ),
+        migrations.AddField(
+            model_name='application',
+            name='tool_enable',
+            field=models.BooleanField(default=False, verbose_name='工具是否启用'),
+        ),
+        migrations.AddField(
+            model_name='application',
+            name='tool_ids',
+            field=models.JSONField(default=list, verbose_name='工具ID列表'),
+        ),
+        migrations.AddField(
+            model_name='applicationversion',
+            name='mcp_enable',
+            field=models.BooleanField(default=False, verbose_name='MCP是否启用'),
+        ),
+        migrations.AddField(
+            model_name='applicationversion',
+            name='mcp_servers',
+            field=models.JSONField(default=dict, verbose_name='MCP服务列表'),
+        ),
+        migrations.AddField(
+            model_name='applicationversion',
+            name='mcp_source',
+            field=models.CharField(default='referencing', max_length=20, verbose_name='MCP Source'),
+        ),
+        migrations.AddField(
+            model_name='applicationversion',
+            name='mcp_tool_ids',
+            field=models.JSONField(default=list, verbose_name='MCP工具ID列表'),
+        ),
+        migrations.AddField(
+            model_name='applicationversion',
+            name='tool_enable',
+            field=models.BooleanField(default=False, verbose_name='工具是否启用'),
+        ),
+        migrations.AddField(
+            model_name='applicationversion',
+            name='tool_ids',
+            field=models.JSONField(default=list, verbose_name='工具ID列表'),
+        ),
+    ]
- -""" - -tool_message_json_template = """ -```json -%s -``` -""" - - -def generate_tool_message_template(name, context): - if '```' in context: - return tool_message_template % (name, context) - else: - return tool_message_template % (name, tool_message_json_template % (context)) - def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str, reasoning_content: str): @@ -122,39 +91,6 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content) -async def _yield_mcp_response(chat_model, message_list, mcp_servers): - client = MultiServerMCPClient(json.loads(mcp_servers)) - tools = await client.get_tools() - agent = create_react_agent(chat_model, tools) - response = agent.astream({"messages": message_list}, stream_mode='messages') - async for chunk in response: - if isinstance(chunk[0], ToolMessage): - content = generate_tool_message_template(chunk[0].name, chunk[0].content) - chunk[0].content = content - yield chunk[0] - if isinstance(chunk[0], AIMessageChunk): - yield chunk[0] - - -def mcp_response_generator(chat_model, message_list, mcp_servers): - loop = asyncio.new_event_loop() - try: - async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers) - while True: - try: - chunk = loop.run_until_complete(anext_async(async_gen)) - yield chunk - except StopAsyncIteration: - break - except Exception as e: - maxkb_logger.error(f'Exception: {e}', traceback.format_exc()) - finally: - loop.close() - - -async def anext_async(agen): - return await agen.__anext__() - def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): """ diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py index 60cdbb408..a76d35308 100644 --- a/apps/application/flow/tools.py +++ b/apps/application/flow/tools.py @@ -6,14 +6,18 @@ @date:2024/6/6 15:15 @desc: """ +import asyncio import json +import traceback from typing import Iterator from django.http import StreamingHttpResponse -from langchain_core.messages import BaseMessageChunk, BaseMessage - +from langchain_core.messages import BaseMessageChunk, BaseMessage, ToolMessage, AIMessageChunk +from langchain_mcp_adapters.client import MultiServerMCPClient +from langgraph.prebuilt import create_react_agent from application.flow.i_step_node import WorkFlowPostHandler from common.result import result +from common.utils.logger import maxkb_logger class Reasoning: @@ -196,3 +200,62 @@ def to_stream_response_simple(stream_event): r['Cache-Control'] = 'no-cache' return r + +tool_message_template = """ +
diff --git a/apps/application/serializers/application.py b/apps/application/serializers/application.py
index a8ed999d4..b5c10836d 100644
--- a/apps/application/serializers/application.py
+++ b/apps/application/serializers/application.py
@@ -706,6 +706,8 @@ class ApplicationOperateSerializer(serializers.Serializer):
             'stt_model_enable': 'stt_model_enable', 'tts_type': 'tts_type',
             'tts_autoplay': 'tts_autoplay', 'stt_autosend': 'stt_autosend',
             'file_upload_enable': 'file_upload_enable', 'file_upload_setting': 'file_upload_setting',
+            'mcp_enable': 'mcp_enable', 'mcp_tool_ids': 'mcp_tool_ids', 'mcp_servers': 'mcp_servers',
+            'mcp_source': 'mcp_source', 'tool_enable': 'tool_enable', 'tool_ids': 'tool_ids',
             'type': 'type'
         }
@@ -829,6 +831,7 @@ class ApplicationOperateSerializer(serializers.Serializer):
                        'stt_model_id', 'tts_model_id', 'tts_model_enable', 'stt_model_enable',
                        'tts_type', 'tts_autoplay', 'stt_autosend', 'file_upload_enable', 'file_upload_setting',
                        'api_key_is_active', 'icon', 'work_flow', 'model_params_setting', 'tts_model_params_setting',
+                       'mcp_enable', 'mcp_tool_ids', 'mcp_servers', 'mcp_source', 'tool_enable', 'tool_ids',
                        'problem_optimization_prompt', 'clean_time', 'folder_id']
         for update_key in update_keys:
             if update_key in instance and instance.get(update_key) is not None:
+ +""" + +tool_message_json_template = """ +```json +%s +``` +""" + + +def generate_tool_message_template(name, context): + if '```' in context: + return tool_message_template % (name, context) + else: + return tool_message_template % (name, tool_message_json_template % (context)) + + +async def _yield_mcp_response(chat_model, message_list, mcp_servers): + client = MultiServerMCPClient(json.loads(mcp_servers)) + tools = await client.get_tools() + agent = create_react_agent(chat_model, tools) + response = agent.astream({"messages": message_list}, stream_mode='messages') + async for chunk in response: + if isinstance(chunk[0], ToolMessage): + content = generate_tool_message_template(chunk[0].name, chunk[0].content) + chunk[0].content = content + yield chunk[0] + if isinstance(chunk[0], AIMessageChunk): + yield chunk[0] + + +def mcp_response_generator(chat_model, message_list, mcp_servers): + loop = asyncio.new_event_loop() + try: + async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers) + while True: + try: + chunk = loop.run_until_complete(anext_async(async_gen)) + yield chunk + except StopAsyncIteration: + break + except Exception as e: + maxkb_logger.error(f'Exception: {e}', traceback.format_exc()) + finally: + loop.close() + + +async def anext_async(agen): + return await agen.__anext__() diff --git a/apps/application/migrations/0002_application_mcp_enable_application_mcp_servers_and_more.py b/apps/application/migrations/0002_application_mcp_enable_application_mcp_servers_and_more.py new file mode 100644 index 000000000..6e27dab8e --- /dev/null +++ b/apps/application/migrations/0002_application_mcp_enable_application_mcp_servers_and_more.py @@ -0,0 +1,73 @@ +# Generated by Django 5.2.4 on 2025-09-08 03:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='mcp_enable', + field=models.BooleanField(default=False, verbose_name='MCP否启用'), + ), + migrations.AddField( + model_name='application', + name='mcp_servers', + field=models.JSONField(default=dict, verbose_name='MCP服务列表'), + ), + migrations.AddField( + model_name='application', + name='mcp_source', + field=models.CharField(default='referencing', max_length=20, verbose_name='MCP Source'), + ), + migrations.AddField( + model_name='application', + name='mcp_tool_ids', + field=models.JSONField(default=list, verbose_name='MCP工具ID列表'), + ), + migrations.AddField( + model_name='application', + name='tool_enable', + field=models.BooleanField(default=False, verbose_name='工具是否启用'), + ), + migrations.AddField( + model_name='application', + name='tool_ids', + field=models.JSONField(default=list, verbose_name='工具ID列表'), + ), + migrations.AddField( + model_name='applicationversion', + name='mcp_enable', + field=models.BooleanField(default=False, verbose_name='MCP否启用'), + ), + migrations.AddField( + model_name='applicationversion', + name='mcp_servers', + field=models.JSONField(default=dict, verbose_name='MCP服务列表'), + ), + migrations.AddField( + model_name='applicationversion', + name='mcp_source', + field=models.CharField(default='referencing', max_length=20, verbose_name='MCP Source'), + ), + migrations.AddField( + model_name='applicationversion', + name='mcp_tool_ids', + field=models.JSONField(default=list, verbose_name='MCP工具ID列表'), + ), + migrations.AddField( + model_name='applicationversion', + name='tool_enable', + field=models.BooleanField(default=False, 
diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts
index a52cc0b6f..6fbb603cc 100644
--- a/ui/src/api/type/application.ts
+++ b/ui/src/api/type/application.ts
@@ -26,6 +26,12 @@ interface ApplicationFormType {
   stt_autosend?: boolean
   folder_id?: string
   workspace_id?: string
+  mcp_enable?: boolean
+  mcp_servers?: string
+  mcp_tool_ids?: string[]
+  mcp_source?: string
+  tool_enable?: boolean
+  tool_ids?: string[]
 }
 interface Chunk {
   real_node_id: string
diff --git a/ui/src/views/application/ApplicationSetting.vue b/ui/src/views/application/ApplicationSetting.vue
index ca28665b8..c13375bd7 100644
--- a/ui/src/views/application/ApplicationSetting.vue
+++ b/ui/src/views/application/ApplicationSetting.vue
@@ -273,6 +273,111 @@
         @submitDialog="submitPrologueDialog"
       />
[... 105 added template lines follow here, apparently the UI for the new MCP / tool settings; the Vue markup was lost in extraction and is not reconstructed ...]