mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
perf: Memory optimization
This commit is contained in:
parent
fbc13a0bc2
commit
3ce69d64fb
|
|
@ -8,7 +8,8 @@
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import traceback
|
import queue
|
||||||
|
import threading
|
||||||
from typing import Iterator
|
from typing import Iterator
|
||||||
|
|
||||||
from django.http import StreamingHttpResponse
|
from django.http import StreamingHttpResponse
|
||||||
|
|
@ -227,6 +228,30 @@ def generate_tool_message_template(name, context):
|
||||||
return tool_message_template % (name, tool_message_json_template % (context))
|
return tool_message_template % (name, tool_message_json_template % (context))
|
||||||
|
|
||||||
|
|
||||||
|
# Process-wide singleton event loop and the thread that drives it.
_global_loop = None
_loop_thread = None
_loop_lock = threading.Lock()


def _serve_loop(loop):
    """Thread target: bind ``loop`` to the current thread and run it forever."""
    asyncio.set_event_loop(loop)
    loop.run_forever()


def get_global_loop():
    """Return the globally shared event loop.

    On the first call a new event loop is created and started on a daemon
    thread named "GlobalAsyncLoop"; later calls return that same loop.
    Creation is guarded by ``_loop_lock``.
    """
    global _global_loop, _loop_thread

    with _loop_lock:
        # Already initialised: hand back the existing loop.
        if _global_loop is not None:
            return _global_loop

        _global_loop = asyncio.new_event_loop()
        _loop_thread = threading.Thread(
            target=_serve_loop,
            args=(_global_loop,),
            daemon=True,
            name="GlobalAsyncLoop",
        )
        _loop_thread.start()
        return _global_loop
|
||||||
|
|
||||||
|
|
||||||
async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True):
|
async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True):
|
||||||
client = MultiServerMCPClient(json.loads(mcp_servers))
|
client = MultiServerMCPClient(json.loads(mcp_servers))
|
||||||
tools = await client.get_tools()
|
tools = await client.get_tools()
|
||||||
|
|
@ -242,19 +267,31 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_
|
||||||
|
|
||||||
|
|
||||||
def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True):
    """Synchronously iterate the chunks produced by ``_yield_mcp_response``.

    The async generator is driven on the shared loop returned by
    ``get_global_loop()`` (no new loop is created per call). Its chunks are
    handed to the calling thread through a queue and yielded from here.

    :param chat_model: forwarded unchanged to ``_yield_mcp_response``
    :param message_list: forwarded unchanged to ``_yield_mcp_response``
    :param mcp_servers: forwarded unchanged to ``_yield_mcp_response``
    :param mcp_output_enable: forwarded unchanged to ``_yield_mcp_response``
    :return: generator yielding each chunk in production order
    :raises Exception: re-raises, in the calling thread, any exception raised
        while producing chunks (it is logged first)
    """
    result_queue = queue.Queue()
    loop = get_global_loop()  # shared loop, never closed here

    async def _run():
        try:
            async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable)
            async for chunk in async_gen:
                result_queue.put(('data', chunk))
        except Exception as e:
            maxkb_logger.error(f'Exception: {e}', exc_info=True)
            result_queue.put(('error', e))
        finally:
            # Always emitted last so the consumer loop below terminates.
            result_queue.put(('done', None))

    # Schedule the producer on the shared loop; keep the future so it can be
    # cancelled if the consumer goes away.
    future = asyncio.run_coroutine_threadsafe(_run(), loop)

    try:
        while True:
            msg_type, data = result_queue.get()
            if msg_type == 'done':
                break
            if msg_type == 'error':
                raise data
            yield data
    finally:
        # If the consumer stops early (generator closed, or an error was
        # re-raised above) the producer would otherwise keep running on the
        # shared loop and keep filling the unbounded queue. Cancelling an
        # already-finished future is a no-op.
        future.cancel()
|
||||||
|
|
||||||
|
|
||||||
async def anext_async(agen):
|
async def anext_async(agen):
|
||||||
|
|
|
||||||
|
|
@ -234,25 +234,62 @@ class WorkflowManage:
|
||||||
非流式响应
|
非流式响应
|
||||||
@return: 结果
|
@return: 结果
|
||||||
"""
|
"""
|
||||||
self.run_chain_async(None, None, language)
|
try:
|
||||||
while self.is_run():
|
self.run_chain_async(None, None, language)
|
||||||
pass
|
while self.is_run():
|
||||||
details = self.get_runtime_details()
|
pass
|
||||||
message_tokens = sum([row.get('message_tokens') for row in details.values() if
|
details = self.get_runtime_details()
|
||||||
'message_tokens' in row and row.get('message_tokens') is not None])
|
message_tokens = sum([row.get('message_tokens') for row in details.values() if
|
||||||
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
|
'message_tokens' in row and row.get('message_tokens') is not None])
|
||||||
'answer_tokens' in row and row.get('answer_tokens') is not None])
|
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
|
||||||
answer_text_list = self.get_answer_text_list()
|
'answer_tokens' in row and row.get('answer_tokens') is not None])
|
||||||
answer_text = '\n\n'.join(
|
answer_text_list = self.get_answer_text_list()
|
||||||
'\n\n'.join([a.get('content') for a in answer]) for answer in
|
answer_text = '\n\n'.join(
|
||||||
answer_text_list)
|
'\n\n'.join([a.get('content') for a in answer]) for answer in
|
||||||
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
|
answer_text_list)
|
||||||
self.work_flow_post_handler.handler(self)
|
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
|
||||||
return self.base_to_response.to_block_response(self.params['chat_id'],
|
self.work_flow_post_handler.handler(self)
|
||||||
self.params['chat_record_id'], answer_text, True
|
|
||||||
, message_tokens, answer_tokens,
|
res = self.base_to_response.to_block_response(self.params['chat_id'],
|
||||||
_status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR,
|
self.params['chat_record_id'], answer_text, True
|
||||||
other_params={'answer_list': answer_list})
|
, message_tokens, answer_tokens,
|
||||||
|
_status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
other_params={'answer_list': answer_list})
|
||||||
|
finally:
|
||||||
|
self._cleanup()
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _cleanup(self):
    """Drop every object reference held by this instance.

    Containers are emptied in place (so the attributes stay usable), while
    single-object attributes are reset to ``None``.
    """
    # Empty the list attributes in place.
    for list_attr in (
            'future_list',
            'field_list',
            'global_field_list',
            'chat_field_list',
            'image_list',
            'video_list',
            'document_list',
            'audio_list',
            'other_list',
    ):
        getattr(self, list_attr).clear()
    # node_context is the only attribute that is cleared conditionally.
    if hasattr(self, 'node_context'):
        self.node_context.clear()

    # Empty the dict attributes in place.
    for dict_attr in ('context', 'chat_context', 'form_data'):
        getattr(self, dict_attr).clear()

    # Release the remaining object references.
    for ref_attr in (
            'node_chunk_manage',
            'work_flow_post_handler',
            'flow',
            'start_node',
            'current_node',
            'current_result',
            'chat_record',
            'base_to_response',
            'params',
            'lock',
    ):
        setattr(self, ref_attr, None)
|
||||||
|
|
||||||
def run_stream(self, current_node, node_result_future, language='zh'):
|
def run_stream(self, current_node, node_result_future, language='zh'):
|
||||||
"""
|
"""
|
||||||
|
|
@ -307,6 +344,7 @@ class WorkflowManage:
|
||||||
'',
|
'',
|
||||||
[],
|
[],
|
||||||
'', True, message_tokens, answer_tokens, {})
|
'', True, message_tokens, answer_tokens, {})
|
||||||
|
self._cleanup()
|
||||||
|
|
||||||
def run_chain_async(self, current_node, node_result_future, language='zh'):
|
def run_chain_async(self, current_node, node_result_future, language='zh'):
|
||||||
future = executor.submit(self.run_chain_manage, current_node, node_result_future, language)
|
future = executor.submit(self.run_chain_manage, current_node, node_result_future, language)
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@
|
||||||
@date:2025/6/9 13:42
|
@date:2025/6/9 13:42
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
from datetime import datetime
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
|
|
@ -226,41 +225,66 @@ class ChatInfo:
|
||||||
chat_record.save()
|
chat_record.save()
|
||||||
ChatCountSerializer(data={'chat_id': self.chat_id}).update_chat()
|
ChatCountSerializer(data={'chat_id': self.chat_id}).update_chat()
|
||||||
|
|
||||||
|
def to_dict(self):
    """Return a plain-dict snapshot of this object.

    Scalar/list attributes are copied by reference; each entry of
    ``chat_record_list`` is converted with ``chat_record_to_map``.
    """
    snapshot = {}
    for attr in (
            'chat_id',
            'chat_user_id',
            'chat_user_type',
            'knowledge_id_list',
            'exclude_document_id_list',
            'application_id',
    ):
        snapshot[attr] = getattr(self, attr)

    records = []
    for record in self.chat_record_list:
        records.append(self.chat_record_to_map(record))
    snapshot['chat_record_list'] = records

    snapshot['debug'] = self.debug
    return snapshot
|
||||||
|
|
||||||
|
def chat_record_to_map(self, chat_record):
    """Copy the fields of ``chat_record`` into a plain dict.

    :param chat_record: object exposing the attributes listed below
    :return: dict keyed by attribute name, in the order listed
    """
    field_names = (
        'id',
        'chat_id',
        'vote_status',
        'problem_text',
        'answer_text',
        'answer_text_list',
        'message_tokens',
        'answer_tokens',
        'const',
        'details',
        'improve_paragraph_id_list',
        'run_time',
        'index',
    )
    return {field: getattr(chat_record, field) for field in field_names}
|
||||||
|
|
||||||
|
@staticmethod
def map_to_chat_record(chat_record_dict):
    """Build a ``ChatRecord`` from a plain dict of its fields.

    Keys absent from the dict are passed to the constructor as ``None``.

    :param chat_record_dict: dict with the keys read below
    :return: the constructed ``ChatRecord`` instance
    """
    # The instance must be returned: callers collect the results into a list,
    # which would otherwise contain only None.
    return ChatRecord(id=chat_record_dict.get('id'),
                      chat_id=chat_record_dict.get('chat_id'),
                      vote_status=chat_record_dict.get('vote_status'),
                      problem_text=chat_record_dict.get('problem_text'),
                      answer_text=chat_record_dict.get('answer_text'),
                      answer_text_list=chat_record_dict.get('answer_text_list'),
                      message_tokens=chat_record_dict.get('message_tokens'),
                      answer_tokens=chat_record_dict.get('answer_tokens'),
                      const=chat_record_dict.get('const'),
                      details=chat_record_dict.get('details'),
                      improve_paragraph_id_list=chat_record_dict.get('improve_paragraph_id_list'),
                      run_time=chat_record_dict.get('run_time'),
                      index=chat_record_dict.get('index'), )
|
||||||
|
|
||||||
def set_cache(self):
|
def set_cache(self):
|
||||||
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),
|
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self.to_dict(),
|
||||||
|
version=Cache_Version.CHAT_INFO.get_version(),
|
||||||
timeout=60 * 30)
|
timeout=60 * 30)
|
||||||
|
|
||||||
|
@staticmethod
def map_to_chat_info(chat_info_dict):
    """Rebuild a ``ChatInfo`` from the plain dict produced by ``to_dict``.

    :param chat_info_dict: dict with the keys written by ``to_dict``
    :return: the reconstructed ``ChatInfo`` instance
    """
    # Tolerate a missing/None record list instead of raising TypeError.
    chat_record_list = [ChatInfo.map_to_chat_record(c_r)
                        for c_r in chat_info_dict.get('chat_record_list') or []]
    # NOTE(review): ChatInfo.__init__ is not visible here; the positional
    # order below is kept exactly as before -- confirm the 7th parameter is
    # the chat record list.
    chat_info = ChatInfo(chat_info_dict.get('chat_id'), chat_info_dict.get('chat_user_id'),
                         chat_info_dict.get('chat_user_type'), chat_info_dict.get('knowledge_id_list'),
                         chat_info_dict.get('exclude_document_id_list'),
                         chat_info_dict.get('application_id'),
                         chat_record_list)
    # to_dict() stores 'debug'; restore it so the flag survives a cache
    # round trip instead of falling back to the constructor's default.
    if 'debug' in chat_info_dict:
        chat_info.debug = chat_info_dict['debug']
    return chat_info
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_cache(chat_id):
|
def get_cache(chat_id):
|
||||||
return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version())
|
chat_info_dict = cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT_INFO.get_version())
|
||||||
|
if chat_info_dict:
|
||||||
def __getstate__(self):
|
return ChatInfo.map_to_chat_info(chat_info_dict)
|
||||||
state = self.__dict__.copy()
|
return None
|
||||||
state['application'] = None
|
|
||||||
state['chat_user'] = None
|
|
||||||
|
|
||||||
# 将 ChatRecord ORM 对象转为轻量字典
|
|
||||||
if not self.debug and len(self.chat_record_list) > 0:
|
|
||||||
state['chat_record_list'] = [
|
|
||||||
{
|
|
||||||
'id': str(record.id),
|
|
||||||
}
|
|
||||||
for record in self.chat_record_list
|
|
||||||
]
|
|
||||||
return state
|
|
||||||
|
|
||||||
def __setstate__(self, state):
|
|
||||||
self.__dict__.update(state)
|
|
||||||
|
|
||||||
# 恢复 application
|
|
||||||
if self.application is None and self.application_id:
|
|
||||||
self.get_application()
|
|
||||||
|
|
||||||
# 如果需要完整的 ChatRecord 对象,从数据库重新加载
|
|
||||||
if not self.debug and len(self.chat_record_list) > 0:
|
|
||||||
record_ids = [record['id'] for record in self.chat_record_list if isinstance(record, dict)]
|
|
||||||
if record_ids:
|
|
||||||
self.chat_record_list = list(
|
|
||||||
QuerySet(ChatRecord).filter(id__in=record_ids).order_by('create_time')
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ class ChatAuthentication:
|
||||||
value = json.dumps(self.to_dict())
|
value = json.dumps(self.to_dict())
|
||||||
authentication = encrypt(value)
|
authentication = encrypt(value)
|
||||||
cache_key = hashlib.sha256(authentication.encode()).hexdigest()
|
cache_key = hashlib.sha256(authentication.encode()).hexdigest()
|
||||||
authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.value, timeout=60 * 60 * 2)
|
authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.get_version(), timeout=60 * 60 * 2)
|
||||||
return authentication
|
return authentication
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,8 @@ class Cache_Version(Enum):
|
||||||
# 对话
|
# 对话
|
||||||
CHAT = "CHAT", lambda key: key
|
CHAT = "CHAT", lambda key: key
|
||||||
|
|
||||||
|
CHAT_INFO = "CHAT_INFO", lambda key: key
|
||||||
|
|
||||||
CHAT_VARIABLE = "CHAT_VARIABLE", lambda key: key
|
CHAT_VARIABLE = "CHAT_VARIABLE", lambda key: key
|
||||||
|
|
||||||
# 应用API KEY
|
# 应用API KEY
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue