diff --git a/apps/common/handle/impl/csv_split_handle.py b/apps/common/handle/impl/csv_split_handle.py
new file mode 100644
index 000000000..d4691a3d6
--- /dev/null
+++ b/apps/common/handle/impl/csv_split_handle.py
@@ -0,0 +1,70 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: csv_split_handle.py
+ @date:2024/5/21 14:59
+ @desc:
+"""
+import csv
+import io
+from typing import List
+
+from charset_normalizer import detect
+
+from common.handle.base_split_handle import BaseSplitHandle
+
+
+def post_cell(cell_value):
+    return cell_value.replace('\n', '<br>').replace('|', '&#124;')
+
+
+def row_to_md(row):
+ return '| ' + ' | '.join(
+ [post_cell(cell) if cell is not None else '' for cell in row]) + ' |\n'
+
+
+class CsvSplitHandle(BaseSplitHandle):
+ def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+ buffer = get_buffer(file)
+ paragraphs = []
+ result = {'name': file.name, 'content': paragraphs}
+ try:
+ reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding']))
+ try:
+                title_row_list = next(reader)
+ title_md_content = row_to_md(title_row_list)
+ title_md_content += '| ' + ' | '.join(
+ ['---' if cell is not None else '' for cell in title_row_list]) + ' |\n'
+            except Exception:
+ return result
+ if len(title_row_list) == 0:
+ return result
+            result_item_content = ''
+            for row in reader:
+                next_md_content = row_to_md(row)
+                if len(result_item_content) == 0:
+                    result_item_content = title_md_content + next_md_content
+                elif len(result_item_content) + len(next_md_content) < limit:
+                    result_item_content += next_md_content
+                else:
+                    paragraphs.append({'content': result_item_content, 'title': ''})
+                    # start the next chunk with the header and the overflowing row so no row is lost
+                    result_item_content = title_md_content + next_md_content
+            if len(result_item_content) > 0:
+                paragraphs.append({'content': result_item_content, 'title': ''})
+            return result
+        except Exception:
+ return result
+
+ def get_content(self, file, save_image):
+ pass
+
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".csv"):
+ return True
+ return False
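
A minimal standalone sketch of the chunking rule used by `CsvSplitHandle` (and by the xls/xlsx handlers below): every chunk repeats the table header, and a new chunk starts once the accumulated Markdown would reach `limit`. All names here are illustrative, not part of the codebase:

```python
def rows_to_chunks(title_row, rows, limit):
    """Group Markdown table rows into chunks, repeating the header per chunk."""
    def to_md(row):
        return '| ' + ' | '.join(row) + ' |\n'

    header = to_md(title_row) + to_md(['---'] * len(title_row))
    chunks, current = [], ''
    for row in rows:
        line = to_md(row)
        if not current:
            current = header + line
        elif len(current) + len(line) < limit:
            current += line
        else:
            chunks.append(current)
            current = header + line  # the overflowing row opens the next chunk
    if current:
        chunks.append(current)
    return chunks

# limit=50 forces a split: rows 1-2 land in the first chunk, row 3 in the second.
print(rows_to_chunks(['id', 'name'], [['1', 'a'], ['2', 'b'], ['3', 'c']], limit=50))
```
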
diff --git a/apps/common/handle/impl/qa/zip_parse_qa_handle.py b/apps/common/handle/impl/qa/zip_parse_qa_handle.py
new file mode 100644
index 000000000..cb9418723
--- /dev/null
+++ b/apps/common/handle/impl/qa/zip_parse_qa_handle.py
@@ -0,0 +1,162 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: zip_parse_qa_handle.py
+ @date:2024/3/27 18:19
+ @desc:
+"""
+import io
+import os
+import re
+import uuid
+import zipfile
+from typing import List
+from urllib.parse import urljoin
+
+from django.db.models import QuerySet
+
+from common.handle.base_parse_qa_handle import BaseParseQAHandle
+from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
+from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
+from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
+from common.util.common import parse_md_image
+from dataset.models import Image
+
+
+class FileBufferHandle:
+ buffer = None
+
+ def get_buffer(self, file):
+ if self.buffer is None:
+ self.buffer = file.read()
+ return self.buffer
+
+
+split_handles = [XlsParseQAHandle(), XlsxParseQAHandle(), CsvParseQAHandle()]
+
+
+def save_inner_image(image_list):
+ """
+    Save images produced by the inner parse handlers
+ @param image_list:
+ @return:
+ """
+ if image_list is not None and len(image_list) > 0:
+ QuerySet(Image).bulk_create(image_list)
+
+
+def file_to_paragraph(file):
+ """
+    Convert a file into a list of paragraphs
+    @param file: the file to parse
+    @return: {
+    name: file name
+    paragraphs: list of paragraphs
+ }
+ """
+ get_buffer = FileBufferHandle().get_buffer
+ for split_handle in split_handles:
+ if split_handle.support(file, get_buffer):
+ return split_handle.handle(file, get_buffer, save_inner_image)
+ raise Exception("不支持的文件格式")
+
+
+def is_valid_uuid(uuid_str: str):
+ """
+    Check whether a string is a valid UUID
+    @param uuid_str: the string to check
+ @return: bool
+ """
+ try:
+ uuid.UUID(uuid_str)
+ except ValueError:
+ return False
+ return True
+
+
+def get_image_list(result_list: list, zip_files: List[str]):
+ """
+    Collect the image files referenced by the parsed results
+ @param result_list:
+ @param zip_files:
+ @return:
+ """
+ image_file_list = []
+ for result in result_list:
+ for p in result.get('paragraphs', []):
+ content: str = p.get('content', '')
+ image_list = parse_md_image(content)
+ for image in image_list:
+                search = re.search(r"\(.*\)", image)
+ if search:
+ new_image_id = str(uuid.uuid1())
+ source_image_path = search.group().replace('(', '').replace(')', '')
+ image_path = urljoin(result.get('name'), '.' + source_image_path if source_image_path.startswith(
+ '/') else source_image_path)
+                    if image_path not in zip_files:
+                        continue
+                    if image_path.startswith('api/file/') or image_path.startswith('api/image/'):
+                        image_id = image_path.replace('api/file/', '').replace('api/image/', '')
+                        if is_valid_uuid(image_id):
+                            image_file_list.append({'source_file': image_path, 'image_id': image_id})
+                            continue
+                    image_file_list.append({'source_file': image_path, 'image_id': new_image_id})
+                    content = content.replace(source_image_path, f'/api/image/{new_image_id}')
+                    p['content'] = content
+
+ return image_file_list
+
+
+def filter_image_file(result_list: list, image_list):
+ image_source_file_list = [image.get('source_file') for image in image_list]
+    return [r for r in result_list if r.get('name', '') not in image_source_file_list]
+
+
+class ZipParseQAHandle(BaseParseQAHandle):
+
+ def handle(self, file, get_buffer, save_image):
+ buffer = get_buffer(file)
+ bytes_io = io.BytesIO(buffer)
+ result = []
+        # Open the zip archive
+        with zipfile.ZipFile(bytes_io, 'r') as zip_ref:
+            # List the file names inside the archive
+            files = zip_ref.namelist()
+            # Parse every file in the archive
+            for entry in files:
+                if entry.endswith('/'):
+                    continue
+                with zip_ref.open(entry) as f:
+                    # Hand the content to a matching QA parse handler
+                    try:
+                        value = file_to_paragraph(f)
+ if isinstance(value, list):
+ result = [*result, *value]
+ else:
+ result.append(value)
+                    except Exception:
+                        # skip entries that no handler supports
+                        pass
+ image_list = get_image_list(result, files)
+ result = filter_image_file(result, image_list)
+ image_mode_list = []
+ for image in image_list:
+ with zip_ref.open(image.get('source_file')) as f:
+ i = Image(id=image.get('image_id'), image=f.read(),
+ image_name=os.path.basename(image.get('source_file')))
+ image_mode_list.append(i)
+ save_image(image_mode_list)
+ return result
+
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".zip") or file_name.endswith(".ZIP"):
+ return True
+ return False
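
The image lookup in `get_image_list` resolves each Markdown reference against the location of the file it was found in, so relative paths line up with the `zip_ref.namelist()` entries. A quick illustration of the `urljoin` normalization (sample paths are hypothetical):

```python
from urllib.parse import urljoin

# A relative reference resolves against the document's folder inside the zip:
print(urljoin('docs/faq.md', 'images/a.png'))     # docs/images/a.png
# Leading-slash references are first prefixed with '.' so they stay in the archive:
print(urljoin('docs/faq.md', './api/image/xyz'))  # docs/api/image/xyz
```
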
diff --git a/apps/common/handle/impl/xls_split_handle.py b/apps/common/handle/impl/xls_split_handle.py
new file mode 100644
index 000000000..332e5b56d
--- /dev/null
+++ b/apps/common/handle/impl/xls_split_handle.py
@@ -0,0 +1,80 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: xls_split_handle.py
+ @date:2024/5/21 14:59
+ @desc:
+"""
+from typing import List
+
+import xlrd
+
+from common.handle.base_split_handle import BaseSplitHandle
+
+
+def post_cell(cell_value):
+    return cell_value.replace('\n', '<br>').replace('|', '&#124;')
+
+
+def row_to_md(row):
+ return '| ' + ' | '.join(
+ [post_cell(str(cell)) if cell is not None else '' for cell in row]) + ' |\n'
+
+
+def handle_sheet(file_name, sheet, limit: int):
+ rows = iter([sheet.row_values(i) for i in range(sheet.nrows)])
+ paragraphs = []
+ result = {'name': file_name, 'content': paragraphs}
+ try:
+ title_row_list = next(rows)
+ title_md_content = row_to_md(title_row_list)
+ title_md_content += '| ' + ' | '.join(
+ ['---' if cell is not None else '' for cell in title_row_list]) + ' |\n'
+    except Exception:
+ return result
+ if len(title_row_list) == 0:
+ return result
+    result_item_content = ''
+    for row in rows:
+        next_md_content = row_to_md(row)
+        if len(result_item_content) == 0:
+            result_item_content = title_md_content + next_md_content
+        elif len(result_item_content) + len(next_md_content) < limit:
+            result_item_content += next_md_content
+        else:
+            paragraphs.append({'content': result_item_content, 'title': ''})
+            # start the next chunk with the header and the overflowing row so no row is lost
+            result_item_content = title_md_content + next_md_content
+    if len(result_item_content) > 0:
+        paragraphs.append({'content': result_item_content, 'title': ''})
+    return result
+
+
+class XlsSplitHandle(BaseSplitHandle):
+ def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+ buffer = get_buffer(file)
+ try:
+ workbook = xlrd.open_workbook(file_contents=buffer)
+ worksheets = workbook.sheets()
+ worksheets_size = len(worksheets)
+            return [result for result in
+                    [handle_sheet(file.name if worksheets_size == 1 and sheet.name == 'Sheet1' else sheet.name,
+                                  sheet, limit) for sheet in worksheets]
+                    if result is not None]
+        except Exception:
+ return [{'name': file.name, 'content': []}]
+
+ def get_content(self, file, save_image):
+ pass
+
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ buffer = get_buffer(file)
+ if file_name.endswith(".xls") and xlrd.inspect_format(content=buffer):
+ return True
+ return False
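
Note that `support` sniffs the actual bytes instead of trusting the extension alone, so a mis-named file falls through to another handler. The same check in isolation, assuming xlrd ≥ 2.0 (where `inspect_format` returns the detected format name or `None`):

```python
import xlrd

def looks_like_xls(file_name: str, buffer: bytes) -> bool:
    # Extension check plus content sniffing, mirroring XlsSplitHandle.support.
    return file_name.lower().endswith('.xls') and bool(xlrd.inspect_format(content=buffer))
```
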
diff --git a/apps/common/handle/impl/xlsx_split_handle.py b/apps/common/handle/impl/xlsx_split_handle.py
new file mode 100644
index 000000000..c8763b1fb
--- /dev/null
+++ b/apps/common/handle/impl/xlsx_split_handle.py
@@ -0,0 +1,92 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: xlsx_split_handle.py
+ @date:2024/5/21 14:59
+ @desc:
+"""
+import io
+from typing import List
+
+import openpyxl
+
+from common.handle.base_split_handle import BaseSplitHandle
+from common.handle.impl.tools import xlsx_embed_cells_images
+
+
+def post_cell(image_dict, cell_value):
+    image = image_dict.get(cell_value, None)
+    if image is not None:
+        return f'![](/api/image/{image.id})'
+    return cell_value.replace('\n', '<br>').replace('|', '&#124;')
+
+
+def row_to_md(row, image_dict):
+ return '| ' + ' | '.join(
+ [post_cell(image_dict, str(cell.value if cell.value is not None else '')) if cell is not None else '' for cell
+ in row]) + ' |\n'
+
+
+def handle_sheet(file_name, sheet, image_dict, limit: int):
+ rows = sheet.rows
+ paragraphs = []
+ result = {'name': file_name, 'content': paragraphs}
+ try:
+ title_row_list = next(rows)
+ title_md_content = row_to_md(title_row_list, image_dict)
+ title_md_content += '| ' + ' | '.join(
+ ['---' if cell is not None else '' for cell in title_row_list]) + ' |\n'
+    except Exception:
+ return result
+ if len(title_row_list) == 0:
+ return result
+    result_item_content = ''
+    for row in rows:
+        next_md_content = row_to_md(row, image_dict)
+        if len(result_item_content) == 0:
+            result_item_content = title_md_content + next_md_content
+        elif len(result_item_content) + len(next_md_content) < limit:
+            result_item_content += next_md_content
+        else:
+            paragraphs.append({'content': result_item_content, 'title': ''})
+            # start the next chunk with the header and the overflowing row so no row is lost
+            result_item_content = title_md_content + next_md_content
+    if len(result_item_content) > 0:
+        paragraphs.append({'content': result_item_content, 'title': ''})
+    return result
+
+
+class XlsxSplitHandle(BaseSplitHandle):
+ def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+ buffer = get_buffer(file)
+ try:
+ workbook = openpyxl.load_workbook(io.BytesIO(buffer))
+            try:
+                image_dict: dict = xlsx_embed_cells_images(io.BytesIO(buffer))
+                save_image(list(image_dict.values()))
+            except Exception:
+                image_dict = {}
+ worksheets = workbook.worksheets
+ worksheets_size = len(worksheets)
+            return [result for result in
+                    [handle_sheet(file.name if worksheets_size == 1 and sheet.title == 'Sheet1' else sheet.title,
+                                  sheet, image_dict, limit) for sheet in worksheets]
+                    if result is not None]
+        except Exception:
+ return [{'name': file.name, 'content': []}]
+
+ def get_content(self, file, save_image):
+ pass
+
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".xlsx"):
+ return True
+ return False
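
`post_cell` above has two paths: a cell that maps to an embedded image becomes a Markdown image link, and plain text is escaped so newlines and pipes cannot break the table. A standalone illustration, where `FakeImage` is a hypothetical stand-in for the dataset `Image` model:

```python
class FakeImage:
    def __init__(self, image_id):
        self.id = image_id

def post_cell(image_dict, cell_value):
    image = image_dict.get(cell_value, None)
    if image is not None:
        return f'![](/api/image/{image.id})'
    return cell_value.replace('\n', '<br>').replace('|', '&#124;')

print(post_cell({'k1': FakeImage('42')}, 'k1'))  # ![](/api/image/42)
print(post_cell({}, 'a|b\nc'))                   # a&#124;b<br>c
```
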
diff --git a/apps/common/handle/impl/zip_split_handle.py b/apps/common/handle/impl/zip_split_handle.py
new file mode 100644
index 000000000..f85d8699d
--- /dev/null
+++ b/apps/common/handle/impl/zip_split_handle.py
@@ -0,0 +1,147 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: zip_split_handle.py
+ @date:2024/3/27 18:19
+ @desc:
+"""
+import io
+import os
+import re
+import uuid
+import zipfile
+from typing import List
+from urllib.parse import urljoin
+
+from django.db.models import QuerySet
+
+from common.handle.base_split_handle import BaseSplitHandle
+from common.handle.impl.csv_split_handle import CsvSplitHandle
+from common.handle.impl.doc_split_handle import DocSplitHandle
+from common.handle.impl.html_split_handle import HTMLSplitHandle
+from common.handle.impl.pdf_split_handle import PdfSplitHandle
+from common.handle.impl.text_split_handle import TextSplitHandle
+from common.handle.impl.xls_split_handle import XlsSplitHandle
+from common.handle.impl.xlsx_split_handle import XlsxSplitHandle
+from common.util.common import parse_md_image
+from dataset.models import Image
+
+
+class FileBufferHandle:
+ buffer = None
+
+ def get_buffer(self, file):
+ if self.buffer is None:
+ self.buffer = file.read()
+ return self.buffer
+
+
+default_split_handle = TextSplitHandle()
+split_handles = [HTMLSplitHandle(), DocSplitHandle(), PdfSplitHandle(), XlsxSplitHandle(), XlsSplitHandle(),
+ CsvSplitHandle(),
+ default_split_handle]
+
+
+def save_inner_image(image_list):
+ if image_list is not None and len(image_list) > 0:
+ QuerySet(Image).bulk_create(image_list)
+
+
+def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
+ get_buffer = FileBufferHandle().get_buffer
+ for split_handle in split_handles:
+ if split_handle.support(file, get_buffer):
+ return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_inner_image)
+ raise Exception("不支持的文件格式")
+
+
+def is_valid_uuid(uuid_str: str):
+ try:
+ uuid.UUID(uuid_str)
+ except ValueError:
+ return False
+ return True
+
+
+def get_image_list(result_list: list, zip_files: List[str]):
+ image_file_list = []
+ for result in result_list:
+ for p in result.get('content', []):
+ content: str = p.get('content', '')
+ image_list = parse_md_image(content)
+ for image in image_list:
+                search = re.search(r"\(.*\)", image)
+ if search:
+ new_image_id = str(uuid.uuid1())
+ source_image_path = search.group().replace('(', '').replace(')', '')
+ image_path = urljoin(result.get('name'), '.' + source_image_path if source_image_path.startswith(
+ '/') else source_image_path)
+                    if image_path not in zip_files:
+                        continue
+                    if image_path.startswith('api/file/') or image_path.startswith('api/image/'):
+                        image_id = image_path.replace('api/file/', '').replace('api/image/', '')
+                        if is_valid_uuid(image_id):
+                            image_file_list.append({'source_file': image_path, 'image_id': image_id})
+                            continue
+                    image_file_list.append({'source_file': image_path, 'image_id': new_image_id})
+                    content = content.replace(source_image_path, f'/api/image/{new_image_id}')
+                    p['content'] = content
+
+ return image_file_list
+
+
+def filter_image_file(result_list: list, image_list):
+ image_source_file_list = [image.get('source_file') for image in image_list]
+    return [r for r in result_list if r.get('name', '') not in image_source_file_list]
+
+
+class ZipSplitHandle(BaseSplitHandle):
+ def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+ buffer = get_buffer(file)
+ bytes_io = io.BytesIO(buffer)
+ result = []
+        # Open the zip archive
+        with zipfile.ZipFile(bytes_io, 'r') as zip_ref:
+            # List the file names inside the archive
+            files = zip_ref.namelist()
+            # Parse every file in the archive
+            for entry in files:
+                if entry.endswith('/'):
+                    continue
+                with zip_ref.open(entry) as f:
+                    # Hand the content to a matching split handler
+                    try:
+                        value = file_to_paragraph(f, pattern_list, with_filter, limit)
+ if isinstance(value, list):
+ result = [*result, *value]
+ else:
+ result.append(value)
+                    except Exception:
+                        # skip entries that no handler supports
+                        pass
+ image_list = get_image_list(result, files)
+ result = filter_image_file(result, image_list)
+ image_mode_list = []
+ for image in image_list:
+ with zip_ref.open(image.get('source_file')) as f:
+ i = Image(id=image.get('image_id'), image=f.read(),
+ image_name=os.path.basename(image.get('source_file')))
+ image_mode_list.append(i)
+ save_image(image_mode_list)
+ return result
+
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".zip") or file_name.endswith(".ZIP"):
+ return True
+ return False
+
+ def get_content(self, file, save_image):
+ return ""
diff --git a/apps/common/util/common.py b/apps/common/util/common.py
index 2feb3f57d..c654862c3 100644
--- a/apps/common/util/common.py
+++ b/apps/common/util/common.py
@@ -9,6 +9,7 @@
import hashlib
import importlib
import io
+import re
import shutil
import mimetypes
from functools import reduce
@@ -109,6 +110,18 @@ def valid_license(model=None, count=None, message=None):
return inner
+def parse_image(content: str):
+    matches = re.finditer(r"!\[.*?\]\(/api/(image|file)/.*?\)", content)
+    image_list = [match.group() for match in matches]
+    return image_list
+
+
+def parse_md_image(content: str):
+    matches = re.finditer(r"!\[.*?\]\(.*?\)", content)
+    image_list = [match.group() for match in matches]
+    return image_list
+
+
def bulk_create_in_batches(model, data, batch_size=1000):
if len(data) == 0:
return
@@ -139,6 +152,7 @@ def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
)
return uploaded_file
+
def any_to_amr(any_path, amr_path):
"""
    Convert any audio format into an amr file
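
The two new helpers differ only in scope: `parse_image` matches images already served from `/api/image/` or `/api/file/` (used on export), while `parse_md_image` matches any Markdown image (used on zip import). A quick check:

```python
import re

def parse_image(content: str):
    return [m.group() for m in re.finditer(r"!\[.*?\]\(/api/(image|file)/.*?\)", content)]

def parse_md_image(content: str):
    return [m.group() for m in re.finditer(r"!\[.*?\]\(.*?\)", content)]

text = '![a](/api/image/1) and ![b](pics/b.png)'
print(parse_image(text))     # ['![a](/api/image/1)']
print(parse_md_image(text))  # ['![a](/api/image/1)', '![b](pics/b.png)']
```
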
diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py
index 8f08a2693..1cb33f88f 100644
--- a/apps/dataset/serializers/common_serializers.py
+++ b/apps/dataset/serializers/common_serializers.py
@@ -7,7 +7,9 @@
@desc:
"""
import os
+import re
import uuid
+import zipfile
from typing import List
from django.db.models import QuerySet
@@ -22,11 +24,46 @@ from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork
-from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet
+from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet, File, Image
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR
+def zip_dir(zip_path, output=None):
+    output = output or os.path.basename(zip_path) + '.zip'
+    with zipfile.ZipFile(output, 'w', zipfile.ZIP_DEFLATED) as zip_file:
+        for root, dirs, files in os.walk(zip_path):
+            relative_root = '' if root == zip_path else root.replace(zip_path, '') + os.sep
+            for filename in files:
+                zip_file.write(os.path.join(root, filename), relative_root + filename)
+
+
+def write_image(zip_path: str, image_list: List[str]):
+ for image in image_list:
+        search = re.search(r"\(.*\)", image)
+ if search:
+ text = search.group()
+ if text.startswith('(/api/file/'):
+ r = text.replace('(/api/file/', '').replace(')', '')
+ file = QuerySet(File).filter(id=r).first()
+ zip_inner_path = os.path.join('api', 'file', r)
+ file_path = os.path.join(zip_path, zip_inner_path)
+ if not os.path.exists(os.path.dirname(file_path)):
+ os.makedirs(os.path.dirname(file_path))
+                with open(file_path, 'wb') as f:
+ f.write(file.get_byte())
+ else:
+ r = text.replace('(/api/image/', '').replace(')', '')
+ image_model = QuerySet(Image).filter(id=r).first()
+ zip_inner_path = os.path.join('api', 'image', r)
+ file_path = os.path.join(zip_path, zip_inner_path)
+ if not os.path.exists(os.path.dirname(file_path)):
+ os.makedirs(os.path.dirname(file_path))
+ with open(file_path, 'wb') as f:
+ f.write(image_model.image)
+
+
def update_document_char_length(document_id: str):
update_execute(get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_char_length.sql')),
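
`zip_dir` accepts anything `zipfile.ZipFile` can open for writing, which is how the export endpoints below build the archive in memory. A minimal usage sketch (assumes a Django app context so the module imports):

```python
import io
import os
from tempfile import TemporaryDirectory

from dataset.serializers.common_serializers import zip_dir

buffer = io.BytesIO()
with TemporaryDirectory() as tempdir:
    with open(os.path.join(tempdir, 'dataset.xlsx'), 'wb') as f:
        f.write(b'placeholder workbook bytes')
    zip_dir(tempdir, buffer)  # ZipFile also accepts file-like targets
print(buffer.getvalue()[:2])  # b'PK' -- a valid zip signature
```
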
diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py
index 000737649..adb6d3d4f 100644
--- a/apps/dataset/serializers/dataset_serializers.py
+++ b/apps/dataset/serializers/dataset_serializers.py
@@ -6,16 +6,19 @@
@date:2023/9/21 16:14
@desc:
"""
+import io
import logging
import os.path
import re
import traceback
import uuid
+import zipfile
from functools import reduce
+from tempfile import TemporaryDirectory
from typing import Dict, List
from urllib.parse import urlparse
-from celery_once import AlreadyQueued, QueueOnce
+from celery_once import AlreadyQueued
from django.contrib.postgres.fields import ArrayField
from django.core import validators
from django.db import transaction, models
@@ -31,15 +34,15 @@ from common.db.sql_execute import select_list
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
-from common.util.common import post, flat_map, valid_license
+from common.util.common import post, flat_map, valid_license, parse_image
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
-from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status, \
- TaskType, State
+from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, TaskType, \
+ State, File, Image
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
- get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id
+ get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id, write_image, zip_dir
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from dataset.task import sync_web_dataset, sync_replace_web_dataset
from embedding.models import SearchMode
@@ -695,6 +698,33 @@ class DataSetSerializers(serializers.ModelSerializer):
workbook.save(response)
return response
+ def export_zip(self, with_valid=True):
+ if with_valid:
+ self.is_valid(raise_exception=True)
+ document_list = QuerySet(Document).filter(dataset_id=self.data.get('id'))
+ paragraph_list = native_search(QuerySet(Paragraph).filter(dataset_id=self.data.get("id")), get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph_document_name.sql')))
+ problem_mapping_list = native_search(
+ QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get("id")), get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem_mapping.sql')),
+ with_table_name=True)
+ data_dict, document_dict = DocumentSerializers.Operate.merge_problem(paragraph_list, problem_mapping_list,
+ document_list)
+ res = [parse_image(paragraph.get('content')) for paragraph in paragraph_list]
+
+ workbook = DocumentSerializers.Operate.get_workbook(data_dict, document_dict)
+ response = HttpResponse(content_type='application/zip')
+ response['Content-Disposition'] = 'attachment; filename="archive.zip"'
+ zip_buffer = io.BytesIO()
+ with TemporaryDirectory() as tempdir:
+ dataset_file = os.path.join(tempdir, 'dataset.xlsx')
+ workbook.save(dataset_file)
+ for r in res:
+ write_image(tempdir, r)
+ zip_dir(tempdir, zip_buffer)
+ response.write(zip_buffer.getvalue())
+ return response
+
@staticmethod
def merge_problem(paragraph_list: List[Dict], problem_mapping_list: List[Dict]):
result = {}
diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index e37af025b..1e778fc3b 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -6,12 +6,14 @@
@date:2023/9/22 13:43
@desc:
"""
+import io
import logging
import os
import re
import traceback
import uuid
from functools import reduce
+from tempfile import TemporaryDirectory
from typing import List, Dict
import openpyxl
@@ -30,18 +32,23 @@ from common.db.search import native_search, native_page_search
from common.event import ListenerManagement
from common.event.common import work_thread_pool
from common.exception.app_exception import AppApiException
+from common.handle.impl.csv_split_handle import CsvSplitHandle
from common.handle.impl.doc_split_handle import DocSplitHandle
from common.handle.impl.html_split_handle import HTMLSplitHandle
from common.handle.impl.pdf_split_handle import PdfSplitHandle
from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
-from common.handle.impl.table.csv_parse_table_handle import CsvSplitHandle
-from common.handle.impl.table.xls_parse_table_handle import XlsSplitHandle
-from common.handle.impl.table.xlsx_parse_table_handle import XlsxSplitHandle
+from common.handle.impl.qa.zip_parse_qa_handle import ZipParseQAHandle
+from common.handle.impl.table.csv_parse_table_handle import CsvSplitHandle as CsvSplitTableHandle
+from common.handle.impl.table.xls_parse_table_handle import XlsSplitHandle as XlsSplitTableHandle
+from common.handle.impl.table.xlsx_parse_table_handle import XlsxSplitHandle as XlsxSplitTableHandle
from common.handle.impl.text_split_handle import TextSplitHandle
+from common.handle.impl.xls_split_handle import XlsSplitHandle
+from common.handle.impl.xlsx_split_handle import XlsxSplitHandle
+from common.handle.impl.zip_split_handle import ZipSplitHandle
from common.mixins.api_mixin import ApiMixin
-from common.util.common import post, flat_map, bulk_create_in_batches
+from common.util.common import post, flat_map, bulk_create_in_batches, parse_image
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork
@@ -49,7 +56,7 @@ from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Image, \
TaskType, State
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
- get_embedding_model_id_by_dataset_id
+ get_embedding_model_id_by_dataset_id, write_image, zip_dir
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from dataset.task import sync_web_document, generate_related_by_document_id
from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
@@ -57,8 +64,8 @@ from embedding.task.embedding import embedding_by_document, delete_embedding_by_
embedding_by_document_list
from smartdoc.conf import PROJECT_DIR
-parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()]
-parse_table_handle_list = [CsvSplitHandle(), XlsSplitHandle(), XlsxSplitHandle()]
+parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()]
+parse_table_handle_list = [CsvSplitTableHandle(), XlsSplitTableHandle(), XlsxSplitTableHandle()]
class FileBufferHandle:
@@ -563,6 +570,34 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
workbook.save(response)
return response
+ def export_zip(self, with_valid=True):
+ if with_valid:
+ self.is_valid(raise_exception=True)
+ document = QuerySet(Document).filter(id=self.data.get("document_id")).first()
+ paragraph_list = native_search(QuerySet(Paragraph).filter(document_id=self.data.get("document_id")),
+ get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "dataset", 'sql',
+ 'list_paragraph_document_name.sql')))
+ problem_mapping_list = native_search(
+ QuerySet(ProblemParagraphMapping).filter(document_id=self.data.get("document_id")), get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem_mapping.sql')),
+ with_table_name=True)
+ data_dict, document_dict = self.merge_problem(paragraph_list, problem_mapping_list, [document])
+ res = [parse_image(paragraph.get('content')) for paragraph in paragraph_list]
+
+ workbook = DocumentSerializers.Operate.get_workbook(data_dict, document_dict)
+ response = HttpResponse(content_type='application/zip')
+ response['Content-Disposition'] = 'attachment; filename="archive.zip"'
+ zip_buffer = io.BytesIO()
+ with TemporaryDirectory() as tempdir:
+ dataset_file = os.path.join(tempdir, 'dataset.xlsx')
+ workbook.save(dataset_file)
+ for r in res:
+ write_image(tempdir, r)
+ zip_dir(tempdir, zip_buffer)
+ response.write(zip_buffer.getvalue())
+ return response
+
@staticmethod
def get_workbook(data_dict, document_dict):
+        # Create the workbook object
@@ -929,9 +964,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
def parse(self):
file_list = self.data.get("file")
- return list(
- map(lambda f: file_to_paragraph(f, self.data.get("patterns", None), self.data.get("with_filter", None),
- self.data.get("limit", 4096)), file_list))
+ return reduce(lambda x, y: [*x, *y],
+ [file_to_paragraph(f, self.data.get("patterns", None), self.data.get("with_filter", None),
+ self.data.get("limit", 4096)) for f in file_list], [])
class SplitPattern(ApiMixin, serializers.Serializer):
@staticmethod
@@ -1109,20 +1144,33 @@ class FileBufferHandle:
default_split_handle = TextSplitHandle()
-split_handles = [HTMLSplitHandle(), DocSplitHandle(), PdfSplitHandle(), default_split_handle]
+split_handles = [HTMLSplitHandle(), DocSplitHandle(), PdfSplitHandle(), XlsxSplitHandle(), XlsSplitHandle(),
+ CsvSplitHandle(),
+ ZipSplitHandle(),
+ default_split_handle]
def save_image(image_list):
if image_list is not None and len(image_list) > 0:
- QuerySet(Image).bulk_create(image_list)
+        exist_image_list = [str(i.get('id')) for i in
+                            QuerySet(Image).filter(id__in=[i.id for i in image_list]).values('id')]
+        save_image_list = [image for image in image_list if str(image.id) not in exist_image_list]
+ if len(save_image_list) > 0:
+ QuerySet(Image).bulk_create(save_image_list)
def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
get_buffer = FileBufferHandle().get_buffer
for split_handle in split_handles:
if split_handle.support(file, get_buffer):
- return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
- return default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
+            result = split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
+            return result if isinstance(result, list) else [result]
+    result = default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
+    return result if isinstance(result, list) else [result]
def delete_problems_and_mappings(document_ids):
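
The reworked `save_image` makes zip re-imports idempotent: ids already present in the `Image` table are filtered out before `bulk_create`, so importing the same archive twice no longer fails on duplicate keys. The filtering step in isolation (a sketch, not the Django query itself):

```python
def filter_new_images(image_list, existing_ids):
    # existing_ids: ids already stored in the Image table
    existing = {str(i) for i in existing_ids}
    return [image for image in image_list if str(image.id) not in existing]
```
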
diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py
index 9e5835318..f7240a667 100644
--- a/apps/dataset/urls.py
+++ b/apps/dataset/urls.py
@@ -9,6 +9,7 @@ urlpatterns = [
path('dataset/qa', views.Dataset.CreateQADataset.as_view()),
path('dataset/
-支持格式:TXT、Markdown、PDF、DOCX、HTML
+支持格式:TXT、Markdown、PDF、DOCX、HTML、ZIP、XLSX、XLS、CSV