diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py index 040a2be86..6adb2b4c4 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py @@ -24,17 +24,7 @@ from knowledge.models import Paragraph, Knowledge from knowledge.models import SearchMode from maxkb.conf import PROJECT_DIR from models_provider.models import Model -from models_provider.tools import get_model - - -def get_model_by_id(_id, workspace_id): - model = QuerySet(Model).filter(id=_id, model_type="EMBEDDING") - get_authorized_model = DatabaseModelManage.get_model("get_authorized_model") - if get_authorized_model is not None: - model = get_authorized_model(model, workspace_id) - if model is None: - raise Exception(_("Model does not exist")) - return model +from models_provider.tools import get_model, get_model_by_id def get_embedding_id(knowledge_id_list): @@ -65,6 +55,8 @@ class BaseSearchDatasetStep(ISearchDatasetStep): exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text model_id = get_embedding_id(knowledge_id_list) model = get_model_by_id(model_id, workspace_id) + if model.model_type != "EMBEDDING": + raise Exception(_("Model does not exist")) self.context['model_name'] = model.name embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model)) embedding_value = embedding_model.embed_query(exec_problem_text) diff --git a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py index 1657dacaf..9a8b7509a 100644 --- a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py +++ b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py @@ -71,7 +71,7 @@ class BaseDocumentExtractNode(IDocumentExtractNode): for doc in document: file = QuerySet(File).filter(id=doc['file_id']).first() - buffer = io.BytesIO(file.get_bytes().tobytes()) + buffer = io.BytesIO(file.get_bytes()) buffer.name = doc['name'] # this is the important line for split_handle in (parse_table_handle_list + split_handles): diff --git a/apps/chat/views/chat.py b/apps/chat/views/chat.py index bfd55dbc0..8d2ca52f1 100644 --- a/apps/chat/views/chat.py +++ b/apps/chat/views/chat.py @@ -9,6 +9,7 @@ from django.http import HttpResponse from django.utils.translation import gettext_lazy as _ from drf_spectacular.utils import extend_schema +from rest_framework.parsers import MultiPartParser from rest_framework.request import Request from rest_framework.views import APIView @@ -24,6 +25,8 @@ from common.auth.authentication import has_permissions from common.constants.permission_constants import ChatAuth from common.exception.app_exception import AppAuthenticationFailed from common.result import result +from knowledge.models import FileSourceType +from oss.serializers.file import FileSerializer from users.api import CaptchaAPI from users.serializers.login import CaptchaSerializer @@ -134,7 +137,7 @@ class CaptchaView(APIView): summary=_("Get Chat captcha"), description=_("Get Chat captcha"), operation_id=_("Get Chat captcha"), # type: ignore - tags=[_("User Management")], # type: ignore + tags=[_("Chat")], # type: ignore responses=CaptchaAPI.get_response()) def get(self, request: Request): return result.success(CaptchaSerializer().generate()) @@ -150,7 +153,7 @@ class SpeechToText(APIView): operation_id=_("speech to text"), # type: ignore request=SpeechToTextAPI.get_request(), responses=SpeechToTextAPI.get_response(), - tags=[_('Application')] # type: ignore + tags=[_('Chat')] # type: ignore ) def post(self, request: Request): return result.success( @@ -169,10 +172,34 @@ class TextToSpeech(APIView): operation_id=_("text to speech"), # type: ignore request=TextToSpeechAPI.get_request(), responses=TextToSpeechAPI.get_response(), - tags=[_('Application')] # type: ignore + tags=[_('Chat')] # type: ignore ) def post(self, request: Request): byte_data = TextToSpeechSerializers( data={'application_id': request.auth.application_id}).text_to_speech(request.data) return HttpResponse(byte_data, status=200, headers={'Content-Type': 'audio/mp3', 'Content-Disposition': 'attachment; filename="abc.mp3"'}) + + +class UploadFile(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @extend_schema( + methods=['POST'], + description=_("Upload files"), + summary=_("Upload files"), + operation_id=_("Upload files"), # type: ignore + request=TextToSpeechAPI.get_request(), + responses=TextToSpeechAPI.get_response(), + tags=[_('Application')] # type: ignore + ) + def post(self, request: Request, chat_id: str): + files = request.FILES.getlist('file') + file_ids = [] + meta = {} + for file in files: + file_url = FileSerializer( + data={'file': file, 'meta': meta, 'source_id': chat_id, 'source_type': FileSourceType.CHAT, }).upload() + file_ids.append({'name': file.name, 'url': file_url, 'file_id': file_url.split('/')[-1]}) + return result.success(file_ids) diff --git a/apps/knowledge/models/knowledge.py b/apps/knowledge/models/knowledge.py index cafaafdc0..26d6bea74 100644 --- a/apps/knowledge/models/knowledge.py +++ b/apps/knowledge/models/knowledge.py @@ -95,8 +95,6 @@ def default_status_meta(): return {"state_time": {}} - - class KnowledgeFolder(MPTTModel, AppModelMixin): id = models.CharField(primary_key=True, max_length=64, editable=False, verbose_name="主键id") name = models.CharField(max_length=64, verbose_name="文件夹名称") @@ -133,9 +131,11 @@ class Knowledge(AppModelMixin): class Meta: db_table = "knowledge" + def get_default_status(): return Status('').__str__() + class Document(AppModelMixin): """ 文档表 @@ -224,10 +224,12 @@ class FileSourceType(models.TextChoices): TOOL = "TOOL" # 文档 DOCUMENT = "DOCUMENT" + # 对话 + CHAT = "CHAT" # 临时30分钟 数据30分钟后被清理 source_id 为TEMPORARY_30_MINUTE TEMPORARY_30_MINUTE = "TEMPORARY_30_MINUTE" # 临时120分钟 数据120分钟后被清理 source_id为TEMPORARY_100_MINUTE - TEMPORARY_120_MINUTE = "TEMPORARY_100_MINUTE" + TEMPORARY_120_MINUTE = "TEMPORARY_120_MINUTE" # 临时1天 数据1天后被清理 source_id为TEMPORARY_1_DAY TEMPORARY_1_DAY = "TEMPORARY_1_DAY" diff --git a/apps/maxkb/urls.py b/apps/maxkb/urls.py index 7963db8f4..f2e8b68fa 100644 --- a/apps/maxkb/urls.py +++ b/apps/maxkb/urls.py @@ -42,6 +42,7 @@ urlpatterns = [ path(admin_api_prefix, include("system_manage.urls")), path(admin_api_prefix, include("application.urls")), path(admin_api_prefix, include("oss.urls")), + path(chat_api_prefix, include("oss.urls")), path(chat_api_prefix, include("chat.urls")), path(f'{admin_ui_prefix[1:]}/', include('oss.retrieval_urls')), path(f'{chat_ui_prefix[1:]}/', include('oss.retrieval_urls')), diff --git a/ui/src/api/application/application.ts b/ui/src/api/application/application.ts index 6fdb944b0..bf416cc2b 100644 --- a/ui/src/api/application/application.ts +++ b/ui/src/api/application/application.ts @@ -296,6 +296,30 @@ const getMcpTools: (application_id: String, loading?: Ref) => Promise Promise> = (file, sourceId, resourceType) => { + const fd = new FormData() + fd.append('file', file) + fd.append('source_id', sourceId) + fd.append('source_type', resourceType) + return post(`/oss/file`, fd) +} + export default { getAllApplication, getApplication, @@ -321,4 +345,5 @@ export default { postTextToSpeech, speechToText, getMcpTools, + uploadFile, } diff --git a/ui/src/api/chat/chat.ts b/ui/src/api/chat/chat.ts index 6dc0d04b1..c75eb8f9d 100644 --- a/ui/src/api/chat/chat.ts +++ b/ui/src/api/chat/chat.ts @@ -265,30 +265,57 @@ const speechToText: (data: any, loading?: Ref) => Promise> return post(`speech_to_text`, data, undefined, loading) } /** - * + * * @param chat_id 对话ID - * @param loading - * @returns + * @param loading + * @returns */ -const deleteChat: (chat_id:string, loading?: Ref) => Promise> = ( +const deleteChat: (chat_id: string, loading?: Ref) => Promise> = ( chat_id, loading, ) => { return del(`historical_conversation/${chat_id}`, loading) } /** - * + * * @param chat_id 对话id * @param data 对话简介 - * @param loading - * @returns + * @param loading + * @returns */ -const modifyChat: (chat_id:string, data:any, loading?: Ref ) => Promise> = ( - chat_id,data,loading +const modifyChat: (chat_id: string, data: any, loading?: Ref) => Promise> = ( + chat_id, + data, + loading, ) => { return put(`historical_conversation/${chat_id}`, data, undefined, loading) } - +/** + * 上传文件 + * @param file 文件 + * @param sourceId 资源id + * @param resourceType 资源类型 + * @returns + */ +const uploadFile: ( + file: any, + sourceId: string, + resourceType: + | 'KNOWLEDGE' + | 'APPLICATION' + | 'TOOL' + | 'DOCUMENT' + | 'CHAT' + | 'TEMPORARY_30_MINUTE' + | 'TEMPORARY_120_MINUTE' + | 'TEMPORARY_1_DAY', +) => Promise> = (file, sourceId, sourceType) => { + const fd = new FormData() + fd.append('file', file) + fd.append('source_id', sourceId) + fd.append('source_type', sourceType) + return post(`/oss/file`, fd) +} export default { open, chat, @@ -316,5 +343,6 @@ export default { textToSpeech, speechToText, deleteChat, - modifyChat + modifyChat, + uploadFile, } diff --git a/ui/src/components/ai-chat/component/chat-input-operate/index.vue b/ui/src/components/ai-chat/component/chat-input-operate/index.vue index 5df07308d..f85ab6c7c 100644 --- a/ui/src/components/ai-chat/component/chat-input-operate/index.vue +++ b/ui/src/components/ai-chat/component/chat-input-operate/index.vue @@ -42,7 +42,7 @@
@@ -80,7 +80,7 @@
@@ -115,7 +115,7 @@
@@ -136,7 +136,7 @@ @mouseleave.stop="mouseleave()" >
@@ -297,7 +297,6 @@ import { ref, computed, onMounted, nextTick, watch, type Ref } from 'vue' import Recorder from 'recorder-core' import TouchChat from './TouchChat.vue' import applicationApi from '@/api/application/application' -import UserForm from '@/components/ai-chat/component/user-form/index.vue' import { MsgAlert } from '@/utils/message' import { type chatType } from '@/api/type/application' import { useRoute, useRouter } from 'vue-router' @@ -354,7 +353,6 @@ const localLoading = computed({ }, }) - const upload = ref() const imageExtensions = ['JPG', 'JPEG', 'PNG', 'GIF', 'BMP'] @@ -421,92 +419,24 @@ const uploadFile = async (file: any, fileList: any) => { fileList.splice(0, fileList.length) return } - - const formData = new FormData() - formData.append('file', file.raw, file.name) - // - const extension = file.name.split('.').pop().toUpperCase() // 获取文件后缀名并转为小写 - if (imageExtensions.includes(extension)) { - uploadImageList.value.push(file) - } else if (documentExtensions.includes(extension)) { - uploadDocumentList.value.push(file) - } else if (videoExtensions.includes(extension)) { - uploadVideoList.value.push(file) - } else if (audioExtensions.includes(extension)) { - uploadAudioList.value.push(file) - } else if (otherExtensions.includes(extension)) { - uploadOtherList.value.push(file) - } + fileAllList.value = fileList if (!chatId_context.value) { const res = await props.openChatId() chatId_context.value = res } - - if (props.type === 'debug-ai-chat') { - formData.append('debug', 'true') - } else { - formData.append('debug', 'false') + const api = + props.type === 'debug-ai-chat' + ? applicationApi.uploadFile(file.raw, 'TEMPORARY_120_MINUTE', 'TEMPORARY_120_MINUTE') + : chatAPI.uploadFile(file.raw, chatId_context.value, 'CHAT') + api.then((ok) => { + file.url = ok.data + const split_path = ok.data.split('/') + file.file_id = split_path[split_path.length - 1] + }) + if (!inputValue.value && uploadImageList.value.length > 0) { + inputValue.value = t('chat.uploadFile.imageMessage') } - - applicationApi - .uploadFile( - props.applicationDetails.id as string, - chatId_context.value as string, - formData, - localLoading, - ) - .then((response: any) => { - fileList.splice(0, fileList.length) - uploadImageList.value.forEach((file: any) => { - const f = response.data.filter( - (f: any) => f.name.replaceAll(' ', '') === file.name.replaceAll(' ', ''), - ) - if (f.length > 0) { - file.url = f[0].url - file.file_id = f[0].file_id - } - }) - uploadDocumentList.value.forEach((file: any) => { - const f = response.data.filter( - (f: any) => f.name.replaceAll(' ', '') == file.name.replaceAll(' ', ''), - ) - if (f.length > 0) { - file.url = f[0].url - file.file_id = f[0].file_id - } - }) - uploadAudioList.value.forEach((file: any) => { - const f = response.data.filter( - (f: any) => f.name.replaceAll(' ', '') === file.name.replaceAll(' ', ''), - ) - if (f.length > 0) { - file.url = f[0].url - file.file_id = f[0].file_id - } - }) - uploadVideoList.value.forEach((file: any) => { - const f = response.data.filter( - (f: any) => f.name.replaceAll(' ', '') === file.name.replaceAll(' ', ''), - ) - if (f.length > 0) { - file.url = f[0].url - file.file_id = f[0].file_id - } - }) - uploadOtherList.value.forEach((file: any) => { - const f = response.data.filter( - (f: any) => f.name.replaceAll(' ', '') === file.name.replaceAll(' ', ''), - ) - if (f.length > 0) { - file.url = f[0].url - file.file_id = f[0].file_id - } - }) - if (!inputValue.value && uploadImageList.value.length > 0) { - inputValue.value = t('chat.uploadFile.imageMessage') - } - }) } // 粘贴处理 const handlePaste = (event: ClipboardEvent) => { @@ -564,11 +494,19 @@ const recorderTime = ref(0) const recorderStatus = ref<'START' | 'TRANSCRIBING' | 'STOP'>('STOP') const inputValue = ref('') -const uploadImageList = ref>([]) -const uploadDocumentList = ref>([]) -const uploadVideoList = ref>([]) -const uploadAudioList = ref>([]) -const uploadOtherList = ref>([]) + +const fileAllList = ref>([]) + +const fileFilter = (fileList: Array, extensionList: Array) => { + return fileList.filter((f) => { + return extensionList.includes(f.name.split('.').pop().toUpperCase()) + }) +} +const uploadImageList = computed(() => fileFilter(fileAllList.value, imageExtensions)) +const uploadDocumentList = computed(() => fileFilter(fileAllList.value, documentExtensions)) +const uploadVideoList = computed(() => fileFilter(fileAllList.value, videoExtensions)) +const uploadAudioList = computed(() => fileFilter(fileAllList.value, audioExtensions)) +const uploadOtherList = computed(() => fileFilter(fileAllList.value, otherExtensions)) const showDelete = ref('') @@ -794,11 +732,11 @@ function autoSendMessage() { other_list: uploadOtherList.value, }) inputValue.value = '' - uploadImageList.value = [] - uploadDocumentList.value = [] - uploadAudioList.value = [] - uploadVideoList.value = [] - uploadOtherList.value = [] + fileAllList.value = [] + if (upload.value) { + upload.value.clearFiles() + } + if (quickInputRef.value) { quickInputRef.value.textarea.style.height = '45px' } @@ -845,18 +783,8 @@ const insertNewlineAtCursor = (event?: any) => { }) } -function deleteFile(index: number, val: string) { - if (val === 'image') { - uploadImageList.value.splice(index, 1) - } else if (val === 'document') { - uploadDocumentList.value.splice(index, 1) - } else if (val === 'video') { - uploadVideoList.value.splice(index, 1) - } else if (val === 'audio') { - uploadAudioList.value.splice(index, 1) - } else if (val === 'other') { - uploadOtherList.value.splice(index, 1) - } +function deleteFile(item: any) { + fileAllList.value = fileAllList.value.filter((i) => i != item) } function mouseenter(row: any) { diff --git a/ui/src/views/application/component/CreateApplicationDialog.vue b/ui/src/views/application/component/CreateApplicationDialog.vue index 79155af01..1d29bd11e 100644 --- a/ui/src/views/application/component/CreateApplicationDialog.vue +++ b/ui/src/views/application/component/CreateApplicationDialog.vue @@ -218,7 +218,6 @@ const open = (folder: string, type?: string) => { const submitHandle = async (formEl: FormInstance | undefined) => { if (!formEl) return - console.log(applicationForm.value.type) await formEl.validate((valid) => { if (valid) { if (isWorkFlow(applicationForm.value.type) && appTemplate.value === 'blank') { @@ -226,7 +225,6 @@ const submitHandle = async (formEl: FormInstance | undefined) => { workflowDefault.value.nodes[0].properties.node_data.name = applicationForm.value.name applicationForm.value['work_flow'] = workflowDefault.value } - console.log(applicationForm.value.type) applicationApi .postApplication({ ...applicationForm.value, folder_id: currentFolder.value }, loading) .then((res) => {