diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py index a5c56ca16..3f9c1bf70 100644 --- a/apps/application/swagger_api/chat_api.py +++ b/apps/application/swagger_api/chat_api.py @@ -165,7 +165,15 @@ class ChatApi(ApiMixin): openapi.Parameter(name='min_trample', in_=openapi.IN_QUERY, type=openapi.TYPE_INTEGER, required=False, description=_("Minimum number of clicks")), openapi.Parameter(name='comparer', in_=openapi.IN_QUERY, type=openapi.TYPE_STRING, required=False, - description=_("or|and comparator")) + description=_("or|and comparator")), + openapi.Parameter(name='start_time', in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=True, + description=_('start time')), + openapi.Parameter(name='end_time', in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=True, + description=_('End time')), ] diff --git a/apps/setting/models_provider/impl/ollama_model_provider/credential/reranker.py b/apps/setting/models_provider/impl/ollama_model_provider/credential/reranker.py new file mode 100644 index 000000000..7f7feff15 --- /dev/null +++ b/apps/setting/models_provider/impl/ollama_model_provider/credential/reranker.py @@ -0,0 +1,67 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 15:10 + @desc: +""" +from typing import Dict + +from django.utils.translation import gettext as _ + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.ollama_model_provider.model.reranker import OllamaReranker +from langchain_core.documents import BaseDocumentCompressor, Document +from django.utils.translation import gettext_lazy as _, gettext + + +class OllamaReRankModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + if not model_type == 'RERANKER': + raise AppApiException(ValidCode.valid_error.value, + gettext('{model_type} Model type is not supported').format(model_type=model_type)) + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, + _('{model_type} Model type is not supported').format(model_type=model_type)) + try: + model_list = provider.get_base_model_list(model_credential.get('api_base')) + except Exception as e: + raise AppApiException(ValidCode.valid_error.value, _('API domain name is invalid')) + exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if + model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name] + if len(exist) == 0: + raise AppApiException(ValidCode.model_not_fount, + _('The model does not exist, please download the model first')) + + try: + model: OllamaReranker = provider.get_model(model_type, model_name, model_credential) + model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello')) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, + gettext( + 'Verification failed, please check whether the parameters are correct: {error}').format( + error=str(e))) + else: + return False + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return model_info + + def build_model(self, model_info: Dict[str, object]): + for key in ['model']: + if key not in model_info: + raise AppApiException(500, _('{key} is required').format(key=key)) + return self + + api_base = forms.TextInputField('API URL', required=True) + api_key = forms.TextInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/reranker.py b/apps/setting/models_provider/impl/ollama_model_provider/model/reranker.py new file mode 100644 index 000000000..fd004ea01 --- /dev/null +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/reranker.py @@ -0,0 +1,82 @@ +from typing import Sequence, Optional, Any, Dict +from langchain_core.callbacks import Callbacks +from langchain_core.documents import BaseDocumentCompressor, Document +import requests + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class OllamaReranker(MaxKBBaseModel, BaseDocumentCompressor): + api_base: Optional[str] + """URL of the Ollama server""" + model_name: Optional[str] + """The model name to use for reranking""" + api_key: Optional[str] + + @staticmethod + def new_instance(model_name, model_credential: Dict[str, object], **model_kwargs): + return OllamaReranker(api_base=model_credential.get('api_base'), model_name=model_name, + api_key=model_credential.get('api_key'), top_n=model_kwargs.get('top_n', 3)) + + top_n: Optional[int] = 3 + + def __init__( + self, api_base: Optional[str] = None, model_name: Optional[str] = None, top_n=3, + api_key: Optional[str] = None + ): + super().__init__() + + if api_base is None: + raise ValueError("Please provide server URL") + + if model_name is None: + raise ValueError("Please provide the model name") + + self.api_base = api_base + self.model_name = model_name + self.api_key = api_key + self.top_n = top_n + + def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ + Sequence[Document]: + """ + Given a query and a set of documents, rerank them using Ollama API. + """ + if not documents or len(documents) == 0: + return [] + + # Prepare the data to send to Ollama API + headers = { + 'Authorization': f'Bearer {self.api_key}' # Use API key for authentication if required + } + + # Format the documents to be sent in a format understood by Ollama's API + documents_text = [document.page_content for document in documents] + + # Make a POST request to Ollama's rerank API endpoint + payload = { + 'model': self.model_name, # Specify the model + 'query': query, + 'documents': documents_text, + 'top_n': self.top_n + } + + try: + response = requests.post(f'{self.api_base}/v1/rerank', headers=headers, json=payload) + response.raise_for_status() + res = response.json() + + # Ensure the response contains expected results + if 'results' not in res: + raise ValueError("The API response did not contain rerank results.") + + # Convert the API response into a list of Document objects with relevance scores + ranked_documents = [ + Document(page_content=d['text'], metadata={'relevance_score': d['relevance_score']}) + for d in res['results'] + ] + return ranked_documents + + except requests.exceptions.RequestException as e: + print(f"Error during API request: {e}") + return [] # Return an empty list if the request failed diff --git a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py index 5da9a7fe0..8a970e6d3 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py @@ -18,9 +18,11 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro from setting.models_provider.impl.ollama_model_provider.credential.embedding import OllamaEmbeddingModelCredential from setting.models_provider.impl.ollama_model_provider.credential.image import OllamaImageModelCredential from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential +from setting.models_provider.impl.ollama_model_provider.credential.reranker import OllamaReRankModelCredential from setting.models_provider.impl.ollama_model_provider.model.embedding import OllamaEmbedding from setting.models_provider.impl.ollama_model_provider.model.image import OllamaImage from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel +from setting.models_provider.impl.ollama_model_provider.model.reranker import OllamaReranker from smartdoc.conf import PROJECT_DIR from django.utils.translation import gettext as _ @@ -153,12 +155,19 @@ model_info_list = [ ] ollama_embedding_model_credential = OllamaEmbeddingModelCredential() ollama_image_model_credential = OllamaImageModelCredential() +ollama_reranker_model_credential = OllamaReRankModelCredential() embedding_model_info = [ ModelInfo( 'nomic-embed-text', _('A high-performance open embedding model with a large token context window.'), ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), ] +reranker_model_info = [ + ModelInfo( + 'ollama:reranker', + '', + ModelTypeConst.RERANKER, ollama_reranker_model_credential, OllamaReranker), +] image_model_info = [ ModelInfo( @@ -189,6 +198,8 @@ model_info_manage = ( ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), ) .append_model_info_list(image_model_info) .append_default_model_info(image_model_info[0]) + .append_model_info_list(reranker_model_info) + .append_default_model_info(reranker_model_info[0]) .build() )