@staticmethod
def merge_problem(paragraph_list: List[Dict], problem_mapping_list: List[Dict]):
    """Group paragraphs into per-document sheet rows and choose a sheet name per document.

    NOTE(review): ``export_excel`` in this class calls
    ``DocumentSerializers.Operate.merge_problem`` (three arguments) rather than
    this two-argument variant; confirm whether this copy is still needed.

    :param paragraph_list: paragraph rows, read through ``.get`` for the keys
        ``id``, ``document_id``, ``document_name``, ``title`` and ``content``.
    :param problem_mapping_list: rows read through ``.get`` for ``paragraph_id``
        and ``content``; every ``content`` whose ``paragraph_id`` equals a
        paragraph ``id`` is attached to that paragraph.
    :return: ``(rows_by_document, sheet_name_by_document)``.
        ``rows_by_document`` maps a document id to a list of
        ``[title, content, problems]`` rows where ``problems`` is the matching
        problem contents joined with ``'\\n'``.
        ``sheet_name_by_document`` maps a document id to a sheet name that is
        unique within the returned mapping.
    """
    # Index the problems once instead of rescanning the whole mapping list
    # for every paragraph.
    problems_by_paragraph = {}
    for problem_mapping in problem_mapping_list:
        problems_by_paragraph.setdefault(problem_mapping.get('paragraph_id'), []).append(
            problem_mapping.get('content'))

    result = {}
    # document name -> document ids in first-seen order. A dict is used as an
    # ordered set so that suffix assignment below is deterministic.
    document_dict = {}
    for paragraph in paragraph_list:
        document_id = paragraph.get('document_id')
        document_dict.setdefault(paragraph.get('document_name'), {})[document_id] = None
        result.setdefault(document_id, []).append([
            paragraph.get('title'),
            paragraph.get('content'),
            '\n'.join(problems_by_paragraph.get(paragraph.get('id'), [])),
        ])

    result_document_dict = {}
    # Every real document name is reserved up front, so a generated
    # "name + suffix" can never shadow another document's own name.
    taken = set(document_dict)
    duplicates = []
    for d_name, document_ids in document_dict.items():
        for index, d_id in enumerate(document_ids):
            if index == 0:
                # The first document seen under a name keeps the plain name.
                result_document_dict[d_id] = d_name
            else:
                duplicates.append((d_name, index, d_id))
    for d_name, suffix, d_id in duplicates:
        candidate = d_name + str(suffix)
        # Skip suffixes that would collide with a real or already generated name.
        while candidate in taken:
            suffix += 1
            candidate = d_name + str(suffix)
        taken.add(candidate)
        result_document_dict[d_id] = candidate
    return result, result_document_dict
@staticmethod
def merge_problem(paragraph_list: List[Dict], problem_mapping_list: List[Dict], document_list):
    """Build the per-document sheet rows and a unique sheet name for every document.

    :param paragraph_list: paragraph rows, read through ``.get`` for the keys
        ``id``, ``document_id``, ``document_name``, ``title`` and ``content``.
    :param problem_mapping_list: rows read through ``.get`` for ``paragraph_id``
        and ``content``.
    :param document_list: objects exposing ``id`` and ``name``; a document that
        owns no paragraph still gets a sheet holding one empty row.
    :return: ``(rows_by_document, sheet_name_by_document)``.
        ``rows_by_document`` maps a document id to ``[title, content, problems]``
        rows, ``problems`` being the matching problem contents joined with
        ``'\\n'``. ``sheet_name_by_document`` maps a document id to a sheet
        name that is unique ignoring case and at most 31 characters long.
    """
    # Worksheet names longer than this are rejected by xlwt's name validation.
    max_sheet_name_length = 31

    # Index the problems once instead of rescanning the whole mapping list
    # for every paragraph.
    problems_by_paragraph = {}
    for problem_mapping in problem_mapping_list:
        problems_by_paragraph.setdefault(problem_mapping.get('paragraph_id'), []).append(
            problem_mapping.get('content'))

    result = {}
    # sanitized document name -> document ids in first-seen order (a dict is
    # used as an ordered set so suffix assignment is deterministic).
    document_dict = {}
    for paragraph in paragraph_list:
        document_id = paragraph.get('document_id')
        document_name = DocumentSerializers.Operate.reset_document_name(paragraph.get('document_name'))
        document_dict.setdefault(document_name, {})[document_id] = None
        result.setdefault(document_id, []).append([
            paragraph.get('title'),
            paragraph.get('content'),
            '\n'.join(problems_by_paragraph.get(paragraph.get('id'), [])),
        ])
    for document in document_list:
        if document.id not in result:
            document_name = DocumentSerializers.Operate.reset_document_name(document.name)
            # Keep a sheet for documents without paragraphs.
            result[document.id] = [[]]
            document_dict.setdefault(document_name, {})[document.id] = None

    result_document_dict = {}
    # Lower-cased names already handed out; duplicates are detected ignoring
    # case because two sheets that differ only by case cannot coexist.
    taken = set()
    duplicates = []
    # Pass 1: the first document of each name keeps the plain name if it is free.
    for d_name, document_ids in document_dict.items():
        for index, d_id in enumerate(document_ids):
            if index == 0 and d_name.lower() not in taken:
                taken.add(d_name.lower())
                result_document_dict[d_id] = d_name
            else:
                duplicates.append((d_name, max(index, 1), d_id))
    # Pass 2: every remaining document gets "name + number", skipping numbers
    # that collide and trimming the base so the result stays within the limit.
    for d_name, suffix, d_id in duplicates:
        while True:
            tail = str(suffix)
            candidate = d_name[:max_sheet_name_length - len(tail)] + tail
            if candidate.lower() not in taken:
                break
            suffix += 1
        taken.add(candidate.lower())
        result_document_dict[d_id] = candidate
    return result, result_document_dict


@staticmethod
def reset_document_name(document_name):
    """Return a worksheet-safe name for a document.

    The name is stripped first and validated afterwards, so the value that is
    returned is the value that was checked. ``None`` and any name that
    ``Utils.valid_sheet_name`` rejects after stripping fall back to ``"Sheet"``.

    :param document_name: raw document name, may be ``None``.
    :return: the stripped name, or ``"Sheet"`` when it cannot be used as-is.
    """
    if document_name is None:
        return "Sheet"
    stripped_name = document_name.strip()
    if not Utils.valid_sheet_name(stripped_name):
        return "Sheet"
    return stripped_name
class Export(APIView):
    """GET endpoint that exports one dataset through ``DataSetSerializers.Operate.export_excel``."""
    authentication_classes = [TokenAuth]

    # ``methods`` is a list of HTTP method names, matching Document.Export;
    # a bare string would be iterated character by character.
    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="导出知识库", operation_id="导出知识库",
                         manual_parameters=DataSetSerializers.Operate.get_request_params_api(),
                         tags=["知识库"]
                         )
    @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                                    dynamic_tag=keywords.get('dataset_id')))
    def get(self, request: Request, dataset_id: str):
        """Return the export response for ``dataset_id``.

        :param request: current request; ``request.user.id`` is forwarded as ``user_id``.
        :param dataset_id: id of the dataset to export, taken from the URL.
        :return: whatever ``export_excel()`` returns; unlike the sibling views
            the value is not wrapped in ``result.success``.
        """
        return DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).export_excel()
/**
 * Export a document as an Excel file.
 * @param document_name document name, used (with a ".xls" suffix) as the download file name
 * @param dataset_id id of the dataset that owns the document
 * @param document_id id of the document to export
 * @param loading optional loading indicator forwarded to exportExcel
 * @returns the promise returned by exportExcel
 */
const exportDocument: (
  document_name: string,
  dataset_id: string,
  document_id: string,
  loading?: Ref<boolean>
) => Promise<any> = (document_name, dataset_id, document_id, loading) => {
  const fileName = document_name + '.xls'
  const url = `${prefix}/${dataset_id}/document/${document_id}/export`
  return exportExcel(fileName, url, {}, loading)
}
26.39732107 26.39732108v316.40624999a26.33705392 26.33705392 0 1 1-52.734375 1e-8V169.25669607a26.33705392 26.33705392 0 0 0-26.39732107-26.39732107H169.25669607a26.33705392 26.33705392 0 0 0-26.39732107 26.39732108v685.48660785z', + fill: 'currentColor' + }), + h('path', { + d: 'M797.97098203 650.85714298H274.72544608a26.33705392 26.33705392 0 1 0-1e-8 52.734375h523.90848285L638.20089298 863.96428595a26.33705392 26.33705392 0 0 0 37.30580309 37.24553513l205.09151799-205.09151715a26.39732108 26.39732108 0 0 0 0-37.24553596L675.50669608 453.78125a26.39732108 26.39732108 0 1 0-37.30580311 37.36607108l159.77008905 159.7098219z', + fill: 'currentColor' + }) + ] + ) + ]) + } } } diff --git a/ui/src/request/index.ts b/ui/src/request/index.ts index 7d748ffd8..354bf639f 100644 --- a/ui/src/request/index.ts +++ b/ui/src/request/index.ts @@ -211,8 +211,13 @@ export const exportExcel: ( url: string, params: any, loading?: NProgress | Ref -) => void = (fileName: string, url: string, params: any, loading?: NProgress | Ref) => { - promise(request({ url: url, method: 'get', params, responseType: 'blob' }), loading) +) => Promise = ( + fileName: string, + url: string, + params: any, + loading?: NProgress | Ref +) => { + return promise(request({ url: url, method: 'get', params, responseType: 'blob' }), loading) .then((res: any) => { if (res) { const blob = new Blob([res], { @@ -225,6 +230,7 @@ export const exportExcel: ( //释放内存 window.URL.revokeObjectURL(link.href) } + return true }) .catch((e) => {}) } diff --git a/ui/src/views/dataset/index.vue b/ui/src/views/dataset/index.vue index 478c5a3ee..c4bec6a79 100644 --- a/ui/src/views/dataset/index.vue +++ b/ui/src/views/dataset/index.vue @@ -85,7 +85,12 @@ 设置 + 设置 + + 导出 删除 { + datasetApi.exportDataset(item.name, item.id, loading).then((ok) => { + MsgSuccess('导出成功') + }) +} function deleteDataset(row: any) { MsgConfirm( diff --git a/ui/src/views/document/index.vue b/ui/src/views/document/index.vue index 6b1505f9c..7d8be9711 100644 --- 
/**
 * Export one document as an .xls file and confirm it to the user.
 * @param document table row of the document; its name, dataset_id and id are read
 */
const exportDocument = (document: any) => {
  documentApi
    .exportDocument(document.name, document.dataset_id, document.id, loading)
    .then((ok: any) => {
      // exportExcel catches request errors and then resolves with undefined,
      // so the success toast is only shown for a truthy result.
      if (ok) {
        MsgSuccess('导出成功')
      }
    })
}