diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py index 527b11ce8..feb8b62dc 100644 --- a/apps/application/flow/tools.py +++ b/apps/application/flow/tools.py @@ -8,7 +8,8 @@ """ import asyncio import json -import traceback +import queue +import threading from typing import Iterator from django.http import StreamingHttpResponse @@ -227,6 +228,30 @@ def generate_tool_message_template(name, context): return tool_message_template % (name, tool_message_json_template % (context)) +# 全局单例事件循环 +_global_loop = None +_loop_thread = None +_loop_lock = threading.Lock() + + +def get_global_loop(): + """获取全局共享的事件循环""" + global _global_loop, _loop_thread + + with _loop_lock: + if _global_loop is None: + _global_loop = asyncio.new_event_loop() + + def run_forever(): + asyncio.set_event_loop(_global_loop) + _global_loop.run_forever() + + _loop_thread = threading.Thread(target=run_forever, daemon=True, name="GlobalAsyncLoop") + _loop_thread.start() + + return _global_loop + + async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True): client = MultiServerMCPClient(json.loads(mcp_servers)) tools = await client.get_tools() @@ -242,19 +267,31 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_ def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True): - loop = asyncio.new_event_loop() - try: - async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable) - while True: - try: - chunk = loop.run_until_complete(anext_async(async_gen)) - yield chunk - except StopAsyncIteration: - break - except Exception as e: - maxkb_logger.error(f'Exception: {e}', exc_info=True) - finally: - loop.close() + """使用全局事件循环,不创建新实例""" + result_queue = queue.Queue() + loop = get_global_loop() # 使用共享循环 + + async def _run(): + try: + async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable) + async for chunk in async_gen: + result_queue.put(('data', chunk)) + except Exception as e: + maxkb_logger.error(f'Exception: {e}', exc_info=True) + result_queue.put(('error', e)) + finally: + result_queue.put(('done', None)) + + # 在全局循环中调度任务 + asyncio.run_coroutine_threadsafe(_run(), loop) + + while True: + msg_type, data = result_queue.get() + if msg_type == 'done': + break + if msg_type == 'error': + raise data + yield data async def anext_async(agen): diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index d35173ac9..343d6f9aa 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -234,25 +234,62 @@ class WorkflowManage: 非流式响应 @return: 结果 """ - self.run_chain_async(None, None, language) - while self.is_run(): - pass - details = self.get_runtime_details() - message_tokens = sum([row.get('message_tokens') for row in details.values() if - 'message_tokens' in row and row.get('message_tokens') is not None]) - answer_tokens = sum([row.get('answer_tokens') for row in details.values() if - 'answer_tokens' in row and row.get('answer_tokens') is not None]) - answer_text_list = self.get_answer_text_list() - answer_text = '\n\n'.join( - '\n\n'.join([a.get('content') for a in answer]) for answer in - answer_text_list) - answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, []) - self.work_flow_post_handler.handler(self) - return self.base_to_response.to_block_response(self.params['chat_id'], - self.params['chat_record_id'], answer_text, True - , message_tokens, answer_tokens, - _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR, - other_params={'answer_list': answer_list}) + try: + self.run_chain_async(None, None, language) + while self.is_run(): + pass + details = self.get_runtime_details() + message_tokens = sum([row.get('message_tokens') for row in details.values() if + 'message_tokens' in row and row.get('message_tokens') is not None]) + answer_tokens = sum([row.get('answer_tokens') for row in details.values() if + 'answer_tokens' in row and row.get('answer_tokens') is not None]) + answer_text_list = self.get_answer_text_list() + answer_text = '\n\n'.join( + '\n\n'.join([a.get('content') for a in answer]) for answer in + answer_text_list) + answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, []) + self.work_flow_post_handler.handler(self) + + res = self.base_to_response.to_block_response(self.params['chat_id'], + self.params['chat_record_id'], answer_text, True + , message_tokens, answer_tokens, + _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR, + other_params={'answer_list': answer_list}) + finally: + self._cleanup() + return res + + def _cleanup(self): + """清理所有对象引用""" + # 清理列表 + self.future_list.clear() + self.field_list.clear() + self.global_field_list.clear() + self.chat_field_list.clear() + self.image_list.clear() + self.video_list.clear() + self.document_list.clear() + self.audio_list.clear() + self.other_list.clear() + if hasattr(self, 'node_context'): + self.node_context.clear() + + # 清理字典 + self.context.clear() + self.chat_context.clear() + self.form_data.clear() + + # 清理对象引用 + self.node_chunk_manage = None + self.work_flow_post_handler = None + self.flow = None + self.start_node = None + self.current_node = None + self.current_result = None + self.chat_record = None + self.base_to_response = None + self.params = None + self.lock = None def run_stream(self, current_node, node_result_future, language='zh'): """ @@ -307,6 +344,7 @@ class WorkflowManage: '', [], '', True, message_tokens, answer_tokens, {}) + self._cleanup() def run_chain_async(self, current_node, node_result_future, language='zh'): future = executor.submit(self.run_chain_manage, current_node, node_result_future, language) diff --git a/apps/application/serializers/common.py b/apps/application/serializers/common.py index f8cfcaf12..e8ac10809 100644 --- a/apps/application/serializers/common.py +++ b/apps/application/serializers/common.py @@ -6,7 +6,6 @@ @date:2025/6/9 13:42 @desc: """ -from datetime import datetime from typing import List from django.core.cache import cache @@ -226,41 +225,66 @@ class ChatInfo: chat_record.save() ChatCountSerializer(data={'chat_id': self.chat_id}).update_chat() + def to_dict(self): + + return { + 'chat_id': self.chat_id, + 'chat_user_id': self.chat_user_id, + 'chat_user_type': self.chat_user_type, + 'knowledge_id_list': self.knowledge_id_list, + 'exclude_document_id_list': self.exclude_document_id_list, + 'application_id': self.application_id, + 'chat_record_list': [self.chat_record_to_map(c) for c in self.chat_record_list], + 'debug': self.debug + } + + def chat_record_to_map(self, chat_record): + return {'id': chat_record.id, + 'chat_id': chat_record.chat_id, + 'vote_status': chat_record.vote_status, + 'problem_text': chat_record.problem_text, + 'answer_text': chat_record.answer_text, + 'answer_text_list': chat_record.answer_text_list, + 'message_tokens': chat_record.message_tokens, + 'answer_tokens': chat_record.answer_tokens, + 'const': chat_record.const, + 'details': chat_record.details, + 'improve_paragraph_id_list': chat_record.improve_paragraph_id_list, + 'run_time': chat_record.run_time, + 'index': chat_record.index} + + @staticmethod + def map_to_chat_record(chat_record_dict): + ChatRecord(id=chat_record_dict.get('id'), + chat_id=chat_record_dict.get('chat_id'), + vote_status=chat_record_dict.get('vote_status'), + problem_text=chat_record_dict.get('problem_text'), + answer_text=chat_record_dict.get('answer_text'), + answer_text_list=chat_record_dict.get('answer_text_list'), + message_tokens=chat_record_dict.get('message_tokens'), + answer_tokens=chat_record_dict.get('answer_tokens'), + const=chat_record_dict.get('const'), + details=chat_record_dict.get('details'), + improve_paragraph_id_list=chat_record_dict.get('improve_paragraph_id_list'), + run_time=chat_record_dict.get('run_time'), + index=chat_record_dict.get('index'), ) + def set_cache(self): - cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(), + cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self.to_dict(), + version=Cache_Version.CHAT_INFO.get_version(), timeout=60 * 30) + @staticmethod + def map_to_chat_info(chat_info_dict): + return ChatInfo(chat_info_dict.get('chat_id'), chat_info_dict.get('chat_user_id'), + chat_info_dict.get('chat_user_type'), chat_info_dict.get('knowledge_id_list'), + chat_info_dict.get('exclude_document_id_list'), + chat_info_dict.get('application_id'), + [ChatInfo.map_to_chat_record(c_r) for c_r in chat_info_dict.get('chat_record_list')]) + @staticmethod def get_cache(chat_id): - return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version()) - - def __getstate__(self): - state = self.__dict__.copy() - state['application'] = None - state['chat_user'] = None - - # 将 ChatRecord ORM 对象转为轻量字典 - if not self.debug and len(self.chat_record_list) > 0: - state['chat_record_list'] = [ - { - 'id': str(record.id), - } - for record in self.chat_record_list - ] - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - # 恢复 application - if self.application is None and self.application_id: - self.get_application() - - # 如果需要完整的 ChatRecord 对象,从数据库重新加载 - if not self.debug and len(self.chat_record_list) > 0: - record_ids = [record['id'] for record in self.chat_record_list if isinstance(record, dict)] - if record_ids: - self.chat_record_list = list( - QuerySet(ChatRecord).filter(id__in=record_ids).order_by('create_time') - ) - + chat_info_dict = cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT_INFO.get_version()) + if chat_info_dict: + return ChatInfo.map_to_chat_info(chat_info_dict) + return None diff --git a/apps/common/auth/common.py b/apps/common/auth/common.py index 40158f7d2..ad8e0e50a 100644 --- a/apps/common/auth/common.py +++ b/apps/common/auth/common.py @@ -45,7 +45,7 @@ class ChatAuthentication: value = json.dumps(self.to_dict()) authentication = encrypt(value) cache_key = hashlib.sha256(authentication.encode()).hexdigest() - authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.value, timeout=60 * 60 * 2) + authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.get_version(), timeout=60 * 60 * 2) return authentication @staticmethod diff --git a/apps/common/constants/cache_version.py b/apps/common/constants/cache_version.py index 2cf889c17..6664acb56 100644 --- a/apps/common/constants/cache_version.py +++ b/apps/common/constants/cache_version.py @@ -30,6 +30,8 @@ class Cache_Version(Enum): # 对话 CHAT = "CHAT", lambda key: key + CHAT_INFO = "CHAT_INFO", lambda key: key + CHAT_VARIABLE = "CHAT_VARIABLE", lambda key: key # 应用API KEY