chore: enhance tool call handling and fragment aggregation in tools.py

This commit is contained in:
CaptainB 2025-12-12 10:47:57 +08:00
parent 6d2f9279c5
commit 351898b5a0

View File

@ -290,26 +290,60 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_
agent = create_react_agent(chat_model, tools)
response = agent.astream({"messages": message_list}, stream_mode='messages')
# 用于存储工具调用信息
# 用于存储工具调用信息(按 tool_id以及按 index 聚合分片
tool_calls_info = {}
_tool_fragments = {} # index -> {'id':..., 'name':..., 'arguments':...}
async for chunk in response:
if isinstance(chunk[0], AIMessageChunk):
tool_calls = chunk[0].additional_kwargs.get('tool_calls', [])
for tool_call in tool_calls:
tool_id = tool_call.get('id', '')
if tool_id:
# 保存工具调用的输入
tool_calls_info[tool_id] = {
'name': tool_call.get('function', {}).get('name', ''),
'input': tool_call.get('function', {}).get('arguments', '')
}
idx = tool_call.get('index')
if idx is None:
continue
entry = _tool_fragments.setdefault(idx, {'id': '', 'name': '', 'arguments': ''})
# 更新 id 与 name如果有
if tool_call.get('id'):
entry['id'] = tool_call.get('id')
func = tool_call.get('function', {})
# arguments 可能在 function.arguments 或顶层 arguments
part_args = ''
if isinstance(func, dict) and 'arguments' in func:
part_args = func.get('arguments') or ''
if func.get('name'):
entry['name'] = func.get('name')
else:
part_args = tool_call.get('arguments', '') or ''
# 统一为字符串
if not isinstance(part_args, str):
try:
part_args = json.dumps(part_args, ensure_ascii=False)
except Exception:
part_args = str(part_args)
entry['arguments'] += part_args
# 尝试判断 JSON 是否完整(若 arguments 是 JSON完整则提交到 tool_calls_info
try:
json.loads(entry['arguments'])
if entry['id']:
tool_calls_info[entry['id']] = {
'name': entry.get('name', ''),
'input': entry['arguments']
}
_tool_fragments.pop(idx, None)
except Exception:
# 如果不是完整 JSON继续等待后续片段
pass
yield chunk[0]
if mcp_output_enable and isinstance(chunk[0], ToolMessage):
tool_id = chunk[0].tool_call_id
if tool_id in tool_calls_info:
# 合并输入和输出
tool_info = tool_calls_info[tool_id]
content = generate_tool_message_complete(
tool_info['name'],
@ -335,6 +369,7 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_
raise RuntimeError(error_msg) from None
def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True):
"""使用全局事件循环,不创建新实例"""
result_queue = queue.Queue()