feat: 支持xinference Rerank模型
Some checks are pending
sync2gitee / repo-sync (push) Waiting to run
Typos Check / Spell Check with Typos (push) Waiting to run

This commit is contained in:
shaohuzhang1 2024-09-10 19:00:17 +08:00 committed by shaohuzhang1
parent d48b51c3e0
commit 504e900edf
3 changed files with 126 additions and 1 deletions

View File

@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file: reranker.py
@date: 2024/9/10 9:46
@desc: Credential form and validation for XInference reranker models.
"""
from typing import Dict
from langchain_core.documents import Document
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class XInferenceRerankerModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for XInference reranker models.

    The form declares a required ``server_url`` field and an optional
    ``api_key`` field.
    """

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True):
        """Validate the credential by issuing one real rerank request.

        :param model_type: must equal ``'RERANKER'``; any other value is rejected.
        :param model_name: model name forwarded to ``provider.get_model``.
        :param model_credential: submitted credential values; ``server_url`` must be a key.
        :param provider: object whose ``get_model`` builds the model used for the probe.
        :param raise_exception: when true, failures raise ``AppApiException``; when
            false, a missing field or a failed probe yields ``False`` instead.
        :return: ``True`` once the probe request has completed without error.
        :raises AppApiException: always for an unsupported model type and for an
            ``AppApiException`` raised during the probe; for other failures only
            when ``raise_exception`` is true.
        """
        if model_type != 'RERANKER':
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        # Only the first missing required key is reported.
        missing_keys = [name for name in ['server_url'] if name not in model_credential]
        if missing_keys:
            key = missing_keys[0]
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
            return False

        try:
            reranker = provider.get_model(model_type, model_name, model_credential)
            # A one-document rerank call exercises the server URL, key and model name.
            reranker.compress_documents([Document(page_content='你好')], '你好')
        except AppApiException:
            # Application errors propagate unchanged, regardless of raise_exception.
            raise
        except Exception as err:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(err)}')
            return False
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """Return ``model_info`` as received; no value is transformed here."""
        return model_info

    server_url = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=False)

View File

@ -0,0 +1,73 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file: reranker.py
@date: 2024/9/10 9:45
@desc: LangChain document compressor backed by an XInference rerank model.
"""
from typing import Sequence, Optional, Any, Dict
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from xinference_client.client.restful.restful_client import RESTfulRerankModelHandle
from setting.models_provider.base_model_provider import MaxKBBaseModel
class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Document compressor that reranks documents through an XInference server."""

    # RESTful client created in ``__init__`` from ``server_url`` and ``api_key``.
    client: Any
    server_url: Optional[str]
    """URL of the xinference server"""
    model_uid: Optional[str]
    """UID of the launched model"""
    # API key handed to the RESTful client; may be None.
    api_key: Optional[str]
    # Passed to the rerank call as its ``top_n`` argument.
    top_n: Optional[int] = 3

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a reranker from a stored credential.

        :param model_type: unused here; part of the factory signature.
        :param model_name: used as the XInference model UID.
        :param model_credential: mapping read for ``server_url`` and ``api_key``.
        :param model_kwargs: optional settings; ``top_n`` is honoured when given
            and falls back to 3 otherwise.
        :return: a configured ``XInferenceReranker``.
        """
        return XInferenceReranker(server_url=model_credential.get('server_url'), model_uid=model_name,
                                  api_key=model_credential.get('api_key'),
                                  top_n=model_kwargs.get('top_n', 3))

    def __init__(
            self, server_url: Optional[str] = None, model_uid: Optional[str] = None, top_n=3,
            api_key: Optional[str] = None
    ):
        """Create the reranker and its RESTful client.

        :param server_url: URL of the xinference server; required.
        :param model_uid: UID of the launched model; required.
        :param top_n: value later passed to the rerank call as ``top_n``.
        :param api_key: optional API key handed to the RESTful client.
        :raises ImportError: if neither ``xinference`` nor ``xinference_client``
            provides ``RESTfulClient``.
        :raises ValueError: if ``server_url`` or ``model_uid`` is ``None``.
        """
        # Prefer the full package, fall back to the lightweight client package.
        try:
            from xinference.client import RESTfulClient
        except ImportError:
            try:
                from xinference_client import RESTfulClient
            except ImportError as e:
                raise ImportError(
                    "Could not import RESTfulClient from xinference. Please install it"
                    " with `pip install xinference` or `pip install xinference_client`."
                ) from e

        super().__init__()

        if server_url is None:
            raise ValueError("Please provide server URL")
        if model_uid is None:
            raise ValueError("Please provide the model UID")

        self.server_url = server_url
        self.model_uid = model_uid
        self.api_key = api_key
        self.client = RESTfulClient(server_url, api_key)
        self.top_n = top_n

    def compress_documents(self, documents: Sequence[Document], query: str,
                           callbacks: Optional[Callbacks] = None) -> Sequence[Document]:
        """Rerank ``documents`` against ``query`` and return the results.

        :param documents: candidate documents; only ``page_content`` is sent.
        :param query: query text the documents are scored against.
        :param callbacks: accepted for interface compatibility; not used here.
        :return: new ``Document`` objects built from the response's ``results``
            entries. Each carries the returned text and a ``relevance_score``
            metadata entry; metadata of the input documents is not carried over.
            An empty list is returned when ``documents`` is ``None`` or empty.
        """
        if not documents:
            return []
        model: RESTfulRerankModelHandle = self.client.get_model(self.model_uid)
        texts = [document.page_content for document in documents]
        # return_documents=True so each result entry includes the document text read below.
        res = model.rerank(texts, query, self.top_n, return_documents=True)
        return [Document(page_content=d.get('document', {}).get('text'),
                         metadata={'relevance_score': d.get('relevance_score')})
                for d in res.get('results', [])]

View File

@ -10,8 +10,10 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
from setting.models_provider.impl.xinference_model_provider.credential.embedding import \
XinferenceEmbeddingModelCredential
from setting.models_provider.impl.xinference_model_provider.credential.llm import XinferenceLLMModelCredential
from setting.models_provider.impl.xinference_model_provider.credential.reranker import XInferenceRerankerModelCredential
from setting.models_provider.impl.xinference_model_provider.model.embedding import XinferenceEmbedding
from setting.models_provider.impl.xinference_model_provider.model.llm import XinferenceChatModel
from setting.models_provider.impl.xinference_model_provider.model.reranker import XInferenceReranker
from smartdoc.conf import PROJECT_DIR
xinference_llm_model_credential = XinferenceLLMModelCredential()
@ -480,7 +482,9 @@ embedding_model_info = [
ModelInfo('text2vec-large-chinese', 'Text2Vec 的中文大型版本嵌入模型。', ModelTypeConst.EMBEDDING,
xinference_embedding_model_credential, XinferenceEmbedding),
]
# Reranker models offered by this provider.
rerank_list = [
    ModelInfo(
        'bce-reranker-base_v1',
        '发布新的重新排名器,建立在强大的 M3 和LLM GEMMA 和 MiniCPM实际上没那么大骨干上支持多语言处理和更大的输入大幅提高 BEIR、C-MTEB/Retrieval 的排名性能、MIRACL、LlamaIndex 评估',
        ModelTypeConst.RERANKER,
        XInferenceRerankerModelCredential(),
        XInferenceReranker,
    ),
]
model_info_manage = (ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
ModelInfo(
'phi3',
@ -492,6 +496,7 @@ model_info_manage = (ModelInfoManage.builder().append_model_info_list(model_info
'',
'',
ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, XinferenceEmbedding))
.append_model_info_list(rerank_list).append_default_model_info(rerank_list[0])
.build())