feat: add translation for video generation error messages and implement TTV model parameters

This commit is contained in:
wxg0103 2025-09-15 11:08:54 +08:00
parent b9b91925f0
commit f1cd675caa
8 changed files with 287 additions and 7 deletions

View File

@ -13,7 +13,7 @@ from common.utils.common import bytes_to_uploaded_file
from knowledge.models import FileSourceType, File
from oss.serializers.file import FileSerializer, mime_types
from models_provider.tools import get_model_instance_by_model_workspace_id
from django.utils.translation import gettext
class BaseImageToVideoNode(IImageToVideoNode):
def save_context(self, details, workflow_manage):
@ -48,7 +48,7 @@ class BaseImageToVideoNode(IImageToVideoNode):
video_urls = ttv_model.generate_video(question, negative_prompt, first_frame_url, last_frame_url)
# 保存图片
if video_urls is None:
return NodeResult({'answer': '生成视频失败'}, {})
return NodeResult({'answer': gettext('Failed to generate video')}, {})
file_name = 'generated_video.mp4'
if isinstance(video_urls, str) and video_urls.startswith('http'):
video_urls = requests.get(video_urls).content

View File

@ -11,7 +11,7 @@ from common.utils.common import bytes_to_uploaded_file
from knowledge.models import FileSourceType
from oss.serializers.file import FileSerializer
from models_provider.tools import get_model_instance_by_model_workspace_id
from django.utils.translation import gettext
class BaseTextToVideoNode(ITextToVideoNode):
def save_context(self, details, workflow_manage):
@ -40,7 +40,7 @@ class BaseTextToVideoNode(ITextToVideoNode):
print('video_urls', video_urls)
# 保存图片
if video_urls is None:
return NodeResult({'answer': '生成视频失败'}, {})
return NodeResult({'answer': gettext('Failed to generate video')}, {})
file_name = 'generated_video.mp4'
if isinstance(video_urls, str) and video_urls.startswith('http'):
video_urls = requests.get(video_urls).content
@ -56,7 +56,6 @@ class BaseTextToVideoNode(ITextToVideoNode):
'source_id': meta['application_id'],
'source_type': FileSourceType.APPLICATION.value
}).upload()
print('file_url', file_url)
video_label = f'<video src="{file_url}" controls style="max-width: 100%; width: 100%; height: auto;"></video>'
video_list = [{'file_id': file_url.split('/')[-1], 'file_name': file_name, 'url': file_url}]
return NodeResult({'answer': video_label, 'chat_model': ttv_model, 'message_list': message_list,

View File

@ -8700,4 +8700,13 @@ msgid "Whether to add watermark"
msgstr ""
msgid "Resolution"
msgstr ""
msgid "Ratio"
msgstr ""
msgid "Duration"
msgstr ""
msgid "Failed to generate video"
msgstr ""

View File

@ -8826,4 +8826,13 @@ msgid "Whether to add watermark"
msgstr "是否添加水印"
msgid "Resolution"
msgstr "分辨率"
msgstr "分辨率"
msgid "Ratio"
msgstr "比例"
msgid "Duration"
msgstr "时长"
msgid "Failed to generate video"
msgstr "生成视频失败"

View File

@ -8826,4 +8826,13 @@ msgid "Whether to add watermark"
msgstr "是否添加水印"
msgid "Resolution"
msgstr "分辨率"
msgstr "分辨率"
msgid "Ratio"
msgstr "比例"
msgid "Duration"
msgstr "時長"
msgid "Failed to generate video"
msgstr "生成視頻失敗"

View File

@ -0,0 +1,89 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel, SingleSelect, TextInputField
from common.forms.switch_field import SwitchField
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class VolcanicEngineTTVModelGeneralParams(BaseForm):
resolution = SingleSelect(
TooltipLabel(_('Resolution'), _('Resolution')),
required=True,
default_value='480P',
option_list=[
{'value': '480P', 'label': '480P'},
{'value': '720P', 'label': '720P'},
{'value': '1080P', 'label': '1080P'},
],
text_field='label',
value_field='value'
)
ratio = SingleSelect(
TooltipLabel(_('Ratio'), _('Ratio')),
required=True,
default_value='16:9',
option_list=[
{'value': '16:9', 'label': '16:9'},
{'value': '9:16', 'label': '9:16'},
{'value': '1:1', 'label': '1:1'},
{'value': '4:3', 'label': '4:3'},
{'value': '3:4', 'label': '3:4'},
{'value': '21:9', 'label': '21:9'},
],
text_field='label',
value_field='value'
)
duration = TextInputField(
TooltipLabel(_('Duration'), _('Duration')),
required=True,
default_value=5,
)
watermark = SwitchField(
TooltipLabel(_('Watermark'), _('Whether to add watermark')),
default_value=False,
)
class VolcanicEngineTTVModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('Api key', required=True)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value,
gettext('{model_type} Model type is not supported').format(model_type=model_type))
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, gettext(
'Verification failed, please check whether the parameters are correct: {error}').format(
error=str(e)))
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
def get_model_params_setting_form(self, model_name):
return VolcanicEngineTTVModelGeneralParams()

View File

@ -0,0 +1,119 @@
import base64
import time
from typing import Dict, Optional
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.base_ttv import BaseGenerationVideo
from common.utils.logger import maxkb_logger
from volcenginesdkarkruntime import Ark
class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo):
api_key: str
model_name: str
params: dict
max_retries: int = 3
retry_delay: int = 5 # seconds
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.model_name = kwargs.get('model_name')
self.params = kwargs.get('params', {})
self.retry_delay = 5
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'params': {}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
return GenerationVideoModel(
model_name=model_name,
api_key=model_credential.get('api_key'),
**optional_params,
)
def check_auth(self):
return True
def _build_prompt(self, prompt: str) -> str:
"""拼接参数到 prompt 文本"""
param_map = {
"ratio": "rt",
"duration": "dur",
"framespersecond": "fps",
"resolution": "rs",
"watermark": "wm",
"camerafixed": "cf",
}
for key, value in self.params.items():
if key in param_map:
prompt += f" --{param_map[key]} {value}"
return prompt
def _poll_task(self, client: Ark, task_id: str, max_wait: int = 60, interval: int = 5):
"""轮询任务状态,直到完成或超时"""
elapsed = 0
while elapsed < max_wait:
result = client.content_generation.tasks.get(task_id=task_id)
status = getattr(result, "status", None)
maxkb_logger.info(f"[ArkVideo] Task {task_id} status={status}")
if status in ("succeeded", "failed", "cancelled"):
return result
time.sleep(interval)
elapsed += interval
maxkb_logger.warning(f"[ArkVideo] Task {task_id} wait timeout")
return None
# --- 通用异步生成函数 ---
def generate_video(self, prompt, negative_prompt=None, first_frame_url=None, last_frame_url=None, **kwargs):
client = Ark(api_key=self.api_key)
# 根据params设置其他参数 豆包的参数和别的不一样 需要拼接在text里
# --rt 16:9 --dur 5 --fps 24 --rs 720p --wm true --cf false
prompt = self._build_prompt(prompt)
content = [{"type": "text", "text": prompt}]
if first_frame_url:
content.append({
"type": "image_url",
"image_url": {
"url": first_frame_url
},
"role": "first_frame"
})
if last_frame_url:
content.append({
"type": "image_url",
"image_url": {
"url": last_frame_url
},
"role": "last_frame"
})
create_result = client.content_generation.tasks.create(
model=self.model_name,
content=content
)
task = client.content_generation.tasks.create(model=self.model_name, content=content)
task_id = task.id
maxkb_logger.info(f"[ArkVideo] Created task {task_id}")
# 轮询获取结果
result = self._poll_task(client, task_id)
if not result:
return {"status": "timeout", "task_id": task_id}
try:
if getattr(result, "status", None) in ("succeeded", "failed", "cancelled"):
client.content_generation.tasks.delete(task_id=task_id)
maxkb_logger.info(f"[ArkVideo] Deleted task {task_id}")
except Exception as e:
maxkb_logger.error(f"[ArkVideo] Failed to delete task {task_id}: {e}")
maxkb_logger.info("视频地址", result.content.video_url)
return result.content.video_url

View File

@ -17,6 +17,7 @@ from models_provider.impl.volcanic_engine_model_provider.credential.image import
VolcanicEngineImageModelCredential
from models_provider.impl.volcanic_engine_model_provider.credential.tti import VolcanicEngineTTIModelCredential
from models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential
from models_provider.impl.volcanic_engine_model_provider.credential.ttv import VolcanicEngineTTVModelCredential
from models_provider.impl.volcanic_engine_model_provider.model.embedding import VolcanicEngineEmbeddingModel
from models_provider.impl.volcanic_engine_model_provider.model.image import VolcanicEngineImage
from models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
@ -28,6 +29,8 @@ from models_provider.impl.volcanic_engine_model_provider.model.tts import Volcan
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _
from models_provider.impl.volcanic_engine_model_provider.model.ttv import GenerationVideoModel
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential()
volcanic_engine_tts_model_credential = VolcanicEngineTTSModelCredential()
@ -69,6 +72,45 @@ model_info_embedding_list = [
ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
VolcanicEngineEmbeddingModel)
]
ttv_credential = VolcanicEngineTTVModelCredential()
model_info_ttv_list = [
ModelInfo('doubao-seedance-1-0-pro-250528',
_(''),
ModelTypeConst.TTV,
ttv_credential, GenerationVideoModel)
,
ModelInfo('doubao-seedance-1-0-lite-t2v-250428',
_(''),
ModelTypeConst.TTV,
ttv_credential, GenerationVideoModel)
,
ModelInfo('wan2-1-14b-t2v-250225',
_(''),
ModelTypeConst.TTV,
ttv_credential, GenerationVideoModel)
]
model_info_itv_list = [
ModelInfo('doubao-seedance-1-0-pro-250528',
_(''),
ModelTypeConst.ITV,
ttv_credential,
GenerationVideoModel),
ModelInfo('doubao-seedance-1-0-lite-i2v-250428',
_(''),
ModelTypeConst.ITV,
ttv_credential,
GenerationVideoModel),
ModelInfo('wan2-1-14b-i2v-250225',
_(''),
ModelTypeConst.ITV,
ttv_credential,
GenerationVideoModel),
ModelInfo('wan2-1-14b-flf2v-250417',
_(''),
ModelTypeConst.ITV,
ttv_credential,
GenerationVideoModel),
]
model_info_manage = (
ModelInfoManage.builder()
@ -80,6 +122,10 @@ model_info_manage = (
.append_default_model_info(model_info_list[4])
.append_model_info_list(model_info_embedding_list)
.append_default_model_info(model_info_embedding_list[0])
.append_model_info_list(model_info_ttv_list)
.append_default_model_info(model_info_ttv_list[0])
.append_model_info_list(model_info_itv_list)
.append_default_model_info(model_info_itv_list[0])
.build()
)