mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 支持向量模型
This commit is contained in:
parent
7660c8f7f0
commit
2f3d282c0d
|
|
@ -0,0 +1,19 @@
|
|||
# Generated by Django 4.2.13 on 2024-07-15 15:52
|
||||
|
||||
import application.models.application
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Alter ``chatrecord.details`` to a JSONField encoded with ``DateEncoder``."""

    dependencies = [('application', '0009_application_type_application_work_flow_and_more')]

    operations = [
        migrations.AlterField(
            model_name='chatrecord',
            name='details',
            field=models.JSONField(
                default=dict,
                encoder=application.models.application.DateEncoder,
                verbose_name='对话详情',
            ),
        ),
    ]
|
||||
|
|
@ -522,12 +522,14 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
if not QuerySet(Application).filter(id=self.data.get('application_id')).exists():
|
||||
raise AppApiException(500, '不存在的应用id')
|
||||
|
||||
def list_model(self, model_type=None, with_valid=True):
    """
    List the models of a given type for the user who owns this application.

    :param model_type: model type to list; "LLM" is used when None is passed
    :param with_valid: when True, ``self.is_valid()`` is called first
    :return: whatever ``ModelSerializer.Query(...).list(with_valid=True)`` returns
    :raises AppApiException: when no application matches ``application_id``
    """
    if with_valid:
        self.is_valid()
    if model_type is None:
        model_type = "LLM"
    application = QuerySet(Application).filter(id=self.data.get("application_id")).first()
    # ``first()`` yields None for an unknown id (possible when validation was
    # skipped); fail with the same error the validator uses instead of an
    # AttributeError on ``application.user_id``.
    if application is None:
        raise AppApiException(500, '不存在的应用id')
    return ModelSerializer.Query(
        data={'user_id': application.user_id, 'model_type': model_type}).list(
        with_valid=True)
|
||||
|
||||
def delete(self, with_valid=True):
|
||||
|
|
|
|||
|
|
@ -67,6 +67,20 @@ class ApplicationApi(ApiMixin):
|
|||
}
|
||||
)
|
||||
|
||||
class Model(ApiMixin):
    """Swagger parameters for the endpoint that lists an application's models."""

    @staticmethod
    def get_request_params_api():
        """Return the ``application_id`` path parameter and the optional ``model_type`` query parameter."""
        application_id_param = openapi.Parameter(
            name='application_id',
            in_=openapi.IN_PATH,
            type=openapi.TYPE_STRING,
            required=True,
            description='应用id',
        )
        model_type_param = openapi.Parameter(
            name='model_type',
            in_=openapi.IN_QUERY,
            type=openapi.TYPE_STRING,
            required=False,
            description='模型类型',
        )
        return [application_id_param, model_type_param]
|
||||
|
||||
class ApiKey(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
|
|
|
|||
|
|
@ -175,7 +175,7 @@ class Application(APIView):
|
|||
@swagger_auto_schema(operation_summary="获取模型列表",
|
||||
operation_id="获取模型列表",
|
||||
tags=["应用"],
|
||||
manual_parameters=ApplicationApi.ApiKey.get_request_params_api())
|
||||
manual_parameters=ApplicationApi.Model.get_request_params_api())
|
||||
@has_permissions(ViewPermission(
|
||||
[RoleConstants.ADMIN, RoleConstants.USER],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
|
|
@ -185,7 +185,7 @@ class Application(APIView):
|
|||
return result.success(
|
||||
ApplicationSerializer.Operate(
|
||||
data={'application_id': application_id,
|
||||
'user_id': request.user.id}).list_model())
|
||||
'user_id': request.user.id}).list_model(request.query_params.get('model_type')))
|
||||
|
||||
class Profile(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
# Generated by Django 4.2.13 on 2024-07-15 15:56
|
||||
|
||||
import dataset.models.data_set
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """
    Add the embedding-model foreign key to ``dataset``.

    Depends on setting ``0005_model_permission_type`` and dataset ``0005_file``.

    NOTE(review): Django appends ``_id`` to a ForeignKey's field name to build
    its column name, so a field that is itself called ``embedding_mode_id``
    ends up stored as ``embedding_mode_id_id`` -- confirm this naming is
    intended before it ships.
    """

    dependencies = [
        ('setting', '0005_model_permission_type'),
        ('dataset', '0005_file'),
    ]

    operations = [
        migrations.AddField(
            model_name='dataset',
            name='embedding_mode_id',
            field=models.ForeignKey(
                default=dataset.models.data_set.default_model,
                on_delete=django.db.models.deletion.DO_NOTHING,
                to='setting.model',
                verbose_name='向量模型',
            ),
        ),
    ]
|
||||
|
|
@ -9,9 +9,11 @@
|
|||
import uuid
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.db.sql_execute import select_one
|
||||
from common.mixins.app_model_mixin import AppModelMixin
|
||||
from setting.models import Model
|
||||
from users.models import User
|
||||
|
||||
|
||||
|
|
@ -33,6 +35,10 @@ class HitHandlingMethod(models.TextChoices):
|
|||
directly_return = 'directly_return', '直接返回'
|
||||
|
||||
|
||||
def default_model():
    """Return the fixed UUID used as the default value of the embedding-model foreign key."""
    return uuid.UUID(hex='42f63a3d-427e-11ef-b3ec-a8a1595801ab')
|
||||
|
||||
|
||||
class DataSet(AppModelMixin):
|
||||
"""
|
||||
数据集表
|
||||
|
|
@ -43,7 +49,8 @@ class DataSet(AppModelMixin):
|
|||
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
|
||||
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
|
||||
default=Type.base)
|
||||
|
||||
embedding_mode_id = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型",
|
||||
default=default_model)
|
||||
meta = models.JSONField(verbose_name="元数据", default=dict)
|
||||
|
||||
class Meta:
|
||||
|
|
|
|||
|
|
@ -206,6 +206,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
max_length=256,
|
||||
min_length=1)
|
||||
|
||||
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
|
||||
|
||||
documents = DocumentInstanceSerializer(required=False, many=True)
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
|
|
@ -226,6 +228,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
max_length=256,
|
||||
min_length=1)
|
||||
|
||||
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
|
||||
|
||||
file_list = serializers.ListSerializer(required=True,
|
||||
error_messages=ErrMessage.list("文件列表"),
|
||||
child=serializers.FileField(required=True,
|
||||
|
|
@ -365,7 +369,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
self.CreateQASerializers(data=instance).is_valid()
|
||||
file_list = instance.get('file_list')
|
||||
document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list])
|
||||
dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list}
|
||||
dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list,
|
||||
'embedding_mode_id': instance.get('embedding_mode_id')}
|
||||
return self.save(dataset_instance, with_valid=True)
|
||||
|
||||
@valid_license(model=DataSet, count=50,
|
||||
|
|
@ -381,7 +386,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
if QuerySet(DataSet).filter(user_id=user_id, name=instance.get('name')).exists():
|
||||
raise AppApiException(500, "知识库名称重复!")
|
||||
dataset = DataSet(
|
||||
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id})
|
||||
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
|
||||
'embedding_mode_id': instance.get('embedding_mode_id')})
|
||||
|
||||
document_model_list = []
|
||||
paragraph_model_list = []
|
||||
|
|
@ -500,6 +506,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
properties={
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
|
||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
|
||||
'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title='向量模型',
|
||||
description='向量模型'),
|
||||
'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据",
|
||||
items=DocumentSerializers().Create.get_request_body_api()
|
||||
)
|
||||
|
|
@ -782,6 +790,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
raise AppApiException(500, "知识库名称重复!")
|
||||
_dataset = QuerySet(DataSet).get(id=self.data.get("id"))
|
||||
DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset)
|
||||
if 'embedding_mode_id' in dataset:
|
||||
_dataset.embedding_mode_id = dataset.get('embedding_mode_id')
|
||||
if "name" in dataset:
|
||||
_dataset.name = dataset.get("name")
|
||||
if 'desc' in dataset:
|
||||
|
|
|
|||
|
|
@ -52,7 +52,6 @@ class Dataset(APIView):
|
|||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="创建QA知识库",
|
||||
operation_id="创建QA知识库",
|
||||
|
||||
manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(),
|
||||
responses=get_api_response(
|
||||
DataSetSerializers.Create.CreateQASerializers.get_response_body_api()),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
# Generated by Django 4.2.13 on 2024-07-15 15:23
|
||||
import json
|
||||
|
||||
from django.db import migrations, models
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.util.rsa_util import rsa_long_encrypt
|
||||
from setting.models import Status, PermissionType
|
||||
from smartdoc.const import CONFIG
|
||||
|
||||
default_embedding_model_id = '42f63a3d-427e-11ef-b3ec-a8a1595801ab'
|
||||
|
||||
|
||||
def save_default_embedding_model(apps, schema_editor):
    """
    Forward data step: store the built-in embedding model row.

    The row uses the fixed id ``default_embedding_model_id``; its credential is
    the JSON ``{"cache_folder": <EMBEDDING_MODEL_PATH>}`` passed through
    ``rsa_long_encrypt``, and its model name comes from ``EMBEDDING_MODEL_NAME``.
    """
    historical_model = apps.get_model('setting', 'Model')
    embedding_dir = CONFIG.get('EMBEDDING_MODEL_PATH')
    embedding_name = CONFIG.get('EMBEDDING_MODEL_NAME')
    credential_json = json.dumps({'cache_folder': embedding_dir})
    row = historical_model(
        id=default_embedding_model_id,
        name='maxkb-embedding',
        status=Status.SUCCESS,
        model_type="EMBEDDING",
        model_name=embedding_name,
        user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab',
        provider='model_local_provider',
        credential=rsa_long_encrypt(credential_json),
        meta={},
        permission_type=PermissionType.PUBLIC,
    )
    row.save()
|
||||
|
||||
|
||||
def reverse_code_embedding_model(apps, schema_editor):
    """Reverse data step: delete the row whose id is ``default_embedding_model_id``."""
    historical_model = apps.get_model('setting', 'Model')
    default_rows = QuerySet(historical_model).filter(id=default_embedding_model_id)
    default_rows.delete()
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Add ``model.permission_type`` and then insert the default embedding model row."""

    dependencies = [('setting', '0004_alter_model_credential')]

    operations = [
        migrations.AddField(
            model_name='model',
            name='permission_type',
            field=models.CharField(
                choices=[('PUBLIC', '公开'), ('PRIVATE', '私有')],
                default='PRIVATE',
                max_length=20,
                verbose_name='权限类型',
            ),
        ),
        # The data step runs after the column exists because the inserted row
        # sets ``permission_type``.
        migrations.RunPython(code=save_default_embedding_model,
                             reverse_code=reverse_code_embedding_model),
    ]
|
||||
|
|
@ -23,6 +23,11 @@ class Status(models.TextChoices):
|
|||
DOWNLOAD = "DOWNLOAD", '下载中'
|
||||
|
||||
|
||||
class PermissionType(models.TextChoices):
    """Choices for ``Model.permission_type`` (stored value, display label)."""

    PUBLIC = 'PUBLIC', '公开'
    PRIVATE = 'PRIVATE', '私有'
|
||||
|
||||
|
||||
class Model(AppModelMixin):
|
||||
"""
|
||||
模型数据
|
||||
|
|
@ -46,6 +51,9 @@ class Model(AppModelMixin):
|
|||
|
||||
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)
|
||||
|
||||
permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices,
|
||||
default=PermissionType.PRIVATE)
|
||||
|
||||
class Meta:
|
||||
db_table = "model"
|
||||
unique_together = ['name', 'user_id']
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class IModelProvider(ABC):
|
|||
def get_model_list(self, model_type):
    """
    Return the provider's models of the requested type.

    :param model_type: model type to filter by; must not be None
    :return: the result of ``get_model_list_by_model_type(model_type)`` on the
             provider's model-info manager
    :raises AppApiException: when ``model_type`` is None
    """
    if model_type is None:
        raise AppApiException(500, '模型类型不能为空')
    # Only the type-filtered lookup is kept: the unfiltered return that
    # preceded it ignored ``model_type`` and made this line unreachable.
    return self.get_model_info_manage().get_model_list_by_model_type(model_type)
|
||||
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||
|
|
@ -191,6 +191,9 @@ class ModelInfoManage:
|
|||
def get_model_list(self):
    """Return ``to_dict()`` of every entry in ``self.model_list``, in order."""
    serialized = []
    for info in self.model_list:
        serialized.append(info.to_dict())
    return serialized
|
||||
|
||||
def get_model_list_by_model_type(self, model_type):
    """Return ``to_dict()`` of the entries in ``self.model_list`` whose ``model_type`` equals ``model_type``."""
    matching = []
    for info in self.model_list:
        if info.model_type == model_type:
            matching.append(info.to_dict())
    return matching
|
||||
|
||||
def get_model_type_list(self):
    """
    Return a ``{'key': message, 'value': code}`` dict for every member of
    ``ModelTypeConst`` that at least one entry in ``self.model_list`` uses.
    """
    type_entries = []
    for type_const in ModelTypeConst:
        if any(info.model_type == type_const.name for info in self.model_list):
            type_entries.append({'key': type_const.value.get('message'),
                                 'value': type_const.value.get('code')})
    return type_entries
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import
|
|||
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
|
||||
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
|
||||
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
|
||||
from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
|
||||
|
||||
|
||||
class ModelProvideConstants(Enum):
|
||||
|
|
@ -31,3 +32,4 @@ class ModelProvideConstants(Enum):
|
|||
model_xf_provider = XunFeiModelProvider()
|
||||
model_deepseek_provider = DeepSeekModelProvider()
|
||||
model_gemini_provider = GeminiModelProvider()
|
||||
model_local_provider = LocalModelProvider()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/7/10 17:48
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/11 11:06
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||
|
||||
|
||||
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the local embedding model."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """
        Validate the credential by building the model and embedding a short probe text.

        :param model_type: must be ``'EMBEDDING'``
        :param model_name: name passed through to ``provider.get_model``
        :param model_credential: must contain ``cache_folder``
        :param provider: object exposing ``get_model``
        :param raise_exception: raise on failure instead of returning False
        :return: True when the probe succeeds, False on failure if ``raise_exception`` is falsy
        :raises AppApiException: on an unsupported model type (always), or on any
                                 other failure when ``raise_exception`` is truthy
        """
        if model_type != 'EMBEDDING':
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        required_keys = ('cache_folder',)
        for key in required_keys:
            if key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')

        try:
            embedding: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
            embedding.embed_query('你好')
        except AppApiException:
            # Application errors already carry a proper code; pass them through.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return the credential dict unchanged."""
        return model

    cache_folder = forms.TextInputField('模型目录', required=True)
|
||||
|
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" t="1720668342208" class="icon" viewBox="0 0 1024 1024" version="1.1" p-id="9052" width="100%" height="100%"><path d="M512.2 475.7c-8.2-0.3-16.1-2.4-23.4-6.1L192.6 330.2c-24.9-11.1-36.1-40.3-25-65.2 5-11.2 13.9-20.1 25-25l281.5-133.2a89.43 89.43 0 0 1 76 0L831.7 240c24.9 11.1 36.1 40.3 25 65.2-5 11.2-13.9 20.1-25 25L535.5 469.5c-7.2 3.8-15.2 5.9-23.3 6.2z m-76.5 452.5c-7.6 0-15.1-1.9-21.8-5.5L146.3 797.2c-17-8.9-27.5-26.5-27.3-45.6v-320c0.1-18 9.7-34.5 25.1-43.7 14.3-8.1 31.8-8.1 46.1 0l267.1 125.4c16.1 8.7 26.4 25.4 27.1 43.7v320.4c-0.2 17.9-9.6 34.4-24.9 43.7-7.2 4.3-15.4 6.8-23.8 7.1z m152.9 0c-8.3 0-16.5-2.2-23.8-6.3-15.3-9.3-24.7-25.8-24.9-43.7V556.9c0.4-18.2 10.4-34.8 26.2-43.7L835 387c14.2-7.5 31.4-7.1 45.2 1.1 15.5 9.1 25 25.7 25.1 43.7v319.8c0.4 18.9-9.7 36.5-26.2 45.6L610.5 922.8c-6.8 3.6-14.3 5.5-21.9 5.4z" p-id="9053"/></svg>
|
||||
|
After Width: | Height: | Size: 931 B |
|
|
@ -0,0 +1,38 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: local_model_provider.py
|
||||
@date:2024/04/19 13:5
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||
ModelInfoManage
|
||||
from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
|
||||
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
# The single model this provider registers; it is added both as an available
# model and as the default model.
embedding_text2vec_base_chinese = ModelInfo(
    'shibing624/text2vec-base-chinese',
    '',
    ModelTypeConst.EMBEDDING,
    LocalEmbeddingCredential(),
    LocalEmbedding,
)

model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info(embedding_text2vec_base_chinese)
    .append_default_model_info(embedding_text2vec_base_chinese)
    .build()
)
|
||||
|
||||
|
||||
class LocalModelProvider(IModelProvider):
    """Provider registered under the key ``model_local_provider``."""

    def get_model_info_manage(self):
        """Return the module-level ``model_info_manage``."""
        return model_info_manage

    def get_model_provide_info(self):
        """Return the provider descriptor, with the icon read from the provider's ``icon`` directory."""
        icon_path = os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl',
                                 'local_model_provider', 'icon', 'local_icon_svg')
        icon_content = get_file_content(icon_path)
        return ModelProvideInfo(provider='model_local_provider', name='本地模型', icon=icon_content)
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/11 14:06
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
    """``HuggingFaceEmbeddings`` built from a stored model credential."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """
        Build a ``LocalEmbedding`` for ``model_name``.

        ``cache_folder`` and ``device`` are read from ``model_credential`` (either
        may be absent, giving None); embeddings are normalized. ``model_type``
        and ``**model_kwargs`` are accepted but not used here.
        """
        cache_dir = model_credential.get('cache_folder')
        device = model_credential.get('device')
        return LocalEmbedding(
            model_name=model_name,
            cache_folder=cache_dir,
            model_kwargs={'device': device},
            encode_kwargs={'normalize_embeddings': True},
        )
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/12 15:10
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||
|
||||
|
||||
class OllamaEmbeddingModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Ollama embedding models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """
        Check that the model exists on the Ollama server and can embed a probe text.

        :param model_type: must be one of the provider's supported model types
        :param model_name: model to look for, with or without a ``:latest`` suffix
        :param model_credential: read for ``api_base``
        :param provider: object exposing ``get_model_type_list``,
                         ``get_base_model_list`` and ``get_model``
        :param raise_exception: accepted for interface compatibility; this
                                validator raises on failure regardless of it
        :return: True when every check passes
        :raises AppApiException: unsupported model type, unreachable API base,
                                 or model not present on the server
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        try:
            model_list = provider.get_base_model_list(model_credential.get('api_base'))
        except Exception as e:
            # Keep the original failure attached so the root cause is not lost.
            raise AppApiException(ValidCode.valid_error.value, "API 域名无效") from e

        def _matches(entry):
            # An entry without a string 'model' value cannot match via the
            # ':latest' comparison; guarding here avoids an AttributeError on None.
            name = entry.get('model')
            if name == model_name:
                return True
            return isinstance(name, str) and name.replace(":latest", "") == model_name

        available = model_list.get('models') or []
        if not any(_matches(entry) for entry in available):
            # NOTE(review): the enum member is passed without ``.value``, unlike
            # ``valid_error`` above; callers may compare against the member, so
            # it is left as-is -- confirm before changing.
            raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
        model = provider.get_model(model_type, model_name, model_credential)
        model.embed_query('你好')
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """Return the credential dict unchanged."""
        return model_info

    def build_model(self, model_info: Dict[str, object]):
        """Return ``self`` after checking that ``model_info`` contains ``model``."""
        for key in ['model']:
            if key not in model_info:
                raise AppApiException(500, f'{key} 字段为必填字段')
        return self

    api_base = forms.TextInputField('API 域名', required=True)
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/12 15:02
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class OllamaEmbedding(MaxKBBaseModel, OllamaEmbeddings):
    """``OllamaEmbeddings`` built from a stored model credential."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """
        Build an ``OllamaEmbedding`` for ``model_name``.

        The server address is read from ``model_credential['api_base']``.
        ``model_type`` and ``**model_kwargs`` are accepted but not used here.
        """
        # Return this subclass rather than the bare ``OllamaEmbeddings`` base so
        # the factory matches the other embedding wrappers, which return their
        # own class.
        return OllamaEmbedding(
            model=model_name,
            base_url=model_credential.get('api_base'),
        )
|
||||
|
|
@ -20,7 +20,9 @@ from common.forms import BaseForm
|
|||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
||||
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage
|
||||
from setting.models_provider.impl.ollama_model_provider.credential.embedding import OllamaEmbeddingModelCredential
|
||||
from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential
|
||||
from setting.models_provider.impl.ollama_model_provider.model.embedding import OllamaEmbedding
|
||||
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
|
@ -88,14 +90,25 @@ model_info_list = [
|
|||
ModelInfo(
|
||||
'phi3',
|
||||
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
|
||||
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)
|
||||
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||
]
|
||||
ollama_embedding_model_credential = OllamaEmbeddingModelCredential()
embedding_model_info = [
    ModelInfo(
        'nomic-embed-text',
        '一个具有大令牌上下文窗口的高性能开放嵌入模型。',
        ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding),
]

# Register the LLM and embedding lists, then one default per model type. The
# embedding default reuses the list entry instead of repeating its literal.
model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info_list(model_info_list)
    .append_model_info_list(embedding_model_info)
    .append_default_model_info(
        ModelInfo(
            'phi3',
            'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
            ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel))
    .append_default_model_info(embedding_model_info[0])
    .build()
)
|
||||
|
||||
|
||||
def get_base_url(url: str):
|
||||
|
|
@ -139,7 +152,6 @@ def convert(response_stream) -> Iterator[DownModelChunk]:
|
|||
temp = ""
|
||||
|
||||
if len(temp) > 0:
|
||||
print(temp)
|
||||
rows = [t for t in temp.split("\n") if len(t) > 0]
|
||||
for row in rows:
|
||||
yield convert_to_down_model_chunk(row, index)
|
||||
|
|
@ -154,9 +166,6 @@ class OllamaModelProvider(IModelProvider):
|
|||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon',
|
||||
'ollama_icon_svg')))
|
||||
|
||||
def get_dialogue_number(self):
|
||||
return 2
|
||||
|
||||
@staticmethod
|
||||
def get_base_model_list(api_base):
|
||||
base_url = get_base_url(api_base)
|
||||
|
|
@ -165,7 +174,7 @@ class OllamaModelProvider(IModelProvider):
|
|||
return r.json()
|
||||
|
||||
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
|
||||
api_base = model_credential.get('api_base')
|
||||
api_base = model_credential.get('api_base', '')
|
||||
base_url = get_base_url(api_base)
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/12 16:45
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for OpenAI embedding models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True):
        """
        Validate the credential by building the model and embedding a short probe text.

        :param model_type: must be one of the provider's supported model types
        :param model_name: name passed through to ``provider.get_model``
        :param model_credential: must contain ``api_base`` and ``api_key``
        :param provider: object exposing ``get_model_type_list`` and ``get_model``
        :param raise_exception: raise on failure instead of returning False
        :return: True when the probe succeeds, False on failure if ``raise_exception`` is falsy
        :raises AppApiException: on an unsupported model type (always), or on any
                                 other failure when ``raise_exception`` is truthy
        """
        supported_types = provider.get_model_type_list()
        matching_types = [mt for mt in supported_types if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        required_keys = ('api_base', 'api_key')
        for key in required_keys:
            if key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')

        try:
            embedding = provider.get_model(model_type, model_name, model_credential)
            embedding.embed_query('你好')
        except AppApiException:
            # Application errors already carry a proper code; pass them through.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` has been passed through ``encryption``."""
        masked = dict(model)
        masked['api_key'] = super().encryption(model.get('api_key', ''))
        return masked

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/12 17:44
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_community.embeddings import OpenAIEmbeddings
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class OpenAIEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
    """``OpenAIEmbeddings`` built from a stored model credential."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """
        Build an ``OpenAIEmbeddingModel`` for ``model_name``.

        ``api_key`` and ``api_base`` are read from ``model_credential``.
        ``model_type`` and ``**model_kwargs`` are accepted but not used here.
        """
        key = model_credential.get('api_key')
        base_url = model_credential.get('api_base')
        return OpenAIEmbeddingModel(api_key=key, model=model_name, openai_api_base=base_url)
|
||||
|
|
@ -11,7 +11,9 @@ import os
|
|||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
|
||||
ModelTypeConst, ModelInfoManage
|
||||
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
|
||||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
|
@ -58,11 +60,17 @@ model_info_list = [
|
|||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel)
|
||||
]
|
||||
open_ai_embedding_credential = OpenAIEmbeddingCredential()
model_info_embedding_list = [
    ModelInfo('text-embedding-ada-002', '',
              ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
              OpenAIEmbeddingModel)]

# Register the LLM list and its default, then the embedding list and its
# default. The pre-change ``)).build()`` terminator is dropped so the chain
# forms one valid expression.
model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info_list(model_info_list)
    .append_default_model_info(
        ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
                  openai_llm_model_credential, OpenAIChatModel))
    .append_model_info_list(model_info_embedding_list)
    .append_default_model_info(model_info_embedding_list[0])
    .build()
)
|
||||
|
||||
|
||||
class OpenAIModelProvider(IModelProvider):
|
||||
|
|
|
|||
Loading…
Reference in New Issue