mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
fix: xinference rerank error
--bug=1054256 --user=王孝刚 【模型】添加硅基流动的重排序模型失败 https://www.tapd.cn/57709429/s/1679612
This commit is contained in:
parent
6cf91098d6
commit
2686e76c8a
|
|
@ -16,7 +16,6 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel
|
|||
|
||||
|
||||
class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
|
||||
client: Any
|
||||
server_url: Optional[str]
|
||||
"""URL of the xinference server"""
|
||||
model_uid: Optional[str]
|
||||
|
|
@ -30,10 +29,13 @@ class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
|
|||
|
||||
top_n: Optional[int] = 3
|
||||
|
||||
def __init__(
|
||||
self, server_url: Optional[str] = None, model_uid: Optional[str] = None, top_n=3,
|
||||
api_key: Optional[str] = None
|
||||
):
|
||||
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
|
||||
Sequence[Document]:
|
||||
if documents is None or len(documents) == 0:
|
||||
return []
|
||||
client: Any
|
||||
if documents is None or len(documents) == 0:
|
||||
return []
|
||||
try:
|
||||
from xinference.client import RESTfulClient
|
||||
except ImportError:
|
||||
|
|
@ -45,29 +47,8 @@ class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
|
|||
" with `pip install xinference` or `pip install xinference_client`."
|
||||
) from e
|
||||
|
||||
super().__init__()
|
||||
|
||||
if server_url is None:
|
||||
raise ValueError("Please provide server URL")
|
||||
|
||||
if model_uid is None:
|
||||
raise ValueError("Please provide the model UID")
|
||||
|
||||
self.server_url = server_url
|
||||
|
||||
self.model_uid = model_uid
|
||||
|
||||
self.api_key = api_key
|
||||
|
||||
self.client = RESTfulClient(server_url, api_key)
|
||||
|
||||
self.top_n = top_n
|
||||
|
||||
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
|
||||
Sequence[Document]:
|
||||
if documents is None or len(documents) == 0:
|
||||
return []
|
||||
model: RESTfulRerankModelHandle = self.client.get_model(self.model_uid)
|
||||
client = RESTfulClient(self.server_url, self.api_key)
|
||||
model: RESTfulRerankModelHandle = client.get_model(self.model_uid)
|
||||
res = model.rerank([document.page_content for document in documents], query, self.top_n, return_documents=True)
|
||||
return [Document(page_content=d.get('document', {}).get('text'),
|
||||
metadata={'relevance_score': d.get('relevance_score')}) for d in res.get('results', [])]
|
||||
|
|
|
|||
Loading…
Reference in New Issue