diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py
index c92c23ed9..af4677d3d 100644
--- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py
+++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py
@@ -50,7 +50,7 @@ qwentti_model_credential = QwenTextToImageModelCredential()
 aliyun_bai_lian_ttv_model_credential = TextToVideoModelCredential()
 aliyun_bai_lian_itv_model_credential = ImageToVideoModelCredential()
 
-model_info_list = [ModelInfo('gte-rerank',
+model_info_list = [ModelInfo('gte-rerank-v2',
                              _('With the GTE-Rerank text sorting series model developed by Alibaba Tongyi Lab, developers can integrate high-quality text retrieval and sorting through the LlamaIndex framework.'),
                              ModelTypeConst.RERANKER, aliyun_bai_lian_model_credential, AliyunBaiLianReranker),
                    ModelInfo('paraformer-realtime-v2',
diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py
index c0d7ceedf..9b5b121a0 100644
--- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py
+++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py
@@ -6,15 +6,53 @@
 @date:2024/9/2 16:42
 @desc:
 """
-from typing import Dict
+from http import HTTPStatus
+from typing import Sequence, Optional, Dict
 
-from langchain_community.document_compressors import DashScopeRerank
+import dashscope
+from langchain_core.callbacks import Callbacks
+from langchain_core.documents import BaseDocumentCompressor, Document
 
 from models_provider.base_model_provider import MaxKBBaseModel
 
 
-class AliyunBaiLianReranker(MaxKBBaseModel, DashScopeRerank):
+class AliyunBaiLianReranker(MaxKBBaseModel, BaseDocumentCompressor):
+    model: Optional[str]
+    api_key: Optional[str]
+
+    top_n: Optional[int] = 3  # keep only the top N most relevant results
+
+    @staticmethod
+    def is_cache_model():
+        return False
+
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        return AliyunBaiLianReranker(model=model_name, dashscope_api_key=model_credential.get('dashscope_api_key'),
+        return AliyunBaiLianReranker(model=model_name,
+                                     api_key=model_credential.get('dashscope_api_key'),
                                      top_n=model_kwargs.get('top_n', 3))
+
+    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
+            Sequence[Document]:
+        if not documents:
+            return []
+
+        texts = [doc.page_content for doc in documents]
+        resp = dashscope.TextReRank.call(
+            model=self.model,
+            query=query,
+            documents=texts,
+            top_n=self.top_n,
+            api_key=self.api_key,
+            return_documents=True
+        )
+        if resp.status_code == HTTPStatus.OK:
+            return [
+                Document(
+                    page_content=item.get('document', {}).get('text', ''),
+                    metadata={'relevance_score': item.get('relevance_score')}
+                )
+                for item in resp.output.get('results', [])
+            ]
+        else:
+            return []
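
Usage sketch (reviewer note, not part of the diff): a minimal example of how the new reranker could be exercised, assuming a valid DashScope API key. The key value, the sample texts, the 'RERANKER' model_type string (new_instance ignores it), and the import path inferred from the file layout above are all placeholders.

    from langchain_core.documents import Document
    from models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker

    # Build the reranker the way the provider layer does: the credential
    # dict carries the DashScope key, top_n caps the returned results.
    reranker = AliyunBaiLianReranker.new_instance(
        'RERANKER', 'gte-rerank-v2',
        {'dashscope_api_key': 'sk-placeholder'},  # placeholder key
        top_n=2)

    docs = [Document(page_content=t) for t in (
        'GTE-Rerank is a text re-ranking model from Tongyi Lab.',
        'Unrelated sentence about cooking.',
        'DashScope exposes re-ranking through its TextReRank endpoint.')]

    # Documents come back ordered by relevance, with the score in metadata.
    for doc in reranker.compress_documents(docs, query='text reranking model'):
        print(doc.metadata['relevance_score'], doc.page_content)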