fix: xinference rerank error

--bug=1054256 --user=王孝刚 【模型】添加硅基流动的重排序模型失败 https://www.tapd.cn/57709429/s/1679612
This commit is contained in:
wxg0103 2025-04-02 16:54:45 +08:00
parent 6cf91098d6
commit 2686e76c8a

View File

@ -16,7 +16,6 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel
class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
client: Any
server_url: Optional[str]
"""URL of the xinference server"""
model_uid: Optional[str]
@ -30,10 +29,13 @@ class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
top_n: Optional[int] = 3
def __init__(
self, server_url: Optional[str] = None, model_uid: Optional[str] = None, top_n=3,
api_key: Optional[str] = None
):
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
Sequence[Document]:
if documents is None or len(documents) == 0:
return []
client: Any
if documents is None or len(documents) == 0:
return []
try:
from xinference.client import RESTfulClient
except ImportError:
@ -45,29 +47,8 @@ class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
" with `pip install xinference` or `pip install xinference_client`."
) from e
super().__init__()
if server_url is None:
raise ValueError("Please provide server URL")
if model_uid is None:
raise ValueError("Please provide the model UID")
self.server_url = server_url
self.model_uid = model_uid
self.api_key = api_key
self.client = RESTfulClient(server_url, api_key)
self.top_n = top_n
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
Sequence[Document]:
if documents is None or len(documents) == 0:
return []
model: RESTfulRerankModelHandle = self.client.get_model(self.model_uid)
client = RESTfulClient(self.server_url, self.api_key)
model: RESTfulRerankModelHandle = client.get_model(self.model_uid)
res = model.rerank([document.page_content for document in documents], query, self.top_n, return_documents=True)
return [Document(page_content=d.get('document', {}).get('text'),
metadata={'relevance_score': d.get('relevance_score')}) for d in res.get('results', [])]