feat: 知识库支持上传csv和excel

--story=1016154 --user=刘瑞斌 【知识库】-支持上传表格类型文档(Excel/CSV)按行分段 https://www.tapd.cn/57709429/s/1567910
This commit is contained in:
CaptainB 2024-08-22 16:52:32 +08:00 committed by 刘瑞斌
parent 3d1b3ea8d5
commit 57b15a8a7f
12 changed files with 274 additions and 1 deletions

View File

@ -0,0 +1,19 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_parse_table_handle.py
@date2024/5/21 14:56
@desc:
"""
from abc import ABC, abstractmethod
class BaseParseTableHandle(ABC):
    """Abstract interface for handlers that parse an uploaded table file.

    A handler first answers whether it can process a file (``support``) and,
    if so, converts the file into its parsed representation (``handle``).
    """

    @abstractmethod
    def support(self, file, get_buffer):
        """Tell whether this handler is able to process ``file``.

        :param file: the uploaded file object
        :param get_buffer: callable used to obtain the content of ``file``
        """

    @abstractmethod
    def handle(self, file, get_buffer):
        """Parse ``file`` and return the parsed result.

        :param file: the uploaded file object
        :param get_buffer: callable used to obtain the content of ``file``
        """

View File

@ -0,0 +1,34 @@
# coding=utf-8
import logging
from charset_normalizer import detect
from common.handle.base_parse_table_handle import BaseParseTableHandle
max_kb = logging.getLogger("max_kb")
class CsvSplitHandle(BaseParseTableHandle):
    """Split an uploaded CSV file into one paragraph per data row."""

    def support(self, file, get_buffer):
        """Return True when the file name ends with ``.csv`` (case-insensitive)."""
        return file.name.lower().endswith(".csv")

    def handle(self, file, get_buffer):
        """Parse the CSV content of ``file`` into a single document.

        The first row is treated as the column titles. Every following
        non-blank row becomes one paragraph whose content is
        ``"title:value; title:value; ..."``.

        :param file: uploaded file object; ``file.name`` is used as the document name
        :param get_buffer: callable returning the raw bytes of ``file``
        :return: a one-element list ``[{'name': ..., 'paragraphs': [...]}]``;
                 ``paragraphs`` is empty when the file cannot be decoded or parsed
        """
        # Imported here so this handler stays self-contained; both are stdlib.
        import csv
        import io

        buffer = get_buffer(file)
        try:
            content = buffer.decode(detect(buffer)['encoding'])
            # newline='' lets the csv module handle \r\n and newlines that are
            # embedded inside quoted fields; a naive split() breaks on both.
            rows = list(csv.reader(io.StringIO(content.lstrip('\ufeff'), newline='')))
        except Exception as e:
            # Best effort: a file that cannot be read yields an empty document
            # instead of failing the whole upload.
            max_kb.error(f'csv split handle error: {e}')
            return [{'name': file.name, 'paragraphs': []}]
        if not rows:
            return [{'name': file.name, 'paragraphs': []}]
        # The first row holds the column titles.
        title = rows[0]
        paragraphs = []
        for row in rows[1:]:
            # Skip blank lines (e.g. the trailing newline at end of file) so
            # they do not turn into meaningless paragraphs.
            if not any(cell.strip() for cell in row):
                continue
            line = '; '.join([f'{key}:{value}' for key, value in zip(title, row)])
            paragraphs.append({'title': '', 'content': line})
        return [{'name': file.name, 'paragraphs': paragraphs}]

View File

@ -0,0 +1,49 @@
# coding=utf-8
import io
import logging
from openpyxl import load_workbook
from common.handle.base_parse_table_handle import BaseParseTableHandle
max_kb = logging.getLogger("max_kb")
class ExcelSplitHandle(BaseParseTableHandle):
    """Split an uploaded Excel workbook into one document per sheet."""

    def support(self, file, get_buffer):
        """Return True when the file name ends with ``.xls`` or ``.xlsx``.

        NOTE(review): ``.xls`` is accepted here, but ``handle`` reads the file
        with openpyxl, which does not appear to support the legacy binary
        ``.xls`` format; such files would end up as an empty document.
        Confirm whether ``.xls`` should be routed to a different reader.
        """
        return file.name.lower().endswith(('.xls', '.xlsx'))

    def handle(self, file, get_buffer):
        """Parse every sheet of the workbook into a document.

        The first row of each sheet is treated as the column titles. Every
        following row that has at least one non-empty cell becomes a paragraph
        ``"title: value; title: value; ..."``. When the sheet name does not
        contain "sheet", the name is appended to each paragraph as
        ``" ——<sheet name>"``.

        :param file: uploaded file object; ``file.name`` names the fallback document
        :param get_buffer: callable returning the raw bytes of ``file``
        :return: list of ``{'name': <sheet name>, 'paragraphs': [...]}``; on any
                 read error a single empty document named after the file
        """
        buffer = get_buffer(file)
        try:
            wb = load_workbook(io.BytesIO(buffer))
            result = []
            for sheetname in wb.sheetnames:
                rows = list(wb[sheetname].rows)
                if not rows:
                    continue
                header = rows[0]
                # Sheet names containing "sheet" (e.g. the default "Sheet1")
                # are not appended to the paragraph text.
                suffix = "" if "sheet" in sheetname.lower() else " ——" + sheetname
                paragraphs = []
                for row in rows[1:]:
                    parts = []
                    for i, cell in enumerate(row):
                        value = cell.value
                        # Only skip truly empty cells; 0 and False are data.
                        if value is None or value == "":
                            continue
                        head = header[i].value if i < len(header) else None
                        # An empty header cell must not render as "None".
                        label = "" if head is None else str(head)
                        parts.append((label + ": " if label else "") + str(value))
                    if not parts:
                        # A row without any value carries no information.
                        continue
                    paragraphs.append({'title': '', 'content': "; ".join(parts) + suffix})
                result.append({'name': sheetname, 'paragraphs': paragraphs})
        except Exception as e:
            # Best effort: an unreadable workbook yields an empty document
            # instead of failing the whole upload.
            max_kb.error(f'excel split handle error: {e}')
            return [{'name': file.name, 'paragraphs': []}]
        return result

View File

@ -33,6 +33,8 @@ from common.handle.impl.pdf_split_handle import PdfSplitHandle
from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
from common.handle.impl.table.csv_parse_table_handle import CsvSplitHandle
from common.handle.impl.table.excel_parse_table_handle import ExcelSplitHandle
from common.handle.impl.text_split_handle import TextSplitHandle
from common.mixins.api_mixin import ApiMixin
from common.util.common import post, flat_map
@ -51,6 +53,7 @@ from embedding.task.embedding import embedding_by_document, delete_embedding_by_
from smartdoc.conf import PROJECT_DIR
parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()]
parse_table_handle_list = [CsvSplitHandle(), ExcelSplitHandle()]
class FileBufferHandle:
@ -152,6 +155,13 @@ class DocumentInstanceQASerializer(ApiMixin, serializers.Serializer):
error_messages=ErrMessage.file("文件")))
class DocumentInstanceTableSerializer(ApiMixin, serializers.Serializer):
    # Payload for a table import: a required list of uploaded files,
    # each entry itself being a required file field.
    file_list = serializers.ListSerializer(
        required=True,
        error_messages=ErrMessage.list("文件列表"),
        child=serializers.FileField(
            required=True,
            error_messages=ErrMessage.file("文件"),
        ),
    )
class DocumentSerializers(ApiMixin, serializers.Serializer):
class Export(ApiMixin, serializers.Serializer):
type = serializers.CharField(required=True, validators=[
@ -187,6 +197,23 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
'Content-Disposition': 'attachment; filename="excel_template.xlsx"'})
def table_export(self, with_valid=True):
    """Return the table import template as a file download.

    :param with_valid: when true, validate the serializer first (raises on
                       invalid data)
    :return: an ``HttpResponse`` carrying the CSV or Excel template selected
             by ``self.data['type']`` (``'csv'`` or ``'excel'``); ``None``
             for any other type
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    # type -> (template file name, Content-Type, download file name)
    templates = {
        'csv': ('MaxKB表格模板.csv', 'text/csv', 'csv_template.csv'),
        'excel': ('MaxKB表格模板.xlsx', 'application/vnd.ms-excel', 'excel_template.xlsx'),
    }
    template = templates.get(self.data.get('type'))
    if template is None:
        # Unknown types fall through without a response.
        return None
    template_name, content_type, download_name = template
    template_path = os.path.join(PROJECT_DIR, "apps", "dataset", 'template', template_name)
    # The context manager closes the file even if reading raises.
    with open(template_path, "rb") as file:
        content = file.read()
    return HttpResponse(content, status=200, headers={
        'Content-Type': content_type,
        'Content-Disposition': f'attachment; filename="{download_name}"'})
class Migrate(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True,
error_messages=ErrMessage.char(
@ -633,6 +660,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return parse_qa_handle.handle(file, get_buffer)
raise AppApiException(500, '不支持的文件格式')
@staticmethod
def parse_table_file(file):
    """Parse ``file`` with the first registered table handler that supports it.

    :param file: the uploaded file object
    :return: whatever the selected handler's ``handle`` returns
    :raises AppApiException: (code 500) when no handler supports the file
    """
    get_buffer = FileBufferHandle().get_buffer
    # Handlers are probed lazily and in registration order; the first match wins.
    matching = (handler for handler in parse_table_handle_list
                if handler.support(file, get_buffer))
    chosen = next(matching, None)
    if chosen is None:
        raise AppApiException(500, '不支持的文件格式')
    return chosen.handle(file, get_buffer)
def save_qa(self, instance: Dict, with_valid=True):
if with_valid:
DocumentInstanceQASerializer(data=instance).is_valid(raise_exception=True)
@ -641,6 +676,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
document_list = flat_map([self.parse_qa_file(file) for file in file_list])
return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)
# Import uploaded table files as documents of the dataset given by
# self.data['dataset_id'].
#   instance   -- dict holding the uploaded files under the 'file_list' key
#   with_valid -- when true, the payload is validated before anything is parsed
# Returns the result of DocumentSerializers.Batch(...).batch_save(...).
def save_table(self, instance: Dict, with_valid=True):
if with_valid:
DocumentInstanceTableSerializer(data=instance).is_valid(raise_exception=True)
# NOTE(review): indentation is not visible here -- presumably this call is
# also guarded by `with_valid`; confirm against the original file.
self.is_valid(raise_exception=True)
file_list = instance.get('file_list')
# Every file is parsed on its own; flat_map presumably flattens the per-file
# lists into one list of documents -- verify against common.util.common.
document_list = flat_map([self.parse_table_file(file) for file in file_list])
return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)
@post(post_function=post_embedding)
@transaction.atomic
def save(self, instance: Dict, with_valid=False, **kwargs):

View File

@ -0,0 +1,13 @@
职务,报销类型,一线城市报销标准(元),二线城市报销标准(元),三线城市报销标准(元)
普通员工,住宿费,500,400,300
部门主管,住宿费,600,500,400
部门总监,住宿费,700,600,500
区域总经理,住宿费,800,700,600
普通员工,伙食费,50,40,30
部门主管,伙食费,50,40,30
部门总监,伙食费,50,40,30
区域总经理,伙食费,50,40,30
普通员工,交通费,50,40,30
部门主管,交通费,50,40,30
部门总监,交通费,50,40,30
区域总经理,交通费,50,40,30
1 职务 报销类型 一线城市报销标准(元) 二线城市报销标准(元) 三线城市报销标准(元)
2 普通员工 住宿费 500 400 300
3 部门主管 住宿费 600 500 400
4 部门总监 住宿费 700 600 500
5 区域总经理 住宿费 800 700 600
6 普通员工 伙食费 50 40 30
7 部门主管 伙食费 50 40 30
8 部门总监 伙食费 50 40 30
9 区域总经理 伙食费 50 40 30
10 普通员工 交通费 50 40 30
11 部门主管 交通费 50 40 30
12 部门总监 交通费 50 40 30
13 区域总经理 交通费 50 40 30

Binary file not shown.

View File

@ -16,8 +16,10 @@ urlpatterns = [
path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
path('dataset/document/template/export', views.Template.as_view()),
path('dataset/document/table_template/export', views.TableTemplate.as_view()),
path('dataset/<str:dataset_id>/document/web', views.WebDocument.as_view()),
path('dataset/<str:dataset_id>/document/qa', views.QaDocument.as_view()),
path('dataset/<str:dataset_id>/document/table', views.TableDocument.as_view()),
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
path('dataset/<str:dataset_id>/document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()),
path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),

View File

@ -33,6 +33,17 @@ class Template(APIView):
def get(self, request: Request):
return DocumentSerializers.Export(data={'type': request.query_params.get('type')}).export(with_valid=True)
class TableTemplate(APIView):
    # Download endpoint for the table import template.
    authentication_classes = [TokenAuth]

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(
        operation_summary="获取表格模版",
        operation_id="获取表格模版",
        manual_parameters=DocumentSerializers.Export.get_request_params_api(),
        tags=["知识库/文档"],
    )
    def get(self, request: Request):
        # The template kind comes from the "type" query parameter and is
        # validated by the Export serializer before the file is returned.
        export_serializer = DocumentSerializers.Export(data={'type': request.query_params.get('type')})
        return export_serializer.table_export(with_valid=True)
class WebDocument(APIView):
authentication_classes = [TokenAuth]
@ -71,6 +82,24 @@ class QaDocument(APIView):
{'file_list': request.FILES.getlist('file')},
with_valid=True))
class TableDocument(APIView):
    # Upload endpoint that turns table files into documents of a dataset.
    authentication_classes = [TokenAuth]
    parser_classes = [MultiPartParser]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(
        operation_summary="导入表格并创建文档",
        operation_id="导入表格并创建文档",
        manual_parameters=DocumentWebInstanceSerializer.get_request_params_api(),
        responses=result.get_api_response(DocumentSerializers.Create.get_response_body_api()),
        tags=["知识库/文档"],
    )
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def post(self, request: Request, dataset_id: str):
        # Files arrive as repeated multipart "file" entries.
        create_serializer = DocumentSerializers.Create(data={'dataset_id': dataset_id})
        payload = {'file_list': request.FILES.getlist('file')}
        return result.success(create_serializer.save_table(payload, with_valid=True))
class Document(APIView):
authentication_classes = [TokenAuth]

View File

@ -211,6 +211,19 @@ const postQADocument: (
return post(`${prefix}/${dataset_id}/document/qa`, data, undefined, loading)
}
/**
 * Upload table files to a dataset so that documents are created from them.
 * @param dataset_id id of the target dataset (part of the request URL)
 * @param data request body handed to `post` unchanged
 * @param loading optional loading flag handed to `post`
 */
const postTableDocument: (
  dataset_id: string,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<any>> = (dataset_id, data, loading) => {
  const url = `${prefix}/${dataset_id}/document/table`
  return post(url, data, undefined, loading)
}
/**
*
* @param dataset_id,target_dataset_id,
@ -256,6 +269,18 @@ const exportQATemplate: (fileName: string, type: string, loading?: Ref<boolean>)
return exportExcel(fileName, `${prefix}/document/template/export`, { type }, loading)
}
/**
 * Download the table import template.
 * @param fileName file name handed to `exportExcel`
 * @param type template kind, sent to the backend as `{ type }`
 * @param loading optional loading flag handed to `exportExcel`
 */
const exportTableTemplate: (fileName: string, type: string, loading?: Ref<boolean>) => void = (
  fileName,
  type,
  loading
) => {
  const url = `${prefix}/document/table_template/export`
  const params = { type }
  return exportExcel(fileName, url, params, loading)
}
/**
*
* @param document_name
@ -295,6 +320,8 @@ export default {
putMigrateMulDocument,
batchEditHitHandling,
exportQATemplate,
exportTableTemplate,
postQADocument,
postTableDocument,
exportDocument
}

View File

@ -39,6 +39,7 @@ export function fileType(name: string) {
*/
// File extensions listed per upload category (txt / table / QA).
const typeList: any = {
txt: ['txt', 'pdf', 'docx', 'csv', 'md', 'html', 'PDF'],
// Table uploads: spreadsheet and CSV files.
table: ['xlsx', 'xls', 'csv'],
QA: ['xlsx', 'csv', 'xls']
}

View File

@ -78,6 +78,21 @@ async function next() {
router.push({ path: `/dataset/${id}/document` })
})
}
} else if (documentsType.value === 'table') {
let fd = new FormData()
documentsFiles.value.forEach((item: any) => {
if (item?.raw) {
fd.append('file', item?.raw)
}
})
if (id) {
// table
documentApi.postTableDocument(id as string, fd, loading).then((res) => {
MsgSuccess('提交成功')
clearStore()
router.push({ path: `/dataset/${id}/document` })
})
}
} else {
if (active.value++ > 2) active.value = 0
}

View File

@ -10,6 +10,7 @@
<el-form-item>
<el-radio-group v-model="form.fileType" @change="radioChange">
<el-radio value="txt">文本文件</el-radio>
<el-radio value="table">表格</el-radio>
<el-radio value="QA">QA 问答对</el-radio>
</el-radio-group>
</el-form-item>
@ -48,6 +49,42 @@
<el-divider direction="vertical" />
<el-button type="primary" link @click="downloadTemplate('csv')"> 下载 CSV 模板 </el-button>
</el-form-item>
<el-form-item prop="fileList" v-else-if="form.fileType === 'table'">
<el-upload
:webkitdirectory="false"
class="w-full mb-4"
drag
multiple
v-model:file-list="form.fileList"
action="#"
:auto-upload="false"
:show-file-list="false"
accept=".xlsx, .xls, .csv"
:limit="50"
:on-exceed="onExceed"
:on-change="fileHandleChange"
@click.prevent="handlePreview(false)"
>
<img src="@/assets/upload-icon.svg" alt="" />
<div class="el-upload__text">
<p>
拖拽文件至此上传或
<em class="hover" @click.prevent="handlePreview(false)"> 选择文件 </em>
<em class="hover" @click.prevent="handlePreview(true)"> 选择文件夹 </em>
</p>
<div class="upload__decoration">
<p>当前支持 EXCEL和CSV 格式文件</p>
<p>第一行必须是列标题且列标题必须是有意义的术语表中每条记录将作为一个分段</p>
<p>每次最多上传50个文档每个文档最大不能超过100MB</p>
</div>
</div>
</el-upload>
<el-button type="primary" link @click="downloadTableTemplate('excel')">
下载 Excel 模板
</el-button>
<el-divider direction="vertical" />
<el-button type="primary" link @click="downloadTableTemplate('csv')"> 下载 CSV 模板 </el-button>
</el-form-item>
<el-form-item prop="fileList" v-else>
<el-upload
:webkitdirectory="false"
@ -73,7 +110,7 @@
</p>
<div class="upload__decoration">
<p>
支持格式TXTMarkdownPDFDOCXHTML 每次最多上传50个文件每个文件不超过 100MB
支持格式TXTMarkdownPDFDOCXHTMLExcelCSV 每次最多上传50个文件每个文件不超过 100MB
</p>
<p>若使用高级分段建议上传前规范文件的分段标识</p>
</div>
@ -133,6 +170,10 @@ function downloadTemplate(type: string) {
documentApi.exportQATemplate(`${type}模版.${type == 'csv' ? type : 'xlsx'}`, type)
}
// Download the table template; anything other than 'csv' is saved as .xlsx.
function downloadTableTemplate(type: string) {
  const extension = type == 'csv' ? type : 'xlsx'
  const fileName = `${type}模版.${extension}`
  documentApi.exportTableTemplate(fileName, type)
}
// Empty the selected file list; wired to the file-type radio group's @change.
function radioChange() {
form.value.fileList = []
}