diff --git a/apps/models_provider/serializers/model_serializer.py b/apps/models_provider/serializers/model_serializer.py index 8675870dd..28524e54a 100644 --- a/apps/models_provider/serializers/model_serializer.py +++ b/apps/models_provider/serializers/model_serializer.py @@ -12,6 +12,7 @@ from rest_framework import serializers from django.db.models.query_utils import Q from common.config.embedding_config import ModelManage from common.database_model_manage.database_model_manage import DatabaseModelManage +from common.db.search import native_search from common.db.sql_execute import select_list from common.exception.app_exception import AppApiException from common.utils.common import get_file_content @@ -21,6 +22,8 @@ from models_provider.base_model_provider import ValidCode, DownModelChunkStatus from models_provider.constants.model_provider_constants import ModelProvideConstants from models_provider.models import Model, Status from models_provider.tools import get_model_credential +from system_manage.models import WorkspaceUserResourcePermission +from users.serializers.user import is_workspace_manage def get_default_model_params_setting(provider, model_type, model_name): @@ -326,6 +329,7 @@ class ModelSerializer(serializers.Serializer): return ModelModelSerializer(model).data class Query(serializers.Serializer): + user_id = serializers.CharField(required=True, label=_("User ID")) name = serializers.CharField(required=False, max_length=64, label=_('model name')) model_type = serializers.CharField(required=False, label=_('model type')) model_name = serializers.CharField(required=False, label=_('base model')) @@ -333,17 +337,40 @@ class ModelSerializer(serializers.Serializer): create_user = serializers.CharField(required=False, label=_('create user')) workspace_id = serializers.CharField(required=False, label=_('workspace id')) + @staticmethod + def is_x_pack_ee(): + workspace_user_role_mapping_model = DatabaseModelManage.get_model("workspace_user_role_mapping") + role_permission_mapping_model = DatabaseModelManage.get_model("role_permission_mapping_model") + return workspace_user_role_mapping_model is not None and role_permission_mapping_model is not None + def list(self, workspace_id, with_valid): if with_valid: self.is_valid(raise_exception=True) + user_id = self.data.get("user_id") + workspace_manage = is_workspace_manage(user_id, workspace_id) + query_params = self._build_query_params(workspace_id, workspace_manage, user_id) + is_x_pack_ee = self.is_x_pack_ee() + return native_search(query_params, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql', + 'list_model.sql' if workspace_manage else ( + 'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql') + ))) - query_params = self._build_query_params(workspace_id) - return [self._build_model_data(model) for model in query_params.order_by("-create_time")] + def share_list(self, workspace_id, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + user_id = self.data.get("user_id") + query_params = self._build_query_params(workspace_id, False, user_id) + return [self._build_model_data(model) for model in + query_params.get('model_query_set').order_by("-create_time")] def model_list(self, workspace_id, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - queryset = self._build_query_params(workspace_id) + user_id = self.data.get("user_id") + workspace_manage = is_workspace_manage(user_id, workspace_id) + queryset = self._build_query_params(workspace_id, workspace_manage, user_id) get_authorized_model = DatabaseModelManage.get_model("get_authorized_model") shared_queryset = QuerySet(Model).filter(workspace_id='None') @@ -352,14 +379,20 @@ class ModelSerializer(serializers.Serializer): # 构建共享模型和普通模型列表 shared_model = [self._build_model_data(model) for model in shared_queryset.order_by("-create_time")] - normal_model = [self._build_model_data(model) for model in queryset.order_by("-create_time")] + is_x_pack_ee = self.is_x_pack_ee() + normal_model = native_search(queryset, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql', + 'list_model.sql' if workspace_manage else ( + 'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql') + ))) return { "shared_model": shared_model, "model": normal_model } - def _build_query_params(self, workspace_id): + def _build_query_params(self, workspace_id, workspace_manage: bool, user_id): queryset = QuerySet(Model) if workspace_id: queryset = queryset.filter(workspace_id=workspace_id) @@ -372,7 +405,15 @@ class ModelSerializer(serializers.Serializer): queryset = queryset.filter(user_id=value) else: queryset = queryset.filter(**{field: value}) - return queryset + return { + 'model_query_set': queryset, + 'workspace_user_resource_permission_query_set': QuerySet(WorkspaceUserResourcePermission).filter( + auth_target_type="MODEL", + workspace_id=workspace_id, + user_id=user_id)} if ( + not workspace_manage) else { + 'model_query_set': queryset, + } def _build_model_data(self, model): return { diff --git a/apps/models_provider/sql/list_model.sql b/apps/models_provider/sql/list_model.sql new file mode 100644 index 000000000..83dcbe1bf --- /dev/null +++ b/apps/models_provider/sql/list_model.sql @@ -0,0 +1,16 @@ +SELECT model."id"::text, model."name", + model.model_name, + model.meta, + model.credential, + model.model_params_form, + model.model_type, + model.provider, + model.status, + model.create_time, + model.update_time, + model.user_id, + "user"."nick_name" as "nick_name", + model.workspace_id +from model + left join "user" on user_id = "user".id + ${model_query_set} \ No newline at end of file diff --git a/apps/models_provider/sql/list_model_user.sql b/apps/models_provider/sql/list_model_user.sql new file mode 100644 index 000000000..71c6cca7d --- /dev/null +++ b/apps/models_provider/sql/list_model_user.sql @@ -0,0 +1,19 @@ +SELECT * +FROM (SELECT model."id"::text, model."name", + model.model_name, + model.meta, + model.credential, + model.model_params_form, + model.model_type, + model.provider, + model.status, + model.create_time, + model.update_time, + model.user_id, + "user"."nick_name" as "nick_name", + model.workspace_id + from model + left join "user" on user_id = "user".id + where model."id" in (select target + from workspace_user_resource_permission ${workspace_user_resource_permission_query_set} + and 'VIEW' = any (permission_list)) ) temp ${model_query_set} diff --git a/apps/models_provider/sql/list_model_user_ee.sql b/apps/models_provider/sql/list_model_user_ee.sql new file mode 100644 index 000000000..1846e6470 --- /dev/null +++ b/apps/models_provider/sql/list_model_user_ee.sql @@ -0,0 +1,38 @@ +SELECT * +FROM (SELECT model."id"::text, model."name", + model.model_name, + model.meta, + model.credential, + model.model_params_form, + model.model_type, + model.provider, + model.status, + model.create_time, + model.update_time, + model.user_id, + "user"."nick_name" as "nick_name", + model.workspace_id + from model + left join "user" on user_id = "user".id + where model."id" in (select target + from workspace_user_resource_permission ${workspace_user_resource_permission_query_set} + and case + when auth_type = 'ROLE' then + 'ROLE' = any (permission_list) + and + 'MODEL:READ' in (select (case + when user_role_relation.role_id = any (array['USER']) + THEN 'MODEL:READ' + else role_permission.permission_id END) + from role_permission role_permission + right join user_role_relation user_role_relation + on user_role_relation.role_id = role_permission.role_id + where user_role_relation.user_id = workspace_user_resource_permission.user_id + and user_role_relation.workspace_id = + workspace_user_resource_permission.workspace_id) + + else + 'VIEW' = any (permission_list) + end) ) temp ${model_query_set} + + diff --git a/apps/models_provider/sql/list_share_authorized_model.sql b/apps/models_provider/sql/list_share_authorized_model.sql deleted file mode 100644 index 7fca84973..000000000 --- a/apps/models_provider/sql/list_share_authorized_model.sql +++ /dev/null @@ -1,8 +0,0 @@ -select model_id -from model_workspace_authorization -where case - when authentication_type = 'WHITE_LIST' then - %s = any (workspace_id_list) - else - not %s = any(workspace_id_list) - end \ No newline at end of file diff --git a/apps/models_provider/views/model.py b/apps/models_provider/views/model.py index dbfd95c69..8d746c484 100644 --- a/apps/models_provider/views/model.py +++ b/apps/models_provider/views/model.py @@ -101,8 +101,9 @@ class ModelSetting(APIView): def get(self, request: Request, workspace_id: str): return result.success( ModelSerializer.Query( - data={**query_params_to_single_dict(request.query_params)}).list(workspace_id=workspace_id, - with_valid=True)) + data={**query_params_to_single_dict(request.query_params), 'user_id': str(request.user.id)}).list( + workspace_id=workspace_id, + with_valid=True)) class Operate(APIView): authentication_classes = [TokenAuth] @@ -266,5 +267,6 @@ class ModelList(APIView): def get(self, request: Request, workspace_id: str): return result.success( ModelSerializer.Query( - data={**query_params_to_single_dict(request.query_params)}).model_list(workspace_id=workspace_id, - with_valid=True)) + data={**query_params_to_single_dict(request.query_params), 'user_id': str(request.user.id)}).model_list( + workspace_id=workspace_id, + with_valid=True))