Mirror of https://github.com/1Panel-dev/MaxKB.git
fix: [Knowledge Base] Uploads whose paragraphs carry associated problems block the knowledge base (#676)

commit fc6da6a484
parent 60181d6f83
@@ -7,6 +7,7 @@
     @desc:
 """
 import os
+import uuid
 from typing import List
 
 from django.db.models import QuerySet
@@ -20,7 +21,7 @@ from common.mixins.api_mixin import ApiMixin
 from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from common.util.fork import Fork
-from dataset.models import Paragraph
+from dataset.models import Paragraph, Problem, ProblemParagraphMapping
 from smartdoc.conf import PROJECT_DIR
 
 
@@ -79,3 +80,53 @@ class BatchSerializer(ApiMixin, serializers.Serializer):
                                           description="list of primary-key ids")
             }
         )
+
+
+class ProblemParagraphObject:
+    def __init__(self, dataset_id: str, document_id: str, paragraph_id: str, problem_content: str):
+        self.dataset_id = dataset_id
+        self.document_id = document_id
+        self.paragraph_id = paragraph_id
+        self.problem_content = problem_content
+
+
+def or_get(exists_problem_list, content, dataset_id, document_id, paragraph_id, problem_content_dict):
+    if content in problem_content_dict:
+        return problem_content_dict.get(content)[0], document_id, paragraph_id
+    exists = [row for row in exists_problem_list if row.content == content]
+    if len(exists) > 0:
+        problem_content_dict[content] = exists[0], False
+        return exists[0], document_id, paragraph_id
+    else:
+        problem = Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id)
+        problem_content_dict[content] = problem, True
+        return problem, document_id, paragraph_id
+
+
+class ProblemParagraphManage:
+    def __init__(self, problemParagraphObjectList: [ProblemParagraphObject], dataset_id):
+        self.dataset_id = dataset_id
+        self.problemParagraphObjectList = problemParagraphObjectList
+
+    def to_problem_model_list(self):
+        problem_list = [item.problem_content for item in self.problemParagraphObjectList]
+        exists_problem_list = []
+        if len(self.problemParagraphObjectList) > 0:
+            # look up the problems that already exist in this dataset
+            exists_problem_list = QuerySet(Problem).filter(dataset_id=self.dataset_id,
+                                                           content__in=problem_list).all()
+        problem_content_dict = {}
+        problem_model_list = [
+            or_get(exists_problem_list, problemParagraphObject.problem_content, problemParagraphObject.dataset_id,
+                   problemParagraphObject.document_id, problemParagraphObject.paragraph_id, problem_content_dict) for
+            problemParagraphObject in self.problemParagraphObjectList]
+
+        problem_paragraph_mapping_list = [
+            ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
+                                    paragraph_id=paragraph_id,
+                                    dataset_id=self.dataset_id) for
+            problem_model, document_id, paragraph_id in problem_model_list]
+
+        result = [problem_model for problem_model, is_create in problem_content_dict.values() if
+                  is_create], problem_paragraph_mapping_list
+        return result
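The hunk above adds the new helpers to common_serializers: ProblemParagraphObject is a lightweight carrier for (dataset_id, document_id, paragraph_id, problem_content), or_get resolves a problem content string to an existing or newly created Problem through a per-batch dict cache, and ProblemParagraphManage turns all carriers into model lists with a single database query per upload. Below is a minimal, framework-free sketch of the same dedup-and-cache idea; FakeProblem, FakeMapping and the in-memory `existing` list are illustrative stand-ins for the Django models and the QuerySet lookup, not part of the commit.

import uuid
from dataclasses import dataclass, field


@dataclass
class FakeProblem:
    content: str
    id: uuid.UUID = field(default_factory=uuid.uuid1)


@dataclass
class FakeMapping:
    problem_id: uuid.UUID
    paragraph_id: str


def to_problem_model_list(carriers, existing):
    """carriers: (paragraph_id, problem_content) pairs; existing: FakeProblem rows already stored."""
    cache = {}        # content -> (problem, is_create), mirrors problem_content_dict
    resolved = []
    for paragraph_id, content in carriers:
        if content not in cache:
            found = [p for p in existing if p.content == content]
            cache[content] = (found[0], False) if found else (FakeProblem(content), True)
        resolved.append((cache[content][0], paragraph_id))
    mappings = [FakeMapping(problem.id, paragraph_id) for problem, paragraph_id in resolved]
    new_problems = [p for p, is_create in cache.values() if is_create]
    return new_problems, mappings


if __name__ == '__main__':
    existing = [FakeProblem('What is MaxKB?')]
    carriers = [('p1', 'What is MaxKB?'), ('p1', 'How to install?'), ('p2', 'How to install?')]
    problems, mappings = to_problem_model_list(carriers, existing)
    print(len(problems), 'new problem(s),', len(mappings), 'mapping(s)')

Running it prints "1 new problem(s), 3 mapping(s)": duplicate contents collapse onto a single problem while every paragraph keeps its own mapping.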
@@ -37,7 +37,7 @@ from common.util.file_util import get_file_content
 from common.util.fork import ChildLink, Fork
 from common.util.split_model import get_split_model
 from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
-from dataset.serializers.common_serializers import list_paragraph, MetaSerializer
+from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage
 from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
 from embedding.models import SearchMode
 from setting.models import AuthOperate
@@ -383,8 +383,7 @@ class DataSetSerializers(serializers.ModelSerializer):
 
             document_model_list = []
             paragraph_model_list = []
-            problem_model_list = []
-            problem_paragraph_mapping_list = []
+            problem_paragraph_object_list = []
             # insert the documents
             for document in instance.get('documents') if 'documents' in instance else []:
                 document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
@@ -392,12 +391,12 @@ class DataSetSerializers(serializers.ModelSerializer):
                 document_model_list.append(document_paragraph_dict_model.get('document'))
                 for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
                     paragraph_model_list.append(paragraph)
-                for problem in document_paragraph_dict_model.get('problem_model_list'):
-                    problem_model_list.append(problem)
-                for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
-                    problem_paragraph_mapping_list.append(problem_paragraph_mapping)
-                problem_model_list, problem_paragraph_mapping_list = DocumentSerializers.Create.reset_problem_model(
-                    problem_model_list, problem_paragraph_mapping_list)
+                for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
+                    problem_paragraph_object_list.append(problem_paragraph_object)
+
+            problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
+                                                                                         dataset_id)
+                                                                  .to_problem_model_list())
             # insert the dataset (knowledge base)
             dataset.save()
             # insert the documents
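In the DataSetSerializers.Create save path above, the per-document loop now only accumulates the cheap problem_paragraph_object_list entries, and the ProblemParagraphManage(...).to_problem_model_list() conversion runs once after the loop, whereas the removed reset_problem_model call re-deduplicated the growing problem list for every document. A rough control-flow sketch of that change, with plain dicts and a callable as placeholders for the serializer output and the manager:

def save_documents(documents, convert_once):
    document_models, paragraph_models, carriers = [], [], []
    for document in documents:
        document_models.append(document['document'])
        paragraph_models.extend(document['paragraph_model_list'])
        # Previously the accumulated problem list was deduplicated inside this
        # loop (reset_problem_model); now only lightweight carriers are gathered.
        carriers.extend(document['problem_paragraph_object_list'])
    # One conversion for the whole upload, after the loop.
    problem_models, mappings = convert_once(carriers)
    return document_models, paragraph_models, problem_models, mappings


if __name__ == '__main__':
    docs = [{'document': 'doc-1', 'paragraph_model_list': ['p1'],
             'problem_paragraph_object_list': [('p1', 'q1')]},
            {'document': 'doc-2', 'paragraph_model_list': ['p2'],
             'problem_paragraph_object_list': [('p2', 'q1')]}]
    print(save_documents(docs, lambda carriers: (['q1'], carriers)))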
@@ -41,7 +41,7 @@ from common.util.file_util import get_file_content
 from common.util.fork import Fork
 from common.util.split_model import get_split_model
 from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
-from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer
+from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage
 from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
 from smartdoc.conf import PROJECT_DIR
 
@@ -380,8 +380,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
             document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs)
 
             paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
-            problem_model_list = document_paragraph_model.get('problem_model_list')
-            problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')
+            problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list')
+            problem_model_list, problem_paragraph_mapping_list = ProblemParagraphManage(
+                problem_paragraph_object_list, document.dataset_id).to_problem_model_list()
             # bulk-insert the paragraphs
             QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
             # bulk-insert the problems
@@ -626,11 +627,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
             self.is_valid(raise_exception=True)
             dataset_id = self.data.get('dataset_id')
             document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance)
+
             document_model = document_paragraph_model.get('document')
             paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
-            problem_model_list = document_paragraph_model.get('problem_model_list')
-            problem_paragraph_mapping_list = document_paragraph_model.get('problem_paragraph_mapping_list')
-
+            problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list')
+            problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
+                                                                                         dataset_id)
+                                                                  .to_problem_model_list())
             # insert the document
             document_model.save()
             # bulk-insert the paragraphs
@@ -685,35 +688,15 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                 dataset_id, document_model.id, paragraph) for paragraph in paragraph_list]
 
             paragraph_model_list = []
-            problem_model_list = []
-            problem_paragraph_mapping_list = []
+            problem_paragraph_object_list = []
             for paragraphs in paragraph_model_dict_list:
                 paragraph = paragraphs.get('paragraph')
-                for problem_model in paragraphs.get('problem_model_list'):
-                    problem_model_list.append(problem_model)
-                for problem_paragraph_mapping in paragraphs.get('problem_paragraph_mapping_list'):
-                    problem_paragraph_mapping_list.append(problem_paragraph_mapping)
+                for problem_model in paragraphs.get('problem_paragraph_object_list'):
+                    problem_paragraph_object_list.append(problem_model)
                 paragraph_model_list.append(paragraph)
 
-            problem_model_list, problem_paragraph_mapping_list = DocumentSerializers.Create.reset_problem_model(
-                problem_model_list, problem_paragraph_mapping_list)
-
             return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
-                    'problem_model_list': problem_model_list,
-                    'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
-
-        @staticmethod
-        def reset_problem_model(problem_model_list, problem_paragraph_mapping_list):
-            new_problem_model_list = [x for i, x in enumerate(problem_model_list) if
-                                      len([item for item in problem_model_list[:i] if item.content == x.content]) <= 0]
-
-            for new_problem_model in new_problem_model_list:
-                old_model_list = [problem.id for problem in problem_model_list if
-                                  problem.content == new_problem_model.content]
-                for problem_paragraph_mapping in problem_paragraph_mapping_list:
-                    if old_model_list.__contains__(problem_paragraph_mapping.problem_id):
-                        problem_paragraph_mapping.problem_id = new_problem_model.id
-            return new_problem_model_list, problem_paragraph_mapping_list
+                    'problem_paragraph_object_list': problem_paragraph_object_list}
 
         @staticmethod
         def get_document_paragraph_model(dataset_id, instance: Dict):
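The removed reset_problem_model above deduplicated by rescanning the accumulated problem list for every element and then walking every mapping for every surviving problem, so the work grew roughly quadratically with the number of uploaded problems; the or_get/ProblemParagraphManage path replaces this with a single dict keyed by content. A framework-free sketch of both strategies for comparison (Item and Map are stand-ins for the Problem and ProblemParagraphMapping models, and the demo sizes are arbitrary):

import time
from dataclasses import dataclass


@dataclass
class Item:
    id: int
    content: str


@dataclass
class Map:
    problem_id: int


def dedupe_quadratic(items, maps):
    # Mirrors the removed reset_problem_model: a prefix scan per element plus
    # nested loops to rewrite mapping ids.
    unique = [x for i, x in enumerate(items)
              if not any(item.content == x.content for item in items[:i])]
    for keep in unique:
        old_ids = [it.id for it in items if it.content == keep.content]
        for m in maps:
            if m.problem_id in old_ids:
                m.problem_id = keep.id
    return unique, maps


def dedupe_with_cache(items, maps):
    # Mirrors the new approach: one dict keyed by content, so each problem and
    # each mapping is touched once.
    cache, remap = {}, {}
    for it in items:
        keep = cache.setdefault(it.content, it)
        remap[it.id] = keep.id
    for m in maps:
        m.problem_id = remap.get(m.problem_id, m.problem_id)
    return list(cache.values()), maps


if __name__ == '__main__':
    items = [Item(i, f'question-{i % 50}') for i in range(2000)]
    maps = [Map(i) for i in range(2000)]
    for fn in (dedupe_quadratic, dedupe_with_cache):
        start = time.perf_counter()
        unique, _ = fn(list(items), [Map(m.problem_id) for m in maps])
        print(fn.__name__, len(unique), f'{time.perf_counter() - start:.3f}s')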
@@ -834,8 +817,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
             dataset_id = self.data.get("dataset_id")
             document_model_list = []
             paragraph_model_list = []
-            problem_model_list = []
-            problem_paragraph_mapping_list = []
+            problem_paragraph_object_list = []
             # insert the documents
             for document in instance_list:
                 document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
@@ -843,11 +825,12 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                 document_model_list.append(document_paragraph_dict_model.get('document'))
                 for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
                     paragraph_model_list.append(paragraph)
-                for problem in document_paragraph_dict_model.get('problem_model_list'):
-                    problem_model_list.append(problem)
-                for problem_paragraph_mapping in document_paragraph_dict_model.get('problem_paragraph_mapping_list'):
-                    problem_paragraph_mapping_list.append(problem_paragraph_mapping)
+                for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'):
+                    problem_paragraph_object_list.append(problem_paragraph_object)
+
+            problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
+                                                                                         dataset_id)
+                                                                  .to_problem_model_list())
             # insert the documents
             QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
             # bulk-insert the paragraphs
@@ -21,7 +21,8 @@ from common.mixins.api_mixin import ApiMixin
 from common.util.common import post
 from common.util.field_message import ErrMessage
 from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
-from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer
+from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
+    ProblemParagraphManage
 from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
 from embedding.models import SourceType
 
@@ -567,8 +568,10 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
             document_id = self.data.get('document_id')
             paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
             paragraph = paragraph_problem_model.get('paragraph')
-            problem_model_list = paragraph_problem_model.get('problem_model_list')
-            problem_paragraph_mapping_list = paragraph_problem_model.get('problem_paragraph_mapping_list')
+            problem_paragraph_object_list = paragraph_problem_model.get('problem_paragraph_object_list')
+            problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list,
+                                                                                         dataset_id).
+                                                                  to_problem_model_list())
             # insert the paragraph
             paragraph_problem_model.get('paragraph').save()
             # insert the problems
@@ -591,30 +594,12 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
                                   content=instance.get("content"),
                                   dataset_id=dataset_id,
                                   title=instance.get("title") if 'title' in instance else '')
-            problem_list = instance.get('problem_list')
-            exists_problem_list = []
-            if 'problem_list' in instance and len(problem_list) > 0:
-                exists_problem_list = QuerySet(Problem).filter(dataset_id=dataset_id,
-                                                               content__in=[p.get('content') for p in
-                                                                            problem_list]).all()
+            problem_paragraph_object_list = [
+                ProblemParagraphObject(dataset_id, document_id, paragraph.id, problem.get('content')) for problem in
+                (instance.get('problem_list') if 'problem_list' in instance else [])]
 
-            problem_model_list = [
-                ParagraphSerializers.Create.or_get(exists_problem_list, problem.get('content'), dataset_id) for
-                problem in (
-                    instance.get('problem_list') if 'problem_list' in instance else [])]
-            # deduplicate the problems
-            problem_model_list = [x for i, x in enumerate(problem_model_list) if
-                                  len([item for item in problem_model_list[:i] if item.content == x.content]) <= 0]
-
-            problem_paragraph_mapping_list = [
-                ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
-                                        paragraph_id=paragraph.id,
-                                        dataset_id=dataset_id) for
-                problem_model in problem_model_list]
             return {'paragraph': paragraph,
-                    'problem_model_list': [problem_model for problem_model in problem_model_list if
-                                           not list(exists_problem_list).__contains__(problem_model)],
-                    'problem_paragraph_mapping_list': problem_paragraph_mapping_list}
+                    'problem_paragraph_object_list': problem_paragraph_object_list}
 
         @staticmethod
         def or_get(exists_problem_list, content, dataset_id):
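After this change, get_paragraph_problem_model no longer queries existing problems, deduplicates them, or builds mappings itself; it only wraps each submitted problem in a ProblemParagraphObject and returns the list, deferring lookup and deduplication to ProblemParagraphManage at save time. A small sketch of that carrier-building step, with Carrier as a stand-in for ProblemParagraphObject:

from dataclasses import dataclass


@dataclass
class Carrier:
    dataset_id: str
    document_id: str
    paragraph_id: str
    problem_content: str


def build_carriers(dataset_id, document_id, paragraph_id, instance):
    # Mirrors the added list comprehension in get_paragraph_problem_model.
    return [Carrier(dataset_id, document_id, paragraph_id, problem.get('content'))
            for problem in (instance.get('problem_list') if 'problem_list' in instance else [])]


if __name__ == '__main__':
    instance = {'content': 'some paragraph text',
                'problem_list': [{'content': 'q1'}, {'content': 'q2'}]}
    print(build_carriers('ds', 'doc', 'par', instance))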