fix: 【知识库】知识库上传 有关联问题的会阻塞 (#676)

This commit is contained in:
shaohuzhang1 2024-07-01 19:39:07 +08:00 committed by GitHub
parent 60181d6f83
commit fc6da6a484
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 89 additions and 71 deletions

View File

@ -7,6 +7,7 @@
@desc:
"""
import os
import uuid
from typing import List
from django.db.models import QuerySet
@ -20,7 +21,7 @@ from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork
from dataset.models import Paragraph
from dataset.models import Paragraph, Problem, ProblemParagraphMapping
from smartdoc.conf import PROJECT_DIR
@ -79,3 +80,53 @@ class BatchSerializer(ApiMixin, serializers.Serializer):
description="主键id列表")
}
)
class ProblemParagraphObject:
def __init__(self, dataset_id: str, document_id: str, paragraph_id: str, problem_content: str):
self.dataset_id = dataset_id
self.document_id = document_id
self.paragraph_id = paragraph_id
self.problem_content = problem_content
def or_get(exists_problem_list, content, dataset_id, document_id, paragraph_id, problem_content_dict):
if content in problem_content_dict:
return problem_content_dict.get(content)[0], document_id, paragraph_id
exists = [row for row in exists_problem_list if row.content == content]
if len(exists) > 0:
problem_content_dict[content] = exists[0], False
return exists[0], document_id, paragraph_id
else:
problem = Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id)
problem_content_dict[content] = problem, True
return problem, document_id, paragraph_id
class ProblemParagraphManage:
def __init__(self, problemParagraphObjectList: [ProblemParagraphObject], dataset_id):
self.dataset_id = dataset_id
self.problemParagraphObjectList = problemParagraphObjectList
def to_problem_model_list(self):
problem_list = [item.problem_content for item in self.problemParagraphObjectList]
exists_problem_list = []
if len(self.problemParagraphObjectList) > 0:
# 查询到已存在的问题列表
exists_problem_list = QuerySet(Problem).filter(dataset_id=self.dataset_id,
content__in=problem_list).all()
problem_content_dict = {}
problem_model_list = [
or_get(exists_problem_list, problemParagraphObject.problem_content, problemParagraphObject.dataset_id,
problemParagraphObject.document_id, problemParagraphObject.paragraph_id, problem_content_dict) for
problemParagraphObject in self.problemParagraphObjectList]
problem_paragraph_mapping_list = [
ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
paragraph_id=paragraph_id,
dataset_id=self.dataset_id) for
problem_model, document_id, paragraph_id in problem_model_list]
result = [problem_model for problem_model, is_create in problem_content_dict.values() if
is_create], problem_paragraph_mapping_list
return result

View File

@ -37,7 +37,7 @@ from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from embedding.models import SearchMode
from setting.models import AuthOperate
@ -383,8 +383,7 @@ class DataSetSerializers(serializers.ModelSerializer):
document_model_list = []
paragraph_model_list = []
problem_model_list = []
problem_paragraph_mapping_list = []
problem_paragraph_object_list = []
# 插入文档
for document in instance.get('documents') if 'documents' in instance else []:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
@ -392,12 +391,12 @@ class DataSetSerializers(serializers.ModelSerializer):
document_model_list.append(document_paragraph_dict_model.get('document'))
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
paragraph_model_list.append(paragraph)
for problem in document_paragraph_dict_model.get('problem_model_list'):
problem_model_list.append(problem)
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
problem_model_list, problem_paragraph_mapping_list = DocumentSerializers.Create.reset_problem_model(
problem_model_list, problem_paragraph_mapping_list)
for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
problem_paragraph_object_list.append(problem_paragraph_object)
problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
dataset_id)
.to_problem_model_list())
# 插入知识库
dataset.save()
# 插入文档

View File

@ -41,7 +41,7 @@ from common.util.file_util import get_file_content
from common.util.fork import Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
@ -380,8 +380,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs)
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
problem_model_list = document_paragraph_model.get('problem_model_list')
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')
problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list')
problem_model_list, problem_paragraph_mapping_list = ProblemParagraphManage(
problem_paragraph_object_list, document.dataset_id).to_problem_model_list()
# 批量插入段落
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
# 批量插入问题
@ -626,11 +627,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
self.is_valid(raise_exception=True)
dataset_id = self.data.get('dataset_id')
document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance)
document_model = document_paragraph_model.get('document')
paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
problem_model_list = document_paragraph_model.get('problem_model_list')
problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')
problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list')
problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
dataset_id)
.to_problem_model_list())
# 插入文档
document_model.save()
# 批量插入段落
@ -685,35 +688,15 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
dataset_id, document_model.id, paragraph) for paragraph in paragraph_list]
paragraph_model_list = []
problem_model_list = []
problem_paragraph_mapping_list = []
problem_paragraph_object_list = []
for paragraphs in paragraph_model_dict_list:
paragraph = paragraphs.get('paragraph')
for problem_model in paragraphs.get('problem_model_list'):
problem_model_list.append(problem_model)
for problem_paragraph_mapping in paragraphs.get('problem_paragraph_mapping_list'):
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
for problem_model in paragraphs.get('problem_paragraph_object_list'):
problem_paragraph_object_list.append(problem_model)
paragraph_model_list.append(paragraph)
problem_model_list, problem_paragraph_mapping_list = DocumentSerializers.Create.reset_problem_model(
problem_model_list, problem_paragraph_mapping_list)
return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
'problem_model_list': problem_model_list,
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
@staticmethod
def reset_problem_model(problem_model_list, problem_paragraph_mapping_list):
new_problem_model_list = [x for i, x in enumerate(problem_model_list) if
len([item for item in problem_model_list[:i] if item.content == x.content]) <= 0]
for new_problem_model in new_problem_model_list:
old_model_list = [problem.id for problem in problem_model_list if
problem.content == new_problem_model.content]
for problem_paragraph_mapping in problem_paragraph_mapping_list:
if old_model_list.__contains__(problem_paragraph_mapping.problem_id):
problem_paragraph_mapping.problem_id = new_problem_model.id
return new_problem_model_list, problem_paragraph_mapping_list
'problem_paragraph_object_list': problem_paragraph_object_list}
@staticmethod
def get_document_paragraph_model(dataset_id, instance: Dict):
@ -834,8 +817,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
dataset_id = self.data.get("dataset_id")
document_model_list = []
paragraph_model_list = []
problem_model_list = []
problem_paragraph_mapping_list = []
problem_paragraph_object_list = []
# 插入文档
for document in instance_list:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
@ -843,11 +825,12 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
document_model_list.append(document_paragraph_dict_model.get('document'))
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
paragraph_model_list.append(paragraph)
for problem in document_paragraph_dict_model.get('problem_model_list'):
problem_model_list.append(problem)
for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
problem_paragraph_mapping_list.append(problem_paragraph_mapping)
for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
problem_paragraph_object_list.append(problem_paragraph_object)
problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
dataset_id)
.to_problem_model_list())
# 插入文档
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
# 批量插入段落

View File

@ -21,7 +21,8 @@ from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
ProblemParagraphManage
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from embedding.models import SourceType
@ -567,8 +568,10 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
document_id = self.data.get('document_id')
paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
paragraph = paragraph_problem_model.get('paragraph')
problem_model_list = paragraph_problem_model.get('problem_model_list')
problem_paragraph_mapping_list = paragraph_problem_model.get('problem_paragraph_mapping_list')
problem_paragraph_object_list = paragraph_problem_model.get('problem_paragraph_object_list')
problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
dataset_id).
to_problem_model_list())
# 插入段落
paragraph_problem_model.get('paragraph').save()
# 插入問題
@ -591,30 +594,12 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
content=instance.get("content"),
dataset_id=dataset_id,
title=instance.get("title") if 'title' in instance else '')
problem_list = instance.get('problem_list')
exists_problem_list = []
if 'problem_list' in instance and len(problem_list) > 0:
exists_problem_list = QuerySet(Problem).filter(dataset_id=dataset_id,
content__in=[p.get('content') for p in
problem_list]).all()
problem_paragraph_object_list = [
ProblemParagraphObject(dataset_id, document_id, paragraph.id, problem.get('content')) for problem in
(instance.get('problem_list') if 'problem_list' in instance else [])]
problem_model_list = [
ParagraphSerializers.Create.or_get(exists_problem_list, problem.get('content'), dataset_id) for
problem in (
instance.get('problem_list') if 'problem_list' in instance else [])]
# 问题去重
problem_model_list = [x for i, x in enumerate(problem_model_list) if
len([item for item in problem_model_list[:i] if item.content == x.content]) <= 0]
problem_paragraph_mapping_list = [
ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
paragraph_id=paragraph.id,
dataset_id=dataset_id) for
problem_model in problem_model_list]
return {'paragraph': paragraph,
'problem_model_list': [problem_model for problem_model in problem_model_list if
not list(exists_problem_list).__contains__(problem_model)],
'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
'problem_paragraph_object_list': problem_paragraph_object_list}
@staticmethod
def or_get(exists_problem_list, content, dataset_id):