mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-29 16:12:55 +00:00
182 lines
7.7 KiB
Python
182 lines
7.7 KiB
Python
# -*- coding: utf-8 -*-
|
|
import json
|
|
import threading
|
|
import time
|
|
from typing import Dict
|
|
|
|
import uuid_utils.compat as uuid
|
|
from django.db.models import QuerySet
|
|
from django.utils.translation import gettext_lazy as _
|
|
from rest_framework import serializers
|
|
|
|
from common.exception.app_exception import AppApiException
|
|
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
|
|
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
|
from models_provider.constants.model_provider_constants import ModelProvideConstants
|
|
from models_provider.models import Model, Status
|
|
|
|
|
|
class ModelModelSerializer(serializers.ModelSerializer):
    """Model-backed serializer that renders a ``Model`` row using the fields listed in ``Meta``."""

    class Meta:
        # The serializer is bound to the ``Model`` ORM class; only the
        # columns named below are included in the serialized output.
        model = Model
        fields = [
            'id',
            'name',
            'status',
            'model_type',
            'model_name',
            'user',
            'provider',
            'credential',
            'meta',
            'model_params_form',
            'workspace_id',
        ]
|
|
|
|
|
|
class ModelCreateRequest(serializers.Serializer):
    """Field set describing a model to create.

    Declares the same fields as ``ModelSerializer.Create`` except for
    ``user_id`` and ``workspace_id``.
    """

    # Display name of the model entry; limited to 64 characters.
    name = serializers.CharField(label=_("model name"), max_length=64, required=True)
    # Key used to look the provider up in ``ModelProvideConstants``.
    provider = serializers.CharField(label=_("provider"), required=True)
    model_type = serializers.CharField(label=_("model type"), required=True)
    # NOTE(review): shares the "model name" label with ``name`` above.
    model_name = serializers.CharField(label=_("model name"), required=True)
    # Optional; falls back to a fresh empty list when omitted.
    model_params_form = serializers.ListField(
        label=_("parameter configuration"),
        default=list,
        required=False,
    )
    credential = serializers.DictField(label=_("certification information"), required=True)
|
|
|
|
|
|
class ModelPullManage:
    """Drives a provider-side model download and mirrors its progress onto the ``Model`` row."""

    @staticmethod
    def pull(model: Model, credential: Dict):
        """Download ``model`` through its provider and persist progress and the final status.

        Progress chunks are keyed by their ``digest`` so a later chunk with the
        same digest replaces the earlier one. At most once every 5 seconds the
        accumulated chunks are written to the row's ``meta``; at that moment the
        row is also re-read and, if its status is ``Status.PAUSE_DOWNLOAD``, the
        method returns immediately without writing anything further.

        When the provider stream is exhausted, the row is marked
        ``Status.SUCCESS`` if any chunk reported success, otherwise
        ``Status.ERROR``; the digest of the last error chunk (if any) is stored
        as ``meta["message"]``.

        :param model: the row being downloaded; ``id``, ``provider``,
            ``model_type`` and ``model_name`` are read from it.
        :param credential: credential dict forwarded to the provider's ``down_model``.
        :return: ``None``. Any exception is caught and recorded on the row as
            ``Status.ERROR`` with the exception text as the message.
        """
        try:
            provider = ModelProvideConstants[model.provider].value
            stream = provider.down_model(model.model_type, model.model_name, credential)
            chunks_by_digest = {}
            flushed_at = time.time()

            for chunk in stream:
                chunks_by_digest[chunk.digest] = chunk.to_dict()

                # Throttle database traffic: only touch the row every 5 seconds.
                if time.time() - flushed_at <= 5:
                    continue

                latest = QuerySet(Model).filter(id=model.id).first()
                if latest and latest.status == Status.PAUSE_DOWNLOAD:
                    # The download was paused elsewhere; stop without overwriting the row.
                    return

                progress = {"down_model_chunk": list(chunks_by_digest.values())}
                QuerySet(Model).filter(id=model.id).update(meta=progress)
                flushed_at = time.time()

            # Derive the final outcome from the last known state of every chunk.
            final_status = Status.ERROR
            error_digest = ""
            for info in chunks_by_digest.values():
                chunk_status = info.get('status')
                if chunk_status == DownModelChunkStatus.success.value:
                    final_status = Status.SUCCESS
                elif chunk_status == DownModelChunkStatus.error.value:
                    error_digest = info.get("digest")

            QuerySet(Model).filter(id=model.id).update(
                meta={"down_model_chunk": [], "message": error_digest},
                status=final_status,
            )
        except Exception as e:
            # Nothing is re-raised from here; the failure is surfaced
            # by recording it on the row instead.
            QuerySet(Model).filter(id=model.id).update(
                meta={"down_model_chunk": [], "message": str(e)},
                status=Status.ERROR,
            )
|
|
|
|
|
|
class ModelSerializer(serializers.Serializer):
    """Groups the model serializers: ``Operate`` (look up one model) and ``Create`` (insert one)."""

    @staticmethod
    def model_to_dict(model: Model):
        """Convert a ``Model`` row into a plain dict.

        The stored credential is decrypted with ``rsa_long_decrypt``, parsed as
        JSON, and passed through ``encryption_dict`` of the provider's credential
        object before being returned (presumably to mask secret values — confirm
        against the provider implementation).

        :param model: the row to convert.
        :return: dict with ``id`` (as ``str``), ``provider``, ``name``,
            ``model_type``, ``model_name``, ``status``, ``meta``, ``credential``
            and ``workspace_id``.
        """
        credential = json.loads(rsa_long_decrypt(model.credential))
        model_credential = ModelProvideConstants[model.provider].value.get_model_credential(
            model.model_type, model.model_name
        )
        return {
            'id': str(model.id),
            'provider': model.provider,
            'name': model.name,
            'model_type': model.model_type,
            'model_name': model.model_name,
            'status': model.status,
            'meta': model.meta,
            'credential': model_credential.encryption_dict(credential),
            'workspace_id': model.workspace_id,
        }

    class Operate(serializers.Serializer):
        """Identifies one model by ``id`` and owning ``user_id``."""

        # Label msgid is Chinese for "model id".
        id = serializers.UUIDField(required=True, label=_("模型id"))
        user_id = serializers.UUIDField(required=True, label=_("user id"))

        def is_valid(self, *, raise_exception=False):
            """Validate the fields and check that the referenced model exists.

            Validation failures always raise, regardless of ``raise_exception``
            (field validation is invoked with ``raise_exception=True``).

            :return: ``True`` when validation passes, matching the DRF
                ``is_valid`` contract.
            :raises AppApiException: code 500 when no model matches
                ``id`` / ``user_id``.
            """
            super().is_valid(raise_exception=True)
            found = QuerySet(Model).filter(
                id=self.data.get("id"), user_id=self.data.get("user_id")
            ).exists()
            if not found:
                # Message msgid is Chinese for "model does not exist".
                raise AppApiException(500, _('模型不存在'))
            return True

        def one(self, with_valid=False):
            """Return the matching model as a dict (see ``ModelSerializer.model_to_dict``).

            :param with_valid: run ``is_valid`` first when true. When false and
                no row matches, ``QuerySet.get`` raises ``Model.DoesNotExist``.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            model = QuerySet(Model).get(
                id=self.data.get('id'), user_id=self.data.get('user_id')
            )
            return ModelSerializer.model_to_dict(model)

    class Create(serializers.Serializer):
        """Validates and inserts a new ``Model`` row."""

        user_id = serializers.UUIDField(required=True, label=_('user id'))
        name = serializers.CharField(required=True, max_length=64, label=_("model name"))
        provider = serializers.CharField(required=True, label=_("provider"))
        model_type = serializers.CharField(required=True, label=_("model type"))
        model_name = serializers.CharField(required=True, label=_("model name"))
        model_params_form = serializers.ListField(required=False, default=list, label=_("parameter configuration"))
        credential = serializers.DictField(required=True, label=_("certification information"))
        workspace_id = serializers.CharField(required=False, label=_("workspace id"), max_length=128)

        def is_valid(self, *, raise_exception=False):
            """Validate the fields, reject duplicate names and verify the credential.

            Validation failures always raise, regardless of ``raise_exception``.

            :return: ``True`` when validation passes, matching the DRF
                ``is_valid`` contract.
            :raises AppApiException: code 500 when a model with the same
                ``user_id``, ``name`` and ``workspace_id`` already exists; the
                provider's ``is_valid_credential`` is called with
                ``raise_exception=True`` and may raise as well.
            """
            super().is_valid(raise_exception=True)
            # NOTE(review): the uniqueness check uses the serializer's own
            # ``workspace_id`` field, whereas ``insert`` stores its
            # ``workspace_id`` argument — confirm callers supply the same value.
            if QuerySet(Model).filter(
                user_id=self.data.get('user_id'),
                name=self.data.get('name'),
                workspace_id=self.data.get('workspace_id')
            ).exists():
                raise AppApiException(
                    500,
                    _('Model name【{model_name}】already exists').format(model_name=self.data.get("name"))
                )
            # Collapse the parameter form into {field: default_value} for the credential check.
            default_params = {item['field']: item['default_value'] for item in self.data.get('model_params_form')}
            ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(
                self.data.get('model_type'),
                self.data.get('model_name'),
                self.data.get('credential'),
                default_params,
                raise_exception=True
            )
            return True

        def insert(self, workspace_id, with_valid=True):
            """Persist a new model and, when required, start downloading it.

            If validation fails with ``ValidCode.model_not_fount`` the row is
            still created, with status ``Status.DOWNLOAD``, and
            ``ModelPullManage.pull`` is started on a background thread. Any
            other validation error propagates.

            :param workspace_id: workspace id stored on the new row.
            :param with_valid: run ``is_valid`` before inserting.
            :return: ``ModelModelSerializer(model).data`` for the saved row.
            :raises AppApiException: on validation errors other than
                ``model_not_fount``, or with code 500 when saving fails.
            """
            status = Status.SUCCESS
            if with_valid:
                try:
                    self.is_valid(raise_exception=True)
                except AppApiException as e:
                    if e.code != ValidCode.model_not_fount:
                        raise
                    # The model is not available at the provider yet: save it and pull it below.
                    status = Status.DOWNLOAD

            credential = self.data.get('credential')
            model = Model(
                id=uuid.uuid1(),
                status=status,
                user_id=self.data.get('user_id'),
                name=self.data.get('name'),
                # The credential is stored encrypted; ``model_to_dict`` reverses this.
                credential=rsa_long_encrypt(json.dumps(credential)),
                provider=self.data.get('provider'),
                model_type=self.data.get('model_type'),
                model_name=self.data.get('model_name'),
                model_params_form=self.data.get('model_params_form'),
                workspace_id=workspace_id,
            )
            try:
                model.save()
            except Exception as save_error:
                # Logging could be added here.
                # Message msgid is Chinese for "failed to save model".
                raise AppApiException(500, _('模型保存失败')) from save_error

            if status == Status.DOWNLOAD:
                # Download asynchronously; progress is written to the row by ``pull``.
                thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
                thread.start()

            return ModelModelSerializer(model).data
|