feat: 语音识别和语音合成

This commit is contained in:
CaptainB 2024-08-27 17:46:52 +08:00 committed by 刘瑞斌
parent 88277c7aea
commit b500404a41
17 changed files with 1122 additions and 8 deletions

View File

@ -139,6 +139,8 @@ class BaseModelCredential(ABC):
class ModelTypeConst(Enum):
    """Closed set of model categories supported by the platform.

    Each member's value is a dict with a machine-readable ``code`` and a
    human-readable (Chinese) ``message`` shown in the UI.
    """
    LLM = {'code': 'LLM', 'message': '大语言模型'}  # large language model
    EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}  # embedding model
    STT = {'code': 'STT', 'message': '语音识别'}  # speech-to-text
    TTS = {'code': 'TTS', 'message': '语音合成'}  # text-to-speech
class ModelInfo:

View File

@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod
from pydantic import BaseModel
class BaseSpeechToText(BaseModel):
    """Contract that every speech-to-text (STT) backend must implement."""

    @abstractmethod
    def check_auth(self):
        """Validate the stored credentials against the remote service."""

    @abstractmethod
    def speech_to_text(self, audio_file):
        """Transcribe the given audio file object and return the text."""

View File

@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod
from pydantic import BaseModel
class BaseTextToSpeech(BaseModel):
    """Contract that every text-to-speech (TTS) backend must implement."""

    @abstractmethod
    def check_auth(self):
        """Validate the stored credentials against the remote service."""

    @abstractmethod
    def text_to_speech(self, text):
        """Synthesize *text* and return the generated audio."""

View File

@ -0,0 +1,42 @@
# coding=utf-8
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form for OpenAI speech-to-text models (API base + key)."""

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check the model type, the presence of both credential keys and
        that the credentials actually authenticate.

        Returns True on success; when *raise_exception* is False, failures
        (other than an unsupported model type) return False instead of raising.
        """
        supported = provider.get_model_type_list()
        if not any(entry.get('value') == model_type for entry in supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        for required_key in ('api_base', 'api_key'):
            if required_key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')
                return False
        try:
            # A cheap authenticated call proves the key/base actually work.
            provider.get_model(model_type, model_name, model_credential).check_auth()
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the api_key replaced by its encrypted form."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """No extra runtime parameters for STT models."""
        pass

View File

@ -0,0 +1,53 @@
from typing import Dict
from openai import OpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_stt import BaseSpeechToText
def custom_get_token_ids(text: str):
    """Tokenize *text* with the shared tokenizer and return its token ids."""
    return TokenizerManage.get_tokenizer().encode(text)
class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text backed by the OpenAI audio transcription API."""

    api_base: str
    api_key: str
    model: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Pin the connection settings from the raw kwargs.
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry; forwards only the optional
        parameters that were explicitly provided."""
        extra = {
            key: model_kwargs[key]
            for key in ('max_tokens', 'temperature')
            if model_kwargs.get(key) is not None
        }
        return OpenAISpeechToText(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            **extra,
        )

    def check_auth(self):
        """Make a cheap authenticated call (model listing) to verify the key."""
        client = OpenAI(base_url=self.api_base, api_key=self.api_key)
        client.models.with_raw_response.list()

    def speech_to_text(self, audio_file):
        """Transcribe *audio_file* (Chinese) and return the recognized text."""
        client = OpenAI(base_url=self.api_base, api_key=self.api_key)
        transcription = client.audio.transcriptions.create(
            model=self.model, language="zh", file=audio_file)
        return transcription.text

View File

@ -0,0 +1,58 @@
from typing import Dict
from openai import OpenAI
from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tts import BaseTextToSpeech
def custom_get_token_ids(text: str):
    """Tokenize *text* with the shared tokenizer and return its token ids."""
    return TokenizerManage.get_tokenizer().encode(text)
class OpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Text-to-speech backed by the OpenAI audio speech API."""

    api_base: str
    api_key: str
    model: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Pin the connection settings and model name from the raw kwargs.
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')
        self.model = kwargs.get('model')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry; forwards only the optional
        parameters that were explicitly provided."""
        extra = {
            key: model_kwargs[key]
            for key in ('max_tokens', 'temperature')
            if model_kwargs.get(key) is not None
        }
        return OpenAITextToSpeech(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            **extra,
        )

    def check_auth(self):
        """Make a cheap authenticated call (model listing) to verify the key."""
        client = OpenAI(base_url=self.api_base, api_key=self.api_key)
        client.models.with_raw_response.list()

    def text_to_speech(self, text):
        """Synthesize *text* and return the complete audio payload as bytes."""
        client = OpenAI(base_url=self.api_base, api_key=self.api_key)
        with client.audio.speech.with_streaming_response.create(
                model=self.model,
                voice="alloy",
                input=text,
        ) as response:
            return response.read()

View File

@ -13,11 +13,15 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
ModelTypeConst, ModelInfoManage
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech
from smartdoc.conf import PROJECT_DIR
openai_llm_model_credential = OpenAILLMModelCredential()
openai_stt_model_credential = OpenAISTTModelCredential()
model_info_list = [
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, OpenAIChatModel
@ -58,7 +62,13 @@ model_info_list = [
OpenAIChatModel),
ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照支持上下文长度128,000 tokens',
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel)
OpenAIChatModel),
ModelInfo('whisper-1', '',
ModelTypeConst.STT, openai_stt_model_credential,
OpenAISpeechToText),
ModelInfo('tts-1', '',
ModelTypeConst.TTS, openai_stt_model_credential,
OpenAITextToSpeech)
]
open_ai_embedding_credential = OpenAIEmbeddingCredential()
model_info_embedding_list = [

View File

@ -0,0 +1,45 @@
# coding=utf-8
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Volcanic Engine speech models
    (endpoint URL, app id, token and cluster)."""

    volcanic_api_url = forms.TextInputField('API 域名', required=True)
    volcanic_app_id = forms.TextInputField('App ID', required=True)
    volcanic_token = forms.PasswordInputField('Token', required=True)
    volcanic_cluster = forms.TextInputField('Cluster', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check the model type, the presence of all credential keys and
        that the credentials actually authenticate.

        Returns True on success; when *raise_exception* is False, failures
        (other than an unsupported model type) return False instead of raising.
        """
        supported = provider.get_model_type_list()
        if not any(entry.get('value') == model_type for entry in supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        for required_key in ('volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster'):
            if required_key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')
                return False
        try:
            # A cheap authenticated call proves the credentials actually work.
            provider.get_model(model_type, model_name, model_credential).check_auth()
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the token replaced by its encrypted form."""
        return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))}

    def get_model_params_setting_form(self, model_name):
        """No extra runtime parameters for speech models."""
        pass

View File

@ -0,0 +1,339 @@
# coding=utf-8
"""
requires Python 3.6 or later
pip install asyncio
pip install websockets
"""
import asyncio
import base64
import gzip
import hmac
import json
import uuid
import wave
from enum import Enum
from hashlib import sha256
from io import BytesIO
from typing import Dict
from urllib.parse import urlparse
import websockets
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_stt import BaseSpeechToText
# Protocol constants for the Volcanic Engine binary websocket framing.
audio_format = "mp3"  # "wav" or "mp3" -- set to match the actual audio format

PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

# Bit widths of the fields in the 4-byte base header.
PROTOCOL_VERSION_BITS = 4
HEADER_BITS = 4
MESSAGE_TYPE_BITS = 4
MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
MESSAGE_SERIALIZATION_BITS = 4
MESSAGE_COMPRESSION_BITS = 4
RESERVED_BITS = 8

# Message Type:
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message Type Specific Flags
NO_SEQUENCE = 0b0000  # no check sequence
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010  # marks the last audio frame of a stream
NEG_SEQUENCE_1 = 0b0011

# Message Serialization
NO_SERIALIZATION = 0b0000
JSON = 0b0001
THRIFT = 0b0011
CUSTOM_TYPE = 0b1111

# Message Compression
NO_COMPRESSION = 0b0000
GZIP = 0b0001
CUSTOM_COMPRESSION = 0b1111
def generate_header(
        version=PROTOCOL_VERSION,
        message_type=CLIENT_FULL_REQUEST,
        message_type_specific_flags=NO_SEQUENCE,
        serial_method=JSON,
        compression_type=GZIP,
        reserved_data=0x00,
        extension_header=bytes()
):
    """Build the 4-byte protocol header plus any extension bytes.

    Layout: protocol_version(4 bits) | header_size(4 bits),
    message_type(4 bits) | message_type_specific_flags(4 bits),
    serialization_method(4 bits) | message_compression(4 bits),
    reserved(8 bits), then 8 * 4 * (header_size - 1) extension bits.
    """
    # Header size is expressed in 4-byte words, including the base header.
    size_words = len(extension_header) // 4 + 1
    packed = bytearray((
        (version << 4) | size_words,
        (message_type << 4) | message_type_specific_flags,
        (serial_method << 4) | compression_type,
        reserved_data,
    ))
    packed.extend(extension_header)
    return packed
def generate_full_default_header():
    # Header for the initial "full client request" (JSON request frame).
    return generate_header()


def generate_audio_default_header():
    # Header for an intermediate audio-only frame.
    return generate_header(
        message_type=CLIENT_AUDIO_ONLY_REQUEST
    )


def generate_last_audio_default_header():
    # Header for the final audio frame (NEG_SEQUENCE marks end of stream).
    return generate_header(
        message_type=CLIENT_AUDIO_ONLY_REQUEST,
        message_type_specific_flags=NEG_SEQUENCE
    )
def parse_response(res):
    """Decode one server frame of the Volcanic Engine ASR protocol.

    Frame layout: protocol_version(4 bits) | header_size(4 bits),
    message_type(4 bits) | flags(4 bits), serialization(4 bits) |
    compression(4 bits), reserved(8 bits), optional header extensions,
    then a message-type dependent payload (like an HTTP body).

    Returns a dict that may contain 'seq', 'code', 'payload_msg' and
    'payload_size', depending on the message type.
    """
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    payload = res[header_size * 4:]

    result = {}
    body = None
    body_size = 0
    if message_type == SERVER_FULL_RESPONSE:
        body_size = int.from_bytes(payload[:4], "big", signed=True)
        body = payload[4:]
    elif message_type == SERVER_ACK:
        result['seq'] = int.from_bytes(payload[:4], "big", signed=True)
        if len(payload) >= 8:
            body_size = int.from_bytes(payload[4:8], "big", signed=False)
            body = payload[8:]
    elif message_type == SERVER_ERROR_RESPONSE:
        result['code'] = int.from_bytes(payload[:4], "big", signed=False)
        body_size = int.from_bytes(payload[4:8], "big", signed=False)
        body = payload[8:]

    if body is None:
        # ACK without payload (or unrecognized type): nothing more to decode.
        return result
    if message_compression == GZIP:
        body = gzip.decompress(body)
    if serialization_method == JSON:
        body = json.loads(str(body, "utf-8"))
    elif serialization_method != NO_SERIALIZATION:
        body = str(body, "utf-8")
    result['payload_msg'] = body
    result['payload_size'] = body_size
    return result
def read_wav_info(data: bytes = None) -> (int, int, int, int, int):
    """Parse an in-memory WAV file.

    :param data: complete WAV file contents.
    :return: tuple of (nchannels, sampwidth, framerate, nframes,
        byte length of the raw frame data).
    """
    with BytesIO(data) as _f:
        # Use a context manager so the wave reader is always closed
        # (the previous version leaked the Wave_read object).
        with wave.open(_f, 'rb') as wave_fp:
            nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4]
            wave_bytes = wave_fp.readframes(nframes)
    return nchannels, sampwidth, framerate, nframes, len(wave_bytes)
class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text over the Volcanic Engine websocket ASR protocol.

    The audio is uploaded in gzip-compressed binary frames; the transcript
    is read from the final server response.
    """

    # Default request parameters (see Volcanic Engine ASR documentation).
    workflow: str = "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate"
    show_language: bool = False
    show_utterances: bool = False
    result_type: str = "full"
    format: str = "mp3"  # uploaded audio format: "mp3" or "wav"
    rate: int = 16000
    language: str = "zh-CN"
    bits: int = 16
    channel: int = 1
    codec: str = "raw"
    audio_type: int = 1
    secret: str = "access_secret"  # only used with signature auth
    auth_method: str = "token"  # "token" or "signature"
    mp3_seg_size: int = 10000  # bytes per websocket frame when streaming mp3
    success_code: int = 1000  # success code, default is 1000
    seg_duration: int = 15000  # wav segment duration, in milliseconds
    nbest: int = 1
    # Credentials supplied by the provider credential form.
    volcanic_app_id: str
    volcanic_cluster: str
    volcanic_api_url: str
    volcanic_token: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.volcanic_api_url = kwargs.get('volcanic_api_url')
        self.volcanic_token = kwargs.get('volcanic_token')
        self.volcanic_app_id = kwargs.get('volcanic_app_id')
        self.volcanic_cluster = kwargs.get('volcanic_cluster')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Factory used by the provider registry; forwards only the optional
        # parameters that were explicitly provided.
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return VolcanicEngineSpeechToText(
            volcanic_api_url=model_credential.get('volcanic_api_url'),
            volcanic_token=model_credential.get('volcanic_token'),
            volcanic_app_id=model_credential.get('volcanic_app_id'),
            volcanic_cluster=model_credential.get('volcanic_cluster'),
            **optional_params
        )

    def construct_request(self, reqid):
        """Build the JSON body of the initial "full client request"."""
        req = {
            'app': {
                'appid': self.volcanic_app_id,
                'cluster': self.volcanic_cluster,
                'token': self.volcanic_token,
            },
            'user': {
                'uid': 'uid'
            },
            'request': {
                'reqid': reqid,
                'nbest': self.nbest,
                'workflow': self.workflow,
                'show_language': self.show_language,
                'show_utterances': self.show_utterances,
                'result_type': self.result_type,
                "sequence": 1
            },
            'audio': {
                'format': self.format,
                'rate': self.rate,
                'language': self.language,
                'bits': self.bits,
                'channel': self.channel,
                'codec': self.codec
            }
        }
        return req

    @staticmethod
    def slice_data(data: bytes, chunk_size: int) -> (list, bool):
        """
        slice data
        :param data: wav data
        :param chunk_size: the segment size in one request
        :return: segment data, last flag
        """
        data_len = len(data)
        offset = 0
        while offset + chunk_size < data_len:
            yield data[offset: offset + chunk_size], False
            offset += chunk_size
        else:
            # while/else: runs when the loop exits normally, so the final
            # (possibly short or empty) chunk is always emitted flagged last.
            yield data[offset: data_len], True

    def _real_processor(self, request_params: dict) -> dict:
        # Placeholder; not used by the websocket flow below.
        pass

    def token_auth(self):
        # NOTE(review): the "Bearer; <token>" format (with a semicolon)
        # appears to be what the Volcanic Engine service expects -- confirm
        # against the service docs before changing.
        return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)}

    def signature_auth(self, data):
        """Build HMAC-SHA256 signature headers over the request line, the
        configured custom headers and the raw request bytes."""
        header_dicts = {
            'Custom': 'auth_custom',
        }
        url_parse = urlparse(self.volcanic_api_url)
        input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path)
        auth_headers = 'Custom'
        for header in auth_headers.split(','):
            input_str += '{}\n'.format(header_dicts[header])
        input_data = bytearray(input_str, 'utf-8')
        input_data += data
        mac = base64.urlsafe_b64encode(
            hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest())
        header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(self.volcanic_token,
                                                                                             str(mac, 'utf-8'),
                                                                                             auth_headers)
        return header_dicts

    async def segment_data_processor(self, wav_data: bytes, segment_size: int):
        """Stream *wav_data* in *segment_size*-byte frames and return the
        recognized text -- or the raw parsed result dict on service error."""
        reqid = str(uuid.uuid4())
        # Build the full client request and gzip-compress the serialized JSON.
        request_params = self.construct_request(reqid)
        payload_bytes = str.encode(json.dumps(request_params))
        payload_bytes = gzip.compress(payload_bytes)
        full_client_request = bytearray(generate_full_default_header())
        full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size (4 bytes)
        full_client_request.extend(payload_bytes)  # payload
        header = None
        if self.auth_method == "token":
            header = self.token_auth()
        elif self.auth_method == "signature":
            header = self.signature_auth(full_client_request)
        async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws:
            # Send the full client request.
            await ws.send(full_client_request)
            res = await ws.recv()
            result = parse_response(res)
            if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
                return result
            for seq, (chunk, last) in enumerate(VolcanicEngineSpeechToText.slice_data(wav_data, segment_size), 1):
                # if no compression, comment this line
                payload_bytes = gzip.compress(chunk)
                audio_only_request = bytearray(generate_audio_default_header())
                if last:
                    # The final frame carries the NEG_SEQUENCE flag instead.
                    audio_only_request = bytearray(generate_last_audio_default_header())
                audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size (4 bytes)
                audio_only_request.extend(payload_bytes)  # payload
                # Send an audio-only client request.
                await ws.send(audio_only_request)
                res = await ws.recv()
                result = parse_response(res)
                if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
                    return result
            # The transcript of the whole utterance is in the last response.
            return result['payload_msg']['result'][0]['text']

    def check_auth(self):
        """Open (and immediately close) an authenticated websocket to verify
        the token; raises on connection/auth failure."""
        header = self.token_auth()

        async def check():
            async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws:
                pass

        asyncio.run(check())

    def speech_to_text(self, file):
        """Transcribe the audio in *file* (mp3 or wav, per ``self.format``)."""
        data = file.read()
        audio_data = bytes(data)
        if self.format == "mp3":
            segment_size = self.mp3_seg_size
            return asyncio.run(self.segment_data_processor(audio_data, segment_size))
        if self.format != "wav":
            raise Exception("format should in wav or mp3")
        nchannels, sampwidth, framerate, nframes, wav_len = read_wav_info(
            audio_data)
        size_per_sec = nchannels * sampwidth * framerate
        # Size each frame to hold seg_duration milliseconds of audio.
        segment_size = int(size_per_sec * self.seg_duration / 1000)
        return asyncio.run(self.segment_data_processor(audio_data, segment_size))

View File

@ -0,0 +1,164 @@
# coding=utf-8
'''
requires Python 3.6 or later
pip install asyncio
pip install websockets
'''
import asyncio
import copy
import gzip
import json
import uuid
from typing import Dict
import websockets
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tts import BaseTextToSpeech
# Server frame decoding tables (codes taken from the protocol header bytes).
MESSAGE_TYPES = {11: "audio-only server response", 12: "frontend server response", 15: "error message from server"}
MESSAGE_TYPE_SPECIFIC_FLAGS = {0: "no sequence number", 1: "sequence number > 0",
                               2: "last message from server (seq < 0)", 3: "sequence number < 0"}
MESSAGE_SERIALIZATION_METHODS = {0: "no serialization", 1: "JSON", 15: "custom type"}
MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression method"}

# Default client request header, packed as four bytes:
# version: b0001 (4 bits)
# header size: b0001 (4 bits)
# message type: b0001 (Full client request) (4bits)
# message type specific flags: b0000 (none) (4bits)
# message serialization method: b0001 (JSON) (4 bits)
# message compression: b0001 (gzip) (4bits)
# reserved data: 0x00 (1 byte)
default_header = bytearray(b'\x11\x10\x11\x00')
class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Text-to-speech over the Volcanic Engine websocket TTS protocol.

    Synthesized mp3 audio arrives as a sequence of binary frames; the frames
    are concatenated and returned as a single bytes object.
    """

    # Credentials supplied by the provider credential form.
    volcanic_app_id: str
    volcanic_cluster: str
    volcanic_api_url: str
    volcanic_token: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.volcanic_api_url = kwargs.get('volcanic_api_url')
        self.volcanic_token = kwargs.get('volcanic_token')
        self.volcanic_app_id = kwargs.get('volcanic_app_id')
        self.volcanic_cluster = kwargs.get('volcanic_cluster')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Factory used by the provider registry; forwards only the optional
        # parameters that were explicitly provided.
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return VolcanicEngineTextToSpeech(
            volcanic_api_url=model_credential.get('volcanic_api_url'),
            volcanic_token=model_credential.get('volcanic_token'),
            volcanic_app_id=model_credential.get('volcanic_app_id'),
            volcanic_cluster=model_credential.get('volcanic_cluster'),
            **optional_params
        )

    def check_auth(self):
        """Open (and immediately close) an authenticated websocket to verify
        the token; raises on connection/auth failure."""
        header = self.token_auth()

        async def check():
            async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws:
                pass

        asyncio.run(check())

    def text_to_speech(self, text):
        """Synthesize *text* and return the mp3 audio bytes."""
        request_json = {
            "app": {
                "appid": self.volcanic_app_id,
                "token": "access_token",
                "cluster": self.volcanic_cluster
            },
            "user": {
                "uid": "uid"
            },
            "audio": {
                "voice_type": "BV002_streaming",
                "encoding": "mp3",
                "speed_ratio": 1.0,
                "volume_ratio": 1.0,
                "pitch_ratio": 1.0,
            },
            "request": {
                "reqid": str(uuid.uuid4()),
                "text": text,
                "text_type": "plain",
                "operation": "xxx"  # placeholder; submit() overwrites it with "submit"
            }
        }
        return asyncio.run(self.submit(request_json))

    def token_auth(self):
        # NOTE(review): "Bearer; <token>" (with a semicolon) appears to be
        # the format the service expects -- confirm before changing.
        return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)}

    async def submit(self, request_json):
        """Send one synthesis request and collect the streamed audio frames."""
        submit_request_json = copy.deepcopy(request_json)
        submit_request_json["request"]["reqid"] = str(uuid.uuid4())
        submit_request_json["request"]["operation"] = "submit"
        payload_bytes = str.encode(json.dumps(submit_request_json))
        payload_bytes = gzip.compress(payload_bytes)  # if no compression, comment this line
        full_client_request = bytearray(default_header)
        full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
        full_client_request.extend(payload_bytes)  # payload
        header = {"Authorization": f"Bearer; {self.volcanic_token}"}
        async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws:
            await ws.send(full_client_request)
            return await self.parse_response(ws)

    @staticmethod
    async def parse_response(ws):
        """Accumulate audio frames until the server signals the last one.

        Returns the concatenated audio bytes.  NOTE(review): on an error
        frame (0xf) this breaks and returns whatever audio was collected so
        far instead of raising -- verify that callers tolerate this.
        """
        result = b''
        while True:
            res = await ws.recv()
            protocol_version = res[0] >> 4
            header_size = res[0] & 0x0f
            message_type = res[1] >> 4
            message_type_specific_flags = res[1] & 0x0f
            serialization_method = res[2] >> 4
            message_compression = res[2] & 0x0f
            reserved = res[3]
            header_extensions = res[4:header_size * 4]
            payload = res[header_size * 4:]
            if header_size != 1:
                # print(f" Header extensions: {header_extensions}")
                pass
            if message_type == 0xb:  # audio-only server response
                if message_type_specific_flags == 0:  # no sequence number as ACK
                    continue
                else:
                    # Payload: sequence number (4B) + payload size (4B) + audio.
                    sequence_number = int.from_bytes(payload[:4], "big", signed=True)
                    payload_size = int.from_bytes(payload[4:8], "big", signed=False)
                    payload = payload[8:]
                    result += payload
                    if sequence_number < 0:
                        # A negative sequence number marks the final frame.
                        break
                    else:
                        continue
            elif message_type == 0xf:
                # Error message from the server; decode it, then stop.
                code = int.from_bytes(payload[:4], "big", signed=False)
                msg_size = int.from_bytes(payload[4:8], "big", signed=False)
                error_msg = payload[8:]
                if message_compression == 1:
                    error_msg = gzip.decompress(error_msg)
                error_msg = str(error_msg, "utf-8")
                break
            elif message_type == 0xc:
                # Frontend server response (metadata); decoded but unused.
                msg_size = int.from_bytes(payload[:4], "big", signed=False)
                payload = payload[4:]
                if message_compression == 1:
                    payload = gzip.decompress(payload)
            else:
                break
        return result

View File

@ -15,17 +15,31 @@ from setting.models_provider.impl.openai_model_provider.credential.embedding imp
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential
from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText
from setting.models_provider.impl.volcanic_engine_model_provider.model.tts import VolcanicEngineTextToSpeech
from smartdoc.conf import PROJECT_DIR
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential()
model_info_list = [
ModelInfo('ep-xxxxxxxxxx-yyyy',
'用户前往火山方舟的模型推理页面创建推理接入点这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
ModelTypeConst.LLM,
volcanic_engine_llm_model_credential, VolcanicEngineChatModel
)
),
ModelInfo('asr',
'',
ModelTypeConst.STT,
volcanic_engine_stt_model_credential, VolcanicEngineSpeechToText
),
ModelInfo('tts',
'',
ModelTypeConst.TTS,
volcanic_engine_stt_model_credential, VolcanicEngineTextToSpeech
),
]
open_ai_embedding_credential = OpenAIEmbeddingCredential()

View File

@ -0,0 +1,46 @@
# coding=utf-8
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form for iFlytek (XunFei) Spark speech models
    (endpoint URL, app id, API key and API secret)."""

    spark_api_url = forms.TextInputField('API 域名', required=True)
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check the model type, the presence of all credential keys and
        that the credentials actually authenticate.

        Returns True on success; when *raise_exception* is False, failures
        (other than an unsupported model type) return False instead of raising.
        """
        supported = provider.get_model_type_list()
        if not any(entry.get('value') == model_type for entry in supported):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        for required_key in ('spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret'):
            if required_key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')
                return False
        try:
            # A cheap authenticated call proves the credentials actually work.
            provider.get_model(model_type, model_name, model_credential).check_auth()
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the API secret replaced by its encrypted form."""
        return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}

    def get_model_params_setting_form(self, model_name):
        """No extra runtime parameters for speech models."""
        pass

View File

@ -0,0 +1,165 @@
# -*- coding:utf-8 -*-
#
# 错误码链接https://www.xfyun.cn/document/error-code code返回错误码时必看
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
import asyncio
import base64
import datetime
import hashlib
import hmac
import json
from datetime import datetime
from typing import Dict
from urllib.parse import urlencode, urlparse
import websockets
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_stt import BaseSpeechToText
STATUS_FIRST_FRAME = 0 # 第一帧的标识
STATUS_CONTINUE_FRAME = 1 # 中间帧标识
STATUS_LAST_FRAME = 2 # 最后一帧的标识
class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text via the iFlytek (XunFei) Spark websocket IAT API.

    Audio is streamed in base64-encoded JSON frames; the transcript is
    assembled from the first result message returned by the server.
    """

    # Credentials supplied by the provider credential form.
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Factory used by the provider registry; forwards only the optional
        # parameters that were explicitly provided.
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return XFSparkSpeechToText(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            **optional_params
        )

    # Build the signed websocket URL.
    def create_url(self):
        url = self.spark_api_url
        host = urlparse(url).hostname
        # RFC 1123 timestamp, required by the signature scheme.
        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
        date = datetime.utcnow().strftime(gmt_format)
        # Canonical string to sign: host, date and the request line.
        # NOTE(review): the request path is hard-coded to /v2/iat -- if
        # spark_api_url ever uses a different path the signature will not
        # match; confirm against the configured URL.
        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/iat " + "HTTP/1.1"
        # HMAC-SHA256 over the canonical string, then base64-encode.
        signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
        authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
            self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
        # The auth parameters are passed as query-string arguments.
        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        url = url + '?' + urlencode(v)
        # print("date: ",date)
        # print("v: ",v)
        # Uncomment the prints above to compare the generated URL against a
        # known-good one when debugging signature mismatches.
        # print('websocket url :', url)
        return url

    def check_auth(self):
        """Open (and immediately close) a signed websocket to verify the
        credentials; raises on connection/auth failure."""
        async def check():
            async with websockets.connect(self.create_url()) as ws:
                pass

        asyncio.run(check())

    def speech_to_text(self, file):
        """Transcribe the audio in *file* and return the recognized text."""
        async def handle():
            async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
                # Stream all audio frames, then read the recognition result.
                await self.send(ws, file)
                return await self.handle_message(ws)

        return asyncio.run(handle())

    @staticmethod
    async def handle_message(ws):
        """Parse one result message.

        Returns the concatenated transcript, or the service's error message
        string when a non-zero code is reported.
        """
        res = await ws.recv()
        message = json.loads(res)
        code = message["code"]
        sid = message["sid"]
        if code != 0:
            errMsg = message["message"]
            print("sid:%s call error:%s code is:%s" % (sid, errMsg, code))
            return errMsg
        else:
            # Concatenate all recognized words from the result blocks.
            data = message["data"]["result"]["ws"]
            result = ""
            for i in data:
                for w in i["cw"]:
                    result += w["w"]
            # print("sid:%s call success!,data is:%s" % (sid, json.dumps(data, ensure_ascii=False)))
            return result

    # Stream the audio to an open websocket connection.
    async def send(self, ws, file):
        frameSize = 8000  # bytes of audio per frame
        status = STATUS_FIRST_FRAME  # whether the next frame is the first, a middle, or the last one
        while True:
            buf = file.read(frameSize)
            # End of file: switch to the last-frame state.
            if not buf:
                status = STATUS_LAST_FRAME
            # First frame: also carries the "common" (app_id) and "business"
            # parameters; sent exactly once.
            if status == STATUS_FIRST_FRAME:
                d = {
                    "common": {"app_id": self.spark_app_id},
                    "business": {
                        "domain": "iat",
                        "language": "zh_cn",
                        "accent": "mandarin",
                        "vinfo": 1,
                        "vad_eos": 10000
                    },
                    "data": {
                        "status": 0, "format": "audio/L16;rate=16000",
                        "audio": str(base64.b64encode(buf), 'utf-8'),
                        "encoding": "lame"}
                }
                d = json.dumps(d)
                await ws.send(d)
                status = STATUS_CONTINUE_FRAME
            # Middle frames: audio only (status 1).
            elif status == STATUS_CONTINUE_FRAME:
                d = {"data": {"status": 1, "format": "audio/L16;rate=16000",
                              "audio": str(base64.b64encode(buf), 'utf-8'),
                              "encoding": "lame"}}
                await ws.send(json.dumps(d))
            # Last frame: status 2 tells the server the stream is finished.
            elif status == STATUS_LAST_FRAME:
                d = {"data": {"status": 2, "format": "audio/L16;rate=16000",
                              "audio": str(base64.b64encode(buf), 'utf-8'),
                              "encoding": "lame"}}
                await ws.send(json.dumps(d))
                break

View File

@ -0,0 +1,138 @@
# -*- coding:utf-8 -*-
#
# author: iflytek
#
# 错误码链接https://www.xfyun.cn/document/error-code code返回错误码时必看
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
import asyncio
import base64
import datetime
import hashlib
import hmac
import json
import os
from datetime import datetime
from typing import Dict
from urllib.parse import urlencode, urlparse
import websockets
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tts import BaseTextToSpeech
STATUS_FIRST_FRAME = 0  # marks the first frame of an audio stream
STATUS_CONTINUE_FRAME = 1  # marks an intermediate frame
STATUS_LAST_FRAME = 2  # marks the final frame
class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """iFlytek Spark text-to-speech client over the /v2/tts websocket API.

    Credentials (app id / api key / api secret) and the websocket endpoint
    come from the model credential form; :meth:`text_to_speech` returns the
    synthesized audio as raw bytes (lame/mp3 encoded).
    """
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # pydantic already populates the declared fields in super().__init__();
        # the explicit assignments mirror the sibling model classes.
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry to build a TTS model instance."""
        optional_params = {}
        # Collected for interface parity with the LLM factories; the TTS model
        # itself declares no max_tokens/temperature fields and ignores them.
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return XFSparkTextToSpeech(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            **optional_params
        )

    def create_url(self):
        """Build the signed websocket URL (HMAC-SHA256 per Xunfei's auth spec)."""
        url = self.spark_api_url
        host = urlparse(url).hostname
        # RFC1123 timestamp required by the signature scheme.
        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
        date = datetime.utcnow().strftime(gmt_format)
        # Canonical string: host header, date header, and the request line.
        # NOTE(review): the path is hard-coded to /v2/tts even though the full
        # url is configurable — confirm, or derive it from urlparse(url).path.
        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
        # HMAC-SHA256 over the canonical string, then base64-encoded.
        signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
        authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
            self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
        # Auth parameters are passed as query-string arguments.
        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        return url + '?' + urlencode(v)

    def check_auth(self):
        """Validate the credential by opening (and closing) a signed connection."""
        async def check():
            async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
                pass

        asyncio.run(check())

    def text_to_speech(self, text):
        """Synthesize *text* and return the resulting audio bytes."""
        # For minority languages the text must be sent as UTF-16LE instead:
        # {"status": 2, "text": str(base64.b64encode(text.encode('utf-16')), "UTF8")}
        async def handle():
            async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
                # Send the full client request, then drain the audio stream.
                await self.send(ws, text)
                return await self.handle_message(ws)

        return asyncio.run(handle())

    @staticmethod
    async def handle_message(ws):
        """Collect audio chunks until the server marks the stream finished."""
        audio_bytes: bytes = b''
        while True:
            message = json.loads(await ws.recv())
            code = message["code"]
            sid = message["sid"]
            if code != 0:
                # Fix: error payloads may not carry a "data" section, so check
                # the code BEFORE touching it and stop instead of looping
                # (the original decoded data["audio"] first and then indexed
                # data["status"] on the error path, raising KeyError).
                errMsg = message["message"]
                print("sid:%s call error:%s code is:%s" % (sid, errMsg, code))
                break
            data = message["data"]
            audio_bytes += base64.b64decode(data["audio"])
            if data["status"] == 2:  # 2 == final chunk of the synthesis stream
                break
        return audio_bytes

    async def send(self, ws, text):
        """Send the single full-text synthesis request (status 2 = complete text)."""
        request = {
            "common": {"app_id": self.spark_app_id},
            "business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"},
            "data": {"status": 2, "text": str(base64.b64encode(text.encode('utf-8')), "UTF8")},
        }
        await ws.send(json.dumps(request))

View File

@ -13,16 +13,24 @@ from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
from smartdoc.conf import PROJECT_DIR
ssl._create_default_https_context = ssl.create_default_context()
qwen_model_credential = XunFeiLLMModelCredential()
model_info_list = [ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)
]
stt_model_credential = XunFeiSTTModelCredential()
# Registered Xunfei models: Spark LLM versions plus the speech models.
model_info_list = [
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
ModelInfo('iat-niche', '小语种识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
# NOTE(review): the TTS entry reuses the STT credential form — confirm the TTS
# endpoint accepts the same fields, or add a dedicated TTS credential class.
ModelInfo('tts', '', ModelTypeConst.TTS, stt_model_credential, XFSparkTextToSpeech),
]
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)).build()

View File

@ -40,7 +40,7 @@ dashscope = "^1.17.0"
zhipuai = "^2.0.1"
httpx = "^0.27.0"
httpx-sse = "^0.4.0"
websocket-client = "^1.7.0"
websockets = "^13.0"
langchain-google-genai = "^1.0.3"
openpyxl = "^3.1.2"
xlrd = "^2.0.1"

View File

@ -9,7 +9,9 @@ export enum PermissionDesc {
export enum modelType {
EMBEDDING = '向量模型',
LLM = '大语言模型'
LLM = '大语言模型',
STT = '语音识别',
TTS = '语音合成',
}