From 16088975faee258734f94e996c155bb7997a1a80 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Wed, 9 Jul 2025 18:08:41 +0800 Subject: [PATCH] refactor: enhance text-to-speech processing by splitting content into chunks and merging audio segments --- .../impl/base_text_to_speech_node.py | 66 +++++++++++++++---- .../model/tts.py | 2 + 2 files changed, 56 insertions(+), 12 deletions(-) diff --git a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py index e673d7aaa..31ff6ba14 100644 --- a/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py +++ b/apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py @@ -6,9 +6,11 @@ from django.core.files.uploadedfile import InMemoryUploadedFile from application.flow.i_step_node import NodeResult from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode +from common.utils.common import _remove_empty_lines from knowledge.models import FileSourceType from models_provider.tools import get_model_instance_by_model_workspace_id from oss.serializers.file import FileSerializer +from pydub import AudioSegment def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"): @@ -41,32 +43,72 @@ class BaseTextToSpeechNode(ITextToSpeechNode): def execute(self, tts_model_id, chat_id, content, model_params_setting=None, - **kwargs) -> NodeResult: - self.context['content'] = content - workspace_id = self.workflow_manage.get_body().get('workspace_id') - model = get_model_instance_by_model_workspace_id(tts_model_id, workspace_id, - **model_params_setting) - audio_byte = model.text_to_speech(content) - # 需要把这个音频文件存储到数据库中 - file_name = 'generated_audio.mp3' - file = bytes_to_uploaded_file(audio_byte, file_name) + max_length=1024, **kwargs) -> NodeResult: + # 分割文本为合理片段 + content = _remove_empty_lines(content) + content_chunks = [content[i:i + max_length] + for i in range(0, len(content), max_length)] + + # 生成并收集所有音频片段 + audio_segments = [] + temp_files = [] + + for i, chunk in enumerate(content_chunks): + self.context['content'] = chunk + workspace_id = self.workflow_manage.get_body().get('workspace_id') + model = get_model_instance_by_model_workspace_id( + tts_model_id, workspace_id, **model_params_setting) + + audio_byte = model.text_to_speech(chunk) + + # 保存为临时音频文件用于合并 + temp_file = io.BytesIO(audio_byte) + audio_segment = AudioSegment.from_file(temp_file) + audio_segments.append(audio_segment) + temp_files.append(temp_file) + + # 合并所有音频片段 + combined_audio = AudioSegment.empty() + for segment in audio_segments: + combined_audio += segment + + # 将合并后的音频转为字节流 + output_buffer = io.BytesIO() + combined_audio.export(output_buffer, format="mp3") + combined_bytes = output_buffer.getvalue() + + # 存储合并后的音频文件 + file_name = 'combined_audio.mp3' + file = bytes_to_uploaded_file(combined_bytes, file_name) + application = self.workflow_manage.work_flow_post_handler.chat_info.application meta = { 'debug': False if application.id else True, 'chat_id': chat_id, 'application_id': str(application.id) if application.id else None, } + file_url = FileSerializer(data={ 'file': file, 'meta': meta, 'source_id': meta['application_id'], 'source_type': FileSourceType.APPLICATION.value }).upload() - # 拼接一个audio标签的src属性 - audio_label = f'' + + # 生成音频标签 + audio_label = f'' file_id = file_url.split('/')[-1] audio_list = [{'file_id': file_id, 'file_name': file_name, 'url': file_url}] - return NodeResult({'answer': audio_label, 'result': audio_list}, {}) + + # 关闭所有临时文件 + for temp_file in temp_files: + temp_file.close() + output_buffer.close() + + return NodeResult({ + 'answer': audio_label, + 'result': audio_list + }, {}) def get_details(self, index: int, **kwargs): return { diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/model/tts.py b/apps/models_provider/impl/volcanic_engine_model_provider/model/tts.py index 2006f1ca4..5dd02f0b2 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/model/tts.py +++ b/apps/models_provider/impl/volcanic_engine_model_provider/model/tts.py @@ -14,6 +14,8 @@ import gzip import json import re import ssl + +import requests import uuid_utils.compat as uuid from typing import Dict