* feat: 支持上传 Excel/CSV 类型的问答对 (#430)
Some checks are pending
sync2gitee / repo-sync (push) Waiting to run
Typos Check / Spell Check with Typos (push) Waiting to run

This commit is contained in:
shaohuzhang1 2024-05-23 18:57:49 +08:00 committed by GitHub
parent 6a226be048
commit 28938104c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 611 additions and 24 deletions

View File

@ -0,0 +1,19 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_parse_qa_handle.py
@date2024/5/21 14:56
@desc:
"""
from abc import ABC, abstractmethod
class BaseParseQAHandle(ABC):
@abstractmethod
def support(self, file, get_buffer):
pass
@abstractmethod
def handle(self, file, get_buffer):
pass

View File

@ -0,0 +1,56 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file csv_parse_qa_handle.py
@date2024/5/21 14:59
@desc:
"""
import csv
import io
from charset_normalizer import detect
from common.handle.base_parse_qa_handle import BaseParseQAHandle
def read_csv_standard(file_path):
data = []
with open(file_path, 'r') as file:
reader = csv.reader(file)
for row in reader:
data.append(row)
return data
class CsvParseQAHandle(BaseParseQAHandle):
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith(".csv"):
return True
return False
def handle(self, file, get_buffer):
buffer = get_buffer(file)
reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding']))
try:
title_row_list = reader.__next__()
except Exception as e:
return []
title_row_index_dict = {'title': 0, 'content': 1, 'problem_list': 2}
for index in range(len(title_row_list)):
title_row = title_row_list[index]
if title_row.startswith('分段标题'):
title_row_index_dict['title'] = index
if title_row.startswith('分段内容'):
title_row_index_dict['content'] = index
if title_row.startswith('问题'):
title_row_index_dict['problem_list'] = index
paragraph_list = []
for row in reader:
problem = row[title_row_index_dict.get('problem_list')]
problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
paragraph_list.append({'title': row[title_row_index_dict.get('title')][0:255],
'content': row[title_row_index_dict.get('content')][0:4096],
'problem_list': problem_list})
return [{'name': file.name, 'paragraphs': paragraph_list}]

View File

@ -0,0 +1,55 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file xls_parse_qa_handle.py
@date2024/5/21 14:59
@desc:
"""
import xlrd
from common.handle.base_parse_qa_handle import BaseParseQAHandle
def handle_sheet(file_name, sheet):
rows = iter([sheet.row_values(i) for i in range(sheet.nrows)])
try:
title_row_list = next(rows)
except Exception as e:
return None
title_row_index_dict = {'title': 0, 'content': 1, 'problem_list': 2}
for index in range(len(title_row_list)):
title_row = str(title_row_list[index])
if title_row.startswith('分段标题'):
title_row_index_dict['title'] = index
if title_row.startswith('分段内容'):
title_row_index_dict['content'] = index
if title_row.startswith('问题'):
title_row_index_dict['problem_list'] = index
paragraph_list = []
for row in rows:
problem = str(row[title_row_index_dict.get('problem_list')])
problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
paragraph_list.append({'title': str(row[title_row_index_dict.get('title')])[0:255],
'content': str(row[title_row_index_dict.get('content')])[0:4096],
'problem_list': problem_list})
return {'name': file_name, 'paragraphs': paragraph_list}
class XlsParseQAHandle(BaseParseQAHandle):
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith(".xls"):
return True
return False
def handle(self, file, get_buffer):
buffer = get_buffer(file)
workbook = xlrd.open_workbook(file_contents=buffer)
worksheets = workbook.sheets()
worksheets_size = len(worksheets)
return [row for row in
[handle_sheet(file.name, sheet) if worksheets_size == 1 and sheet.name == 'Sheet1' else handle_sheet(
sheet.name, sheet) for sheet
in worksheets] if row is not None]

View File

@ -0,0 +1,56 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file xlsx_parse_qa_handle.py
@date2024/5/21 14:59
@desc:
"""
import io
import openpyxl
from common.handle.base_parse_qa_handle import BaseParseQAHandle
def handle_sheet(file_name, sheet):
rows = sheet.rows
try:
title_row_list = next(rows)
except Exception as e:
return None
title_row_index_dict = {}
for index in range(len(title_row_list)):
title_row = str(title_row_list[index].value)
if title_row.startswith('分段标题'):
title_row_index_dict['title'] = index
if title_row.startswith('分段内容'):
title_row_index_dict['content'] = index
if title_row.startswith('问题'):
title_row_index_dict['problem_list'] = index
paragraph_list = []
for row in rows:
problem = str(row[title_row_index_dict.get('problem_list')].value)
problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
paragraph_list.append({'title': str(row[title_row_index_dict.get('title')].value)[0:255],
'content': str(row[title_row_index_dict.get('content')].value)[0:4096],
'problem_list': problem_list})
return {'name': file_name, 'paragraphs': paragraph_list}
class XlsxParseQAHandle(BaseParseQAHandle):
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith(".xlsx"):
return True
return False
def handle(self, file, get_buffer):
buffer = get_buffer(file)
workbook = openpyxl.load_workbook(io.BytesIO(buffer))
worksheets = workbook.worksheets
worksheets_size = len(worksheets)
return [row for row in
[handle_sheet(file.name, sheet) if worksheets_size == 1 and sheet.title == 'Sheet1' else handle_sheet(
sheet.title, sheet) for sheet
in worksheets] if row is not None]

View File

@ -45,6 +45,18 @@ def get_exec_method(clazz_: str, method_: str):
return getattr(getattr(package_model, clazz_name), method_)
def flat_map(array: List[List]):
"""
将二位数组转为一维数组
:param array: 二维数组
:return: 一维数组
"""
result = []
for e in array:
result += e
return result
def post(post_function):
def inner(func):
def run(*args, **kwargs):

View File

@ -104,3 +104,13 @@ class ErrMessage:
'invalid_image': gettext_lazy('%s】上载有效的图像。您上载的文件不是图像或图像已损坏。' % field),
'max_length': gettext_lazy('请确保此文件名最多包含 {max_length} 个字符(长度为 {length})。')
}
@staticmethod
def file(field: str):
return {
'required': gettext_lazy('%s】此字段必填。' % field),
'empty': gettext_lazy('%s】提交的文件为空。' % field),
'invalid': gettext_lazy('%s】提交的数据不是文件。请检查表单上的编码类型。' % field),
'no_name': gettext_lazy('%s】无法确定任何文件名。' % field),
'max_length': gettext_lazy('请确保此文件名最多包含 {max_length} 个字符(长度为 {length})。')
}

View File

@ -18,7 +18,7 @@ from urllib.parse import urlparse
from django.contrib.postgres.fields import ArrayField
from django.core import validators
from django.db import transaction, models
from django.db.models import QuerySet, Q
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
@ -29,7 +29,7 @@ from common.db.sql_execute import select_list
from common.event import ListenerManagement, SyncWebDatasetArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.common import post, flat_map
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork
@ -210,6 +210,75 @@ class DataSetSerializers(serializers.ModelSerializer):
super().is_valid(raise_exception=True)
return True
class CreateQASerializers(serializers.Serializer):
"""
创建web站点序列化对象
"""
name = serializers.CharField(required=True,
error_messages=ErrMessage.char("知识库名称"),
max_length=64,
min_length=1)
desc = serializers.CharField(required=True,
error_messages=ErrMessage.char("知识库描述"),
max_length=256,
min_length=1)
file_list = serializers.ListSerializer(required=True,
error_messages=ErrMessage.list("文件列表"),
child=serializers.FileField(required=True,
error_messages=ErrMessage.file("文件")))
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='file',
in_=openapi.IN_FORM,
type=openapi.TYPE_ARRAY,
items=openapi.Items(type=openapi.TYPE_FILE),
required=True,
description='上传文件'),
openapi.Parameter(name='name',
in_=openapi.IN_FORM,
required=True,
type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
openapi.Parameter(name='desc',
in_=openapi.IN_FORM,
required=True,
type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
]
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count',
'update_time', 'create_time', 'document_list'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
description="id", default="xx"),
'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
description="名称", default="测试知识库"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述",
description="描述", default="测试知识库描述"),
'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id",
description="所属用户id", default="user_xxxx"),
'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数",
description="字符数", default=10),
'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量",
description="文档数量", default=1),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
description="修改时间",
default="1970-01-01 00:00:00"),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
description="创建时间",
default="1970-01-01 00:00:00"
),
'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表",
description="文档列表",
items=DocumentSerializers.Operate.get_response_body_api())
}
)
class CreateWebSerializers(serializers.Serializer):
"""
创建web站点序列化对象
@ -288,6 +357,15 @@ class DataSetSerializers(serializers.ModelSerializer):
ListenerManagement.embedding_by_dataset_signal.send(dataset_id)
return document_list
def save_qa(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
self.CreateQASerializers(data=instance).is_valid()
file_list = instance.get('file_list')
document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list])
dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list}
return self.save(dataset_instance, with_valid=True)
@post(post_function=post_embedding_dataset)
@transaction.atomic
def save(self, instance: Dict, with_valid=True):

View File

@ -17,6 +17,7 @@ from typing import List, Dict
from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
from django.http import HttpResponse
from drf_yasg import openapi
from rest_framework import serializers
@ -27,9 +28,12 @@ from common.exception.app_exception import AppApiException
from common.handle.impl.doc_split_handle import DocSplitHandle
from common.handle.impl.html_split_handle import HTMLSplitHandle
from common.handle.impl.pdf_split_handle import PdfSplitHandle
from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
from common.handle.impl.text_split_handle import TextSplitHandle
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.common import post, flat_map
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork
@ -39,6 +43,17 @@ from dataset.serializers.common_serializers import BatchSerializer, MetaSerializ
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()]
class FileBufferHandle:
buffer = None
def get_buffer(self, file):
if self.buffer is None:
self.buffer = file.read()
return self.buffer
class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer):
meta = serializers.DictField(required=False)
@ -87,16 +102,19 @@ class DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer):
"选择器"))
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['source_url_list'],
properties={
'source_url_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="段落列表", description="段落列表",
items=openapi.Schema(type=openapi.TYPE_STRING)),
'selector': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称")
}
)
def get_request_params_api():
return [openapi.Parameter(name='file',
in_=openapi.IN_FORM,
type=openapi.TYPE_ARRAY,
items=openapi.Items(type=openapi.TYPE_FILE),
required=True,
description='上传文件'),
openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='知识库id'),
]
class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
@ -120,7 +138,48 @@ class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
)
class DocumentInstanceQASerializer(ApiMixin, serializers.Serializer):
file_list = serializers.ListSerializer(required=True,
error_messages=ErrMessage.list("文件列表"),
child=serializers.FileField(required=True,
error_messages=ErrMessage.file("文件")))
class DocumentSerializers(ApiMixin, serializers.Serializer):
class Export(ApiMixin, serializers.Serializer):
type = serializers.CharField(required=True, validators=[
validators.RegexValidator(regex=re.compile("^csv|excel$"),
message="模版类型只支持excel|csv",
code=500)
], error_messages=ErrMessage.char("模版类型"))
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='type',
in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=True,
description='导出模板类型csv|excel'),
]
def export(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
if self.data.get('type') == 'csv':
file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', 'csv_template.csv'), "rb")
content = file.read()
file.close()
return HttpResponse(content, status=200, headers={'Content-Type': 'text/cxv',
'Content-Disposition': 'attachment; filename="csv_template.csv"'})
elif self.data.get('type') == 'excel':
file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', 'excel_template.xlsx'), "rb")
content = file.read()
file.close()
return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
'Content-Disposition': 'attachment; filename="excel_template.xlsx"'})
class Migrate(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True,
error_messages=ErrMessage.char(
@ -473,6 +532,22 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
ListenerManagement.embedding_by_document_signal.send(document_id)
return result
@staticmethod
def parse_qa_file(file):
for parse_qa_handle in parse_qa_handle_list:
get_buffer = FileBufferHandle().get_buffer
if parse_qa_handle.support(file, get_buffer):
return parse_qa_handle.handle(file, get_buffer)
raise AppApiException(500, '不支持的文件格式')
def save_qa(self, instance: Dict, with_valid=True):
if with_valid:
DocumentInstanceQASerializer(data=instance).is_valid(raise_exception=True)
self.is_valid(raise_exception=True)
file_list = instance.get('file_list')
document_list = flat_map([self.parse_qa_file(file) for file in file_list])
return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)
@post(post_function=post_embedding)
@transaction.atomic
def save(self, instance: Dict, with_valid=False, **kwargs):
@ -714,6 +789,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
problem_paragraph_mapping_list) > 0 else None
# 查询文档
query_set = QuerySet(model=Document)
if len(document_model_list) == 0:
return [],
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
return native_search(query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),

View File

@ -0,0 +1,8 @@
分段标题(选填),分段内容必填问题答案最长不超过4096个字符,问题(选填,单元格内一行一个)
MaxKB产品介绍,"MaxKB 是一款基于 LLM 大语言模型的知识库问答系统。MaxKB = Max Knowledge Base旨在成为企业的最强大脑。
开箱即用:支持直接上传文档、自动爬取在线文档,支持文本自动拆分、向量化,智能问答交互体验好;
无缝嵌入:支持零编码快速嵌入到第三方业务系统;
多模型支持:支持对接主流的大模型,包括 Ollama 本地私有大模型(如 Llama 2、Llama 3、qwen、通义千问、OpenAI、Azure OpenAI、Kimi、智谱 AI、讯飞星火和百度千帆大模型等。","MaxKB是什么
MaxKB产品介绍
MaxKB支持的大语言模型
MaxKB优势"
1 分段标题(选填) 分段内容(必填,问题答案,最长不超过4096个字符)) 问题(选填,单元格内一行一个)
2 MaxKB产品介绍 MaxKB 是一款基于 LLM 大语言模型的知识库问答系统。MaxKB = Max Knowledge Base,旨在成为企业的最强大脑。 开箱即用:支持直接上传文档、自动爬取在线文档,支持文本自动拆分、向量化,智能问答交互体验好; 无缝嵌入:支持零编码快速嵌入到第三方业务系统; 多模型支持:支持对接主流的大模型,包括 Ollama 本地私有大模型(如 Llama 2、Llama 3、qwen)、通义千问、OpenAI、Azure OpenAI、Kimi、智谱 AI、讯飞星火和百度千帆大模型等。 MaxKB是什么? MaxKB产品介绍 MaxKB支持的大语言模型 MaxKB优势

Binary file not shown.

View File

@ -6,13 +6,16 @@ app_name = "dataset"
urlpatterns = [
path('dataset', views.Dataset.as_view(), name="dataset"),
path('dataset/web', views.Dataset.CreateWebDataset.as_view()),
path('dataset/qa', views.Dataset.CreateQADataset.as_view()),
path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
path('dataset/<str:dataset_id>/sync_web', views.Dataset.SyncWeb.as_view()),
path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
path('dataset/document/template/export', views.Template.as_view()),
path('dataset/<str:dataset_id>/document/web', views.WebDocument.as_view()),
path('dataset/<str:dataset_id>/document/qa', views.QaDocument.as_view()),
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
path('dataset/<str:dataset_id>/document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()),
path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),

View File

@ -9,6 +9,7 @@
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.parsers import MultiPartParser
from rest_framework.views import APIView
from rest_framework.views import Request
@ -44,6 +45,27 @@ class Dataset(APIView):
data={'sync_type': request.query_params.get('sync_type'), 'id': dataset_id,
'user_id': str(request.user.id)}).sync())
class CreateQADataset(APIView):
authentication_classes = [TokenAuth]
parser_classes = [MultiPartParser]
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="创建QA知识库",
operation_id="创建QA知识库",
manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(),
responses=get_api_response(
DataSetSerializers.Create.CreateQASerializers.get_response_body_api()),
tags=["知识库"]
)
@has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
def post(self, request: Request):
return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save_qa({
'file_list': request.FILES.getlist('file'),
'name': request.data.get('name'),
'desc': request.data.get('desc')
}))
class CreateWebDataset(APIView):
authentication_classes = [TokenAuth]

View File

@ -22,6 +22,18 @@ from dataset.serializers.document_serializers import DocumentSerializers, Docume
from dataset.swagger_api.document_api import DocumentApi
class Template(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取QA模版",
operation_id="获取QA模版",
manual_parameters=DocumentSerializers.Export.get_request_params_api(),
tags=["知识库/文档"])
def get(self, request: Request):
return DocumentSerializers.Export(data={'type': request.query_params.get('type')}).export(with_valid=True)
class WebDocument(APIView):
authentication_classes = [TokenAuth]
@ -40,6 +52,26 @@ class WebDocument(APIView):
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save_web(request.data, with_valid=True))
class QaDocument(APIView):
authentication_classes = [TokenAuth]
parser_classes = [MultiPartParser]
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="导入QA并创建文档",
operation_id="导入QA并创建文档",
manual_parameters=DocumentWebInstanceSerializer.get_request_params_api(),
responses=result.get_api_response(DocumentSerializers.Create.get_response_body_api()),
tags=["知识库/文档"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def post(self, request: Request, dataset_id: str):
return result.success(
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save_qa(
{'file_list': request.FILES.getlist('file')},
with_valid=True))
class Document(APIView):
authentication_classes = [TokenAuth]

View File

@ -38,6 +38,8 @@ httpx = "^0.27.0"
httpx-sse = "^0.4.0"
websocket-client = "^1.7.0"
langchain-google-genai = "^1.0.3"
openpyxl = "^3.1.2"
xlrd = "^2.0.1"
[build-system]
requires = ["poetry-core"]

View File

@ -93,6 +93,22 @@ const postWebDataset: (data: any, loading?: Ref<boolean>) => Promise<Result<any>
return post(`${prefix}/web`, data, undefined, loading)
}
/**
* QA知识库
* @param formData
* {
"file": "file",
"name": "string",
"desc": "string",
}
*/
const postQADataset: (data: any, loading?: Ref<boolean>) => Promise<Result<any>> = (
data,
loading
) => {
return post(`${prefix}/qa`, data, undefined, loading)
}
/**
*
* @param dataset_id
@ -170,5 +186,6 @@ export default {
listUsableApplication,
getDatasetHitTest,
postWebDataset,
postQADataset,
putSyncWebDataset
}

View File

@ -1,5 +1,5 @@
import { Result } from '@/request/Result'
import { get, post, del, put } from '@/request/index'
import { get, post, del, put, exportExcel } from '@/request/index'
import type { Ref } from 'vue'
import type { KeyValue } from '@/api/type/common'
import type { pageRequest } from '@/api/type/common'
@ -188,6 +188,20 @@ const postWebDocument: (
return post(`${prefix}/${dataset_id}/document/web`, data, undefined, loading)
}
/**
* QA文档
* @param
* file
}
*/
const postQADocument: (
dataset_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<any>> = (dataset_id, data, loading) => {
return post(`${prefix}/${dataset_id}/document/qa`, data, undefined, loading)
}
/**
*
* @param dataset_id,target_dataset_id,
@ -220,6 +234,19 @@ const batchEditHitHandling: (
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
return put(`${prefix}/${dataset_id}/document/batch_hit_handling`, data, undefined, loading)
}
/**
* QA模版
* @param fileName,type,
*/
const exportQATemplate: (fileName: string, type: string, loading?: Ref<boolean>) => void = (
fileName,
type,
loading
) => {
return exportExcel(fileName, `${prefix}/document/template/export`, { type }, loading)
}
export default {
postSplitDocument,
getDocument,
@ -234,5 +261,7 @@ export default {
delMulSyncDocument,
postWebDocument,
putMigrateMulDocument,
batchEditHitHandling
batchEditHitHandling,
exportQATemplate,
postQADocument
}

View File

@ -0,0 +1,5 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M4 2.5C4 1.94772 4.44772 1.5 5 1.5H14.7929C14.9255 1.5 15.0527 1.55268 15.1464 1.64645L19.8536 6.35355C19.9473 6.44732 20 6.5745 20 6.70711V21.5C20 22.0523 19.5523 22.5 19 22.5H5C4.44772 22.5 4 22.0523 4 21.5V2.5Z" fill="#34C724"/>
<path d="M15 1.54492C15.054 1.56949 15.1037 1.6037 15.1464 1.64646L19.8536 6.35357C19.8963 6.39632 19.9305 6.44602 19.9551 6.50001H16C15.4477 6.50001 15 6.0523 15 5.50001V1.54492Z" fill="#2CA91F"/>
<path d="M11.308 13.5956L8.33203 17.9996H9.60403L11.98 14.4596L14.284 17.9996H15.676L12.664 13.5956L15.496 9.43164H14.224L11.992 12.7796L9.85603 9.43164H8.48803L11.308 13.5956Z" fill="white"/>
</svg>

After

Width:  |  Height:  |  Size: 735 B

View File

@ -0,0 +1,5 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M4 2.5C4 1.94772 4.44772 1.5 5 1.5H14.7929C14.9255 1.5 15.0527 1.55268 15.1464 1.64645L19.8536 6.35355C19.9473 6.44732 20 6.5745 20 6.70711V21.5C20 22.0523 19.5523 22.5 19 22.5H5C4.44772 22.5 4 22.0523 4 21.5V2.5Z" fill="#34C724"/>
<path d="M15 1.54492C15.054 1.56949 15.1037 1.6037 15.1464 1.64646L19.8536 6.35357C19.8963 6.39632 19.9305 6.44602 19.9551 6.50001H16C15.4477 6.50001 15 6.0523 15 5.50001V1.54492Z" fill="#2CA91F"/>
<path d="M11.308 13.5956L8.33203 17.9996H9.60403L11.98 14.4596L14.284 17.9996H15.676L12.664 13.5956L15.496 9.43164H14.224L11.992 12.7796L9.85603 9.43164H8.48803L11.308 13.5956Z" fill="white"/>
</svg>

After

Width:  |  Height:  |  Size: 735 B

View File

@ -7,6 +7,7 @@ import { type Ref } from 'vue'
export interface datasetStateTypes {
baseInfo: datasetData | null
webInfo: any
documentsType: string
documentsFiles: UploadUserFile[]
}
@ -15,6 +16,7 @@ const useDatasetStore = defineStore({
state: (): datasetStateTypes => ({
baseInfo: null,
webInfo: null,
documentsType: '',
documentsFiles: []
}),
actions: {
@ -24,6 +26,9 @@ const useDatasetStore = defineStore({
saveWebInfo(info: any) {
this.webInfo = info
},
saveDocumentsType(val: string) {
this.documentsType = val
},
saveDocumentsFile(file: UploadUserFile[]) {
this.documentsFiles = file
},

View File

@ -37,14 +37,20 @@ export function fileType(name: string) {
/*
*/
const typeList: any = {
txt: ['txt', 'pdf', 'docx', 'csv', 'md'],
QA: ['xlsx', 'csv', 'xls']
}
export function getImgUrl(name: string) {
const type = isRightType(name) ? fileType(name) : 'unknow'
const list = Object.values(typeList).flat()
const type = list.includes(fileType(name)) ? fileType(name) : 'unknow'
return new URL(`../assets/${type}-icon.svg`, import.meta.url).href
}
// 是否是白名单后缀
export function isRightType(name: string) {
const typeList = ['txt', 'pdf', 'docx', 'csv', 'md', 'html']
return typeList.includes(fileType(name))
export function isRightType(name: string, type: string) {
return typeList[type].includes(fileType(name))
}
/*

View File

@ -40,6 +40,7 @@ import StepFirst from './step/StepFirst.vue'
import StepSecond from './step/StepSecond.vue'
import ResultSuccess from './step/ResultSuccess.vue'
import datasetApi from '@/api/dataset'
import documentApi from '@/api/document'
import type { datasetData } from '@/api/type/dataset'
import { MsgConfirm, MsgSuccess } from '@/utils/message'
@ -48,6 +49,7 @@ const { dataset, document } = useStore()
const baseInfo = computed(() => dataset.baseInfo)
const webInfo = computed(() => dataset.webInfo)
const documentsFiles = computed(() => dataset.documentsFiles)
const documentsType = computed(() => dataset.documentsType)
const router = useRouter()
const route = useRoute()
@ -80,7 +82,33 @@ const successInfo = ref<any>(null)
async function next() {
disabled.value = true
if (await StepFirstRef.value?.onSubmit()) {
if (active.value++ > 2) active.value = 0
if (documentsType.value === 'QA') {
let fd = new FormData()
documentsFiles.value.forEach((item: any) => {
if (item?.raw) {
fd.append('file', item?.raw)
}
})
if (id) {
// QA
documentApi.postQADocument(id as string, fd, loading).then((res) => {
MsgSuccess('提交成功')
clearStore()
router.push({ path: `/dataset/${id}/document` })
})
} else {
// QA
fd.append('name', baseInfo.value?.name as string)
fd.append('desc', baseInfo.value?.desc as string)
datasetApi.postQADataset(fd, loading).then((res) => {
successInfo.value = res.data
active.value = 2
})
}
} else {
if (active.value++ > 2) active.value = 0
}
} else {
disabled.value = false
}
@ -93,6 +121,7 @@ function clearStore() {
dataset.saveBaseInfo(null)
dataset.saveWebInfo(null)
dataset.saveDocumentsFile([])
dataset.saveDocumentsType('')
}
function submit() {
loading.value = true

View File

@ -7,7 +7,48 @@
label-position="top"
require-asterisk-position="right"
>
<el-form-item prop="fileList">
<el-form-item>
<el-radio-group v-model="form.fileType" @change="radioChange">
<el-radio value="txt">文本文件</el-radio>
<el-radio value="QA">QA 问答对</el-radio>
</el-radio-group>
</el-form-item>
<el-form-item prop="fileList" v-if="form.fileType === 'QA'">
<el-upload
:webkitdirectory="false"
class="w-full mb-4"
drag
multiple
v-model:file-list="form.fileList"
action="#"
:auto-upload="false"
:show-file-list="false"
accept=".xlsx, .xls, .csv"
:limit="50"
:on-exceed="onExceed"
:on-change="fileHandleChange"
@click.prevent="handlePreview(false)"
>
<img src="@/assets/upload-icon.svg" alt="" />
<div class="el-upload__text">
<p>
拖拽文件至此上传或
<em class="hover" @click.prevent="handlePreview(false)"> 选择文件 </em>
<em class="hover" @click.prevent="handlePreview(true)"> 选择文件夹 </em>
</p>
<div class="upload__decoration">
<p>当前支持 XLSX / XLS / CSV 格式的文档</p>
<p>每次最多上传50个文件每个文件不超过 100MB</p>
</div>
</div>
</el-upload>
<el-button type="primary" link @click="downloadTemplate('excel')">
下载 Excel 模板
</el-button>
<el-divider direction="vertical" />
<el-button type="primary" link @click="downloadTemplate('csv')"> 下载 CSV 模板 </el-button>
</el-form-item>
<el-form-item prop="fileList" v-else>
<el-upload
:webkitdirectory="false"
class="w-full"
@ -63,13 +104,16 @@
</template>
<script setup lang="ts">
import { ref, reactive, onUnmounted, onMounted, computed, watch, nextTick } from 'vue'
import type { UploadFile, UploadFiles } from 'element-plus'
import type { UploadFiles } from 'element-plus'
import { filesize, getImgUrl, isRightType } from '@/utils/utils'
import { MsgError } from '@/utils/message'
import documentApi from '@/api/document'
import useStore from '@/stores'
const { dataset } = useStore()
const documentsFiles = computed(() => dataset.documentsFiles)
const documentsType = computed(() => dataset.documentsType)
const form = ref({
fileType: 'txt',
fileList: [] as any
})
@ -79,22 +123,32 @@ const rules = reactive({
const FormRef = ref()
watch(form.value, (value) => {
dataset.saveDocumentsType(value.fileType)
dataset.saveDocumentsFile(value.fileList)
})
function downloadTemplate(type: string) {
documentApi.exportQATemplate(`${type}模版.${type == 'csv' ? type : 'xlsx'}`, type)
}
function radioChange() {
form.value.fileList = []
}
function deleteFile(index: number) {
form.value.fileList.splice(index, 1)
}
// on-change
const fileHandleChange = (file: any, fileList: UploadFiles) => {
//110M
//1100M
const isLimit = file?.size / 1024 / 1024 < 100
if (!isLimit) {
MsgError('文件大小超过 100MB')
fileList.splice(-1, 1) //
return false
}
if (!isRightType(file?.name)) {
if (!isRightType(file?.name, form.value.fileType)) {
MsgError('文件格式不支持')
fileList.splice(-1, 1)
return false
@ -126,12 +180,16 @@ function validate() {
}
onMounted(() => {
if (documentsType.value) {
form.value.fileType = documentsType.value
}
if (documentsFiles.value) {
form.value.fileList = documentsFiles.value
}
})
onUnmounted(() => {
form.value = {
fileType: 'txt',
fileList: []
}
})

View File

@ -110,6 +110,7 @@ watch(form.value, (value) => {
function radioChange() {
dataset.saveDocumentsFile([])
dataset.saveDocumentsType('')
form.value.source_url = ''
form.value.selector = ''
}
@ -126,6 +127,7 @@ const onSubmit = async () => {
stores保存数据
*/
dataset.saveBaseInfo(BaseFormRef.value.form)
dataset.saveDocumentsType(UploadComponentRef.value.form.fileType)
dataset.saveDocumentsFile(UploadComponentRef.value.form.fileList)
return true
}
@ -156,6 +158,7 @@ const onSubmit = async () => {
/*
stores保存数据
*/
dataset.saveDocumentsType(UploadComponentRef.value.form.fileType)
dataset.saveDocumentsFile(UploadComponentRef.value.form.fileList)
return true
} else {