mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: Support iFLYTEK large model for Chinese-English speech recognition
This commit is contained in:
parent
f9f96fd2cd
commit
4786970689
|
|
@ -0,0 +1,56 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
|
||||
class ZhEnXunFeiSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form for the iFLYTEK Chinese/English speech recognition model.

    Declares the four connection fields and validates them by building the
    model through the provider and calling its ``check_auth``.
    """
    spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat.xf-yun.com/v1')
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self,
                 model_type: str,
                 model_name,
                 model_credential: Dict[str, object],
                 model_params, provider,
                 raise_exception=False):
        """Check that the credential is complete and accepted by the service.

        :param model_type: model type value that must appear in the provider's type list
        :param model_name: name of the model to instantiate for the check
        :param model_credential: mapping holding the ``spark_*`` credential values
        :param model_params: not used by this check
        :param provider: object exposing ``get_model_type_list`` and ``get_model``
        :param raise_exception: raise ``AppApiException`` instead of returning ``False``
        :return: ``True`` when every check passes, ``False`` otherwise
        :raises AppApiException: for an unsupported model type (always), or for a
            missing key / failed verification when ``raise_exception`` is true
        """
        matching_types = [entry for entry in provider.get_model_type_list()
                          if entry.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        # Report the first required key that is absent from the credential.
        required_keys = ('spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret')
        missing_key = next((name for name in required_keys if name not in model_credential), None)
        if missing_key is not None:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=missing_key))

        try:
            candidate = provider.get_model(model_type, model_name, model_credential)
            candidate.check_auth()
        except AppApiException as api_error:
            # Application-level errors are propagated unchanged.
            traceback.print_exc()
            raise api_error
        except Exception as e:
            traceback.print_exc()
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value,
                                  _('Verification failed, please check whether the parameters are correct: {error}').format(
                                      error=str(e)))
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``spark_api_secret`` value has been passed through ``encryption``."""
        masked = dict(model)
        masked['spark_api_secret'] = super().encryption(model.get('spark_api_secret', ''))
        return masked

    def get_model_params_setting_form(self, model_name):
        """Return ``None``; no parameter-setting form is defined for this model."""
        return None
|
||||
|
|
@ -0,0 +1,192 @@
|
|||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
import hmac
|
||||
import hashlib
|
||||
import ssl
|
||||
import traceback
|
||||
from typing import Dict
|
||||
from urllib.parse import urlencode
|
||||
from datetime import datetime, timezone, UTC
|
||||
import websockets
|
||||
import os
|
||||
|
||||
from future.backports.urllib.parse import urlparse
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
from models_provider.impl.base_stt import BaseSpeechToText
|
||||
|
||||
# TLS context used for the iFLYTEK WebSocket connection.
# The previous context set check_hostname=False and verify_mode=CERT_NONE, which
# accepts any certificate and lets a man-in-the-middle read the signed request
# URL and the uploaded audio. create_default_context() verifies the certificate
# chain and the host name against the system trust store.
ssl_context = ssl.create_default_context()
|
||||
|
||||
|
||||
class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text client for the iFLYTEK Chinese/English recognition WebSocket API.

    Audio is streamed to ``spark_api_url`` as base64-encoded JSON frames and the
    words found in the response frames are concatenated into one string.
    """
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str

    def __init__(self, **kwargs):
        """Store the ``spark_*`` connection settings passed as keyword arguments."""
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')

    @staticmethod
    def is_cache_model():
        """Always return ``False``."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from a credential mapping.

        ``max_tokens`` and ``temperature`` are forwarded to the constructor only
        when they are present in ``model_kwargs`` and not ``None``.
        """
        optional_params = {
            key: model_kwargs[key]
            for key in ('max_tokens', 'temperature')
            if model_kwargs.get(key) is not None
        }
        return XFZhEnSparkSpeechToText(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            **optional_params
        )

    def create_url(self):
        """Return ``spark_api_url`` with the signed ``authorization``, ``date`` and ``host`` query parameters.

        The signature is an HMAC-SHA256 over the host, the current GMT date and
        the request line, keyed with ``spark_api_secret``.

        :raises ValueError: if ``spark_api_url`` has no host component
        """
        url = self.spark_api_url
        parsed = urlparse(url)
        host = parsed.hostname
        if not host:
            raise ValueError(f'Invalid iFLYTEK API URL: {url!r}')
        # Sign the path that is actually requested instead of a hard-coded "/v1",
        # so the signature stays valid when a different endpoint path is configured.
        path = parsed.path or '/'

        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
        date = datetime.now(UTC).strftime(gmt_format)

        # String to sign: host, date and request line, newline separated.
        signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
        signature_sha = hmac.new(
            self.spark_api_secret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            hashlib.sha256
        ).digest()
        signature = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        params = {
            'authorization': authorization,
            'date': date,
            'host': host
        }
        # Keep any query string that is already part of the configured URL.
        separator = '&' if parsed.query else '?'
        return url + separator + urlencode(params)

    async def _recognize(self, audio_file):
        """Open the WebSocket, stream ``audio_file`` and return the recognised text.

        Unlike ``speech_to_text`` this lets connection and service errors propagate.
        """
        async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
            await self.send_audio(ws, audio_file)
            return await self.handle_message(ws)

    def check_auth(self):
        """Verify the credentials by recognising the bundled ``iat_mp3_16k.mp3`` sample.

        Errors are deliberately not swallowed here: going through
        ``speech_to_text`` would turn every failure into an empty string and
        make the check succeed for any credentials.
        """
        cwd = os.path.dirname(os.path.abspath(__file__))
        with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f:
            asyncio.run(self._recognize(f))

    def speech_to_text(self, audio_file_path):
        """Recognise speech from an open binary audio file object.

        :param audio_file_path: despite the name, an object with a ``read(size)``
            method returning bytes (it is read in chunks, not opened)
        :return: the recognised text, or ``""`` if any error occurs (the error is logged)
        """
        try:
            return asyncio.run(self._recognize(audio_file_path))
        except Exception as err:
            maxkb_logger.error(f"语音识别错误: {str(err)}: {traceback.format_exc()}")
            return ""

    def _audio_frame(self, seq, status, audio_base64):
        """Build one request frame; ``status`` is 0 (first), 1 (middle) or 2 (last).

        Only the first frame carries the ``parameter`` section.
        """
        frame = {"header": {"app_id": self.spark_app_id, "status": status}}
        if status == 0:
            frame["parameter"] = {
                "iat": {
                    "domain": "slm", "language": "zh_cn", "accent": "mandarin",
                    "eos": 10000, "vinfo": 1,
                    "result": {"encoding": "utf8", "compress": "raw", "format": "json"}
                }
            }
        frame["payload"] = {
            "audio": {
                "encoding": "lame", "sample_rate": 16000, "channels": 1,
                "bit_depth": 16, "seq": seq, "status": status, "audio": audio_base64
            }
        }
        return frame

    async def send_audio(self, ws, audio_file):
        """Stream ``audio_file`` to ``ws`` in 4000-byte chunks, then send the end frame.

        At most 10000 chunks are sent; any audio beyond that is not transmitted.
        """
        chunk_size = 4000
        max_chunks = 10000
        seq = 1
        while seq <= max_chunks:
            chunk = audio_file.read(chunk_size)
            if not chunk:
                break
            status = 0 if seq == 1 else 1
            chunk_base64 = base64.b64encode(chunk).decode('utf-8')
            await ws.send(json.dumps(self._audio_frame(seq, status, chunk_base64)))
            seq += 1

        # Closing frame: status 2 with an empty audio payload.
        await ws.send(json.dumps(self._audio_frame(seq, 2, "")))

    async def handle_message(self, ws):
        """Collect recognised words from ``ws`` until the final frame or a 30 s receive timeout.

        :return: the concatenated recognised text (possibly partial after a timeout)
        :raises RuntimeError: if a response frame carries a non-zero ``code``
        """
        words = []
        while True:
            try:
                message = await asyncio.wait_for(ws.recv(), timeout=30.0)
            except asyncio.TimeoutError:
                # No frame within the timeout: return what has been received so far.
                break

            data = json.loads(message)
            header = data['header']
            if header['code'] != 0:
                # Surface the service's own diagnostics instead of an empty message.
                raise RuntimeError(
                    f"iFLYTEK speech recognition failed: code={header['code']}, "
                    f"message={header.get('message', '')}"
                )

            result = (data.get('payload') or {}).get('result') or {}
            text = result.get('text', '')
            if text:
                # ``text`` is base64-encoded JSON holding the recognised words.
                text_data = json.loads(base64.b64decode(text).decode('utf-8'))
                for ws_item in text_data.get('ws', []):
                    for cw in ws_item.get('cw', []):
                        words.extend(sw['w'] for sw in cw.get('sw', []))

            if header.get('status') == 2:
                break

        return "".join(words)
|
||||
|
|
@ -17,6 +17,7 @@ from models_provider.impl.xf_model_provider.credential.image import XunFeiImageM
|
|||
from models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
||||
from models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
|
||||
from models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
|
||||
from models_provider.impl.xf_model_provider.credential.zh_en_stt import ZhEnXunFeiSTTModelCredential
|
||||
from models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
|
||||
from models_provider.impl.xf_model_provider.model.image import XFSparkImage
|
||||
from models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
||||
|
|
@ -25,10 +26,13 @@ from models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
|
|||
from maxkb.conf import PROJECT_DIR
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from models_provider.impl.xf_model_provider.model.zh_en_stt import XFZhEnSparkSpeechToText
|
||||
|
||||
# The standard library calls ssl._create_default_https_context() as a factory
# whenever an HTTPS connection is opened without an explicit context, so it must
# be bound to the function itself. Binding an already-built SSLContext instance
# fails with "'SSLContext' object is not callable" on the first such request.
ssl._create_default_https_context = ssl.create_default_context
|
||||
|
||||
# Module-level credential form instances, one per credential class.
# They are passed to the ModelInfo entries registered further down in this module.
xunfei_model_credential = XunFeiLLMModelCredential()
stt_model_credential = XunFeiSTTModelCredential()
zh_en_stt_credential = ZhEnXunFeiSTTModelCredential()
image_model_credential = XunFeiImageModelCredential()
tts_model_credential = XunFeiTTSModelCredential()
embedding_model_credential = XFEmbeddingCredential()
|
||||
|
|
@ -36,7 +40,10 @@ model_info_list = [
|
|||
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
|
||||
ModelInfo('generalv3', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
|
||||
ModelInfo('generalv2', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
|
||||
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
||||
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
|
||||
XFSparkSpeechToText),
|
||||
ModelInfo('slm', _('Chinese and English recognition'), ModelTypeConst.STT, zh_en_stt_credential,
|
||||
XFZhEnSparkSpeechToText),
|
||||
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
|
||||
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
|
||||
]
|
||||
|
|
@ -47,7 +54,8 @@ model_info_manage = (
|
|||
.append_default_model_info(
|
||||
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM))
|
||||
.append_default_model_info(
|
||||
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
||||
ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
|
||||
XFSparkSpeechToText),
|
||||
)
|
||||
.append_default_model_info(
|
||||
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech))
|
||||
|
|
|
|||
Loading…
Reference in New Issue