diff --git a/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py b/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py
index 48a0840e7..1ce1af46c 100644
--- a/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py
+++ b/apps/application/flow/step_node/image_generate_step_node/i_image_generate_node.py
@@ -21,6 +21,8 @@ class ImageGenerateNodeSerializer(serializers.Serializer):
     is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
 
+    model_params_setting = serializers.JSONField(required=False, default=dict, error_messages=ErrMessage.json("模型参数设置"))
+
 
 class IImageGenerateNode(INode):
     type = 'image-generate-node'
 
@@ -32,6 +34,7 @@ class IImageGenerateNode(INode):
         return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
 
     def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
+                model_params_setting,
                 chat_record_id,
                 **kwargs) -> NodeResult:
         pass
diff --git a/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py
index 2933c46ec..77231b70d 100644
--- a/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py
+++ b/apps/application/flow/step_node/image_generate_step_node/impl/base_image_generate_node.py
@@ -2,10 +2,13 @@
 from functools import reduce
 from typing import List
 
+import requests
 from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
 
 from application.flow.i_step_node import NodeResult
 from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
+from common.util.common import bytes_to_uploaded_file
+from dataset.serializers.file_serializers import FileSerializer
 from setting.models_provider.tools import get_model_instance_by_model_user_id
 
 
@@ -16,10 +19,12 @@ class BaseImageGenerateNode(IImageGenerateNode):
         self.answer_text = details.get('answer')
 
     def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
+                model_params_setting,
                 chat_record_id,
                 **kwargs) -> NodeResult:
-
-        tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
+        print(model_params_setting)
+        application = self.workflow_manage.work_flow_post_handler.chat_info.application
+        tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
         history_message = self.get_history_message(history_chat_record, dialogue_number)
         self.context['history_message'] = history_message
         question = self.generate_prompt_question(prompt)
@@ -28,10 +33,21 @@ class BaseImageGenerateNode(IImageGenerateNode):
         self.context['message_list'] = message_list
         self.context['dialogue_type'] = dialogue_type
         print(message_list)
-        print(negative_prompt)
         image_urls = tti_model.generate_image(question, negative_prompt)
-        self.context['image_list'] = image_urls
-        answer = '\n'.join([f"![Image]({path})" for path in image_urls])
+        # 保存图片
+        file_urls = []
+        for image_url in image_urls:
+            file_name = 'generated_image.png'
+            file = bytes_to_uploaded_file(requests.get(image_url).content, file_name)
+            meta = {
+                'debug': False if application.id else True,
+                'chat_id': chat_id,
+                'application_id': str(application.id) if application.id else None,
+            }
+            file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
+            file_urls.append(file_url)
+        self.context['image_list'] = file_urls
+        answer = '\n'.join([f"![Image]({path})" for path in file_urls])
         return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list,
                            'image': [{'file_id': path.split('/')[-1], 'file_url': path} for path in file_urls],
                            'history_message': history_message, 'question': question}, {})
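The node-level change above does two things: the new model_params_setting dict is forwarded into get_model_instance_by_model_user_id as keyword overrides, and each generated image URL is downloaded and re-uploaded through FileSerializer so the answer points at a locally stored file whose meta records whether the run came from a saved application or the debug workbench. A minimal, runnable sketch of that meta logic (the _Application stub below is ours; the real object comes from chat_info.application):

    # Stand-in for chat_info.application; id stays None for debug runs in the editor.
    class _Application:
        id = None

    application = _Application()
    chat_id = 'chat-0001'

    meta = {
        'debug': False if application.id else True,
        'chat_id': chat_id,
        'application_id': str(application.id) if application.id else None,
    }
    print(meta)  # -> {'debug': True, 'chat_id': 'chat-0001', 'application_id': None}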
diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py b/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py
index 7a9a70006..668ebca8a 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py
@@ -10,14 +10,32 @@
 from common.exception.app_exception import AppApiException
 from common.forms import BaseForm, TooltipLabel
 from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
 
-class OpenAITTIModelParams(BaseForm):
-    size = forms.TextInputField(
-        TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
-        required=True, default_value='1024x1024')
-    quality = forms.TextInputField(
+class OpenAITTIModelParams(BaseForm):
+    size = forms.SingleSelect(
+        TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
+        required=True,
+        default_value='1024x1024',
+        option_list=[
+            {'value': '1024x1024', 'label': '1024x1024'},
+            {'value': '1024x1792', 'label': '1024x1792'},
+            {'value': '1792x1024', 'label': '1792x1024'},
+        ],
+        text_field='label',
+        value_field='value'
+    )
+
+    quality = forms.SingleSelect(
         TooltipLabel('图片质量', ''),
-        required=True, default_value='standard')
+        required=True,
+        default_value='standard',
+        option_list=[
+            {'value': 'standard', 'label': 'standard'},
+            {'value': 'hd', 'label': 'hd'},
+        ],
+        text_field='label',
+        value_field='value'
+    )
 
     n = forms.SliderField(
         TooltipLabel('图片数量', '指定生成图片的数量'),
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/tti.py b/apps/setting/models_provider/impl/openai_model_provider/model/tti.py
index 08ceede02..942afcf9f 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/model/tti.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/tti.py
@@ -1,13 +1,8 @@
 from typing import Dict
 
-import requests
-from langchain_core.messages import HumanMessage
-from langchain_openai import ChatOpenAI
 from openai import OpenAI
 
 from common.config.tokenizer_manage_config import TokenizerManage
-from common.util.common import bytes_to_uploaded_file
-from dataset.serializers.file_serializers import FileSerializer
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_tti import BaseTextToImage
 
@@ -32,7 +27,7 @@ class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {'params': {}}
+        optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}}
         for key, value in model_kwargs.items():
             if key not in ['model_id', 'use_local', 'streaming']:
                 optional_params['params'][key] = value
@@ -43,6 +38,9 @@ class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
             **optional_params,
         )
 
+    def is_cache_model(self):
+        return False
+
     def check_auth(self):
         chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
         response_list = chat.models.with_raw_response.list()
@@ -50,18 +48,11 @@ class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
         # self.generate_image('生成一个小猫图片')
 
     def generate_image(self, prompt: str, negative_prompt: str = None):
-
         chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
         res = chat.images.generate(model=self.model, prompt=prompt, **self.params)
-
         file_urls = []
         for content in res.data:
             url = content.url
-            print(url)
-            file_name = 'generated_image.png'
-            file = bytes_to_uploaded_file(requests.get(url).content, file_name)
-            meta = {'debug': True}
-            file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
-            file_urls.append(file_url)
+            file_urls.append(url)
         return file_urls
diff --git a/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py b/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py
index 805751b4d..395d94db9 100644
--- a/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py
+++ b/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py
@@ -19,9 +19,18 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
 
 
 class QwenModelParams(BaseForm):
-    size = forms.TextInputField(
+    size = forms.SingleSelect(
         TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
-        required=True, default_value='1024x1024')
+        required=True,
+        default_value='1024*1024',
+        option_list=[
+            {'value': '1024*1024', 'label': '1024*1024'},
+            {'value': '720*1280', 'label': '720*1280'},
+            {'value': '768*1152', 'label': '768*1152'},
+            {'value': '1280*720', 'label': '1280*720'},
+        ],
+        text_field='label',
+        value_field='value')
     n = forms.SliderField(
         TooltipLabel('图片数量', '指定生成图片的数量'),
         required=True, default_value=1,
@@ -29,9 +38,25 @@ class QwenModelParams(BaseForm):
         _min=1,
         _max=4,
         _step=1,
         precision=0)
-    style = forms.TextInputField(
+    style = forms.SingleSelect(
         TooltipLabel('风格', '指定生成图片的风格'),
-        required=True, default_value='')
+        required=True,
+        default_value='',
+        option_list=[
+            {'value': '', 'label': '默认值,由模型随机输出图像风格'},
+            {'value': '', 'label': '摄影'},
+            {'value': '', 'label': '人像写真'},
+            {'value': '<3d cartoon>', 'label': '3D卡通'},
+            {'value': '', 'label': '动画'},
+            {'value': '', 'label': '油画'},
+            {'value': '', 'label': '水彩'},
+            {'value': '', 'label': '素描'},
+            {'value': '', 'label': '中国画'},
+            {'value': '', 'label': '扁平插画'},
+        ],
+        text_field='label',
+        value_field='value'
+    )
 
 
 class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
diff --git a/apps/setting/models_provider/impl/qwen_model_provider/model/tti.py b/apps/setting/models_provider/impl/qwen_model_provider/model/tti.py
index 593171dfe..c2fd32877 100644
--- a/apps/setting/models_provider/impl/qwen_model_provider/model/tti.py
+++ b/apps/setting/models_provider/impl/qwen_model_provider/model/tti.py
@@ -1,16 +1,11 @@
 # coding=utf-8
 from http import HTTPStatus
-from pathlib import PurePosixPath
 from typing import Dict
-from urllib.parse import unquote, urlparse
 
-import requests
 from dashscope import ImageSynthesis
 from langchain_community.chat_models import ChatTongyi
 from langchain_core.messages import HumanMessage
 
-from common.util.common import bytes_to_uploaded_file
-from dataset.serializers.file_serializers import FileSerializer
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_tti import BaseTextToImage
 
@@ -28,7 +23,7 @@ class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {'params': {}}
+        optional_params = {'params': {'size': '1024*1024', 'style': '', 'n': 1}}
         for key, value in model_kwargs.items():
             if key not in ['model_id', 'use_local', 'streaming']:
                 optional_params['params'][key] = value
@@ -39,6 +34,9 @@ class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
         )
         return chat_tong_yi
 
+    def is_cache_model(self):
+        return False
+
     def check_auth(self):
         chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max')
         chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])])
@@ -53,11 +51,7 @@ class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
         file_urls = []
         if rsp.status_code == HTTPStatus.OK:
             for result in rsp.output.results:
-                file_name = PurePosixPath(unquote(urlparse(result.url).path)).parts[-1]
-                file = bytes_to_uploaded_file(requests.get(result.url).content, file_name)
-                meta = {'debug': True}
-                file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
-                file_urls.append(file_url)
+                file_urls.append(result.url)
         else:
             print('sync_call Failed, status_code: %s, code: %s, message: %s' %
                   (rsp.status_code, rsp.code, rsp.message))
diff --git a/apps/setting/models_provider/impl/tencent_model_provider/model/tti.py b/apps/setting/models_provider/impl/tencent_model_provider/model/tti.py
index e2a9976f0..e8d57dc13 100644
--- a/apps/setting/models_provider/impl/tencent_model_provider/model/tti.py
+++ b/apps/setting/models_provider/impl/tencent_model_provider/model/tti.py
@@ -3,15 +3,12 @@
 import json
 from typing import Dict
 
-import requests
 from tencentcloud.common import credential
 from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
 from tencentcloud.common.profile.client_profile import ClientProfile
 from tencentcloud.common.profile.http_profile import HttpProfile
 from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
 
-from common.util.common import bytes_to_uploaded_file
-from dataset.serializers.file_serializers import FileSerializer
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_tti import BaseTextToImage
 from setting.models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan
@@ -87,12 +84,8 @@ class TencentTextToImageModel(MaxKBBaseModel, BaseTextToImage):
             # 输出json格式的字符串回包
             print(resp.to_json_string())
             file_urls = []
-            file_name = 'generated_image.png'
-            file = bytes_to_uploaded_file(requests.get(resp.ResultImage).content, file_name)
-            meta = {'debug': True}
-            file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
-            file_urls.append(file_url)
+
+            file_urls.append(resp.ResultImage)
             return file_urls
         except TencentCloudSDKException as err:
             print(err)
-
diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py b/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py
index b0a9d7a61..f951efd9e 100644
--- a/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py
+++ b/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py
@@ -8,10 +8,22 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
 
 
 class ZhiPuTTIModelParams(BaseForm):
-    size = forms.TextInputField(
+    size = forms.SingleSelect(
         TooltipLabel('图片尺寸',
                      '图片尺寸,仅 cogview-3-plus 支持该参数。可选范围:[1024x1024,768x1344,864x1152,1344x768,1152x864,1440x720,720x1440],默认是1024x1024。'),
-        required=True, default_value='1024x1024')
+        required=True,
+        default_value='1024x1024',
+        option_list=[
+            {'value': '1024x1024', 'label': '1024x1024'},
+            {'value': '768x1344', 'label': '768x1344'},
+            {'value': '864x1152', 'label': '864x1152'},
+            {'value': '1344x768', 'label': '1344x768'},
+            {'value': '1152x864', 'label': '1152x864'},
+            {'value': '1440x720', 'label': '1440x720'},
+            {'value': '720x1440', 'label': '720x1440'},
+        ],
+        text_field='label',
+        value_field='value')
 
 
 class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):
diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/model/tti.py b/apps/setting/models_provider/impl/zhipu_model_provider/model/tti.py
index c537669ce..e2c59b85a 100644
--- a/apps/setting/models_provider/impl/zhipu_model_provider/model/tti.py
+++ b/apps/setting/models_provider/impl/zhipu_model_provider/model/tti.py
@@ -1,13 +1,10 @@
 from typing import Dict
 
-import requests
 from langchain_community.chat_models import ChatZhipuAI
 from langchain_core.messages import HumanMessage
 from zhipuai import ZhipuAI
 
 from common.config.tokenizer_manage_config import TokenizerManage
-from common.util.common import bytes_to_uploaded_file
-from dataset.serializers.file_serializers import FileSerializer
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_tti import BaseTextToImage
 
@@ -30,7 +27,7 @@ class ZhiPuTextToImage(MaxKBBaseModel, BaseTextToImage):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {'params': {}}
+        optional_params = {'params': {'size': '1024x1024'}}
         for key, value in model_kwargs.items():
             if key not in ['model_id', 'use_local', 'streaming']:
                 optional_params['params'][key] = value
@@ -40,6 +37,9 @@ class ZhiPuTextToImage(MaxKBBaseModel, BaseTextToImage):
             **optional_params,
         )
 
+    def is_cache_model(self):
+        return False
+
     def check_auth(self):
         chat = ChatZhipuAI(
             zhipuai_api_key=self.api_key,
@@ -58,16 +58,11 @@ class ZhiPuTextToImage(MaxKBBaseModel, BaseTextToImage):
         response = chat.images.generations(
             model=self.model,  # 填写需要调用的模型编码
             prompt=prompt,  # 填写需要生成图片的文本
-            **self.params  # 填写额外参数
+            **self.params  # 填写额外参数
         )
         file_urls = []
         for content in response.data:
-            url = content['url']
-            print(url)
-            file_name = url.split('/')[-1]
-            file = bytes_to_uploaded_file(requests.get(url).content, file_name)
-            meta = {'debug': True}
-            file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
-            file_urls.append(file_url)
+            url = content.url
+            file_urls.append(url)
         return file_urls
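All four text-to-image providers above now follow the same pattern in new_instance: seed params with provider defaults, then overlay whatever the saved parameter form supplies while skipping the reserved keys. A runnable sketch of that merge using the OpenAI defaults from this patch (the helper name build_optional_params is illustrative and not part of the codebase):

    def build_optional_params(**model_kwargs):
        # Defaults mirror the OpenAI branch; Qwen and ZhiPu seed their own values.
        optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}}
        for key, value in model_kwargs.items():
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        return optional_params

    print(build_optional_params(size='1792x1024', n=2, model_id='ignored'))
    # -> {'params': {'size': '1792x1024', 'quality': 'standard', 'n': 2}}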
diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py
index 2f7d6dd82..b8f314127 100644
--- a/apps/setting/serializers/provider_serializers.py
+++ b/apps/setting/serializers/provider_serializers.py
@@ -30,8 +30,10 @@ from setting.models_provider.constants.model_provider_constants import ModelProv
 
 def get_default_model_params_setting(provider, model_type, model_name):
     credential = get_model_credential(provider, model_type, model_name)
-    model_params_setting = credential.get_model_params_setting_form(model_name).to_form_list()
-    return model_params_setting
+    setting_form = credential.get_model_params_setting_form(model_name)
+    if setting_form is not None:
+        return setting_form.to_form_list()
+    return []
 
 
 class ModelPullManage:
@@ -178,6 +180,8 @@ class ModelSerializer(serializers.Serializer):
 
     model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型"))
 
+    model_params_form = serializers.ListField(required=False, default=list, error_messages=ErrMessage.char("参数配置"))
+
     credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息"))
 
     def is_valid(self, *, raise_exception=False):
@@ -207,11 +211,12 @@ class ModelSerializer(serializers.Serializer):
         model_type = self.data.get('model_type')
         model_name = self.data.get('model_name')
         permission_type = self.data.get('permission_type')
+        model_params_form = self.data.get('model_params_form')
         model_credential_str = json.dumps(credential)
         model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
                       credential=rsa_long_encrypt(model_credential_str),
                       provider=provider, model_type=model_type, model_name=model_name,
-                      model_params_form=get_default_model_params_setting(provider, model_type, model_name),
+                      model_params_form=model_params_form,
                       permission_type=permission_type)
         model.save()
         if status == Status.DOWNLOAD:
diff --git a/apps/setting/urls.py b/apps/setting/urls.py
index 42e80592c..73fe9ba12 100644
--- a/apps/setting/urls.py
+++ b/apps/setting/urls.py
@@ -14,6 +14,8 @@ urlpatterns = [
     path('provider/model_type_list', views.Provide.ModelTypeList.as_view(), name="provider/model_type_list"),
     path('provider/model_list', views.Provide.ModelList.as_view(), name="provider/model_name_list"),
+    path('provider/model_params_form', views.Provide.ModelParamsForm.as_view(),
+         name="provider/model_params_form"),
     path('provider/model_form', views.Provide.ModelForm.as_view(), name="provider/model_form"),
     path('model', views.Model.as_view(), name='model'),
diff --git a/apps/setting/views/model.py b/apps/setting/views/model.py
index b5abf9196..965f68b1b 100644
--- a/apps/setting/views/model.py
+++ b/apps/setting/views/model.py
@@ -16,7 +16,7 @@ from common.constants.permission_constants import PermissionConstants
 from common.response import result
 from common.util.common import query_params_to_single_dict
 from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
-from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer
+from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer, get_default_model_params_setting
 from setting.swagger_api.provide_api import ProvideApi, ModelCreateApi, ModelQueryApi, ModelEditApi
 
 
@@ -207,6 +207,24 @@ class Provide(APIView):
                     ModelProvideConstants[provider].value.get_model_list(
                         model_type))
 
+    class ModelParamsForm(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=['GET'], detail=False)
+        @swagger_auto_schema(operation_summary="获取模型默认参数",
+                             operation_id="获取模型创建表单",
+                             manual_parameters=ProvideApi.ModelList.get_request_params_api(),
+                             responses=result.get_api_array_response(ProvideApi.ModelList.get_response_body_api())
+            , tags=["模型"]
+                             )
+        @has_permissions(PermissionConstants.MODEL_READ)
+        def get(self, request: Request):
+            provider = request.query_params.get('provider')
+            model_type = request.query_params.get('model_type')
+            model_name = request.query_params.get('model_name')
+
+            return result.success(get_default_model_params_setting(provider, model_type, model_name))
+
     class ModelForm(APIView):
         authentication_classes = [TokenAuth]
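The new Provide.ModelParamsForm view exposes the default parameter form over GET so the front end (next diff) can fetch it when a base model is selected. A hedged example call; the host, URL prefix, header name and query values below are placeholders rather than facts from this patch:

    import requests

    resp = requests.get(
        'http://localhost:8080/api/provider/model_params_form',  # URL prefix is an assumption
        params={'provider': '<provider key>', 'model_type': '<type>', 'model_name': '<base model>'},
        headers={'AUTHORIZATION': '<api token>'},  # the view is TokenAuth-protected
    )
    print(resp.json())  # result.success(...) wrapping get_default_model_params_setting(...)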
diff --git a/ui/src/api/model.ts b/ui/src/api/model.ts
index 6519f1bc0..5129dd055 100644
--- a/ui/src/api/model.ts
+++ b/ui/src/api/model.ts
@@ -98,6 +98,15 @@ const listBaseModel: (
   return get(`${prefix_provider}/model_list`, { provider, model_type }, loading)
 }
+const listBaseModelParamsForm: (
+  provider: string,
+  model_type: string,
+  model_name: string,
+  loading?: Ref
+) => Promise>> = (provider, model_type, model_name, loading) => {
+  return get(`${prefix_provider}/model_params_form`, { provider, model_type, model_name}, loading)
+}
+
 /**
  * 创建模型
  * @param request 请求对象
  */
@@ -187,6 +196,7 @@ export default {
   getModelCreateForm,
   listModelType,
   listBaseModel,
+  listBaseModelParamsForm,
   createModel,
   updateModel,
   deleteModel,
diff --git a/ui/src/views/template/component/CreateModelDialog.vue b/ui/src/views/template/component/CreateModelDialog.vue
index 8f4bf727d..de8537e60 100644
--- a/ui/src/views/template/component/CreateModelDialog.vue
+++ b/ui/src/views/template/component/CreateModelDialog.vue
@@ -22,136 +22,192 @@
             >
-
-
-