fix: 修复语音模型传入不正确参数报错的问题
Some checks are pending
sync2gitee / repo-sync (push) Waiting to run
Typos Check / Spell Check with Typos (push) Waiting to run

This commit is contained in:
CaptainB 2024-10-24 19:02:05 +08:00 committed by 刘瑞斌
parent a0ad4c911c
commit a46cf1c18b
9 changed files with 76 additions and 50 deletions

View File

@ -1028,7 +1028,11 @@ class ApplicationSerializer(serializers.Serializer):
application_id = self.data.get('application_id')
application = QuerySet(Application).filter(id=application_id).first()
if application.tts_model_enable:
model = get_model_instance_by_model_user_id(application.tts_model_id, application.user_id, **form_data)
tts_model_id = application.tts_model_id
if 'tts_model_id' in form_data:
tts_model_id = form_data.get('tts_model_id')
del form_data['tts_model_id']
model = get_model_instance_by_model_user_id(tts_model_id, application.user_id, **form_data)
return model.text_to_speech(text)
class ApplicationKeySerializerModel(serializers.ModelSerializer):

View File

@ -10,23 +10,21 @@ from setting.models_provider.impl.base_tts import BaseTextToSpeech
class AliyunBaiLianTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
api_key: str
model: str
voice: str
speech_rate: float
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.model = kwargs.get('model')
self.voice = kwargs.get('voice')
self.speech_rate = kwargs.get('speech_rate')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'voice': 'longxiaochun', 'speech_rate': 1.0}
if 'voice' in model_kwargs and model_kwargs['voice'] is not None:
optional_params['voice'] = model_kwargs['voice']
if 'speech_rate' in model_kwargs and model_kwargs['speech_rate'] is not None:
optional_params['speech_rate'] = model_kwargs['speech_rate']
optional_params = {'params': {'voice': 'longxiaochun', 'speech_rate': 1.0}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
return AliyunBaiLianTextToSpeech(
model=model_name,
api_key=model_credential.get('api_key'),
@ -38,7 +36,7 @@ class AliyunBaiLianTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
def text_to_speech(self, text):
dashscope.api_key = self.api_key
synthesizer = SpeechSynthesizer(model=self.model, voice=self.voice, speech_rate=self.speech_rate)
synthesizer = SpeechSynthesizer(model=self.model, **self.params)
audio = synthesizer.call(text)
if type(audio) == str:
print(audio)

View File

@ -16,20 +16,21 @@ class OpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
api_base: str
api_key: str
model: str
voice: str
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.model = kwargs.get('model')
self.voice = kwargs.get('voice', 'alloy')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'voice': 'alloy'}
if 'voice' in model_kwargs and model_kwargs['voice'] is not None:
optional_params['voice'] = model_kwargs['voice']
optional_params = {'params': {'voice': 'alloy'}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
return OpenAITextToSpeech(
model=model_name,
api_base=model_credential.get('api_base'),
@ -52,10 +53,10 @@ class OpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
)
with client.audio.speech.with_streaming_response.create(
model=self.model,
voice=self.voice,
input=text,
**self.params
) as response:
return response.read()
def is_cache_model(self):
return False
return False

View File

@ -45,8 +45,7 @@ class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
volcanic_cluster: str
volcanic_api_url: str
volcanic_token: str
speed_ratio: float
voice_type: str
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -54,16 +53,14 @@ class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
self.volcanic_token = kwargs.get('volcanic_token')
self.volcanic_app_id = kwargs.get('volcanic_app_id')
self.volcanic_cluster = kwargs.get('volcanic_cluster')
self.voice_type = kwargs.get('voice_type')
self.speed_ratio = kwargs.get('speed_ratio')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'voice_type': 'BV002_streaming', 'speed_ratio': 1.0}
if 'voice_type' in model_kwargs and model_kwargs['voice_type'] is not None:
optional_params['voice_type'] = model_kwargs['voice_type']
if 'speed_ratio' in model_kwargs and model_kwargs['speed_ratio'] is not None:
optional_params['speed_ratio'] = model_kwargs['speed_ratio']
optional_params = {'params': {'voice_type': 'BV002_streaming', 'speed_ratio': 1.0}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
return VolcanicEngineTextToSpeech(
volcanic_api_url=model_credential.get('volcanic_api_url'),
volcanic_token=model_credential.get('volcanic_token'),
@ -86,12 +83,10 @@ class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
"uid": "uid"
},
"audio": {
"voice_type": self.voice_type,
"encoding": "mp3",
"speed_ratio": self.speed_ratio,
"volume_ratio": 1.0,
"pitch_ratio": 1.0,
},
} | self.params,
"request": {
"reqid": str(uuid.uuid4()),
"text": '',

View File

@ -37,8 +37,7 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
spark_api_key: str
spark_api_secret: str
spark_api_url: str
speed: int
vcn: str
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -46,16 +45,14 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
self.spark_app_id = kwargs.get('spark_app_id')
self.spark_api_key = kwargs.get('spark_api_key')
self.spark_api_secret = kwargs.get('spark_api_secret')
self.vcn = kwargs.get('vcn')
self.speed = kwargs.get('speed')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'vcn': 'xiaoyan', 'speed': 50}
if 'vcn' in model_kwargs and model_kwargs['vcn'] is not None:
optional_params['vcn'] = model_kwargs['vcn']
if 'speed' in model_kwargs and model_kwargs['speed'] is not None:
optional_params['speed'] = model_kwargs['speed']
optional_params = {'params': {'vcn': 'xiaoyan', 'speed': 50}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
return XFSparkTextToSpeech(
spark_app_id=model_credential.get('spark_app_id'),
spark_api_key=model_credential.get('spark_api_key'),
@ -139,9 +136,10 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
return audio_bytes
async def send(self, ws, text):
business = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "tte": "utf8"}
d = {
"common": {"app_id": self.spark_app_id},
"business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": self.vcn, "speed": self.speed, "tte": "utf8"},
"business": business | self.params,
"data": {"status": 2, "text": str(base64.b64encode(text.encode('utf-8')), "UTF8")},
}
d = json.dumps(d)

View File

@ -15,21 +15,20 @@ def custom_get_token_ids(text: str):
class XInferenceTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
api_base: str
api_key: str
model: str
voice: str
params: dict
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')
self.model = kwargs.get('model')
self.voice = kwargs.get('voice', '中文女')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'voice': '中文女'}
if 'voice' in model_kwargs and model_kwargs['voice'] is not None:
optional_params['voice'] = model_kwargs['voice']
optional_params = {'params': {'voice': '中文女'}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
return XInferenceTextToSpeech(
model=model_name,
api_base=model_credential.get('api_base'),
@ -54,8 +53,8 @@ class XInferenceTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
with client.audio.speech.with_streaming_response.create(
model=self.model,
voice=self.voice,
input=text,
**self.params
) as response:
return response.read()

View File

@ -422,6 +422,7 @@
v-model="applicationForm.tts_model_id"
class="w-full"
popper-class="select-model"
@change="ttsModelChange()"
placeholder="请选择语音合成模型"
>
<el-option-group
@ -807,6 +808,14 @@ function getTTSModel() {
})
}
function ttsModelChange() {
if (applicationForm.value.tts_model_id) {
TTSModeParamSettingDialogRef.value?.reset_default(applicationForm.value.tts_model_id, id)
} else {
refreshTTSForm({})
}
}
function getProvider() {
loading.value = true
model

View File

@ -50,11 +50,13 @@ import applicationApi from '@/api/application'
import DynamicsForm from '@/components/dynamics-form/index.vue'
import { keys } from 'lodash'
import { app } from '@/main'
import { MsgError } from '@/utils/message'
const {
params: { id }
} = app.config.globalProperties.$route as any
const tts_model_id = ref('')
const model_form_field = ref<Array<FormField>>([])
const emit = defineEmits(['refresh'])
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
@ -69,6 +71,7 @@ const getApi = (model_id: string, application_id?: string) => {
}
const open = (model_id: string, application_id?: string, model_setting_data?: any) => {
form_data.value = {}
tts_model_id.value = model_id
const api = getApi(model_id, application_id)
api.then((ok) => {
model_form_field.value = ok.data
@ -104,9 +107,18 @@ const submit = async () => {
const audioPlayer = ref<HTMLAudioElement | null>(null)
const testPlay = () => {
const data = {
...form_data.value,
tts_model_id: tts_model_id.value
}
applicationApi
.playDemoText(id as string, form_data.value, playLoading)
.then((res: any) => {
.playDemoText(id as string, data, playLoading)
.then(async (res: any) => {
if (res.type === 'application/json') {
const text = await res.text();
MsgError(text)
return
}
// Blob
const blob = new Blob([res], { type: 'audio/mp3' })

View File

@ -153,6 +153,7 @@
v-model="form_data.tts_model_id"
class="w-full"
popper-class="select-model"
@change="ttsModelChange()"
placeholder="请选择语音合成模型"
>
<el-option-group
@ -312,6 +313,15 @@ function getTTSModel() {
})
}
function ttsModelChange() {
if (form_data.value.tts_model_id) {
TTSModeParamSettingDialogRef.value?.reset_default(form_data.value.tts_model_id, id)
} else {
refreshTTSForm({})
}
}
const openTTSParamSettingDialog = () => {
const model_id = form_data.value.tts_model_id
if (!model_id) {