refactor: model list

This commit is contained in:
wxg0103 2025-06-26 19:07:59 +08:00
parent 8fc074fecb
commit 75df321783
6 changed files with 126 additions and 18 deletions

View File

@ -12,6 +12,7 @@ from rest_framework import serializers
from django.db.models.query_utils import Q
from common.config.embedding_config import ModelManage
from common.database_model_manage.database_model_manage import DatabaseModelManage
from common.db.search import native_search
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException
from common.utils.common import get_file_content
@ -21,6 +22,8 @@ from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from models_provider.constants.model_provider_constants import ModelProvideConstants
from models_provider.models import Model, Status
from models_provider.tools import get_model_credential
from system_manage.models import WorkspaceUserResourcePermission
from users.serializers.user import is_workspace_manage
def get_default_model_params_setting(provider, model_type, model_name):
@ -326,6 +329,7 @@ class ModelSerializer(serializers.Serializer):
return ModelModelSerializer(model).data
class Query(serializers.Serializer):
user_id = serializers.CharField(required=True, label=_("User ID"))
name = serializers.CharField(required=False, max_length=64, label=_('model name'))
model_type = serializers.CharField(required=False, label=_('model type'))
model_name = serializers.CharField(required=False, label=_('base model'))
@ -333,17 +337,40 @@ class ModelSerializer(serializers.Serializer):
create_user = serializers.CharField(required=False, label=_('create user'))
workspace_id = serializers.CharField(required=False, label=_('workspace id'))
@staticmethod
def is_x_pack_ee():
workspace_user_role_mapping_model = DatabaseModelManage.get_model("workspace_user_role_mapping")
role_permission_mapping_model = DatabaseModelManage.get_model("role_permission_mapping_model")
return workspace_user_role_mapping_model is not None and role_permission_mapping_model is not None
def list(self, workspace_id, with_valid):
if with_valid:
self.is_valid(raise_exception=True)
user_id = self.data.get("user_id")
workspace_manage = is_workspace_manage(user_id, workspace_id)
query_params = self._build_query_params(workspace_id, workspace_manage, user_id)
is_x_pack_ee = self.is_x_pack_ee()
return native_search(query_params,
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql',
'list_model.sql' if workspace_manage else (
'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql')
)))
query_params = self._build_query_params(workspace_id)
return [self._build_model_data(model) for model in query_params.order_by("-create_time")]
def share_list(self, workspace_id, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
user_id = self.data.get("user_id")
query_params = self._build_query_params(workspace_id, False, user_id)
return [self._build_model_data(model) for model in
query_params.get('model_query_set').order_by("-create_time")]
def model_list(self, workspace_id, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
queryset = self._build_query_params(workspace_id)
user_id = self.data.get("user_id")
workspace_manage = is_workspace_manage(user_id, workspace_id)
queryset = self._build_query_params(workspace_id, workspace_manage, user_id)
get_authorized_model = DatabaseModelManage.get_model("get_authorized_model")
shared_queryset = QuerySet(Model).filter(workspace_id='None')
@ -352,14 +379,20 @@ class ModelSerializer(serializers.Serializer):
# 构建共享模型和普通模型列表
shared_model = [self._build_model_data(model) for model in shared_queryset.order_by("-create_time")]
normal_model = [self._build_model_data(model) for model in queryset.order_by("-create_time")]
is_x_pack_ee = self.is_x_pack_ee()
normal_model = native_search(queryset,
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql',
'list_model.sql' if workspace_manage else (
'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql')
)))
return {
"shared_model": shared_model,
"model": normal_model
}
def _build_query_params(self, workspace_id):
def _build_query_params(self, workspace_id, workspace_manage: bool, user_id):
queryset = QuerySet(Model)
if workspace_id:
queryset = queryset.filter(workspace_id=workspace_id)
@ -372,7 +405,15 @@ class ModelSerializer(serializers.Serializer):
queryset = queryset.filter(user_id=value)
else:
queryset = queryset.filter(**{field: value})
return queryset
return {
'model_query_set': queryset,
'workspace_user_resource_permission_query_set': QuerySet(WorkspaceUserResourcePermission).filter(
auth_target_type="MODEL",
workspace_id=workspace_id,
user_id=user_id)} if (
not workspace_manage) else {
'model_query_set': queryset,
}
def _build_model_data(self, model):
return {

View File

@ -0,0 +1,16 @@
SELECT model."id"::text, model."name",
model.model_name,
model.meta,
model.credential,
model.model_params_form,
model.model_type,
model.provider,
model.status,
model.create_time,
model.update_time,
model.user_id,
"user"."nick_name" as "nick_name",
model.workspace_id
from model
left join "user" on user_id = "user".id
${model_query_set}

View File

@ -0,0 +1,19 @@
SELECT *
FROM (SELECT model."id"::text, model."name",
model.model_name,
model.meta,
model.credential,
model.model_params_form,
model.model_type,
model.provider,
model.status,
model.create_time,
model.update_time,
model.user_id,
"user"."nick_name" as "nick_name",
model.workspace_id
from model
left join "user" on user_id = "user".id
where model."id" in (select target
from workspace_user_resource_permission ${workspace_user_resource_permission_query_set}
and 'VIEW' = any (permission_list)) ) temp ${model_query_set}

View File

@ -0,0 +1,38 @@
SELECT *
FROM (SELECT model."id"::text, model."name",
model.model_name,
model.meta,
model.credential,
model.model_params_form,
model.model_type,
model.provider,
model.status,
model.create_time,
model.update_time,
model.user_id,
"user"."nick_name" as "nick_name",
model.workspace_id
from model
left join "user" on user_id = "user".id
where model."id" in (select target
from workspace_user_resource_permission ${workspace_user_resource_permission_query_set}
and case
when auth_type = 'ROLE' then
'ROLE' = any (permission_list)
and
'MODEL:READ' in (select (case
when user_role_relation.role_id = any (array['USER'])
THEN 'MODEL:READ'
else role_permission.permission_id END)
from role_permission role_permission
right join user_role_relation user_role_relation
on user_role_relation.role_id = role_permission.role_id
where user_role_relation.user_id = workspace_user_resource_permission.user_id
and user_role_relation.workspace_id =
workspace_user_resource_permission.workspace_id)
else
'VIEW' = any (permission_list)
end) ) temp ${model_query_set}

View File

@ -1,8 +0,0 @@
select model_id
from model_workspace_authorization
where case
when authentication_type = 'WHITE_LIST' then
%s = any (workspace_id_list)
else
not %s = any(workspace_id_list)
end

View File

@ -101,8 +101,9 @@ class ModelSetting(APIView):
def get(self, request: Request, workspace_id: str):
return result.success(
ModelSerializer.Query(
data={**query_params_to_single_dict(request.query_params)}).list(workspace_id=workspace_id,
with_valid=True))
data={**query_params_to_single_dict(request.query_params), 'user_id': str(request.user.id)}).list(
workspace_id=workspace_id,
with_valid=True))
class Operate(APIView):
authentication_classes = [TokenAuth]
@ -266,5 +267,6 @@ class ModelList(APIView):
def get(self, request: Request, workspace_id: str):
return result.success(
ModelSerializer.Query(
data={**query_params_to_single_dict(request.query_params)}).model_list(workspace_id=workspace_id,
with_valid=True))
data={**query_params_to_single_dict(request.query_params), 'user_id': str(request.user.id)}).model_list(
workspace_id=workspace_id,
with_valid=True))