mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 语音识别和语音合成
This commit is contained in:
parent
88277c7aea
commit
b500404a41
|
|
@ -139,6 +139,8 @@ class BaseModelCredential(ABC):
|
||||||
class ModelTypeConst(Enum):
|
class ModelTypeConst(Enum):
|
||||||
LLM = {'code': 'LLM', 'message': '大语言模型'}
|
LLM = {'code': 'LLM', 'message': '大语言模型'}
|
||||||
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
|
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
|
||||||
|
STT = {'code': 'STT', 'message': '语音识别'}
|
||||||
|
TTS = {'code': 'TTS', 'message': '语音合成'}
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo:
|
class ModelInfo:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
# coding=utf-8
|
||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSpeechToText(BaseModel):
    """Abstract interface for speech-to-text (STT) model clients.

    Concrete providers (e.g. OpenAI, Volcanic Engine) subclass this and
    implement credential verification and audio transcription.
    """

    @abstractmethod
    def check_auth(self):
        """Verify the configured credentials; implementations raise on failure."""
        pass

    @abstractmethod
    def speech_to_text(self, audio_file):
        """Transcribe *audio_file* (a binary file-like object) and return text."""
        pass
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
# coding=utf-8
|
||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTextToSpeech(BaseModel):
    """Abstract interface for text-to-speech (TTS) model clients.

    Concrete providers (e.g. OpenAI, Volcanic Engine) subclass this and
    implement credential verification and speech synthesis.
    """

    @abstractmethod
    def check_auth(self):
        """Verify the configured credentials; implementations raise on failure."""
        pass

    @abstractmethod
    def text_to_speech(self, text):
        """Synthesize *text* and return the generated audio bytes."""
        pass
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
# coding=utf-8
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for OpenAI speech-to-text models."""

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate *model_credential* for the given model.

        Checks that *model_type* is supported by *provider*, that all required
        credential keys are present, and that the credentials actually
        authenticate (via ``model.check_auth()``).

        Returns True on success; otherwise returns False or raises
        AppApiException depending on *raise_exception*.
        """
        model_type_list = provider.get_model_type_list()
        # any() over a generator instead of the original list(filter(...)) wrapper.
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the api_key value encrypted for storage."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        # STT models expose no tunable runtime parameters.
        pass
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from setting.models_provider.impl.base_stt import BaseSpeechToText
|
||||||
|
|
||||||
|
|
||||||
|
def custom_get_token_ids(text: str):
    """Tokenize *text* with the shared tokenizer and return its token ids."""
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text client backed by the OpenAI audio transcription API."""

    api_base: str
    api_key: str
    model: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory: build an instance from stored credentials and optional params."""
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return OpenAISpeechToText(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def check_auth(self):
        """Verify credentials by listing models; raises if the API rejects them."""
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        # A raw model-list call fails fast on an invalid api_key/api_base.
        client.models.with_raw_response.list()

    def speech_to_text(self, audio_file):
        """Transcribe *audio_file* (binary file-like) and return the text.

        NOTE(review): the language is hard-coded to "zh" — confirm this is intended.
        """
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        res = client.audio.transcriptions.create(model=self.model, language="zh", file=audio_file)
        return res.text
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from setting.models_provider.impl.base_tts import BaseTextToSpeech
|
||||||
|
|
||||||
|
|
||||||
|
def custom_get_token_ids(text: str):
    """Tokenize *text* with the shared tokenizer and return its token ids."""
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Text-to-speech client backed by the OpenAI audio speech API."""

    api_base: str
    api_key: str
    model: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')
        self.model = kwargs.get('model')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory: build an instance from stored credentials and optional params."""
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return OpenAITextToSpeech(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def check_auth(self):
        """Verify credentials by listing models; raises if the API rejects them."""
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        client.models.with_raw_response.list()

    def text_to_speech(self, text):
        """Synthesize *text* and return the full audio payload as bytes.

        NOTE(review): the voice is hard-coded to "alloy" — confirm whether it
        should be configurable.
        """
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        with client.audio.speech.with_streaming_response.create(
                model=self.model,
                voice="alloy",
                input=text,
        ) as response:
            return response.read()
|
||||||
|
|
@ -13,11 +13,15 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
|
||||||
ModelTypeConst, ModelInfoManage
|
ModelTypeConst, ModelInfoManage
|
||||||
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
|
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
|
||||||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||||
|
from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
|
||||||
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||||
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
||||||
|
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
|
||||||
|
from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
openai_llm_model_credential = OpenAILLMModelCredential()
|
openai_llm_model_credential = OpenAILLMModelCredential()
|
||||||
|
openai_stt_model_credential = OpenAISTTModelCredential()
|
||||||
model_info_list = [
|
model_info_list = [
|
||||||
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||||
openai_llm_model_credential, OpenAIChatModel
|
openai_llm_model_credential, OpenAIChatModel
|
||||||
|
|
@ -58,7 +62,13 @@ model_info_list = [
|
||||||
OpenAIChatModel),
|
OpenAIChatModel),
|
||||||
ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
|
||||||
ModelTypeConst.LLM, openai_llm_model_credential,
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
OpenAIChatModel)
|
OpenAIChatModel),
|
||||||
|
ModelInfo('whisper-1', '',
|
||||||
|
ModelTypeConst.STT, openai_stt_model_credential,
|
||||||
|
OpenAISpeechToText),
|
||||||
|
ModelInfo('tts-1', '',
|
||||||
|
ModelTypeConst.TTS, openai_stt_model_credential,
|
||||||
|
OpenAITextToSpeech)
|
||||||
]
|
]
|
||||||
open_ai_embedding_credential = OpenAIEmbeddingCredential()
|
open_ai_embedding_credential = OpenAIEmbeddingCredential()
|
||||||
model_info_embedding_list = [
|
model_info_embedding_list = [
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,45 @@
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Volcanic Engine speech-to-text models."""

    volcanic_api_url = forms.TextInputField('API 域名', required=True)
    volcanic_app_id = forms.TextInputField('App ID', required=True)
    volcanic_token = forms.PasswordInputField('Token', required=True)
    volcanic_cluster = forms.TextInputField('Cluster', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate *model_credential* for the given model.

        Checks that *model_type* is supported by *provider*, that all required
        credential keys are present, and that the credentials actually
        authenticate (via ``model.check_auth()``).

        Returns True on success; otherwise returns False or raises
        AppApiException depending on *raise_exception*.
        """
        model_type_list = provider.get_model_type_list()
        # any() over a generator instead of the original list(filter(...)) wrapper.
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the volcanic_token value encrypted for storage."""
        return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))}

    def get_model_params_setting_form(self, model_name):
        # STT models expose no tunable runtime parameters.
        pass
|
||||||
|
|
@ -0,0 +1,339 @@
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
"""
|
||||||
|
requires Python 3.6 or later
|
||||||
|
|
||||||
|
pip install asyncio
|
||||||
|
pip install websockets
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import gzip
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
import wave
|
||||||
|
from enum import Enum
|
||||||
|
from hashlib import sha256
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Dict
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from setting.models_provider.impl.base_stt import BaseSpeechToText
|
||||||
|
|
||||||
|
audio_format = "mp3"  # "wav" or "mp3" — set according to the actual audio format

# --- Volcanic Engine binary websocket protocol ---------------------------
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001

# Bit widths of the header fields (4-byte base header).
PROTOCOL_VERSION_BITS = 4
HEADER_BITS = 4
MESSAGE_TYPE_BITS = 4
MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
MESSAGE_SERIALIZATION_BITS = 4
MESSAGE_COMPRESSION_BITS = 4
RESERVED_BITS = 8

# Message Type:
CLIENT_FULL_REQUEST = 0b0001
CLIENT_AUDIO_ONLY_REQUEST = 0b0010
SERVER_FULL_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111

# Message Type Specific Flags
NO_SEQUENCE = 0b0000  # no check sequence
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_SEQUENCE_1 = 0b0011

# Message Serialization
NO_SERIALIZATION = 0b0000
JSON = 0b0001
THRIFT = 0b0011
CUSTOM_TYPE = 0b1111

# Message Compression
NO_COMPRESSION = 0b0000
GZIP = 0b0001
CUSTOM_COMPRESSION = 0b1111
|
||||||
|
|
||||||
|
|
||||||
|
def generate_header(
        version=PROTOCOL_VERSION,
        message_type=CLIENT_FULL_REQUEST,
        message_type_specific_flags=NO_SEQUENCE,
        serial_method=JSON,
        compression_type=GZIP,
        reserved_data=0x00,
        extension_header=bytes()
):
    """Build the 4-byte protocol header (plus optional extension bytes).

    Layout:
        protocol_version (4 bits), header_size (4 bits),
        message_type (4 bits), message_type_specific_flags (4 bits),
        serialization_method (4 bits), message_compression (4 bits),
        reserved (8 bits),
        header_extensions (size == 8 * 4 * (header_size - 1) bits).

    Returns a ``bytearray`` ready to prepend to a payload.
    """
    header = bytearray()
    # header_size is measured in 4-byte words, including the base header word.
    header_size = int(len(extension_header) / 4) + 1
    header.append((version << 4) | header_size)
    header.append((message_type << 4) | message_type_specific_flags)
    header.append((serial_method << 4) | compression_type)
    header.append(reserved_data)
    header.extend(extension_header)
    return header
|
||||||
|
|
||||||
|
|
||||||
|
def generate_full_default_header():
    """Header for an initial "full client request" with all default fields."""
    return generate_header()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_audio_default_header():
    """Header for an intermediate audio-only chunk request."""
    return generate_header(
        message_type=CLIENT_AUDIO_ONLY_REQUEST
    )
|
||||||
|
|
||||||
|
|
||||||
|
def generate_last_audio_default_header():
    """Header for the final audio-only chunk (NEG_SEQUENCE marks end of stream)."""
    return generate_header(
        message_type=CLIENT_AUDIO_ONLY_REQUEST,
        message_type_specific_flags=NEG_SEQUENCE
    )
|
||||||
|
|
||||||
|
|
||||||
|
def parse_response(res):
    """Parse one binary server frame of the Volcanic Engine ASR protocol.

    Frame layout:
        protocol_version (4 bits), header_size (4 bits),
        message_type (4 bits), message_type_specific_flags (4 bits),
        serialization_method (4 bits), message_compression (4 bits),
        reserved (8 bits),
        header_extensions (size == 8 * 4 * (header_size - 1) bits),
        payload (like an HTTP request body).

    Returns a dict that may contain 'seq', 'code', 'payload_msg' and
    'payload_size' depending on the message type.
    """
    protocol_version = res[0] >> 4
    header_size = res[0] & 0x0f
    message_type = res[1] >> 4
    message_type_specific_flags = res[1] & 0x0f
    serialization_method = res[2] >> 4
    message_compression = res[2] & 0x0f
    reserved = res[3]
    header_extensions = res[4:header_size * 4]
    payload = res[header_size * 4:]
    result = {}
    payload_msg = None
    payload_size = 0
    if message_type == SERVER_FULL_RESPONSE:
        payload_size = int.from_bytes(payload[:4], "big", signed=True)
        payload_msg = payload[4:]
    elif message_type == SERVER_ACK:
        seq = int.from_bytes(payload[:4], "big", signed=True)
        result['seq'] = seq
        # An ACK may carry an optional payload after the sequence number.
        if len(payload) >= 8:
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload_msg = payload[8:]
    elif message_type == SERVER_ERROR_RESPONSE:
        code = int.from_bytes(payload[:4], "big", signed=False)
        result['code'] = code
        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
        payload_msg = payload[8:]
    if payload_msg is None:
        return result
    if message_compression == GZIP:
        payload_msg = gzip.decompress(payload_msg)
    if serialization_method == JSON:
        payload_msg = json.loads(str(payload_msg, "utf-8"))
    elif serialization_method != NO_SERIALIZATION:
        payload_msg = str(payload_msg, "utf-8")
    result['payload_msg'] = payload_msg
    result['payload_size'] = payload_size
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def read_wav_info(data: bytes = None) -> (int, int, int, int, int):
    """Parse a complete WAV byte string.

    Returns (nchannels, sampwidth, framerate, nframes, frame-data length
    in bytes).
    """
    with BytesIO(data) as _f:
        # Use the wave reader as a context manager so it is always closed
        # (the original leaked the Wave_read object).
        with wave.open(_f, 'rb') as wave_fp:
            nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4]
            wave_bytes = wave_fp.readframes(nframes)
    return nchannels, sampwidth, framerate, nframes, len(wave_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text client for the Volcanic Engine (ByteDance) ASR
    websocket API.

    The audio is sent as one gzip-compressed "full client request" followed
    by a sequence of audio-only chunk requests; the transcript is taken from
    the final server response.
    """

    # Default ASR request parameters (see Volcanic Engine ASR docs).
    workflow: str = "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate"
    show_language: bool = False
    show_utterances: bool = False
    result_type: str = "full"
    format: str = "mp3"
    rate: int = 16000
    language: str = "zh-CN"
    bits: int = 16
    channel: int = 1
    codec: str = "raw"
    audio_type: int = 1
    secret: str = "access_secret"
    auth_method: str = "token"
    mp3_seg_size: int = 10000
    success_code: int = 1000  # success code, default is 1000
    seg_duration: int = 15000
    nbest: int = 1

    # Credentials supplied per model instance.
    volcanic_app_id: str
    volcanic_cluster: str
    volcanic_api_url: str
    volcanic_token: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.volcanic_api_url = kwargs.get('volcanic_api_url')
        self.volcanic_token = kwargs.get('volcanic_token')
        self.volcanic_app_id = kwargs.get('volcanic_app_id')
        self.volcanic_cluster = kwargs.get('volcanic_cluster')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory: build an instance from stored credentials and optional params."""
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return VolcanicEngineSpeechToText(
            volcanic_api_url=model_credential.get('volcanic_api_url'),
            volcanic_token=model_credential.get('volcanic_token'),
            volcanic_app_id=model_credential.get('volcanic_app_id'),
            volcanic_cluster=model_credential.get('volcanic_cluster'),
            **optional_params
        )

    def construct_request(self, reqid):
        """Build the JSON body for the initial "full client request"."""
        req = {
            'app': {
                'appid': self.volcanic_app_id,
                'cluster': self.volcanic_cluster,
                'token': self.volcanic_token,
            },
            'user': {
                'uid': 'uid'
            },
            'request': {
                'reqid': reqid,
                'nbest': self.nbest,
                'workflow': self.workflow,
                'show_language': self.show_language,
                'show_utterances': self.show_utterances,
                'result_type': self.result_type,
                "sequence": 1
            },
            'audio': {
                'format': self.format,
                'rate': self.rate,
                'language': self.language,
                'bits': self.bits,
                'channel': self.channel,
                'codec': self.codec
            }
        }
        return req

    @staticmethod
    def slice_data(data: bytes, chunk_size: int) -> (list, bool):
        """Slice *data* into chunks of *chunk_size*.

        :param data: wav data
        :param chunk_size: the segment size in one request
        :return: yields (segment data, last flag) pairs; the final chunk is
            flagged True.
        """
        data_len = len(data)
        offset = 0
        while offset + chunk_size < data_len:
            yield data[offset: offset + chunk_size], False
            offset += chunk_size
        else:
            # while-else: always yields the trailing (possibly short) chunk.
            yield data[offset: data_len], True

    def _real_processor(self, request_params: dict) -> dict:
        # Placeholder — not used by the websocket flow below.
        pass

    def token_auth(self):
        """Bearer-token auth header (Volcanic's 'Bearer; <token>' format)."""
        return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)}

    def signature_auth(self, data):
        """HMAC-SHA256 signature auth headers over the serialized request."""
        header_dicts = {
            'Custom': 'auth_custom',
        }

        url_parse = urlparse(self.volcanic_api_url)
        input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path)
        auth_headers = 'Custom'
        for header in auth_headers.split(','):
            input_str += '{}\n'.format(header_dicts[header])
        input_data = bytearray(input_str, 'utf-8')
        input_data += data
        mac = base64.urlsafe_b64encode(
            hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest())
        header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(self.volcanic_token,
                                                                                             str(mac, 'utf-8'),
                                                                                             auth_headers)
        return header_dicts

    async def segment_data_processor(self, wav_data: bytes, segment_size: int):
        """Stream *wav_data* to the ASR websocket in *segment_size* chunks.

        Returns the recognized text on success, or the raw parsed error
        response dict if the server reports a non-success code.
        """
        reqid = str(uuid.uuid4())
        # Build the full client request, then serialize and compress it.
        request_params = self.construct_request(reqid)
        payload_bytes = str.encode(json.dumps(request_params))
        payload_bytes = gzip.compress(payload_bytes)
        full_client_request = bytearray(generate_full_default_header())
        full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
        full_client_request.extend(payload_bytes)  # payload
        header = None
        if self.auth_method == "token":
            header = self.token_auth()
        elif self.auth_method == "signature":
            header = self.signature_auth(full_client_request)
        async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws:
            # Send the full client request first.
            await ws.send(full_client_request)
            res = await ws.recv()
            result = parse_response(res)
            if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
                return result
            for seq, (chunk, last) in enumerate(VolcanicEngineSpeechToText.slice_data(wav_data, segment_size), 1):
                # if no compression, comment this line
                payload_bytes = gzip.compress(chunk)
                audio_only_request = bytearray(generate_audio_default_header())
                if last:
                    audio_only_request = bytearray(generate_last_audio_default_header())
                audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
                audio_only_request.extend(payload_bytes)  # payload
                # Send the audio-only client request.
                await ws.send(audio_only_request)
                res = await ws.recv()
                result = parse_response(res)
                if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
                    return result
        return result['payload_msg']['result'][0]['text']

    def check_auth(self):
        """Verify credentials by opening (and closing) a websocket connection."""
        header = self.token_auth()

        async def check():
            async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws:
                pass

        asyncio.run(check())

    def speech_to_text(self, file):
        """Transcribe *file* (mp3 or wav, per ``self.format``) to text."""
        data = file.read()
        audio_data = bytes(data)
        if self.format == "mp3":
            segment_size = self.mp3_seg_size
            return asyncio.run(self.segment_data_processor(audio_data, segment_size))
        if self.format != "wav":
            raise Exception("format should in wav or mp3")
        nchannels, sampwidth, framerate, nframes, wav_len = read_wav_info(
            audio_data)
        # One segment per seg_duration milliseconds of audio.
        size_per_sec = nchannels * sampwidth * framerate
        segment_size = int(size_per_sec * self.seg_duration / 1000)
        return asyncio.run(self.segment_data_processor(audio_data, segment_size))
|
||||||
|
|
@ -0,0 +1,164 @@
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
'''
|
||||||
|
requires Python 3.6 or later
|
||||||
|
|
||||||
|
pip install asyncio
|
||||||
|
pip install websockets
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import copy
|
||||||
|
import gzip
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from setting.models_provider.impl.base_tts import BaseTextToSpeech
|
||||||
|
|
||||||
|
# Human-readable names for the Volcanic Engine TTS binary protocol fields.
MESSAGE_TYPES = {11: "audio-only server response", 12: "frontend server response", 15: "error message from server"}
MESSAGE_TYPE_SPECIFIC_FLAGS = {0: "no sequence number", 1: "sequence number > 0",
                               2: "last message from server (seq < 0)", 3: "sequence number < 0"}
MESSAGE_SERIALIZATION_METHODS = {0: "no serialization", 1: "JSON", 15: "custom type"}
MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression method"}

# Default 4-byte request header:
# version: b0001 (4 bits)
# header size: b0001 (4 bits)
# message type: b0001 (Full client request) (4bits)
# message type specific flags: b0000 (none) (4bits)
# message serialization method: b0001 (JSON) (4 bits)
# message compression: b0001 (gzip) (4bits)
# reserved data: 0x00 (1 byte)
default_header = bytearray(b'\x11\x10\x11\x00')
|
||||||
|
|
||||||
|
|
||||||
|
class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Text-to-speech client for the Volcanic Engine (ByteDance) TTS
    websocket API.

    A single gzip-compressed "submit" request is sent; audio chunks are then
    read from the websocket until the final (negative sequence) frame.
    """

    volcanic_app_id: str
    volcanic_cluster: str
    volcanic_api_url: str
    volcanic_token: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.volcanic_api_url = kwargs.get('volcanic_api_url')
        self.volcanic_token = kwargs.get('volcanic_token')
        self.volcanic_app_id = kwargs.get('volcanic_app_id')
        self.volcanic_cluster = kwargs.get('volcanic_cluster')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory: build an instance from stored credentials and optional params."""
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return VolcanicEngineTextToSpeech(
            volcanic_api_url=model_credential.get('volcanic_api_url'),
            volcanic_token=model_credential.get('volcanic_token'),
            volcanic_app_id=model_credential.get('volcanic_app_id'),
            volcanic_cluster=model_credential.get('volcanic_cluster'),
            **optional_params
        )

    def check_auth(self):
        """Verify credentials by opening (and closing) a websocket connection."""
        header = self.token_auth()

        async def check():
            async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws:
                pass

        asyncio.run(check())

    def text_to_speech(self, text):
        """Synthesize *text* and return the mp3 audio bytes.

        NOTE(review): voice_type/encoding/ratios are hard-coded — confirm
        whether they should be configurable.
        """
        request_json = {
            "app": {
                "appid": self.volcanic_app_id,
                "token": "access_token",
                "cluster": self.volcanic_cluster
            },
            "user": {
                "uid": "uid"
            },
            "audio": {
                "voice_type": "BV002_streaming",
                "encoding": "mp3",
                "speed_ratio": 1.0,
                "volume_ratio": 1.0,
                "pitch_ratio": 1.0,
            },
            "request": {
                "reqid": str(uuid.uuid4()),
                "text": text,
                "text_type": "plain",
                "operation": "xxx"  # overwritten with "submit" in submit()
            }
        }

        return asyncio.run(self.submit(request_json))

    def token_auth(self):
        """Bearer-token auth header (Volcanic's 'Bearer; <token>' format)."""
        return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)}

    async def submit(self, request_json):
        """Send one "submit" request over the websocket and collect the audio."""
        submit_request_json = copy.deepcopy(request_json)
        submit_request_json["request"]["reqid"] = str(uuid.uuid4())
        submit_request_json["request"]["operation"] = "submit"
        payload_bytes = str.encode(json.dumps(submit_request_json))
        payload_bytes = gzip.compress(payload_bytes)  # if no compression, comment this line
        full_client_request = bytearray(default_header)
        full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
        full_client_request.extend(payload_bytes)  # payload
        header = {"Authorization": f"Bearer; {self.volcanic_token}"}
        async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws:
            await ws.send(full_client_request)
            return await self.parse_response(ws)

    @staticmethod
    async def parse_response(ws):
        """Read frames from *ws* until the stream ends; return concatenated audio.

        NOTE(review): on a server error frame (0xf) the loop breaks and the
        partial result is returned without raising — confirm this is intended.
        """
        result = b''
        while True:
            res = await ws.recv()
            protocol_version = res[0] >> 4
            header_size = res[0] & 0x0f
            message_type = res[1] >> 4
            message_type_specific_flags = res[1] & 0x0f
            serialization_method = res[2] >> 4
            message_compression = res[2] & 0x0f
            reserved = res[3]
            header_extensions = res[4:header_size * 4]
            payload = res[header_size * 4:]
            if header_size != 1:
                # Header extensions are present but not used.
                pass
            if message_type == 0xb:  # audio-only server response
                if message_type_specific_flags == 0:  # no sequence number as ACK
                    continue
                else:
                    sequence_number = int.from_bytes(payload[:4], "big", signed=True)
                    payload_size = int.from_bytes(payload[4:8], "big", signed=False)
                    payload = payload[8:]
                    result += payload
                    if sequence_number < 0:
                        # Negative sequence number marks the final frame.
                        break
                    else:
                        continue
            elif message_type == 0xf:  # error message from server
                code = int.from_bytes(payload[:4], "big", signed=False)
                msg_size = int.from_bytes(payload[4:8], "big", signed=False)
                error_msg = payload[8:]
                if message_compression == 1:
                    error_msg = gzip.decompress(error_msg)
                error_msg = str(error_msg, "utf-8")
                break
            elif message_type == 0xc:  # frontend server response (metadata)
                msg_size = int.from_bytes(payload[:4], "big", signed=False)
                payload = payload[4:]
                if message_compression == 1:
                    payload = gzip.decompress(payload)
            else:
                break
        return result
|
||||||
|
|
@ -15,17 +15,31 @@ from setting.models_provider.impl.openai_model_provider.credential.embedding imp
|
||||||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||||
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||||
from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
|
from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
|
||||||
|
from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential
|
||||||
|
from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText
|
||||||
|
from setting.models_provider.impl.volcanic_engine_model_provider.model.tts import VolcanicEngineTextToSpeech
|
||||||
|
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
|
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
|
||||||
|
volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential()
|
||||||
|
|
||||||
model_info_list = [
|
model_info_list = [
|
||||||
ModelInfo('ep-xxxxxxxxxx-yyyy',
|
ModelInfo('ep-xxxxxxxxxx-yyyy',
|
||||||
'用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
|
'用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用',
|
||||||
ModelTypeConst.LLM,
|
ModelTypeConst.LLM,
|
||||||
volcanic_engine_llm_model_credential, VolcanicEngineChatModel
|
volcanic_engine_llm_model_credential, VolcanicEngineChatModel
|
||||||
)
|
),
|
||||||
|
ModelInfo('asr',
|
||||||
|
'',
|
||||||
|
ModelTypeConst.STT,
|
||||||
|
volcanic_engine_stt_model_credential, VolcanicEngineSpeechToText
|
||||||
|
),
|
||||||
|
ModelInfo('tts',
|
||||||
|
'',
|
||||||
|
ModelTypeConst.TTS,
|
||||||
|
volcanic_engine_stt_model_credential, VolcanicEngineTextToSpeech
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
open_ai_embedding_credential = OpenAIEmbeddingCredential()
|
open_ai_embedding_credential = OpenAIEmbeddingCredential()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form for the iFlytek (XunFei) speech-to-text model.

    Declares the form fields shown in the model-configuration UI and
    validates a submitted credential by opening an authenticated
    connection to the service.
    """

    # Form fields rendered in the model-configuration UI.
    spark_api_url = forms.TextInputField('API 域名', required=True)
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential dictionary for the given model.

        Checks that the model type is supported by *provider*, that all
        required fields are present, and that the credential actually
        authenticates (via ``model.check_auth()``).

        :param model_type: model type code (e.g. ``STT``)
        :param model_name: model identifier within the provider
        :param model_credential: user-supplied credential fields
        :param provider: the owning model provider
        :param raise_exception: raise ``AppApiException`` on failure
            instead of returning ``False`` (an unsupported model type
            always raises)
        :return: ``True`` if the credential is valid, else ``False``
        """
        model_type_list = provider.get_model_type_list()
        # Unsupported model type is a programming/config error: always raise.
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            # Instantiate the model and perform a live auth check.
            model = provider.get_model(model_type, model_name, model_credential)
            model.check_auth()
        except Exception as e:
            if isinstance(e, AppApiException):
                # Bare raise preserves the original traceback.
                raise
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return *model* with the API secret masked for safe display/storage."""
        return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}

    def get_model_params_setting_form(self, model_name):
        """No tunable runtime parameters for the STT model."""
        pass
|
||||||
|
|
@ -0,0 +1,165 @@
|
||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
#
|
||||||
|
# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
|
||||||
|
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import datetime
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict
|
||||||
|
from urllib.parse import urlencode, urlparse
|
||||||
|
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from setting.models_provider.impl.base_stt import BaseSpeechToText
|
||||||
|
|
||||||
|
STATUS_FIRST_FRAME = 0 # 第一帧的标识
|
||||||
|
STATUS_CONTINUE_FRAME = 1 # 中间帧标识
|
||||||
|
STATUS_LAST_FRAME = 2 # 最后一帧的标识
|
||||||
|
|
||||||
|
|
||||||
|
class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """iFlytek Spark speech recognition (IAT) client over the v2 websocket API.

    Streams an audio file to the service in fixed-size frames and returns
    the recognized text from the server response.
    """

    # Credentials taken from the model configuration form.
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry to build a model instance.

        NOTE(review): ``max_tokens``/``temperature`` appear copied from the
        LLM factory — nothing in this class reads them; confirm they can be
        dropped.
        """
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return XFSparkSpeechToText(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            **optional_params
        )

    # Build the signed websocket URL (HMAC-SHA256 over host/date/request-line).
    def create_url(self):
        url = self.spark_api_url
        host = urlparse(url).hostname
        # RFC 1123 timestamp, required by the signature scheme.
        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
        date = datetime.utcnow().strftime(gmt_format)

        # The exact concatenation order and spacing are mandated by the API;
        # any change invalidates the signature.
        # NOTE(review): the request line is hard-coded to "/v2/iat" — assumes
        # spark_api_url always ends with that path; confirm for other models.
        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/iat " + "HTTP/1.1"
        # HMAC-SHA256 signature, base64-encoded.
        signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
            self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
        # Assemble the auth parameters into the query string.
        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        url = url + '?' + urlencode(v)
        # Uncomment to compare the generated URL against the official demo.
        # print("date: ",date)
        # print("v: ",v)
        # print('websocket url :', url)
        return url

    def check_auth(self):
        """Verify the credential by opening (and immediately closing) a signed connection."""
        async def check():
            async with websockets.connect(self.create_url()) as ws:
                pass

        asyncio.run(check())

    def speech_to_text(self, file):
        """Transcribe *file* (a binary file-like object) and return the text.

        NOTE(review): only a single server response is read — assumed
        sufficient for short clips; confirm for long audio.
        """
        async def handle():
            async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
                # Stream the audio frames, then read the recognition result.
                await self.send(ws, file)
                return await self.handle_message(ws)

        return asyncio.run(handle())

    @staticmethod
    async def handle_message(ws):
        """Read one server message and extract the recognized text (or error message)."""
        res = await ws.recv()
        message = json.loads(res)
        code = message["code"]
        sid = message["sid"]
        if code != 0:
            errMsg = message["message"]
            print("sid:%s call error:%s code is:%s" % (sid, errMsg, code))
            # NOTE(review): returns the error text as if it were a transcript.
            return errMsg
        else:
            # Concatenate all candidate words from the result segments.
            data = message["data"]["result"]["ws"]
            result = ""
            for i in data:
                for w in i["cw"]:
                    result += w["w"]
            # print("sid:%s call success!,data is:%s" % (sid, json.dumps(data, ensure_ascii=False)))
            return result

    # Stream the audio to the server, frame by frame.
    async def send(self, ws, file):
        frameSize = 8000  # bytes of audio per frame
        status = STATUS_FIRST_FRAME  # tracks first / middle / last frame

        while True:
            buf = file.read(frameSize)
            # End of file: mark the final frame.
            if not buf:
                status = STATUS_LAST_FRAME
            # First frame carries the app id and business parameters;
            # subsequent frames carry audio only.
            if status == STATUS_FIRST_FRAME:
                d = {
                    "common": {"app_id": self.spark_app_id},
                    "business": {
                        "domain": "iat",
                        "language": "zh_cn",
                        "accent": "mandarin",
                        "vinfo": 1,
                        "vad_eos": 10000
                    },
                    "data": {
                        "status": 0, "format": "audio/L16;rate=16000",
                        "audio": str(base64.b64encode(buf), 'utf-8'),
                        "encoding": "lame"}
                }
                d = json.dumps(d)
                await ws.send(d)
                status = STATUS_CONTINUE_FRAME
            # Middle frames.
            elif status == STATUS_CONTINUE_FRAME:
                d = {"data": {"status": 1, "format": "audio/L16;rate=16000",
                              "audio": str(base64.b64encode(buf), 'utf-8'),
                              "encoding": "lame"}}
                await ws.send(json.dumps(d))
            # Final (empty) frame signals end of audio.
            elif status == STATUS_LAST_FRAME:
                d = {"data": {"status": 2, "format": "audio/L16;rate=16000",
                              "audio": str(base64.b64encode(buf), 'utf-8'),
                              "encoding": "lame"}}
                await ws.send(json.dumps(d))
                break
||||||
|
|
@ -0,0 +1,138 @@
|
||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
#
|
||||||
|
# author: iflytek
|
||||||
|
#
|
||||||
|
# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
|
||||||
|
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import datetime
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict
|
||||||
|
from urllib.parse import urlencode, urlparse
|
||||||
|
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from setting.models_provider.impl.base_tts import BaseTextToSpeech
|
||||||
|
|
||||||
|
STATUS_FIRST_FRAME = 0 # 第一帧的标识
|
||||||
|
STATUS_CONTINUE_FRAME = 1 # 中间帧标识
|
||||||
|
STATUS_LAST_FRAME = 2 # 最后一帧的标识
|
||||||
|
|
||||||
|
|
||||||
|
class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """iFlytek Spark text-to-speech client over the v2 websocket API.

    Sends the whole text in one request and accumulates the returned
    base64-encoded audio chunks until the server reports completion.
    """

    # Credentials taken from the model configuration form.
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.spark_api_url = kwargs.get('spark_api_url')
        self.spark_app_id = kwargs.get('spark_app_id')
        self.spark_api_key = kwargs.get('spark_api_key')
        self.spark_api_secret = kwargs.get('spark_api_secret')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry to build a model instance.

        NOTE(review): ``max_tokens``/``temperature`` appear copied from the
        LLM factory — nothing in this class reads them; confirm they can be
        dropped.
        """
        optional_params = {}
        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
            optional_params['max_tokens'] = model_kwargs['max_tokens']
        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
            optional_params['temperature'] = model_kwargs['temperature']
        return XFSparkTextToSpeech(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            **optional_params
        )

    # Build the signed websocket URL (HMAC-SHA256 over host/date/request-line).
    def create_url(self):
        url = self.spark_api_url
        host = urlparse(url).hostname
        # RFC 1123 timestamp, required by the signature scheme.
        gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
        date = datetime.utcnow().strftime(gmt_format)

        # The exact concatenation order and spacing are mandated by the API;
        # any change invalidates the signature.
        # NOTE(review): the request line is hard-coded to "/v2/tts" — assumes
        # spark_api_url always ends with that path; confirm.
        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
        # HMAC-SHA256 signature, base64-encoded.
        signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
            self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
        # Assemble the auth parameters into the query string.
        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        url = url + '?' + urlencode(v)
        # Uncomment to compare the generated URL against the official demo.
        # print("date: ",date)
        # print("v: ",v)
        # print('websocket url :', url)
        return url

    def check_auth(self):
        """Verify the credential by opening (and immediately closing) a signed connection."""
        async def check():
            async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
                pass

        asyncio.run(check())

    def text_to_speech(self, text):
        """Synthesize *text* and return the resulting audio bytes (MP3/lame-encoded)."""

        # For minority languages use UTF-16LE encoding of the text instead:
        # self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")}
        async def handle():
            async with websockets.connect(self.create_url(), max_size=1000000000) as ws:
                # Send the full client request, then drain the audio stream.
                await self.send(ws, text)
                return await self.handle_message(ws)

        return asyncio.run(handle())

    @staticmethod
    async def handle_message(ws):
        """Accumulate audio chunks from the server until status == 2 (stream end)."""
        audio_bytes: bytes = b''
        while True:
            res = await ws.recv()
            message = json.loads(res)
            # print(message)
            code = message["code"]
            sid = message["sid"]
            # NOTE(review): "data"/"audio" are read before checking the error
            # code — an error frame without them would raise KeyError; confirm
            # the service always includes them.
            audio = message["data"]["audio"]
            audio = base64.b64decode(audio)

            if code != 0:
                errMsg = message["message"]
                print("sid:%s call error:%s code is:%s" % (sid, errMsg, code))
            else:
                audio_bytes += audio
            # Status 2 marks the final frame.
            if message["data"]["status"] == 2:
                break
        return audio_bytes

    async def send(self, ws, text):
        """Send the single TTS request frame (whole text, base64-encoded UTF-8)."""
        d = {
            "common": {"app_id": self.spark_app_id},
            "business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"},
            "data": {"status": 2, "text": str(base64.b64encode(text.encode('utf-8')), "UTF8")},
        }
        d = json.dumps(d)
        await ws.send(d)
|
||||||
|
|
@ -13,16 +13,24 @@ from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||||
ModelInfoManage
|
ModelInfoManage
|
||||||
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
||||||
|
from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
|
||||||
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
||||||
|
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
|
||||||
|
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
ssl._create_default_https_context = ssl.create_default_context()
|
ssl._create_default_https_context = ssl.create_default_context()
|
||||||
|
|
||||||
qwen_model_credential = XunFeiLLMModelCredential()
|
qwen_model_credential = XunFeiLLMModelCredential()
|
||||||
model_info_list = [ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
stt_model_credential = XunFeiSTTModelCredential()
|
||||||
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
model_info_list = [
|
||||||
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)
|
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||||
]
|
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||||
|
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||||
|
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
||||||
|
ModelInfo('iat-niche', '小语种识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
||||||
|
ModelInfo('tts', '', ModelTypeConst.TTS, stt_model_credential, XFSparkTextToSpeech),
|
||||||
|
]
|
||||||
|
|
||||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||||
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)).build()
|
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)).build()
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ dashscope = "^1.17.0"
|
||||||
zhipuai = "^2.0.1"
|
zhipuai = "^2.0.1"
|
||||||
httpx = "^0.27.0"
|
httpx = "^0.27.0"
|
||||||
httpx-sse = "^0.4.0"
|
httpx-sse = "^0.4.0"
|
||||||
websocket-client = "^1.7.0"
|
websockets = "^13.0"
|
||||||
langchain-google-genai = "^1.0.3"
|
langchain-google-genai = "^1.0.3"
|
||||||
openpyxl = "^3.1.2"
|
openpyxl = "^3.1.2"
|
||||||
xlrd = "^2.0.1"
|
xlrd = "^2.0.1"
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,9 @@ export enum PermissionDesc {
|
||||||
|
|
||||||
export enum modelType {
|
export enum modelType {
|
||||||
EMBEDDING = '向量模型',
|
EMBEDDING = '向量模型',
|
||||||
LLM = '大语言模型'
|
LLM = '大语言模型',
|
||||||
|
STT = '语音识别',
|
||||||
|
TTS = '语音合成',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue