def bulk_create_in_batches(model, data, batch_size=1000):
    """Insert ``data`` via ``model.objects.bulk_create`` in fixed-size batches.

    The objects are handed to ``bulk_create`` in consecutive chunks of at most
    ``batch_size`` items, in their original order, so a very large collection
    is never passed to a single ``bulk_create`` call.

    Args:
        model: Object exposing ``objects.bulk_create(list_of_instances)``
            (presumably a Django model class -- confirm against callers).
        data: Iterable of unsaved instances. Any iterable is accepted,
            including generators; it is consumed lazily, one batch at a time.
        batch_size: Maximum number of objects per ``bulk_create`` call.
            Must be >= 1.

    Returns:
        None. Empty ``data`` results in no ``bulk_create`` call at all.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Local import: the module's import block is not touched by this helper.
    from itertools import islice

    # A zero step makes ``range`` raise an opaque error and a negative step
    # silently yields no batches (i.e. nothing would be inserted), so reject
    # both explicitly before touching the data.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    iterator = iter(data)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            # Iterator exhausted (or ``data`` was empty from the start).
            return
        model.objects.bulk_create(batch)
bulk_create_in_batches(ProblemParagraphMapping, problem_paragraph_mapping_list, batch_size=1000) # 查询文档 query_set = QuerySet(model=Document) if len(document_model_list) == 0: