From a2b6620b10ba6318e7069f92cc17fc7255a6dd1e Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Mon, 31 Mar 2025 15:31:52 +0800 Subject: [PATCH] refactor: azure llm params --- .../impl/azure_model_provider/credential/llm.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py index 60fac8105..ac1727924 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py @@ -10,6 +10,7 @@ import traceback from typing import Dict from langchain_core.messages import HumanMessage +from openai import BadRequestError from common import forms from common.exception.app_exception import AppApiException @@ -37,6 +38,17 @@ class AzureLLMModelParams(BaseForm): precision=0) +class o3MiniLLMModelParams(BaseForm): + max_completion_tokens = forms.SliderField( + TooltipLabel(_('Output the maximum Tokens'), + _('Specify the maximum number of tokens that the model can generate')), + required=True, default_value=800, + _min=1, + _max=5000, + _step=1, + precision=0) + + class AzureLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, @@ -57,7 +69,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential): model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: traceback.print_exc() - if isinstance(e, AppApiException): + if isinstance(e, AppApiException) or isinstance(e, BadRequestError): raise e if raise_exception: raise AppApiException(ValidCode.valid_error.value, @@ -79,4 +91,6 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential): deployment_name = forms.TextInputField("Deployment name", required=True) def get_model_params_setting_form(self, model_name): + if 'o3' in model_name or 'o1' in model_name: + return o3MiniLLMModelParams() return AzureLLMModelParams()