mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: ollama支持下载模型
This commit is contained in:
parent
bdf5edc203
commit
d074424398
|
|
@ -13,13 +13,14 @@ from django.core.cache import caches
|
|||
memory_cache = caches['default']
|
||||
|
||||
|
||||
def try_lock(key: str):
|
||||
def try_lock(key: str, timeout=None):
|
||||
"""
|
||||
获取锁
|
||||
:param key: 获取锁 key
|
||||
:param key: 获取锁 key
|
||||
:param timeout 超时时间
|
||||
:return: 是否获取到锁
|
||||
"""
|
||||
return memory_cache.add(key, 'lock', timeout=timedelta(hours=1).total_seconds())
|
||||
return memory_cache.add(key, 'lock', timeout=timedelta(hours=1).total_seconds() if timeout is not None else timeout)
|
||||
|
||||
|
||||
def un_lock(key: str):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
# Generated by Django 4.1.13 on 2024-03-22 17:51
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('setting', '0002_systemsetting'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='model',
|
||||
name='meta',
|
||||
field=models.JSONField(default=dict, verbose_name='模型元数据,用于存储下载,或者错误信息'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='model',
|
||||
name='status',
|
||||
field=models.CharField(choices=[('SUCCESS', '成功'), ('ERROR', '失败'), ('DOWNLOAD', '下载中')], default='SUCCESS', max_length=20, verbose_name='设置类型'),
|
||||
),
|
||||
]
|
||||
|
|
@ -14,6 +14,15 @@ from common.mixins.app_model_mixin import AppModelMixin
|
|||
from users.models import User
|
||||
|
||||
|
||||
class Status(models.TextChoices):
|
||||
"""系统设置类型"""
|
||||
SUCCESS = "SUCCESS", '成功'
|
||||
|
||||
ERROR = "ERROR", "失败"
|
||||
|
||||
DOWNLOAD = "DOWNLOAD", '下载中'
|
||||
|
||||
|
||||
class Model(AppModelMixin):
|
||||
"""
|
||||
模型数据
|
||||
|
|
@ -22,6 +31,9 @@ class Model(AppModelMixin):
|
|||
|
||||
name = models.CharField(max_length=128, verbose_name="名称")
|
||||
|
||||
status = models.CharField(max_length=20, verbose_name='设置类型', choices=Status.choices,
|
||||
default=Status.SUCCESS)
|
||||
|
||||
model_type = models.CharField(max_length=128, verbose_name="模型类型")
|
||||
|
||||
model_name = models.CharField(max_length=128, verbose_name="模型名称")
|
||||
|
|
@ -32,6 +44,8 @@ class Model(AppModelMixin):
|
|||
|
||||
credential = models.CharField(max_length=5120, verbose_name="模型认证信息")
|
||||
|
||||
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)
|
||||
|
||||
class Meta:
|
||||
db_table = "model"
|
||||
unique_together = ['name', 'user_id']
|
||||
|
|
|
|||
|
|
@ -9,10 +9,42 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
from typing import Dict
|
||||
from typing import Dict, Iterator
|
||||
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
|
||||
from common.exception.app_exception import AppApiException
|
||||
|
||||
|
||||
class DownModelChunkStatus(Enum):
|
||||
success = "success"
|
||||
error = "error"
|
||||
pulling = "pulling"
|
||||
unknown = 'unknown'
|
||||
|
||||
|
||||
class ValidCode(Enum):
|
||||
valid_error = 500
|
||||
model_not_fount = 404
|
||||
|
||||
|
||||
class DownModelChunk:
|
||||
def __init__(self, status: DownModelChunkStatus, digest: str, progress: int, details: str, index: int):
|
||||
self.details = details
|
||||
self.status = status
|
||||
self.digest = digest
|
||||
self.progress = progress
|
||||
self.index = index
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"details": self.details,
|
||||
"status": self.status.value,
|
||||
"digest": self.digest,
|
||||
"progress": self.progress,
|
||||
"index": self.index
|
||||
}
|
||||
|
||||
|
||||
class IModelProvider(ABC):
|
||||
|
||||
|
|
@ -40,6 +72,9 @@ class IModelProvider(ABC):
|
|||
def get_dialogue_number(self):
|
||||
pass
|
||||
|
||||
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
|
||||
raise AppApiException(500, "当前平台不支持下载模型")
|
||||
|
||||
|
||||
class BaseModelCredential(ABC):
|
||||
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@
|
|||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain_community.chat_models import AzureChatOpenAI
|
||||
from langchain.schema import HumanMessage
|
||||
from langchain_community.chat_models import AzureChatOpenAI
|
||||
|
||||
from common import froms
|
||||
from common.exception.app_exception import AppApiException
|
||||
|
|
@ -18,7 +18,7 @@ from common.froms import BaseForm
|
|||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
||||
ModelInfo, \
|
||||
ModelTypeConst
|
||||
ModelTypeConst, ValidCode
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
|
|
@ -27,15 +27,15 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||
model_type_list = AzureModelProvider().get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(500, f'{model_type} 模型类型不支持')
|
||||
raise AppApiException(ValidCode.valid_error, f'{model_type} 模型类型不支持')
|
||||
|
||||
if model_name not in model_dict:
|
||||
raise AppApiException(500, f'{model_name} 模型名称不支持')
|
||||
raise AppApiException(ValidCode.valid_error, f'{model_name} 模型名称不支持')
|
||||
|
||||
for key in ['api_base', 'api_key', 'deployment_name']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(500, f'{key} 字段为必填字段')
|
||||
raise AppApiException(ValidCode.valid_error, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
|
|
@ -45,7 +45,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(500, '校验失败,请检查参数是否正确')
|
||||
raise AppApiException(ValidCode.valid_error, '校验失败,请检查参数是否正确')
|
||||
else:
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -6,9 +6,14 @@
|
|||
@date:2024/3/5 17:23
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from typing import Dict
|
||||
from typing import Dict, Iterator
|
||||
from urllib.parse import urlparse, ParseResult
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from django.http import StreamingHttpResponse
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
|
|
@ -17,29 +22,26 @@ from common.exception.app_exception import AppApiException
|
|||
from common.froms import BaseForm
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
||||
BaseModelCredential
|
||||
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode
|
||||
from setting.models_provider.impl.ollama_model_provider.model.ollama_chat_model import OllamaChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
""
|
||||
|
||||
|
||||
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||
model_type_list = OllamaModelProvider().get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(500, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(500, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
raise AppApiException(ValidCode.valid_error, f'{model_type} 模型类型不支持')
|
||||
try:
|
||||
OllamaModelProvider().get_model(model_type, model_name, model_credential).invoke(
|
||||
[HumanMessage(content='valid')])
|
||||
model_list = OllamaModelProvider.get_base_model_list(model_credential.get('api_base'))
|
||||
except Exception as e:
|
||||
if raise_exception:
|
||||
raise AppApiException(500, "校验失败,请检查 api_key secret_key 是否正确")
|
||||
raise AppApiException(ValidCode.valid_error, "API 域名无效")
|
||||
exist = [model for model in model_list.get('models') if
|
||||
model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
|
||||
if len(exist) == 0:
|
||||
raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model_info: Dict[str, object]):
|
||||
|
|
@ -86,6 +88,52 @@ model_dict = {
|
|||
}
|
||||
|
||||
|
||||
def get_base_url(url: str):
|
||||
parse = urlparse(url)
|
||||
return ParseResult(scheme=parse.scheme, netloc=parse.netloc, path='', params='',
|
||||
query='',
|
||||
fragment='').geturl()
|
||||
|
||||
|
||||
def convert_to_down_model_chunk(row_str: str, chunk_index: int):
|
||||
row = json.loads(row_str)
|
||||
status = DownModelChunkStatus.unknown
|
||||
digest = ""
|
||||
progress = 100
|
||||
if 'status' in row:
|
||||
digest = row.get('status')
|
||||
if row.get('status') == 'success':
|
||||
status = DownModelChunkStatus.success
|
||||
if row.get('status').__contains__("pulling"):
|
||||
status = DownModelChunkStatus.pulling
|
||||
if 'total' in row and 'completed' in row:
|
||||
progress = (row.get('completed') / row.get('total') * 100)
|
||||
elif 'error' in row:
|
||||
status = DownModelChunkStatus.error
|
||||
digest = row.get('error')
|
||||
return DownModelChunk(status=status, digest=digest, progress=progress, details=row_str, index=chunk_index)
|
||||
|
||||
|
||||
def convert(response_stream) -> Iterator[DownModelChunk]:
|
||||
temp = ""
|
||||
index = 0
|
||||
for c in response_stream:
|
||||
index += 1
|
||||
row_content = c.decode()
|
||||
temp += row_content
|
||||
if row_content.endswith('}') or row_content.endswith('\n'):
|
||||
rows = [t for t in temp.split("\n") if len(t) > 0]
|
||||
for row in rows:
|
||||
yield convert_to_down_model_chunk(row, index)
|
||||
temp = ""
|
||||
|
||||
if len(temp) > 0:
|
||||
print(temp)
|
||||
rows = [t for t in temp.split("\n") if len(t) > 0]
|
||||
for row in rows:
|
||||
yield convert_to_down_model_chunk(row, index)
|
||||
|
||||
|
||||
class OllamaModelProvider(IModelProvider):
|
||||
def get_model_provide_info(self):
|
||||
return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content(
|
||||
|
|
@ -113,3 +161,21 @@ class OllamaModelProvider(IModelProvider):
|
|||
|
||||
def get_dialogue_number(self):
|
||||
return 2
|
||||
|
||||
@staticmethod
|
||||
def get_base_model_list(api_base):
|
||||
base_url = get_base_url(api_base)
|
||||
r = requests.request(method="GET", url=f"{base_url}/api/tags")
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
|
||||
api_base = model_credential.get('api_base')
|
||||
base_url = get_base_url(api_base)
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
url=f"{base_url}/api/pull",
|
||||
data=json.dumps({"name": model_name}).encode(),
|
||||
stream=True,
|
||||
)
|
||||
return convert(r)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@
|
|||
@desc:
|
||||
"""
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
|
|
@ -17,10 +19,36 @@ from application.models import Application
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.util.field_message import ErrMessage
|
||||
from common.util.rsa_util import encrypt, decrypt
|
||||
from setting.models.model_management import Model
|
||||
from setting.models.model_management import Model, Status
|
||||
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
|
||||
|
||||
class ModelPullManage:
|
||||
|
||||
@staticmethod
|
||||
def pull(model: Model, credential: Dict):
|
||||
response = ModelProvideConstants[model.provider].value.down_model(model.model_type, model.model_name,
|
||||
credential)
|
||||
down_model_chunk = {}
|
||||
timestamp = time.time()
|
||||
for chunk in response:
|
||||
down_model_chunk[chunk.digest] = chunk.to_dict()
|
||||
if time.time() - timestamp > 5:
|
||||
QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": list(down_model_chunk.values())})
|
||||
timestamp = time.time()
|
||||
status = Status.ERROR
|
||||
message = ""
|
||||
down_model_chunk_list = list(down_model_chunk.values())
|
||||
for chunk in down_model_chunk_list:
|
||||
if chunk.get('status') == DownModelChunkStatus.success.value:
|
||||
status = Status.SUCCESS
|
||||
if chunk.get('status') == DownModelChunkStatus.error.value:
|
||||
message = chunk.get("digest")
|
||||
QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": down_model_chunk_list, "message": message},
|
||||
status=status)
|
||||
|
||||
|
||||
class ModelSerializer(serializers.Serializer):
|
||||
class Query(serializers.Serializer):
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||
|
|
@ -50,7 +78,10 @@ class ModelSerializer(serializers.Serializer):
|
|||
if self.data.get('provider') is not None:
|
||||
query_params['provider'] = self.data.get('provider')
|
||||
|
||||
return [ModelSerializer.model_to_dict(model) for model in model_query_set.filter(**query_params)]
|
||||
return [
|
||||
{'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||
'model_name': model.model_name, 'status': model.status, 'meta': model.meta} for model in
|
||||
model_query_set.filter(**query_params)]
|
||||
|
||||
class Edit(serializers.Serializer):
|
||||
user_id = serializers.CharField(required=False, error_messages=ErrMessage.uuid("用户id"))
|
||||
|
|
@ -88,13 +119,7 @@ class ModelSerializer(serializers.Serializer):
|
|||
for k in source_encryption_model_credential.keys():
|
||||
if credential[k] == source_encryption_model_credential[k]:
|
||||
credential[k] = source_model_credential[k]
|
||||
# 校验模型认证数据
|
||||
model_credential.is_valid(
|
||||
model_type,
|
||||
model_name,
|
||||
credential,
|
||||
raise_exception=True)
|
||||
return credential
|
||||
return credential, model_credential
|
||||
|
||||
class Create(serializers.Serializer):
|
||||
user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||
|
|
@ -124,18 +149,28 @@ class ModelSerializer(serializers.Serializer):
|
|||
raise_exception=True)
|
||||
|
||||
def insert(self, user_id, with_valid=False):
|
||||
status = Status.SUCCESS
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
try:
|
||||
self.is_valid(raise_exception=True)
|
||||
except AppApiException as e:
|
||||
if e.code == ValidCode.model_not_fount:
|
||||
status = Status.DOWNLOAD
|
||||
else:
|
||||
raise e
|
||||
credential = self.data.get('credential')
|
||||
name = self.data.get('name')
|
||||
provider = self.data.get('provider')
|
||||
model_type = self.data.get('model_type')
|
||||
model_name = self.data.get('model_name')
|
||||
model_credential_str = json.dumps(credential)
|
||||
model = Model(id=uuid.uuid1(), user_id=user_id, name=name,
|
||||
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
|
||||
credential=encrypt(model_credential_str),
|
||||
provider=provider, model_type=model_type, model_name=model_name)
|
||||
model.save()
|
||||
if status == Status.DOWNLOAD:
|
||||
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
|
||||
thread.start()
|
||||
return ModelSerializer.Operate(data={'id': model.id, 'user_id': user_id}).one(with_valid=True)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -143,6 +178,8 @@ class ModelSerializer(serializers.Serializer):
|
|||
credential = json.loads(decrypt(model.credential))
|
||||
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||
'model_name': model.model_name,
|
||||
'status': model.status,
|
||||
'meta': model.meta,
|
||||
'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
|
||||
model.model_name).encryption_dict(
|
||||
credential)}
|
||||
|
|
@ -164,6 +201,15 @@ class ModelSerializer(serializers.Serializer):
|
|||
model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
|
||||
return ModelSerializer.model_to_dict(model)
|
||||
|
||||
def one_meta(self, with_valid=False):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
|
||||
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||
'model_name': model.model_name,
|
||||
'status': model.status,
|
||||
'meta': model.meta, }
|
||||
|
||||
def delete(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
|
|
@ -181,7 +227,20 @@ class ModelSerializer(serializers.Serializer):
|
|||
if model is None:
|
||||
raise AppApiException(500, '不存在的id')
|
||||
else:
|
||||
credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(model=model)
|
||||
credential, model_credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(
|
||||
model=model)
|
||||
try:
|
||||
# 校验模型认证数据
|
||||
model_credential.is_valid(
|
||||
model.model_type,
|
||||
instance.get("model_name"),
|
||||
credential,
|
||||
raise_exception=True)
|
||||
except AppApiException as e:
|
||||
if e.code == ValidCode.model_not_fount:
|
||||
model.status = Status.DOWNLOAD
|
||||
else:
|
||||
raise e
|
||||
update_keys = ['credential', 'name', 'model_type', 'model_name']
|
||||
for update_key in update_keys:
|
||||
if update_key in instance and instance.get(update_key) is not None:
|
||||
|
|
@ -191,6 +250,9 @@ class ModelSerializer(serializers.Serializer):
|
|||
else:
|
||||
model.__setattr__(update_key, instance.get(update_key))
|
||||
model.save()
|
||||
if model.status == Status.DOWNLOAD:
|
||||
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
|
||||
thread.start()
|
||||
return self.one(with_valid=False)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ urlpatterns = [
|
|||
name="provider/model_form"),
|
||||
path('model', views.Model.as_view(), name='model'),
|
||||
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
|
||||
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
|
||||
path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting')
|
||||
|
||||
]
|
||||
|
|
|
|||
|
|
@ -34,6 +34,17 @@ class Model(APIView):
|
|||
ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id,
|
||||
with_valid=True))
|
||||
|
||||
@action(methods=['PUT'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="下载模型,只试用与Ollama平台",
|
||||
operation_id="下载模型,只试用与Ollama平台",
|
||||
request_body=ModelCreateApi.get_request_body_api()
|
||||
, tags=["模型"])
|
||||
@has_permissions(PermissionConstants.MODEL_CREATE)
|
||||
def put(self, request: Request):
|
||||
return result.success(
|
||||
ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id,
|
||||
with_valid=True))
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="获取模型列表",
|
||||
operation_id="获取模型列表",
|
||||
|
|
@ -46,6 +57,18 @@ class Model(APIView):
|
|||
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list(
|
||||
with_valid=True))
|
||||
|
||||
class ModelMeta(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['GET'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="查询模型meta信息,该接口不携带认证信息",
|
||||
operation_id="查询模型meta信息,该接口不携带认证信息",
|
||||
tags=["模型"])
|
||||
@has_permissions(PermissionConstants.MODEL_READ)
|
||||
def get(self, request: Request, model_id: str):
|
||||
return result.success(
|
||||
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one_meta(with_valid=True))
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
|
|
|
|||
|
|
@ -106,6 +106,31 @@ const updateModel: (
|
|||
return put(`${prefix}/${model_id}`, request, {}, loading)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取模型详情根据模型id 包括认证信息
|
||||
* @param model_id 模型id
|
||||
* @param loading 加载器
|
||||
* @returns
|
||||
*/
|
||||
const getModelById: (model_id: string, loading?: Ref<boolean>) => Promise<Result<Model>> = (
|
||||
model_id,
|
||||
loading
|
||||
) => {
|
||||
return get(`${prefix}/${model_id}`, {}, loading)
|
||||
}
|
||||
/**
|
||||
* 获取模型信息不包括认证信息根据模型id
|
||||
* @param model_id 模型id
|
||||
* @param loading 加载器
|
||||
* @returns
|
||||
*/
|
||||
const getModelMetaById: (model_id: string, loading?: Ref<boolean>) => Promise<Result<Model>> = (
|
||||
model_id,
|
||||
loading
|
||||
) => {
|
||||
return get(`${prefix}/${model_id}/meta`, {}, loading)
|
||||
}
|
||||
|
||||
const deleteModel: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
|
||||
model_id,
|
||||
loading
|
||||
|
|
@ -120,5 +145,7 @@ export default {
|
|||
listBaseModel,
|
||||
createModel,
|
||||
updateModel,
|
||||
deleteModel
|
||||
deleteModel,
|
||||
getModelById,
|
||||
getModelMetaById
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import { store } from '@/stores'
|
||||
import { Dict } from './common'
|
||||
interface modelRequest {
|
||||
name: string
|
||||
model_type: string
|
||||
|
|
@ -64,6 +65,14 @@ interface Model {
|
|||
* 供应商
|
||||
*/
|
||||
provider: string
|
||||
/**
|
||||
* 状态
|
||||
*/
|
||||
status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR'
|
||||
/**
|
||||
* 元数据
|
||||
*/
|
||||
meta: Dict<any>
|
||||
}
|
||||
interface CreateModelRequest {
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@
|
|||
</template>
|
||||
|
||||
<DynamicsForm
|
||||
v-loading="formLoading"
|
||||
v-model="form_data"
|
||||
:render_data="model_form_field"
|
||||
:model="form_data"
|
||||
|
|
@ -56,7 +57,7 @@
|
|||
@change="getModelForm($event)"
|
||||
v-loading="base_model_loading"
|
||||
style="width: 100%"
|
||||
v-model="form_data.model_name"
|
||||
v-model="base_form_data.model_name"
|
||||
class="m-2"
|
||||
placeholder="请选择基础模型"
|
||||
filterable
|
||||
|
|
@ -90,10 +91,12 @@ import type { FormField } from '@/components/dynamics-form/type'
|
|||
import DynamicsForm from '@/components/dynamics-form/index.vue'
|
||||
import type { FormRules } from 'element-plus'
|
||||
import { MsgSuccess } from '@/utils/message'
|
||||
|
||||
const providerValue = ref<Provider>()
|
||||
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
|
||||
const emit = defineEmits(['change', 'submit'])
|
||||
const loading = ref<boolean>(false)
|
||||
const formLoading = ref<boolean>(false)
|
||||
const model_type_loading = ref<boolean>(false)
|
||||
const base_model_loading = ref<boolean>(false)
|
||||
const model_type_list = ref<Array<KeyValue<string, string>>>([])
|
||||
|
|
@ -152,21 +155,22 @@ const list_base_model = (model_type: any) => {
|
|||
}
|
||||
}
|
||||
const open = (provider: Provider, model: Model) => {
|
||||
modelValue.value = model
|
||||
ModelApi.listModelType(model.provider, model_type_loading).then((ok) => {
|
||||
model_type_list.value = ok.data
|
||||
list_base_model(model.model_type)
|
||||
ModelApi.getModelById(model.id, formLoading).then((ok) => {
|
||||
modelValue.value = ok.data
|
||||
ModelApi.listModelType(model.provider, model_type_loading).then((ok) => {
|
||||
model_type_list.value = ok.data
|
||||
list_base_model(model.model_type)
|
||||
})
|
||||
providerValue.value = provider
|
||||
|
||||
base_form_data.value = {
|
||||
name: model.name,
|
||||
model_type: model.model_type,
|
||||
model_name: model.model_name
|
||||
}
|
||||
form_data.value = model.credential
|
||||
getModelForm(model.model_name)
|
||||
})
|
||||
|
||||
providerValue.value = provider
|
||||
|
||||
base_form_data.value = {
|
||||
name: model.name,
|
||||
model_type: model.model_type,
|
||||
model_name: model.model_name
|
||||
}
|
||||
form_data.value = model.credential
|
||||
getModelForm(model.model_name)
|
||||
dialogVisible.value = true
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -37,15 +37,32 @@
|
|||
<script setup lang="ts">
|
||||
import type { Provider, Model } from '@/api/type/model'
|
||||
import ModelApi from '@/api/model'
|
||||
import { computed, ref } from 'vue'
|
||||
import { computed, ref, onMounted, onBeforeUnmount } from 'vue'
|
||||
import EditModel from '@/views/template/component/EditModel.vue'
|
||||
import { MsgSuccess, MsgConfirm } from '@/utils/message'
|
||||
|
||||
const props = defineProps<{
|
||||
model: Model
|
||||
provider_list: Array<Provider>
|
||||
}>()
|
||||
const downModel = ref<Model>()
|
||||
|
||||
const progress = computed(() => {
|
||||
if (downModel.value) {
|
||||
const down_model_chunk = downModel.value.meta['down_model_chunk']
|
||||
if (down_model_chunk) {
|
||||
const maxObj = down_model_chunk.reduce((prev: any, current: any) => {
|
||||
return (prev.index || 0) > (current.index || 0) ? prev : current
|
||||
})
|
||||
return maxObj.progress
|
||||
}
|
||||
return 0
|
||||
}
|
||||
return 0
|
||||
})
|
||||
const emit = defineEmits(['change'])
|
||||
const eidtModelRef = ref<InstanceType<typeof EditModel>>()
|
||||
let interval: any
|
||||
const deleteModel = () => {
|
||||
MsgConfirm(`删除模型 `, `是否删除模型:${props.model.name} ?`, {
|
||||
confirmButtonText: '删除',
|
||||
|
|
@ -67,6 +84,34 @@ const openEditModel = () => {
|
|||
const icon = computed(() => {
|
||||
return props.provider_list.find((p) => p.provider === props.model.provider)?.icon
|
||||
})
|
||||
|
||||
/**
|
||||
* 初始化轮询
|
||||
*/
|
||||
const initInterval = () => {
|
||||
interval = setInterval(() => {
|
||||
if (props.model.status === 'DOWNLOAD') {
|
||||
ModelApi.getModelMetaById(props.model.id).then((ok) => {
|
||||
downModel.value = ok.data
|
||||
})
|
||||
}
|
||||
}, 6000)
|
||||
}
|
||||
/**
|
||||
* 关闭轮询
|
||||
*/
|
||||
const closeInterval = () => {
|
||||
if (interval) {
|
||||
clearInterval(interval)
|
||||
}
|
||||
}
|
||||
onMounted(() => {
|
||||
initInterval()
|
||||
})
|
||||
onBeforeUnmount(() => {
|
||||
// 清除定时任务
|
||||
closeInterval()
|
||||
})
|
||||
</script>
|
||||
<style lang="scss" scoped>
|
||||
.model-card {
|
||||
|
|
|
|||
Loading…
Reference in New Issue