feat: enhance TTS model parameters and API version handling (#4480)

Co-authored-by: xiaoc <648844981@qq.com>
This commit is contained in:
qijingyu0727 2025-12-10 12:25:43 +08:00 committed by GitHub
parent 4e58d079bf
commit 532eee571f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 494 additions and 3 deletions

View File

@ -0,0 +1,133 @@
# coding=utf-8
"""
iFlytek TTS factory credential: routes to the concrete credential according to api_version.
"""
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class XunFeiDefaultTTSModelCredential(BaseForm, BaseModelCredential):
    """Credential form for the iFlytek "default" TTS factory model.

    The ``api_version`` field selects the target service (``online`` or
    ``super_humanoid``). The API URL and speaker fields exist once per
    service and are tied to ``api_version`` via ``relation_show_field_dict``.
    """

    api_version = forms.SingleSelect(
        _("API Version"), required=True,
        text_field='label',
        value_field='value',
        default_value='online',
        option_list=[
            {'label': _('Online TTS'), 'value': 'online'},
            {'label': _('Super Humanoid TTS'), 'value': 'super_humanoid'}
        ])
    spark_api_url = forms.TextInputField('API URL', required=True,
                                         default_value='wss://tts-api.xfyun.cn/v2/tts',
                                         relation_show_field_dict={"api_version": ["online"]})
    spark_api_url_super = forms.TextInputField('API URL', required=True,
                                               default_value='wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6',
                                               relation_show_field_dict={"api_version": ["super_humanoid"]})
    # The speaker (vcn) choice lives in the credential and is displayed
    # according to the selected api_version.
    # NOTE(review): text_field is 'value' although the options carry a 'text'
    # key - confirm which one the UI is meant to display.
    vcn_online = forms.SingleSelect(
        TooltipLabel(_('Speaker'), _('Speaker selection for standard TTS service')),
        required=True, default_value='xiaoyan',
        text_field='value',
        value_field='value',
        option_list=[
            {'text': _('iFlytek Xiaoyan'), 'value': 'xiaoyan'},
            {'text': _('iFlytek Xujiu'), 'value': 'aisjiuxu'},
            {'text': _('iFlytek Xiaoping'), 'value': 'aisxping'},
            {'text': _('iFlytek Xiaojing'), 'value': 'aisjinger'},
            {'text': _('iFlytek Xuxiaobao'), 'value': 'aisbabyxu'},
        ],
        relation_show_field_dict={"api_version": ["online"]})
    vcn_super = forms.SingleSelect(
        TooltipLabel(_('Speaker'), _('Speaker selection for super-humanoid TTS service')),
        required=True, default_value='x5_lingxiaoxuan_flow',
        text_field='value',
        value_field='value',
        option_list=[
            {'text': _('Super-humanoid: Lingxiaoxuan Flow'), 'value': 'x5_lingxiaoxuan_flow'},
            {'text': _('Super-humanoid: Lingyuyan Flow'), 'value': 'x5_lingyuyan_flow'},
            {'text': _('Super-humanoid: Lingfeiyi Flow'), 'value': 'x5_lingfeiyi_flow'},
            {'text': _('Super-humanoid: Lingxiaoyue Flow'), 'value': 'x5_lingxiaoyue_flow'},
            {'text': _('Super-humanoid: Sun Dasheng Flow'), 'value': 'x5_sundasheng_flow'},
            {'text': _('Super-humanoid: Lingyuzhao Flow'), 'value': 'x5_lingyuzhao_flow'},
            {'text': _('Super-humanoid: Lingxiaotang Flow'), 'value': 'x5_lingxiaotang_flow'},
            {'text': _('Super-humanoid: Lingxiaorong Flow'), 'value': 'x5_lingxiaorong_flow'},
            {'text': _('Super-humanoid: Xinyun Flow'), 'value': 'x5_xinyun_flow'},
            {'text': _('Super-humanoid: Grant (EN)'), 'value': 'x5_EnUs_Grant_flow'},
            {'text': _('Super-humanoid: Lila (EN)'), 'value': 'x5_EnUs_Lila_flow'},
            {'text': _('Super-humanoid: Lingwanwan Pro'), 'value': 'x6_lingwanwan_pro'},
            {'text': _('Super-humanoid: Yiyi Pro'), 'value': 'x6_yiyi_pro'},
            {'text': _('Super-humanoid: Huifangnv Pro'), 'value': 'x6_huifangnv_pro'},
            {'text': _('Super-humanoid: Lingxiaoying Pro'), 'value': 'x6_lingxiaoying_pro'},
            {'text': _('Super-humanoid: Lingfeibo Pro'), 'value': 'x6_lingfeibo_pro'},
            {'text': _('Super-humanoid: Lingyuyan Pro'), 'value': 'x6_lingyuyan_pro'},
        ],
        relation_show_field_dict={"api_version": ["super_humanoid"]})
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        """Validate the credential by building the model and calling ``check_auth``.

        :param model_type: model type value; must be one the provider lists.
        :param model_name: model name handed to ``provider.get_model``.
        :param model_credential: credential values keyed by field name.
        :param model_params: extra keyword arguments for ``provider.get_model``.
        :param provider: provider exposing ``get_model_type_list`` and ``get_model``.
        :param raise_exception: raise ``AppApiException`` on failure instead of
            returning ``False``.
        :return: ``True`` when the credential passed every check, else ``False``.
        :raises AppApiException: for an unsupported model type (always), or for
            any other failure when ``raise_exception`` is true.
        """
        model_type_list = provider.get_model_type_list()
        # Test for the existence of a matching entry, not the truthiness of the
        # matched dict.
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        # Each API version stores its URL under its own key.
        api_version = model_credential.get('api_version', 'online')
        url_key = 'spark_api_url_super' if api_version == 'super_humanoid' else 'spark_api_url'
        required_keys = [url_key, 'spark_app_id', 'spark_api_key', 'spark_api_secret']
        for key in required_keys:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            model.check_auth()
        except Exception as e:
            maxkb_logger.error(f'Exception: {e}', exc_info=True)
            if isinstance(e, AppApiException):
                # Re-raise unchanged so the original traceback is preserved.
                raise
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e))) from e
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``spark_api_secret`` passed through ``encryption``."""
        return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form; it holds only shared parameters because vcn is in the credential."""
        return XunFeiDefaultTTSModelParams()
class XunFeiDefaultTTSModelParams(BaseForm):
    """Parameter form of the factory model; it only carries the shared parameters."""

    speed = forms.SliderField(
        TooltipLabel(
            _('speaking speed'),
            _('Speech speed, optional value: [0-100], default is 50'),
        ),
        required=True,
        default_value=50,
        _min=1,
        _max=100,
        _step=5,
        precision=1,
    )

View File

@ -0,0 +1,93 @@
# coding=utf-8
"""
iFlytek Super Humanoid TTS credential
"""
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class XunFeiSuperHumanoidTTSModelParams(BaseForm):
    """Parameter form for Super Humanoid TTS: speaker (vcn) and speaking speed."""

    vcn = forms.SingleSelect(
        TooltipLabel(
            _('Speaker'),
            _('Speaker selection for super-humanoid TTS service'),
        ),
        required=True,
        default_value='x5_lingxiaoxuan_flow',
        text_field='value',
        value_field='value',
        option_list=[
            # x5 "Flow" voices
            {'text': _('Super-humanoid: Lingxiaoxuan Flow'), 'value': 'x5_lingxiaoxuan_flow'},
            {'text': _('Super-humanoid: Lingyuyan Flow'), 'value': 'x5_lingyuyan_flow'},
            {'text': _('Super-humanoid: Lingfeiyi Flow'), 'value': 'x5_lingfeiyi_flow'},
            {'text': _('Super-humanoid: Lingxiaoyue Flow'), 'value': 'x5_lingxiaoyue_flow'},
            {'text': _('Super-humanoid: Sun Dasheng Flow'), 'value': 'x5_sundasheng_flow'},
            {'text': _('Super-humanoid: Lingyuzhao Flow'), 'value': 'x5_lingyuzhao_flow'},
            {'text': _('Super-humanoid: Lingxiaotang Flow'), 'value': 'x5_lingxiaotang_flow'},
            {'text': _('Super-humanoid: Lingxiaorong Flow'), 'value': 'x5_lingxiaorong_flow'},
            {'text': _('Super-humanoid: Xinyun Flow'), 'value': 'x5_xinyun_flow'},
            {'text': _('Super-humanoid: Grant (EN)'), 'value': 'x5_EnUs_Grant_flow'},
            {'text': _('Super-humanoid: Lila (EN)'), 'value': 'x5_EnUs_Lila_flow'},
            # x6 "Pro" voices
            {'text': _('Super-humanoid: Lingwanwan Pro'), 'value': 'x6_lingwanwan_pro'},
            {'text': _('Super-humanoid: Yiyi Pro'), 'value': 'x6_yiyi_pro'},
            {'text': _('Super-humanoid: Huifangnv Pro'), 'value': 'x6_huifangnv_pro'},
            {'text': _('Super-humanoid: Lingxiaoying Pro'), 'value': 'x6_lingxiaoying_pro'},
            {'text': _('Super-humanoid: Lingfeibo Pro'), 'value': 'x6_lingfeibo_pro'},
            {'text': _('Super-humanoid: Lingyuyan Pro'), 'value': 'x6_lingyuyan_pro'},
        ],
    )

    speed = forms.SliderField(
        TooltipLabel(
            _('speaking speed'),
            _('Speech speed, optional value: [0-100], default is 50'),
        ),
        required=True,
        default_value=50,
        _min=1,
        _max=100,
        _step=5,
        precision=1,
    )
class XunFeiSuperHumanoidTTSModelCredential(BaseForm, BaseModelCredential):
    """Credential form for the iFlytek Super Humanoid TTS model."""

    spark_api_url = forms.TextInputField('API URL', required=True,
                                         default_value='wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6')
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        """Validate the credential by building the model and calling ``check_auth``.

        :param model_type: model type value; must be one the provider lists.
        :param model_name: model name handed to ``provider.get_model``.
        :param model_credential: credential values keyed by field name.
        :param model_params: extra keyword arguments for ``provider.get_model``.
        :param provider: provider exposing ``get_model_type_list`` and ``get_model``.
        :param raise_exception: raise ``AppApiException`` on failure instead of
            returning ``False``.
        :return: ``True`` when the credential passed every check, else ``False``.
        :raises AppApiException: for an unsupported model type (always), or for
            any other failure when ``raise_exception`` is true.
        """
        model_type_list = provider.get_model_type_list()
        # Test for the existence of a matching entry, not the truthiness of the
        # matched dict.
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        required_keys = ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']
        for key in required_keys:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            model.check_auth()
        except AppApiException:
            # Already an application-level error: propagate with its traceback intact.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e))) from e
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``spark_api_secret`` passed through ``encryption``."""
        return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the Super Humanoid parameter form (speaker and speed)."""
        return XunFeiSuperHumanoidTTSModelParams()

View File

@ -0,0 +1,58 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file default_tts.py
@date2025/12/9
@desc: iFlytek TTS factory class; routes to the concrete implementation according to api_version
"""
from typing import Dict
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_tts import BaseTextToSpeech
class XFSparkDefaultTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Factory entry for iFlytek TTS.

    ``new_instance`` does not build this class; it delegates to the concrete
    implementation selected by the credential's ``api_version``.
    """

    def check_auth(self):
        # No-op placeholder: new_instance hands back a concrete implementation.
        pass

    def text_to_speech(self, text):
        # No-op placeholder: new_instance hands back a concrete implementation.
        pass

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build the concrete TTS model chosen by ``model_credential['api_version']``.

        The factory credential is reshaped into the layout the concrete
        classes read (``spark_api_url`` instead of the per-version URL key),
        and the speaker stored in the credential is forwarded as ``vcn``.

        :param model_type: forwarded unchanged to the concrete ``new_instance``.
        :param model_name: forwarded unchanged to the concrete ``new_instance``.
        :param model_credential: factory credential values.
        :param model_kwargs: extra model parameters; a ``vcn`` entry here is
            overridden by the speaker stored in the credential.
        :return: whatever the selected concrete ``new_instance`` returns.
        """
        from models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
        from models_provider.impl.xf_model_provider.model.super_humanoid_tts import XFSparkSuperHumanoidTextToSpeech

        api_version = model_credential.get('api_version', 'online')
        if api_version == 'super_humanoid':
            target = XFSparkSuperHumanoidTextToSpeech
            url_key, vcn_key, default_vcn = 'spark_api_url_super', 'vcn_super', 'x5_lingxiaoxuan_flow'
        else:
            target = XFSparkTextToSpeech
            url_key, vcn_key, default_vcn = 'spark_api_url', 'vcn_online', 'xiaoyan'

        unified_credential = {
            'spark_app_id': model_credential.get('spark_app_id'),
            'spark_api_key': model_credential.get('spark_api_key'),
            'spark_api_secret': model_credential.get('spark_api_secret'),
            'spark_api_url': model_credential.get(url_key),
        }

        # Merge instead of passing vcn=... next to **model_kwargs: a 'vcn' key
        # already present in model_kwargs would otherwise raise
        # "TypeError: got multiple values for keyword argument 'vcn'".
        forwarded_kwargs = {'vcn': model_credential.get(vcn_key, default_vcn)}
        forwarded_kwargs.update((k, v) for k, v in model_kwargs.items() if k != 'vcn')
        return target.new_instance(model_type, model_name, unified_credential, **forwarded_kwargs)

View File

@ -0,0 +1,195 @@
# -*- coding:utf-8 -*-
#
# author: iflytek
#
# 错误码链接https://www.xfyun.cn/document/error-code code返回错误码时必看
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
import asyncio
import base64
import hashlib
import hmac
import json
import ssl
from datetime import datetime, UTC
from typing import Dict
from urllib.parse import urlencode, urlparse
import websockets
from django.utils.translation import gettext as _
from common.utils.common import _remove_empty_lines
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_tts import BaseTextToSpeech
# Module-level TLS context passed to websockets.connect() in text_to_speech.
# NOTE(review): hostname checking and certificate verification are both
# switched off, so the server's identity is not authenticated - confirm this
# is intentional before relying on it in production.
# Order matters: check_hostname has to be cleared before verify_mode can be
# set to CERT_NONE (the ssl module raises ValueError otherwise).
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
class XFSparkSuperHumanoidTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """iFlytek Super Humanoid TTS client.

    Opens a websocket to a signed URL, sends the text in a single frame and
    concatenates the base64-decoded audio chunks of the responses.
    """

    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')
        self.params = kwargs.get('params') or {}

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance from a credential dict and model keyword arguments.

        Every keyword argument except ``model_id``, ``use_local`` and
        ``streaming`` is copied into ``params``; ``vcn`` falls back to
        ``'x5_lingxiaoxuan_flow'`` when absent.
        """
        skipped = ('model_id', 'use_local', 'streaming', 'vcn')
        params = {'vcn': model_kwargs.get('vcn', 'x5_lingxiaoxuan_flow')}
        params.update((k, v) for k, v in model_kwargs.items() if k not in skipped)
        return XFSparkSuperHumanoidTextToSpeech(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            params=params
        )

    def create_url(self):
        """Return ``spark_api_url`` with authorization, date and host query parameters appended.

        The signature is an HMAC-SHA256 (keyed with ``spark_api_secret``) over
        the host, the current GMT date and the request line, base64-encoded.
        """
        url = self.spark_api_url
        parsed = urlparse(url)  # parse once; host and path are both needed
        host = parsed.hostname
        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
        date = datetime.now(UTC).strftime(gmt_format)

        signature_origin = f"host: {host}\ndate: {date}\nGET {parsed.path} HTTP/1.1"
        digest = hmac.new(
            self.spark_api_secret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            digestmod=hashlib.sha256
        ).digest()
        signature_sha = base64.b64encode(digest).decode('utf-8')

        authorization_origin = (
            f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature_sha}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8')
        query = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        return url + '?' + urlencode(query)

    def check_auth(self):
        """Verify the credentials by synthesizing a short greeting."""
        self.text_to_speech(_('Hello'))

    def text_to_speech(self, text):
        """Synthesize ``text`` and return the audio as ``bytes``.

        Runs the websocket exchange via ``asyncio.run``.

        :raises Exception: on HTTP-level connection failure (with a dedicated
            message for 401) or on any other error during the exchange; the
            original exception is attached as ``__cause__``.
        """
        text = _remove_empty_lines(text)

        async def handle():
            try:
                async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
                    await self.send(ws, text)
                    return await self.handle_message(ws)
            except websockets.exceptions.InvalidStatus as e:
                if e.response.status_code == 401:
                    raise Exception(
                        _("Authentication failed (HTTP 401). Please check: "
                          "1) API URL is correct for TTS service; "
                          "2) APP ID, API Key, and API Secret are correct; "
                          "3) Your iFlytek account has TTS service enabled.")
                    ) from e
                raise Exception(f"WebSocket connection failed: HTTP {e.response.status_code}") from e
            except Exception as e:
                # Errors that already carry the authentication hint pass through unwrapped.
                if "Authentication failed" in str(e):
                    raise
                raise Exception(f"iFlytek TTS service error: {str(e)}") from e

        return asyncio.run(handle())

    @staticmethod
    async def handle_message(ws):
        """Read frames from ``ws`` until the audio payload reports status 2.

        :return: the concatenated, base64-decoded audio of all frames.
        :raises Exception: when a frame carries a non-zero header code, or has
            no ``payload.audio`` section.
        """
        # Collect chunks and join once; repeated ``bytes +=`` re-copies the
        # whole buffer on every frame.
        chunks = []
        while True:
            message = json.loads(await ws.recv())
            if "header" in message and "code" in message["header"]:
                header = message["header"]
                code = header["code"]
                sid = header.get("sid", "unknown")
                if code != 0:
                    err_msg = header.get("message", "Unknown error")
                    raise Exception(f"sid: {sid} call error: {err_msg} code is: {code}")
            if "payload" in message and "audio" in message["payload"]:
                audio_section = message["payload"]["audio"]
                # A frame may carry only a status update; tolerate missing audio data.
                encoded_audio = audio_section.get("audio")
                if encoded_audio:
                    chunks.append(base64.b64decode(encoded_audio))
                if audio_section.get("status") == 2:
                    break
            else:
                raise Exception(
                    f"Unexpected response from iFlytek API. Response: {json.dumps(message, ensure_ascii=False)}"
                )
        return b"".join(chunks)

    async def send(self, ws, text):
        """Send one request frame containing ``text`` (base64-encoded UTF-8) over ``ws``.

        Audio and voice settings are read from ``self.params`` with defaults.
        """
        vcn_value = self.params.get("vcn", "x5_lingxiaoxuan_flow")
        # Only x5_/x6_ speaker ids are passed on; anything else falls back to the default.
        if not vcn_value or not str(vcn_value).startswith(('x5_', 'x6_')):
            vcn_value = 'x5_lingxiaoxuan_flow'

        audio_params = {
            "encoding": self.params.get("encoding", "lame"),
            "sample_rate": self.params.get("sample_rate", 24000),
            "channels": self.params.get("channels", 1),
            "bit_depth": self.params.get("bit_depth", 16),
            "frame_size": self.params.get("frame_size", 0)
        }
        tts_params = {
            "vcn": vcn_value,
            "audio": audio_params,
            "volume": self.params.get("volume", 50),
            "speed": self.params.get("speed", 50),
            "pitch": self.params.get("pitch", 50)
        }
        payload_text_obj = {
            "encoding": "utf8",
            "compress": "raw",
            "format": "plain",
            "status": 2,
            "seq": 0,
            "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')
        }
        request = {
            "header": {"app_id": self.spark_app_id, "status": 2},
            "parameter": {"tts": tts_params},
            "payload": {"text": payload_text_obj}
        }
        await ws.send(json.dumps(request))

View File

@ -17,12 +17,16 @@ from models_provider.impl.xf_model_provider.credential.image import XunFeiImageM
from models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
from models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
from models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
from models_provider.impl.xf_model_provider.credential.super_humanoid_tts import XunFeiSuperHumanoidTTSModelCredential
from models_provider.impl.xf_model_provider.credential.default_tts import XunFeiDefaultTTSModelCredential
from models_provider.impl.xf_model_provider.credential.zh_en_stt import ZhEnXunFeiSTTModelCredential
from models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
from models_provider.impl.xf_model_provider.model.image import XFSparkImage
from models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
from models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
from models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
from models_provider.impl.xf_model_provider.model.super_humanoid_tts import XFSparkSuperHumanoidTextToSpeech
from models_provider.impl.xf_model_provider.model.default_tts import XFSparkDefaultTextToSpeech
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _
@ -34,8 +38,12 @@ xunfei_model_credential = XunFeiLLMModelCredential()
stt_model_credential = XunFeiSTTModelCredential()
zh_en_stt_credential = ZhEnXunFeiSTTModelCredential()
image_model_credential = XunFeiImageModelCredential()
# TTS credentials
tts_model_credential = XunFeiTTSModelCredential()
super_humanoid_tts_credential = XunFeiSuperHumanoidTTSModelCredential()
default_tts_credential = XunFeiDefaultTTSModelCredential()
embedding_model_credential = XFEmbeddingCredential()
model_info_list = [
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
ModelInfo('generalv3', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
@ -44,7 +52,10 @@ model_info_list = [
XFSparkSpeechToText),
ModelInfo('slm', _('Chinese and English recognition'), ModelTypeConst.STT, zh_en_stt_credential,
XFZhEnSparkSpeechToText),
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
# 具体 TTS 模型
ModelInfo('tts', _('Online TTS'), ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
ModelInfo('tts-super-humanoid', _('Super Humanoid TTS'), ModelTypeConst.TTS, super_humanoid_tts_credential,
XFSparkSuperHumanoidTextToSpeech),
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
]
@ -57,8 +68,9 @@ model_info_manage = (
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
XFSparkSpeechToText),
)
# default TTS 工厂入口
.append_default_model_info(
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech))
ModelInfo('default', _('default'), ModelTypeConst.TTS, default_tts_credential, XFSparkDefaultTextToSpeech))
.append_default_model_info(
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding))
.build()
@ -73,4 +85,4 @@ class XunFeiModelProvider(IModelProvider):
def get_model_provide_info(self):
return ModelProvideInfo(provider='model_xf_provider', name=_('iFlytek Spark'), icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'xf_model_provider', 'icon',
'xf_icon_svg')))
'xf_icon_svg')))