diff --git a/apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py b/apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py index e99808520..d294058c1 100644 --- a/apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py +++ b/apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py @@ -13,7 +13,7 @@ from common.utils.common import bytes_to_uploaded_file from knowledge.models import FileSourceType, File from oss.serializers.file import FileSerializer, mime_types from models_provider.tools import get_model_instance_by_model_workspace_id - +from django.utils.translation import gettext class BaseImageToVideoNode(IImageToVideoNode): def save_context(self, details, workflow_manage): @@ -48,7 +48,7 @@ class BaseImageToVideoNode(IImageToVideoNode): video_urls = ttv_model.generate_video(question, negative_prompt, first_frame_url, last_frame_url) # 保存图片 if video_urls is None: - return NodeResult({'answer': '生成视频失败'}, {}) + return NodeResult({'answer': gettext('Failed to generate video')}, {}) file_name = 'generated_video.mp4' if isinstance(video_urls, str) and video_urls.startswith('http'): video_urls = requests.get(video_urls).content diff --git a/apps/application/flow/step_node/text_to_video_step_node/impl/base_text_to_video_node.py b/apps/application/flow/step_node/text_to_video_step_node/impl/base_text_to_video_node.py index 2beb69ce1..0fa97a07c 100644 --- a/apps/application/flow/step_node/text_to_video_step_node/impl/base_text_to_video_node.py +++ b/apps/application/flow/step_node/text_to_video_step_node/impl/base_text_to_video_node.py @@ -11,7 +11,7 @@ from common.utils.common import bytes_to_uploaded_file from knowledge.models import FileSourceType from oss.serializers.file import FileSerializer from models_provider.tools import get_model_instance_by_model_workspace_id - +from django.utils.translation import gettext class BaseTextToVideoNode(ITextToVideoNode): def save_context(self, details, workflow_manage): @@ -40,7 +40,7 @@ class BaseTextToVideoNode(ITextToVideoNode): print('video_urls', video_urls) # 保存图片 if video_urls is None: - return NodeResult({'answer': '生成视频失败'}, {}) + return NodeResult({'answer': gettext('Failed to generate video')}, {}) file_name = 'generated_video.mp4' if isinstance(video_urls, str) and video_urls.startswith('http'): video_urls = requests.get(video_urls).content @@ -56,7 +56,6 @@ class BaseTextToVideoNode(ITextToVideoNode): 'source_id': meta['application_id'], 'source_type': FileSourceType.APPLICATION.value }).upload() - print('file_url', file_url) video_label = f'' video_list = [{'file_id': file_url.split('/')[-1], 'file_name': file_name, 'url': file_url}] return NodeResult({'answer': video_label, 'chat_model': ttv_model, 'message_list': message_list, diff --git a/apps/locales/en_US/LC_MESSAGES/django.po b/apps/locales/en_US/LC_MESSAGES/django.po index 37be5a3f7..364419723 100644 --- a/apps/locales/en_US/LC_MESSAGES/django.po +++ b/apps/locales/en_US/LC_MESSAGES/django.po @@ -8700,4 +8700,13 @@ msgid "Whether to add watermark" msgstr "" msgid "Resolution" +msgstr "" + +msgid "Ratio" +msgstr "" + +msgid "Duration" +msgstr "" + +msgid "Failed to generate video" msgstr "" \ No newline at end of file diff --git a/apps/locales/zh_CN/LC_MESSAGES/django.po b/apps/locales/zh_CN/LC_MESSAGES/django.po index 0497e828a..dc1046dcb 100644 --- a/apps/locales/zh_CN/LC_MESSAGES/django.po +++ b/apps/locales/zh_CN/LC_MESSAGES/django.po @@ -8826,4 +8826,13 @@ msgid "Whether to add watermark" msgstr "是否添加水印" msgid "Resolution" -msgstr "分辨率" \ No newline at end of file +msgstr "分辨率" + +msgid "Ratio" +msgstr "比例" + +msgid "Duration" +msgstr "时长" + +msgid "Failed to generate video" +msgstr "生成视频失败" \ No newline at end of file diff --git a/apps/locales/zh_Hant/LC_MESSAGES/django.po b/apps/locales/zh_Hant/LC_MESSAGES/django.po index cc2e1e060..c7c0d9c98 100644 --- a/apps/locales/zh_Hant/LC_MESSAGES/django.po +++ b/apps/locales/zh_Hant/LC_MESSAGES/django.po @@ -8826,4 +8826,13 @@ msgid "Whether to add watermark" msgstr "是否添加水印" msgid "Resolution" -msgstr "分辨率" \ No newline at end of file +msgstr "分辨率" + +msgid "Ratio" +msgstr "比例" + +msgid "Duration" +msgstr "時長" + +msgid "Failed to generate video" +msgstr "生成視頻失敗" \ No newline at end of file diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/credential/ttv.py b/apps/models_provider/impl/volcanic_engine_model_provider/credential/ttv.py new file mode 100644 index 000000000..dec802a94 --- /dev/null +++ b/apps/models_provider/impl/volcanic_engine_model_provider/credential/ttv.py @@ -0,0 +1,89 @@ +# coding=utf-8 +import traceback +from typing import Dict + +from django.utils.translation import gettext_lazy as _, gettext + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel, SingleSelect, TextInputField +from common.forms.switch_field import SwitchField +from models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class VolcanicEngineTTVModelGeneralParams(BaseForm): + resolution = SingleSelect( + TooltipLabel(_('Resolution'), _('Resolution')), + required=True, + default_value='480P', + option_list=[ + {'value': '480P', 'label': '480P'}, + {'value': '720P', 'label': '720P'}, + {'value': '1080P', 'label': '1080P'}, + ], + text_field='label', + value_field='value' + ) + ratio = SingleSelect( + TooltipLabel(_('Ratio'), _('Ratio')), + required=True, + default_value='16:9', + option_list=[ + {'value': '16:9', 'label': '16:9'}, + {'value': '9:16', 'label': '9:16'}, + {'value': '1:1', 'label': '1:1'}, + {'value': '4:3', 'label': '4:3'}, + {'value': '3:4', 'label': '3:4'}, + {'value': '21:9', 'label': '21:9'}, + ], + text_field='label', + value_field='value' + ) + duration = TextInputField( + TooltipLabel(_('Duration'), _('Duration')), + required=True, + default_value=5, + ) + + watermark = SwitchField( + TooltipLabel(_('Watermark'), _('Whether to add watermark')), + default_value=False, + ) + + +class VolcanicEngineTTVModelCredential(BaseForm, BaseModelCredential): + api_key = forms.PasswordInputField('Api key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, + gettext('{model_type} Model type is not supported').format(model_type=model_type)) + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key)) + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential, **model_params) + model.check_auth() + except Exception as e: + traceback.print_exc() + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, gettext( + 'Verification failed, please check whether the parameters are correct: {error}').format( + error=str(e))) + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + return VolcanicEngineTTVModelGeneralParams() diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/model/ttv.py b/apps/models_provider/impl/volcanic_engine_model_provider/model/ttv.py new file mode 100644 index 000000000..92200afe0 --- /dev/null +++ b/apps/models_provider/impl/volcanic_engine_model_provider/model/ttv.py @@ -0,0 +1,119 @@ +import base64 +import time +from typing import Dict, Optional +from models_provider.base_model_provider import MaxKBBaseModel +from models_provider.base_ttv import BaseGenerationVideo +from common.utils.logger import maxkb_logger +from volcenginesdkarkruntime import Ark + + +class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo): + api_key: str + model_name: str + params: dict + max_retries: int = 3 + retry_delay: int = 5 # seconds + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get('api_key') + self.model_name = kwargs.get('model_name') + self.params = kwargs.get('params', {}) + self.retry_delay = 5 + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {'params': {}} + for key, value in model_kwargs.items(): + if key not in ['model_id', 'use_local', 'streaming']: + optional_params['params'][key] = value + return GenerationVideoModel( + model_name=model_name, + api_key=model_credential.get('api_key'), + **optional_params, + ) + + def check_auth(self): + return True + + def _build_prompt(self, prompt: str) -> str: + """拼接参数到 prompt 文本""" + param_map = { + "ratio": "rt", + "duration": "dur", + "framespersecond": "fps", + "resolution": "rs", + "watermark": "wm", + "camerafixed": "cf", + } + for key, value in self.params.items(): + if key in param_map: + prompt += f" --{param_map[key]} {value}" + return prompt + + def _poll_task(self, client: Ark, task_id: str, max_wait: int = 60, interval: int = 5): + """轮询任务状态,直到完成或超时""" + elapsed = 0 + while elapsed < max_wait: + result = client.content_generation.tasks.get(task_id=task_id) + status = getattr(result, "status", None) + maxkb_logger.info(f"[ArkVideo] Task {task_id} status={status}") + + if status in ("succeeded", "failed", "cancelled"): + return result + + time.sleep(interval) + elapsed += interval + maxkb_logger.warning(f"[ArkVideo] Task {task_id} wait timeout") + return None + + # --- 通用异步生成函数 --- + def generate_video(self, prompt, negative_prompt=None, first_frame_url=None, last_frame_url=None, **kwargs): + client = Ark(api_key=self.api_key) + # 根据params设置其他参数 豆包的参数和别的不一样 需要拼接在text里 + # --rt 16:9 --dur 5 --fps 24 --rs 720p --wm true --cf false + prompt = self._build_prompt(prompt) + content = [{"type": "text", "text": prompt}] + + if first_frame_url: + content.append({ + "type": "image_url", + "image_url": { + "url": first_frame_url + }, + "role": "first_frame" + }) + if last_frame_url: + content.append({ + "type": "image_url", + "image_url": { + "url": last_frame_url + }, + "role": "last_frame" + }) + create_result = client.content_generation.tasks.create( + model=self.model_name, + content=content + ) + + task = client.content_generation.tasks.create(model=self.model_name, content=content) + task_id = task.id + maxkb_logger.info(f"[ArkVideo] Created task {task_id}") + + # 轮询获取结果 + result = self._poll_task(client, task_id) + if not result: + return {"status": "timeout", "task_id": task_id} + + try: + if getattr(result, "status", None) in ("succeeded", "failed", "cancelled"): + client.content_generation.tasks.delete(task_id=task_id) + maxkb_logger.info(f"[ArkVideo] Deleted task {task_id}") + except Exception as e: + maxkb_logger.error(f"[ArkVideo] Failed to delete task {task_id}: {e}") + maxkb_logger.info("视频地址", result.content.video_url) + return result.content.video_url diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py b/apps/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py index 5f9460d76..96937b3e6 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py +++ b/apps/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py @@ -17,6 +17,7 @@ from models_provider.impl.volcanic_engine_model_provider.credential.image import VolcanicEngineImageModelCredential from models_provider.impl.volcanic_engine_model_provider.credential.tti import VolcanicEngineTTIModelCredential from models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential +from models_provider.impl.volcanic_engine_model_provider.credential.ttv import VolcanicEngineTTVModelCredential from models_provider.impl.volcanic_engine_model_provider.model.embedding import VolcanicEngineEmbeddingModel from models_provider.impl.volcanic_engine_model_provider.model.image import VolcanicEngineImage from models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel @@ -28,6 +29,8 @@ from models_provider.impl.volcanic_engine_model_provider.model.tts import Volcan from maxkb.conf import PROJECT_DIR from django.utils.translation import gettext as _ +from models_provider.impl.volcanic_engine_model_provider.model.ttv import GenerationVideoModel + volcanic_engine_llm_model_credential = OpenAILLMModelCredential() volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential() volcanic_engine_tts_model_credential = VolcanicEngineTTSModelCredential() @@ -69,6 +72,45 @@ model_info_embedding_list = [ ModelTypeConst.EMBEDDING, open_ai_embedding_credential, VolcanicEngineEmbeddingModel) ] +ttv_credential = VolcanicEngineTTVModelCredential() +model_info_ttv_list = [ + ModelInfo('doubao-seedance-1-0-pro-250528', + _(''), + ModelTypeConst.TTV, + ttv_credential, GenerationVideoModel) + , + ModelInfo('doubao-seedance-1-0-lite-t2v-250428', + _(''), + ModelTypeConst.TTV, + ttv_credential, GenerationVideoModel) + , + ModelInfo('wan2-1-14b-t2v-250225', + _(''), + ModelTypeConst.TTV, + ttv_credential, GenerationVideoModel) +] +model_info_itv_list = [ + ModelInfo('doubao-seedance-1-0-pro-250528', + _(''), + ModelTypeConst.ITV, + ttv_credential, + GenerationVideoModel), + ModelInfo('doubao-seedance-1-0-lite-i2v-250428', + _(''), + ModelTypeConst.ITV, + ttv_credential, + GenerationVideoModel), + ModelInfo('wan2-1-14b-i2v-250225', + _(''), + ModelTypeConst.ITV, + ttv_credential, + GenerationVideoModel), + ModelInfo('wan2-1-14b-flf2v-250417', + _(''), + ModelTypeConst.ITV, + ttv_credential, + GenerationVideoModel), +] model_info_manage = ( ModelInfoManage.builder() @@ -80,6 +122,10 @@ model_info_manage = ( .append_default_model_info(model_info_list[4]) .append_model_info_list(model_info_embedding_list) .append_default_model_info(model_info_embedding_list[0]) + .append_model_info_list(model_info_ttv_list) + .append_default_model_info(model_info_ttv_list[0]) + .append_model_info_list(model_info_itv_list) + .append_default_model_info(model_info_itv_list[0]) .build() )