diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py
index 6b3313983..4de2f0067 100644
--- a/apps/setting/models_provider/base_model_provider.py
+++ b/apps/setting/models_provider/base_model_provider.py
@@ -139,6 +139,8 @@ class BaseModelCredential(ABC):
 class ModelTypeConst(Enum):
     LLM = {'code': 'LLM', 'message': '大语言模型'}
     EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
+    STT = {'code': 'STT', 'message': '语音识别'}
+    TTS = {'code': 'TTS', 'message': '语音合成'}
 
 
 class ModelInfo:
diff --git a/apps/setting/models_provider/impl/base_stt.py b/apps/setting/models_provider/impl/base_stt.py
new file mode 100644
index 000000000..aae72a559
--- /dev/null
+++ b/apps/setting/models_provider/impl/base_stt.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+from abc import abstractmethod
+
+from pydantic import BaseModel
+
+
+class BaseSpeechToText(BaseModel):
+    @abstractmethod
+    def check_auth(self):
+        pass
+
+    @abstractmethod
+    def speech_to_text(self, audio_file):
+        pass
diff --git a/apps/setting/models_provider/impl/base_tts.py b/apps/setting/models_provider/impl/base_tts.py
new file mode 100644
index 000000000..6311f2686
--- /dev/null
+++ b/apps/setting/models_provider/impl/base_tts.py
@@ -0,0 +1,14 @@
+# coding=utf-8
+from abc import abstractmethod
+
+from pydantic import BaseModel
+
+
+class BaseTextToSpeech(BaseModel):
+    @abstractmethod
+    def check_auth(self):
+        pass
+
+    @abstractmethod
+    def text_to_speech(self, text):
+        pass
diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/stt.py b/apps/setting/models_provider/impl/openai_model_provider/credential/stt.py
new file mode 100644
index 000000000..59506311c
--- /dev/null
+++ b/apps/setting/models_provider/impl/openai_model_provider/credential/stt.py
@@ -0,0 +1,42 @@
+# coding=utf-8
+from typing import Dict
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
+    api_base = forms.TextInputField('API 域名', required=True)
+    api_key = forms.PasswordInputField('API Key', required=True)
+
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+                 raise_exception=False):
+        model_type_list = provider.get_model_type_list()
+        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+        for key in ['api_base', 'api_key']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+                else:
+                    return False
+        try:
+            model = provider.get_model(model_type, model_name, model_credential)
+            model.check_auth()
+        except Exception as e:
+            if isinstance(e, AppApiException):
+                raise e
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+            else:
+                return False
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
+
+    def get_model_params_setting_form(self, model_name):
+        pass
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/stt.py b/apps/setting/models_provider/impl/openai_model_provider/model/stt.py
new file mode 100644
index 000000000..66b2daeda
--- /dev/null
+++ 
b/apps/setting/models_provider/impl/openai_model_provider/model/stt.py
@@ -0,0 +1,53 @@
+from typing import Dict
+
+from openai import OpenAI
+
+from common.config.tokenizer_manage_config import TokenizerManage
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_stt import BaseSpeechToText
+
+
+def custom_get_token_ids(text: str):
+    tokenizer = TokenizerManage.get_tokenizer()
+    return tokenizer.encode(text)
+
+
+class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
+    api_base: str
+    api_key: str
+    model: str
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.api_key = kwargs.get('api_key')
+        self.api_base = kwargs.get('api_base')
+
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        optional_params = {}
+        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
+            optional_params['max_tokens'] = model_kwargs['max_tokens']
+        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
+            optional_params['temperature'] = model_kwargs['temperature']
+        return OpenAISpeechToText(
+            model=model_name,
+            api_base=model_credential.get('api_base'),
+            api_key=model_credential.get('api_key'),
+            **optional_params,
+        )
+
+    def check_auth(self):
+        client = OpenAI(
+            base_url=self.api_base,
+            api_key=self.api_key
+        )
+        response_list = client.models.with_raw_response.list()
+        # print(response_list)
+
+    def speech_to_text(self, audio_file):
+        client = OpenAI(
+            base_url=self.api_base,
+            api_key=self.api_key
+        )
+        res = client.audio.transcriptions.create(model=self.model, language="zh", file=audio_file)
+        return res.text
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/tts.py b/apps/setting/models_provider/impl/openai_model_provider/model/tts.py
new file mode 100644
index 000000000..c09754840
--- /dev/null
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/tts.py
@@ -0,0 +1,58 @@
+from typing import Dict
+
+from openai import OpenAI
+
+from common.config.tokenizer_manage_config import TokenizerManage
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_tts import BaseTextToSpeech
+
+
+def custom_get_token_ids(text: str):
+    tokenizer = TokenizerManage.get_tokenizer()
+    return tokenizer.encode(text)
+
+
+class OpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
+    api_base: str
+    api_key: str
+    model: str
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.api_key = kwargs.get('api_key')
+        self.api_base = kwargs.get('api_base')
+        self.model = kwargs.get('model')
+
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        optional_params = {}
+        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
+            optional_params['max_tokens'] = model_kwargs['max_tokens']
+        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
+            optional_params['temperature'] = model_kwargs['temperature']
+        return OpenAITextToSpeech(
+            model=model_name,
+            api_base=model_credential.get('api_base'),
+            api_key=model_credential.get('api_key'),
+            **optional_params,
+        )
+
+    def check_auth(self):
+        client = OpenAI(
+            base_url=self.api_base,
+            api_key=self.api_key
+        )
+        response_list = client.models.with_raw_response.list()
+        # print(response_list)
+
+    def text_to_speech(self, text):
+        client = OpenAI(
+            base_url=self.api_base,
+            api_key=self.api_key
) + with client.audio.speech.with_streaming_response.create( + model=self.model, + voice="alloy", + input=text, + ) as response: + return response.read() diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py index fb4c89d7b..cbec2d504 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py +++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py @@ -13,11 +13,15 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro ModelTypeConst, ModelInfoManage from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential +from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel +from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText +from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech from smartdoc.conf import PROJECT_DIR openai_llm_model_credential = OpenAILLMModelCredential() +openai_stt_model_credential = OpenAISTTModelCredential() model_info_list = [ ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential, OpenAIChatModel @@ -58,7 +62,13 @@ model_info_list = [ OpenAIChatModel), ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens', ModelTypeConst.LLM, openai_llm_model_credential, - OpenAIChatModel) + OpenAIChatModel), + ModelInfo('whisper-1', '', + ModelTypeConst.STT, openai_stt_model_credential, + OpenAISpeechToText), + ModelInfo('tts-1', '', + ModelTypeConst.TTS, openai_stt_model_credential, + OpenAITextToSpeech) ] open_ai_embedding_credential = OpenAIEmbeddingCredential() model_info_embedding_list = [ diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py new file mode 100644 index 000000000..2715d9bd5 --- /dev/null +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py @@ -0,0 +1,45 @@ +# coding=utf-8 + +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential): + volcanic_api_url = forms.TextInputField('API 域名', required=True) + volcanic_app_id = forms.TextInputField('App ID', required=True) + volcanic_token = forms.PasswordInputField('Token', required=True) + volcanic_cluster = forms.TextInputField('Cluster', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 
'volcanic_cluster']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))} + + def get_model_params_setting_form(self, model_name): + pass diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py new file mode 100644 index 000000000..8c0c0e116 --- /dev/null +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py @@ -0,0 +1,339 @@ +# coding=utf-8 + +""" +requires Python 3.6 or later + +pip install asyncio +pip install websockets +""" +import asyncio +import base64 +import gzip +import hmac +import json +import uuid +import wave +from enum import Enum +from hashlib import sha256 +from io import BytesIO +from typing import Dict +from urllib.parse import urlparse + +import websockets + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_stt import BaseSpeechToText + +audio_format = "mp3" # wav 或者 mp3,根据实际音频格式设置 + +PROTOCOL_VERSION = 0b0001 +DEFAULT_HEADER_SIZE = 0b0001 + +PROTOCOL_VERSION_BITS = 4 +HEADER_BITS = 4 +MESSAGE_TYPE_BITS = 4 +MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4 +MESSAGE_SERIALIZATION_BITS = 4 +MESSAGE_COMPRESSION_BITS = 4 +RESERVED_BITS = 8 + +# Message Type: +CLIENT_FULL_REQUEST = 0b0001 +CLIENT_AUDIO_ONLY_REQUEST = 0b0010 +SERVER_FULL_RESPONSE = 0b1001 +SERVER_ACK = 0b1011 +SERVER_ERROR_RESPONSE = 0b1111 + +# Message Type Specific Flags +NO_SEQUENCE = 0b0000 # no check sequence +POS_SEQUENCE = 0b0001 +NEG_SEQUENCE = 0b0010 +NEG_SEQUENCE_1 = 0b0011 + +# Message Serialization +NO_SERIALIZATION = 0b0000 +JSON = 0b0001 +THRIFT = 0b0011 +CUSTOM_TYPE = 0b1111 + +# Message Compression +NO_COMPRESSION = 0b0000 +GZIP = 0b0001 +CUSTOM_COMPRESSION = 0b1111 + + +def generate_header( + version=PROTOCOL_VERSION, + message_type=CLIENT_FULL_REQUEST, + message_type_specific_flags=NO_SEQUENCE, + serial_method=JSON, + compression_type=GZIP, + reserved_data=0x00, + extension_header=bytes() +): + """ + protocol_version(4 bits), header_size(4 bits), + message_type(4 bits), message_type_specific_flags(4 bits) + serialization_method(4 bits) message_compression(4 bits) + reserved (8bits) 保留字段 + header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) ) + """ + header = bytearray() + header_size = int(len(extension_header) / 4) + 1 + header.append((version << 4) | header_size) + header.append((message_type << 4) | message_type_specific_flags) + header.append((serial_method << 4) | compression_type) + header.append(reserved_data) + header.extend(extension_header) + return header + + +def generate_full_default_header(): + return generate_header() + + +def generate_audio_default_header(): + return generate_header( + message_type=CLIENT_AUDIO_ONLY_REQUEST + ) + + +def generate_last_audio_default_header(): + return generate_header( + message_type=CLIENT_AUDIO_ONLY_REQUEST, + message_type_specific_flags=NEG_SEQUENCE + ) + + +def parse_response(res): + """ + 
protocol_version(4 bits), header_size(4 bits), + message_type(4 bits), message_type_specific_flags(4 bits) + serialization_method(4 bits) message_compression(4 bits) + reserved (8bits) 保留字段 + header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) ) + payload 类似与http 请求体 + """ + protocol_version = res[0] >> 4 + header_size = res[0] & 0x0f + message_type = res[1] >> 4 + message_type_specific_flags = res[1] & 0x0f + serialization_method = res[2] >> 4 + message_compression = res[2] & 0x0f + reserved = res[3] + header_extensions = res[4:header_size * 4] + payload = res[header_size * 4:] + result = {} + payload_msg = None + payload_size = 0 + if message_type == SERVER_FULL_RESPONSE: + payload_size = int.from_bytes(payload[:4], "big", signed=True) + payload_msg = payload[4:] + elif message_type == SERVER_ACK: + seq = int.from_bytes(payload[:4], "big", signed=True) + result['seq'] = seq + if len(payload) >= 8: + payload_size = int.from_bytes(payload[4:8], "big", signed=False) + payload_msg = payload[8:] + elif message_type == SERVER_ERROR_RESPONSE: + code = int.from_bytes(payload[:4], "big", signed=False) + result['code'] = code + payload_size = int.from_bytes(payload[4:8], "big", signed=False) + payload_msg = payload[8:] + if payload_msg is None: + return result + if message_compression == GZIP: + payload_msg = gzip.decompress(payload_msg) + if serialization_method == JSON: + payload_msg = json.loads(str(payload_msg, "utf-8")) + elif serialization_method != NO_SERIALIZATION: + payload_msg = str(payload_msg, "utf-8") + result['payload_msg'] = payload_msg + result['payload_size'] = payload_size + return result + + +def read_wav_info(data: bytes = None) -> (int, int, int, int, int): + with BytesIO(data) as _f: + wave_fp = wave.open(_f, 'rb') + nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4] + wave_bytes = wave_fp.readframes(nframes) + return nchannels, sampwidth, framerate, nframes, len(wave_bytes) + + +class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText): + workflow: str = "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate" + show_language: bool = False + show_utterances: bool = False + result_type: str = "full" + format: str = "mp3" + rate: int = 16000 + language: str = "zh-CN" + bits: int = 16 + channel: int = 1 + codec: str = "raw" + audio_type: int = 1 + secret: str = "access_secret" + auth_method: str = "token" + mp3_seg_size: int = 10000 + success_code: int = 1000 # success code, default is 1000 + seg_duration: int = 15000 + nbest: int = 1 + + volcanic_app_id: str + volcanic_cluster: str + volcanic_api_url: str + volcanic_token: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.volcanic_api_url = kwargs.get('volcanic_api_url') + self.volcanic_token = kwargs.get('volcanic_token') + self.volcanic_app_id = kwargs.get('volcanic_app_id') + self.volcanic_cluster = kwargs.get('volcanic_cluster') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return VolcanicEngineSpeechToText( + volcanic_api_url=model_credential.get('volcanic_api_url'), + volcanic_token=model_credential.get('volcanic_token'), + volcanic_app_id=model_credential.get('volcanic_app_id'), + 
volcanic_cluster=model_credential.get('volcanic_cluster'), + **optional_params + ) + + def construct_request(self, reqid): + req = { + 'app': { + 'appid': self.volcanic_app_id, + 'cluster': self.volcanic_cluster, + 'token': self.volcanic_token, + }, + 'user': { + 'uid': 'uid' + }, + 'request': { + 'reqid': reqid, + 'nbest': self.nbest, + 'workflow': self.workflow, + 'show_language': self.show_language, + 'show_utterances': self.show_utterances, + 'result_type': self.result_type, + "sequence": 1 + }, + 'audio': { + 'format': self.format, + 'rate': self.rate, + 'language': self.language, + 'bits': self.bits, + 'channel': self.channel, + 'codec': self.codec + } + } + return req + + @staticmethod + def slice_data(data: bytes, chunk_size: int) -> (list, bool): + """ + slice data + :param data: wav data + :param chunk_size: the segment size in one request + :return: segment data, last flag + """ + data_len = len(data) + offset = 0 + while offset + chunk_size < data_len: + yield data[offset: offset + chunk_size], False + offset += chunk_size + else: + yield data[offset: data_len], True + + def _real_processor(self, request_params: dict) -> dict: + pass + + def token_auth(self): + return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)} + + def signature_auth(self, data): + header_dicts = { + 'Custom': 'auth_custom', + } + + url_parse = urlparse(self.volcanic_api_url) + input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path) + auth_headers = 'Custom' + for header in auth_headers.split(','): + input_str += '{}\n'.format(header_dicts[header]) + input_data = bytearray(input_str, 'utf-8') + input_data += data + mac = base64.urlsafe_b64encode( + hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest()) + header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(self.volcanic_token, + str(mac, 'utf-8'), + auth_headers) + return header_dicts + + async def segment_data_processor(self, wav_data: bytes, segment_size: int): + reqid = str(uuid.uuid4()) + # 构建 full client request,并序列化压缩 + request_params = self.construct_request(reqid) + payload_bytes = str.encode(json.dumps(request_params)) + payload_bytes = gzip.compress(payload_bytes) + full_client_request = bytearray(generate_full_default_header()) + full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) + full_client_request.extend(payload_bytes) # payload + header = None + if self.auth_method == "token": + header = self.token_auth() + elif self.auth_method == "signature": + header = self.signature_auth(full_client_request) + async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws: + # 发送 full client request + await ws.send(full_client_request) + res = await ws.recv() + result = parse_response(res) + if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code: + return result + for seq, (chunk, last) in enumerate(VolcanicEngineSpeechToText.slice_data(wav_data, segment_size), 1): + # if no compression, comment this line + payload_bytes = gzip.compress(chunk) + audio_only_request = bytearray(generate_audio_default_header()) + if last: + audio_only_request = bytearray(generate_last_audio_default_header()) + audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) + audio_only_request.extend(payload_bytes) # payload + # 发送 audio-only client request + await ws.send(audio_only_request) + res = await ws.recv() + result = parse_response(res) + if 'payload_msg' in result 
and result['payload_msg']['code'] != self.success_code: + return result + return result['payload_msg']['result'][0]['text'] + + def check_auth(self): + header = self.token_auth() + + async def check(): + async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000) as ws: + pass + + asyncio.run(check()) + + def speech_to_text(self, file): + data = file.read() + audio_data = bytes(data) + if self.format == "mp3": + segment_size = self.mp3_seg_size + return asyncio.run(self.segment_data_processor(audio_data, segment_size)) + if self.format != "wav": + raise Exception("format should in wav or mp3") + nchannels, sampwidth, framerate, nframes, wav_len = read_wav_info( + audio_data) + size_per_sec = nchannels * sampwidth * framerate + segment_size = int(size_per_sec * self.seg_duration / 1000) + return asyncio.run(self.segment_data_processor(audio_data, segment_size)) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py new file mode 100644 index 000000000..4cd6adfa8 --- /dev/null +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py @@ -0,0 +1,164 @@ +# coding=utf-8 + +''' +requires Python 3.6 or later + +pip install asyncio +pip install websockets + +''' + +import asyncio +import copy +import gzip +import json +import uuid +from typing import Dict + +import websockets + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_tts import BaseTextToSpeech + +MESSAGE_TYPES = {11: "audio-only server response", 12: "frontend server response", 15: "error message from server"} +MESSAGE_TYPE_SPECIFIC_FLAGS = {0: "no sequence number", 1: "sequence number > 0", + 2: "last message from server (seq < 0)", 3: "sequence number < 0"} +MESSAGE_SERIALIZATION_METHODS = {0: "no serialization", 1: "JSON", 15: "custom type"} +MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression method"} + +# version: b0001 (4 bits) +# header size: b0001 (4 bits) +# message type: b0001 (Full client request) (4bits) +# message type specific flags: b0000 (none) (4bits) +# message serialization method: b0001 (JSON) (4 bits) +# message compression: b0001 (gzip) (4bits) +# reserved data: 0x00 (1 byte) +default_header = bytearray(b'\x11\x10\x11\x00') + + +class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): + volcanic_app_id: str + volcanic_cluster: str + volcanic_api_url: str + volcanic_token: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.volcanic_api_url = kwargs.get('volcanic_api_url') + self.volcanic_token = kwargs.get('volcanic_token') + self.volcanic_app_id = kwargs.get('volcanic_app_id') + self.volcanic_cluster = kwargs.get('volcanic_cluster') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return VolcanicEngineTextToSpeech( + volcanic_api_url=model_credential.get('volcanic_api_url'), + volcanic_token=model_credential.get('volcanic_token'), + volcanic_app_id=model_credential.get('volcanic_app_id'), + volcanic_cluster=model_credential.get('volcanic_cluster'), + 
**optional_params + ) + + def check_auth(self): + header = self.token_auth() + + async def check(): + async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws: + pass + + asyncio.run(check()) + + def text_to_speech(self, text): + request_json = { + "app": { + "appid": self.volcanic_app_id, + "token": "access_token", + "cluster": self.volcanic_cluster + }, + "user": { + "uid": "uid" + }, + "audio": { + "voice_type": "BV002_streaming", + "encoding": "mp3", + "speed_ratio": 1.0, + "volume_ratio": 1.0, + "pitch_ratio": 1.0, + }, + "request": { + "reqid": str(uuid.uuid4()), + "text": text, + "text_type": "plain", + "operation": "xxx" + } + } + + return asyncio.run(self.submit(request_json)) + + def token_auth(self): + return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)} + + async def submit(self, request_json): + submit_request_json = copy.deepcopy(request_json) + submit_request_json["request"]["reqid"] = str(uuid.uuid4()) + submit_request_json["request"]["operation"] = "submit" + payload_bytes = str.encode(json.dumps(submit_request_json)) + payload_bytes = gzip.compress(payload_bytes) # if no compression, comment this line + full_client_request = bytearray(default_header) + full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) + full_client_request.extend(payload_bytes) # payload + header = {"Authorization": f"Bearer; {self.volcanic_token}"} + async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None) as ws: + await ws.send(full_client_request) + return await self.parse_response(ws) + + @staticmethod + async def parse_response(ws): + result = b'' + while True: + res = await ws.recv() + protocol_version = res[0] >> 4 + header_size = res[0] & 0x0f + message_type = res[1] >> 4 + message_type_specific_flags = res[1] & 0x0f + serialization_method = res[2] >> 4 + message_compression = res[2] & 0x0f + reserved = res[3] + header_extensions = res[4:header_size * 4] + payload = res[header_size * 4:] + if header_size != 1: + # print(f" Header extensions: {header_extensions}") + pass + if message_type == 0xb: # audio-only server response + if message_type_specific_flags == 0: # no sequence number as ACK + continue + else: + sequence_number = int.from_bytes(payload[:4], "big", signed=True) + payload_size = int.from_bytes(payload[4:8], "big", signed=False) + payload = payload[8:] + result += payload + if sequence_number < 0: + break + else: + continue + elif message_type == 0xf: + code = int.from_bytes(payload[:4], "big", signed=False) + msg_size = int.from_bytes(payload[4:8], "big", signed=False) + error_msg = payload[8:] + if message_compression == 1: + error_msg = gzip.decompress(error_msg) + error_msg = str(error_msg, "utf-8") + break + elif message_type == 0xc: + msg_size = int.from_bytes(payload[:4], "big", signed=False) + payload = payload[4:] + if message_compression == 1: + payload = gzip.decompress(payload) + else: + break + return result \ No newline at end of file diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py index 48802f6b8..0238e5694 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py @@ -15,17 +15,31 @@ from 
setting.models_provider.impl.openai_model_provider.credential.embedding imp from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel +from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential +from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText +from setting.models_provider.impl.volcanic_engine_model_provider.model.tts import VolcanicEngineTextToSpeech from smartdoc.conf import PROJECT_DIR volcanic_engine_llm_model_credential = OpenAILLMModelCredential() +volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential() model_info_list = [ ModelInfo('ep-xxxxxxxxxx-yyyy', '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用', ModelTypeConst.LLM, volcanic_engine_llm_model_credential, VolcanicEngineChatModel - ) + ), + ModelInfo('asr', + '', + ModelTypeConst.STT, + volcanic_engine_stt_model_credential, VolcanicEngineSpeechToText + ), + ModelInfo('tts', + '', + ModelTypeConst.TTS, + volcanic_engine_stt_model_credential, VolcanicEngineTextToSpeech + ), ] open_ai_embedding_credential = OpenAIEmbeddingCredential() diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py b/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py new file mode 100644 index 000000000..c32de6cbe --- /dev/null +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py @@ -0,0 +1,46 @@ +# coding=utf-8 + +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class XunFeiSTTModelCredential(BaseForm, BaseModelCredential): + spark_api_url = forms.TextInputField('API 域名', required=True) + spark_app_id = forms.TextInputField('APP ID', required=True) + spark_api_key = forms.PasswordInputField("API Key", required=True) + spark_api_secret = forms.PasswordInputField('API Secret', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} + + + def get_model_params_setting_form(self, model_name): + pass diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py 
b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py new file mode 100644 index 000000000..fd439df72 --- /dev/null +++ b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py @@ -0,0 +1,165 @@ +# -*- coding:utf-8 -*- +# +# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看) +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +import asyncio +import base64 +import datetime +import hashlib +import hmac +import json +from datetime import datetime +from typing import Dict +from urllib.parse import urlencode, urlparse + +import websockets + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_stt import BaseSpeechToText + +STATUS_FIRST_FRAME = 0 # 第一帧的标识 +STATUS_CONTINUE_FRAME = 1 # 中间帧标识 +STATUS_LAST_FRAME = 2 # 最后一帧的标识 + + +class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): + spark_app_id: str + spark_api_key: str + spark_api_secret: str + spark_api_url: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.spark_api_url = kwargs.get('spark_api_url') + self.spark_app_id = kwargs.get('spark_app_id') + self.spark_api_key = kwargs.get('spark_api_key') + self.spark_api_secret = kwargs.get('spark_api_secret') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return XFSparkSpeechToText( + spark_app_id=model_credential.get('spark_app_id'), + spark_api_key=model_credential.get('spark_api_key'), + spark_api_secret=model_credential.get('spark_api_secret'), + spark_api_url=model_credential.get('spark_api_url'), + **optional_params + ) + + # 生成url + def create_url(self): + url = self.spark_api_url + host = urlparse(url).hostname + # 生成RFC1123格式的时间戳 + gmt_format = '%a, %d %b %Y %H:%M:%S GMT' + date = datetime.utcnow().strftime(gmt_format) + + # 拼接字符串 + signature_origin = "host: " + host + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + "/v2/iat " + "HTTP/1.1" + # 进行hmac-sha256进行加密 + signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'), + digestmod=hashlib.sha256).digest() + signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') + + authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( + self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha) + authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + # 将请求的鉴权参数组合为字典 + v = { + "authorization": authorization, + "date": date, + "host": host + } + # 拼接鉴权参数,生成url + url = url + '?' 
+ urlencode(v) + # print("date: ",date) + # print("v: ",v) + # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 + # print('websocket url :', url) + return url + + def check_auth(self): + async def check(): + async with websockets.connect(self.create_url()) as ws: + pass + + asyncio.run(check()) + + def speech_to_text(self, file): + async def handle(): + async with websockets.connect(self.create_url(), max_size=1000000000) as ws: + # 发送 full client request + await self.send(ws, file) + return await self.handle_message(ws) + + return asyncio.run(handle()) + + @staticmethod + async def handle_message(ws): + res = await ws.recv() + message = json.loads(res) + code = message["code"] + sid = message["sid"] + if code != 0: + errMsg = message["message"] + print("sid:%s call error:%s code is:%s" % (sid, errMsg, code)) + return errMsg + else: + data = message["data"]["result"]["ws"] + result = "" + for i in data: + for w in i["cw"]: + result += w["w"] + # print("sid:%s call success!,data is:%s" % (sid, json.dumps(data, ensure_ascii=False))) + return result + + # 收到websocket连接建立的处理 + async def send(self, ws, file): + frameSize = 8000 # 每一帧的音频大小 + status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧 + + while True: + buf = file.read(frameSize) + # 文件结束 + if not buf: + status = STATUS_LAST_FRAME + # 第一帧处理 + # 发送第一帧音频,带business 参数 + # appid 必须带上,只需第一帧发送 + if status == STATUS_FIRST_FRAME: + d = { + "common": {"app_id": self.spark_app_id}, + "business": { + "domain": "iat", + "language": "zh_cn", + "accent": "mandarin", + "vinfo": 1, + "vad_eos": 10000 + }, + "data": { + "status": 0, "format": "audio/L16;rate=16000", + "audio": str(base64.b64encode(buf), 'utf-8'), + "encoding": "lame"} + } + d = json.dumps(d) + await ws.send(d) + status = STATUS_CONTINUE_FRAME + # 中间帧处理 + elif status == STATUS_CONTINUE_FRAME: + d = {"data": {"status": 1, "format": "audio/L16;rate=16000", + "audio": str(base64.b64encode(buf), 'utf-8'), + "encoding": "lame"}} + await ws.send(json.dumps(d)) + # 最后一帧处理 + elif status == STATUS_LAST_FRAME: + d = {"data": {"status": 2, "format": "audio/L16;rate=16000", + "audio": str(base64.b64encode(buf), 'utf-8'), + "encoding": "lame"}} + await ws.send(json.dumps(d)) + break diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py new file mode 100644 index 000000000..b33537b6b --- /dev/null +++ b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py @@ -0,0 +1,138 @@ +# -*- coding:utf-8 -*- +# +# author: iflytek +# +# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看) +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +import asyncio +import base64 +import datetime +import hashlib +import hmac +import json +import os +from datetime import datetime +from typing import Dict +from urllib.parse import urlencode, urlparse + +import websockets + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_tts import BaseTextToSpeech + +STATUS_FIRST_FRAME = 0 # 第一帧的标识 +STATUS_CONTINUE_FRAME = 1 # 中间帧标识 +STATUS_LAST_FRAME = 2 # 最后一帧的标识 + + +class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): + spark_app_id: str + spark_api_key: str + spark_api_secret: str + spark_api_url: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.spark_api_url = kwargs.get('spark_api_url') + self.spark_app_id = kwargs.get('spark_app_id') + 
self.spark_api_key = kwargs.get('spark_api_key') + self.spark_api_secret = kwargs.get('spark_api_secret') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return XFSparkTextToSpeech( + spark_app_id=model_credential.get('spark_app_id'), + spark_api_key=model_credential.get('spark_api_key'), + spark_api_secret=model_credential.get('spark_api_secret'), + spark_api_url=model_credential.get('spark_api_url'), + **optional_params + ) + + # 生成url + def create_url(self): + url = self.spark_api_url + host = urlparse(url).hostname + # 生成RFC1123格式的时间戳 + gmt_format = '%a, %d %b %Y %H:%M:%S GMT' + date = datetime.utcnow().strftime(gmt_format) + + # 拼接字符串 + signature_origin = "host: " + host + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + "/v2/tts " + "HTTP/1.1" + # 进行hmac-sha256进行加密 + signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'), + digestmod=hashlib.sha256).digest() + signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') + + authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( + self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha) + authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + # 将请求的鉴权参数组合为字典 + v = { + "authorization": authorization, + "date": date, + "host": host + } + # 拼接鉴权参数,生成url + url = url + '?' 
+ urlencode(v) + # print("date: ",date) + # print("v: ",v) + # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 + # print('websocket url :', url) + return url + + def check_auth(self): + async def check(): + async with websockets.connect(self.create_url(), max_size=1000000000) as ws: + pass + + asyncio.run(check()) + + def text_to_speech(self, text): + + # 使用小语种须使用以下方式,此处的unicode指的是 utf16小端的编码方式,即"UTF-16LE"” + # self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")} + async def handle(): + async with websockets.connect(self.create_url(), max_size=1000000000) as ws: + # 发送 full client request + await self.send(ws, text) + return await self.handle_message(ws) + + return asyncio.run(handle()) + + @staticmethod + async def handle_message(ws): + audio_bytes: bytes = b'' + while True: + res = await ws.recv() + message = json.loads(res) + # print(message) + code = message["code"] + sid = message["sid"] + audio = message["data"]["audio"] + audio = base64.b64decode(audio) + + if code != 0: + errMsg = message["message"] + print("sid:%s call error:%s code is:%s" % (sid, errMsg, code)) + else: + audio_bytes += audio + # 退出 + if message["data"]["status"] == 2: + break + return audio_bytes + + async def send(self, ws, text): + d = { + "common": {"app_id": self.spark_app_id}, + "business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"}, + "data": {"status": 2, "text": str(base64.b64encode(text.encode('utf-8')), "UTF8")}, + } + d = json.dumps(d) + await ws.send(d) diff --git a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py index d33b944e3..61bd5e0ac 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py +++ b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py @@ -13,16 +13,24 @@ from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ ModelInfoManage from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential +from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM +from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText +from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech from smartdoc.conf import PROJECT_DIR ssl._create_default_https_context = ssl.create_default_context() qwen_model_credential = XunFeiLLMModelCredential() -model_info_list = [ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), - ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), - ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM) - ] +stt_model_credential = XunFeiSTTModelCredential() +model_info_list = [ + ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), + ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), + ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), + ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText), + ModelInfo('iat-niche', '小语种识别', ModelTypeConst.STT, stt_model_credential, 
XFSparkSpeechToText),
+    ModelInfo('tts', '', ModelTypeConst.TTS, stt_model_credential, XFSparkTextToSpeech),
+]
 model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
     ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)).build()
diff --git a/pyproject.toml b/pyproject.toml
index d37963caa..37d1f61df 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@ dashscope = "^1.17.0"
 zhipuai = "^2.0.1"
 httpx = "^0.27.0"
 httpx-sse = "^0.4.0"
-websocket-client = "^1.7.0"
+websockets = "^13.0"
 langchain-google-genai = "^1.0.3"
 openpyxl = "^3.1.2"
 xlrd = "^2.0.1"
diff --git a/ui/src/enums/model.ts b/ui/src/enums/model.ts
index c27fc61fd..52ff935d8 100644
--- a/ui/src/enums/model.ts
+++ b/ui/src/enums/model.ts
@@ -9,7 +9,9 @@ export enum PermissionDesc {
 
 export enum modelType {
   EMBEDDING = '向量模型',
-  LLM = '大语言模型'
+  LLM = '大语言模型',
+  STT = '语音识别',
+  TTS = '语音合成',
 }
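
Usage sketch (reviewer note, not part of the patch): a minimal example of how the new OpenAI STT/TTS classes added above are constructed and invoked, following the new_instance/speech_to_text/text_to_speech signatures in this diff. The credential values and file names are placeholders.

from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech

# Placeholder credentials; api_base and api_key are the fields declared by OpenAISTTModelCredential.
credential = {'api_base': 'https://api.openai.com/v1', 'api_key': 'sk-...'}

# Speech to text: 'whisper-1' is the model registered under ModelTypeConst.STT above.
stt = OpenAISpeechToText.new_instance('STT', 'whisper-1', credential)
with open('question.mp3', 'rb') as audio_file:
    text = stt.speech_to_text(audio_file)  # returns the transcribed text

# Text to speech: 'tts-1' is the model registered under ModelTypeConst.TTS above.
tts = OpenAITextToSpeech.new_instance('TTS', 'tts-1', credential)
audio = tts.text_to_speech(text)  # returns the synthesized audio as bytes
with open('answer.mp3', 'wb') as f:
    f.write(audio)

The Volcanic Engine and XunFei implementations expose the same speech_to_text/text_to_speech interface; internally they drive their own WebSocket protocols via the websockets dependency introduced in pyproject.toml.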