mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
perf: Memory optimization
This commit is contained in:
parent
fbc13a0bc2
commit
3ce69d64fb
|
|
@ -8,7 +8,8 @@
|
|||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import traceback
|
||||
import queue
|
||||
import threading
|
||||
from typing import Iterator
|
||||
|
||||
from django.http import StreamingHttpResponse
|
||||
|
|
@ -227,6 +228,30 @@ def generate_tool_message_template(name, context):
|
|||
return tool_message_template % (name, tool_message_json_template % (context))
|
||||
|
||||
|
||||
# Process-wide singleton event loop (memory optimization: one shared loop
# instead of a new loop per MCP request).
_global_loop = None
_loop_thread = None
_loop_lock = threading.Lock()


def get_global_loop():
    """Return the shared global asyncio event loop, creating it lazily.

    The loop runs forever on a dedicated daemon thread; callers schedule
    coroutines onto it with asyncio.run_coroutine_threadsafe().

    Fix: also recreate the loop when the backing thread has died (e.g. the
    loop crashed) — otherwise every later caller would schedule work onto a
    dead loop and block forever.

    :return: the running shared event loop
    """
    global _global_loop, _loop_thread

    with _loop_lock:
        # (Re)initialize when never created, or when the runner thread is gone.
        if _global_loop is None or _loop_thread is None or not _loop_thread.is_alive():
            _global_loop = asyncio.new_event_loop()

            def run_forever(loop=_global_loop):
                # Bind the loop to this thread, then spin it until shutdown.
                asyncio.set_event_loop(loop)
                loop.run_forever()

            _loop_thread = threading.Thread(target=run_forever, daemon=True,
                                            name="GlobalAsyncLoop")
            _loop_thread.start()

    return _global_loop
|
||||
|
||||
|
||||
async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True):
|
||||
client = MultiServerMCPClient(json.loads(mcp_servers))
|
||||
tools = await client.get_tools()
|
||||
|
|
@ -242,19 +267,31 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_
|
|||
|
||||
|
||||
def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True):
    """Bridge the async MCP response stream into a synchronous generator.

    Uses the shared global event loop (see get_global_loop) instead of
    creating a new event loop per call; chunks are handed from the loop's
    thread to this caller through a thread-safe queue. Any exception raised
    inside the coroutine is logged there and re-raised here.

    Fix: the span contained the superseded per-call-event-loop implementation
    merged in front of this one; only the queue-based version is kept.

    :param chat_model: chat model used to produce the response
    :param message_list: conversation messages forwarded to the model
    :param mcp_servers: JSON string describing the MCP servers
    :param mcp_output_enable: whether tool output is included in the stream
    """
    result_queue = queue.Queue()
    loop = get_global_loop()  # shared loop — no per-request loop creation

    async def _run():
        try:
            async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable)
            async for chunk in async_gen:
                result_queue.put(('data', chunk))
        except Exception as e:
            maxkb_logger.error(f'Exception: {e}', exc_info=True)
            result_queue.put(('error', e))
        finally:
            # Always signal completion so the consumer loop below terminates.
            result_queue.put(('done', None))

    # Schedule the coroutine onto the shared loop's thread.
    asyncio.run_coroutine_threadsafe(_run(), loop)

    while True:
        msg_type, data = result_queue.get()
        if msg_type == 'done':
            break
        if msg_type == 'error':
            raise data
        yield data
|
||||
|
||||
|
||||
async def anext_async(agen):
|
||||
|
|
|
|||
|
|
@ -234,25 +234,62 @@ class WorkflowManage:
|
|||
非流式响应
|
||||
@return: 结果
|
||||
"""
|
||||
self.run_chain_async(None, None, language)
|
||||
while self.is_run():
|
||||
pass
|
||||
details = self.get_runtime_details()
|
||||
message_tokens = sum([row.get('message_tokens') for row in details.values() if
|
||||
'message_tokens' in row and row.get('message_tokens') is not None])
|
||||
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
|
||||
'answer_tokens' in row and row.get('answer_tokens') is not None])
|
||||
answer_text_list = self.get_answer_text_list()
|
||||
answer_text = '\n\n'.join(
|
||||
'\n\n'.join([a.get('content') for a in answer]) for answer in
|
||||
answer_text_list)
|
||||
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
|
||||
self.work_flow_post_handler.handler(self)
|
||||
return self.base_to_response.to_block_response(self.params['chat_id'],
|
||||
self.params['chat_record_id'], answer_text, True
|
||||
, message_tokens, answer_tokens,
|
||||
_status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
other_params={'answer_list': answer_list})
|
||||
try:
|
||||
self.run_chain_async(None, None, language)
|
||||
while self.is_run():
|
||||
pass
|
||||
details = self.get_runtime_details()
|
||||
message_tokens = sum([row.get('message_tokens') for row in details.values() if
|
||||
'message_tokens' in row and row.get('message_tokens') is not None])
|
||||
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
|
||||
'answer_tokens' in row and row.get('answer_tokens') is not None])
|
||||
answer_text_list = self.get_answer_text_list()
|
||||
answer_text = '\n\n'.join(
|
||||
'\n\n'.join([a.get('content') for a in answer]) for answer in
|
||||
answer_text_list)
|
||||
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
|
||||
self.work_flow_post_handler.handler(self)
|
||||
|
||||
res = self.base_to_response.to_block_response(self.params['chat_id'],
|
||||
self.params['chat_record_id'], answer_text, True
|
||||
, message_tokens, answer_tokens,
|
||||
_status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
other_params={'answer_list': answer_list})
|
||||
finally:
|
||||
self._cleanup()
|
||||
return res
|
||||
|
||||
def _cleanup(self):
    """Release every object reference held by the workflow (memory optimization).

    Called from ``finally`` blocks after a run completes, so it must never
    raise: the original version accessed most attributes unguarded (only
    ``node_context`` had a ``hasattr`` check), and a missing attribute would
    raise AttributeError and mask the real exception. All clears are now
    defensive.
    """
    # Empty out list-like containers (clear in place so shared references
    # also see the emptied container).
    for attr in ('future_list', 'field_list', 'global_field_list',
                 'chat_field_list', 'image_list', 'video_list',
                 'document_list', 'audio_list', 'other_list', 'node_context'):
        container = getattr(self, attr, None)
        if container is not None:
            container.clear()

    # Empty out dict containers.
    for attr in ('context', 'chat_context', 'form_data'):
        mapping = getattr(self, attr, None)
        if mapping is not None:
            mapping.clear()

    # Drop object references so they become collectable.
    for attr in ('node_chunk_manage', 'work_flow_post_handler', 'flow',
                 'start_node', 'current_node', 'current_result',
                 'chat_record', 'base_to_response', 'params', 'lock'):
        setattr(self, attr, None)
|
||||
def run_stream(self, current_node, node_result_future, language='zh'):
|
||||
"""
|
||||
|
|
@ -307,6 +344,7 @@ class WorkflowManage:
|
|||
'',
|
||||
[],
|
||||
'', True, message_tokens, answer_tokens, {})
|
||||
self._cleanup()
|
||||
|
||||
def run_chain_async(self, current_node, node_result_future, language='zh'):
|
||||
future = executor.submit(self.run_chain_manage, current_node, node_result_future, language)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2025/6/9 13:42
|
||||
@desc:
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from django.core.cache import cache
|
||||
|
|
@ -226,41 +225,66 @@ class ChatInfo:
|
|||
chat_record.save()
|
||||
ChatCountSerializer(data={'chat_id': self.chat_id}).update_chat()
|
||||
|
||||
def to_dict(self):
    """Serialize this chat into a plain dict suitable for caching.

    Chat records are flattened through chat_record_to_map so the result
    contains no ORM objects.
    """
    record_maps = [self.chat_record_to_map(record) for record in self.chat_record_list]
    payload = {key: getattr(self, key) for key in (
        'chat_id',
        'chat_user_id',
        'chat_user_type',
        'knowledge_id_list',
        'exclude_document_id_list',
        'application_id',
    )}
    payload['chat_record_list'] = record_maps
    payload['debug'] = self.debug
    return payload
|
||||
|
||||
def chat_record_to_map(self, chat_record):
    """Flatten a ChatRecord into a plain dict (one key per model field)."""
    field_names = ('id', 'chat_id', 'vote_status', 'problem_text', 'answer_text',
                   'answer_text_list', 'message_tokens', 'answer_tokens', 'const',
                   'details', 'improve_paragraph_id_list', 'run_time', 'index')
    return {name: getattr(chat_record, name) for name in field_names}
|
||||
|
||||
@staticmethod
def map_to_chat_record(chat_record_dict):
    """Rebuild a ChatRecord model instance from its cached dict form.

    Bug fix: the original constructed the ChatRecord but never returned it,
    so callers (e.g. map_to_chat_info) always received None.

    :param chat_record_dict: dict produced by chat_record_to_map
    :return: a ChatRecord populated from the dict
    """
    return ChatRecord(id=chat_record_dict.get('id'),
                      chat_id=chat_record_dict.get('chat_id'),
                      vote_status=chat_record_dict.get('vote_status'),
                      problem_text=chat_record_dict.get('problem_text'),
                      answer_text=chat_record_dict.get('answer_text'),
                      answer_text_list=chat_record_dict.get('answer_text_list'),
                      message_tokens=chat_record_dict.get('message_tokens'),
                      answer_tokens=chat_record_dict.get('answer_tokens'),
                      const=chat_record_dict.get('const'),
                      details=chat_record_dict.get('details'),
                      improve_paragraph_id_list=chat_record_dict.get('improve_paragraph_id_list'),
                      run_time=chat_record_dict.get('run_time'),
                      index=chat_record_dict.get('index'), )
|
||||
|
||||
def set_cache(self):
    """Cache a lightweight dict snapshot of this chat for 30 minutes.

    Stores self.to_dict() under the CHAT_INFO cache version rather than the
    ChatInfo object itself, so the cached payload carries no ORM references
    (memory optimization). The span contained a stale duplicate cache.set
    line from the old object-based implementation; only this call is kept.
    """
    cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self.to_dict(),
              version=Cache_Version.CHAT_INFO.get_version(),
              timeout=60 * 30)
|
||||
|
||||
@staticmethod
def map_to_chat_info(chat_info_dict):
    """Inverse of to_dict: rebuild a ChatInfo from its cached dict snapshot."""
    records = [ChatInfo.map_to_chat_record(entry)
               for entry in chat_info_dict.get('chat_record_list')]
    return ChatInfo(chat_info_dict.get('chat_id'),
                    chat_info_dict.get('chat_user_id'),
                    chat_info_dict.get('chat_user_type'),
                    chat_info_dict.get('knowledge_id_list'),
                    chat_info_dict.get('exclude_document_id_list'),
                    chat_info_dict.get('application_id'),
                    records)
|
||||
|
||||
@staticmethod
def get_cache(chat_id):
    """Load a ChatInfo from cache, rebuilding it from its dict snapshot.

    Fix: reads with the CHAT_INFO cache version to match set_cache, which
    writes to_dict() under that version — the old CHAT-version read could
    never hit entries written by set_cache.

    :param chat_id: id of the chat to look up
    :return: a rebuilt ChatInfo, or None on a cache miss
    """
    chat_info_dict = cache.get(Cache_Version.CHAT.get_key(key=chat_id),
                               version=Cache_Version.CHAT_INFO.get_version())
    if chat_info_dict:
        return ChatInfo.map_to_chat_info(chat_info_dict)
    return None
|
||||
def __getstate__(self):
    """Pickle hook: drop heavy references so the serialized payload stays small.

    application and chat_user are nulled out, and (outside debug mode)
    ChatRecord ORM objects are replaced by id-only dicts; __setstate__
    restores both from the database on load.
    """
    state = dict(self.__dict__)
    state['application'] = None
    state['chat_user'] = None

    # Swap ChatRecord ORM objects for lightweight id-only dicts.
    if not self.debug and len(self.chat_record_list) > 0:
        state['chat_record_list'] = [{'id': str(record.id)}
                                     for record in self.chat_record_list]
    return state
|
||||
|
||||
def __setstate__(self, state):
    """Unpickle hook: restore the references stripped by __getstate__."""
    self.__dict__.update(state)

    # Re-fetch the application object that __getstate__ nulled out.
    if self.application is None and self.application_id:
        self.get_application()

    # Re-hydrate full ChatRecord rows from the id-only dicts when needed.
    if not self.debug and len(self.chat_record_list) > 0:
        record_ids = [entry['id'] for entry in self.chat_record_list
                      if isinstance(entry, dict)]
        if record_ids:
            self.chat_record_list = list(
                QuerySet(ChatRecord).filter(id__in=record_ids).order_by('create_time')
            )
|
||||
chat_info_dict = cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT_INFO.get_version())
|
||||
if chat_info_dict:
|
||||
return ChatInfo.map_to_chat_info(chat_info_dict)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class ChatAuthentication:
|
|||
value = json.dumps(self.to_dict())
|
||||
authentication = encrypt(value)
|
||||
cache_key = hashlib.sha256(authentication.encode()).hexdigest()
|
||||
authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.value, timeout=60 * 60 * 2)
|
||||
authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.get_version(), timeout=60 * 60 * 2)
|
||||
return authentication
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ class Cache_Version(Enum):
|
|||
# 对话
|
||||
CHAT = "CHAT", lambda key: key
|
||||
|
||||
CHAT_INFO = "CHAT_INFO", lambda key: key
|
||||
|
||||
CHAT_VARIABLE = "CHAT_VARIABLE", lambda key: key
|
||||
|
||||
# 应用API KEY
|
||||
|
|
|
|||
Loading…
Reference in New Issue