# --- apps/setting/models_provider/impl/vllm_model_provider/credential/embedding.py ---
# coding=utf-8
"""
    @project: MaxKB
    @Author:虎
    @file: embedding.py
    @date:2024/7/12 16:45
    @desc: Credential form and validator for embedding models served by a
           vLLM OpenAI-compatible endpoint.
"""
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _


class VllmEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form for vLLM embedding models.

    Collects the endpoint URL and API key, and verifies them by issuing a
    real ``embed_query`` request against the server.
    """

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=True):
        """Validate the supplied credential by round-tripping an embedding request.

        :param model_type: must be one of the types advertised by ``provider``
        :param model_name: model identifier forwarded to ``provider.get_model``
        :param model_credential: dict that must contain ``api_base`` and ``api_key``
        :param model_params: accepted for interface compatibility; not used here
        :param provider: model provider used to build the model instance
        :param raise_exception: when False, failures return False instead of raising
        :return: True when the credential works; otherwise False or AppApiException
        :raises AppApiException: unsupported model type (always), or any
            validation failure when ``raise_exception`` is True
        """
        model_type_list = provider.get_model_type_list()
        # An unsupported model type is a caller error: raised unconditionally,
        # regardless of raise_exception.
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        # Presence check only — an empty string still passes here and is caught
        # by the live request below.
        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Smoke-test with a short query; any connectivity/auth/model-name
            # problem surfaces here.
            model.embed_query('你好')
        except AppApiException:
            # Validation errors raised by the provider propagate untouched.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return ``model`` with its API key replaced by the encrypted form."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_base = forms.TextInputField('API Url', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)


# --- apps/setting/models_provider/impl/vllm_model_provider/model/embedding.py ---
# coding=utf-8
"""
    @project: MaxKB
    @Author:虎
    @file: embedding.py
    @date:2024/7/12 17:44
    @desc: Embedding model wrapper for vLLM's OpenAI-compatible API.
"""
from langchain_community.embeddings import OpenAIEmbeddings

from setting.models_provider.base_model_provider import MaxKBBaseModel


class VllmEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
    """OpenAI-compatible embedding client pointed at a vLLM server."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry.

        :param model_type: unused; kept for the shared factory signature
        :param model_name: model id requested from the vLLM server
        :param model_credential: dict holding ``api_key`` and ``api_base``
        :param model_kwargs: currently ignored — accepted for interface compatibility
        :return: a configured ``VllmEmbeddingModel`` instance
        """
        # NOTE(review): langchain's OpenAIEmbeddings tokenizes input by default
        # (check_embedding_ctx_length=True); some OpenAI-compatible backends
        # reject token-id payloads — confirm against a live vLLM deployment.
        return VllmEmbeddingModel(
            model=model_name,
            openai_api_key=model_credential.get('api_key'),
            openai_api_base=model_credential.get('api_base'),
        )
get_file_content from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ ModelInfoManage +from setting.models_provider.impl.vllm_model_provider.credential.embedding import VllmEmbeddingCredential from setting.models_provider.impl.vllm_model_provider.credential.image import VllmImageModelCredential from setting.models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential +from setting.models_provider.impl.vllm_model_provider.model.embedding import VllmEmbeddingModel from setting.models_provider.impl.vllm_model_provider.model.image import VllmImage from setting.models_provider.impl.vllm_model_provider.model.llm import VllmChatModel from smartdoc.conf import PROJECT_DIR @@ -16,6 +18,7 @@ from django.utils.translation import gettext_lazy as _ v_llm_model_credential = VLLMModelCredential() image_model_credential = VllmImageModelCredential() +embedding_model_credential = VllmEmbeddingCredential() model_info_list = [ ModelInfo('facebook/opt-125m', _('Facebook’s 125M parameter model'), ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), @@ -28,6 +31,10 @@ image_model_info_list = [ ModelInfo('Qwen/Qwen2-VL-2B-Instruct', '', ModelTypeConst.IMAGE, image_model_credential, VllmImage), ] +embedding_model_info_list = [ + ModelInfo('HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5', '', ModelTypeConst.EMBEDDING, embedding_model_credential, VllmEmbeddingModel), +] + model_info_manage = ( ModelInfoManage.builder() .append_model_info_list(model_info_list) @@ -36,6 +43,8 @@ model_info_manage = ( ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel)) .append_model_info_list(image_model_info_list) .append_default_model_info(image_model_info_list[0]) + .append_model_info_list(embedding_model_info_list) + .append_default_model_info(embedding_model_info_list[0]) .build() )