class CsvParseQAHandle(BaseParseQAHandle):
    """Parse an uploaded CSV file of question/answer rows into document dicts.

    The first CSV row is treated as a header. Columns are located by header
    prefix ('分段标题' -> title, '分段内容' -> content, '问题' -> problems);
    a column whose header is not found keeps its default position (0, 1, 2).
    """

    def support(self, file, get_buffer):
        """Return True when the file name ends with ``.csv`` (case-insensitive).

        :param file: uploaded file object exposing a ``name`` attribute
        :param get_buffer: not used here; part of the handler interface
        """
        file_name: str = file.name.lower()
        return file_name.endswith(".csv")

    def handle(self, file, get_buffer):
        """Read the CSV content and build the paragraph list.

        :param file: uploaded file object; only ``file.name`` is read directly
        :param get_buffer: callable returning the raw bytes of ``file``
        :return: ``[]`` when the file has no rows at all, otherwise a
                 one-element list ``[{'name': file.name, 'paragraphs': [...]}]``
                 where every paragraph has ``title`` (max 255 chars),
                 ``content`` (max 4096 chars) and ``problem_list``
        """
        buffer = get_buffer(file)
        # The detector can report no encoding; fall back to UTF-8 instead of
        # letting TextIOWrapper pick the platform locale default.
        encoding = detect(buffer)['encoding'] or 'utf-8'
        reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=encoding))
        try:
            title_row_list = next(reader)
        except StopIteration:
            # No header row: nothing to import.
            return []

        column_index = {'title': 0, 'content': 1, 'problem_list': 2}
        for index, header in enumerate(title_row_list):
            if header.startswith('分段标题'):
                column_index['title'] = index
            if header.startswith('分段内容'):
                column_index['content'] = index
            if header.startswith('问题'):
                column_index['problem_list'] = index

        def cell(row, key):
            # Rows shorter than the header are treated as having empty trailing cells.
            index = column_index[key]
            return row[index] if index < len(row) else ''

        paragraph_list = []
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing newline).
                continue
            # One problem per line inside the cell; whitespace-only lines are dropped.
            problem_list = [{'content': p[0:255]}
                            for p in cell(row, 'problem_list').split('\n') if p.strip()]
            paragraph_list.append({'title': cell(row, 'title')[0:255],
                                   'content': cell(row, 'content')[0:4096],
                                   'problem_list': problem_list})
        return [{'name': file.name, 'paragraphs': paragraph_list}]
def handle_sheet(file_name, sheet):
    """Convert one worksheet of question/answer rows into a document dict.

    The first row is treated as a header. Columns are located by header
    prefix ('分段标题' -> title, '分段内容' -> content, '问题' -> problems);
    a column whose header is not found keeps its default position (0, 1, 2),
    matching the CSV and XLS handlers.

    :param file_name: name stored under the ``name`` key of the result
    :param sheet: worksheet-like object whose ``rows`` attribute iterates over
                  rows of cells, each cell exposing a ``value`` attribute
    :return: ``None`` when the sheet has no rows, otherwise
             ``{'name': file_name, 'paragraphs': [...]}``
    """
    rows = iter(sheet.rows)
    try:
        title_row_list = next(rows)
    except StopIteration:
        # Completely empty sheet: the caller filters out None results.
        return None

    column_index = {'title': 0, 'content': 1, 'problem_list': 2}
    for index, header_cell in enumerate(title_row_list):
        header = str(header_cell.value)
        if header.startswith('分段标题'):
            column_index['title'] = index
        if header.startswith('分段内容'):
            column_index['content'] = index
        if header.startswith('问题'):
            column_index['problem_list'] = index

    def cell_text(row, key):
        # Empty cells carry value None; map them (and cells beyond the row
        # width) to '' so the literal text 'None' never reaches the data.
        index = column_index[key]
        if index >= len(row):
            return ''
        value = row[index].value
        return '' if value is None else str(value)

    paragraph_list = []
    for row in rows:
        if all(cell.value is None for cell in row):
            # Skip rows that contain no data at all.
            continue
        # One problem per line inside the cell; whitespace-only lines are dropped.
        problem_list = [{'content': p[0:255]}
                        for p in cell_text(row, 'problem_list').split('\n') if p.strip()]
        paragraph_list.append({'title': cell_text(row, 'title')[0:255],
                               'content': cell_text(row, 'content')[0:4096],
                               'problem_list': problem_list})
    return {'name': file_name, 'paragraphs': paragraph_list}
def flat_map(array: List[List]):
    """
    Flatten a two-dimensional list into a one-dimensional list.
    :param array: two-dimensional list
    :return: a new one-dimensional list holding the elements of every inner list, in order
    """
    return [element for inner in array for element in inner]
b/apps/dataset/serializers/dataset_serializers.py @@ -18,7 +18,7 @@ from urllib.parse import urlparse from django.contrib.postgres.fields import ArrayField from django.core import validators from django.db import transaction, models -from django.db.models import QuerySet, Q +from django.db.models import QuerySet from drf_yasg import openapi from rest_framework import serializers @@ -29,7 +29,7 @@ from common.db.sql_execute import select_list from common.event import ListenerManagement, SyncWebDatasetArgs from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin -from common.util.common import post +from common.util.common import post, flat_map from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.fork import ChildLink, Fork @@ -210,6 +210,75 @@ class DataSetSerializers(serializers.ModelSerializer): super().is_valid(raise_exception=True) return True + class CreateQASerializers(serializers.Serializer): + """ + 创建web站点序列化对象 + """ + name = serializers.CharField(required=True, + error_messages=ErrMessage.char("知识库名称"), + max_length=64, + min_length=1) + + desc = serializers.CharField(required=True, + error_messages=ErrMessage.char("知识库描述"), + max_length=256, + min_length=1) + + file_list = serializers.ListSerializer(required=True, + error_messages=ErrMessage.list("文件列表"), + child=serializers.FileField(required=True, + error_messages=ErrMessage.file("文件"))) + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='file', + in_=openapi.IN_FORM, + type=openapi.TYPE_ARRAY, + items=openapi.Items(type=openapi.TYPE_FILE), + required=True, + description='上传文件'), + openapi.Parameter(name='name', + in_=openapi.IN_FORM, + required=True, + type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"), + openapi.Parameter(name='desc', + in_=openapi.IN_FORM, + required=True, + type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"), + ] + + 
def save_qa(self, instance: Dict, with_valid=True):
    """Create a dataset from uploaded QA files.

    :param instance: dict with ``name``, ``desc`` and ``file_list`` (uploaded files)
    :param with_valid: when true, validate this serializer and ``instance``
                       before parsing the files
    :return: whatever ``self.save`` returns for the assembled dataset
    :raises: the serializer validation error when ``with_valid`` is true and
             either this serializer or ``instance`` is invalid
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        # Without raise_exception=True the result of is_valid() was discarded,
        # so an invalid name/desc/file_list was never rejected.
        self.CreateQASerializers(data=instance).is_valid(raise_exception=True)
    file_list = instance.get('file_list')
    # Each file may yield several documents (one per sheet); flatten them.
    document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list])
    dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list}
    return self.save(dataset_instance, with_valid=True)
class FileBufferHandle:
    """Read a file's content once and hand back the cached bytes afterwards."""

    # Cached content; stays None until get_buffer() has been called once.
    buffer = None

    def get_buffer(self, file):
        """Return the content of ``file``, reading it only on the first call.

        :param file: object with a ``read()`` method
        :return: the value returned by the first ``file.read()`` on this instance
        """
        if self.buffer is not None:
            return self.buffer
        self.buffer = file.read()
        return self.buffer
class Export(ApiMixin, serializers.Serializer):
    """Serve the QA import template (CSV or Excel) as a file download."""

    # The alternation is grouped so both anchors apply to both alternatives;
    # "^csv|excel$" also accepted values such as "csvfoo" or "fooexcel".
    type = serializers.CharField(required=True, validators=[
        validators.RegexValidator(regex=re.compile("^(csv|excel)$"),
                                  message="模版类型只支持excel|csv",
                                  code=500)
    ], error_messages=ErrMessage.char("模版类型"))

    # Template type -> (file name inside apps/dataset/template, Content-Type header value).
    template_dict = {
        'csv': ('csv_template.csv', 'text/csv'),
        'excel': ('excel_template.xlsx',
                  'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'),
    }

    @staticmethod
    def get_request_params_api():
        """Return the swagger query parameters accepted by the export endpoint."""
        return [openapi.Parameter(name='type',
                                  in_=openapi.IN_QUERY,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='导出模板类型csv|excel'),
                ]

    def export(self, with_valid=True):
        """Build the download response for the requested template type.

        :param with_valid: when true, validate the serializer data first
        :return: HttpResponse carrying the template bytes as an attachment
        :raises AppApiException: when the type is neither ``csv`` nor ``excel``
        """
        if with_valid:
            self.is_valid(raise_exception=True)

        template = self.template_dict.get(self.data.get('type'))
        if template is None:
            # Previously this case fell through and returned None to the view.
            raise AppApiException(500, "模版类型只支持excel|csv")
        file_name, content_type = template
        # The context manager closes the handle even if read() raises.
        with open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', file_name), "rb") as file:
            content = file.read()
        return HttpResponse(content, status=200,
                            headers={'Content-Type': content_type,
                                     'Content-Disposition': f'attachment; filename="{file_name}"'})
native_search(query_set, select_string=get_file_content( os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False), diff --git a/apps/dataset/template/csv_template.csv b/apps/dataset/template/csv_template.csv new file mode 100644 index 000000000..b306a9c17 --- /dev/null +++ b/apps/dataset/template/csv_template.csv @@ -0,0 +1,8 @@ +分段标题(选填),分段内容(必填,问题答案,最长不超过4096个字符)),问题(选填,单元格内一行一个) +MaxKB产品介绍,"MaxKB 是一款基于 LLM 大语言模型的知识库问答系统。MaxKB = Max Knowledge Base,旨在成为企业的最强大脑。 +开箱即用:支持直接上传文档、自动爬取在线文档,支持文本自动拆分、向量化,智能问答交互体验好; +无缝嵌入:支持零编码快速嵌入到第三方业务系统; +多模型支持:支持对接主流的大模型,包括 Ollama 本地私有大模型(如 Llama 2、Llama 3、qwen)、通义千问、OpenAI、Azure OpenAI、Kimi、智谱 AI、讯飞星火和百度千帆大模型等。","MaxKB是什么? +MaxKB产品介绍 +MaxKB支持的大语言模型 +MaxKB优势" diff --git a/apps/dataset/template/excel_template.xlsx b/apps/dataset/template/excel_template.xlsx new file mode 100644 index 000000000..6517b154f Binary files /dev/null and b/apps/dataset/template/excel_template.xlsx differ diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index 237a81f59..b9d1bd431 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -6,13 +6,16 @@ app_name = "dataset" urlpatterns = [ path('dataset', views.Dataset.as_view(), name="dataset"), path('dataset/web', views.Dataset.CreateWebDataset.as_view()), + path('dataset/qa', views.Dataset.CreateQADataset.as_view()), path('dataset/', views.Dataset.Operate.as_view(), name="dataset_key"), path('dataset//application', views.Dataset.Application.as_view()), path('dataset//', views.Dataset.Page.as_view(), name="dataset"), path('dataset//sync_web', views.Dataset.SyncWeb.as_view()), path('dataset//hit_test', views.Dataset.HitTest.as_view()), path('dataset//document', views.Document.as_view(), name='document'), + path('dataset/document/template/export', views.Template.as_view()), path('dataset//document/web', views.WebDocument.as_view()), + path('dataset//document/qa', views.QaDocument.as_view()), path('dataset//document/_bach', views.Document.Batch.as_view()), 
class CreateQADataset(APIView):
    # Endpoint that creates a dataset from QA files uploaded as multipart form data.
    authentication_classes = [TokenAuth]
    parser_classes = [MultiPartParser]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="创建QA知识库",
                         operation_id="创建QA知识库",
                         manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(),
                         responses=get_api_response(
                             DataSetSerializers.Create.CreateQASerializers.get_response_body_api()),
                         tags=["知识库"])
    @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
    def post(self, request: Request):
        # The serializer is bound to the requesting user; the form fields and
        # the files sent under the 'file' key become the QA instance.
        create_serializer = DataSetSerializers.Create(data={'user_id': request.user.id})
        qa_instance = {
            'file_list': request.FILES.getlist('file'),
            'name': request.data.get('name'),
            'desc': request.data.get('desc')
        }
        return result.success(create_serializer.save_qa(qa_instance))
class QaDocument(APIView):
    # Endpoint that imports QA files into an existing dataset as new documents.
    authentication_classes = [TokenAuth]
    parser_classes = [MultiPartParser]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="导入QA并创建文档",
                         operation_id="导入QA并创建文档",
                         manual_parameters=DocumentWebInstanceSerializer.get_request_params_api(),
                         responses=result.get_api_response(DocumentSerializers.Create.get_response_body_api()),
                         tags=["知识库/文档"])
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def post(self, request: Request, dataset_id: str):
        # Files arrive under the multipart key 'file'; validation is always on.
        create_serializer = DocumentSerializers.Create(data={'dataset_id': dataset_id})
        qa_instance = {'file_list': request.FILES.getlist('file')}
        return result.success(create_serializer.save_qa(qa_instance, with_valid=True))
undefined, loading) } +/** + * 创建QA知识库 + * @param 参数 formData + * { + "file": "file", + "name": "string", + "desc": "string", + } + */ +const postQADataset: (data: any, loading?: Ref) => Promise> = ( + data, + loading +) => { + return post(`${prefix}/qa`, data, undefined, loading) +} + /** * 知识库详情 * @param 参数 dataset_id @@ -170,5 +186,6 @@ export default { listUsableApplication, getDatasetHitTest, postWebDataset, + postQADataset, putSyncWebDataset } diff --git a/ui/src/api/document.ts b/ui/src/api/document.ts index 2d2fc1f65..647fab34a 100644 --- a/ui/src/api/document.ts +++ b/ui/src/api/document.ts @@ -1,5 +1,5 @@ import { Result } from '@/request/Result' -import { get, post, del, put } from '@/request/index' +import { get, post, del, put, exportExcel } from '@/request/index' import type { Ref } from 'vue' import type { KeyValue } from '@/api/type/common' import type { pageRequest } from '@/api/type/common' @@ -188,6 +188,20 @@ const postWebDocument: ( return post(`${prefix}/${dataset_id}/document/web`, data, undefined, loading) } +/** + * 导入QA文档 + * @param 参数 + * file +} + */ +const postQADocument: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return post(`${prefix}/${dataset_id}/document/qa`, data, undefined, loading) +} + /** * 批量迁移文档 * @param 参数 dataset_id,target_dataset_id, @@ -220,6 +234,19 @@ const batchEditHitHandling: ( ) => Promise> = (dataset_id, data, loading) => { return put(`${prefix}/${dataset_id}/document/batch_hit_handling`, data, undefined, loading) } + +/** + * 获得QA模版 + * @param 参数 fileName,type, + */ +const exportQATemplate: (fileName: string, type: string, loading?: Ref) => void = ( + fileName, + type, + loading +) => { + return exportExcel(fileName, `${prefix}/document/template/export`, { type }, loading) +} + export default { postSplitDocument, getDocument, @@ -234,5 +261,7 @@ export default { delMulSyncDocument, postWebDocument, putMigrateMulDocument, - batchEditHitHandling + 
/*
  Accepted file extensions per upload mode:
  `txt` for plain documents, `QA` for question/answer spreadsheets.
*/
const typeList: any = {
  txt: ['txt', 'pdf', 'docx', 'csv', 'md'],
  QA: ['xlsx', 'csv', 'xls']
}

/*
  Return the URL of the icon matching the file's extension.
  Extensions that appear in no upload mode use the 'unknow' icon.
*/
export function getImgUrl(name: string) {
  const list = Object.values(typeList).flat()
  const extension = fileType(name)
  const type = list.includes(extension) ? extension : 'unknow'
  return new URL(`../assets/${type}-icon.svg`, import.meta.url).href
}

/*
  Whether the file's extension is allowed for the given upload mode.
  `type` defaults to 'txt' so existing single-argument callers keep working.
  An unknown mode (for example the store's initial '') yields false instead
  of throwing on `undefined.includes`.
*/
export function isRightType(name: string, type: string = 'txt') {
  const allowed = typeList[type]
  return allowed ? allowed.includes(fileType(name)) : false
}
{ + successInfo.value = res.data + active.value = 2 + }) + } + } else { + if (active.value++ > 2) active.value = 0 + } } else { disabled.value = false } @@ -93,6 +121,7 @@ function clearStore() { dataset.saveBaseInfo(null) dataset.saveWebInfo(null) dataset.saveDocumentsFile([]) + dataset.saveDocumentsType('') } function submit() { loading.value = true diff --git a/ui/src/views/dataset/component/UploadComponent.vue b/ui/src/views/dataset/component/UploadComponent.vue index 54181156e..4cc3779ed 100644 --- a/ui/src/views/dataset/component/UploadComponent.vue +++ b/ui/src/views/dataset/component/UploadComponent.vue @@ -7,7 +7,48 @@ label-position="top" require-asterisk-position="right" > - + + + 文本文件 + QA 问答对 + + + + + +
+

+ 拖拽文件至此上传或 + 选择文件 + 选择文件夹 +

+
+

当前支持 XLSX / XLS / CSV 格式的文档

+

每次最多上传50个文件,每个文件不超过 100MB

+
+
+
+ + 下载 Excel 模板 + + + 下载 CSV 模板 +
+