diff --git a/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py
index 60fac8105..ac1727924 100644
--- a/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py
+++ b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py
@@ -10,6 +10,7 @@ import traceback
 from typing import Dict
 
 from langchain_core.messages import HumanMessage
+from openai import BadRequestError
 
 from common import forms
 from common.exception.app_exception import AppApiException
@@ -37,6 +38,17 @@ class AzureLLMModelParams(BaseForm):
        precision=0)
 
 
+class o3MiniLLMModelParams(BaseForm):
+    max_completion_tokens = forms.SliderField(
+        TooltipLabel(_('Output the maximum Tokens'),
+                     _('Specify the maximum number of tokens that the model can generate')),
+        required=True, default_value=800,
+        _min=1,
+        _max=5000,
+        _step=1,
+        precision=0)
+
+
 class AzureLLMModelCredential(BaseForm, BaseModelCredential):
 
     def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
@@ -57,7 +69,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
             model.invoke([HumanMessage(content=gettext('Hello'))])
         except Exception as e:
             traceback.print_exc()
-            if isinstance(e, AppApiException):
+            if isinstance(e, AppApiException) or isinstance(e, BadRequestError):
                 raise e
             if raise_exception:
                 raise AppApiException(ValidCode.valid_error.value,
@@ -79,4 +91,6 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
         deployment_name = forms.TextInputField("Deployment name", required=True)
 
     def get_model_params_setting_form(self, model_name):
+        if 'o3' in model_name or 'o1' in model_name:
+            return o3MiniLLMModelParams()
         return AzureLLMModelParams()