mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
refactor: ollama support rerank
This commit is contained in:
parent
f02b40b830
commit
9185515660
|
|
@ -64,4 +64,3 @@ class OllamaReRankModelCredential(BaseForm, BaseModelCredential):
|
|||
return self
|
||||
|
||||
api_base = forms.TextInputField('API URL', required=True)
|
||||
api_key = forms.TextInputField('API Key', required=True)
|
||||
|
|
|
|||
|
|
@ -1,82 +1,48 @@
|
|||
from typing import Sequence, Optional, Any, Dict
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
import requests
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
|
||||
class OllamaReranker(MaxKBBaseModel, BaseDocumentCompressor):
|
||||
api_base: Optional[str]
|
||||
"""URL of the Ollama server"""
|
||||
model_name: Optional[str]
|
||||
"""The model name to use for reranking"""
|
||||
api_key: Optional[str]
|
||||
class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel):
|
||||
top_n: Optional[int] = Field(3, description="Number of top documents to return")
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
return OllamaReranker(api_base=model_credential.get('api_base'), model_name=model_name,
|
||||
api_key=model_credential.get('api_key'), top_n=model_kwargs.get('top_n', 3))
|
||||
|
||||
top_n: Optional[int] = 3
|
||||
|
||||
def __init__(
|
||||
self, api_base: Optional[str] = None, model_name: Optional[str] = None, top_n=3,
|
||||
api_key: Optional[str] = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if api_base is None:
|
||||
raise ValueError("Please provide server URL")
|
||||
|
||||
if model_name is None:
|
||||
raise ValueError("Please provide the model name")
|
||||
|
||||
self.api_base = api_base
|
||||
self.model_name = model_name
|
||||
self.api_key = api_key
|
||||
self.top_n = top_n
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||
return OllamaReranker(
|
||||
model=model_name,
|
||||
base_url=model_credential.get('api_base'),
|
||||
**optional_params
|
||||
)
|
||||
|
||||
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
|
||||
Sequence[Document]:
|
||||
"""
|
||||
Given a query and a set of documents, rerank them using Ollama API.
|
||||
"""
|
||||
if not documents or len(documents) == 0:
|
||||
return []
|
||||
"""Rank documents based on their similarity to the query.
|
||||
|
||||
# Prepare the data to send to Ollama API
|
||||
headers = {
|
||||
'Authorization': f'Bearer {self.api_key}' # Use API key for authentication if required
|
||||
}
|
||||
Args:
|
||||
query: The query text.
|
||||
documents: The list of document texts to rank.
|
||||
|
||||
# Format the documents to be sent in a format understood by Ollama's API
|
||||
documents_text = [document.page_content for document in documents]
|
||||
Returns:
|
||||
List of documents sorted by relevance to the query.
|
||||
"""
|
||||
# 获取查询和文档的嵌入
|
||||
query_embedding = self.embed_query(query)
|
||||
documents = [doc.page_content for doc in documents]
|
||||
document_embeddings = self.embed_documents(documents)
|
||||
# 计算相似度
|
||||
similarities = cosine_similarity([query_embedding], document_embeddings)[0]
|
||||
ranked_docs = [(doc,_) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n]
|
||||
return [
|
||||
Document(
|
||||
page_content=doc, # 第一个值是文档内容
|
||||
metadata={'relevance_score': score} # 第二个值是相似度分数
|
||||
)
|
||||
for doc, score in ranked_docs
|
||||
]
|
||||
|
||||
# Make a POST request to Ollama's rerank API endpoint
|
||||
payload = {
|
||||
'model': self.model_name, # Specify the model
|
||||
'query': query,
|
||||
'documents': documents_text,
|
||||
'top_n': self.top_n
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(f'{self.api_base}/v1/rerank', headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
res = response.json()
|
||||
|
||||
# Ensure the response contains expected results
|
||||
if 'results' not in res:
|
||||
raise ValueError("The API response did not contain rerank results.")
|
||||
|
||||
# Convert the API response into a list of Document objects with relevance scores
|
||||
ranked_documents = [
|
||||
Document(page_content=d['text'], metadata={'relevance_score': d['relevance_score']})
|
||||
for d in res['results']
|
||||
]
|
||||
return ranked_documents
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"Error during API request: {e}")
|
||||
return [] # Return an empty list if the request failed
|
||||
|
|
|
|||
Loading…
Reference in New Issue