mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: enhance TTS model parameters and API version handling (#4480)
Co-authored-by: xiaoc <648844981@qq.com>
This commit is contained in:
parent
4e58d079bf
commit
532eee571f
|
|
@ -0,0 +1,133 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
讯飞 TTS 工厂类 Credential,根据 api_version 路由到具体 Credential
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
|
||||
class XunFeiDefaultTTSModelCredential(BaseForm, BaseModelCredential):
    """Credential form for the iFlytek (XunFei) default TTS entry.

    The ``api_version`` value selects which API URL field ``is_valid`` requires
    (``spark_api_url`` for ``online``, ``spark_api_url_super`` for
    ``super_humanoid``). URL and speaker fields carry a
    ``relation_show_field_dict`` keyed on ``api_version``.
    """

    api_version = forms.SingleSelect(
        _("API Version"),
        required=True,
        text_field='label',
        value_field='value',
        default_value='online',
        option_list=[
            {'label': _('Online TTS'), 'value': 'online'},
            {'label': _('Super Humanoid TTS'), 'value': 'super_humanoid'},
        ],
    )

    spark_api_url = forms.TextInputField(
        'API URL',
        required=True,
        default_value='wss://tts-api.xfyun.cn/v2/tts',
        relation_show_field_dict={"api_version": ["online"]},
    )
    spark_api_url_super = forms.TextInputField(
        'API URL',
        required=True,
        default_value='wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6',
        relation_show_field_dict={"api_version": ["super_humanoid"]},
    )

    # Speaker (vcn) selection lives in the credential; its display is linked
    # to the chosen api_version.
    vcn_online = forms.SingleSelect(
        TooltipLabel(_('Speaker'), _('Speaker selection for standard TTS service')),
        required=True,
        default_value='xiaoyan',
        text_field='value',
        value_field='value',
        option_list=[
            {'text': _('iFlytek Xiaoyan'), 'value': 'xiaoyan'},
            {'text': _('iFlytek Xujiu'), 'value': 'aisjiuxu'},
            {'text': _('iFlytek Xiaoping'), 'value': 'aisxping'},
            {'text': _('iFlytek Xiaojing'), 'value': 'aisjinger'},
            {'text': _('iFlytek Xuxiaobao'), 'value': 'aisbabyxu'},
        ],
        relation_show_field_dict={"api_version": ["online"]},
    )

    vcn_super = forms.SingleSelect(
        TooltipLabel(_('Speaker'), _('Speaker selection for super-humanoid TTS service')),
        required=True,
        default_value='x5_lingxiaoxuan_flow',
        text_field='value',
        value_field='value',
        option_list=[
            {'text': _('Super-humanoid: Lingxiaoxuan Flow'), 'value': 'x5_lingxiaoxuan_flow'},
            {'text': _('Super-humanoid: Lingyuyan Flow'), 'value': 'x5_lingyuyan_flow'},
            {'text': _('Super-humanoid: Lingfeiyi Flow'), 'value': 'x5_lingfeiyi_flow'},
            {'text': _('Super-humanoid: Lingxiaoyue Flow'), 'value': 'x5_lingxiaoyue_flow'},
            {'text': _('Super-humanoid: Sun Dasheng Flow'), 'value': 'x5_sundasheng_flow'},
            {'text': _('Super-humanoid: Lingyuzhao Flow'), 'value': 'x5_lingyuzhao_flow'},
            {'text': _('Super-humanoid: Lingxiaotang Flow'), 'value': 'x5_lingxiaotang_flow'},
            {'text': _('Super-humanoid: Lingxiaorong Flow'), 'value': 'x5_lingxiaorong_flow'},
            {'text': _('Super-humanoid: Xinyun Flow'), 'value': 'x5_xinyun_flow'},
            {'text': _('Super-humanoid: Grant (EN)'), 'value': 'x5_EnUs_Grant_flow'},
            {'text': _('Super-humanoid: Lila (EN)'), 'value': 'x5_EnUs_Lila_flow'},
            {'text': _('Super-humanoid: Lingwanwan Pro'), 'value': 'x6_lingwanwan_pro'},
            {'text': _('Super-humanoid: Yiyi Pro'), 'value': 'x6_yiyi_pro'},
            {'text': _('Super-humanoid: Huifangnv Pro'), 'value': 'x6_huifangnv_pro'},
            {'text': _('Super-humanoid: Lingxiaoying Pro'), 'value': 'x6_lingxiaoying_pro'},
            {'text': _('Super-humanoid: Lingfeibo Pro'), 'value': 'x6_lingfeibo_pro'},
            {'text': _('Super-humanoid: Lingyuyan Pro'), 'value': 'x6_lingyuyan_pro'},
        ],
        relation_show_field_dict={"api_version": ["super_humanoid"]},
    )

    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        """Validate the credential and probe the service through ``check_auth``.

        :param model_type: model type value; must be one of the provider's types.
        :param model_name: model name forwarded to ``provider.get_model``.
        :param model_credential: submitted credential values.
        :param model_params: extra keyword parameters forwarded to ``provider.get_model``.
        :param provider: object exposing ``get_model_type_list`` and ``get_model``.
        :param raise_exception: when true, failures raise ``AppApiException``
            instead of returning ``False``.
        :return: ``True`` when every check passes, otherwise ``False``.
        :raises AppApiException: always for an unsupported model type, and for an
            ``AppApiException`` raised during the probe; for other failures only
            when ``raise_exception`` is true.
        """
        matching_types = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        # Only the URL key differs between the two API versions.
        if model_credential.get('api_version', 'online') == 'super_humanoid':
            url_key = 'spark_api_url_super'
        else:
            url_key = 'spark_api_url'

        for key in (url_key, 'spark_app_id', 'spark_api_key', 'spark_api_secret'):
            if key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))

        try:
            provider.get_model(model_type, model_name, model_credential, **model_params).check_auth()
        except Exception as e:
            maxkb_logger.error(f'Exception: {e}', exc_info=True)
            # Application-level errors are passed through untouched.
            if isinstance(e, AppApiException):
                raise e
            if not raise_exception:
                return False
            raise AppApiException(
                ValidCode.valid_error.value,
                gettext('Verification failed, please check whether the parameters are correct: {error}').format(
                    error=str(e)))
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``spark_api_secret`` passed through ``encryption``."""
        masked_secret = super().encryption(model.get('spark_api_secret', ''))
        return {**model, 'spark_api_secret': masked_secret}

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form; it holds only shared parameters since vcn is in the credential."""
        return XunFeiDefaultTTSModelParams()
|
||||
|
||||
|
||||
class XunFeiDefaultTTSModelParams(BaseForm):
    """Parameter form for the default TTS entry; it only declares the speed slider."""

    speed = forms.SliderField(
        TooltipLabel(
            _('speaking speed'),
            _('Speech speed, optional value: [0-100], default is 50'),
        ),
        required=True,
        default_value=50,
        _min=1,
        _max=100,
        _step=5,
        precision=1,
    )
|
||||
|
|
@ -0,0 +1,93 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
讯飞超拟人语音合成 (Super Humanoid TTS) Credential
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class XunFeiSuperHumanoidTTSModelParams(BaseForm):
    """Parameter form for Super Humanoid TTS: speaker (``vcn``) and speaking speed."""

    vcn = forms.SingleSelect(
        TooltipLabel(_('Speaker'), _('Speaker selection for super-humanoid TTS service')),
        required=True,
        default_value='x5_lingxiaoxuan_flow',
        text_field='value',
        value_field='value',
        option_list=[
            {'text': _('Super-humanoid: Lingxiaoxuan Flow'), 'value': 'x5_lingxiaoxuan_flow'},
            {'text': _('Super-humanoid: Lingyuyan Flow'), 'value': 'x5_lingyuyan_flow'},
            {'text': _('Super-humanoid: Lingfeiyi Flow'), 'value': 'x5_lingfeiyi_flow'},
            {'text': _('Super-humanoid: Lingxiaoyue Flow'), 'value': 'x5_lingxiaoyue_flow'},
            {'text': _('Super-humanoid: Sun Dasheng Flow'), 'value': 'x5_sundasheng_flow'},
            {'text': _('Super-humanoid: Lingyuzhao Flow'), 'value': 'x5_lingyuzhao_flow'},
            {'text': _('Super-humanoid: Lingxiaotang Flow'), 'value': 'x5_lingxiaotang_flow'},
            {'text': _('Super-humanoid: Lingxiaorong Flow'), 'value': 'x5_lingxiaorong_flow'},
            {'text': _('Super-humanoid: Xinyun Flow'), 'value': 'x5_xinyun_flow'},
            {'text': _('Super-humanoid: Grant (EN)'), 'value': 'x5_EnUs_Grant_flow'},
            {'text': _('Super-humanoid: Lila (EN)'), 'value': 'x5_EnUs_Lila_flow'},
            {'text': _('Super-humanoid: Lingwanwan Pro'), 'value': 'x6_lingwanwan_pro'},
            {'text': _('Super-humanoid: Yiyi Pro'), 'value': 'x6_yiyi_pro'},
            {'text': _('Super-humanoid: Huifangnv Pro'), 'value': 'x6_huifangnv_pro'},
            {'text': _('Super-humanoid: Lingxiaoying Pro'), 'value': 'x6_lingxiaoying_pro'},
            {'text': _('Super-humanoid: Lingfeibo Pro'), 'value': 'x6_lingfeibo_pro'},
            {'text': _('Super-humanoid: Lingyuyan Pro'), 'value': 'x6_lingyuyan_pro'},
        ],
    )

    speed = forms.SliderField(
        TooltipLabel(
            _('speaking speed'),
            _('Speech speed, optional value: [0-100], default is 50'),
        ),
        required=True,
        default_value=50,
        _min=1,
        _max=100,
        _step=5,
        precision=1,
    )
|
||||
|
||||
|
||||
class XunFeiSuperHumanoidTTSModelCredential(BaseForm, BaseModelCredential):
    """Credential form for the iFlytek Super Humanoid TTS model."""

    spark_api_url = forms.TextInputField(
        'API URL',
        required=True,
        default_value='wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6',
    )
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        """Validate the credential and probe the service through ``check_auth``.

        :param model_type: model type value; must be one of the provider's types.
        :param model_name: model name forwarded to ``provider.get_model``.
        :param model_credential: submitted credential values.
        :param model_params: extra keyword parameters forwarded to ``provider.get_model``.
        :param provider: object exposing ``get_model_type_list`` and ``get_model``.
        :param raise_exception: when true, failures raise ``AppApiException``
            instead of returning ``False``.
        :return: ``True`` when every check passes, otherwise ``False``.
        :raises AppApiException: always for an unsupported model type, and for an
            ``AppApiException`` raised during the probe; for other failures only
            when ``raise_exception`` is true.
        """
        matching_types = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))

        # Report only the first missing key, in declaration order.
        expected_keys = ('spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret')
        missing = next((key for key in expected_keys if key not in model_credential), None)
        if missing is not None:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=missing))

        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            model.check_auth()
        except AppApiException:
            # Application-level errors are passed through untouched.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(
                ValidCode.valid_error.value,
                gettext('Verification failed, please check whether the parameters are correct: {error}').format(
                    error=str(e)))
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``spark_api_secret`` passed through ``encryption``."""
        masked_secret = super().encryption(model.get('spark_api_secret', ''))
        return {**model, 'spark_api_secret': masked_secret}

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form (speaker and speed) for this model."""
        return XunFeiSuperHumanoidTTSModelParams()
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:
|
||||
@file: default_tts.py
|
||||
@date:2025/12/9
|
||||
@desc: 讯飞 TTS 工厂类,根据 api_version 路由到具体实现
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
from models_provider.impl.base_tts import BaseTextToSpeech
|
||||
|
||||
|
||||
class XFSparkDefaultTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Factory entry for iFlytek TTS.

    ``new_instance`` does not build an instance of this class: it reads
    ``api_version`` from the credential and delegates to the ``new_instance`` of
    ``XFSparkSuperHumanoidTextToSpeech`` or ``XFSparkTextToSpeech``.
    """

    def check_auth(self):
        # Placeholder: the objects produced by new_instance are of the concrete
        # TTS classes, so this body is not what runs on them.
        pass

    def text_to_speech(self, text):
        # Placeholder, see check_auth.
        pass

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build the concrete TTS model selected by ``model_credential['api_version']``.

        :param model_type: forwarded unchanged to the concrete ``new_instance``.
        :param model_name: forwarded unchanged to the concrete ``new_instance``.
        :param model_credential: factory credential; provides ``api_version``, the
            shared app id / key / secret, and the version-specific URL and speaker
            (``spark_api_url_super`` + ``vcn_super`` or ``spark_api_url`` + ``vcn_online``).
        :param model_kwargs: extra model parameters, forwarded to the concrete class.
        :return: whatever the selected concrete ``new_instance`` returns.
        """
        # Imported here rather than at module level, as in the original code.
        from models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
        from models_provider.impl.xf_model_provider.model.super_humanoid_tts import XFSparkSuperHumanoidTextToSpeech

        if model_credential.get('api_version', 'online') == 'super_humanoid':
            target = XFSparkSuperHumanoidTextToSpeech
            url_key, vcn_key, default_vcn = 'spark_api_url_super', 'vcn_super', 'x5_lingxiaoxuan_flow'
        else:
            target = XFSparkTextToSpeech
            url_key, vcn_key, default_vcn = 'spark_api_url', 'vcn_online', 'xiaoyan'

        # Normalise to the key names the concrete classes read.
        unified_credential = {
            'spark_app_id': model_credential.get('spark_app_id'),
            'spark_api_key': model_credential.get('spark_api_key'),
            'spark_api_secret': model_credential.get('spark_api_secret'),
            'spark_api_url': model_credential.get(url_key),
        }

        # The speaker comes from the credential. Merge instead of passing
        # ``vcn=...`` next to ``**model_kwargs``: a 'vcn' key inside model_kwargs
        # would otherwise raise TypeError (duplicate keyword argument).
        forwarded = {'vcn': model_credential.get(vcn_key, default_vcn)}
        forwarded.update((key, value) for key, value in model_kwargs.items() if key != 'vcn')

        return target.new_instance(model_type, model_name, unified_credential, **forwarded)
|
||||
|
|
@ -0,0 +1,195 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
#
|
||||
# author: iflytek
|
||||
#
|
||||
# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import ssl
|
||||
from datetime import datetime, UTC
|
||||
from typing import Dict
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import websockets
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from common.utils.common import _remove_empty_lines
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
from models_provider.impl.base_tts import BaseTextToSpeech
|
||||
|
||||
# Shared client-side TLS context for the WebSocket connection.
# NOTE(review): hostname checking and certificate verification are both turned
# off here, so the server certificate is not validated — confirm this is intended.
# check_hostname must be cleared before verify_mode can be set to CERT_NONE.
ssl_context = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
|
||||
class XFSparkSuperHumanoidTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """iFlytek Super Humanoid TTS client that talks to the service over a WebSocket."""
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')
        # A missing or None 'params' becomes an empty dict so .get() is always safe.
        self.params = kwargs.get('params') or {}

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance from a credential dict and extra model parameters.

        ``vcn`` (default ``x5_lingxiaoxuan_flow``) and every other keyword except
        ``model_id``, ``use_local`` and ``streaming`` are collected into ``params``.
        """
        params = {'vcn': model_kwargs.get('vcn', 'x5_lingxiaoxuan_flow')}
        for k, v in model_kwargs.items():
            if k not in ['model_id', 'use_local', 'streaming', 'vcn']:
                params[k] = v

        return XFSparkSuperHumanoidTextToSpeech(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            params=params
        )

    def create_url(self):
        """Return ``spark_api_url`` with the signed ``authorization``/``date``/``host`` query appended.

        The signature is an HMAC-SHA256 (keyed with the API secret) over the
        host, date and request line, base64-encoded.
        """
        url = self.spark_api_url
        parsed = urlparse(url)  # parse once; host and path are both needed
        host = parsed.hostname

        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
        date = datetime.now(UTC).strftime(gmt_format)

        signature_origin = f"host: {host}\n"
        signature_origin += f"date: {date}\n"
        signature_origin += f"GET {parsed.path} HTTP/1.1"

        signature_sha = hmac.new(
            self.spark_api_secret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            digestmod=hashlib.sha256
        ).digest()
        signature_sha = base64.b64encode(signature_sha).decode('utf-8')

        authorization_origin = \
            f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha}"'
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8')

        query = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        return url + '?' + urlencode(query)

    def check_auth(self):
        """Verify the credentials by synthesising a short greeting."""
        self.text_to_speech(_('Hello'))

    def text_to_speech(self, text):
        """Synthesise ``text`` and return the decoded audio bytes.

        :raises Exception: on an HTTP handshake failure (a dedicated message for
            401) or any other error while talking to the service; the original
            error is attached as ``__cause__``.
        """
        text = _remove_empty_lines(text)

        async def handle():
            try:
                async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
                    await self.send(ws, text)
                    return await self.handle_message(ws)
            except websockets.exceptions.InvalidStatus as e:
                if e.response.status_code == 401:
                    raise Exception(
                        _("Authentication failed (HTTP 401). Please check: "
                          "1) API URL is correct for TTS service; "
                          "2) APP ID, API Key, and API Secret are correct; "
                          "3) Your iFlytek account has TTS service enabled.")
                    ) from e
                raise Exception(f"WebSocket connection failed: HTTP {e.response.status_code}") from e
            except Exception as e:
                # Errors that already describe an authentication failure are not re-wrapped.
                if "Authentication failed" in str(e):
                    raise
                raise Exception(f"iFlytek TTS service error: {str(e)}") from e

        return asyncio.run(handle())

    @staticmethod
    async def handle_message(ws):
        """Read frames from ``ws`` until the audio status is 2 and return the joined audio.

        :raises Exception: when a frame reports a non-zero header code, or has no
            ``payload.audio`` section.
        """
        # Collect chunks and join once instead of repeatedly concatenating bytes.
        chunks = []
        while True:
            message = json.loads(await ws.recv())

            if "header" in message and "code" in message["header"]:
                code = message["header"]["code"]
                sid = message["header"].get("sid", "unknown")
                if code != 0:
                    err_msg = message["header"].get("message", "Unknown error")
                    raise Exception(f"sid: {sid} call error: {err_msg} code is: {code}")

            if "payload" in message and "audio" in message["payload"]:
                audio_frame = message["payload"]["audio"]
                # Tolerate a frame that carries only a status and no audio data.
                chunks.append(base64.b64decode(audio_frame.get("audio") or ""))
                if audio_frame.get("status") == 2:
                    break
            else:
                raise Exception(
                    f"Unexpected response from iFlytek API. Response: {json.dumps(message, ensure_ascii=False)}"
                )

        return b"".join(chunks)

    async def send(self, ws, text):
        """Send the synthesis request for ``text`` as a single JSON frame on ``ws``."""
        vcn_value = self.params.get("vcn", "x5_lingxiaoxuan_flow")

        # Make sure the vcn value follows the super-humanoid naming (x5_* / x6_*);
        # anything else falls back to the default speaker.
        if not vcn_value or not str(vcn_value).startswith(('x5_', 'x6_')):
            vcn_value = 'x5_lingxiaoxuan_flow'

        audio_params = {
            "encoding": self.params.get("encoding", "lame"),
            "sample_rate": self.params.get("sample_rate", 24000),
            "channels": self.params.get("channels", 1),
            "bit_depth": self.params.get("bit_depth", 16),
            "frame_size": self.params.get("frame_size", 0)
        }

        tts_params = {
            "vcn": vcn_value,
            "audio": audio_params,
            "volume": self.params.get("volume", 50),
            "speed": self.params.get("speed", 50),
            "pitch": self.params.get("pitch", 50)
        }

        # The text travels base64-encoded inside the payload.
        encoded_text = base64.b64encode(text.encode('utf-8')).decode('utf-8')
        payload_text_obj = {
            "encoding": "utf8",
            "compress": "raw",
            "format": "plain",
            "status": 2,
            "seq": 0,
            "text": encoded_text
        }

        request = {
            "header": {"app_id": self.spark_app_id, "status": 2},
            "parameter": {"tts": tts_params},
            "payload": {"text": payload_text_obj}
        }

        await ws.send(json.dumps(request))
|
||||
|
|
@ -17,12 +17,16 @@ from models_provider.impl.xf_model_provider.credential.image import XunFeiImageM
|
|||
from models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
||||
from models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
|
||||
from models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
|
||||
from models_provider.impl.xf_model_provider.credential.super_humanoid_tts import XunFeiSuperHumanoidTTSModelCredential
|
||||
from models_provider.impl.xf_model_provider.credential.default_tts import XunFeiDefaultTTSModelCredential
|
||||
from models_provider.impl.xf_model_provider.credential.zh_en_stt import ZhEnXunFeiSTTModelCredential
|
||||
from models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
|
||||
from models_provider.impl.xf_model_provider.model.image import XFSparkImage
|
||||
from models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
||||
from models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
|
||||
from models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
|
||||
from models_provider.impl.xf_model_provider.model.super_humanoid_tts import XFSparkSuperHumanoidTextToSpeech
|
||||
from models_provider.impl.xf_model_provider.model.default_tts import XFSparkDefaultTextToSpeech
|
||||
from maxkb.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
|
|
@ -34,8 +38,12 @@ xunfei_model_credential = XunFeiLLMModelCredential()
|
|||
stt_model_credential = XunFeiSTTModelCredential()
|
||||
zh_en_stt_credential = ZhEnXunFeiSTTModelCredential()
|
||||
image_model_credential = XunFeiImageModelCredential()
|
||||
# TTS credentials
|
||||
tts_model_credential = XunFeiTTSModelCredential()
|
||||
super_humanoid_tts_credential = XunFeiSuperHumanoidTTSModelCredential()
|
||||
default_tts_credential = XunFeiDefaultTTSModelCredential()
|
||||
embedding_model_credential = XFEmbeddingCredential()
|
||||
|
||||
model_info_list = [
|
||||
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
|
||||
ModelInfo('generalv3', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
|
||||
|
|
@ -44,7 +52,10 @@ model_info_list = [
|
|||
XFSparkSpeechToText),
|
||||
ModelInfo('slm', _('Chinese and English recognition'), ModelTypeConst.STT, zh_en_stt_credential,
|
||||
XFZhEnSparkSpeechToText),
|
||||
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
|
||||
# 具体 TTS 模型
|
||||
ModelInfo('tts', _('Online TTS'), ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
|
||||
ModelInfo('tts-super-humanoid', _('Super Humanoid TTS'), ModelTypeConst.TTS, super_humanoid_tts_credential,
|
||||
XFSparkSuperHumanoidTextToSpeech),
|
||||
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
|
||||
]
|
||||
|
||||
|
|
@ -57,8 +68,9 @@ model_info_manage = (
|
|||
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
|
||||
XFSparkSpeechToText),
|
||||
)
|
||||
# default TTS 工厂入口
|
||||
.append_default_model_info(
|
||||
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech))
|
||||
ModelInfo('default', _('default'), ModelTypeConst.TTS, default_tts_credential, XFSparkDefaultTextToSpeech))
|
||||
.append_default_model_info(
|
||||
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding))
|
||||
.build()
|
||||
|
|
@ -73,4 +85,4 @@ class XunFeiModelProvider(IModelProvider):
|
|||
def get_model_provide_info(self):
|
||||
return ModelProvideInfo(provider='model_xf_provider', name=_('iFlytek Spark'), icon=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'xf_model_provider', 'icon',
|
||||
'xf_icon_svg')))
|
||||
'xf_icon_svg')))
|
||||
Loading…
Reference in New Issue