feat: 支持向量模型

This commit is contained in:
shaohuzhang1 2024-07-15 16:26:54 +08:00
parent 7660c8f7f0
commit 2f3d282c0d
23 changed files with 416 additions and 18 deletions

View File

@ -0,0 +1,19 @@
# Generated by Django 4.2.13 on 2024-07-15 15:52
import application.models.application
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('application', '0009_application_type_application_work_flow_and_more'),
]
operations = [
migrations.AlterField(
model_name='chatrecord',
name='details',
field=models.JSONField(default=dict, encoder=application.models.application.DateEncoder, verbose_name='对话详情'),
),
]

View File

@ -522,12 +522,14 @@ class ApplicationSerializer(serializers.Serializer):
if not QuerySet(Application).filter(id=self.data.get('application_id')).exists():
raise AppApiException(500, '不存在的应用id')
def list_model(self, with_valid=True):
def list_model(self, model_type=None, with_valid=True):
if with_valid:
self.is_valid()
if model_type is None:
model_type = "LLM"
application = QuerySet(Application).filter(id=self.data.get("application_id")).first()
return ModelSerializer.Query(
data={'user_id': application.user_id}).list(
data={'user_id': application.user_id, 'model_type': model_type}).list(
with_valid=True)
def delete(self, with_valid=True):

View File

@ -67,6 +67,20 @@ class ApplicationApi(ApiMixin):
}
)
class Model(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
openapi.Parameter(name='model_type', in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='模型类型'),
]
class ApiKey(ApiMixin):
@staticmethod
def get_request_params_api():

View File

@ -175,7 +175,7 @@ class Application(APIView):
@swagger_auto_schema(operation_summary="获取模型列表",
operation_id="获取模型列表",
tags=["应用"],
manual_parameters=ApplicationApi.ApiKey.get_request_params_api())
manual_parameters=ApplicationApi.Model.get_request_params_api())
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
@ -185,7 +185,7 @@ class Application(APIView):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id,
'user_id': request.user.id}).list_model())
'user_id': request.user.id}).list_model(request.query_params.get('model_type')))
class Profile(APIView):
authentication_classes = [TokenAuth]

View File

@ -0,0 +1,21 @@
# Generated by Django 4.2.13 on 2024-07-15 15:56
import dataset.models.data_set
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('setting', '0005_model_permission_type'),
('dataset', '0005_file'),
]
operations = [
migrations.AddField(
model_name='dataset',
name='embedding_mode_id',
field=models.ForeignKey(default=dataset.models.data_set.default_model, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model', verbose_name='向量模型'),
),
]

View File

@ -9,9 +9,11 @@
import uuid
from django.db import models
from django.db.models import QuerySet
from common.db.sql_execute import select_one
from common.mixins.app_model_mixin import AppModelMixin
from setting.models import Model
from users.models import User
@ -33,6 +35,10 @@ class HitHandlingMethod(models.TextChoices):
directly_return = 'directly_return', '直接返回'
def default_model():
return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab')
class DataSet(AppModelMixin):
"""
数据集表
@ -43,7 +49,8 @@ class DataSet(AppModelMixin):
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
default=Type.base)
embedding_mode_id = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型",
default=default_model)
meta = models.JSONField(verbose_name="元数据", default=dict)
class Meta:

View File

@ -206,6 +206,8 @@ class DataSetSerializers(serializers.ModelSerializer):
max_length=256,
min_length=1)
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
documents = DocumentInstanceSerializer(required=False, many=True)
def is_valid(self, *, raise_exception=False):
@ -226,6 +228,8 @@ class DataSetSerializers(serializers.ModelSerializer):
max_length=256,
min_length=1)
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
file_list = serializers.ListSerializer(required=True,
error_messages=ErrMessage.list("文件列表"),
child=serializers.FileField(required=True,
@ -365,7 +369,8 @@ class DataSetSerializers(serializers.ModelSerializer):
self.CreateQASerializers(data=instance).is_valid()
file_list = instance.get('file_list')
document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list])
dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list}
dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list,
'embedding_mode_id': instance.get('embedding_mode_id')}
return self.save(dataset_instance, with_valid=True)
@valid_license(model=DataSet, count=50,
@ -381,7 +386,8 @@ class DataSetSerializers(serializers.ModelSerializer):
if QuerySet(DataSet).filter(user_id=user_id, name=instance.get('name')).exists():
raise AppApiException(500, "知识库名称重复!")
dataset = DataSet(
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id})
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
'embedding_mode_id': instance.get('embedding_mode_id')})
document_model_list = []
paragraph_model_list = []
@ -500,6 +506,8 @@ class DataSetSerializers(serializers.ModelSerializer):
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title='向量模型',
description='向量模型'),
'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据",
items=DocumentSerializers().Create.get_request_body_api()
)
@ -782,6 +790,8 @@ class DataSetSerializers(serializers.ModelSerializer):
raise AppApiException(500, "知识库名称重复!")
_dataset = QuerySet(DataSet).get(id=self.data.get("id"))
DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset)
if 'embedding_mode_id' in dataset:
_dataset.embedding_mode_id = dataset.get('embedding_mode_id')
if "name" in dataset:
_dataset.name = dataset.get("name")
if 'desc' in dataset:

View File

@ -52,7 +52,6 @@ class Dataset(APIView):
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="创建QA知识库",
operation_id="创建QA知识库",
manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(),
responses=get_api_response(
DataSetSerializers.Create.CreateQASerializers.get_response_body_api()),

View File

@ -0,0 +1,46 @@
# Generated by Django 4.2.13 on 2024-07-15 15:23
import json
from django.db import migrations, models
from django.db.models import QuerySet
from common.util.rsa_util import rsa_long_encrypt
from setting.models import Status, PermissionType
from smartdoc.const import CONFIG
default_embedding_model_id = '42f63a3d-427e-11ef-b3ec-a8a1595801ab'
def save_default_embedding_model(apps, schema_editor):
ModelModel = apps.get_model('setting', 'Model')
cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH')
model_name = CONFIG.get('EMBEDDING_MODEL_NAME')
credential = {'cache_folder': cache_folder}
model_credential_str = json.dumps(credential)
model = ModelModel(id=default_embedding_model_id, name='maxkb-embedding', status=Status.SUCCESS,
model_type="EMBEDDING", model_name=model_name, user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab',
provider='model_local_provider',
credential=rsa_long_encrypt(model_credential_str), meta={},
permission_type=PermissionType.PUBLIC)
model.save()
def reverse_code_embedding_model(apps, schema_editor):
ModelModel = apps.get_model('setting', 'Model')
QuerySet(ModelModel).filter(id=default_embedding_model_id).delete()
class Migration(migrations.Migration):
dependencies = [
('setting', '0004_alter_model_credential'),
]
operations = [
migrations.AddField(
model_name='model',
name='permission_type',
field=models.CharField(choices=[('PUBLIC', '公开'), ('PRIVATE', '私有')], default='PRIVATE', max_length=20,
verbose_name='权限类型'),
),
migrations.RunPython(save_default_embedding_model, reverse_code_embedding_model)
]

View File

@ -23,6 +23,11 @@ class Status(models.TextChoices):
DOWNLOAD = "DOWNLOAD", '下载中'
class PermissionType(models.TextChoices):
PUBLIC = "PUBLIC", '公开'
PRIVATE = "PRIVATE", "私有"
class Model(AppModelMixin):
"""
模型数据
@ -46,6 +51,9 @@ class Model(AppModelMixin):
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)
permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices,
default=PermissionType.PRIVATE)
class Meta:
db_table = "model"
unique_together = ['name', 'user_id']

View File

@ -61,7 +61,7 @@ class IModelProvider(ABC):
def get_model_list(self, model_type):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return self.get_model_info_manage().get_model_list()
return self.get_model_info_manage().get_model_list_by_model_type(model_type)
def get_model_credential(self, model_type, model_name):
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
@ -191,6 +191,9 @@ class ModelInfoManage:
def get_model_list(self):
return [model.to_dict() for model in self.model_list]
def get_model_list_by_model_type(self, model_type):
return [model.to_dict() for model in self.model_list if model.model_type == model_type]
def get_model_type_list(self):
return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if
len([model for model in self.model_list if model.model_type == _type.name]) > 0]

View File

@ -18,6 +18,7 @@ from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
class ModelProvideConstants(Enum):
@ -31,3 +32,4 @@ class ModelProvideConstants(Enum):
model_xf_provider = XunFeiModelProvider()
model_deepseek_provider = DeepSeekModelProvider()
model_gemini_provider = GeminiModelProvider()
model_local_provider = LocalModelProvider()

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date2024/7/10 17:48
@desc:
"""

View File

@ -0,0 +1,45 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/11 11:06
@desc:
"""
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
if not model_type == 'EMBEDDING':
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['cache_folder']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
model.embed_query('你好')
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return model
cache_folder = forms.TextInputField('模型目录', required=True)

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" t="1720668342208" class="icon" viewBox="0 0 1024 1024" version="1.1" p-id="9052" width="100%" height="100%"><path d="M512.2 475.7c-8.2-0.3-16.1-2.4-23.4-6.1L192.6 330.2c-24.9-11.1-36.1-40.3-25-65.2 5-11.2 13.9-20.1 25-25l281.5-133.2a89.43 89.43 0 0 1 76 0L831.7 240c24.9 11.1 36.1 40.3 25 65.2-5 11.2-13.9 20.1-25 25L535.5 469.5c-7.2 3.8-15.2 5.9-23.3 6.2z m-76.5 452.5c-7.6 0-15.1-1.9-21.8-5.5L146.3 797.2c-17-8.9-27.5-26.5-27.3-45.6v-320c0.1-18 9.7-34.5 25.1-43.7 14.3-8.1 31.8-8.1 46.1 0l267.1 125.4c16.1 8.7 26.4 25.4 27.1 43.7v320.4c-0.2 17.9-9.6 34.4-24.9 43.7-7.2 4.3-15.4 6.8-23.8 7.1z m152.9 0c-8.3 0-16.5-2.2-23.8-6.3-15.3-9.3-24.7-25.8-24.9-43.7V556.9c0.4-18.2 10.4-34.8 26.2-43.7L835 387c14.2-7.5 31.4-7.1 45.2 1.1 15.5 9.1 25 25.7 25.1 43.7v319.8c0.4 18.9-9.7 36.5-26.2 45.6L610.5 922.8c-6.8 3.6-14.3 5.5-21.9 5.4z" p-id="9053"/></svg>

After

Width:  |  Height:  |  Size: 931 B

View File

@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file zhipu_model_provider.py
@date2024/04/19 13:5
@desc:
"""
import os
from typing import Dict
from pydantic import BaseModel
from common.exception.app_exception import AppApiException
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
from smartdoc.conf import PROJECT_DIR
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
LocalEmbeddingCredential(), LocalEmbedding)
model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese)
.append_default_model_info(embedding_text2vec_base_chinese)
.build())
class LocalModelProvider(IModelProvider):
def get_model_info_manage(self):
return model_info_manage
def get_model_provide_info(self):
return ModelProvideInfo(provider='model_local_provider', name='本地模型', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'local_model_provider', 'icon',
'local_icon_svg')))

View File

@ -0,0 +1,22 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/11 14:06
@desc:
"""
from typing import Dict
from langchain_huggingface import HuggingFaceEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
model_kwargs={'device': model_credential.get('device')},
encode_kwargs={'normalize_embeddings': True},
)

View File

@ -0,0 +1,45 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 15:10
@desc:
"""
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
class OllamaEmbeddingModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
try:
model_list = provider.get_base_model_list(model_credential.get('api_base'))
except Exception as e:
raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if
model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
if len(exist) == 0:
raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
model.embed_query('你好')
return True
def encryption_dict(self, model_info: Dict[str, object]):
return model_info
def build_model(self, model_info: Dict[str, object]):
for key in ['model']:
if key not in model_info:
raise AppApiException(500, f'{key} 字段为必填字段')
return self
api_base = forms.TextInputField('API 域名', required=True)

View File

@ -0,0 +1,22 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 15:02
@desc:
"""
from typing import Dict
from langchain_community.embeddings import OllamaEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class OllamaEmbedding(MaxKBBaseModel, OllamaEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return OllamaEmbeddings(
model=model_name,
base_url=model_credential.get('api_base'),
)

View File

@ -20,7 +20,9 @@ from common.forms import BaseForm
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage
from setting.models_provider.impl.ollama_model_provider.credential.embedding import OllamaEmbeddingModelCredential
from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential
from setting.models_provider.impl.ollama_model_provider.model.embedding import OllamaEmbedding
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
from smartdoc.conf import PROJECT_DIR
@ -88,14 +90,25 @@ model_info_list = [
ModelInfo(
'phi3',
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
]
ollama_embedding_model_credential = OllamaEmbeddingModelCredential()
embedding_model_info = [
ModelInfo(
'nomic-embed-text',
'一个具有大令牌上下文窗口的高性能开放嵌入模型。',
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding),
]
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_model_info_list(
embedding_model_info).append_default_model_info(
ModelInfo(
'phi3',
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).build()
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).append_default_model_info(ModelInfo(
'nomic-embed-text',
'一个具有大令牌上下文窗口的高性能开放嵌入模型。',
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), ).build()
def get_base_url(url: str):
@ -139,7 +152,6 @@ def convert(response_stream) -> Iterator[DownModelChunk]:
temp = ""
if len(temp) > 0:
print(temp)
rows = [t for t in temp.split("\n") if len(t) > 0]
for row in rows:
yield convert_to_down_model_chunk(row, index)
@ -154,9 +166,6 @@ class OllamaModelProvider(IModelProvider):
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon',
'ollama_icon_svg')))
def get_dialogue_number(self):
return 2
@staticmethod
def get_base_model_list(api_base):
base_url = get_base_url(api_base)
@ -165,7 +174,7 @@ class OllamaModelProvider(IModelProvider):
return r.json()
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
api_base = model_credential.get('api_base')
api_base = model_credential.get('api_base', '')
base_url = get_base_url(api_base)
r = requests.request(
method="POST",

View File

@ -0,0 +1,46 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 16:45
@desc:
"""
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=True):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['api_base', 'api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.embed_query('你好')
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -0,0 +1,23 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 17:44
@desc:
"""
from typing import Dict
from langchain_community.embeddings import OpenAIEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class OpenAIEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return OpenAIEmbeddingModel(
api_key=model_credential.get('api_key'),
model=model_name,
openai_api_base=model_credential.get('api_base'),
)

View File

@ -11,7 +11,9 @@ import os
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
ModelTypeConst, ModelInfoManage
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
from smartdoc.conf import PROJECT_DIR
@ -58,11 +60,17 @@ model_info_list = [
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel)
]
open_ai_embedding_credential = OpenAIEmbeddingCredential()
model_info_embedding_list = [
ModelInfo('text-embedding-ada-002', '',
ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
OpenAIEmbeddingModel)]
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, OpenAIChatModel
)).build()
)).append_model_info_list(model_info_embedding_list).append_default_model_info(
model_info_embedding_list[0]).build()
class OpenAIModelProvider(IModelProvider):