From 28938104c03c3773864796007ccf0b2261c8c123 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Thu, 23 May 2024 18:57:49 +0800 Subject: [PATCH] =?UTF-8?q?*=20feat:=20=E6=94=AF=E6=8C=81=E4=B8=8A?= =?UTF-8?q?=E4=BC=A0=20Excel/CSV=20=E7=B1=BB=E5=9E=8B=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E7=AD=94=E5=AF=B9=20(#430)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/handle/base_parse_qa_handle.py | 19 ++++ .../handle/impl/qa/csv_parse_qa_handle.py | 56 ++++++++++ .../handle/impl/qa/xls_parse_qa_handle.py | 55 ++++++++++ .../handle/impl/qa/xlsx_parse_qa_handle.py | 56 ++++++++++ apps/common/util/common.py | 12 +++ apps/common/util/field_message.py | 10 ++ .../serializers/dataset_serializers.py | 82 ++++++++++++++- .../serializers/document_serializers.py | 99 ++++++++++++++++-- apps/dataset/template/csv_template.csv | 8 ++ apps/dataset/template/excel_template.xlsx | Bin 0 -> 10639 bytes apps/dataset/urls.py | 3 + apps/dataset/views/dataset.py | 22 ++++ apps/dataset/views/document.py | 32 ++++++ pyproject.toml | 2 + ui/src/api/dataset.ts | 17 +++ ui/src/api/document.ts | 33 +++++- ui/src/assets/xls-icon.svg | 5 + ui/src/assets/xlsx-icon.svg | 5 + ui/src/stores/modules/dataset.ts | 5 + ui/src/utils/utils.ts | 14 ++- ui/src/views/dataset/CreateDataset.vue | 31 +++++- .../dataset/component/UploadComponent.vue | 66 +++++++++++- ui/src/views/dataset/step/StepFirst.vue | 3 + 23 files changed, 611 insertions(+), 24 deletions(-) create mode 100644 apps/common/handle/base_parse_qa_handle.py create mode 100644 apps/common/handle/impl/qa/csv_parse_qa_handle.py create mode 100644 apps/common/handle/impl/qa/xls_parse_qa_handle.py create mode 100644 apps/common/handle/impl/qa/xlsx_parse_qa_handle.py create mode 100644 apps/dataset/template/csv_template.csv create mode 100644 apps/dataset/template/excel_template.xlsx create mode 100644 ui/src/assets/xls-icon.svg create mode 100644 ui/src/assets/xlsx-icon.svg diff --git a/apps/common/handle/base_parse_qa_handle.py b/apps/common/handle/base_parse_qa_handle.py new file mode 100644 index 000000000..948b8e536 --- /dev/null +++ b/apps/common/handle/base_parse_qa_handle.py @@ -0,0 +1,19 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_parse_qa_handle.py + @date:2024/5/21 14:56 + @desc: +""" +from abc import ABC, abstractmethod + + +class BaseParseQAHandle(ABC): + @abstractmethod + def support(self, file, get_buffer): + pass + + @abstractmethod + def handle(self, file, get_buffer): + pass diff --git a/apps/common/handle/impl/qa/csv_parse_qa_handle.py b/apps/common/handle/impl/qa/csv_parse_qa_handle.py new file mode 100644 index 000000000..6952bf8d8 --- /dev/null +++ b/apps/common/handle/impl/qa/csv_parse_qa_handle.py @@ -0,0 +1,56 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: csv_parse_qa_handle.py + @date:2024/5/21 14:59 + @desc: +""" +import csv +import io + +from charset_normalizer import detect + +from common.handle.base_parse_qa_handle import BaseParseQAHandle + + +def read_csv_standard(file_path): + data = [] + with open(file_path, 'r') as file: + reader = csv.reader(file) + for row in reader: + data.append(row) + return data + + +class CsvParseQAHandle(BaseParseQAHandle): + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".csv"): + return True + return False + + def handle(self, file, get_buffer): + buffer = get_buffer(file) + reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding'])) + try: + title_row_list = reader.__next__() + except Exception as e: + return [] + title_row_index_dict = {'title': 0, 'content': 1, 'problem_list': 2} + for index in range(len(title_row_list)): + title_row = title_row_list[index] + if title_row.startswith('分段标题'): + title_row_index_dict['title'] = index + if title_row.startswith('分段内容'): + title_row_index_dict['content'] = index + if title_row.startswith('问题'): + title_row_index_dict['problem_list'] = index + paragraph_list = [] + for row in reader: + problem = row[title_row_index_dict.get('problem_list')] + problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] + paragraph_list.append({'title': row[title_row_index_dict.get('title')][0:255], + 'content': row[title_row_index_dict.get('content')][0:4096], + 'problem_list': problem_list}) + return [{'name': file.name, 'paragraphs': paragraph_list}] diff --git a/apps/common/handle/impl/qa/xls_parse_qa_handle.py b/apps/common/handle/impl/qa/xls_parse_qa_handle.py new file mode 100644 index 000000000..090745141 --- /dev/null +++ b/apps/common/handle/impl/qa/xls_parse_qa_handle.py @@ -0,0 +1,55 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: xls_parse_qa_handle.py + @date:2024/5/21 14:59 + @desc: +""" + +import xlrd + +from common.handle.base_parse_qa_handle import BaseParseQAHandle + + +def handle_sheet(file_name, sheet): + rows = iter([sheet.row_values(i) for i in range(sheet.nrows)]) + try: + title_row_list = next(rows) + except Exception as e: + return None + title_row_index_dict = {'title': 0, 'content': 1, 'problem_list': 2} + for index in range(len(title_row_list)): + title_row = str(title_row_list[index]) + if title_row.startswith('分段标题'): + title_row_index_dict['title'] = index + if title_row.startswith('分段内容'): + title_row_index_dict['content'] = index + if title_row.startswith('问题'): + title_row_index_dict['problem_list'] = index + paragraph_list = [] + for row in rows: + problem = str(row[title_row_index_dict.get('problem_list')]) + problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] + paragraph_list.append({'title': str(row[title_row_index_dict.get('title')])[0:255], + 'content': str(row[title_row_index_dict.get('content')])[0:4096], + 'problem_list': problem_list}) + return {'name': file_name, 'paragraphs': paragraph_list} + + +class XlsParseQAHandle(BaseParseQAHandle): + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".xls"): + return True + return False + + def handle(self, file, get_buffer): + buffer = get_buffer(file) + workbook = xlrd.open_workbook(file_contents=buffer) + worksheets = workbook.sheets() + worksheets_size = len(worksheets) + return [row for row in + [handle_sheet(file.name, sheet) if worksheets_size == 1 and sheet.name == 'Sheet1' else handle_sheet( + sheet.name, sheet) for sheet + in worksheets] if row is not None] diff --git a/apps/common/handle/impl/qa/xlsx_parse_qa_handle.py b/apps/common/handle/impl/qa/xlsx_parse_qa_handle.py new file mode 100644 index 000000000..9ed0ac06a --- /dev/null +++ b/apps/common/handle/impl/qa/xlsx_parse_qa_handle.py @@ -0,0 +1,56 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: xlsx_parse_qa_handle.py + @date:2024/5/21 14:59 + @desc: +""" +import io + +import openpyxl + +from common.handle.base_parse_qa_handle import BaseParseQAHandle + + +def handle_sheet(file_name, sheet): + rows = sheet.rows + try: + title_row_list = next(rows) + except Exception as e: + return None + title_row_index_dict = {} + for index in range(len(title_row_list)): + title_row = str(title_row_list[index].value) + if title_row.startswith('分段标题'): + title_row_index_dict['title'] = index + if title_row.startswith('分段内容'): + title_row_index_dict['content'] = index + if title_row.startswith('问题'): + title_row_index_dict['problem_list'] = index + paragraph_list = [] + for row in rows: + problem = str(row[title_row_index_dict.get('problem_list')].value) + problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] + paragraph_list.append({'title': str(row[title_row_index_dict.get('title')].value)[0:255], + 'content': str(row[title_row_index_dict.get('content')].value)[0:4096], + 'problem_list': problem_list}) + return {'name': file_name, 'paragraphs': paragraph_list} + + +class XlsxParseQAHandle(BaseParseQAHandle): + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".xlsx"): + return True + return False + + def handle(self, file, get_buffer): + buffer = get_buffer(file) + workbook = openpyxl.load_workbook(io.BytesIO(buffer)) + worksheets = workbook.worksheets + worksheets_size = len(worksheets) + return [row for row in + [handle_sheet(file.name, sheet) if worksheets_size == 1 and sheet.title == 'Sheet1' else handle_sheet( + sheet.title, sheet) for sheet + in worksheets] if row is not None] diff --git a/apps/common/util/common.py b/apps/common/util/common.py index 52d90ec85..9ba50c671 100644 --- a/apps/common/util/common.py +++ b/apps/common/util/common.py @@ -45,6 +45,18 @@ def get_exec_method(clazz_: str, method_: str): return getattr(getattr(package_model, clazz_name), method_) +def flat_map(array: List[List]): + """ + 将二位数组转为一维数组 + :param array: 二维数组 + :return: 一维数组 + """ + result = [] + for e in array: + result += e + return result + + def post(post_function): def inner(func): def run(*args, **kwargs): diff --git a/apps/common/util/field_message.py b/apps/common/util/field_message.py index 93b51b920..3d3945400 100644 --- a/apps/common/util/field_message.py +++ b/apps/common/util/field_message.py @@ -104,3 +104,13 @@ class ErrMessage: 'invalid_image': gettext_lazy('【%s】上载有效的图像。您上载的文件不是图像或图像已损坏。' % field), 'max_length': gettext_lazy('请确保此文件名最多包含 {max_length} 个字符(长度为 {length})。') } + + @staticmethod + def file(field: str): + return { + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'empty': gettext_lazy('【%s】提交的文件为空。' % field), + 'invalid': gettext_lazy('【%s】提交的数据不是文件。请检查表单上的编码类型。' % field), + 'no_name': gettext_lazy('【%s】无法确定任何文件名。' % field), + 'max_length': gettext_lazy('请确保此文件名最多包含 {max_length} 个字符(长度为 {length})。') + } diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index d3f5af73a..61d4b1a3a 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -18,7 +18,7 @@ from urllib.parse import urlparse from django.contrib.postgres.fields import ArrayField from django.core import validators from django.db import transaction, models -from django.db.models import QuerySet, Q +from django.db.models import QuerySet from drf_yasg import openapi from rest_framework import serializers @@ -29,7 +29,7 @@ from common.db.sql_execute import select_list from common.event import ListenerManagement, SyncWebDatasetArgs from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin -from common.util.common import post +from common.util.common import post, flat_map from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import ChildLink, Fork @@ -210,6 +210,75 @@ class DataSetSerializers(serializers.ModelSerializer): super().is_valid(raise_exception=True) return True + class CreateQASerializers(serializers.Serializer): + """ + 创建web站点序列化对象 + """ + name = serializers.CharField(required=True, + error_messages=ErrMessage.char("知识库名称"), + max_length=64, + min_length=1) + + desc = serializers.CharField(required=True, + error_messages=ErrMessage.char("知识库描述"), + max_length=256, + min_length=1) + + file_list = serializers.ListSerializer(required=True, + error_messages=ErrMessage.list("文件列表"), + child=serializers.FileField(required=True, + error_messages=ErrMessage.file("文件"))) + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='file', + in_=openapi.IN_FORM, + type=openapi.TYPE_ARRAY, + items=openapi.Items(type=openapi.TYPE_FILE), + required=True, + description='上传文件'), + openapi.Parameter(name='name', + in_=openapi.IN_FORM, + required=True, + type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"), + openapi.Parameter(name='desc', + in_=openapi.IN_FORM, + required=True, + type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"), + ] + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count', + 'update_time', 'create_time', 'document_list'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称", + description="名称", default="测试知识库"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述", + description="描述", default="测试知识库描述"), + 'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id", + description="所属用户id", default="user_xxxx"), + 'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数", + description="字符数", default=10), + 'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量", + description="文档数量", default=1), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ), + 'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表", + description="文档列表", + items=DocumentSerializers.Operate.get_response_body_api()) + } + ) + class CreateWebSerializers(serializers.Serializer): """ 创建web站点序列化对象 @@ -288,6 +357,15 @@ class DataSetSerializers(serializers.ModelSerializer): ListenerManagement.embedding_by_dataset_signal.send(dataset_id) return document_list + def save_qa(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + self.CreateQASerializers(data=instance).is_valid() + file_list = instance.get('file_list') + document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list]) + dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list} + return self.save(dataset_instance, with_valid=True) + @post(post_function=post_embedding_dataset) @transaction.atomic def save(self, instance: Dict, with_valid=True): diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index c5f88e336..5d6518ad5 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -17,6 +17,7 @@ from typing import List, Dict from django.core import validators from django.db import transaction from django.db.models import QuerySet +from django.http import HttpResponse from drf_yasg import openapi from rest_framework import serializers @@ -27,9 +28,12 @@ from common.exception.app_exception import AppApiException from common.handle.impl.doc_split_handle import DocSplitHandle from common.handle.impl.html_split_handle import HTMLSplitHandle from common.handle.impl.pdf_split_handle import PdfSplitHandle +from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle +from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle +from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle from common.handle.impl.text_split_handle import TextSplitHandle from common.mixins.api_mixin import ApiMixin -from common.util.common import post +from common.util.common import post, flat_map from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import Fork @@ -39,6 +43,17 @@ from dataset.serializers.common_serializers import BatchSerializer, MetaSerializ from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer from smartdoc.conf import PROJECT_DIR +parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()] + + +class FileBufferHandle: + buffer = None + + def get_buffer(self, file): + if self.buffer is None: + self.buffer = file.read() + return self.buffer + class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer): meta = serializers.DictField(required=False) @@ -87,16 +102,19 @@ class DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer): "选择器")) @staticmethod - def get_request_body_api(): - return openapi.Schema( - type=openapi.TYPE_OBJECT, - required=['source_url_list'], - properties={ - 'source_url_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="段落列表", description="段落列表", - items=openapi.Schema(type=openapi.TYPE_STRING)), - 'selector': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称") - } - ) + def get_request_params_api(): + return [openapi.Parameter(name='file', + in_=openapi.IN_FORM, + type=openapi.TYPE_ARRAY, + items=openapi.Items(type=openapi.TYPE_FILE), + required=True, + description='上传文件'), + openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id'), + ] class DocumentInstanceSerializer(ApiMixin, serializers.Serializer): @@ -120,7 +138,48 @@ class DocumentInstanceSerializer(ApiMixin, serializers.Serializer): ) +class DocumentInstanceQASerializer(ApiMixin, serializers.Serializer): + file_list = serializers.ListSerializer(required=True, + error_messages=ErrMessage.list("文件列表"), + child=serializers.FileField(required=True, + error_messages=ErrMessage.file("文件"))) + + class DocumentSerializers(ApiMixin, serializers.Serializer): + class Export(ApiMixin, serializers.Serializer): + type = serializers.CharField(required=True, validators=[ + validators.RegexValidator(regex=re.compile("^csv|excel$"), + message="模版类型只支持excel|csv", + code=500) + ], error_messages=ErrMessage.char("模版类型")) + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='type', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=True, + description='导出模板类型csv|excel'), + + ] + + def export(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + + if self.data.get('type') == 'csv': + file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', 'csv_template.csv'), "rb") + content = file.read() + file.close() + return HttpResponse(content, status=200, headers={'Content-Type': 'text/cxv', + 'Content-Disposition': 'attachment; filename="csv_template.csv"'}) + elif self.data.get('type') == 'excel': + file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', 'excel_template.xlsx'), "rb") + content = file.read() + file.close() + return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel', + 'Content-Disposition': 'attachment; filename="excel_template.xlsx"'}) + class Migrate(ApiMixin, serializers.Serializer): dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char( @@ -473,6 +532,22 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): ListenerManagement.embedding_by_document_signal.send(document_id) return result + @staticmethod + def parse_qa_file(file): + for parse_qa_handle in parse_qa_handle_list: + get_buffer = FileBufferHandle().get_buffer + if parse_qa_handle.support(file, get_buffer): + return parse_qa_handle.handle(file, get_buffer) + raise AppApiException(500, '不支持的文件格式') + + def save_qa(self, instance: Dict, with_valid=True): + if with_valid: + DocumentInstanceQASerializer(data=instance).is_valid(raise_exception=True) + self.is_valid(raise_exception=True) + file_list = instance.get('file_list') + document_list = flat_map([self.parse_qa_file(file) for file in file_list]) + return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list) + @post(post_function=post_embedding) @transaction.atomic def save(self, instance: Dict, with_valid=False, **kwargs): @@ -714,6 +789,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): problem_paragraph_mapping_list) > 0 else None # 查询文档 query_set = QuerySet(model=Document) + if len(document_model_list) == 0: + return [], query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]}) return native_search(query_set, select_string=get_file_content( os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False), diff --git a/apps/dataset/template/csv_template.csv b/apps/dataset/template/csv_template.csv new file mode 100644 index 000000000..b306a9c17 --- /dev/null +++ b/apps/dataset/template/csv_template.csv @@ -0,0 +1,8 @@ +分段标题(选填),分段内容(必填,问题答案,最长不超过4096个字符)),问题(选填,单元格内一行一个) +MaxKB产品介绍,"MaxKB 是一款基于 LLM 大语言模型的知识库问答系统。MaxKB = Max Knowledge Base,旨在成为企业的最强大脑。 +开箱即用:支持直接上传文档、自动爬取在线文档,支持文本自动拆分、向量化,智能问答交互体验好; +无缝嵌入:支持零编码快速嵌入到第三方业务系统; +多模型支持:支持对接主流的大模型,包括 Ollama 本地私有大模型(如 Llama 2、Llama 3、qwen)、通义千问、OpenAI、Azure OpenAI、Kimi、智谱 AI、讯飞星火和百度千帆大模型等。","MaxKB是什么? +MaxKB产品介绍 +MaxKB支持的大语言模型 +MaxKB优势" diff --git a/apps/dataset/template/excel_template.xlsx b/apps/dataset/template/excel_template.xlsx new file mode 100644 index 0000000000000000000000000000000000000000..6517b154ffeea0a2136d98192026cf0c8655bcd4 GIT binary patch literal 10639 zcma)iWmp}{wk5D+ym5D>uM#dNK$>6|Ss(xN*h!5Ps+u6>>WpMldD$dpNxAwfhNUW#9AcBF9= zB?Qm#)SSwDm6#4T*f4$dM){^lb@vP``+ zan@IRgVq(bBEdpgC{d=6plLoizSbULVoJD`6rIT0SsMKPLHXl}pZX)=zFgIv$rM}C zO++~M*e4HRIXbvU`3lY*+OuK511Nw=4Aa1wF3;1SwC4qfOYu5!rs5#somuKLVZ&W0 z2N+K>-ncqHHh6B9HtyaL4KuNMH@0maKkgE)kKA3BM~|9(7!IpchHOzmh=*hiWLJU>L}9iTGbN>Ygw$Z1XsS|fstt<9q#(W2r+3hZ_{IM&u2$&92zc^d>VLSQ|S8J zOT3wTc)R$8iD7v9c&X{DoA1p0he8d`6KOVRN(K*xN1O5)Zlq_pQU3wAgPlFl@&)j`C{3wu2DE@<$S0uz zkK_nCO|FITj|e?*kPh|o3-;7C5)xDUfTw0`yAaHkz3p`#Ztml^r?(Dx9ReoP)JgmY z$S{5u3X8*OAL6n-7^YIT>Hx#DYbSPe226&r44Z?Y`u;jLjzQ5xzU)pRYATd~YH z8&#TeKe(2b;a(z^lpo^C6&n%w*CZmNtQoN>g_+!Tbut4TDm_K&ir;=sI@+|cZ8E)w z<@2$s@kr^YkH$0{gcU0dc_j=xDz_yjB`heYG8Q5ot9T4}S6~R`T`H#hkU!5k*I6F0 z->Z@)PY_K3M|sIC_{k6hDn%~}+FRQ&^>d-C)&!PkMgNcN)Az(rukG)wcU|l3^v@d! z7s@91?y=V4Ay>LcvA<`d_j?RTzJ64uwUlO(KAYw&@e-sn)8kc{Yb)O7TLRW1d;X zI79@hf<+7e#?`TtSK=E<#Y@P=?j&a8F|P#n@i0>83DLs3T$gA$Fk0m}lX|j;hr{)Q zWE6c>b;cq!ihVdN*XrAQ)34MI296-{L_+40P1SYkxkc+-CW#UX(6rxYB-uwc3Rd#- zr#q-JW(H~oKASEuOip!FP7tokI?!?~QMUmzYrTbi4KOZK()LC=t~5M?DyB;NwV-*( zQI5N1wubvLZ*e+nnTqRRm8EfVH67o+-!=ANyZ6~csK&|9xLQU>ahhKe71z~vZ}ugs ztl;)oCHF8!Ss_qN#flx5uRLt6DsD86pPZd~zb02uVn{Y2!tR=#MYnI+rV>HX$FmXP zwUXL`I?To=09G%%E|=AV?;R%p9`OgdP}1fVNdG_sI<$&!ELm&HI_2%I`5HBbxW51N zb=HC}A4^2Cf^2nbjXBv<$k2C5&i5g5Y$;=bafp$G#ZO>*&DDh#$!l4PZqy=ht7{-A z>{4o}{gKnXvT({MaGR1i>)%1h5Zn+E0{6=XpJi*e7)iWJ|Ufax%gMZY!@iLoZa zx4UN34z;#g$^s&Pe;un6-npk6M>`E39z=-jln``}SJftrh8&e$T8MyuwmFoYb^J+) zJEWU61Xo0l{)vugqA3U~gJdotNJ2O@SO^8Ug)$VxYeMGR(N$nrqsvYc>iGTHieF5d6l8Ge$-69_ zk-!fZVk>*k_By5^V0Bol9_N$k#%c-5lxNwatXjuakwxNzd{` zpbI`1j)we{4mumGVo%%i48;7*6i%&t+UoG-6r+x2niX7K&Dc5hR) zjWEqO?BtWad~6f@P#^auun7JAleO~(4?dykQYC?Pg_q>blh_#d=?^c$wy>IZMVJdG zgB!^-E;|oqah7?WBK_aw_7XrnlbgMXp{3y;0hHhE9sG}H%4&O_{g>0_FR`D(KN;&t zbuQ^?sgy6wK zq;#L*3r;LYoP9`J&IZSVa+oJlCem!tW*t$4rx?WLz|4-*FS;K5#}FZVi_pw;#9iK- z_E2|B*5vYjYBZ4}NPWEkBXAv*c6kbp4TS~^ONgVJ>$yJm7>8ez04LWdFc*DzpiP$m z7n*O#L!JR2B0(KzmggwD_XULNi%ptq!CVpgEei-j;97oF?_{HTMNRVCIj|(Li0@-o zjY|$E@sCpGY)-?YzJ=e`SO@khO4ko(jhZ($SO?Zh%zsG9QJX2@J>VFe)Uvu4sF^Bq zA!_;RNlxwKaE1)k%CL%IMl$9EYB{%lI3TF8`+)*!K~b2Ok5#SU8XTW_7wWWLA}B{H zfivI|&~E}$V@WItl?owdfFi{Tz0f^yYZVZnXDKd?8AME;WZ_3>1oKL_=yLY_Sc}IK zr^WSg>wdlZJdiJ)&(;3%1UhFJnOk_j2gChlw7}E-Qtsq#xaO&7V|tj+FF9t*Y_5Q4DBGzb ze;=O;_O#`#wy6hjUXCsc7R4-#(fF22xKpV!I*xE>QN)?}e4qip43lQOa(0RUkeBxx z3?%93V=s{ipdtgsZw=M}rj+*LCp}V#NJwG_4dGJrZ~n#Zp&A7(^+H21)+zcCCM;3~-6)!oS~Vze+qWaeXCfjmueJ zITcNGFdXV;>w>@>*xvdUpQN-I-rk`=o1Aocp=A{Qc)^i`d#y*EJCJ(R74^I4eD5$*Lcro}&o+8`bw+#BT5Og@%hns2d$(i8tQzuNG(ylZS%vrSII9*b=a?{h z^G0|1xi$s!7SSohj|AK0Xz6E1gG#~q>W~0BuCkEh4@C)Npo2D|m>3f(A1X7nrwg0| z;4y zP}>YN2G~%l2AlGuRq;_?Ga^w);H;MqMoGX+vnsAoZ9d9OazK=re5B>jnEBQ&%TD%w zxkb*(JAn@A)HyV8$|#>TyFo|mT&dfLKShh16AsG04@jJ8MoBe`7APj%aoUapnl=Nc zp{z*la!DVIXZ_|F1^Zerwll6fKHT2~gG5L(QBTMa^UXes%A8M$mU2yndAbdEJ^-fS ztsiw4_;ye-;5(Aoi7llkx16PtTkT1GX{yGh@(Av_e~ZJ8BDa^_eDQJq#>)Ig*Qp_0 z@uhsNxX8Y(ftD6I`qgY&2na`kprJ9Ck6f&m_u&BMj>WKVNQYA-rBz#C zat#-3XIL%QiL5&jD$Ez^j`}4@&ERUOZSD>&LFZ+2i{~5BTFuYLAh(oM;e^eDG2rw2UOhqlGg9bxygCnp2LV~r1_43%bx_#ZyI2_7y@UtL z8kP}@;`p8sJx-PKop8czABTLe+0-4Y&G&O(Db^5?^2fz+G-Jr5+NC;PPe}M{QQ}dd zP|0$n_1F~1YfV5R_9*t(UK=LcU!!Y#KaI?%Cpmv>Sd`w-P&UIeejIPV89%qs0)vP)5APOj)y3W|ybZqYq z+`9BZ&`~^fyR7D(h?jUin7Noc8%65ga-{LM9D0uL3G~LeMdN4hDL%Dg+PuHVSxCBd3t)<4Xj)X~ z0rLBn76S|~jNHsG>g=a&vyv$&1o6i&-6=t8gkn!qyl3krf(HjHw7A{8r0~@qVW?n< zXY4N!m7^*RRzW+m6VPA#-mi-0#lN8yLh+)r2FwkgiJ}*|*sz6vu(&{47>za^IvHoN z3yzg5x{*mEsRRYXO`O%hwR?qj!&(zQ?-VsxtLnrCFh;RjnU%gs zkTQ_bw%vI?$fvkva2%JXav8_q?S>%(>A>TXi;^xEy1M!jF*YDzXD|X|a{28F+Q_v{ z$aXVd15C0#G54oJf6$IMDVtr{DxJc2XO$;ddpDWkXbFfzQkt* z)W>)d^Ma7)Yl-3_yZYU56y!i5{rkDc@+uR&09O1>S8wZxUcYXDOO6$ms8A6r4ua5^ z*zMs8A9o>hELQZHNKAy;+y|VX3gwcpV!gtq&wFTqA?@g-S(BQ^YzF2tWfUXPN6D^G zE}l;nqqXRPNm1?5Kf8BseHL37OlOqwehQ2XXiFrXl1^-xGnWD)DZ9|gEMN3~b|=h0 zdiBP}1^i1_n|Z`J61IDxZr4e7Tl$r9%mf&!OKt1+LMupLRGXc35HC0|NHk6h&>=s2 zrSzo=vBi74OP?$|h!>fxbiQ*#qZ$pbYeS6%2@g@;N%0~O$Lkdp&d`8__hsk77AZ(C zNvoYafMjn zy)6$A20O%^*Z{^hATp}~26w;5S_}8Nhm8s1t)&5l)@)xndw|G3Rz&bVFe=W@G}>{D z3|WOU;Y~%EU2(2TxM^1tZaQ0_f@8WXWhteGM+T5tnOiE1VOj9(`+T_@Lf+To>katw z9JG%!`WQ#LT##xLB6naDB<~agB1ko3=oi)F5Aon{uR!91E%~fDO`;mXwJkq0xzD8N zmTVgSaB(Ly=dh2Ej*Q5Q0cL_d)O}JXy7eHR@@;E6aPwY|-RO3+WUsReCRLuXJeM?qeO8U4;aJ^Uu+-Q}qDNh_O*9h$^OfF04dnX7H zI^hR1-&aJ{+poc?&flfy{%YFBeCI==sv2{cZnVKDOTcNHtUibh#&fJ$sj`$E@GcfE zBqys&;iGeI93cu&AT_c0^+m5zceG?6gLo zwrn+AZ)ZM4&MNN~)rbO6oC4%`P^QzO=MDt&=|V!@F+(p*O&F2K zEcZr?&tG)bSY!eP;tX)uC+ZU=iw8<{$G|<{ib}7>=tc#{-?yyW|82{29AlweQCLO zALMLwa6{c$T|pP!Pz;%?qr;=dr}HVf9+CS&kkHaO>RlBR<6K7O za*uxxcDUT<9chdNTP*H5oQ%5owE1WJ)%|Uw&ixzB*QzCjjb_kOE-^ zCq46Ke7PEWyrBIf+Mqz$+dSgHQR5|sn~Rv~82T}S^cMH0nEzJtGBs>qBBfz??RgjUq8`5C?1s3hBoFYtc*wGh*aSO41ERFkd5YA;H%aI4Gwl zvVJyj!I>uFJ)6FG9LWfC`{&Rx&5WEA?%bPDSgWwHLZsKcpYl$EygK89wzkIU-xKta za=+2pA|Xop=>vDGT zEeA=)^9I(RHZ@_}ho-cr_@1?h#_;oc71#r#0NUk^h4VdsphQPLTLwY|<=!OBJj_V& zVcrlp+!z%cR<0|Vk#BPz8pD=+_*cn<1N5PQaj|58Q;}Ad#aeLDcrsSFL}^bj^il^S zV1s#9RXXUaLWL|Fw=FBN?T1#uuO|bu{sZVbWMic`5vmRB2a1)?eTn`Lj$mvr zW}LUW5I=yXC{*-VbRbw1I_Ns0IYa6RF5udGV1Tn}J{v22HfH2NZ1k5g>C0D!;E%dt zD6f^*(7jal&~&Gh%W1BiRPzEnK}$hPu9)KW0R!nOo(khmt8b~2kq2mvNOSt zM<6bP-crHG3lgU-7kazb#Hx}v-XKnjE=9!0sXDk%+O$rRrAj{fH40oF$;RgMPTK0m zEAsIWYq3uapMAhl_x*D=Q3BJCTM%HXmcF54+t8w=W|zpm8JT38^W243rm-hho{ z(`%kcc*8_ksEyjCjgyh9M2^A1;glDzAqVl{QoqWwSFNi4Uw|}#xgw z4v#SRX;sO|9(SnNTFCBb75FZVr#$PGqd7vYV);RU*4*{=)i4)pqg?!Ue3dq<;k0b6 zlManQ*pkyaqBe&+_S*2K6~1eESe`Gk!4H422-9t(bs6);r*J1~y72x|(_V0-CUFZ> zZebbfWxIR1iL)ge!tz1H+D@T!7O*dpvHhFHHD@Xyg?6+jX<7Kg9YRg-aI$CMvG>(J zaLx_uYD^0d>2k7_n=qwADV5EJspZU4e0gVAQCeT4=29>#Kh*aauJ@fXOtiK(27TTj zQzsusjlGx-$xk>pcBCWUg$d|xtWjhTfDPBoW<~BW5_&c8-PyN1(&}RmvSgp-mPiAB z`fJF3?$JQu%tVy;db!AmW#EhU7WHnuW{ck`Fair?c2GRVZTqc1cqr(f)JXfwXLahQ z_7bgjE^NdhhX$>Gz_~*tje1yas;Yl}qYY?Tq&Q60q{W?G<{6+YTW;l)}(qf~6Y6K%hbf&)D8St0ddtf*XfW5%+_(H>|U7p!2U_C*V$)nWyTr3sT-=oR_KC++k zrIrQ*r$PdpP0f3HdO%J^JUgJmCvundoKCN5AYFZFF*|+o7L*!UWYt3T7LN2hLygD2 zdK_q5J1S=}wtO2)Js)LfULCX8>r8PDZON8kO_uGN8qLZSE@Vnd456CV-M^{sXa6oP zJZoSm8BT_V6Tc#_=n80bCCe?w=~NLxkYY||bE5~yk{g?7C$H)!C_Y@q89FI zyk>cL1#tym_eh9i1tsi3O66hCy9El3m*3S}nf6`pLX#%*NibRh#QOnYpJ3M~ZrTn&+8oBd|A+-?ekkLJ9-0kzXR-ii{1JC#$#BQbv2Gm?r_yPz-m%5@hmxX7 z#@dY6#fr<@v9b5M=pw9;tt4UW?m3IhHzs0-5K5UE)X3n0v z9-d93nfaLgU?@cjWzn{ZET~GvOMtxs+eWpC+}OScp@P0UH&_ya*Y|95fGbO=DDtWH ziqA<0#VSd;(@K`L`5|4aGDRHrP-}w`I%s;37K>HE|xUp(t>NwN7{Lw{Se% z-SDVF^`C%smRhPDxIBu$g5Hw}KZ5*oqd?JHp?rAG0$@CE7NlPrGI>8uf0FUA!~qHg*+_WbRL!hG!2LN@^wDaH72=`Wa_S(wINH-1teY?X}{R!(EN4` zqm&QGkzTT!b~u)uDI%m3{s*5v86%5Mo_xnn4O4Hp1n?h%RTMB#rEBuj@eH4zbQ5V} zC{vT{6}_%#dwezswwpEL!{2Paqe`sDMb?UEZ{c4y7YW`N%YfS#=&AWsRHUQ5_d!p2 z+nx;F@qM#^u^vpGF^AXkEHqerWZ1X(S^Yz`W4geeO=*hOWTrjN!Jfh7X8ob35L zXbvNosx)bnoXhAJ<>{j@yVw1Pj^|g^E@J2A?M8)^!rmmo&|Ijx2wW}B%J0{EtT;In zt(0`7;Z~&@C4sa>LcT|rw{q>s*VlQ2|0T}|^^Cu^t)a!A`9(Jl$Cjt(|5l&p2<|T# zx);f^=yAy&20-7j&oN-9P1*?^`Kvws>u=Hq9$ty>o+jR)MynpRwF=15%rrqqB;D+! zthyUC<1Q0vG3JN2_&CV<@B;@N+vHB%omG~7#7kRe!Ce%ryC9km&h+m5e0R8G7wAZ% zG~MMVm;$*dV@xOa6=eA{xL8FB-ty^___DQH$yyDaiS}g>$Vzkfu22QJHAC5O__y2u zfyIYoYdlFgLgUQTTMAKW4i=HyL)oX^uIsM4851&Yx`g^wGU3giw#cj`MOPv0n zd5`}DCp4g~`0^P>vFF4W&h!2GZ1^Si_l5rx-PUMnbHC@+f9(Gl$Z(BQOBnmvhCZ+I zi{l-*Pq?TO>fk%R`)hpeQgjTrC@rvKc$Zaj=C4QXcQPJ&xopnhs?k1z+5yFvVcyr9 z!=oe=dWY1DkP%R#r1O|xl$({m(F~K>9`Q?RaARAKbjQpoaDJmD)8(g64k%-MxQPEHL|X|;*njlx=- zC&E>UMJ0yvz}~U?fJMNSwDXQ}zxKx>?WU@upWQ9#rM;0k{7pQ(HgWm@@!5zCU6h3s z=MB!gt8s?i!M)yV`j3C`L+3RKywx*5#62g9UXn!WLO?5fLo0i21s7{WJIxo4D2pAo z>S6$hT*uuaMb#_n>BuRfdy5bbV-#J0Yq$y0Yy@%#bgep}WTM1^$;}JferwzvyVU2< zN}^H~nVLnQN zMi>^|5&TiSp%MaeKmx@XHJ=pD>m{oKCj7zBZJ{~GmgZ_fOheh)0X%FE+csAUvgPKA z3wmN874%~>!;`Jm3btb+TV6VL))C~r(Wgkq46`PuR$%V;3S_H&ed}%h#~*8>Ih-$b z44`0WKVzC-Wl(=k(DNLDfPxS`U&BcL<8=Kp@SiIkUKr?Cd#R{8BKplo|8)8%BfaSS za*9m+Z=Ihz=)XC8NlpH0FV0@t-zvTS>F>{4uNUE8{>cBX_OHsXKOy?FYU#Ij%Ja(p z%jv&snEq7%vkK?8`m1LK`K|s}P0oLN=x2?8wVyR>&*AAa{roqDf1>pE_22mF&w>6) zY`+Jp^LMrXMRR}J`7?O_ZATdIKZnwPTKF?#@Y_PvbCB^8jQ^EC`2T-)foCY5RsWsJ z{^tt$GnxI{qsiaZey6to9P*!m&~F2lgg^cM9k%~A=$C4tU+rav{eR8yi}GK!L;oE4 zpVdPzy1&+T)&J7{T}$-m@c*pl_-(})`mf>tYXm=UKWjT=r67N9c@z*75b0-cG-NM& F{{zz', views.Dataset.Operate.as_view(), name="dataset_key"), path('dataset//application', views.Dataset.Application.as_view()), path('dataset//', views.Dataset.Page.as_view(), name="dataset"), path('dataset//sync_web', views.Dataset.SyncWeb.as_view()), path('dataset//hit_test', views.Dataset.HitTest.as_view()), path('dataset//document', views.Document.as_view(), name='document'), + path('dataset/document/template/export', views.Template.as_view()), path('dataset//document/web', views.WebDocument.as_view()), + path('dataset//document/qa', views.QaDocument.as_view()), path('dataset//document/_bach', views.Document.Batch.as_view()), path('dataset//document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()), path('dataset//document//', views.Document.Page.as_view()), diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index d3720977b..a864f089c 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -9,6 +9,7 @@ from drf_yasg.utils import swagger_auto_schema from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser from rest_framework.views import APIView from rest_framework.views import Request @@ -44,6 +45,27 @@ class Dataset(APIView): data={'sync_type': request.query_params.get('sync_type'), 'id': dataset_id, 'user_id': str(request.user.id)}).sync()) + class CreateQADataset(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建QA知识库", + operation_id="创建QA知识库", + + manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(), + responses=get_api_response( + DataSetSerializers.Create.CreateQASerializers.get_response_body_api()), + tags=["知识库"] + ) + @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND) + def post(self, request: Request): + return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save_qa({ + 'file_list': request.FILES.getlist('file'), + 'name': request.data.get('name'), + 'desc': request.data.get('desc') + })) + class CreateWebDataset(APIView): authentication_classes = [TokenAuth] diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index a727a31fa..443f650d9 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -22,6 +22,18 @@ from dataset.serializers.document_serializers import DocumentSerializers, Docume from dataset.swagger_api.document_api import DocumentApi +class Template(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取QA模版", + operation_id="获取QA模版", + manual_parameters=DocumentSerializers.Export.get_request_params_api(), + tags=["知识库/文档"]) + def get(self, request: Request): + return DocumentSerializers.Export(data={'type': request.query_params.get('type')}).export(with_valid=True) + + class WebDocument(APIView): authentication_classes = [TokenAuth] @@ -40,6 +52,26 @@ class WebDocument(APIView): DocumentSerializers.Create(data={'dataset_id': dataset_id}).save_web(request.data, with_valid=True)) +class QaDocument(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="导入QA并创建文档", + operation_id="导入QA并创建文档", + manual_parameters=DocumentWebInstanceSerializer.get_request_params_api(), + responses=result.get_api_response(DocumentSerializers.Create.get_response_body_api()), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str): + return result.success( + DocumentSerializers.Create(data={'dataset_id': dataset_id}).save_qa( + {'file_list': request.FILES.getlist('file')}, + with_valid=True)) + + class Document(APIView): authentication_classes = [TokenAuth] diff --git a/pyproject.toml b/pyproject.toml index 3eddbace0..ba1e0dedb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,8 @@ httpx = "^0.27.0" httpx-sse = "^0.4.0" websocket-client = "^1.7.0" langchain-google-genai = "^1.0.3" +openpyxl = "^3.1.2" +xlrd = "^2.0.1" [build-system] requires = ["poetry-core"] diff --git a/ui/src/api/dataset.ts b/ui/src/api/dataset.ts index 702731a26..63c93f1a4 100644 --- a/ui/src/api/dataset.ts +++ b/ui/src/api/dataset.ts @@ -93,6 +93,22 @@ const postWebDataset: (data: any, loading?: Ref) => Promise return post(`${prefix}/web`, data, undefined, loading) } +/** + * 创建QA知识库 + * @param 参数 formData + * { + "file": "file", + "name": "string", + "desc": "string", + } + */ +const postQADataset: (data: any, loading?: Ref) => Promise> = ( + data, + loading +) => { + return post(`${prefix}/qa`, data, undefined, loading) +} + /** * 知识库详情 * @param 参数 dataset_id @@ -170,5 +186,6 @@ export default { listUsableApplication, getDatasetHitTest, postWebDataset, + postQADataset, putSyncWebDataset } diff --git a/ui/src/api/document.ts b/ui/src/api/document.ts index 2d2fc1f65..647fab34a 100644 --- a/ui/src/api/document.ts +++ b/ui/src/api/document.ts @@ -1,5 +1,5 @@ import { Result } from '@/request/Result' -import { get, post, del, put } from '@/request/index' +import { get, post, del, put, exportExcel } from '@/request/index' import type { Ref } from 'vue' import type { KeyValue } from '@/api/type/common' import type { pageRequest } from '@/api/type/common' @@ -188,6 +188,20 @@ const postWebDocument: ( return post(`${prefix}/${dataset_id}/document/web`, data, undefined, loading) } +/** + * 导入QA文档 + * @param 参数 + * file +} + */ +const postQADocument: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return post(`${prefix}/${dataset_id}/document/qa`, data, undefined, loading) +} + /** * 批量迁移文档 * @param 参数 dataset_id,target_dataset_id, @@ -220,6 +234,19 @@ const batchEditHitHandling: ( ) => Promise> = (dataset_id, data, loading) => { return put(`${prefix}/${dataset_id}/document/batch_hit_handling`, data, undefined, loading) } + +/** + * 获得QA模版 + * @param 参数 fileName,type, + */ +const exportQATemplate: (fileName: string, type: string, loading?: Ref) => void = ( + fileName, + type, + loading +) => { + return exportExcel(fileName, `${prefix}/document/template/export`, { type }, loading) +} + export default { postSplitDocument, getDocument, @@ -234,5 +261,7 @@ export default { delMulSyncDocument, postWebDocument, putMigrateMulDocument, - batchEditHitHandling + batchEditHitHandling, + exportQATemplate, + postQADocument } diff --git a/ui/src/assets/xls-icon.svg b/ui/src/assets/xls-icon.svg new file mode 100644 index 000000000..22cb86953 --- /dev/null +++ b/ui/src/assets/xls-icon.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/ui/src/assets/xlsx-icon.svg b/ui/src/assets/xlsx-icon.svg new file mode 100644 index 000000000..22cb86953 --- /dev/null +++ b/ui/src/assets/xlsx-icon.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/ui/src/stores/modules/dataset.ts b/ui/src/stores/modules/dataset.ts index 22289ea5e..b185d795a 100644 --- a/ui/src/stores/modules/dataset.ts +++ b/ui/src/stores/modules/dataset.ts @@ -7,6 +7,7 @@ import { type Ref } from 'vue' export interface datasetStateTypes { baseInfo: datasetData | null webInfo: any + documentsType: string documentsFiles: UploadUserFile[] } @@ -15,6 +16,7 @@ const useDatasetStore = defineStore({ state: (): datasetStateTypes => ({ baseInfo: null, webInfo: null, + documentsType: '', documentsFiles: [] }), actions: { @@ -24,6 +26,9 @@ const useDatasetStore = defineStore({ saveWebInfo(info: any) { this.webInfo = info }, + saveDocumentsType(val: string) { + this.documentsType = val + }, saveDocumentsFile(file: UploadUserFile[]) { this.documentsFiles = file }, diff --git a/ui/src/utils/utils.ts b/ui/src/utils/utils.ts index 581a4ec5c..8f2463985 100644 --- a/ui/src/utils/utils.ts +++ b/ui/src/utils/utils.ts @@ -37,14 +37,20 @@ export function fileType(name: string) { /* 获得文件对应图片 */ +const typeList: any = { + txt: ['txt', 'pdf', 'docx', 'csv', 'md'], + QA: ['xlsx', 'csv', 'xls'] +} + export function getImgUrl(name: string) { - const type = isRightType(name) ? fileType(name) : 'unknow' + const list = Object.values(typeList).flat() + + const type = list.includes(fileType(name)) ? fileType(name) : 'unknow' return new URL(`../assets/${type}-icon.svg`, import.meta.url).href } // 是否是白名单后缀 -export function isRightType(name: string) { - const typeList = ['txt', 'pdf', 'docx', 'csv', 'md', 'html'] - return typeList.includes(fileType(name)) +export function isRightType(name: string, type: string) { + return typeList[type].includes(fileType(name)) } /* diff --git a/ui/src/views/dataset/CreateDataset.vue b/ui/src/views/dataset/CreateDataset.vue index 3d48b702b..ec0184ba6 100644 --- a/ui/src/views/dataset/CreateDataset.vue +++ b/ui/src/views/dataset/CreateDataset.vue @@ -40,6 +40,7 @@ import StepFirst from './step/StepFirst.vue' import StepSecond from './step/StepSecond.vue' import ResultSuccess from './step/ResultSuccess.vue' import datasetApi from '@/api/dataset' +import documentApi from '@/api/document' import type { datasetData } from '@/api/type/dataset' import { MsgConfirm, MsgSuccess } from '@/utils/message' @@ -48,6 +49,7 @@ const { dataset, document } = useStore() const baseInfo = computed(() => dataset.baseInfo) const webInfo = computed(() => dataset.webInfo) const documentsFiles = computed(() => dataset.documentsFiles) +const documentsType = computed(() => dataset.documentsType) const router = useRouter() const route = useRoute() @@ -80,7 +82,33 @@ const successInfo = ref(null) async function next() { disabled.value = true if (await StepFirstRef.value?.onSubmit()) { - if (active.value++ > 2) active.value = 0 + if (documentsType.value === 'QA') { + let fd = new FormData() + documentsFiles.value.forEach((item: any) => { + if (item?.raw) { + fd.append('file', item?.raw) + } + }) + if (id) { + // QA文档上传 + documentApi.postQADocument(id as string, fd, loading).then((res) => { + MsgSuccess('提交成功') + clearStore() + router.push({ path: `/dataset/${id}/document` }) + }) + } else { + // QA知识库创建 + fd.append('name', baseInfo.value?.name as string) + fd.append('desc', baseInfo.value?.desc as string) + + datasetApi.postQADataset(fd, loading).then((res) => { + successInfo.value = res.data + active.value = 2 + }) + } + } else { + if (active.value++ > 2) active.value = 0 + } } else { disabled.value = false } @@ -93,6 +121,7 @@ function clearStore() { dataset.saveBaseInfo(null) dataset.saveWebInfo(null) dataset.saveDocumentsFile([]) + dataset.saveDocumentsType('') } function submit() { loading.value = true diff --git a/ui/src/views/dataset/component/UploadComponent.vue b/ui/src/views/dataset/component/UploadComponent.vue index 54181156e..4cc3779ed 100644 --- a/ui/src/views/dataset/component/UploadComponent.vue +++ b/ui/src/views/dataset/component/UploadComponent.vue @@ -7,7 +7,48 @@ label-position="top" require-asterisk-position="right" > - + + + 文本文件 + QA 问答对 + + + + + +
+

+ 拖拽文件至此上传或 + 选择文件 + 选择文件夹 +

+
+

当前支持 XLSX / XLS / CSV 格式的文档

+

每次最多上传50个文件,每个文件不超过 100MB

+
+
+
+ + 下载 Excel 模板 + + + 下载 CSV 模板 +
+