mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-25 17:22:55 +00:00
feat: 语音识别和语音合成
This commit is contained in:
parent
88277c7aea
commit
b500404a41
|
|
@ -139,6 +139,8 @@ class BaseModelCredential(ABC):
|
|||
class ModelTypeConst(Enum):
    """Enumeration of the model categories supported by the provider framework.

    Each member's value is a dict with a machine-readable ``code`` and a
    human-readable ``message`` (Chinese UI label).
    """
    LLM = {'code': 'LLM', 'message': '大语言模型'}  # large language model
    EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}  # embedding model
    STT = {'code': 'STT', 'message': '语音识别'}  # speech-to-text
    TTS = {'code': 'TTS', 'message': '语音合成'}  # text-to-speech
|
||||
|
||||
|
||||
class ModelInfo:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
# coding=utf-8
|
||||
from abc import abstractmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseSpeechToText(BaseModel):
    """Abstract interface implemented by every speech-to-text (STT) model.

    NOTE(review): the class derives from pydantic's BaseModel rather than
    abc.ABC, so @abstractmethod is not enforced at instantiation time here —
    it documents the required interface only.
    """

    @abstractmethod
    def check_auth(self):
        """Verify the configured credentials against the remote service."""
        pass

    @abstractmethod
    def speech_to_text(self, audio_file):
        """Transcribe *audio_file* (read as binary audio data by the visible
        implementations) and return the recognized text."""
        pass
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
# coding=utf-8
|
||||
from abc import abstractmethod
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseTextToSpeech(BaseModel):
    """Abstract interface implemented by every text-to-speech (TTS) model.

    NOTE(review): derives from pydantic's BaseModel, not abc.ABC, so
    @abstractmethod is documentation only — instantiation is not blocked.
    """

    @abstractmethod
    def check_auth(self):
        """Verify the configured credentials against the remote service."""
        pass

    @abstractmethod
    def text_to_speech(self, text):
        """Synthesize *text* and return the audio payload (bytes in the
        visible implementations)."""
        pass
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
# coding=utf-8
|
||||
from typing import Dict
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form for OpenAI speech-to-text models (API base URL + key)."""

    # Form fields rendered in the UI; labels are user-facing Chinese strings.
    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate *model_credential* for the given model.

        Checks that *model_type* is supported by *provider*, that the required
        keys are present, and that the credentials actually authenticate
        (via ``model.check_auth()``).

        :raises AppApiException: on an unsupported model type (always), or on
            any other failure when ``raise_exception`` is True.
        :return: True when valid; False when invalid and ``raise_exception``
            is False.
        """
        model_type_list = provider.get_model_type_list()
        # An unsupported model type raises unconditionally (same convention as
        # the other credential classes in this commit).
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        # Required credential keys.
        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            # Instantiate the model and probe the remote API with the credentials.
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            # Re-raise domain exceptions as-is; wrap everything else.
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of *model* with the sensitive api_key encrypted."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """No extra per-model parameter form for STT models."""
        pass
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
from typing import Dict
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from setting.models_provider.impl.base_stt import BaseSpeechToText
|
||||
|
||||
|
||||
def custom_get_token_ids(text: str):
    """Encode *text* with the shared project tokenizer and return its token ids."""
    return TokenizerManage.get_tokenizer().encode(text)
|
||||
|
||||
|
||||
class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text backed by the OpenAI audio transcription API."""

    # Connection settings, populated from the stored credential dict.
    api_base: str
    api_key: str
    model: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Re-assigned after the pydantic constructor; redundant with the field
        # initialisation done by super().__init__, but harmless.
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry.

        Collects the generic optional parameters (max_tokens / temperature)
        when present, although this model declares no matching fields —
        presumably the extras are ignored by the model config; TODO confirm.
        """
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return OpenAISpeechToText(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def check_auth(self):
        """Probe the API by listing models; raises if key/base URL is invalid."""
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        # Result is unused — the call itself is the auth check.
        response_list = client.models.with_raw_response.list()

    def speech_to_text(self, audio_file):
        """Transcribe *audio_file* (language hard-coded to Chinese) and return the text."""
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        res = client.audio.transcriptions.create(model=self.model, language="zh", file=audio_file)
        return res.text
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
from typing import Dict
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from setting.models_provider.impl.base_tts import BaseTextToSpeech
|
||||
|
||||
|
||||
def custom_get_token_ids(text: str):
    """Return the token ids for *text* using the project-wide tokenizer."""
    shared_tokenizer = TokenizerManage.get_tokenizer()
    return shared_tokenizer.encode(text)
|
||||
|
||||
|
||||
class OpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Text-to-speech backed by the OpenAI audio speech API."""

    # Connection settings, populated from the stored credential dict.
    api_base: str
    api_key: str
    model: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Redundant with the pydantic field initialisation, but harmless.
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')
        self.model = kwargs.get('model')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry.

        Collects the generic optional parameters (max_tokens / temperature)
        when present, although this model declares no matching fields —
        presumably the extras are ignored; TODO confirm.
        """
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return OpenAITextToSpeech(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def check_auth(self):
        """Probe the API by listing models; raises if key/base URL is invalid."""
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        # Result is unused — the call itself is the auth check.
        response_list = client.models.with_raw_response.list()

    def text_to_speech(self, text):
        """Synthesize *text* and return the complete audio payload as bytes.

        NOTE(review): the voice is hard-coded to "alloy" — consider making it
        configurable.
        """
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        with client.audio.speech.with_streaming_response.create(
                model=self.model,
                voice="alloy",
                input=text,
        ) as response:
            return response.read()
|
||||
|
|
@ -13,11 +13,15 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
|
|||
ModelTypeConst, ModelInfoManage
|
||||
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
|
||||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
||||
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
|
||||
from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
openai_llm_model_credential = OpenAILLMModelCredential()
|
||||
openai_stt_model_credential = OpenAISTTModelCredential()
|
||||
model_info_list = [
|
||||
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential, OpenAIChatModel
|
||||
|
|
@ -58,7 +62,13 @@ model_info_list = [
|
|||
OpenAIChatModel),
|
||||
ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel)
|
||||
OpenAIChatModel),
|
||||
ModelInfo('whisper-1', '',
|
||||
ModelTypeConst.STT, openai_stt_model_credential,
|
||||
OpenAISpeechToText),
|
||||
ModelInfo('tts-1', '',
|
||||
ModelTypeConst.TTS, openai_stt_model_credential,
|
||||
OpenAITextToSpeech)
|
||||
]
|
||||
open_ai_embedding_credential = OpenAIEmbeddingCredential()
|
||||
model_info_embedding_list = [
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Volcanic Engine speech models (URL, app id, token, cluster)."""

    # Form fields rendered in the UI; labels are user-facing strings.
    volcanic_api_url = forms.TextInputField('API 域名', required=True)
    volcanic_app_id = forms.TextInputField('App ID', required=True)
    volcanic_token = forms.PasswordInputField('Token', required=True)
    volcanic_cluster = forms.TextInputField('Cluster', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate *model_credential* for the given model.

        Checks that *model_type* is supported by *provider*, that the required
        keys are present, and that the credentials authenticate
        (via ``model.check_auth()``).

        :raises AppApiException: on an unsupported model type (always), or on
            any other failure when ``raise_exception`` is True.
        :return: True when valid; False when invalid and ``raise_exception``
            is False.
        """
        model_type_list = provider.get_model_type_list()
        # An unsupported model type raises unconditionally (same convention as
        # the other credential classes in this commit).
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        # Required credential keys.
        for key in ['volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            # Instantiate the model and probe the remote API with the credentials.
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            # Re-raise domain exceptions as-is; wrap everything else.
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of *model* with the sensitive volcanic_token encrypted."""
        return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))}

    def get_model_params_setting_form(self, model_name):
        """No extra per-model parameter form for STT models."""
        pass
|
||||
|
|
@ -0,0 +1,339 @@
|
|||
# coding=utf-8
|
||||
|
||||
"""
|
||||
requires Python 3.6 or later
|
||||
|
||||
pip install asyncio
|
||||
pip install websockets
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import gzip
|
||||
import hmac
|
||||
import json
|
||||
import uuid
|
||||
import wave
|
||||
from enum import Enum
|
||||
from hashlib import sha256
|
||||
from io import BytesIO
|
||||
from typing import Dict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import websockets
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from setting.models_provider.impl.base_stt import BaseSpeechToText
|
||||
|
||||
audio_format = "mp3"  # "wav" or "mp3" — set to match the actual audio format

PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

# Bit widths of the individual header fields (documentation of the layout).
PROTOCOL_VERSION_BITS = 4
HEADER_BITS = 4
MESSAGE_TYPE_BITS = 4
MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
MESSAGE_SERIALIZATION_BITS = 4
MESSAGE_COMPRESSION_BITS = 4
RESERVED_BITS = 8

# Message Type:
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message Type Specific Flags
NO_SEQUENCE = 0b0000  # no check sequence
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_SEQUENCE_1 = 0b0011

# Message Serialization
NO_SERIALIZATION = 0b0000
JSON = 0b0001
THRIFT = 0b0011
CUSTOM_TYPE = 0b1111

# Message Compression
NO_COMPRESSION = 0b0000
GZIP = 0b0001
CUSTOM_COMPRESSION = 0b1111


def generate_header(
        version=PROTOCOL_VERSION,
        message_type=CLIENT_FULL_REQUEST,
        message_type_specific_flags=NO_SEQUENCE,
        serial_method=JSON,
        compression_type=GZIP,
        reserved_data=0x00,
        extension_header=bytes()
):
    """Build a binary protocol header as a bytearray.

    Layout: protocol_version(4 bits) | header_size(4 bits) |
    message_type(4 bits) | message_type_specific_flags(4 bits) |
    serialization_method(4 bits) | message_compression(4 bits) |
    reserved(8 bits) | extension bytes.  ``header_size`` counts 4-byte words
    including the base header, so the extension carries
    8 * 4 * (header_size - 1) bits.
    """
    size_in_words = len(extension_header) // 4 + 1
    header = bytearray((
        (version << 4) | size_in_words,
        (message_type << 4) | message_type_specific_flags,
        (serial_method << 4) | compression_type,
        reserved_data,
    ))
    header += extension_header
    return header


def generate_full_default_header():
    """Header for a full client request, all defaults."""
    return generate_header()


def generate_audio_default_header():
    """Header for an intermediate audio-only chunk."""
    return generate_header(message_type=CLIENT_AUDIO_ONLY_REQUEST)


def generate_last_audio_default_header():
    """Header marking the final audio-only chunk (negative-sequence flag)."""
    return generate_header(
        message_type=CLIENT_AUDIO_ONLY_REQUEST,
        message_type_specific_flags=NEG_SEQUENCE,
    )
|
||||
|
||||
|
||||
def parse_response(res):
    """Parse one binary server frame into a result dict.

    Frame layout:
      protocol_version(4 bits), header_size(4 bits),
      message_type(4 bits), message_type_specific_flags(4 bits),
      serialization_method(4 bits), message_compression(4 bits),
      reserved(8 bits),
      header_extensions (8 * 4 * (header_size - 1) bits),
      payload (analogous to an HTTP body).

    :param res: raw frame bytes received from the websocket.
    :return: dict possibly containing 'seq', 'code', 'payload_msg',
        'payload_size' depending on the message type.
    """
    protocol_version = res[0] >> 4
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    reserved = res[3]
    header_extensions = res[4:header_size * 4]
    payload = res[header_size * 4:]
    result = {}
    payload_msg = None
    payload_size = 0
    if message_type == SERVER_FULL_RESPONSE:
        # Full response: 4-byte size prefix, then the message body.
        payload_size = int.from_bytes(payload[:4], "big", signed=True)
        payload_msg = payload[4:]
    elif message_type == SERVER_ACK:
        # ACK carries a sequence number, optionally followed by a sized body.
        seq = int.from_bytes(payload[:4], "big", signed=True)
        result['seq'] = seq
        if len(payload) >= 8:
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload_msg = payload[8:]
    elif message_type == SERVER_ERROR_RESPONSE:
        # Error: 4-byte error code, 4-byte size, then the message body.
        code = int.from_bytes(payload[:4], "big", signed=False)
        result['code'] = code
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]
    if payload_msg is None:
        return result
    # Undo transport encodings: gzip first, then JSON / plain text.
    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)
    if serialization_method == JSON:
        payload_msg = json.loads(str(payload_msg, "utf-8"))
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = str(payload_msg, "utf-8")
    result['payload_msg'] = payload_msg
    result['payload_size'] = payload_size
    return result
|
||||
|
||||
|
||||
def read_wav_info(data: bytes = None) -> (int, int, int, int, int):
    """Parse an in-memory WAV file.

    :param data: complete WAV file contents.
    :return: tuple of (nchannels, sampwidth, framerate, nframes,
        byte length of the decoded frame data).
    :raises wave.Error: if *data* is not a valid WAV stream.
    """
    with BytesIO(data) as _f:
        # Use the wave reader as a context manager so it is always closed
        # (the original leaked the Wave_read object).
        with wave.open(_f, 'rb') as wave_fp:
            nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4]
            wave_bytes = wave_fp.readframes(nframes)
    return nchannels, sampwidth, framerate, nframes, len(wave_bytes)
|
||||
|
||||
|
||||
class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text over the Volcanic Engine (ByteDance) websocket ASR API.

    Audio is streamed in compressed chunks over a websocket using the custom
    binary framing implemented by generate_header()/parse_response() in this
    module.
    """

    # --- recognition request parameters (defaults) ---
    workflow: str = "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate"
    show_language: bool = False
    show_utterances: bool = False
    result_type: str = "full"
    format: str = "mp3"  # input audio container: "mp3" or "wav"
    rate: int = 16000
    language: str = "zh-CN"
    bits: int = 16
    channel: int = 1
    codec: str = "raw"
    audio_type: int = 1
    secret: str = "access_secret"  # only used by signature auth
    auth_method: str = "token"  # "token" or "signature"
    mp3_seg_size: int = 10000  # chunk size in bytes when streaming mp3 input
    success_code: int = 1000  # success code, default is 1000
    seg_duration: int = 15000  # wav segment duration in milliseconds
    nbest: int = 1

    # --- credential fields, populated from the stored credential dict ---
    volcanic_app_id: str
    volcanic_cluster: str
    volcanic_api_url: str
    volcanic_token: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Redundant with the pydantic field initialisation, but harmless.
        self.volcanic_api_url = kwargs.get('volcanic_api_url')
        self.volcanic_token = kwargs.get('volcanic_token')
        self.volcanic_app_id = kwargs.get('volcanic_app_id')
        self.volcanic_cluster = kwargs.get('volcanic_cluster')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry; builds an instance from the credential dict."""
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return VolcanicEngineSpeechToText(
            volcanic_api_url=model_credential.get('volcanic_api_url'),
            volcanic_token=model_credential.get('volcanic_token'),
            volcanic_app_id=model_credential.get('volcanic_app_id'),
            volcanic_cluster=model_credential.get('volcanic_cluster'),
            **optional_params
        )

    def construct_request(self, reqid):
        """Build the JSON body of the initial "full client request" frame."""
        req = {
            'app': {
                'appid': self.volcanic_app_id,
                'cluster': self.volcanic_cluster,
                'token': self.volcanic_token,
            },
            'user': {
                'uid': 'uid'
            },
            'request': {
                'reqid': reqid,
                'nbest': self.nbest,
                'workflow': self.workflow,
                'show_language': self.show_language,
                'show_utterances': self.show_utterances,
                'result_type': self.result_type,
                "sequence": 1
            },
            'audio': {
                'format': self.format,
                'rate': self.rate,
                'language': self.language,
                'bits': self.bits,
                'channel': self.channel,
                'codec': self.codec
            }
        }
        return req

    @staticmethod
    def slice_data(data: bytes, chunk_size: int) -> (list, bool):
        """
        slice data
        :param data: wav data
        :param chunk_size: the segment size in one request
        :return: segment data, last flag

        Generator: yields (chunk, False) for each full chunk; the while/else
        branch then always yields the trailing (possibly empty) remainder
        with last=True after normal loop exit.
        """
        data_len = len(data)
        offset = 0
        while offset + chunk_size < data_len:
            yield data[offset: offset + chunk_size], False
            offset += chunk_size
        else:
            yield data[offset: data_len], True

    def _real_processor(self, request_params: dict) -> dict:
        # Placeholder; not implemented.
        pass

    def token_auth(self):
        """Build the token auth header (the 'Bearer; <token>' format matches the remote API)."""
        return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)}

    def signature_auth(self, data):
        """Build HMAC-SHA256 signature auth headers over the request bytes *data*."""
        header_dicts = {
            'Custom': 'auth_custom',
        }

        url_parse = urlparse(self.volcanic_api_url)
        input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path)
        auth_headers = 'Custom'
        for header in auth_headers.split(','):
            input_str += '{}\n'.format(header_dicts[header])
        input_data = bytearray(input_str, 'utf-8')
        input_data += data
        # Sign the request line + custom headers + body with the shared secret.
        mac = base64.urlsafe_b64encode(
            hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest())
        header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(self.volcanic_token,
                                                                                             str(mac, 'utf-8'),
                                                                                             auth_headers)
        return header_dicts

    async def segment_data_processor(self, wav_data: bytes, segment_size: int):
        """Stream *wav_data* in *segment_size* chunks and return the transcript.

        NOTE(review): on a server error the raw parse_response() dict is
        returned instead of text — callers receive mixed return types.
        """
        reqid = str(uuid.uuid4())
        # Build the full client request, then serialize and gzip-compress it.
        request_params = self.construct_request(reqid)
        payload_bytes = str.encode(json.dumps(request_params))
        payload_bytes = gzip.compress(payload_bytes)
        full_client_request = bytearray(generate_full_default_header())
        full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
        full_client_request.extend(payload_bytes)  # payload
        header = None
        if self.auth_method == "token":
            header = self.token_auth()
        elif self.auth_method == "signature":
            header = self.signature_auth(full_client_request)
        async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws:
            # Send the full client request and check the handshake response.
            await ws.send(full_client_request)
            res = await ws.recv()
            result = parse_response(res)
            if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
                return result
            for seq, (chunk, last) in enumerate(VolcanicEngineSpeechToText.slice_data(wav_data, segment_size), 1):
                # if no compression, comment this line
                payload_bytes = gzip.compress(chunk)
                audio_only_request = bytearray(generate_audio_default_header())
                if last:
                    # Final chunk uses the negative-sequence header variant.
                    audio_only_request = bytearray(generate_last_audio_default_header())
                audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
                audio_only_request.extend(payload_bytes)  # payload
                # Send the audio-only client request.
                await ws.send(audio_only_request)
                res = await ws.recv()
                result = parse_response(res)
                if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
                    return result
            return result['payload_msg']['result'][0]['text']

    def check_auth(self):
        """Auth check: opening the websocket with the token header must succeed."""
        header = self.token_auth()

        async def check():
            async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws:
                pass

        asyncio.run(check())

    def speech_to_text(self, file):
        """Transcribe the audio in *file* (a readable binary stream) to text."""
        data = file.read()
        audio_data = bytes(data)
        if self.format == "mp3":
            # mp3 input is streamed in fixed-size byte chunks.
            segment_size = self.mp3_seg_size
            return asyncio.run(self.segment_data_processor(audio_data, segment_size))
        if self.format != "wav":
            raise Exception("format should in wav or mp3")
        # wav input: derive the chunk size from the stream parameters so each
        # segment covers seg_duration milliseconds of audio.
        nchannels, sampwidth, framerate, nframes, wav_len = read_wav_info(
            audio_data)
        size_per_sec = nchannels * sampwidth * framerate
        segment_size = int(size_per_sec * self.seg_duration / 1000)
        return asyncio.run(self.segment_data_processor(audio_data, segment_size))
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
# coding=utf-8
|
||||
|
||||
'''
|
||||
requires Python 3.6 or later
|
||||
|
||||
pip install asyncio
|
||||
pip install websockets
|
||||
|
||||
'''
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import gzip
|
||||
import json
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
import websockets
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from setting.models_provider.impl.base_tts import BaseTextToSpeech
|
||||
|
||||
# Human-readable names for the binary protocol field values (used for
# reference/debugging; the parser below compares against the raw codes).
MESSAGE_TYPES = {11: "audio-only server response", 12: "frontend server response", 15: "error message from server"}
MESSAGE_TYPE_SPECIFIC_FLAGS = {0: "no sequence number", 1: "sequence number > 0",
                               2: "last message from server (seq < 0)", 3: "sequence number < 0"}
MESSAGE_SERIALIZATION_METHODS = {0: "no serialization", 1: "JSON", 15: "custom type"}
MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression method"}

# Pre-built 4-byte request header:
# version: b0001 (4 bits)
# header size: b0001 (4 bits)
# message type: b0001 (Full client request) (4bits)
# message type specific flags: b0000 (none) (4bits)
# message serialization method: b0001 (JSON) (4 bits)
# message compression: b0001 (gzip) (4bits)
# reserved data: 0x00 (1 byte)
default_header = bytearray(b'\x11\x10\x11\x00')
|
||||
|
||||
|
||||
class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Text-to-speech over the Volcanic Engine (ByteDance) websocket TTS API."""

    # Credential fields, populated from the stored credential dict.
    volcanic_app_id: str
    volcanic_cluster: str
    volcanic_api_url: str
    volcanic_token: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Redundant with the pydantic field initialisation, but harmless.
        self.volcanic_api_url = kwargs.get('volcanic_api_url')
        self.volcanic_token = kwargs.get('volcanic_token')
        self.volcanic_app_id = kwargs.get('volcanic_app_id')
        self.volcanic_cluster = kwargs.get('volcanic_cluster')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry; builds an instance from the credential dict."""
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return VolcanicEngineTextToSpeech(
            volcanic_api_url=model_credential.get('volcanic_api_url'),
            volcanic_token=model_credential.get('volcanic_token'),
            volcanic_app_id=model_credential.get('volcanic_app_id'),
            volcanic_cluster=model_credential.get('volcanic_cluster'),
            **optional_params
        )

    def check_auth(self):
        """Auth check: opening the websocket with the token header must succeed."""
        header = self.token_auth()

        async def check():
            async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws:
                pass

        asyncio.run(check())

    def text_to_speech(self, text):
        """Synthesize *text* and return the accumulated mp3 audio bytes.

        NOTE(review): voice_type is hard-coded to "BV002_streaming"; the
        "operation": "xxx" placeholder below is overwritten to "submit" in
        submit().
        """
        request_json = {
            "app": {
                "appid": self.volcanic_app_id,
                "token": "access_token",
                "cluster": self.volcanic_cluster
            },
            "user": {
                "uid": "uid"
            },
            "audio": {
                "voice_type": "BV002_streaming",
                "encoding": "mp3",
                "speed_ratio": 1.0,
                "volume_ratio": 1.0,
                "pitch_ratio": 1.0,
            },
            "request": {
                "reqid": str(uuid.uuid4()),
                "text": text,
                "text_type": "plain",
                "operation": "xxx"
            }
        }

        return asyncio.run(self.submit(request_json))

    def token_auth(self):
        """Build the token auth header (the 'Bearer; <token>' format matches the remote API)."""
        return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)}

    async def submit(self, request_json):
        """Send one 'submit' request frame and collect the streamed audio reply."""
        submit_request_json = copy.deepcopy(request_json)
        # Fresh request id and the real operation for this submission.
        submit_request_json["request"]["reqid"] = str(uuid.uuid4())
        submit_request_json["request"]["operation"] = "submit"
        payload_bytes = str.encode(json.dumps(submit_request_json))
        payload_bytes = gzip.compress(payload_bytes)  # if no compression, comment this line
        full_client_request = bytearray(default_header)
        full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
        full_client_request.extend(payload_bytes)  # payload
        header = {"Authorization": f"Bearer; {self.volcanic_token}"}
        async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws:
            await ws.send(full_client_request)
            return await self.parse_response(ws)

    @staticmethod
    async def parse_response(ws):
        """Receive frames from *ws* until the stream ends; return the audio bytes.

        NOTE(review): on an error frame (0xf) the decoded error_msg is
        discarded and whatever audio accumulated so far is returned —
        consider raising instead.
        """
        result = b''
        while True:
            res = await ws.recv()
            # Decode the 4-byte binary header (same layout as the STT module).
            protocol_version = res[0] >> 4
            header_size = res[0] & 0x0f
            message_type = res[1] >> 4
            message_type_specific_flags = res[1] & 0x0f
            serialization_method = res[2] >> 4
            message_compression = res[2] & 0x0f
            reserved = res[3]
            header_extensions = res[4:header_size * 4]
            payload = res[header_size * 4:]
            if header_size != 1:
                # Header extensions are ignored.
                pass
            if message_type == 0xb:  # audio-only server response
                if message_type_specific_flags == 0:  # no sequence number as ACK
                    continue
                else:
                    # Sequence-numbered audio chunk: append; a negative
                    # sequence number marks the final chunk.
                    sequence_number = int.from_bytes(payload[:4], "big", signed=True)
                    payload_size = int.from_bytes(payload[4:8], "big", signed=False)
                    payload = payload[8:]
                    result += payload
                    if sequence_number < 0:
                        break
                    else:
                        continue
            elif message_type == 0xf:
                # Error frame: 4-byte code, 4-byte size, then the message.
                code = int.from_bytes(payload[:4], "big", signed=False)
                msg_size = int.from_bytes(payload[4:8], "big", signed=False)
                error_msg = payload[8:]
                if message_compression == 1:
                    error_msg = gzip.decompress(error_msg)
                error_msg = str(error_msg, "utf-8")
                break
            elif message_type == 0xc:
                # Frontend server response: decoded but otherwise ignored.
                msg_size = int.from_bytes(payload[:4], "big", signed=False)
                payload = payload[4:]
                if message_compression == 1:
                    payload = gzip.decompress(payload)
            else:
                # Unknown message type: stop reading.
                break
        return result
|
||||
|
|
@ -15,17 +15,31 @@ from setting.models_provider.impl.openai_model_provider.credential.embedding imp
|
|||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||
from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
|
||||
from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential
|
||||
from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText
|
||||
from setting.models_provider.impl.volcanic_engine_model_provider.model.tts import VolcanicEngineTextToSpeech
|
||||
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
# Credential singletons shared by the model registrations below.
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential()

model_info_list = [
    ModelInfo('ep-xxxxxxxxxx-yyyy',
              '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
              ModelTypeConst.LLM,
              volcanic_engine_llm_model_credential, VolcanicEngineChatModel
              ),
    ModelInfo('asr',
              '',
              ModelTypeConst.STT,
              volcanic_engine_stt_model_credential, VolcanicEngineSpeechToText
              ),
    # NOTE(review): the TTS entry reuses the STT credential class — the two
    # share the same volcanic_* fields, but a dedicated TTS credential would
    # be clearer; confirm this is intentional.
    ModelInfo('tts',
              '',
              ModelTypeConst.TTS,
              volcanic_engine_stt_model_credential, VolcanicEngineTextToSpeech
              ),
]
|
||||
|
||||
open_ai_embedding_credential = OpenAIEmbeddingCredential()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
# coding=utf-8
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form for iFlytek (XunFei) Spark speech models."""

    # Form fields rendered in the UI; labels are user-facing strings.
    spark_api_url = forms.TextInputField('API 域名', required=True)
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate *model_credential* for the given model.

        Checks that *model_type* is supported by *provider*, that the required
        keys are present, and that the credentials authenticate
        (via ``model.check_auth()``).

        :raises AppApiException: on an unsupported model type (always), or on
            any other failure when ``raise_exception`` is True.
        :return: True when valid; False when invalid and ``raise_exception``
            is False.
        """
        model_type_list = provider.get_model_type_list()
        # An unsupported model type raises unconditionally (same convention as
        # the other credential classes in this commit).
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        # Required credential keys.
        for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            # Instantiate the model and probe the remote API with the credentials.
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            # Re-raise domain exceptions as-is; wrap everything else.
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of *model* with the sensitive spark_api_secret encrypted."""
        return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}

    def get_model_params_setting_form(self, model_name):
        """No extra per-model parameter form for STT models."""
        pass
|
||||
|
|
@ -0,0 +1,165 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
#
|
||||
# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import websockets
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from setting.models_provider.impl.base_stt import BaseSpeechToText
|
||||
|
||||
STATUS_FIRST_FRAME = 0 # 第一帧的标识
|
||||
STATUS_CONTINUE_FRAME = 1 # 中间帧标识
|
||||
STATUS_LAST_FRAME = 2 # 最后一帧的标识
|
||||
|
||||
|
||||
class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """iFlytek Spark speech-to-text client over the IAT websocket API.

    Streams an audio file to the service frame by frame and returns the
    recognized text.
    """

    # Credentials and endpoint used to build the signed websocket URL.
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider framework to build a configured instance."""
        optional_params = {}
        # NOTE(review): max_tokens/temperature are LLM-style knobs; nothing in
        # this class uses them — presumably copied from the LLM factory. Confirm
        # whether they can be dropped.
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return XFSparkSpeechToText(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            **optional_params
        )

    def create_url(self):
        """Build the signed websocket URL (HMAC-SHA256 auth, xfyun scheme)."""
        url = self.spark_api_url
        host = urlparse(url).hostname
        # RFC1123-formatted timestamp required by the signature.
        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
        date = datetime.utcnow().strftime(gmt_format)

        # Canonical string to sign: host header, date header and request line.
        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/iat " + "HTTP/1.1"
        # Sign with HMAC-SHA256 and base64-encode the digest.
        signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
            self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
        # Collect the auth parameters for the query string.
        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        # Append the auth parameters to produce the final connect URL.
        url = url + '?' + urlencode(v)
        return url

    def check_auth(self):
        """Verify credentials by opening (and immediately closing) a signed connection."""
        async def check():
            async with websockets.connect(self.create_url()) as ws:
                pass

        asyncio.run(check())

    def speech_to_text(self, file):
        """Stream *file* (an open binary audio file) to the service; return the text."""
        async def handle():
            async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
                # Send all audio frames, then read the recognition result.
                await self.send(ws, file)
                return await self.handle_message(ws)

        return asyncio.run(handle())

    @staticmethod
    async def handle_message(ws):
        """Read one result message and join the recognized words into a string.

        On a non-zero error code the service error message is returned instead.
        NOTE(review): only a single message is awaited — assumes the service
        sends the full result in one frame for this usage; confirm against the
        IAT protocol.
        """
        res = await ws.recv()
        message = json.loads(res)
        code = message["code"]
        sid = message["sid"]
        if code != 0:
            errMsg = message["message"]
            print("sid:%s call error:%s code is:%s" % (sid, errMsg, code))
            return errMsg
        else:
            # Concatenate every candidate word ("cw") of every segment ("ws").
            data = message["data"]["result"]["ws"]
            result = ""
            for i in data:
                for w in i["cw"]:
                    result += w["w"]
            return result

    # Streams the audio once the websocket connection is established.
    async def send(self, ws, file):
        frameSize = 8000  # bytes of audio per frame
        status = STATUS_FIRST_FRAME  # marks the next frame as first/middle/last

        while True:
            buf = file.read(frameSize)
            # End of file: switch to the final-frame state.
            if not buf:
                status = STATUS_LAST_FRAME
            # First frame: carries the business parameters; app_id is required
            # only on this frame.
            if status == STATUS_FIRST_FRAME:
                d = {
                    "common": {"app_id": self.spark_app_id},
                    "business": {
                        "domain": "iat",
                        "language": "zh_cn",
                        "accent": "mandarin",
                        "vinfo": 1,
                        "vad_eos": 10000
                    },
                    "data": {
                        "status": 0, "format": "audio/L16;rate=16000",
                        "audio": str(base64.b64encode(buf), 'utf-8'),
                        "encoding": "lame"}
                }
                d = json.dumps(d)
                await ws.send(d)
                status = STATUS_CONTINUE_FRAME
            # Middle frames: audio payload only.
            elif status == STATUS_CONTINUE_FRAME:
                d = {"data": {"status": 1, "format": "audio/L16;rate=16000",
                              "audio": str(base64.b64encode(buf), 'utf-8'),
                              "encoding": "lame"}}
                await ws.send(json.dumps(d))
            # Final frame: status 2 tells the service the stream is complete.
            elif status == STATUS_LAST_FRAME:
                d = {"data": {"status": 2, "format": "audio/L16;rate=16000",
                              "audio": str(base64.b64encode(buf), 'utf-8'),
                              "encoding": "lame"}}
                await ws.send(json.dumps(d))
                break
|
||||
|
|
@ -0,0 +1,138 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
#
|
||||
# author: iflytek
|
||||
#
|
||||
# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
import asyncio
|
||||
import base64
|
||||
import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import websockets
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from setting.models_provider.impl.base_tts import BaseTextToSpeech
|
||||
|
||||
STATUS_FIRST_FRAME = 0 # 第一帧的标识
|
||||
STATUS_CONTINUE_FRAME = 1 # 中间帧标识
|
||||
STATUS_LAST_FRAME = 2 # 最后一帧的标识
|
||||
|
||||
|
||||
class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """iFlytek Spark text-to-speech client over the TTS websocket API.

    Sends the text in one request and accumulates the returned audio chunks
    into a single bytes object.
    """

    # Credentials and endpoint used to build the signed websocket URL.
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider framework to build a configured instance."""
        optional_params = {}
        # NOTE(review): max_tokens/temperature are LLM-style knobs; nothing in
        # this class uses them — presumably copied from the LLM factory. Confirm
        # whether they can be dropped.
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return XFSparkTextToSpeech(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            **optional_params
        )

    def create_url(self):
        """Build the signed websocket URL (HMAC-SHA256 auth, xfyun scheme)."""
        url = self.spark_api_url
        host = urlparse(url).hostname
        # RFC1123-formatted timestamp required by the signature.
        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
        date = datetime.utcnow().strftime(gmt_format)

        # Canonical string to sign: host header, date header and request line.
        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
        # Sign with HMAC-SHA256 and base64-encode the digest.
        signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
            self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
        # Collect the auth parameters for the query string.
        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        # Append the auth parameters to produce the final connect URL.
        url = url + '?' + urlencode(v)
        return url

    def check_auth(self):
        """Verify credentials by opening (and immediately closing) a signed connection."""
        async def check():
            async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
                pass

        asyncio.run(check())

    def text_to_speech(self, text):
        """Synthesize *text* and return the resulting audio as bytes."""

        # For minority languages the text must be sent as UTF-16LE, e.g.:
        # self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")}
        async def handle():
            async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
                # Send the full client request, then collect the audio chunks.
                await self.send(ws, text)
                return await self.handle_message(ws)

        return asyncio.run(handle())

    @staticmethod
    async def handle_message(ws):
        """Accumulate base64-decoded audio chunks until the final-status message.

        NOTE(review): message["data"]["audio"] is read before the error-code
        check — assumes error responses still carry a "data" object; confirm
        against the TTS protocol, otherwise this raises KeyError on errors.
        """
        audio_bytes: bytes = b''
        while True:
            res = await ws.recv()
            message = json.loads(res)
            code = message["code"]
            sid = message["sid"]
            audio = message["data"]["audio"]
            audio = base64.b64decode(audio)

            if code != 0:
                errMsg = message["message"]
                print("sid:%s call error:%s code is:%s" % (sid, errMsg, code))
            else:
                audio_bytes += audio
            # data.status == 2 marks the final chunk — stop reading.
            if message["data"]["status"] == 2:
                break
        return audio_bytes

    async def send(self, ws, text):
        """Send the single synthesis request (whole text, status 2 = final)."""
        d = {
            "common": {"app_id": self.spark_app_id},
            "business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"},
            "data": {"status": 2, "text": str(base64.b64encode(text.encode('utf-8')), "UTF8")},
        }
        d = json.dumps(d)
        await ws.send(d)
|
||||
|
|
@ -13,16 +13,24 @@ from common.util.file_util import get_file_content
|
|||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||
ModelInfoManage
|
||||
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
||||
from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
|
||||
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
||||
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
|
||||
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
ssl._create_default_https_context = ssl.create_default_context()
|
||||
|
||||
qwen_model_credential = XunFeiLLMModelCredential()
|
||||
model_info_list = [ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)
|
||||
]
|
||||
stt_model_credential = XunFeiSTTModelCredential()
|
||||
model_info_list = [
|
||||
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
||||
ModelInfo('iat-niche', '小语种识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
||||
ModelInfo('tts', '', ModelTypeConst.TTS, stt_model_credential, XFSparkTextToSpeech),
|
||||
]
|
||||
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)).build()
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ dashscope = "^1.17.0"
|
|||
zhipuai = "^2.0.1"
|
||||
httpx = "^0.27.0"
|
||||
httpx-sse = "^0.4.0"
|
||||
websocket-client = "^1.7.0"
|
||||
websockets = "^13.0"
|
||||
langchain-google-genai = "^1.0.3"
|
||||
openpyxl = "^3.1.2"
|
||||
xlrd = "^2.0.1"
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ export enum PermissionDesc {
|
|||
|
||||
export enum modelType {
|
||||
EMBEDDING = '向量模型',
|
||||
LLM = '大语言模型'
|
||||
LLM = '大语言模型',
|
||||
STT = '语音识别',
|
||||
TTS = '语音合成',
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue