mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: enhance model credential validation and support for multiple API versions
This commit is contained in:
parent
795db14c75
commit
044465fcc6
|
|
@ -60,7 +60,10 @@ class Reasoning:
|
|||
if not self.reasoning_content_is_end:
|
||||
self.reasoning_content_is_end = True
|
||||
self.content += self.all_content
|
||||
return {'content': self.all_content, 'reasoning_content': ''}
|
||||
return {'content': self.all_content,
|
||||
'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
|
||||
'') if chunk.additional_kwargs else ''
|
||||
}
|
||||
else:
|
||||
if self.reasoning_content_is_start:
|
||||
self.reasoning_content_chunk += chunk.content
|
||||
|
|
@ -68,7 +71,9 @@ class Reasoning:
|
|||
self.reasoning_content_end_tag_prefix)
|
||||
if self.reasoning_content_is_end:
|
||||
self.content += chunk.content
|
||||
return {'content': chunk.content, 'reasoning_content': ''}
|
||||
return {'content': chunk.content, 'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
|
||||
'') if chunk.additional_kwargs else ''
|
||||
}
|
||||
# 是否包含结束
|
||||
if reasoning_content_end_tag_prefix_index > -1:
|
||||
if len(self.reasoning_content_chunk) - reasoning_content_end_tag_prefix_index >= self.reasoning_content_end_tag_len:
|
||||
|
|
@ -93,7 +98,9 @@ class Reasoning:
|
|||
else:
|
||||
if self.reasoning_content_is_end:
|
||||
self.content += chunk.content
|
||||
return {'content': chunk.content, 'reasoning_content': ''}
|
||||
return {'content': chunk.content, 'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
|
||||
'') if chunk.additional_kwargs else ''
|
||||
}
|
||||
else:
|
||||
# aaa
|
||||
result = {'content': '', 'reasoning_content': self.reasoning_content_chunk}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from typing import List, Dict
|
|||
from common.forms.base_field import BaseExecField, TriggerType
|
||||
|
||||
|
||||
class Radio(BaseExecField):
|
||||
class RadioButton(BaseExecField):
|
||||
"""
|
||||
下拉单选
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from typing import List, Dict
|
|||
from common.forms.base_field import BaseExecField, TriggerType
|
||||
|
||||
|
||||
class Radio(BaseExecField):
|
||||
class RadioCard(BaseExecField):
|
||||
"""
|
||||
下拉单选
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -40,16 +40,23 @@ class WenxinLLMModelParams(BaseForm):
|
|||
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value,
|
||||
gettext('{model_type} Model type is not supported').format(model_type=model_type))
|
||||
# 根据api_version检查必需字段
|
||||
api_version = model_credential.get('api_version', 'v1')
|
||||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model_info = [model.lower() for model in model.client.models()]
|
||||
if not model_info.__contains__(model_name.lower()):
|
||||
raise AppApiException(ValidCode.valid_error.value,
|
||||
gettext('{model_name} The model does not support').format(model_name=model_name))
|
||||
for key in ['api_key', 'secret_key']:
|
||||
if api_version == 'v1':
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value,
|
||||
gettext('{model_type} Model type is not supported').format(model_type=model_type))
|
||||
model_info = [model.lower() for model in model.client.models()]
|
||||
if not model_info.__contains__(model_name.lower()):
|
||||
raise AppApiException(ValidCode.valid_error.value,
|
||||
gettext('{model_name} The model does not support').format(model_name=model_name))
|
||||
required_keys = ['api_key', 'secret_key']
|
||||
if api_version == 'v2':
|
||||
required_keys = ['api_base', 'api_key']
|
||||
|
||||
for key in required_keys:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
|
||||
|
|
@ -64,19 +71,47 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
return True
|
||||
|
||||
def encryption_dict(self, model_info: Dict[str, object]):
|
||||
return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}
|
||||
# 根据api_version加密不同字段
|
||||
api_version = model_info.get('api_version', 'v1')
|
||||
if api_version == 'v1':
|
||||
return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}
|
||||
else: # v2
|
||||
return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}
|
||||
|
||||
def build_model(self, model_info: Dict[str, object]):
|
||||
for key in ['api_key', 'secret_key', 'model']:
|
||||
if key not in model_info:
|
||||
raise AppApiException(500, gettext('{key} is required').format(key=key))
|
||||
self.api_key = model_info.get('api_key')
|
||||
self.secret_key = model_info.get('secret_key')
|
||||
api_version = model_info.get('api_version', 'v1')
|
||||
# 根据api_version检查必需字段
|
||||
if api_version == 'v1':
|
||||
for key in ['api_version', 'api_key', 'secret_key', 'model']:
|
||||
if key not in model_info:
|
||||
raise AppApiException(500, gettext('{key} is required').format(key=key))
|
||||
self.api_key = model_info.get('api_key')
|
||||
self.secret_key = model_info.get('secret_key')
|
||||
else: # v2
|
||||
for key in ['api_version', 'api_base', 'api_key', 'model', ]:
|
||||
if key not in model_info:
|
||||
raise AppApiException(500, gettext('{key} is required').format(key=key))
|
||||
self.api_base = model_info.get('api_base')
|
||||
self.api_key = model_info.get('api_key')
|
||||
return self
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
# 动态字段定义 - 根据api_version显示不同字段
|
||||
api_version = forms.Radio('API Version', required=True, text_field='label', value_field='value',
|
||||
option_list=[
|
||||
{'label': 'v1', 'value': 'v1'},
|
||||
{'label': 'v2', 'value': 'v2'}
|
||||
],
|
||||
default_value='v1',
|
||||
provider='',
|
||||
method='', )
|
||||
|
||||
secret_key = forms.PasswordInputField("Secret Key", required=True)
|
||||
# v2版本字段
|
||||
api_base = forms.TextInputField("API Base", required=False, relation_show_field_dict={"api_version": ["v2"]})
|
||||
|
||||
# v1版本字段
|
||||
api_key = forms.PasswordInputField('API Key', required=False)
|
||||
secret_key = forms.PasswordInputField("Secret Key", required=False,
|
||||
relation_show_field_dict={"api_version": ["v1"]})
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return WenxinLLMModelParams()
|
||||
|
|
|
|||
|
|
@ -17,9 +17,10 @@ from langchain_core.messages import (
|
|||
from langchain_core.outputs import ChatGenerationChunk
|
||||
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
|
||||
|
||||
|
||||
class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
|
||||
class QianfanChatModelQianfan(MaxKBBaseModel, QianfanChatEndpoint):
|
||||
@staticmethod
|
||||
def is_cache_model():
|
||||
return False
|
||||
|
|
@ -27,11 +28,11 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
|
|||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||
return QianfanChatModel(model=model_name,
|
||||
qianfan_ak=model_credential.get('api_key'),
|
||||
qianfan_sk=model_credential.get('secret_key'),
|
||||
streaming=model_kwargs.get('streaming', False),
|
||||
init_kwargs=optional_params)
|
||||
return QianfanChatModelQianfan(model=model_name,
|
||||
qianfan_ak=model_credential.get('api_key'),
|
||||
qianfan_sk=model_credential.get('secret_key'),
|
||||
streaming=model_kwargs.get('streaming', False),
|
||||
init_kwargs=optional_params)
|
||||
|
||||
usage_metadata: dict = {}
|
||||
|
||||
|
|
@ -74,3 +75,30 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
|
|||
if run_manager:
|
||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||
yield chunk
|
||||
|
||||
|
||||
class QianfanChatModelOpenai(MaxKBBaseModel, BaseChatOpenAI):
|
||||
@staticmethod
|
||||
def is_cache_model():
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||
return QianfanChatModelOpenai(
|
||||
model=model_name,
|
||||
openai_api_base=model_credential.get('api_base'),
|
||||
openai_api_key=model_credential.get('api_key'),
|
||||
extra_body=optional_params
|
||||
)
|
||||
|
||||
|
||||
class QianfanChatModel(MaxKBBaseModel):
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
api_version = model_credential.get('api_version', 'v1')
|
||||
|
||||
if api_version == "v1":
|
||||
return QianfanChatModelQianfan.new_instance(model_type, model_name, model_credential, **model_kwargs)
|
||||
elif api_version == "v2":
|
||||
return QianfanChatModelOpenai.new_instance(model_type, model_name, model_credential, **model_kwargs)
|
||||
|
|
|
|||
Loading…
Reference in New Issue