From 71cec2fca48389548acd7afc2a0adc6b16d86c4d Mon Sep 17 00:00:00 2001
From: zhangzhanwei
Date: Fri, 19 Sep 2025 10:12:53 +0800
Subject: [PATCH] feat: Support stt model params setting

---
 ...ation_stt_model_params_setting_and_more.py |  23 ++++
 apps/application/models/application.py        |   2 +
 apps/application/serializers/application.py   |   5 +-
 .../model/asr_stt.py                          |   1 +
 .../model/omni_stt.py                         |   1 +
 .../impl/tencent_model_provider/model/stt.py  |   1 +
 .../impl/xf_model_provider/model/zh_en_stt.py |  49 +++++--
 ui/src/api/type/application.ts                |   1 +
 .../views/application/ApplicationSetting.vue  |  47 ++++++-
 .../component/STTModelParamSettingDialog.vue  | 122 ++++++++++++++++++
 ui/src/views/model/component/ModelCard.vue    |   1 +
 11 files changed, 238 insertions(+), 15 deletions(-)
 create mode 100644 apps/application/migrations/0003_application_stt_model_params_setting_and_more.py
 create mode 100644 ui/src/views/application/component/STTModelParamSettingDialog.vue

diff --git a/apps/application/migrations/0003_application_stt_model_params_setting_and_more.py b/apps/application/migrations/0003_application_stt_model_params_setting_and_more.py
new file mode 100644
index 000000000..86170b4d4
--- /dev/null
+++ b/apps/application/migrations/0003_application_stt_model_params_setting_and_more.py
@@ -0,0 +1,23 @@
+# Generated by Django 5.2.4 on 2025-09-16 08:10
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('application', '0002_application_simple_mcp'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='application',
+            name='stt_model_params_setting',
+            field=models.JSONField(default=dict, verbose_name='STT模型参数相关设置'),
+        ),
+        migrations.AddField(
+            model_name='applicationversion',
+            name='stt_model_params_setting',
+            field=models.JSONField(default=dict, verbose_name='STT模型参数相关设置'),
+        ),
+    ]
diff --git a/apps/application/models/application.py b/apps/application/models/application.py
index a8524b863..0a4b73765 100644
--- a/apps/application/models/application.py
+++ b/apps/application/models/application.py
@@ -72,6 +72,7 @@ class Application(AppModelMixin):
     model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
     model_params_setting = models.JSONField(verbose_name="模型参数相关设置", default=dict)
     tts_model_params_setting = models.JSONField(verbose_name="模型参数相关设置", default=dict)
+    stt_model_params_setting = models.JSONField(verbose_name="STT模型参数相关设置", default=dict)
     problem_optimization = models.BooleanField(verbose_name="问题优化", default=False)
     icon = models.CharField(max_length=256, verbose_name="应用icon", default="./favicon.ico")
     work_flow = models.JSONField(verbose_name="工作流数据", default=dict)
@@ -145,6 +146,7 @@ class ApplicationVersion(AppModelMixin):
     model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
     model_params_setting = models.JSONField(verbose_name="模型参数相关设置", default=dict)
     tts_model_params_setting = models.JSONField(verbose_name="模型参数相关设置", default=dict)
+    stt_model_params_setting = models.JSONField(verbose_name="STT模型参数相关设置", default=dict)
     problem_optimization = models.BooleanField(verbose_name="问题优化", default=False)
     icon = models.CharField(max_length=256, verbose_name="应用icon", default="./favicon.ico")
     work_flow = models.JSONField(verbose_name="工作流数据", default=dict)
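Note: the new `stt_model_params_setting` column mirrors the existing `tts_model_params_setting` field: a schemaless JSON blob defaulting to an empty dict, so existing rows need no backfill. A minimal sketch of how such a value could be set from a Django shell (the import path and the keys shown are illustrative, not a fixed schema):

    # Illustrative only: keys are provider-specific and nothing validates them at the model layer.
    from application.models import Application

    app = Application.objects.get(id=some_application_id)      # some_application_id is a placeholder
    app.stt_model_params_setting = {"eos": 6000, "vinfo": 1}    # e.g. iFlytek iat options
    app.save(update_fields=["stt_model_params_setting"])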
diff --git a/apps/application/serializers/application.py b/apps/application/serializers/application.py
index f7a875cf8..b3db3e0a2 100644
--- a/apps/application/serializers/application.py
+++ b/apps/application/serializers/application.py
@@ -700,6 +700,7 @@ class ApplicationOperateSerializer(serializers.Serializer):
             'user_id': 'user_id', 'model_id': 'model_id', 'knowledge_setting': 'knowledge_setting',
             'model_setting': 'model_setting', 'model_params_setting': 'model_params_setting',
             'tts_model_params_setting': 'tts_model_params_setting',
+            'stt_model_params_setting': 'stt_model_params_setting',
             'problem_optimization': 'problem_optimization', 'icon': 'icon', 'work_flow': 'work_flow',
             'problem_optimization_prompt': 'problem_optimization_prompt', 'tts_model_id': 'tts_model_id',
             'stt_model_id': 'stt_model_id', 'tts_model_enable': 'tts_model_enable',
@@ -785,6 +786,8 @@ class ApplicationOperateSerializer(serializers.Serializer):
             instance['stt_autosend'] = node_data['stt_autosend']
         if 'tts_model_params_setting' in node_data:
             instance['tts_model_params_setting'] = node_data['tts_model_params_setting']
+        if 'stt_model_params_setting' in node_data:
+            instance['stt_model_params_setting'] = node_data['stt_model_params_setting']
         if 'file_upload_enable' in node_data:
             instance['file_upload_enable'] = node_data['file_upload_enable']
         if 'file_upload_setting' in node_data:
@@ -830,7 +833,7 @@ class ApplicationOperateSerializer(serializers.Serializer):
                        'knowledge_setting', 'model_setting', 'problem_optimization', 'dialogue_number',
                        'stt_model_id', 'tts_model_id', 'tts_model_enable', 'stt_model_enable', 'tts_type',
                        'tts_autoplay', 'stt_autosend', 'file_upload_enable', 'file_upload_setting',
-                       'api_key_is_active', 'icon', 'work_flow', 'model_params_setting', 'tts_model_params_setting',
+                       'api_key_is_active', 'icon', 'work_flow', 'model_params_setting', 'tts_model_params_setting', 'stt_model_params_setting',
                        'mcp_enable', 'mcp_tool_ids', 'mcp_servers', 'mcp_source', 'tool_enable', 'tool_ids',
                        'mcp_output_enable', 'problem_optimization_prompt', 'clean_time', 'folder_id']
        for update_key in update_keys:
diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/asr_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/asr_stt.py
index e843180f5..ecf8a83c5 100644
--- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/asr_stt.py
+++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/asr_stt.py
@@ -60,6 +60,7 @@ class AliyunBaiLianAsrSpeechToText(MaxKBBaseModel, BaseSpeechToText):
             model=self.model,
             messages=messages,
             result_format="message",
+            **self.params
         )
         if response.status_code == 200:
             text = response["output"]["choices"][0]["message"].content[0]["text"]
diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/omni_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/omni_stt.py
index 56e060f6b..1951953e2 100644
--- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/omni_stt.py
+++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/omni_stt.py
@@ -77,6 +77,7 @@ class AliyunBaiLianOmiSpeechToText(MaxKBBaseModel, BaseSpeechToText):
             # stream 必须设置为 True,否则会报错
             stream=True,
             stream_options={"include_usage": True},
+            extra_body=self.params
         )
         result = []
         for chunk in completion:
diff --git a/apps/models_provider/impl/tencent_model_provider/model/stt.py b/apps/models_provider/impl/tencent_model_provider/model/stt.py
index a501fed19..d2aed7370 100644
--- a/apps/models_provider/impl/tencent_model_provider/model/stt.py
+++ b/apps/models_provider/impl/tencent_model_provider/model/stt.py
@@ -69,6 +69,7 @@ class TencentSpeechToText(MaxKBBaseModel, BaseSpeechToText):
             "SourceType": 1,
             "VoiceFormat": "mp3",
             "Data": _v.decode(),
+            **self.params
         }
 
         req.from_json_string(json.dumps(params))
diff --git a/apps/models_provider/impl/xf_model_provider/model/zh_en_stt.py b/apps/models_provider/impl/xf_model_provider/model/zh_en_stt.py
index fb1498fca..53d985d07 100644
--- a/apps/models_provider/impl/xf_model_provider/model/zh_en_stt.py
+++ b/apps/models_provider/impl/xf_model_provider/model/zh_en_stt.py
@@ -22,11 +22,25 @@ ssl_context.check_hostname = False
 ssl_context.verify_mode = ssl.CERT_NONE
 
 
+def deep_merge_dict(target_dict, source_dict):
+    """Recursively merge source_dict into target_dict without mutating either."""
+    if not isinstance(source_dict, dict):
+        return source_dict
+    result = target_dict.copy() if isinstance(target_dict, dict) else {}
+    for key, value in source_dict.items():
+        if key in result and isinstance(result[key], dict) and isinstance(value, dict):
+            result[key] = deep_merge_dict(result[key], value)
+        else:
+            result[key] = value
+    return result
+
+
 class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
     spark_app_id: str
     spark_api_key: str
     spark_api_secret: str
     spark_api_url: str
+    params: dict
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -34,6 +48,7 @@ class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
         self.spark_app_id = kwargs.get('spark_app_id')
         self.spark_api_key = kwargs.get('spark_api_key')
         self.spark_api_secret = kwargs.get('spark_api_secret')
+        self.params = kwargs.get('params') or {}
 
     @staticmethod
     def is_cache_model():
@@ -41,17 +56,14 @@ class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
-        optional_params = {}
-        if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
-            optional_params['max_tokens'] = model_kwargs['max_tokens']
-        if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
-            optional_params['temperature'] = model_kwargs['temperature']
+
         return XFZhEnSparkSpeechToText(
             spark_app_id=model_credential.get('spark_app_id'),
             spark_api_key=model_credential.get('spark_api_key'),
             spark_api_secret=model_credential.get('spark_api_secret'),
             spark_api_url=model_credential.get('spark_api_url'),
-            **optional_params
+            params=model_kwargs,
+            **model_kwargs
         )
 
     # 生成url
@@ -106,6 +118,10 @@ class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
             maxkb_logger.error(f"语音识别错误: {str(err)}: {traceback.format_exc()}")
             return ""
 
+    def merge_params_to_frame(self, frame, params):
+        # Deep-merge the user-configured STT params into a websocket frame.
+        return deep_merge_dict(frame, params)
+
     async def send_audio(self, ws, audio_file):
         """发送音频数据"""
         chunk_size = 4000
@@ -123,8 +139,11 @@ class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
                     "header": {"app_id": self.spark_app_id, "status": 0},
                     "parameter": {
                         "iat": {
-                            "domain": "slm", "language": "zh_cn", "accent": "mandarin",
-                            "eos": 10000, "vinfo": 1,
+                            "domain": "slm",
+                            "language": "zh_cn",
+                            "accent": "mandarin",
+                            "eos": 10000,
+                            "vinfo": 1,
                             "result": {"encoding": "utf8", "compress": "raw", "format": "json"}
                         }
                     },
@@ -135,6 +154,9 @@ class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
                         }
                     }
                 }
+                frame = self.merge_params_to_frame(frame, {key: value for key, value in self.params.items()
+                                                           if key not in ('model_id', 'use_local', 'streaming')})
+
             # 中间帧
             else:
                 frame = {
@@ -147,6 +169,9 @@ class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
                     }
                 }
 
+                frame = self.merge_params_to_frame(frame, {key: value for key, value in self.params.items()
+                                                           if key not in ('model_id', 'use_local', 'streaming', 'parameter')})
+
             await ws.send(json.dumps(frame))
             seq += 1
 
@@ -160,17 +185,19 @@ class XFZhEnSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
                 }
             }
         }
+
+        end_frame = self.merge_params_to_frame(end_frame, {key: value for key, value in self.params.items()
+                                                           if key not in ('model_id', 'use_local', 'streaming', 'parameter')})
+
         await ws.send(json.dumps(end_frame))
 
-
-# 接受信息处理器
+    # 接受信息处理器
     async def handle_message(self, ws):
         result_text = ""
         while True:
             try:
                 message = await asyncio.wait_for(ws.recv(), timeout=30.0)
                 data = json.loads(message)
-
                 if data['header']['code'] != 0:
                     raise Exception("")
diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts
index b7b4a85c6..86f735ab6 100644
--- a/ui/src/api/type/application.ts
+++ b/ui/src/api/type/application.ts
@@ -17,6 +17,7 @@ interface ApplicationFormType {
   work_flow?: any
   model_params_setting?: any
   tts_model_params_setting?: any
+  stt_model_params_setting?: any
  stt_model_id?: string
  tts_model_id?: string
  stt_model_enable?: boolean
diff --git a/ui/src/views/application/ApplicationSetting.vue b/ui/src/views/application/ApplicationSetting.vue
index 62284aa14..c5264855b 100644
--- a/ui/src/views/application/ApplicationSetting.vue
+++ b/ui/src/views/application/ApplicationSetting.vue
@@ -482,14 +482,28 @@
[Template markup lost during extraction: this hunk appears to add the STT model parameter-setting entry next to the existing TTS one; only the script-level ref declarations below survive.]
 const ReasoningParamSettingDialogRef = ref<InstanceType<typeof ReasoningParamSettingDialog>>()
 const TTSModeParamSettingDialogRef = ref<InstanceType<typeof TTSModeParamSettingDialog>>()
+const STTModeParamSettingDialogRef = ref<InstanceType<typeof STTModelParamSettingDialog>>()
 const ParamSettingDialogRef = ref<InstanceType<typeof ParamSettingDialog>>()
 const GeneratePromptDialogRef = ref<InstanceType<typeof GeneratePromptDialog>>()
@@ -806,6 +824,17 @@ const openTTSParamSettingDialog = () => {
   }
 }
 
+const openSTTParamSettingDialog = () => {
+  if (applicationForm.value.stt_model_id) {
+    STTModeParamSettingDialogRef.value?.open(
+      applicationForm.value.stt_model_id,
+      id,
+      applicationForm.value.stt_model_params_setting,
+    )
+  }
+}
+
+
 const openParamSettingDialog = () => {
   ParamSettingDialogRef.value?.open(applicationForm.value)
 }
@@ -905,6 +934,10 @@ function refreshTTSForm(data: any) {
   applicationForm.value.tts_model_params_setting = data
 }
 
+function refreshSTTForm(data: any) {
+  applicationForm.value.stt_model_params_setting = data
+}
+
 function removeKnowledge(id: any) {
   if (applicationForm.value.knowledge_id_list) {
     applicationForm.value.knowledge_id_list.splice(
@@ -1022,6 +1055,14 @@ function ttsModelChange() {
   }
 }
 
+function sttModelChange() {
+  if (applicationForm.value.stt_model_id) {
+    STTModeParamSettingDialogRef.value?.reset_default(applicationForm.value.stt_model_id, id)
+  } else {
+    refreshSTTForm({})
+  }
+}
+
 function ttsModelEnableChange() {
   if (!applicationForm.value.tts_model_enable) {
     applicationForm.value.tts_model_id = undefined
diff --git a/ui/src/views/application/component/STTModelParamSettingDialog.vue b/ui/src/views/application/component/STTModelParamSettingDialog.vue
new file mode 100644
index 000000000..c4e4723f5
--- /dev/null
+++ b/ui/src/views/application/component/STTModelParamSettingDialog.vue
@@ -0,0 +1,122 @@
[The 122-line body of the new STT model parameter-setting dialog component was lost during extraction.]
\ No newline at end of file
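Note: the dialog component body is missing above, but the surviving wiring shows the intended flow: `openSTTParamSettingDialog` opens the dialog with the current STT model id, the application id and the saved settings, the dialog hands its result back through `refreshSTTForm`, and the value is persisted on save via the existing `putApplication` call, where the serializer change earlier in this patch picks it up. The update payload presumably just gains one extra key, roughly like this (values are illustrative):

    # Illustrative shape of the application update body; only the new key is of interest here.
    payload = {
        "stt_model_enable": True,
        "stt_model_id": "<model-uuid placeholder>",
        "stt_model_params_setting": {"eos": 6000},   # free-form, provider-specific keys
        # ...the remaining application form fields are unchanged...
    }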
diff --git a/ui/src/views/model/component/ModelCard.vue b/ui/src/views/model/component/ModelCard.vue
index 87a443145..dffca98c7 100644
--- a/ui/src/views/model/component/ModelCard.vue
+++ b/ui/src/views/model/component/ModelCard.vue
@@ -90,6 +90,7 @@
[The single added line in this hunk (per the diffstat, ModelCard.vue gains one line) was lost during extraction.]