mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-25 17:22:55 +00:00
feat: 优化创建数据集文档段落
This commit is contained in:
parent
5488804ea8
commit
019e133f0f
|
|
@ -212,12 +212,31 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
dataset_id = uuid.uuid1()
|
||||
dataset = DataSet(
|
||||
**{'id': dataset_id, 'name': self.data.get("name"), 'desc': self.data.get('desc'), 'user': user})
|
||||
|
||||
document_model_list = []
|
||||
paragraph_model_list = []
|
||||
problem_model_list = []
|
||||
# 插入文档
|
||||
for document in self.data.get('documents') if 'documents' in self.data else []:
|
||||
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
|
||||
document)
|
||||
document_model_list.append(document_paragraph_dict_model.get('document'))
|
||||
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
|
||||
paragraph_model_list.append(paragraph)
|
||||
for problem in document_paragraph_dict_model.get('problem_model_list'):
|
||||
problem_model_list.append(problem)
|
||||
|
||||
# 插入数据集
|
||||
dataset.save()
|
||||
for document in self.data.get('documents') if 'documents' in self.data else []:
|
||||
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(document, with_valid=True,
|
||||
with_embedding=False)
|
||||
# 插入文档
|
||||
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
|
||||
# 批量插入段落
|
||||
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||
# 批量插入问题
|
||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||
# 发送向量化事件
|
||||
ListenerManagement.embedding_by_dataset_signal.send(str(dataset.id))
|
||||
# 响应数据
|
||||
return {**DataSetSerializers(dataset).data,
|
||||
'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(with_valid=True)}
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
|
|||
message="数据集名称在1-128个字符之间")
|
||||
])
|
||||
|
||||
paragraphs = ParagraphInstanceSerializer(required=False, many=True)
|
||||
paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
|
|
@ -204,7 +204,24 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
|
||||
self.is_valid(raise_exception=True)
|
||||
dataset_id = self.data.get('dataset_id')
|
||||
document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance)
|
||||
document_model = document_paragraph_model.get('document')
|
||||
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
|
||||
problem_model_list = document_paragraph_model.get('problem_model_list')
|
||||
# 插入文档
|
||||
document_model.save()
|
||||
# 批量插入段落
|
||||
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
|
||||
# 批量插入问题
|
||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
|
||||
return DocumentSerializers.Operate(
|
||||
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one(
|
||||
with_valid=True)
|
||||
|
||||
@staticmethod
|
||||
def get_document_paragraph_model(dataset_id, instance: Dict):
|
||||
document_model = Document(
|
||||
**{'dataset_id': dataset_id,
|
||||
'id': uuid.uuid1(),
|
||||
|
|
@ -212,19 +229,22 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
'char_length': reduce(lambda x, y: x + y,
|
||||
[len(p.get('content')) for p in instance.get('paragraphs', [])],
|
||||
0)})
|
||||
# 插入文档
|
||||
document_model.save()
|
||||
|
||||
for paragraph in instance.get('paragraphs') if 'paragraphs' in instance else []:
|
||||
ParagraphSerializers.Create(
|
||||
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).save(paragraph,
|
||||
with_valid=True,
|
||||
with_embedding=False)
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
|
||||
return DocumentSerializers.Operate(
|
||||
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one(
|
||||
with_valid=True)
|
||||
paragraph_model_dict_list = [ParagraphSerializers.Create(
|
||||
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model(
|
||||
dataset_id, document_model.id, paragraph) for paragraph in (instance.get('paragraphs') if
|
||||
'paragraphs' in instance else [])]
|
||||
|
||||
paragraph_model_list = []
|
||||
problem_model_list = []
|
||||
for paragraphs in paragraph_model_dict_list:
|
||||
paragraph = paragraphs.get('paragraph')
|
||||
for problem_model in paragraphs.get('problem_model_list'):
|
||||
problem_model_list.append(problem_model)
|
||||
paragraph_model_list.append(paragraph)
|
||||
|
||||
return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
|
||||
'problem_model_list': problem_model_list}
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
|
|
|
|||
|
|
@ -39,8 +39,8 @@ class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
|
|||
validators.MaxLengthValidator(limit_value=1024,
|
||||
message="段落在1-1024个字符之间"),
|
||||
validators.MinLengthValidator(limit_value=1,
|
||||
message="段落在1-1024个字符之间")
|
||||
])
|
||||
message="段落在1-1024个字符之间"),
|
||||
], allow_null=True, allow_blank=True)
|
||||
|
||||
title = serializers.CharField(required=False, allow_null=True, allow_blank=True)
|
||||
|
||||
|
|
@ -179,17 +179,11 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
self.is_valid()
|
||||
dataset_id = self.data.get("dataset_id")
|
||||
document_id = self.data.get('document_id')
|
||||
|
||||
paragraph = Paragraph(id=uuid.uuid1(),
|
||||
document_id=document_id,
|
||||
content=instance.get("content"),
|
||||
dataset_id=dataset_id,
|
||||
title=instance.get("title") if 'title' in instance else '')
|
||||
paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
|
||||
paragraph = paragraph_problem_model.get('paragraph')
|
||||
problem_model_list = paragraph_problem_model.get('problem_model_list')
|
||||
# 插入段落
|
||||
paragraph.save()
|
||||
problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id,
|
||||
document_id=document_id, dataset_id=dataset_id) for problem in (
|
||||
instance.get('problem_list') if 'problem_list' in instance else [])]
|
||||
paragraph_problem_model.get('paragraph').save()
|
||||
# 插入問題
|
||||
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
|
||||
# 修改长度
|
||||
|
|
@ -200,6 +194,20 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
|
||||
with_valid=True)
|
||||
|
||||
@staticmethod
|
||||
def get_paragraph_problem_model(dataset_id: str, document_id: str, instance: Dict):
|
||||
paragraph = Paragraph(id=uuid.uuid1(),
|
||||
document_id=document_id,
|
||||
content=instance.get("content"),
|
||||
dataset_id=dataset_id,
|
||||
title=instance.get("title") if 'title' in instance else '')
|
||||
|
||||
problem_model_list = [Problem(id=uuid.uuid1(), content=problem.get('content'), paragraph_id=paragraph.id,
|
||||
document_id=document_id, dataset_id=dataset_id) for problem in (
|
||||
instance.get('problem_list') if 'problem_list' in instance else [])]
|
||||
|
||||
return {'paragraph': paragraph, 'problem_model_list': problem_model_list}
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return ParagraphInstanceSerializer.get_request_body_api()
|
||||
|
|
|
|||
Loading…
Reference in New Issue