From 563516f835be11de7b71f54d558f55427989955c Mon Sep 17 00:00:00 2001 From: CaptainB Date: Mon, 24 Mar 2025 11:05:28 +0800 Subject: [PATCH] feat: AI chat node mcp server config --- .../ai_chat_step_node/i_chat_node.py | 5 + .../ai_chat_step_node/impl/base_chat_node.py | 59 ++++++++++- .../component/McpServersDialog.vue | 98 +++++++++++++++++++ ui/src/workflow/nodes/ai-chat-node/index.vue | 32 ++++++ 4 files changed, 193 insertions(+), 1 deletion(-) create mode 100644 ui/src/views/application/component/McpServersDialog.vue diff --git a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py index 336753450..a83d2ef57 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py @@ -33,6 +33,9 @@ class ChatNodeSerializer(serializers.Serializer): error_messages=ErrMessage.dict('Model settings')) dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True, error_messages=ErrMessage.char(_("Context Type"))) + mcp_enable = serializers.BooleanField(required=False, + error_messages=ErrMessage.boolean(_("Whether to enable MCP"))) + mcp_servers = serializers.JSONField(required=False, error_messages=ErrMessage.list(_("MCP Server"))) class IChatNode(INode): @@ -49,5 +52,7 @@ class IChatNode(INode): model_params_setting=None, dialogue_type=None, model_setting=None, + mcp_enable=False, + mcp_servers=None, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index 7da004822..ae84c4bbb 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -6,14 +6,19 @@ @date:2024/6/4 14:30 @desc: """ +import asyncio +import json import re import time from 
async def _yield_mcp_response(chat_model, message_list, mcp_servers):
    """Stream AI token chunks from a LangGraph ReAct agent wired to MCP tools.

    :param chat_model: the LangChain chat model used by the agent.
    :param message_list: conversation messages passed to the agent.
    :param mcp_servers: JSON string describing the MCP server connections
                        (the mapping format ``MultiServerMCPClient`` expects).
    :yields: ``AIMessageChunk`` items only — tool messages are filtered out
             so downstream consumers see just the model's own tokens.
    """
    async with MultiServerMCPClient(json.loads(mcp_servers)) as client:
        agent = create_react_agent(chat_model, client.get_tools())
        response = agent.astream({"messages": message_list}, stream_mode='messages')
        async for chunk in response:
            # In 'messages' stream mode each item is a (message, metadata)
            # tuple; forward only the model's token chunks.
            if isinstance(chunk[0], AIMessageChunk):
                yield chunk[0]


def mcp_response_generator(chat_model, message_list, mcp_servers):
    """Bridge the async MCP agent stream into a synchronous generator.

    A private event loop is created because this runs in a worker thread
    that has no running loop.  Errors from the agent propagate to the
    caller instead of being silently printed, and the async generator is
    always closed so the MCP client connections are not leaked.

    :raises Exception: whatever the underlying agent/MCP client raises.
    """
    loop = asyncio.new_event_loop()
    async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers)
    try:
        while True:
            try:
                chunk = loop.run_until_complete(async_gen.__anext__())
            except StopAsyncIteration:
                break
            yield chunk
    finally:
        # Close the async generator first: this exits the
        # MultiServerMCPClient context and releases its connections,
        # even when the consumer abandons the stream early.
        loop.run_until_complete(async_gen.aclose())
        loop.close()
async def anext_async(agen):
    """Await and return the next item of *agen*.

    Portable stand-in for the ``anext`` builtin (only available on
    Python 3.10+); raises ``StopAsyncIteration`` when *agen* is exhausted.
    """
    return await agen.__anext__()


async def _get_mcp_response(chat_model, message_list, mcp_servers):
    """Run the MCP-tooled ReAct agent to completion and collect its answer.

    Non-streaming counterpart of ``_yield_mcp_response``.

    :param chat_model: the LangChain chat model used by the agent.
    :param message_list: conversation messages passed to the agent.
    :param mcp_servers: JSON string describing the MCP server connections.
    :return: list of ``AIMessageChunk`` objects produced by the model
             (tool messages are dropped).
    """
    async with MultiServerMCPClient(json.loads(mcp_servers)) as client:
        agent = create_react_agent(chat_model, client.get_tools())
        response = agent.astream({"messages": message_list}, stream_mode='messages')
        # Each streamed item is a (message, metadata) tuple; keep only the
        # model's own token chunks.
        return [chunk[0] async for chunk in response
                if isinstance(chunk[0], AIMessageChunk)]
a/ui/src/workflow/nodes/ai-chat-node/index.vue +++ b/ui/src/workflow/nodes/ai-chat-node/index.vue @@ -116,6 +116,22 @@ /> + + + +