mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-27 20:42:52 +00:00
refactor: model list
This commit is contained in:
parent
8fc074fecb
commit
75df321783
|
|
@ -12,6 +12,7 @@ from rest_framework import serializers
|
|||
from django.db.models.query_utils import Q
|
||||
from common.config.embedding_config import ModelManage
|
||||
from common.database_model_manage.database_model_manage import DatabaseModelManage
|
||||
from common.db.search import native_search
|
||||
from common.db.sql_execute import select_list
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.utils.common import get_file_content
|
||||
|
|
@ -21,6 +22,8 @@ from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
|||
from models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
from models_provider.models import Model, Status
|
||||
from models_provider.tools import get_model_credential
|
||||
from system_manage.models import WorkspaceUserResourcePermission
|
||||
from users.serializers.user import is_workspace_manage
|
||||
|
||||
|
||||
def get_default_model_params_setting(provider, model_type, model_name):
|
||||
|
|
@ -326,6 +329,7 @@ class ModelSerializer(serializers.Serializer):
|
|||
return ModelModelSerializer(model).data
|
||||
|
||||
class Query(serializers.Serializer):
|
||||
user_id = serializers.CharField(required=True, label=_("User ID"))
|
||||
name = serializers.CharField(required=False, max_length=64, label=_('model name'))
|
||||
model_type = serializers.CharField(required=False, label=_('model type'))
|
||||
model_name = serializers.CharField(required=False, label=_('base model'))
|
||||
|
|
@ -333,17 +337,40 @@ class ModelSerializer(serializers.Serializer):
|
|||
create_user = serializers.CharField(required=False, label=_('create user'))
|
||||
workspace_id = serializers.CharField(required=False, label=_('workspace id'))
|
||||
|
||||
@staticmethod
|
||||
def is_x_pack_ee():
|
||||
workspace_user_role_mapping_model = DatabaseModelManage.get_model("workspace_user_role_mapping")
|
||||
role_permission_mapping_model = DatabaseModelManage.get_model("role_permission_mapping_model")
|
||||
return workspace_user_role_mapping_model is not None and role_permission_mapping_model is not None
|
||||
|
||||
def list(self, workspace_id, with_valid):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
user_id = self.data.get("user_id")
|
||||
workspace_manage = is_workspace_manage(user_id, workspace_id)
|
||||
query_params = self._build_query_params(workspace_id, workspace_manage, user_id)
|
||||
is_x_pack_ee = self.is_x_pack_ee()
|
||||
return native_search(query_params,
|
||||
select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql',
|
||||
'list_model.sql' if workspace_manage else (
|
||||
'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql')
|
||||
)))
|
||||
|
||||
query_params = self._build_query_params(workspace_id)
|
||||
return [self._build_model_data(model) for model in query_params.order_by("-create_time")]
|
||||
def share_list(self, workspace_id, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
user_id = self.data.get("user_id")
|
||||
query_params = self._build_query_params(workspace_id, False, user_id)
|
||||
return [self._build_model_data(model) for model in
|
||||
query_params.get('model_query_set').order_by("-create_time")]
|
||||
|
||||
def model_list(self, workspace_id, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
queryset = self._build_query_params(workspace_id)
|
||||
user_id = self.data.get("user_id")
|
||||
workspace_manage = is_workspace_manage(user_id, workspace_id)
|
||||
queryset = self._build_query_params(workspace_id, workspace_manage, user_id)
|
||||
get_authorized_model = DatabaseModelManage.get_model("get_authorized_model")
|
||||
|
||||
shared_queryset = QuerySet(Model).filter(workspace_id='None')
|
||||
|
|
@ -352,14 +379,20 @@ class ModelSerializer(serializers.Serializer):
|
|||
|
||||
# 构建共享模型和普通模型列表
|
||||
shared_model = [self._build_model_data(model) for model in shared_queryset.order_by("-create_time")]
|
||||
normal_model = [self._build_model_data(model) for model in queryset.order_by("-create_time")]
|
||||
|
||||
is_x_pack_ee = self.is_x_pack_ee()
|
||||
normal_model = native_search(queryset,
|
||||
select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql',
|
||||
'list_model.sql' if workspace_manage else (
|
||||
'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql')
|
||||
)))
|
||||
return {
|
||||
"shared_model": shared_model,
|
||||
"model": normal_model
|
||||
}
|
||||
|
||||
def _build_query_params(self, workspace_id):
|
||||
def _build_query_params(self, workspace_id, workspace_manage: bool, user_id):
|
||||
queryset = QuerySet(Model)
|
||||
if workspace_id:
|
||||
queryset = queryset.filter(workspace_id=workspace_id)
|
||||
|
|
@ -372,7 +405,15 @@ class ModelSerializer(serializers.Serializer):
|
|||
queryset = queryset.filter(user_id=value)
|
||||
else:
|
||||
queryset = queryset.filter(**{field: value})
|
||||
return queryset
|
||||
return {
|
||||
'model_query_set': queryset,
|
||||
'workspace_user_resource_permission_query_set': QuerySet(WorkspaceUserResourcePermission).filter(
|
||||
auth_target_type="MODEL",
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id)} if (
|
||||
not workspace_manage) else {
|
||||
'model_query_set': queryset,
|
||||
}
|
||||
|
||||
def _build_model_data(self, model):
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
SELECT model."id"::text, model."name",
|
||||
model.model_name,
|
||||
model.meta,
|
||||
model.credential,
|
||||
model.model_params_form,
|
||||
model.model_type,
|
||||
model.provider,
|
||||
model.status,
|
||||
model.create_time,
|
||||
model.update_time,
|
||||
model.user_id,
|
||||
"user"."nick_name" as "nick_name",
|
||||
model.workspace_id
|
||||
from model
|
||||
left join "user" on user_id = "user".id
|
||||
${model_query_set}
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
SELECT *
|
||||
FROM (SELECT model."id"::text, model."name",
|
||||
model.model_name,
|
||||
model.meta,
|
||||
model.credential,
|
||||
model.model_params_form,
|
||||
model.model_type,
|
||||
model.provider,
|
||||
model.status,
|
||||
model.create_time,
|
||||
model.update_time,
|
||||
model.user_id,
|
||||
"user"."nick_name" as "nick_name",
|
||||
model.workspace_id
|
||||
from model
|
||||
left join "user" on user_id = "user".id
|
||||
where model."id" in (select target
|
||||
from workspace_user_resource_permission ${workspace_user_resource_permission_query_set}
|
||||
and 'VIEW' = any (permission_list)) ) temp ${model_query_set}
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
SELECT *
|
||||
FROM (SELECT model."id"::text, model."name",
|
||||
model.model_name,
|
||||
model.meta,
|
||||
model.credential,
|
||||
model.model_params_form,
|
||||
model.model_type,
|
||||
model.provider,
|
||||
model.status,
|
||||
model.create_time,
|
||||
model.update_time,
|
||||
model.user_id,
|
||||
"user"."nick_name" as "nick_name",
|
||||
model.workspace_id
|
||||
from model
|
||||
left join "user" on user_id = "user".id
|
||||
where model."id" in (select target
|
||||
from workspace_user_resource_permission ${workspace_user_resource_permission_query_set}
|
||||
and case
|
||||
when auth_type = 'ROLE' then
|
||||
'ROLE' = any (permission_list)
|
||||
and
|
||||
'MODEL:READ' in (select (case
|
||||
when user_role_relation.role_id = any (array['USER'])
|
||||
THEN 'MODEL:READ'
|
||||
else role_permission.permission_id END)
|
||||
from role_permission role_permission
|
||||
right join user_role_relation user_role_relation
|
||||
on user_role_relation.role_id = role_permission.role_id
|
||||
where user_role_relation.user_id = workspace_user_resource_permission.user_id
|
||||
and user_role_relation.workspace_id =
|
||||
workspace_user_resource_permission.workspace_id)
|
||||
|
||||
else
|
||||
'VIEW' = any (permission_list)
|
||||
end) ) temp ${model_query_set}
|
||||
|
||||
|
||||
|
|
@ -1,8 +0,0 @@
|
|||
select model_id
|
||||
from model_workspace_authorization
|
||||
where case
|
||||
when authentication_type = 'WHITE_LIST' then
|
||||
%s = any (workspace_id_list)
|
||||
else
|
||||
not %s = any(workspace_id_list)
|
||||
end
|
||||
|
|
@ -101,8 +101,9 @@ class ModelSetting(APIView):
|
|||
def get(self, request: Request, workspace_id: str):
|
||||
return result.success(
|
||||
ModelSerializer.Query(
|
||||
data={**query_params_to_single_dict(request.query_params)}).list(workspace_id=workspace_id,
|
||||
with_valid=True))
|
||||
data={**query_params_to_single_dict(request.query_params), 'user_id': str(request.user.id)}).list(
|
||||
workspace_id=workspace_id,
|
||||
with_valid=True))
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
|
@ -266,5 +267,6 @@ class ModelList(APIView):
|
|||
def get(self, request: Request, workspace_id: str):
|
||||
return result.success(
|
||||
ModelSerializer.Query(
|
||||
data={**query_params_to_single_dict(request.query_params)}).model_list(workspace_id=workspace_id,
|
||||
with_valid=True))
|
||||
data={**query_params_to_single_dict(request.query_params), 'user_id': str(request.user.id)}).model_list(
|
||||
workspace_id=workspace_id,
|
||||
with_valid=True))
|
||||
|
|
|
|||
Loading…
Reference in New Issue