mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 10:12:51 +00:00
feat: Support image generate model
This commit is contained in:
parent
9e859be5ff
commit
add99fabc6
|
|
@ -18,6 +18,7 @@ from .reranker_node import *
|
|||
|
||||
from .document_extract_node import *
|
||||
from .image_understand_step_node import *
|
||||
from .image_generate_step_node import *
|
||||
|
||||
from .search_dataset_node import *
|
||||
from .start_node import *
|
||||
|
|
@ -25,7 +26,7 @@ from .start_node import *
|
|||
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
|
||||
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode,
|
||||
BaseDocumentExtractNode,
|
||||
BaseImageUnderstandNode, BaseFormNode]
|
||||
BaseImageUnderstandNode, BaseImageGenerateNode, BaseFormNode]
|
||||
|
||||
|
||||
def get_node(node_type):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
# coding=utf-8
|
||||
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class ImageGenerateNodeSerializer(serializers.Serializer):
|
||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
|
||||
|
||||
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词(正向)"))
|
||||
|
||||
negative_prompt = serializers.CharField(required=False, default='', error_messages=ErrMessage.char("提示词(负向)"))
|
||||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
||||
|
||||
dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("对话存储类型"))
|
||||
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
|
||||
|
||||
|
||||
class IImageGenerateNode(INode):
|
||||
type = 'image-generate-node'
|
||||
|
||||
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
|
||||
return ImageGenerateNodeSerializer
|
||||
|
||||
def _run(self):
|
||||
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
|
||||
chat_record_id,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
# coding=utf-8
|
||||
|
||||
from .base_image_generate_node import BaseImageGenerateNode
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
# coding=utf-8
|
||||
from functools import reduce
|
||||
from typing import List
|
||||
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
class BaseImageGenerateNode(IImageGenerateNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['answer'] = details.get('answer')
|
||||
self.context['question'] = details.get('question')
|
||||
self.answer_text = details.get('answer')
|
||||
|
||||
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
|
||||
chat_record_id,
|
||||
**kwargs) -> NodeResult:
|
||||
|
||||
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||
self.context['history_message'] = history_message
|
||||
question = self.generate_prompt_question(prompt)
|
||||
self.context['question'] = question
|
||||
message_list = self.generate_message_list(question, history_message)
|
||||
self.context['message_list'] = message_list
|
||||
self.context['dialogue_type'] = dialogue_type
|
||||
print(message_list)
|
||||
print(negative_prompt)
|
||||
image_urls = tti_model.generate_image(question, negative_prompt)
|
||||
self.context['image_list'] = image_urls
|
||||
answer = '\n'.join([f"" for path in image_urls])
|
||||
return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list,
|
||||
'image': [{'file_id': path.split('/')[-1], 'file_url': path} for path in file_urls],
|
||||
'history_message': history_message, 'question': question}, {})
|
||||
|
||||
def generate_history_ai_message(self, chat_record):
|
||||
for val in chat_record.details.values():
|
||||
if self.node.id == val['node_id'] and 'image_list' in val:
|
||||
if val['dialogue_type'] == 'WORKFLOW':
|
||||
return chat_record.get_ai_message()
|
||||
return AIMessage(content=val['answer'])
|
||||
return chat_record.get_ai_message()
|
||||
|
||||
def get_history_message(self, history_chat_record, dialogue_number):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = reduce(lambda x, y: [*x, *y], [
|
||||
[self.generate_history_human_message(history_chat_record[index]),
|
||||
self.generate_history_ai_message(history_chat_record[index])]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
|
||||
return history_message
|
||||
|
||||
def generate_history_human_message(self, chat_record):
|
||||
|
||||
for data in chat_record.details.values():
|
||||
if self.node.id == data['node_id'] and 'image_list' in data:
|
||||
image_list = data['image_list']
|
||||
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
return HumanMessage(content=data['question'])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
||||
def generate_prompt_question(self, prompt):
|
||||
return self.workflow_manage.generate_prompt(prompt)
|
||||
|
||||
def generate_message_list(self, question: str, history_message):
|
||||
return [
|
||||
*history_message,
|
||||
question
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def reset_message_list(message_list: List[BaseMessage], answer_text):
|
||||
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
|
||||
message
|
||||
in
|
||||
message_list]
|
||||
result.append({'role': 'ai', 'content': answer_text})
|
||||
return result
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
'run_time': self.context.get('run_time'),
|
||||
'history_message': [{'content': message.content, 'role': message.type} for message in
|
||||
(self.context.get('history_message') if self.context.get(
|
||||
'history_message') is not None else [])],
|
||||
'question': self.context.get('question'),
|
||||
'answer': self.context.get('answer'),
|
||||
'type': self.node.type,
|
||||
'message_tokens': self.context.get('message_tokens'),
|
||||
'answer_tokens': self.context.get('answer_tokens'),
|
||||
'status': self.status,
|
||||
'err_message': self.err_message,
|
||||
'image_list': self.context.get('image_list'),
|
||||
'dialogue_type': self.context.get('dialogue_type')
|
||||
}
|
||||
|
|
@ -54,7 +54,7 @@ class Node:
|
|||
|
||||
|
||||
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
|
||||
'image-understand-node']
|
||||
'image-understand-node', 'image-generate-node']
|
||||
|
||||
|
||||
class Flow:
|
||||
|
|
|
|||
|
|
@ -8,9 +8,12 @@
|
|||
"""
|
||||
import hashlib
|
||||
import importlib
|
||||
import mimetypes
|
||||
import io
|
||||
from functools import reduce
|
||||
from typing import Dict, List
|
||||
|
||||
from django.core.files.uploadedfile import InMemoryUploadedFile
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from ..exception.app_exception import AppApiException
|
||||
|
|
@ -111,3 +114,25 @@ def bulk_create_in_batches(model, data, batch_size=1000):
|
|||
batch = data[i:i + batch_size]
|
||||
model.objects.bulk_create(batch)
|
||||
|
||||
|
||||
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
|
||||
content_type, _ = mimetypes.guess_type(file_name)
|
||||
if content_type is None:
|
||||
# 如果未能识别,设置为默认的二进制文件类型
|
||||
content_type = "application/octet-stream"
|
||||
# 创建一个内存中的字节流对象
|
||||
file_stream = io.BytesIO(file_bytes)
|
||||
|
||||
# 获取文件大小
|
||||
file_size = len(file_bytes)
|
||||
|
||||
# 创建 InMemoryUploadedFile 对象
|
||||
uploaded_file = InMemoryUploadedFile(
|
||||
file=file_stream,
|
||||
field_name=None,
|
||||
name=file_name,
|
||||
content_type=content_type,
|
||||
size=file_size,
|
||||
charset=None,
|
||||
)
|
||||
return uploaded_file
|
||||
|
|
|
|||
|
|
@ -150,6 +150,7 @@ class ModelTypeConst(Enum):
|
|||
STT = {'code': 'STT', 'message': '语音识别'}
|
||||
TTS = {'code': 'TTS', 'message': '语音合成'}
|
||||
IMAGE = {'code': 'IMAGE', 'message': '图片理解'}
|
||||
TTI = {'code': 'TTI', 'message': '图片生成'}
|
||||
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
# coding=utf-8
|
||||
from abc import abstractmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseTextToImage(BaseModel):
|
||||
@abstractmethod
|
||||
def check_auth(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def generate_image(self, prompt: str, negative_prompt: str = None):
|
||||
pass
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
|
||||
api_base = forms.TextInputField('API 域名', required=True)
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['api_base', 'api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
res = model.check_auth()
|
||||
print(res)
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
pass
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
from typing import Dict
|
||||
|
||||
import requests
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from openai import OpenAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from common.util.common import bytes_to_uploaded_file
|
||||
from dataset.serializers.file_serializers import FileSerializer
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from setting.models_provider.impl.base_tti import BaseTextToImage
|
||||
|
||||
|
||||
def custom_get_token_ids(text: str):
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return tokenizer.encode(text)
|
||||
|
||||
|
||||
class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
|
||||
api_base: str
|
||||
api_key: str
|
||||
model: str
|
||||
params: dict
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.api_key = kwargs.get('api_key')
|
||||
self.api_base = kwargs.get('api_base')
|
||||
self.model = kwargs.get('model')
|
||||
self.params = kwargs.get('params')
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
optional_params = {'params': {}}
|
||||
for key, value in model_kwargs.items():
|
||||
if key not in ['model_id', 'use_local', 'streaming']:
|
||||
optional_params['params'][key] = value
|
||||
return OpenAITextToImage(
|
||||
model=model_name,
|
||||
api_base=model_credential.get('api_base'),
|
||||
api_key=model_credential.get('api_key'),
|
||||
**optional_params,
|
||||
)
|
||||
|
||||
def check_auth(self):
|
||||
chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
|
||||
response_list = chat.models.with_raw_response.list()
|
||||
|
||||
# self.generate_image('生成一个小猫图片')
|
||||
|
||||
def generate_image(self, prompt: str, negative_prompt: str = None):
|
||||
|
||||
chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
|
||||
res = chat.images.generate(model=self.model, prompt=prompt, **self.params)
|
||||
|
||||
file_urls = []
|
||||
for content in res.data:
|
||||
url = content.url
|
||||
print(url)
|
||||
file_name = 'generated_image.png'
|
||||
file = bytes_to_uploaded_file(requests.get(url).content, file_name)
|
||||
meta = {'debug': True}
|
||||
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||
file_urls.append(file_url)
|
||||
|
||||
return file_urls
|
||||
|
|
@ -15,11 +15,13 @@ from setting.models_provider.impl.openai_model_provider.credential.embedding imp
|
|||
from setting.models_provider.impl.openai_model_provider.credential.image import OpenAIImageModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.credential.tti import OpenAITextToImageModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||
from setting.models_provider.impl.openai_model_provider.model.image import OpenAIImage
|
||||
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
||||
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
|
||||
from setting.models_provider.impl.openai_model_provider.model.tti import OpenAITextToImage
|
||||
from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
|
@ -27,6 +29,7 @@ openai_llm_model_credential = OpenAILLMModelCredential()
|
|||
openai_stt_model_credential = OpenAISTTModelCredential()
|
||||
openai_tts_model_credential = OpenAITTSModelCredential()
|
||||
openai_image_model_credential = OpenAIImageModelCredential()
|
||||
openai_tti_model_credential = OpenAITextToImageModelCredential()
|
||||
model_info_list = [
|
||||
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential, OpenAIChatModel
|
||||
|
|
@ -37,8 +40,8 @@ model_info_list = [
|
|||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
ModelInfo('gpt-4o-mini', '最新的gpt-4o-mini,比gpt-4o更便宜、更快,随OpenAI调整而更新',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential,
|
||||
OpenAIChatModel),
|
||||
|
|
@ -100,11 +103,27 @@ model_info_image_list = [
|
|||
OpenAIImage),
|
||||
]
|
||||
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential, OpenAIChatModel
|
||||
)).append_model_info_list(model_info_embedding_list).append_default_model_info(
|
||||
model_info_embedding_list[0]).append_model_info_list(model_info_image_list).build()
|
||||
model_info_tti_list = [
|
||||
ModelInfo('dall-e-2', '',
|
||||
ModelTypeConst.TTI, openai_tti_model_credential,
|
||||
OpenAITextToImage),
|
||||
ModelInfo('dall-e-3', '',
|
||||
ModelTypeConst.TTI, openai_tti_model_credential,
|
||||
OpenAITextToImage),
|
||||
]
|
||||
|
||||
model_info_manage = (
|
||||
ModelInfoManage.builder()
|
||||
.append_model_info_list(model_info_list)
|
||||
.append_default_model_info(ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential, OpenAIChatModel
|
||||
))
|
||||
.append_model_info_list(model_info_embedding_list)
|
||||
.append_default_model_info(model_info_embedding_list[0])
|
||||
.append_model_info_list(model_info_image_list)
|
||||
.append_model_info_list(model_info_tti_list)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
class OpenAIModelProvider(IModelProvider):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,70 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: llm.py
|
||||
@date:2024/7/11 18:41
|
||||
@desc:
|
||||
"""
|
||||
import base64
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class QwenModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
|
||||
required=True, default_value=1.0,
|
||||
_min=0.1,
|
||||
_max=1.9,
|
||||
_step=0.01,
|
||||
precision=2)
|
||||
|
||||
max_tokens = forms.SliderField(
|
||||
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
|
||||
required=True, default_value=800,
|
||||
_min=1,
|
||||
_max=100000,
|
||||
_step=1,
|
||||
precision=0)
|
||||
|
||||
|
||||
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
for key in ['api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
res = model.check_auth()
|
||||
print(res)
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return QwenModelParams()
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
# coding=utf-8
|
||||
from http import HTTPStatus
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Dict
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import requests
|
||||
from dashscope import ImageSynthesis
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common.util.common import bytes_to_uploaded_file
|
||||
from dataset.serializers.file_serializers import FileSerializer
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from setting.models_provider.impl.base_tti import BaseTextToImage
|
||||
|
||||
|
||||
class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
|
||||
api_key: str
|
||||
model_name: str
|
||||
params: dict
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.api_key = kwargs.get('api_key')
|
||||
self.model_name = kwargs.get('model_name')
|
||||
self.params = kwargs.get('params')
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
optional_params = {'params': {}}
|
||||
for key, value in model_kwargs.items():
|
||||
if key not in ['model_id', 'use_local', 'streaming']:
|
||||
optional_params['params'][key] = value
|
||||
chat_tong_yi = QwenTextToImageModel(
|
||||
model_name=model_name,
|
||||
api_key=model_credential.get('api_key'),
|
||||
**optional_params,
|
||||
)
|
||||
return chat_tong_yi
|
||||
|
||||
def check_auth(self):
|
||||
chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max')
|
||||
chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])])
|
||||
|
||||
def generate_image(self, prompt: str, negative_prompt: str = None):
|
||||
# api_base='https://dashscope.aliyuncs.com/compatible-mode/v1',
|
||||
rsp = ImageSynthesis.call(api_key=self.api_key,
|
||||
model=self.model_name,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
**self.params)
|
||||
file_urls = []
|
||||
if rsp.status_code == HTTPStatus.OK:
|
||||
for result in rsp.output.results:
|
||||
file_name = PurePosixPath(unquote(urlparse(result.url).path)).parts[-1]
|
||||
file = bytes_to_uploaded_file(requests.get(result.url).content, file_name)
|
||||
meta = {'debug': True}
|
||||
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||
file_urls.append(file_url)
|
||||
else:
|
||||
print('sync_call Failed, status_code: %s, code: %s, message: %s' %
|
||||
(rsp.status_code, rsp.code, rsp.message))
|
||||
return file_urls
|
||||
|
|
@ -13,13 +13,16 @@ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelT
|
|||
ModelInfoManage
|
||||
from setting.models_provider.impl.qwen_model_provider.credential.image import QwenVLModelCredential
|
||||
from setting.models_provider.impl.qwen_model_provider.credential.llm import OpenAILLMModelCredential
|
||||
from setting.models_provider.impl.qwen_model_provider.credential.tti import QwenTextToImageModelCredential
|
||||
from setting.models_provider.impl.qwen_model_provider.model.image import QwenVLChatModel
|
||||
|
||||
from setting.models_provider.impl.qwen_model_provider.model.llm import QwenChatModel
|
||||
from setting.models_provider.impl.qwen_model_provider.model.tti import QwenTextToImageModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
qwen_model_credential = OpenAILLMModelCredential()
|
||||
qwenvl_model_credential = QwenVLModelCredential()
|
||||
qwentti_model_credential = QwenTextToImageModelCredential()
|
||||
|
||||
module_info_list = [
|
||||
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel),
|
||||
|
|
@ -31,13 +34,21 @@ module_info_vl_list = [
|
|||
ModelInfo('qwen-vl-max-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
|
||||
ModelInfo('qwen-vl-plus-0809', '', ModelTypeConst.IMAGE, qwenvl_model_credential, QwenVLChatModel),
|
||||
]
|
||||
module_info_tti_list = [
|
||||
ModelInfo('wanx-v1',
|
||||
'通义万相-文本生成图像大模型,支持中英文双语输入,支持输入参考图片进行参考内容或者参考风格迁移,重点风格包括但不限于水彩、油画、中国画、素描、扁平插画、二次元、3D卡通。',
|
||||
ModelTypeConst.TTI, qwentti_model_credential, QwenTextToImageModel),
|
||||
]
|
||||
|
||||
model_info_manage = (ModelInfoManage.builder()
|
||||
.append_model_info_list(module_info_list)
|
||||
.append_default_model_info(
|
||||
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel))
|
||||
.append_model_info_list(module_info_vl_list)
|
||||
.build())
|
||||
model_info_manage = (
|
||||
ModelInfoManage.builder()
|
||||
.append_model_info_list(module_info_list)
|
||||
.append_default_model_info(
|
||||
ModelInfo('qwen-turbo', '', ModelTypeConst.LLM, qwen_model_credential, QwenChatModel))
|
||||
.append_model_info_list(module_info_vl_list)
|
||||
.append_model_info_list(module_info_tti_list)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
class QwenModelProvider(IModelProvider):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,108 @@
|
|||
# coding=utf-8
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class TencentTTIModelParams(BaseForm):
|
||||
Style = forms.SingleSelect(
|
||||
TooltipLabel('绘画风格', '不传默认使用201(日系动漫风格)'),
|
||||
required=True,
|
||||
default_value='201',
|
||||
option_list=[
|
||||
{'value': '000', 'label': '不限定风格'},
|
||||
{'value': '101', 'label': '水墨画'},
|
||||
{'value': '102', 'label': '概念艺术'},
|
||||
{'value': '103', 'label': '油画1'},
|
||||
{'value': '118', 'label': '油画2(梵高)'},
|
||||
{'value': '104', 'label': '水彩画'},
|
||||
{'value': '105', 'label': '像素画'},
|
||||
{'value': '106', 'label': '厚涂风格'},
|
||||
{'value': '107', 'label': '插图'},
|
||||
{'value': '108', 'label': '剪纸风格'},
|
||||
{'value': '109', 'label': '印象派1(莫奈)'},
|
||||
{'value': '119', 'label': '印象派2'},
|
||||
{'value': '110', 'label': '2.5D'},
|
||||
{'value': '111', 'label': '古典肖像画'},
|
||||
{'value': '112', 'label': '黑白素描画'},
|
||||
{'value': '113', 'label': '赛博朋克'},
|
||||
{'value': '114', 'label': '科幻风格'},
|
||||
{'value': '115', 'label': '暗黑风格'},
|
||||
{'value': '116', 'label': '3D'},
|
||||
{'value': '117', 'label': '蒸汽波'},
|
||||
{'value': '201', 'label': '日系动漫'},
|
||||
{'value': '202', 'label': '怪兽风格'},
|
||||
{'value': '203', 'label': '唯美古风'},
|
||||
{'value': '204', 'label': '复古动漫'},
|
||||
{'value': '301', 'label': '游戏卡通手绘'},
|
||||
{'value': '401', 'label': '通用写实风格'},
|
||||
],
|
||||
value_field='value',
|
||||
text_field='label'
|
||||
)
|
||||
|
||||
Resolution = forms.SingleSelect(
|
||||
TooltipLabel('生成图分辨率', '不传默认使用768:768。'),
|
||||
required=True,
|
||||
default_value='768:768',
|
||||
option_list=[
|
||||
{'value': '768:768', 'label': '768:768(1:1)'},
|
||||
{'value': '768:1024', 'label': '768:1024(3:4)'},
|
||||
{'value': '1024:768', 'label': '1024:768(4:3)'},
|
||||
{'value': '1024:1024', 'label': '1024:1024(1:1)'},
|
||||
{'value': '720:1280', 'label': '720:1280(9:16)'},
|
||||
{'value': '1280:720', 'label': '1280:720(16:9)'},
|
||||
{'value': '768:1280', 'label': '768:1280(3:5)'},
|
||||
{'value': '1280:768', 'label': '1280:768(5:3)'},
|
||||
{'value': '1080:1920', 'label': '1080:1920(9:16)'},
|
||||
{'value': '1920:1080', 'label': '1920:1080(16:9)'},
|
||||
],
|
||||
value_field='value',
|
||||
text_field='label'
|
||||
)
|
||||
|
||||
|
||||
class TencentTTIModelCredential(BaseForm, BaseModelCredential):
|
||||
REQUIRED_FIELDS = ['hunyuan_secret_id', 'hunyuan_secret_key']
|
||||
|
||||
@classmethod
|
||||
def _validate_model_type(cls, model_type, provider, raise_exception=False):
|
||||
if not any(mt['value'] == model_type for mt in provider.get_model_type_list()):
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _validate_credential_fields(cls, model_credential, raise_exception=False):
|
||||
missing_keys = [key for key in cls.REQUIRED_FIELDS if key not in model_credential]
|
||||
if missing_keys:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{", ".join(missing_keys)} 字段为必填字段')
|
||||
return False
|
||||
return True
|
||||
|
||||
def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False):
|
||||
if not (self._validate_model_type(model_type, provider, raise_exception) and
|
||||
self._validate_credential_fields(model_credential, raise_exception)):
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.check_auth()
|
||||
except Exception as e:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model):
|
||||
return {**model, 'hunyuan_secret_key': super().encryption(model.get('hunyuan_secret_key', ''))}
|
||||
|
||||
hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
|
||||
hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
return TencentTTIModelParams()
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
# coding=utf-8
|
||||
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
import requests
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
|
||||
from tencentcloud.common.profile.client_profile import ClientProfile
|
||||
from tencentcloud.common.profile.http_profile import HttpProfile
|
||||
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
|
||||
|
||||
from common.util.common import bytes_to_uploaded_file
|
||||
from dataset.serializers.file_serializers import FileSerializer
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from setting.models_provider.impl.base_tti import BaseTextToImage
|
||||
from setting.models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan
|
||||
|
||||
|
||||
class TencentTextToImageModel(MaxKBBaseModel, BaseTextToImage):
|
||||
hunyuan_secret_id: str
|
||||
hunyuan_secret_key: str
|
||||
model: str
|
||||
params: dict
|
||||
|
||||
@staticmethod
|
||||
def is_cache_model():
|
||||
return False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.hunyuan_secret_id = kwargs.get('hunyuan_secret_id')
|
||||
self.hunyuan_secret_key = kwargs.get('hunyuan_secret_key')
|
||||
self.model = kwargs.get('model_name')
|
||||
self.params = kwargs.get('params')
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
|
||||
**model_kwargs) -> 'TencentTextToImageModel':
|
||||
optional_params = {'params': {'Style': '201', 'Resolution': '768:768'}}
|
||||
for key, value in model_kwargs.items():
|
||||
if key not in ['model_id', 'use_local', 'streaming']:
|
||||
optional_params['params'][key] = value
|
||||
return TencentTextToImageModel(
|
||||
model=model_name,
|
||||
hunyuan_secret_id=model_credential.get('hunyuan_secret_id'),
|
||||
hunyuan_secret_key=model_credential.get('hunyuan_secret_key'),
|
||||
**optional_params
|
||||
)
|
||||
|
||||
def check_auth(self):
|
||||
chat = ChatHunyuan(hunyuan_app_id='111111',
|
||||
hunyuan_secret_id=self.hunyuan_secret_id,
|
||||
hunyuan_secret_key=self.hunyuan_secret_key,
|
||||
model="hunyuan-standard")
|
||||
res = chat.invoke('你好')
|
||||
# print(res)
|
||||
|
||||
def generate_image(self, prompt: str, negative_prompt: str = None):
|
||||
try:
|
||||
# 实例化一个认证对象,入参需要传入腾讯云账户 SecretId 和 SecretKey,此处还需注意密钥对的保密
|
||||
# 代码泄露可能会导致 SecretId 和 SecretKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考,建议采用更安全的方式来使用密钥,请参见:https://cloud.tencent.com/document/product/1278/85305
|
||||
# 密钥可前往官网控制台 https://console.cloud.tencent.com/cam/capi 进行获取
|
||||
cred = credential.Credential(self.hunyuan_secret_id, self.hunyuan_secret_key)
|
||||
# 实例化一个http选项,可选的,没有特殊需求可以跳过
|
||||
httpProfile = HttpProfile()
|
||||
httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
|
||||
|
||||
# 实例化一个client选项,可选的,没有特殊需求可以跳过
|
||||
clientProfile = ClientProfile()
|
||||
clientProfile.httpProfile = httpProfile
|
||||
# 实例化要请求产品的client对象,clientProfile是可选的
|
||||
client = hunyuan_client.HunyuanClient(cred, "ap-guangzhou", clientProfile)
|
||||
|
||||
# 实例化一个请求对象,每个接口都会对应一个request对象
|
||||
req = models.TextToImageLiteRequest()
|
||||
params = {
|
||||
"Prompt": prompt,
|
||||
"NegativePrompt": negative_prompt,
|
||||
"RspImgType": "url",
|
||||
**self.params
|
||||
}
|
||||
req.from_json_string(json.dumps(params))
|
||||
|
||||
# 返回的resp是一个TextToImageLiteResponse的实例,与请求对象对应
|
||||
resp = client.TextToImageLite(req)
|
||||
# 输出json格式的字符串回包
|
||||
print(resp.to_json_string())
|
||||
file_urls = []
|
||||
file_name = 'generated_image.png'
|
||||
file = bytes_to_uploaded_file(requests.get(resp.ResultImage).content, file_name)
|
||||
meta = {'debug': True}
|
||||
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||
file_urls.append(file_url)
|
||||
return file_urls
|
||||
except TencentCloudSDKException as err:
|
||||
print(err)
|
||||
|
||||
|
|
@ -9,9 +9,11 @@ from setting.models_provider.base_model_provider import (
|
|||
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
|
||||
from setting.models_provider.impl.tencent_model_provider.credential.image import TencentVisionModelCredential
|
||||
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
|
||||
from setting.models_provider.impl.tencent_model_provider.credential.tti import TencentTTIModelCredential
|
||||
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
|
||||
from setting.models_provider.impl.tencent_model_provider.model.image import TencentVision
|
||||
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
|
||||
from setting.models_provider.impl.tencent_model_provider.model.tti import TencentTextToImageModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
|
|
@ -87,11 +89,19 @@ def _initialize_model_info():
|
|||
TencentVisionModelCredential,
|
||||
TencentVision)]
|
||||
|
||||
model_info_tti_list = [_create_model_info(
|
||||
'hunyuan-dit',
|
||||
'混元生图模型',
|
||||
ModelTypeConst.TTI,
|
||||
TencentTTIModelCredential,
|
||||
TencentTextToImageModel)]
|
||||
|
||||
|
||||
model_info_manage = ModelInfoManage.builder() \
|
||||
.append_model_info_list(model_info_list) \
|
||||
.append_model_info_list(model_info_embedding_list) \
|
||||
.append_model_info_list(model_info_vision_list) \
|
||||
.append_model_info_list(model_info_tti_list) \
|
||||
.append_default_model_info(model_info_list[0]) \
|
||||
.build()
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
# coding=utf-8
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
res = model.check_auth()
|
||||
print(res)
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
def get_model_params_setting_form(self, model_name):
|
||||
pass
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
from typing import Dict
|
||||
|
||||
import requests
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_core.messages import HumanMessage
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from common.util.common import bytes_to_uploaded_file
|
||||
from dataset.serializers.file_serializers import FileSerializer
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from setting.models_provider.impl.base_tti import BaseTextToImage
|
||||
|
||||
|
||||
def custom_get_token_ids(text: str):
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return tokenizer.encode(text)
|
||||
|
||||
|
||||
class ZhiPuTextToImage(MaxKBBaseModel, BaseTextToImage):
|
||||
api_key: str
|
||||
model: str
|
||||
params: dict
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.api_key = kwargs.get('api_key')
|
||||
self.model = kwargs.get('model')
|
||||
self.params = kwargs.get('params')
|
||||
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
optional_params = {'params': {}}
|
||||
for key, value in model_kwargs.items():
|
||||
if key not in ['model_id', 'use_local', 'streaming']:
|
||||
optional_params['params'][key] = value
|
||||
return ZhiPuTextToImage(
|
||||
model=model_name,
|
||||
api_key=model_credential.get('api_key'),
|
||||
**optional_params,
|
||||
)
|
||||
|
||||
def check_auth(self):
|
||||
chat = ChatZhipuAI(
|
||||
zhipuai_api_key=self.api_key,
|
||||
model_name=self.model,
|
||||
)
|
||||
chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])])
|
||||
|
||||
# self.generate_image('生成一个小猫图片')
|
||||
|
||||
def generate_image(self, prompt: str, negative_prompt: str = None):
|
||||
# chat = ChatZhipuAI(
|
||||
# zhipuai_api_key=self.api_key,
|
||||
# model_name=self.model,
|
||||
# )
|
||||
chat = ZhipuAI(api_key=self.api_key)
|
||||
response = chat.images.generations(
|
||||
model=self.model, # 填写需要调用的模型编码
|
||||
prompt=prompt, # 填写需要生成图片的文本
|
||||
**self.params # 填写额外参数
|
||||
)
|
||||
file_urls = []
|
||||
for content in response.data:
|
||||
url = content['url']
|
||||
print(url)
|
||||
file_name = url.split('/')[-1]
|
||||
file = bytes_to_uploaded_file(requests.get(url).content, file_name)
|
||||
meta = {'debug': True}
|
||||
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||
file_urls.append(file_url)
|
||||
|
||||
return file_urls
|
||||
|
|
@ -13,12 +13,15 @@ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelT
|
|||
ModelInfoManage
|
||||
from setting.models_provider.impl.zhipu_model_provider.credential.image import ZhiPuImageModelCredential
|
||||
from setting.models_provider.impl.zhipu_model_provider.credential.llm import ZhiPuLLMModelCredential
|
||||
from setting.models_provider.impl.zhipu_model_provider.credential.tti import ZhiPuTextToImageModelCredential
|
||||
from setting.models_provider.impl.zhipu_model_provider.model.image import ZhiPuImage
|
||||
from setting.models_provider.impl.zhipu_model_provider.model.llm import ZhipuChatModel
|
||||
from setting.models_provider.impl.zhipu_model_provider.model.tti import ZhiPuTextToImage
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
qwen_model_credential = ZhiPuLLMModelCredential()
|
||||
zhipu_image_model_credential = ZhiPuImageModelCredential()
|
||||
zhipu_tti_model_credential = ZhiPuTextToImageModelCredential()
|
||||
|
||||
model_info_list = [
|
||||
ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel),
|
||||
|
|
@ -38,11 +41,21 @@ model_info_image_list = [
|
|||
ZhiPuImage),
|
||||
]
|
||||
|
||||
model_info_tti_list = [
|
||||
ModelInfo('cogview-3', '根据用户文字描述快速、精准生成图像。分辨率支持1024x1024',
|
||||
ModelTypeConst.TTI, zhipu_tti_model_credential,
|
||||
ZhiPuTextToImage),
|
||||
ModelInfo('cogview-3-plus', '根据用户文字描述生成高质量图像,支持多图片尺寸',
|
||||
ModelTypeConst.TTI, zhipu_tti_model_credential,
|
||||
ZhiPuTextToImage),
|
||||
]
|
||||
|
||||
model_info_manage = (
|
||||
ModelInfoManage.builder()
|
||||
.append_model_info_list(model_info_list)
|
||||
.append_default_model_info(ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel))
|
||||
.append_model_info_list(model_info_image_list)
|
||||
.append_model_info_list(model_info_tti_list)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -293,6 +293,13 @@ const getApplicationImageModel: (
|
|||
return get(`${prefix}/${application_id}/model`, { model_type: 'IMAGE' }, loading)
|
||||
}
|
||||
|
||||
const getApplicationTTIModel: (
|
||||
application_id: string,
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<Array<any>>> = (application_id, loading) => {
|
||||
return get(`${prefix}/${application_id}/model`, { model_type: 'TTI' }, loading)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 发布应用
|
||||
|
|
@ -523,6 +530,7 @@ export default {
|
|||
getApplicationSTTModel,
|
||||
getApplicationTTSModel,
|
||||
getApplicationImageModel,
|
||||
getApplicationTTIModel,
|
||||
postSpeechToText,
|
||||
postTextToSpeech,
|
||||
getPlatformStatus,
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@
|
|||
item.type === WorkflowType.Question ||
|
||||
item.type === WorkflowType.AiChat ||
|
||||
item.type === WorkflowType.ImageUnderstandNode ||
|
||||
item.type === WorkflowType.ImageGenerateNode ||
|
||||
item.type === WorkflowType.Application
|
||||
"
|
||||
>{{ item?.message_tokens + item?.answer_tokens }} tokens</span
|
||||
|
|
@ -444,6 +445,65 @@
|
|||
</div>
|
||||
</div>
|
||||
</template>
|
||||
<!-- 图片生成 -->
|
||||
<template v-if="item.type == WorkflowType.ImageGenerateNode">
|
||||
<div
|
||||
class="card-never border-r-4 mt-8"
|
||||
v-if="item.type !== WorkflowType.Application"
|
||||
>
|
||||
<h5 class="p-8-12">历史记录</h5>
|
||||
<div class="p-8-12 border-t-dashed lighter">
|
||||
<template v-if="item.history_message?.length > 0">
|
||||
<p
|
||||
class="mt-4 mb-4"
|
||||
v-for="(history, historyIndex) in item.history_message"
|
||||
:key="historyIndex"
|
||||
>
|
||||
<span class="color-secondary mr-4">{{ history.role }}:</span>
|
||||
|
||||
<span v-if="Array.isArray(history.content)">
|
||||
<template v-for="(h, i) in history.content" :key="i">
|
||||
<el-image
|
||||
v-if="h.type === 'image_url'"
|
||||
:src="h.image_url.url"
|
||||
alt=""
|
||||
fit="cover"
|
||||
style="width: 40px; height: 40px; display: inline-block"
|
||||
class="border-r-4 mr-8"
|
||||
/>
|
||||
|
||||
<span v-else>{{ h.text }}<br /></span>
|
||||
</template>
|
||||
</span>
|
||||
|
||||
<span v-else>{{ history.content }}</span>
|
||||
</p>
|
||||
</template>
|
||||
<template v-else> - </template>
|
||||
</div>
|
||||
</div>
|
||||
<div class="card-never border-r-4 mt-8">
|
||||
<h5 class="p-8-12">本次对话</h5>
|
||||
<div class="p-8-12 border-t-dashed lighter pre-wrap">
|
||||
{{ item.question || '-' }}
|
||||
</div>
|
||||
</div>
|
||||
<div class="card-never border-r-4 mt-8">
|
||||
<h5 class="p-8-12">
|
||||
{{ item.type == WorkflowType.Application ? '参数输出' : 'AI 回答' }}
|
||||
</h5>
|
||||
<div class="p-8-12 border-t-dashed lighter">
|
||||
<MdPreview
|
||||
v-if="item.answer"
|
||||
ref="editorRef"
|
||||
editorId="preview-only"
|
||||
:modelValue="item.answer"
|
||||
style="background: none"
|
||||
/>
|
||||
<template v-else> - </template>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</template>
|
||||
<template v-else>
|
||||
<div class="card-never border-r-4">
|
||||
|
|
|
|||
|
|
@ -13,5 +13,6 @@ export enum modelType {
|
|||
STT = '语音识别',
|
||||
TTS = '语音合成',
|
||||
IMAGE = '图片理解',
|
||||
TTI = '图片生成',
|
||||
RERANKER = '重排模型'
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,5 +12,6 @@ export enum WorkflowType {
|
|||
Application = 'application-node',
|
||||
DocumentExtractNode = 'document-extract-node',
|
||||
ImageUnderstandNode = 'image-understand-node',
|
||||
ImageGenerateNode = 'image-generate-node',
|
||||
FormNode = 'form-node'
|
||||
}
|
||||
|
|
|
|||
|
|
@ -63,6 +63,7 @@ const modelTypeOptions = ref([
|
|||
{ text: '语音识别', value: 'STT' },
|
||||
{ text: '语音合成', value: 'TTS' },
|
||||
{ text: '图片理解', value: 'IMAGE' },
|
||||
{ text: '图片生成', value: 'TTI' },
|
||||
])
|
||||
|
||||
const open = () => {
|
||||
|
|
|
|||
|
|
@ -133,6 +133,7 @@
|
|||
<el-option label="语音识别" value="STT" />
|
||||
<el-option label="语音合成" value="TTS" />
|
||||
<el-option label="图片理解" value="IMAGE" />
|
||||
<el-option label="图片生成" value="TTI" />
|
||||
</el-select>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -227,6 +227,28 @@ export const imageUnderstandNode = {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const imageGenerateNode = {
|
||||
type: WorkflowType.ImageGenerateNode,
|
||||
text: '根据提供的文本内容生成图片',
|
||||
label: '图片生成',
|
||||
height: 252,
|
||||
properties: {
|
||||
stepName: '图片生成',
|
||||
config: {
|
||||
fields: [
|
||||
{
|
||||
label: 'AI 回答内容',
|
||||
value: 'answer'
|
||||
},
|
||||
{
|
||||
label: '图片',
|
||||
value: 'image'
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
export const menuNodes = [
|
||||
aiChatNode,
|
||||
searchDatasetNode,
|
||||
|
|
@ -236,6 +258,7 @@ export const menuNodes = [
|
|||
rerankerNode,
|
||||
documentExtractNode,
|
||||
imageUnderstandNode,
|
||||
imageGenerateNode,
|
||||
formNode
|
||||
]
|
||||
|
||||
|
|
@ -326,7 +349,8 @@ export const nodeDict: any = {
|
|||
[WorkflowType.FormNode]: formNode,
|
||||
[WorkflowType.Application]: applicationNode,
|
||||
[WorkflowType.DocumentExtractNode]: documentExtractNode,
|
||||
[WorkflowType.ImageUnderstandNode]: imageUnderstandNode
|
||||
[WorkflowType.ImageUnderstandNode]: imageUnderstandNode,
|
||||
[WorkflowType.ImageGenerateNode]: imageGenerateNode
|
||||
}
|
||||
export function isWorkFlow(type: string | undefined) {
|
||||
return type === 'WORK_FLOW'
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ const end_nodes: Array<string> = [
|
|||
WorkflowType.FunctionLib,
|
||||
WorkflowType.FunctionLibCustom,
|
||||
WorkflowType.ImageUnderstandNode,
|
||||
WorkflowType.ImageGenerateNode,
|
||||
WorkflowType.Application
|
||||
]
|
||||
export class WorkFlowInstance {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
<template>
|
||||
<AppAvatar shape="square" style="background: #14C0FF;">
|
||||
<img src="@/assets/icon_image.svg" style="width: 65%" alt="" />
|
||||
</AppAvatar>
|
||||
</template>
|
||||
<script setup lang="ts"></script>
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
import ImageGenerateNodeVue from './index.vue'
|
||||
import { AppNode, AppNodeModel } from '@/workflow/common/app-node'
|
||||
|
||||
class RerankerNode extends AppNode {
|
||||
constructor(props: any) {
|
||||
super(props, ImageGenerateNodeVue)
|
||||
}
|
||||
}
|
||||
|
||||
export default {
|
||||
type: 'image-generate-node',
|
||||
model: AppNodeModel,
|
||||
view: RerankerNode
|
||||
}
|
||||
|
|
@ -0,0 +1,323 @@
|
|||
<template>
|
||||
<NodeContainer :node-model="nodeModel">
|
||||
<h5 class="title-decoration-1 mb-8">节点设置</h5>
|
||||
<el-card shadow="never" class="card-never">
|
||||
<el-form
|
||||
@submit.prevent
|
||||
:model="form_data"
|
||||
label-position="top"
|
||||
require-asterisk-position="right"
|
||||
label-width="auto"
|
||||
ref="aiChatNodeFormRef"
|
||||
hide-required-asterisk
|
||||
>
|
||||
<el-form-item
|
||||
label="图片生成模型"
|
||||
prop="model_id"
|
||||
:rules="{
|
||||
required: true,
|
||||
message: '请选择图片生成模型',
|
||||
trigger: 'change'
|
||||
}"
|
||||
>
|
||||
<template #label>
|
||||
<div class="flex-between w-full">
|
||||
<div>
|
||||
<span>图片生成模型<span class="danger">*</span></span>
|
||||
</div>
|
||||
<el-button
|
||||
:disabled="!form_data.model_id"
|
||||
type="primary"
|
||||
link
|
||||
@click="openAIParamSettingDialog(form_data.model_id)"
|
||||
@refreshForm="refreshParam"
|
||||
>
|
||||
{{ $t('views.application.applicationForm.form.paramSetting') }}
|
||||
</el-button>
|
||||
</div>
|
||||
</template>
|
||||
<el-select
|
||||
@change="model_change"
|
||||
@wheel="wheel"
|
||||
:teleported="false"
|
||||
v-model="form_data.model_id"
|
||||
placeholder="请选择图片生成模型"
|
||||
class="w-full"
|
||||
popper-class="select-model"
|
||||
:clearable="true"
|
||||
>
|
||||
<el-option-group
|
||||
v-for="(value, label) in modelOptions"
|
||||
:key="value"
|
||||
:label="relatedObject(providerOptions, label, 'provider')?.name"
|
||||
>
|
||||
<el-option
|
||||
v-for="item in value.filter((v: any) => v.status === 'SUCCESS')"
|
||||
:key="item.id"
|
||||
:label="item.name"
|
||||
:value="item.id"
|
||||
class="flex-between"
|
||||
>
|
||||
<div class="flex align-center">
|
||||
<span
|
||||
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||
class="model-icon mr-8"
|
||||
></span>
|
||||
<span>{{ item.name }}</span>
|
||||
<el-tag v-if="item.permission_type === 'PUBLIC'" type="info" class="info-tag ml-8"
|
||||
>公用
|
||||
</el-tag>
|
||||
</div>
|
||||
<el-icon class="check-icon" v-if="item.id === form_data.model_id">
|
||||
<Check />
|
||||
</el-icon>
|
||||
</el-option>
|
||||
<!-- 不可用 -->
|
||||
<el-option
|
||||
v-for="item in value.filter((v: any) => v.status !== 'SUCCESS')"
|
||||
:key="item.id"
|
||||
:label="item.name"
|
||||
:value="item.id"
|
||||
class="flex-between"
|
||||
disabled
|
||||
>
|
||||
<div class="flex">
|
||||
<span
|
||||
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||
class="model-icon mr-8"
|
||||
></span>
|
||||
<span>{{ item.name }}</span>
|
||||
<span class="danger">(不可用)</span>
|
||||
</div>
|
||||
<el-icon class="check-icon" v-if="item.id === form_data.model_id">
|
||||
<Check />
|
||||
</el-icon>
|
||||
</el-option>
|
||||
</el-option-group>
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
|
||||
|
||||
<el-form-item
|
||||
label="提示词(正向)"
|
||||
prop="prompt"
|
||||
:rules="{
|
||||
required: true,
|
||||
message: '请输入提示词',
|
||||
trigger: 'blur'
|
||||
}"
|
||||
>
|
||||
<template #label>
|
||||
<div class="flex align-center">
|
||||
<div class="mr-4">
|
||||
<span>提示词(正向)<span class="danger">*</span></span>
|
||||
</div>
|
||||
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
|
||||
<template #content
|
||||
>正向提示词,用来描述生成图像中期望包含的元素和视觉特点
|
||||
</template>
|
||||
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||
</el-tooltip>
|
||||
</div>
|
||||
</template>
|
||||
<MdEditorMagnify
|
||||
@wheel="wheel"
|
||||
title="提示词(正向)"
|
||||
v-model="form_data.prompt"
|
||||
style="height: 150px"
|
||||
@submitDialog="submitDialog"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item
|
||||
label="提示词(负向)"
|
||||
prop="prompt"
|
||||
:rules="{
|
||||
required: false,
|
||||
message: '请输入提示词',
|
||||
trigger: 'blur'
|
||||
}"
|
||||
>
|
||||
<template #label>
|
||||
<div class="flex align-center">
|
||||
<div class="mr-4">
|
||||
<span>提示词(负向)</span>
|
||||
</div>
|
||||
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
|
||||
<template #content
|
||||
>反向提示词,用来描述不希望在画面中看到的内容,可以对画面进行限制。
|
||||
</template>
|
||||
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||
</el-tooltip>
|
||||
</div>
|
||||
</template>
|
||||
<MdEditorMagnify
|
||||
@wheel="wheel"
|
||||
title="提示词(负向)"
|
||||
v-model="form_data.negative_prompt"
|
||||
style="height: 150px"
|
||||
@submitDialog="submitDialog"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item>
|
||||
<template #label>
|
||||
<div class="flex-between">
|
||||
<div>历史聊天记录</div>
|
||||
<el-select v-model="form_data.dialogue_type" type="small" style="width: 100px">
|
||||
<el-option label="节点" value="NODE" />
|
||||
<el-option label="工作流" value="WORKFLOW" />
|
||||
</el-select>
|
||||
</div>
|
||||
</template>
|
||||
<el-input-number
|
||||
v-model="form_data.dialogue_number"
|
||||
:min="0"
|
||||
:value-on-clear="0"
|
||||
controls-position="right"
|
||||
class="w-full"
|
||||
:step="1"
|
||||
:step-strictly="true"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item label="返回内容" @click.prevent>
|
||||
<template #label>
|
||||
<div class="flex align-center">
|
||||
<div class="mr-4">
|
||||
<span>返回内容<span class="danger">*</span></span>
|
||||
</div>
|
||||
<el-tooltip effect="dark" placement="right" popper-class="max-w-200">
|
||||
<template #content>
|
||||
关闭后该节点的内容则不输出给用户。
|
||||
如果你想让用户看到该节点的输出内容,请打开开关。
|
||||
</template>
|
||||
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||
</el-tooltip>
|
||||
</div>
|
||||
</template>
|
||||
<el-switch size="small" v-model="form_data.is_result" />
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
</el-card>
|
||||
<AIModeParamSettingDialog ref="AIModeParamSettingDialogRef" @refresh="refreshParam" />
|
||||
</NodeContainer>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import NodeContainer from '@/workflow/common/NodeContainer.vue'
|
||||
import { computed, onMounted, ref } from 'vue'
|
||||
import { groupBy, set } from 'lodash'
|
||||
import { relatedObject } from '@/utils/utils'
|
||||
import type { Provider } from '@/api/type/model'
|
||||
import applicationApi from '@/api/application'
|
||||
import { app } from '@/main'
|
||||
import useStore from '@/stores'
|
||||
import NodeCascader from '@/workflow/common/NodeCascader.vue'
|
||||
import type { FormInstance } from 'element-plus'
|
||||
import AIModeParamSettingDialog from '@/views/application/component/AIModeParamSettingDialog.vue'
|
||||
|
||||
const { model } = useStore()
|
||||
|
||||
const {
|
||||
params: { id }
|
||||
} = app.config.globalProperties.$route as any
|
||||
|
||||
const props = defineProps<{ nodeModel: any }>()
|
||||
const modelOptions = ref<any>(null)
|
||||
const providerOptions = ref<Array<Provider>>([])
|
||||
const AIModeParamSettingDialogRef = ref<InstanceType<typeof AIModeParamSettingDialog>>()
|
||||
|
||||
const aiChatNodeFormRef = ref<FormInstance>()
|
||||
const validate = () => {
|
||||
return aiChatNodeFormRef.value?.validate().catch((err) => {
|
||||
return Promise.reject({ node: props.nodeModel, errMessage: err })
|
||||
})
|
||||
}
|
||||
|
||||
const wheel = (e: any) => {
|
||||
if (e.ctrlKey === true) {
|
||||
e.preventDefault()
|
||||
return true
|
||||
} else {
|
||||
e.stopPropagation()
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
const defaultPrompt = `{{开始.question}}`
|
||||
|
||||
const form = {
|
||||
model_id: '',
|
||||
system: '',
|
||||
prompt: defaultPrompt,
|
||||
negative_prompt: '',
|
||||
dialogue_number: 0,
|
||||
dialogue_type: 'NODE',
|
||||
is_result: true,
|
||||
temperature: null,
|
||||
max_tokens: null,
|
||||
image_list: ['start-node', 'image']
|
||||
}
|
||||
|
||||
const form_data = computed({
|
||||
get: () => {
|
||||
if (props.nodeModel.properties.node_data) {
|
||||
return props.nodeModel.properties.node_data
|
||||
} else {
|
||||
set(props.nodeModel.properties, 'node_data', form)
|
||||
}
|
||||
return props.nodeModel.properties.node_data
|
||||
},
|
||||
set: (value) => {
|
||||
set(props.nodeModel.properties, 'node_data', value)
|
||||
}
|
||||
})
|
||||
|
||||
function getModel() {
|
||||
if (id) {
|
||||
applicationApi.getApplicationTTIModel(id).then((res: any) => {
|
||||
modelOptions.value = groupBy(res?.data, 'provider')
|
||||
})
|
||||
} else {
|
||||
model.asyncGetModel().then((res: any) => {
|
||||
modelOptions.value = groupBy(res?.data, 'provider')
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function getProvider() {
|
||||
model.asyncGetProvider().then((res: any) => {
|
||||
providerOptions.value = res?.data
|
||||
})
|
||||
}
|
||||
|
||||
const model_change = () => {
|
||||
if (form_data.value.model_id) {
|
||||
AIModeParamSettingDialogRef.value?.reset_default(form_data.value.model_id, id)
|
||||
} else {
|
||||
refreshParam({})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
const openAIParamSettingDialog = (modelId: string) => {
|
||||
if (modelId) {
|
||||
AIModeParamSettingDialogRef.value?.open(modelId, id, form_data.value.model_params_setting)
|
||||
}
|
||||
}
|
||||
|
||||
function refreshParam(data: any) {
|
||||
set(props.nodeModel.properties.node_data, 'model_params_setting', data)
|
||||
}
|
||||
|
||||
function submitDialog(val: string) {
|
||||
set(props.nodeModel.properties.node_data, 'prompt', val)
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
getModel()
|
||||
getProvider()
|
||||
|
||||
set(props.nodeModel, 'validate', validate)
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped lang="scss"></style>
|
||||
Loading…
Reference in New Issue