diff --git a/apps/common/handle/impl/table/csv_parse_table_handle.py b/apps/common/handle/impl/table/csv_parse_table_handle.py
index 4971c424f..ece767ff8 100644
--- a/apps/common/handle/impl/table/csv_parse_table_handle.py
+++ b/apps/common/handle/impl/table/csv_parse_table_handle.py
@@ -8,7 +8,7 @@ from common.handle.base_parse_table_handle import BaseParseTableHandle
 max_kb = logging.getLogger("max_kb")
 
 
-class CsvSplitHandle(BaseParseTableHandle):
+class CsvParseTableHandle(BaseParseTableHandle):
     def support(self, file, get_buffer):
         file_name: str = file.name.lower()
         if file_name.endswith(".csv"):
diff --git a/apps/common/handle/impl/table/xls_parse_table_handle.py b/apps/common/handle/impl/table/xls_parse_table_handle.py
index 897e347e8..657aaf0a6 100644
--- a/apps/common/handle/impl/table/xls_parse_table_handle.py
+++ b/apps/common/handle/impl/table/xls_parse_table_handle.py
@@ -8,7 +8,7 @@ from common.handle.base_parse_table_handle import BaseParseTableHandle
 max_kb = logging.getLogger("max_kb")
 
 
-class XlsSplitHandle(BaseParseTableHandle):
+class XlsParseTableHandle(BaseParseTableHandle):
     def support(self, file, get_buffer):
         file_name: str = file.name.lower()
         buffer = get_buffer(file)
diff --git a/apps/common/handle/impl/table/xlsx_parse_table_handle.py b/apps/common/handle/impl/table/xlsx_parse_table_handle.py
index c7364169f..8693a260c 100644
--- a/apps/common/handle/impl/table/xlsx_parse_table_handle.py
+++ b/apps/common/handle/impl/table/xlsx_parse_table_handle.py
@@ -10,7 +10,7 @@ from common.handle.impl.common_handle import xlsx_embed_cells_images
 max_kb = logging.getLogger("max_kb")
 
 
-class XlsxSplitHandle(BaseParseTableHandle):
+class XlsxParseTableHandle(BaseParseTableHandle):
     def support(self, file, get_buffer):
         file_name: str = file.name.lower()
         if file_name.endswith('.xlsx'):
diff --git a/apps/knowledge/api/document.py b/apps/knowledge/api/document.py
index 737864d4e..cfafaa583 100644
--- a/apps/knowledge/api/document.py
+++ b/apps/knowledge/api/document.py
@@ -4,7 +4,7 @@ from drf_spectacular.utils import OpenApiParameter
 from common.mixins.api_mixin import APIMixin
 from common.result import DefaultResultSerializer
 from knowledge.serializers.common import BatchSerializer
-from knowledge.serializers.document import DocumentInstanceSerializer
+from knowledge.serializers.document import DocumentInstanceSerializer, DocumentWebInstanceSerializer
 
 
 class DocumentSplitAPI(APIMixin):
@@ -176,3 +176,45 @@ class DocumentEditAPI(DocumentReadAPI):
 
 class DocumentDeleteAPI(DocumentReadAPI):
     pass
+
+
+class TableDocumentCreateAPI(APIMixin):
+    @staticmethod
+    def get_parameters():
+        return [
+            OpenApiParameter(
+                name="workspace_id",
+                description="workspace id",
+                type=OpenApiTypes.STR,
+                location='path',
+                required=True,
+            ),
+            OpenApiParameter(
+                name="knowledge_id",
+                description="knowledge base id",
+                type=OpenApiTypes.STR,
+                location='path',
+                required=True,
+            ),
+            OpenApiParameter(
+                name="file",
+                description="file",
+                type=OpenApiTypes.BINARY,
+                location='query',
+                required=False,
+            ),
+        ]
+
+    @staticmethod
+    def get_response():
+        return DefaultResultSerializer
+
+
+class QaDocumentCreateAPI(TableDocumentCreateAPI):
+    pass
+
+
+class WebDocumentCreateAPI(APIMixin):
+    @staticmethod
+    def get_request():
+        return DocumentWebInstanceSerializer
diff --git a/apps/knowledge/serializers/document.py b/apps/knowledge/serializers/document.py
index 3d0d860aa..001d17f4d 100644
--- a/apps/knowledge/serializers/document.py
+++ b/apps/knowledge/serializers/document.py
@@ -1,11 +1,13 @@
 import logging
 import os
+import re
 import traceback
 from functools import reduce
 from typing import Dict, List
 
 import uuid_utils.compat as uuid
 from celery_once import AlreadyQueued
+from django.core import validators
 from django.db import transaction, models
 from django.db.models import QuerySet, Model
 from django.db.models.functions import Substr, Reverse
@@ -16,6 +18,13 @@ from common.db.search import native_search, get_dynamics_model, native_page_sear
 from common.event import ListenerManagement
 from common.event.common import work_thread_pool
 from common.exception.app_exception import AppApiException
+from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
+from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
+from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
+from common.handle.impl.qa.zip_parse_qa_handle import ZipParseQAHandle
+from common.handle.impl.table.csv_parse_table_handle import CsvParseTableHandle
+from common.handle.impl.table.xls_parse_table_handle import XlsParseTableHandle
+from common.handle.impl.table.xlsx_parse_table_handle import XlsxParseTableHandle
 from common.handle.impl.text.csv_split_handle import CsvSplitHandle
 from common.handle.impl.text.doc_split_handle import DocSplitHandle
 from common.handle.impl.text.html_split_handle import HTMLSplitHandle
@@ -26,14 +35,16 @@ from common.handle.impl.text.xlsx_split_handle import XlsxSplitHandle
 from common.handle.impl.text.zip_split_handle import ZipSplitHandle
 from common.utils.common import post, get_file_content, bulk_create_in_batches
 from common.utils.fork import Fork
-from common.utils.split_model import get_split_model
+from common.utils.split_model import get_split_model, flat_map
 from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
     TaskType, File
-from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer, get_embedding_model_id_by_knowledge_id
+from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer, \
+    get_embedding_model_id_by_knowledge_id, MetaSerializer
 from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \
     delete_problems_and_mappings
 from knowledge.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
     delete_embedding_by_document
+from knowledge.task.sync import sync_web_document
 from maxkb.const import PROJECT_DIR
 
 default_split_handle = TextSplitHandle()
@@ -48,6 +59,9 @@ split_handles = [
     default_split_handle
 ]
 
+parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()]
+parse_table_handle_list = [CsvParseTableHandle(), XlsParseTableHandle(), XlsxParseTableHandle()]
+
 
 class BatchCancelInstanceSerializer(serializers.Serializer):
     id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))
@@ -67,6 +81,36 @@ class DocumentInstanceSerializer(serializers.Serializer):
     paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
 
 
+class DocumentEditInstanceSerializer(serializers.Serializer):
+    meta = serializers.DictField(required=False)
+    name = serializers.CharField(required=False, max_length=128, min_length=1, label=_('document name'))
+    hit_handling_method = serializers.CharField(required=False, validators=[
+        validators.RegexValidator(regex=re.compile("^optimization|directly_return$"),
+                                  message=_('The type only supports optimization|directly_return'),
+                                  code=500)
+    ], label=_('hit handling method'))
+
+    directly_return_similarity = serializers.FloatField(required=False, max_value=2, min_value=0,
+                                                        label=_('directly return similarity'))
+
+    is_active = serializers.BooleanField(required=False, label=_('document is active'))
+
+    @staticmethod
+    def get_meta_valid_map():
+        dataset_meta_valid_map = {
+            KnowledgeType.BASE: MetaSerializer.BaseMeta,
+            KnowledgeType.WEB: MetaSerializer.WebMeta
+        }
+        return dataset_meta_valid_map
+
+    def is_valid(self, *, document: Document = None):
+        super().is_valid(raise_exception=True)
+        if 'meta' in self.data and self.data.get('meta') is not None:
+            dataset_meta_valid_map = self.get_meta_valid_map()
+            valid_class = dataset_meta_valid_map.get(document.type)
+            valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
+
+
 class DocumentSplitRequest(serializers.Serializer):
     file = serializers.ListField(required=True, label=_('file list'))
     limit = serializers.IntegerField(required=False, label=_('limit'))
@@ -78,6 +122,22 @@ class DocumentSplitRequest(serializers.Serializer):
     with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))
 
 
+class DocumentWebInstanceSerializer(serializers.Serializer):
+    source_url_list = serializers.ListField(required=True, label=_('document url list'),
+                                            child=serializers.CharField(required=True, label=_('document url list')))
+    selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('selector'))
+
+
+class DocumentInstanceQASerializer(serializers.Serializer):
+    file_list = serializers.ListSerializer(required=True, label=_('file list'),
+                                           child=serializers.FileField(required=True, label=_('file')))
+
+
+class DocumentInstanceTableSerializer(serializers.Serializer):
+    file_list = serializers.ListSerializer(required=True, label=_('file list'),
+                                           child=serializers.FileField(required=True, label=_('file')))
+
+
 class DocumentSerializers(serializers.Serializer):
     class Query(serializers.Serializer):
         # knowledge base id
@@ -226,6 +286,7 @@
             return True
 
     class Operate(serializers.Serializer):
+        workspace_id = serializers.CharField(required=True, label=_('workspace id'))
         document_id = serializers.UUIDField(required=True, label=_('document id'))
         knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
 
@@ -246,6 +307,31 @@
                 }, select_string=get_file_content(
                     os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')),
                 with_search_one=True)
+        def edit(self, instance: Dict, with_valid=False):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            _document = QuerySet(Document).get(id=self.data.get("document_id"))
+            if with_valid:
+                DocumentEditInstanceSerializer(data=instance).is_valid(document=_document)
+            update_keys = ['name', 'is_active', 'hit_handling_method', 'directly_return_similarity', 'meta']
+            for update_key in update_keys:
+                if update_key in instance and instance.get(update_key) is not None:
+                    _document.__setattr__(update_key, instance.get(update_key))
+            _document.save()
+            return self.one()
+
+        @transaction.atomic
+        def delete(self):
+            document_id = self.data.get("document_id")
+            QuerySet(model=Document).filter(id=document_id).delete()
+            # delete paragraphs
+            QuerySet(model=Paragraph).filter(document_id=document_id).delete()
+            # delete problems
+            delete_problems_and_mappings([document_id])
+            # delete from the vector store
+            delete_embedding_by_document(document_id)
+            return True
+
         def refresh(self, state_list=None, with_valid=True):
             if state_list is None:
                state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
@@ -369,6 +455,58 @@
                 instance.get('paragraphs') if 'paragraphs' in instance else []
             )
 
+        def save_web(self, instance: Dict, with_valid=True):
+            if with_valid:
+                DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True)
+                self.is_valid(raise_exception=True)
+            knowledge_id = self.data.get('knowledge_id')
+            source_url_list = instance.get('source_url_list')
+            selector = instance.get('selector')
+            sync_web_document.delay(knowledge_id, source_url_list, selector)
+
+        def save_qa(self, instance: Dict, with_valid=True):
+            if with_valid:
+                DocumentInstanceQASerializer(data=instance).is_valid(raise_exception=True)
+                self.is_valid(raise_exception=True)
+            file_list = instance.get('file_list')
+            document_list = flat_map([self.parse_qa_file(file) for file in file_list])
+            return DocumentSerializers.Batch(data={'workspace_id': self.data.get('workspace_id'), 'knowledge_id': self.data.get('knowledge_id')}).batch_save(document_list)
+
+        def save_table(self, instance: Dict, with_valid=True):
+            if with_valid:
+                DocumentInstanceTableSerializer(data=instance).is_valid(raise_exception=True)
+                self.is_valid(raise_exception=True)
+            file_list = instance.get('file_list')
+            document_list = flat_map([self.parse_table_file(file) for file in file_list])
+            return DocumentSerializers.Batch(data={'workspace_id': self.data.get('workspace_id'), 'knowledge_id': self.data.get('knowledge_id')}).batch_save(document_list)
+
+        def parse_qa_file(self, file):
+            get_buffer = FileBufferHandle().get_buffer
+            for parse_qa_handle in parse_qa_handle_list:
+                if parse_qa_handle.support(file, get_buffer):
+                    return parse_qa_handle.handle(file, get_buffer, self.save_image)
+            raise AppApiException(500, _('Unsupported file format'))
+
+        def parse_table_file(self, file):
+            get_buffer = FileBufferHandle().get_buffer
+            for parse_table_handle in parse_table_handle_list:
+                if parse_table_handle.support(file, get_buffer):
+                    return parse_table_handle.handle(file, get_buffer, self.save_image)
+            raise AppApiException(500, _('Unsupported file format'))
+
+        def save_image(self, image_list):
+            if image_list is not None and len(image_list) > 0:
+                exist_image_list = [str(i.get('id')) for i in
+                                    QuerySet(File).filter(id__in=[i.id for i in image_list]).values('id')]
+                save_image_list = [image for image in image_list if not exist_image_list.__contains__(str(image.id))]
+                save_image_list = list({img.id: img for img in save_image_list}.values())
+                # save image
+                for file in save_image_list:
+                    file_bytes = file.meta.pop('content')
+                    file.workspace_id = self.data.get('workspace_id')
+                    file.meta['knowledge_id'] = self.data.get('knowledge_id')
+                    file.save(file_bytes)
+
     class Split(serializers.Serializer):
         workspace_id = serializers.CharField(required=True, label=_('workspace id'))
         knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
diff --git a/apps/knowledge/urls.py b/apps/knowledge/urls.py
index 277c0aa05..d85b7c32d 100644
--- a/apps/knowledge/urls.py
+++ b/apps/knowledge/urls.py
@@ -11,6 +11,9 @@ urlpatterns = [
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document', views.DocumentView.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch', views.DocumentView.Batch.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/web', views.WebDocumentView.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/qa', views.QaDocumentView.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/table', views.TableDocumentView.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>', views.DocumentView.Operate.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<int:current_page>/<int:page_size>', views.KnowledgeView.Page.as_view()),
 ]
diff --git a/apps/knowledge/views/document.py b/apps/knowledge/views/document.py
index ecabcff40..b12f1437d 100644
--- a/apps/knowledge/views/document.py
+++ b/apps/knowledge/views/document.py
@@ -9,7 +9,8 @@ from common.auth.authentication import has_permissions
 from common.constants.permission_constants import PermissionConstants
 from common.result import result
 from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI, DocumentBatchCreateAPI, DocumentCreateAPI, \
-    DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI
+    DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI, TableDocumentCreateAPI, QaDocumentCreateAPI, \
+    WebDocumentCreateAPI
 from knowledge.api.knowledge import KnowledgeTreeReadAPI
 from knowledge.serializers.document import DocumentSerializers
 
@@ -68,8 +69,10 @@
         tags=[_('Knowledge Base/Documentation')]
     )
     @has_permissions(PermissionConstants.DOCUMENT_READ.get_workspace_permission())
-    def get(self, request: Request, knowledge_id: str, document_id: str):
-        operate = DocumentSerializers.Operate(data={'document_id': document_id, 'knowledge_id': knowledge_id})
+    def get(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
+        operate = DocumentSerializers.Operate(data={
+            'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id
+        })
         operate.is_valid(raise_exception=True)
         return result.success(operate.one())
 
@@ -83,11 +86,10 @@
         tags=[_('Knowledge Base/Documentation')]
     )
     @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
-    def put(self, request: Request, knowledge_id: str, document_id: str):
-        return result.success(
-            DocumentSerializers.Operate(data={'document_id': document_id, 'knowledge_id': knowledge_id}).edit(
-                request.data,
-                with_valid=True))
+    def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
+        return result.success(DocumentSerializers.Operate(data={
+            'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id
+        }).edit(request.data, with_valid=True))
 
     @extend_schema(
         description=_('Delete document'),
@@ -98,8 +100,10 @@
         tags=[_('Knowledge Base/Documentation')]
     )
     @has_permissions(PermissionConstants.DOCUMENT_DELETE.get_workspace_permission())
-    def delete(self, request: Request, knowledge_id: str, document_id: str):
-        operate = DocumentSerializers.Operate(data={'document_id': document_id, 'knowledge_id': knowledge_id})
+    def delete(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
+        operate = DocumentSerializers.Operate(data={
+            'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id
+        })
         operate.is_valid(raise_exception=True)
         return result.success(operate.delete())
 
@@ -195,3 +199,63 @@
         return result.success(DocumentSerializers.Batch(
             data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id}
         ).batch_delete(request.data))
+
+
+class WebDocumentView(APIView):
+    authentication_classes = [TokenAuth]
+
+    @extend_schema(
+        methods=['POST'],
+        description=_('Create Web site documents'),
+        summary=_('Create Web site documents'),
+        operation_id=_('Create Web site documents'),
+        request=WebDocumentCreateAPI.get_request(),
+        parameters=WebDocumentCreateAPI.get_parameters(),
+        responses=WebDocumentCreateAPI.get_response(),
+        tags=[_('Knowledge Base/Documentation')]
+    )
+    @has_permissions(PermissionConstants.DOCUMENT_CREATE.get_workspace_permission())
+    def post(self, request: Request, workspace_id: str, knowledge_id: str):
+        return result.success(DocumentSerializers.Create(data={
+            'knowledge_id': knowledge_id, 'workspace_id': workspace_id
+        }).save_web(request.data, with_valid=True))
+
+
+class QaDocumentView(APIView):
+    authentication_classes = [TokenAuth]
+    parser_classes = [MultiPartParser]
+
+    @extend_schema(
+        summary=_('Import QA and create documentation'),
+        description=_('Import QA and create documentation'),
+        operation_id=_('Import QA and create documentation'),
+        request=QaDocumentCreateAPI.get_request(),
+        parameters=QaDocumentCreateAPI.get_parameters(),
+        responses=QaDocumentCreateAPI.get_response(),
+        tags=[_('Knowledge Base/Documentation')]
+    )
+    @has_permissions(PermissionConstants.DOCUMENT_CREATE.get_workspace_permission())
+    def post(self, request: Request, workspace_id: str, knowledge_id: str):
+        return result.success(DocumentSerializers.Create(data={
+            'knowledge_id': knowledge_id, 'workspace_id': workspace_id
+        }).save_qa({'file_list': request.FILES.getlist('file')}, with_valid=True))
+
+
+class TableDocumentView(APIView):
+    authentication_classes = [TokenAuth]
+    parser_classes = [MultiPartParser]
+
+    @extend_schema(
+        summary=_('Import tables and create documents'),
+        description=_('Import tables and create documents'),
+        operation_id=_('Import tables and create documents'),
+        request=TableDocumentCreateAPI.get_request(),
+        parameters=TableDocumentCreateAPI.get_parameters(),
+        responses=TableDocumentCreateAPI.get_response(),
+        tags=[_('Knowledge Base/Documentation')]
+    )
+    @has_permissions(PermissionConstants.DOCUMENT_CREATE.get_workspace_permission())
+    def post(self, request: Request, workspace_id: str, knowledge_id: str):
+        return result.success(DocumentSerializers.Create(
+            data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
+        ).save_table({'file_list': request.FILES.getlist('file')}, with_valid=True))
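
For reviewers, a rough client-side sketch of the three new endpoints follows. It is not part of the patch: the URL paths, the multipart field name 'file', and the web-import body shape come from the diff above, while the host, '/api' prefix, token header, and IDs are placeholder assumptions.

# Usage sketch only (not part of the patch).
import requests

BASE = "http://localhost:8080/api"                # assumed server prefix
HEADERS = {"Authorization": "Bearer <token>"}     # assumed TokenAuth credential
prefix = f"{BASE}/workspace/<workspace_id>/knowledge/<knowledge_id>/document"

# Import a QA spreadsheet: multipart upload, read server-side via request.FILES.getlist('file').
with open("qa.xlsx", "rb") as f:
    print(requests.post(f"{prefix}/qa", headers=HEADERS, files=[("file", f)]).json())

# Import a table file through the same multipart contract.
with open("data.csv", "rb") as f:
    print(requests.post(f"{prefix}/table", headers=HEADERS, files=[("file", f)]).json())

# Create documents from web pages: JSON body validated by DocumentWebInstanceSerializer.
payload = {"source_url_list": ["https://example.com/page"], "selector": ".content"}
print(requests.post(f"{prefix}/web", headers=HEADERS, json=payload).json())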