feat: ollama支持下载模型

This commit is contained in:
shaohuzhang1 2024-03-22 17:56:56 +08:00
parent bdf5edc203
commit d074424398
13 changed files with 363 additions and 53 deletions

View File

@ -13,13 +13,14 @@ from django.core.cache import caches
memory_cache = caches['default']
def try_lock(key: str, timeout=None):
    """
    Try to acquire the lock identified by ``key``.

    The lock is a cache entry written with the cache's ``add`` call, whose
    result is returned directly as the "acquired" flag.

    :param key: cache key identifying the lock
    :param timeout: lock lifetime in seconds; one hour is used when omitted
    :return: whether the lock was acquired
    """
    # Use the caller-supplied lifetime when given; otherwise fall back to one
    # hour so an abandoned lock cannot live forever.
    lock_timeout = timeout if timeout is not None else timedelta(hours=1).total_seconds()
    return memory_cache.add(key, 'lock', timeout=lock_timeout)
def un_lock(key: str):

View File

@ -0,0 +1,23 @@
# Generated by Django 4.1.13 on 2024-03-22 17:51
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add the ``meta`` and ``status`` fields to the ``model`` model."""

    dependencies = [
        ('setting', '0002_systemsetting'),
    ]

    operations = [
        # JSON blob holding download progress or error information.
        migrations.AddField(
            model_name='model',
            name='meta',
            field=models.JSONField(
                default=dict,
                verbose_name='模型元数据,用于存储下载,或者错误信息',
            ),
        ),
        # Model state; existing rows default to SUCCESS.
        migrations.AddField(
            model_name='model',
            name='status',
            field=models.CharField(
                choices=[('SUCCESS', '成功'), ('ERROR', '失败'), ('DOWNLOAD', '下载中')],
                default='SUCCESS',
                max_length=20,
                verbose_name='设置类型',
            ),
        ),
    ]

View File

@ -14,6 +14,15 @@ from common.mixins.app_model_mixin import AppModelMixin
from users.models import User
class Status(models.TextChoices):
    """State of a configured model."""
    # Label: "success"
    SUCCESS = "SUCCESS", '成功'
    # Label: "failed"
    ERROR = "ERROR", "失败"
    # Label: "downloading"
    DOWNLOAD = "DOWNLOAD", '下载中'
class Model(AppModelMixin):
"""
模型数据
@ -22,6 +31,9 @@ class Model(AppModelMixin):
name = models.CharField(max_length=128, verbose_name="名称")
status = models.CharField(max_length=20, verbose_name='设置类型', choices=Status.choices,
default=Status.SUCCESS)
model_type = models.CharField(max_length=128, verbose_name="模型类型")
model_name = models.CharField(max_length=128, verbose_name="模型名称")
@ -32,6 +44,8 @@ class Model(AppModelMixin):
credential = models.CharField(max_length=5120, verbose_name="模型认证信息")
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)
class Meta:
db_table = "model"
unique_together = ['name', 'user_id']

View File

@ -9,10 +9,42 @@
from abc import ABC, abstractmethod
from enum import Enum
from functools import reduce
from typing import Dict
from typing import Dict, Iterator
from langchain.chat_models.base import BaseChatModel
from common.exception.app_exception import AppApiException
class DownModelChunkStatus(Enum):
    """Status attached to a single parsed row of a model-download stream."""
    success = "success"
    error = "error"
    pulling = "pulling"
    # Used when a row carries neither a recognised status nor an error.
    unknown = 'unknown'
class ValidCode(int, Enum):
    """
    Error codes raised while validating model credentials.

    Members are passed directly as the ``code`` of ``AppApiException``, where
    plain integers are used elsewhere. Mixing in ``int`` makes each member a
    real integer, so it compares equal to its numeric value and serializes as
    a number, while identity comparisons between members keep working.
    """
    # Generic validation failure.
    valid_error = 500
    # The requested model does not exist on the platform yet.
    # NOTE: the misspelled name is kept because callers reference it.
    model_not_fount = 404
class DownModelChunk:
    """One progress record parsed from a model-download stream."""

    def __init__(self, status: DownModelChunkStatus, digest: str, progress: int, details: str, index: int):
        """
        :param status: parsed status of the row
        :param digest: short description of the row (status text or error text)
        :param progress: completion percentage
        :param details: the raw row text the record was parsed from
        :param index: counter of the stream chunk the row was read from
        """
        self.status = status
        self.digest = digest
        self.progress = progress
        self.details = details
        self.index = index

    def to_dict(self):
        """Return the record as a plain dict, with the status reduced to its enum value."""
        serialized = dict(
            details=self.details,
            status=self.status.value,
            digest=self.digest,
            progress=self.progress,
            index=self.index,
        )
        return serialized
class IModelProvider(ABC):
@ -40,6 +72,9 @@ class IModelProvider(ABC):
def get_dialogue_number(self):
pass
def down_model(
        self,
        model_type: str,
        model_name,
        model_credential: Dict[str, object],
) -> Iterator[DownModelChunk]:
    """
    Download a model on the provider's platform.

    This base implementation always fails; a provider that can download
    models supplies its own version.

    :param model_type: type of the model to download
    :param model_name: name of the model to download
    :param model_credential: credential data for the platform
    :raises AppApiException: always, with code 500
    """
    raise AppApiException(500, "当前平台不支持下载模型")
class BaseModelCredential(ABC):

View File

@ -9,8 +9,8 @@
import os
from typing import Dict
from langchain_community.chat_models import AzureChatOpenAI
from langchain.schema import HumanMessage
from langchain_community.chat_models import AzureChatOpenAI
from common import froms
from common.exception.app_exception import AppApiException
@ -18,7 +18,7 @@ from common.froms import BaseForm
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
ModelInfo, \
ModelTypeConst
ModelTypeConst, ValidCode
from smartdoc.conf import PROJECT_DIR
@ -27,15 +27,15 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = AzureModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(500, f'{model_type} 模型类型不支持')
raise AppApiException(ValidCode.valid_error, f'{model_type} 模型类型不支持')
if model_name not in model_dict:
raise AppApiException(500, f'{model_name} 模型名称不支持')
raise AppApiException(ValidCode.valid_error, f'{model_name} 模型名称不支持')
for key in ['api_base', 'api_key', 'deployment_name']:
if key not in model_credential:
if raise_exception:
raise AppApiException(500, f'{key} 字段为必填字段')
raise AppApiException(ValidCode.valid_error, f'{key} 字段为必填字段')
else:
return False
try:
@ -45,7 +45,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(500, '校验失败,请检查参数是否正确')
raise AppApiException(ValidCode.valid_error, '校验失败,请检查参数是否正确')
else:
return False

View File

@ -6,9 +6,14 @@
@date2024/3/5 17:23
@desc:
"""
import json
import os
from typing import Dict
from typing import Dict, Iterator
from urllib.parse import urlparse, ParseResult
import aiohttp
import requests
from django.http import StreamingHttpResponse
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage
@ -17,29 +22,26 @@ from common.exception.app_exception import AppApiException
from common.froms import BaseForm
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
BaseModelCredential
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode
from setting.models_provider.impl.ollama_model_provider.model.ollama_chat_model import OllamaChatModel
from smartdoc.conf import PROJECT_DIR
""
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = OllamaModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(500, f'{model_type} 模型类型不支持')
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(500, f'{key} 字段为必填字段')
else:
return False
raise AppApiException(ValidCode.valid_error, f'{model_type} 模型类型不支持')
try:
OllamaModelProvider().get_model(model_type, model_name, model_credential).invoke(
[HumanMessage(content='valid')])
model_list = OllamaModelProvider.get_base_model_list(model_credential.get('api_base'))
except Exception as e:
if raise_exception:
raise AppApiException(500, "校验失败,请检查 api_key secret_key 是否正确")
raise AppApiException(ValidCode.valid_error, "API 域名无效")
exist = [model for model in model_list.get('models') if
model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
if len(exist) == 0:
raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
return True
def encryption_dict(self, model_info: Dict[str, object]):
@ -86,6 +88,52 @@ model_dict = {
}
def get_base_url(url: str):
    """
    Reduce ``url`` to its scheme and network location.

    Path, params, query string and fragment are all dropped.

    :param url: URL to reduce
    :return: ``scheme://netloc`` rebuilt from the parsed URL
    """
    parsed = urlparse(url)
    origin_only = ParseResult(parsed.scheme, parsed.netloc, '', '', '', '')
    return origin_only.geturl()
def convert_to_down_model_chunk(row_str: str, chunk_index: int):
    """
    Parse one JSON row of a download stream into a ``DownModelChunk``.

    :param row_str: a single JSON document as text
    :param chunk_index: counter of the stream chunk the row came from
    :return: the parsed ``DownModelChunk``; ``details`` keeps the raw row text
    :raises ValueError: if ``row_str`` is not valid JSON (``json.JSONDecodeError``)
    """
    row = json.loads(row_str)
    status = DownModelChunkStatus.unknown
    digest = ""
    progress = 100
    if 'status' in row:
        row_status = row.get('status')
        digest = row_status
        if row_status == 'success':
            status = DownModelChunkStatus.success
        if isinstance(row_status, str) and "pulling" in row_status:
            status = DownModelChunkStatus.pulling
            # A pulling row without usable totals has not made measurable
            # progress yet, so report 0 rather than the default 100.
            progress = 0
            total = row.get('total')
            completed = row.get('completed')
            # Skip the division when total is missing or 0.
            if total and completed is not None:
                progress = completed / total * 100
    elif 'error' in row:
        status = DownModelChunkStatus.error
        digest = row.get('error')
    return DownModelChunk(status=status, digest=digest, progress=progress, details=row_str, index=chunk_index)
def convert(response_stream) -> Iterator[DownModelChunk]:
    """
    Turn a streamed download response into ``DownModelChunk`` records.

    Bytes are buffered and only decoded once a whole row is available, so a
    row (or a multi-byte character) split across two stream chunks is
    reassembled before parsing.

    :param response_stream: iterable yielding the response body as bytes
    :return: iterator of parsed chunks; ``index`` is the count of stream
        chunks read when the row was completed
    """
    pending = b""
    index = 0
    for raw in response_stream:
        index += 1
        pending += raw
        # Everything before the last newline is complete; the tail may be partial.
        *complete_rows, pending = pending.split(b"\n")
        for row in complete_rows:
            if row.strip():
                yield convert_to_down_model_chunk(row.decode(), index)
        # A tail ending in '}' is usually a finished row without its newline.
        # Emit it right away if it parses; otherwise keep buffering.
        if pending.rstrip().endswith(b"}"):
            try:
                chunk = convert_to_down_model_chunk(pending.decode(), index)
            except ValueError:
                # Incomplete JSON or a split multi-byte character.
                continue
            pending = b""
            yield chunk
    # Flush whatever is left once the stream ends.
    if pending.strip():
        yield convert_to_down_model_chunk(pending.decode(), index)
class OllamaModelProvider(IModelProvider):
def get_model_provide_info(self):
return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content(
@ -113,3 +161,21 @@ class OllamaModelProvider(IModelProvider):
def get_dialogue_number(self):
    """Return this provider's dialogue number, fixed at 2."""
    dialogue_number = 2
    return dialogue_number
@staticmethod
def get_base_model_list(api_base):
    """
    Fetch the model listing from the server behind ``api_base``.

    :param api_base: any URL on the target server; only scheme and host are used
    :return: the decoded JSON body of ``GET {base_url}/api/tags``
    :raises requests.RequestException: on connection failure, timeout or an
        error status code
    """
    base_url = get_base_url(api_base)
    # Bound the wait so an unreachable host cannot block the caller forever.
    r = requests.get(f"{base_url}/api/tags", timeout=10)
    r.raise_for_status()
    return r.json()
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
    """
    Request a model pull from the server and stream its progress.

    :param model_type: unused here
    :param model_name: name of the model to pull
    :param model_credential: credential data; ``api_base`` locates the server
    :return: iterator of ``DownModelChunk`` parsed from the streamed response
    """
    base_url = get_base_url(model_credential.get('api_base'))
    payload = json.dumps({"name": model_name}).encode()
    # stream=True leaves the body unread so it can be consumed incrementally.
    pull_response = requests.post(f"{base_url}/api/pull", data=payload, stream=True)
    return convert(pull_response)

View File

@ -7,6 +7,8 @@
@desc:
"""
import json
import threading
import time
import uuid
from typing import Dict
@ -17,10 +19,36 @@ from application.models import Application
from common.exception.app_exception import AppApiException
from common.util.field_message import ErrMessage
from common.util.rsa_util import encrypt, decrypt
from setting.models.model_management import Model
from setting.models.model_management import Model, Status
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
class ModelPullManage:
    """Runs a provider model download and records its progress on the ``Model`` row."""

    @staticmethod
    def pull(model: Model, credential: Dict):
        """
        Download ``model`` through its provider and persist progress and outcome.

        While the download streams, the latest chunk per digest is written to
        ``meta`` at most once every 5 seconds. When the stream ends, the final
        chunk list, an error message and the resulting status are stored.

        :param model: the model row being downloaded
        :param credential: decrypted credential data passed to the provider
        """
        # Latest chunk per digest; later rows for a digest replace earlier ones.
        down_model_chunk = {}
        try:
            response = ModelProvideConstants[model.provider].value.down_model(model.model_type, model.model_name,
                                                                              credential)
            timestamp = time.time()
            for chunk in response:
                down_model_chunk[chunk.digest] = chunk.to_dict()
                # Throttle database writes to one every 5 seconds.
                if time.time() - timestamp > 5:
                    QuerySet(Model).filter(id=model.id).update(
                        meta={"down_model_chunk": list(down_model_chunk.values())})
                    timestamp = time.time()
        except Exception as e:
            # Callers start this in a thread, so nothing upstream sees the
            # failure. Record it so the model does not stay in DOWNLOAD forever.
            QuerySet(Model).filter(id=model.id).update(
                meta={"down_model_chunk": list(down_model_chunk.values()), "message": str(e)},
                status=Status.ERROR)
            return
        # The download counts as failed unless a success chunk was seen.
        status = Status.ERROR
        message = ""
        down_model_chunk_list = list(down_model_chunk.values())
        for chunk in down_model_chunk_list:
            if chunk.get('status') == DownModelChunkStatus.success.value:
                status = Status.SUCCESS
            if chunk.get('status') == DownModelChunkStatus.error.value:
                message = chunk.get("digest")
        QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": down_model_chunk_list, "message": message},
                                                   status=status)
class ModelSerializer(serializers.Serializer):
class Query(serializers.Serializer):
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
@ -50,7 +78,10 @@ class ModelSerializer(serializers.Serializer):
if self.data.get('provider') is not None:
query_params['provider'] = self.data.get('provider')
return [ModelSerializer.model_to_dict(model) for model in model_query_set.filter(**query_params)]
return [
{'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name, 'status': model.status, 'meta': model.meta} for model in
model_query_set.filter(**query_params)]
class Edit(serializers.Serializer):
user_id = serializers.CharField(required=False, error_messages=ErrMessage.uuid("用户id"))
@ -88,13 +119,7 @@ class ModelSerializer(serializers.Serializer):
for k in source_encryption_model_credential.keys():
if credential[k] == source_encryption_model_credential[k]:
credential[k] = source_model_credential[k]
# 校验模型认证数据
model_credential.is_valid(
model_type,
model_name,
credential,
raise_exception=True)
return credential
return credential, model_credential
class Create(serializers.Serializer):
user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id"))
@ -124,18 +149,28 @@ class ModelSerializer(serializers.Serializer):
raise_exception=True)
def insert(self, user_id, with_valid=False):
    """
    Create and save a model record for ``user_id``.

    If validation fails only because the model is not present on the
    platform, the record is saved in DOWNLOAD state and the download is
    started in a background thread.

    :param user_id: owner of the new model
    :param with_valid: validate the serializer data before inserting
    :return: the saved model as returned by ``ModelSerializer.Operate.one``
    :raises AppApiException: for any validation error other than "model not found"
    """
    status = Status.SUCCESS
    if with_valid:
        try:
            self.is_valid(raise_exception=True)
        except AppApiException as e:
            if e.code != ValidCode.model_not_fount:
                raise e
            # A missing model is not fatal: it gets downloaded after saving.
            status = Status.DOWNLOAD
    credential = self.data.get('credential')
    model = Model(
        id=uuid.uuid1(),
        status=status,
        user_id=user_id,
        name=self.data.get('name'),
        credential=encrypt(json.dumps(credential)),
        provider=self.data.get('provider'),
        model_type=self.data.get('model_type'),
        model_name=self.data.get('model_name'),
    )
    model.save()
    if status == Status.DOWNLOAD:
        threading.Thread(target=ModelPullManage.pull, args=(model, credential)).start()
    return ModelSerializer.Operate(data={'id': model.id, 'user_id': user_id}).one(with_valid=True)
@staticmethod
@ -143,6 +178,8 @@ class ModelSerializer(serializers.Serializer):
credential = json.loads(decrypt(model.credential))
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name,
'status': model.status,
'meta': model.meta,
'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
model.model_name).encryption_dict(
credential)}
@ -164,6 +201,15 @@ class ModelSerializer(serializers.Serializer):
model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
return ModelSerializer.model_to_dict(model)
def one_meta(self, with_valid=False):
    """
    Return the model's descriptive fields, status and meta, without its credential.

    :param with_valid: validate the serializer data before querying
    :return: dict with id, provider, name, model_type, model_name, status and meta
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
    plain_fields = ('provider', 'name', 'model_type', 'model_name', 'status', 'meta')
    # The id is stringified; every other field is copied as stored.
    return {'id': str(model.id), **{field: getattr(model, field) for field in plain_fields}}
def delete(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
@ -181,7 +227,20 @@ class ModelSerializer(serializers.Serializer):
if model is None:
raise AppApiException(500, '不存在的id')
else:
credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(model=model)
credential, model_credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(
model=model)
try:
# 校验模型认证数据
model_credential.is_valid(
model.model_type,
instance.get("model_name"),
credential,
raise_exception=True)
except AppApiException as e:
if e.code == ValidCode.model_not_fount:
model.status = Status.DOWNLOAD
else:
raise e
update_keys = ['credential', 'name', 'model_type', 'model_name']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
@ -191,6 +250,9 @@ class ModelSerializer(serializers.Serializer):
else:
model.__setattr__(update_key, instance.get(update_key))
model.save()
if model.status == Status.DOWNLOAD:
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
thread.start()
return self.one(with_valid=False)

View File

@ -16,6 +16,7 @@ urlpatterns = [
name="provider/model_form"),
path('model', views.Model.as_view(), name='model'),
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting')
]

View File

@ -34,6 +34,17 @@ class Model(APIView):
ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id,
with_valid=True))
@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="下载模型,只试用与Ollama平台",
                     operation_id="下载模型,只试用与Ollama平台",
                     request_body=ModelCreateApi.get_request_body_api(),
                     tags=["模型"])
@has_permissions(PermissionConstants.MODEL_CREATE)
def put(self, request: Request):
    """Create a model for the requesting user through ``ModelSerializer.Create.insert``."""
    payload = {**request.data, 'user_id': str(request.user.id)}
    created = ModelSerializer.Create(data=payload).insert(request.user.id, with_valid=True)
    return result.success(created)
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取模型列表",
operation_id="获取模型列表",
@ -46,6 +57,18 @@ class Model(APIView):
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list(
with_valid=True))
class ModelMeta(APIView):
    """Endpoint returning a model's status and meta without credential data."""
    authentication_classes = [TokenAuth]

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="查询模型meta信息,该接口不携带认证信息",
                         operation_id="查询模型meta信息,该接口不携带认证信息",
                         tags=["模型"])
    @has_permissions(PermissionConstants.MODEL_READ)
    def get(self, request: Request, model_id: str):
        """Look up the requesting user's model by ``model_id`` and return its meta view."""
        lookup = {'id': model_id, 'user_id': request.user.id}
        return result.success(ModelSerializer.Operate(data=lookup).one_meta(with_valid=True))
class Operate(APIView):
authentication_classes = [TokenAuth]

View File

@ -106,6 +106,31 @@ const updateModel: (
return put(`${prefix}/${model_id}`, request, {}, loading)
}
/**
 * Fetch a single model by its id.
 * @param model_id id of the model to fetch
 * @param loading optional loading flag handed to the request helper
 * @returns promise resolving to the request result that wraps the model
 */
const getModelById: (model_id: string, loading?: Ref<boolean>) => Promise<Result<Model>> = (
  model_id,
  loading
) => get(`${prefix}/${model_id}`, {}, loading)
/**
 * Fetch the meta view of a single model from its `/meta` endpoint.
 * @param model_id id of the model to fetch
 * @param loading optional loading flag handed to the request helper
 * @returns promise resolving to the request result that wraps the model
 */
const getModelMetaById: (model_id: string, loading?: Ref<boolean>) => Promise<Result<Model>> = (
  model_id,
  loading
) => get(`${prefix}/${model_id}/meta`, {}, loading)
const deleteModel: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
model_id,
loading
@ -120,5 +145,7 @@ export default {
listBaseModel,
createModel,
updateModel,
deleteModel
deleteModel,
getModelById,
getModelMetaById
}

View File

@ -1,4 +1,5 @@
import { store } from '@/stores'
import { Dict } from './common'
interface modelRequest {
name: string
model_type: string
@ -64,6 +65,14 @@ interface Model {
*
*/
provider: string
/**
 * Model state: 'SUCCESS', 'DOWNLOAD' or 'ERROR'
 */
status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR'
/**
 * Model metadata; holds download progress or error information
 */
meta: Dict<any>
}
interface CreateModelRequest {
/**

View File

@ -18,6 +18,7 @@
</template>
<DynamicsForm
v-loading="formLoading"
v-model="form_data"
:render_data="model_form_field"
:model="form_data"
@ -56,7 +57,7 @@
@change="getModelForm($event)"
v-loading="base_model_loading"
style="width: 100%"
v-model="form_data.model_name"
v-model="base_form_data.model_name"
class="m-2"
placeholder="请选择基础模型"
filterable
@ -90,10 +91,12 @@ import type { FormField } from '@/components/dynamics-form/type'
import DynamicsForm from '@/components/dynamics-form/index.vue'
import type { FormRules } from 'element-plus'
import { MsgSuccess } from '@/utils/message'
const providerValue = ref<Provider>()
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
const emit = defineEmits(['change', 'submit'])
const loading = ref<boolean>(false)
const formLoading = ref<boolean>(false)
const model_type_loading = ref<boolean>(false)
const base_model_loading = ref<boolean>(false)
const model_type_list = ref<Array<KeyValue<string, string>>>([])
@ -152,21 +155,22 @@ const list_base_model = (model_type: any) => {
}
}
const open = (provider: Provider, model: Model) => {
modelValue.value = model
ModelApi.listModelType(model.provider, model_type_loading).then((ok) => {
model_type_list.value = ok.data
list_base_model(model.model_type)
ModelApi.getModelById(model.id, formLoading).then((ok) => {
modelValue.value = ok.data
ModelApi.listModelType(model.provider, model_type_loading).then((ok) => {
model_type_list.value = ok.data
list_base_model(model.model_type)
})
providerValue.value = provider
base_form_data.value = {
name: model.name,
model_type: model.model_type,
model_name: model.model_name
}
form_data.value = model.credential
getModelForm(model.model_name)
})
providerValue.value = provider
base_form_data.value = {
name: model.name,
model_type: model.model_type,
model_name: model.model_name
}
form_data.value = model.credential
getModelForm(model.model_name)
dialogVisible.value = true
}

View File

@ -37,15 +37,32 @@
<script setup lang="ts">
import type { Provider, Model } from '@/api/type/model'
import ModelApi from '@/api/model'
import { computed, ref } from 'vue'
import { computed, ref, onMounted, onBeforeUnmount } from 'vue'
import EditModel from '@/views/template/component/EditModel.vue'
import { MsgSuccess, MsgConfirm } from '@/utils/message'
const props = defineProps<{
model: Model
provider_list: Array<Provider>
}>()
const downModel = ref<Model>()
/**
 * Progress of the most recent download chunk (the one with the highest
 * `index`), or 0 when no download data has been fetched yet.
 */
const progress = computed(() => {
  const chunks = downModel.value?.meta?.['down_model_chunk']
  // reduce() without an initial value throws on an empty array, so guard it.
  if (!Array.isArray(chunks) || chunks.length === 0) {
    return 0
  }
  const latest = chunks.reduce((prev: any, current: any) =>
    (prev.index || 0) > (current.index || 0) ? prev : current
  )
  return latest.progress
})
const emit = defineEmits(['change'])
const eidtModelRef = ref<InstanceType<typeof EditModel>>()
let interval: any
const deleteModel = () => {
MsgConfirm(`删除模型 `, `是否删除模型:${props.model.name} ?`, {
confirmButtonText: '删除',
@ -67,6 +84,34 @@ const openEditModel = () => {
// Icon of the provider this model belongs to, if that provider is in the list.
const icon = computed(() => {
  const matched = props.provider_list.find(
    (candidate) => candidate.provider === props.model.provider
  )
  return matched?.icon
})
/**
 * Start polling: every 6 seconds, while the model's status is 'DOWNLOAD',
 * fetch its meta and store the result in `downModel`.
 */
const initInterval = () => {
  const pollMeta = () => {
    if (props.model.status !== 'DOWNLOAD') {
      return
    }
    ModelApi.getModelMetaById(props.model.id).then((ok) => {
      downModel.value = ok.data
    })
  }
  interval = setInterval(pollMeta, 6000)
}
/**
 * Stop polling, if a polling timer was started.
 */
const closeInterval = () => {
  if (!interval) {
    return
  }
  clearInterval(interval)
}
// Poll only while the card is mounted.
onMounted(() => initInterval())
// Clear the timer before the component is unmounted.
onBeforeUnmount(() => closeInterval())
</script>
<style lang="scss" scoped>
.model-card {