From 4786970689fcbf7571a2b3eaded3d41fba366bac Mon Sep 17 00:00:00 2001
From: zhangzhanwei
Date: Thu, 28 Aug 2025 10:54:15 +0800
Subject: [PATCH] feat: Support iFLYTEK large model for Chinese-English speech recognition

---
 .../xf_model_provider/credential/zh_en_stt.py |  64 ++++++
 .../impl/xf_model_provider/model/zh_en_stt.py | 216 ++++++++++++++++++
 .../xf_model_provider/xf_model_provider.py    |  12 +-
 3 files changed, 290 insertions(+), 2 deletions(-)
 create mode 100644 apps/models_provider/impl/xf_model_provider/credential/zh_en_stt.py
 create mode 100644 apps/models_provider/impl/xf_model_provider/model/zh_en_stt.py

diff --git a/apps/models_provider/impl/xf_model_provider/credential/zh_en_stt.py b/apps/models_provider/impl/xf_model_provider/credential/zh_en_stt.py
new file mode 100644
index 000000000..05c0e4e17
--- /dev/null
+++ b/apps/models_provider/impl/xf_model_provider/credential/zh_en_stt.py
@@ -0,0 +1,64 @@
+# coding=utf-8
+import traceback
+from typing import Dict
+
+from django.utils.translation import gettext as _
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class ZhEnXunFeiSTTModelCredential(BaseForm, BaseModelCredential):
+    """Credential form for the iFLYTEK Chinese-English speech recognition model."""
+
+    spark_api_url = forms.TextInputField('API URL', required=True, default_value='wss://iat.xf-yun.com/v1')
+    spark_app_id = forms.TextInputField('APP ID', required=True)
+    spark_api_key = forms.PasswordInputField("API Key", required=True)
+    spark_api_secret = forms.PasswordInputField('API Secret', required=True)
+
+    def is_valid(self,
+                 model_type: str,
+                 model_name,
+                 model_credential: Dict[str, object],
+                 model_params, provider,
+                 raise_exception=False):
+        """Check the required credential keys and run the model's ``check_auth``.
+
+        Returns True on success. On failure returns False, or raises AppApiException
+        when ``raise_exception`` is true. An unsupported ``model_type`` and any
+        AppApiException raised by the auth check are raised regardless of the flag.
+        """
+        model_type_list = provider.get_model_type_list()
+        if not any(mt.get('value') == model_type for mt in model_type_list):
+            raise AppApiException(ValidCode.valid_error.value,
+                                  _('{model_type} Model type is not supported').format(model_type=model_type))
+
+        for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
+                return False
+        try:
+            model = provider.get_model(model_type, model_name, model_credential)
+            model.check_auth()
+        except AppApiException:
+            traceback.print_exc()
+            raise
+        except Exception as e:
+            traceback.print_exc()
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value,
+                                      _('Verification failed, please check whether the parameters are correct: {error}').format(
+                                          error=str(e))) from e
+            return False
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        """Return a copy of ``model`` with ``spark_api_secret`` replaced by its ``encryption`` result."""
+        return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
+
+    def get_model_params_setting_form(self, model_name):
+        """This model exposes no parameter-setting form; returns None."""
+        pass
diff --git a/apps/models_provider/impl/xf_model_provider/model/zh_en_stt.py b/apps/models_provider/impl/xf_model_provider/model/zh_en_stt.py
new file mode 100644
index 000000000..fb1498fca
--- /dev/null
+++ b/apps/models_provider/impl/xf_model_provider/model/zh_en_stt.py
@@ -0,0 +1,216 @@
+import asyncio
+import json
+import base64
+import hmac
+import hashlib
+import ssl
+import traceback
+from typing import Dict
+from urllib.parse import urlencode, urlparse
+from datetime import datetime, timezone
+import websockets
+import os
+
+from common.utils.logger import maxkb_logger
+from models_provider.base_model_provider import MaxKBBaseModel
+from models_provider.impl.base_stt import BaseSpeechToText
+
+# NOTE(review): certificate and hostname verification are disabled for this websocket
+# client, which allows man-in-the-middle interception. Confirm this is intentional.
+ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ssl_context.check_hostname = False
+ssl_context.verify_mode = ssl.CERT_NONE
+
+
+class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
+    """Speech-to-text client for the iFLYTEK large-model recognition websocket API."""
+
+    spark_app_id: str
+    spark_api_key: str
+    spark_api_secret: str
+    spark_api_url: str
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.spark_api_url = kwargs.get('spark_api_url')
+        self.spark_app_id = kwargs.get('spark_app_id')
+        self.spark_api_key = kwargs.get('spark_api_key')
+        self.spark_api_secret = kwargs.get('spark_api_secret')
+
+    @staticmethod
+    def is_cache_model():
+        return False
+
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        """Create an instance from ``model_credential``, forwarding any non-None
+        ``max_tokens`` / ``temperature`` found in ``model_kwargs``."""
+        optional_params = {}
+        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
+            optional_params['max_tokens'] = model_kwargs['max_tokens']
+        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
+            optional_params['temperature'] = model_kwargs['temperature']
+        return XFZhEnSparkSpeechToText(
+            spark_app_id=model_credential.get('spark_app_id'),
+            spark_api_key=model_credential.get('spark_api_key'),
+            spark_api_secret=model_credential.get('spark_api_secret'),
+            spark_api_url=model_credential.get('spark_api_url'),
+            **optional_params
+        )
+
+    def create_url(self):
+        """Return ``spark_api_url`` with the signed ``authorization``, ``date`` and ``host`` query parameters."""
+        url = self.spark_api_url
+        parsed_url = urlparse(url)
+        host = parsed_url.hostname
+        # Sign the path that is actually requested rather than assuming '/v1'.
+        path = parsed_url.path or '/'
+
+        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
+        date = datetime.now(timezone.utc).strftime(gmt_format)
+        # Build the string to sign: host, date and request line
+        signature_origin = "host: " + host + "\n"
+        signature_origin += "date: " + date + "\n"
+        signature_origin += "GET " + path + " HTTP/1.1"
+        # Sign it with HMAC-SHA256 keyed by the API secret
+        signature_sha = hmac.new(
+            self.spark_api_secret.encode('utf-8'),
+            signature_origin.encode('utf-8'),
+            hashlib.sha256
+        ).digest()
+        signature = base64.b64encode(signature_sha).decode(encoding='utf-8')
+
+        authorization_origin = (
+            f'api_key="{self.spark_api_key}", algorithm="hmac-sha256", '
+            f'headers="host date request-line", signature="{signature}"'
+        )
+        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
+
+        params = {
+            'authorization': authorization,
+            'date': date,
+            'host': host
+        }
+        auth_url = url + '?' + urlencode(params)
+        return auth_url
+
+    def check_auth(self):
+        """Verify the credentials by recognizing the sample audio file next to this module.
+
+        Unlike ``speech_to_text``, errors propagate to the caller, so a failed
+        connection or an error response is reported instead of being swallowed.
+        """
+        cwd = os.path.dirname(os.path.abspath(__file__))
+        with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f:
+            asyncio.run(self._recognize(f))
+
+    async def _recognize(self, audio_file):
+        """Stream ``audio_file`` over a new websocket connection and return the recognized text."""
+        async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
+            # Send the audio frames, then collect the recognition result
+            await self.send_audio(ws, audio_file)
+            return await self.handle_message(ws)
+
+    def speech_to_text(self, audio_file_path):
+        """Recognize speech from ``audio_file_path``, which is read through ``.read()``.
+
+        Returns the recognized text, or an empty string if recognition raises
+        (the error is logged).
+        """
+        try:
+            return asyncio.run(self._recognize(audio_file_path))
+        except Exception as err:
+            maxkb_logger.error(f"语音识别错误: {str(err)}: {traceback.format_exc()}")
+            return ""
+
+    async def send_audio(self, ws, audio_file):
+        """Send the audio to ``ws`` as base64-encoded frames followed by an end frame."""
+        chunk_size = 4000
+        seq = 1
+        max_chunks = 10000
+        while True:
+            chunk = audio_file.read(chunk_size)
+            # Stop at end of file, or once max_chunks frames have been sent
+            if not chunk or seq > max_chunks:
+                break
+
+            chunk_base64 = base64.b64encode(chunk).decode('utf-8')
+            # First frame: carries the recognition parameters
+            if seq == 1:
+                frame = {
+                    "header": {"app_id": self.spark_app_id, "status": 0},
+                    "parameter": {
+                        "iat": {
+                            "domain": "slm", "language": "zh_cn", "accent": "mandarin",
+                            "eos": 10000, "vinfo": 1,
+                            "result": {"encoding": "utf8", "compress": "raw", "format": "json"}
+                        }
+                    },
+                    "payload": {
+                        "audio": {
+                            "encoding": "lame", "sample_rate": 16000, "channels": 1,
+                            "bit_depth": 16, "seq": seq, "status": 0, "audio": chunk_base64
+                        }
+                    }
+                }
+            # Intermediate frame
+            else:
+                frame = {
+                    "header": {"app_id": self.spark_app_id, "status": 1},
+                    "payload": {
+                        "audio": {
+                            "encoding": "lame", "sample_rate": 16000, "channels": 1,
+                            "bit_depth": 16, "seq": seq, "status": 1, "audio": chunk_base64
+                        }
+                    }
+                }
+
+            await ws.send(json.dumps(frame))
+            seq += 1
+
+        # Send the end frame
+        end_frame = {
+            "header": {"app_id": self.spark_app_id, "status": 2},
+            "payload": {
+                "audio": {
+                    "encoding": "lame", "sample_rate": 16000, "channels": 1,
+                    "bit_depth": 16, "seq": seq, "status": 2, "audio": ""
+                }
+            }
+        }
+        await ws.send(json.dumps(end_frame))
+
+    async def handle_message(self, ws):
+        """Receive result messages from ``ws`` and return the concatenated recognized text.
+
+        Stops at the message whose header status is 2, or when no message arrives
+        within 30 seconds. Raises Exception when a message reports a non-zero code.
+        """
+        text_parts = []
+        while True:
+            try:
+                message = await asyncio.wait_for(ws.recv(), timeout=30.0)
+            except asyncio.TimeoutError:
+                break
+            data = json.loads(message)
+            header = data['header']
+
+            if header['code'] != 0:
+                raise Exception(
+                    f"iFLYTEK speech recognition failed: code={header['code']}, "
+                    f"message={header.get('message', '')}")
+
+            if 'payload' in data and 'result' in data['payload']:
+                result = data['payload']['result']
+                text = result.get('text', '')
+                if text:
+                    text_data = json.loads(base64.b64decode(text).decode('utf-8'))
+                    for ws_item in text_data.get('ws', []):
+                        for cw in ws_item.get('cw', []):
+                            for sw in cw.get('sw', []):
+                                text_parts.append(sw['w'])
+
+            if header.get('status') == 2:
+                break
+
+        return "".join(text_parts)
diff --git a/apps/models_provider/impl/xf_model_provider/xf_model_provider.py b/apps/models_provider/impl/xf_model_provider/xf_model_provider.py
index 7bcf4fcfb..4a576e9aa 100644
--- a/apps/models_provider/impl/xf_model_provider/xf_model_provider.py
+++ b/apps/models_provider/impl/xf_model_provider/xf_model_provider.py
@@ -17,6 +17,7 @@ from models_provider.impl.xf_model_provider.credential.image import XunFeiImageM
 from models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
 from models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
 from models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
+from models_provider.impl.xf_model_provider.credential.zh_en_stt import ZhEnXunFeiSTTModelCredential
 from models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
 from models_provider.impl.xf_model_provider.model.image import XFSparkImage
 from models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
@@ -25,10 +26,13 @@ from models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
 from maxkb.conf import PROJECT_DIR
 from django.utils.translation import gettext as _
 
+from models_provider.impl.xf_model_provider.model.zh_en_stt import XFZhEnSparkSpeechToText
+
 ssl._create_default_https_context = ssl.create_default_context()
 
 xunfei_model_credential = XunFeiLLMModelCredential()
 stt_model_credential = XunFeiSTTModelCredential()
+zh_en_stt_credential = ZhEnXunFeiSTTModelCredential()
 image_model_credential = XunFeiImageModelCredential()
 tts_model_credential = XunFeiTTSModelCredential()
 embedding_model_credential = XFEmbeddingCredential()
@@ -36,7 +40,10 @@ model_info_list = [
     ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
     ModelInfo('generalv3', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
     ModelInfo('generalv2', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM),
-    ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
+    ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
+              XFSparkSpeechToText),
+    ModelInfo('slm', _('Chinese and English recognition'), ModelTypeConst.STT, zh_en_stt_credential,
+              XFZhEnSparkSpeechToText),
     ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
     ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
 ]
@@ -47,7 +54,8 @@ model_info_manage = (
     .append_default_model_info(
         ModelInfo('generalv3.5', '', ModelTypeConst.LLM, xunfei_model_credential, XFChatSparkLLM))
     .append_default_model_info(
-        ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
+        ModelInfo('iat', _('Chinese and English recognition'), ModelTypeConst.STT, stt_model_credential,
+                  XFSparkSpeechToText),
     )
     .append_default_model_info(
         ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech))