feat: add model setting

This commit is contained in:
wxg0103 2025-04-18 17:45:15 +08:00
parent 9b0f9b04b7
commit 6f6b163416
14 changed files with 7031 additions and 390 deletions

View File

@ -47,20 +47,20 @@ class ModelManage:
ModelManage.cache.delete(_id)
class VectorStore:
from embedding.vector.pg_vector import PGVector
from embedding.vector.base_vector import BaseVectorStore
instance_map = {
'pg_vector': PGVector,
}
instance = None
@staticmethod
def get_embedding_vector() -> BaseVectorStore:
from embedding.vector.pg_vector import PGVector
if VectorStore.instance is None:
from maxkb.const import CONFIG
vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"),
PGVector)
VectorStore.instance = vector_store_class()
return VectorStore.instance
# class VectorStore:
# from embedding.vector.pg_vector import PGVector
# from embedding.vector.base_vector import BaseVectorStore
# instance_map = {
# 'pg_vector': PGVector,
# }
# instance = None
#
# @staticmethod
# def get_embedding_vector() -> BaseVectorStore:
# from embedding.vector.pg_vector import PGVector
# if VectorStore.instance is None:
# from maxkb.const import CONFIG
# vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"),
# PGVector)
# VectorStore.instance = vector_store_class()
# return VectorStore.instance

View File

@ -13,7 +13,8 @@ import io
import mimetypes
import re
import shutil
from typing import List
from functools import reduce
from typing import List, Dict
from django.core.files.uploadedfile import InMemoryUploadedFile
from django.utils.translation import gettext as _
@ -50,13 +51,13 @@ def group_by(list_source: List, key):
return result
CHAR_SET = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
def get_random_chars(number=6):
return "".join([CHAR_SET[random.randint(0, len(CHAR_SET) - 1)] for index in range(number)])
def encryption(message: str):
"""
加密敏感字段数据 加密方式是 如果密码是 1234567890 那么给前端则是 123******890
@ -122,7 +123,6 @@ def get_file_content(path):
return content
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
content_type, _ = mimetypes.guess_type(file_name)
if content_type is None:
@ -205,3 +205,9 @@ def split_and_transcribe(file_path, model, max_segment_length_ms=59000, audio_fo
full_text.append(text)
return ' '.join(full_text)
def query_params_to_single_dict(query_params: Dict):
return reduce(lambda x, y: {**x, **y}, list(
filter(lambda item: item is not None, [({key: value} if value is not None and len(value) > 0 else None) for
key, value in
query_params.items()])), {})

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,12 @@
# coding=utf-8
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter
from rest_framework import serializers
from common.mixins.api_mixin import APIMixin
from common.result import ResultSerializer
from models_provider.serializers.model import ModelCreateRequest, ModelModelSerializer
from models_provider.serializers.model_serializer import ModelModelSerializer, ModelCreateRequest
from django.utils.translation import gettext_lazy as _
class ModelCreateResponse(ResultSerializer):
@ -10,6 +14,12 @@ class ModelCreateResponse(ResultSerializer):
return ModelModelSerializer()
class ModelListResponse(APIMixin):
@staticmethod
def get_response():
return serializers.ListSerializer(child=ModelModelSerializer())
class ModelCreateAPI(APIMixin):
@staticmethod
def get_request():
@ -18,3 +28,47 @@ class ModelCreateAPI(APIMixin):
@staticmethod
def get_response():
return ModelCreateResponse
@classmethod
def get_query_params_api(cls):
return [OpenApiParameter(
name="workspace_id",
description=_("workspace id"),
type=OpenApiTypes.STR,
location=OpenApiParameter.PATH,
required=True,
)]
class GetModelApi(APIMixin):
@staticmethod
def get_query_params_api():
return [OpenApiParameter(
name="workspace_id",
description=_("workspace id"),
type=OpenApiTypes.STR,
location=OpenApiParameter.PATH,
required=True,
), OpenApiParameter(
name="model_id",
description=_("model id"),
type=OpenApiTypes.STR,
location=OpenApiParameter.PATH,
required=True,
)
]
@staticmethod
def get_response():
return ModelModelSerializer
class ModelEditApi(APIMixin):
@staticmethod
def get_request():
return ModelCreateRequest
@staticmethod
def get_response():
return ModelModelSerializer

View File

@ -30,29 +30,65 @@ class ModelListSerializer(serializers.Serializer):
desc = serializers.CharField(required=True, label=_("model name"))
class ModelParamsFormSerializer(serializers.Serializer):
input_type = serializers.CharField(required=False, label=_("input type"))
label = serializers.CharField(required=False, label=_("label"))
text_field = serializers.CharField(required=False, label=_("text field"))
value_field = serializers.CharField(required=False, label=_("value field"))
provider = serializers.CharField(required=False, label=_("provider"))
method = serializers.CharField(required=False, label=_("method"))
required = serializers.BooleanField(required=False, label=_("required"))
default_value = serializers.CharField(required=False, label=_("default value"))
relation_show_field_dict = serializers.DictField(required=False, label=_("relation show field dict"))
relation_trigger_field_dict = serializers.DictField(required=False, label=_("relation trigger field dict"))
trigger_type = serializers.CharField(required=False, label=_("trigger type"))
attrs = serializers.DictField(required=False, label=_("attrs"))
props_info = serializers.DictField(required=False, label=_("props info"))
class ProvideApi(APIMixin):
class ModelParamsForm(APIMixin):
@staticmethod
def get_query_params_api():
return [OpenApiParameter(
name="model_type",
description=_("model type"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
required=True,
), OpenApiParameter(
name="provider",
description=_("provider"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
required=True,
), OpenApiParameter(
name="model_name",
description=_("model name"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
required=True,
)
]
@staticmethod
def get_response():
return serializers.ListSerializer(child=ModelParamsFormSerializer())
class ModelList(APIMixin):
@staticmethod
def get_query_params_api():
return [OpenApiParameter(
# 参数的名称是done
name="model_type",
# 对参数的备注
description="model_type",
# 指定参数的类型
description=_("model type"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
# 指定必须给
required=False,
required=True,
), OpenApiParameter(
# 参数的名称是done
name="provider",
# 对参数的备注
description="provider",
# 指定参数的类型
description=_("provider"),
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
# 指定必须给
required=True,
)
]
@ -72,7 +108,7 @@ class ProvideApi(APIMixin):
# 参数的名称是done
name="provider",
# 对参数的备注
description="provider",
description=_("provider"),
# 指定参数的类型
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,

View File

@ -15,7 +15,7 @@ class BaiLianLLMModelParams(BaseForm):
temperature = forms.SliderField(
TooltipLabel(
_('Temperature'),
_('Higher values make the output more random, while lower values make it more focused and deterministic.')
_('Higher values make the output more random, while lower values make it more focused and deterministic')
),
required=True,
default_value=0.7,

View File

@ -1,181 +0,0 @@
# -*- coding: utf-8 -*-
import json
import threading
import time
from typing import Dict
import uuid_utils.compat as uuid
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.exception.app_exception import AppApiException
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from models_provider.constants.model_provider_constants import ModelProvideConstants
from models_provider.models import Model, Status
class ModelModelSerializer(serializers.ModelSerializer):
class Meta:
model = Model
fields = [
'id', 'name', 'status', 'model_type', 'model_name',
'user', 'provider', 'credential', 'meta',
'model_params_form', 'workspace_id'
]
class ModelCreateRequest(serializers.Serializer):
name = serializers.CharField(required=True, max_length=64, label=_("model name"))
provider = serializers.CharField(required=True, label=_("provider"))
model_type = serializers.CharField(required=True, label=_("model type"))
model_name = serializers.CharField(required=True, label=_("model name"))
model_params_form = serializers.ListField(required=False, default=list, label=_("parameter configuration"))
credential = serializers.DictField(required=True, label=_("certification information"))
class ModelPullManage:
@staticmethod
def pull(model: Model, credential: Dict):
try:
response = ModelProvideConstants[model.provider].value.down_model(
model.model_type, model.model_name, credential
)
down_model_chunk = {}
last_update_time = time.time()
for chunk in response:
down_model_chunk[chunk.digest] = chunk.to_dict()
if time.time() - last_update_time > 5:
current_model = QuerySet(Model).filter(id=model.id).first()
if current_model and current_model.status == Status.PAUSE_DOWNLOAD:
return
QuerySet(Model).filter(id=model.id).update(
meta={"down_model_chunk": list(down_model_chunk.values())}
)
last_update_time = time.time()
status = Status.ERROR
message = ""
for chunk in down_model_chunk.values():
if chunk.get('status') == DownModelChunkStatus.success.value:
status = Status.SUCCESS
elif chunk.get('status') == DownModelChunkStatus.error.value:
message = chunk.get("digest")
QuerySet(Model).filter(id=model.id).update(
meta={"down_model_chunk": [], "message": message},
status=status
)
except Exception as e:
QuerySet(Model).filter(id=model.id).update(
meta={"down_model_chunk": [], "message": str(e)},
status=Status.ERROR
)
class ModelSerializer(serializers.Serializer):
@staticmethod
def model_to_dict(model: Model):
credential = json.loads(rsa_long_decrypt(model.credential))
return {
'id': str(model.id),
'provider': model.provider,
'name': model.name,
'model_type': model.model_type,
'model_name': model.model_name,
'status': model.status,
'meta': model.meta,
'credential': ModelProvideConstants[model.provider].value.get_model_credential(
model.model_type, model.model_name
).encryption_dict(credential),
'workspace_id': model.workspace_id
}
class Operate(serializers.Serializer):
id = serializers.UUIDField(required=True, label=_("模型id"))
user_id = serializers.UUIDField(required=True, label=_("user id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
model = QuerySet(Model).filter(
id=self.data.get("id"), user_id=self.data.get("user_id")
).first()
if model is None:
raise AppApiException(500, _('模型不存在'))
def one(self, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
model = QuerySet(Model).get(
id=self.data.get('id'), user_id=self.data.get('user_id')
)
return ModelSerializer.model_to_dict(model)
class Create(serializers.Serializer):
user_id = serializers.UUIDField(required=True, label=_('user id'))
name = serializers.CharField(required=True, max_length=64, label=_("model name"))
provider = serializers.CharField(required=True, label=_("provider"))
model_type = serializers.CharField(required=True, label=_("model type"))
model_name = serializers.CharField(required=True, label=_("model name"))
model_params_form = serializers.ListField(required=False, default=list, label=_("parameter configuration"))
credential = serializers.DictField(required=True, label=_("certification information"))
workspace_id = serializers.CharField(required=False, label=_("workspace id"), max_length=128)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if QuerySet(Model).filter(
user_id=self.data.get('user_id'),
name=self.data.get('name'),
workspace_id=self.data.get('workspace_id')
).exists():
raise AppApiException(
500,
_('Model name【{model_name}】already exists').format(model_name=self.data.get("name"))
)
default_params = {item['field']: item['default_value'] for item in self.data.get('model_params_form')}
ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(
self.data.get('model_type'),
self.data.get('model_name'),
self.data.get('credential'),
default_params,
raise_exception=True
)
def insert(self, workspace_id, with_valid=True):
status = Status.SUCCESS
if with_valid:
try:
self.is_valid(raise_exception=True)
except AppApiException as e:
if e.code == ValidCode.model_not_fount:
status = Status.DOWNLOAD
else:
raise e
credential = self.data.get('credential')
model_data = {
'id': uuid.uuid1(),
'status': status,
'user_id': self.data.get('user_id'),
'name': self.data.get('name'),
'credential': rsa_long_encrypt(json.dumps(credential)),
'provider': self.data.get('provider'),
'model_type': self.data.get('model_type'),
'model_name': self.data.get('model_name'),
'model_params_form': self.data.get('model_params_form'),
'workspace_id': workspace_id
}
model = Model(**model_data)
try:
model.save()
except Exception as save_error:
# 可添加日志记录
raise AppApiException(500, _('模型保存失败')) from save_error
if status == Status.DOWNLOAD:
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
thread.start()
return ModelModelSerializer(model).data

View File

@ -0,0 +1,389 @@
# -*- coding: utf-8 -*-
import json
import threading
import time
from typing import Dict
import uuid_utils.compat as uuid
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.config.embedding_config import ModelManage
from common.exception.app_exception import AppApiException
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from models_provider.constants.model_provider_constants import ModelProvideConstants
from models_provider.models import Model, Status
from models_provider.tools import get_model_credential
def get_default_model_params_setting(provider, model_type, model_name):
credential = get_model_credential(provider, model_type, model_name)
setting_form = credential.get_model_params_setting_form(model_name)
if setting_form is not None:
return setting_form.to_form_list()
return []
class ModelModelSerializer(serializers.ModelSerializer):
class Meta:
model = Model
fields = [
'id', 'name', 'status', 'model_type', 'model_name',
'user', 'provider', 'credential', 'meta',
'model_params_form', 'workspace_id'
]
class ModelCreateRequest(serializers.Serializer):
name = serializers.CharField(required=True, max_length=64, label=_("model name"))
provider = serializers.CharField(required=True, label=_("provider"))
model_type = serializers.CharField(required=True, label=_("model type"))
model_name = serializers.CharField(required=True, label=_("base model"))
model_params_form = serializers.ListField(required=False, default=list, label=_("parameter configuration"))
credential = serializers.DictField(required=True, label=_("certification information"))
class ModelPullManage:
@staticmethod
def pull(model: Model, credential: Dict):
try:
response = ModelProvideConstants[model.provider].value.down_model(
model.model_type, model.model_name, credential
)
down_model_chunk = {}
last_update_time = time.time()
for chunk in response:
down_model_chunk[chunk.digest] = chunk.to_dict()
if time.time() - last_update_time > 5:
current_model = QuerySet(Model).filter(id=model.id).first()
if current_model and current_model.status == Status.PAUSE_DOWNLOAD:
return
QuerySet(Model).filter(id=model.id).update(
meta={"down_model_chunk": list(down_model_chunk.values())}
)
last_update_time = time.time()
status = Status.ERROR
message = ""
for chunk in down_model_chunk.values():
if chunk.get('status') == DownModelChunkStatus.success.value:
status = Status.SUCCESS
elif chunk.get('status') == DownModelChunkStatus.error.value:
message = chunk.get("digest")
QuerySet(Model).filter(id=model.id).update(
meta={"down_model_chunk": [], "message": message},
status=status
)
except Exception as e:
QuerySet(Model).filter(id=model.id).update(
meta={"down_model_chunk": [], "message": str(e)},
status=Status.ERROR
)
class ModelSerializer(serializers.Serializer):
@staticmethod
def model_to_dict(model: Model):
credential = json.loads(rsa_long_decrypt(model.credential))
return {
'id': str(model.id),
'provider': model.provider,
'name': model.name,
'model_type': model.model_type,
'model_name': model.model_name,
'status': model.status,
'meta': model.meta,
'credential': ModelProvideConstants[model.provider].value.get_model_credential(
model.model_type, model.model_name
).encryption_dict(credential),
'workspace_id': model.workspace_id
}
class Operate(serializers.Serializer):
id = serializers.UUIDField(required=True, label=_("model id"))
user_id = serializers.UUIDField(required=True, label=_("user id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
model = QuerySet(Model).filter(
id=self.data.get("id")
).first()
if model is None:
raise AppApiException(500, _('Model does not exist'))
def one(self, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
model = QuerySet(Model).get(
id=self.data.get('id')
)
return ModelSerializer.model_to_dict(model)
def one_meta(self, with_valid=False):
model = None
if with_valid:
super().is_valid(raise_exception=True)
model = QuerySet(Model).filter(id=self.data.get("id")).first()
if model is None:
raise AppApiException(500, _('Model does not exist'))
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name,
'status': model.status,
'meta': model.meta
}
def pause_download(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
QuerySet(Model).filter(id=self.data.get('id')).update(status=Status.PAUSE_DOWNLOAD)
return True
def delete(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
model_id = self.data.get('id')
model = Model.objects.filter(id=model_id).first()
if not model:
raise AppApiException(500, _("Model does not exist"))
# TODO : 这里可以添加模型删除的逻辑,需要注意删除模型时的权限和关联关系
# if model.model_type == 'LLM':
# application_count = Application.objects.filter(model_id=model_id).count()
# if application_count > 0:
# raise AppApiException(500, f"该模型关联了{application_count} 个应用,无法删除该模型。")
# elif model.model_type == 'EMBEDDING':
# dataset_count = DataSet.objects.filter(embedding_mode_id=model_id).count()
# if dataset_count > 0:
# raise AppApiException(500, f"该模型关联了{dataset_count} 个知识库,无法删除该模型。")
# elif model.model_type == 'TTS':
# dataset_count = Application.objects.filter(tts_model_id=model_id).count()
# if dataset_count > 0:
# raise AppApiException(500, f"该模型关联了{dataset_count} 个应用,无法删除该模型。")
# elif model.model_type == 'STT':
# dataset_count = Application.objects.filter(stt_model_id=model_id).count()
# if dataset_count > 0:
# raise AppApiException(500, f"该模型关联了{dataset_count} 个应用,无法删除该模型。")
model.delete()
return True
def edit(self, instance: Dict, user_id: str, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
model = QuerySet(Model).filter(id=self.data.get('id')).first()
if model is None:
raise AppApiException(500, _('Model does not exist'))
else:
credential, model_credential, provider_handler = ModelSerializer.Edit(
data={**instance}).is_valid(
model=model)
try:
model.status = Status.SUCCESS
default_params = {item['field']: item['default_value'] for item in model.model_params_form}
# 校验模型认证数据
provider_handler.is_valid_credential(model.model_type,
instance.get("model_name"),
credential,
default_params,
raise_exception=True)
except AppApiException as e:
if e.code == ValidCode.model_not_fount:
model.status = Status.DOWNLOAD
else:
raise e
update_keys = ['credential', 'name', 'model_type', 'model_name']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'credential':
model_credential_str = json.dumps(credential)
model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
else:
model.__setattr__(update_key, instance.get(update_key))
ModelManage.delete_key(str(model.id))
model.save()
if model.status == Status.DOWNLOAD:
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
thread.start()
return self.one(with_valid=False)
class Edit(serializers.Serializer):
user_id = serializers.CharField(required=False, label=(_('user id')))
name = serializers.CharField(required=False, max_length=64,
label=(_("model name")))
model_type = serializers.CharField(required=False, label=(_("model type")))
model_name = serializers.CharField(required=False, label=(_("base model")))
credential = serializers.DictField(required=False,
label=(_("certification information")))
def is_valid(self, model=None, raise_exception=False):
super().is_valid(raise_exception=True)
filter_params = {'workspace_id': self.data.get('workspace_id')}
if 'name' in self.data and self.data.get('name') is not None:
filter_params['name'] = self.data.get('name')
if QuerySet(Model).exclude(id=model.id).filter(**filter_params).exists():
raise AppApiException(500, _('base model【{model_name}】already exists').format(
model_name=self.data.get("name")))
ModelSerializer.model_to_dict(model)
provider = model.provider
model_type = self.data.get('model_type')
model_name = self.data.get(
'model_name')
credential = self.data.get('credential')
provider_handler = ModelProvideConstants[provider].value
model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type,
model_name)
source_model_credential = json.loads(rsa_long_decrypt(model.credential))
source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
if credential is not None:
for k in source_encryption_model_credential.keys():
if k in credential and credential[k] == source_encryption_model_credential[k]:
credential[k] = source_model_credential[k]
return credential, model_credential, provider_handler
class Create(serializers.Serializer):
user_id = serializers.UUIDField(required=True, label=_('user id'))
name = serializers.CharField(required=True, max_length=64, label=_("model name"))
provider = serializers.CharField(required=True, label=_("provider"))
model_type = serializers.CharField(required=True, label=_("model type"))
model_name = serializers.CharField(required=True, label=_("base model"))
model_params_form = serializers.ListField(required=False, default=list, label=_("parameter configuration"))
credential = serializers.DictField(required=True, label=_("certification information"))
workspace_id = serializers.CharField(required=False, label=_("workspace id"), max_length=128)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if QuerySet(Model).filter(
name=self.data.get('name'),
workspace_id=self.data.get('workspace_id')
).exists():
raise AppApiException(
500,
_('base model【{model_name}】already exists').format(model_name=self.data.get("name"))
)
default_params = {item['field']: item['default_value'] for item in self.data.get('model_params_form')}
ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(
self.data.get('model_type'),
self.data.get('model_name'),
self.data.get('credential'),
default_params,
raise_exception=True
)
def insert(self, workspace_id, with_valid=True):
status = Status.SUCCESS
if with_valid:
try:
self.is_valid(raise_exception=True)
except AppApiException as e:
if e.code == ValidCode.model_not_fount:
status = Status.DOWNLOAD
else:
raise e
credential = self.data.get('credential')
model_data = {
'id': uuid.uuid1(),
'status': status,
'user_id': self.data.get('user_id'),
'name': self.data.get('name'),
'credential': rsa_long_encrypt(json.dumps(credential)),
'provider': self.data.get('provider'),
'model_type': self.data.get('model_type'),
'model_name': self.data.get('model_name'),
'model_params_form': self.data.get('model_params_form'),
'workspace_id': workspace_id
}
model = Model(**model_data)
try:
model.save()
except Exception as save_error:
# 可添加日志记录
raise AppApiException(500, _("Model saving failed")) from save_error
if status == Status.DOWNLOAD:
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
thread.start()
return ModelModelSerializer(model).data
class Query(serializers.Serializer):
name = serializers.CharField(required=False, max_length=64, label=_('model name'))
model_type = serializers.CharField(required=False, label=_('model type'))
model_name = serializers.CharField(required=False, label=_('base model'))
provider = serializers.CharField(required=False, label=_('provider'))
create_user = serializers.CharField(required=False, label=_('create user'))
workspace_id = serializers.CharField(required=False, label=_('workspace id'))
def list(self, with_valid):
if with_valid:
self.is_valid(raise_exception=True)
query_params = self._build_query_params()
return self._fetch_models(query_params)
def _build_query_params(self):
query_params = {}
for field in ['name', 'model_type', 'model_name', 'provider', 'create_user', 'workspace_id']:
value = self.data.get(field)
if value is not None:
if field == 'name':
query_params[f'{field}__icontains'] = value
elif field == 'create_user':
query_params['user_id'] = value
else:
query_params[field] = value
return query_params
def _fetch_models(self, query_params):
return [
{
'id': str(model.id),
'provider': model.provider,
'name': model.name,
'model_type': model.model_type,
'model_name': model.model_name,
'status': model.status,
'meta': model.meta,
'user_id': model.user_id,
'username': model.user.username
}
for model in Model.objects.filter(**query_params).order_by("-create_time")
]
class ModelParams(serializers.Serializer):
id = serializers.UUIDField(required=True, label=_('model id'))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
model = QuerySet(Model).filter(id=self.data.get("id")).first()
if model is None:
raise AppApiException(500, _("Model does not exist"))
def get_model_params(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
model_id = self.data.get('id')
model = QuerySet(Model).filter(id=model_id).first()
return model.model_params_form
def save_model_params_form(self, model_params_form, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
if model_params_form is None:
model_params_form = []
model_id = self.data.get('id')
model = QuerySet(Model).filter(id=model_id).first()
model.model_params_form = model_params_form
model.save()
return True

View File

@ -109,8 +109,6 @@ def get_model_by_id(_id, user_id):
connection.close()
if model is None:
raise Exception(_('Model does not exist'))
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
raise Exception(_('No permission to use this model') + f"{model.name}")
return model

View File

@ -4,18 +4,19 @@ from . import views
app_name = "models_provider"
urlpatterns = [
# path('provider/<str:provider>/<str:method>', views.Provide.Exec.as_view(), name='provide_exec'),
path('provider', views.Provide.as_view(), name='provide'),
path('provider/model_type_list', views.Provide.ModelTypeList.as_view(), name="provider/model_type_list"),
path('provider/model_list', views.Provide.ModelList.as_view(), name="provider/model_name_list"),
# path('provider/model_params_form', views.Provide.ModelParamsForm.as_view(),
# name="provider/model_params_form"),
# path('provider/model_form', views.Provide.ModelForm.as_view(),
# name="provider/model_form"),
path('provider/model_params_form', views.Provide.ModelParamsForm.as_view(),
name="provider/model_params_form"),
path('provider/model_form', views.Provide.ModelForm.as_view(),
name="provider/model_form"),
path('workspace/<str:workspace_id>/model', views.Model.as_view(), name='model'),
# path('workspace/<str:workspace_id>/model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(),
# name='model/model_params_form'),
# path('workspace/<str:workspace_id>/model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
# path('workspace/<str:workspace_id>/model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(), name='model/operate'),
# path('workspace/<str:workspace_id>/model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
path('workspace/<str:workspace_id>/model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(),
name='model/model_params_form'),
path('workspace/<str:workspace_id>/model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
path('workspace/<str:workspace_id>/model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(),
name='model/operate'),
path('workspace/<str:workspace_id>/model/<str:model_id>/meta', views.Model.ModelMeta.as_view(),
name='model/operate/meta'),
]

View File

@ -15,8 +15,10 @@ from common.auth import TokenAuth
from common.auth.authentication import has_permissions
from common.constants.permission_constants import PermissionConstants
from common.result import result
from models_provider.api.model import ModelCreateAPI
from models_provider.serializers.model import ModelSerializer
from common.utils.common import query_params_to_single_dict
from models_provider.api.model import ModelCreateAPI, GetModelApi, ModelEditApi, ModelListResponse
from models_provider.api.provide import ProvideApi
from models_provider.serializers.model_serializer import ModelSerializer
class Model(APIView):
@ -26,10 +28,127 @@ class Model(APIView):
description=_("Create model"),
operation_id=_("Create model"),
tags=[_("Model")],
parameters=ModelCreateAPI.get_query_params_api(),
request=ModelCreateAPI.get_request(),
responses=ModelCreateAPI.get_response())
@has_permissions(PermissionConstants.MODEL_CREATE)
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission())
def post(self, request: Request, workspace_id: str):
return result.success(
ModelSerializer.Create(data={**request.data, 'user_id': request.user.id}).insert(workspace_id,
with_valid=True))
# @extend_schema(methods=['PUT'],
# description=_('Update model'),
# operation_id=_('Update model'),
# request=ModelEditApi.get_request(),
# responses=ModelCreateApi.get_response(),
# tags=[_('Model')])
# @has_permissions(PermissionConstants.MODEL_CREATE)
# def put(self, request: Request):
# return result.success(
# ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id,
# with_valid=True))
@extend_schema(methods=['GET'],
description=_('Query model list'),
operation_id=_('Query model list'),
parameters=ModelCreateAPI.get_query_params_api(),
responses=ModelListResponse.get_response(),
tags=[_('Model')])
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
def get(self, request: Request):
return result.success(
ModelSerializer.Query(
data={**query_params_to_single_dict(request.query_params)}).list(
with_valid=True))
class Operate(APIView):
authentication_classes = [TokenAuth]
@extend_schema(methods=['PUT'],
description=_('Update model'),
operation_id=_('Update model'),
request=ModelEditApi.get_request(),
parameters=GetModelApi.get_query_params_api(),
responses=ModelEditApi.get_response(),
tags=[_('Model')])
@has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission())
def put(self, request: Request, workspace_id, model_id: str):
return result.success(
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).edit(request.data,
str(request.user.id)))
@extend_schema(methods=['DELETE'],
description=_('Delete model'),
operation_id=_('Delete model'),
parameters=GetModelApi.get_query_params_api(),
tags=[_('Model')])
@has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission())
def delete(self, request: Request, workspace_id: str, model_id: str):
return result.success(
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).delete())
@extend_schema(methods=['GET'],
description=_('Query model details'),
operation_id=_('Query model details'),
parameters=GetModelApi.get_query_params_api(),
responses=GetModelApi.get_response(),
tags=[_('Model')])
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
def get(self, request: Request, workspace_id: str, model_id: str):
return result.success(
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one(with_valid=True))
class ModelParamsForm(APIView):
authentication_classes = [TokenAuth]
@extend_schema(methods=['GET'],
description=_('Get model parameter form'),
operation_id=_('Get model parameter form'),
parameters=GetModelApi.get_query_params_api(),
responses=ProvideApi.ModelParamsForm.get_response(),
tags=[_('Model')])
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
def get(self, request: Request, workspace_id: str, model_id: str):
return result.success(
ModelSerializer.ModelParams(data={'id': model_id}).get_model_params())
@extend_schema(methods=['PUT'],
description=_('Save model parameter form'),
operation_id=_('Save model parameter form'),
parameters=GetModelApi.get_query_params_api(),
responses=ProvideApi.ModelParamsForm.get_response(),
tags=[_('Model')])
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
def put(self, request: Request, workspace_id: str, model_id: str):
return result.success(
ModelSerializer.ModelParams(data={'id': model_id}).save_model_params_form(request.data))
class ModelMeta(APIView):
authentication_classes = [TokenAuth]
@extend_schema(methods=['GET'],
description=_(
'Query model meta information, this interface does not carry authentication information'),
operation_id=_(
'Query model meta information, this interface does not carry authentication information'),
parameters=GetModelApi.get_query_params_api(),
responses=GetModelApi.get_response(),
tags=[_('Model')])
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
def get(self, request: Request, workspace_id: str, model_id: str):
return result.success(
ModelSerializer.Operate(data={'id': model_id}).one_meta(with_valid=True))
class PauseDownload(APIView):
authentication_classes = [TokenAuth]
@extend_schema(methods=['PUT'],
description=_('Pause model download'),
operation_id=_('Pause model download'),
parameters=GetModelApi.get_query_params_api(),
tags=[_('Model')])
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission())
def put(self, request: Request, workspace_id: str, model_id: str):
return result.success(
ModelSerializer.Operate(data={'id': model_id}).pause_download())

View File

@ -11,6 +11,7 @@ from common.auth.authentication import has_permissions
from common.constants.permission_constants import PermissionConstants
from models_provider.api.provide import ProvideApi
from models_provider.constants.model_provider_constants import ModelProvideConstants
from models_provider.serializers.model_serializer import get_default_model_params_setting
class Provide(APIView):
@ -66,3 +67,37 @@ class Provide(APIView):
return result.success(
ModelProvideConstants[provider].value.get_model_list(
model_type))
class ModelParamsForm(APIView):
authentication_classes = [TokenAuth]
@extend_schema(methods=['GET'],
description=_('Get model default parameters'),
operation_id=_('Get the model creation form'),
parameters=ProvideApi.ModelParamsForm.get_query_params_api(),
responses=ProvideApi.ModelParamsForm.get_response(),
tags=[_('Model')])
@has_permissions(PermissionConstants.MODEL_READ)
def get(self, request: Request):
provider = request.query_params.get('provider')
model_type = request.query_params.get('model_type')
model_name = request.query_params.get('model_name')
return result.success(get_default_model_params_setting(provider, model_type, model_name))
class ModelForm(APIView):
authentication_classes = [TokenAuth]
@extend_schema(methods=['GET'],
description=_('Get the model creation form'),
operation_id=_('Get the model creation form'),
parameters=ProvideApi.ModelParamsForm.get_query_params_api(),
responses=ProvideApi.ModelParamsForm.get_response(),
tags=[_('Model')])
@has_permissions(PermissionConstants.MODEL_READ)
def get(self, request: Request):
provider = request.query_params.get('provider')
model_type = request.query_params.get('model_type')
model_name = request.query_params.get('model_name')
return result.success(
ModelProvideConstants[provider].value.get_model_credential(model_type, model_name).to_form_list())