feat: support ollama rerank

--story=1017862 --user=王孝刚 Support adding rerank models in Ollama issue#2243 https://www.tapd.cn/57709429/s/1655139
wxg0103 2025-02-17 18:38:49 +08:00
parent a071d7c89b
commit 7f597b6409
4 changed files with 169 additions and 1 deletion


@@ -165,7 +165,15 @@ class ChatApi(ApiMixin):
openapi.Parameter(name='min_trample', in_=openapi.IN_QUERY, type=openapi.TYPE_INTEGER, required=False,
description=_("Minimum number of clicks")),
openapi.Parameter(name='comparer', in_=openapi.IN_QUERY, type=openapi.TYPE_STRING, required=False,
description=_("or|and comparator"))
description=_("or|and comparator")),
openapi.Parameter(name='start_time', in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=True,
description=_('start time')),
openapi.Parameter(name='end_time', in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=True,
description=_('End time')),
]
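The hunk above makes start_time and end_time required query parameters on this chat-log listing endpoint. A minimal request sketch follows; the endpoint path, token, and date format are placeholders (assumptions), only the parameter names come from the change:

import requests

# Placeholder URL, token, and date format; only the query parameter names reflect the change above.
resp = requests.get(
    'http://localhost:8080/api/<chat-log-endpoint>',
    headers={'Authorization': '<token>'},
    params={
        'start_time': '2025-02-01',  # now required
        'end_time': '2025-02-17',    # now required
        'min_trample': 0,
        'comparer': 'and',
    },
)
print(resp.status_code)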


@@ -0,0 +1,67 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file: reranker.py
@date: 2024/7/12 15:10
@desc:
"""
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.documents import Document

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.ollama_model_provider.model.reranker import OllamaReranker


class OllamaReRankModelCredential(BaseForm, BaseModelCredential):

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        if not model_type == 'RERANKER':
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))
        try:
            model_list = provider.get_base_model_list(model_credential.get('api_base'))
        except Exception as e:
            raise AppApiException(ValidCode.valid_error.value, _('API domain name is invalid'))
        exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if
                 model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
        if len(exist) == 0:
            raise AppApiException(ValidCode.model_not_fount,
                                  _('The model does not exist, please download the model first'))
        try:
            model: OllamaReranker = provider.get_model(model_type, model_name, model_credential)
            model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello'))
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        return model_info

    def build_model(self, model_info: Dict[str, object]):
        for key in ['model']:
            if key not in model_info:
                raise AppApiException(500, _('{key} is required').format(key=key))
        return self

    api_base = forms.TextInputField('API URL', required=True)
    api_key = forms.TextInputField('API Key', required=True)
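
A minimal sketch of how this credential might be validated against a local Ollama instance. The provider import path, the model name, and the credential values are assumptions (the provider module is the one modified further below); only the is_valid signature comes from the class above.

from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider  # assumed path

credential = OllamaReRankModelCredential()
ok = credential.is_valid(
    model_type='RERANKER',
    model_name='bge-reranker-v2-m3',  # assumed: any reranker already pulled into Ollama
    model_credential={'api_base': 'http://localhost:11434', 'api_key': 'EMPTY'},
    model_params={},
    provider=OllamaModelProvider(),
    raise_exception=False,
)
print(ok)  # True when the model is listed by the server and a test rerank call succeeds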


@@ -0,0 +1,82 @@
from typing import Sequence, Optional, Any, Dict
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
import requests
from setting.models_provider.base_model_provider import MaxKBBaseModel
class OllamaReranker(MaxKBBaseModel, BaseDocumentCompressor):
    api_base: Optional[str]
    """URL of the Ollama server"""
    model_name: Optional[str]
    """The model name to use for reranking"""
    api_key: Optional[str]

    @staticmethod
    def new_instance(model_name, model_credential: Dict[str, object], **model_kwargs):
        return OllamaReranker(api_base=model_credential.get('api_base'), model_name=model_name,
                              api_key=model_credential.get('api_key'), top_n=model_kwargs.get('top_n', 3))

    top_n: Optional[int] = 3

    def __init__(
            self, api_base: Optional[str] = None, model_name: Optional[str] = None, top_n=3,
            api_key: Optional[str] = None
    ):
        super().__init__()
        if api_base is None:
            raise ValueError("Please provide server URL")
        if model_name is None:
            raise ValueError("Please provide the model name")
        self.api_base = api_base
        self.model_name = model_name
        self.api_key = api_key
        self.top_n = top_n

    def compress_documents(self, documents: Sequence[Document], query: str,
                           callbacks: Optional[Callbacks] = None) -> Sequence[Document]:
        """
        Given a query and a set of documents, rerank them using the Ollama API.
        """
        if not documents or len(documents) == 0:
            return []
        # Prepare the data to send to the Ollama API
        headers = {
            'Authorization': f'Bearer {self.api_key}'  # Use API key for authentication if required
        }
        # Format the documents to be sent in a format understood by Ollama's API
        documents_text = [document.page_content for document in documents]
        # Make a POST request to Ollama's rerank API endpoint
        payload = {
            'model': self.model_name,  # Specify the model
            'query': query,
            'documents': documents_text,
            'top_n': self.top_n
        }
        try:
            response = requests.post(f'{self.api_base}/v1/rerank', headers=headers, json=payload)
            response.raise_for_status()
            res = response.json()
            # Ensure the response contains the expected results
            if 'results' not in res:
                raise ValueError("The API response did not contain rerank results.")
            # Convert the API response into a list of Document objects with relevance scores
            ranked_documents = [
                Document(page_content=d['text'], metadata={'relevance_score': d['relevance_score']})
                for d in res['results']
            ]
            return ranked_documents
        except requests.exceptions.RequestException as e:
            print(f"Error during API request: {e}")
            return []  # Return an empty list if the request failed
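
A usage sketch for the class above, assuming an Ollama-compatible server reachable at http://localhost:11434 that exposes the /v1/rerank endpoint used by compress_documents, and a reranker model that has already been pulled; the URL and model name are illustrative only.

from langchain_core.documents import Document

reranker = OllamaReranker(
    api_base='http://localhost:11434',  # illustrative server URL
    model_name='bge-reranker-v2-m3',    # illustrative model name
    api_key='EMPTY',
    top_n=2,
)

docs = [
    Document(page_content='MaxKB is a knowledge base question-answering system.'),
    Document(page_content='Ollama runs large language models locally.'),
    Document(page_content='A reranker reorders retrieved passages by relevance to a query.'),
]

# The server is asked for at most top_n results; each returned Document carries
# a 'relevance_score' in its metadata.
for doc in reranker.compress_documents(docs, query='What does a rerank model do?'):
    print(doc.metadata['relevance_score'], doc.page_content)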


@@ -18,9 +18,11 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
from setting.models_provider.impl.ollama_model_provider.credential.embedding import OllamaEmbeddingModelCredential
from setting.models_provider.impl.ollama_model_provider.credential.image import OllamaImageModelCredential
from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential
from setting.models_provider.impl.ollama_model_provider.credential.reranker import OllamaReRankModelCredential
from setting.models_provider.impl.ollama_model_provider.model.embedding import OllamaEmbedding
from setting.models_provider.impl.ollama_model_provider.model.image import OllamaImage
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
from setting.models_provider.impl.ollama_model_provider.model.reranker import OllamaReranker
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext as _
@@ -153,12 +155,19 @@ model_info_list = [
]
ollama_embedding_model_credential = OllamaEmbeddingModelCredential()
ollama_image_model_credential = OllamaImageModelCredential()
ollama_reranker_model_credential = OllamaReRankModelCredential()
embedding_model_info = [
    ModelInfo(
        'nomic-embed-text',
        _('A high-performance open embedding model with a large token context window.'),
        ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding),
]
reranker_model_info = [
    ModelInfo(
        'ollama:reranker',
        '',
        ModelTypeConst.RERANKER, ollama_reranker_model_credential, OllamaReranker),
]
image_model_info = [
    ModelInfo(
@@ -189,6 +198,8 @@ model_info_manage = (
                              ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), )
    .append_model_info_list(image_model_info)
    .append_default_model_info(image_model_info[0])
    .append_model_info_list(reranker_model_info)
    .append_default_model_info(reranker_model_info[0])
    .build()
)
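
For context, the registration above routes model creation through OllamaReranker, whose new_instance factory builds the model from stored credentials. A short sketch with illustrative values; only the factory signature comes from the new model class:

model_credential = {'api_base': 'http://localhost:11434', 'api_key': 'EMPTY'}  # illustrative values

reranker = OllamaReranker.new_instance(
    model_name='bge-reranker-v2-m3',  # illustrative: any reranker available on the Ollama server
    model_credential=model_credential,
    top_n=3,
)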