diff --git a/apps/application/migrations/0010_alter_chatrecord_details.py b/apps/application/migrations/0010_alter_chatrecord_details.py new file mode 100644 index 000000000..e46278009 --- /dev/null +++ b/apps/application/migrations/0010_alter_chatrecord_details.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.13 on 2024-07-15 15:52 + +import application.models.application +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0009_application_type_application_work_flow_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='chatrecord', + name='details', + field=models.JSONField(default=dict, encoder=application.models.application.DateEncoder, verbose_name='对话详情'), + ), + ] diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 8ac962c95..390456305 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -522,12 +522,14 @@ class ApplicationSerializer(serializers.Serializer): if not QuerySet(Application).filter(id=self.data.get('application_id')).exists(): raise AppApiException(500, '不存在的应用id') - def list_model(self, with_valid=True): + def list_model(self, model_type=None, with_valid=True): if with_valid: self.is_valid() + if model_type is None: + model_type = "LLM" application = QuerySet(Application).filter(id=self.data.get("application_id")).first() return ModelSerializer.Query( - data={'user_id': application.user_id}).list( + data={'user_id': application.user_id, 'model_type': model_type}).list( with_valid=True) def delete(self, with_valid=True): diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py index 6e46931a6..e153f6279 100644 --- a/apps/application/swagger_api/application_api.py +++ b/apps/application/swagger_api/application_api.py @@ -67,6 +67,20 @@ class ApplicationApi(ApiMixin): } ) + class Model(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='application_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='应用id'), + openapi.Parameter(name='model_type', in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='模型类型'), + ] + class ApiKey(ApiMixin): @staticmethod def get_request_params_api(): diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index c5fc165d3..28fe412ad 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -175,7 +175,7 @@ class Application(APIView): @swagger_auto_schema(operation_summary="获取模型列表", operation_id="获取模型列表", tags=["应用"], - manual_parameters=ApplicationApi.ApiKey.get_request_params_api()) + manual_parameters=ApplicationApi.Model.get_request_params_api()) @has_permissions(ViewPermission( [RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, @@ -185,7 +185,7 @@ class Application(APIView): return result.success( ApplicationSerializer.Operate( data={'application_id': application_id, - 'user_id': request.user.id}).list_model()) + 'user_id': request.user.id}).list_model(request.query_params.get('model_type'))) class Profile(APIView): authentication_classes = [TokenAuth] diff --git a/apps/dataset/migrations/0006_dataset_embedding_mode_id.py b/apps/dataset/migrations/0006_dataset_embedding_mode_id.py new file mode 100644 index 000000000..77f546c60 --- /dev/null +++ b/apps/dataset/migrations/0006_dataset_embedding_mode_id.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.13 on 2024-07-15 15:56 + +import dataset.models.data_set +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0005_model_permission_type'), + ('dataset', '0005_file'), + ] + + operations = [ + migrations.AddField( + model_name='dataset', + name='embedding_mode_id', + field=models.ForeignKey(default=dataset.models.data_set.default_model, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model', verbose_name='向量模型'), + ), + ] diff --git a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py index ca3b05e0c..69fed2a09 100644 --- a/apps/dataset/models/data_set.py +++ b/apps/dataset/models/data_set.py @@ -9,9 +9,11 @@ import uuid from django.db import models +from django.db.models import QuerySet from common.db.sql_execute import select_one from common.mixins.app_model_mixin import AppModelMixin +from setting.models import Model from users.models import User @@ -33,6 +35,10 @@ class HitHandlingMethod(models.TextChoices): directly_return = 'directly_return', '直接返回' +def default_model(): + return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab') + + class DataSet(AppModelMixin): """ 数据集表 @@ -43,7 +49,8 @@ class DataSet(AppModelMixin): user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户") type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices, default=Type.base) - + embedding_mode_id = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型", + default=default_model) meta = models.JSONField(verbose_name="元数据", default=dict) class Meta: diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 50a6a9b99..0d87744c1 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -206,6 +206,8 @@ class DataSetSerializers(serializers.ModelSerializer): max_length=256, min_length=1) + embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型")) + documents = DocumentInstanceSerializer(required=False, many=True) def is_valid(self, *, raise_exception=False): @@ -226,6 +228,8 @@ class DataSetSerializers(serializers.ModelSerializer): max_length=256, min_length=1) + embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型")) + file_list = serializers.ListSerializer(required=True, error_messages=ErrMessage.list("文件列表"), child=serializers.FileField(required=True, @@ -365,7 +369,8 @@ class DataSetSerializers(serializers.ModelSerializer): self.CreateQASerializers(data=instance).is_valid() file_list = instance.get('file_list') document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list]) - dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list} + dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list, + 'embedding_mode_id': instance.get('embedding_mode_id')} return self.save(dataset_instance, with_valid=True) @valid_license(model=DataSet, count=50, @@ -381,7 +386,8 @@ class DataSetSerializers(serializers.ModelSerializer): if QuerySet(DataSet).filter(user_id=user_id, name=instance.get('name')).exists(): raise AppApiException(500, "知识库名称重复!") dataset = DataSet( - **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id}) + **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id, + 'embedding_mode_id': instance.get('embedding_mode_id')}) document_model_list = [] paragraph_model_list = [] @@ -500,6 +506,8 @@ class DataSetSerializers(serializers.ModelSerializer): properties={ 'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"), + 'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title='向量模型', + description='向量模型'), 'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据", items=DocumentSerializers().Create.get_request_body_api() ) @@ -782,6 +790,8 @@ class DataSetSerializers(serializers.ModelSerializer): raise AppApiException(500, "知识库名称重复!") _dataset = QuerySet(DataSet).get(id=self.data.get("id")) DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset) + if 'embedding_mode_id' in dataset: + _dataset.embedding_mode_id = dataset.get('embedding_mode_id') if "name" in dataset: _dataset.name = dataset.get("name") if 'desc' in dataset: diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index e2bd10e09..adf10e3ff 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -52,7 +52,6 @@ class Dataset(APIView): @action(methods=['POST'], detail=False) @swagger_auto_schema(operation_summary="创建QA知识库", operation_id="创建QA知识库", - manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(), responses=get_api_response( DataSetSerializers.Create.CreateQASerializers.get_response_body_api()), diff --git a/apps/setting/migrations/0005_model_permission_type.py b/apps/setting/migrations/0005_model_permission_type.py new file mode 100644 index 000000000..dba081a19 --- /dev/null +++ b/apps/setting/migrations/0005_model_permission_type.py @@ -0,0 +1,46 @@ +# Generated by Django 4.2.13 on 2024-07-15 15:23 +import json + +from django.db import migrations, models +from django.db.models import QuerySet + +from common.util.rsa_util import rsa_long_encrypt +from setting.models import Status, PermissionType +from smartdoc.const import CONFIG + +default_embedding_model_id = '42f63a3d-427e-11ef-b3ec-a8a1595801ab' + + +def save_default_embedding_model(apps, schema_editor): + ModelModel = apps.get_model('setting', 'Model') + cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH') + model_name = CONFIG.get('EMBEDDING_MODEL_NAME') + credential = {'cache_folder': cache_folder} + model_credential_str = json.dumps(credential) + model = ModelModel(id=default_embedding_model_id, name='maxkb-embedding', status=Status.SUCCESS, + model_type="EMBEDDING", model_name=model_name, user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab', + provider='model_local_provider', + credential=rsa_long_encrypt(model_credential_str), meta={}, + permission_type=PermissionType.PUBLIC) + model.save() + + +def reverse_code_embedding_model(apps, schema_editor): + ModelModel = apps.get_model('setting', 'Model') + QuerySet(ModelModel).filter(id=default_embedding_model_id).delete() + + +class Migration(migrations.Migration): + dependencies = [ + ('setting', '0004_alter_model_credential'), + ] + + operations = [ + migrations.AddField( + model_name='model', + name='permission_type', + field=models.CharField(choices=[('PUBLIC', '公开'), ('PRIVATE', '私有')], default='PRIVATE', max_length=20, + verbose_name='权限类型'), + ), + migrations.RunPython(save_default_embedding_model, reverse_code_embedding_model) + ] diff --git a/apps/setting/models/model_management.py b/apps/setting/models/model_management.py index 5bdd1b296..42136a48e 100644 --- a/apps/setting/models/model_management.py +++ b/apps/setting/models/model_management.py @@ -23,6 +23,11 @@ class Status(models.TextChoices): DOWNLOAD = "DOWNLOAD", '下载中' +class PermissionType(models.TextChoices): + PUBLIC = "PUBLIC", '公开' + PRIVATE = "PRIVATE", "私有" + + class Model(AppModelMixin): """ 模型数据 @@ -46,6 +51,9 @@ class Model(AppModelMixin): meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict) + permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices, + default=PermissionType.PRIVATE) + class Meta: db_table = "model" unique_together = ['name', 'user_id'] diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index aa6c525e0..8a9ab5e2b 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -61,7 +61,7 @@ class IModelProvider(ABC): def get_model_list(self, model_type): if model_type is None: raise AppApiException(500, '模型类型不能为空') - return self.get_model_info_manage().get_model_list() + return self.get_model_info_manage().get_model_list_by_model_type(model_type) def get_model_credential(self, model_type, model_name): model_info = self.get_model_info_manage().get_model_info(model_type, model_name) @@ -191,6 +191,9 @@ class ModelInfoManage: def get_model_list(self): return [model.to_dict() for model in self.model_list] + def get_model_list_by_model_type(self, model_type): + return [model.to_dict() for model in self.model_list if model.model_type == model_type] + def get_model_type_list(self): return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if len([model for model in self.model_list if model.model_type == _type.name]) > 0] diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py index 1db811f9e..c9e9659c3 100644 --- a/apps/setting/models_provider/constants/model_provider_constants.py +++ b/apps/setting/models_provider/constants/model_provider_constants.py @@ -18,6 +18,7 @@ from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider +from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider class ModelProvideConstants(Enum): @@ -31,3 +32,4 @@ class ModelProvideConstants(Enum): model_xf_provider = XunFeiModelProvider() model_deepseek_provider = DeepSeekModelProvider() model_gemini_provider = GeminiModelProvider() + model_local_provider = LocalModelProvider() diff --git a/apps/setting/models_provider/impl/local_model_provider/__init__.py b/apps/setting/models_provider/impl/local_model_provider/__init__.py new file mode 100644 index 000000000..90a8d72c3 --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/7/10 17:48 + @desc: +""" diff --git a/apps/setting/models_provider/impl/local_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/local_model_provider/credential/embedding.py new file mode 100644 index 000000000..a631196eb --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/credential/embedding.py @@ -0,0 +1,45 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/11 11:06 + @desc: +""" +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding + + +class LocalEmbeddingCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + if not model_type == 'EMBEDDING': + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['cache_folder']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential) + model.embed_query('你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return model + + cache_folder = forms.TextInputField('模型目录', required=True) diff --git a/apps/setting/models_provider/impl/local_model_provider/icon/local_icon_svg b/apps/setting/models_provider/impl/local_model_provider/icon/local_icon_svg new file mode 100644 index 000000000..62930faab --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/icon/local_icon_svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py b/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py new file mode 100644 index 000000000..65cc57322 --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py @@ -0,0 +1,38 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: zhipu_model_provider.py + @date:2024/04/19 13:5 + @desc: +""" +import os +from typing import Dict + +from pydantic import BaseModel + +from common.exception.app_exception import AppApiException +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ + ModelInfoManage +from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential +from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding +from smartdoc.conf import PROJECT_DIR + +embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING, + LocalEmbeddingCredential(), LocalEmbedding) + +model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese) + .append_default_model_info(embedding_text2vec_base_chinese) + .build()) + + +class LocalModelProvider(IModelProvider): + + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='model_local_provider', name='本地模型', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'local_model_provider', 'icon', + 'local_icon_svg'))) diff --git a/apps/setting/models_provider/impl/local_model_provider/model/embedding.py b/apps/setting/models_provider/impl/local_model_provider/model/embedding.py new file mode 100644 index 000000000..92cdd7390 --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/model/embedding.py @@ -0,0 +1,22 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/11 14:06 + @desc: +""" +from typing import Dict + +from langchain_huggingface import HuggingFaceEmbeddings + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), + model_kwargs={'device': model_credential.get('device')}, + encode_kwargs={'normalize_embeddings': True}, + ) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/ollama_model_provider/credential/embedding.py new file mode 100644 index 000000000..e0eeabe59 --- /dev/null +++ b/apps/setting/models_provider/impl/ollama_model_provider/credential/embedding.py @@ -0,0 +1,45 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 15:10 + @desc: +""" +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding + + +class OllamaEmbeddingModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + try: + model_list = provider.get_base_model_list(model_credential.get('api_base')) + except Exception as e: + raise AppApiException(ValidCode.valid_error.value, "API 域名无效") + exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if + model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name] + if len(exist) == 0: + raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型") + model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential) + model.embed_query('你好') + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return model_info + + def build_model(self, model_info: Dict[str, object]): + for key in ['model']: + if key not in model_info: + raise AppApiException(500, f'{key} 字段为必填字段') + return self + + api_base = forms.TextInputField('API 域名', required=True) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/embedding.py b/apps/setting/models_provider/impl/ollama_model_provider/model/embedding.py new file mode 100644 index 000000000..2bb3bb659 --- /dev/null +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/embedding.py @@ -0,0 +1,22 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 15:02 + @desc: +""" +from typing import Dict + +from langchain_community.embeddings import OllamaEmbeddings + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class OllamaEmbedding(MaxKBBaseModel, OllamaEmbeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return OllamaEmbeddings( + model=model_name, + base_url=model_credential.get('api_base'), + ) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py index 3839d5fcb..eb01d38d7 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py @@ -20,7 +20,9 @@ from common.forms import BaseForm from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage +from setting.models_provider.impl.ollama_model_provider.credential.embedding import OllamaEmbeddingModelCredential from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential +from setting.models_provider.impl.ollama_model_provider.model.embedding import OllamaEmbedding from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel from smartdoc.conf import PROJECT_DIR @@ -88,14 +90,25 @@ model_info_list = [ ModelInfo( 'phi3', 'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。', - ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel) + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel), +] +ollama_embedding_model_credential = OllamaEmbeddingModelCredential() +embedding_model_info = [ + ModelInfo( + 'nomic-embed-text', + '一个具有大令牌上下文窗口的高性能开放嵌入模型。', + ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), ] -model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( +model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_model_info_list( + embedding_model_info).append_default_model_info( ModelInfo( 'phi3', 'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。', - ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).build() + ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).append_default_model_info(ModelInfo( + 'nomic-embed-text', + '一个具有大令牌上下文窗口的高性能开放嵌入模型。', + ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), ).build() def get_base_url(url: str): @@ -139,7 +152,6 @@ def convert(response_stream) -> Iterator[DownModelChunk]: temp = "" if len(temp) > 0: - print(temp) rows = [t for t in temp.split("\n") if len(t) > 0] for row in rows: yield convert_to_down_model_chunk(row, index) @@ -154,9 +166,6 @@ class OllamaModelProvider(IModelProvider): os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon', 'ollama_icon_svg'))) - def get_dialogue_number(self): - return 2 - @staticmethod def get_base_model_list(api_base): base_url = get_base_url(api_base) @@ -165,7 +174,7 @@ class OllamaModelProvider(IModelProvider): return r.json() def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]: - api_base = model_credential.get('api_base') + api_base = model_credential.get('api_base', '') base_url = get_base_url(api_base) r = requests.request( method="POST", diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/openai_model_provider/credential/embedding.py new file mode 100644 index 000000000..d49d22e22 --- /dev/null +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/embedding.py @@ -0,0 +1,46 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 16:45 + @desc: +""" +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=True): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.embed_query('你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py b/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py new file mode 100644 index 000000000..5ac1f8e6f --- /dev/null +++ b/apps/setting/models_provider/impl/openai_model_provider/model/embedding.py @@ -0,0 +1,23 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 17:44 + @desc: +""" +from typing import Dict + +from langchain_community.embeddings import OpenAIEmbeddings + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class OpenAIEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return OpenAIEmbeddingModel( + api_key=model_credential.get('api_key'), + model=model_name, + openai_api_base=model_credential.get('api_base'), + ) diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py index 6d12869ce..fb4c89d7b 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py +++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py @@ -11,7 +11,9 @@ import os from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \ ModelTypeConst, ModelInfoManage +from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential +from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel from smartdoc.conf import PROJECT_DIR @@ -58,11 +60,17 @@ model_info_list = [ ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel) ] +open_ai_embedding_credential = OpenAIEmbeddingCredential() +model_info_embedding_list = [ + ModelInfo('text-embedding-ada-002', '', + ModelTypeConst.EMBEDDING, open_ai_embedding_credential, + OpenAIEmbeddingModel)] model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel - )).build() + )).append_model_info_list(model_info_embedding_list).append_default_model_info( + model_info_embedding_list[0]).build() class OpenAIModelProvider(IModelProvider):