diff --git a/apps/common/util/lock.py b/apps/common/util/lock.py index e9dff31af..4276f1c65 100644 --- a/apps/common/util/lock.py +++ b/apps/common/util/lock.py @@ -13,13 +13,14 @@ from django.core.cache import caches memory_cache = caches['default'] -def try_lock(key: str): +def try_lock(key: str, timeout=None): """ 获取锁 - :param key: 获取锁 key + :param key: 获取锁 key + :param timeout: lock expiry in seconds; None falls back to one hour :return: 是否获取到锁 """ - return memory_cache.add(key, 'lock', timeout=timedelta(hours=1).total_seconds()) + return memory_cache.add(key, 'lock', timeout=timeout if timeout is not None else timedelta(hours=1).total_seconds()) def un_lock(key: str): diff --git a/apps/setting/migrations/0003_model_meta_model_status.py b/apps/setting/migrations/0003_model_meta_model_status.py new file mode 100644 index 000000000..f4956e880 --- /dev/null +++ b/apps/setting/migrations/0003_model_meta_model_status.py @@ -0,0 +1,23 @@ +# Generated by Django 4.1.13 on 2024-03-22 17:51 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0002_systemsetting'), + ] + + operations = [ + migrations.AddField( + model_name='model', + name='meta', + field=models.JSONField(default=dict, verbose_name='模型元数据,用于存储下载,或者错误信息'), + ), + migrations.AddField( + model_name='model', + name='status', + field=models.CharField(choices=[('SUCCESS', '成功'), ('ERROR', '失败'), ('DOWNLOAD', '下载中')], default='SUCCESS', max_length=20, verbose_name='设置类型'), + ), + ] diff --git a/apps/setting/models/model_management.py b/apps/setting/models/model_management.py index 59d427e41..d97100815 100644 --- a/apps/setting/models/model_management.py +++ b/apps/setting/models/model_management.py @@ -14,6 +14,15 @@ from common.mixins.app_model_mixin import AppModelMixin from users.models import User +class Status(models.TextChoices): + """系统设置类型""" + SUCCESS = "SUCCESS", '成功' + + ERROR = "ERROR", "失败" + + DOWNLOAD = "DOWNLOAD", '下载中' + + class Model(AppModelMixin): """ 模型数据 @@ -22,6 +31,9 
@@ class Model(AppModelMixin): name = models.CharField(max_length=128, verbose_name="名称") + status = models.CharField(max_length=20, verbose_name='设置类型', choices=Status.choices, + default=Status.SUCCESS) + model_type = models.CharField(max_length=128, verbose_name="模型类型") model_name = models.CharField(max_length=128, verbose_name="模型名称") @@ -32,6 +44,8 @@ class Model(AppModelMixin): credential = models.CharField(max_length=5120, verbose_name="模型认证信息") + meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict) + class Meta: db_table = "model" unique_together = ['name', 'user_id'] diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index 04c88b401..3796b5bbe 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -9,10 +9,42 @@ from abc import ABC, abstractmethod from enum import Enum from functools import reduce -from typing import Dict +from typing import Dict, Iterator from langchain.chat_models.base import BaseChatModel +from common.exception.app_exception import AppApiException + + +class DownModelChunkStatus(Enum): + success = "success" + error = "error" + pulling = "pulling" + unknown = 'unknown' + + +class ValidCode(Enum): + valid_error = 500 + model_not_fount = 404 + + +class DownModelChunk: + def __init__(self, status: DownModelChunkStatus, digest: str, progress: int, details: str, index: int): + self.details = details + self.status = status + self.digest = digest + self.progress = progress + self.index = index + + def to_dict(self): + return { + "details": self.details, + "status": self.status.value, + "digest": self.digest, + "progress": self.progress, + "index": self.index + } + class IModelProvider(ABC): @@ -40,6 +72,9 @@ class IModelProvider(ABC): def get_dialogue_number(self): pass + def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]: + 
raise AppApiException(500, "当前平台不支持下载模型") + class BaseModelCredential(ABC): diff --git a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py index 53668e4a9..d3a818add 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py +++ b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py @@ -9,8 +9,8 @@ import os from typing import Dict -from langchain_community.chat_models import AzureChatOpenAI from langchain.schema import HumanMessage +from langchain_community.chat_models import AzureChatOpenAI from common import froms from common.exception.app_exception import AppApiException @@ -18,7 +18,7 @@ from common.froms import BaseForm from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \ ModelInfo, \ - ModelTypeConst + ModelTypeConst, ValidCode from smartdoc.conf import PROJECT_DIR @@ -27,15 +27,15 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): model_type_list = AzureModelProvider().get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(500, f'{model_type} 模型类型不支持') + raise AppApiException(ValidCode.valid_error, f'{model_type} 模型类型不支持') if model_name not in model_dict: - raise AppApiException(500, f'{model_name} 模型名称不支持') + raise AppApiException(ValidCode.valid_error, f'{model_name} 模型名称不支持') for key in ['api_base', 'api_key', 'deployment_name']: if key not in model_credential: if raise_exception: - raise AppApiException(500, f'{key} 字段为必填字段') + raise AppApiException(ValidCode.valid_error, f'{key} 字段为必填字段') else: return False try: @@ -45,7 +45,7 @@ class AzureLLMModelCredential(BaseForm, 
BaseModelCredential): if isinstance(e, AppApiException): raise e if raise_exception: - raise AppApiException(500, '校验失败,请检查参数是否正确') + raise AppApiException(ValidCode.valid_error, '校验失败,请检查参数是否正确') else: return False diff --git a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py index f5fc46f6e..f965727ce 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py @@ -6,9 +6,14 @@ @date:2024/3/5 17:23 @desc: """ +import json import os -from typing import Dict +from typing import Dict, Iterator +from urllib.parse import urlparse, ParseResult +import aiohttp +import requests +from django.http import StreamingHttpResponse from langchain.chat_models.base import BaseChatModel from langchain.schema import HumanMessage @@ -17,29 +22,26 @@ from common.exception.app_exception import AppApiException from common.froms import BaseForm from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ - BaseModelCredential + BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode from setting.models_provider.impl.ollama_model_provider.model.ollama_chat_model import OllamaChatModel from smartdoc.conf import PROJECT_DIR +"" + class OllamaLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): model_type_list = OllamaModelProvider().get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(500, f'{model_type} 模型类型不支持') - - for key in ['api_key']: - if key not in model_credential: - if raise_exception: - raise AppApiException(500, f'{key} 字段为必填字段') - else: - return 
False + raise AppApiException(ValidCode.valid_error, f'{model_type} 模型类型不支持') try: - OllamaModelProvider().get_model(model_type, model_name, model_credential).invoke( - [HumanMessage(content='valid')]) + model_list = OllamaModelProvider.get_base_model_list(model_credential.get('api_base')) except Exception as e: - if raise_exception: - raise AppApiException(500, "校验失败,请检查 api_key secret_key 是否正确") + raise AppApiException(ValidCode.valid_error, "API 域名无效") + exist = [model for model in model_list.get('models') if + model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name] + if len(exist) == 0: + raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型") return True def encryption_dict(self, model_info: Dict[str, object]): @@ -86,6 +88,51 @@ model_dict = { } +def get_base_url(url: str): + parse = urlparse(url) + return ParseResult(scheme=parse.scheme, netloc=parse.netloc, path='', params='', + query='', + fragment='').geturl() + + +def convert_to_down_model_chunk(row_str: str, chunk_index: int): + row = json.loads(row_str) + status = DownModelChunkStatus.unknown + digest = "" + progress = 100 + if 'status' in row: + digest = row.get('status') + if row.get('status') == 'success': + status = DownModelChunkStatus.success + if "pulling" in row.get('status'): + status = DownModelChunkStatus.pulling + if row.get('total') and 'completed' in row: + progress = (row.get('completed') / row.get('total') * 100) + elif 'error' in row: + status = DownModelChunkStatus.error + digest = row.get('error') + return DownModelChunk(status=status, digest=digest, progress=progress, details=row_str, index=chunk_index) + + +def convert(response_stream) -> Iterator[DownModelChunk]: + temp = "" + index = 0 + for c in response_stream: + index += 1 + row_content = c.decode() + temp += row_content + if row_content.endswith('}') or row_content.endswith('\n'): + rows = [t for t in temp.split("\n") if len(t) > 0] + for row in rows: + yield 
convert_to_down_model_chunk(row, index) + temp = "" + + if len(temp) > 0: + rows = [t for t in temp.split("\n") if len(t) > 0] + for row in rows: + yield convert_to_down_model_chunk(row, index) + + class OllamaModelProvider(IModelProvider): def get_model_provide_info(self): return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content( @@ -113,3 +160,21 @@ class OllamaModelProvider(IModelProvider): def get_dialogue_number(self): return 2 + + @staticmethod + def get_base_model_list(api_base): + base_url = get_base_url(api_base) + r = requests.request(method="GET", url=f"{base_url}/api/tags") + r.raise_for_status() + return r.json() + + def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]: + api_base = model_credential.get('api_base') + base_url = get_base_url(api_base) + r = requests.request( + method="POST", + url=f"{base_url}/api/pull", + data=json.dumps({"name": model_name}).encode(), + stream=True, + ) + return convert(r) diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index 9a405e28d..e744b0c6b 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -7,6 +7,8 @@ @desc: """ import json +import threading +import time import uuid from typing import Dict @@ -17,10 +19,36 @@ from application.models import Application from common.exception.app_exception import AppApiException from common.util.field_message import ErrMessage from common.util.rsa_util import encrypt, decrypt -from setting.models.model_management import Model +from setting.models.model_management import Model, Status +from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus from setting.models_provider.constants.model_provider_constants import ModelProvideConstants +class ModelPullManage: + + @staticmethod + def pull(model: Model, 
credential: Dict): + response = ModelProvideConstants[model.provider].value.down_model(model.model_type, model.model_name, + credential) + down_model_chunk = {} + timestamp = time.time() + for chunk in response: + down_model_chunk[chunk.digest] = chunk.to_dict() + if time.time() - timestamp > 5: + QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": list(down_model_chunk.values())}) + timestamp = time.time() + status = Status.ERROR + message = "" + down_model_chunk_list = list(down_model_chunk.values()) + for chunk in down_model_chunk_list: + if chunk.get('status') == DownModelChunkStatus.success.value: + status = Status.SUCCESS + if chunk.get('status') == DownModelChunkStatus.error.value: + message = chunk.get("digest") + QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": down_model_chunk_list, "message": message}, + status=status) + + class ModelSerializer(serializers.Serializer): class Query(serializers.Serializer): user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) @@ -50,7 +78,10 @@ class ModelSerializer(serializers.Serializer): if self.data.get('provider') is not None: query_params['provider'] = self.data.get('provider') - return [ModelSerializer.model_to_dict(model) for model in model_query_set.filter(**query_params)] + return [ + {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, + 'model_name': model.model_name, 'status': model.status, 'meta': model.meta} for model in + model_query_set.filter(**query_params)] class Edit(serializers.Serializer): user_id = serializers.CharField(required=False, error_messages=ErrMessage.uuid("用户id")) @@ -88,13 +119,7 @@ class ModelSerializer(serializers.Serializer): for k in source_encryption_model_credential.keys(): if credential[k] == source_encryption_model_credential[k]: credential[k] = source_model_credential[k] - # 校验模型认证数据 - model_credential.is_valid( - model_type, - model_name, - credential, - 
raise_exception=True) - return credential + return credential, model_credential class Create(serializers.Serializer): user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id")) @@ -124,18 +149,28 @@ class ModelSerializer(serializers.Serializer): raise_exception=True) def insert(self, user_id, with_valid=False): + status = Status.SUCCESS if with_valid: - self.is_valid(raise_exception=True) + try: + self.is_valid(raise_exception=True) + except AppApiException as e: + if e.code == ValidCode.model_not_fount: + status = Status.DOWNLOAD + else: + raise e credential = self.data.get('credential') name = self.data.get('name') provider = self.data.get('provider') model_type = self.data.get('model_type') model_name = self.data.get('model_name') model_credential_str = json.dumps(credential) - model = Model(id=uuid.uuid1(), user_id=user_id, name=name, + model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, credential=encrypt(model_credential_str), provider=provider, model_type=model_type, model_name=model_name) model.save() + if status == Status.DOWNLOAD: + thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential)) + thread.start() return ModelSerializer.Operate(data={'id': model.id, 'user_id': user_id}).one(with_valid=True) @staticmethod @@ -143,6 +178,8 @@ class ModelSerializer(serializers.Serializer): credential = json.loads(decrypt(model.credential)) return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, 'model_name': model.model_name, + 'status': model.status, + 'meta': model.meta, 'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type, model.model_name).encryption_dict( credential)} @@ -164,6 +201,15 @@ class ModelSerializer(serializers.Serializer): model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id')) return ModelSerializer.model_to_dict(model) + def one_meta(self, 
with_valid=False): + if with_valid: + self.is_valid(raise_exception=True) + model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id')) + return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, + 'model_name': model.model_name, + 'status': model.status, + 'meta': model.meta, } + def delete(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) @@ -181,7 +227,20 @@ class ModelSerializer(serializers.Serializer): if model is None: raise AppApiException(500, '不存在的id') else: - credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(model=model) + credential, model_credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid( + model=model) + try: + # 校验模型认证数据 + model_credential.is_valid( + model.model_type, + instance.get("model_name") or model.model_name, + credential, + raise_exception=True) + except AppApiException as e: + if e.code == ValidCode.model_not_fount: + model.status = Status.DOWNLOAD + else: + raise e update_keys = ['credential', 'name', 'model_type', 'model_name'] for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: @@ -191,6 +250,9 @@ class ModelSerializer(serializers.Serializer): else: model.__setattr__(update_key, instance.get(update_key)) model.save() + if model.status == Status.DOWNLOAD: + thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential)) + thread.start() return self.one(with_valid=False) diff --git a/apps/setting/urls.py b/apps/setting/urls.py index 460facb34..42fea74ec 100644 --- a/apps/setting/urls.py +++ b/apps/setting/urls.py @@ -16,6 +16,7 @@ urlpatterns = [ name="provider/model_form"), path('model', views.Model.as_view(), name='model'), path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'), + path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'), path('email_setting', views.SystemSetting.Email.as_view(), 
name='email_setting') ] diff --git a/apps/setting/views/model.py b/apps/setting/views/model.py index 330834efd..7ba0304fc 100644 --- a/apps/setting/views/model.py +++ b/apps/setting/views/model.py @@ -34,6 +34,17 @@ class Model(APIView): ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id, with_valid=True)) + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="下载模型,只试用与Ollama平台", + operation_id="下载模型,只试用与Ollama平台", + request_body=ModelCreateApi.get_request_body_api() + , tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_CREATE) + def put(self, request: Request): + return result.success( + ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id, + with_valid=True)) + @action(methods=['GET'], detail=False) @swagger_auto_schema(operation_summary="获取模型列表", operation_id="获取模型列表", @@ -46,6 +57,18 @@ class Model(APIView): data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list( with_valid=True)) + class ModelMeta(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="查询模型meta信息,该接口不携带认证信息", + operation_id="查询模型meta信息,该接口不携带认证信息", + tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_READ) + def get(self, request: Request, model_id: str): + return result.success( + ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one_meta(with_valid=True)) + class Operate(APIView): authentication_classes = [TokenAuth] diff --git a/ui/src/api/model.ts b/ui/src/api/model.ts index 265582117..bb98984f8 100644 --- a/ui/src/api/model.ts +++ b/ui/src/api/model.ts @@ -106,6 +106,31 @@ const updateModel: ( return put(`${prefix}/${model_id}`, request, {}, loading) } +/** + * 获取模型详情根据模型id 包括认证信息 + * @param model_id 模型id + * @param loading 加载器 + * @returns + */ +const getModelById: (model_id: string, loading?: Ref) => 
Promise> = ( + model_id, + loading +) => { + return get(`${prefix}/${model_id}`, {}, loading) +} +/** + * 获取模型信息不包括认证信息根据模型id + * @param model_id 模型id + * @param loading 加载器 + * @returns + */ +const getModelMetaById: (model_id: string, loading?: Ref) => Promise> = ( + model_id, + loading +) => { + return get(`${prefix}/${model_id}/meta`, {}, loading) +} + const deleteModel: (model_id: string, loading?: Ref) => Promise> = ( model_id, loading @@ -120,5 +145,7 @@ export default { listBaseModel, createModel, updateModel, - deleteModel + deleteModel, + getModelById, + getModelMetaById } diff --git a/ui/src/api/type/model.ts b/ui/src/api/type/model.ts index 167154ada..e9a5203ed 100644 --- a/ui/src/api/type/model.ts +++ b/ui/src/api/type/model.ts @@ -1,4 +1,5 @@ import { store } from '@/stores' +import { Dict } from './common' interface modelRequest { name: string model_type: string @@ -64,6 +65,14 @@ interface Model { * 供应商 */ provider: string + /** + * 状态 + */ + status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR' + /** + * 元数据 + */ + meta: Dict } interface CreateModelRequest { /** diff --git a/ui/src/views/template/component/EditModel.vue b/ui/src/views/template/component/EditModel.vue index 02597da1f..a1d670e4c 100644 --- a/ui/src/views/template/component/EditModel.vue +++ b/ui/src/views/template/component/EditModel.vue @@ -18,6 +18,7 @@ () const dynamicsFormRef = ref>() const emit = defineEmits(['change', 'submit']) const loading = ref(false) +const formLoading = ref(false) const model_type_loading = ref(false) const base_model_loading = ref(false) const model_type_list = ref>>([]) @@ -152,21 +155,22 @@ const list_base_model = (model_type: any) => { } } const open = (provider: Provider, model: Model) => { - modelValue.value = model - ModelApi.listModelType(model.provider, model_type_loading).then((ok) => { - model_type_list.value = ok.data - list_base_model(model.model_type) + ModelApi.getModelById(model.id, formLoading).then((ok) => { + modelValue.value = ok.data + 
ModelApi.listModelType(model.provider, model_type_loading).then((ok) => { + model_type_list.value = ok.data + list_base_model(model.model_type) + }) + providerValue.value = provider + + base_form_data.value = { + name: model.name, + model_type: model.model_type, + model_name: model.model_name + } + form_data.value = ok.data.credential + getModelForm(model.model_name) }) - - providerValue.value = provider - - base_form_data.value = { - name: model.name, - model_type: model.model_type, - model_name: model.model_name - } - form_data.value = model.credential - getModelForm(model.model_name) dialogVisible.value = true } diff --git a/ui/src/views/template/component/ModelCard.vue b/ui/src/views/template/component/ModelCard.vue index ea46dc928..8694c7894 100644 --- a/ui/src/views/template/component/ModelCard.vue +++ b/ui/src/views/template/component/ModelCard.vue @@ -37,15 +37,32 @@