feat: Image understanding and image generation models support configuring model parameters

This commit is contained in:
CaptainB 2024-12-10 14:56:52 +08:00 committed by 刘瑞斌
parent add99fabc6
commit 8b33c99235
10 changed files with 89 additions and 24 deletions

View File

@ -22,6 +22,9 @@ class ImageUnderstandNodeSerializer(serializers.Serializer):
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
model_params_setting = serializers.JSONField(required=False, default=dict, error_messages=ErrMessage.json("模型参数设置"))
class IImageUnderstandNode(INode):
type = 'image-understand-node'
@ -35,6 +38,7 @@ class IImageUnderstandNode(INode):
return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
model_params_setting,
chat_record_id,
image,
**kwargs) -> NodeResult:

View File

@ -70,6 +70,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
self.answer_text = details.get('answer')
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
model_params_setting,
chat_record_id,
image,
**kwargs) -> NodeResult:
@ -77,7 +78,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
if image is None or not isinstance(image, list):
image = []
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
# 执行详情中的历史消息不需要图片内容
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
self.context['history_message'] = history_message

View File

@ -8,6 +8,7 @@
"""
from typing import Dict
from common.forms import BaseLabel
from common.forms.base_field import BaseField, TriggerType
@ -16,7 +17,7 @@ class TextInputField(BaseField):
文本输入框
"""
def __init__(self, label: str,
def __init__(self, label: str or BaseLabel,
required: bool = False,
default_value=None,
relation_show_field_dict: Dict = None,

View File

@ -7,9 +7,26 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAIImageModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.7,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
required=True, default_value=800,
_min=1,
_max=100000,
_step=1,
precision=0)
class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
@ -45,4 +62,4 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
pass
return OpenAIImageModelParams()

View File

@ -7,9 +7,26 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAITTIModelParams(BaseForm):
size = forms.TextInputField(
TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
required=True, default_value='1024x1024')
quality = forms.TextInputField(
TooltipLabel('图片质量', ''),
required=True, default_value='standard')
n = forms.SliderField(
TooltipLabel('图片数量', '指定生成图片的数量'),
required=True, default_value=1,
_min=1,
_max=10,
_step=1,
precision=0)
class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
@ -44,4 +61,4 @@ class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
pass
return OpenAITTIModelParams()

View File

@ -19,20 +19,19 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
class QwenModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=1.0,
_min=0.1,
_max=1.9,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
required=True, default_value=800,
size = forms.TextInputField(
TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
required=True, default_value='1024x1024')
n = forms.SliderField(
TooltipLabel('图片数量', '指定生成图片的数量'),
required=True, default_value=1,
_min=1,
_max=100000,
_max=4,
_step=1,
precision=0)
style = forms.TextInputField(
TooltipLabel('风格', '指定生成图片的风格'),
required=True, default_value='<auto>')
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):

View File

@ -7,9 +7,24 @@ from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class ZhiPuImageModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=0.95,
_min=0.1,
_max=1.0,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
required=True, default_value=1024,
_min=1,
_max=100000,
_step=1,
precision=0)
class ZhiPuImageModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)
@ -44,4 +59,4 @@ class ZhiPuImageModelCredential(BaseForm, BaseModelCredential):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
pass
return ZhiPuImageModelParams()

View File

@ -1,14 +1,19 @@
# coding=utf-8
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class ZhiPuTTIModelParams(BaseForm):
size = forms.TextInputField(
TooltipLabel('图片尺寸',
'图片尺寸,仅 cogview-3-plus 支持该参数。可选范围:[1024x1024,768x1344,864x1152,1344x768,1152x864,1440x720,720x1440]默认是1024x1024。'),
required=True, default_value='1024x1024')
class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)
@ -41,4 +46,4 @@ class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
pass
return ZhiPuTTIModelParams()

View File

@ -28,6 +28,11 @@ from setting.models_provider import get_model, get_model_credential
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
def get_default_model_params_setting(provider, model_type, model_name):
credential = get_model_credential(provider, model_type, model_name)
model_params_setting = credential.get_model_params_setting_form(model_name).to_form_list()
return model_params_setting
class ModelPullManage:
@ -206,6 +211,7 @@ class ModelSerializer(serializers.Serializer):
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
credential=rsa_long_encrypt(model_credential_str),
provider=provider, model_type=model_type, model_name=model_name,
model_params_form=get_default_model_params_setting(provider, model_type, model_name),
permission_type=permission_type)
model.save()
if status == Status.DOWNLOAD:

View File

@ -94,7 +94,7 @@
<template #dropdown>
<el-dropdown-menu>
<el-dropdown-item
v-if="currentModel.model_type === 'TTS' || currentModel.model_type === 'LLM'"
v-if="currentModel.model_type === 'TTS' || currentModel.model_type === 'LLM' || currentModel.model_type === 'IMAGE' || currentModel.model_type === 'TTI'"
:disabled="!is_permisstion"
icon="Setting" @click.stop="openParamSetting"
>