mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 18:22:46 +00:00
refactor: model
This commit is contained in:
parent
d49f448a5f
commit
e8f80094ce
|
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Dict
|
||||
|
|
@ -11,8 +12,11 @@ from rest_framework import serializers
|
|||
from django.db.models.query_utils import Q
|
||||
from common.config.embedding_config import ModelManage
|
||||
from common.database_model_manage.database_model_manage import DatabaseModelManage
|
||||
from common.db.sql_execute import select_list
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.utils.common import get_file_content
|
||||
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
|
||||
from maxkb.conf import PROJECT_DIR
|
||||
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
||||
from models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
from models_provider.models import Model, Status
|
||||
|
|
@ -412,27 +416,13 @@ class ModelSerializer(serializers.Serializer):
|
|||
return True
|
||||
|
||||
|
||||
def get_authorized_tool(tool_query_set, workspace_id, model_workspace_authorization):
|
||||
# 对所有工作空间拉黑的工具
|
||||
non_auths = QuerySet(model_workspace_authorization).filter(
|
||||
Q(workspace_id='None') & Q(authentication_type='WHITE_LIST')
|
||||
).values_list('model_id', flat=True)
|
||||
# 授权给所有工作空间的工具
|
||||
all_auths = QuerySet(model_workspace_authorization).filter(
|
||||
Q(workspace_id='None') & Q(authentication_type='BLACK_LIST')
|
||||
).values_list('model_id', flat=True)
|
||||
# 查询白名单授权的工具
|
||||
white_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
|
||||
workspace_id=workspace_id, authentication_type='WHITE_LIST'
|
||||
).values_list('model_id', flat=True)
|
||||
# 查询黑名单授权的工具
|
||||
black_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
|
||||
workspace_id=workspace_id, authentication_type='BLACK_LIST'
|
||||
).values_list('model_id', flat=True)
|
||||
def get_authorized_tool(tool_query_set, workspace_id):
|
||||
model_id_list = select_list(get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql',
|
||||
'list_share_authorized_model.sql'
|
||||
)), [workspace_id, workspace_id])
|
||||
tool_query_set = tool_query_set.filter(
|
||||
id__in=list(white_authorized_tool_ids) + list(all_auths)
|
||||
).exclude(
|
||||
id__in=list(black_authorized_tool_ids) + list(non_auths)
|
||||
id__in=[k.get('model_id') for k in model_id_list]
|
||||
)
|
||||
return tool_query_set
|
||||
|
||||
|
|
@ -471,8 +461,7 @@ class WorkspaceSharedModelSerializer(serializers.Serializer):
|
|||
if workspace_id:
|
||||
model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization")
|
||||
if model_workspace_authorization is not None:
|
||||
queryset = get_authorized_tool(queryset, workspace_id,
|
||||
model_workspace_authorization=model_workspace_authorization)
|
||||
queryset = get_authorized_tool(queryset, workspace_id)
|
||||
|
||||
for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']:
|
||||
value = self.data.get(field)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
select model_id
|
||||
from model_workspace_authorization
|
||||
where case
|
||||
when authentication_type = 'WHITE_LIST' then
|
||||
%s = any (workspace_id_list)
|
||||
else
|
||||
not %s = any(workspace_id_list)
|
||||
end
|
||||
|
|
@ -9,7 +9,7 @@ export default {
|
|||
},
|
||||
tip: {
|
||||
createSuccessMessage: '创建模型成功',
|
||||
createErrorMessage: '基础信息有填写错误',
|
||||
createErrorMessage: '基础信息填写错误',
|
||||
errorMessage: '变量已存在: ',
|
||||
emptyMessage1: '请先选择基础信息的模型类型和基础模型',
|
||||
emptyMessage2: '所选模型不支持参数设置',
|
||||
|
|
|
|||
Loading…
Reference in New Issue