# coding=utf-8
# ---- patch: new file apps/common/handle/base_parse_table_handle.py ----
"""
    @project: maxkb
    @Author:虎
    @file: base_parse_table_handle.py
    @date:2024/5/21 14:56
    @desc: Abstract interface implemented by the table-file parsers.
"""
from abc import ABC, abstractmethod


class BaseParseTableHandle(ABC):
    """Contract shared by the handlers that turn a table file into documents."""

    @abstractmethod
    def support(self, file, get_buffer):
        """Return True when this handler can parse ``file``."""
        pass

    @abstractmethod
    def handle(self, file, get_buffer):
        """Parse ``file``.

        The implementations in this patch return a list of
        ``{'name': ..., 'paragraphs': [...]}`` dicts.
        """
        pass


# ---- patch: new file apps/common/handle/impl/table/csv_parse_table_handle.py ----
# In the repository layout this module imports BaseParseTableHandle from
# common.handle.base_parse_table_handle; in this listing the class is defined
# directly above.
import csv
import io
import logging

max_kb = logging.getLogger("max_kb")


class CsvSplitHandle(BaseParseTableHandle):
    """Turn every CSV data row into one paragraph of ``header:value`` pairs."""

    def support(self, file, get_buffer):
        """Return True for file names ending in ``.csv`` (case-insensitive)."""
        file_name: str = file.name.lower()
        return file_name.endswith(".csv")

    @staticmethod
    def _decode_buffer(buffer):
        """Decode the raw bytes, preferring UTF-8 over charset detection.

        Raises:
            ValueError: when the encoding cannot be detected.
        """
        try:
            # utf-8-sig also strips a leading BOM, so it never ends up glued
            # to the first column title.
            return buffer.decode('utf-8-sig')
        except UnicodeDecodeError:
            # Imported lazily: only non-UTF-8 uploads need detection.
            from charset_normalizer import detect
            encoding = detect(buffer)['encoding']
            if encoding is None:
                # detect() reports None when it cannot make a guess.
                raise ValueError('unable to detect the file encoding')
            return buffer.decode(encoding)

    def handle(self, file, get_buffer):
        """Parse the CSV file into a single document.

        The first row is the header. Each following non-blank row becomes a
        paragraph whose content is ``"header:value; header:value"``. Values
        beyond the number of header columns are dropped by ``zip``.

        Args:
            file: uploaded file object; only ``file.name`` is read here.
            get_buffer: callable returning the file content as bytes.

        Returns:
            ``[{'name': file.name, 'paragraphs': [...]}]``. The paragraph list
            is empty when the content cannot be decoded or parsed; the error
            is logged rather than raised.
        """
        buffer = get_buffer(file)
        try:
            content = self._decode_buffer(buffer)
            # The csv module handles quoted fields, embedded commas and
            # CRLF line endings that a plain str.split cannot.
            rows = list(csv.reader(io.StringIO(content, newline='')))
        except Exception as e:
            # Best effort: a broken upload yields an empty document.
            max_kb.error(f'csv split handle error: {e}')
            return [{'name': file.name, 'paragraphs': []}]

        # The first row holds the column titles.
        title = rows[0] if rows else []
        paragraphs = [
            {'title': '',
             'content': '; '.join(f'{key}:{value}' for key, value in zip(title, row))}
            for row in rows[1:]
            # Skip blank lines (including the one after a trailing newline).
            if any(cell.strip() for cell in row)
        ]
        return [{'name': file.name, 'paragraphs': paragraphs}]


# ---- patch: the header of the next new file
# (apps/common/handle/impl/table/excel_parse_table_handle.py) starts here and
# continues on the following line of the patch ----
# ---- patch: new file apps/common/handle/impl/table/excel_parse_table_handle.py ----
# coding=utf-8
import io
import logging

from openpyxl import load_workbook

from common.handle.base_parse_table_handle import BaseParseTableHandle

max_kb = logging.getLogger("max_kb")


class ExcelSplitHandle(BaseParseTableHandle):
    """Turn every worksheet into a document with one paragraph per data row."""

    def support(self, file, get_buffer):
        """Return True for ``.xls`` / ``.xlsx`` file names (case-insensitive)."""
        file_name: str = file.name.lower()
        # NOTE(review): openpyxl documents support for the OOXML formats only,
        # so a legacy .xls upload is accepted here but is expected to fail in
        # handle() and produce an empty document - confirm whether .xls should
        # be rejected or routed to a different reader.
        return file_name.endswith(('.xls', '.xlsx'))

    @staticmethod
    def _sheet_paragraphs(sheet_name, rows):
        """Build the paragraph list for one sheet.

        Args:
            sheet_name: sheet title; appended to every paragraph unless it
                contains "sheet" (case-insensitive).
            rows: non-empty sequence of rows; each row is a sequence of
                objects exposing ``.value``. The first row is the header.

        Returns:
            A list of ``{'title': '', 'content': ...}`` dicts, one per row
            that has at least one non-empty cell.
        """
        # An empty header cell yields no label instead of the text "None".
        header = ['' if cell.value is None else str(cell.value) for cell in rows[0]]
        suffix = '' if 'sheet' in sheet_name.lower() else ' ——' + sheet_name
        paragraphs = []
        for row in rows[1:]:
            parts = []
            for index, cell in enumerate(row):
                value = cell.value
                # Only truly empty cells are skipped; 0 and False are data.
                if value is None or value == '':
                    continue
                label = header[index] if index < len(header) else ''
                parts.append(f'{label}: {value}' if label else str(value))
            if not parts:
                # A fully empty row carries no content worth indexing.
                continue
            paragraphs.append({'title': '', 'content': '; '.join(parts) + suffix})
        return paragraphs

    def handle(self, file, get_buffer):
        """Parse the workbook into one document per non-empty worksheet.

        Args:
            file: uploaded file object; only ``file.name`` is read here.
            get_buffer: callable returning the file content as bytes.

        Returns:
            A list of ``{'name': <sheet title>, 'paragraphs': [...]}`` dicts.
            If the workbook cannot be read, the error is logged and
            ``[{'name': file.name, 'paragraphs': []}]`` is returned.
        """
        buffer = get_buffer(file)
        try:
            wb = load_workbook(io.BytesIO(buffer))
            result = []
            # wb.worksheets leaves out chartsheets, which have no rows.
            for ws in wb.worksheets:
                rows = list(ws.rows)
                if not rows:
                    continue
                result.append({'name': ws.title,
                               'paragraphs': self._sheet_paragraphs(ws.title, rows)})
        except Exception as e:
            # Best effort: a broken upload yields an empty document.
            max_kb.error(f'excel split handle error: {e}')
            return [{'name': file.name, 'paragraphs': []}]
        return result


# ---- patch: apps/dataset/serializers/document_serializers.py (modified) ----
# The import hunk (@@ -33,6 +33,8 @@) adds, after the XlsxParseQAHandle import:
#   from common.handle.impl.table.csv_parse_table_handle import CsvSplitHandle
# followed by a second import whose text continues on the next patch line.
# ---- patch: apps/dataset/serializers/document_serializers.py (modified, continued) ----
# Only the code this patch adds is spelled out below. Each addition is
# preceded by a comment naming the unchanged hunk context it sits next to.

# Import hunk, continued from the previous patch line.
from common.handle.impl.table.excel_parse_table_handle import ExcelSplitHandle

# Added after parse_qa_handle_list; parse_table_file tries these in order.
parse_table_handle_list = [CsvSplitHandle(), ExcelSplitHandle()]


# Added after DocumentInstanceQASerializer.
class DocumentInstanceTableSerializer(ApiMixin, serializers.Serializer):
    # The list and its child file field are both declared required=True.
    file_list = serializers.ListSerializer(required=True,
                                           error_messages=ErrMessage.list("文件列表"),
                                           child=serializers.FileField(required=True,
                                                                       error_messages=ErrMessage.file("文件")))


# Added to DocumentSerializers.Export, after export().
def table_export(self, with_valid=True):
    """Return the table template selected by ``type`` as a download.

    Args:
        with_valid: when true, validate the serializer first.

    Returns:
        An ``HttpResponse`` carrying the CSV or Excel template, or ``None``
        when ``type`` is neither ``'csv'`` nor ``'excel'``.
    """
    if with_valid:
        self.is_valid(raise_exception=True)

    export_type = self.data.get('type')
    if export_type == 'csv':
        # 'text/csv' is the registered media type for CSV.
        template_name, content_type = 'MaxKB表格模板.csv', 'text/csv'
        disposition = 'attachment; filename="csv_template.csv"'
    elif export_type == 'excel':
        template_name, content_type = 'MaxKB表格模板.xlsx', 'application/vnd.ms-excel'
        disposition = 'attachment; filename="excel_template.xlsx"'
    else:
        # Unknown type: nothing to export (same result as before).
        return None

    template_path = os.path.join(PROJECT_DIR, "apps", "dataset", 'template', template_name)
    # The context manager closes the file even if read() raises.
    with open(template_path, "rb") as file:
        content = file.read()
    return HttpResponse(content, status=200, headers={'Content-Type': content_type,
                                                      'Content-Disposition': disposition})


# Added after parse_qa_file (hunk @@ -633,6 +660,14 @@).
@staticmethod
def parse_table_file(file):
    """Parse ``file`` with the first handler that supports it.

    Raises:
        AppApiException: when no handler in parse_table_handle_list
            supports the file.
    """
    get_buffer = FileBufferHandle().get_buffer
    for parse_table_handle in parse_table_handle_list:
        if parse_table_handle.support(file, get_buffer):
            return parse_table_handle.handle(file, get_buffer)
    raise AppApiException(500, '不支持的文件格式')


# Added after save_qa (hunk @@ -641,6 +676,14 @@).
def save_table(self, instance: Dict, with_valid=True):
    """Parse every uploaded table file and batch-save the resulting documents.

    Args:
        instance: dict whose ``file_list`` entry holds the uploaded files.
        with_valid: when true, validate ``instance`` and this serializer.

    Returns:
        Whatever ``DocumentSerializers.Batch(...).batch_save`` returns.
    """
    if with_valid:
        DocumentInstanceTableSerializer(data=instance).is_valid(raise_exception=True)
        self.is_valid(raise_exception=True)
    file_list = instance.get('file_list')
    document_list = flat_map([self.parse_table_file(file) for file in file_list])
    return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)


# ---- patch: new file apps/dataset/template/MaxKB表格模板.csv (13 lines) ----
# First rows of the template; the remaining rows follow on the next patch line.
# 职务,报销类型,一线城市报销标准(元),二线城市报销标准(元),三线城市报销标准(元)
# 普通员工,住宿费,500,400,300
# 部门主管,住宿费,600,500,400
# 部门总监,住宿费,700,600,500
+区域总经理,住宿费,800,700,600 +普通员工,伙食费,50,40,30 +部门主管,伙食费,50,40,30 +部门总监,伙食费,50,40,30 +区域总经理,伙食费,50,40,30 +普通员工,交通费,50,40,30 +部门主管,交通费,50,40,30 +部门总监,交通费,50,40,30 +区域总经理,交通费,50,40,30 diff --git a/apps/dataset/template/MaxKB表格模板.xlsx b/apps/dataset/template/MaxKB表格模板.xlsx new file mode 100644 index 000000000..2bc94a5b8 Binary files /dev/null and b/apps/dataset/template/MaxKB表格模板.xlsx differ diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index 8162a13ab..026492d18 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -16,8 +16,10 @@ urlpatterns = [ path('dataset//hit_test', views.Dataset.HitTest.as_view()), path('dataset//document', views.Document.as_view(), name='document'), path('dataset/document/template/export', views.Template.as_view()), + path('dataset/document/table_template/export', views.TableTemplate.as_view()), path('dataset//document/web', views.WebDocument.as_view()), path('dataset//document/qa', views.QaDocument.as_view()), + path('dataset//document/table', views.TableDocument.as_view()), path('dataset//document/_bach', views.Document.Batch.as_view()), path('dataset//document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()), path('dataset//document//', views.Document.Page.as_view()), diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index f522d01ce..1988ca75a 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -33,6 +33,17 @@ class Template(APIView): def get(self, request: Request): return DocumentSerializers.Export(data={'type': request.query_params.get('type')}).export(with_valid=True) +class TableTemplate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取表格模版", + operation_id="获取表格模版", + manual_parameters=DocumentSerializers.Export.get_request_params_api(), + tags=["知识库/文档"]) + def get(self, request: Request): + return DocumentSerializers.Export(data={'type': 
request.query_params.get('type')}).table_export(with_valid=True) + class WebDocument(APIView): authentication_classes = [TokenAuth] @@ -71,6 +82,24 @@ class QaDocument(APIView): {'file_list': request.FILES.getlist('file')}, with_valid=True)) +class TableDocument(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="导入表格并创建文档", + operation_id="导入表格并创建文档", + manual_parameters=DocumentWebInstanceSerializer.get_request_params_api(), + responses=result.get_api_response(DocumentSerializers.Create.get_response_body_api()), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str): + return result.success( + DocumentSerializers.Create(data={'dataset_id': dataset_id}).save_table( + {'file_list': request.FILES.getlist('file')}, + with_valid=True)) class Document(APIView): authentication_classes = [TokenAuth] diff --git a/ui/src/api/document.ts b/ui/src/api/document.ts index 5bf294100..0653f2d40 100644 --- a/ui/src/api/document.ts +++ b/ui/src/api/document.ts @@ -211,6 +211,19 @@ const postQADocument: ( return post(`${prefix}/${dataset_id}/document/qa`, data, undefined, loading) } +/** + * 导入表格 + * @param 参数 + * file + */ +const postTableDocument: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return post(`${prefix}/${dataset_id}/document/table`, data, undefined, loading) +} + /** * 批量迁移文档 * @param 参数 dataset_id,target_dataset_id, @@ -256,6 +269,18 @@ const exportQATemplate: (fileName: string, type: string, loading?: Ref) return exportExcel(fileName, `${prefix}/document/template/export`, { type }, loading) } +/** + * 获得table模版 + * @param 参数 fileName,type, + */ +const exportTableTemplate: (fileName: string, type: string, loading?: Ref) => void = ( + fileName, + type, 
+ loading +) => { + return exportExcel(fileName, `${prefix}/document/table_template/export`, { type }, loading) +} + /** * 导出文档 * @param document_name 文档名称 @@ -295,6 +320,8 @@ export default { putMigrateMulDocument, batchEditHitHandling, exportQATemplate, + exportTableTemplate, postQADocument, + postTableDocument, exportDocument } diff --git a/ui/src/utils/utils.ts b/ui/src/utils/utils.ts index 9b30135fb..b2d77d834 100644 --- a/ui/src/utils/utils.ts +++ b/ui/src/utils/utils.ts @@ -39,6 +39,7 @@ export function fileType(name: string) { */ const typeList: any = { txt: ['txt', 'pdf', 'docx', 'csv', 'md', 'html', 'PDF'], + table: ['xlsx', 'xls', 'csv'], QA: ['xlsx', 'csv', 'xls'] } diff --git a/ui/src/views/dataset/UploadDocumentDataset.vue b/ui/src/views/dataset/UploadDocumentDataset.vue index 370451a54..c434ea2fc 100644 --- a/ui/src/views/dataset/UploadDocumentDataset.vue +++ b/ui/src/views/dataset/UploadDocumentDataset.vue @@ -78,6 +78,21 @@ async function next() { router.push({ path: `/dataset/${id}/document` }) }) } + } else if (documentsType.value === 'table') { + let fd = new FormData() + documentsFiles.value.forEach((item: any) => { + if (item?.raw) { + fd.append('file', item?.raw) + } + }) + if (id) { + // table文档上传 + documentApi.postTableDocument(id as string, fd, loading).then((res) => { + MsgSuccess('提交成功') + clearStore() + router.push({ path: `/dataset/${id}/document` }) + }) + } } else { if (active.value++ > 2) active.value = 0 } diff --git a/ui/src/views/dataset/component/UploadComponent.vue b/ui/src/views/dataset/component/UploadComponent.vue index d51aff263..305ae1ca4 100644 --- a/ui/src/views/dataset/component/UploadComponent.vue +++ b/ui/src/views/dataset/component/UploadComponent.vue @@ -10,6 +10,7 @@ 文本文件 + 表格 QA 问答对 @@ -48,6 +49,42 @@ 下载 CSV 模板 + + + +
+

+ 拖拽文件至此上传或 + 选择文件 + 选择文件夹 +

+
+

当前支持 EXCEL和CSV 格式文件。

+

第一行必须是列标题,且列标题必须是有意义的术语,表中每条记录将作为一个分段。

+

每次最多上传50个文档,每个文档最大不能超过100MB。

+
+
+
+ + 下载 Excel 模板 + + + 下载 CSV 模板 +

- 支持格式:TXT、Markdown、PDF、DOCX、HTML 每次最多上传50个文件,每个文件不超过 100MB + 支持格式:TXT、Markdown、PDF、DOCX、HTML、Excel、CSV 每次最多上传50个文件,每个文件不超过 100MB

若使用【高级分段】建议上传前规范文件的分段标识

// ---- patch hunk @@ -133,6 +170,10 @@ of
// ui/src/views/dataset/component/UploadComponent.vue (script section) ----

/**
 * Download the QA template.
 * The saved file uses the `csv` extension for type 'csv' and `xlsx` otherwise.
 */
function downloadTemplate(type: string) {
  documentApi.exportQATemplate(`${type}模版.${type === 'csv' ? type : 'xlsx'}`, type)
}

/**
 * Download the table template (added by this patch).
 * The saved file uses the `csv` extension for type 'csv' and `xlsx` otherwise.
 */
function downloadTableTemplate(type: string) {
  // Strict equality: `type` is always a string, so no coercion is wanted.
  documentApi.exportTableTemplate(`${type}模版.${type === 'csv' ? type : 'xlsx'}`, type)
}

/** Clear the selected files when the document type radio changes. */
function radioChange() {
  form.value.fileList = []
}