mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-27 20:42:52 +00:00
178 lines
6.9 KiB
Python
178 lines
6.9 KiB
Python
# -*- coding:utf-8 -*-
|
||
#
|
||
# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
|
||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||
import asyncio
|
||
import base64
|
||
import datetime
|
||
import hashlib
|
||
import hmac
|
||
import json
|
||
import logging
|
||
import os
|
||
import ssl
|
||
from datetime import datetime, UTC
|
||
from typing import Dict
|
||
from urllib.parse import urlencode, urlparse
|
||
|
||
import websockets
|
||
|
||
from models_provider.base_model_provider import MaxKBBaseModel
|
||
from models_provider.impl.base_stt import BaseSpeechToText
|
||
|
||
# Status markers that identify an audio frame's position in the stream.
STATUS_FIRST_FRAME = 0  # marks the first frame
STATUS_CONTINUE_FRAME = 1  # marks an intermediate frame
STATUS_LAST_FRAME = 2  # marks the last frame
|
||
|
||
# TLS client context handed to the WebSocket connection.
# NOTE(review): hostname checking and certificate verification are both
# disabled, so the peer's identity is never authenticated. Confirm that this
# is intentional before relying on it outside trusted networks.
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# check_hostname has to be cleared first; CERT_NONE is rejected while it is on.
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
|
||
|
||
|
||
class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text model that streams audio to a XunFei IAT WebSocket endpoint.

    The audio is sent to ``spark_api_url`` as base64-encoded JSON frames over a
    signed WebSocket connection, and the recognised text is assembled from the
    JSON result messages returned by the server.
    """
    # Credentials and endpoint, taken from the constructor keyword arguments.
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str
    # Raw model parameters; ``send`` forwards only a whitelisted subset of them.
    params: dict

    def __init__(self, **kwargs):
        """Store the endpoint, the credentials and the model parameters.

        All values are read from ``kwargs``; a missing key is stored as ``None``.
        """
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')
        self.params = kwargs.get('params')

    @staticmethod
    def is_cache_model():
        """Report that this model is not a cache model (always ``False``)."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from a credential mapping and model keyword arguments.

        ``model_type`` and ``model_name`` are accepted but not used here.
        ``max_tokens`` and ``temperature`` are passed on to the constructor only
        when they are present and not ``None``; the complete ``model_kwargs``
        mapping is passed as ``params``.
        """
        optional_params = {
            key: model_kwargs[key]
            for key in ('max_tokens', 'temperature')
            if model_kwargs.get(key) is not None
        }
        return XFSparkSpeechToText(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            params=model_kwargs,
            **optional_params
        )

    def create_url(self):
        """Return ``spark_api_url`` with the signed authentication query appended.

        The signature is an HMAC-SHA256 digest, keyed with ``spark_api_secret``,
        over the host, the current GMT date and the fixed request line
        ``GET /v2/iat HTTP/1.1``. The base64-encoded authorization string is
        sent together with ``date`` and ``host`` as query parameters.
        """
        url = self.spark_api_url
        host = urlparse(url).hostname
        # RFC 1123 formatted timestamp in GMT.
        date = datetime.now(UTC).strftime('%a, %d %b %Y %H:%M:%S GMT')

        # The text that gets signed: host, date and request line, one per line.
        signature_origin = "\n".join([
            "host: " + host,
            "date: " + date,
            "GET /v2/iat HTTP/1.1",
        ])
        digest = hmac.new(
            self.spark_api_secret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(digest).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # Authentication parameters appended to the endpoint as a query string.
        query = {
            "authorization": authorization,
            "date": date,
            "host": host,
        }
        return url + '?' + urlencode(query)

    def check_auth(self):
        """Check the credentials by transcribing the sample file next to this module.

        The transcription is discarded; a failure surfaces as whatever
        exception ``speech_to_text`` raises.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        with open(f'{module_dir}/iat_mp3_16k.mp3', 'rb') as f:
            self.speech_to_text(f)

    def speech_to_text(self, file):
        """Transcribe ``file`` (a binary file-like object) and return the text.

        Opens a WebSocket connection to the signed URL, streams the audio and
        collects the recognition result. The coroutine is driven with
        ``asyncio.run``, so this method blocks until the transcription is done.
        """
        async def handle():
            async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
                # Send the whole audio stream, then read the recognition result.
                await self.send(ws, file)
                return await self.handle_message(ws)

        return asyncio.run(handle())

    @staticmethod
    async def handle_message(ws):
        """Read result messages from ``ws`` and return the recognised text.

        The server reports the recognition result in several messages, so
        reading continues until a message carries ``data.status == 2`` (the
        last result block) or the server closes the connection normally.
        When a result has ``pgs == "rpl"`` (dynamic correction), the earlier
        results whose sequence numbers fall inside its ``rg`` range are
        discarded before the new one is stored.

        Raises:
            Exception: if a message carries a non-zero ``code``.
        """
        final_status = 2  # data.status value that marks the last result block
        segments = {}  # result sequence number -> text of that result
        received_any = False

        while True:
            try:
                raw = await ws.recv()
            except websockets.ConnectionClosedOK:
                # A normal close after at least one result ends the stream;
                # a close before any result is still reported to the caller.
                if not received_any:
                    raise
                break
            received_any = True

            message = json.loads(raw)
            code = message["code"]
            sid = message["sid"]
            if code != 0:
                err_msg = message["message"]
                raise Exception(f"sid: {sid} call error: {err_msg} code is: {code}")

            data = message.get("data") or {}
            result = data.get("result")
            if result:
                if result.get("pgs") == "rpl":
                    # This result supersedes the results numbered rg[0]..rg[1].
                    first_sn, last_sn = result["rg"]
                    for stale_sn in range(first_sn, last_sn + 1):
                        segments.pop(stale_sn, None)
                sn = result.get("sn", max(segments, default=0) + 1)
                segments[sn] = "".join(
                    word["w"]
                    for item in result.get("ws", [])
                    for word in item["cw"]
                )

            if data.get("status") == final_status:
                break

        return "".join(segments[sn] for sn in sorted(segments))

    async def send(self, ws, file):
        """Stream the audio in ``file`` to ``ws`` as JSON frames.

        The first frame carries the application id and the whitelisted business
        parameters from ``self.params`` (``None`` is treated as empty). It is
        always sent, even when ``file`` yields no data. Every further chunk is
        sent as an intermediate frame, and a final frame with empty audio marks
        the end of the stream.
        """
        frame_size = 8000  # number of audio bytes per frame

        allowed_params = {'language', 'domain', 'accent', 'vad_eos', 'dwa', 'pd', 'ptt',
                          'pcm', 'ltc', 'rlang', 'vinfo', 'nunum', 'speex_size', 'nbest', 'wbest'}
        business_params = {k: v for k, v in (self.params or {}).items() if k in allowed_params}

        def audio_payload(status, chunk):
            # "data" section shared by every frame; only status and audio vary.
            return {
                "status": status,
                "format": "audio/L16;rate=16000",
                "audio": base64.b64encode(chunk).decode('utf-8'),
                "encoding": "lame",
            }

        # First frame: the only one carrying the app id and business parameters.
        chunk = file.read(frame_size)
        await ws.send(json.dumps({
            "common": {"app_id": self.spark_app_id},
            "business": business_params,
            "data": audio_payload(STATUS_FIRST_FRAME, chunk),
        }))

        # Intermediate frames: one per remaining non-empty chunk.
        while chunk:
            chunk = file.read(frame_size)
            if chunk:
                await ws.send(json.dumps({"data": audio_payload(STATUS_CONTINUE_FRAME, chunk)}))

        # Last frame: empty audio, signals the end of the stream.
        await ws.send(json.dumps({"data": audio_payload(STATUS_LAST_FRAME, b"")}))
|