mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: support siliconCloud rerank
This commit is contained in:
parent
8957b77d55
commit
1234096678
|
|
@ -28,7 +28,7 @@ from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.stt impor
|
|||
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tti import QwenTextToImageModel
|
||||
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tts import AliyunBaiLianTextToSpeech
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
from django.utils.translation import gettext as _, gettext
|
||||
|
||||
aliyun_bai_lian_model_credential = AliyunBaiLianRerankerCredential()
|
||||
aliyun_bai_lian_tts_model_credential = AliyunBaiLianTTSModelCredential()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from setting.models_provider.impl.aws_bedrock_model_provider.credential.llm impo
|
|||
from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
|
||||
from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
|
||||
def _create_model_info(model_name, description, model_type, credential_class, model_class):
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
|
|||
from setting.models_provider.impl.deepseek_model_provider.credential.llm import DeepSeekLLMModelCredential
|
||||
from setting.models_provider.impl.deepseek_model_provider.model.llm import DeepSeekChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import gettext as _
|
||||
deepseek_llm_model_credential = DeepSeekLLMModelCredential()
|
||||
|
||||
deepseek_chat = ModelInfo('deepseek-chat', _('Good at common conversational tasks, supports 32K contexts'), ModelTypeConst.LLM,
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from setting.models_provider.impl.gemini_model_provider.model.image import Gemin
|
|||
from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
|
||||
from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
|
||||
gemini_llm_model_credential = GeminiLLMModelCredential()
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from setting.models_provider.impl.local_model_provider.credential.reranker impor
|
|||
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||
from setting.models_provider.impl.local_model_provider.model.reranker import LocalReranker
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
|
||||
LocalEmbeddingCredential(), LocalEmbedding)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from setting.models_provider.impl.ollama_model_provider.model.embedding import O
|
|||
from setting.models_provider.impl.ollama_model_provider.model.image import OllamaImage
|
||||
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
""
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from setting.models_provider.impl.qwen_model_provider.model.image import QwenVLC
|
|||
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
|
||||
from setting.models_provider.impl.qwen_model_provider.model.tti import QwenTextToImageModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
qwen_model_credential = OpenAILLMModelCredential()
|
||||
qwenvl_model_credential = QwenVLModelCredential()
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
|
||||
from setting.models_provider.impl.siliconCloud_model_provider.model.reranker import SiliconCloudReranker
|
||||
|
||||
|
||||
|
|
@ -26,7 +25,7 @@ class SiliconCloudRerankerCredential(BaseForm, BaseModelCredential):
|
|||
if not model_type == 'RERANKER':
|
||||
raise AppApiException(ValidCode.valid_error.value,
|
||||
_('{model_type} Model type is not supported').format(model_type=model_type))
|
||||
for key in ['dashscope_api_key']:
|
||||
for key in ['api_base', 'api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
|
||||
|
|
@ -47,6 +46,6 @@ class SiliconCloudRerankerCredential(BaseForm, BaseModelCredential):
|
|||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'dashscope_api_key': super().encryption(model.get('dashscope_api_key', ''))}
|
||||
|
||||
dashscope_api_key = forms.PasswordInputField('API Key', required=True)
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
api_base = forms.TextInputField('API URL', required=True)
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
|
|
|||
|
|
@ -2,19 +2,92 @@
|
|||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: reranker.py.py
|
||||
@date:2024/9/2 16:42
|
||||
@desc:
|
||||
@file: siliconcloud_reranker.py
|
||||
@date:2024/9/10 9:45
|
||||
@desc: SiliconCloud 文档重排封装
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_community.document_compressors import DashScopeRerank
|
||||
from typing import Sequence, Optional, Any, Dict
|
||||
import requests
|
||||
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
|
||||
class SiliconCloudReranker(MaxKBBaseModel, DashScopeRerank):
|
||||
class SiliconCloudReranker(MaxKBBaseModel, BaseDocumentCompressor):
|
||||
api_base: Optional[str]
|
||||
"""SiliconCloud API URL"""
|
||||
model: Optional[str]
|
||||
"""SiliconCloud 重排模型 ID"""
|
||||
api_key: Optional[str]
|
||||
"""API Key"""
|
||||
|
||||
top_n: Optional[int] = 3 # 取前 N 个最相关的结果
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
return SiliconCloudReranker(model=model_name, dashscope_api_key=model_credential.get('dashscope_api_key'),
|
||||
top_n=model_kwargs.get('top_n', 3))
|
||||
return SiliconCloudReranker(
|
||||
api_base=model_credential.get('api_base'),
|
||||
model=model_name,
|
||||
api_key=model_credential.get('api_key'),
|
||||
top_n=model_kwargs.get('top_n', 3)
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, api_base: Optional[str] = None, model: Optional[str] = None, top_n=3,
|
||||
api_key: Optional[str] = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not api_base:
|
||||
raise ValueError(_('Please provide server URL'))
|
||||
|
||||
if not model:
|
||||
raise ValueError(_('Please provide the model'))
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(_('Please provide the API Key'))
|
||||
|
||||
self.api_base = api_base
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.top_n = top_n
|
||||
|
||||
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
|
||||
Sequence[Document]:
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
# 预处理文本
|
||||
texts = [doc.page_content for doc in documents]
|
||||
|
||||
# 发送请求到 SiliconCloud API
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"query": query,
|
||||
"documents": texts,
|
||||
"top_n": self.top_n
|
||||
}
|
||||
|
||||
response = requests.post(f"{self.api_base}/rerank", json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(f"SiliconCloud API 请求失败: {response.text}")
|
||||
|
||||
res = response.json()
|
||||
|
||||
# 解析返回结果
|
||||
return [
|
||||
Document(
|
||||
page_content=item.get('document', {}).get('text', ''),
|
||||
metadata={'relevance_score': item.get('relevance_score')}
|
||||
)
|
||||
for item in res.get('results', [])
|
||||
]
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from setting.models_provider.impl.siliconCloud_model_provider.model.reranker imp
|
|||
from setting.models_provider.impl.siliconCloud_model_provider.model.stt import SiliconCloudSpeechToText
|
||||
from setting.models_provider.impl.siliconCloud_model_provider.model.tti import SiliconCloudTextToImage
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
openai_llm_model_credential = SiliconCloudLLMModelCredential()
|
||||
openai_stt_model_credential = SiliconCloudSTTModelCredential()
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from setting.models_provider.impl.tencent_model_provider.model.image import Tenc
|
|||
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
|
||||
from setting.models_provider.impl.tencent_model_provider.model.tti import TencentTextToImageModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
def _create_model_info(model_name, description, model_type, credential_class, model_class):
|
||||
return ModelInfo(
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from setting.models_provider.impl.vllm_model_provider.model.embedding import Vll
|
|||
from setting.models_provider.impl.vllm_model_provider.model.image import VllmImage
|
||||
from setting.models_provider.impl.vllm_model_provider.model.llm import VllmChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
v_llm_model_credential = VLLMModelCredential()
|
||||
image_model_credential = VllmImageModelCredential()
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from setting.models_provider.impl.volcanic_engine_model_provider.model.tti impor
|
|||
from setting.models_provider.impl.volcanic_engine_model_provider.model.tts import VolcanicEngineTextToSpeech
|
||||
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
|
||||
volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential()
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from setting.models_provider.impl.wenxin_model_provider.credential.llm import We
|
|||
from setting.models_provider.impl.wenxin_model_provider.model.embedding import QianfanEmbeddings
|
||||
from setting.models_provider.impl.wenxin_model_provider.model.llm import QianfanChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
win_xin_llm_model_credential = WenxinLLMModelCredential()
|
||||
qianfan_embedding_credential = QianfanEmbeddingCredential()
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSpark
|
|||
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
|
||||
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
ssl._create_default_https_context = ssl.create_default_context()
|
||||
|
||||
|
|
|
|||
|
|
@ -23,8 +23,7 @@ from setting.models_provider.impl.xinference_model_provider.model.stt import XIn
|
|||
from setting.models_provider.impl.xinference_model_provider.model.tti import XinferenceTextToImage
|
||||
from setting.models_provider.impl.xinference_model_provider.model.tts import XInferenceTextToSpeech
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
xinference_llm_model_credential = XinferenceLLMModelCredential()
|
||||
xinference_stt_model_credential = XInferenceSTTModelCredential()
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from setting.models_provider.impl.zhipu_model_provider.model.image import ZhiPuI
|
|||
from setting.models_provider.impl.zhipu_model_provider.model.llm import ZhipuChatModel
|
||||
from setting.models_provider.impl.zhipu_model_provider.model.tti import ZhiPuTextToImage
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
qwen_model_credential = ZhiPuLLMModelCredential()
|
||||
zhipu_image_model_credential = ZhiPuImageModelCredential()
|
||||
|
|
|
|||
Loading…
Reference in New Issue