feat: 模型接口添加权限参数
Some checks failed
sync2gitee / repo-sync (push) Has been cancelled
Typos Check / Spell Check with Typos (push) Has been cancelled

This commit is contained in:
shaohuzhang1 2024-07-16 16:08:31 +08:00
parent 2f3d282c0d
commit 9b81b89975
2 changed files with 22 additions and 5 deletions

View File

@ -7,12 +7,14 @@
@desc:
"""
import json
import re
import threading
import time
import uuid
from typing import Dict
from django.db.models import QuerySet
from django.core import validators
from django.db.models import QuerySet, Q
from rest_framework import serializers
from application.models import Application
@ -72,7 +74,7 @@ class ModelSerializer(serializers.Serializer):
self.is_valid(raise_exception=True)
user_id = self.data.get('user_id')
name = self.data.get('name')
model_query_set = QuerySet(Model).filter(user_id=user_id)
model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC')))
query_params = {}
if name is not None:
query_params['name__contains'] = name
@ -96,6 +98,11 @@ class ModelSerializer(serializers.Serializer):
model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char("权限"), validators=[
validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"),
message="权限只支持PUBLIC|PRIVATE", code=500)
])
model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息"))
@ -135,6 +142,11 @@ class ModelSerializer(serializers.Serializer):
model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型"))
permission_type = serializers.CharField(required=True, error_messages=ErrMessage.char("权限"), validators=[
validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"),
message="权限只支持PUBLIC|PRIVATE", code=500)
])
model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型"))
credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息"))
@ -165,10 +177,12 @@ class ModelSerializer(serializers.Serializer):
provider = self.data.get('provider')
model_type = self.data.get('model_type')
model_name = self.data.get('model_name')
permission_type = self.data.get('permission_type')
model_credential_str = json.dumps(credential)
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
credential=rsa_long_encrypt(model_credential_str),
provider=provider, model_type=model_type, model_name=model_name)
provider=provider, model_type=model_type, model_name=model_name,
permission_type=permission_type)
model.save()
if status == Status.DOWNLOAD:
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
@ -245,7 +259,7 @@ class ModelSerializer(serializers.Serializer):
model.status = Status.DOWNLOAD
else:
raise e
update_keys = ['credential', 'name', 'model_type', 'model_name']
update_keys = ['credential', 'name', 'model_type', 'model_name', 'permission_type']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'credential':

View File

@ -74,6 +74,8 @@ class ModelCreateApi(ApiMixin):
'provider': openapi.Schema(type=openapi.TYPE_STRING,
title="供应商",
description="供应商"),
'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限",
description="PUBLIC|PRIVATE"),
'model_type': openapi.Schema(type=openapi.TYPE_STRING,
title="供应商",
description="供应商"),
@ -82,7 +84,8 @@ class ModelCreateApi(ApiMixin):
description="供应商"),
'credential': openapi.Schema(type=openapi.TYPE_OBJECT,
title="模型证书信息",
description="模型证书信息")
description="模型证书信息"),
}
)