diff --git a/apps/common/utils/common.py b/apps/common/utils/common.py
index d5d3dac10..a3db845fe 100644
--- a/apps/common/utils/common.py
+++ b/apps/common/utils/common.py
@@ -263,3 +263,10 @@ def parse_md_image(content: str):
     image_list = [match.group() for match in matches]
     return image_list
 
+
+def bulk_create_in_batches(model, data, batch_size=1000):
+    if len(data) == 0:
+        return
+    for i in range(0, len(data), batch_size):
+        batch = data[i:i + batch_size]
+        model.objects.bulk_create(batch)
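For context, a minimal usage sketch of the new helper; the model and rows below are hypothetical, not part of this diff. Django's bulk_create also accepts a batch_size argument natively, so the helper mainly centralizes the chunking policy and the empty-list guard:

# Hypothetical model and rows, for illustration only.
rows = [SomeModel(name=f'row {i}') for i in range(2500)]
# Issues three INSERT statements: 1000 + 1000 + 500 rows.
bulk_create_in_batches(SomeModel, rows, batch_size=1000)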
diff --git a/apps/knowledge/api/document.py b/apps/knowledge/api/document.py
index 47a7ddb03..b8dca50c4 100644
--- a/apps/knowledge/api/document.py
+++ b/apps/knowledge/api/document.py
@@ -2,16 +2,11 @@ from drf_spectacular.types import OpenApiTypes
 from drf_spectacular.utils import OpenApiParameter
 
 from common.mixins.api_mixin import APIMixin
-from common.result import DefaultResultSerializer, ResultSerializer
+from common.result import DefaultResultSerializer
+from knowledge.serializers.common import BatchSerializer
 from knowledge.serializers.document import DocumentCreateRequest
 
 
-class DocumentCreateResponse(ResultSerializer):
-    @staticmethod
-    def get_data():
-        return DefaultResultSerializer()
-
-
 class DocumentCreateAPI(APIMixin):
     @staticmethod
     def get_parameters():
@@ -31,7 +26,7 @@ class DocumentCreateAPI(APIMixin):
 
     @staticmethod
     def get_response():
-        return DocumentCreateResponse
+        return DefaultResultSerializer
 
 
 class DocumentSplitAPI(APIMixin):
@@ -75,3 +70,31 @@ class DocumentSplitAPI(APIMixin):
             ),
         ]
 
+
+class DocumentBatchAPI(APIMixin):
+    @staticmethod
+    def get_parameters():
+        return [
+            OpenApiParameter(
+                name="workspace_id",
+                description="workspace id",
+                type=OpenApiTypes.STR,
+                location='path',
+                required=True,
+            ),
+            OpenApiParameter(
+                name="knowledge_id",
+                description="knowledge base id",
+                type=OpenApiTypes.STR,
+                location='path',
+                required=True,
+            ),
+        ]
+
+    @staticmethod
+    def get_request():
+        return BatchSerializer
+
+    @staticmethod
+    def get_response():
+        return DefaultResultSerializer
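DocumentBatchAPI.get_request() returns BatchSerializer from knowledge.serializers.common, whose definition is not included in this patch. A hedged sketch of the shape implied by its call sites below (an is_valid that takes a model keyword and verifies that every id exists); the real implementation may differ:

from django.db.models import QuerySet
from rest_framework import serializers


class BatchSerializer(serializers.Serializer):
    # Matches the id_list payloads sent to batch_sync/batch_delete below.
    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))

    def is_valid(self, *, model=None, raise_exception=False):
        super().is_valid(raise_exception=True)
        if model is not None:
            id_list = self.data.get('id_list')
            # Assumed behavior: reject the batch when any id is unknown.
            if QuerySet(model).filter(id__in=id_list).count() != len(id_list):
                raise serializers.ValidationError('id_list contains unknown ids')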
diff --git a/apps/knowledge/serializers/document.py b/apps/knowledge/serializers/document.py
index e3df94fb6..8c83fc9be 100644
--- a/apps/knowledge/serializers/document.py
+++ b/apps/knowledge/serializers/document.py
@@ -12,6 +12,7 @@ from rest_framework import serializers
 
 from common.db.search import native_search
 from common.event import ListenerManagement
+from common.event.common import work_thread_pool
 from common.exception.app_exception import AppApiException
 from common.handle.impl.text.csv_split_handle import CsvSplitHandle
 from common.handle.impl.text.doc_split_handle import DocSplitHandle
@@ -21,12 +22,13 @@ from common.handle.impl.text.text_split_handle import TextSplitHandle
 from common.handle.impl.text.xls_split_handle import XlsSplitHandle
 from common.handle.impl.text.xlsx_split_handle import XlsxSplitHandle
 from common.handle.impl.text.zip_split_handle import ZipSplitHandle
-from common.utils.common import post, get_file_content
+from common.utils.common import post, get_file_content, bulk_create_in_batches
 from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
     TaskType, File
-from knowledge.serializers.common import ProblemParagraphManage
-from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer
-from knowledge.task import embedding_by_document
+from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer
+from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \
+    delete_problems_and_mappings
+from knowledge.task import embedding_by_document, delete_embedding_by_document_list
 from maxkb.const import PROJECT_DIR
 
 default_split_handle = TextSplitHandle()
@@ -42,6 +44,19 @@ split_handles = [
 ]
 
 
+class BatchCancelInstanceSerializer(serializers.Serializer):
+    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))
+    type = serializers.IntegerField(required=True, label=_('task type'))
+
+    def is_valid(self, *, raise_exception=False):
+        super().is_valid(raise_exception=True)
+        _type = self.data.get('type')
+        try:
+            TaskType(_type)
+        except Exception:
+            raise AppApiException(500, _('task type not support'))
+
+
 class DocumentInstanceSerializer(serializers.Serializer):
     name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1)
     paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
@@ -65,6 +80,17 @@ class DocumentSplitRequest(serializers.Serializer):
     with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))
 
 
+class DocumentBatchRequest(serializers.Serializer):
+    file = serializers.ListField(required=True, label=_('file list'))
+    limit = serializers.IntegerField(required=False, label=_('limit'))
+    patterns = serializers.ListField(
+        required=False,
+        child=serializers.CharField(required=True, label=_('patterns')),
+        label=_('patterns')
+    )
+    with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))
+
+
 class DocumentSerializers(serializers.Serializer):
     class Operate(serializers.Serializer):
         document_id = serializers.UUIDField(required=True, label=_('document id'))
@@ -264,6 +290,150 @@ class DocumentSerializers(serializers.Serializer):
             return result
         return [result]
 
+    class Batch(serializers.Serializer):
+        workspace_id = serializers.UUIDField(required=True, label=_('workspace id'))
+        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
+
+        @staticmethod
+        def post_embedding(document_list, knowledge_id):
+            for document_dict in document_list:
+                DocumentSerializers.Operate(
+                    data={'knowledge_id': knowledge_id, 'document_id': document_dict.get('id')}).refresh()
+            return document_list
+
+        @post(post_function=post_embedding)
+        @transaction.atomic
+        def batch_save(self, instance_list: List[Dict], with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+                DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True)
+            knowledge_id = self.data.get("knowledge_id")
+            document_model_list = []
+            paragraph_model_list = []
+            problem_paragraph_object_list = []
+            # Build the model instances for each document
+            for document in instance_list:
+                document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(knowledge_id,
+                                                                                                        document)
+                document_model_list.append(document_paragraph_dict_model.get('document'))
+                for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
+                    paragraph_model_list.append(paragraph)
+                for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
+                    problem_paragraph_object_list.append(problem_paragraph_object)
+
+            problem_model_list, problem_paragraph_mapping_list = (
+                ProblemParagraphManage(problem_paragraph_object_list, knowledge_id).to_problem_model_list()
+            )
+            # Insert documents
+            if len(document_model_list) > 0:
+                QuerySet(Document).bulk_create(document_model_list)
+            # Batch insert paragraphs
+            bulk_create_in_batches(Paragraph, paragraph_model_list, batch_size=1000)
+            # Batch insert problems
+            bulk_create_in_batches(Problem, problem_model_list, batch_size=1000)
+            # Batch insert problem-paragraph mappings
+            bulk_create_in_batches(ProblemParagraphMapping, problem_paragraph_mapping_list, batch_size=1000)
+            # Query the inserted documents back
+            query_set = QuerySet(model=Document)
+            if len(document_model_list) == 0:
+                return [], knowledge_id
+            query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
+            return native_search(
+                {
+                    'document_custom_sql': query_set,
+                    'order_by_query': QuerySet(Document).order_by('-create_time', 'id')
+                },
+                select_string=get_file_content(
+                    os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')
+                ),
+                with_search_one=False
+            ), knowledge_id
+
+        @staticmethod
+        def _batch_sync(document_id_list: List[str]):
+            for document_id in document_id_list:
+                DocumentSerializers.Sync(data={'document_id': document_id}).sync()
+
+        def batch_sync(self, instance: Dict, with_valid=True):
+            if with_valid:
+                BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
+                self.is_valid(raise_exception=True)
+            # Sync asynchronously
+            work_thread_pool.submit(self._batch_sync, instance.get('id_list'))
+            return True
+
+        @transaction.atomic
+        def batch_delete(self, instance: Dict, with_valid=True):
+            if with_valid:
+                BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
+                self.is_valid(raise_exception=True)
+            document_id_list = instance.get("id_list")
+            QuerySet(Document).filter(id__in=document_id_list).delete()
+            QuerySet(Paragraph).filter(document_id__in=document_id_list).delete()
+            delete_problems_and_mappings(document_id_list)
+            # Delete from the vector store
+            delete_embedding_by_document_list(document_id_list)
+            return True
+
+        def batch_cancel(self, instance: Dict, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+                BatchCancelInstanceSerializer(data=instance).is_valid(raise_exception=True)
+            document_id_list = instance.get("id_list")
+            ListenerManagement.update_status(
+                QuerySet(Paragraph).annotate(
+                    reversed_status=Reverse('status'),
+                    task_type_status=Substr('reversed_status', TaskType(instance.get('type')).value, 1),
+                ).filter(
+                    task_type_status__in=[State.PENDING.value, State.STARTED.value]
+                ).filter(
+                    document_id__in=document_id_list
+                ).values('id'),
+                TaskType(instance.get('type')),
+                State.REVOKE
+            )
+            ListenerManagement.update_status(
+                QuerySet(Document).annotate(
+                    reversed_status=Reverse('status'),
+                    task_type_status=Substr('reversed_status', TaskType(instance.get('type')).value, 1),
+                ).filter(
+                    task_type_status__in=[State.PENDING.value, State.STARTED.value]
+                ).filter(
+                    id__in=document_id_list
+                ).values('id'),
+                TaskType(instance.get('type')),
+                State.REVOKE
+            )
+
+        def batch_edit_hit_handling(self, instance: Dict, with_valid=True):
+            if with_valid:
+                BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
+                hit_handling_method = instance.get('hit_handling_method')
+                if hit_handling_method is None:
+                    raise AppApiException(500, _('Hit handling method is required'))
+                if hit_handling_method not in ('optimization', 'directly_return'):
+                    raise AppApiException(500, _('The hit processing method must be directly_return|optimization'))
+                self.is_valid(raise_exception=True)
+            document_id_list = instance.get("id_list")
+            hit_handling_method = instance.get('hit_handling_method')
+            directly_return_similarity = instance.get('directly_return_similarity')
+            update_dict = {'hit_handling_method': hit_handling_method}
+            if directly_return_similarity is not None:
+                update_dict['directly_return_similarity'] = directly_return_similarity
+            QuerySet(Document).filter(id__in=document_id_list).update(**update_dict)
+
+        def batch_refresh(self, instance: Dict, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            document_id_list = instance.get("id_list")
+            state_list = instance.get("state_list")
+            knowledge_id = self.data.get('knowledge_id')
+            for document_id in document_id_list:
+                try:
+                    DocumentSerializers.Operate(
+                        data={'knowledge_id': knowledge_id, 'document_id': document_id}).refresh(state_list)
+                except AlreadyQueued:
+                    pass
+
 
 class FileBufferHandle:
     buffer = None
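batch_cancel above relies on a packed status encoding: the status column is a string holding one state character per task type, which Reverse plus Substr address from the right-hand end. A plain-Python sketch of the same addressing; the exact layout and the sample values are assumptions, not taken from this patch:

# Mirrors Reverse('status') + Substr(reversed_status, task_type_value, 1)
# from the ORM annotations in batch_cancel; the encoding is assumed.
def task_state(status: str, task_type_value: int) -> str:
    reversed_status = status[::-1]               # Reverse('status')
    return reversed_status[task_type_value - 1]  # Substr is 1-indexed

# With an assumed status string '21' and a task type whose value is 1,
# the state character for that task is '1'.
assert task_state('21', 1) == '1'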
diff --git a/apps/knowledge/urls.py b/apps/knowledge/urls.py
index eaa95c729..5caa357d8 100644
--- a/apps/knowledge/urls.py
+++ b/apps/knowledge/urls.py
@@ -9,5 +9,6 @@ urlpatterns = [
     path('workspace/<str:workspace_id>/knowledge/web', views.KnowledgeWebView.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch', views.DocumentView.Batch.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<int:current_page>/<int:page_size>', views.KnowledgeView.Page.as_view()),
 ]
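The path converters <str:workspace_id> and <str:knowledge_id> arrive at the Batch view as keyword arguments. A hedged client-side sketch against the new route; host, API prefix, auth scheme, and ids are placeholders, not taken from this diff:

import requests

url = ('http://localhost:8080/api/workspace/<workspace-uuid>'
       '/knowledge/<knowledge-uuid>/document/batch')
headers = {'Authorization': 'Bearer <token>'}

# PUT -> DocumentSerializers.Batch.batch_sync (async via work_thread_pool)
requests.put(url, json={'id_list': ['<document-uuid>']}, headers=headers)

# DELETE -> DocumentSerializers.Batch.batch_delete (rows plus embeddings)
requests.delete(url, json={'id_list': ['<document-uuid>']}, headers=headers)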
diff --git a/apps/knowledge/views/document.py b/apps/knowledge/views/document.py
index fbe276a63..320d09bde 100644
--- a/apps/knowledge/views/document.py
+++ b/apps/knowledge/views/document.py
@@ -6,9 +6,9 @@ from rest_framework.views import APIView
 
 from common.auth import TokenAuth
 from common.auth.authentication import has_permissions
-from common.constants.permission_constants import PermissionConstants, CompareConstants
+from common.constants.permission_constants import PermissionConstants
 from common.result import result
-from knowledge.api.document import DocumentSplitAPI
+from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI
 from knowledge.api.knowledge import KnowledgeTreeReadAPI
 from knowledge.serializers.document import DocumentSerializers
 from knowledge.serializers.knowledge import KnowledgeSerializer
@@ -68,3 +68,60 @@ class DocumentView(APIView):
                 'workspace_id': workspace_id,
                 'knowledge_id': knowledge_id,
             }).parse(split_data))
+
+    class Batch(APIView):
+        authentication_classes = [TokenAuth]
+
+        @extend_schema(
+            methods=['POST'],
+            description=_('Create documents in batches'),
+            operation_id=_('Create documents in batches'),
+            request=DocumentBatchAPI.get_request(),
+            parameters=DocumentBatchAPI.get_parameters(),
+            responses=DocumentBatchAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation')]
+        )
+        @has_permissions([
+            PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
+            PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
+        ])
+        def post(self, request: Request, workspace_id: str, knowledge_id: str):
+            return result.success(DocumentSerializers.Batch(
+                data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
+            ).batch_save(request.data))
+
+        @extend_schema(
+            methods=['PUT'],
+            description=_('Batch sync documents'),
+            operation_id=_('Batch sync documents'),
+            request=DocumentBatchAPI.get_request(),
+            parameters=DocumentBatchAPI.get_parameters(),
+            responses=DocumentBatchAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation')]
+        )
+        @has_permissions([
+            PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
+            PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
+        ])
+        def put(self, request: Request, workspace_id: str, knowledge_id: str):
+            return result.success(DocumentSerializers.Batch(
+                data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
+            ).batch_sync(request.data))
+
+        @extend_schema(
+            methods=['DELETE'],
+            description=_('Delete documents in batches'),
+            operation_id=_('Delete documents in batches'),
+            request=DocumentBatchAPI.get_request(),
+            parameters=DocumentBatchAPI.get_parameters(),
+            responses=DocumentBatchAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation')]
+        )
+        @has_permissions([
+            PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
+            PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
+        ])
+        def delete(self, request: Request, workspace_id: str, knowledge_id: str):
+            return result.success(DocumentSerializers.Batch(
+                data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id}
+            ).batch_delete(request.data))
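batch_save is wrapped in @post(post_function=post_embedding) from common.utils.common; the decorator's implementation is not shown in this patch. The call sites imply this contract: the wrapped function runs first, and its return tuple (document_list, knowledge_id) is unpacked into post_function, which then triggers embedding refreshes. A hedged sketch under that assumption; the real decorator may also defer the post step to a worker:

import functools


def post(post_function):
    # Assumed shape, inferred from usage in DocumentSerializers.Batch.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            # batch_save returns (document_list, knowledge_id), matching
            # post_embedding(document_list, knowledge_id).
            return post_function(*result)
        return wrapper
    return decorator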