diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/iat_mp3_16k.mp3 b/apps/setting/models_provider/impl/xf_model_provider/model/iat_mp3_16k.mp3 new file mode 100644 index 000000000..75e744c8f Binary files /dev/null and b/apps/setting/models_provider/impl/xf_model_provider/model/iat_mp3_16k.mp3 differ diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py index f57e6bf13..f400473ed 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py +++ b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py @@ -8,6 +8,8 @@ import datetime import hashlib import hmac import json +import logging +import os from datetime import datetime from typing import Dict from urllib.parse import urlencode, urlparse @@ -25,6 +27,7 @@ ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE +max_kb = logging.getLogger("max_kb") class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): spark_app_id: str @@ -89,11 +92,9 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): return url def check_auth(self): - async def check(): - async with websockets.connect(self.create_url(), ssl=ssl_context) as ws: - pass - - asyncio.run(check()) + cwd = os.path.dirname(os.path.abspath(__file__)) + with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f: + self.speech_to_text(f) def speech_to_text(self, file): async def handle(): @@ -112,8 +113,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): sid = message["sid"] if code != 0: errMsg = message["message"] - print("sid:%s call error:%s code is:%s" % (sid, errMsg, code)) - return errMsg + raise Exception(f"sid: {sid} call error: {errMsg} code is: {code}") else: data = message["data"]["result"]["ws"] result = "" diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py index 8d7755832..004b78858 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py +++ b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py @@ -10,6 +10,7 @@ import datetime import hashlib import hmac import json +import logging import os from datetime import datetime from typing import Dict @@ -20,6 +21,8 @@ import websockets from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.impl.base_tts import BaseTextToSpeech +max_kb = logging.getLogger("max_kb") + STATUS_FIRST_FRAME = 0 # 第一帧的标识 STATUS_CONTINUE_FRAME = 1 # 中间帧标识 STATUS_LAST_FRAME = 2 # 最后一帧的标识 @@ -92,11 +95,7 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): return url def check_auth(self): - async def check(): - async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws: - pass - - asyncio.run(check()) + self.text_to_speech("你好") def text_to_speech(self, text): @@ -119,13 +118,13 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): # print(message) code = message["code"] sid = message["sid"] - audio = message["data"]["audio"] - audio = base64.b64decode(audio) if code != 0: errMsg = message["message"] - print("sid:%s call error:%s code is:%s" % (sid, errMsg, code)) + raise Exception(f"sid: {sid} call error: {errMsg} code is: {code}") else: + audio = message["data"]["audio"] + audio = base64.b64decode(audio) audio_bytes += audio # 退出 if message["data"]["status"] == 2: