diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py index c02d7c0f6..ca6b9233f 100644 --- a/apps/application/flow/step_node/__init__.py +++ b/apps/application/flow/step_node/__init__.py @@ -13,6 +13,7 @@ from .direct_reply_node import * from .document_extract_node import * from .form_node import * from .image_generate_step_node import * +from .image_to_video_step_node import BaseImageToVideoNode from .image_understand_step_node import * from .mcp_node import BaseMcpNode from .question_node import * @@ -21,6 +22,7 @@ from .search_knowledge_node import * from .speech_to_text_step_node import BaseSpeechToTextNode from .start_node import * from .text_to_speech_step_node.impl.base_text_to_speech_node import BaseTextToSpeechNode +from .text_to_video_step_node.impl.base_text_to_video_node import BaseTextToVideoNode from .tool_lib_node import * from .tool_node import * from .variable_assign_node import BaseVariableAssignNode @@ -31,7 +33,8 @@ node_list = [BaseStartStepNode, BaseChatNode, BaseSearchKnowledgeNode, BaseQuest BaseToolNodeNode, BaseToolLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode, BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode, - BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode,BaseIntentNode] + BaseImageGenerateNode, BaseVariableAssignNode, BaseMcpNode, BaseTextToVideoNode, BaseImageToVideoNode, + BaseIntentNode] def get_node(node_type): diff --git a/apps/application/flow/step_node/image_to_video_step_node/__init__.py b/apps/application/flow/step_node/image_to_video_step_node/__init__.py new file mode 100644 index 000000000..f3feecc9c --- /dev/null +++ b/apps/application/flow/step_node/image_to_video_step_node/__init__.py @@ -0,0 +1,3 @@ +# coding=utf-8 + +from .impl import * diff --git a/apps/application/flow/step_node/image_to_video_step_node/i_image_to_video_node.py 
# coding=utf-8

from typing import Type

from django.utils.translation import gettext_lazy as _
from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult


class ImageToVideoNodeSerializer(serializers.Serializer):
    """Validates the node parameters for the image-to-video workflow step."""

    model_id = serializers.CharField(required=True, label=_("Model id"))

    prompt = serializers.CharField(required=True, label=_("Prompt word (positive)"))

    negative_prompt = serializers.CharField(required=False, label=_("Prompt word (negative)"),
                                            allow_null=True, allow_blank=True, )
    # Number of history dialogue rounds carried into the generation context.
    dialogue_number = serializers.IntegerField(required=False, default=0,
                                               label=_("Number of multi-round conversations"))

    dialogue_type = serializers.CharField(required=False, default='NODE',
                                          label=_("Conversation storage type"))

    is_result = serializers.BooleanField(required=False,
                                         label=_('Whether to return content'))

    model_params_setting = serializers.JSONField(required=False, default=dict,
                                                 label=_("Model parameter settings"))

    # Reference fields: a list shaped [node_id, *field_path], resolved at run
    # time via workflow_manage.get_reference_field.
    first_frame_url = serializers.ListField(required=True, label=_("First frame url"))
    last_frame_url = serializers.ListField(required=False, label=_("Last frame url"))


class IImageToVideoNode(INode):
    """Interface node for the image-to-video step.

    Resolves the first/last frame reference fields and dispatches to
    ``execute``, which concrete implementations override.
    """

    type = 'image-to-video-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return ImageToVideoNodeSerializer

    def _run(self):
        first_frame_url = self.workflow_manage.get_reference_field(
            self.node_params_serializer.data.get('first_frame_url')[0],
            self.node_params_serializer.data.get('first_frame_url')[1:])
        # BUG FIX: the original guard was `if first_frame_url is []:`, an
        # identity comparison against a fresh list literal that is always
        # False, so an empty resolved value was never rejected. Use a
        # truthiness check instead (catches [] and None alike).
        if not first_frame_url:
            raise ValueError(
                _("First frame url cannot be empty"))
        last_frame_url = None
        # The last frame is optional; only resolve it when a non-empty
        # reference was supplied.
        last_frame_ref = self.node_params_serializer.data.get('last_frame_url')
        if last_frame_ref is not None and last_frame_ref != []:
            last_frame_url = self.workflow_manage.get_reference_field(
                last_frame_ref[0],
                last_frame_ref[1:])
        # Strip the reference fields; they are passed explicitly below.
        node_params_data = {k: v for k, v in self.node_params_serializer.data.items()
                            if k not in ['first_frame_url', 'last_frame_url']}
        return self.execute(first_frame_url=first_frame_url, last_frame_url=last_frame_url,
                            **node_params_data, **self.flow_params_serializer.data)

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                model_params_setting,
                chat_record_id,
                first_frame_url, last_frame_url,
                **kwargs) -> NodeResult:
        """Run video generation; implemented by BaseImageToVideoNode."""
        pass
class BaseImageToVideoNode(IImageToVideoNode):
    """Concrete image-to-video step node.

    Builds the prompt (with optional multi-round history), converts the
    first/last frame references into something the model accepts (URL or
    base64 data URI), calls the provider model, and uploads the resulting
    video to application file storage.
    """

    def save_context(self, details, workflow_manage):
        # Restore a previous run's output when the workflow is replayed.
        self.context['answer'] = details.get('answer')
        self.context['question'] = details.get('question')
        if self.node_params.get('is_result', False):
            self.answer_text = details.get('answer')

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                model_params_setting,
                chat_record_id,
                first_frame_url, last_frame_url=None,
                **kwargs) -> NodeResult:
        """Generate a video from first (and optional last) frame plus prompt.

        Returns a NodeResult whose 'answer' is markup referencing the stored
        video file, or a failure message when the model returns nothing.
        """
        application = self.workflow_manage.work_flow_post_handler.chat_info.application
        workspace_id = self.workflow_manage.get_body().get('workspace_id')
        ttv_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
                                                             **model_params_setting)
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)
        self.context['question'] = question
        message_list = self.generate_message_list(question, history_message)
        self.context['message_list'] = message_list
        self.context['dialogue_type'] = dialogue_type
        self.context['negative_prompt'] = negative_prompt
        self.context['first_frame_url'] = first_frame_url
        self.context['last_frame_url'] = last_frame_url
        # The frames may be plain URLs (forwarded as-is) or uploaded file ids
        # (converted to base64 data URIs), see get_file_base64.
        first_frame_url = self.get_file_base64(first_frame_url)
        last_frame_url = self.get_file_base64(last_frame_url)
        video_urls = ttv_model.generate_video(question, negative_prompt, first_frame_url, last_frame_url)
        # Persist the generated video.
        if video_urls is None:
            # Model produced nothing; answer text is "video generation failed".
            return NodeResult({'answer': '生成视频失败'}, {})
        file_name = 'generated_video.mp4'
        if isinstance(video_urls, str) and video_urls.startswith('http'):
            # Remote result: download the bytes before uploading to storage.
            # NOTE(review): requests.get has no timeout here — confirm intended.
            video_urls = requests.get(video_urls).content
        file = bytes_to_uploaded_file(video_urls, file_name)
        meta = {
            # No application id means this is a debug (preview) run.
            'debug': False if application.id else True,
            'chat_id': chat_id,
            'application_id': str(application.id) if application.id else None,
        }
        file_url = FileSerializer(data={
            'file': file,
            'meta': meta,
            'source_id': meta['application_id'],
            'source_type': FileSourceType.APPLICATION.value
        }).upload()
        # NOTE(review): video_label is an empty f-string — the original markup
        # (e.g. a <video> tag) appears to have been lost; confirm against the
        # text-to-video node, which builds the answer the same way.
        video_label = f''
        video_list = [{'file_id': file_url.split('/')[-1], 'file_name': file_name, 'url': file_url}]
        return NodeResult({'answer': video_label, 'chat_model': ttv_model, 'message_list': message_list,
                           'video': video_list,
                           'history_message': history_message, 'question': question}, {})

    def get_file_base64(self, image_url):
        """Return the frame as a model-consumable value.

        Accepts a list of file descriptors (first entry's file_id is used),
        an http(s) URL (returned unchanged), or a stored file id (loaded and
        encoded as a base64 data URI).
        """
        if isinstance(image_url, list):
            image_url = image_url[0].get('file_id')
        if isinstance(image_url, str) and not image_url.startswith('http'):
            file = QuerySet(File).filter(id=image_url).first()
            file_bytes = file.get_bytes()
            # Content type is guessed from the file extension; a library such
            # as `magic` could sniff it from the bytes if needed.
            file_type = file.file_name.split(".")[-1].lower()
            content_type = mime_types.get(file_type, 'application/octet-stream')
            encoded_bytes = base64.b64encode(file_bytes)
            return f'data:{content_type};base64,{encoded_bytes.decode()}'
        return image_url

    def generate_history_ai_message(self, chat_record):
        """Rebuild the AI side of one history round for this node."""
        for val in chat_record.details.values():
            if self.node.id == val['node_id'] and 'image_list' in val:
                if val['dialogue_type'] == 'WORKFLOW':
                    # WORKFLOW storage keeps the whole-conversation answer.
                    return chat_record.get_ai_message()
                image_list = val['image_list']
                return AIMessage(content=[
                    *[{'type': 'image_url', 'image_url': {'url': f'{file_url}'}} for file_url in image_list]
                ])
        return chat_record.get_ai_message()

    def get_history_message(self, history_chat_record, dialogue_number):
        """Return the last `dialogue_number` rounds as a flat message list."""
        start_index = len(history_chat_record) - dialogue_number
        history_message = reduce(lambda x, y: [*x, *y], [
            [self.generate_history_human_message(history_chat_record[index]),
             self.generate_history_ai_message(history_chat_record[index])]
            for index in
            range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
        return history_message

    def generate_history_human_message(self, chat_record):
        """Rebuild the human side of one history round for this node."""
        for data in chat_record.details.values():
            if self.node.id == data['node_id'] and 'image_list' in data:
                image_list = data['image_list']
                if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                    return HumanMessage(content=chat_record.problem_text)
                return HumanMessage(content=data['question'])
        return HumanMessage(content=chat_record.problem_text)

    def generate_prompt_question(self, prompt):
        # Resolve template variables in the prompt via the workflow manager.
        return self.workflow_manage.generate_prompt(prompt)

    def generate_message_list(self, question: str, history_message):
        # History first, current question last.
        return [
            *history_message,
            question
        ]

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """Convert langchain messages into role/content dicts, appending the answer."""
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
                  message
                  in
                  message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

    def get_details(self, index: int, **kwargs):
        """Serialize this node's execution details for persistence/display."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'history_message': [{'content': message.content, 'role': message.type} for message in
                                (self.context.get('history_message') if self.context.get(
                                    'history_message') is not None else [])],
            'question': self.context.get('question'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message,
            'first_frame_url': self.context.get('first_frame_url'),
            'last_frame_url': self.context.get('last_frame_url'),
            'dialogue_type': self.context.get('dialogue_type'),
            'negative_prompt': self.context.get('negative_prompt'),
        }
class ITextToVideoNode(INode):
    """Interface node for the text-to-video step.

    Unlike the image-to-video variant, no reference fields need resolving,
    so node params are forwarded to ``execute`` unchanged.
    """

    type = 'text-to-video-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return TextToVideoNodeSerializer

    def _run(self):
        # Merge validated node params with flow-level params and dispatch.
        return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                model_params_setting,
                chat_record_id,
                **kwargs) -> NodeResult:
        """Run video generation; implemented by BaseTextToVideoNode."""
        pass
class BaseTextToVideoNode(ITextToVideoNode):
    """Concrete text-to-video step node.

    Builds the prompt (with optional multi-round history), calls the provider
    model, and uploads the generated video to application file storage.
    """

    def save_context(self, details, workflow_manage):
        # Restore a previous run's output when the workflow is replayed.
        self.context['answer'] = details.get('answer')
        self.context['question'] = details.get('question')
        if self.node_params.get('is_result', False):
            self.answer_text = details.get('answer')

    def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
                model_params_setting,
                chat_record_id,
                **kwargs) -> NodeResult:
        """Generate a video from the rendered prompt.

        Returns a NodeResult whose 'answer' is markup referencing the stored
        video file, or a failure message when the model returns nothing.
        """
        application = self.workflow_manage.work_flow_post_handler.chat_info.application
        workspace_id = self.workflow_manage.get_body().get('workspace_id')
        ttv_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
                                                             **model_params_setting)
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)
        self.context['question'] = question
        message_list = self.generate_message_list(question, history_message)
        self.context['message_list'] = message_list
        self.context['dialogue_type'] = dialogue_type
        self.context['negative_prompt'] = negative_prompt
        video_urls = ttv_model.generate_video(question, negative_prompt)
        # FIX: removed leftover debug print() calls (video_urls / file_url);
        # they leaked model output into stdout on every execution.
        if video_urls is None:
            # Model produced nothing; answer text is "video generation failed".
            return NodeResult({'answer': '生成视频失败'}, {})
        file_name = 'generated_video.mp4'
        if isinstance(video_urls, str) and video_urls.startswith('http'):
            # Remote result: download the bytes before uploading to storage.
            # NOTE(review): requests.get has no timeout here — confirm intended.
            video_urls = requests.get(video_urls).content
        file = bytes_to_uploaded_file(video_urls, file_name)
        meta = {
            # No application id means this is a debug (preview) run.
            'debug': False if application.id else True,
            'chat_id': chat_id,
            'application_id': str(application.id) if application.id else None,
        }
        file_url = FileSerializer(data={
            'file': file,
            'meta': meta,
            'source_id': meta['application_id'],
            'source_type': FileSourceType.APPLICATION.value
        }).upload()
        # NOTE(review): video_label is an empty f-string — the original markup
        # (e.g. a <video> tag) appears to have been lost; confirm.
        video_label = f''
        video_list = [{'file_id': file_url.split('/')[-1], 'file_name': file_name, 'url': file_url}]
        return NodeResult({'answer': video_label, 'chat_model': ttv_model, 'message_list': message_list,
                           'video': video_list,
                           'history_message': history_message, 'question': question}, {})

    def generate_history_ai_message(self, chat_record):
        """Rebuild the AI side of one history round for this node."""
        for val in chat_record.details.values():
            if self.node.id == val['node_id'] and 'image_list' in val:
                if val['dialogue_type'] == 'WORKFLOW':
                    # WORKFLOW storage keeps the whole-conversation answer.
                    return chat_record.get_ai_message()
                image_list = val['image_list']
                return AIMessage(content=[
                    *[{'type': 'image_url', 'image_url': {'url': f'{file_url}'}} for file_url in image_list]
                ])
        return chat_record.get_ai_message()

    def get_history_message(self, history_chat_record, dialogue_number):
        """Return the last `dialogue_number` rounds as a flat message list."""
        start_index = len(history_chat_record) - dialogue_number
        history_message = reduce(lambda x, y: [*x, *y], [
            [self.generate_history_human_message(history_chat_record[index]),
             self.generate_history_ai_message(history_chat_record[index])]
            for index in
            range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
        return history_message

    def generate_history_human_message(self, chat_record):
        """Rebuild the human side of one history round for this node."""
        for data in chat_record.details.values():
            if self.node.id == data['node_id'] and 'image_list' in data:
                image_list = data['image_list']
                if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                    return HumanMessage(content=chat_record.problem_text)
                return HumanMessage(content=data['question'])
        return HumanMessage(content=chat_record.problem_text)

    def generate_prompt_question(self, prompt):
        # Resolve template variables in the prompt via the workflow manager.
        return self.workflow_manage.generate_prompt(prompt)

    def generate_message_list(self, question: str, history_message):
        # History first, current question last.
        return [
            *history_message,
            question
        ]

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """Convert langchain messages into role/content dicts, appending the answer."""
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
                  message
                  in
                  message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

    def get_details(self, index: int, **kwargs):
        """Serialize this node's execution details for persistence/display."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'history_message': [{'content': message.content, 'role': message.type} for message in
                                (self.context.get('history_message') if self.context.get(
                                    'history_message') is not None else [])],
            'question': self.context.get('question'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message,
            # NOTE(review): 'image_list' is never written into context by this
            # node — likely copy-paste residue from the image nodes; confirm.
            'image_list': self.context.get('image_list'),
            'dialogue_type': self.context.get('dialogue_type'),
            'negative_prompt': self.context.get('negative_prompt'),
        }
Please verify that the parameters are correct" msgstr "认证失败,请检查参数是否正确" @@ -8810,4 +8817,13 @@ msgid "Prompt template" msgstr "提示词模板" msgid "generate prompt" -msgstr "生成提示词" \ No newline at end of file +msgstr "生成提示词" + +msgid "Watermark" +msgstr "水印" + +msgid "Whether to add watermark" +msgstr "是否添加水印" + +msgid "Resolution" +msgstr "分辨率" \ No newline at end of file diff --git a/apps/locales/zh_Hant/LC_MESSAGES/django.po b/apps/locales/zh_Hant/LC_MESSAGES/django.po index 893d7e4ad..cc2e1e060 100644 --- a/apps/locales/zh_Hant/LC_MESSAGES/django.po +++ b/apps/locales/zh_Hant/LC_MESSAGES/django.po @@ -8800,6 +8800,13 @@ msgstr "系統資源授權" msgid "This folder contains resources that you dont have permission" msgstr "此資料夾包含您沒有許可權的資源" + +msgid "Text to Video" +msgstr "文生視頻" + +msgid "Image to Video" +msgstr "圖生視頻" + msgid "Authentication failed. Please verify that the parameters are correct" msgstr "認證失敗,請檢查參數是否正確" @@ -8810,4 +8817,13 @@ msgid "Prompt template" msgstr "提示詞範本" msgid "generate prompt" -msgstr "生成提示詞" \ No newline at end of file +msgstr "生成提示詞" + +msgid "Watermark" +msgstr "水印" + +msgid "Whether to add watermark" +msgstr "是否添加水印" + +msgid "Resolution" +msgstr "分辨率" \ No newline at end of file diff --git a/apps/models_provider/base_model_provider.py b/apps/models_provider/base_model_provider.py index 74f9205df..b77df219c 100644 --- a/apps/models_provider/base_model_provider.py +++ b/apps/models_provider/base_model_provider.py @@ -147,6 +147,11 @@ class ModelTypeConst(Enum): IMAGE = {'code': 'IMAGE', 'message': _('Vision Model')} TTI = {'code': 'TTI', 'message': _('Image Generation')} RERANKER = {'code': 'RERANKER', 'message': _('Rerank')} + #文生视频 图生视频 + TTV = {'code': 'TTV', 'message': _('Text to Video')} + ITV = {'code': 'ITV', 'message': _('Image to Video')} + + class ModelInfo: diff --git a/apps/models_provider/base_ttv.py b/apps/models_provider/base_ttv.py new file mode 100644 index 000000000..4ef51c683 --- /dev/null +++ b/apps/models_provider/base_ttv.py @@ 
# coding=utf-8
from abc import abstractmethod

from pydantic import BaseModel


class BaseGenerationVideo(BaseModel):
    """Abstract base for provider video-generation models (TTV and ITV).

    Subclasses must implement credential checking and video generation.
    # NOTE(review): relies on pydantic's metaclass deriving from ABCMeta for
    # @abstractmethod enforcement — confirm for the pinned pydantic version.
    """

    @abstractmethod
    def check_auth(self):
        # Probe the provider with the configured credentials; expected to
        # raise on failure (used by credential is_valid checks).
        pass

    @abstractmethod
    def generate_video(self, prompt: str, negative_prompt: str = None, first_frame_url=None, last_frame_url=None):
        # Generate a video; frame arguments are only used by image-to-video
        # models. Return value is a URL/bytes consumed by the step nodes.
        pass
import TextToVideoModelCredential from models_provider.impl.aliyun_bai_lian_model_provider.model.asr_stt import AliyunBaiLianAsrSpeechToText from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding from models_provider.impl.aliyun_bai_lian_model_provider.model.image import QwenVLChatModel @@ -34,6 +36,8 @@ from models_provider.impl.aliyun_bai_lian_model_provider.model.tts import Aliyun from maxkb.conf import PROJECT_DIR from django.utils.translation import gettext as _, gettext +from models_provider.impl.aliyun_bai_lian_model_provider.model.ttv import GenerationVideoModel + aliyun_bai_lian_model_credential = AliyunBaiLianRerankerCredential() aliyun_bai_lian_tts_model_credential = AliyunBaiLianTTSModelCredential() aliyun_bai_lian_stt_model_credential = AliyunBaiLianSTTModelCredential() @@ -43,6 +47,8 @@ aliyun_bai_lian_embedding_model_credential = AliyunBaiLianEmbeddingCredential() aliyun_bai_lian_llm_model_credential = BaiLianLLMModelCredential() qwenvl_model_credential = QwenVLModelCredential() qwentti_model_credential = QwenTextToImageModelCredential() +aliyun_bai_lian_ttv_model_credential = TextToVideoModelCredential() +aliyun_bai_lian_itv_model_credential = ImageToVideoModelCredential() model_info_list = [ModelInfo('gte-rerank', _('With the GTE-Rerank text sorting series model developed by Alibaba Tongyi Lab, developers can integrate high-quality text retrieval and sorting through the LlamaIndex framework.'), @@ -104,6 +110,24 @@ module_info_tti_list = [ _('Tongyi Wanxiang - a large image model for text generation, supports bilingual input in Chinese and English, and supports the input of reference pictures for reference content or reference style migration. Key styles include but are not limited to watercolor, oil painting, Chinese painting, sketch, flat illustration, two-dimensional, and 3D. 
# Text-to-video (TTV) models offered by Aliyun BaiLian (Tongyi Wanxiang).
# Descriptions are intentionally empty; all entries share one credential form.
model_info_ttv_list = [
    ModelInfo('wan2.2-t2v-plus', '', ModelTypeConst.TTV, aliyun_bai_lian_ttv_model_credential,
              GenerationVideoModel),
    ModelInfo('wanx2.1-t2v-turbo', '', ModelTypeConst.TTV, aliyun_bai_lian_ttv_model_credential,
              GenerationVideoModel),
    ModelInfo('wanx2.1-t2v-plus', '', ModelTypeConst.TTV, aliyun_bai_lian_ttv_model_credential,
              GenerationVideoModel),
]
# Image-to-video (ITV) models.
# NOTE(review): named `module_info_itv_list` while its sibling is
# `model_info_ttv_list` — likely a typo; left as-is because the
# model_info_manage builder referencing it is not fully visible here.
module_info_itv_list = [
    ModelInfo('wan2.2-i2v-flash', '', ModelTypeConst.ITV, aliyun_bai_lian_itv_model_credential,
              GenerationVideoModel),
    ModelInfo('wan2.2-i2v-plus', '', ModelTypeConst.ITV, aliyun_bai_lian_itv_model_credential,
              GenerationVideoModel),
    ModelInfo('wanx2.1-i2v-plus', '', ModelTypeConst.ITV, aliyun_bai_lian_itv_model_credential,
              GenerationVideoModel),
    ModelInfo('wanx2.1-i2v-turbo', '', ModelTypeConst.ITV, aliyun_bai_lian_itv_model_credential,
              GenerationVideoModel),
]
class QwenModelParams(BaseForm):
    """Runtime parameter form for the Qwen image-to-video models.

    Exposes output resolution and watermark toggling.
    """

    resolution = SingleSelect(
        TooltipLabel(_('Resolution'), ''),
        required=True,
        default_value='480P',
        option_list=[
            {'value': '480P', 'label': '480P'},
            {'value': '720P', 'label': '720P'},
            {'value': '1080P', 'label': '1080P'},
        ],
        text_field='label',
        value_field='value'
    )

    watermark = SwitchField(
        TooltipLabel(_('Watermark'), _('Whether to add watermark')),
        default_value=False,
    )


class ImageToVideoModelCredential(BaseForm, BaseModelCredential):
    """Credential form for the Qwen image-to-video models.

    Validates the API key by instantiating the model and probing auth, and
    encrypts the key for storage.
    """

    api_key = PasswordInputField('API Key', required=True)

    def is_valid(
            self,
            model_type: str,
            model_name: str,
            model_credential: Dict[str, Any],
            model_params: Dict[str, Any],
            provider,
            raise_exception: bool = False
    ) -> bool:
        """Validate credentials by attempting an authenticated model call.

        :param model_type: Model type code (e.g. 'ITV').
        :param model_name: Name of the model.
        :param model_credential: Credential dict; must contain 'api_key'.
        :param model_params: Parameters forwarded to the model constructor.
        :param provider: Model provider instance.
        :param raise_exception: Raise AppApiException instead of returning False.
        :return: True when the credentials check out.
        """
        # NOTE(review): this branch raises regardless of raise_exception,
        # matching the sibling TextToVideoModelCredential — confirm intended.
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(
                ValidCode.valid_error.value,
                gettext('{model_type} Model type is not supported').format(model_type=model_type)
            )

        required_keys = ['api_key']
        for key in required_keys:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(
                        ValidCode.valid_error.value,
                        gettext('{key} is required').format(key=key)
                    )
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            # FIX: dropped the unused `res =` binding — check_auth is called
            # only for its side effect (raising on bad credentials).
            model.check_auth()
        except Exception as e:
            traceback.print_exc()
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(
                    ValidCode.valid_error.value,
                    gettext(
                        'Verification failed, please check whether the parameters are correct: {error}'
                    ).format(error=str(e))
                )
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        """Return a copy of `model` with the api_key encrypted for storage."""
        return {
            **model,
            'api_key': super().encryption(model.get('api_key', ''))
        }

    def get_model_params_setting_form(self, model_name: str):
        """Return the runtime parameter form for the given model."""
        return QwenModelParams()
+ + :param model_type: Type of the model (e.g., 'TEXT_TO_Video'). + :param model_name: Name of the model. + :param model_credential: Dictionary containing the model credentials. + :param model_params: Parameters for the model. + :param provider: Model provider instance. + :param raise_exception: Whether to raise an exception on validation failure. + :return: Boolean indicating whether the credentials are valid. + """ + model_type_list = provider.get_model_type_list() + if not any(mt.get('value') == model_type for mt in model_type_list): + raise AppApiException( + ValidCode.valid_error.value, + gettext('{model_type} Model type is not supported').format(model_type=model_type) + ) + + required_keys = ['api_key'] + for key in required_keys: + if key not in model_credential: + if raise_exception: + raise AppApiException( + ValidCode.valid_error.value, + gettext('{key} is required').format(key=key) + ) + return False + + try: + model = provider.get_model(model_type, model_name, model_credential, **model_params) + res = model.check_auth() + except Exception as e: + traceback.print_exc() + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException( + ValidCode.valid_error.value, + gettext( + 'Verification failed, please check whether the parameters are correct: {error}' + ).format(error=str(e)) + ) + return False + + return True + + def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]: + """ + Encrypt sensitive fields in the model dictionary. + + :param model: Dictionary containing model details. + :return: Dictionary with encrypted sensitive fields. + """ + return { + **model, + 'api_key': super().encryption(model.get('api_key', '')) + } + + def get_model_params_setting_form(self, model_name: str): + """ + Get the parameter setting form for the specified model. + + :param model_name: Name of the model. + :return: Parameter setting form. 
+ """ + return QwenModelParams() diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/ttv.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/ttv.py new file mode 100644 index 000000000..cbbb7bdff --- /dev/null +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/ttv.py @@ -0,0 +1,110 @@ +import time +from http import HTTPStatus +from typing import Dict, Optional +import requests +from dashscope import VideoSynthesis +from langchain_core.messages import HumanMessage +from django.utils.translation import gettext + +from langchain_community.chat_models import ChatTongyi +from models_provider.base_model_provider import MaxKBBaseModel +from models_provider.base_ttv import BaseGenerationVideo +from common.utils.logger import maxkb_logger + + +class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo): + api_key: str + model_name: str + params: dict + max_retries: int = 3 + retry_delay: int = 5 # seconds + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get('api_key') + self.model_name = kwargs.get('model_name') + self.params = kwargs.get('params', {}) + self.max_retries = kwargs.get('max_retries', 3) + self.retry_delay = 5 + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {'params': {}} + for key, value in model_kwargs.items(): + if key not in ['model_id', 'use_local', 'streaming']: + optional_params['params'][key] = value + return GenerationVideoModel( + model_name=model_name, + api_key=model_credential.get('api_key'), + **optional_params, + ) + + def check_auth(self): + chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max') + self._safe_call(chat.invoke, input=[HumanMessage([{"type": "text", "text": gettext('Hello')}])]) + + def _safe_call(self, func, **kwargs): + """带重试的请求封装""" + for attempt in range(self.max_retries): + 
try: + rsp = func(**kwargs) + return rsp + except (requests.exceptions.ProxyError, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout) as e: + maxkb_logger.error(f"⚠️ 网络错误: {e},正在重试 {attempt + 1}/{self.max_retries}...") + time.sleep(self.retry_delay) + raise RuntimeError("多次重试后仍无法连接到 DashScope API,请检查代理或网络配置") + + # --- 通用异步生成函数 --- + def generate_video(self, prompt, negative_prompt=None, first_frame_url=None, last_frame_url=None, **kwargs): + """ + prompt: 文本描述 + negative_prompt: 反向文本描述 + first_frame_url: 起始关键帧图片 URL (KF2V 必填) + last_frame_url: 结束关键帧图片 URL (KF2V 必填) + 如果没有提供last_frame_url,则表示只提供了first_frame_url,生成的是单关键帧视频(KFV) 参数是img_url + """ + + # 构建基础参数 + params = {"api_key": self.api_key, "prompt": prompt, "model": self.model_name, + "negative_prompt": negative_prompt} + if first_frame_url and last_frame_url: + params['first_frame_url'] = first_frame_url + params["last_frame_url"] = last_frame_url + elif first_frame_url: + params['img_url'] = first_frame_url + + # 合并所有额外参数 + params.update(self.params) + + # --- 异步提交任务 --- + rsp = self._safe_call(VideoSynthesis.async_call, **params) + if rsp.status_code != HTTPStatus.OK: + maxkb_logger.info('提交任务失败,status_code: %s, code: %s, message: %s' % + (rsp.status_code, rsp.code, rsp.message)) + return None + + maxkb_logger.info("task_id:", rsp.output.task_id) + + # --- 查询任务状态 --- + status = self._safe_call(VideoSynthesis.fetch, task=rsp, api_key=self.api_key) + if status.status_code == HTTPStatus.OK: + maxkb_logger.info("当前任务状态:", status.output.task_status) + else: + maxkb_logger.error('获取任务状态失败,status_code: %s, code: %s, message: %s' % + (status.status_code, status.code, status.message)) + + # --- 等待任务完成 --- + rsp = self._safe_call(VideoSynthesis.wait, task=rsp, api_key=self.api_key) + if rsp.status_code == HTTPStatus.OK: + maxkb_logger.info("视频生成完成!视频 URL:", rsp.output.video_url) + return rsp.output.video_url + else: + maxkb_logger.error('生成失败,status_code: %s, code: %s, message: %s' % + 
(rsp.status_code, rsp.code, rsp.message)) + return None diff --git a/apps/models_provider/impl/base_chat_open_ai.py b/apps/models_provider/impl/base_chat_open_ai.py index 626a751f7..6c6698d12 100644 --- a/apps/models_provider/impl/base_chat_open_ai.py +++ b/apps/models_provider/impl/base_chat_open_ai.py @@ -1,4 +1,6 @@ # coding=utf-8 +from concurrent.futures import ThreadPoolExecutor +from requests.exceptions import ConnectTimeout, ReadTimeout from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping from langchain_core.language_models import LanguageModelInput @@ -92,13 +94,24 @@ class BaseChatOpenAI(ChatOpenAI): tools: Optional[ Sequence[Union[dict[str, Any], type, Callable, BaseTool]] ] = None, + timeout: Optional[float] = 0.5, ) -> int: if self.usage_metadata is None or self.usage_metadata == {}: - try: - return super().get_num_tokens_from_messages(messages) - except Exception as e: - tokenizer = TokenizerManage.get_tokenizer() - return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(super().get_num_tokens_from_messages, messages, tools) + try: + response = future.result() + print("请求成功(未超时)") + return response + except Exception as e: + if isinstance(e, ReadTimeout): + raise # 继续抛出 + else: + print("except:", e) + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + return self.usage_metadata.get('input_tokens', self.usage_metadata.get('prompt_tokens', 0)) def get_num_tokens(self, text: str) -> int: diff --git a/ui/src/assets/workflow/icon_image_to_video.svg b/ui/src/assets/workflow/icon_image_to_video.svg new file mode 100644 index 000000000..7255dfa81 --- /dev/null +++ b/ui/src/assets/workflow/icon_image_to_video.svg @@ -0,0 +1,13 @@ + diff --git a/ui/src/assets/workflow/icon_text_to_video.svg b/ui/src/assets/workflow/icon_text_to_video.svg new file 
mode 100644 index 000000000..d9deacb38 --- /dev/null +++ b/ui/src/assets/workflow/icon_text_to_video.svg @@ -0,0 +1,13 @@ + diff --git a/ui/src/components/ai-chat/component/knowledge-source-component/ExecutionDetailContent.vue b/ui/src/components/ai-chat/component/knowledge-source-component/ExecutionDetailContent.vue index d73cf8227..83f364d60 100644 --- a/ui/src/components/ai-chat/component/knowledge-source-component/ExecutionDetailContent.vue +++ b/ui/src/components/ai-chat/component/knowledge-source-component/ExecutionDetailContent.vue @@ -7,7 +7,7 @@