perf: Memory optimization

zhangzhanwei 2025-11-14 18:40:05 +08:00 committed by CaptainB
parent fbc13a0bc2
commit 3ce69d64fb
5 changed files with 169 additions and 68 deletions

View File

@ -8,7 +8,8 @@
"""
import asyncio
import json
import traceback
import queue
import threading
from typing import Iterator
from django.http import StreamingHttpResponse
@ -227,6 +228,30 @@ def generate_tool_message_template(name, context):
return tool_message_template % (name, tool_message_json_template % (context))
# Global singleton event loop
_global_loop = None
_loop_thread = None
_loop_lock = threading.Lock()
def get_global_loop():
"""获取全局共享的事件循环"""
global _global_loop, _loop_thread
with _loop_lock:
if _global_loop is None:
_global_loop = asyncio.new_event_loop()
def run_forever():
asyncio.set_event_loop(_global_loop)
_global_loop.run_forever()
_loop_thread = threading.Thread(target=run_forever, daemon=True, name="GlobalAsyncLoop")
_loop_thread.start()
return _global_loop
async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True):
client = MultiServerMCPClient(json.loads(mcp_servers))
tools = await client.get_tools()
@ -242,19 +267,31 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_
def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True):
loop = asyncio.new_event_loop()
try:
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable)
while True:
try:
chunk = loop.run_until_complete(anext_async(async_gen))
yield chunk
except StopAsyncIteration:
break
except Exception as e:
maxkb_logger.error(f'Exception: {e}', exc_info=True)
finally:
loop.close()
"""使用全局事件循环,不创建新实例"""
result_queue = queue.Queue()
loop = get_global_loop()  # use the shared loop
async def _run():
try:
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable)
async for chunk in async_gen:
result_queue.put(('data', chunk))
except Exception as e:
maxkb_logger.error(f'Exception: {e}', exc_info=True)
result_queue.put(('error', e))
finally:
result_queue.put(('done', None))
# Schedule the task on the global loop
asyncio.run_coroutine_threadsafe(_run(), loop)
while True:
msg_type, data = result_queue.get()
if msg_type == 'done':
break
if msg_type == 'error':
raise data
yield data
async def anext_async(agen):
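
The new code above keeps a single asyncio loop alive in a daemon thread and feeds chunks from the async MCP generator back to the synchronous Django response through a queue.Queue. Below is a self-contained sketch of that bridging pattern; the names (get_loop, numbers, sync_stream) are illustrative only and not part of MaxKB:

import asyncio
import queue
import threading

_loop = None
_loop_lock = threading.Lock()


def get_loop():
    """Lazily start one background event loop shared by all callers."""
    global _loop
    with _loop_lock:
        if _loop is None:
            _loop = asyncio.new_event_loop()
            threading.Thread(target=_loop.run_forever, daemon=True).start()
    return _loop


async def numbers(n):
    # Stand-in for an async generator such as _yield_mcp_response.
    for i in range(n):
        await asyncio.sleep(0.01)
        yield i


def sync_stream(n):
    """Consume the async generator from synchronous code via a queue."""
    q = queue.Queue()

    async def pump():
        try:
            async for item in numbers(n):
                q.put(('data', item))
        except Exception as e:  # forward any failure to the consuming thread
            q.put(('error', e))
        finally:
            q.put(('done', None))

    asyncio.run_coroutine_threadsafe(pump(), get_loop())
    while True:
        kind, payload = q.get()
        if kind == 'done':
            break
        if kind == 'error':
            raise payload
        yield payload


if __name__ == '__main__':
    print(list(sync_stream(5)))  # [0, 1, 2, 3, 4]

Running it prints [0, 1, 2, 3, 4]; an exception raised inside the coroutine is re-raised in the consuming thread, mirroring the ('error', e) path in the hunk above.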

View File

@ -234,25 +234,62 @@ class WorkflowManage:
Non-streaming response
@return: result
"""
self.run_chain_async(None, None, language)
while self.is_run():
pass
details = self.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if
'message_tokens' in row and row.get('message_tokens') is not None])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
answer_text_list = self.get_answer_text_list()
answer_text = '\n\n'.join(
'\n\n'.join([a.get('content') for a in answer]) for answer in
answer_text_list)
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
self.work_flow_post_handler.handler(self)
return self.base_to_response.to_block_response(self.params['chat_id'],
self.params['chat_record_id'], answer_text, True
, message_tokens, answer_tokens,
_status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR,
other_params={'answer_list': answer_list})
try:
self.run_chain_async(None, None, language)
while self.is_run():
pass
details = self.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if
'message_tokens' in row and row.get('message_tokens') is not None])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
answer_text_list = self.get_answer_text_list()
answer_text = '\n\n'.join(
'\n\n'.join([a.get('content') for a in answer]) for answer in
answer_text_list)
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
self.work_flow_post_handler.handler(self)
res = self.base_to_response.to_block_response(self.params['chat_id'],
self.params['chat_record_id'], answer_text, True
, message_tokens, answer_tokens,
_status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR,
other_params={'answer_list': answer_list})
finally:
self._cleanup()
return res
def _cleanup(self):
"""清理所有对象引用"""
# 清理列表
self.future_list.clear()
self.field_list.clear()
self.global_field_list.clear()
self.chat_field_list.clear()
self.image_list.clear()
self.video_list.clear()
self.document_list.clear()
self.audio_list.clear()
self.other_list.clear()
if hasattr(self, 'node_context'):
self.node_context.clear()
# Clear dictionaries
self.context.clear()
self.chat_context.clear()
self.form_data.clear()
# Drop object references
self.node_chunk_manage = None
self.work_flow_post_handler = None
self.flow = None
self.start_node = None
self.current_node = None
self.current_result = None
self.chat_record = None
self.base_to_response = None
self.params = None
self.lock = None
def run_stream(self, current_node, node_result_future, language='zh'):
"""
@ -307,6 +344,7 @@ class WorkflowManage:
'',
[],
'', True, message_tokens, answer_tokens, {})
self._cleanup()
def run_chain_async(self, current_node, node_result_future, language='zh'):
future = executor.submit(self.run_chain_manage, current_node, node_result_future, language)
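
_cleanup exists so the WorkflowManage instance stops pinning its per-request object graph (node contexts, futures, ORM records, post handlers) once the response has been produced, and the try/finally in the non-streaming path makes it run even on failure. A toy sketch of the same idea, with hypothetical names that are not MaxKB code:

class RequestRunner:
    """Illustrative only: release per-request state as soon as the result is built."""

    def __init__(self):
        self.node_context = []
        self.context = {}
        self.current_node = object()   # stand-in for a heavy node / ORM reference

    def run(self):
        try:
            self.node_context.append('step-1')
            self.context['answer'] = 'hello'
            return dict(self.context)   # copy out what the caller needs
        finally:
            self._cleanup()             # drop references even if run() raised

    def _cleanup(self):
        # Emptying containers and nulling attributes breaks the reference graph,
        # so per-request objects can be collected even while the runner itself
        # is still referenced (e.g. by a cache or an executor).
        self.node_context.clear()
        self.context.clear()
        self.current_node = None


print(RequestRunner().run())   # {'answer': 'hello'}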

View File

@ -6,7 +6,6 @@
@date: 2025/6/9 13:42
@desc:
"""
from datetime import datetime
from typing import List
from django.core.cache import cache
@ -226,41 +225,66 @@ class ChatInfo:
chat_record.save()
ChatCountSerializer(data={'chat_id': self.chat_id}).update_chat()
def to_dict(self):
return {
'chat_id': self.chat_id,
'chat_user_id': self.chat_user_id,
'chat_user_type': self.chat_user_type,
'knowledge_id_list': self.knowledge_id_list,
'exclude_document_id_list': self.exclude_document_id_list,
'application_id': self.application_id,
'chat_record_list': [self.chat_record_to_map(c) for c in self.chat_record_list],
'debug': self.debug
}
def chat_record_to_map(self, chat_record):
return {'id': chat_record.id,
'chat_id': chat_record.chat_id,
'vote_status': chat_record.vote_status,
'problem_text': chat_record.problem_text,
'answer_text': chat_record.answer_text,
'answer_text_list': chat_record.answer_text_list,
'message_tokens': chat_record.message_tokens,
'answer_tokens': chat_record.answer_tokens,
'const': chat_record.const,
'details': chat_record.details,
'improve_paragraph_id_list': chat_record.improve_paragraph_id_list,
'run_time': chat_record.run_time,
'index': chat_record.index}
@staticmethod
def map_to_chat_record(chat_record_dict):
return ChatRecord(id=chat_record_dict.get('id'),
chat_id=chat_record_dict.get('chat_id'),
vote_status=chat_record_dict.get('vote_status'),
problem_text=chat_record_dict.get('problem_text'),
answer_text=chat_record_dict.get('answer_text'),
answer_text_list=chat_record_dict.get('answer_text_list'),
message_tokens=chat_record_dict.get('message_tokens'),
answer_tokens=chat_record_dict.get('answer_tokens'),
const=chat_record_dict.get('const'),
details=chat_record_dict.get('details'),
improve_paragraph_id_list=chat_record_dict.get('improve_paragraph_id_list'),
run_time=chat_record_dict.get('run_time'),
index=chat_record_dict.get('index'), )
def set_cache(self):
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self.to_dict(),
version=Cache_Version.CHAT_INFO.get_version(),
timeout=60 * 30)
@staticmethod
def map_to_chat_info(chat_info_dict):
return ChatInfo(chat_info_dict.get('chat_id'), chat_info_dict.get('chat_user_id'),
chat_info_dict.get('chat_user_type'), chat_info_dict.get('knowledge_id_list'),
chat_info_dict.get('exclude_document_id_list'),
chat_info_dict.get('application_id'),
[ChatInfo.map_to_chat_record(c_r) for c_r in chat_info_dict.get('chat_record_list')])
@staticmethod
def get_cache(chat_id):
return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version())
def __getstate__(self):
state = self.__dict__.copy()
state['application'] = None
state['chat_user'] = None
# Convert the ChatRecord ORM objects into lightweight dicts
if not self.debug and len(self.chat_record_list) > 0:
state['chat_record_list'] = [
{
'id': str(record.id),
}
for record in self.chat_record_list
]
return state
def __setstate__(self, state):
self.__dict__.update(state)
# Restore application
if self.application is None and self.application_id:
self.get_application()
# Reload full ChatRecord objects from the database when they are needed
if not self.debug and len(self.chat_record_list) > 0:
record_ids = [record['id'] for record in self.chat_record_list if isinstance(record, dict)]
if record_ids:
self.chat_record_list = list(
QuerySet(ChatRecord).filter(id__in=record_ids).order_by('create_time')
)
chat_info_dict = cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT_INFO.get_version())
if chat_info_dict:
return ChatInfo.map_to_chat_info(chat_info_dict)
return None
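
This file swaps pickling the whole ChatInfo object (the removed __getstate__/__setstate__ pair) for caching a plain dict under the new CHAT_INFO version and rebuilding the object on read, so ORM instances and the application object never enter the cache. A rough sketch of that dict round-trip, with made-up names (Session, record_ids) rather than the real fields:

import json


class Session:
    """Illustrative only: cache a small dict instead of pickling the object."""

    def __init__(self, session_id, record_ids):
        self.session_id = session_id
        self.record_ids = record_ids
        self.heavy_handle = object()   # e.g. an ORM instance; never serialized

    def to_dict(self):
        return {'session_id': self.session_id, 'record_ids': self.record_ids}

    @staticmethod
    def from_dict(data):
        return Session(data['session_id'], data['record_ids'])


# set_cache-style write: only the lightweight payload goes into the cache.
payload = json.dumps(Session('chat-1', [1, 2]).to_dict())
# get_cache-style read: rebuild the object, reloading anything heavy on demand.
restored = Session.from_dict(json.loads(payload))
print(restored.record_ids)   # [1, 2]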

View File

@ -45,7 +45,7 @@ class ChatAuthentication:
value = json.dumps(self.to_dict())
authentication = encrypt(value)
cache_key = hashlib.sha256(authentication.encode()).hexdigest()
authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.value, timeout=60 * 60 * 2)
authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.get_version(), timeout=60 * 60 * 2)
return authentication
@staticmethod

View File

@ -30,6 +30,8 @@ class Cache_Version(Enum):
# Chat
CHAT = "CHAT", lambda key: key
CHAT_INFO = "CHAT_INFO", lambda key: key
CHAT_VARIABLE = "CHAT_VARIABLE", lambda key: key
# Application API KEY
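
The last two files go together: each Cache_Version member is a (version-string, key-builder) tuple, so CHAT.value is a tuple rather than the plain version string that get_version() presumably returns, which is what the ChatAuthentication hunk corrects, and CHAT_INFO is added as a dedicated version for the new dict payload. A guess at how such an enum is typically laid out; the real helpers in MaxKB may differ:

from enum import Enum


class Cache_Version(Enum):
    """Illustrative guess at the (version, key-builder) layout seen above."""
    CHAT = "CHAT", lambda key: key
    CHAT_INFO = "CHAT_INFO", lambda key: key

    def get_version(self):
        # First tuple element: the plain version string the cache backend expects.
        return self.value[0]

    def get_key(self, **kwargs):
        # Second tuple element: a callable that builds the cache key.
        return self.value[1](**kwargs)


print(Cache_Version.CHAT_INFO.get_version())        # CHAT_INFO
print(Cache_Version.CHAT.get_key(key="chat-123"))   # chat-123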