feat: 优化创建数据集文档段落

This commit is contained in:
shaohuzhang1 2023-12-12 15:44:21 +08:00
parent 5488804ea8
commit 019e133f0f
4 changed files with 76 additions and 29 deletions

View File

@ -212,12 +212,31 @@ class DataSetSerializers(serializers.ModelSerializer):
dataset_id = uuid.uuid1()
dataset = DataSet(
**{'id': dataset_id, 'name': self.data.get("name"), 'desc': self.data.get('desc'), 'user': user})
document_model_list = []
paragraph_model_list = []
problem_model_list = []
# 插入文档
for document in self.data.get('documents') if 'documents' in self.data else []:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
document)
document_model_list.append(document_paragraph_dict_model.get('document'))
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
paragraph_model_list.append(paragraph)
for problem in document_paragraph_dict_model.get('problem_model_list'):
problem_model_list.append(problem)
# 插入数据集
dataset.save()
for document in self.data.get('documents') if 'documents' in self.data else []:
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(document, with_valid=True,
with_embedding=False)
# 插入文档
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
# 批量插入段落
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
# 批量插入问题
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
# 发送向量化事件
ListenerManagement.embedding_by_dataset_signal.send(str(dataset.id))
# 响应数据
return {**DataSetSerializers(dataset).data,
'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(with_valid=True)}

View File

@ -37,7 +37,7 @@ class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
message="数据集名称在1-128个字符之间")
])
paragraphs = ParagraphInstanceSerializer(required=False, many=True)
paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
@staticmethod
def get_request_body_api():
@ -204,7 +204,24 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
self.is_valid(raise_exception=True)
dataset_id = self.data.get('dataset_id')
document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance)
document_model = document_paragraph_model.get('document')
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
problem_model_list = document_paragraph_model.get('problem_model_list')
# 插入文档
document_model.save()
# 批量插入段落
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
# 批量插入问题
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
if with_embedding:
ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
return DocumentSerializers.Operate(
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one(
with_valid=True)
@staticmethod
def get_document_paragraph_model(dataset_id, instance: Dict):
document_model = Document(
**{'dataset_id': dataset_id,
'id': uuid.uuid1(),
@ -212,19 +229,22 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
'char_length': reduce(lambda x, y: x + y,
[len(p.get('content')) for p in instance.get('paragraphs', [])],
0)})
# 插入文档
document_model.save()
for paragraph in instance.get('paragraphs') if 'paragraphs' in instance else []:
ParagraphSerializers.Create(
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).save(paragraph,
with_valid=True,
with_embedding=False)
if with_embedding:
ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
return DocumentSerializers.Operate(
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one(
with_valid=True)
paragraph_model_dict_list = [ParagraphSerializers.Create(
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model(
dataset_id, document_model.id, paragraph) for paragraph in (instance.get('paragraphs') if
'paragraphs' in instance else [])]
paragraph_model_list = []
problem_model_list = []
for paragraphs in paragraph_model_dict_list:
paragraph = paragraphs.get('paragraph')
for problem_model in paragraphs.get('problem_model_list'):
problem_model_list.append(problem_model)
paragraph_model_list.append(paragraph)
return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
'problem_model_list': problem_model_list}
@staticmethod
def get_request_body_api():

View File

@ -39,8 +39,8 @@ class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
validators.MaxLengthValidator(limit_value=1024,
message="段落在1-1024个字符之间"),
validators.MinLengthValidator(limit_value=1,
message="段落在1-1024个字符之间")
])
message="段落在1-1024个字符之间"),
], allow_null=True, allow_blank=True)
title = serializers.CharField(required=False, allow_null=True, allow_blank=True)
@ -179,17 +179,11 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
self.is_valid()
dataset_id = self.data.get("dataset_id")
document_id = self.data.get('document_id')
paragraph = Paragraph(id=uuid.uuid1(),
document_id=document_id,
content=instance.get("content"),
dataset_id=dataset_id,
title=instance.get("title") if 'title' in instance else '')
paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
paragraph = paragraph_problem_model.get('paragraph')
problem_model_list = paragraph_problem_model.get('problem_model_list')
# 插入段落
paragraph.save()
problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id,
document_id=document_id, dataset_id=dataset_id) for problem in (
instance.get('problem_list') if 'problem_list' in instance else [])]
paragraph_problem_model.get('paragraph').save()
# 插入問題
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
# 修改长度
@ -200,6 +194,20 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
with_valid=True)
@staticmethod
def get_paragraph_problem_model(dataset_id: str, document_id: str, instance: Dict):
paragraph = Paragraph(id=uuid.uuid1(),
document_id=document_id,
content=instance.get("content"),
dataset_id=dataset_id,
title=instance.get("title") if 'title' in instance else '')
problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id,
document_id=document_id, dataset_id=dataset_id) for problem in (
instance.get('problem_list') if 'problem_list' in instance else [])]
return {'paragraph': paragraph, 'problem_model_list': problem_model_list}
@staticmethod
def get_request_body_api():
return ParagraphInstanceSerializer.get_request_body_api()

View File

@ -43,7 +43,7 @@ def perform_db_migrate():
def start_services():
management.call_command('runserver')
management.call_command('runserver',"0.0.0.0:8000")
if __name__ == '__main__':