From 2ae813f9ff2ccfe00dc3527a510f3c530ab19b47 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Mon, 8 Dec 2025 14:08:28 +0800 Subject: [PATCH] perf: Optimization of data source file download logic --- .../tool_lib_node/impl/base_tool_lib_node.py | 61 ++++++++++++++++++- apps/common/utils/tool_code.py | 43 ++++++++++++- 2 files changed, 101 insertions(+), 3 deletions(-) diff --git a/apps/application/flow/step_node/tool_lib_node/impl/base_tool_lib_node.py b/apps/application/flow/step_node/tool_lib_node/impl/base_tool_lib_node.py index 7a0690e93..da22e30c1 100644 --- a/apps/application/flow/step_node/tool_lib_node/impl/base_tool_lib_node.py +++ b/apps/application/flow/step_node/tool_lib_node/impl/base_tool_lib_node.py @@ -6,10 +6,14 @@ @date:2024/8/8 17:49 @desc: """ +import ast +import io import json +import mimetypes import time from typing import Dict +from django.core.files.uploadedfile import InMemoryUploadedFile from django.db.models import QuerySet from django.utils.translation import gettext as _ @@ -19,7 +23,8 @@ from common.database_model_manage.database_model_manage import DatabaseModelMana from common.exception.app_exception import AppApiException from common.utils.rsa_util import rsa_long_decrypt from common.utils.tool_code import ToolExecutor -from maxkb.const import CONFIG +from knowledge.models import FileSourceType +from oss.serializers.file import FileSerializer from tools.models import Tool function_executor = ToolExecutor() @@ -126,6 +131,7 @@ def valid_function(tool_lib, workspace_id): if not tool_lib.is_active: raise Exception(_("Tool is not active")) + def _filter_file_bytes(data): """递归过滤掉所有层级的 file_bytes""" if isinstance(data, dict): @@ -136,6 +142,27 @@ def _filter_file_bytes(data): return data +def bytes_to_uploaded_file(file_bytes, file_name="unknown"): + content_type, _ = mimetypes.guess_type(file_name) + if content_type is None: + # 如果未能识别,设置为默认的二进制文件类型 + content_type = "application/octet-stream" + # 创建一个内存中的字节流对象 + file_stream = io.BytesIO(file_bytes) + + # 获取文件大小 + file_size = len(file_bytes) + + uploaded_file = InMemoryUploadedFile( + file=file_stream, + field_name=None, + name=file_name, + content_type=content_type, + size=file_size, + charset=None, + ) + return uploaded_file + class BaseToolLibNodeNode(IToolLibNode): def save_context(self, details, workflow_manage): @@ -168,12 +195,42 @@ class BaseToolLibNodeNode(IToolLibNode): else: all_params = init_params_default_value | params if self.node.properties.get('kind') == 'data-source': - all_params = {**all_params, **self.workflow_params.get('data_source')} + exist = function_executor.exist_function(tool_lib.code, 'get_download_file_list') + if exist: + download_file_list = [] + download_list = function_executor.exec_code(tool_lib.code, + {**all_params, **self.workflow_params.get('data_source')}, + function_name='get_download_file_list') + for item in download_list: + result = function_executor.exec_code(tool_lib.code, + {**all_params, **self.workflow_params.get('data_source'), + 'download_item': item}, + function_name='download') + file = bytes_to_uploaded_file(ast.literal_eval(result.get('file_bytes')), result.get('name')) + file_url = self.upload_knowledge_file(file) + download_file_list.append({'file_id': file_url, 'name': result.get('name')}) + all_params = {**all_params, **self.workflow_params.get('data_source'), + 'download_file_list': download_file_list} result = function_executor.exec_code(tool_lib.code, all_params) return NodeResult({'result': result}, (self.workflow_manage.params.get('knowledge_base') or {}) if self.node.properties.get( 'kind') == 'data-source' else {}, _write_context=write_context) + def upload_knowledge_file(self, file): + knowledge_id = self.workflow_params.get('knowledge_id') + meta = { + 'debug': False, + 'knowledge_id': knowledge_id, + } + file_url = FileSerializer(data={ + 'file': file, + 'meta': meta, + 'source_id': knowledge_id, + 'source_type': FileSourceType.KNOWLEDGE.value + }).upload().replace("./oss/file/", '') + file.close() + return file_url + def get_details(self, index: int, **kwargs): result = _filter_file_bytes(self.context.get('result')) diff --git a/apps/common/utils/tool_code.py b/apps/common/utils/tool_code.py index 6e94542aa..ff77b9467 100644 --- a/apps/common/utils/tool_code.py +++ b/apps/common/utils/tool_code.py @@ -74,6 +74,47 @@ class ToolExecutor: except Exception as e: maxkb_logger.error(f'Exception: {e}', exc_info=True) + def exist_function(self, code_str, name): + _id = str(uuid.uuid7()) + python_paths = CONFIG.get_sandbox_python_package_paths().split(',') + set_run_user = f'os.setgid({pwd.getpwnam(_run_user).pw_gid});os.setuid({pwd.getpwnam(_run_user).pw_uid});' if _enable_sandbox else '' + _exec_code = f""" +try: + import os, sys, json + path_to_exclude = ['/opt/py3/lib/python3.11/site-packages', '/opt/maxkb-app/apps'] + sys.path = [p for p in sys.path if p not in path_to_exclude] + sys.path += {python_paths} + locals_v={{}} + globals_v={{}} + {set_run_user} + os.environ.clear() + exec({dedent(code_str)!a}, globals_v, locals_v) + exec_result=locals_v.__contains__('{name}') + sys.stdout.write("\\n{_id}:") + json.dump({{'code':200,'msg':'success','data':exec_result}}, sys.stdout, default=str) +except Exception as e: + if isinstance(e, MemoryError): e = Exception("Cannot allocate more memory: exceeded the limit of {_process_limit_mem_mb} MB.") + sys.stdout.write("\\n{_id}:") + json.dump({{'code':500,'msg':str(e),'data':False}}, sys.stdout, default=str) +sys.stdout.flush() + """ + maxkb_logger.debug(f"Sandbox execute code: {_exec_code}") + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=True) as f: + f.write(_exec_code) + f.flush() + subprocess_result = self._exec(f.name) + if subprocess_result.returncode != 0: + raise Exception(subprocess_result.stderr or subprocess_result.stdout or "Unknown exception occurred") + lines = subprocess_result.stdout.splitlines() + result_line = [line for line in lines if line.startswith(_id)] + if not result_line: + maxkb_logger.error("\n".join(lines)) + raise Exception("No result found.") + result = json.loads(result_line[-1].split(":", 1)[1]) + if result.get('code') == 200: + return result.get('data') + raise Exception(result.get('msg')) + def exec_code(self, code_str, keywords, function_name=None): _id = str(uuid.uuid7()) action_function = f'({function_name !a}, locals_v.get({function_name !a}))' if function_name else 'locals_v.popitem()' @@ -212,7 +253,7 @@ exec({dedent(code)!a}) ], 'cwd': _sandbox_path, 'env': { - 'LD_PRELOAD': f'{_sandbox_path}/lib/sandbox.so', + 'LD_PRELOAD': f'{_sandbox_path}/lib/sandbox.so', }, 'transport': 'stdio', }