diff --git a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py index f1cd269f1..bfaf91308 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py +++ b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py @@ -10,7 +10,7 @@ import os from typing import Dict from langchain.schema import HumanMessage -from langchain_community.chat_models import AzureChatOpenAI +from langchain_community.chat_models.azure_openai import AzureChatOpenAI from common import froms from common.exception.app_exception import AppApiException @@ -29,9 +29,6 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential): if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - if model_name not in model_dict: - raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型名称不支持') - for key in ['api_base', 'api_key', 'deployment_name']: if key not in model_credential: if raise_exception: @@ -40,7 +37,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential): return False try: model = AzureModelProvider().get_model(model_type, model_name, model_credential) - model.invoke([HumanMessage(content='valid')]) + model.invoke([HumanMessage(content='你好')]) except Exception as e: if isinstance(e, AppApiException): raise e @@ -61,8 +58,48 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential): deployment_name = froms.TextInputField("部署名", required=True) +class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): + model_type_list = AzureModelProvider().get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key', 'deployment_name', 'api_version']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = AzureModelProvider().get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确') + else: + return False + + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_version = froms.TextInputField("api_version", required=True) + + api_base = froms.TextInputField('API 域名', required=True) + + api_key = froms.PasswordInputField("API Key", required=True) + + deployment_name = froms.TextInputField("部署名", required=True) + + azure_llm_model_credential = AzureLLMModelCredential() +base_azure_llm_model_credential = DefaultAzureLLMModelCredential() + model_dict = { 'gpt-3.5-turbo-0613': ModelInfo('gpt-3.5-turbo-0613', '', ModelTypeConst.LLM, azure_llm_model_credential, api_version='2023-07-01-preview'), @@ -84,18 +121,18 @@ class AzureModelProvider(IModelProvider): model_info: ModelInfo = model_dict.get(model_name) azure_chat_open_ai = AzureChatOpenAI( openai_api_base=model_credential.get('api_base'), - openai_api_version=model_info.api_version, + openai_api_version=model_credential.get( + 'api_version') if 'api_version' in model_credential else model_info.api_version, deployment_name=model_credential.get('deployment_name'), openai_api_key=model_credential.get('api_key'), - openai_api_type="azure", - tiktoken_model_name=model_name + openai_api_type="azure" ) return azure_chat_open_ai def get_model_credential(self, model_type, model_name): if model_name in model_dict: return model_dict.get(model_name).model_credential - raise AppApiException(500, f'不支持的模型:{model_name}') + return base_azure_llm_model_credential def get_model_provide_info(self): return ModelProvideInfo(provider='model_azure_provider', name='Azure OpenAI', icon=get_file_content( diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py b/apps/setting/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py index 3c05777fa..2805f4b1d 100644 --- a/apps/setting/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/wenxin_model_provider.py @@ -9,8 +9,9 @@ import os from typing import Dict -from langchain_community.chat_models import QianfanChatEndpoint from langchain.schema import HumanMessage +from langchain_community.chat_models import QianfanChatEndpoint +from qianfan import ChatCompletion from common import froms from common.exception.app_exception import AppApiException @@ -27,10 +28,9 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential): model_type_list = WenxinModelProvider().get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - - if model_name not in model_dict: - raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型名称不支持') - + model_info = [model.lower() for model in ChatCompletion.models()] + if not model_info.__contains__(model_name.lower()): + raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持') for key in ['api_key', 'secret_key']: if key not in model_credential: if raise_exception: @@ -39,10 +39,9 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential): return False try: WenxinModelProvider().get_model(model_type, model_name, model_credential).invoke( - [HumanMessage(content='valid')]) + [HumanMessage(content='你好')]) except Exception as e: - if raise_exception: - raise AppApiException(ValidCode.valid_error.value, "校验失败,请检查 api_key secret_key 是否正确") + raise e return True def encryption_dict(self, model_info: Dict[str, object]): @@ -121,7 +120,7 @@ class WenxinModelProvider(IModelProvider): def get_model_credential(self, model_type, model_name): if model_name in model_dict: return model_dict.get(model_name).model_credential - raise AppApiException(500, f'不支持的模型:{model_name}') + return win_xin_llm_model_credential def get_model_provide_info(self): return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型', icon=get_file_content( diff --git a/pyproject.toml b/pyproject.toml index d9a2f62b1..bee1b9cf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ sentence-transformers = "^2.2.2" blinker = "^1.6.3" openai = "^1.13.3" tiktoken = "^0.5.1" -qianfan = "^0.1.1" +qianfan = "^0.3.6.1" pycryptodome = "^3.19.0" beautifulsoup4 = "^4.12.2" html2text = "^2024.2.26"