refactor: ollama support rerank

This commit is contained in:
wxg0103 2025-02-24 18:35:37 +08:00
parent f02b40b830
commit 9185515660
2 changed files with 34 additions and 69 deletions

View File

@ -64,4 +64,3 @@ class OllamaReRankModelCredential(BaseForm, BaseModelCredential):
return self
api_base = forms.TextInputField('API URL', required=True)
api_key = forms.TextInputField('API Key', required=True)

View File

@ -1,82 +1,48 @@
from typing import Sequence, Optional, Any, Dict
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
import requests
from langchain_core.documents import Document
from langchain_community.embeddings import OllamaEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
from sklearn.metrics.pairwise import cosine_similarity
from pydantic.v1 import BaseModel, Field
class OllamaReranker(MaxKBBaseModel, BaseDocumentCompressor):
api_base: Optional[str]
"""URL of the Ollama server"""
model_name: Optional[str]
"""The model name to use for reranking"""
api_key: Optional[str]
class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
top_n: Optional[int] = Field(3, description="Number of top documents to return")
@staticmethod
def new_instance(model_name, model_credential: Dict[str, object], **model_kwargs):
return OllamaReranker(api_base=model_credential.get('api_base'), model_name=model_name,
api_key=model_credential.get('api_key'), top_n=model_kwargs.get('top_n', 3))
top_n: Optional[int] = 3
def __init__(
self, api_base: Optional[str] = None, model_name: Optional[str] = None, top_n=3,
api_key: Optional[str] = None
):
super().__init__()
if api_base is None:
raise ValueError("Please provide server URL")
if model_name is None:
raise ValueError("Please provide the model name")
self.api_base = api_base
self.model_name = model_name
self.api_key = api_key
self.top_n = top_n
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return OllamaReranker(
model=model_name,
base_url=model_credential.get('api_base'),
**optional_params
)
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
Sequence[Document]:
"""
Given a query and a set of documents, rerank them using Ollama API.
"""
if not documents or len(documents) == 0:
return []
"""Rank documents based on their similarity to the query.
# Prepare the data to send to Ollama API
headers = {
'Authorization': f'Bearer {self.api_key}' # Use API key for authentication if required
}
Args:
query: The query text.
documents: The list of document texts to rank.
# Format the documents to be sent in a format understood by Ollama's API
documents_text = [document.page_content for document in documents]
Returns:
List of documents sorted by relevance to the query.
"""
# 获取查询和文档的嵌入
query_embedding = self.embed_query(query)
documents = [doc.page_content for doc in documents]
document_embeddings = self.embed_documents(documents)
# 计算相似度
similarities = cosine_similarity([query_embedding], document_embeddings)[0]
ranked_docs = [(doc,_) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n]
return [
Document(
page_content=doc, # 第一个值是文档内容
metadata={'relevance_score': score} # 第二个值是相似度分数
)
for doc, score in ranked_docs
]
# Make a POST request to Ollama's rerank API endpoint
payload = {
'model': self.model_name, # Specify the model
'query': query,
'documents': documents_text,
'top_n': self.top_n
}
try:
response = requests.post(f'{self.api_base}/v1/rerank', headers=headers, json=payload)
response.raise_for_status()
res = response.json()
# Ensure the response contains expected results
if 'results' not in res:
raise ValueError("The API response did not contain rerank results.")
# Convert the API response into a list of Document objects with relevance scores
ranked_documents = [
Document(page_content=d['text'], metadata={'relevance_score': d['relevance_score']})
for d in res['results']
]
return ranked_documents
except requests.exceptions.RequestException as e:
print(f"Error during API request: {e}")
return [] # Return an empty list if the request failed