diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py index de9d4add0..60cdbb408 100644 --- a/apps/application/flow/tools.py +++ b/apps/application/flow/tools.py @@ -60,7 +60,10 @@ class Reasoning: if not self.reasoning_content_is_end: self.reasoning_content_is_end = True self.content += self.all_content - return {'content': self.all_content, 'reasoning_content': ''} + return {'content': self.all_content, + 'reasoning_content': chunk.additional_kwargs.get('reasoning_content', + '') if chunk.additional_kwargs else '' + } else: if self.reasoning_content_is_start: self.reasoning_content_chunk += chunk.content @@ -68,7 +71,9 @@ class Reasoning: self.reasoning_content_end_tag_prefix) if self.reasoning_content_is_end: self.content += chunk.content - return {'content': chunk.content, 'reasoning_content': ''} + return {'content': chunk.content, 'reasoning_content': chunk.additional_kwargs.get('reasoning_content', + '') if chunk.additional_kwargs else '' + } # 是否包含结束 if reasoning_content_end_tag_prefix_index > -1: if len(self.reasoning_content_chunk) - reasoning_content_end_tag_prefix_index >= self.reasoning_content_end_tag_len: @@ -93,7 +98,9 @@ class Reasoning: else: if self.reasoning_content_is_end: self.content += chunk.content - return {'content': chunk.content, 'reasoning_content': ''} + return {'content': chunk.content, 'reasoning_content': chunk.additional_kwargs.get('reasoning_content', + '') if chunk.additional_kwargs else '' + } else: # aaa result = {'content': '', 'reasoning_content': self.reasoning_content_chunk} diff --git a/apps/common/forms/radio_button_field.py b/apps/common/forms/radio_button_field.py index aa6952303..b31572d6c 100644 --- a/apps/common/forms/radio_button_field.py +++ b/apps/common/forms/radio_button_field.py @@ -11,7 +11,7 @@ from typing import List, Dict from common.forms.base_field import BaseExecField, TriggerType -class Radio(BaseExecField): +class RadioButton(BaseExecField): """ 下拉单选 """ diff --git a/apps/common/forms/radio_card_field.py b/apps/common/forms/radio_card_field.py index b3579b84d..31c66d678 100644 --- a/apps/common/forms/radio_card_field.py +++ b/apps/common/forms/radio_card_field.py @@ -11,7 +11,7 @@ from typing import List, Dict from common.forms.base_field import BaseExecField, TriggerType -class Radio(BaseExecField): +class RadioCard(BaseExecField): """ 下拉单选 """ diff --git a/apps/models_provider/impl/wenxin_model_provider/credential/llm.py b/apps/models_provider/impl/wenxin_model_provider/credential/llm.py index bb458fb6f..4c06ee3b4 100644 --- a/apps/models_provider/impl/wenxin_model_provider/credential/llm.py +++ b/apps/models_provider/impl/wenxin_model_provider/credential/llm.py @@ -40,16 +40,23 @@ class WenxinLLMModelParams(BaseForm): class WenxinLLMModelCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): - model_type_list = provider.get_model_type_list() - if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): - raise AppApiException(ValidCode.valid_error.value, - gettext('{model_type} Model type is not supported').format(model_type=model_type)) + # 根据api_version检查必需字段 + api_version = model_credential.get('api_version', 'v1') model = provider.get_model(model_type, model_name, model_credential, **model_params) - model_info = [model.lower() for model in model.client.models()] - if not model_info.__contains__(model_name.lower()): - raise AppApiException(ValidCode.valid_error.value, - gettext('{model_name} The model does not support').format(model_name=model_name)) - for key in ['api_key', 'secret_key']: + if api_version == 'v1': + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, + gettext('{model_type} Model type is not supported').format(model_type=model_type)) + model_info = [model.lower() for model in model.client.models()] + if not model_info.__contains__(model_name.lower()): + raise AppApiException(ValidCode.valid_error.value, + gettext('{model_name} The model does not support').format(model_name=model_name)) + required_keys = ['api_key', 'secret_key'] + if api_version == 'v2': + required_keys = ['api_base', 'api_key'] + + for key in required_keys: if key not in model_credential: if raise_exception: raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key)) @@ -64,19 +71,47 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential): return True def encryption_dict(self, model_info: Dict[str, object]): - return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))} + # 根据api_version加密不同字段 + api_version = model_info.get('api_version', 'v1') + if api_version == 'v1': + return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))} + else: # v2 + return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))} def build_model(self, model_info: Dict[str, object]): - for key in ['api_key', 'secret_key', 'model']: - if key not in model_info: - raise AppApiException(500, gettext('{key} is required').format(key=key)) - self.api_key = model_info.get('api_key') - self.secret_key = model_info.get('secret_key') + api_version = model_info.get('api_version', 'v1') + # 根据api_version检查必需字段 + if api_version == 'v1': + for key in ['api_version', 'api_key', 'secret_key', 'model']: + if key not in model_info: + raise AppApiException(500, gettext('{key} is required').format(key=key)) + self.api_key = model_info.get('api_key') + self.secret_key = model_info.get('secret_key') + else: # v2 + for key in ['api_version', 'api_base', 'api_key', 'model', ]: + if key not in model_info: + raise AppApiException(500, gettext('{key} is required').format(key=key)) + self.api_base = model_info.get('api_base') + self.api_key = model_info.get('api_key') return self - api_key = forms.PasswordInputField('API Key', required=True) + # 动态字段定义 - 根据api_version显示不同字段 + api_version = forms.Radio('API Version', required=True, text_field='label', value_field='value', + option_list=[ + {'label': 'v1', 'value': 'v1'}, + {'label': 'v2', 'value': 'v2'} + ], + default_value='v1', + provider='', + method='', ) - secret_key = forms.PasswordInputField("Secret Key", required=True) + # v2版本字段 + api_base = forms.TextInputField("API Base", required=False, relation_show_field_dict={"api_version": ["v2"]}) + + # v1版本字段 + api_key = forms.PasswordInputField('API Key', required=False) + secret_key = forms.PasswordInputField("Secret Key", required=False, + relation_show_field_dict={"api_version": ["v1"]}) def get_model_params_setting_form(self, model_name): return WenxinLLMModelParams() diff --git a/apps/models_provider/impl/wenxin_model_provider/model/llm.py b/apps/models_provider/impl/wenxin_model_provider/model/llm.py index aa79692b2..688530136 100644 --- a/apps/models_provider/impl/wenxin_model_provider/model/llm.py +++ b/apps/models_provider/impl/wenxin_model_provider/model/llm.py @@ -17,9 +17,10 @@ from langchain_core.messages import ( from langchain_core.outputs import ChatGenerationChunk from models_provider.base_model_provider import MaxKBBaseModel +from models_provider.impl.base_chat_open_ai import BaseChatOpenAI -class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint): +class QianfanChatModelQianfan(MaxKBBaseModel, QianfanChatEndpoint): @staticmethod def is_cache_model(): return False @@ -27,11 +28,11 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) - return QianfanChatModel(model=model_name, - qianfan_ak=model_credential.get('api_key'), - qianfan_sk=model_credential.get('secret_key'), - streaming=model_kwargs.get('streaming', False), - init_kwargs=optional_params) + return QianfanChatModelQianfan(model=model_name, + qianfan_ak=model_credential.get('api_key'), + qianfan_sk=model_credential.get('secret_key'), + streaming=model_kwargs.get('streaming', False), + init_kwargs=optional_params) usage_metadata: dict = {} @@ -74,3 +75,30 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint): if run_manager: run_manager.on_llm_new_token(chunk.text, chunk=chunk) yield chunk + + +class QianfanChatModelOpenai(MaxKBBaseModel, BaseChatOpenAI): + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + return QianfanChatModelOpenai( + model=model_name, + openai_api_base=model_credential.get('api_base'), + openai_api_key=model_credential.get('api_key'), + extra_body=optional_params + ) + + +class QianfanChatModel(MaxKBBaseModel): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + api_version = model_credential.get('api_version', 'v1') + + if api_version == "v1": + return QianfanChatModelQianfan.new_instance(model_type, model_name, model_credential, **model_kwargs) + elif api_version == "v2": + return QianfanChatModelOpenai.new_instance(model_type, model_name, model_credential, **model_kwargs)