diff --git a/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py index ceda39444..727267c33 100644 --- a/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py +++ b/apps/application/flow/step_node/document_extract_node/i_document_extract_node.py @@ -23,5 +23,5 @@ class IDocumentExtractNode(INode): self.node_params_serializer.data.get('document_list')[1:]) return self.execute(document=res, **self.flow_params_serializer.data) - def execute(self, document, **kwargs) -> NodeResult: + def execute(self, document, chat_id, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py index 1802f68d3..2cfbf3c4b 100644 --- a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py +++ b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py @@ -1,16 +1,43 @@ # coding=utf-8 import io +import mimetypes +from django.core.files.uploadedfile import InMemoryUploadedFile from django.db.models import QuerySet from application.flow.i_step_node import NodeResult from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode +from application.models import Chat from dataset.models import File from dataset.serializers.document_serializers import split_handles, parse_table_handle_list, FileBufferHandle +from dataset.serializers.file_serializers import FileSerializer + + +def bytes_to_uploaded_file(file_bytes, file_name="file.txt"): + content_type, _ = mimetypes.guess_type(file_name) + if content_type is None: + # 如果未能识别,设置为默认的二进制文件类型 + content_type = "application/octet-stream" + # 创建一个内存中的字节流对象 + file_stream = io.BytesIO(file_bytes) + + # 获取文件大小 + file_size = len(file_bytes) + + # 创建 InMemoryUploadedFile 对象 + uploaded_file = InMemoryUploadedFile( + file=file_stream, + field_name=None, + name=file_name, + content_type=content_type, + size=file_size, + charset=None, + ) + return uploaded_file class BaseDocumentExtractNode(IDocumentExtractNode): - def execute(self, document, **kwargs): + def execute(self, document, chat_id, **kwargs): get_buffer = FileBufferHandle().get_buffer self.context['document_list'] = document @@ -19,6 +46,20 @@ class BaseDocumentExtractNode(IDocumentExtractNode): if document is None or not isinstance(document, list): return NodeResult({'content': content}, {}) + application = self.workflow_manage.work_flow_post_handler.chat_info.application + + # doc文件中的图片保存 + def save_image(image_list): + for image in image_list: + meta = { + 'debug': False if application.id else True, + 'chat_id': chat_id, + 'application_id': str(application.id) if application.id else None, + 'file_id': str(image.id) + } + file = bytes_to_uploaded_file(image.image, image.image_name) + FileSerializer(data={'file': file, 'meta': meta}).upload() + for doc in document: file = QuerySet(File).filter(id=doc['file_id']).first() buffer = io.BytesIO(file.get_byte().tobytes()) @@ -28,7 +69,7 @@ class BaseDocumentExtractNode(IDocumentExtractNode): if split_handle.support(buffer, get_buffer): # 回到文件头 buffer.seek(0) - file_content = split_handle.get_content(buffer) + file_content = split_handle.get_content(buffer, save_image) content.append('## ' + doc['name'] + '\n' + file_content) break diff --git a/apps/common/handle/base_split_handle.py b/apps/common/handle/base_split_handle.py index ea48e6868..bedaad5e1 100644 --- a/apps/common/handle/base_split_handle.py +++ b/apps/common/handle/base_split_handle.py @@ -20,5 +20,5 @@ class BaseSplitHandle(ABC): pass @abstractmethod - def get_content(self, file): + def get_content(self, file, save_image): pass diff --git a/apps/common/handle/impl/doc_split_handle.py b/apps/common/handle/impl/doc_split_handle.py index 6ac6f43f9..d377abeba 100644 --- a/apps/common/handle/impl/doc_split_handle.py +++ b/apps/common/handle/impl/doc_split_handle.py @@ -190,12 +190,16 @@ class DocSplitHandle(BaseSplitHandle): return True return False - def get_content(self, file): + def get_content(self, file, save_image): try: image_list = [] buffer = file.read() doc = Document(io.BytesIO(buffer)) - return self.to_md(doc, image_list, get_image_id_func()) + content = self.to_md(doc, image_list, get_image_id_func()) + if len(image_list) > 0: + content = content.replace('/api/image/', '/api/file/') + save_image(image_list) + return content except BaseException as e: traceback.print_exception(e) return f'{e}' \ No newline at end of file diff --git a/apps/common/handle/impl/html_split_handle.py b/apps/common/handle/impl/html_split_handle.py index bb69e0af0..90e59ebcb 100644 --- a/apps/common/handle/impl/html_split_handle.py +++ b/apps/common/handle/impl/html_split_handle.py @@ -61,7 +61,7 @@ class HTMLSplitHandle(BaseSplitHandle): 'content': split_model.parse(content) } - def get_content(self, file): + def get_content(self, file, save_image): buffer = file.read() try: diff --git a/apps/common/handle/impl/pdf_split_handle.py b/apps/common/handle/impl/pdf_split_handle.py index 21d243058..8de0129e1 100644 --- a/apps/common/handle/impl/pdf_split_handle.py +++ b/apps/common/handle/impl/pdf_split_handle.py @@ -309,7 +309,7 @@ class PdfSplitHandle(BaseSplitHandle): return True return False - def get_content(self, file): + def get_content(self, file, save_image): with tempfile.NamedTemporaryFile(delete=False) as temp_file: # 将上传的文件保存到临时文件中 temp_file.write(file.read()) diff --git a/apps/common/handle/impl/text_split_handle.py b/apps/common/handle/impl/text_split_handle.py index 1ae22f95f..9d91d874d 100644 --- a/apps/common/handle/impl/text_split_handle.py +++ b/apps/common/handle/impl/text_split_handle.py @@ -51,7 +51,7 @@ class TextSplitHandle(BaseSplitHandle): 'content': split_model.parse(content) } - def get_content(self, file): + def get_content(self, file, save_image): buffer = file.read() try: return buffer.decode(detect(buffer)['encoding']) diff --git a/apps/dataset/serializers/file_serializers.py b/apps/dataset/serializers/file_serializers.py index 1bbd9feb2..2512f13e6 100644 --- a/apps/dataset/serializers/file_serializers.py +++ b/apps/dataset/serializers/file_serializers.py @@ -61,8 +61,9 @@ class FileSerializer(serializers.Serializer): def upload(self, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - file_id = uuid.uuid1() - file = File(id=file_id, file_name=self.data.get('file').name, meta=self.data.get('meta')) + meta = self.data.get('meta') + file_id = meta.get('file_id', uuid.uuid1()) + file = File(id=file_id, file_name=self.data.get('file').name, meta=meta) file.save(self.data.get('file').read()) return f'/api/file/{file_id}'