feat: Knowledge base import supports zip, xls, xlsx, and csv formats, while knowledge base export supports zip format (#1869)

shaohuzhang1 2024-12-18 18:00:19 +08:00 committed by GitHub
parent 854d74bbe5
commit 832b0dbd63
18 changed files with 805 additions and 36 deletions

View File

@ -0,0 +1,70 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file: csv_split_handle.py
@date: 2024/5/21 14:59
@desc:
"""
import csv
import io
from typing import List
from charset_normalizer import detect
from common.handle.base_split_handle import BaseSplitHandle
def post_cell(cell_value):
return cell_value.replace('\n', '<br>').replace('|', '&#124;')
def row_to_md(row):
return '| ' + ' | '.join(
[post_cell(cell) if cell is not None else '' for cell in row]) + ' |\n'
class CsvSplitHandle(BaseSplitHandle):
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
buffer = get_buffer(file)
paragraphs = []
result = {'name': file.name, 'content': paragraphs}
try:
reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding']))
try:
                title_row_list = next(reader)
title_md_content = row_to_md(title_row_list)
title_md_content += '| ' + ' | '.join(
['---' if cell is not None else '' for cell in title_row_list]) + ' |\n'
            except Exception:
return result
if len(title_row_list) == 0:
return result
result_item_content = ''
for row in reader:
next_md_content = row_to_md(row)
next_md_content_len = len(next_md_content)
result_item_content_len = len(result_item_content)
                if result_item_content_len == 0:
result_item_content += title_md_content
result_item_content += next_md_content
else:
if result_item_content_len + next_md_content_len < limit:
result_item_content += next_md_content
                    else:
                        paragraphs.append({'content': result_item_content, 'title': ''})
                        # Start the next chunk with the header row plus the current row,
                        # so the row that triggered the flush is not dropped.
                        result_item_content = title_md_content + next_md_content
if len(result_item_content) > 0:
paragraphs.append({'content': result_item_content, 'title': ''})
return result
        except Exception:
return result
def get_content(self, file, save_image):
pass
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith(".csv"):
return True
return False
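A standalone sketch (not part of the commit) of the chunking rule CsvSplitHandle applies: every paragraph is a self-contained Markdown table that repeats the header row, and data rows accumulate until the next row would push the chunk past limit characters; with the fix above, the overflowing row opens the next chunk instead of being dropped.

import csv
import io

def row_to_md(row):
    return '| ' + ' | '.join(c.replace('\n', '<br>').replace('|', '&#124;') for c in row) + ' |\n'

data = 'question,answer\n' + '\n'.join(f'q{i},a{i}' for i in range(10))
reader = csv.reader(io.StringIO(data))
header = row_to_md(next(reader)) + '| --- | --- |\n'

limit = 60  # tiny limit to force several chunks
paragraphs, chunk = [], ''
for row in reader:
    row_md = row_to_md(row)
    if chunk == '':
        chunk = header + row_md
    elif len(chunk) + len(row_md) < limit:
        chunk += row_md
    else:
        paragraphs.append(chunk)
        chunk = header + row_md  # keep the row that overflowed
if chunk:
    paragraphs.append(chunk)
print(len(paragraphs), 'chunks')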

View File

@ -0,0 +1,162 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file: zip_parse_qa_handle.py
@date: 2024/3/27 18:19
@desc:
"""
import io
import os
import re
import uuid
import zipfile
from typing import List
from urllib.parse import urljoin
from django.db.models import QuerySet
from common.handle.base_parse_qa_handle import BaseParseQAHandle
from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
from common.util.common import parse_md_image
from dataset.models import Image
class FileBufferHandle:
buffer = None
def get_buffer(self, file):
if self.buffer is None:
self.buffer = file.read()
return self.buffer
split_handles = [XlsParseQAHandle(), XlsxParseQAHandle(), CsvParseQAHandle()]
def save_inner_image(image_list):
    """
    Persist the images extracted by the inner handles
    @param image_list:
    @return:
    """
    if image_list is not None and len(image_list) > 0:
        QuerySet(Image).bulk_create(image_list)
def file_to_paragraph(file):
    """
    Convert a file into a list of paragraphs
    @param file: file
    @return: {
        name: file name
        paragraphs: paragraph list
    }
    """
get_buffer = FileBufferHandle().get_buffer
for split_handle in split_handles:
if split_handle.support(file, get_buffer):
return split_handle.handle(file, get_buffer, save_inner_image)
raise Exception("不支持的文件格式")
def is_valid_uuid(uuid_str: str):
    """
    Check whether a string is a valid UUID
    @param uuid_str: string to check
    @return: bool
    """
try:
uuid.UUID(uuid_str)
except ValueError:
return False
return True
def get_image_list(result_list: list, zip_files: List[str]):
    """
    Collect the image files referenced by the parsed results
    @param result_list:
    @param zip_files:
    @return:
    """
image_file_list = []
for result in result_list:
for p in result.get('paragraphs', []):
content: str = p.get('content', '')
image_list = parse_md_image(content)
for image in image_list:
                search = re.search(r"\(.*\)", image)
if search:
new_image_id = str(uuid.uuid1())
source_image_path = search.group().replace('(', '').replace(')', '')
                    image_path = urljoin(result.get('name'),
                                         '.' + source_image_path if source_image_path.startswith('/')
                                         else source_image_path)
                    if image_path not in zip_files:
                        continue
if image_path.startswith('api/file/') or image_path.startswith('api/image/'):
image_id = image_path.replace('api/file/', '').replace('api/image/', '')
if is_valid_uuid(image_id):
image_file_list.append({'source_file': image_path,
'image_id': image_id})
else:
image_file_list.append({'source_file': image_path,
'image_id': new_image_id})
content = content.replace(source_image_path, f'/api/image/{new_image_id}')
p['content'] = content
else:
image_file_list.append({'source_file': image_path,
'image_id': new_image_id})
content = content.replace(source_image_path, f'/api/image/{new_image_id}')
p['content'] = content
return image_file_list
def filter_image_file(result_list: list, image_list):
image_source_file_list = [image.get('source_file') for image in image_list]
    return [r for r in result_list if r.get('name', '') not in image_source_file_list]
class ZipParseQAHandle(BaseParseQAHandle):
def handle(self, file, get_buffer, save_image):
buffer = get_buffer(file)
bytes_io = io.BytesIO(buffer)
result = []
        # Open the zip archive
        with zipfile.ZipFile(bytes_io, 'r') as zip_ref:
            # List the entry names in the archive
            files = zip_ref.namelist()
            # Read and parse each entry
for file in files:
if file.endswith('/'):
continue
with zip_ref.open(file) as f:
                    # Parse the entry content
try:
value = file_to_paragraph(f)
if isinstance(value, list):
result = [*result, *value]
else:
result.append(value)
except Exception:
pass
image_list = get_image_list(result, files)
result = filter_image_file(result, image_list)
            image_model_list = []
for image in image_list:
with zip_ref.open(image.get('source_file')) as f:
i = Image(id=image.get('image_id'), image=f.read(),
image_name=os.path.basename(image.get('source_file')))
                    image_model_list.append(i)
            save_image(image_model_list)
return result
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith(".zip") or file_name.endswith(".ZIP"):
return True
return False
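The subtle part of get_image_list is resolving a Markdown image reference into a zip entry name: urljoin resolves the path relative to the referencing file, and absolute paths are first prefixed with '.' so they resolve against the archive root. A quick sketch (not part of the commit):

from urllib.parse import urljoin

# Relative to the directory of the Markdown file inside the archive:
print(urljoin('docs/readme.md', 'images/a.png'))    # docs/images/a.png
# A path starting with '/' is rewritten to './...' first, so it lands
# at the archive root instead of being treated as absolute:
print(urljoin('readme.md', '.' + '/assets/b.png'))  # assets/b.png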

View File

@ -0,0 +1,80 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file: xls_split_handle.py
@date: 2024/5/21 14:59
@desc:
"""
from typing import List
import xlrd
from common.handle.base_split_handle import BaseSplitHandle
def post_cell(cell_value):
return cell_value.replace('\n', '<br>').replace('|', '&#124;')
def row_to_md(row):
return '| ' + ' | '.join(
[post_cell(str(cell)) if cell is not None else '' for cell in row]) + ' |\n'
def handle_sheet(file_name, sheet, limit: int):
rows = iter([sheet.row_values(i) for i in range(sheet.nrows)])
paragraphs = []
result = {'name': file_name, 'content': paragraphs}
try:
title_row_list = next(rows)
title_md_content = row_to_md(title_row_list)
title_md_content += '| ' + ' | '.join(
['---' if cell is not None else '' for cell in title_row_list]) + ' |\n'
    except Exception:
return result
if len(title_row_list) == 0:
return result
result_item_content = ''
for row in rows:
next_md_content = row_to_md(row)
next_md_content_len = len(next_md_content)
result_item_content_len = len(result_item_content)
        if result_item_content_len == 0:
result_item_content += title_md_content
result_item_content += next_md_content
else:
if result_item_content_len + next_md_content_len < limit:
result_item_content += next_md_content
            else:
                paragraphs.append({'content': result_item_content, 'title': ''})
                # Start the next chunk with the header row plus the current row,
                # so the row that triggered the flush is not dropped.
                result_item_content = title_md_content + next_md_content
if len(result_item_content) > 0:
paragraphs.append({'content': result_item_content, 'title': ''})
return result
class XlsSplitHandle(BaseSplitHandle):
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
buffer = get_buffer(file)
try:
workbook = xlrd.open_workbook(file_contents=buffer)
worksheets = workbook.sheets()
worksheets_size = len(worksheets)
            results = [handle_sheet(file.name, sheet, limit)
                       if worksheets_size == 1 and sheet.name == 'Sheet1'
                       else handle_sheet(sheet.name, sheet, limit)
                       for sheet in worksheets]
            return [r for r in results if r is not None]
        except Exception:
return [{'name': file.name, 'content': []}]
def get_content(self, file, save_image):
pass
def support(self, file, get_buffer):
file_name: str = file.name.lower()
buffer = get_buffer(file)
if file_name.endswith(".xls") and xlrd.inspect_format(content=buffer):
return True
return False
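support() sniffs the buffer with xlrd (2.x) instead of trusting the extension, so a renamed non-spreadsheet is rejected before open_workbook ever runs. A minimal illustration (not part of the commit):

import xlrd

# Legacy .xls files start with the 8-byte OLE2 compound-document magic.
OLE2_MAGIC = b'\xd0\xcf\x11\xe0\xa1\xb1\x1a\xe1'
print(xlrd.inspect_format(content=OLE2_MAGIC))         # 'xls'
print(xlrd.inspect_format(content=b'just some text'))  # None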

View File

@ -0,0 +1,92 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file: xlsx_split_handle.py
@date: 2024/5/21 14:59
@desc:
"""
import io
from typing import List
import openpyxl
from common.handle.base_split_handle import BaseSplitHandle
from common.handle.impl.tools import xlsx_embed_cells_images
def post_cell(image_dict, cell_value):
image = image_dict.get(cell_value, None)
if image is not None:
return f'![](/api/image/{image.id})'
return cell_value.replace('\n', '<br>').replace('|', '&#124;')
def row_to_md(row, image_dict):
return '| ' + ' | '.join(
[post_cell(image_dict, str(cell.value if cell.value is not None else '')) if cell is not None else '' for cell
in row]) + ' |\n'
def handle_sheet(file_name, sheet, image_dict, limit: int):
rows = sheet.rows
paragraphs = []
result = {'name': file_name, 'content': paragraphs}
try:
title_row_list = next(rows)
title_md_content = row_to_md(title_row_list, image_dict)
title_md_content += '| ' + ' | '.join(
['---' if cell is not None else '' for cell in title_row_list]) + ' |\n'
    except Exception:
return result
if len(title_row_list) == 0:
return result
result_item_content = ''
for row in rows:
next_md_content = row_to_md(row, image_dict)
next_md_content_len = len(next_md_content)
result_item_content_len = len(result_item_content)
        if result_item_content_len == 0:
result_item_content += title_md_content
result_item_content += next_md_content
else:
if result_item_content_len + next_md_content_len < limit:
result_item_content += next_md_content
            else:
                paragraphs.append({'content': result_item_content, 'title': ''})
                # Start the next chunk with the header row plus the current row,
                # so the row that triggered the flush is not dropped.
                result_item_content = title_md_content + next_md_content
if len(result_item_content) > 0:
paragraphs.append({'content': result_item_content, 'title': ''})
return result
class XlsxSplitHandle(BaseSplitHandle):
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
buffer = get_buffer(file)
try:
workbook = openpyxl.load_workbook(io.BytesIO(buffer))
try:
image_dict: dict = xlsx_embed_cells_images(io.BytesIO(buffer))
                save_image(list(image_dict.values()))
            except Exception:
                image_dict = {}
worksheets = workbook.worksheets
worksheets_size = len(worksheets)
            results = [handle_sheet(file.name, sheet, image_dict, limit)
                       if worksheets_size == 1 and sheet.title == 'Sheet1'
                       else handle_sheet(sheet.title, sheet, image_dict, limit)
                       for sheet in worksheets]
            return [r for r in results if r is not None]
        except Exception:
return [{'name': file.name, 'content': []}]
def get_content(self, file, save_image):
pass
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith(".xlsx"):
return True
return False
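post_cell turns a cell value into a Markdown image link whenever xlsx_embed_cells_images mapped that value to an Image row. A hedged sketch (not part of the commit; FakeImage and the dict key are stand-ins for the project's Image model and for whatever keys xlsx_embed_cells_images actually produces):

class FakeImage:
    # Stand-in for dataset.models.Image; only the id attribute matters here.
    def __init__(self, image_id):
        self.id = image_id

def post_cell(image_dict, cell_value):
    image = image_dict.get(cell_value, None)
    if image is not None:
        return f'![](/api/image/{image.id})'
    return cell_value.replace('\n', '<br>').replace('|', '&#124;')

image_dict = {'embedded-image-cell': FakeImage('3f2b6c1e')}
print(post_cell(image_dict, 'embedded-image-cell'))  # ![](/api/image/3f2b6c1e)
print(post_cell(image_dict, 'a|b'))                  # a&#124;b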

View File

@ -0,0 +1,147 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file: zip_split_handle.py
@date: 2024/3/27 18:19
@desc:
"""
import io
import os
import re
import uuid
import zipfile
from typing import List
from urllib.parse import urljoin
from django.db.models import QuerySet
from common.handle.base_split_handle import BaseSplitHandle
from common.handle.impl.csv_split_handle import CsvSplitHandle
from common.handle.impl.doc_split_handle import DocSplitHandle
from common.handle.impl.html_split_handle import HTMLSplitHandle
from common.handle.impl.pdf_split_handle import PdfSplitHandle
from common.handle.impl.text_split_handle import TextSplitHandle
from common.handle.impl.xls_split_handle import XlsSplitHandle
from common.handle.impl.xlsx_split_handle import XlsxSplitHandle
from common.util.common import parse_md_image
from dataset.models import Image
class FileBufferHandle:
buffer = None
def get_buffer(self, file):
if self.buffer is None:
self.buffer = file.read()
return self.buffer
default_split_handle = TextSplitHandle()
split_handles = [HTMLSplitHandle(), DocSplitHandle(), PdfSplitHandle(), XlsxSplitHandle(), XlsSplitHandle(),
CsvSplitHandle(),
default_split_handle]
def save_inner_image(image_list):
if image_list is not None and len(image_list) > 0:
QuerySet(Image).bulk_create(image_list)
def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
get_buffer = FileBufferHandle().get_buffer
for split_handle in split_handles:
if split_handle.support(file, get_buffer):
return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_inner_image)
raise Exception("不支持的文件格式")
def is_valid_uuid(uuid_str: str):
try:
uuid.UUID(uuid_str)
except ValueError:
return False
return True
def get_image_list(result_list: list, zip_files: List[str]):
image_file_list = []
for result in result_list:
for p in result.get('content', []):
content: str = p.get('content', '')
image_list = parse_md_image(content)
for image in image_list:
                search = re.search(r"\(.*\)", image)
if search:
new_image_id = str(uuid.uuid1())
source_image_path = search.group().replace('(', '').replace(')', '')
                    image_path = urljoin(result.get('name'),
                                         '.' + source_image_path if source_image_path.startswith('/')
                                         else source_image_path)
                    if image_path not in zip_files:
                        continue
if image_path.startswith('api/file/') or image_path.startswith('api/image/'):
image_id = image_path.replace('api/file/', '').replace('api/image/', '')
if is_valid_uuid(image_id):
image_file_list.append({'source_file': image_path,
'image_id': image_id})
else:
image_file_list.append({'source_file': image_path,
'image_id': new_image_id})
content = content.replace(source_image_path, f'/api/image/{new_image_id}')
p['content'] = content
else:
image_file_list.append({'source_file': image_path,
'image_id': new_image_id})
content = content.replace(source_image_path, f'/api/image/{new_image_id}')
p['content'] = content
return image_file_list
def filter_image_file(result_list: list, image_list):
image_source_file_list = [image.get('source_file') for image in image_list]
    return [r for r in result_list if r.get('name', '') not in image_source_file_list]
class ZipSplitHandle(BaseSplitHandle):
def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
buffer = get_buffer(file)
bytes_io = io.BytesIO(buffer)
result = []
        # Open the zip archive
        with zipfile.ZipFile(bytes_io, 'r') as zip_ref:
            # List the entry names in the archive
            files = zip_ref.namelist()
            # Read and parse each entry
for file in files:
if file.endswith('/'):
continue
with zip_ref.open(file) as f:
                    # Parse the entry content
try:
value = file_to_paragraph(f, pattern_list, with_filter, limit)
if isinstance(value, list):
result = [*result, *value]
else:
result.append(value)
except Exception:
pass
image_list = get_image_list(result, files)
result = filter_image_file(result, image_list)
            image_model_list = []
for image in image_list:
with zip_ref.open(image.get('source_file')) as f:
i = Image(id=image.get('image_id'), image=f.read(),
image_name=os.path.basename(image.get('source_file')))
                    image_model_list.append(i)
            save_image(image_model_list)
return result
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith(".zip") or file_name.endswith(".ZIP"):
return True
return False
def get_content(self, file, save_image):
return ""

View File

@ -9,6 +9,7 @@
import hashlib
import importlib
import io
import re
import shutil
import mimetypes
from functools import reduce
@ -109,6 +110,18 @@ def valid_license(model=None, count=None, message=None):
return inner
def parse_image(content: str):
    matches = re.finditer(r"!\[.*?\]\(/api/(image|file)/.*?\)", content)
image_list = [match.group() for match in matches]
return image_list
def parse_md_image(content: str):
    matches = re.finditer(r"!\[.*?\]\(.*?\)", content)
image_list = [match.group() for match in matches]
return image_list
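The two helpers differ only in how strict the link target is: parse_image keeps server-hosted images and files, parse_md_image keeps every Markdown image. A quick demo (not part of the commit):

import re

content = 'a ![](/api/image/42) b ![alt](img/local.png) c'
print([m.group() for m in re.finditer(r"!\[.*?\]\(/api/(image|file)/.*?\)", content)])
# ['![](/api/image/42)']
print([m.group() for m in re.finditer(r"!\[.*?\]\(.*?\)", content)])
# ['![](/api/image/42)', '![alt](img/local.png)']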
def bulk_create_in_batches(model, data, batch_size=1000):
if len(data) == 0:
return
@ -139,6 +152,7 @@ def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
)
return uploaded_file
def any_to_amr(any_path, amr_path):
"""
    Convert an audio file of any format to an amr file

View File

@ -7,7 +7,9 @@
@desc:
"""
import os
import re
import uuid
import zipfile
from typing import List
from django.db.models import QuerySet
@ -22,11 +24,46 @@ from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork
from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet
from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet, File, Image
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR
def zip_dir(zip_path, output=None):
    output = output or os.path.basename(zip_path) + '.zip'
    with zipfile.ZipFile(output, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        for root, dirs, files in os.walk(zip_path):
            relative_root = '' if root == zip_path else root.replace(zip_path, '') + os.sep
            for filename in files:
                zip_file.write(os.path.join(root, filename), relative_root + filename)
def write_image(zip_path: str, image_list: List[str]):
for image in image_list:
        search = re.search(r"\(.*\)", image)
if search:
text = search.group()
if text.startswith('(/api/file/'):
r = text.replace('(/api/file/', '').replace(')', '')
file = QuerySet(File).filter(id=r).first()
zip_inner_path = os.path.join('api', 'file', r)
file_path = os.path.join(zip_path, zip_inner_path)
if not os.path.exists(os.path.dirname(file_path)):
os.makedirs(os.path.dirname(file_path))
                with open(file_path, 'wb') as f:  # file_path is already rooted at zip_path
f.write(file.get_byte())
else:
r = text.replace('(/api/image/', '').replace(')', '')
image_model = QuerySet(Image).filter(id=r).first()
zip_inner_path = os.path.join('api', 'image', r)
file_path = os.path.join(zip_path, zip_inner_path)
if not os.path.exists(os.path.dirname(file_path)):
os.makedirs(os.path.dirname(file_path))
with open(file_path, 'wb') as f:
f.write(image_model.image)
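Together, write_image and zip_dir produce an archive with dataset.xlsx at the root and the referenced binaries under api/file/<id> and api/image/<id>, which is exactly the layout the zip import handlers expect. A condensed sketch (not part of the commit) of how the export serializers below drive them:

import io
import zipfile
from pathlib import Path
from tempfile import TemporaryDirectory

zip_buffer = io.BytesIO()
with TemporaryDirectory() as tempdir:
    # Stand-ins for workbook.save(...) and write_image(...):
    Path(tempdir, 'dataset.xlsx').write_bytes(b'workbook bytes')
    Path(tempdir, 'api', 'image').mkdir(parents=True)
    Path(tempdir, 'api', 'image', 'abc').write_bytes(b'png bytes')
    zip_dir(tempdir, zip_buffer)  # zip_dir as defined above
print(zipfile.ZipFile(zip_buffer).namelist())  # ['dataset.xlsx', 'api/image/abc']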
def update_document_char_length(document_id: str):
update_execute(get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_char_length.sql')),

View File

@ -6,16 +6,19 @@
@date2023/9/21 16:14
@desc:
"""
import io
import logging
import os.path
import re
import traceback
import uuid
import zipfile
from functools import reduce
from tempfile import TemporaryDirectory
from typing import Dict, List
from urllib.parse import urlparse
from celery_once import AlreadyQueued, QueueOnce
from celery_once import AlreadyQueued
from django.contrib.postgres.fields import ArrayField
from django.core import validators
from django.db import transaction, models
@ -31,15 +34,15 @@ from common.db.sql_execute import select_list
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post, flat_map, valid_license
from common.util.common import post, flat_map, valid_license, parse_image
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status, \
TaskType, State
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, TaskType, \
State, File, Image
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id
get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id, write_image, zip_dir
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from dataset.task import sync_web_dataset, sync_replace_web_dataset
from embedding.models import SearchMode
@ -695,6 +698,33 @@ class DataSetSerializers(serializers.ModelSerializer):
workbook.save(response)
return response
def export_zip(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
document_list = QuerySet(Document).filter(dataset_id=self.data.get('id'))
paragraph_list = native_search(QuerySet(Paragraph).filter(dataset_id=self.data.get("id")), get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph_document_name.sql')))
problem_mapping_list = native_search(
QuerySet(ProblemParagraphMapping).filter(dataset_id=self.data.get("id")), get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem_mapping.sql')),
with_table_name=True)
data_dict, document_dict = DocumentSerializers.Operate.merge_problem(paragraph_list, problem_mapping_list,
document_list)
res = [parse_image(paragraph.get('content')) for paragraph in paragraph_list]
workbook = DocumentSerializers.Operate.get_workbook(data_dict, document_dict)
response = HttpResponse(content_type='application/zip')
response['Content-Disposition'] = 'attachment; filename="archive.zip"'
zip_buffer = io.BytesIO()
with TemporaryDirectory() as tempdir:
dataset_file = os.path.join(tempdir, 'dataset.xlsx')
workbook.save(dataset_file)
for r in res:
write_image(tempdir, r)
zip_dir(tempdir, zip_buffer)
response.write(zip_buffer.getvalue())
return response
@staticmethod
def merge_problem(paragraph_list: List[Dict], problem_mapping_list: List[Dict]):
result = {}

View File

@ -6,12 +6,14 @@
@date2023/9/22 13:43
@desc:
"""
import io
import logging
import os
import re
import traceback
import uuid
from functools import reduce
from tempfile import TemporaryDirectory
from typing import List, Dict
import openpyxl
@ -30,18 +32,23 @@ from common.db.search import native_search, native_page_search
from common.event import ListenerManagement
from common.event.common import work_thread_pool
from common.exception.app_exception import AppApiException
from common.handle.impl.csv_split_handle import CsvSplitHandle
from common.handle.impl.doc_split_handle import DocSplitHandle
from common.handle.impl.html_split_handle import HTMLSplitHandle
from common.handle.impl.pdf_split_handle import PdfSplitHandle
from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
from common.handle.impl.table.csv_parse_table_handle import CsvSplitHandle
from common.handle.impl.table.xls_parse_table_handle import XlsSplitHandle
from common.handle.impl.table.xlsx_parse_table_handle import XlsxSplitHandle
from common.handle.impl.qa.zip_parse_qa_handle import ZipParseQAHandle
from common.handle.impl.table.csv_parse_table_handle import CsvSplitHandle as CsvSplitTableHandle
from common.handle.impl.table.xls_parse_table_handle import XlsSplitHandle as XlsSplitTableHandle
from common.handle.impl.table.xlsx_parse_table_handle import XlsxSplitHandle as XlsxSplitTableHandle
from common.handle.impl.text_split_handle import TextSplitHandle
from common.handle.impl.xls_split_handle import XlsSplitHandle
from common.handle.impl.xlsx_split_handle import XlsxSplitHandle
from common.handle.impl.zip_split_handle import ZipSplitHandle
from common.mixins.api_mixin import ApiMixin
from common.util.common import post, flat_map, bulk_create_in_batches
from common.util.common import post, flat_map, bulk_create_in_batches, parse_image
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork
@ -49,7 +56,7 @@ from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Image, \
TaskType, State
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_id_by_dataset_id
get_embedding_model_id_by_dataset_id, write_image, zip_dir
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from dataset.task import sync_web_document, generate_related_by_document_id
from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
@ -57,8 +64,8 @@ from embedding.task.embedding import embedding_by_document, delete_embedding_by_
embedding_by_document_list
from smartdoc.conf import PROJECT_DIR
parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()]
parse_table_handle_list = [CsvSplitHandle(), XlsSplitHandle(), XlsxSplitHandle()]
parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()]
parse_table_handle_list = [CsvSplitTableHandle(), XlsSplitTableHandle(), XlsxSplitTableHandle()]
class FileBufferHandle:
@ -563,6 +570,34 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
workbook.save(response)
return response
def export_zip(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
document = QuerySet(Document).filter(id=self.data.get("document_id")).first()
paragraph_list = native_search(QuerySet(Paragraph).filter(document_id=self.data.get("document_id")),
get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql',
'list_paragraph_document_name.sql')))
problem_mapping_list = native_search(
QuerySet(ProblemParagraphMapping).filter(document_id=self.data.get("document_id")), get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem_mapping.sql')),
with_table_name=True)
data_dict, document_dict = self.merge_problem(paragraph_list, problem_mapping_list, [document])
res = [parse_image(paragraph.get('content')) for paragraph in paragraph_list]
workbook = DocumentSerializers.Operate.get_workbook(data_dict, document_dict)
response = HttpResponse(content_type='application/zip')
response['Content-Disposition'] = 'attachment; filename="archive.zip"'
zip_buffer = io.BytesIO()
with TemporaryDirectory() as tempdir:
dataset_file = os.path.join(tempdir, 'dataset.xlsx')
workbook.save(dataset_file)
for r in res:
write_image(tempdir, r)
zip_dir(tempdir, zip_buffer)
response.write(zip_buffer.getvalue())
return response
@staticmethod
def get_workbook(data_dict, document_dict):
            # Create the workbook
@ -929,9 +964,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
def parse(self):
file_list = self.data.get("file")
return list(
map(lambda f: file_to_paragraph(f, self.data.get("patterns", None), self.data.get("with_filter", None),
self.data.get("limit", 4096)), file_list))
return reduce(lambda x, y: [*x, *y],
[file_to_paragraph(f, self.data.get("patterns", None), self.data.get("with_filter", None),
self.data.get("limit", 4096)) for f in file_list], [])
class SplitPattern(ApiMixin, serializers.Serializer):
@staticmethod
@ -1109,20 +1144,33 @@ class FileBufferHandle:
default_split_handle = TextSplitHandle()
split_handles = [HTMLSplitHandle(), DocSplitHandle(), PdfSplitHandle(), default_split_handle]
split_handles = [HTMLSplitHandle(), DocSplitHandle(), PdfSplitHandle(), XlsxSplitHandle(), XlsSplitHandle(),
CsvSplitHandle(),
ZipSplitHandle(),
default_split_handle]
def save_image(image_list):
if image_list is not None and len(image_list) > 0:
QuerySet(Image).bulk_create(image_list)
exist_image_list = [str(i.get('id')) for i in
QuerySet(Image).filter(id__in=[i.id for i in image_list]).values('id')]
    save_image_list = [image for image in image_list if str(image.id) not in exist_image_list]
if len(save_image_list) > 0:
QuerySet(Image).bulk_create(save_image_list)
def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
get_buffer = FileBufferHandle().get_buffer
for split_handle in split_handles:
if split_handle.support(file, get_buffer):
return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
return default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
result = split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
if isinstance(result, list):
return result
return [result]
result = default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image)
if isinstance(result, list):
return result
return [result]
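file_to_paragraph now always returns a list, since a zip can expand into one parsed document per inner file; that is also why parse() above switched from map to a reduce-based flatten. The same flattening in miniature (not part of the commit):

from functools import reduce

per_file = [[{'name': 'a.md'}], [{'name': 'b.md'}, {'name': 'c.md'}]]
flat = reduce(lambda x, y: [*x, *y], per_file, [])
print([d['name'] for d in flat])  # ['a.md', 'b.md', 'c.md']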
def delete_problems_and_mappings(document_ids):

View File

@ -9,6 +9,7 @@ urlpatterns = [
path('dataset/qa', views.Dataset.CreateQADataset.as_view()),
path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
path('dataset/<str:dataset_id>/export', views.Dataset.Export.as_view(), name="export"),
path('dataset/<str:dataset_id>/export_zip', views.Dataset.ExportZip.as_view(), name="export_zip"),
path('dataset/<str:dataset_id>/re_embedding', views.Dataset.Embedding.as_view(), name="dataset_key"),
path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
@ -35,6 +36,8 @@ urlpatterns = [
path('dataset/<str:dataset_id>/document/migrate/<str:target_dataset_id>', views.Document.Migrate.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/export', views.Document.Export.as_view(),
name="document_export"),
path('dataset/<str:dataset_id>/document/<str:document_id>/export_zip', views.Document.ExportZip.as_view(),
name="document_export"),
path('dataset/<str:dataset_id>/document/<str:document_id>/sync', views.Document.SyncWeb.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/refresh', views.Document.Refresh.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/cancel_task', views.Document.CancelTask.as_view()),

View File

@ -165,6 +165,19 @@ class Dataset(APIView):
def get(self, request: Request, dataset_id: str):
return DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).export_excel()
class ExportZip(APIView):
authentication_classes = [TokenAuth]
@action(methods="GET", detail=False)
@swagger_auto_schema(operation_summary="导出知识库包含图片", operation_id="导出知识库包含图片",
manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
tags=["知识库"]
)
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=keywords.get('dataset_id')))
def get(self, request: Request, dataset_id: str):
return DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).export_zip()
class Operate(APIView):
authentication_classes = [TokenAuth]
@ -221,7 +234,8 @@ class Dataset(APIView):
def get(self, request: Request, current_page, page_size):
d = DataSetSerializers.Query(
data={'name': request.query_params.get('name', None), 'desc': request.query_params.get("desc", None),
'user_id': str(request.user.id), 'select_user_id': request.query_params.get('select_user_id', None)})
'user_id': str(request.user.id),
'select_user_id': request.query_params.get('select_user_id', None)})
d.is_valid()
return result.success(d.page(current_page, page_size))

View File

@ -314,6 +314,20 @@ class Document(APIView):
def get(self, request: Request, dataset_id: str, document_id: str):
return DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).export()
class ExportZip(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="导出Zip文档",
operation_id="导出Zip文档",
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
tags=["知识库/文档"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def get(self, request: Request, dataset_id: str, document_id: str):
return DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).export_zip()
class Operate(APIView):
authentication_classes = [TokenAuth]

View File

@ -1,5 +1,5 @@
import { Result } from '@/request/Result'
import { get, post, del, put, exportExcel } from '@/request/index'
import { get, post, del, put, exportExcel, exportFile } from '@/request/index'
import type { datasetData } from '@/api/type/dataset'
import type { pageRequest } from '@/api/type/common'
import type { ApplicationFormType } from '@/api/type/application'
@ -201,7 +201,20 @@ const exportDataset: (
) => Promise<any> = (dataset_name, dataset_id, loading) => {
return exportExcel(dataset_name + '.xlsx', `dataset/${dataset_id}/export`, undefined, loading)
}
/**
 * Export a knowledge base as a zip archive
 * @param dataset_name dataset name
 * @param dataset_id dataset id
 * @param loading loading ref
 * @returns
 */
const exportZipDataset: (
dataset_name: string,
dataset_id: string,
loading?: Ref<boolean>
) => Promise<any> = (dataset_name, dataset_id, loading) => {
return exportFile(dataset_name + '.zip', `dataset/${dataset_id}/export_zip`, undefined, loading)
}
/**
 * Get the models available to the knowledge base
@ -217,7 +230,6 @@ const getDatasetModel: (
return get(`${prefix}/${dataset_id}/model`, loading)
}
export default {
getDataset,
getAllDataset,
@ -232,5 +244,6 @@ export default {
putReEmbeddingDataset,
postQADataset,
exportDataset,
getDatasetModel
getDatasetModel,
exportZipDataset
}

View File

@ -1,5 +1,5 @@
import { Result } from '@/request/Result'
import { get, post, del, put, exportExcel } from '@/request/index'
import { get, post, del, put, exportExcel, exportFile } from '@/request/index'
import type { Ref } from 'vue'
import type { KeyValue } from '@/api/type/common'
import type { pageRequest } from '@/api/type/common'
@ -316,7 +316,27 @@ const exportDocument: (
loading
)
}
/**
 * Export a document as a zip archive
 * @param document_name document name
 * @param dataset_id dataset id
 * @param document_id document id
 * @param loading loading ref
 * @returns
 */
const exportDocumentZip: (
document_name: string,
dataset_id: string,
document_id: string,
loading?: Ref<boolean>
) => Promise<any> = (document_name, dataset_id, document_id, loading) => {
return exportFile(
document_name + '.zip',
`${prefix}/${dataset_id}/document/${document_id}/export_zip`,
{},
loading
)
}
const batchGenerateRelated: (
dataset_id: string,
data: any,
@ -362,5 +382,6 @@ export default {
exportDocument,
batchRefresh,
batchGenerateRelated,
cancelTask
cancelTask,
exportDocumentZip
}

View File

@ -38,9 +38,9 @@ export function fileType(name: string) {
*/
const typeList: any = {
txt: ['txt', 'pdf', 'docx', 'md', 'html'],
txt: ['txt', 'pdf', 'docx', 'md', 'html', 'zip', 'xlsx', 'xls', 'csv'],
table: ['xlsx', 'xls', 'csv'],
QA: ['xlsx', 'csv', 'xls']
QA: ['xlsx', 'csv', 'xls', 'zip']
}
export function getImgUrl(name: string) {
@ -51,6 +51,7 @@ export function getImgUrl(name: string) {
}
// Whether the file extension is on the whitelist
export function isRightType(name: string, type: string) {
return typeList[type].includes(fileType(name).toLowerCase())
}

View File

@ -43,7 +43,7 @@
action="#"
:auto-upload="false"
:show-file-list="false"
accept=".xlsx, .xls, .csv"
accept=".xlsx, .xls, .csv,.zip"
:limit="50"
:on-exceed="onExceed"
:on-change="fileHandleChange"
@ -129,7 +129,7 @@
action="#"
:auto-upload="false"
:show-file-list="false"
accept=".txt, .md, .log, .docx, .pdf, .html"
accept=".txt, .md, .log, .docx, .pdf, .html,.zip,.xlsx,.xls,.csv"
:limit="50"
:on-exceed="onExceed"
:on-change="fileHandleChange"
@ -143,7 +143,7 @@
<em class="hover" @click.prevent="handlePreview(true)"> 选择文件夹 </em>
</p>
<div class="upload__decoration">
<p>支持格式TXTMarkdownPDFDOCXHTML</p>
<p>支持格式TXTMarkdownPDFDOCXHTMLZIPXLSXXLSCSV</p>
</div>
</div>
</el-upload>

View File

@ -118,7 +118,10 @@
设置</el-dropdown-item
>
<el-dropdown-item @click.stop="export_dataset(item)">
<AppIcon iconName="app-export"></AppIcon>导出</el-dropdown-item
<AppIcon iconName="app-export"></AppIcon>导出Excel</el-dropdown-item
>
<el-dropdown-item @click.stop="export_zip_dataset(item)">
<AppIcon iconName="app-export"></AppIcon>导出ZIP</el-dropdown-item
>
<el-dropdown-item icon="Delete" @click.stop="deleteDataset(item)"
>删除</el-dropdown-item
@ -225,6 +228,11 @@ const export_dataset = (item: any) => {
MsgSuccess('导出成功')
})
}
const export_zip_dataset = (item: any) => {
datasetApi.exportZipDataset(item.name, item.id, loading).then((ok) => {
MsgSuccess('导出成功')
})
}
function deleteDataset(row: any) {
MsgConfirm(

View File

@ -305,7 +305,11 @@
</el-dropdown-item>
<el-dropdown-item @click="exportDocument(row)">
<AppIcon iconName="app-export"></AppIcon>
导出
导出Excel
</el-dropdown-item>
<el-dropdown-item @click="exportDocumentZip(row)">
<AppIcon iconName="app-export"></AppIcon>
导出Zip
</el-dropdown-item>
<el-dropdown-item icon="Delete" @click.stop="deleteDocument(row)"
>删除</el-dropdown-item
@ -381,7 +385,11 @@
>
<el-dropdown-item @click="exportDocument(row)">
<AppIcon iconName="app-export"></AppIcon>
导出
导出Excel
</el-dropdown-item>
<el-dropdown-item @click="exportDocumentZip(row)">
<AppIcon iconName="app-export"></AppIcon>
导出Zip
</el-dropdown-item>
<el-dropdown-item icon="Delete" @click.stop="deleteDocument(row)"
>删除</el-dropdown-item
@ -475,6 +483,13 @@ const exportDocument = (document: any) => {
MsgSuccess('导出成功')
})
}
const exportDocumentZip = (document: any) => {
documentApi
.exportDocumentZip(document.name, document.dataset_id, document.id, loading)
.then(() => {
MsgSuccess('导出成功')
})
}
function openDatasetDialog(row?: any) {
const arr: string[] = []
if (row) {