def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
    """Wrap raw bytes in an ``InMemoryUploadedFile``.

    The content type is guessed from ``file_name`` with ``mimetypes``; when
    the guess yields nothing, the generic binary type
    ``application/octet-stream`` is used instead.

    :param file_bytes: the file payload; passed to ``io.BytesIO`` and ``len``.
    :param file_name: name given to the upload and used for the type guess.
    :return: an ``InMemoryUploadedFile`` backed by an in-memory byte stream.
    """
    guessed_type = mimetypes.guess_type(file_name)[0]
    return InMemoryUploadedFile(
        # In-memory stream holding the payload.
        file=io.BytesIO(file_bytes),
        field_name=None,
        name=file_name,
        # Fall back to the generic binary type when the name is not recognised.
        content_type="application/octet-stream" if guessed_type is None else guessed_type,
        size=len(file_bytes),
        charset=None,
    )
'content': [file_content[:500] for file_content in content], 'status': self.status, 'err_message': self.err_message, 'document_list': self.context.get('document_list') diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 23b058040..7c43ada4b 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -147,9 +147,9 @@ class ApplicationWorkflowSerializer(serializers.Serializer): default_workflow = json.loads(default_workflow_json) for node in default_workflow.get('nodes'): if node.get('id') == 'base-node': - node.get('properties')['node_data'] = {"desc": application.get('desc'), - "name": application.get('name'), - "prologue": application.get('prologue')} + node.get('properties')['node_data']['desc'] = application.get('desc') + node.get('properties')['node_data']['name'] = application.get('name') + node.get('properties')['node_data']['prologue'] = application.get('prologue') return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'), @@ -160,6 +160,14 @@ class ApplicationWorkflowSerializer(serializers.Serializer): model_setting={}, problem_optimization=False, type=ApplicationTypeChoices.WORK_FLOW, + stt_model_enable=application.get('stt_model_enable', False), + stt_model_id=application.get('stt_model', None), + tts_model_id=application.get('tts_model', None), + tts_model_enable=application.get('tts_model_enable', False), + tts_model_params_setting=application.get('tts_model_params_setting', {}), + tts_type=application.get('tts_type', None), + file_upload_enable=application.get('file_upload_enable', False), + file_upload_setting=application.get('file_upload_setting', {}), work_flow=default_workflow ) @@ -502,6 +510,14 @@ class ApplicationSerializer(serializers.Serializer): type=ApplicationTypeChoices.SIMPLE, model_params_setting=application.get('model_params_setting', {}), 
def get_content(self, file, save_image=None):
    """Return the document in ``file`` rendered as markdown text.

    The file is read fully, parsed with ``Document`` and converted through
    ``self.to_md``, which is handed an ``image_list`` (presumably filled with
    the images it encounters -- confirm against ``to_md``).

    When images were collected and a ``save_image`` callback is supplied,
    every ``/api/image/`` occurrence in the markdown is rewritten to
    ``/api/file/`` and the callback is invoked with the image list. Without a
    callback the markdown is returned untouched, so its links are not pointed
    at files that were never stored.

    :param file: readable binary file object.
    :param save_image: optional callable taking the list of collected images.
        Defaults to ``None`` so callers using the old one-argument form keep
        working.
    :return: the markdown content, or the exception message as text when
        parsing or saving fails.
    """
    try:
        image_list = []
        buffer = file.read()
        doc = Document(io.BytesIO(buffer))
        content = self.to_md(doc, image_list, get_image_id_func())
        if image_list and save_image is not None:
            content = content.replace('/api/image/', '/api/file/')
            save_image(image_list)
        return content
    except Exception as e:
        # Best effort: report the failure as content instead of raising.
        # ``Exception`` rather than ``BaseException`` so KeyboardInterrupt and
        # SystemExit still propagate.
        traceback.print_exception(e)
        return f'{e}'
XlsSplitHandle(BaseParseTableHandle): return [{'name': file.name, 'paragraphs': []}] return result - def get_content(self, file): + def get_content(self, file, save_image): # 打开 .xls 文件 try: workbook = xlrd.open_workbook(file_contents=file.read(), formatting_info=True) diff --git a/apps/common/handle/impl/table/xlsx_parse_table_handle.py b/apps/common/handle/impl/table/xlsx_parse_table_handle.py index 3fd40b2d1..2ae22d019 100644 --- a/apps/common/handle/impl/table/xlsx_parse_table_handle.py +++ b/apps/common/handle/impl/table/xlsx_parse_table_handle.py @@ -74,7 +74,7 @@ class XlsxSplitHandle(BaseParseTableHandle): return result - def get_content(self, file): + def get_content(self, file, save_image): try: # 加载 Excel 文件 workbook = load_workbook(file) diff --git a/apps/common/handle/impl/text_split_handle.py b/apps/common/handle/impl/text_split_handle.py index 1ae22f95f..9d91d874d 100644 --- a/apps/common/handle/impl/text_split_handle.py +++ b/apps/common/handle/impl/text_split_handle.py @@ -51,7 +51,7 @@ class TextSplitHandle(BaseSplitHandle): 'content': split_model.parse(content) } - def get_content(self, file): + def get_content(self, file, save_image): buffer = file.read() try: return buffer.decode(detect(buffer)['encoding']) diff --git a/apps/common/job/clean_chat_job.py b/apps/common/job/clean_chat_job.py index 332beee18..d42c39982 100644 --- a/apps/common/job/clean_chat_job.py +++ b/apps/common/job/clean_chat_job.py @@ -4,6 +4,7 @@ import logging import datetime from django.db import transaction +from django.db.models.fields.json import KeyTextTransform from django.utils import timezone from apscheduler.schedulers.background import BackgroundScheduler from django_apscheduler.jobstores import DjangoJobStore @@ -11,6 +12,8 @@ from application.models import Application, Chat from django.db.models import Q from common.lock.impl.file_lock import FileLock from dataset.models import File +from django.db.models.functions import Cast +from django.db import models 
def upload(self, with_valid=True):
    """Store the uploaded file and return its download URL.

    :param with_valid: when true, run ``is_valid(raise_exception=True)``
        before saving.
    :return: the string ``/api/file/<file_id>``.

    The file id is taken from ``meta['file_id']`` when the metadata is a
    mapping carrying a non-empty value; otherwise a fresh ``uuid.uuid1()`` is
    generated. ``meta`` itself is passed to ``File`` unchanged.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    meta = self.data.get('meta')
    # ``meta`` may be missing or null (NOTE(review): confirm against the
    # serializer's field definition); never call ``.get`` on a non-mapping.
    file_id = meta.get('file_id') if isinstance(meta, dict) else None
    if not file_id:
        # Generate an id only when the caller did not supply one.
        file_id = uuid.uuid1()
    upload_file = self.data.get('file')
    file = File(id=file_id, file_name=upload_file.name, meta=meta)
    file.save(upload_file.read())
    return f'/api/file/{file_id}'