From 9185515660ad13af85af1b43ff8d6b7b6d6daf23 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Mon, 24 Feb 2025 18:35:37 +0800 Subject: [PATCH] refactor: ollama support rerank --- .../credential/reranker.py | 1 - .../ollama_model_provider/model/reranker.py | 102 ++++++------------ 2 files changed, 34 insertions(+), 69 deletions(-) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/credential/reranker.py b/apps/setting/models_provider/impl/ollama_model_provider/credential/reranker.py index 7f7feff15..c2825aacb 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/credential/reranker.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/credential/reranker.py @@ -64,4 +64,3 @@ class OllamaReRankModelCredential(BaseForm, BaseModelCredential): return self api_base = forms.TextInputField('API URL', required=True) - api_key = forms.TextInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/reranker.py b/apps/setting/models_provider/impl/ollama_model_provider/model/reranker.py index fd004ea01..f82c9a21a 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/model/reranker.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/reranker.py @@ -1,82 +1,48 @@ from typing import Sequence, Optional, Any, Dict from langchain_core.callbacks import Callbacks -from langchain_core.documents import BaseDocumentCompressor, Document -import requests - +from langchain_core.documents import Document +from langchain_community.embeddings import OllamaEmbeddings from setting.models_provider.base_model_provider import MaxKBBaseModel +from sklearn.metrics.pairwise import cosine_similarity +from pydantic.v1 import BaseModel, Field -class OllamaReranker(MaxKBBaseModel, BaseDocumentCompressor): - api_base: Optional[str] - """URL of the Ollama server""" - model_name: Optional[str] - """The model name to use for reranking""" - api_key: Optional[str] +class OllamaReranker(MaxKBBaseModel, OllamaEmbeddings, BaseModel): + top_n: Optional[int] = Field(3, description="Number of top documents to return") @staticmethod - def new_instance(model_name, model_credential: Dict[str, object], **model_kwargs): - return OllamaReranker(api_base=model_credential.get('api_base'), model_name=model_name, - api_key=model_credential.get('api_key'), top_n=model_kwargs.get('top_n', 3)) - - top_n: Optional[int] = 3 - - def __init__( - self, api_base: Optional[str] = None, model_name: Optional[str] = None, top_n=3, - api_key: Optional[str] = None - ): - super().__init__() - - if api_base is None: - raise ValueError("Please provide server URL") - - if model_name is None: - raise ValueError("Please provide the model name") - - self.api_base = api_base - self.model_name = model_name - self.api_key = api_key - self.top_n = top_n + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + return OllamaReranker( + model=model_name, + base_url=model_credential.get('api_base'), + **optional_params + ) def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ Sequence[Document]: - """ - Given a query and a set of documents, rerank them using Ollama API. - """ - if not documents or len(documents) == 0: - return [] + """Rank documents based on their similarity to the query. - # Prepare the data to send to Ollama API - headers = { - 'Authorization': f'Bearer {self.api_key}' # Use API key for authentication if required - } + Args: + query: The query text. + documents: The list of document texts to rank. - # Format the documents to be sent in a format understood by Ollama's API - documents_text = [document.page_content for document in documents] + Returns: + List of documents sorted by relevance to the query. + """ + # 获取查询和文档的嵌入 + query_embedding = self.embed_query(query) + documents = [doc.page_content for doc in documents] + document_embeddings = self.embed_documents(documents) + # 计算相似度 + similarities = cosine_similarity([query_embedding], document_embeddings)[0] + ranked_docs = [(doc,_) for _, doc in sorted(zip(similarities, documents), reverse=True)][:self.top_n] + return [ + Document( + page_content=doc, # 第一个值是文档内容 + metadata={'relevance_score': score} # 第二个值是相似度分数 + ) + for doc, score in ranked_docs + ] - # Make a POST request to Ollama's rerank API endpoint - payload = { - 'model': self.model_name, # Specify the model - 'query': query, - 'documents': documents_text, - 'top_n': self.top_n - } - try: - response = requests.post(f'{self.api_base}/v1/rerank', headers=headers, json=payload) - response.raise_for_status() - res = response.json() - - # Ensure the response contains expected results - if 'results' not in res: - raise ValueError("The API response did not contain rerank results.") - - # Convert the API response into a list of Document objects with relevance scores - ranked_documents = [ - Document(page_content=d['text'], metadata={'relevance_score': d['relevance_score']}) - for d in res['results'] - ] - return ranked_documents - - except requests.exceptions.RequestException as e: - print(f"Error during API request: {e}") - return [] # Return an empty list if the request failed