MaxKB/apps/models_provider/impl/xinference_model_provider/model/reranker.py
2025-04-17 18:01:33 +08:00

55 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
@project: MaxKB
@Author
@file reranker.py
@date2024/9/10 9:45
@desc:
"""
from typing import Sequence, Optional, Any, Dict
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from xinference_client.client.restful.restful_client import RESTfulRerankModelHandle
from models_provider.base_model_provider import MaxKBBaseModel
class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Document compressor that delegates reranking to an Xinference server.

    The server is reached through ``RESTfulClient`` using ``server_url`` and
    ``api_key``; the rerank model is looked up by ``model_uid``.
    """

    server_url: Optional[str]
    """URL of the xinference server"""
    model_uid: Optional[str]
    """UID of the launched model"""
    api_key: Optional[str]
    """API key handed to the REST client together with ``server_url``"""
    top_n: Optional[int] = 3
    """Value forwarded as the third positional argument of the rerank call"""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a reranker from a credential mapping and optional settings.

        Args:
            model_type: Not used by this implementation.
            model_name: Stored as ``model_uid``.
            model_credential: Mapping read for ``'server_url'`` and ``'api_key'``.
            **model_kwargs: Only ``'top_n'`` is read; it defaults to 3.

        Returns:
            A configured ``XInferenceReranker``.
        """
        return XInferenceReranker(
            server_url=model_credential.get('server_url'),
            model_uid=model_name,
            api_key=model_credential.get('api_key'),
            top_n=model_kwargs.get('top_n', 3),
        )

    def compress_documents(self, documents: Sequence[Document], query: str,
                           callbacks: Optional[Callbacks] = None) -> Sequence[Document]:
        """Rerank ``documents`` against ``query`` and return the scored results.

        Args:
            documents: Candidate documents; only ``page_content`` is sent.
            query: Text the documents are ranked against.
            callbacks: Accepted for interface compatibility; not used here.

        Returns:
            New ``Document`` objects built from the rerank response. Their
            metadata holds only ``'relevance_score'``; metadata of the input
            documents is not carried over. An empty list is returned when
            ``documents`` is ``None`` or empty.

        Raises:
            ImportError: If neither ``xinference`` nor ``xinference_client``
                provides ``RESTfulClient``.
        """
        # Nothing to rank: skip the import and the network round trip.
        if not documents:
            return []

        # Prefer the full xinference package, fall back to the slim client.
        try:
            from xinference.client import RESTfulClient
        except ImportError:
            try:
                from xinference_client import RESTfulClient
            except ImportError as e:
                raise ImportError(
                    "Could not import RESTfulClient from xinference. Please install it"
                    " with `pip install xinference` or `pip install xinference_client`."
                ) from e

        client = RESTfulClient(self.server_url, self.api_key)
        model: RESTfulRerankModelHandle = client.get_model(self.model_uid)
        # return_documents=True so each result carries the text it scored.
        res = model.rerank(
            [document.page_content for document in documents],
            query,
            self.top_n,
            return_documents=True,
        )
        return [
            Document(
                page_content=d.get('document', {}).get('text'),
                metadata={'relevance_score': d.get('relevance_score')},
            )
            for d in res.get('results', [])
        ]