feat: Support iFLYTEK large model for Chinese-English speech recognition

This commit is contained in:
zhangzhanwei 2025-08-28 10:54:15 +08:00 committed by zhanweizhang7
parent f9f96fd2cd
commit 4786970689
3 changed files with 258 additions and 2 deletions

View File

@ -0,0 +1,56 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext as _
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
class ZhEnXunFeiSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form for the iFLYTEK Chinese-English speech-to-text model.

    Declares the four connection fields and validates a submitted credential
    dict by building the model through the provider and calling its
    ``check_auth`` method.
    """
    spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat.xf-yun.com/v1')
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self,
                 model_type: str,
                 model_name,
                 model_credential: Dict[str, object],
                 model_params, provider,
                 raise_exception=False):
        """Check that ``model_credential`` is complete and accepted by the model.

        :param model_type: model type value; must be one the provider lists.
        :param model_name: model name forwarded to ``provider.get_model``.
        :param model_credential: submitted credential values keyed by field name.
        :param model_params: not used by this implementation.
        :param provider: object exposing ``get_model_type_list`` and ``get_model``.
        :param raise_exception: when true, failures raise instead of returning ``False``.
        :return: ``True`` when every check passes, otherwise ``False``.
        :raises AppApiException: always for an unsupported model type or when
            the model check itself raises one; for the other failures only
            when ``raise_exception`` is true.
        """
        type_is_supported = any(
            entry.get('value') == model_type for entry in provider.get_model_type_list()
        )
        if not type_is_supported:
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        required_keys = ('spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret')
        # Only the first absent key is reported.
        missing_key = next((name for name in required_keys if name not in model_credential), None)
        if missing_key is not None:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('{key} is required').format(key=missing_key))
            return False

        try:
            provider.get_model(model_type, model_name, model_credential).check_auth()
        except AppApiException:
            # Application-level errors pass through untouched.
            traceback.print_exc()
            raise
        except Exception as error:
            traceback.print_exc()
            if raise_exception:
                raise AppApiException(
                    ValidCode.valid_error.value,
                    _('Verification failed, please check whether the parameters are correct: {error}').format(
                        error=str(error)))
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``spark_api_secret`` value is passed through ``encryption``."""
        masked = dict(model)
        masked['spark_api_secret'] = super().encryption(model.get('spark_api_secret', ''))
        return masked

    def get_model_params_setting_form(self, model_name):
        """Return ``None``; no parameter-settings form is defined for this model."""
        return None

View File

@ -0,0 +1,192 @@
import asyncio
import json
import base64
import hmac
import hashlib
import ssl
import traceback
from typing import Dict
from urllib.parse import urlencode
from datetime import datetime, timezone, UTC
import websockets
import os
from future.backports.urllib.parse import urlparse
from common.utils.logger import maxkb_logger
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_stt import BaseSpeechToText
# TLS context handed to websockets.connect() by the speech-to-text model below.
# NOTE(review): hostname checking and certificate verification are both switched
# off, so the remote server is not authenticated — confirm this is intentional.
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# check_hostname has to be disabled first; the ssl module rejects CERT_NONE
# while hostname checking is still enabled.
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text model that streams audio to an iFLYTEK WebSocket endpoint.

    Audio is sent as base64-encoded JSON frames; the recognized text is
    assembled from the JSON results received on the same connection.
    """
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str

    def __init__(self, **kwargs):
        """Store the endpoint URL and the credentials found in ``kwargs``."""
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')

    @staticmethod
    def is_cache_model():
        """Always return ``False``."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from a credential dict.

        ``max_tokens`` and ``temperature`` are forwarded to the constructor
        when present and not ``None``; ``model_type`` and ``model_name`` are
        not used here.
        """
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return XFZhEnSparkSpeechToText(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            **optional_params
        )

    def create_url(self):
        """Return the endpoint URL with HMAC-SHA256 authentication query parameters.

        The signed text is the host, the current UTC date and the HTTP request
        line. The request line uses the path of the configured URL, so a
        non-default endpoint path is signed correctly.

        :raises ValueError: if the configured URL has no host component.
        """
        url = self.spark_api_url
        # NOTE(review): urlparse is imported from future.backports at file level;
        # the stdlib urllib.parse.urlparse should be a drop-in replacement.
        parsed = urlparse(url)
        host = parsed.hostname
        if not host:
            raise ValueError(f'Invalid spark_api_url: {url!r}')
        # The default URL ends in "/v1", which yields the previously hard-coded request line.
        path = parsed.path or '/'
        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
        date = datetime.now(UTC).strftime(gmt_format)

        # Text to sign: host, date and request line, newline separated.
        signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
        signature_sha = hmac.new(
            self.spark_api_secret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            hashlib.sha256
        ).digest()
        signature = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
        params = {
            'authorization': authorization,
            'date': date,
            'host': host
        }
        return url + '?' + urlencode(params)

    def check_auth(self):
        """Verify the credentials by recognizing the bundled sample recording.

        Unlike ``speech_to_text`` this does not swallow errors: connection,
        authentication and service failures propagate to the caller so that
        credential validation can actually fail.
        """
        cwd = os.path.dirname(os.path.abspath(__file__))
        with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f:
            asyncio.run(self._recognize(f))

    def speech_to_text(self, audio_file_path):
        """Recognize speech and return the text, or ``""`` on any error.

        :param audio_file_path: despite the name, a binary file-like object;
            it is consumed through ``read(size)``.
        :return: the recognized text; errors are logged and yield ``""``.
        """
        try:
            return asyncio.run(self._recognize(audio_file_path))
        except Exception as err:
            maxkb_logger.error(f"语音识别错误: {str(err)}: {traceback.format_exc()}")
            return ""

    async def _recognize(self, audio_file):
        """Open the WebSocket, stream ``audio_file`` and return the recognized text."""
        async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
            # Send all audio first, then collect the results.
            await self.send_audio(ws, audio_file)
            return await self.handle_message(ws)

    def _build_frame(self, seq, status, audio_base64):
        """Build one audio frame as a JSON-serialisable dict.

        ``status`` is 0 for the first frame, 1 for a continuation frame and 2
        for the closing frame. Only the first frame carries the recognition
        parameters.
        """
        frame = {"header": {"app_id": self.spark_app_id, "status": status}}
        if status == 0:
            frame["parameter"] = {
                "iat": {
                    "domain": "slm", "language": "zh_cn", "accent": "mandarin",
                    "eos": 10000, "vinfo": 1,
                    "result": {"encoding": "utf8", "compress": "raw", "format": "json"}
                }
            }
        frame["payload"] = {
            "audio": {
                "encoding": "lame", "sample_rate": 16000, "channels": 1,
                "bit_depth": 16, "seq": seq, "status": status, "audio": audio_base64
            }
        }
        return frame

    async def send_audio(self, ws, audio_file):
        """Stream ``audio_file`` to ``ws`` in base64 frames, then send the closing frame.

        At most ``max_chunks`` chunks of ``chunk_size`` bytes are sent; any
        remaining audio is dropped. An empty file results in the closing frame
        only.
        """
        chunk_size = 4000
        max_chunks = 10000
        seq = 1
        while (chunk := audio_file.read(chunk_size)) and seq <= max_chunks:
            status = 0 if seq == 1 else 1
            chunk_base64 = base64.b64encode(chunk).decode('utf-8')
            await ws.send(json.dumps(self._build_frame(seq, status, chunk_base64)))
            seq += 1
        # Closing frame: empty audio, status 2.
        await ws.send(json.dumps(self._build_frame(seq, 2, "")))

    async def handle_message(self, ws):
        """Collect recognition results from ``ws`` and return the joined text.

        Reading stops when a message with header status 2 arrives or when no
        message is received for 30 seconds; in the latter case the text
        gathered so far is returned.

        :raises Exception: if a message reports a non-zero header code; the
            code and the service message are included in the error text.
        """
        pieces = []
        while True:
            try:
                message = await asyncio.wait_for(ws.recv(), timeout=30.0)
            except asyncio.TimeoutError:
                break
            data = json.loads(message)
            header = data['header']
            if header['code'] != 0:
                raise Exception(
                    f"iFLYTEK speech recognition failed, code={header['code']}, "
                    f"message={header.get('message', '')}"
                )
            # Tolerate a missing or null payload/result.
            result = (data.get('payload') or {}).get('result') or {}
            text = result.get('text', '')
            if text:
                text_data = json.loads(base64.b64decode(text).decode('utf-8'))
                # NOTE(review): words are read from ws[].cw[].sw[].w — confirm
                # this nesting against the service's result schema.
                for ws_item in text_data.get('ws', []):
                    for cw in ws_item.get('cw', []):
                        for sw in cw.get('sw', []):
                            pieces.append(sw['w'])
            if header.get('status') == 2:
                break
        return ''.join(pieces)

View File

@ -17,6 +17,7 @@ from models_provider.impl.xf_model_provider.credential.image import XunFeiImageM
from models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
from models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
from models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
from models_provider.impl.xf_model_provider.credential.zh_en_stt import ZhEnXunFeiSTTModelCredential
from models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
from models_provider.impl.xf_model_provider.model.image import XFSparkImage
from models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
@ -25,10 +26,13 @@ from models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _
from models_provider.impl.xf_model_provider.model.zh_en_stt import XFZhEnSparkSpeechToText
# This hook must be a callable returning an SSLContext: http.client calls it
# whenever an HTTPS connection is opened without an explicit context. Assigning
# an already-built context instance would make that call raise TypeError.
ssl._create_default_https_context = ssl.create_default_context
# Module-level credential form instances; each is passed to the ModelInfo
# entries of the matching model type.
xunfei_model_credential = XunFeiLLMModelCredential()
stt_model_credential = XunFeiSTTModelCredential()
zh_en_stt_credential = ZhEnXunFeiSTTModelCredential()
image_model_credential = XunFeiImageModelCredential()
tts_model_credential = XunFeiTTSModelCredential()
embedding_model_credential = XFEmbeddingCredential()
@ -36,7 +40,10 @@ model_info_list = [
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
ModelInfo('generalv3', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
ModelInfo('generalv2', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
XFSparkSpeechToText),
ModelInfo('slm', _('Chinese and English recognition'), ModelTypeConst.STT, zh_en_stt_credential,
XFZhEnSparkSpeechToText),
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
]
@ -47,7 +54,8 @@ model_info_manage = (
.append_default_model_info(
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM))
.append_default_model_info(
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
XFSparkSpeechToText),
)
.append_default_model_info(
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech))