fix: 修复文档提取doc图片没有保存和展示的问题

This commit is contained in:
CaptainB 2024-11-28 14:34:51 +08:00 committed by 刘瑞斌
parent be1ddeb252
commit f638abdea2
8 changed files with 57 additions and 11 deletions

View File

@ -23,5 +23,5 @@ class IDocumentExtractNode(INode):
self.node_params_serializer.data.get('document_list')[1:])
return self.execute(document=res, **self.flow_params_serializer.data)
def execute(self, document, chat_id, **kwargs) -> NodeResult:
    """Run the document-extract step; concrete nodes override this method.

    Args:
        document: Value resolved from the node's ``document_list`` reference.
        chat_id: Identifier of the current chat, supplied through the flow
            parameters that the caller unpacks into this call.
        **kwargs: Any remaining flow parameters.

    Returns:
        The node result. This base implementation does nothing and
        returns ``None``.
    """
    pass

View File

@ -1,16 +1,43 @@
# coding=utf-8
import io
import mimetypes
from django.core.files.uploadedfile import InMemoryUploadedFile
from django.db.models import QuerySet
from application.flow.i_step_node import NodeResult
from application.flow.step_node.document_extract_node.i_document_extract_node import IDocumentExtractNode
from application.models import Chat
from dataset.models import File
from dataset.serializers.document_serializers import split_handles, parse_table_handle_list, FileBufferHandle
from dataset.serializers.file_serializers import FileSerializer
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
    """Wrap raw bytes in an ``InMemoryUploadedFile``.

    Args:
        file_bytes: The file content as a bytes object.
        file_name: Name given to the upload; its extension is used to guess
            the content type.

    Returns:
        An ``InMemoryUploadedFile`` backed by an in-memory byte stream, with
        ``size`` set to ``len(file_bytes)`` and no field name or charset.
    """
    guessed_type = mimetypes.guess_type(file_name)[0]
    # Fall back to the generic binary type when the name is not recognised.
    mime_type = guessed_type if guessed_type is not None else "application/octet-stream"
    return InMemoryUploadedFile(
        file=io.BytesIO(file_bytes),
        field_name=None,
        name=file_name,
        content_type=mime_type,
        size=len(file_bytes),
        charset=None,
    )
class BaseDocumentExtractNode(IDocumentExtractNode):
def execute(self, document, **kwargs):
def execute(self, document, chat_id, **kwargs):
get_buffer = FileBufferHandle().get_buffer
self.context['document_list'] = document
@ -19,6 +46,20 @@ class BaseDocumentExtractNode(IDocumentExtractNode):
if document is None or not isinstance(document, list):
return NodeResult({'content': content}, {})
application = self.workflow_manage.work_flow_post_handler.chat_info.application
# doc文件中的图片保存
def save_image(image_list):
for image in image_list:
meta = {
'debug': False if application.id else True,
'chat_id': chat_id,
'application_id': str(application.id) if application.id else None,
'file_id': str(image.id)
}
file = bytes_to_uploaded_file(image.image, image.image_name)
FileSerializer(data={'file': file, 'meta': meta}).upload()
for doc in document:
file = QuerySet(File).filter(id=doc['file_id']).first()
buffer = io.BytesIO(file.get_byte().tobytes())
@ -28,7 +69,7 @@ class BaseDocumentExtractNode(IDocumentExtractNode):
if split_handle.support(buffer, get_buffer):
# 回到文件头
buffer.seek(0)
file_content = split_handle.get_content(buffer)
file_content = split_handle.get_content(buffer, save_image)
content.append('## ' + doc['name'] + '\n' + file_content)
break

View File

@ -20,5 +20,5 @@ class BaseSplitHandle(ABC):
pass
@abstractmethod
def get_content(self, file, save_image):
    """Return the textual content extracted from ``file``.

    Args:
        file: Readable binary stream holding the document.
        save_image: Callback that receives the list of images found in the
            document so the caller can persist them. Handlers that extract
            no images may ignore it.

    Returns:
        The extracted content as a string (implemented by subclasses).
    """
    pass

View File

@ -190,12 +190,16 @@ class DocSplitHandle(BaseSplitHandle):
return True
return False
def get_content(self, file, save_image):
    """Convert a Word document to Markdown and hand its images to the caller.

    Args:
        file: Readable binary stream holding the document.
        save_image: Callback invoked with the list of images collected while
            converting, so the caller can persist them.

    Returns:
        The Markdown text. If conversion or image saving fails, the error is
        printed and its message is returned in place of the content
        (best-effort behaviour kept from the original).
    """
    try:
        image_list = []
        buffer = file.read()
        doc = Document(io.BytesIO(buffer))
        content = self.to_md(doc, image_list, get_image_id_func())
        if image_list:
            # The images are persisted through the file API by the callback,
            # so the generated links must point at /api/file/ instead.
            content = content.replace('/api/image/', '/api/file/')
            save_image(image_list)
        return content
    except Exception as e:
        # Only ordinary errors are downgraded to a message; KeyboardInterrupt
        # and SystemExit are allowed to propagate.
        traceback.print_exception(e)
        return f'{e}'

View File

@ -61,7 +61,7 @@ class HTMLSplitHandle(BaseSplitHandle):
'content': split_model.parse(content)
}
def get_content(self, file):
def get_content(self, file, save_image):
buffer = file.read()
try:

View File

@ -309,7 +309,7 @@ class PdfSplitHandle(BaseSplitHandle):
return True
return False
def get_content(self, file):
def get_content(self, file, save_image):
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
# 将上传的文件保存到临时文件中
temp_file.write(file.read())

View File

@ -51,7 +51,7 @@ class TextSplitHandle(BaseSplitHandle):
'content': split_model.parse(content)
}
def get_content(self, file):
def get_content(self, file, save_image):
buffer = file.read()
try:
return buffer.decode(detect(buffer)['encoding'])

View File

@ -61,8 +61,9 @@ class FileSerializer(serializers.Serializer):
def upload(self, with_valid=True):
    """Persist the uploaded file and return the URL it is served from.

    Args:
        with_valid: When true, validate the serializer first and raise on
            invalid data.

    Returns:
        The string ``/api/file/<file_id>``.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    uploaded = self.data.get('file')
    meta = self.data.get('meta')
    # A caller may pin the id through meta['file_id'] so that links generated
    # before the upload still resolve. Tolerate a missing meta and only
    # generate a fresh id when none was supplied.
    preset_id = meta.get('file_id') if meta else None
    file_id = preset_id if preset_id is not None else uuid.uuid1()
    file = File(id=file_id, file_name=uploaded.name, meta=meta)
    file.save(uploaded.read())
    return f'/api/file/{file_id}'