From 3450ef78a46c9b5fcae219c966cdcba9e20b3415 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Wed, 9 Jul 2025 21:34:41 +0800 Subject: [PATCH] fix: Application import and export (#3538) --- apps/application/serializers/application.py | 50 ++++++++++++++++----- apps/common/utils/common.py | 5 +++ apps/tools/serializers/tool.py | 14 +++++- 3 files changed, 56 insertions(+), 13 deletions(-) diff --git a/apps/application/serializers/application.py b/apps/application/serializers/application.py index aa4c181c1..80eecae8a 100644 --- a/apps/application/serializers/application.py +++ b/apps/application/serializers/application.py @@ -13,6 +13,7 @@ import json import os import pickle import re +from functools import reduce from typing import Dict, List import uuid_utils.compat as uuid @@ -33,7 +34,7 @@ from common.database_model_manage.database_model_manage import DatabaseModelMana from common.db.search import native_search, native_page_search from common.exception.app_exception import AppApiException from common.field.common import UploadedFileField -from common.utils.common import get_file_content, valid_license, restricted_loads +from common.utils.common import get_file_content, valid_license, restricted_loads, generate_uuid from knowledge.models import Knowledge, KnowledgeScope from knowledge.serializers.knowledge import KnowledgeSerializer, KnowledgeModelSerializer from maxkb.conf import PROJECT_DIR @@ -493,18 +494,37 @@ class ApplicationSerializer(serializers.Serializer): except Exception as e: raise AppApiException(1001, _("Unsupported file format")) application = mk_instance.application - tool_list = mk_instance.get_tool_list() + update_tool_map = {} if len(tool_list) > 0: - tool_id_list = [tool.get('id') for tool in tool_list] + tool_id_list = reduce(lambda x, y: [*x, *y], + [[tool.get('id'), generate_uuid((tool.get('id') + tool.get('workspace_id') or ''))] + for tool + in + tool_list], []) + # 存在的工具列表 exits_tool_id_list = [str(tool.id) for tool in - QuerySet(Tool).filter(id__in=tool_id_list)] - # 获取到需要插入的函数 - tool_list = [tool for tool in tool_id_list if - not exits_tool_id_list.__contains__(tool.get('id'))] - application_model = self.to_application(application, workspace_id, user_id) + QuerySet(Tool).filter(id__in=tool_id_list, workspace_id=workspace_id)] + # 需要更新的工具集合 + update_tool_map = {tool.get('id'): generate_uuid((tool.get('id') + tool.get('workspace_id') or '')) for tool + in + tool_list if + not exits_tool_id_list.__contains__( + tool.get('id'))} + + tool_list = [{**tool, 'id': update_tool_map.get(tool.get('id'))} for tool in tool_list if + not exits_tool_id_list.__contains__( + tool.get('id')) and not exits_tool_id_list.__contains__( + generate_uuid((tool.get('id') + tool.get('workspace_id') or '')))] + application_model = self.to_application(application, workspace_id, user_id, update_tool_map) tool_model_list = [self.to_tool(f, workspace_id, user_id) for f in tool_list] application_model.save() + # 插入授权数据 + UserResourcePermissionSerializer(data={ + 'workspace_id': self.data.get('workspace_id'), + 'user_id': self.data.get('user_id'), + 'auth_target_type': AuthTargetType.APPLICATION.value + }).auth_resource(str(application_model.id)) # 插入认证信息 ApplicationAccessToken(application_id=application_model.id, access_token=hashlib.md5(str(uuid.uuid7()).encode()).hexdigest()[8:24]).save() @@ -526,18 +546,24 @@ class ApplicationSerializer(serializers.Serializer): input_field_list=tool.get('input_field_list'), is_active=tool.get('is_active'), scope=ToolScope.WORKSPACE, + folder_id=workspace_id, workspace_id=workspace_id) @staticmethod - def to_application(application, workspace_id, user_id): + def to_application(application, workspace_id, user_id, update_tool_map): work_flow = application.get('work_flow') for node in work_flow.get('nodes', []): + if node.get('type') == 'tool-lib-node': + tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '') + node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id, + tool_lib_id) if node.get('type') == 'search-knowledge-node': node.get('properties', {}).get('node_data', {})['knowledge_id_list'] = [] return Application(id=uuid.uuid7(), user_id=user_id, name=application.get('name'), workspace_id=workspace_id, + folder_id=workspace_id, desc=application.get('desc'), prologue=application.get('prologue'), dialogue_number=application.get('dialogue_number'), knowledge_setting=application.get('knowledge_setting'), @@ -624,13 +650,13 @@ class ApplicationOperateSerializer(serializers.Serializer): self.is_valid() application_id = self.data.get('application_id') application = QuerySet(Application).filter(id=application_id).first() - tool_id_list = [node.get('properties', {}).get('node_data', {}).get('tool_id') for node + tool_id_list = [node.get('properties', {}).get('node_data', {}).get('tool_lib_id') for node in application.work_flow.get('nodes', []) if - node.get('type') == 'tool-node'] + node.get('type') == 'tool-lib-node'] tool_list = [] if len(tool_id_list) > 0: - tool_list = QuerySet(Tool).filter(id__in=tool_id_list) + tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED) application_dict = ApplicationSerializerModel(application).data mk_instance = MKInstance(application_dict, diff --git a/apps/common/utils/common.py b/apps/common/utils/common.py index 6d12e803a..128f32da4 100644 --- a/apps/common/utils/common.py +++ b/apps/common/utils/common.py @@ -13,6 +13,7 @@ import pickle import random import re import shutil +import uuid from functools import reduce from typing import List, Dict @@ -329,3 +330,7 @@ def parse_image(content: str): matches = re.finditer("!\[.*?\]\(\/oss\/(image|file)\/.*?\)", content) image_list = [match.group() for match in matches] return image_list + + +def generate_uuid(tag: str): + return str(uuid.uuid5(uuid.NAMESPACE_DNS, tag)) diff --git a/apps/tools/serializers/tool.py b/apps/tools/serializers/tool.py index 2c8208094..6f6bcf898 100644 --- a/apps/tools/serializers/tool.py +++ b/apps/tools/serializers/tool.py @@ -287,6 +287,18 @@ class ToolSerializer(serializers.Serializer): id = serializers.UUIDField(required=True, label=_('tool id')) workspace_id = serializers.CharField(required=True, label=_('workspace id')) + def is_one_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + workspace_id = self.data.get('workspace_id') + query_set = QuerySet(Tool).filter(id=self.data.get('id')) + if workspace_id: + query_set = query_set.filter(workspace_id=workspace_id) + if not query_set.exists(): + get_authorized_tool = DatabaseModelManage.get_model('get_authorized_tool') + if get_authorized_tool: + return get_authorized_tool(QuerySet(Tool).filter(id=self.data.get('id')), workspace_id).exists() + raise AppApiException(500, _('Tool id does not exist')) + def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) workspace_id = self.data.get('workspace_id') @@ -337,7 +349,7 @@ class ToolSerializer(serializers.Serializer): QuerySet(Tool).filter(id=self.data.get('id')).delete() def one(self): - self.is_valid(raise_exception=True) + self.is_one_valid(raise_exception=True) tool = QuerySet(Tool).filter(id=self.data.get('id')).first() if tool.init_params: tool.init_params = json.loads(rsa_long_decrypt(tool.init_params))