refactor: model

This commit is contained in:
wxg0103 2025-06-26 17:26:50 +08:00
parent d49f448a5f
commit e8f80094ce
3 changed files with 20 additions and 23 deletions

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
import json
import os
import threading
import time
from typing import Dict
@ -11,8 +12,11 @@ from rest_framework import serializers
from django.db.models.query_utils import Q
from common.config.embedding_config import ModelManage
from common.database_model_manage.database_model_manage import DatabaseModelManage
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException
from common.utils.common import get_file_content
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
from maxkb.conf import PROJECT_DIR
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from models_provider.constants.model_provider_constants import ModelProvideConstants
from models_provider.models import Model, Status
@ -412,27 +416,13 @@ class ModelSerializer(serializers.Serializer):
return True
def get_authorized_tool(tool_query_set, workspace_id, model_workspace_authorization):
# 对所有工作空间拉黑的工具
non_auths = QuerySet(model_workspace_authorization).filter(
Q(workspace_id='None') & Q(authentication_type='WHITE_LIST')
).values_list('model_id', flat=True)
# 授权给所有工作空间的工具
all_auths = QuerySet(model_workspace_authorization).filter(
Q(workspace_id='None') & Q(authentication_type='BLACK_LIST')
).values_list('model_id', flat=True)
# 查询白名单授权的工具
white_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
workspace_id=workspace_id, authentication_type='WHITE_LIST'
).values_list('model_id', flat=True)
# 查询黑名单授权的工具
black_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
workspace_id=workspace_id, authentication_type='BLACK_LIST'
).values_list('model_id', flat=True)
def get_authorized_tool(tool_query_set, workspace_id):
model_id_list = select_list(get_file_content(
os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql',
'list_share_authorized_model.sql'
)), [workspace_id, workspace_id])
tool_query_set = tool_query_set.filter(
id__in=list(white_authorized_tool_ids) + list(all_auths)
).exclude(
id__in=list(black_authorized_tool_ids) + list(non_auths)
id__in=[k.get('model_id') for k in model_id_list]
)
return tool_query_set
@ -471,8 +461,7 @@ class WorkspaceSharedModelSerializer(serializers.Serializer):
if workspace_id:
model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization")
if model_workspace_authorization is not None:
queryset = get_authorized_tool(queryset, workspace_id,
model_workspace_authorization=model_workspace_authorization)
queryset = get_authorized_tool(queryset, workspace_id)
for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']:
value = self.data.get(field)

View File

@ -0,0 +1,8 @@
select model_id
from model_workspace_authorization
where case
when authentication_type = 'WHITE_LIST' then
%s = any (workspace_id_list)
else
not %s = any(workspace_id_list)
end

View File

@ -9,7 +9,7 @@ export default {
},
tip: {
createSuccessMessage: '创建模型成功',
createErrorMessage: '基础信息填写错误',
createErrorMessage: '基础信息填写错误',
errorMessage: '变量已存在: ',
emptyMessage1: '请先选择基础信息的模型类型和基础模型',
emptyMessage2: '所选模型不支持参数设置',