# Source file: apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py
# Helper used by BaseSpeechToTextNode; it relies on `any_to_mp3` and
# `split_and_transcribe`, both imported from `common.util.common`.
def process_audio_item(audio_item, model):
    """Transcribe one uploaded audio file.

    The stored bytes are written to a temporary file that keeps the upload's
    own extension, converted to MP3 with ``any_to_mp3``, and the converted
    MP3 is handed to ``split_and_transcribe``.

    Args:
        audio_item: Mapping with a ``'file_id'`` key used to look up the
            ``File`` row.
        model: Speech-to-text model, forwarded unchanged to
            ``split_and_transcribe``.

    Returns:
        Whatever ``split_and_transcribe`` returns for the converted audio.
    """
    # NOTE(review): `.first()` returns None when no row matches, which surfaces
    # below as an AttributeError -- confirm whether callers guarantee the id.
    stored_file = QuerySet(File).filter(id=audio_item['file_id']).first()

    # Keep the upload's extension on the source temp file. `splitext` yields
    # '' for names without an extension instead of echoing the whole name.
    source_suffix = os.path.splitext(stored_file.file_name)[1]

    temp_paths = []
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=source_suffix) as source_file:
            # Register the path first so it is removed even if the write fails.
            temp_paths.append(source_file.name)
            source_file.write(stored_file.get_byte().tobytes())
        source_path = temp_paths[0]

        # Only the name is needed here; any_to_mp3 writes the content.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as mp3_file:
            temp_paths.append(mp3_file.name)
        mp3_path = temp_paths[1]

        any_to_mp3(source_path, mp3_path)
        # Transcribe the converted file: split_and_transcribe decodes its
        # input as MP3 by default, so the raw upload must not be passed here.
        return split_and_transcribe(mp3_path, model)
    finally:
        # delete=False means both temp files have to be removed by hand.
        for path in temp_paths:
            os.remove(path)
# Source file: apps/common/util/common.py
def split_and_transcribe(file_path, model, max_segment_length_ms=59000, audio_format="mp3"):
    """Transcribe an audio file, splitting it into segments when it is long.

    Args:
        file_path: Path of the audio file, loaded with
            ``AudioSegment.from_file``.
        model: Object exposing ``speech_to_text(buffer)``; it receives an
            ``io.BytesIO`` holding the exported audio.
        max_segment_length_ms: Maximum length, in milliseconds, of each piece
            sent to the model. Must be greater than 0.
        audio_format: Format used both to decode ``file_path`` and to
            re-encode each piece.

    Returns:
        The model's result as-is when the whole file fits in one segment.
        Otherwise the per-segment results that are ``str`` joined with single
        spaces; non-string results are skipped.

    Raises:
        ValueError: If ``max_segment_length_ms`` is not positive.
    """
    # Fail before any decoding: 0 would make range() raise an obscure error
    # and a negative step would silently produce an empty transcription.
    if max_segment_length_ms <= 0:
        raise ValueError("max_segment_length_ms must be greater than 0")

    def _transcribe(segment):
        # Export straight into an in-memory buffer so no intermediate file
        # handle is left open, then rewind it for the model to read.
        buffer = io.BytesIO()
        segment.export(buffer, format=audio_format)
        buffer.seek(0)
        return model.speech_to_text(buffer)

    audio_data = AudioSegment.from_file(file_path, format=audio_format)
    audio_length_ms = len(audio_data)
    if audio_length_ms <= max_segment_length_ms:
        return _transcribe(audio_data)

    full_text = []
    for start_ms in range(0, audio_length_ms, max_segment_length_ms):
        end_ms = min(audio_length_ms, start_ms + max_segment_length_ms)
        text = _transcribe(audio_data[start_ms:end_ms])
        # Best effort: a segment whose result is not text is left out.
        if isinstance(text, str):
            full_text.append(text)
    return ' '.join(full_text)
# Source file: apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py
# Static method of VllmModelProvider.
@staticmethod
def get_base_model_list(api_base, api_key=None):
    """Fetch the model list from the server's ``/v1/models`` endpoint.

    Args:
        api_base: Base address of the server; ``/v1`` is appended when the
            normalised URL does not already end with it.
        api_key: Optional key. When truthy it is sent as an
            ``Authorization: Bearer`` header. Defaults to ``None`` so callers
            that pass only ``api_base`` keep working.

    Returns:
        The value stored under ``'data'`` in the JSON response (``None`` if
        the key is absent).

    Raises:
        requests.HTTPError: Raised by ``raise_for_status`` for an error
            response.
    """
    base_url = get_base_url(api_base)
    if not base_url.endswith('/v1'):
        base_url = base_url + '/v1'

    # Only attach credentials when a key was actually supplied.
    headers = {}
    if api_key:
        headers['Authorization'] = f"Bearer {api_key}"

    r = requests.request(method="GET", url=f"{base_url}/models", headers=headers, timeout=5)
    r.raise_for_status()
    return r.json().get('data')
@@ -48,7 +50,9 @@ class="delete-icon color-secondary" v-if="showDelete === item.url" > - + + +
- + + +
@@ -180,6 +186,7 @@ import 'recorder-core/src/engine/mp3' import 'recorder-core/src/engine/mp3-engine' import { MsgWarning } from '@/utils/message' + const route = useRoute() const { query: { mode } @@ -227,7 +234,7 @@ const localLoading = computed({ const imageExtensions = ['jpg', 'jpeg', 'png', 'gif', 'bmp'] const documentExtensions = ['pdf', 'docx', 'txt', 'xls', 'xlsx', 'md', 'html', 'csv'] const videoExtensions = ['mp4', 'avi', 'mov', 'mkv', 'flv'] -const audioExtensions = ['mp3'] +const audioExtensions = ['mp3', 'wav', 'ogg', 'aac'] const getAcceptList = () => { const { image, document, audio, video } = props.applicationDetails.file_upload_setting @@ -513,9 +520,11 @@ function deleteFile(index: number, val: string) { uploadAudioList.value.splice(index, 1) } } + function mouseenter(row: any) { showDelete.value = row.url } + function mouseleave() { showDelete.value = '' } @@ -530,9 +539,11 @@ onMounted(() => {