mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
chore: enhance tool call handling and fragment aggregation in tools.py
This commit is contained in:
parent
6d2f9279c5
commit
351898b5a0
|
|
@ -290,26 +290,60 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_
|
|||
agent = create_react_agent(chat_model, tools)
|
||||
response = agent.astream({"messages": message_list}, stream_mode='messages')
|
||||
|
||||
# 用于存储工具调用信息
|
||||
# 用于存储工具调用信息(按 tool_id)以及按 index 聚合分片
|
||||
tool_calls_info = {}
|
||||
_tool_fragments = {} # index -> {'id':..., 'name':..., 'arguments':...}
|
||||
|
||||
async for chunk in response:
|
||||
if isinstance(chunk[0], AIMessageChunk):
|
||||
tool_calls = chunk[0].additional_kwargs.get('tool_calls', [])
|
||||
for tool_call in tool_calls:
|
||||
tool_id = tool_call.get('id', '')
|
||||
if tool_id:
|
||||
# 保存工具调用的输入
|
||||
tool_calls_info[tool_id] = {
|
||||
'name': tool_call.get('function', {}).get('name', ''),
|
||||
'input': tool_call.get('function', {}).get('arguments', '')
|
||||
}
|
||||
idx = tool_call.get('index')
|
||||
if idx is None:
|
||||
continue
|
||||
entry = _tool_fragments.setdefault(idx, {'id': '', 'name': '', 'arguments': ''})
|
||||
|
||||
# 更新 id 与 name(如果有)
|
||||
if tool_call.get('id'):
|
||||
entry['id'] = tool_call.get('id')
|
||||
|
||||
func = tool_call.get('function', {})
|
||||
# arguments 可能在 function.arguments 或顶层 arguments
|
||||
part_args = ''
|
||||
if isinstance(func, dict) and 'arguments' in func:
|
||||
part_args = func.get('arguments') or ''
|
||||
if func.get('name'):
|
||||
entry['name'] = func.get('name')
|
||||
else:
|
||||
part_args = tool_call.get('arguments', '') or ''
|
||||
|
||||
# 统一为字符串
|
||||
if not isinstance(part_args, str):
|
||||
try:
|
||||
part_args = json.dumps(part_args, ensure_ascii=False)
|
||||
except Exception:
|
||||
part_args = str(part_args)
|
||||
|
||||
entry['arguments'] += part_args
|
||||
|
||||
# 尝试判断 JSON 是否完整(若 arguments 是 JSON),完整则提交到 tool_calls_info
|
||||
try:
|
||||
json.loads(entry['arguments'])
|
||||
if entry['id']:
|
||||
tool_calls_info[entry['id']] = {
|
||||
'name': entry.get('name', ''),
|
||||
'input': entry['arguments']
|
||||
}
|
||||
_tool_fragments.pop(idx, None)
|
||||
except Exception:
|
||||
# 如果不是完整 JSON,继续等待后续片段
|
||||
pass
|
||||
|
||||
yield chunk[0]
|
||||
|
||||
if mcp_output_enable and isinstance(chunk[0], ToolMessage):
|
||||
tool_id = chunk[0].tool_call_id
|
||||
if tool_id in tool_calls_info:
|
||||
# 合并输入和输出
|
||||
tool_info = tool_calls_info[tool_id]
|
||||
content = generate_tool_message_complete(
|
||||
tool_info['name'],
|
||||
|
|
@ -335,6 +369,7 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_
|
|||
raise RuntimeError(error_msg) from None
|
||||
|
||||
|
||||
|
||||
def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True):
|
||||
"""使用全局事件循环,不创建新实例"""
|
||||
result_queue = queue.Queue()
|
||||
|
|
|
|||
Loading…
Reference in New Issue