From 3b1de042270fa1dce128f09731d2a9b98995d94e Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Mon, 18 Nov 2024 16:23:32 +0800 Subject: [PATCH] =?UTF-8?q?feat(=E6=A8=A1=E5=9E=8B=E7=AE=A1=E7=90=86):=20?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E9=98=BF=E9=87=8C=E4=BA=91=E7=99=BE=E7=82=BC?= =?UTF-8?q?=E5=A4=A7=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../aliyun_bai_lian_model_provider.py | 11 +++- .../credential/llm.py | 62 +++++++++++++++++++ .../model/llm.py | 22 +++++++ ui/src/views/login/index.vue | 7 ++- 4 files changed, 100 insertions(+), 2 deletions(-) create mode 100644 apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py create mode 100644 apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/llm.py diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py index f3fd75a80..2333505b6 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py @@ -13,11 +13,13 @@ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelT ModelInfoManage from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.embedding import \ AliyunBaiLianEmbeddingCredential +from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.llm import BaiLianLLMModelCredential from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.reranker import \ AliyunBaiLianRerankerCredential from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.stt import AliyunBaiLianSTTModelCredential from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.tts import 
AliyunBaiLianTTSModelCredential from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding +from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.llm import BaiLianChatModel from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.stt import AliyunBaiLianSpeechToText from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tts import AliyunBaiLianTextToSpeech @@ -27,6 +29,7 @@ aliyun_bai_lian_model_credential = AliyunBaiLianRerankerCredential() aliyun_bai_lian_tts_model_credential = AliyunBaiLianTTSModelCredential() aliyun_bai_lian_stt_model_credential = AliyunBaiLianSTTModelCredential() aliyun_bai_lian_embedding_model_credential = AliyunBaiLianEmbeddingCredential() +aliyun_bai_lian_llm_model_credential = BaiLianLLMModelCredential() model_info_list = [ModelInfo('gte-rerank', '阿里巴巴通义实验室开发的GTE-Rerank文本排序系列模型,开发者可以通过LlamaIndex框架进行集成高质量文本检索、排序。', @@ -41,11 +44,17 @@ model_info_list = [ModelInfo('gte-rerank', '通用文本向量,是通义实验室基于LLM底座的多语言文本统一向量模型,面向全球多个主流语种,提供高水准的向量服务,帮助开发者将文本数据快速转换为高质量的向量数据。', ModelTypeConst.EMBEDDING, aliyun_bai_lian_embedding_model_credential, AliyunBaiLianEmbedding), + ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen-plus', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel), + ModelInfo('qwen-max', '', ModelTypeConst.LLM, aliyun_bai_lian_llm_model_credential, + BaiLianChatModel) ] model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( model_info_list[1]).append_default_model_info(model_info_list[2]).append_default_model_info( - model_info_list[3]).build() + model_info_list[3]).append_default_model_info(model_info_list[4]).build() class AliyunBaiLianModelProvider(IModelProvider): diff --git 
# coding=utf-8
"""Credential form, generation-parameter form and chat-model wrapper for the
Aliyun BaiLian (阿里云百炼) LLM provider.

Added by this patch as two new modules:
  - credential/llm.py: BaiLianLLMModelParams, BaiLianLLMModelCredential
  - model/llm.py:      BaiLianChatModel
"""
from typing import Dict

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, MaxKBBaseModel, ValidCode
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


class BaiLianLLMModelParams(BaseForm):
    """Tunable generation parameters shown in the model-settings UI."""

    # Sampling temperature: higher values make output more random,
    # lower values make it more focused and deterministic.
    temperature = forms.SliderField(
        TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
        required=True, default_value=0.7,
        _min=0.1,
        _max=1.0,
        _step=0.01,
        precision=2)

    # Hard cap on the number of tokens the model may generate per reply.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class BaiLianLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form (API base URL + API key) and validation for BaiLian LLMs."""

    # Form fields rendered to the user when adding a credential.
    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by building the model and sending a one-message probe.

        :param model_type:       model category value (e.g. 'LLM') to check against the provider
        :param model_name:       concrete model identifier (e.g. 'qwen-turbo')
        :param model_credential: dict expected to carry 'api_base' and 'api_key'
        :param provider:         owning provider, used to list supported types and build the model
        :param raise_exception:  when True, failures raise AppApiException; otherwise return False
        :return: True when the credential works, False otherwise (unless raising)
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            # Fix: honor raise_exception here as well — the original raised
            # unconditionally, inconsistent with every other failure path below.
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False
        for key in ('api_base', 'api_key'):
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            # Round-trip a trivial message to prove the endpoint + key actually work.
            model = provider.get_model(model_type, model_name, model_credential)
            model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Domain errors from get_model/invoke already carry a user-facing message.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the API key masked for safe display/storage."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the generation-parameter form for *model_name* (same form for all models)."""
        return BaiLianLLMModelParams()


class BaiLianChatModel(MaxKBBaseModel, BaseChatOpenAI):
    """OpenAI-compatible chat client pointed at an Aliyun BaiLian endpoint."""

    @staticmethod
    def is_cache_model():
        # Instances are credential-specific; do not reuse cached instances.
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry to build a configured chat model."""
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return BaiLianChatModel(
            model=model_name,
            openai_api_base=model_credential.get('api_base'),
            openai_api_key=model_credential.get('api_key'),
            **optional_params
        )
=== 'OIDC') { url = `${config.authEndpoint}?client_id=${config.clientId}&redirect_uri=${redirectUrl}&response_type=code&scope=openid+profile+email`