diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index 0c8333566..73eb9a782 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -178,7 +178,7 @@ class Application(APIView): tags=["应用"], manual_parameters=ApplicationApi.Model.get_request_params_api()) @has_permissions(ViewPermission( - [RoleConstants.ADMIN, RoleConstants.USER], + [RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))], compare=CompareConstants.AND)) diff --git a/apps/setting/models_provider/impl/local_model_provider/credential/tts.py b/apps/setting/models_provider/impl/local_model_provider/credential/tts.py new file mode 100644 index 000000000..7047908b5 --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/credential/tts.py @@ -0,0 +1,16 @@ +# coding=utf-8 + +from typing import Dict + +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential + + +class BrowserTextToSpeechCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + return True + + def encryption_dict(self, model: Dict[str, object]): + return model diff --git a/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py b/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py index 2c92bbbfb..e1b462bb2 100644 --- a/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py +++ b/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py @@ -17,8 +17,10 @@ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelT ModelInfoManage from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential from setting.models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential +from setting.models_provider.impl.local_model_provider.credential.tts import BrowserTextToSpeechCredential from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding from setting.models_provider.impl.local_model_provider.model.reranker import LocalReranker +from setting.models_provider.impl.local_model_provider.model.tts import BrowserTextToSpeech from smartdoc.conf import PROJECT_DIR embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING, @@ -26,10 +28,14 @@ embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', bge_reranker_v2_m3 = ModelInfo('BAAI/bge-reranker-v2-m3', '', ModelTypeConst.RERANKER, LocalRerankerCredential(), LocalReranker) +browser_tts = ModelInfo('browser_tts', '', ModelTypeConst.TTS, BrowserTextToSpeechCredential(), BrowserTextToSpeech) + + model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese) .append_default_model_info(embedding_text2vec_base_chinese) .append_model_info(bge_reranker_v2_m3) .append_default_model_info(bge_reranker_v2_m3) + .append_model_info(browser_tts) .build()) diff --git a/apps/setting/models_provider/impl/local_model_provider/model/tts.py b/apps/setting/models_provider/impl/local_model_provider/model/tts.py new file mode 100644 index 000000000..56e6292be --- /dev/null +++ b/apps/setting/models_provider/impl/local_model_provider/model/tts.py @@ -0,0 +1,31 @@ +from typing import Dict + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_tts import BaseTextToSpeech + + + +class BrowserTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): + model: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model = kwargs.get('model') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return BrowserTextToSpeech( + model=model_name, + **optional_params, + ) + + def check_auth(self): + pass + + def text_to_speech(self, text): + pass diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 7a75f994f..185dafd29 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -136,6 +136,7 @@ @regeneration="regenerationChart(item)" /> +
@@ -247,7 +248,13 @@ const props = defineProps({ chatId: { type: String, default: '' - } // 历史记录Id + }, // 历史记录Id + ttsModelOptions: { + type: Object, + default: () => { + return {} + } + } }) const emit = defineEmits(['refresh', 'scroll']) @@ -321,8 +328,7 @@ watch( ) function handleInputFieldList() { - props.data.work_flow?.nodes - .filter((v: any) => v.id === 'base-node') + props.data.work_flow?.nodes?.filter((v: any) => v.id === 'base-node') .map((v: any) => { inputFieldList.value = v.properties.input_field_list.map((v: any) => { switch (v.type) { @@ -763,6 +769,14 @@ const uploadRecording = async (audioBlob: Blob) => { } const playAnswerText = (text: string) => { + if (props.ttsModelOptions?.model_local_provider?.filter((v: any) => v.id === props.data.tts_model_id).length > 0) { + // 创建一个新的 SpeechSynthesisUtterance 实例 + const utterance = new SpeechSynthesisUtterance(text); + // 调用浏览器的朗读功能 + window.speechSynthesis.speak(utterance); + return + } + applicationApi.postTextToSpeech(props.data.id as string, { 'text': text }, loading) .then((res: any) => { diff --git a/ui/src/views/application-workflow/index.vue b/ui/src/views/application-workflow/index.vue index bd5d77a04..43a1b4da5 100644 --- a/ui/src/views/application-workflow/index.vue +++ b/ui/src/views/application-workflow/index.vue @@ -138,7 +138,7 @@
- +
@@ -157,6 +157,7 @@ import { datetimeFormat } from '@/utils/time' import useStore from '@/stores' import { WorkFlowInstance } from '@/workflow/common/validate' import { hasPermission } from '@/utils/permission' +import { groupBy } from 'lodash' const { user, application } = useStore() const router = useRouter() @@ -181,6 +182,7 @@ const enlarge = ref(false) const saveTime = ref('') const activeName = ref('base') const functionLibList = ref([]) +const ttsModelOptions = ref(null) function publicHandle() { workflowRef.value @@ -310,6 +312,20 @@ function getList() { }) } +function getTTSModel() { + loading.value = true + applicationApi + .getApplicationTTSModel(id) + .then((res: any) => { + ttsModelOptions.value = groupBy(res?.data, 'provider') + loading.value = false + }) + .catch(() => { + loading.value = false + }) +} + + /** * 定时保存 */ @@ -329,6 +345,7 @@ const closeInterval = () => { } onMounted(() => { + getTTSModel() getDetail() getList() // 初始化定时任务 diff --git a/ui/src/views/application/ApplicationSetting.vue b/ui/src/views/application/ApplicationSetting.vue index e52b7d734..6fac3e062 100644 --- a/ui/src/views/application/ApplicationSetting.vue +++ b/ui/src/views/application/ApplicationSetting.vue @@ -464,7 +464,7 @@
- +
diff --git a/ui/src/views/chat/pc/index.vue b/ui/src/views/chat/pc/index.vue index 3ed0c7a30..e89363371 100644 --- a/ui/src/views/chat/pc/index.vue +++ b/ui/src/views/chat/pc/index.vue @@ -108,6 +108,7 @@ :appId="applicationDetail?.id" :record="currentRecordList" :chatId="currentChatId" + :tts-model-options="ttsModelOptions" @refresh="refresh" @scroll="handleScroll" > @@ -130,6 +131,8 @@ import { marked } from 'marked' import { saveAs } from 'file-saver' import { isAppIcon } from '@/utils/application' import useStore from '@/stores' +import applicationApi from '@/api/application' +import { groupBy } from 'lodash' import useResize from '@/layout/hooks/useResize' useResize() @@ -167,6 +170,8 @@ const left_loading = ref(false) const applicationDetail = ref({}) const applicationAvailable = ref(true) const chatLogeData = ref([]) +const ttsModelOptions = ref(null) + const paginationConfig = ref({ current_page: 1, @@ -228,6 +233,7 @@ function getAppProfile() { if (res.data?.show_history || !user.isEnterprise()) { getChatLog(applicationDetail.value.id) } + getTTSModel() }) .catch(() => { applicationAvailable.value = false @@ -336,6 +342,20 @@ async function exportHTML(): Promise { saveAs(blob, suggestedName) } +function getTTSModel() { + loading.value = true + applicationApi + .getApplicationTTSModel(applicationDetail.value.id) + .then((res: any) => { + ttsModelOptions.value = groupBy(res?.data, 'provider') + loading.value = false + }) + .catch(() => { + loading.value = false + }) +} + + onMounted(() => { user.changeUserType(2) getAccessToken(accessToken)