diff --git a/apps/common/handle/impl/csv_split_handle.py b/apps/common/handle/impl/csv_split_handle.py
new file mode 100644
index 000000000..d4691a3d6
--- /dev/null
+++ b/apps/common/handle/impl/csv_split_handle.py
@@ -0,0 +1,70 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: csv_split_handle.py
+    @date:2024/5/21 14:59
+    @desc:
+"""
+import csv
+import io
+from typing import List
+
+from charset_normalizer import detect
+
+from common.handle.base_split_handle import BaseSplitHandle
+
+
+def post_cell(cell_value):
+    # Escape newlines and pipes so the value stays inside one markdown table cell
+    return cell_value.replace('\n', '<br>').replace('|', '&#124;')
+
+
+def row_to_md(row):
+    return '| ' + ' | '.join(
+        [post_cell(cell) if cell is not None else '' for cell in row]) + ' |\n'
+
+
+class CsvSplitHandle(BaseSplitHandle):
+    def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+        buffer = get_buffer(file)
+        paragraphs = []
+        result = {'name': file.name, 'content': paragraphs}
+        try:
+            reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding']))
+            try:
+                title_row_list = next(reader)
+                title_md_content = row_to_md(title_row_list)
+                title_md_content += '| ' + ' | '.join(
+                    ['---' if cell is not None else '' for cell in title_row_list]) + ' |\n'
+            except Exception:
+                return result
+            if len(title_row_list) == 0:
+                return result
+            result_item_content = ''
+            for row in reader:
+                next_md_content = row_to_md(row)
+                if len(result_item_content) == 0:
+                    result_item_content = title_md_content + next_md_content
+                elif len(result_item_content) + len(next_md_content) < limit:
+                    result_item_content += next_md_content
+                else:
+                    # Flush the chunk and re-seed with the header plus the current
+                    # row, so the row that overflows the limit is not dropped
+                    paragraphs.append({'content': result_item_content, 'title': ''})
+                    result_item_content = title_md_content + next_md_content
+            if len(result_item_content) > 0:
+                paragraphs.append({'content': result_item_content, 'title': ''})
+            return result
+        except Exception:
+            return result
+
+    def get_content(self, file, save_image):
+        pass
+
+    def support(self, file, get_buffer):
+        file_name: str = file.name.lower()
+        return file_name.endswith(".csv")
diff --git a/apps/common/handle/impl/qa/zip_parse_qa_handle.py b/apps/common/handle/impl/qa/zip_parse_qa_handle.py
new file mode 100644
index 000000000..cb9418723
--- /dev/null
+++ b/apps/common/handle/impl/qa/zip_parse_qa_handle.py
@@ -0,0 +1,162 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: zip_parse_qa_handle.py
+    @date:2024/3/27 18:19
+    @desc:
+"""
+import io
+import os
+import re
+import uuid
+import zipfile
+from typing import List
+from urllib.parse import urljoin
+
+from django.db.models import QuerySet
+
+from common.handle.base_parse_qa_handle import BaseParseQAHandle
+from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
+from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
+from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
+from common.util.common import parse_md_image
+from dataset.models import Image
+
+
+class FileBufferHandle:
+    buffer = None
+
+    def get_buffer(self, file):
+        if self.buffer is None:
+            self.buffer = file.read()
+        return self.buffer
+
+
+split_handles = [XlsParseQAHandle(), XlsxParseQAHandle(), CsvParseQAHandle()]
+
+
+def save_inner_image(image_list):
+    """
+    Persist the images extracted by the inner file handles
+    @param image_list:
+    @return:
+    """
+    if image_list is not None and len(image_list) > 0:
+        QuerySet(Image).bulk_create(image_list)
+
+
+def file_to_paragraph(file):
+    """
+    Convert a file into a paragraph list
+    @param file: file
+    @return: {
+        name: file name
+        paragraphs: paragraph list
+    }
+    """
+    get_buffer = FileBufferHandle().get_buffer
+    for split_handle in split_handles:
+        if split_handle.support(file, get_buffer):
+            return split_handle.handle(file, get_buffer, save_inner_image)
+    raise Exception("不支持的文件格式")
+
+
+def is_valid_uuid(uuid_str: str):
+    """
+    Check whether a string is a valid UUID
+    @param uuid_str: string to check
+    @return: bool
+    """
+    try:
+        uuid.UUID(uuid_str)
+    except ValueError:
+        return False
+    return True
+
+
+def get_image_list(result_list: list, zip_files: List[str]):
+    """
+    Collect the image files referenced from the parsed paragraphs
+    @param result_list:
+    @param zip_files:
+    @return:
+    """
+    image_file_list = []
+    for result in result_list:
+        for p in result.get('paragraphs', []):
+            content: str = p.get('content', '')
+            image_list = parse_md_image(content)
+            for image in image_list:
+                search = re.search(r"\(.*\)", image)
+                if search:
+                    new_image_id = str(uuid.uuid1())
+                    source_image_path = search.group().replace('(', '').replace(')', '')
+                    image_path = urljoin(result.get('name'), '.' + source_image_path if source_image_path.startswith(
+                        '/') else source_image_path)
+                    if image_path not in zip_files:
+                        continue
+                    if image_path.startswith('api/file/') or image_path.startswith('api/image/'):
+                        image_id = image_path.replace('api/file/', '').replace('api/image/', '')
+                        if is_valid_uuid(image_id):
+                            image_file_list.append({'source_file': image_path,
+                                                    'image_id': image_id})
+                        else:
+                            image_file_list.append({'source_file': image_path,
+                                                    'image_id': new_image_id})
+                            content = content.replace(source_image_path, f'/api/image/{new_image_id}')
+                            p['content'] = content
+                    else:
+                        image_file_list.append({'source_file': image_path,
+                                                'image_id': new_image_id})
+                        content = content.replace(source_image_path, f'/api/image/{new_image_id}')
+                        p['content'] = content
+
+    return image_file_list
+
+
+def filter_image_file(result_list: list, image_list):
+    image_source_file_list = [image.get('source_file') for image in image_list]
+    return [r for r in result_list if r.get('name', '') not in image_source_file_list]
+
+
+class ZipParseQAHandle(BaseParseQAHandle):
+
+    def handle(self, file, get_buffer, save_image):
+        buffer = get_buffer(file)
+        bytes_io = io.BytesIO(buffer)
+        result = []
+        # Open the zip archive
+        with zipfile.ZipFile(bytes_io, 'r') as zip_ref:
+            # List the file names inside the archive
+            files = zip_ref.namelist()
+            # Read and parse each file in the archive
+            for entry in files:
+                if entry.endswith('/'):
+                    continue
+                with zip_ref.open(entry) as f:
+                    # Hand the inner file to the matching QA handle
+                    try:
+                        value = file_to_paragraph(f)
+                        if isinstance(value, list):
+                            result = [*result, *value]
+                        else:
+                            result.append(value)
+                    except Exception:
+                        pass
+            image_list = get_image_list(result, files)
+            result = filter_image_file(result, image_list)
+            image_mode_list = []
+            for image in image_list:
+                with zip_ref.open(image.get('source_file')) as f:
+                    i = Image(id=image.get('image_id'), image=f.read(),
+                              image_name=os.path.basename(image.get('source_file')))
+                    image_mode_list.append(i)
+            save_image(image_mode_list)
+        return result
+
+    def support(self, file, get_buffer):
+        file_name: str = file.name.lower()
+        return file_name.endswith(".zip")
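For orientation, here is a minimal, hypothetical driver for the new split handle outside the upload pipeline; the in-memory file stub and sample bytes are illustrative assumptions, not part of this patch.

# Hypothetical smoke test for CsvSplitHandle (stub objects, not patch code).
from common.handle.impl.csv_split_handle import CsvSplitHandle


class StubFile:
    """Mimics the minimal surface the handles rely on: a name and read()."""

    def __init__(self, name, data: bytes):
        self.name = name
        self._data = data

    def read(self):
        return self._data


f = StubFile('demo.csv', 'col_a,col_b\n1,a|b\n2,plain\n'.encode('utf-8'))
handle = CsvSplitHandle()
assert handle.support(f, lambda file: file.read())
result = handle.handle(f, pattern_list=[], with_filter=False, limit=4096,
                       get_buffer=lambda file: file.read(), save_image=lambda images: None)
# result == {'name': 'demo.csv', 'content': [{'content':
#     '| col_a | col_b |\n| --- | --- |\n| 1 | a&#124;b |\n| 2 | plain |\n', 'title': ''}]}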
diff --git a/apps/common/handle/impl/xls_split_handle.py b/apps/common/handle/impl/xls_split_handle.py
new file mode 100644
index 000000000..332e5b56d
--- /dev/null
+++ b/apps/common/handle/impl/xls_split_handle.py
@@ -0,0 +1,80 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: xls_split_handle.py
+    @date:2024/5/21 14:59
+    @desc:
+"""
+from typing import List
+
+import xlrd
+
+from common.handle.base_split_handle import BaseSplitHandle
+
+
+def post_cell(cell_value):
+    # Escape newlines and pipes so the value stays inside one markdown table cell
+    return cell_value.replace('\n', '<br>').replace('|', '&#124;')
+
+
+def row_to_md(row):
+    return '| ' + ' | '.join(
+        [post_cell(str(cell)) if cell is not None else '' for cell in row]) + ' |\n'
+
+
+def handle_sheet(file_name, sheet, limit: int):
+    rows = iter([sheet.row_values(i) for i in range(sheet.nrows)])
+    paragraphs = []
+    result = {'name': file_name, 'content': paragraphs}
+    try:
+        title_row_list = next(rows)
+        title_md_content = row_to_md(title_row_list)
+        title_md_content += '| ' + ' | '.join(
+            ['---' if cell is not None else '' for cell in title_row_list]) + ' |\n'
+    except Exception:
+        return result
+    if len(title_row_list) == 0:
+        return result
+    result_item_content = ''
+    for row in rows:
+        next_md_content = row_to_md(row)
+        if len(result_item_content) == 0:
+            result_item_content = title_md_content + next_md_content
+        elif len(result_item_content) + len(next_md_content) < limit:
+            result_item_content += next_md_content
+        else:
+            # Flush the chunk and re-seed with the header plus the current row,
+            # so the row that overflows the limit is not dropped
+            paragraphs.append({'content': result_item_content, 'title': ''})
+            result_item_content = title_md_content + next_md_content
+    if len(result_item_content) > 0:
+        paragraphs.append({'content': result_item_content, 'title': ''})
+    return result
+
+
+class XlsSplitHandle(BaseSplitHandle):
+    def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+        buffer = get_buffer(file)
+        try:
+            workbook = xlrd.open_workbook(file_contents=buffer)
+            worksheets = workbook.sheets()
+            worksheets_size = len(worksheets)
+            # A single default sheet keeps the file name; otherwise use the sheet name
+            results = [handle_sheet(file.name if worksheets_size == 1 and sheet.name == 'Sheet1' else sheet.name,
+                                    sheet, limit)
+                       for sheet in worksheets]
+            return [r for r in results if r is not None]
+        except Exception:
+            return [{'name': file.name, 'content': []}]
+
+    def get_content(self, file, save_image):
+        pass
+
+    def support(self, file, get_buffer):
+        file_name: str = file.name.lower()
+        buffer = get_buffer(file)
+        return file_name.endswith(".xls") and xlrd.inspect_format(content=buffer)
diff --git a/apps/common/handle/impl/xlsx_split_handle.py b/apps/common/handle/impl/xlsx_split_handle.py
new file mode 100644
index 000000000..c8763b1fb
--- /dev/null
+++ b/apps/common/handle/impl/xlsx_split_handle.py
@@ -0,0 +1,92 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: xlsx_split_handle.py
+    @date:2024/5/21 14:59
+    @desc:
+"""
+import io
+from typing import List
+
+import openpyxl
+
+from common.handle.base_split_handle import BaseSplitHandle
+from common.handle.impl.tools import xlsx_embed_cells_images
+
+
+def post_cell(image_dict, cell_value):
+    # Cells whose value maps to an embedded image become markdown image links
+    image = image_dict.get(cell_value, None)
+    if image is not None:
+        return f'![](/api/image/{image.id})'
+    return cell_value.replace('\n', '<br>').replace('|', '&#124;')
+
+
+def row_to_md(row, image_dict):
+    return '| ' + ' | '.join(
+        [post_cell(image_dict, str(cell.value if cell.value is not None else '')) if cell is not None else '' for cell
+         in row]) + ' |\n'
+
+
+def handle_sheet(file_name, sheet, image_dict, limit: int):
+    rows = sheet.rows
+    paragraphs = []
+    result = {'name': file_name, 'content': paragraphs}
+    try:
+        title_row_list = next(rows)
+        title_md_content = row_to_md(title_row_list, image_dict)
+        title_md_content += '| ' + ' | '.join(
+            ['---' if cell is not None else '' for cell in title_row_list]) + ' |\n'
+    except Exception:
+        return result
+    if len(title_row_list) == 0:
+        return result
+    result_item_content = ''
+    for row in rows:
+        next_md_content = row_to_md(row, image_dict)
+        if len(result_item_content) == 0:
+            result_item_content = title_md_content + next_md_content
+        elif len(result_item_content) + len(next_md_content) < limit:
+            result_item_content += next_md_content
+        else:
+            # Flush the chunk and re-seed with the header plus the current row
+            paragraphs.append({'content': result_item_content, 'title': ''})
+            result_item_content = title_md_content + next_md_content
+    if len(result_item_content) > 0:
+        paragraphs.append({'content': result_item_content, 'title': ''})
+    return result
+
+
+class XlsxSplitHandle(BaseSplitHandle):
+    def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+        buffer = get_buffer(file)
+        try:
+            workbook = openpyxl.load_workbook(io.BytesIO(buffer))
+            try:
+                image_dict: dict = xlsx_embed_cells_images(io.BytesIO(buffer))
+                save_image([item for item in image_dict.values()])
+            except Exception:
+                image_dict = {}
+            worksheets = workbook.worksheets
+            worksheets_size = len(worksheets)
+            results = [handle_sheet(file.name if worksheets_size == 1 and sheet.title == 'Sheet1' else sheet.title,
+                                    sheet, image_dict, limit)
+                       for sheet in worksheets]
+            return [r for r in results if r is not None]
+        except Exception:
+            return [{'name': file.name, 'content': []}]
+
+    def get_content(self, file, save_image):
+        pass
+
+    def support(self, file, get_buffer):
+        file_name: str = file.name.lower()
+        return file_name.endswith(".xlsx")
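The escaping in post_cell/row_to_md is easiest to see on a concrete row; a tiny illustration with invented values:

# Illustration of the markdown-cell escaping shared by the csv/xls/xlsx handles.
from common.handle.impl.xls_split_handle import row_to_md

print(row_to_md(['a|b', 'line1\nline2', 3.0]))
# | a&#124;b | line1<br>line2 | 3.0 |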
diff --git a/apps/common/handle/impl/zip_split_handle.py b/apps/common/handle/impl/zip_split_handle.py
new file mode 100644
index 000000000..f85d8699d
--- /dev/null
+++ b/apps/common/handle/impl/zip_split_handle.py
@@ -0,0 +1,147 @@
+# coding=utf-8
+"""
+    @project: maxkb
+    @Author:虎
+    @file: zip_split_handle.py
+    @date:2024/3/27 18:19
+    @desc:
+"""
+import io
+import os
+import re
+import uuid
+import zipfile
+from typing import List
+from urllib.parse import urljoin
+
+from django.db.models import QuerySet
+
+from common.handle.base_split_handle import BaseSplitHandle
+from common.handle.impl.csv_split_handle import CsvSplitHandle
+from common.handle.impl.doc_split_handle import DocSplitHandle
+from common.handle.impl.html_split_handle import HTMLSplitHandle
+from common.handle.impl.pdf_split_handle import PdfSplitHandle
+from common.handle.impl.text_split_handle import TextSplitHandle
+from common.handle.impl.xls_split_handle import XlsSplitHandle
+from common.handle.impl.xlsx_split_handle import XlsxSplitHandle
+from common.util.common import parse_md_image
+from dataset.models import Image
+
+
+class FileBufferHandle:
+    buffer = None
+
+    def get_buffer(self, file):
+        if self.buffer is None:
+            self.buffer = file.read()
+        return self.buffer
+
+
+default_split_handle = TextSplitHandle()
+split_handles = [HTMLSplitHandle(), DocSplitHandle(), PdfSplitHandle(), XlsxSplitHandle(), XlsSplitHandle(),
+                 CsvSplitHandle(),
+                 default_split_handle]
+
+
+def save_inner_image(image_list):
+    if image_list is not None and len(image_list) > 0:
+        QuerySet(Image).bulk_create(image_list)
+
+
+def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
+    get_buffer = FileBufferHandle().get_buffer
+    for split_handle in split_handles:
+        if split_handle.support(file, get_buffer):
+            return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_inner_image)
+    raise Exception("不支持的文件格式")
+
+
+def is_valid_uuid(uuid_str: str):
+    try:
+        uuid.UUID(uuid_str)
+    except ValueError:
+        return False
+    return True
+
+
+def get_image_list(result_list: list, zip_files: List[str]):
+    image_file_list = []
+    for result in result_list:
+        for p in result.get('content', []):
+            content: str = p.get('content', '')
+            image_list = parse_md_image(content)
+            for image in image_list:
+                search = re.search(r"\(.*\)", image)
+                if search:
+                    new_image_id = str(uuid.uuid1())
+                    source_image_path = search.group().replace('(', '').replace(')', '')
+                    image_path = urljoin(result.get('name'), '.' + source_image_path if source_image_path.startswith(
+                        '/') else source_image_path)
+                    if image_path not in zip_files:
+                        continue
+                    if image_path.startswith('api/file/') or image_path.startswith('api/image/'):
+                        image_id = image_path.replace('api/file/', '').replace('api/image/', '')
+                        if is_valid_uuid(image_id):
+                            image_file_list.append({'source_file': image_path,
+                                                    'image_id': image_id})
+                        else:
+                            image_file_list.append({'source_file': image_path,
+                                                    'image_id': new_image_id})
+                            content = content.replace(source_image_path, f'/api/image/{new_image_id}')
+                            p['content'] = content
+                    else:
+                        image_file_list.append({'source_file': image_path,
+                                                'image_id': new_image_id})
+                        content = content.replace(source_image_path, f'/api/image/{new_image_id}')
+                        p['content'] = content
+
+    return image_file_list
+
+
+def filter_image_file(result_list: list, image_list):
+    image_source_file_list = [image.get('source_file') for image in image_list]
+    return [r for r in result_list if r.get('name', '') not in image_source_file_list]
+
+
+class ZipSplitHandle(BaseSplitHandle):
+    def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+        buffer = get_buffer(file)
+        bytes_io = io.BytesIO(buffer)
+        result = []
+        # Open the zip archive
+        with zipfile.ZipFile(bytes_io, 'r') as zip_ref:
+            # List the file names inside the archive
+            files = zip_ref.namelist()
+            # Read and parse each file in the archive
+            for entry in files:
+                if entry.endswith('/'):
+                    continue
+                with zip_ref.open(entry) as f:
+                    # Hand the inner file to the matching split handle
+                    try:
+                        value = file_to_paragraph(f, pattern_list, with_filter, limit)
+                        if isinstance(value, list):
+                            result = [*result, *value]
+                        else:
+                            result.append(value)
+                    except Exception:
+                        pass
+            image_list = get_image_list(result, files)
+            result = filter_image_file(result, image_list)
+            image_mode_list = []
+            for image in image_list:
+                with zip_ref.open(image.get('source_file')) as f:
+                    i = Image(id=image.get('image_id'), image=f.read(),
+                              image_name=os.path.basename(image.get('source_file')))
+                    image_mode_list.append(i)
+            save_image(image_mode_list)
+        return result
+
+    def support(self, file, get_buffer):
+        file_name: str = file.name.lower()
+        return file_name.endswith(".zip")
+
+    def get_content(self, file, save_image):
+        return ""
diff --git a/apps/common/util/common.py b/apps/common/util/common.py
index 2feb3f57d..c654862c3 100644
--- a/apps/common/util/common.py
+++ b/apps/common/util/common.py
@@ -9,6 +9,7 @@
 import hashlib
 import importlib
 import io
+import re
 import shutil
 import mimetypes
 from functools import reduce
@@ -109,6 +110,18 @@ def valid_license(model=None, count=None, message=None):
     return inner
 
 
+def parse_image(content: str):
+    # Match only markdown images that point at /api/image/ or /api/file/
+    matches = re.finditer(r"!\[.*?\]\(/api/(image|file)/.*?\)", content)
+    image_list = [match.group() for match in matches]
+    return image_list
+
+
+def parse_md_image(content: str):
+    # Match any markdown image, wherever it points
+    matches = re.finditer(r"!\[.*?\]\(.*?\)", content)
+    image_list = [match.group() for match in matches]
+    return image_list
+
+
 def bulk_create_in_batches(model, data, batch_size=1000):
     if len(data) == 0:
         return
@@ -139,6 +152,7 @@ def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
     )
     return uploaded_file
 
+
 def any_to_amr(any_path, amr_path):
     """
     Convert any audio format to an amr file
diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py
index 8f08a2693..1cb33f88f 100644
--- a/apps/dataset/serializers/common_serializers.py
+++ b/apps/dataset/serializers/common_serializers.py
@@ -7,7 +7,9 @@
     @desc:
 """
 import os
+import re
 import uuid
+import zipfile
 from typing import List
 
 from django.db.models import QuerySet
@@ -22,11 +24,46 @@
 from common.mixins.api_mixin import ApiMixin
 from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from common.util.fork import Fork
-from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet
+from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet, File, Image
 from setting.models_provider import get_model
 from smartdoc.conf import PROJECT_DIR
 
 
+def zip_dir(zip_path, output=None):
+    output = output or os.path.basename(zip_path) + '.zip'
+    zip_file = zipfile.ZipFile(output, 'w', zipfile.ZIP_DEFLATED)
+    for root, dirs, files in os.walk(zip_path):
+        relative_root = '' if root == zip_path else root.replace(zip_path, '') + os.sep
+        for filename in files:
+            zip_file.write(os.path.join(root, filename), relative_root + filename)
+    zip_file.close()
+
+
+def write_image(zip_path: str, image_list: List[str]):
+    for image in image_list:
+        search = re.search(r"\(.*\)", image)
+        if search:
+            text = search.group()
+            if text.startswith('(/api/file/'):
+                r = text.replace('(/api/file/', '').replace(')', '')
+                file = QuerySet(File).filter(id=r).first()
+                zip_inner_path = os.path.join('api', 'file', r)
+                file_path = os.path.join(zip_path, zip_inner_path)
+                if not os.path.exists(os.path.dirname(file_path)):
+                    os.makedirs(os.path.dirname(file_path))
+                with open(file_path, 'wb') as f:
+                    f.write(file.get_byte())
+            else:
+                r = text.replace('(/api/image/', '').replace(')', '')
+                image_model = QuerySet(Image).filter(id=r).first()
+                zip_inner_path = os.path.join('api', 'image', r)
+                file_path = os.path.join(zip_path, zip_inner_path)
+                if not os.path.exists(os.path.dirname(file_path)):
+                    os.makedirs(os.path.dirname(file_path))
+                with open(file_path, 'wb') as f:
+                    f.write(image_model.image)
+
+
 def update_document_char_length(document_id: str):
     update_execute(get_file_content(
         os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_char_length.sql')),
diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py
index 000737649..adb6d3d4f 100644
--- a/apps/dataset/serializers/dataset_serializers.py
+++ b/apps/dataset/serializers/dataset_serializers.py
@@ -6,16 +6,19 @@
     @date:2023/9/21 16:14
     @desc:
 """
+import io
 import logging
 import os.path
 import re
 import traceback
 import uuid
+import zipfile
 from functools import reduce
+from tempfile import TemporaryDirectory
 from typing import Dict, List
 from urllib.parse import urlparse
 
-from celery_once import AlreadyQueued, QueueOnce
+from celery_once import AlreadyQueued
 from django.contrib.postgres.fields import ArrayField
 from django.core import validators
 from django.db import transaction, models
@@ -31,15 +34,15 @@
 from common.db.sql_execute import select_list
 from common.event import ListenerManagement
 from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
-from common.util.common import post, flat_map, valid_license
+from common.util.common import post, flat_map, valid_license, parse_image
 from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from common.util.fork import ChildLink, Fork
 from common.util.split_model import get_split_model
-from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status, \
-    TaskType, State
+from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, TaskType, \
+    State, File, Image
 from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
-    get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id
+    get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id, write_image, zip_dir
 from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
 from dataset.task import sync_web_dataset, sync_replace_web_dataset
 from embedding.models import SearchMode
@@ -695,6 +698,33 @@ class DataSetSerializers(serializers.ModelSerializer):
             workbook.save(response)
             return response
 
+        def export_zip(self, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            document_list = QuerySet(Document).filter(dataset_id=self.data.get('id'))
+            paragraph_list = native_search(QuerySet(Paragraph).filter(dataset_id=self.data.get("id")), get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph_document_name.sql')))
+            problem_mapping_list = native_search(
+                QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get("id")), get_file_content(
+                    os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem_mapping.sql')),
+                with_table_name=True)
+            data_dict, document_dict = DocumentSerializers.Operate.merge_problem(paragraph_list, problem_mapping_list,
+                                                                                 document_list)
+            # Collect the /api/image and /api/file links referenced by each paragraph
+            res = [parse_image(paragraph.get('content')) for paragraph in paragraph_list]
+
+            workbook = DocumentSerializers.Operate.get_workbook(data_dict, document_dict)
+            response = HttpResponse(content_type='application/zip')
+            response['Content-Disposition'] = 'attachment; filename="archive.zip"'
+            zip_buffer = io.BytesIO()
+            with TemporaryDirectory() as tempdir:
+                dataset_file = os.path.join(tempdir, 'dataset.xlsx')
+                workbook.save(dataset_file)
+                for r in res:
+                    write_image(tempdir, r)
+                zip_dir(tempdir, zip_buffer)
+            response.write(zip_buffer.getvalue())
+            return response
+
         @staticmethod
         def merge_problem(paragraph_list: List[Dict], problem_mapping_list: List[Dict]):
             result = {}
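A standalone sketch of the archive layout export_zip builds, assuming a single referenced image; the ids and file contents are placeholders:

# Sketch of the archive layout (placeholder data, not patch code).
import io
import os
import zipfile
from tempfile import TemporaryDirectory

from dataset.serializers.common_serializers import zip_dir

zip_buffer = io.BytesIO()
with TemporaryDirectory() as tempdir:
    with open(os.path.join(tempdir, 'dataset.xlsx'), 'wb') as f:
        f.write(b'workbook bytes')  # stands in for workbook.save(dataset_file)
    image_dir = os.path.join(tempdir, 'api', 'image')
    os.makedirs(image_dir)
    with open(os.path.join(image_dir, '00000000-0000-0000-0000-000000000000'), 'wb') as f:
        f.write(b'png bytes')  # stands in for write_image(tempdir, r)
    zip_dir(tempdir, zip_buffer)
print(zipfile.ZipFile(zip_buffer).namelist())
# ['dataset.xlsx', 'api/image/00000000-0000-0000-0000-000000000000']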
diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index e37af025b..1e778fc3b 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -6,12 +6,14 @@
     @date:2023/9/22 13:43
     @desc:
 """
+import io
 import logging
 import os
 import re
 import traceback
 import uuid
 from functools import reduce
+from tempfile import TemporaryDirectory
 from typing import List, Dict
 
 import openpyxl
@@ -30,18 +32,23 @@
 from common.db.search import native_search, native_page_search
 from common.event import ListenerManagement
 from common.event.common import work_thread_pool
 from common.exception.app_exception import AppApiException
+from common.handle.impl.csv_split_handle import CsvSplitHandle
 from common.handle.impl.doc_split_handle import DocSplitHandle
 from common.handle.impl.html_split_handle import HTMLSplitHandle
 from common.handle.impl.pdf_split_handle import PdfSplitHandle
 from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
 from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
 from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
-from common.handle.impl.table.csv_parse_table_handle import CsvSplitHandle
-from common.handle.impl.table.xls_parse_table_handle import XlsSplitHandle
-from common.handle.impl.table.xlsx_parse_table_handle import XlsxSplitHandle
+from common.handle.impl.qa.zip_parse_qa_handle import ZipParseQAHandle
+from common.handle.impl.table.csv_parse_table_handle import CsvSplitHandle as CsvSplitTableHandle
+from common.handle.impl.table.xls_parse_table_handle import XlsSplitHandle as XlsSplitTableHandle
+from common.handle.impl.table.xlsx_parse_table_handle import XlsxSplitHandle as XlsxSplitTableHandle
 from common.handle.impl.text_split_handle import TextSplitHandle
+from common.handle.impl.xls_split_handle import XlsSplitHandle
+from common.handle.impl.xlsx_split_handle import XlsxSplitHandle
+from common.handle.impl.zip_split_handle import ZipSplitHandle
 from common.mixins.api_mixin import ApiMixin
-from common.util.common import post, flat_map, bulk_create_in_batches
+from common.util.common import post, flat_map, bulk_create_in_batches, parse_image
 from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from common.util.fork import Fork
@@ -49,7 +56,7 @@
 from common.util.split_model import get_split_model
 from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Image, \
     TaskType, State
 from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
-    get_embedding_model_id_by_dataset_id
+    get_embedding_model_id_by_dataset_id, write_image, zip_dir
 from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
 from dataset.task import sync_web_document, generate_related_by_document_id
 from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
@@ -57,8 +64,8 @@
     embedding_by_document_list
 from smartdoc.conf import PROJECT_DIR
 
-parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()]
-parse_table_handle_list = [CsvSplitHandle(), XlsSplitHandle(), XlsxSplitHandle()]
+parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()]
+parse_table_handle_list = [CsvSplitTableHandle(), XlsSplitTableHandle(), XlsxSplitTableHandle()]
 
 
 class FileBufferHandle:
@@ -563,6 +570,34 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
             workbook.save(response)
             return response
 
+        def export_zip(self, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            document = QuerySet(Document).filter(id=self.data.get("document_id")).first()
+            paragraph_list = native_search(QuerySet(Paragraph).filter(document_id=self.data.get("document_id")),
+                                           get_file_content(
+                                               os.path.join(PROJECT_DIR, "apps", "dataset", 'sql',
+                                                            'list_paragraph_document_name.sql')))
+            problem_mapping_list = native_search(
+                QuerySet(ProblemParagraphMapping).filter(document_id=self.data.get("document_id")), get_file_content(
+                    os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem_mapping.sql')),
+                with_table_name=True)
+            data_dict, document_dict = self.merge_problem(paragraph_list, problem_mapping_list, [document])
+            # Collect the /api/image and /api/file links referenced by each paragraph
+            res = [parse_image(paragraph.get('content')) for paragraph in paragraph_list]
+
+            workbook = DocumentSerializers.Operate.get_workbook(data_dict, document_dict)
+            response = HttpResponse(content_type='application/zip')
+            response['Content-Disposition'] = 'attachment; filename="archive.zip"'
+            zip_buffer = io.BytesIO()
+            with TemporaryDirectory() as tempdir:
+                dataset_file = os.path.join(tempdir, 'dataset.xlsx')
+                workbook.save(dataset_file)
+                for r in res:
+                    write_image(tempdir, r)
+                zip_dir(tempdir, zip_buffer)
+            response.write(zip_buffer.getvalue())
+            return response
+
         @staticmethod
         def get_workbook(data_dict, document_dict):
             # Create the workbook object
@@ -929,9 +964,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
 
         def parse(self):
             file_list = self.data.get("file")
-            return list(
-                map(lambda f: file_to_paragraph(f, self.data.get("patterns", None), self.data.get("with_filter", None),
-                                                self.data.get("limit", 4096)), file_list))
+            # Each file may now yield several documents (zip, multi-sheet xlsx), so flatten the per-file lists
+            return reduce(lambda x, y: [*x, *y],
+                          [file_to_paragraph(f, self.data.get("patterns", None), self.data.get("with_filter", None),
+                                             self.data.get("limit", 4096)) for f in file_list], [])
 
     class SplitPattern(ApiMixin, serializers.Serializer):
         @staticmethod
@@ -1109,20 +1144,33 @@ class FileBufferHandle:
 
 
 default_split_handle = TextSplitHandle()
-split_handles = [HTMLSplitHandle(), DocSplitHandle(), PdfSplitHandle(), default_split_handle]
+split_handles = [HTMLSplitHandle(), DocSplitHandle(), PdfSplitHandle(), XlsxSplitHandle(), XlsSplitHandle(),
+                 CsvSplitHandle(),
+                 ZipSplitHandle(),
+                 default_split_handle]
 
 
 def save_image(image_list):
     if image_list is not None and len(image_list) > 0:
-        QuerySet(Image).bulk_create(image_list)
+        # Skip ids that already exist so zip re-imports do not violate the primary-key constraint
+        exist_image_list = [str(i.get('id')) for i in
+                            QuerySet(Image).filter(id__in=[i.id for i in image_list]).values('id')]
+        save_image_list = [image for image in image_list if str(image.id) not in exist_image_list]
+        if len(save_image_list) > 0:
+            QuerySet(Image).bulk_create(save_image_list)
 
 
 def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
     get_buffer = FileBufferHandle().get_buffer
     for split_handle in split_handles:
         if split_handle.support(file, get_buffer):
-            return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
-    return default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
+            result = split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
+            # Normalize to a list: the zip/xlsx handles return several documents per file
+            if isinstance(result, list):
+                return result
+            return [result]
+    result = default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
+    if isinstance(result, list):
+        return result
+    return [result]
 
 
 def delete_problems_and_mappings(document_ids):
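The rewritten parse is plain list flattening; the reduce is equivalent to itertools.chain, as this small check (sample data invented) shows:

# Equivalence check for the flattening used in parse.
from functools import reduce
from itertools import chain

per_file_results = [[{'name': 'a.csv'}], [{'name': 'Sheet1'}, {'name': 'Sheet2'}], []]
flat = reduce(lambda x, y: [*x, *y], per_file_results, [])
assert flat == list(chain.from_iterable(per_file_results))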
diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py
index 9e5835318..f7240a667 100644
--- a/apps/dataset/urls.py
+++ b/apps/dataset/urls.py
@@ -9,6 +9,7 @@ urlpatterns = [
     path('dataset/qa', views.Dataset.CreateQADataset.as_view()),
     path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
     path('dataset/<str:dataset_id>/export', views.Dataset.Export.as_view(), name="export"),
+    path('dataset/<str:dataset_id>/export_zip', views.Dataset.ExportZip.as_view(), name="export_zip"),
     path('dataset/<str:dataset_id>/re_embedding', views.Dataset.Embedding.as_view(), name="dataset_key"),
     path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
     path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
@@ -35,6 +36,8 @@ urlpatterns = [
     path('dataset/<str:dataset_id>/document/migrate/<str:target_dataset_id>', views.Document.Migrate.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/export', views.Document.Export.as_view(),
         name="document_export"),
+    path('dataset/<str:dataset_id>/document/<str:document_id>/export_zip', views.Document.ExportZip.as_view(),
+        name="document_export_zip"),
     path('dataset/<str:dataset_id>/document/<str:document_id>/sync', views.Document.SyncWeb.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/refresh', views.Document.Refresh.as_view()),
    path('dataset/<str:dataset_id>/document/<str:document_id>/cancel_task', views.Document.CancelTask.as_view()),
diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py
index f16fbd45b..c831aad50 100644
--- a/apps/dataset/views/dataset.py
+++ b/apps/dataset/views/dataset.py
@@ -165,6 +165,19 @@ class Dataset(APIView):
         def get(self, request: Request, dataset_id: str):
             return DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).export_excel()
 
+    class ExportZip(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=['GET'], detail=False)
+        @swagger_auto_schema(operation_summary="导出知识库包含图片", operation_id="导出知识库包含图片",
+                             manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
+                             tags=["知识库"]
+                             )
+        @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                                        dynamic_tag=keywords.get('dataset_id')))
+        def get(self, request: Request, dataset_id: str):
+            return DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).export_zip()
+
     class Operate(APIView):
         authentication_classes = [TokenAuth]
 
@@ -221,7 +234,8 @@ class Dataset(APIView):
         def get(self, request: Request, current_page, page_size):
             d = DataSetSerializers.Query(
                 data={'name': request.query_params.get('name', None), 'desc': request.query_params.get("desc", None),
-                      'user_id': str(request.user.id), 'select_user_id': request.query_params.get('select_user_id', None)})
+                      'user_id': str(request.user.id),
+                      'select_user_id': request.query_params.get('select_user_id', None)})
             d.is_valid()
             return result.success(d.page(current_page, page_size))
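A hypothetical client call against the new routes; the host, API prefix, auth header and ids are assumptions, not defined by this patch:

# Hypothetical smoke test of the new endpoints (host/prefix/auth are assumed).
import requests

base = 'http://localhost:8080/api'
headers = {'AUTHORIZATION': '<user-or-application-token>'}  # assumed header name
dataset_id = '00000000-0000-0000-0000-000000000000'
document_id = '11111111-1111-1111-1111-111111111111'

resp = requests.get(f'{base}/dataset/{dataset_id}/export_zip', headers=headers)
open('dataset_archive.zip', 'wb').write(resp.content)

resp = requests.get(f'{base}/dataset/{dataset_id}/document/{document_id}/export_zip', headers=headers)
open('document_archive.zip', 'wb').write(resp.content)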
diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py
index 87e0b886a..2ed8fc5d7 100644
--- a/apps/dataset/views/document.py
+++ b/apps/dataset/views/document.py
@@ -314,6 +314,20 @@ class Document(APIView):
         def get(self, request: Request, dataset_id: str, document_id: str):
             return DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).export()
 
+    class ExportZip(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=['GET'], detail=False)
+        @swagger_auto_schema(operation_summary="导出Zip文档",
+                             operation_id="导出Zip文档",
+                             manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
+                             tags=["知识库/文档"])
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def get(self, request: Request, dataset_id: str, document_id: str):
+            return DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).export_zip()
+
     class Operate(APIView):
         authentication_classes = [TokenAuth]
diff --git a/ui/src/api/dataset.ts b/ui/src/api/dataset.ts
index d6ee4d84c..d2599041f 100644
--- a/ui/src/api/dataset.ts
+++ b/ui/src/api/dataset.ts
@@ -1,5 +1,5 @@
 import { Result } from '@/request/Result'
-import { get, post, del, put, exportExcel } from '@/request/index'
+import { get, post, del, put, exportExcel, exportFile } from '@/request/index'
 import type { datasetData } from '@/api/type/dataset'
 import type { pageRequest } from '@/api/type/common'
 import type { ApplicationFormType } from '@/api/type/application'
@@ -201,7 +201,20 @@ const exportDataset: (
 ) => Promise<any> = (dataset_name, dataset_id, loading) => {
   return exportExcel(dataset_name + '.xlsx', `dataset/${dataset_id}/export`, undefined, loading)
 }
-
+/**
+ * Export a knowledge base as a zip
+ * @param dataset_name knowledge base name
+ * @param dataset_id   knowledge base id
+ * @param loading      loading ref
+ * @returns
+ */
+const exportZipDataset: (
+  dataset_name: string,
+  dataset_id: string,
+  loading?: Ref<boolean>
+) => Promise<any> = (dataset_name, dataset_id, loading) => {
+  return exportFile(dataset_name + '.zip', `dataset/${dataset_id}/export_zip`, undefined, loading)
+}
 
 /**
  * Get the model list available to the current user
@@ -217,7 +230,6 @@ const getDatasetModel: (
   return get(`${prefix}/${dataset_id}/model`, loading)
 }
 
-
 export default {
   getDataset,
   getAllDataset,
@@ -232,5 +244,6 @@ export default {
   putReEmbeddingDataset,
   postQADataset,
   exportDataset,
-  getDatasetModel
+  getDatasetModel,
+  exportZipDataset
 }
diff --git a/ui/src/api/document.ts b/ui/src/api/document.ts
index 7bd42546c..6356a722e 100644
--- a/ui/src/api/document.ts
+++ b/ui/src/api/document.ts
@@ -1,5 +1,5 @@
 import { Result } from '@/request/Result'
-import { get, post, del, put, exportExcel } from '@/request/index'
+import { get, post, del, put, exportExcel, exportFile } from '@/request/index'
 import type { Ref } from 'vue'
 import type { KeyValue } from '@/api/type/common'
 import type { pageRequest } from '@/api/type/common'
@@ -316,7 +316,27 @@ const exportDocument: (
     loading
   )
 }
-
+/**
+ * Export a document as a zip
+ * @param document_name document name
+ * @param dataset_id    dataset id
+ * @param document_id   document id
+ * @param loading       loading ref
+ * @returns
+ */
+const exportDocumentZip: (
+  document_name: string,
+  dataset_id: string,
+  document_id: string,
+  loading?: Ref<boolean>
+) => Promise<any> = (document_name, dataset_id, document_id, loading) => {
+  return exportFile(
+    document_name + '.zip',
+    `${prefix}/${dataset_id}/document/${document_id}/export_zip`,
+    {},
+    loading
+  )
+}
 const batchGenerateRelated: (
   dataset_id: string,
   data: any,
@@ -362,5 +382,6 @@ export default {
   exportDocument,
   batchRefresh,
   batchGenerateRelated,
-  cancelTask
+  cancelTask,
+  exportDocumentZip
 }
diff --git a/ui/src/utils/utils.ts b/ui/src/utils/utils.ts
index a3c03ef28..a926a4074 100644
--- a/ui/src/utils/utils.ts
+++ b/ui/src/utils/utils.ts
@@ -38,9 +38,9 @@ export function fileType(name: string) {
   获得文件对应图片
 */
 const typeList: any = {
-  txt: ['txt', 'pdf', 'docx', 'md', 'html'],
+  txt: ['txt', 'pdf', 'docx', 'md', 'html', 'zip', 'xlsx', 'xls', 'csv'],
   table: ['xlsx', 'xls', 'csv'],
-  QA: ['xlsx', 'csv', 'xls']
+  QA: ['xlsx', 'csv', 'xls', 'zip']
 }
 
 export function getImgUrl(name: string) {
diff --git a/ui/src/views/dataset/component/UploadComponent.vue b/ui/src/views/dataset/component/UploadComponent.vue
index c6f0ad768..01303f2f0 100644
--- a/ui/src/views/dataset/component/UploadComponent.vue
+++ b/ui/src/views/dataset/component/UploadComponent.vue
@@ -43,7 +43,7 @@
           action="#"
           :auto-upload="false"
           :show-file-list="false"
-          accept=".xlsx, .xls, .csv"
+          accept=".xlsx, .xls, .csv, .zip"
           :limit="50"
           :on-exceed="onExceed"
           :on-change="fileHandleChange"
@@ -129,7 +129,7 @@
           action="#"
           :auto-upload="false"
           :show-file-list="false"
-          accept=".txt, .md, .log, .docx, .pdf, .html"
+          accept=".txt, .md, .log, .docx, .pdf, .html, .zip, .xlsx, .xls, .csv"
           :limit="50"
           :on-exceed="onExceed"
           :on-change="fileHandleChange"
@@ -143,7 +143,7 @@
           选择文件夹
-          支持格式:TXT、Markdown、PDF、DOCX、HTML
+          支持格式:TXT、Markdown、PDF、DOCX、HTML、ZIP、XLSX、XLS、CSV
diff --git a/ui/src/views/dataset/index.vue
index 6f062af3c..7214b532d 100644
--- a/ui/src/views/dataset/index.vue
+++ b/ui/src/views/dataset/index.vue
@@ -118,7 +118,10 @@
           设置
-          导出
+          导出Excel
+          导出ZIP
           删除
@@ ... @@
     MsgSuccess('导出成功')
   })
 }
+const export_zip_dataset = (item: any) => {
+  datasetApi.exportZipDataset(item.name, item.id, loading).then(() => {
+    MsgSuccess('导出成功')
+  })
+}
 
 function deleteDataset(row: any) {
   MsgConfirm(
diff --git a/ui/src/views/document/index.vue b/ui/src/views/document/index.vue
index 5a3cd2f2c..116949aef 100644
--- a/ui/src/views/document/index.vue
+++ b/ui/src/views/document/index.vue
@@ -305,7 +305,11 @@
-          导出
+          导出Excel
+          导出Zip
           删除
@@ ... @@
-          导出
+          导出Excel
+          导出Zip
           删除
@@ ... @@
     MsgSuccess('导出成功')
   })
 }
+const exportDocumentZip = (document: any) => {
+  documentApi
+    .exportDocumentZip(document.name, document.dataset_id, document.id, loading)
+    .then(() => {
+      MsgSuccess('导出成功')
+    })
+}
 
 function openDatasetDialog(row?: any) {
   const arr: string[] = []
   if (row) {