MaxKB/apps/setting/serializers/provider_serializers.py
2023-12-01 17:30:06 +08:00

202 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
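"""
@project: maxkb
@Author
@file: provider_serializers.py
@date: 2023/11/2 14:01
@desc: Serializers for managing a user's models (list, create, edit, delete)
       and for dispatching method calls to a model provider.
"""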
import json
import uuid
from typing import Dict
from django.db.models import QuerySet
from rest_framework import serializers
from common.exception.app_exception import AppApiException
from common.util.rsa_util import encrypt, decrypt
from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
class ModelSerializer(serializers.Serializer):
    """Serializers for listing, creating, editing and deleting a user's models.

    Credentials are persisted as ``encrypt(json.dumps(credential))`` and are
    only handed back to callers after passing through the provider's
    ``encryption_dict`` (presumably a masking step -- confirm against the
    provider implementation).
    """

    class Query(serializers.Serializer):
        """Filter parameters for listing the models owned by one user."""

        user_id = serializers.UUIDField(required=True)
        name = serializers.CharField(required=False)
        model_type = serializers.CharField(required=False)
        model_name = serializers.CharField(required=False)
        provider = serializers.CharField(required=False)

        def list(self, with_valid):
            """Return the user's models as dicts, narrowed by the optional filters.

            :param with_valid: when truthy, validate the serializer first.
            :return: list of ``ModelSerializer.model_to_dict`` results.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            data = self.data
            filters = {'user_id': data.get('user_id')}
            name = data.get('name')
            if name is not None:
                # The display name is matched as a substring.
                filters['name__contains'] = name
            # The remaining filters are exact matches and only apply when supplied.
            for field in ('model_type', 'model_name', 'provider'):
                value = data.get(field)
                if value is not None:
                    filters[field] = value
            return [ModelSerializer.model_to_dict(model) for model in QuerySet(Model).filter(**filters)]

    class Edit(serializers.Serializer):
        """Payload of a model update; every field is optional."""

        user_id = serializers.CharField(required=False)
        name = serializers.CharField(required=False)
        model_type = serializers.CharField(required=False)
        model_name = serializers.CharField(required=False)
        credential = serializers.DictField(required=False)

        def is_valid(self, model=None, raise_exception=False):
            """Validate the update payload against the stored ``model``.

            NOTE: unlike the DRF base method this returns the credential dict
            (or ``None`` when no credential was submitted) instead of a bool,
            because ``Operate.edit`` consumes that value. Field validation
            always raises, regardless of ``raise_exception``.

            :param model: the stored ``Model`` instance being edited (required).
            :return: the submitted credential with masked values swapped back
                     to the stored originals, or ``None``.
            :raises AppApiException: if ``model`` is missing or the new name is
                     already used by another model of the same user.
            """
            super().is_valid(raise_exception=True)
            if model is None:
                raise AppApiException(500, '模型不存在')
            data = self.data

            # Only check name uniqueness when a name was actually submitted;
            # otherwise the query would match any other model of the user.
            name = data.get('name')
            if name is not None:
                same_name = QuerySet(Model).exclude(id=model.id).filter(user_id=data.get('user_id'), name=name)
                if same_name.exists():
                    raise AppApiException(500, f'模型名称【{name}】已存在')

            # Fields left out of the payload keep their stored value, so validate
            # against what will be in effect after the edit.
            model_type = data.get('model_type')
            if model_type is None:
                model_type = model.model_type
            model_name = data.get('model_name')
            if model_name is None:
                model_name = model.model_name
            credential = data.get('credential')

            model_credential = ModelProvideConstants[model.provider].value.get_model_credential(model_type,
                                                                                                model_name)
            source_model_credential = json.loads(decrypt(model.credential))
            if credential is None:
                effective_credential = source_model_credential
            else:
                # A submitted value equal to the encryption_dict rendering means the
                # client echoed it back unchanged: restore the stored original.
                source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
                for key, masked_value in source_encryption_model_credential.items():
                    if key in credential and credential[key] == masked_value:
                        credential[key] = source_model_credential[key]
                effective_credential = credential

            # Validate the model credential data.
            model_credential.is_valid(
                model_type,
                model_name,
                effective_credential,
                raise_exception=True)
            return credential

    class Create(serializers.Serializer):
        """Payload for creating a new model; every field is required."""

        user_id = serializers.CharField(required=True)
        name = serializers.CharField(required=True)
        provider = serializers.CharField(required=True)
        model_type = serializers.CharField(required=True)
        model_name = serializers.CharField(required=True)
        credential = serializers.DictField(required=True)

        def is_valid(self, *, raise_exception=False):
            """Validate fields, name uniqueness and the provider credential.

            Always raises on failure, regardless of ``raise_exception``.

            :return: ``True`` when validation passes.
            :raises AppApiException: if the user already has a model with this name.
            """
            super().is_valid(raise_exception=True)
            data = self.data
            name = data.get('name')
            if QuerySet(Model).filter(user_id=data.get('user_id'), name=name).exists():
                raise AppApiException(500, f'模型名称【{name}】已存在')
            # Validate the model credential data.
            model_type = data.get('model_type')
            model_name = data.get('model_name')
            model_credential = ModelProvideConstants[data.get('provider')].value.get_model_credential(model_type,
                                                                                                      model_name)
            model_credential.is_valid(
                model_type,
                model_name,
                data.get('credential'),
                raise_exception=True)
            return True

        def insert(self, user_id, with_valid=False):
            """Persist a new model for ``user_id`` and return it as a dict.

            :param user_id: owner of the new model.
            :param with_valid: when truthy, validate the serializer first.
            :return: the stored model, as produced by ``Operate.one``.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            data = self.data
            # The credential is stored encrypted, serialized as a JSON string.
            model = Model(id=uuid.uuid1(), user_id=user_id, name=data.get('name'),
                          credential=encrypt(json.dumps(data.get('credential'))),
                          provider=data.get('provider'), model_type=data.get('model_type'),
                          model_name=data.get('model_name'))
            model.save()
            return ModelSerializer.Operate(data={'id': model.id, 'user_id': user_id}).one(with_valid=True)

    @staticmethod
    def model_to_dict(model: Model):
        """Convert a stored ``Model`` into its API representation.

        The decrypted credential is passed through the provider's
        ``encryption_dict`` before being returned.
        """
        credential = json.loads(decrypt(model.credential))
        model_credential = ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
                                                                                            model.model_name)
        return {'id': str(model.id), 'provider': model.provider, 'name': model.name,
                'model_type': model.model_type,
                'model_name': model.model_name,
                'credential': model_credential.encryption_dict(credential)}

    class Operate(serializers.Serializer):
        """Operations on one existing model, identified by ``id`` and ``user_id``."""

        id = serializers.UUIDField(required=True)
        user_id = serializers.UUIDField(required=True)

        def is_valid(self, *, raise_exception=False):
            """Validate the ids and check that the model belongs to the user.

            Always raises on failure, regardless of ``raise_exception``.

            :return: ``True`` when validation passes.
            :raises AppApiException: if no such model exists for the user.
            """
            super().is_valid(raise_exception=True)
            if not QuerySet(Model).filter(id=self.data.get('id'), user_id=self.data.get('user_id')).exists():
                raise AppApiException(500, '模型不存在')
            return True

        def one(self, with_valid=False):
            """Return the model as a dict (see ``ModelSerializer.model_to_dict``)."""
            if with_valid:
                self.is_valid(raise_exception=True)
            model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
            return ModelSerializer.model_to_dict(model)

        def delete(self, with_valid=True):
            """Delete the model and return ``True``."""
            if with_valid:
                self.is_valid(raise_exception=True)
            # Scope by owner as well, so skipping validation cannot remove
            # another user's model.
            QuerySet(Model).filter(id=self.data.get('id'), user_id=self.data.get('user_id')).delete()
            return True

        def edit(self, instance: Dict, user_id: str, with_valid=True):
            """Apply the non-null fields of ``instance`` to the stored model.

            :param instance: update payload; recognised keys are ``credential``,
                             ``name``, ``model_type`` and ``model_name``.
            :param user_id: id of the user performing the edit.
            :param with_valid: when truthy, validate the serializer first.
            :return: the updated model as a dict.
            :raises AppApiException: if the model does not exist for the user.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            # Scope by owner as well as id (see ``delete``).
            model = QuerySet(Model).filter(id=self.data.get('id'), user_id=self.data.get('user_id')).first()
            if model is None:
                raise AppApiException(500, '不存在的id')
            credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(model=model)
            for update_key in ('credential', 'name', 'model_type', 'model_name'):
                value = instance.get(update_key)
                if value is None:
                    continue
                if update_key == 'credential':
                    # Store the validated credential (masked values already
                    # restored), encrypted as a JSON string.
                    value = encrypt(json.dumps(credential))
                setattr(model, update_key, value)
            model.save()
            return self.one(with_valid=False)
class ProviderSerializer(serializers.Serializer):
    """Dispatch a named method call to a registered model provider."""

    provider = serializers.CharField(required=True)
    method = serializers.CharField(required=True)

    def exec(self, exec_params: Dict[str, object], with_valid=False):
        """Call ``method`` on the selected provider with ``exec_params``.

        :param exec_params: single argument forwarded to the provider method.
        :param with_valid: when truthy, validate the serializer first.
        :return: whatever the provider method returns.
        :raises AppApiException: if the provider is unknown, or the method is
                 private, missing or not callable.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        provider = self.data.get('provider')
        method = self.data.get('method')
        try:
            provider_impl = ModelProvideConstants[provider].value
        except KeyError:
            # Message: "provider does not exist".
            raise AppApiException(500, f'供应商【{provider}】不存在') from None
        # The method name is supplied by the caller of this serializer: never
        # expose private or dunder attributes through this dispatch.
        if not isinstance(method, str) or method.startswith('_'):
            handler = None
        else:
            handler = getattr(provider_impl, method, None)
        if not callable(handler):
            # Message: "unsupported method".
            raise AppApiException(500, f'不支持的方法【{method}】')
        return handler(exec_params)