diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py index 2715d9bd5..94e656162 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py @@ -9,7 +9,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential): - volcanic_api_url = forms.TextInputField('API 域名', required=True) + volcanic_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://openspeech.bytedance.com/api/v2/asr') volcanic_app_id = forms.TextInputField('App ID', required=True) volcanic_token = forms.PasswordInputField('Token', required=True) volcanic_cluster = forms.TextInputField('Cluster', required=True) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py new file mode 100644 index 000000000..353740a42 --- /dev/null +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py @@ -0,0 +1,45 @@ +# coding=utf-8 + +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential): + volcanic_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://openspeech.bytedance.com/api/v1/tts/ws_binary') + volcanic_app_id = forms.TextInputField('App ID', required=True) + volcanic_token = forms.PasswordInputField('Token', required=True) + volcanic_cluster = forms.TextInputField('Cluster', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))} + + def get_model_params_setting_form(self, model_name): + pass diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py index 0238e5694..1a0e17d8b 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py @@ -14,6 +14,7 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel +from setting.models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText @@ -23,6 +24,7 @@ from smartdoc.conf import PROJECT_DIR volcanic_engine_llm_model_credential = OpenAILLMModelCredential() volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential() +volcanic_engine_tts_model_credential = VolcanicEngineTTSModelCredential() model_info_list = [ ModelInfo('ep-xxxxxxxxxx-yyyy', @@ -38,7 +40,7 @@ model_info_list = [ ModelInfo('tts', '', ModelTypeConst.TTS, - volcanic_engine_stt_model_credential, VolcanicEngineTextToSpeech + volcanic_engine_tts_model_credential, VolcanicEngineTextToSpeech ), ] diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py b/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py index c32de6cbe..bf051c18a 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py @@ -9,7 +9,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class XunFeiSTTModelCredential(BaseForm, BaseModelCredential): - spark_api_url = forms.TextInputField('API 域名', required=True) + spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://iat-api.xfyun.cn/v2/iat') spark_app_id = forms.TextInputField('APP ID', required=True) spark_api_key = forms.PasswordInputField("API Key", required=True) spark_api_secret = forms.PasswordInputField('API Secret', required=True) diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py b/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py new file mode 100644 index 000000000..309f882b9 --- /dev/null +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py @@ -0,0 +1,46 @@ +# coding=utf-8 + +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class XunFeiTTSModelCredential(BaseForm, BaseModelCredential): + spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://tts-api.xfyun.cn/v2/tts') + spark_app_id = forms.TextInputField('APP ID', required=True) + spark_api_key = forms.PasswordInputField("API Key", required=True) + spark_api_secret = forms.PasswordInputField('API Secret', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))} + + + def get_model_params_setting_form(self, model_name): + pass diff --git a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py index 61bd5e0ac..8c60083cf 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py +++ b/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py @@ -14,6 +14,7 @@ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelT ModelInfoManage from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential +from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech @@ -23,13 +24,13 @@ ssl._create_default_https_context = ssl.create_default_context() qwen_model_credential = XunFeiLLMModelCredential() stt_model_credential = XunFeiSTTModelCredential() +tts_model_credential = XunFeiTTSModelCredential() model_info_list = [ ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText), - ModelInfo('iat-niche', '小语种识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText), - ModelInfo('tts', '', ModelTypeConst.TTS, stt_model_credential, XFSparkTextToSpeech), + ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech), ] model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(