feat: Support image generate model

This commit is contained in:
CaptainB 2024-12-09 11:00:51 +08:00 committed by 刘瑞斌
parent 9e859be5ff
commit add99fabc6
32 changed files with 1265 additions and 16 deletions

View File

@ -18,6 +18,7 @@ from .reranker_node import *
from .document_extract_node import *
from .image_understand_step_node import *
from .image_generate_step_node import *
from .search_dataset_node import *
from .start_node import *
@ -25,7 +26,7 @@ from .start_node import *
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
BaseDocumentExtractNode,
BaseImageUnderstandNode, BaseFormNode]
BaseImageUnderstandNode, BaseImageGenerateNode, BaseFormNode]
def get_node(node_type):

View File

@ -0,0 +1,3 @@
# coding=utf-8
from .impl import *

View File

@ -0,0 +1,37 @@
# coding=utf-8
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
class ImageGenerateNodeSerializer(serializers.Serializer):
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词(正向)"))
negative_prompt = serializers.CharField(required=False, default='', error_messages=ErrMessage.char("提示词(负向)"))
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("对话存储类型"))
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
class IImageGenerateNode(INode):
type = 'image-generate-node'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return ImageGenerateNodeSerializer
def _run(self):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
chat_record_id,
**kwargs) -> NodeResult:
pass

View File

@ -0,0 +1,3 @@
# coding=utf-8
from .base_image_generate_node import BaseImageGenerateNode

View File

@ -0,0 +1,101 @@
# coding=utf-8
from functools import reduce
from typing import List
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from application.flow.i_step_node import NodeResult
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
from setting.models_provider.tools import get_model_instance_by_model_user_id
class BaseImageGenerateNode(IImageGenerateNode):
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.context['question'] = details.get('question')
self.answer_text = details.get('answer')
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
chat_record_id,
**kwargs) -> NodeResult:
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)
self.context['question'] = question
message_list = self.generate_message_list(question, history_message)
self.context['message_list'] = message_list
self.context['dialogue_type'] = dialogue_type
print(message_list)
print(negative_prompt)
image_urls = tti_model.generate_image(question, negative_prompt)
self.context['image_list'] = image_urls
answer = '\n'.join([f"![Image]({path})" for path in image_urls])
return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list,
'image': [{'file_id': path.split('/')[-1], 'file_url': path} for path in file_urls],
'history_message': history_message, 'question': question}, {})
def generate_history_ai_message(self, chat_record):
for val in chat_record.details.values():
if self.node.id == val['node_id'] and 'image_list' in val:
if val['dialogue_type'] == 'WORKFLOW':
return chat_record.get_ai_message()
return AIMessage(content=val['answer'])
return chat_record.get_ai_message()
def get_history_message(self, history_chat_record, dialogue_number):
start_index = len(history_chat_record) - dialogue_number
history_message = reduce(lambda x, y: [*x, *y], [
[self.generate_history_human_message(history_chat_record[index]),
self.generate_history_ai_message(history_chat_record[index])]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
return history_message
def generate_history_human_message(self, chat_record):
for data in chat_record.details.values():
if self.node.id == data['node_id'] and 'image_list' in data:
image_list = data['image_list']
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
return HumanMessage(content=chat_record.problem_text)
return HumanMessage(content=data['question'])
return HumanMessage(content=chat_record.problem_text)
def generate_prompt_question(self, prompt):
return self.workflow_manage.generate_prompt(prompt)
def generate_message_list(self, question: str, history_message):
return [
*history_message,
question
]
@staticmethod
def reset_message_list(message_list: List[BaseMessage], answer_text):
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
message
in
message_list]
result.append({'role': 'ai', 'content': answer_text})
return result
def get_details(self, index: int, **kwargs):
return {
'name': self.node.properties.get('stepName'),
"index": index,
'run_time': self.context.get('run_time'),
'history_message': [{'content': message.content, 'role': message.type} for message in
(self.context.get('history_message') if self.context.get(
'history_message') is not None else [])],
'question': self.context.get('question'),
'answer': self.context.get('answer'),
'type': self.node.type,
'message_tokens': self.context.get('message_tokens'),
'answer_tokens': self.context.get('answer_tokens'),
'status': self.status,
'err_message': self.err_message,
'image_list': self.context.get('image_list'),
'dialogue_type': self.context.get('dialogue_type')
}

View File

@ -54,7 +54,7 @@ class Node:
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
'image-understand-node']
'image-understand-node', 'image-generate-node']
class Flow:

View File

@ -8,9 +8,12 @@
"""
import hashlib
import importlib
import mimetypes
import io
from functools import reduce
from typing import Dict, List
from django.core.files.uploadedfile import InMemoryUploadedFile
from django.db.models import QuerySet
from ..exception.app_exception import AppApiException
@ -111,3 +114,25 @@ def bulk_create_in_batches(model, data, batch_size=1000):
batch = data[i:i + batch_size]
model.objects.bulk_create(batch)
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
content_type, _ = mimetypes.guess_type(file_name)
if content_type is None:
# 如果未能识别,设置为默认的二进制文件类型
content_type = "application/octet-stream"
# 创建一个内存中的字节流对象
file_stream = io.BytesIO(file_bytes)
# 获取文件大小
file_size = len(file_bytes)
# 创建 InMemoryUploadedFile 对象
uploaded_file = InMemoryUploadedFile(
file=file_stream,
field_name=None,
name=file_name,
content_type=content_type,
size=file_size,
charset=None,
)
return uploaded_file

View File

@ -150,6 +150,7 @@ class ModelTypeConst(Enum):
STT = {'code': 'STT', 'message': '语音识别'}
TTS = {'code': 'TTS', 'message': '语音合成'}
IMAGE = {'code': 'IMAGE', 'message': '图片理解'}
TTI = {'code': 'TTI', 'message': '图片生成'}
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}

View File

@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod
from pydantic import BaseModel
class BaseTextToImage(BaseModel):
@abstractmethod
def check_auth(self):
pass
@abstractmethod
def generate_image(self, prompt: str, negative_prompt: str = None):
pass

View File

@ -0,0 +1,47 @@
# coding=utf-8
import base64
import os
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['api_base', 'api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
res = model.check_auth()
print(res)
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
pass

View File

@ -0,0 +1,67 @@
from typing import Dict
import requests
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from openai import OpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from common.util.common import bytes_to_uploaded_file
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tti import BaseTextToImage
def custom_get_token_ids(text: str):
tokenizer = TokenizerManage.get_tokenizer()
return tokenizer.encode(text)
class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
api_base: str
api_key: str
model: str
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.model = kwargs.get('model')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'params': {}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
return OpenAITextToImage(
model=model_name,
api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
**optional_params,
)
def check_auth(self):
chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
response_list = chat.models.with_raw_response.list()
# self.generate_image('生成一个小猫图片')
def generate_image(self, prompt: str, negative_prompt: str = None):
chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
res = chat.images.generate(model=self.model, prompt=prompt, **self.params)
file_urls = []
for content in res.data:
url = content.url
print(url)
file_name = 'generated_image.png'
file = bytes_to_uploaded_file(requests.get(url).content, file_name)
meta = {'debug': True}
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
file_urls.append(file_url)
return file_urls

View File

@ -15,11 +15,13 @@ from setting.models_provider.impl.openai_model_provider.credential.embedding imp
from setting.models_provider.impl.openai_model_provider.credential.image import OpenAIImageModelCredential
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
from setting.models_provider.impl.openai_model_provider.credential.tti import OpenAITextToImageModelCredential
from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
from setting.models_provider.impl.openai_model_provider.model.image import OpenAIImage
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
from setting.models_provider.impl.openai_model_provider.model.tti import OpenAITextToImage
from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech
from smartdoc.conf import PROJECT_DIR
@ -27,6 +29,7 @@ openai_llm_model_credential = OpenAILLMModelCredential()
openai_stt_model_credential = OpenAISTTModelCredential()
openai_tts_model_credential = OpenAITTSModelCredential()
openai_image_model_credential = OpenAIImageModelCredential()
openai_tti_model_credential = OpenAITextToImageModelCredential()
model_info_list = [
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, OpenAIChatModel
@ -37,8 +40,8 @@ model_info_list = [
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel),
ModelInfo('gpt-4o-mini', '最新的gpt-4o-mini比gpt-4o更便宜、更快随OpenAI调整而更新',
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel),
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel),
ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential,
OpenAIChatModel),
@ -100,11 +103,27 @@ model_info_image_list = [
OpenAIImage),
]
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, OpenAIChatModel
)).append_model_info_list(model_info_embedding_list).append_default_model_info(
model_info_embedding_list[0]).append_model_info_list(model_info_image_list).build()
model_info_tti_list = [
ModelInfo('dall-e-2', '',
ModelTypeConst.TTI, openai_tti_model_credential,
OpenAITextToImage),
ModelInfo('dall-e-3', '',
ModelTypeConst.TTI, openai_tti_model_credential,
OpenAITextToImage),
]
model_info_manage = (
ModelInfoManage.builder()
.append_model_info_list(model_info_list)
.append_default_model_info(ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, OpenAIChatModel
))
.append_model_info_list(model_info_embedding_list)
.append_default_model_info(model_info_embedding_list[0])
.append_model_info_list(model_info_image_list)
.append_model_info_list(model_info_tti_list)
.build()
)
class OpenAIModelProvider(IModelProvider):

View File

@ -0,0 +1,70 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file llm.py
@date2024/7/11 18:41
@desc:
"""
import base64
import os
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class QwenModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
required=True, default_value=1.0,
_min=0.1,
_max=1.9,
_step=0.01,
precision=2)
max_tokens = forms.SliderField(
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
required=True, default_value=800,
_min=1,
_max=100000,
_step=1,
precision=0)
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
res = model.check_auth()
print(res)
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_key = forms.PasswordInputField('API Key', required=True)
def get_model_params_setting_form(self, model_name):
return QwenModelParams()

View File

@ -0,0 +1,64 @@
# coding=utf-8
from http import HTTPStatus
from pathlib import PurePosixPath
from typing import Dict
from urllib.parse import unquote, urlparse
import requests
from dashscope import ImageSynthesis
from langchain_community.chat_models import ChatTongyi
from langchain_core.messages import HumanMessage
from common.util.common import bytes_to_uploaded_file
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tti import BaseTextToImage
class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
api_key: str
model_name: str
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.model_name = kwargs.get('model_name')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'params': {}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
chat_tong_yi = QwenTextToImageModel(
model_name=model_name,
api_key=model_credential.get('api_key'),
**optional_params,
)
return chat_tong_yi
def check_auth(self):
chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max')
chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])])
def generate_image(self, prompt: str, negative_prompt: str = None):
# api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
rsp = ImageSynthesis.call(api_key=self.api_key,
model=self.model_name,
prompt=prompt,
negative_prompt=negative_prompt,
**self.params)
file_urls = []
if rsp.status_code == HTTPStatus.OK:
for result in rsp.output.results:
file_name = PurePosixPath(unquote(urlparse(result.url).path)).parts[-1]
file = bytes_to_uploaded_file(requests.get(result.url).content, file_name)
meta = {'debug': True}
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
file_urls.append(file_url)
else:
print('sync_call Failed, status_code: %s, code: %s, message: %s' %
(rsp.status_code, rsp.code, rsp.message))
return file_urls

View File

@ -13,13 +13,16 @@ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelT
ModelInfoManage
from setting.models_provider.impl.qwen_model_provider.credential.image import QwenVLModelCredential
from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.qwen_model_provider.credential.tti import QwenTextToImageModelCredential
from setting.models_provider.impl.qwen_model_provider.model.image import QwenVLChatModel
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
from setting.models_provider.impl.qwen_model_provider.model.tti import QwenTextToImageModel
from smartdoc.conf import PROJECT_DIR
qwen_model_credential = OpenAILLMModelCredential()
qwenvl_model_credential = QwenVLModelCredential()
qwentti_model_credential = QwenTextToImageModelCredential()
module_info_list = [
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
@ -31,13 +34,21 @@ module_info_vl_list = [
ModelInfo('qwen-vl-max-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
ModelInfo('qwen-vl-plus-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
]
module_info_tti_list = [
ModelInfo('wanx-v1',
'通义万相-文本生成图像大模型支持中英文双语输入支持输入参考图片进行参考内容或者参考风格迁移重点风格包括但不限于水彩、油画、中国画、素描、扁平插画、二次元、3D卡通。',
ModelTypeConst.TTI, qwentti_model_credential, QwenTextToImageModel),
]
model_info_manage = (ModelInfoManage.builder()
.append_model_info_list(module_info_list)
.append_default_model_info(
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel))
.append_model_info_list(module_info_vl_list)
.build())
model_info_manage = (
ModelInfoManage.builder()
.append_model_info_list(module_info_list)
.append_default_model_info(
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel))
.append_model_info_list(module_info_vl_list)
.append_model_info_list(module_info_tti_list)
.build()
)
class QwenModelProvider(IModelProvider):

View File

@ -0,0 +1,108 @@
# coding=utf-8
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class TencentTTIModelParams(BaseForm):
Style = forms.SingleSelect(
TooltipLabel('绘画风格', '不传默认使用201日系动漫风格'),
required=True,
default_value='201',
option_list=[
{'value': '000', 'label': '不限定风格'},
{'value': '101', 'label': '水墨画'},
{'value': '102', 'label': '概念艺术'},
{'value': '103', 'label': '油画1'},
{'value': '118', 'label': '油画2梵高'},
{'value': '104', 'label': '水彩画'},
{'value': '105', 'label': '像素画'},
{'value': '106', 'label': '厚涂风格'},
{'value': '107', 'label': '插图'},
{'value': '108', 'label': '剪纸风格'},
{'value': '109', 'label': '印象派1莫奈'},
{'value': '119', 'label': '印象派2'},
{'value': '110', 'label': '2.5D'},
{'value': '111', 'label': '古典肖像画'},
{'value': '112', 'label': '黑白素描画'},
{'value': '113', 'label': '赛博朋克'},
{'value': '114', 'label': '科幻风格'},
{'value': '115', 'label': '暗黑风格'},
{'value': '116', 'label': '3D'},
{'value': '117', 'label': '蒸汽波'},
{'value': '201', 'label': '日系动漫'},
{'value': '202', 'label': '怪兽风格'},
{'value': '203', 'label': '唯美古风'},
{'value': '204', 'label': '复古动漫'},
{'value': '301', 'label': '游戏卡通手绘'},
{'value': '401', 'label': '通用写实风格'},
],
value_field='value',
text_field='label'
)
Resolution = forms.SingleSelect(
TooltipLabel('生成图分辨率', '不传默认使用768:768。'),
required=True,
default_value='768:768',
option_list=[
{'value': '768:768', 'label': '768:7681:1'},
{'value': '768:1024', 'label': '768:10243:4'},
{'value': '1024:768', 'label': '1024:7684:3'},
{'value': '1024:1024', 'label': '1024:10241:1'},
{'value': '720:1280', 'label': '720:12809:16'},
{'value': '1280:720', 'label': '1280:72016:9'},
{'value': '768:1280', 'label': '768:12803:5'},
{'value': '1280:768', 'label': '1280:7685:3'},
{'value': '1080:1920', 'label': '1080:19209:16'},
{'value': '1920:1080', 'label': '1920:108016:9'},
],
value_field='value',
text_field='label'
)
class TencentTTIModelCredential(BaseForm, BaseModelCredential):
REQUIRED_FIELDS = ['hunyuan_secret_id', 'hunyuan_secret_key']
@classmethod
def _validate_model_type(cls, model_type, provider, raise_exception=False):
if not any(mt['value'] == model_type for mt in provider.get_model_type_list()):
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
return False
return True
@classmethod
def _validate_credential_fields(cls, model_credential, raise_exception=False):
missing_keys = [key for key in cls.REQUIRED_FIELDS if key not in model_credential]
if missing_keys:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{", ".join(missing_keys)} 字段为必填字段')
return False
return True
def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False):
if not (self._validate_model_type(model_type, provider, raise_exception) and
self._validate_credential_fields(model_credential, raise_exception)):
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.check_auth()
except Exception as e:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
return False
return True
def encryption_dict(self, model):
return {**model, 'hunyuan_secret_key': super().encryption(model.get('hunyuan_secret_key', ''))}
hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)
def get_model_params_setting_form(self, model_name):
return TencentTTIModelParams()

View File

@ -0,0 +1,98 @@
# coding=utf-8
import json
from typing import Dict
import requests
from tencentcloud.common import credential
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
from common.util.common import bytes_to_uploaded_file
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tti import BaseTextToImage
from setting.models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan
class TencentTextToImageModel(MaxKBBaseModel, BaseTextToImage):
hunyuan_secret_id: str
hunyuan_secret_key: str
model: str
params: dict
@staticmethod
def is_cache_model():
return False
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.hunyuan_secret_id = kwargs.get('hunyuan_secret_id')
self.hunyuan_secret_key = kwargs.get('hunyuan_secret_key')
self.model = kwargs.get('model_name')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
**model_kwargs) -> 'TencentTextToImageModel':
optional_params = {'params': {'Style': '201', 'Resolution': '768:768'}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
return TencentTextToImageModel(
model=model_name,
hunyuan_secret_id=model_credential.get('hunyuan_secret_id'),
hunyuan_secret_key=model_credential.get('hunyuan_secret_key'),
**optional_params
)
def check_auth(self):
chat = ChatHunyuan(hunyuan_app_id='111111',
hunyuan_secret_id=self.hunyuan_secret_id,
hunyuan_secret_key=self.hunyuan_secret_key,
model="hunyuan-standard")
res = chat.invoke('你好')
# print(res)
def generate_image(self, prompt: str, negative_prompt: str = None):
try:
# 实例化一个认证对象,入参需要传入腾讯云账户 SecretId 和 SecretKey此处还需注意密钥对的保密
# 代码泄露可能会导致 SecretId 和 SecretKey 泄露并威胁账号下所有资源的安全性。以下代码示例仅供参考建议采用更安全的方式来使用密钥请参见https://cloud.tencent.com/document/product/1278/85305
# 密钥可前往官网控制台 https://console.cloud.tencent.com/cam/capi 进行获取
cred = credential.Credential(self.hunyuan_secret_id, self.hunyuan_secret_key)
# 实例化一个http选项可选的没有特殊需求可以跳过
httpProfile = HttpProfile()
httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
# 实例化一个client选项可选的没有特殊需求可以跳过
clientProfile = ClientProfile()
clientProfile.httpProfile = httpProfile
# 实例化要请求产品的client对象,clientProfile是可选的
client = hunyuan_client.HunyuanClient(cred, "ap-guangzhou", clientProfile)
# 实例化一个请求对象,每个接口都会对应一个request对象
req = models.TextToImageLiteRequest()
params = {
"Prompt": prompt,
"NegativePrompt": negative_prompt,
"RspImgType": "url",
**self.params
}
req.from_json_string(json.dumps(params))
# 返回的resp是一个TextToImageLiteResponse的实例与请求对象对应
resp = client.TextToImageLite(req)
# 输出json格式的字符串回包
print(resp.to_json_string())
file_urls = []
file_name = 'generated_image.png'
file = bytes_to_uploaded_file(requests.get(resp.ResultImage).content, file_name)
meta = {'debug': True}
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
file_urls.append(file_url)
return file_urls
except TencentCloudSDKException as err:
print(err)

View File

@ -9,9 +9,11 @@ from setting.models_provider.base_model_provider import (
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
from setting.models_provider.impl.tencent_model_provider.credential.image import TencentVisionModelCredential
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
from setting.models_provider.impl.tencent_model_provider.credential.tti import TencentTTIModelCredential
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
from setting.models_provider.impl.tencent_model_provider.model.image import TencentVision
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
from setting.models_provider.impl.tencent_model_provider.model.tti import TencentTextToImageModel
from smartdoc.conf import PROJECT_DIR
@ -87,11 +89,19 @@ def _initialize_model_info():
TencentVisionModelCredential,
TencentVision)]
model_info_tti_list = [_create_model_info(
'hunyuan-dit',
'混元生图模型',
ModelTypeConst.TTI,
TencentTTIModelCredential,
TencentTextToImageModel)]
model_info_manage = ModelInfoManage.builder() \
.append_model_info_list(model_info_list) \
.append_model_info_list(model_info_embedding_list) \
.append_model_info_list(model_info_vision_list) \
.append_model_info_list(model_info_tti_list) \
.append_default_model_info(model_info_list[0]) \
.build()

View File

@ -0,0 +1,44 @@
# coding=utf-8
from typing import Dict
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
res = model.check_auth()
print(res)
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
pass

View File

@ -0,0 +1,73 @@
from typing import Dict
import requests
from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import HumanMessage
from zhipuai import ZhipuAI
from common.config.tokenizer_manage_config import TokenizerManage
from common.util.common import bytes_to_uploaded_file
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tti import BaseTextToImage
def custom_get_token_ids(text: str):
tokenizer = TokenizerManage.get_tokenizer()
return tokenizer.encode(text)
class ZhiPuTextToImage(MaxKBBaseModel, BaseTextToImage):
api_key: str
model: str
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.model = kwargs.get('model')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'params': {}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
return ZhiPuTextToImage(
model=model_name,
api_key=model_credential.get('api_key'),
**optional_params,
)
def check_auth(self):
chat = ChatZhipuAI(
zhipuai_api_key=self.api_key,
model_name=self.model,
)
chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])])
# self.generate_image('生成一个小猫图片')
def generate_image(self, prompt: str, negative_prompt: str = None):
# chat = ChatZhipuAI(
# zhipuai_api_key=self.api_key,
# model_name=self.model,
# )
chat = ZhipuAI(api_key=self.api_key)
response = chat.images.generations(
model=self.model, # 填写需要调用的模型编码
prompt=prompt, # 填写需要生成图片的文本
**self.params # 填写额外参数
)
file_urls = []
for content in response.data:
url = content['url']
print(url)
file_name = url.split('/')[-1]
file = bytes_to_uploaded_file(requests.get(url).content, file_name)
meta = {'debug': True}
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
file_urls.append(file_url)
return file_urls

View File

@ -13,12 +13,15 @@ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelT
ModelInfoManage
from setting.models_provider.impl.zhipu_model_provider.credential.image import ZhiPuImageModelCredential
from setting.models_provider.impl.zhipu_model_provider.credential.llm import ZhiPuLLMModelCredential
from setting.models_provider.impl.zhipu_model_provider.credential.tti import ZhiPuTextToImageModelCredential
from setting.models_provider.impl.zhipu_model_provider.model.image import ZhiPuImage
from setting.models_provider.impl.zhipu_model_provider.model.llm import ZhipuChatModel
from setting.models_provider.impl.zhipu_model_provider.model.tti import ZhiPuTextToImage
from smartdoc.conf import PROJECT_DIR
qwen_model_credential = ZhiPuLLMModelCredential()
zhipu_image_model_credential = ZhiPuImageModelCredential()
zhipu_tti_model_credential = ZhiPuTextToImageModelCredential()
model_info_list = [
ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
@ -38,11 +41,21 @@ model_info_image_list = [
ZhiPuImage),
]
model_info_tti_list = [
ModelInfo('cogview-3', '根据用户文字描述快速、精准生成图像。分辨率支持1024x1024',
ModelTypeConst.TTI, zhipu_tti_model_credential,
ZhiPuTextToImage),
ModelInfo('cogview-3-plus', '根据用户文字描述生成高质量图像,支持多图片尺寸',
ModelTypeConst.TTI, zhipu_tti_model_credential,
ZhiPuTextToImage),
]
model_info_manage = (
ModelInfoManage.builder()
.append_model_info_list(model_info_list)
.append_default_model_info(ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel))
.append_model_info_list(model_info_image_list)
.append_model_info_list(model_info_tti_list)
.build()
)

View File

@ -293,6 +293,13 @@ const getApplicationImageModel: (
return get(`${prefix}/${application_id}/model`, { model_type: 'IMAGE' }, loading)
}
const getApplicationTTIModel: (
application_id: string,
loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, loading) => {
return get(`${prefix}/${application_id}/model`, { model_type: 'TTI' }, loading)
}
/**
*
@ -523,6 +530,7 @@ export default {
getApplicationSTTModel,
getApplicationTTSModel,
getApplicationImageModel,
getApplicationTTIModel,
postSpeechToText,
postTextToSpeech,
getPlatformStatus,

View File

@ -32,6 +32,7 @@
item.type === WorkflowType.Question ||
item.type === WorkflowType.AiChat ||
item.type === WorkflowType.ImageUnderstandNode ||
item.type === WorkflowType.ImageGenerateNode ||
item.type === WorkflowType.Application
"
>{{ item?.message_tokens + item?.answer_tokens }} tokens</span
@ -444,6 +445,65 @@
</div>
</div>
</template>
<!-- 图片生成 -->
<template v-if="item.type == WorkflowType.ImageGenerateNode">
<div
class="card-never border-r-4 mt-8"
v-if="item.type !== WorkflowType.Application"
>
<h5 class="p-8-12">历史记录</h5>
<div class="p-8-12 border-t-dashed lighter">
<template v-if="item.history_message?.length > 0">
<p
class="mt-4 mb-4"
v-for="(history, historyIndex) in item.history_message"
:key="historyIndex"
>
<span class="color-secondary mr-4">{{ history.role }}:</span>
<span v-if="Array.isArray(history.content)">
<template v-for="(h, i) in history.content" :key="i">
<el-image
v-if="h.type === 'image_url'"
:src="h.image_url.url"
alt=""
fit="cover"
style="width: 40px; height: 40px; display: inline-block"
class="border-r-4 mr-8"
/>
<span v-else>{{ h.text }}<br /></span>
</template>
</span>
<span v-else>{{ history.content }}</span>
</p>
</template>
<template v-else> - </template>
</div>
</div>
<div class="card-never border-r-4 mt-8">
<h5 class="p-8-12">本次对话</h5>
<div class="p-8-12 border-t-dashed lighter pre-wrap">
{{ item.question || '-' }}
</div>
</div>
<div class="card-never border-r-4 mt-8">
<h5 class="p-8-12">
{{ item.type == WorkflowType.Application ? '参数输出' : 'AI 回答' }}
</h5>
<div class="p-8-12 border-t-dashed lighter">
<MdPreview
v-if="item.answer"
ref="editorRef"
editorId="preview-only"
:modelValue="item.answer"
style="background: none"
/>
<template v-else> - </template>
</div>
</div>
</template>
</template>
<template v-else>
<div class="card-never border-r-4">

View File

@ -13,5 +13,6 @@ export enum modelType {
STT = '语音识别',
TTS = '语音合成',
IMAGE = '图片理解',
TTI = '图片生成',
RERANKER = '重排模型'
}

View File

@ -12,5 +12,6 @@ export enum WorkflowType {
Application = 'application-node',
DocumentExtractNode = 'document-extract-node',
ImageUnderstandNode = 'image-understand-node',
ImageGenerateNode = 'image-generate-node',
FormNode = 'form-node'
}

View File

@ -63,6 +63,7 @@ const modelTypeOptions = ref([
{ text: '语音识别', value: 'STT' },
{ text: '语音合成', value: 'TTS' },
{ text: '图片理解', value: 'IMAGE' },
{ text: '图片生成', value: 'TTI' },
])
const open = () => {

View File

@ -133,6 +133,7 @@
<el-option label="语音识别" value="STT" />
<el-option label="语音合成" value="TTS" />
<el-option label="图片理解" value="IMAGE" />
<el-option label="图片生成" value="TTI" />
</el-select>
</div>
</div>

View File

@ -227,6 +227,28 @@ export const imageUnderstandNode = {
}
}
}
export const imageGenerateNode = {
type: WorkflowType.ImageGenerateNode,
text: '根据提供的文本内容生成图片',
label: '图片生成',
height: 252,
properties: {
stepName: '图片生成',
config: {
fields: [
{
label: 'AI 回答内容',
value: 'answer'
},
{
label: '图片',
value: 'image'
}
]
}
}
}
export const menuNodes = [
aiChatNode,
searchDatasetNode,
@ -236,6 +258,7 @@ export const menuNodes = [
rerankerNode,
documentExtractNode,
imageUnderstandNode,
imageGenerateNode,
formNode
]
@ -326,7 +349,8 @@ export const nodeDict: any = {
[WorkflowType.FormNode]: formNode,
[WorkflowType.Application]: applicationNode,
[WorkflowType.DocumentExtractNode]: documentExtractNode,
[WorkflowType.ImageUnderstandNode]: imageUnderstandNode
[WorkflowType.ImageUnderstandNode]: imageUnderstandNode,
[WorkflowType.ImageGenerateNode]: imageGenerateNode
}
export function isWorkFlow(type: string | undefined) {
return type === 'WORK_FLOW'

View File

@ -6,6 +6,7 @@ const end_nodes: Array<string> = [
WorkflowType.FunctionLib,
WorkflowType.FunctionLibCustom,
WorkflowType.ImageUnderstandNode,
WorkflowType.ImageGenerateNode,
WorkflowType.Application
]
export class WorkFlowInstance {

View File

@ -0,0 +1,6 @@
<template>
<AppAvatar shape="square" style="background: #14C0FF;">
<img src="@/assets/icon_image.svg" style="width: 65%" alt="" />
</AppAvatar>
</template>
<script setup lang="ts"></script>

View File

@ -0,0 +1,14 @@
import ImageGenerateNodeVue from './index.vue'
import { AppNode, AppNodeModel } from '@/workflow/common/app-node'
class RerankerNode extends AppNode {
constructor(props: any) {
super(props, ImageGenerateNodeVue)
}
}
export default {
type: 'image-generate-node',
model: AppNodeModel,
view: RerankerNode
}

View File

@ -0,0 +1,323 @@
<template>
<NodeContainer :node-model="nodeModel">
<h5 class="title-decoration-1 mb-8">节点设置</h5>
<el-card shadow="never" class="card-never">
<el-form
@submit.prevent
:model="form_data"
label-position="top"
require-asterisk-position="right"
label-width="auto"
ref="aiChatNodeFormRef"
hide-required-asterisk
>
<el-form-item
label="图片生成模型"
prop="model_id"
:rules="{
required: true,
message: '请选择图片生成模型',
trigger: 'change'
}"
>
<template #label>
<div class="flex-between w-full">
<div>
<span>图片生成模型<span class="danger">*</span></span>
</div>
<el-button
:disabled="!form_data.model_id"
type="primary"
link
@click="openAIParamSettingDialog(form_data.model_id)"
@refreshForm="refreshParam"
>
{{ $t('views.application.applicationForm.form.paramSetting') }}
</el-button>
</div>
</template>
<el-select
@change="model_change"
@wheel="wheel"
:teleported="false"
v-model="form_data.model_id"
placeholder="请选择图片生成模型"
class="w-full"
popper-class="select-model"
:clearable="true"
>
<el-option-group
v-for="(value, label) in modelOptions"
:key="value"
:label="relatedObject(providerOptions, label, 'provider')?.name"
>
<el-option
v-for="item in value.filter((v: any) => v.status === 'SUCCESS')"
:key="item.id"
:label="item.name"
:value="item.id"
class="flex-between"
>
<div class="flex align-center">
<span
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
class="model-icon mr-8"
></span>
<span>{{ item.name }}</span>
<el-tag v-if="item.permission_type === 'PUBLIC'" type="info" class="info-tag ml-8"
>公用
</el-tag>
</div>
<el-icon class="check-icon" v-if="item.id === form_data.model_id">
<Check />
</el-icon>
</el-option>
<!-- 不可用 -->
<el-option
v-for="item in value.filter((v: any) => v.status !== 'SUCCESS')"
:key="item.id"
:label="item.name"
:value="item.id"
class="flex-between"
disabled
>
<div class="flex">
<span
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
class="model-icon mr-8"
></span>
<span>{{ item.name }}</span>
<span class="danger">不可用</span>
</div>
<el-icon class="check-icon" v-if="item.id === form_data.model_id">
<Check />
</el-icon>
</el-option>
</el-option-group>
</el-select>
</el-form-item>
<el-form-item
label="提示词(正向)"
prop="prompt"
:rules="{
required: true,
message: '请输入提示词',
trigger: 'blur'
}"
>
<template #label>
<div class="flex align-center">
<div class="mr-4">
<span>提示词(正向)<span class="danger">*</span></span>
</div>
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
<template #content
>正向提示词用来描述生成图像中期望包含的元素和视觉特点
</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<MdEditorMagnify
@wheel="wheel"
title="提示词(正向)"
v-model="form_data.prompt"
style="height: 150px"
@submitDialog="submitDialog"
/>
</el-form-item>
<el-form-item
label="提示词(负向)"
prop="prompt"
:rules="{
required: false,
message: '请输入提示词',
trigger: 'blur'
}"
>
<template #label>
<div class="flex align-center">
<div class="mr-4">
<span>提示词(负向)</span>
</div>
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
<template #content
>反向提示词用来描述不希望在画面中看到的内容可以对画面进行限制
</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<MdEditorMagnify
@wheel="wheel"
title="提示词(负向)"
v-model="form_data.negative_prompt"
style="height: 150px"
@submitDialog="submitDialog"
/>
</el-form-item>
<el-form-item>
<template #label>
<div class="flex-between">
<div>历史聊天记录</div>
<el-select v-model="form_data.dialogue_type" type="small" style="width: 100px">
<el-option label="节点" value="NODE" />
<el-option label="工作流" value="WORKFLOW" />
</el-select>
</div>
</template>
<el-input-number
v-model="form_data.dialogue_number"
:min="0"
:value-on-clear="0"
controls-position="right"
class="w-full"
:step="1"
:step-strictly="true"
/>
</el-form-item>
<el-form-item label="返回内容" @click.prevent>
<template #label>
<div class="flex align-center">
<div class="mr-4">
<span>返回内容<span class="danger">*</span></span>
</div>
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
<template #content>
关闭后该节点的内容则不输出给用户
如果你想让用户看到该节点的输出内容请打开开关
</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<el-switch size="small" v-model="form_data.is_result" />
</el-form-item>
</el-form>
</el-card>
<AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshParam" />
</NodeContainer>
</template>
<script setup lang="ts">
import NodeContainer from '@/workflow/common/NodeContainer.vue'
import { computed, onMounted, ref } from 'vue'
import { groupBy, set } from 'lodash'
import { relatedObject } from '@/utils/utils'
import type { Provider } from '@/api/type/model'
import applicationApi from '@/api/application'
import { app } from '@/main'
import useStore from '@/stores'
import NodeCascader from '@/workflow/common/NodeCascader.vue'
import type { FormInstance } from 'element-plus'
import AIModeParamSettingDialog from '@/views/application/component/AIModeParamSettingDialog.vue'
const { model } = useStore()
const {
params: { id }
} = app.config.globalProperties.$route as any
const props = defineProps<{ nodeModel: any }>()
const modelOptions = ref<any>(null)
const providerOptions = ref<Array<Provider>>([])
const AIModeParamSettingDialogRef = ref<InstanceType<typeof AIModeParamSettingDialog>>()
const aiChatNodeFormRef = ref<FormInstance>()
const validate = () => {
return aiChatNodeFormRef.value?.validate().catch((err) => {
return Promise.reject({ node: props.nodeModel, errMessage: err })
})
}
const wheel = (e: any) => {
if (e.ctrlKey === true) {
e.preventDefault()
return true
} else {
e.stopPropagation()
return true
}
}
const defaultPrompt = `{{开始.question}}`
const form = {
model_id: '',
system: '',
prompt: defaultPrompt,
negative_prompt: '',
dialogue_number: 0,
dialogue_type: 'NODE',
is_result: true,
temperature: null,
max_tokens: null,
image_list: ['start-node', 'image']
}
const form_data = computed({
get: () => {
if (props.nodeModel.properties.node_data) {
return props.nodeModel.properties.node_data
} else {
set(props.nodeModel.properties, 'node_data', form)
}
return props.nodeModel.properties.node_data
},
set: (value) => {
set(props.nodeModel.properties, 'node_data', value)
}
})
function getModel() {
if (id) {
applicationApi.getApplicationTTIModel(id).then((res: any) => {
modelOptions.value = groupBy(res?.data, 'provider')
})
} else {
model.asyncGetModel().then((res: any) => {
modelOptions.value = groupBy(res?.data, 'provider')
})
}
}
function getProvider() {
model.asyncGetProvider().then((res: any) => {
providerOptions.value = res?.data
})
}
const model_change = () => {
if (form_data.value.model_id) {
AIModeParamSettingDialogRef.value?.reset_default(form_data.value.model_id, id)
} else {
refreshParam({})
}
}
const openAIParamSettingDialog = (modelId: string) => {
if (modelId) {
AIModeParamSettingDialogRef.value?.open(modelId, id, form_data.value.model_params_setting)
}
}
function refreshParam(data: any) {
set(props.nodeModel.properties.node_data, 'model_params_setting', data)
}
function submitDialog(val: string) {
set(props.nodeModel.properties.node_data, 'prompt', val)
}
onMounted(() => {
getModel()
getProvider()
set(props.nodeModel, 'validate', validate)
})
</script>
<style scoped lang="scss"></style>