fix: Application import and export (#3538)

This commit is contained in:
shaohuzhang1 2025-07-09 21:34:41 +08:00 committed by GitHub
parent 857f988992
commit 3450ef78a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 56 additions and 13 deletions

View File

@ -13,6 +13,7 @@ import json
import os
import pickle
import re
from functools import reduce
from typing import Dict, List
import uuid_utils.compat as uuid
@ -33,7 +34,7 @@ from common.database_model_manage.database_model_manage import DatabaseModelMana
from common.db.search import native_search, native_page_search
from common.exception.app_exception import AppApiException
from common.field.common import UploadedFileField
from common.utils.common import get_file_content, valid_license, restricted_loads
from common.utils.common import get_file_content, valid_license, restricted_loads, generate_uuid
from knowledge.models import Knowledge, KnowledgeScope
from knowledge.serializers.knowledge import KnowledgeSerializer, KnowledgeModelSerializer
from maxkb.conf import PROJECT_DIR
@ -493,18 +494,37 @@ class ApplicationSerializer(serializers.Serializer):
except Exception as e:
raise AppApiException(1001, _("Unsupported file format"))
application = mk_instance.application
tool_list = mk_instance.get_tool_list()
update_tool_map = {}
if len(tool_list) > 0:
tool_id_list = [tool.get('id') for tool in tool_list]
tool_id_list = reduce(lambda x, y: [*x, *y],
[[tool.get('id'), generate_uuid((tool.get('id') + tool.get('workspace_id') or ''))]
for tool
in
tool_list], [])
# 存在的工具列表
exits_tool_id_list = [str(tool.id) for tool in
QuerySet(Tool).filter(id__in=tool_id_list)]
# 获取到需要插入的函数
tool_list = [tool for tool in tool_id_list if
not exits_tool_id_list.__contains__(tool.get('id'))]
application_model = self.to_application(application, workspace_id, user_id)
QuerySet(Tool).filter(id__in=tool_id_list, workspace_id=workspace_id)]
# 需要更新的工具集合
update_tool_map = {tool.get('id'): generate_uuid((tool.get('id') + tool.get('workspace_id') or '')) for tool
in
tool_list if
not exits_tool_id_list.__contains__(
tool.get('id'))}
tool_list = [{**tool, 'id': update_tool_map.get(tool.get('id'))} for tool in tool_list if
not exits_tool_id_list.__contains__(
tool.get('id')) and not exits_tool_id_list.__contains__(
generate_uuid((tool.get('id') + tool.get('workspace_id') or '')))]
application_model = self.to_application(application, workspace_id, user_id, update_tool_map)
tool_model_list = [self.to_tool(f, workspace_id, user_id) for f in tool_list]
application_model.save()
# 插入授权数据
UserResourcePermissionSerializer(data={
'workspace_id': self.data.get('workspace_id'),
'user_id': self.data.get('user_id'),
'auth_target_type': AuthTargetType.APPLICATION.value
}).auth_resource(str(application_model.id))
# 插入认证信息
ApplicationAccessToken(application_id=application_model.id,
access_token=hashlib.md5(str(uuid.uuid7()).encode()).hexdigest()[8:24]).save()
@ -526,18 +546,24 @@ class ApplicationSerializer(serializers.Serializer):
input_field_list=tool.get('input_field_list'),
is_active=tool.get('is_active'),
scope=ToolScope.WORKSPACE,
folder_id=workspace_id,
workspace_id=workspace_id)
@staticmethod
def to_application(application, workspace_id, user_id):
def to_application(application, workspace_id, user_id, update_tool_map):
work_flow = application.get('work_flow')
for node in work_flow.get('nodes', []):
if node.get('type') == 'tool-lib-node':
tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '')
node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id,
tool_lib_id)
if node.get('type') == 'search-knowledge-node':
node.get('properties', {}).get('node_data', {})['knowledge_id_list'] = []
return Application(id=uuid.uuid7(),
user_id=user_id,
name=application.get('name'),
workspace_id=workspace_id,
folder_id=workspace_id,
desc=application.get('desc'),
prologue=application.get('prologue'), dialogue_number=application.get('dialogue_number'),
knowledge_setting=application.get('knowledge_setting'),
@ -624,13 +650,13 @@ class ApplicationOperateSerializer(serializers.Serializer):
self.is_valid()
application_id = self.data.get('application_id')
application = QuerySet(Application).filter(id=application_id).first()
tool_id_list = [node.get('properties', {}).get('node_data', {}).get('tool_id') for node
tool_id_list = [node.get('properties', {}).get('node_data', {}).get('tool_lib_id') for node
in
application.work_flow.get('nodes', []) if
node.get('type') == 'tool-node']
node.get('type') == 'tool-lib-node']
tool_list = []
if len(tool_id_list) > 0:
tool_list = QuerySet(Tool).filter(id__in=tool_id_list)
tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED)
application_dict = ApplicationSerializerModel(application).data
mk_instance = MKInstance(application_dict,

View File

@ -13,6 +13,7 @@ import pickle
import random
import re
import shutil
import uuid
from functools import reduce
from typing import List, Dict
@ -329,3 +330,7 @@ def parse_image(content: str):
matches = re.finditer("!\[.*?\]\(\/oss\/(image|file)\/.*?\)", content)
image_list = [match.group() for match in matches]
return image_list
def generate_uuid(tag: str):
return str(uuid.uuid5(uuid.NAMESPACE_DNS, tag))

View File

@ -287,6 +287,18 @@ class ToolSerializer(serializers.Serializer):
id = serializers.UUIDField(required=True, label=_('tool id'))
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
def is_one_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
workspace_id = self.data.get('workspace_id')
query_set = QuerySet(Tool).filter(id=self.data.get('id'))
if workspace_id:
query_set = query_set.filter(workspace_id=workspace_id)
if not query_set.exists():
get_authorized_tool = DatabaseModelManage.get_model('get_authorized_tool')
if get_authorized_tool:
return get_authorized_tool(QuerySet(Tool).filter(id=self.data.get('id')), workspace_id).exists()
raise AppApiException(500, _('Tool id does not exist'))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
workspace_id = self.data.get('workspace_id')
@ -337,7 +349,7 @@ class ToolSerializer(serializers.Serializer):
QuerySet(Tool).filter(id=self.data.get('id')).delete()
def one(self):
self.is_valid(raise_exception=True)
self.is_one_valid(raise_exception=True)
tool = QuerySet(Tool).filter(id=self.data.get('id')).first()
if tool.init_params:
tool.init_params = json.loads(rsa_long_decrypt(tool.init_params))