mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-27 20:42:52 +00:00
153 lines
5.6 KiB
Python
153 lines
5.6 KiB
Python
# -*- coding:utf-8 -*-
|
||
#
|
||
# author: iflytek
|
||
#
|
||
# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
import asyncio
|
||
import base64
|
||
import datetime
|
||
import hashlib
|
||
import hmac
|
||
import json
|
||
import logging
|
||
import ssl
|
||
from datetime import datetime, UTC
|
||
from typing import Dict
|
||
from urllib.parse import urlencode, urlparse
|
||
|
||
import websockets
|
||
from django.utils.translation import gettext as _
|
||
|
||
from common.utils.common import _remove_empty_lines
|
||
from models_provider.base_model_provider import MaxKBBaseModel
|
||
from models_provider.impl.base_tts import BaseTextToSpeech
|
||
|
||
# Frame status markers: first frame, intermediate (continuation) frame, last frame.
STATUS_FIRST_FRAME, STATUS_CONTINUE_FRAME, STATUS_LAST_FRAME = 0, 1, 2
|
||
|
||
# Module-wide TLS client context passed to websockets.connect() for the xfyun endpoint.
# NOTE(review): both certificate and hostname verification are disabled here, which
# leaves the connection open to man-in-the-middle interception. Consider enabling
# verification unless a deployment genuinely cannot validate the server certificate.
# check_hostname must be cleared before verify_mode may be set to CERT_NONE; the
# tuple assignment below assigns its targets left to right, preserving that order.
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname, ssl_context.verify_mode = False, ssl.CERT_NONE
|
||
|
||
|
||
class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Text-to-speech client for the iFlytek (xfyun) Spark online TTS WebSocket API.

    The handshake is authenticated through signed query parameters (see
    ``create_url``). The whole text is sent as a single final frame (see
    ``send``), and the base64 audio chunks returned by the server are
    concatenated into one ``bytes`` object (see ``handle_message``).

    Error codes: https://www.xfyun.cn/document/error-code
    """

    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str
    # Extra request options merged over the default ``business`` section in ``send``.
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')
        self.params = kwargs.get('params')

    @staticmethod
    def is_cache_model():
        """Instances of this model are not cached.

        Defined once as a staticmethod, so it works both on the class and on an
        instance.
        """
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create a client from stored credentials plus per-call options.

        Every entry of ``model_kwargs`` except the framework keys ``model_id``,
        ``use_local`` and ``streaming`` is placed into ``params``. These entries
        override the defaults ``vcn='xiaoyan'`` and ``speed=50``.
        """
        framework_keys = {'model_id', 'use_local', 'streaming'}
        params = {'vcn': 'xiaoyan', 'speed': 50}
        params.update({key: value for key, value in model_kwargs.items() if key not in framework_keys})
        return XFSparkTextToSpeech(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            params=params,
        )

    def create_url(self):
        """Build the authenticated WebSocket URL for the TTS endpoint.

        An HMAC-SHA256 signature (keyed with ``spark_api_secret``) is computed
        over the ``host``, ``date`` and request-line values. It is combined with
        ``spark_api_key`` into a base64 ``authorization`` value.

        Returns:
            ``spark_api_url`` with ``authorization``, ``date`` and ``host``
            URL-encoded onto its query string.

        Raises:
            ValueError: if ``spark_api_url`` has no host component.
        """
        from email.utils import formatdate

        url = self.spark_api_url
        parsed = urlparse(url)
        host = parsed.hostname
        if not host:
            raise ValueError(f"Invalid spark_api_url: {url!r}")
        # The signed request line must match the path the socket actually requests;
        # keep the historical default when the configured URL has no path.
        path = parsed.path or '/v2/tts'
        # RFC 1123 timestamp in GMT. formatdate() always emits English day/month
        # names, whereas strftime('%a'/'%b') follows the process locale.
        date = formatdate(usegmt=True)

        signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
        signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'),
                                 signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature = base64.b64encode(signature_sha).decode('utf-8')

        authorization_origin = (
            f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8')

        query = urlencode({"authorization": authorization, "date": date, "host": host})
        # Use '&' if the configured URL already carries a query string.
        separator = '&' if parsed.query else '?'
        return f"{url}{separator}{query}"

    def check_auth(self):
        """Verify the credentials by synthesizing a short sample text."""
        self.text_to_speech(_('Hello'))

    def text_to_speech(self, text):
        """Synthesize ``text`` and return the audio as ``bytes``.

        The text is first passed through ``_remove_empty_lines``. The call
        blocks: it drives the WebSocket exchange with ``asyncio.run``, so it
        must not be invoked from a thread that already has a running event loop.

        Note: the text is sent UTF-8 encoded (``tte='utf8'`` in ``send``). The
        original upstream note says minor languages instead require the
        "unicode" encoding, meaning UTF-16 little-endian ("UTF-16LE"). That
        mode is not implemented here.
        """
        text = _remove_empty_lines(text)

        async def handle():
            async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
                await self.send(ws, text)
                return await self.handle_message(ws)

        return asyncio.run(handle())

    @staticmethod
    async def handle_message(ws):
        """Receive frames from ``ws`` until the last frame and return the joined audio.

        Raises:
            Exception: if a frame carries a non-zero ``code``.
        """
        chunks = []
        while True:
            message = json.loads(await ws.recv())
            code = message["code"]
            if code != 0:
                raise Exception(
                    f"sid: {message.get('sid')} call error: {message.get('message')} code is: {code}")
            data = message["data"]
            chunks.append(base64.b64decode(data["audio"]))
            if data["status"] == STATUS_LAST_FRAME:
                break
        # Join once instead of repeatedly concatenating immutable bytes.
        return b''.join(chunks)

    async def send(self, ws, text):
        """Send ``text`` as a single, final request frame.

        ``self.params`` (voice, speed, ...) overrides the default ``business``
        options. A missing ``params`` is treated as no overrides.
        """
        business = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "tte": "utf8"}
        payload = {
            "common": {"app_id": self.spark_app_id},
            "business": business | (self.params or {}),
            "data": {
                "status": STATUS_LAST_FRAME,
                "text": base64.b64encode(text.encode('utf-8')).decode('utf-8'),
            },
        }
        await ws.send(json.dumps(payload))