diff --git a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py
index ed2ca9a68..b4806c537 100644
--- a/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py
+++ b/apps/application/flow/step_node/speech_to_text_step_node/impl/base_speech_to_text_node.py
@@ -10,7 +10,7 @@ from pydub import AudioSegment
from concurrent.futures import ThreadPoolExecutor
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.speech_to_text_step_node.i_speech_to_text_node import ISpeechToTextNode
-from common.util.common import split_and_transcribe
+from common.util.common import split_and_transcribe, any_to_mp3
from dataset.models import File
from setting.models_provider.tools import get_model_instance_by_model_user_id
@@ -26,16 +26,21 @@ class BaseSpeechToTextNode(ISpeechToTextNode):
audio_list = audio
self.context['audio_list'] = audio
-
def process_audio_item(audio_item, model):
file = QuerySet(File).filter(id=audio_item['file_id']).first()
- with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_file:
+ # 根据file_name 吧文件转成mp3格式
+ file_format = file.file_name.split('.')[-1]
+ with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{file_format}') as temp_file:
temp_file.write(file.get_byte().tobytes())
temp_file_path = temp_file.name
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as temp_amr_file:
+ temp_mp3_path = temp_amr_file.name
+ any_to_mp3(temp_file_path, temp_mp3_path)
try:
return split_and_transcribe(temp_file_path, model)
finally:
os.remove(temp_file_path)
+ os.remove(temp_mp3_path)
def process_audio_items(audio_list, model):
with ThreadPoolExecutor(max_workers=5) as executor:
diff --git a/apps/common/util/common.py b/apps/common/util/common.py
index 73107d62f..2feb3f57d 100644
--- a/apps/common/util/common.py
+++ b/apps/common/util/common.py
@@ -181,18 +181,18 @@ def sil_to_wav(silk_path, wav_path, rate: int = 24000):
f.write(wav_data)
-def split_and_transcribe(file_path, model, max_segment_length_ms=59000, format="mp3"):
- audio_data = AudioSegment.from_file(file_path, format=format)
+def split_and_transcribe(file_path, model, max_segment_length_ms=59000, audio_format="mp3"):
+ audio_data = AudioSegment.from_file(file_path, format=audio_format)
audio_length_ms = len(audio_data)
if audio_length_ms <= max_segment_length_ms:
- return model.speech_to_text(io.BytesIO(audio_data.export(format=format).read()))
+ return model.speech_to_text(io.BytesIO(audio_data.export(format=audio_format).read()))
full_text = []
for start_ms in range(0, audio_length_ms, max_segment_length_ms):
end_ms = min(audio_length_ms, start_ms + max_segment_length_ms)
segment = audio_data[start_ms:end_ms]
- text = model.speech_to_text(io.BytesIO(segment.export(format=format).read()))
+ text = model.speech_to_text(io.BytesIO(segment.export(format=audio_format).read()))
if isinstance(text, str):
full_text.append(text)
return ' '.join(full_text)
diff --git a/apps/dataset/serializers/file_serializers.py b/apps/dataset/serializers/file_serializers.py
index 28a806338..4fe66a9a6 100644
--- a/apps/dataset/serializers/file_serializers.py
+++ b/apps/dataset/serializers/file_serializers.py
@@ -77,9 +77,11 @@ class FileSerializer(serializers.Serializer):
file = QuerySet(File).filter(id=file_id).first()
if file is None:
raise NotFound404(404, "不存在的文件")
- # 如果是mp3文件,直接返回文件流
- if file.file_name.split(".")[-1] == 'mp3':
- return HttpResponse(file.get_byte(), status=200, headers={'Content-Type': 'audio/mp3',
- 'Content-Disposition': 'attachment; filename="abc.mp3"'})
+ # 如果是音频文件,直接返回文件流
+ file_type = file.file_name.split(".")[-1]
+ if file_type in ['mp3', 'wav', 'ogg', 'aac']:
+ return HttpResponse(file.get_byte(), status=200, headers={'Content-Type': f'audio/{file_type}',
+ 'Content-Disposition': 'attachment; filename="{}"'.format(
+ file.file_name)})
return HttpResponse(file.get_byte(), status=200,
headers={'Content-Type': mime_types.get(file.file_name.split(".")[-1], 'text/plain')})
diff --git a/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py b/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py
index 0e100d742..6c14a45fa 100644
--- a/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py
+++ b/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py
@@ -34,7 +34,7 @@ class VLLMModelCredential(BaseForm, BaseModelCredential):
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
try:
- model_list = provider.get_base_model_list(model_credential.get('api_base'))
+ model_list = provider.get_base_model_list(model_credential.get('api_base'), model_credential.get('api_key'))
except Exception as e:
raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
exist = provider.get_model_info_by_name(model_list, model_name)
diff --git a/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py b/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py
index 42ba361fc..7912fff96 100644
--- a/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py
+++ b/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py
@@ -45,10 +45,13 @@ class VllmModelProvider(IModelProvider):
'vllm_icon_svg')))
@staticmethod
- def get_base_model_list(api_base):
+ def get_base_model_list(api_base, api_key):
base_url = get_base_url(api_base)
base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
- r = requests.request(method="GET", url=f"{base_url}/models", timeout=5)
+ headers = {}
+ if api_key:
+ headers['Authorization'] = f"Bearer {api_key}"
+ r = requests.request(method="GET", url=f"{base_url}/models", headers=headers, timeout=5)
r.raise_for_status()
return r.json().get('data')
diff --git a/ui/src/components/ai-chat/component/chat-input-operate/index.vue b/ui/src/components/ai-chat/component/chat-input-operate/index.vue
index b8c33b93d..7c90dbd7f 100644
--- a/ui/src/components/ai-chat/component/chat-input-operate/index.vue
+++ b/ui/src/components/ai-chat/component/chat-input-operate/index.vue
@@ -27,7 +27,9 @@
class="delete-icon color-secondary"
v-if="showDelete === item.url"
>
-