diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 680cb4ea7..bf761b23a 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -212,12 +212,31 @@ class DataSetSerializers(serializers.ModelSerializer): dataset_id = uuid.uuid1() dataset = DataSet( **{'id': dataset_id, 'name': self.data.get("name"), 'desc': self.data.get('desc'), 'user': user}) + + document_model_list = [] + paragraph_model_list = [] + problem_model_list = [] + # 插入文档 + for document in self.data.get('documents') if 'documents' in self.data else []: + document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id, + document) + document_model_list.append(document_paragraph_dict_model.get('document')) + for paragraph in document_paragraph_dict_model.get('paragraph_model_list'): + paragraph_model_list.append(paragraph) + for problem in document_paragraph_dict_model.get('problem_model_list'): + problem_model_list.append(problem) + # 插入数据集 dataset.save() - for document in self.data.get('documents') if 'documents' in self.data else []: - DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(document, with_valid=True, - with_embedding=False) + # 插入文档 + QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None + # 批量插入段落 + QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None + # 批量插入问题 + QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None + # 发送向量化事件 ListenerManagement.embedding_by_dataset_signal.send(str(dataset.id)) + # 响应数据 return {**DataSetSerializers(dataset).data, 'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(with_valid=True)} diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index e25fc5495..b80301975 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -37,7 +37,7 @@ class DocumentInstanceSerializer(ApiMixin, serializers.Serializer): message="数据集名称在1-128个字符之间") ]) - paragraphs = ParagraphInstanceSerializer(required=False, many=True) + paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True) @staticmethod def get_request_body_api(): @@ -204,7 +204,24 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True) self.is_valid(raise_exception=True) dataset_id = self.data.get('dataset_id') + document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance) + document_model = document_paragraph_model.get('document') + paragraph_model_list = document_paragraph_model.get('paragraph_model_list') + problem_model_list = document_paragraph_model.get('problem_model_list') + # 插入文档 + document_model.save() + # 批量插入段落 + QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None + # 批量插入问题 + QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None + if with_embedding: + ListenerManagement.embedding_by_document_signal.send(str(document_model.id)) + return DocumentSerializers.Operate( + data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one( + with_valid=True) + @staticmethod + def get_document_paragraph_model(dataset_id, instance: Dict): document_model = Document( **{'dataset_id': dataset_id, 'id': uuid.uuid1(), @@ -212,19 +229,22 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): 'char_length': reduce(lambda x, y: x + y, [len(p.get('content')) for p in instance.get('paragraphs', [])], 0)}) - # 插入文档 - document_model.save() - for paragraph in instance.get('paragraphs') if 'paragraphs' in instance else []: - ParagraphSerializers.Create( - data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).save(paragraph, - with_valid=True, - with_embedding=False) - if with_embedding: - ListenerManagement.embedding_by_document_signal.send(str(document_model.id)) - return DocumentSerializers.Operate( - data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one( - with_valid=True) + paragraph_model_dict_list = [ParagraphSerializers.Create( + data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model( + dataset_id, document_model.id, paragraph) for paragraph in (instance.get('paragraphs') if + 'paragraphs' in instance else [])] + + paragraph_model_list = [] + problem_model_list = [] + for paragraphs in paragraph_model_dict_list: + paragraph = paragraphs.get('paragraph') + for problem_model in paragraphs.get('problem_model_list'): + problem_model_list.append(problem_model) + paragraph_model_list.append(paragraph) + + return {'document': document_model, 'paragraph_model_list': paragraph_model_list, + 'problem_model_list': problem_model_list} @staticmethod def get_request_body_api(): diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py index d0436d691..97db251e5 100644 --- a/apps/dataset/serializers/paragraph_serializers.py +++ b/apps/dataset/serializers/paragraph_serializers.py @@ -39,8 +39,8 @@ class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer): validators.MaxLengthValidator(limit_value=1024, message="段落在1-1024个字符之间"), validators.MinLengthValidator(limit_value=1, - message="段落在1-1024个字符之间") - ]) + message="段落在1-1024个字符之间"), + ], allow_null=True, allow_blank=True) title = serializers.CharField(required=False, allow_null=True, allow_blank=True) @@ -179,17 +179,11 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): self.is_valid() dataset_id = self.data.get("dataset_id") document_id = self.data.get('document_id') - - paragraph = Paragraph(id=uuid.uuid1(), - document_id=document_id, - content=instance.get("content"), - dataset_id=dataset_id, - title=instance.get("title") if 'title' in instance else '') + paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance) + paragraph = paragraph_problem_model.get('paragraph') + problem_model_list = paragraph_problem_model.get('problem_model_list') # 插入段落 - paragraph.save() - problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id, - document_id=document_id, dataset_id=dataset_id) for problem in ( - instance.get('problem_list') if 'problem_list' in instance else [])] + paragraph_problem_model.get('paragraph').save() # 插入問題 QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None # 修改长度 @@ -200,6 +194,20 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer): data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one( with_valid=True) + @staticmethod + def get_paragraph_problem_model(dataset_id: str, document_id: str, instance: Dict): + paragraph = Paragraph(id=uuid.uuid1(), + document_id=document_id, + content=instance.get("content"), + dataset_id=dataset_id, + title=instance.get("title") if 'title' in instance else '') + + problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id, + document_id=document_id, dataset_id=dataset_id) for problem in ( + instance.get('problem_list') if 'problem_list' in instance else [])] + + return {'paragraph': paragraph, 'problem_model_list': problem_model_list} + @staticmethod def get_request_body_api(): return ParagraphInstanceSerializer.get_request_body_api() diff --git a/main.py b/main.py index cb55ebaff..5fbb24c29 100644 --- a/main.py +++ b/main.py @@ -43,7 +43,7 @@ def perform_db_migrate(): def start_services(): - management.call_command('runserver') + management.call_command('runserver',"0.0.0.0:8000") if __name__ == '__main__':