feat: 支持阿里百炼Rerank模型
Some checks are pending
sync2gitee / repo-sync (push) Waiting to run
Typos Check / Spell Check with Typos (push) Waiting to run

This commit is contained in:
shaohuzhang1 2024-09-09 18:45:44 +08:00 committed by shaohuzhang1
parent 7e01c600ae
commit 277ed17f93
7 changed files with 118 additions and 1 deletions

View File

@ -213,7 +213,8 @@ class ChatMessageSerializer(serializers.Serializer):
self.is_valid_intraday_access_num()
def is_valid_chat_id(self, chat_info: ChatInfo):
if self.data.get('application_id') != str(chat_info.application.id):
if self.data.get('application_id') is not None and self.data.get('application_id') != str(
chat_info.application.id):
raise ChatException(500, "会话不存在")
def is_valid_intraday_access_num(self):

View File

@ -8,6 +8,8 @@
"""
from enum import Enum
from setting.models_provider.impl.aliyun_bai_lian_model_provider.aliyun_bai_lian_model_provider import \
AliyunBaiLianModelProvider
from setting.models_provider.impl.aws_bedrock_model_provider.aws_bedrock_model_provider import BedrockModelProvider
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
@ -44,3 +46,4 @@ class ModelProvideConstants(Enum):
model_local_provider = LocalModelProvider()
model_xinference_provider = XinferenceModelProvider()
model_vllm_provider = VllmModelProvider()
aliyun_bai_lian_model_provider = AliyunBaiLianModelProvider()

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date2024/9/9 17:42
@desc:
"""

View File

@ -0,0 +1,37 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file aliyun_bai_lian_model_provider.py
@date2024/9/9 17:43
@desc:
"""
import os
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.reranker import \
AliyunBaiLianRerankerCredential
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
from smartdoc.conf import PROJECT_DIR
aliyun_bai_lian_model_credential = AliyunBaiLianRerankerCredential()
model_info_list = [ModelInfo('gte-rerank',
'阿里巴巴通义实验室开发的GTE-Rerank文本排序系列模型开发者可以通过LlamaIndex框架进行集成高质量文本检索、排序。',
ModelTypeConst.RERANKER, aliyun_bai_lian_model_credential, AliyunBaiLianReranker)
]
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).build()
class AliyunBaiLianModelProvider(IModelProvider):
def get_model_info_manage(self):
return model_info_manage
def get_model_provide_info(self):
return ModelProvideInfo(provider='aliyun_bai_lian_model_provider', name='阿里百炼大模型', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'aliyun_bai_lian_model_provider',
'icon',
'aliyun_bai_lian_icon_svg')))

View File

@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file reranker.py
@date2024/9/9 17:51
@desc:
"""
from typing import Dict
from langchain_core.documents import Document
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
if not model_type == 'RERANKER':
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['dashscope_api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model: AliyunBaiLianReranker = provider.get_model(model_type, model_name, model_credential)
model.compress_documents([Document(page_content='你好')], '你好')
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return model
dashscope_api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -0,0 +1 @@
<svg t="1725878508761" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="9190" width="100%" height="100%"><path d="M908.68 766.88L515.55 995.34 119.86 766.88l0.59-510.09L514.21 28.98l393.88 228.49 0.59 509.41z" fill="#FFFFFF" p-id="9191"></path><path d="M908.68 766.88L515.55 995.34 119.86 766.88l393.12-228.46 395.7 228.46z" fill="#A1C2F9" p-id="9192"></path><path d="M214.29 99.17l-47.06 27.35-47.37-27.35 47.06-27.34 47.37 27.34z" fill="#A1C2F9" opacity=".8" p-id="9193"></path><path d="M214.29 99.17v54.7l-47.06 27.09v-54.44l47.06-27.35z" fill="#A1C2F9" p-id="9194"></path><path d="M167.23 126.53v54.43l-47.37-27.26V99.18l47.37 27.35z" fill="#A1C2F9" opacity=".6" p-id="9195"></path><path d="M742.65 510.46v-17.57C742.65 338.31 640.77 254 514.27 254s-228.38 84.31-228.38 238.89v24.63c0 7.06 0 10.51 3.53 14 14.13 144.07 108.94 284.52 224.85 284.52S725.08 672 739.12 531.56c3.53-7.07 3.53-10.51 3.53-14z m-210.82 59.76v-105.4c108.94 3.53 168.62 31.6 175.68 49.17v3.53c-7 21.1-66.74 49.17-175.68 52.7z m-207.19-49.17V514c7.07-17.57 66.74-45.64 172.15-49.17v105.4c-102-3.53-165.17-31.6-172.15-49.17zM704 468.35c-38.66-21.1-105.4-35.14-172.15-35.14V289.14C630.27 296.2 697 366.47 704 468.35zM496.7 289.14v140.54c-66.74 3.53-133.48 14-172.14 35.14 10.59-98.35 73.8-165.17 172.14-175.68zM331.62 570.22c38.66 21.11 101.87 31.61 165.08 35.14V781C423 766.91 356.25 675.63 331.62 570.22zM531.83 781V605.27c63.21 0 126.51-14 165.09-35.13-21 108.93-91.28 196.77-165.09 210.81z m0 0" fill="#366FD9" p-id="9196"></path><path d="M515.61 485.95v508.7L119.87 766.92l-0.01-509.45 395.75 228.48z" fill="#D2E3FE" opacity=".6" p-id="9197"></path><path d="M909.96 257.47L516.78 485.95 121.04 257.47 514.21 28.98l395.75 228.49z" fill="#366FD9" opacity=".8" p-id="9198"></path><path d="M909.96 258.28v510.8L516.78 995.34l0.01-509.32 393.17-227.74z" fill="#548FEF" opacity=".8" p-id="9199"></path><path d="M519.96 722.89l-2.61-4.84 386.86-207.8 2.61 4.84-386.86 207.8zM777.18 440.27l45-29c3.17-2 5.64-0.05 5.3 4.76l0.19 66.88a15.86 15.86 0 0 1-1.61 6.84 11.69 11.69 0 0 1-4 4.89c-1.5 1-3 1-4 0.16s-1.65-2.59-1.64-4.77l-0.16-58.63-39 25c-3.12 2-5.9 0.23-5.56-4.55a16.06 16.06 0 0 1 1.59-6.79 11.6 11.6 0 0 1 3.89-4.79zM777.74 703l39.55-24-0.16-56.39a16 16 0 0 1 1.61-6.83 11.42 11.42 0 0 1 4-4.83c1.5-1 3-1 4-0.09s1.66 2.64 1.65 4.84l0.18 64.69a15.8 15.8 0 0 1-1.61 6.83 11.28 11.28 0 0 1-4 4.8l-45.2 27.23c-2 1.2-3.83 0.75-4.83-1.17s-1-5 0-8.16 2.82-5.77 4.81-6.92zM635 531.72l47.19-30.37c1.44-1 2.82-1 3.84-0.16s1.58 2.55 1.57 4.69a15.89 15.89 0 0 1-1.56 6.7 11.24 11.24 0 0 1-3.83 4.77L640 544.41v57.22a15.87 15.87 0 0 1-1.54 6.64 11 11 0 0 1-3.79 4.69c-1.41 0.92-2.78 1-3.78 0.09s-1.56-2.56-1.54-4.69v-65.07c0.04-4.47 2.37-9.43 5.65-11.57zM629.56 737.65c0-4.41 2.38-9.45 5.33-11.26s5.34 0.3 5.34 4.72v55l42-25.43c1.94-1.17 3.73-0.72 4.7 1.19s1 5 0 8.07-2.75 5.69-4.68 6.86l-47.72 28.75c-3 1.81-5.33-0.29-5-5z" fill="#FFFFFF" p-id="9200"></path><path d="M974.19 758.32l-86.41 50.21-86.98-50.21 86.41-50.22 86.98 50.22z" fill="#A1C2F9" opacity=".8" p-id="9201"></path><path d="M974.19 758.32v100.43l-86.41 49.73v-99.95l86.41-50.21z" fill="#548FEF" opacity=".6" p-id="9202"></path><path d="M887.78 808.54v99.94l-86.98-50.06v-100.1l86.98 50.22z" fill="#A1C2F9" opacity=".6" p-id="9203"></path><path d="M515.07 724.05L117.73 503.36l2.68-4.81 397.34 220.7-2.68 4.8zM258 448.33l-44.26-25.69c-3.09-1.79-5.49 0.29-5.15 5l-0.06 65a15 15 0 0 0 1.58 6.57 10.5 10.5 0 0 0 3.9 4.56 3.18 3.18 0 0 0 3.91 0c1-0.9 1.61-2.6 1.59-4.73v-57.1l38.43 22.3c3.09 1.8 5.83-0.09 5.49-4.75a15.06 15.06 0 0 0-1.58-6.57 10.5 10.5 0 0 0-3.85-4.59zM257.77 705.12L219 682.61V627.9a15.1 15.1 0 0 0-1.58-6.58 10.61 10.61 0 0 0-3.91-4.57 3.19 3.19 0 0 0-3.91 0c-1 0.9-1.61 2.6-1.59 4.73L208 684.2a15.13 15.13 0 0 0 1.58 6.58 10.67 10.67 0 0 0 3.91 4.57L257.75 721c2 1.14 3.78 0.67 4.77-1.22s1-4.93 0-8-2.79-5.53-4.75-6.66zM401.56 531.68l-48.35-28.06a3.16 3.16 0 0 0-3.9 0c-1 0.9-1.6 2.6-1.59 4.73a15.17 15.17 0 0 0 1.59 6.56 10.5 10.5 0 0 0 3.9 4.56l43.2 25.08v57.1a15 15 0 0 0 1.59 6.57 10.53 10.53 0 0 0 3.9 4.57 3.18 3.18 0 0 0 3.9 0c1-0.89 1.6-2.6 1.59-4.73V543c0-4.46-2.39-9.33-5.83-11.32zM407.22 737.05c0-4.39-2.46-9.38-5.5-11.14s-5.49 0.38-5.49 4.77v54.72l-42.91-24.88c-2-1.14-3.78-0.67-4.76 1.22s-1 4.93 0 8 2.79 5.61 4.75 6.74l48.75 28.27c3.1 1.79 5.5-0.3 5.16-5z" fill="#366FD9" p-id="9204"></path></svg>

After

Width:  |  Height:  |  Size: 4.3 KiB

View File

@ -0,0 +1,20 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file reranker.py.py
@date2024/9/2 16:42
@desc:
"""
from typing import Dict
from langchain_community.document_compressors import DashScopeRerank
from setting.models_provider.base_model_provider import MaxKBBaseModel
class AliyunBaiLianReranker(MaxKBBaseModel, DashScopeRerank):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return AliyunBaiLianReranker(model=model_name, dashscope_api_key=model_credential.get('dashscope_api_key'),
top_n=model_kwargs.get('top_n', 3))