feat: enhance model credential validation and support for multiple API versions

wxg0103 2025-08-18 16:46:27 +08:00
parent 795db14c75
commit 044465fcc6
5 changed files with 98 additions and 28 deletions

View File

@@ -60,7 +60,10 @@ class Reasoning:
if not self.reasoning_content_is_end:
self.reasoning_content_is_end = True
self.content += self.all_content
return {'content': self.all_content, 'reasoning_content': ''}
return {'content': self.all_content,
'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
'') if chunk.additional_kwargs else ''
}
else:
if self.reasoning_content_is_start:
self.reasoning_content_chunk += chunk.content
@@ -68,7 +71,9 @@
self.reasoning_content_end_tag_prefix)
if self.reasoning_content_is_end:
self.content += chunk.content
return {'content': chunk.content, 'reasoning_content': ''}
return {'content': chunk.content, 'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
'') if chunk.additional_kwargs else ''
}
# Whether the end tag is included
if reasoning_content_end_tag_prefix_index > -1:
if len(self.reasoning_content_chunk) - reasoning_content_end_tag_prefix_index >= self.reasoning_content_end_tag_len:
@@ -93,7 +98,9 @@
else:
if self.reasoning_content_is_end:
self.content += chunk.content
return {'content': chunk.content, 'reasoning_content': ''}
return {'content': chunk.content, 'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
'') if chunk.additional_kwargs else ''
}
else:
# Still inside the reasoning block: emit the buffered text as reasoning_content
result = {'content': '', 'reasoning_content': self.reasoning_content_chunk}
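
Net effect of these three hunks: when the provider streams its reasoning through additional_kwargs (as some OpenAI-compatible backends do) rather than inline think-tags, the chunk's reasoning_content is now passed through instead of being hard-coded to ''. A minimal sketch of the new return shape, using a stand-in chunk object (split_chunk and the SimpleNamespace stand-in are illustrative, not part of the diff):

from types import SimpleNamespace

def split_chunk(chunk):
    # Mirror the new return shape: read reasoning_content from
    # additional_kwargs when present, instead of always returning ''.
    reasoning = (chunk.additional_kwargs.get('reasoning_content', '')
                 if chunk.additional_kwargs else '')
    return {'content': chunk.content, 'reasoning_content': reasoning}

with_reasoning = SimpleNamespace(content='4', additional_kwargs={'reasoning_content': '2 + 2 = 4'})
plain = SimpleNamespace(content='hi', additional_kwargs=None)
print(split_chunk(with_reasoning))  # {'content': '4', 'reasoning_content': '2 + 2 = 4'}
print(split_chunk(plain))           # {'content': 'hi', 'reasoning_content': ''}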

View File

@@ -11,7 +11,7 @@ from typing import List, Dict
from common.forms.base_field import BaseExecField, TriggerType
class Radio(BaseExecField):
class RadioButton(BaseExecField):
"""
Drop-down single select
"""

View File

@@ -11,7 +11,7 @@ from typing import List, Dict
from common.forms.base_field import BaseExecField, TriggerType
class Radio(BaseExecField):
class RadioCard(BaseExecField):
"""
Drop-down single select
"""

View File

@@ -40,16 +40,23 @@ class WenxinLLMModelParams(BaseForm):
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))
# Check required fields according to api_version
api_version = model_credential.get('api_version', 'v1')
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model_info = [model.lower() for model in model.client.models()]
if not model_info.__contains__(model_name.lower()):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_name} The model does not support').format(model_name=model_name))
for key in ['api_key', 'secret_key']:
if api_version == 'v1':
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))
model_info = [model.lower() for model in model.client.models()]
if not model_info.__contains__(model_name.lower()):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_name} The model does not support').format(model_name=model_name))
required_keys = ['api_key', 'secret_key']
if api_version == 'v2':
required_keys = ['api_base', 'api_key']
for key in required_keys:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
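
In isolation, the new required-field selection is a two-way branch on api_version; a standalone sketch with the exception machinery omitted (function names here are illustrative):

def required_credential_keys(api_version):
    # v1 authenticates with Baidu's AK/SK pair; v2 talks to an
    # OpenAI-compatible endpoint with a base URL plus a single key.
    return ['api_base', 'api_key'] if api_version == 'v2' else ['api_key', 'secret_key']

def missing_credential_keys(model_credential):
    api_version = model_credential.get('api_version', 'v1')
    return [key for key in required_credential_keys(api_version) if key not in model_credential]

print(missing_credential_keys({'api_version': 'v2', 'api_key': 'k'}))  # ['api_base']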
@@ -64,19 +71,47 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
return True
def encryption_dict(self, model_info: Dict[str, object]):
return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}
# Encrypt different fields depending on api_version
api_version = model_info.get('api_version', 'v1')
if api_version == 'v1':
return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}
else: # v2
return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}
def build_model(self, model_info: Dict[str, object]):
for key in ['api_key', 'secret_key', 'model']:
if key not in model_info:
raise AppApiException(500, gettext('{key} is required').format(key=key))
self.api_key = model_info.get('api_key')
self.secret_key = model_info.get('secret_key')
api_version = model_info.get('api_version', 'v1')
# Check required fields according to api_version
if api_version == 'v1':
for key in ['api_version', 'api_key', 'secret_key', 'model']:
if key not in model_info:
raise AppApiException(500, gettext('{key} is required').format(key=key))
self.api_key = model_info.get('api_key')
self.secret_key = model_info.get('secret_key')
else: # v2
for key in ['api_version', 'api_base', 'api_key', 'model', ]:
if key not in model_info:
raise AppApiException(500, gettext('{key} is required').format(key=key))
self.api_base = model_info.get('api_base')
self.api_key = model_info.get('api_key')
return self
api_key = forms.PasswordInputField('API Key', required=True)
# Dynamic field definitions - different fields are shown depending on api_version
api_version = forms.Radio('API Version', required=True, text_field='label', value_field='value',
option_list=[
{'label': 'v1', 'value': 'v1'},
{'label': 'v2', 'value': 'v2'}
],
default_value='v1',
provider='',
method='', )
secret_key = forms.PasswordInputField("Secret Key", required=True)
# v2 fields
api_base = forms.TextInputField("API Base", required=False, relation_show_field_dict={"api_version": ["v2"]})
# v1 fields
api_key = forms.PasswordInputField('API Key', required=False)
secret_key = forms.PasswordInputField("Secret Key", required=False,
relation_show_field_dict={"api_version": ["v1"]})
def get_model_params_setting_form(self, model_name):
return WenxinLLMModelParams()
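
Taken together, the remaining hunks in this file split every credential path on api_version: encryption_dict masks secret_key for v1 but api_key for v2, build_model requires and copies the matching fields, and the form shows secret_key only for v1 and api_base only for v2 (assuming relation_show_field_dict gates a field on another field's current value). The two resulting credential shapes, with placeholder values:

# v1: classic AK/SK pair; secret_key is the field that gets encrypted.
v1_credential = {
    'api_version': 'v1',
    'api_key': 'your-ak',      # always shown
    'secret_key': 'your-sk',   # shown only while api_version == 'v1'
    'model': 'model-name',
}

# v2: OpenAI-compatible endpoint; api_key is the field that gets encrypted.
v2_credential = {
    'api_version': 'v2',
    'api_base': 'https://host/v2',  # shown only while api_version == 'v2'
    'api_key': 'your-key',
    'model': 'model-name',
}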

View File

@@ -17,9 +17,10 @@ from langchain_core.messages import (
from langchain_core.outputs import ChatGenerationChunk
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
class QianfanChatModelQianfan(MaxKBBaseModel, QianfanChatEndpoint):
@staticmethod
def is_cache_model():
return False
@@ -27,11 +28,11 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return QianfanChatModel(model=model_name,
qianfan_ak=model_credential.get('api_key'),
qianfan_sk=model_credential.get('secret_key'),
streaming=model_kwargs.get('streaming', False),
init_kwargs=optional_params)
return QianfanChatModelQianfan(model=model_name,
qianfan_ak=model_credential.get('api_key'),
qianfan_sk=model_credential.get('secret_key'),
streaming=model_kwargs.get('streaming', False),
init_kwargs=optional_params)
usage_metadata: dict = {}
@@ -74,3 +75,30 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
class QianfanChatModelOpenai(MaxKBBaseModel, BaseChatOpenAI):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return QianfanChatModelOpenai(
model=model_name,
openai_api_base=model_credential.get('api_base'),
openai_api_key=model_credential.get('api_key'),
extra_body=optional_params
)
class QianfanChatModel(MaxKBBaseModel):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
api_version = model_credential.get('api_version', 'v1')
if api_version == "v1":
return QianfanChatModelQianfan.new_instance(model_type, model_name, model_credential, **model_kwargs)
elif api_version == "v2":
return QianfanChatModelOpenai.new_instance(model_type, model_name, model_credential, **model_kwargs)
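
Callers keep going through QianfanChatModel.new_instance; the stored api_version routes construction to the Qianfan SDK backend (v1) or the OpenAI-compatible one (v2), and an unrecognized api_version falls through and returns None. A usage sketch with placeholder values (the model_type and model name are illustrative):

v2_credential = {'api_version': 'v2', 'api_base': 'https://host/v2', 'api_key': 'your-key'}
model = QianfanChatModel.new_instance('LLM', 'model-name', v2_credential, streaming=True)
# -> a QianfanChatModelOpenai instance; with api_version 'v1' plus AK/SK
#    credentials it would return a QianfanChatModelQianfan instead.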