diff --git a/apps/models_provider/impl/vllm_model_provider/credential/reranker.py b/apps/models_provider/impl/vllm_model_provider/credential/reranker.py
new file mode 100644
--- /dev/null
+++ b/apps/models_provider/impl/vllm_model_provider/credential/reranker.py
@@ -0,0 +1,50 @@
+import traceback
+from typing import Dict
+
+from langchain_core.documents import Document
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from models_provider.base_model_provider import BaseModelCredential, ValidCode
+from django.utils.translation import gettext_lazy as _
+
+from models_provider.impl.vllm_model_provider.model.reranker import VllmBgeReranker
+
+
+class VllmRerankerCredential(BaseForm, BaseModelCredential):
+    api_url = forms.TextInputField('API URL', required=True)
+    api_key = forms.PasswordInputField('API Key', required=True)
+
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
+                 raise_exception=True):
+        model_type_list = provider.get_model_type_list()
+        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+            raise AppApiException(ValidCode.valid_error.value,
+                                  _('{model_type} Model type is not supported').format(model_type=model_type))
+
+        for key in ['api_url', 'api_key']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
+                else:
+                    return False
+        try:
+            model: VllmBgeReranker = provider.get_model(model_type, model_name, model_credential)
+            model.compress_documents([Document(page_content=_('Hello'))], _('Hello'))
+        except Exception as e:
+            traceback.print_exc()
+            if isinstance(e, AppApiException):
+                raise e
+            if raise_exception:
+                raise AppApiException(
+                    ValidCode.valid_error.value,
+                    _('Verification failed, please check whether the parameters are correct: {error}').format(
+                        error=str(e))
+                )
+            return False
+
+        return True
+
+    def encryption_dict(self, model_info: Dict[str, object]):
+        return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}
diff --git a/apps/models_provider/impl/vllm_model_provider/model/reranker.py b/apps/models_provider/impl/vllm_model_provider/model/reranker.py
new file mode 100644
--- /dev/null
+++ b/apps/models_provider/impl/vllm_model_provider/model/reranker.py
@@ -0,0 +1,50 @@
+from typing import Sequence, Optional, Dict, Any
+
+import cohere
+from langchain_core.callbacks import Callbacks
+from langchain_core.documents import BaseDocumentCompressor, Document
+
+from models_provider.base_model_provider import MaxKBBaseModel
+
+
+class VllmBgeReranker(MaxKBBaseModel, BaseDocumentCompressor):
+    api_key: str
+    api_url: str
+    model: str
+    params: dict
+    client: Any = None
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.api_key = kwargs.get('api_key')
+        self.model = kwargs.get('model')
+        self.params = kwargs.get('params')
+        self.api_url = kwargs.get('api_url')
+        self.client = cohere.ClientV2(kwargs.get('api_key'), base_url=kwargs.get('api_url'))
+
+    @staticmethod
+    def is_cache_model():
+        return False
+
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        return VllmBgeReranker(
+            model=model_name,
+            api_key=model_credential.get('api_key'),
+            api_url=model_credential.get('api_url'),
+            params=model_kwargs,
+            **model_kwargs
+        )
+
+    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
+            Sequence[Document]:
+        if documents is None or len(documents) == 0:
+            return []
+
+        ds = [d.page_content for d in documents]
+        result = self.client.rerank(model=self.model, query=query, documents=ds)
+        # The v2 rerank response carries only (index, relevance_score) per result
+        # unless return_documents=True, so map each result back onto the caller's
+        # original documents by index instead of reading result.document.
+        return [Document(page_content=documents[item.index].page_content,
+                         metadata={'relevance_score': item.relevance_score})
+                for item in result.results]
diff --git a/apps/models_provider/impl/vllm_model_provider/vllm_model_provider.py b/apps/models_provider/impl/vllm_model_provider/vllm_model_provider.py
index 12da7ce66..009fe88b8 100644
--- a/apps/models_provider/impl/vllm_model_provider/vllm_model_provider.py
+++ b/apps/models_provider/impl/vllm_model_provider/vllm_model_provider.py
@@ -10,6 +10,7 @@ from models_provider.base_model_provider import IModelProvider, ModelProvideInfo
 from models_provider.impl.vllm_model_provider.credential.embedding import VllmEmbeddingCredential
 from models_provider.impl.vllm_model_provider.credential.image import VllmImageModelCredential
 from models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential
+from models_provider.impl.vllm_model_provider.credential.reranker import VllmRerankerCredential
 from models_provider.impl.vllm_model_provider.credential.whisper_stt import VLLMWhisperModelCredential
 from models_provider.impl.vllm_model_provider.model.embedding import VllmEmbeddingModel
 from models_provider.impl.vllm_model_provider.model.image import VllmImage
@@ -17,12 +18,14 @@ from models_provider.impl.vllm_model_provider.model.llm import VllmChatModel
 from maxkb.conf import PROJECT_DIR
 from django.utils.translation import gettext as _
 
+from models_provider.impl.vllm_model_provider.model.reranker import VllmBgeReranker
 from models_provider.impl.vllm_model_provider.model.whisper_sst import VllmWhisperSpeechToText
 
 v_llm_model_credential = VLLMModelCredential()
 image_model_credential = VllmImageModelCredential()
 embedding_model_credential = VllmEmbeddingCredential()
 whisper_model_credential = VLLMWhisperModelCredential()
+rerank_model_credential = VllmRerankerCredential()
 
 model_info_list = [
     ModelInfo('facebook/opt-125m', _('Facebook’s 125M parameter model'), ModelTypeConst.LLM, v_llm_model_credential,
@@ -50,6 +53,10 @@ whisper_model_info_list = [
     ModelInfo('whisper-large-v3', '', ModelTypeConst.STT, whisper_model_credential,
               VllmWhisperSpeechToText),
 ]
+reranker_model_info_list = [
+    ModelInfo('bge-reranker-v2-m3', '', ModelTypeConst.RERANKER, rerank_model_credential, VllmBgeReranker),
+]
+
 model_info_manage = (
     ModelInfoManage.builder()
     .append_model_info_list(model_info_list)
@@ -62,6 +69,8 @@ model_info_manage = (
     .append_default_model_info(embedding_model_info_list[0])
     .append_model_info_list(whisper_model_info_list)
     .append_default_model_info(whisper_model_info_list[0])
+    .append_model_info_list(reranker_model_info_list)
+    .append_default_model_info(reranker_model_info_list[0])
     .build()
 )
diff --git a/pyproject.toml b/pyproject.toml
index 5c60d2142..70458993a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,6 +57,7 @@ dependencies = [
     "python-daemon==3.1.2",
     "websockets==15.0.1",
     "pylint==3.3.7",
+    "cohere>=5.17.0",
 ]
 
 [tool.uv]