diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py b/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py index ed2db0f91..8820a1986 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py @@ -16,7 +16,6 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor): - client: Any server_url: Optional[str] """URL of the xinference server""" model_uid: Optional[str] @@ -30,10 +29,13 @@ class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor): top_n: Optional[int] = 3 - def __init__( - self, server_url: Optional[str] = None, model_uid: Optional[str] = None, top_n=3, - api_key: Optional[str] = None - ): + def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ + Sequence[Document]: + if documents is None or len(documents) == 0: + return [] + client: Any + if documents is None or len(documents) == 0: + return [] try: from xinference.client import RESTfulClient except ImportError: @@ -45,29 +47,8 @@ class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor): " with `pip install xinference` or `pip install xinference_client`." ) from e - super().__init__() - - if server_url is None: - raise ValueError("Please provide server URL") - - if model_uid is None: - raise ValueError("Please provide the model UID") - - self.server_url = server_url - - self.model_uid = model_uid - - self.api_key = api_key - - self.client = RESTfulClient(server_url, api_key) - - self.top_n = top_n - - def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ - Sequence[Document]: - if documents is None or len(documents) == 0: - return [] - model: RESTfulRerankModelHandle = self.client.get_model(self.model_uid) + client = RESTfulClient(self.server_url, self.api_key) + model: RESTfulRerankModelHandle = client.get_model(self.model_uid) res = model.rerank([document.page_content for document in documents], query, self.top_n, return_documents=True) return [Document(page_content=d.get('document', {}).get('text'), metadata={'relevance_score': d.get('relevance_score')}) for d in res.get('results', [])]