From e8f80094ceeda298cfaa9fe9cae05c5312ea9f4e Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Thu, 26 Jun 2025 17:26:50 +0800 Subject: [PATCH] refactor: model --- .../serializers/model_serializer.py | 33 +++++++------------ .../sql/list_share_authorized_model.sql | 8 +++++ ui/src/locales/lang/zh-CN/views/model.ts | 2 +- 3 files changed, 20 insertions(+), 23 deletions(-) create mode 100644 apps/models_provider/sql/list_share_authorized_model.sql diff --git a/apps/models_provider/serializers/model_serializer.py b/apps/models_provider/serializers/model_serializer.py index 6270420e7..37b99075f 100644 --- a/apps/models_provider/serializers/model_serializer.py +++ b/apps/models_provider/serializers/model_serializer.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import json +import os import threading import time from typing import Dict @@ -11,8 +12,11 @@ from rest_framework import serializers from django.db.models.query_utils import Q from common.config.embedding_config import ModelManage from common.database_model_manage.database_model_manage import DatabaseModelManage +from common.db.sql_execute import select_list from common.exception.app_exception import AppApiException +from common.utils.common import get_file_content from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt +from maxkb.conf import PROJECT_DIR from models_provider.base_model_provider import ValidCode, DownModelChunkStatus from models_provider.constants.model_provider_constants import ModelProvideConstants from models_provider.models import Model, Status @@ -412,27 +416,13 @@ class ModelSerializer(serializers.Serializer): return True -def get_authorized_tool(tool_query_set, workspace_id, model_workspace_authorization): - # 对所有工作空间拉黑的工具 - non_auths = QuerySet(model_workspace_authorization).filter( - Q(workspace_id='None') & Q(authentication_type='WHITE_LIST') - ).values_list('model_id', flat=True) - # 授权给所有工作空间的工具 - all_auths = QuerySet(model_workspace_authorization).filter( - Q(workspace_id='None') & Q(authentication_type='BLACK_LIST') - ).values_list('model_id', flat=True) - # 查询白名单授权的工具 - white_authorized_tool_ids = QuerySet(model_workspace_authorization).filter( - workspace_id=workspace_id, authentication_type='WHITE_LIST' - ).values_list('model_id', flat=True) - # 查询黑名单授权的工具 - black_authorized_tool_ids = QuerySet(model_workspace_authorization).filter( - workspace_id=workspace_id, authentication_type='BLACK_LIST' - ).values_list('model_id', flat=True) +def get_authorized_tool(tool_query_set, workspace_id): + model_id_list = select_list(get_file_content( + os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql', + 'list_share_authorized_model.sql' + )), [workspace_id, workspace_id]) tool_query_set = tool_query_set.filter( - id__in=list(white_authorized_tool_ids) + list(all_auths) - ).exclude( - id__in=list(black_authorized_tool_ids) + list(non_auths) + id__in=[k.get('model_id') for k in model_id_list] ) return tool_query_set @@ -471,8 +461,7 @@ class WorkspaceSharedModelSerializer(serializers.Serializer): if workspace_id: model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization") if model_workspace_authorization is not None: - queryset = get_authorized_tool(queryset, workspace_id, - model_workspace_authorization=model_workspace_authorization) + queryset = get_authorized_tool(queryset, workspace_id) for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']: value = self.data.get(field) diff --git a/apps/models_provider/sql/list_share_authorized_model.sql b/apps/models_provider/sql/list_share_authorized_model.sql new file mode 100644 index 000000000..7fca84973 --- /dev/null +++ b/apps/models_provider/sql/list_share_authorized_model.sql @@ -0,0 +1,8 @@ +select model_id +from model_workspace_authorization +where case + when authentication_type = 'WHITE_LIST' then + %s = any (workspace_id_list) + else + not %s = any(workspace_id_list) + end \ No newline at end of file diff --git a/ui/src/locales/lang/zh-CN/views/model.ts b/ui/src/locales/lang/zh-CN/views/model.ts index 0c1375712..16de7793a 100644 --- a/ui/src/locales/lang/zh-CN/views/model.ts +++ b/ui/src/locales/lang/zh-CN/views/model.ts @@ -9,7 +9,7 @@ export default { }, tip: { createSuccessMessage: '创建模型成功', - createErrorMessage: '基础信息有填写错误', + createErrorMessage: '基础信息填写错误', errorMessage: '变量已存在: ', emptyMessage1: '请先选择基础信息的模型类型和基础模型', emptyMessage2: '所选模型不支持参数设置',