From 16088975faee258734f94e996c155bb7997a1a80 Mon Sep 17 00:00:00 2001
From: wxg0103 <727495428@qq.com>
Date: Wed, 9 Jul 2025 18:08:41 +0800
Subject: [PATCH] refactor: enhance text-to-speech processing by splitting
content into chunks and merging audio segments
---
.../impl/base_text_to_speech_node.py | 66 +++++++++++++++----
.../model/tts.py | 2 +
2 files changed, 56 insertions(+), 12 deletions(-)
diff --git a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py
index e673d7aaa..31ff6ba14 100644
--- a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py
+++ b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py
@@ -6,9 +6,11 @@ from django.core.files.uploadedfile import InMemoryUploadedFile
from application.flow.i_step_node import NodeResult
from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode
+from common.utils.common import _remove_empty_lines
from knowledge.models import FileSourceType
from models_provider.tools import get_model_instance_by_model_workspace_id
from oss.serializers.file import FileSerializer
+from pydub import AudioSegment
def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"):
@@ -41,32 +43,72 @@ class BaseTextToSpeechNode(ITextToSpeechNode):
     def execute(self, tts_model_id, chat_id,
                 content, model_params_setting=None,
-                **kwargs) -> NodeResult:
-        self.context['content'] = content
-        workspace_id = self.workflow_manage.get_body().get('workspace_id')
-        model = get_model_instance_by_model_workspace_id(tts_model_id, workspace_id,
-                                                         **model_params_setting)
-        audio_byte = model.text_to_speech(content)
-        # 需要把这个音频文件存储到数据库中
-        file_name = 'generated_audio.mp3'
-        file = bytes_to_uploaded_file(audio_byte, file_name)
+                max_length=1024, **kwargs) -> NodeResult:
+        """Convert ``content`` to speech and store the merged audio as one MP3.
+
+        The text is cut into slices of at most ``max_length`` characters; each
+        slice is synthesized on its own and the decoded audio is concatenated
+        before the upload.
+
+        :param tts_model_id: id of the text-to-speech model to load.
+        :param chat_id: chat id recorded in the uploaded file's metadata.
+        :param content: text to synthesize, run through ``_remove_empty_lines``.
+        :param model_params_setting: extra model parameters; may be ``None``.
+        :param max_length: maximum number of characters per synthesized slice.
+        :return: ``NodeResult`` with the audio label as ``answer`` and a
+            one-element file list as ``result``.
+        """
+        content = _remove_empty_lines(content)
+        # Expose the whole text in the node context, not only the last chunk.
+        self.context['content'] = content
+        content_chunks = [content[i:i + max_length]
+                          for i in range(0, len(content), max_length)]
+
+        # The model is the same for every chunk, so load it once; a missing
+        # ``model_params_setting`` (the default) is treated as "no parameters".
+        workspace_id = self.workflow_manage.get_body().get('workspace_id')
+        model = get_model_instance_by_model_workspace_id(
+            tts_model_id, workspace_id, **(model_params_setting or {}))
+
+        # Synthesize each chunk and append the decoded audio to one track.
+        combined_audio = AudioSegment.empty()
+        for chunk in content_chunks:
+            audio_byte = model.text_to_speech(chunk)
+            with io.BytesIO(audio_byte) as chunk_buffer:
+                combined_audio += AudioSegment.from_file(chunk_buffer)
+
+        # Encode the merged track as MP3 bytes.
+        with io.BytesIO() as output_buffer:
+            combined_audio.export(output_buffer, format="mp3")
+            combined_bytes = output_buffer.getvalue()
+
+        # Store the merged audio file.
+        file_name = 'combined_audio.mp3'
+        file = bytes_to_uploaded_file(combined_bytes, file_name)
+
         application = self.workflow_manage.work_flow_post_handler.chat_info.application
         meta = {
             'debug': False if application.id else True,
             'chat_id': chat_id,
             'application_id': str(application.id) if application.id else None,
         }
+
         file_url = FileSerializer(data={
             'file': file,
             'meta': meta,
             'source_id': meta['application_id'],
             'source_type': FileSourceType.APPLICATION.value
         }).upload()
-        # 拼接一个audio标签的src属性
-        audio_label = f''
+
+        # Build the audio label returned as the answer.
+        audio_label = f''
         file_id = file_url.split('/')[-1]
         audio_list = [{'file_id': file_id, 'file_name': file_name, 'url': file_url}]
-        return NodeResult({'answer': audio_label, 'result': audio_list}, {})
+
+        return NodeResult({
+            'answer': audio_label,
+            'result': audio_list
+        }, {})
def get_details(self, index: int, **kwargs):
return {
diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/model/tts.py b/apps/models_provider/impl/volcanic_engine_model_provider/model/tts.py
index 2006f1ca4..5dd02f0b2 100644
--- a/apps/models_provider/impl/volcanic_engine_model_provider/model/tts.py
+++ b/apps/models_provider/impl/volcanic_engine_model_provider/model/tts.py
@@ -14,6 +14,8 @@ import gzip
import json
import re
import ssl
+
+import requests
import uuid_utils.compat as uuid
from typing import Dict