mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
536 lines
25 KiB
Python
536 lines
25 KiB
Python
# -*- coding: utf-8 -*-
|
|
import json
|
|
import os
|
|
import threading
|
|
import time
|
|
from typing import Dict
|
|
|
|
import uuid_utils.compat as uuid
|
|
from django.core.cache import cache
|
|
from django.db import transaction
|
|
from django.db.models import QuerySet
|
|
from django.utils.translation import gettext_lazy as _
|
|
from rest_framework import serializers
|
|
|
|
from common.config.embedding_config import ModelManage
|
|
from common.constants.cache_version import Cache_Version
|
|
from common.constants.permission_constants import ResourcePermission, ResourceAuthType
|
|
from common.database_model_manage.database_model_manage import DatabaseModelManage
|
|
from common.db.search import native_search
|
|
from common.exception.app_exception import AppApiException
|
|
from common.utils.common import get_file_content
|
|
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
|
|
from maxkb.conf import PROJECT_DIR
|
|
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
|
from models_provider.constants.model_provider_constants import ModelProvideConstants
|
|
from models_provider.models import Model, Status
|
|
from models_provider.tools import get_model_credential
|
|
from system_manage.models import WorkspaceUserResourcePermission, AuthTargetType
|
|
from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer
|
|
from users.serializers.user import is_workspace_manage
|
|
|
|
|
|
def get_default_model_params_setting(provider, model_type, model_name):
    """Return the default parameter-setting form of a provider model as a list.

    Resolves the credential class for ``(provider, model_type, model_name)``
    and asks it for its params-setting form. Returns ``[]`` when the
    credential defines no such form.
    """
    model_credential = get_model_credential(provider, model_type, model_name)
    form = model_credential.get_model_params_setting_form(model_name)
    return [] if form is None else form.to_form_list()
|
|
|
|
|
|
class ModelModelSerializer(serializers.ModelSerializer):
    """DRF ``ModelSerializer`` exposing the listed columns of ``Model``."""

    class Meta:
        model = Model
        # Serialized columns, in output order.
        fields = [
            'id',
            'name',
            'status',
            'model_type',
            'model_name',
            'user',
            'provider',
            'credential',
            'meta',
            'model_params_form',
            'workspace_id',
            'create_time',
            'update_time',
        ]
|
|
|
|
|
|
class ModelCreateRequest(serializers.Serializer):
    """Request body accepted when creating a model."""
    # Display name of the model (at most 64 characters).
    name = serializers.CharField(label=_("model name"), required=True, max_length=64)
    # Provider key.
    provider = serializers.CharField(label=_("provider"), required=True)
    # Model type.
    model_type = serializers.CharField(label=_("model type"), required=True)
    # Name of the underlying (base) model at the provider.
    model_name = serializers.CharField(label=_("base model"), required=True)
    # Optional parameter form; defaults to an empty list.
    model_params_form = serializers.ListField(label=_("parameter configuration"), required=False, default=list)
    # Credential fields for the provider.
    credential = serializers.DictField(label=_("certification information"), required=True)
|
|
|
|
|
|
class ModelPullManage:
    """Downloads a provider model and records progress on the ``Model`` row."""

    @staticmethod
    def pull(model: Model, credential: Dict):
        """Stream the download of ``model`` and persist progress and outcome.

        Chunks are keyed by their ``digest``. The collected chunks are written
        to ``meta["down_model_chunk"]`` only when more than 5 seconds have
        passed since the previous write. Before each write the row is re-read;
        if its status is PAUSE_DOWNLOAD the function returns without a final
        status update.

        When the stream ends, the status becomes SUCCESS if any chunk reported
        the success status, otherwise ERROR. ``meta["message"]`` receives the
        ``digest`` of the last error-status chunk seen, or ``""``. Any
        exception marks the model ERROR with the exception text as message.
        """
        try:
            stream = ModelProvideConstants[model.provider].value.down_model(
                model.model_type, model.model_name, credential
            )
            chunks_by_digest = {}
            flushed_at = time.time()

            for piece in stream:
                chunks_by_digest[piece.digest] = piece.to_dict()
                if time.time() - flushed_at <= 5:
                    continue
                latest = QuerySet(Model).filter(id=model.id).first()
                if latest and latest.status == Status.PAUSE_DOWNLOAD:
                    # User paused the download: stop without touching status.
                    return
                QuerySet(Model).filter(id=model.id).update(
                    meta={"down_model_chunk": list(chunks_by_digest.values())}
                )
                flushed_at = time.time()

            final_status = Status.ERROR
            error_message = ""
            for info in chunks_by_digest.values():
                chunk_status = info.get('status')
                if chunk_status == DownModelChunkStatus.success.value:
                    final_status = Status.SUCCESS
                elif chunk_status == DownModelChunkStatus.error.value:
                    error_message = info.get("digest")

            QuerySet(Model).filter(id=model.id).update(
                meta={"down_model_chunk": [], "message": error_message},
                status=final_status
            )
        except Exception as err:
            # Runs in a background thread: record the failure on the row
            # instead of letting the exception escape.
            QuerySet(Model).filter(id=model.id).update(
                meta={"down_model_chunk": [], "message": str(err)},
                status=Status.ERROR
            )
|
|
|
|
|
|
class ModelSerializer(serializers.Serializer):
    """Namespace for the serializers that manage ``Model`` records.

    ``model_to_dict`` renders a single model. The nested serializers handle:

    - single-model operations (``Operate``),
    - edit validation (``Edit``),
    - creation (``Create``),
    - listing (``Query``),
    - the parameter form (``ModelParams``).
    """

    @staticmethod
    def model_to_dict(model: Model):
        """Render ``model`` as a plain dict for API responses.

        The stored credential is RSA-decrypted, JSON-decoded and passed through
        the provider credential's ``encryption_dict`` before being returned
        (presumably masking secret values; confirm in the provider classes).
        ``nick_name`` and ``username`` are ``''`` when the model has no user.
        """
        credential = json.loads(rsa_long_decrypt(model.credential))
        return {
            'id': str(model.id),
            'provider': model.provider,
            'name': model.name,
            'model_type': model.model_type,
            'model_name': model.model_name,
            'status': model.status,
            'meta': model.meta,
            'credential': ModelProvideConstants[model.provider].value.get_model_credential(
                model.model_type, model.model_name
            ).encryption_dict(credential),
            'workspace_id': model.workspace_id,
            'nick_name': model.user.nick_name if model.user else '',
            'username': model.user.username if model.user else ''
        }

    class Operate(serializers.Serializer):
        """Operations on the single model identified by ``id``.

        Methods called with ``with_valid=False`` read ``self.data``. DRF only
        allows that after ``is_valid()`` has been called on this instance.
        """
        id = serializers.UUIDField(required=True, label=_("model id"))
        user_id = serializers.UUIDField(required=False, label=_("user id"))
        workspace_id = serializers.CharField(required=False, label=_("workspace id"))

        def is_valid(self, *, raise_exception=False):
            """Validate the fields and check that the model may be modified.

            Validation errors are always raised, whatever ``raise_exception`` is.

            :raises AppApiException: if no model matches ``id`` (restricted to
                ``workspace_id`` when one is supplied), or if the model is a
                shared model (its stored ``workspace_id`` is the string ``'None'``).
            """
            super().is_valid(raise_exception=True)
            workspace_id = self.data.get("workspace_id")
            model_query = QuerySet(Model).filter(id=self.data.get("id"))
            if workspace_id is not None:
                model_query = model_query.filter(workspace_id=workspace_id)
            model = model_query.first()
            if model is None:
                raise AppApiException(500, _('Model does not exist'))
            if model.workspace_id == 'None':
                raise AppApiException(500, _('Shared models cannot be deleted or modified'))

        def one(self, with_valid=False):
            """Return ``model_to_dict`` of the model matching ``id`` and ``workspace_id``.

            ``workspace_id`` defaults to ``'None'`` (shared models) when not
            supplied. ``QuerySet.get`` raises ``Model.DoesNotExist`` when no
            row matches.
            """
            if with_valid:
                super().is_valid(raise_exception=True)
            model = QuerySet(Model).get(
                id=self.data.get('id'), workspace_id=self.data.get('workspace_id', 'None')
            )
            return ModelSerializer.model_to_dict(model)

        def one_meta(self, with_valid=False):
            """Return a model's metadata, without its credential.

            :raises AppApiException: if no model matches ``id`` and
                ``workspace_id`` (default ``'None'``).
            """
            if with_valid:
                super().is_valid(raise_exception=True)
            # The lookup must not depend on ``with_valid``. It used to sit inside
            # the ``if`` above, so ``one_meta(with_valid=False)`` always raised
            # "Model does not exist", even for an existing model.
            model = QuerySet(Model).filter(id=self.data.get("id"),
                                           workspace_id=self.data.get('workspace_id', 'None')).first()
            if model is None:
                raise AppApiException(500, _('Model does not exist'))
            return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
                    'model_name': model.model_name,
                    'status': model.status,
                    'meta': model.meta,
                    'workspace_id': model.workspace_id,
                    }

        def pause_download(self, with_valid=True):
            """Set the model's status to PAUSE_DOWNLOAD and return True.

            ``ModelPullManage.pull`` checks for this status before each
            progress write and stops. Validation uses this class's
            ``is_valid``, so shared models are rejected.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            QuerySet(Model).filter(id=self.data.get('id')).update(status=Status.PAUSE_DOWNLOAD)
            return True

        @transaction.atomic
        def delete(self, with_valid=True):
            """Delete the model and its workspace resource-permission rows in one transaction.

            Returns True, including when no model matches ``id``.
            """
            if with_valid:
                super().is_valid(raise_exception=True)
            model_id = self.data.get('id')
            model = Model.objects.filter(id=model_id).first()
            if model is None:
                return True
            QuerySet(WorkspaceUserResourcePermission).filter(target=model_id).delete()
            # TODO: refuse deletion while the model is still referenced, raising
            # AppApiException(500, ...) with the number of references:
            #   LLM       -> applications using it (Application.model_id)
            #   EMBEDDING -> knowledge bases using it (DataSet.embedding_model_id)
            #   TTS       -> applications using it (Application.tts_model_id)
            #   STT       -> applications using it (Application.stt_model_id)
            # Permissions and relations must be taken into account when adding this.
            model.delete()
            return True

        def edit(self, instance: Dict, user_id: str, with_valid=True):
            """Apply the partial edit payload ``instance`` to the model and return ``one()``.

            The credential is re-validated with the provider. If the provider
            reports ``ValidCode.model_not_fount``, the model is saved with status
            DOWNLOAD and a background ``ModelPullManage.pull`` thread is started.
            The cached model instance is evicted via ``ModelManage.delete_key``.

            :raises AppApiException: if the model does not exist, the new name
                clashes, or credential validation fails for another reason.
            """
            if with_valid:
                super().is_valid(raise_exception=True)
            model = QuerySet(Model).filter(id=self.data.get('id')).first()
            if model is None:
                # Fail with an API error rather than an AttributeError in Edit.is_valid.
                raise AppApiException(500, _('Model does not exist'))

            credential, model_credential, provider_handler = ModelSerializer.Edit(
                data={**instance}).is_valid(model=model)
            try:
                model.status = Status.SUCCESS
                default_params = {item['field']: item['default_value'] for item in model.model_params_form}
                # Validate the model credential with the provider.
                provider_handler.is_valid_credential(model.model_type,
                                                     instance.get("model_name"),
                                                     credential,
                                                     default_params,
                                                     raise_exception=True)
            except AppApiException as e:
                if e.code == ValidCode.model_not_fount:
                    # The model is not present at the provider yet: download it.
                    model.status = Status.DOWNLOAD
                else:
                    raise

            update_keys = ['credential', 'name', 'model_type', 'model_name']
            for update_key in update_keys:
                if update_key in instance and instance.get(update_key) is not None:
                    if update_key == 'credential':
                        model_credential_str = json.dumps(credential)
                        setattr(model, update_key, rsa_long_encrypt(model_credential_str))
                    else:
                        setattr(model, update_key, instance.get(update_key))

            ModelManage.delete_key(str(model.id))
            model.save()
            if model.status == Status.DOWNLOAD:
                thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
                thread.start()
            return self.one(with_valid=False)

    class Edit(serializers.Serializer):
        """Validates a partial edit payload against an existing model."""
        user_id = serializers.CharField(required=False, label=_('user id'))

        name = serializers.CharField(required=False, max_length=64,
                                     label=_("model name"))

        model_type = serializers.CharField(required=False, label=_("model type"))

        model_name = serializers.CharField(required=False, label=_("base model"))

        credential = serializers.DictField(required=False,
                                           label=_("certification information"))
        workspace_id = serializers.CharField(required=False, label=_("workspace id"))

        def is_valid(self, model=None, raise_exception=False):
            """Validate the payload against ``model`` and prepare its credential.

            Returns ``(credential, model_credential, provider_handler)``. In
            ``credential``, every submitted value that equals the
            ``encryption_dict`` form of the stored value is replaced with the
            stored original. Values echoed back unchanged therefore do not
            overwrite the real secret.

            :raises AppApiException: if a submitted ``name`` is already used
                by another model in the same workspace.
            """
            super().is_valid(raise_exception=True)
            name = self.data.get('name')
            # Only check for a clash when a name is actually submitted. Without a
            # name, the query used to match every other model in the workspace
            # and rejected edits that did not touch the name at all.
            if name is not None and QuerySet(Model).exclude(id=model.id).filter(
                    workspace_id=model.workspace_id, name=name).exists():
                raise AppApiException(500, _('base model【{model_name}】already exists').format(
                    model_name=name))

            provider = model.provider
            model_type = self.data.get('model_type')
            model_name = self.data.get('model_name')
            credential = self.data.get('credential')
            provider_handler = ModelProvideConstants[provider].value
            model_credential = provider_handler.get_model_credential(model_type, model_name)
            source_model_credential = json.loads(rsa_long_decrypt(model.credential))
            source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
            if credential is not None:
                for k in source_encryption_model_credential.keys():
                    if k in credential and credential[k] == source_encryption_model_credential[k]:
                        credential[k] = source_model_credential[k]
            return credential, model_credential, provider_handler

    class Create(serializers.Serializer):
        """Validates and inserts a new model."""
        user_id = serializers.UUIDField(required=True, label=_('user id'))
        name = serializers.CharField(required=True, max_length=64, label=_("model name"))
        provider = serializers.CharField(required=True, label=_("provider"))
        model_type = serializers.CharField(required=True, label=_("model type"))
        model_name = serializers.CharField(required=True, label=_("base model"))
        model_params_form = serializers.ListField(required=False, default=list, label=_("parameter configuration"))
        credential = serializers.DictField(required=True, label=_("certification information"))
        workspace_id = serializers.CharField(required=False, label=_("workspace id"), max_length=128)

        def is_valid(self, *, raise_exception=False):
            """Validate fields, reject duplicate names and check the credential with the provider.

            The duplicate-name check uses ``workspace_id`` (default ``'None'``).
            Errors from the provider's ``is_valid_credential`` propagate.

            :raises AppApiException: on a duplicate name or a provider
                credential error.
            """
            super().is_valid(raise_exception=True)
            if QuerySet(Model).filter(
                    name=self.data.get('name'),
                    workspace_id=self.data.get('workspace_id', 'None')
            ).exists():
                raise AppApiException(
                    500,
                    _('base model【{model_name}】already exists').format(model_name=self.data.get("name"))
                )
            default_params = {item['field']: item['default_value'] for item in self.data.get('model_params_form')}
            ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(
                self.data.get('model_type'),
                self.data.get('model_name'),
                self.data.get('credential'),
                default_params,
                raise_exception=True
            )

        def insert(self, workspace_id, with_valid=True):
            """Create the model in ``workspace_id`` and return its serialized form.

            If the provider reports ``ValidCode.model_not_fount``, the model is
            stored with status DOWNLOAD and a background pull is started. For
            any ``workspace_id`` other than ``'None'``, the creating user is
            granted resource permission on the new model.

            :raises AppApiException: on validation errors, or
                "Model saving failed" if saving or granting permission fails.
            """
            status = Status.SUCCESS
            if with_valid:
                try:
                    self.is_valid(raise_exception=True)
                except AppApiException as e:
                    if e.code == ValidCode.model_not_fount:
                        status = Status.DOWNLOAD
                    else:
                        raise

            credential = self.data.get('credential')
            model_data = {
                'id': uuid.uuid7(),
                'status': status,
                'user_id': self.data.get('user_id'),
                'name': self.data.get('name'),
                'credential': rsa_long_encrypt(json.dumps(credential)),
                'provider': self.data.get('provider'),
                'model_type': self.data.get('model_type'),
                'model_name': self.data.get('model_name'),
                'model_params_form': self.data.get('model_params_form'),
                'workspace_id': workspace_id
            }
            model = Model(**model_data)
            try:
                model.save()
                if workspace_id != 'None':
                    UserResourcePermissionSerializer(data={
                        'workspace_id': workspace_id,
                        'user_id': self.data.get('user_id'),
                        'auth_target_type': AuthTargetType.MODEL.value
                    }).auth_resource(str(model.id))
            except Exception as save_error:
                # Logging could be added here.
                raise AppApiException(500, _("Model saving failed")) from save_error

            if status == Status.DOWNLOAD:
                thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
                thread.start()

            return ModelModelSerializer(model).data

    class Query(serializers.Serializer):
        """Lists models, optionally filtered by name, type, base model, provider or creator."""
        user_id = serializers.CharField(required=True, label=_("User ID"))
        name = serializers.CharField(required=False, max_length=64, label=_('model name'))
        model_type = serializers.CharField(required=False, label=_('model type'))
        model_name = serializers.CharField(required=False, label=_('base model'))
        provider = serializers.CharField(required=False, label=_('provider'))
        create_user = serializers.CharField(required=False, label=_('create user'))
        workspace_id = serializers.CharField(required=False, label=_('workspace id'))

        @staticmethod
        def is_x_pack_ee():
            """Return True when both role-mapping models are registered with ``DatabaseModelManage``.

            Presumably this distinguishes the enterprise edition (confirm).
            """
            workspace_user_role_mapping_model = DatabaseModelManage.get_model("workspace_user_role_mapping")
            role_permission_mapping_model = DatabaseModelManage.get_model("role_permission_mapping_model")
            return workspace_user_role_mapping_model is not None and role_permission_mapping_model is not None

        def list(self, workspace_id, with_valid):
            """Return the workspace's models via ``native_search``.

            Workspace managers use ``list_model.sql``. Other users use
            ``list_model_user_ee.sql`` or ``list_model_user.sql``, depending on
            ``is_x_pack_ee()``.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            user_id = self.data.get("user_id")
            workspace_manage = is_workspace_manage(user_id, workspace_id)
            query_params = self._build_query_params(workspace_id, workspace_manage, user_id)
            is_x_pack_ee = self.is_x_pack_ee()
            return native_search(query_params,
                                 select_string=get_file_content(
                                     os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql',
                                                  'list_model.sql' if workspace_manage else (
                                                      'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql')
                                                  )))

        def share_list(self, workspace_id, with_valid=True):
            """Return the filtered models of ``workspace_id`` as ``_build_model_data`` dicts."""
            if with_valid:
                self.is_valid(raise_exception=True)
            user_id = self.data.get("user_id")
            query_params = self._build_query_params(workspace_id, False, user_id)
            return [
                self._build_model_data(
                    model
                ) for model in query_params.get('model_query_set')
            ]

        def model_list(self, workspace_id, with_valid=True):
            """Return ``{"shared_model": [...], "model": [...]}`` for ``workspace_id``.

            Shared models (workspace ``'None'``) are included only when a
            ``get_authorized_model`` hook is registered. The hook filters them
            to those authorized for the workspace.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            user_id = self.data.get("user_id")
            workspace_manage = is_workspace_manage(user_id, workspace_id)
            queryset = self._build_query_params(workspace_id, workspace_manage, user_id)
            get_authorized_model = DatabaseModelManage.get_model("get_authorized_model")

            shared_queryset = QuerySet(Model).none()
            if get_authorized_model is not None:
                shared_queryset = self._build_query_params('None', False, user_id)['model_query_set']
                shared_queryset = get_authorized_model(shared_queryset, workspace_id)

            # Build the shared-model and regular-model lists.
            shared_model = [self._build_model_data(model) for model in shared_queryset]

            is_x_pack_ee = self.is_x_pack_ee()
            normal_model = native_search(
                queryset,
                select_string=get_file_content(
                    os.path.join(
                        PROJECT_DIR, "apps", "models_provider", 'sql',
                        'list_model.sql' if workspace_manage else (
                            'list_model_user_ee.sql' if is_x_pack_ee else 'list_model_user.sql')
                    )
                )
            )
            return {
                "shared_model": shared_model,
                "model": normal_model
            }

        def _build_query_params(self, workspace_id, workspace_manage: bool, user_id):
            """Build the query-set dict consumed by ``native_search``.

            ``model_query_set`` holds the ``Model`` rows filtered by workspace
            and the optional request filters, newest first. Non-managers also
            get ``workspace_user_resource_permission_query_set``: the user's
            MODEL permissions in the workspace.
            """
            queryset = QuerySet(Model)
            if workspace_id:
                queryset = queryset.filter(workspace_id=workspace_id)
            for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']:
                value = self.data.get(field)
                if value is not None:
                    if field == 'name':
                        queryset = queryset.filter(**{f'{field}__icontains': value})
                    elif field == 'create_user':
                        queryset = queryset.filter(user_id=value)
                    else:
                        queryset = queryset.filter(**{field: value})
            queryset = queryset.order_by("-create_time")
            if workspace_manage:
                return {'model_query_set': queryset}
            return {
                'model_query_set': queryset,
                'workspace_user_resource_permission_query_set': QuerySet(WorkspaceUserResourcePermission).filter(
                    auth_target_type="MODEL",
                    workspace_id=workspace_id,
                    user_id=user_id),
            }

        def _build_model_data(self, model):
            """Render ``model`` without its credential.

            ``username`` and ``nick_name`` are ``''`` when the model has no
            user, matching ``model_to_dict``.
            """
            user = model.user
            return {
                'id': str(model.id),
                'provider': model.provider,
                'name': model.name,
                'model_type': model.model_type,
                'model_name': model.model_name,
                'status': model.status,
                'meta': model.meta,
                'user_id': model.user_id,
                'username': user.username if user else '',
                'nick_name': user.nick_name if user else '',
            }

        def page(self, current_page, page_size):
            """Not implemented; returns None."""
            pass

    class ModelParams(serializers.Serializer):
        """Reads and writes a model's ``model_params_form``."""
        id = serializers.UUIDField(required=True, label=_('model id'))

        def is_valid(self, *, raise_exception=False):
            """Validate ``id`` and check the model exists.

            :raises AppApiException: if no model matches ``id``.
            """
            super().is_valid(raise_exception=True)
            model = QuerySet(Model).filter(id=self.data.get("id")).first()
            if model is None:
                raise AppApiException(500, _("Model does not exist"))

        def get_model_params(self, with_valid=True):
            """Return the model's ``model_params_form``."""
            if with_valid:
                self.is_valid(raise_exception=True)
            model_id = self.data.get('id')
            model = QuerySet(Model).filter(id=model_id).first()
            return model.model_params_form

        def save_model_params_form(self, model_params_form, with_valid=True):
            """Store ``model_params_form`` on the model and return True.

            ``None`` is stored as an empty list.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            if model_params_form is None:
                model_params_form = []
            model_id = self.data.get('id')
            model = QuerySet(Model).filter(id=model_id).first()
            model.model_params_form = model_params_form
            model.save()
            return True
|
|
|
|
|
|
class WorkspaceSharedModelSerializer(serializers.Serializer):
    """Lists models made available to a workspace, with optional filters."""
    workspace_id = serializers.CharField(required=True, label=_('workspace id'))
    name = serializers.CharField(required=False, max_length=64, label=_('model name'))
    model_type = serializers.CharField(required=False, label=_('model type'))
    model_name = serializers.CharField(required=False, label=_('base model'))
    provider = serializers.CharField(required=False, label=_('provider'))
    create_user = serializers.CharField(required=False, label=_('create user'))

    def get_share_model_list(self):
        """Validate the request and return the matching models, newest first.

        Each entry omits the credential. ``nick_name`` and ``username`` are
        ``''`` when the model has no user; previously such models raised
        ``AttributeError``.
        """
        self.is_valid(raise_exception=True)
        workspace_id = self.data.get('workspace_id')
        queryset = self._build_queryset(workspace_id)
        return [self._model_to_item(model) for model in queryset.order_by("-create_time")]

    @staticmethod
    def _model_to_item(model):
        """Render one model row as a response dict, tolerating a missing user."""
        user = model.user
        return {
            'id': str(model.id),
            'provider': model.provider,
            'name': model.name,
            'model_type': model.model_type,
            'model_name': model.model_name,
            'status': model.status,
            'meta': model.meta,
            'user_id': model.user_id,
            'nick_name': user.nick_name if user else '',
            'username': user.username if user else ''
        }

    def _build_queryset(self, workspace_id):
        """Return the ``Model`` query set authorized for ``workspace_id``, with request filters applied.

        ``name`` matches case-insensitively as a substring; ``create_user``
        filters on ``user_id``; the other fields match exactly.
        """
        queryset = QuerySet(Model)
        if workspace_id:
            get_authorized_model = DatabaseModelManage.get_model("get_authorized_model")
            # NOTE(review): without the get_authorized_model hook no workspace
            # restriction is applied at all. ModelSerializer.Query.model_list
            # returns an empty set in that case; confirm which is intended.
            if get_authorized_model is not None:
                queryset = get_authorized_model(queryset, workspace_id)

        for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']:
            value = self.data.get(field)
            if value is not None:
                if field == 'name':
                    queryset = queryset.filter(**{f'{field}__icontains': value})
                elif field == 'create_user':
                    queryset = queryset.filter(user_id=value)
                else:
                    queryset = queryset.filter(**{field: value})

        return queryset
|