diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index b5d85e8a3..c155ca1cd 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -7,12 +7,14 @@ @desc: """ import json +import re import threading import time import uuid from typing import Dict -from django.db.models import QuerySet +from django.core import validators +from django.db.models import QuerySet, Q from rest_framework import serializers from application.models import Application @@ -72,7 +74,7 @@ class ModelSerializer(serializers.Serializer): self.is_valid(raise_exception=True) user_id = self.data.get('user_id') name = self.data.get('name') - model_query_set = QuerySet(Model).filter(user_id=user_id) + model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC'))) query_params = {} if name is not None: query_params['name__contains'] = name @@ -96,6 +98,11 @@ class ModelSerializer(serializers.Serializer): model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型")) + permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char("权限"), validators=[ + validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"), + message="权限只支持PUBLIC|PRIVATE", code=500) + ]) + model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型")) credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息")) @@ -135,6 +142,11 @@ class ModelSerializer(serializers.Serializer): model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型")) + permission_type = serializers.CharField(required=True, error_messages=ErrMessage.char("权限"), validators=[ + validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"), + message="权限只支持PUBLIC|PRIVATE", code=500) + ]) + model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型")) credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息")) @@ -165,10 +177,12 @@ class ModelSerializer(serializers.Serializer): provider = self.data.get('provider') model_type = self.data.get('model_type') model_name = self.data.get('model_name') + permission_type = self.data.get('permission_type') model_credential_str = json.dumps(credential) model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, credential=rsa_long_encrypt(model_credential_str), - provider=provider, model_type=model_type, model_name=model_name) + provider=provider, model_type=model_type, model_name=model_name, + permission_type=permission_type) model.save() if status == Status.DOWNLOAD: thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential)) @@ -245,7 +259,7 @@ class ModelSerializer(serializers.Serializer): model.status = Status.DOWNLOAD else: raise e - update_keys = ['credential', 'name', 'model_type', 'model_name'] + update_keys = ['credential', 'name', 'model_type', 'model_name', 'permission_type'] for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: if update_key == 'credential': diff --git a/apps/setting/swagger_api/provide_api.py b/apps/setting/swagger_api/provide_api.py index f68ac5be4..7544fdf25 100644 --- a/apps/setting/swagger_api/provide_api.py +++ b/apps/setting/swagger_api/provide_api.py @@ -74,6 +74,8 @@ class ModelCreateApi(ApiMixin): 'provider': openapi.Schema(type=openapi.TYPE_STRING, title="供应商", description="供应商"), + 'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限", + description="PUBLIC|PRIVATE"), 'model_type': openapi.Schema(type=openapi.TYPE_STRING, title="供应商", description="供应商"), @@ -82,7 +84,8 @@ class ModelCreateApi(ApiMixin): description="供应商"), 'credential': openapi.Schema(type=openapi.TYPE_OBJECT, title="模型证书信息", - description="模型证书信息") + description="模型证书信息"), + } )