refactor: azure llm params

This commit is contained in:
wxg0103 2025-03-31 15:31:52 +08:00
parent b0a4e9e78f
commit a2b6620b10

View File

@ -10,6 +10,7 @@ import traceback
from typing import Dict
from langchain_core.messages import HumanMessage
from openai import BadRequestError
from common import forms
from common.exception.app_exception import AppApiException
@ -37,6 +38,17 @@ class AzureLLMModelParams(BaseForm):
precision=0)
class o3MiniLLMModelParams(BaseForm):
max_completion_tokens = forms.SliderField(
TooltipLabel(_('Output the maximum Tokens'),
_('Specify the maximum number of tokens that the model can generate')),
required=True, default_value=800,
_min=1,
_max=5000,
_step=1,
precision=0)
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
@ -57,7 +69,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
if isinstance(e, AppApiException) or isinstance(e, BadRequestError):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
@ -79,4 +91,6 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
deployment_name = forms.TextInputField("Deployment name", required=True)
def get_model_params_setting_form(self, model_name):
if 'o3' in model_name or 'o1' in model_name:
return o3MiniLLMModelParams()
return AzureLLMModelParams()