mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-30 01:32:49 +00:00
feat: support ollama rerank
--story=1017862 --user=王孝刚 希望支持在 Ollama 中添加 rerank 模型 issue#2243 https://www.tapd.cn/57709429/s/1655139
This commit is contained in:
parent
a071d7c89b
commit
7f597b6409
|
|
@ -165,7 +165,15 @@ class ChatApi(ApiMixin):
|
|||
openapi.Parameter(name='min_trample', in_=openapi.IN_QUERY, type=openapi.TYPE_INTEGER, required=False,
|
||||
description=_("Minimum number of clicks")),
|
||||
openapi.Parameter(name='comparer', in_=openapi.IN_QUERY, type=openapi.TYPE_STRING, required=False,
|
||||
description=_("or|and comparator"))
|
||||
description=_("or|and comparator")),
|
||||
openapi.Parameter(name='start_time', in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description=_('start time')),
|
||||
openapi.Parameter(name='end_time', in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description=_('End time')),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,67 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/12 15:10
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from setting.models_provider.impl.ollama_model_provider.model.reranker import OllamaReranker
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
||||
|
||||
class OllamaReRankModelCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||
raise_exception=False):
|
||||
if not model_type == 'RERANKER':
|
||||
raise AppApiException(ValidCode.valid_error.value,
|
||||
gettext('{model_type} Model type is not supported').format(model_type=model_type))
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value,
|
||||
_('{model_type} Model type is not supported').format(model_type=model_type))
|
||||
try:
|
||||
model_list = provider.get_base_model_list(model_credential.get('api_base'))
|
||||
except Exception as e:
|
||||
raise AppApiException(ValidCode.valid_error.value, _('API domain name is invalid'))
|
||||
exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if
|
||||
model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
|
||||
if len(exist) == 0:
|
||||
raise AppApiException(ValidCode.model_not_fount,
|
||||
_('The model does not exist, please download the model first'))
|
||||
|
||||
try:
|
||||
model: OllamaReranker = provider.get_model(model_type, model_name, model_credential)
|
||||
model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello'))
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value,
|
||||
gettext(
|
||||
'Verification failed, please check whether the parameters are correct: {error}').format(
|
||||
error=str(e)))
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model_info: Dict[str, object]):
|
||||
return model_info
|
||||
|
||||
def build_model(self, model_info: Dict[str, object]):
|
||||
for key in ['model']:
|
||||
if key not in model_info:
|
||||
raise AppApiException(500, _('{key} is required').format(key=key))
|
||||
return self
|
||||
|
||||
api_base = forms.TextInputField('API URL', required=True)
|
||||
api_key = forms.TextInputField('API Key', required=True)
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
from typing import Sequence, Optional, Any, Dict
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
import requests
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class OllamaReranker(MaxKBBaseModel, BaseDocumentCompressor):
|
||||
api_base: Optional[str]
|
||||
"""URL of the Ollama server"""
|
||||
model_name: Optional[str]
|
||||
"""The model name to use for reranking"""
|
||||
api_key: Optional[str]
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
return OllamaReranker(api_base=model_credential.get('api_base'), model_name=model_name,
|
||||
api_key=model_credential.get('api_key'), top_n=model_kwargs.get('top_n', 3))
|
||||
|
||||
top_n: Optional[int] = 3
|
||||
|
||||
def __init__(
|
||||
self, api_base: Optional[str] = None, model_name: Optional[str] = None, top_n=3,
|
||||
api_key: Optional[str] = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if api_base is None:
|
||||
raise ValueError("Please provide server URL")
|
||||
|
||||
if model_name is None:
|
||||
raise ValueError("Please provide the model name")
|
||||
|
||||
self.api_base = api_base
|
||||
self.model_name = model_name
|
||||
self.api_key = api_key
|
||||
self.top_n = top_n
|
||||
|
||||
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
|
||||
Sequence[Document]:
|
||||
"""
|
||||
Given a query and a set of documents, rerank them using Ollama API.
|
||||
"""
|
||||
if not documents or len(documents) == 0:
|
||||
return []
|
||||
|
||||
# Prepare the data to send to Ollama API
|
||||
headers = {
|
||||
'Authorization': f'Bearer {self.api_key}' # Use API key for authentication if required
|
||||
}
|
||||
|
||||
# Format the documents to be sent in a format understood by Ollama's API
|
||||
documents_text = [document.page_content for document in documents]
|
||||
|
||||
# Make a POST request to Ollama's rerank API endpoint
|
||||
payload = {
|
||||
'model': self.model_name, # Specify the model
|
||||
'query': query,
|
||||
'documents': documents_text,
|
||||
'top_n': self.top_n
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(f'{self.api_base}/v1/rerank', headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
res = response.json()
|
||||
|
||||
# Ensure the response contains expected results
|
||||
if 'results' not in res:
|
||||
raise ValueError("The API response did not contain rerank results.")
|
||||
|
||||
# Convert the API response into a list of Document objects with relevance scores
|
||||
ranked_documents = [
|
||||
Document(page_content=d['text'], metadata={'relevance_score': d['relevance_score']})
|
||||
for d in res['results']
|
||||
]
|
||||
return ranked_documents
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"Error during API request: {e}")
|
||||
return [] # Return an empty list if the request failed
|
||||
|
|
@ -18,9 +18,11 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
|
|||
from setting.models_provider.impl.ollama_model_provider.credential.embedding import OllamaEmbeddingModelCredential
|
||||
from setting.models_provider.impl.ollama_model_provider.credential.image import OllamaImageModelCredential
|
||||
from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential
|
||||
from setting.models_provider.impl.ollama_model_provider.credential.reranker import OllamaReRankModelCredential
|
||||
from setting.models_provider.impl.ollama_model_provider.model.embedding import OllamaEmbedding
|
||||
from setting.models_provider.impl.ollama_model_provider.model.image import OllamaImage
|
||||
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
|
||||
from setting.models_provider.impl.ollama_model_provider.model.reranker import OllamaReranker
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
|
|
@ -153,12 +155,19 @@ model_info_list = [
|
|||
]
|
||||
ollama_embedding_model_credential = OllamaEmbeddingModelCredential()
|
||||
ollama_image_model_credential = OllamaImageModelCredential()
|
||||
ollama_reranker_model_credential = OllamaReRankModelCredential()
|
||||
embedding_model_info = [
|
||||
ModelInfo(
|
||||
'nomic-embed-text',
|
||||
_('A high-performance open embedding model with a large token context window.'),
|
||||
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding),
|
||||
]
|
||||
reranker_model_info = [
|
||||
ModelInfo(
|
||||
'ollama:reranker',
|
||||
'',
|
||||
ModelTypeConst.RERANKER, ollama_reranker_model_credential, OllamaReranker),
|
||||
]
|
||||
|
||||
image_model_info = [
|
||||
ModelInfo(
|
||||
|
|
@ -189,6 +198,8 @@ model_info_manage = (
|
|||
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), )
|
||||
.append_model_info_list(image_model_info)
|
||||
.append_default_model_info(image_model_info[0])
|
||||
.append_model_info_list(reranker_model_info)
|
||||
.append_default_model_info(reranker_model_info[0])
|
||||
.build()
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue