diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index 390b0765a..ceb7ecc95 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -220,11 +220,13 @@ class BaseChatNode(IChatNode): mcp_tool_ids = list(set(mcp_tool_ids + [mcp_tool_id])) if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers: mcp_servers_config = json.loads(mcp_servers) + mcp_servers_config = self.handle_variables(mcp_servers_config) elif mcp_tool_ids: mcp_tools = QuerySet(Tool).filter(id__in=mcp_tool_ids).values() for mcp_tool in mcp_tools: if mcp_tool and mcp_tool['is_active']: mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])} + mcp_servers_config = self.handle_variables(mcp_servers_config) if tool_enable: if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP @@ -253,6 +255,22 @@ class BaseChatNode(IChatNode): return None + def handle_variables(self, tool_params): + # 处理参数中的变量 + for k, v in tool_params.items(): + if type(v) == str: + tool_params[k] = self.workflow_manage.generate_prompt(tool_params[k]) + if type(v) == dict: + self.handle_variables(v) + if (type(v) == list) and (type(v[0]) == str): + tool_params[k] = self.get_reference_content(v) + return tool_params + + def get_reference_content(self, fields: List[str]): + return str(self.workflow_manage.get_reference_field( + fields[0], + fields[1:])) + @staticmethod def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id): start_index = len(history_chat_record) - dialogue_number diff --git a/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py b/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py index e7de53efb..d4a046cc9 100644 --- a/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py +++ b/apps/application/flow/step_node/mcp_node/impl/base_mcp_node.py @@ -27,10 +27,12 @@ class BaseMcpNode(IMcpNode): if not tool.is_active: raise ValueError(f"Tool with ID {mcp_tool_id} is inactive.") servers = json.loads(tool.code) + servers = self.handle_variables(servers) # 处理servers中的变量 params = json.loads(json.dumps(tool_params)) params = self.handle_variables(params) else: servers = json.loads(mcp_servers) + servers = self.handle_variables(servers) # 处理servers中的变量 params = json.loads(json.dumps(tool_params)) params = self.handle_variables(params)