mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: enhance Document API with workspace ID support for get, put, and delete operations
This commit is contained in:
parent
3e9069aac1
commit
e702af8c2b
|
|
@ -8,7 +8,7 @@ from common.handle.base_parse_table_handle import BaseParseTableHandle
|
|||
max_kb = logging.getLogger("max_kb")
|
||||
|
||||
|
||||
class CsvSplitHandle(BaseParseTableHandle):
|
||||
class CsvParseTableHandle(BaseParseTableHandle):
|
||||
def support(self, file, get_buffer):
|
||||
file_name: str = file.name.lower()
|
||||
if file_name.endswith(".csv"):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from common.handle.base_parse_table_handle import BaseParseTableHandle
|
|||
max_kb = logging.getLogger("max_kb")
|
||||
|
||||
|
||||
class XlsSplitHandle(BaseParseTableHandle):
|
||||
class XlsParseTableHandle(BaseParseTableHandle):
|
||||
def support(self, file, get_buffer):
|
||||
file_name: str = file.name.lower()
|
||||
buffer = get_buffer(file)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from common.handle.impl.common_handle import xlsx_embed_cells_images
|
|||
max_kb = logging.getLogger("max_kb")
|
||||
|
||||
|
||||
class XlsxSplitHandle(BaseParseTableHandle):
|
||||
class XlsxParseTableHandle(BaseParseTableHandle):
|
||||
def support(self, file, get_buffer):
|
||||
file_name: str = file.name.lower()
|
||||
if file_name.endswith('.xlsx'):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from drf_spectacular.utils import OpenApiParameter
|
|||
from common.mixins.api_mixin import APIMixin
|
||||
from common.result import DefaultResultSerializer
|
||||
from knowledge.serializers.common import BatchSerializer
|
||||
from knowledge.serializers.document import DocumentInstanceSerializer
|
||||
from knowledge.serializers.document import DocumentInstanceSerializer, DocumentWebInstanceSerializer
|
||||
|
||||
|
||||
class DocumentSplitAPI(APIMixin):
|
||||
|
|
@ -176,3 +176,45 @@ class DocumentEditAPI(DocumentReadAPI):
|
|||
|
||||
class DocumentDeleteAPI(DocumentReadAPI):
|
||||
pass
|
||||
|
||||
|
||||
class TableDocumentCreateAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
OpenApiParameter(
|
||||
name="workspace_id",
|
||||
description="工作空间id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="knowledge_id",
|
||||
description="知识库id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="file",
|
||||
description="文件",
|
||||
type=OpenApiTypes.BINARY,
|
||||
location='query',
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return DefaultResultSerializer
|
||||
|
||||
|
||||
class QaDocumentCreateAPI(TableDocumentCreateAPI):
|
||||
pass
|
||||
|
||||
|
||||
class WebDocumentCreateAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return DocumentWebInstanceSerializer
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from functools import reduce
|
||||
from typing import Dict, List
|
||||
|
||||
import uuid_utils.compat as uuid
|
||||
from celery_once import AlreadyQueued
|
||||
from django.core import validators
|
||||
from django.db import transaction, models
|
||||
from django.db.models import QuerySet, Model
|
||||
from django.db.models.functions import Substr, Reverse
|
||||
|
|
@ -16,6 +18,13 @@ from common.db.search import native_search, get_dynamics_model, native_page_sear
|
|||
from common.event import ListenerManagement
|
||||
from common.event.common import work_thread_pool
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
|
||||
from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
|
||||
from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
|
||||
from common.handle.impl.qa.zip_parse_qa_handle import ZipParseQAHandle
|
||||
from common.handle.impl.table.csv_parse_table_handle import CsvParseTableHandle
|
||||
from common.handle.impl.table.xls_parse_table_handle import XlsParseTableHandle
|
||||
from common.handle.impl.table.xlsx_parse_table_handle import XlsxParseTableHandle
|
||||
from common.handle.impl.text.csv_split_handle import CsvSplitHandle
|
||||
from common.handle.impl.text.doc_split_handle import DocSplitHandle
|
||||
from common.handle.impl.text.html_split_handle import HTMLSplitHandle
|
||||
|
|
@ -26,14 +35,16 @@ from common.handle.impl.text.xlsx_split_handle import XlsxSplitHandle
|
|||
from common.handle.impl.text.zip_split_handle import ZipSplitHandle
|
||||
from common.utils.common import post, get_file_content, bulk_create_in_batches
|
||||
from common.utils.fork import Fork
|
||||
from common.utils.split_model import get_split_model
|
||||
from common.utils.split_model import get_split_model, flat_map
|
||||
from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
|
||||
TaskType, File
|
||||
from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer, get_embedding_model_id_by_knowledge_id
|
||||
from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer, \
|
||||
get_embedding_model_id_by_knowledge_id, MetaSerializer
|
||||
from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \
|
||||
delete_problems_and_mappings
|
||||
from knowledge.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
|
||||
delete_embedding_by_document
|
||||
from knowledge.task.sync import sync_web_document
|
||||
from maxkb.const import PROJECT_DIR
|
||||
|
||||
default_split_handle = TextSplitHandle()
|
||||
|
|
@ -48,6 +59,9 @@ split_handles = [
|
|||
default_split_handle
|
||||
]
|
||||
|
||||
parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()]
|
||||
parse_table_handle_list = [CsvParseTableHandle(), XlsParseTableHandle(), XlsxParseTableHandle()]
|
||||
|
||||
|
||||
class BatchCancelInstanceSerializer(serializers.Serializer):
|
||||
id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=('id list'))
|
||||
|
|
@ -67,6 +81,36 @@ class DocumentInstanceSerializer(serializers.Serializer):
|
|||
paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
|
||||
|
||||
|
||||
class DocumentEditInstanceSerializer(serializers.Serializer):
|
||||
meta = serializers.DictField(required=False)
|
||||
name = serializers.CharField(required=False, max_length=128, min_length=1, label=_('document name'))
|
||||
hit_handling_method = serializers.CharField(required=False, validators=[
|
||||
validators.RegexValidator(regex=re.compile("^optimization|directly_return$"),
|
||||
message=_('The type only supports optimization|directly_return'),
|
||||
code=500)
|
||||
], label=_('hit handling method'))
|
||||
|
||||
directly_return_similarity = serializers.FloatField(required=False, max_value=2, min_value=0,
|
||||
label=_('directly return similarity'))
|
||||
|
||||
is_active = serializers.BooleanField(required=False, label=_('document is active'))
|
||||
|
||||
@staticmethod
|
||||
def get_meta_valid_map():
|
||||
dataset_meta_valid_map = {
|
||||
KnowledgeType.BASE: MetaSerializer.BaseMeta,
|
||||
KnowledgeType.WEB: MetaSerializer.WebMeta
|
||||
}
|
||||
return dataset_meta_valid_map
|
||||
|
||||
def is_valid(self, *, document: Document = None):
|
||||
super().is_valid(raise_exception=True)
|
||||
if 'meta' in self.data and self.data.get('meta') is not None:
|
||||
dataset_meta_valid_map = self.get_meta_valid_map()
|
||||
valid_class = dataset_meta_valid_map.get(document.type)
|
||||
valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
|
||||
|
||||
|
||||
class DocumentSplitRequest(serializers.Serializer):
|
||||
file = serializers.ListField(required=True, label=_('file list'))
|
||||
limit = serializers.IntegerField(required=False, label=_('limit'))
|
||||
|
|
@ -78,6 +122,22 @@ class DocumentSplitRequest(serializers.Serializer):
|
|||
with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))
|
||||
|
||||
|
||||
class DocumentWebInstanceSerializer(serializers.Serializer):
|
||||
source_url_list = serializers.ListField(required=True, label=_('document url list'),
|
||||
child=serializers.CharField(required=True, label=_('document url list')))
|
||||
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('selector'))
|
||||
|
||||
|
||||
class DocumentInstanceQASerializer(serializers.Serializer):
|
||||
file_list = serializers.ListSerializer(required=True, label=_('file list'),
|
||||
child=serializers.FileField(required=True, label=_('file')))
|
||||
|
||||
|
||||
class DocumentInstanceTableSerializer(serializers.Serializer):
|
||||
file_list = serializers.ListSerializer(required=True, label=_('file list'),
|
||||
child=serializers.FileField(required=True, label=_('file')))
|
||||
|
||||
|
||||
class DocumentSerializers(serializers.Serializer):
|
||||
class Query(serializers.Serializer):
|
||||
# 知识库id
|
||||
|
|
@ -226,6 +286,7 @@ class DocumentSerializers(serializers.Serializer):
|
|||
return True
|
||||
|
||||
class Operate(serializers.Serializer):
|
||||
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
|
||||
document_id = serializers.UUIDField(required=True, label=_('document id'))
|
||||
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
|
||||
|
||||
|
|
@ -246,6 +307,31 @@ class DocumentSerializers(serializers.Serializer):
|
|||
}, select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')), with_search_one=True)
|
||||
|
||||
def edit(self, instance: Dict, with_valid=False):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
_document = QuerySet(Document).get(id=self.data.get("document_id"))
|
||||
if with_valid:
|
||||
DocumentEditInstanceSerializer(data=instance).is_valid(document=_document)
|
||||
update_keys = ['name', 'is_active', 'hit_handling_method', 'directly_return_similarity', 'meta']
|
||||
for update_key in update_keys:
|
||||
if update_key in instance and instance.get(update_key) is not None:
|
||||
_document.__setattr__(update_key, instance.get(update_key))
|
||||
_document.save()
|
||||
return self.one()
|
||||
|
||||
@transaction.atomic
|
||||
def delete(self):
|
||||
document_id = self.data.get("document_id")
|
||||
QuerySet(model=Document).filter(id=document_id).delete()
|
||||
# 删除段落
|
||||
QuerySet(model=Paragraph).filter(document_id=document_id).delete()
|
||||
# 删除问题
|
||||
delete_problems_and_mappings([document_id])
|
||||
# 删除向量库
|
||||
delete_embedding_by_document(document_id)
|
||||
return True
|
||||
|
||||
def refresh(self, state_list=None, with_valid=True):
|
||||
if state_list is None:
|
||||
state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
|
||||
|
|
@ -369,6 +455,58 @@ class DocumentSerializers(serializers.Serializer):
|
|||
instance.get('paragraphs') if 'paragraphs' in instance else []
|
||||
)
|
||||
|
||||
def save_web(self, instance: Dict, with_valid=True):
|
||||
if with_valid:
|
||||
DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True)
|
||||
self.is_valid(raise_exception=True)
|
||||
dataset_id = self.data.get('dataset_id')
|
||||
source_url_list = instance.get('source_url_list')
|
||||
selector = instance.get('selector')
|
||||
sync_web_document.delay(dataset_id, source_url_list, selector)
|
||||
|
||||
def save_qa(self, instance: Dict, with_valid=True):
|
||||
if with_valid:
|
||||
DocumentInstanceQASerializer(data=instance).is_valid(raise_exception=True)
|
||||
self.is_valid(raise_exception=True)
|
||||
file_list = instance.get('file_list')
|
||||
document_list = flat_map([self.parse_qa_file(file) for file in file_list])
|
||||
return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)
|
||||
|
||||
def save_table(self, instance: Dict, with_valid=True):
|
||||
if with_valid:
|
||||
DocumentInstanceTableSerializer(data=instance).is_valid(raise_exception=True)
|
||||
self.is_valid(raise_exception=True)
|
||||
file_list = instance.get('file_list')
|
||||
document_list = flat_map([self.parse_table_file(file) for file in file_list])
|
||||
return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)
|
||||
|
||||
def parse_qa_file(self, file):
|
||||
get_buffer = FileBufferHandle().get_buffer
|
||||
for parse_qa_handle in parse_qa_handle_list:
|
||||
if parse_qa_handle.support(file, get_buffer):
|
||||
return parse_qa_handle.handle(file, get_buffer, self.save_image)
|
||||
raise AppApiException(500, _('Unsupported file format'))
|
||||
|
||||
def parse_table_file(self, file):
|
||||
get_buffer = FileBufferHandle().get_buffer
|
||||
for parse_table_handle in parse_table_handle_list:
|
||||
if parse_table_handle.support(file, get_buffer):
|
||||
return parse_table_handle.handle(file, get_buffer, self.save_image)
|
||||
raise AppApiException(500, _('Unsupported file format'))
|
||||
|
||||
def save_image(self, image_list):
|
||||
if image_list is not None and len(image_list) > 0:
|
||||
exist_image_list = [str(i.get('id')) for i in
|
||||
QuerySet(File).filter(id__in=[i.id for i in image_list]).values('id')]
|
||||
save_image_list = [image for image in image_list if not exist_image_list.__contains__(str(image.id))]
|
||||
save_image_list = list({img.id: img for img in save_image_list}.values())
|
||||
# save image
|
||||
for file in save_image_list:
|
||||
file_bytes = file.meta.pop('content')
|
||||
file.workspace_id = self.data.get('workspace_id')
|
||||
file.meta['knowledge_id'] = self.data.get('knowledge_id')
|
||||
file.save(file_bytes)
|
||||
|
||||
class Split(serializers.Serializer):
|
||||
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
|
||||
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ urlpatterns = [
|
|||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch', views.DocumentView.Batch.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/web', views.WebDocumentView.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/qa', views.QaDocumentView.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/table', views.TableDocumentView.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>', views.DocumentView.Operate.as_view()),
|
||||
path('workspace/<str:workspace_id>/knowledge/<int:current_page>/<int:page_size>', views.KnowledgeView.Page.as_view()),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ from common.auth.authentication import has_permissions
|
|||
from common.constants.permission_constants import PermissionConstants
|
||||
from common.result import result
|
||||
from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI, DocumentBatchCreateAPI, DocumentCreateAPI, \
|
||||
DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI
|
||||
DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI, TableDocumentCreateAPI, QaDocumentCreateAPI, \
|
||||
WebDocumentCreateAPI
|
||||
from knowledge.api.knowledge import KnowledgeTreeReadAPI
|
||||
from knowledge.serializers.document import DocumentSerializers
|
||||
|
||||
|
|
@ -68,8 +69,10 @@ class DocumentView(APIView):
|
|||
tags=[_('Knowledge Base/Documentation')]
|
||||
)
|
||||
@has_permissions(PermissionConstants.DOCUMENT_READ.get_workspace_permission())
|
||||
def get(self, request: Request, knowledge_id: str, document_id: str):
|
||||
operate = DocumentSerializers.Operate(data={'document_id': document_id, 'knowledge_id': knowledge_id})
|
||||
def get(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
|
||||
operate = DocumentSerializers.Operate(data={
|
||||
'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id
|
||||
})
|
||||
operate.is_valid(raise_exception=True)
|
||||
return result.success(operate.one())
|
||||
|
||||
|
|
@ -83,11 +86,10 @@ class DocumentView(APIView):
|
|||
tags=[_('Knowledge Base/Documentation')]
|
||||
)
|
||||
@has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
|
||||
def put(self, request: Request, knowledge_id: str, document_id: str):
|
||||
return result.success(
|
||||
DocumentSerializers.Operate(data={'document_id': document_id, 'knowledge_id': knowledge_id}).edit(
|
||||
request.data,
|
||||
with_valid=True))
|
||||
def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
|
||||
return result.success(DocumentSerializers.Operate(data={
|
||||
'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id
|
||||
}).edit(request.data, with_valid=True))
|
||||
|
||||
@extend_schema(
|
||||
description=_('Delete document'),
|
||||
|
|
@ -98,8 +100,10 @@ class DocumentView(APIView):
|
|||
tags=[_('Knowledge Base/Documentation')]
|
||||
)
|
||||
@has_permissions(PermissionConstants.DOCUMENT_DELETE.get_workspace_permission())
|
||||
def delete(self, request: Request, knowledge_id: str, document_id: str):
|
||||
operate = DocumentSerializers.Operate(data={'document_id': document_id, 'knowledge_id': knowledge_id})
|
||||
def delete(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
|
||||
operate = DocumentSerializers.Operate(data={
|
||||
'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id
|
||||
})
|
||||
operate.is_valid(raise_exception=True)
|
||||
return result.success(operate.delete())
|
||||
|
||||
|
|
@ -195,3 +199,63 @@ class DocumentView(APIView):
|
|||
return result.success(DocumentSerializers.Batch(
|
||||
data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id}
|
||||
).batch_delete(request.data))
|
||||
|
||||
|
||||
class WebDocumentView(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@extend_schema(
|
||||
methods=['POST'],
|
||||
description=_('Create Web site documents'),
|
||||
summary=_('Create Web site documents'),
|
||||
operation_id=_('Create Web site documents'),
|
||||
request=WebDocumentCreateAPI.get_request(),
|
||||
parameters=WebDocumentCreateAPI.get_parameters(),
|
||||
responses=WebDocumentCreateAPI.get_response(),
|
||||
tags=[_('Knowledge Base/Documentation')]
|
||||
)
|
||||
@has_permissions(PermissionConstants.DOCUMENT_CREATE.get_workspace_permission())
|
||||
def post(self, request: Request, workspace_id: str, knowledge_id: str):
|
||||
return result.success(DocumentSerializers.Create(data={
|
||||
'knowledge_id': knowledge_id, 'workspace_id': workspace_id
|
||||
}).save_web(request.data, with_valid=True))
|
||||
|
||||
|
||||
class QaDocumentView(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
parser_classes = [MultiPartParser]
|
||||
|
||||
@extend_schema(
|
||||
summary=_('Import QA and create documentation'),
|
||||
description=_('Import QA and create documentation'),
|
||||
operation_id=_('Import QA and create documentation'),
|
||||
request=QaDocumentCreateAPI.get_request(),
|
||||
parameters=QaDocumentCreateAPI.get_parameters(),
|
||||
responses=QaDocumentCreateAPI.get_response(),
|
||||
tags=[_('Knowledge Base/Documentation')]
|
||||
)
|
||||
@has_permissions(PermissionConstants.DOCUMENT_CREATE.get_workspace_permission())
|
||||
def post(self, request: Request, workspace_id: str, knowledge_id: str):
|
||||
return result.success(DocumentSerializers.Create(data={
|
||||
'knowledge_id': knowledge_id, 'workspace_id': workspace_id
|
||||
}).save_qa({'file_list': request.FILES.getlist('file')}, with_valid=True))
|
||||
|
||||
|
||||
class TableDocumentView(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
parser_classes = [MultiPartParser]
|
||||
|
||||
@extend_schema(
|
||||
summary=_('Import tables and create documents'),
|
||||
description=_('Import tables and create documents'),
|
||||
operation_id=_('Import tables and create documents'),
|
||||
request=TableDocumentCreateAPI.get_request(),
|
||||
parameters=TableDocumentCreateAPI.get_parameters(),
|
||||
responses=TableDocumentCreateAPI.get_response(),
|
||||
tags=[_('Knowledge Base/Documentation')]
|
||||
)
|
||||
@has_permissions(PermissionConstants.DOCUMENT_CREATE.get_workspace_permission())
|
||||
def post(self, request: Request, workspace_id: str, knowledge_id: str):
|
||||
return result.success(DocumentSerializers.Create(
|
||||
data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
|
||||
).save_table({'file_list': request.FILES.getlist('file')}, with_valid=True))
|
||||
|
|
|
|||
Loading…
Reference in New Issue