feat: support siliconCloud rerank

This commit is contained in:
wxg0103 2025-02-06 11:53:19 +08:00 committed by wxg
parent 8957b77d55
commit 1234096678
17 changed files with 100 additions and 29 deletions

View File

@ -28,7 +28,7 @@ from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.stt impor
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tti import QwenTextToImageModel
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tts import AliyunBaiLianTextToSpeech
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _, gettext
from django.utils.translation import gettext as _, gettext
aliyun_bai_lian_model_credential = AliyunBaiLianRerankerCredential()
aliyun_bai_lian_tts_model_credential = AliyunBaiLianTTSModelCredential()

View File

@ -11,7 +11,7 @@ from setting.models_provider.impl.aws_bedrock_model_provider.credential.llm impo
from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
def _create_model_info(model_name, description, model_type, credential_class, model_class):

View File

@ -14,7 +14,7 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
from setting.models_provider.impl.deepseek_model_provider.credential.llm import DeepSeekLLMModelCredential
from setting.models_provider.impl.deepseek_model_provider.model.llm import DeepSeekChatModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
deepseek_llm_model_credential = DeepSeekLLMModelCredential()
deepseek_chat = ModelInfo('deepseek-chat', _('Good at common conversational tasks, supports 32K contexts'), ModelTypeConst.LLM,

View File

@ -20,7 +20,7 @@ from setting.models_provider.impl.gemini_model_provider.model.image import Gemin
from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
gemini_llm_model_credential = GeminiLLMModelCredential()

View File

@ -16,7 +16,7 @@ from setting.models_provider.impl.local_model_provider.credential.reranker impor
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
from setting.models_provider.impl.local_model_provider.model.reranker import LocalReranker
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
LocalEmbeddingCredential(), LocalEmbedding)

View File

@ -22,7 +22,7 @@ from setting.models_provider.impl.ollama_model_provider.model.embedding import O
from setting.models_provider.impl.ollama_model_provider.model.image import OllamaImage
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
""

View File

@ -19,7 +19,7 @@ from setting.models_provider.impl.qwen_model_provider.model.image import QwenVLC
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
from setting.models_provider.impl.qwen_model_provider.model.tti import QwenTextToImageModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
qwen_model_credential = OpenAILLMModelCredential()
qwenvl_model_credential = QwenVLModelCredential()

View File

@ -15,7 +15,6 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
from setting.models_provider.impl.siliconCloud_model_provider.model.reranker import SiliconCloudReranker
@ -26,7 +25,7 @@ class SiliconCloudRerankerCredential(BaseForm, BaseModelCredential):
if not model_type == 'RERANKER':
raise AppApiException(ValidCode.valid_error.value,
_('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['dashscope_api_key']:
for key in ['api_base', 'api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
@ -47,6 +46,6 @@ class SiliconCloudRerankerCredential(BaseForm, BaseModelCredential):
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'dashscope_api_key': super().encryption(model.get('dashscope_api_key', ''))}
dashscope_api_key = forms.PasswordInputField('API Key', required=True)
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_base = forms.TextInputField('API URL', required=True)
api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -2,19 +2,92 @@
"""
@project: MaxKB
@Author
@file reranker.py.py
@date2024/9/2 16:42
@desc:
@file siliconcloud_reranker.py
@date2024/9/10 9:45
@desc: SiliconCloud 文档重排封装
"""
from typing import Dict
from langchain_community.document_compressors import DashScopeRerank
from typing import Sequence, Optional, Any, Dict
import requests
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from setting.models_provider.base_model_provider import MaxKBBaseModel
from django.utils.translation import gettext as _
class SiliconCloudReranker(MaxKBBaseModel, DashScopeRerank):
class SiliconCloudReranker(MaxKBBaseModel, BaseDocumentCompressor):
api_base: Optional[str]
"""SiliconCloud API URL"""
model: Optional[str]
"""SiliconCloud 重排模型 ID"""
api_key: Optional[str]
"""API Key"""
top_n: Optional[int] = 3 # 取前 N 个最相关的结果
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return SiliconCloudReranker(model=model_name, dashscope_api_key=model_credential.get('dashscope_api_key'),
top_n=model_kwargs.get('top_n', 3))
return SiliconCloudReranker(
api_base=model_credential.get('api_base'),
model=model_name,
api_key=model_credential.get('api_key'),
top_n=model_kwargs.get('top_n', 3)
)
def __init__(
self, api_base: Optional[str] = None, model: Optional[str] = None, top_n=3,
api_key: Optional[str] = None
):
super().__init__()
if not api_base:
raise ValueError(_('Please provide server URL'))
if not model:
raise ValueError(_('Please provide the model'))
if not api_key:
raise ValueError(_('Please provide the API Key'))
self.api_base = api_base
self.model = model
self.api_key = api_key
self.top_n = top_n
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
Sequence[Document]:
if not documents:
return []
# 预处理文本
texts = [doc.page_content for doc in documents]
# 发送请求到 SiliconCloud API
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": self.model,
"query": query,
"documents": texts,
"top_n": self.top_n
}
response = requests.post(f"{self.api_base}/rerank", json=payload, headers=headers)
if response.status_code != 200:
raise RuntimeError(f"SiliconCloud API 请求失败: {response.text}")
res = response.json()
# 解析返回结果
return [
Document(
page_content=item.get('document', {}).get('text', ''),
metadata={'relevance_score': item.get('relevance_score')}
)
for item in res.get('results', [])
]

View File

@ -24,7 +24,7 @@ from setting.models_provider.impl.siliconCloud_model_provider.model.reranker imp
from setting.models_provider.impl.siliconCloud_model_provider.model.stt import SiliconCloudSpeechToText
from setting.models_provider.impl.siliconCloud_model_provider.model.tti import SiliconCloudTextToImage
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
openai_llm_model_credential = SiliconCloudLLMModelCredential()
openai_stt_model_credential = SiliconCloudSTTModelCredential()

View File

@ -15,7 +15,7 @@ from setting.models_provider.impl.tencent_model_provider.model.image import Tenc
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
from setting.models_provider.impl.tencent_model_provider.model.tti import TencentTextToImageModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
def _create_model_info(model_name, description, model_type, credential_class, model_class):
return ModelInfo(

View File

@ -14,7 +14,7 @@ from setting.models_provider.impl.vllm_model_provider.model.embedding import Vll
from setting.models_provider.impl.vllm_model_provider.model.image import VllmImage
from setting.models_provider.impl.vllm_model_provider.model.llm import VllmChatModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
v_llm_model_credential = VLLMModelCredential()
image_model_credential = VllmImageModelCredential()

View File

@ -28,7 +28,7 @@ from setting.models_provider.impl.volcanic_engine_model_provider.model.tti impor
from setting.models_provider.impl.volcanic_engine_model_provider.model.tts import VolcanicEngineTextToSpeech
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential()

View File

@ -16,7 +16,7 @@ from setting.models_provider.impl.wenxin_model_provider.credential.llm import We
from setting.models_provider.impl.wenxin_model_provider.model.embedding import QianfanEmbeddings
from setting.models_provider.impl.wenxin_model_provider.model.llm import QianfanChatModel
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
win_xin_llm_model_credential = WenxinLLMModelCredential()
qianfan_embedding_credential = QianfanEmbeddingCredential()

View File

@ -23,7 +23,7 @@ from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSpark
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
ssl._create_default_https_context = ssl.create_default_context()

View File

@ -23,8 +23,7 @@ from setting.models_provider.impl.xinference_model_provider.model.stt import XIn
from setting.models_provider.impl.xinference_model_provider.model.tti import XinferenceTextToImage
from setting.models_provider.impl.xinference_model_provider.model.tts import XInferenceTextToSpeech
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
xinference_llm_model_credential = XinferenceLLMModelCredential()
xinference_stt_model_credential = XInferenceSTTModelCredential()

View File

@ -18,7 +18,7 @@ from setting.models_provider.impl.zhipu_model_provider.model.image import ZhiPuI
from setting.models_provider.impl.zhipu_model_provider.model.llm import ZhipuChatModel
from setting.models_provider.impl.zhipu_model_provider.model.tti import ZhiPuTextToImage
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
from django.utils.translation import gettext as _
qwen_model_credential = ZhiPuLLMModelCredential()
zhipu_image_model_credential = ZhiPuImageModelCredential()