diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index a994481ab..a570f1ff7 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -213,7 +213,8 @@ class ChatMessageSerializer(serializers.Serializer): self.is_valid_intraday_access_num() def is_valid_chat_id(self, chat_info: ChatInfo): - if self.data.get('application_id') != str(chat_info.application.id): + if self.data.get('application_id') is not None and self.data.get('application_id') != str( + chat_info.application.id): raise ChatException(500, "会话不存在") def is_valid_intraday_access_num(self): diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py index 13e43a278..c471cead4 100644 --- a/apps/setting/models_provider/constants/model_provider_constants.py +++ b/apps/setting/models_provider/constants/model_provider_constants.py @@ -8,6 +8,8 @@ """ from enum import Enum +from setting.models_provider.impl.aliyun_bai_lian_model_provider.aliyun_bai_lian_model_provider import \ + AliyunBaiLianModelProvider from setting.models_provider.impl.aws_bedrock_model_provider.aws_bedrock_model_provider import BedrockModelProvider from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider @@ -44,3 +46,4 @@ class ModelProvideConstants(Enum): model_local_provider = LocalModelProvider() model_xinference_provider = XinferenceModelProvider() model_vllm_provider = VllmModelProvider() + aliyun_bai_lian_model_provider = AliyunBaiLianModelProvider() diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/__init__.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/__init__.py new file mode 100644 
index 000000000..3c10c5535 --- /dev/null +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/9/9 17:42 + @desc: +""" diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py new file mode 100644 index 000000000..bfc4bf2aa --- /dev/null +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py @@ -0,0 +1,37 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: aliyun_bai_lian_model_provider.py + @date:2024/9/9 17:43 + @desc: +""" +import os + +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ + ModelInfoManage +from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.reranker import \ + AliyunBaiLianRerankerCredential +from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker +from smartdoc.conf import PROJECT_DIR + +aliyun_bai_lian_model_credential = AliyunBaiLianRerankerCredential() +model_info_list = [ModelInfo('gte-rerank', + '阿里巴巴通义实验室开发的GTE-Rerank文本排序系列模型,开发者可以通过LlamaIndex框架进行集成高质量文本检索、排序。', + ModelTypeConst.RERANKER, aliyun_bai_lian_model_credential, AliyunBaiLianReranker) + ] + +model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).build() + + +class AliyunBaiLianModelProvider(IModelProvider): + + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='aliyun_bai_lian_model_provider', name='阿里百炼大模型', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 
'aliyun_bai_lian_model_provider', + 'icon', + 'aliyun_bai_lian_icon_svg'))) diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py new file mode 100644 index 000000000..4e079dccc --- /dev/null +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py @@ -0,0 +1,47 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: reranker.py + @date:2024/9/9 17:51 + @desc: +""" +from typing import Dict + +from langchain_core.documents import Document + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker + + +class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + if not model_type == 'RERANKER': + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['dashscope_api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model: AliyunBaiLianReranker = provider.get_model(model_type, model_name, model_credential) + model.compress_documents([Document(page_content='你好')], '你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return model + + dashscope_api_key = forms.PasswordInputField('API Key', required=True) diff --git 
a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/icon/aliyun_bai_lian_icon_svg b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/icon/aliyun_bai_lian_icon_svg new file mode 100644 index 000000000..d77747f14 --- /dev/null +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/icon/aliyun_bai_lian_icon_svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py new file mode 100644 index 000000000..5c9bea4af --- /dev/null +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py @@ -0,0 +1,20 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: reranker.py + @date:2024/9/2 16:42 + @desc: +""" +from typing import Dict + +from langchain_community.document_compressors import DashScopeRerank + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class AliyunBaiLianReranker(MaxKBBaseModel, DashScopeRerank): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return AliyunBaiLianReranker(model=model_name, dashscope_api_key=model_credential.get('dashscope_api_key'), + top_n=model_kwargs.get('top_n', 3))