From d551462c05e5e8850c301a3ade259fc57de64cdb Mon Sep 17 00:00:00 2001 From: CaptainB Date: Thu, 12 Sep 2024 11:03:31 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=94=AF=E6=8C=81=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E6=B5=8F=E8=A7=88=E5=99=A8=E8=AF=AD=E9=9F=B3=E6=92=AD?= =?UTF-8?q?=E6=94=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../migrations/0013_application_tts_type.py | 18 +++++++++++ apps/application/models/application.py | 1 + .../serializers/application_serializers.py | 5 ++- apps/application/views/application_views.py | 2 +- .../local_model_provider/credential/tts.py | 16 ---------- .../local_model_provider.py | 6 ---- .../impl/local_model_provider/model/tts.py | 31 ------------------- ui/src/api/type/application.ts | 1 + ui/src/components/ai-chat/index.vue | 21 +++---------- ui/src/views/application-workflow/index.vue | 20 ++---------- .../views/application/ApplicationSetting.vue | 9 +++++- ui/src/views/chat/pc/index.vue | 20 ------------ ui/src/workflow/nodes/base-node/index.vue | 5 +++ 13 files changed, 45 insertions(+), 110 deletions(-) create mode 100644 apps/application/migrations/0013_application_tts_type.py delete mode 100644 apps/setting/models_provider/impl/local_model_provider/credential/tts.py delete mode 100644 apps/setting/models_provider/impl/local_model_provider/model/tts.py diff --git a/apps/application/migrations/0013_application_tts_type.py b/apps/application/migrations/0013_application_tts_type.py new file mode 100644 index 000000000..c64c8e76d --- /dev/null +++ b/apps/application/migrations/0013_application_tts_type.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.15 on 2024-09-12 11:01 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0012_application_stt_model_application_stt_model_enable_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='tts_type', + field=models.CharField(default='BROWSER', max_length=20, verbose_name='语音播放类型'), + ), + ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index 54019cb07..c03bec1eb 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -58,6 +58,7 @@ class Application(AppModelMixin): stt_model = models.ForeignKey(Model, related_name='stt_model_id', on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True) tts_model_enable = models.BooleanField(verbose_name="语音合成模型是否启用", default=False) stt_model_enable = models.BooleanField(verbose_name="语音识别模型是否启用", default=False) + tts_type = models.CharField(verbose_name="语音播放类型", max_length=20, default="BROWSER") @staticmethod def get_default_model_prompt(): diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 7eab38f90..a7e7f6423 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -694,6 +694,7 @@ class ApplicationSerializer(serializers.Serializer): 'tts_model_id': application.tts_model_id, 'stt_model_enable': application.stt_model_enable, 'tts_model_enable': application.tts_model_enable, + 'tts_type': application.tts_type, 'work_flow': application.work_flow, 'show_source': application_access_token.show_source}) @@ -745,7 +746,7 @@ class ApplicationSerializer(serializers.Serializer): update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status', 'dataset_setting', 'model_setting', 'problem_optimization', 'dialogue_number', - 'stt_model_id', 'tts_model_id', 'tts_model_enable', 'stt_model_enable', + 'stt_model_id', 'tts_model_id', 'tts_model_enable', 'stt_model_enable', 'tts_type', 'api_key_is_active', 'icon', 'work_flow', 'model_params_setting'] for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: @@ -865,6 +866,8 @@ class ApplicationSerializer(serializers.Serializer): instance['stt_model_enable'] = node_data['stt_model_enable'] if 'tts_model_enable' in node_data: instance['tts_model_enable'] = node_data['tts_model_enable'] + if 'tts_type' in node_data: + instance['tts_type'] = node_data['tts_type'] break def speech_to_text(self, file, with_valid=True): diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index 73eb9a782..0c8333566 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -178,7 +178,7 @@ class Application(APIView): tags=["应用"], manual_parameters=ApplicationApi.Model.get_request_params_api()) @has_permissions(ViewPermission( - [RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN], + [RoleConstants.ADMIN, RoleConstants.USER], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))], compare=CompareConstants.AND)) diff --git a/apps/setting/models_provider/impl/local_model_provider/credential/tts.py b/apps/setting/models_provider/impl/local_model_provider/credential/tts.py deleted file mode 100644 index 7047908b5..000000000 --- a/apps/setting/models_provider/impl/local_model_provider/credential/tts.py +++ /dev/null @@ -1,16 +0,0 @@ -# coding=utf-8 - -from typing import Dict - -from common.forms import BaseForm -from setting.models_provider.base_model_provider import BaseModelCredential - - -class BrowserTextToSpeechCredential(BaseForm, BaseModelCredential): - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, - raise_exception=False): - return True - - def encryption_dict(self, model: Dict[str, object]): - return model diff --git a/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py b/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py index e1b462bb2..2c92bbbfb 100644 --- a/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py +++ b/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py @@ -17,10 +17,8 @@ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelT ModelInfoManage from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential from setting.models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential -from setting.models_provider.impl.local_model_provider.credential.tts import BrowserTextToSpeechCredential from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding from setting.models_provider.impl.local_model_provider.model.reranker import LocalReranker -from setting.models_provider.impl.local_model_provider.model.tts import BrowserTextToSpeech from smartdoc.conf import PROJECT_DIR embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING, @@ -28,14 +26,10 @@ embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', bge_reranker_v2_m3 = ModelInfo('BAAI/bge-reranker-v2-m3', '', ModelTypeConst.RERANKER, LocalRerankerCredential(), LocalReranker) -browser_tts = ModelInfo('browser_tts', '', ModelTypeConst.TTS, BrowserTextToSpeechCredential(), BrowserTextToSpeech) - - model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese) .append_default_model_info(embedding_text2vec_base_chinese) .append_model_info(bge_reranker_v2_m3) .append_default_model_info(bge_reranker_v2_m3) - .append_model_info(browser_tts) .build()) diff --git a/apps/setting/models_provider/impl/local_model_provider/model/tts.py b/apps/setting/models_provider/impl/local_model_provider/model/tts.py deleted file mode 100644 index 56e6292be..000000000 --- a/apps/setting/models_provider/impl/local_model_provider/model/tts.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Dict - -from setting.models_provider.base_model_provider import MaxKBBaseModel -from setting.models_provider.impl.base_tts import BaseTextToSpeech - - - -class BrowserTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): - model: str - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.model = kwargs.get('model') - - @staticmethod - def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): - optional_params = {} - if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: - optional_params['max_tokens'] = model_kwargs['max_tokens'] - if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: - optional_params['temperature'] = model_kwargs['temperature'] - return BrowserTextToSpeech( - model=model_name, - **optional_params, - ) - - def check_auth(self): - pass - - def text_to_speech(self, text): - pass diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index 458414908..24215a5fb 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -18,6 +18,7 @@ interface ApplicationFormType { tts_model_id?: string stt_model_enable?: boolean tts_model_enable?: boolean + tts_type?: string } interface chatType { id: string diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 01a4b01e6..df3ea2b62 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -239,13 +239,7 @@ const props = defineProps({ chatId: { type: String, default: '' - }, // 历史记录Id - ttsModelOptions: { - type: Object, - default: () => { - return {} - } - } + } // 历史记录Id }) const emit = defineEmits(['refresh', 'scroll']) @@ -771,20 +765,14 @@ const uploadRecording = async (audioBlob: Blob) => { } const playAnswerText = (text: string) => { - if ( - props.ttsModelOptions?.model_local_provider?.filter( - (v: any) => v.id === props.data.tts_model_id - ).length > 0 - ) { + if (props.data.tts_type === 'BROWSER') { // 创建一个新的 SpeechSynthesisUtterance 实例 const utterance = new SpeechSynthesisUtterance(text) // 调用浏览器的朗读功能 window.speechSynthesis.speak(utterance) - return } - - applicationApi - .postTextToSpeech(props.data.id as string, { text: text }, loading) + if (props.data.tts_type === 'TTS') { + applicationApi.postTextToSpeech(props.data.id as string, { 'text': text }, loading) .then((res: any) => { // 假设我们有一个 MP3 文件的字节数组 // 创建 Blob 对象 @@ -810,6 +798,7 @@ const playAnswerText = (text: string) => { .catch((err) => { console.log('err: ', err) }) + } } onMounted(() => { diff --git a/ui/src/views/application-workflow/index.vue b/ui/src/views/application-workflow/index.vue index 43a1b4da5..c3f3f6377 100644 --- a/ui/src/views/application-workflow/index.vue +++ b/ui/src/views/application-workflow/index.vue @@ -138,7 +138,7 @@
- +
@@ -157,7 +157,6 @@ import { datetimeFormat } from '@/utils/time' import useStore from '@/stores' import { WorkFlowInstance } from '@/workflow/common/validate' import { hasPermission } from '@/utils/permission' -import { groupBy } from 'lodash' const { user, application } = useStore() const router = useRouter() @@ -182,7 +181,6 @@ const enlarge = ref(false) const saveTime = ref('') const activeName = ref('base') const functionLibList = ref([]) -const ttsModelOptions = ref(null) function publicHandle() { workflowRef.value @@ -290,6 +288,7 @@ function getDetail() { detail.value = res.data detail.value.stt_model_id = res.data.stt_model detail.value.tts_model_id = res.data.tts_model + detail.value.tts_type = res.data.tts_type saveTime.value = res.data?.update_time }) } @@ -312,20 +311,6 @@ function getList() { }) } -function getTTSModel() { - loading.value = true - applicationApi - .getApplicationTTSModel(id) - .then((res: any) => { - ttsModelOptions.value = groupBy(res?.data, 'provider') - loading.value = false - }) - .catch(() => { - loading.value = false - }) -} - - /** * 定时保存 */ @@ -345,7 +330,6 @@ const closeInterval = () => { } onMounted(() => { - getTTSModel() getDetail() getList() // 初始化定时任务 diff --git a/ui/src/views/application/ApplicationSetting.vue b/ui/src/views/application/ApplicationSetting.vue index 6fac3e062..91f2c153f 100644 --- a/ui/src/views/application/ApplicationSetting.vue +++ b/ui/src/views/application/ApplicationSetting.vue @@ -369,7 +369,12 @@ + + + +
- +
@@ -556,6 +561,7 @@ const applicationForm = ref({ tts_model_id: '', stt_model_enable: false, tts_model_enable: false, + tts_type: 'BROWSER', type: 'SIMPLE' }) @@ -657,6 +663,7 @@ function getDetail() { applicationForm.value.model_id = res.data.model applicationForm.value.stt_model_id = res.data.stt_model applicationForm.value.tts_model_id = res.data.tts_model + applicationForm.value.tts_type = res.data.tts_type }) } diff --git a/ui/src/views/chat/pc/index.vue b/ui/src/views/chat/pc/index.vue index e89363371..3ed0c7a30 100644 --- a/ui/src/views/chat/pc/index.vue +++ b/ui/src/views/chat/pc/index.vue @@ -108,7 +108,6 @@ :appId="applicationDetail?.id" :record="currentRecordList" :chatId="currentChatId" - :tts-model-options="ttsModelOptions" @refresh="refresh" @scroll="handleScroll" > @@ -131,8 +130,6 @@ import { marked } from 'marked' import { saveAs } from 'file-saver' import { isAppIcon } from '@/utils/application' import useStore from '@/stores' -import applicationApi from '@/api/application' -import { groupBy } from 'lodash' import useResize from '@/layout/hooks/useResize' useResize() @@ -170,8 +167,6 @@ const left_loading = ref(false) const applicationDetail = ref({}) const applicationAvailable = ref(true) const chatLogeData = ref([]) -const ttsModelOptions = ref(null) - const paginationConfig = ref({ current_page: 1, @@ -233,7 +228,6 @@ function getAppProfile() { if (res.data?.show_history || !user.isEnterprise()) { getChatLog(applicationDetail.value.id) } - getTTSModel() }) .catch(() => { applicationAvailable.value = false @@ -342,20 +336,6 @@ async function exportHTML(): Promise { saveAs(blob, suggestedName) } -function getTTSModel() { - loading.value = true - applicationApi - .getApplicationTTSModel(applicationDetail.value.id) - .then((res: any) => { - ttsModelOptions.value = groupBy(res?.data, 'provider') - loading.value = false - }) - .catch(() => { - loading.value = false - }) -} - - onMounted(() => { user.changeUserType(2) getAccessToken(accessToken) diff --git a/ui/src/workflow/nodes/base-node/index.vue b/ui/src/workflow/nodes/base-node/index.vue index c63d60986..072b9e857 100644 --- a/ui/src/workflow/nodes/base-node/index.vue +++ b/ui/src/workflow/nodes/base-node/index.vue @@ -135,7 +135,12 @@ + + + +