feat: enhance Document API with workspace ID support for get, put, and delete operations

This commit is contained in:
CaptainB 2025-05-06 14:33:59 +08:00
parent 3e9069aac1
commit e702af8c2b
7 changed files with 263 additions and 16 deletions

View File

@ -8,7 +8,7 @@ from common.handle.base_parse_table_handle import BaseParseTableHandle
max_kb = logging.getLogger("max_kb")
class CsvSplitHandle(BaseParseTableHandle):
class CsvParseTableHandle(BaseParseTableHandle):
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith(".csv"):

View File

@ -8,7 +8,7 @@ from common.handle.base_parse_table_handle import BaseParseTableHandle
max_kb = logging.getLogger("max_kb")
class XlsSplitHandle(BaseParseTableHandle):
class XlsParseTableHandle(BaseParseTableHandle):
def support(self, file, get_buffer):
file_name: str = file.name.lower()
buffer = get_buffer(file)

View File

@ -10,7 +10,7 @@ from common.handle.impl.common_handle import xlsx_embed_cells_images
max_kb = logging.getLogger("max_kb")
class XlsxSplitHandle(BaseParseTableHandle):
class XlsxParseTableHandle(BaseParseTableHandle):
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith('.xlsx'):

View File

@ -4,7 +4,7 @@ from drf_spectacular.utils import OpenApiParameter
from common.mixins.api_mixin import APIMixin
from common.result import DefaultResultSerializer
from knowledge.serializers.common import BatchSerializer
from knowledge.serializers.document import DocumentInstanceSerializer
from knowledge.serializers.document import DocumentInstanceSerializer, DocumentWebInstanceSerializer
class DocumentSplitAPI(APIMixin):
@ -176,3 +176,45 @@ class DocumentEditAPI(DocumentReadAPI):
class DocumentDeleteAPI(DocumentReadAPI):
pass
class TableDocumentCreateAPI(APIMixin):
@staticmethod
def get_parameters():
return [
OpenApiParameter(
name="workspace_id",
description="工作空间id",
type=OpenApiTypes.STR,
location='path',
required=True,
),
OpenApiParameter(
name="knowledge_id",
description="知识库id",
type=OpenApiTypes.STR,
location='path',
required=True,
),
OpenApiParameter(
name="file",
description="文件",
type=OpenApiTypes.BINARY,
location='query',
required=False,
),
]
@staticmethod
def get_response():
return DefaultResultSerializer
class QaDocumentCreateAPI(TableDocumentCreateAPI):
pass
class WebDocumentCreateAPI(APIMixin):
@staticmethod
def get_request():
return DocumentWebInstanceSerializer

View File

@ -1,11 +1,13 @@
import logging
import os
import re
import traceback
from functools import reduce
from typing import Dict, List
import uuid_utils.compat as uuid
from celery_once import AlreadyQueued
from django.core import validators
from django.db import transaction, models
from django.db.models import QuerySet, Model
from django.db.models.functions import Substr, Reverse
@ -16,6 +18,13 @@ from common.db.search import native_search, get_dynamics_model, native_page_sear
from common.event import ListenerManagement
from common.event.common import work_thread_pool
from common.exception.app_exception import AppApiException
from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
from common.handle.impl.qa.zip_parse_qa_handle import ZipParseQAHandle
from common.handle.impl.table.csv_parse_table_handle import CsvParseTableHandle
from common.handle.impl.table.xls_parse_table_handle import XlsParseTableHandle
from common.handle.impl.table.xlsx_parse_table_handle import XlsxParseTableHandle
from common.handle.impl.text.csv_split_handle import CsvSplitHandle
from common.handle.impl.text.doc_split_handle import DocSplitHandle
from common.handle.impl.text.html_split_handle import HTMLSplitHandle
@ -26,14 +35,16 @@ from common.handle.impl.text.xlsx_split_handle import XlsxSplitHandle
from common.handle.impl.text.zip_split_handle import ZipSplitHandle
from common.utils.common import post, get_file_content, bulk_create_in_batches
from common.utils.fork import Fork
from common.utils.split_model import get_split_model
from common.utils.split_model import get_split_model, flat_map
from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
TaskType, File
from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer, get_embedding_model_id_by_knowledge_id
from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer, \
get_embedding_model_id_by_knowledge_id, MetaSerializer
from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \
delete_problems_and_mappings
from knowledge.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
delete_embedding_by_document
from knowledge.task.sync import sync_web_document
from maxkb.const import PROJECT_DIR
default_split_handle = TextSplitHandle()
@ -48,6 +59,9 @@ split_handles = [
default_split_handle
]
parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()]
parse_table_handle_list = [CsvParseTableHandle(), XlsParseTableHandle(), XlsxParseTableHandle()]
class BatchCancelInstanceSerializer(serializers.Serializer):
id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=('id list'))
@ -67,6 +81,36 @@ class DocumentInstanceSerializer(serializers.Serializer):
paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
class DocumentEditInstanceSerializer(serializers.Serializer):
meta = serializers.DictField(required=False)
name = serializers.CharField(required=False, max_length=128, min_length=1, label=_('document name'))
hit_handling_method = serializers.CharField(required=False, validators=[
validators.RegexValidator(regex=re.compile("^optimization|directly_return$"),
message=_('The type only supports optimization|directly_return'),
code=500)
], label=_('hit handling method'))
directly_return_similarity = serializers.FloatField(required=False, max_value=2, min_value=0,
label=_('directly return similarity'))
is_active = serializers.BooleanField(required=False, label=_('document is active'))
@staticmethod
def get_meta_valid_map():
dataset_meta_valid_map = {
KnowledgeType.BASE: MetaSerializer.BaseMeta,
KnowledgeType.WEB: MetaSerializer.WebMeta
}
return dataset_meta_valid_map
def is_valid(self, *, document: Document = None):
super().is_valid(raise_exception=True)
if 'meta' in self.data and self.data.get('meta') is not None:
dataset_meta_valid_map = self.get_meta_valid_map()
valid_class = dataset_meta_valid_map.get(document.type)
valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
class DocumentSplitRequest(serializers.Serializer):
file = serializers.ListField(required=True, label=_('file list'))
limit = serializers.IntegerField(required=False, label=_('limit'))
@ -78,6 +122,22 @@ class DocumentSplitRequest(serializers.Serializer):
with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))
class DocumentWebInstanceSerializer(serializers.Serializer):
source_url_list = serializers.ListField(required=True, label=_('document url list'),
child=serializers.CharField(required=True, label=_('document url list')))
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('selector'))
class DocumentInstanceQASerializer(serializers.Serializer):
file_list = serializers.ListSerializer(required=True, label=_('file list'),
child=serializers.FileField(required=True, label=_('file')))
class DocumentInstanceTableSerializer(serializers.Serializer):
file_list = serializers.ListSerializer(required=True, label=_('file list'),
child=serializers.FileField(required=True, label=_('file')))
class DocumentSerializers(serializers.Serializer):
class Query(serializers.Serializer):
# 知识库id
@ -226,6 +286,7 @@ class DocumentSerializers(serializers.Serializer):
return True
class Operate(serializers.Serializer):
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
document_id = serializers.UUIDField(required=True, label=_('document id'))
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
@ -246,6 +307,31 @@ class DocumentSerializers(serializers.Serializer):
}, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')), with_search_one=True)
def edit(self, instance: Dict, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
_document = QuerySet(Document).get(id=self.data.get("document_id"))
if with_valid:
DocumentEditInstanceSerializer(data=instance).is_valid(document=_document)
update_keys = ['name', 'is_active', 'hit_handling_method', 'directly_return_similarity', 'meta']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
_document.__setattr__(update_key, instance.get(update_key))
_document.save()
return self.one()
@transaction.atomic
def delete(self):
document_id = self.data.get("document_id")
QuerySet(model=Document).filter(id=document_id).delete()
# 删除段落
QuerySet(model=Paragraph).filter(document_id=document_id).delete()
# 删除问题
delete_problems_and_mappings([document_id])
# 删除向量库
delete_embedding_by_document(document_id)
return True
def refresh(self, state_list=None, with_valid=True):
if state_list is None:
state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
@ -369,6 +455,58 @@ class DocumentSerializers(serializers.Serializer):
instance.get('paragraphs') if 'paragraphs' in instance else []
)
def save_web(self, instance: Dict, with_valid=True):
if with_valid:
DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True)
self.is_valid(raise_exception=True)
dataset_id = self.data.get('dataset_id')
source_url_list = instance.get('source_url_list')
selector = instance.get('selector')
sync_web_document.delay(dataset_id, source_url_list, selector)
def save_qa(self, instance: Dict, with_valid=True):
if with_valid:
DocumentInstanceQASerializer(data=instance).is_valid(raise_exception=True)
self.is_valid(raise_exception=True)
file_list = instance.get('file_list')
document_list = flat_map([self.parse_qa_file(file) for file in file_list])
return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)
def save_table(self, instance: Dict, with_valid=True):
if with_valid:
DocumentInstanceTableSerializer(data=instance).is_valid(raise_exception=True)
self.is_valid(raise_exception=True)
file_list = instance.get('file_list')
document_list = flat_map([self.parse_table_file(file) for file in file_list])
return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)
def parse_qa_file(self, file):
get_buffer = FileBufferHandle().get_buffer
for parse_qa_handle in parse_qa_handle_list:
if parse_qa_handle.support(file, get_buffer):
return parse_qa_handle.handle(file, get_buffer, self.save_image)
raise AppApiException(500, _('Unsupported file format'))
def parse_table_file(self, file):
get_buffer = FileBufferHandle().get_buffer
for parse_table_handle in parse_table_handle_list:
if parse_table_handle.support(file, get_buffer):
return parse_table_handle.handle(file, get_buffer, self.save_image)
raise AppApiException(500, _('Unsupported file format'))
def save_image(self, image_list):
if image_list is not None and len(image_list) > 0:
exist_image_list = [str(i.get('id')) for i in
QuerySet(File).filter(id__in=[i.id for i in image_list]).values('id')]
save_image_list = [image for image in image_list if not exist_image_list.__contains__(str(image.id))]
save_image_list = list({img.id: img for img in save_image_list}.values())
# save image
for file in save_image_list:
file_bytes = file.meta.pop('content')
file.workspace_id = self.data.get('workspace_id')
file.meta['knowledge_id'] = self.data.get('knowledge_id')
file.save(file_bytes)
class Split(serializers.Serializer):
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))

View File

@ -11,6 +11,9 @@ urlpatterns = [
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch', views.DocumentView.Batch.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/web', views.WebDocumentView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/qa', views.QaDocumentView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/table', views.TableDocumentView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>', views.DocumentView.Operate.as_view()),
path('workspace/<str:workspace_id>/knowledge/<int:current_page>/<int:page_size>', views.KnowledgeView.Page.as_view()),
]

View File

@ -9,7 +9,8 @@ from common.auth.authentication import has_permissions
from common.constants.permission_constants import PermissionConstants
from common.result import result
from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI, DocumentBatchCreateAPI, DocumentCreateAPI, \
DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI
DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI, TableDocumentCreateAPI, QaDocumentCreateAPI, \
WebDocumentCreateAPI
from knowledge.api.knowledge import KnowledgeTreeReadAPI
from knowledge.serializers.document import DocumentSerializers
@ -68,8 +69,10 @@ class DocumentView(APIView):
tags=[_('Knowledge Base/Documentation')]
)
@has_permissions(PermissionConstants.DOCUMENT_READ.get_workspace_permission())
def get(self, request: Request, knowledge_id: str, document_id: str):
operate = DocumentSerializers.Operate(data={'document_id': document_id, 'knowledge_id': knowledge_id})
def get(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
operate = DocumentSerializers.Operate(data={
'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id
})
operate.is_valid(raise_exception=True)
return result.success(operate.one())
@ -83,11 +86,10 @@ class DocumentView(APIView):
tags=[_('Knowledge Base/Documentation')]
)
@has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
def put(self, request: Request, knowledge_id: str, document_id: str):
return result.success(
DocumentSerializers.Operate(data={'document_id': document_id, 'knowledge_id': knowledge_id}).edit(
request.data,
with_valid=True))
def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
return result.success(DocumentSerializers.Operate(data={
'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id
}).edit(request.data, with_valid=True))
@extend_schema(
description=_('Delete document'),
@ -98,8 +100,10 @@ class DocumentView(APIView):
tags=[_('Knowledge Base/Documentation')]
)
@has_permissions(PermissionConstants.DOCUMENT_DELETE.get_workspace_permission())
def delete(self, request: Request, knowledge_id: str, document_id: str):
operate = DocumentSerializers.Operate(data={'document_id': document_id, 'knowledge_id': knowledge_id})
def delete(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
operate = DocumentSerializers.Operate(data={
'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id
})
operate.is_valid(raise_exception=True)
return result.success(operate.delete())
@ -195,3 +199,63 @@ class DocumentView(APIView):
return result.success(DocumentSerializers.Batch(
data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id}
).batch_delete(request.data))
class WebDocumentView(APIView):
authentication_classes = [TokenAuth]
@extend_schema(
methods=['POST'],
description=_('Create Web site documents'),
summary=_('Create Web site documents'),
operation_id=_('Create Web site documents'),
request=WebDocumentCreateAPI.get_request(),
parameters=WebDocumentCreateAPI.get_parameters(),
responses=WebDocumentCreateAPI.get_response(),
tags=[_('Knowledge Base/Documentation')]
)
@has_permissions(PermissionConstants.DOCUMENT_CREATE.get_workspace_permission())
def post(self, request: Request, workspace_id: str, knowledge_id: str):
return result.success(DocumentSerializers.Create(data={
'knowledge_id': knowledge_id, 'workspace_id': workspace_id
}).save_web(request.data, with_valid=True))
class QaDocumentView(APIView):
authentication_classes = [TokenAuth]
parser_classes = [MultiPartParser]
@extend_schema(
summary=_('Import QA and create documentation'),
description=_('Import QA and create documentation'),
operation_id=_('Import QA and create documentation'),
request=QaDocumentCreateAPI.get_request(),
parameters=QaDocumentCreateAPI.get_parameters(),
responses=QaDocumentCreateAPI.get_response(),
tags=[_('Knowledge Base/Documentation')]
)
@has_permissions(PermissionConstants.DOCUMENT_CREATE.get_workspace_permission())
def post(self, request: Request, workspace_id: str, knowledge_id: str):
return result.success(DocumentSerializers.Create(data={
'knowledge_id': knowledge_id, 'workspace_id': workspace_id
}).save_qa({'file_list': request.FILES.getlist('file')}, with_valid=True))
class TableDocumentView(APIView):
authentication_classes = [TokenAuth]
parser_classes = [MultiPartParser]
@extend_schema(
summary=_('Import tables and create documents'),
description=_('Import tables and create documents'),
operation_id=_('Import tables and create documents'),
request=TableDocumentCreateAPI.get_request(),
parameters=TableDocumentCreateAPI.get_parameters(),
responses=TableDocumentCreateAPI.get_response(),
tags=[_('Knowledge Base/Documentation')]
)
@has_permissions(PermissionConstants.DOCUMENT_CREATE.get_workspace_permission())
def post(self, request: Request, workspace_id: str, knowledge_id: str):
return result.success(DocumentSerializers.Create(
data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
).save_table({'file_list': request.FILES.getlist('file')}, with_valid=True))