diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/asr_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/asr_stt.py index e9d4f2308..93374e47d 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/asr_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/asr_stt.py @@ -8,8 +8,6 @@ from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode from django.utils.translation import gettext as _ -from models_provider.impl.aliyun_bai_lian_model_provider.credential.omni_stt import AliyunBaiLianOmiSTTModelParams - class AliyunBaiLianAsrSTTModelCredential(BaseForm, BaseModelCredential): api_url = forms.TextInputField(_('API URL'), required=True) @@ -64,4 +62,4 @@ class AliyunBaiLianAsrSTTModelCredential(BaseForm, BaseModelCredential): def get_model_params_setting_form(self, model_name): - return AliyunBaiLianOmiSTTModelParams() + pass diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/omni_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/omni_stt.py index 82dcb4b55..274e56678 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/omni_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/omni_stt.py @@ -8,12 +8,12 @@ from common.forms import BaseForm, PasswordInputField, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from django.utils.translation import gettext as _ -class AliyunBaiLianOmiSTTModelParams(BaseForm): - CueWord = forms.TextInputField( - TooltipLabel(_('CueWord'), _('If not passed, the default value is What is this audio saying? Only answer the audio content')), - required=True, - default_value='这段音频在说什么,只回答音频的内容', - ) +# class AliyunBaiLianOmiSTTModelParams(BaseForm): +# CueWord = forms.TextInputField( +# TooltipLabel(_('CueWord'), _('If not passed, the default value is What is this audio saying? Only answer the audio content')), +# required=True, +# default_value='这段音频在说什么,只回答音频的内容', +# ) class AliyunBaiLianOmiSTTModelCredential(BaseForm, BaseModelCredential): @@ -70,4 +70,4 @@ class AliyunBaiLianOmiSTTModelCredential(BaseForm, BaseModelCredential): def get_model_params_setting_form(self, model_name): - return AliyunBaiLianOmiSTTModelParams() + pass \ No newline at end of file diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/asr_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/asr_stt.py index c326e452e..e843180f5 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/asr_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/asr_stt.py @@ -61,10 +61,11 @@ class AliyunBaiLianAsrSpeechToText(MaxKBBaseModel, BaseSpeechToText): messages=messages, result_format="message", ) - - text = response["output"]["choices"][0]["message"].content[0]["text"] - - return text + if response.status_code == 200: + text = response["output"]["choices"][0]["message"].content[0]["text"] + return text + else: + raise Exception('Error: ', response.message) except Exception as err: maxkb_logger.error(f":Error: {str(err)}: {traceback.format_exc()}") diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/omni_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/omni_stt.py index 56e060f6b..23a060288 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/omni_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/omni_stt.py @@ -68,7 +68,7 @@ class AliyunBaiLianOmiSpeechToText(MaxKBBaseModel, BaseSpeechToText): "format": "mp3", }, }, - {"type": "text", "text": self.params.get('CueWord')}, + {"type": "text", "text": '这段音频在说什么,只回答音频的内容'}, ], }, ], diff --git a/apps/models_provider/impl/tencent_model_provider/credential/stt.py b/apps/models_provider/impl/tencent_model_provider/credential/stt.py index 3eea500f2..a055ccb1d 100644 --- a/apps/models_provider/impl/tencent_model_provider/credential/stt.py +++ b/apps/models_provider/impl/tencent_model_provider/credential/stt.py @@ -8,38 +8,38 @@ from django.utils.translation import gettext_lazy as _, gettext from models_provider.base_model_provider import BaseModelCredential, ValidCode -class TencentSSTModelParams(BaseForm): - EngSerViceType = forms.SingleSelect( - TooltipLabel(_('Engine model type'), _('If not passed, the default value is 16k_zh (Chinese universal)')), - required=True, - default_value='16k_zh', - option_list=[ - {"value": "8k_zh", "label": _("Chinese telephone universal")}, - {"value": "8k_en", "label": _("English telephone universal")}, - {"value": "16k_zh", "label": _("Commonly used in Chinese")}, - {"value": "16k_zh-PY", "label": _("Chinese, English, and Guangdong")}, - {"value": "16k_zh_medical", "label": _("Chinese medical")}, - {"value": "16k_en", "label": _("English")}, - {"value": "16k_yue", "label": _("Cantonese")}, - {"value": "16k_ja", "label": _("Japanese")}, - {"value": "16k_ko", "label": _("Korean")}, - {"value": "16k_vi", "label": _("Vietnamese")}, - {"value": "16k_ms", "label": _("Malay language")}, - {"value": "16k_id", "label": _("Indonesian language")}, - {"value": "16k_fil", "label": _("Filipino language")}, - {"value": "16k_th", "label": _("Thai")}, - {"value": "16k_pt", "label": _("Portuguese")}, - {"value": "16k_tr", "label": _("Turkish")}, - {"value": "16k_ar", "label": _("Arabic")}, - {"value": "16k_es", "label": _("Spanish")}, - {"value": "16k_hi", "label": _("Hindi")}, - {"value": "16k_fr", "label": _("French")}, - {"value": "16k_de", "label": _("German")}, - {"value": "16k_zh_dialect", "label": _("Multiple dialects, supporting 23 dialects")} - ], - value_field='value', - text_field='label' - ) +# class TencentSSTModelParams(BaseForm): +# EngSerViceType = forms.SingleSelect( +# TooltipLabel(_('Engine model type'), _('If not passed, the default value is 16k_zh (Chinese universal)')), +# required=True, +# default_value='16k_zh', +# option_list=[ +# {"value": "8k_zh", "label": _("Chinese telephone universal")}, +# {"value": "8k_en", "label": _("English telephone universal")}, +# {"value": "16k_zh", "label": _("Commonly used in Chinese")}, +# {"value": "16k_zh-PY", "label": _("Chinese, English, and Guangdong")}, +# {"value": "16k_zh_medical", "label": _("Chinese medical")}, +# {"value": "16k_en", "label": _("English")}, +# {"value": "16k_yue", "label": _("Cantonese")}, +# {"value": "16k_ja", "label": _("Japanese")}, +# {"value": "16k_ko", "label": _("Korean")}, +# {"value": "16k_vi", "label": _("Vietnamese")}, +# {"value": "16k_ms", "label": _("Malay language")}, +# {"value": "16k_id", "label": _("Indonesian language")}, +# {"value": "16k_fil", "label": _("Filipino language")}, +# {"value": "16k_th", "label": _("Thai")}, +# {"value": "16k_pt", "label": _("Portuguese")}, +# {"value": "16k_tr", "label": _("Turkish")}, +# {"value": "16k_ar", "label": _("Arabic")}, +# {"value": "16k_es", "label": _("Spanish")}, +# {"value": "16k_hi", "label": _("Hindi")}, +# {"value": "16k_fr", "label": _("French")}, +# {"value": "16k_de", "label": _("German")}, +# {"value": "16k_zh_dialect", "label": _("Multiple dialects, supporting 23 dialects")} +# ], +# value_field='value', +# text_field='label' +# ) class TencentSTTModelCredential(BaseForm, BaseModelCredential): REQUIRED_FIELDS = ["SecretId", "SecretKey"] @@ -87,4 +87,4 @@ class TencentSTTModelCredential(BaseForm, BaseModelCredential): SecretKey = forms.PasswordInputField('SecretKey', required=True) def get_model_params_setting_form(self, model_name): - return TencentSSTModelParams() + pass diff --git a/apps/models_provider/impl/tencent_model_provider/model/stt.py b/apps/models_provider/impl/tencent_model_provider/model/stt.py index a501fed19..87a1a05e4 100644 --- a/apps/models_provider/impl/tencent_model_provider/model/stt.py +++ b/apps/models_provider/impl/tencent_model_provider/model/stt.py @@ -65,7 +65,7 @@ class TencentSpeechToText(MaxKBBaseModel, BaseSpeechToText): # 实例化一个请求对象,每个接口都会对应一个request对象 req = models.SentenceRecognitionRequest() params = { - "EngSerViceType": self.params.get('EngSerViceType'), + "EngSerViceType": '16k_zh', "SourceType": 1, "VoiceFormat": "mp3", "Data": _v.decode(), diff --git a/apps/models_provider/impl/vllm_model_provider/credential/reranker.py b/apps/models_provider/impl/vllm_model_provider/credential/reranker.py index 6d391d14d..a2ae71e73 100644 --- a/apps/models_provider/impl/vllm_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/vllm_model_provider/credential/reranker.py @@ -31,7 +31,8 @@ class VllmRerankerCredential(BaseForm, BaseModelCredential): return False try: model: VllmBgeReranker = provider.get_model(model_type, model_name, model_credential) - model.compress_documents([Document(page_content=_('Hello'))], _('Hello')) + test_text = str(_('Hello')) + model.compress_documents([Document(page_content=test_text)], test_text) except Exception as e: traceback.print_exc() if isinstance(e, AppApiException): diff --git a/apps/models_provider/impl/vllm_model_provider/vllm_model_provider.py b/apps/models_provider/impl/vllm_model_provider/vllm_model_provider.py index 009fe88b8..ffe544081 100644 --- a/apps/models_provider/impl/vllm_model_provider/vllm_model_provider.py +++ b/apps/models_provider/impl/vllm_model_provider/vllm_model_provider.py @@ -54,7 +54,7 @@ whisper_model_info_list = [ ] reranker_model_info_list = [ - ModelInfo('bge-reranker-v2-m3', '', ModelTypeConst.RERANKER, rerank_model_credential, VllmBgeReranker), + ModelInfo('BAAI/bge-reranker-v2-m3', '', ModelTypeConst.RERANKER, rerank_model_credential, VllmBgeReranker), ] model_info_manage = (