From 9b81b89975cad700395401c235596cd06bfaeaf6 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Tue, 16 Jul 2024 16:08:31 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=A8=A1=E5=9E=8B=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=9D=83=E9=99=90=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/provider_serializers.py | 22 +++++++++++++++---- apps/setting/swagger_api/provide_api.py | 5 ++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index b5d85e8a3..c155ca1cd 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -7,12 +7,14 @@ @desc: """ import json +import re import threading import time import uuid from typing import Dict -from django.db.models import QuerySet +from django.core import validators +from django.db.models import QuerySet, Q from rest_framework import serializers from application.models import Application @@ -72,7 +74,7 @@ class ModelSerializer(serializers.Serializer): self.is_valid(raise_exception=True) user_id = self.data.get('user_id') name = self.data.get('name') - model_query_set = QuerySet(Model).filter(user_id=user_id) + model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC'))) query_params = {} if name is not None: query_params['name__contains'] = name @@ -96,6 +98,11 @@ class ModelSerializer(serializers.Serializer): model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型")) + permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char("权限"), validators=[ + validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"), + message="权限只支持PUBLIC|PRIVATE", code=500) + ]) + model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型")) credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息")) @@ -135,6 +142,11 @@ class ModelSerializer(serializers.Serializer): model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型")) + permission_type = serializers.CharField(required=True, error_messages=ErrMessage.char("权限"), validators=[ + validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"), + message="权限只支持PUBLIC|PRIVATE", code=500) + ]) + model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型")) credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息")) @@ -165,10 +177,12 @@ class ModelSerializer(serializers.Serializer): provider = self.data.get('provider') model_type = self.data.get('model_type') model_name = self.data.get('model_name') + permission_type = self.data.get('permission_type') model_credential_str = json.dumps(credential) model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, credential=rsa_long_encrypt(model_credential_str), - provider=provider, model_type=model_type, model_name=model_name) + provider=provider, model_type=model_type, model_name=model_name, + permission_type=permission_type) model.save() if status == Status.DOWNLOAD: thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential)) @@ -245,7 +259,7 @@ class ModelSerializer(serializers.Serializer): model.status = Status.DOWNLOAD else: raise e - update_keys = ['credential', 'name', 'model_type', 'model_name'] + update_keys = ['credential', 'name', 'model_type', 'model_name', 'permission_type'] for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: if update_key == 'credential': diff --git a/apps/setting/swagger_api/provide_api.py b/apps/setting/swagger_api/provide_api.py index f68ac5be4..7544fdf25 100644 --- a/apps/setting/swagger_api/provide_api.py +++ b/apps/setting/swagger_api/provide_api.py @@ -74,6 +74,8 @@ class ModelCreateApi(ApiMixin): 'provider': openapi.Schema(type=openapi.TYPE_STRING, title="供应商", description="供应商"), + 'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限", + description="PUBLIC|PRIVATE"), 'model_type': openapi.Schema(type=openapi.TYPE_STRING, title="供应商", description="供应商"), @@ -82,7 +84,8 @@ class ModelCreateApi(ApiMixin): description="供应商"), 'credential': openapi.Schema(type=openapi.TYPE_OBJECT, title="模型证书信息", - description="模型证书信息") + description="模型证书信息"), + } )