From a46cf1c18b163cf2e68725b8ffe64106eff327e1 Mon Sep 17 00:00:00 2001
From: CaptainB
Date: Thu, 24 Oct 2024 19:02:05 +0800
Subject: [PATCH] fix: resolve the error raised when incorrect parameters are passed to the speech (TTS) model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../serializers/application_serializers.py    |  6 +++++-
 .../model/tts.py                               | 18 ++++++++----------
 .../impl/openai_model_provider/model/tts.py    | 15 ++++++++-------
 .../model/tts.py                               | 19 +++++++------------
 .../impl/xf_model_provider/model/tts.py        | 18 ++++++++----------
 .../xinference_model_provider/model/tts.py     | 15 +++++++--------
 .../views/application/ApplicationSetting.vue   |  9 +++++++++
 .../component/TTSModeParamSettingDialog.vue    | 16 ++++++++++++++--
 ui/src/workflow/nodes/base-node/index.vue      | 10 ++++++++++
 9 files changed, 76 insertions(+), 50 deletions(-)

diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py
index 4b54e3a4a..0be967621 100644
--- a/apps/application/serializers/application_serializers.py
+++ b/apps/application/serializers/application_serializers.py
@@ -1028,7 +1028,11 @@ class ApplicationSerializer(serializers.Serializer):
         application_id = self.data.get('application_id')
         application = QuerySet(Application).filter(id=application_id).first()
         if application.tts_model_enable:
-            model = get_model_instance_by_model_user_id(application.tts_model_id, application.user_id, **form_data)
+            tts_model_id = application.tts_model_id
+            if 'tts_model_id' in form_data:
+                tts_model_id = form_data.get('tts_model_id')
+                del form_data['tts_model_id']
+            model = get_model_instance_by_model_user_id(tts_model_id, application.user_id, **form_data)
             return model.text_to_speech(text)
 
     class ApplicationKeySerializerModel(serializers.ModelSerializer):
diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py
index 356f9f907..1dbee97d3 100644
--- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py
+++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py
@@ -10,23 +10,21 @@ from setting.models_provider.impl.base_tts import BaseTextToSpeech
 class AliyunBaiLianTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
     api_key: str
     model: str
-    voice: str
-    speech_rate: float
+    params: dict
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.api_key = kwargs.get('api_key')
         self.model = kwargs.get('model')
-        self.voice = kwargs.get('voice')
-        self.speech_rate = kwargs.get('speech_rate')
+        self.params = kwargs.get('params')
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {'voice': 'longxiaochun', 'speech_rate': 1.0}
-        if 'voice' in model_kwargs and model_kwargs['voice'] is not None:
-            optional_params['voice'] = model_kwargs['voice']
-        if 'speech_rate' in model_kwargs and model_kwargs['speech_rate'] is not None:
-            optional_params['speech_rate'] = model_kwargs['speech_rate']
+        optional_params = {'params': {'voice': 'longxiaochun', 'speech_rate': 1.0}}
+        for key, value in model_kwargs.items():
+            if key not in ['model_id', 'use_local', 'streaming']:
+                optional_params['params'][key] = value
+
         return AliyunBaiLianTextToSpeech(
             model=model_name,
             api_key=model_credential.get('api_key'),
@@ -38,7 +36,7 @@ class AliyunBaiLianTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
 
     def text_to_speech(self, text):
         dashscope.api_key = self.api_key
-        synthesizer = SpeechSynthesizer(model=self.model, voice=self.voice, speech_rate=self.speech_rate)
+        synthesizer = SpeechSynthesizer(model=self.model, **self.params)
         audio = synthesizer.call(text)
         if type(audio) == str:
             print(audio)
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/tts.py b/apps/setting/models_provider/impl/openai_model_provider/model/tts.py
index 1247eabb5..6e9aa2c6e 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/model/tts.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/tts.py
@@ -16,20 +16,21 @@ class OpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
     api_base: str
     api_key: str
     model: str
-    voice: str
+    params: dict
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.api_key = kwargs.get('api_key')
         self.api_base = kwargs.get('api_base')
         self.model = kwargs.get('model')
-        self.voice = kwargs.get('voice', 'alloy')
+        self.params = kwargs.get('params')
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {'voice': 'alloy'}
-        if 'voice' in model_kwargs and model_kwargs['voice'] is not None:
-            optional_params['voice'] = model_kwargs['voice']
+        optional_params = {'params': {'voice': 'alloy'}}
+        for key, value in model_kwargs.items():
+            if key not in ['model_id', 'use_local', 'streaming']:
+                optional_params['params'][key] = value
         return OpenAITextToSpeech(
             model=model_name,
             api_base=model_credential.get('api_base'),
@@ -52,10 +53,10 @@ class OpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
         )
         with client.audio.speech.with_streaming_response.create(
             model=self.model,
-            voice=self.voice,
             input=text,
+            **self.params
         ) as response:
             return response.read()
 
     def is_cache_model(self):
-        return False
\ No newline at end of file
+        return False
diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py
index 6b6ace2cf..e2cc4e286 100644
--- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py
+++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py
@@ -45,8 +45,7 @@ class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
     volcanic_cluster: str
     volcanic_api_url: str
     volcanic_token: str
-    speed_ratio: float
-    voice_type: str
+    params: dict
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -54,16 +53,14 @@ class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
         self.volcanic_token = kwargs.get('volcanic_token')
         self.volcanic_app_id = kwargs.get('volcanic_app_id')
         self.volcanic_cluster = kwargs.get('volcanic_cluster')
-        self.voice_type = kwargs.get('voice_type')
-        self.speed_ratio = kwargs.get('speed_ratio')
+        self.params = kwargs.get('params')
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {'voice_type': 'BV002_streaming', 'speed_ratio': 1.0}
-        if 'voice_type' in model_kwargs and model_kwargs['voice_type'] is not None:
-            optional_params['voice_type'] = model_kwargs['voice_type']
-        if 'speed_ratio' in model_kwargs and model_kwargs['speed_ratio'] is not None:
-            optional_params['speed_ratio'] = model_kwargs['speed_ratio']
+        optional_params = {'params': {'voice_type': 'BV002_streaming', 'speed_ratio': 1.0}}
+        for key, value in model_kwargs.items():
+            if key not in ['model_id', 'use_local', 'streaming']:
+                optional_params['params'][key] = value
         return VolcanicEngineTextToSpeech(
             volcanic_api_url=model_credential.get('volcanic_api_url'),
             volcanic_token=model_credential.get('volcanic_token'),
@@ -86,12 +83,10 @@ class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
                 "uid": "uid"
             },
             "audio": {
-                "voice_type": self.voice_type,
                 "encoding": "mp3",
-                "speed_ratio": self.speed_ratio,
                 "volume_ratio": 1.0,
                 "pitch_ratio": 1.0,
-            },
+            } | self.params,
             "request": {
                 "reqid": str(uuid.uuid4()),
                 "text": '',
diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py
index b0bc0b017..3a575ed28 100644
--- a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py
+++ b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py
@@ -37,8 +37,7 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
     spark_api_key: str
     spark_api_secret: str
     spark_api_url: str
-    speed: int
-    vcn: str
+    params: dict
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -46,16 +45,14 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
         self.spark_app_id = kwargs.get('spark_app_id')
         self.spark_api_key = kwargs.get('spark_api_key')
         self.spark_api_secret = kwargs.get('spark_api_secret')
-        self.vcn = kwargs.get('vcn')
-        self.speed = kwargs.get('speed')
+        self.params = kwargs.get('params')
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {'vcn': 'xiaoyan', 'speed': 50}
-        if 'vcn' in model_kwargs and model_kwargs['vcn'] is not None:
-            optional_params['vcn'] = model_kwargs['vcn']
-        if 'speed' in model_kwargs and model_kwargs['speed'] is not None:
-            optional_params['speed'] = model_kwargs['speed']
+        optional_params = {'params': {'vcn': 'xiaoyan', 'speed': 50}}
+        for key, value in model_kwargs.items():
+            if key not in ['model_id', 'use_local', 'streaming']:
+                optional_params['params'][key] = value
         return XFSparkTextToSpeech(
             spark_app_id=model_credential.get('spark_app_id'),
             spark_api_key=model_credential.get('spark_api_key'),
@@ -139,9 +136,10 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
         return audio_bytes
 
     async def send(self, ws, text):
+        business = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "tte": "utf8"}
         d = {
             "common": {"app_id": self.spark_app_id},
-            "business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.vcn, "speed": self.speed, "tte": "utf8"},
+            "business": business | self.params,
             "data": {"status": 2, "text": str(base64.b64encode(text.encode('utf-8')), "UTF8")},
         }
         d = json.dumps(d)
diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/tts.py b/apps/setting/models_provider/impl/xinference_model_provider/model/tts.py
index 22accfc6f..1ba221eed 100644
--- a/apps/setting/models_provider/impl/xinference_model_provider/model/tts.py
+++ b/apps/setting/models_provider/impl/xinference_model_provider/model/tts.py
@@ -15,21 +15,20 @@ def custom_get_token_ids(text: str):
 class XInferenceTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
     api_base: str
     api_key: str
-    model: str
-    voice: str
+    params: dict
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.api_key = kwargs.get('api_key')
         self.api_base = kwargs.get('api_base')
-        self.model = kwargs.get('model')
-        self.voice = kwargs.get('voice', '中文女')
+        self.params = kwargs.get('params')
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {'voice': '中文女'}
-        if 'voice' in model_kwargs and model_kwargs['voice'] is not None:
-            optional_params['voice'] = model_kwargs['voice']
+        optional_params = {'params': {'voice': '中文女'}}
+        for key, value in model_kwargs.items():
+            if key not in ['model_id', 'use_local', 'streaming']:
+                optional_params['params'][key] = value
         return XInferenceTextToSpeech(
             model=model_name,
             api_base=model_credential.get('api_base'),
@@ -54,8 +53,8 @@ class XInferenceTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
 
         with client.audio.speech.with_streaming_response.create(
             model=self.model,
-            voice=self.voice,
             input=text,
+            **self.params
         ) as response:
             return response.read()
 
diff --git a/ui/src/views/application/ApplicationSetting.vue b/ui/src/views/application/ApplicationSetting.vue
index fbb028ef0..9256d2d2e 100644
--- a/ui/src/views/application/ApplicationSetting.vue
+++ b/ui/src/views/application/ApplicationSetting.vue
@@ -422,6 +422,7 @@
             v-model="applicationForm.tts_model_id"
             class="w-full"
             popper-class="select-model"
+            @change="ttsModelChange()"
             placeholder="请选择语音合成模型"
           >
>([])
 const emit = defineEmits(['refresh'])
 const dynamicsFormRef = ref>()
@@ -69,6 +71,7 @@ const getApi = (model_id: string, application_id?: string) => {
 }
 const open = (model_id: string, application_id?: string, model_setting_data?: any) => {
   form_data.value = {}
+  tts_model_id.value = model_id
   const api = getApi(model_id, application_id)
   api.then((ok) => {
     model_form_field.value = ok.data
@@ -104,9 +107,18 @@ const submit = async () => {
 
 const audioPlayer = ref(null)
 const testPlay = () => {
+  const data = {
+    ...form_data.value,
+    tts_model_id: tts_model_id.value
+  }
   applicationApi
-    .playDemoText(id as string, form_data.value, playLoading)
-    .then((res: any) => {
+    .playDemoText(id as string, data, playLoading)
+    .then(async (res: any) => {
+      if (res.type === 'application/json') {
+        const text = await res.text();
+        MsgError(text)
+        return
+      }
       // 创建 Blob 对象
       const blob = new Blob([res], { type: 'audio/mp3' })
diff --git a/ui/src/workflow/nodes/base-node/index.vue b/ui/src/workflow/nodes/base-node/index.vue
index 90b67fe04..987316242 100644
--- a/ui/src/workflow/nodes/base-node/index.vue
+++ b/ui/src/workflow/nodes/base-node/index.vue
@@ -153,6 +153,7 @@
             v-model="form_data.tts_model_id"
             class="w-full"
             popper-class="select-model"
+            @change="ttsModelChange()"
             placeholder="请选择语音合成模型"
           >
 {
   const model_id = form_data.value.tts_model_id
   if (!model_id) {
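
The backend half of this patch applies one pattern to all five TTS providers: provider-specific options such as voice or speed are no longer individual constructor fields but are collected into a single params dict in new_instance (skipping the bookkeeping keys model_id, use_local and streaming) and expanded into the underlying SDK call, while the serializer now prefers a tts_model_id sent with the form data and removes it before the remaining fields reach the model. Below is a minimal, self-contained sketch of that flow; the class and function names are illustrative stand-ins, not the actual MaxKB code.

    from typing import Dict

    EXCLUDED_KEYS = ['model_id', 'use_local', 'streaming']


    class ExampleTextToSpeech:
        """Stand-in for one of the provider classes patched above."""

        def __init__(self, **kwargs):
            self.model = kwargs.get('model')
            self.api_key = kwargs.get('api_key')
            # Every provider-specific option (voice, speed_ratio, vcn, ...) lives here.
            self.params = kwargs.get('params') or {}

        @staticmethod
        def new_instance(model_name, model_credential: Dict[str, object], **model_kwargs):
            # Provider defaults first, then overlay whatever the settings form sent,
            # ignoring keys that are not real TTS parameters.
            optional_params = {'params': {'voice': 'default_voice'}}
            for key, value in model_kwargs.items():
                if key not in EXCLUDED_KEYS:
                    optional_params['params'][key] = value
            return ExampleTextToSpeech(model=model_name,
                                       api_key=model_credential.get('api_key'),
                                       **optional_params)

        def text_to_speech(self, text):
            # A real provider would forward the options to its SDK via keyword
            # expansion, e.g. client.create(model=self.model, input=text,
            # **self.params); returning the payload keeps the sketch runnable.
            return {'model': self.model, 'input': text, **self.params}


    def play_demo_text(form_data: dict, application_tts_model_id: str):
        # Mirrors the serializer change: prefer the model id sent by the dialog
        # and strip it from form_data before the rest is forwarded as **kwargs.
        tts_model_id = application_tts_model_id
        if 'tts_model_id' in form_data:
            tts_model_id = form_data.pop('tts_model_id')
        model = ExampleTextToSpeech.new_instance('tts-demo', {'api_key': 'sk-xxx'}, **form_data)
        return tts_model_id, model.text_to_speech('hello')


    if __name__ == '__main__':
        # The settings dialog sends the currently selected model id plus its form values.
        print(play_demo_text({'tts_model_id': 'model-from-dialog', 'voice': 'alloy', 'streaming': True},
                             'model-on-application'))

Because unrecognised form keys travel through params instead of being passed as individual constructor arguments, a renamed or newly added field in the TTS settings form no longer triggers an unexpected-keyword-argument error in the provider, which appears to be the failure the commit title refers to.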