Mirror of https://github.com/1Panel-dev/MaxKB.git (synced 2025-12-25 17:22:55 +00:00)
* feat: Support uploading Excel/CSV QA pairs (#430)

Parent: 6a226be048
Commit: 28938104c0
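Before the file-by-file diff, a minimal usage sketch of the new QA-dataset route registered in urls.py below. The host, API prefix and token value are assumptions for illustration only; the multipart field names file/name/desc come from the CreateQADataset view added in this commit.

# Hypothetical call against the new POST dataset/qa route (not part of the commit).
# Assumes the API is served at http://localhost:8080/api and that AUTHORIZATION
# carries a valid MaxKB token; adjust both for a real deployment.
import requests

files = [("file", open("qa_pairs.xlsx", "rb")), ("file", open("faq.csv", "rb"))]
data = {"name": "demo-qa-dataset", "desc": "QA pairs imported from spreadsheets"}
resp = requests.post(
    "http://localhost:8080/api/dataset/qa",
    headers={"AUTHORIZATION": "TOKEN"},  # placeholder token
    files=files,
    data=data,
)
print(resp.json())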
@@ -0,0 +1,19 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: base_parse_qa_handle.py
@date:2024/5/21 14:56
@desc:
"""
from abc import ABC, abstractmethod


class BaseParseQAHandle(ABC):
    @abstractmethod
    def support(self, file, get_buffer):
        pass

    @abstractmethod
    def handle(self, file, get_buffer):
        pass
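The abstract base above defines the two-method contract (support/handle) that every QA importer follows. As an illustration only, a hypothetical handler for a .json QA file could plug into the same contract; this is not part of the commit and assumes the MaxKB apps/ directory is importable.

# Hypothetical example: a JSON QA handler following the same contract as the
# spreadsheet handlers below (same truncation limits of 255/4096 characters).
import json

from common.handle.base_parse_qa_handle import BaseParseQAHandle


class JsonParseQAHandle(BaseParseQAHandle):
    def support(self, file, get_buffer):
        return file.name.lower().endswith(".json")

    def handle(self, file, get_buffer):
        rows = json.loads(get_buffer(file))
        paragraphs = [{'title': row.get('title', '')[0:255],
                       'content': row.get('content', '')[0:4096],
                       'problem_list': [{'content': p[0:255]} for p in row.get('problems', [])]}
                      for row in rows]
        return [{'name': file.name, 'paragraphs': paragraphs}]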
@@ -0,0 +1,56 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: csv_parse_qa_handle.py
@date:2024/5/21 14:59
@desc:
"""
import csv
import io

from charset_normalizer import detect

from common.handle.base_parse_qa_handle import BaseParseQAHandle


def read_csv_standard(file_path):
    data = []
    with open(file_path, 'r') as file:
        reader = csv.reader(file)
        for row in reader:
            data.append(row)
    return data


class CsvParseQAHandle(BaseParseQAHandle):
    def support(self, file, get_buffer):
        file_name: str = file.name.lower()
        if file_name.endswith(".csv"):
            return True
        return False

    def handle(self, file, get_buffer):
        buffer = get_buffer(file)
        reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding']))
        try:
            title_row_list = reader.__next__()
        except Exception as e:
            return []
        title_row_index_dict = {'title': 0, 'content': 1, 'problem_list': 2}
        for index in range(len(title_row_list)):
            title_row = title_row_list[index]
            if title_row.startswith('分段标题'):
                title_row_index_dict['title'] = index
            if title_row.startswith('分段内容'):
                title_row_index_dict['content'] = index
            if title_row.startswith('问题'):
                title_row_index_dict['problem_list'] = index
        paragraph_list = []
        for row in reader:
            problem = row[title_row_index_dict.get('problem_list')]
            problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
            paragraph_list.append({'title': row[title_row_index_dict.get('title')][0:255],
                                   'content': row[title_row_index_dict.get('content')][0:4096],
                                   'problem_list': problem_list})
        return [{'name': file.name, 'paragraphs': paragraph_list}]
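A quick local check of the CSV handler above, assuming the MaxKB apps/ directory is on PYTHONPATH and charset_normalizer is installed; SimpleNamespace stands in for Django's uploaded-file object, which the handler only touches via its name attribute and the get_buffer callback.

# Sketch only (not part of the commit): feed an in-memory CSV to CsvParseQAHandle.
from types import SimpleNamespace

from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle

raw = "分段标题,分段内容,问题\nIntro,MaxKB is a knowledge-base QA system.,What is MaxKB?\n".encode("utf-8")
file = SimpleNamespace(name="faq.csv")
result = CsvParseQAHandle().handle(file, lambda f: raw)
# Expected shape:
# [{'name': 'faq.csv', 'paragraphs': [{'title': 'Intro',
#   'content': 'MaxKB is a knowledge-base QA system.',
#   'problem_list': [{'content': 'What is MaxKB?'}]}]}]
print(result)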
@@ -0,0 +1,55 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: xls_parse_qa_handle.py
@date:2024/5/21 14:59
@desc:
"""

import xlrd

from common.handle.base_parse_qa_handle import BaseParseQAHandle


def handle_sheet(file_name, sheet):
    rows = iter([sheet.row_values(i) for i in range(sheet.nrows)])
    try:
        title_row_list = next(rows)
    except Exception as e:
        return None
    title_row_index_dict = {'title': 0, 'content': 1, 'problem_list': 2}
    for index in range(len(title_row_list)):
        title_row = str(title_row_list[index])
        if title_row.startswith('分段标题'):
            title_row_index_dict['title'] = index
        if title_row.startswith('分段内容'):
            title_row_index_dict['content'] = index
        if title_row.startswith('问题'):
            title_row_index_dict['problem_list'] = index
    paragraph_list = []
    for row in rows:
        problem = str(row[title_row_index_dict.get('problem_list')])
        problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
        paragraph_list.append({'title': str(row[title_row_index_dict.get('title')])[0:255],
                               'content': str(row[title_row_index_dict.get('content')])[0:4096],
                               'problem_list': problem_list})
    return {'name': file_name, 'paragraphs': paragraph_list}


class XlsParseQAHandle(BaseParseQAHandle):
    def support(self, file, get_buffer):
        file_name: str = file.name.lower()
        if file_name.endswith(".xls"):
            return True
        return False

    def handle(self, file, get_buffer):
        buffer = get_buffer(file)
        workbook = xlrd.open_workbook(file_contents=buffer)
        worksheets = workbook.sheets()
        worksheets_size = len(worksheets)
        return [row for row in
                [handle_sheet(file.name, sheet) if worksheets_size == 1 and sheet.name == 'Sheet1' else handle_sheet(
                    sheet.name, sheet) for sheet
                 in worksheets] if row is not None]
@@ -0,0 +1,56 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: xlsx_parse_qa_handle.py
@date:2024/5/21 14:59
@desc:
"""
import io

import openpyxl

from common.handle.base_parse_qa_handle import BaseParseQAHandle


def handle_sheet(file_name, sheet):
    rows = sheet.rows
    try:
        title_row_list = next(rows)
    except Exception as e:
        return None
    title_row_index_dict = {}
    for index in range(len(title_row_list)):
        title_row = str(title_row_list[index].value)
        if title_row.startswith('分段标题'):
            title_row_index_dict['title'] = index
        if title_row.startswith('分段内容'):
            title_row_index_dict['content'] = index
        if title_row.startswith('问题'):
            title_row_index_dict['problem_list'] = index
    paragraph_list = []
    for row in rows:
        problem = str(row[title_row_index_dict.get('problem_list')].value)
        problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
        paragraph_list.append({'title': str(row[title_row_index_dict.get('title')].value)[0:255],
                               'content': str(row[title_row_index_dict.get('content')].value)[0:4096],
                               'problem_list': problem_list})
    return {'name': file_name, 'paragraphs': paragraph_list}


class XlsxParseQAHandle(BaseParseQAHandle):
    def support(self, file, get_buffer):
        file_name: str = file.name.lower()
        if file_name.endswith(".xlsx"):
            return True
        return False

    def handle(self, file, get_buffer):
        buffer = get_buffer(file)
        workbook = openpyxl.load_workbook(io.BytesIO(buffer))
        worksheets = workbook.worksheets
        worksheets_size = len(worksheets)
        return [row for row in
                [handle_sheet(file.name, sheet) if worksheets_size == 1 and sheet.title == 'Sheet1' else handle_sheet(
                    sheet.title, sheet) for sheet
                 in worksheets] if row is not None]
@@ -45,6 +45,18 @@ def get_exec_method(clazz_: str, method_: str):
    return getattr(getattr(package_model, clazz_name), method_)


def flat_map(array: List[List]):
    """
    将二维数组转为一维数组
    :param array: 二维数组
    :return: 一维数组
    """
    result = []
    for e in array:
        result += e
    return result


def post(post_function):
    def inner(func):
        def run(*args, **kwargs):
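The serializers below use flat_map to merge the per-file lists returned by parse_qa_file into a single document list. With the flat_map defined above in scope, it flattens exactly one level of nesting:

# Illustration only: one level of flattening.
nested = [[{'name': 'a.xlsx'}], [{'name': 'b.csv'}, {'name': 'c.xls'}]]
assert flat_map(nested) == [{'name': 'a.xlsx'}, {'name': 'b.csv'}, {'name': 'c.xls'}]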
@@ -104,3 +104,13 @@ class ErrMessage:
            'invalid_image': gettext_lazy('【%s】上载有效的图像。您上载的文件不是图像或图像已损坏。' % field),
            'max_length': gettext_lazy('请确保此文件名最多包含 {max_length} 个字符(长度为 {length})。')
        }

    @staticmethod
    def file(field: str):
        return {
            'required': gettext_lazy('【%s】此字段必填。' % field),
            'empty': gettext_lazy('【%s】提交的文件为空。' % field),
            'invalid': gettext_lazy('【%s】提交的数据不是文件。请检查表单上的编码类型。' % field),
            'no_name': gettext_lazy('【%s】无法确定任何文件名。' % field),
            'max_length': gettext_lazy('请确保此文件名最多包含 {max_length} 个字符(长度为 {length})。')
        }
@@ -18,7 +18,7 @@ from urllib.parse import urlparse
from django.contrib.postgres.fields import ArrayField
from django.core import validators
from django.db import transaction, models
from django.db.models import QuerySet, Q
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers


@@ -29,7 +29,7 @@ from common.db.sql_execute import select_list
from common.event import ListenerManagement, SyncWebDatasetArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.common import post, flat_map
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork
@@ -210,6 +210,75 @@ class DataSetSerializers(serializers.ModelSerializer):
            super().is_valid(raise_exception=True)
            return True

        class CreateQASerializers(serializers.Serializer):
            """
            创建web站点序列化对象
            """
            name = serializers.CharField(required=True,
                                         error_messages=ErrMessage.char("知识库名称"),
                                         max_length=64,
                                         min_length=1)

            desc = serializers.CharField(required=True,
                                         error_messages=ErrMessage.char("知识库描述"),
                                         max_length=256,
                                         min_length=1)

            file_list = serializers.ListSerializer(required=True,
                                                   error_messages=ErrMessage.list("文件列表"),
                                                   child=serializers.FileField(required=True,
                                                                               error_messages=ErrMessage.file("文件")))

            @staticmethod
            def get_request_params_api():
                return [openapi.Parameter(name='file',
                                          in_=openapi.IN_FORM,
                                          type=openapi.TYPE_ARRAY,
                                          items=openapi.Items(type=openapi.TYPE_FILE),
                                          required=True,
                                          description='上传文件'),
                        openapi.Parameter(name='name',
                                          in_=openapi.IN_FORM,
                                          required=True,
                                          type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
                        openapi.Parameter(name='desc',
                                          in_=openapi.IN_FORM,
                                          required=True,
                                          type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
                        ]

            @staticmethod
            def get_response_body_api():
                return openapi.Schema(
                    type=openapi.TYPE_OBJECT,
                    required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count',
                              'update_time', 'create_time', 'document_list'],
                    properties={
                        'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                             description="id", default="xx"),
                        'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
                                               description="名称", default="测试知识库"),
                        'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述",
                                               description="描述", default="测试知识库描述"),
                        'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id",
                                                  description="所属用户id", default="user_xxxx"),
                        'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数",
                                                      description="字符数", default=10),
                        'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量",
                                                         description="文档数量", default=1),
                        'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                                      description="修改时间",
                                                      default="1970-01-01 00:00:00"),
                        'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                                      description="创建时间",
                                                      default="1970-01-01 00:00:00"
                                                      ),
                        'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表",
                                                        description="文档列表",
                                                        items=DocumentSerializers.Operate.get_response_body_api())
                    }
                )

        class CreateWebSerializers(serializers.Serializer):
            """
            创建web站点序列化对象
@@ -288,6 +357,15 @@ class DataSetSerializers(serializers.ModelSerializer):
            ListenerManagement.embedding_by_dataset_signal.send(dataset_id)
            return document_list

        def save_qa(self, instance: Dict, with_valid=True):
            if with_valid:
                self.is_valid(raise_exception=True)
                self.CreateQASerializers(data=instance).is_valid()
            file_list = instance.get('file_list')
            document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list])
            dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list}
            return self.save(dataset_instance, with_valid=True)

        @post(post_function=post_embedding_dataset)
        @transaction.atomic
        def save(self, instance: Dict, with_valid=True):
@@ -17,6 +17,7 @@ from typing import List, Dict
from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
from django.http import HttpResponse
from drf_yasg import openapi
from rest_framework import serializers


@@ -27,9 +28,12 @@ from common.exception.app_exception import AppApiException
from common.handle.impl.doc_split_handle import DocSplitHandle
from common.handle.impl.html_split_handle import HTMLSplitHandle
from common.handle.impl.pdf_split_handle import PdfSplitHandle
from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
from common.handle.impl.text_split_handle import TextSplitHandle
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.common import post, flat_map
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork


@@ -39,6 +43,17 @@ from dataset.serializers.common_serializers import BatchSerializer, MetaSerializ
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR

parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()]


class FileBufferHandle:
    buffer = None

    def get_buffer(self, file):
        if self.buffer is None:
            self.buffer = file.read()
        return self.buffer


class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer):
    meta = serializers.DictField(required=False)
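FileBufferHandle above caches the uploaded file's bytes so several handlers can probe the same upload without re-reading the stream. A standalone illustration of that behaviour (class body copied from the hunk above, with io.BytesIO standing in for the upload):

# Sketch only: the upload is read once, then the cached bytes are reused.
import io


class FileBufferHandle:
    buffer = None

    def get_buffer(self, file):
        if self.buffer is None:
            self.buffer = file.read()
        return self.buffer


upload = io.BytesIO(b"col_a,col_b\n1,2\n")
get_buffer = FileBufferHandle().get_buffer
assert get_buffer(upload) == b"col_a,col_b\n1,2\n"
# A second call returns the cached bytes; re-reading the exhausted stream
# directly would yield b"" instead.
assert get_buffer(upload) == b"col_a,col_b\n1,2\n"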
@@ -87,16 +102,19 @@ class DocumentWebInstanceSerializer(ApiMixin, serializers.Serializer):
                                                 "选择器"))

    @staticmethod
    def get_request_body_api():
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['source_url_list'],
            properties={
                'source_url_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="段落列表", description="段落列表",
                                                  items=openapi.Schema(type=openapi.TYPE_STRING)),
                'selector': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称")
            }
        )
    def get_request_params_api():
        return [openapi.Parameter(name='file',
                                  in_=openapi.IN_FORM,
                                  type=openapi.TYPE_ARRAY,
                                  items=openapi.Items(type=openapi.TYPE_FILE),
                                  required=True,
                                  description='上传文件'),
                openapi.Parameter(name='dataset_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='知识库id'),
                ]


class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
@@ -120,7 +138,48 @@ class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
        )


class DocumentInstanceQASerializer(ApiMixin, serializers.Serializer):
    file_list = serializers.ListSerializer(required=True,
                                           error_messages=ErrMessage.list("文件列表"),
                                           child=serializers.FileField(required=True,
                                                                       error_messages=ErrMessage.file("文件")))


class DocumentSerializers(ApiMixin, serializers.Serializer):
    class Export(ApiMixin, serializers.Serializer):
        type = serializers.CharField(required=True, validators=[
            validators.RegexValidator(regex=re.compile("^csv|excel$"),
                                      message="模版类型只支持excel|csv",
                                      code=500)
        ], error_messages=ErrMessage.char("模版类型"))

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='type',
                                      in_=openapi.IN_QUERY,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='导出模板类型csv|excel'),

                    ]

        def export(self, with_valid=True):
            if with_valid:
                self.is_valid(raise_exception=True)

            if self.data.get('type') == 'csv':
                file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', 'csv_template.csv'), "rb")
                content = file.read()
                file.close()
                return HttpResponse(content, status=200, headers={'Content-Type': 'text/cxv',
                                                                  'Content-Disposition': 'attachment; filename="csv_template.csv"'})
            elif self.data.get('type') == 'excel':
                file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', 'excel_template.xlsx'), "rb")
                content = file.read()
                file.close()
                return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
                                                                  'Content-Disposition': 'attachment; filename="excel_template.xlsx"'})

    class Migrate(ApiMixin, serializers.Serializer):
        dataset_id = serializers.UUIDField(required=True,
                                           error_messages=ErrMessage.char(
@@ -473,6 +532,22 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
            ListenerManagement.embedding_by_document_signal.send(document_id)
            return result

        @staticmethod
        def parse_qa_file(file):
            for parse_qa_handle in parse_qa_handle_list:
                get_buffer = FileBufferHandle().get_buffer
                if parse_qa_handle.support(file, get_buffer):
                    return parse_qa_handle.handle(file, get_buffer)
            raise AppApiException(500, '不支持的文件格式')

        def save_qa(self, instance: Dict, with_valid=True):
            if with_valid:
                DocumentInstanceQASerializer(data=instance).is_valid(raise_exception=True)
                self.is_valid(raise_exception=True)
            file_list = instance.get('file_list')
            document_list = flat_map([self.parse_qa_file(file) for file in file_list])
            return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)

        @post(post_function=post_embedding)
        @transaction.atomic
        def save(self, instance: Dict, with_valid=False, **kwargs):
@@ -714,6 +789,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                problem_paragraph_mapping_list) > 0 else None
            # 查询文档
            query_set = QuerySet(model=Document)
            if len(document_model_list) == 0:
                return [],
            query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
            return native_search(query_set, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),
@@ -0,0 +1,8 @@
分段标题(选填),分段内容(必填,问题答案,最长不超过4096个字符)),问题(选填,单元格内一行一个)
MaxKB产品介绍,"MaxKB 是一款基于 LLM 大语言模型的知识库问答系统。MaxKB = Max Knowledge Base,旨在成为企业的最强大脑。
开箱即用:支持直接上传文档、自动爬取在线文档,支持文本自动拆分、向量化,智能问答交互体验好;
无缝嵌入:支持零编码快速嵌入到第三方业务系统;
多模型支持:支持对接主流的大模型,包括 Ollama 本地私有大模型(如 Llama 2、Llama 3、qwen)、通义千问、OpenAI、Azure OpenAI、Kimi、智谱 AI、讯飞星火和百度千帆大模型等。","MaxKB是什么?
MaxKB产品介绍
MaxKB支持的大语言模型
MaxKB优势"

Binary file not shown.
@@ -6,13 +6,16 @@ app_name = "dataset"
urlpatterns = [
    path('dataset', views.Dataset.as_view(), name="dataset"),
    path('dataset/web', views.Dataset.CreateWebDataset.as_view()),
    path('dataset/qa', views.Dataset.CreateQADataset.as_view()),
    path('dataset/<str:dataset_id>', views.Dataset.Operate.as_view(), name="dataset_key"),
    path('dataset/<str:dataset_id>/application', views.Dataset.Application.as_view()),
    path('dataset/<int:current_page>/<int:page_size>', views.Dataset.Page.as_view(), name="dataset"),
    path('dataset/<str:dataset_id>/sync_web', views.Dataset.SyncWeb.as_view()),
    path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
    path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
    path('dataset/document/template/export', views.Template.as_view()),
    path('dataset/<str:dataset_id>/document/web', views.WebDocument.as_view()),
    path('dataset/<str:dataset_id>/document/qa', views.QaDocument.as_view()),
    path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
    path('dataset/<str:dataset_id>/document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()),
    path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),
@@ -9,6 +9,7 @@

from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.parsers import MultiPartParser
from rest_framework.views import APIView
from rest_framework.views import Request
@@ -44,6 +45,27 @@ class Dataset(APIView):
                             data={'sync_type': request.query_params.get('sync_type'), 'id': dataset_id,
                                   'user_id': str(request.user.id)}).sync())

    class CreateQADataset(APIView):
        authentication_classes = [TokenAuth]
        parser_classes = [MultiPartParser]

        @action(methods=['POST'], detail=False)
        @swagger_auto_schema(operation_summary="创建QA知识库",
                             operation_id="创建QA知识库",

                             manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(),
                             responses=get_api_response(
                                 DataSetSerializers.Create.CreateQASerializers.get_response_body_api()),
                             tags=["知识库"]
                             )
        @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
        def post(self, request: Request):
            return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save_qa({
                'file_list': request.FILES.getlist('file'),
                'name': request.data.get('name'),
                'desc': request.data.get('desc')
            }))

    class CreateWebDataset(APIView):
        authentication_classes = [TokenAuth]
@@ -22,6 +22,18 @@ from dataset.serializers.document_serializers import DocumentSerializers, Docume
from dataset.swagger_api.document_api import DocumentApi


class Template(APIView):
    authentication_classes = [TokenAuth]

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取QA模版",
                         operation_id="获取QA模版",
                         manual_parameters=DocumentSerializers.Export.get_request_params_api(),
                         tags=["知识库/文档"])
    def get(self, request: Request):
        return DocumentSerializers.Export(data={'type': request.query_params.get('type')}).export(with_valid=True)


class WebDocument(APIView):
    authentication_classes = [TokenAuth]
@@ -40,6 +52,26 @@ class WebDocument(APIView):
                DocumentSerializers.Create(data={'dataset_id': dataset_id}).save_web(request.data, with_valid=True))


class QaDocument(APIView):
    authentication_classes = [TokenAuth]
    parser_classes = [MultiPartParser]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="导入QA并创建文档",
                         operation_id="导入QA并创建文档",
                         manual_parameters=DocumentWebInstanceSerializer.get_request_params_api(),
                         responses=result.get_api_response(DocumentSerializers.Create.get_response_body_api()),
                         tags=["知识库/文档"])
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def post(self, request: Request, dataset_id: str):
        return result.success(
            DocumentSerializers.Create(data={'dataset_id': dataset_id}).save_qa(
                {'file_list': request.FILES.getlist('file')},
                with_valid=True))


class Document(APIView):
    authentication_classes = [TokenAuth]
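For an existing knowledge base, the document-level route added above (dataset/<dataset_id>/document/qa) can be exercised the same way; host, token and dataset id below are placeholders, not values from the commit.

# Hypothetical call against the new QA document-import route.
import requests

dataset_id = "00000000-0000-0000-0000-000000000000"  # placeholder
resp = requests.post(
    f"http://localhost:8080/api/dataset/{dataset_id}/document/qa",
    headers={"AUTHORIZATION": "TOKEN"},  # placeholder token
    files=[("file", open("qa_pairs.xls", "rb"))],
)
print(resp.json())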
@@ -38,6 +38,8 @@ httpx = "^0.27.0"
httpx-sse = "^0.4.0"
websocket-client = "^1.7.0"
langchain-google-genai = "^1.0.3"
openpyxl = "^3.1.2"
xlrd = "^2.0.1"

[build-system]
requires = ["poetry-core"]
@@ -93,6 +93,22 @@ const postWebDataset: (data: any, loading?: Ref<boolean>) => Promise<Result<any>
  return post(`${prefix}/web`, data, undefined, loading)
}

/**
 * 创建QA知识库
 * @param 参数 formData
 * {
 "file": "file",
 "name": "string",
 "desc": "string",
 }
 */
const postQADataset: (data: any, loading?: Ref<boolean>) => Promise<Result<any>> = (
  data,
  loading
) => {
  return post(`${prefix}/qa`, data, undefined, loading)
}

/**
 * 知识库详情
 * @param 参数 dataset_id


@@ -170,5 +186,6 @@ export default {
  listUsableApplication,
  getDatasetHitTest,
  postWebDataset,
  postQADataset,
  putSyncWebDataset
}
@@ -1,5 +1,5 @@
import { Result } from '@/request/Result'
import { get, post, del, put } from '@/request/index'
import { get, post, del, put, exportExcel } from '@/request/index'
import type { Ref } from 'vue'
import type { KeyValue } from '@/api/type/common'
import type { pageRequest } from '@/api/type/common'
@@ -188,6 +188,20 @@ const postWebDocument: (
  return post(`${prefix}/${dataset_id}/document/web`, data, undefined, loading)
}

/**
 * 导入QA文档
 * @param 参数
 * file
 }
 */
const postQADocument: (
  dataset_id: string,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<any>> = (dataset_id, data, loading) => {
  return post(`${prefix}/${dataset_id}/document/qa`, data, undefined, loading)
}

/**
 * 批量迁移文档
 * @param 参数 dataset_id,target_dataset_id,
@@ -220,6 +234,19 @@ const batchEditHitHandling: (
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
  return put(`${prefix}/${dataset_id}/document/batch_hit_handling`, data, undefined, loading)
}

/**
 * 获得QA模版
 * @param 参数 fileName,type,
 */
const exportQATemplate: (fileName: string, type: string, loading?: Ref<boolean>) => void = (
  fileName,
  type,
  loading
) => {
  return exportExcel(fileName, `${prefix}/document/template/export`, { type }, loading)
}

export default {
  postSplitDocument,
  getDocument,


@@ -234,5 +261,7 @@ export default {
  delMulSyncDocument,
  postWebDocument,
  putMigrateMulDocument,
  batchEditHitHandling
  batchEditHitHandling,
  exportQATemplate,
  postQADocument
}
@@ -0,0 +1,5 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M4 2.5C4 1.94772 4.44772 1.5 5 1.5H14.7929C14.9255 1.5 15.0527 1.55268 15.1464 1.64645L19.8536 6.35355C19.9473 6.44732 20 6.5745 20 6.70711V21.5C20 22.0523 19.5523 22.5 19 22.5H5C4.44772 22.5 4 22.0523 4 21.5V2.5Z" fill="#34C724"/>
<path d="M15 1.54492C15.054 1.56949 15.1037 1.6037 15.1464 1.64646L19.8536 6.35357C19.8963 6.39632 19.9305 6.44602 19.9551 6.50001H16C15.4477 6.50001 15 6.0523 15 5.50001V1.54492Z" fill="#2CA91F"/>
<path d="M11.308 13.5956L8.33203 17.9996H9.60403L11.98 14.4596L14.284 17.9996H15.676L12.664 13.5956L15.496 9.43164H14.224L11.992 12.7796L9.85603 9.43164H8.48803L11.308 13.5956Z" fill="white"/>
</svg>


@@ -0,0 +1,5 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M4 2.5C4 1.94772 4.44772 1.5 5 1.5H14.7929C14.9255 1.5 15.0527 1.55268 15.1464 1.64645L19.8536 6.35355C19.9473 6.44732 20 6.5745 20 6.70711V21.5C20 22.0523 19.5523 22.5 19 22.5H5C4.44772 22.5 4 22.0523 4 21.5V2.5Z" fill="#34C724"/>
<path d="M15 1.54492C15.054 1.56949 15.1037 1.6037 15.1464 1.64646L19.8536 6.35357C19.8963 6.39632 19.9305 6.44602 19.9551 6.50001H16C15.4477 6.50001 15 6.0523 15 5.50001V1.54492Z" fill="#2CA91F"/>
<path d="M11.308 13.5956L8.33203 17.9996H9.60403L11.98 14.4596L14.284 17.9996H15.676L12.664 13.5956L15.496 9.43164H14.224L11.992 12.7796L9.85603 9.43164H8.48803L11.308 13.5956Z" fill="white"/>
</svg>
@@ -7,6 +7,7 @@ import { type Ref } from 'vue'
export interface datasetStateTypes {
  baseInfo: datasetData | null
  webInfo: any
  documentsType: string
  documentsFiles: UploadUserFile[]
}


@@ -15,6 +16,7 @@ const useDatasetStore = defineStore({
  state: (): datasetStateTypes => ({
    baseInfo: null,
    webInfo: null,
    documentsType: '',
    documentsFiles: []
  }),
  actions: {


@@ -24,6 +26,9 @@ const useDatasetStore = defineStore({
    saveWebInfo(info: any) {
      this.webInfo = info
    },
    saveDocumentsType(val: string) {
      this.documentsType = val
    },
    saveDocumentsFile(file: UploadUserFile[]) {
      this.documentsFiles = file
    },
@@ -37,14 +37,20 @@ export function fileType(name: string) {
/*
  获得文件对应图片
*/
const typeList: any = {
  txt: ['txt', 'pdf', 'docx', 'csv', 'md'],
  QA: ['xlsx', 'csv', 'xls']
}

export function getImgUrl(name: string) {
  const type = isRightType(name) ? fileType(name) : 'unknow'
  const list = Object.values(typeList).flat()

  const type = list.includes(fileType(name)) ? fileType(name) : 'unknow'
  return new URL(`../assets/${type}-icon.svg`, import.meta.url).href
}
// 是否是白名单后缀
export function isRightType(name: string) {
  const typeList = ['txt', 'pdf', 'docx', 'csv', 'md', 'html']
  return typeList.includes(fileType(name))
export function isRightType(name: string, type: string) {
  return typeList[type].includes(fileType(name))
}

/*
@@ -40,6 +40,7 @@ import StepFirst from './step/StepFirst.vue'
import StepSecond from './step/StepSecond.vue'
import ResultSuccess from './step/ResultSuccess.vue'
import datasetApi from '@/api/dataset'
import documentApi from '@/api/document'
import type { datasetData } from '@/api/type/dataset'
import { MsgConfirm, MsgSuccess } from '@/utils/message'


@@ -48,6 +49,7 @@ const { dataset, document } = useStore()
const baseInfo = computed(() => dataset.baseInfo)
const webInfo = computed(() => dataset.webInfo)
const documentsFiles = computed(() => dataset.documentsFiles)
const documentsType = computed(() => dataset.documentsType)

const router = useRouter()
const route = useRoute()
@@ -80,7 +82,33 @@ const successInfo = ref<any>(null)
async function next() {
  disabled.value = true
  if (await StepFirstRef.value?.onSubmit()) {
    if (active.value++ > 2) active.value = 0
    if (documentsType.value === 'QA') {
      let fd = new FormData()
      documentsFiles.value.forEach((item: any) => {
        if (item?.raw) {
          fd.append('file', item?.raw)
        }
      })
      if (id) {
        // QA文档上传
        documentApi.postQADocument(id as string, fd, loading).then((res) => {
          MsgSuccess('提交成功')
          clearStore()
          router.push({ path: `/dataset/${id}/document` })
        })
      } else {
        // QA知识库创建
        fd.append('name', baseInfo.value?.name as string)
        fd.append('desc', baseInfo.value?.desc as string)

        datasetApi.postQADataset(fd, loading).then((res) => {
          successInfo.value = res.data
          active.value = 2
        })
      }
    } else {
      if (active.value++ > 2) active.value = 0
    }
  } else {
    disabled.value = false
  }
@@ -93,6 +121,7 @@ function clearStore() {
  dataset.saveBaseInfo(null)
  dataset.saveWebInfo(null)
  dataset.saveDocumentsFile([])
  dataset.saveDocumentsType('')
}
function submit() {
  loading.value = true
@@ -7,7 +7,48 @@
    label-position="top"
    require-asterisk-position="right"
  >
    <el-form-item prop="fileList">
    <el-form-item>
      <el-radio-group v-model="form.fileType" @change="radioChange">
        <el-radio value="txt">文本文件</el-radio>
        <el-radio value="QA">QA 问答对</el-radio>
      </el-radio-group>
    </el-form-item>
    <el-form-item prop="fileList" v-if="form.fileType === 'QA'">
      <el-upload
        :webkitdirectory="false"
        class="w-full mb-4"
        drag
        multiple
        v-model:file-list="form.fileList"
        action="#"
        :auto-upload="false"
        :show-file-list="false"
        accept=".xlsx, .xls, .csv"
        :limit="50"
        :on-exceed="onExceed"
        :on-change="fileHandleChange"
        @click.prevent="handlePreview(false)"
      >
        <img src="@/assets/upload-icon.svg" alt="" />
        <div class="el-upload__text">
          <p>
            拖拽文件至此上传或
            <em class="hover" @click.prevent="handlePreview(false)"> 选择文件 </em>
            <em class="hover" @click.prevent="handlePreview(true)"> 选择文件夹 </em>
          </p>
          <div class="upload__decoration">
            <p>当前支持 XLSX / XLS / CSV 格式的文档</p>
            <p>每次最多上传50个文件,每个文件不超过 100MB</p>
          </div>
        </div>
      </el-upload>
      <el-button type="primary" link @click="downloadTemplate('excel')">
        下载 Excel 模板
      </el-button>
      <el-divider direction="vertical" />
      <el-button type="primary" link @click="downloadTemplate('csv')"> 下载 CSV 模板 </el-button>
    </el-form-item>
    <el-form-item prop="fileList" v-else>
      <el-upload
        :webkitdirectory="false"
        class="w-full"
@@ -63,13 +104,16 @@
</template>
<script setup lang="ts">
import { ref, reactive, onUnmounted, onMounted, computed, watch, nextTick } from 'vue'
import type { UploadFile, UploadFiles } from 'element-plus'
import type { UploadFiles } from 'element-plus'
import { filesize, getImgUrl, isRightType } from '@/utils/utils'
import { MsgError } from '@/utils/message'
import documentApi from '@/api/document'
import useStore from '@/stores'
const { dataset } = useStore()
const documentsFiles = computed(() => dataset.documentsFiles)
const documentsType = computed(() => dataset.documentsType)
const form = ref({
  fileType: 'txt',
  fileList: [] as any
})
@@ -79,22 +123,32 @@ const rules = reactive({
const FormRef = ref()

watch(form.value, (value) => {
  dataset.saveDocumentsType(value.fileType)
  dataset.saveDocumentsFile(value.fileList)
})

function downloadTemplate(type: string) {
  documentApi.exportQATemplate(`${type}模版.${type == 'csv' ? type : 'xlsx'}`, type)
}

function radioChange() {
  form.value.fileList = []
}

function deleteFile(index: number) {
  form.value.fileList.splice(index, 1)
}

// 上传on-change事件
const fileHandleChange = (file: any, fileList: UploadFiles) => {
  //1、判断文件大小是否合法,文件限制不能大于10M
  //1、判断文件大小是否合法,文件限制不能大于100M
  const isLimit = file?.size / 1024 / 1024 < 100
  if (!isLimit) {
    MsgError('文件大小超过 100MB')
    fileList.splice(-1, 1) //移除当前超出大小的文件
    return false
  }
  if (!isRightType(file?.name)) {
  if (!isRightType(file?.name, form.value.fileType)) {
    MsgError('文件格式不支持')
    fileList.splice(-1, 1)
    return false
@@ -126,12 +180,16 @@ function validate() {
}

onMounted(() => {
  if (documentsType.value) {
    form.value.fileType = documentsType.value
  }
  if (documentsFiles.value) {
    form.value.fileList = documentsFiles.value
  }
})
onUnmounted(() => {
  form.value = {
    fileType: 'txt',
    fileList: []
  }
})
@@ -110,6 +110,7 @@ watch(form.value, (value) => {

function radioChange() {
  dataset.saveDocumentsFile([])
  dataset.saveDocumentsType('')
  form.value.source_url = ''
  form.value.selector = ''
}
@@ -126,6 +127,7 @@ const onSubmit = async () => {
    stores保存数据
    */
    dataset.saveBaseInfo(BaseFormRef.value.form)
    dataset.saveDocumentsType(UploadComponentRef.value.form.fileType)
    dataset.saveDocumentsFile(UploadComponentRef.value.form.fileList)
    return true
  }
@@ -156,6 +158,7 @@ const onSubmit = async () => {
    /*
    stores保存数据
    */
    dataset.saveDocumentsType(UploadComponentRef.value.form.fileType)
    dataset.saveDocumentsFile(UploadComponentRef.value.form.fileList)
    return true
  } else {