diff --git a/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py index 31aaa0205..8fce0b176 100644 --- a/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py +++ b/apps/application/flow/step_node/image_understand_step_node/i_image_understand_node.py @@ -22,6 +22,9 @@ class ImageUnderstandNodeSerializer(serializers.Serializer): image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片")) + model_params_setting = serializers.JSONField(required=False, default=dict, error_messages=ErrMessage.json("模型参数设置")) + + class IImageUnderstandNode(INode): type = 'image-understand-node' @@ -35,6 +38,7 @@ class IImageUnderstandNode(INode): return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data) def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, + model_params_setting, chat_record_id, image, **kwargs) -> NodeResult: diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py index 1c2536e0c..a26c6f52f 100644 --- a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py +++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py @@ -70,6 +70,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode): self.answer_text = details.get('answer') def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, + model_params_setting, chat_record_id, image, **kwargs) -> NodeResult: @@ -77,7 +78,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode): if image is None or not isinstance(image, list): image = [] - image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id')) + image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting) # 执行详情中的历史消息不需要图片内容 history_message = self.get_history_message_for_details(history_chat_record, dialogue_number) self.context['history_message'] = history_message diff --git a/apps/common/forms/text_input_field.py b/apps/common/forms/text_input_field.py index 28a821e15..2b8b2ce04 100644 --- a/apps/common/forms/text_input_field.py +++ b/apps/common/forms/text_input_field.py @@ -8,6 +8,7 @@ """ from typing import Dict +from common.forms import BaseLabel from common.forms.base_field import BaseField, TriggerType @@ -16,7 +17,7 @@ class TextInputField(BaseField): 文本输入框 """ - def __init__(self, label: str, + def __init__(self, label: str or BaseLabel, required: bool = False, default_value=None, relation_show_field_dict: Dict = None, diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/image.py b/apps/setting/models_provider/impl/openai_model_provider/credential/image.py index e6063695a..83c1e70d1 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/image.py @@ -7,9 +7,26 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class OpenAIImageModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + class OpenAIImageModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) @@ -45,4 +62,4 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} def get_model_params_setting_form(self, model_name): - pass + return OpenAIImageModelParams() diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py b/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py index 480c20a8c..7a9a70006 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py @@ -7,9 +7,26 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class OpenAITTIModelParams(BaseForm): + size = forms.TextInputField( + TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'), + required=True, default_value='1024x1024') + + quality = forms.TextInputField( + TooltipLabel('图片质量', ''), + required=True, default_value='standard') + + n = forms.SliderField( + TooltipLabel('图片数量', '指定生成图片的数量'), + required=True, default_value=1, + _min=1, + _max=10, + _step=1, + precision=0) + class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) @@ -44,4 +61,4 @@ class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} def get_model_params_setting_form(self, model_name): - pass + return OpenAITTIModelParams() diff --git a/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py b/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py index dc4779f6b..805751b4d 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py @@ -19,20 +19,19 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class QwenModelParams(BaseForm): - temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), - required=True, default_value=1.0, - _min=0.1, - _max=1.9, - _step=0.01, - precision=2) - - max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), - required=True, default_value=800, + size = forms.TextInputField( + TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'), + required=True, default_value='1024x1024') + n = forms.SliderField( + TooltipLabel('图片数量', '指定生成图片的数量'), + required=True, default_value=1, _min=1, - _max=100000, + _max=4, _step=1, precision=0) + style = forms.TextInputField( + TooltipLabel('风格', '指定生成图片的风格'), + required=True, default_value='') class QwenTextToImageModelCredential(BaseForm, BaseModelCredential): diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/credential/image.py b/apps/setting/models_provider/impl/zhipu_model_provider/credential/image.py index 0eb05bb91..54bd19e14 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/credential/image.py @@ -7,9 +7,24 @@ from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class ZhiPuImageModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.95, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=1024, + _min=1, + _max=100000, + _step=1, + precision=0) class ZhiPuImageModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) @@ -44,4 +59,4 @@ class ZhiPuImageModelCredential(BaseForm, BaseModelCredential): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} def get_model_params_setting_form(self, model_name): - pass + return ZhiPuImageModelParams() diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py b/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py index 4ac28c1b5..b0a9d7a61 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py @@ -1,14 +1,19 @@ # coding=utf-8 from typing import Dict -from langchain_core.messages import HumanMessage - from common import forms from common.exception.app_exception import AppApiException -from common.forms import BaseForm +from common.forms import BaseForm, TooltipLabel from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +class ZhiPuTTIModelParams(BaseForm): + size = forms.TextInputField( + TooltipLabel('图片尺寸', + '图片尺寸,仅 cogview-3-plus 支持该参数。可选范围:[1024x1024,768x1344,864x1152,1344x768,1152x864,1440x720,720x1440],默认是1024x1024。'), + required=True, default_value='1024x1024') + + class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) @@ -41,4 +46,4 @@ class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential): return {**model, 'api_key': super().encryption(model.get('api_key', ''))} def get_model_params_setting_form(self, model_name): - pass + return ZhiPuTTIModelParams() diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index e732087b0..2f7d6dd82 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -28,6 +28,11 @@ from setting.models_provider import get_model, get_model_credential from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus from setting.models_provider.constants.model_provider_constants import ModelProvideConstants +def get_default_model_params_setting(provider, model_type, model_name): + credential = get_model_credential(provider, model_type, model_name) + model_params_setting = credential.get_model_params_setting_form(model_name).to_form_list() + return model_params_setting + class ModelPullManage: @@ -206,6 +211,7 @@ class ModelSerializer(serializers.Serializer): model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, credential=rsa_long_encrypt(model_credential_str), provider=provider, model_type=model_type, model_name=model_name, + model_params_form=get_default_model_params_setting(provider, model_type, model_name), permission_type=permission_type) model.save() if status == Status.DOWNLOAD: diff --git a/ui/src/views/template/component/ModelCard.vue b/ui/src/views/template/component/ModelCard.vue index 1ab1118fb..4420cd5b4 100644 --- a/ui/src/views/template/component/ModelCard.vue +++ b/ui/src/views/template/component/ModelCard.vue @@ -94,7 +94,7 @@