MaxKB/apps/dataset/serializers/paragraph_serializers.py
shaohuzhang1 196a7c73ff
Pr@main@list search (#68)
* fix: type

* fix: 本地搜索对英文大小写敏感#57
2024-04-15 13:58:40 +08:00

563 lines
30 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
@project: maxkb
@Author
@file: paragraph_serializers.py
@date: 2023/10/16 15:51
@desc:
"""
import uuid
from typing import Dict
from django.db import transaction
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.db.search import page_search
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
from dataset.serializers.common_serializers import update_document_char_length
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from embedding.models import SourceType
class ParagraphSerializer(serializers.ModelSerializer):
    """Model serializer for ``Paragraph`` restricted to the fields listed in ``Meta``."""

    class Meta:
        # Model backing this serializer and the attributes it serializes.
        model = Paragraph
        fields = [
            'id',
            'content',
            'is_active',
            'document_id',
            'title',
            'create_time',
            'update_time',
        ]
class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
    """
    Paragraph instance object.

    Declares the fields of one paragraph payload: the content, an optional
    title, an optional list of problems and an optional active flag.
    """
    content = serializers.CharField(
        required=True,
        error_messages=ErrMessage.char("段落内容"),
        max_length=4096,
        min_length=1,
        allow_null=True,
        allow_blank=True,
    )
    title = serializers.CharField(
        required=False,
        max_length=256,
        error_messages=ErrMessage.char("段落标题"),
        allow_null=True,
        allow_blank=True,
    )
    problem_list = ProblemInstanceSerializer(required=False, many=True)
    is_active = serializers.BooleanField(
        required=False,
        error_messages=ErrMessage.char("段落是否可用"),
    )

    @staticmethod
    def get_request_body_api():
        """
        Build the OpenAPI description of the request body.

        :return: an object schema with ``content``, ``title``, ``is_active``
                 and ``problem_list`` properties; only ``content`` is required
        """
        content_schema = openapi.Schema(type=openapi.TYPE_STRING, max_length=4096, title="分段内容",
                                        description="分段内容")
        title_schema = openapi.Schema(type=openapi.TYPE_STRING, max_length=256, title="分段标题",
                                      description="分段标题")
        is_active_schema = openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用")
        # Each array item reuses the request schema published by ProblemInstanceSerializer.
        problem_list_schema = openapi.Schema(type=openapi.TYPE_ARRAY, title='问题列表',
                                             description="问题列表",
                                             items=ProblemInstanceSerializer.get_request_body_api())
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['content'],
            properties={
                'content': content_schema,
                'title': title_schema,
                'is_active': is_active_schema,
                'problem_list': problem_list_schema,
            },
        )
class EditParagraphSerializers(serializers.Serializer):
    """Declares the optional ``title``, ``content`` and ``problem_list`` fields of an edit payload."""
    title = serializers.CharField(
        required=False,
        max_length=256,
        error_messages=ErrMessage.char("分段标题"),
        allow_null=True,
        allow_blank=True,
    )
    content = serializers.CharField(
        required=False,
        max_length=4096,
        allow_null=True,
        allow_blank=True,
        error_messages=ErrMessage.char("分段内容"),
    )
    problem_list = ProblemInstanceSerializer(required=False, many=True)
# Serializer declaring the title/content fields of a paragraph payload.
# Other code in this file addresses nested serializers through it
# (ParagraphSerializers.Operate, ParagraphSerializers.Create), so it also
# acts as the namespace for those classes.
class ParagraphSerializers(ApiMixin, serializers.Serializer):
# Optional title: at most 256 characters; null and blank values are accepted.
title = serializers.CharField(required=False, max_length=256, error_messages=ErrMessage.char(
"分段标题"), allow_null=True, allow_blank=True)
# Required content: at most 4096 characters.
content = serializers.CharField(required=True, max_length=4096, error_messages=ErrMessage.char(
"分段内容"))
class Problem(ApiMixin, serializers.Serializer):
    """
    Lists and creates the problems attached to one paragraph.

    NOTE: this serializer shares its name with the ``Problem`` model imported
    at module level. The method bodies below rely on the bare name ``Problem``
    resolving to the model, which holds while this class stays nested inside
    another class (class-body names are not visible inside methods).
    """
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
    paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))

    def is_valid(self, *, raise_exception=False):
        """
        Validate the ids and make sure the paragraph exists.

        Field validation always raises on failure; ``raise_exception`` is
        accepted for signature compatibility but is not consulted.

        :param raise_exception: unused, kept for compatibility with the base class
        :return: True when validation succeeds
        :raises AppApiException: when no paragraph has the given id
        """
        super().is_valid(raise_exception=True)
        if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
            raise AppApiException(500, "段落id不存在")
        # Honour the base-class contract: is_valid() reports success as a boolean.
        return True

    def list(self, with_valid=False):
        """
        Get the problems mapped to the paragraph.

        :param with_valid: whether to run ``is_valid`` first
        :return: list of serialized problems
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
            dataset_id=self.data.get("dataset_id"),
            paragraph_id=self.data.get('paragraph_id'))
        problem_id_list = [row.problem_id for row in problem_paragraph_mapping]
        return [ProblemSerializer(row).data for row in QuerySet(Problem).filter(id__in=problem_id_list)]

    @transaction.atomic
    def save(self, instance: Dict, with_valid=True, with_embedding=True):
        """
        Attach a problem with the given content to the paragraph.

        An existing problem of the dataset with the same content is reused;
        otherwise a new one is created. A mapping row is then inserted.

        :param instance: payload holding the problem ``content``
        :param with_valid: whether to validate this serializer and the payload
        :param with_embedding: whether to send the problem embedding signal
        :return: result of ``ProblemSerializers.Operate(...).one`` for the problem
        :raises AppApiException: when the problem is already mapped to the paragraph
        """
        if with_valid:
            self.is_valid()
            ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
        dataset_id = self.data.get('dataset_id')
        document_id = self.data.get('document_id')
        paragraph_id = self.data.get('paragraph_id')

        problem = QuerySet(Problem).filter(dataset_id=dataset_id,
                                           content=instance.get('content')).first()
        if problem is None:
            problem = Problem(id=uuid.uuid1(), dataset_id=dataset_id,
                              content=instance.get('content'))
            problem.save()

        if QuerySet(ProblemParagraphMapping).filter(dataset_id=dataset_id, problem_id=problem.id,
                                                    paragraph_id=paragraph_id).exists():
            raise AppApiException(500, "已经关联,请勿重复关联")

        problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(),
                                                            problem_id=problem.id,
                                                            document_id=document_id,
                                                            paragraph_id=paragraph_id,
                                                            dataset_id=dataset_id)
        problem_paragraph_mapping.save()
        if with_embedding:
            # The embedding source is the mapping row, not the problem itself.
            ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
                                                                 'is_active': True,
                                                                 'source_type': SourceType.PROBLEM,
                                                                 'source_id': problem_paragraph_mapping.id,
                                                                 'document_id': document_id,
                                                                 'paragraph_id': paragraph_id,
                                                                 'dataset_id': dataset_id,
                                                                 })

        return ProblemSerializers.Operate(
            data={'dataset_id': dataset_id,
                  'problem_id': problem.id}).one(with_valid=True)

    @staticmethod
    def get_request_params_api():
        """
        Describe the path parameters of the endpoint.

        :return: OpenAPI parameters for dataset_id, document_id and paragraph_id
        """
        return [openapi.Parameter(name=name,
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description=description)
                for name, description in (('dataset_id', '知识库id'),
                                          ('document_id', '文档id'),
                                          ('paragraph_id', '段落id'))]

    @staticmethod
    def get_request_body_api():
        """
        Describe the request body.

        :return: an object schema with a single required ``content`` string
        """
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              required=["content"],
                              properties={
                                  'content': openapi.Schema(
                                      type=openapi.TYPE_STRING, title="内容")
                              })

    @staticmethod
    def get_response_body_api():
        """
        Describe the response body.

        :return: an object schema describing one problem
        """
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['id', 'content', 'hit_num', 'dataset_id', 'create_time', 'update_time'],
            properties={
                'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                     description="id", default="xx"),
                'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
                                          description="问题内容", default='问题内容'),
                'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
                                          default=1),
                'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
                                             description="知识库id", default='xxx'),
                'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                              description="修改时间",
                                              default="1970-01-01 00:00:00"),
                'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                              description="创建时间",
                                              default="1970-01-01 00:00:00"
                                              )
            }
        )
class Association(ApiMixin, serializers.Serializer):
    """Creates and removes the mapping between an existing problem and a paragraph."""
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
    problem_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("问题id"))
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
    paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))

    def is_valid(self, *, raise_exception=True):
        """
        Validate the ids and check that paragraph and problem exist in the dataset.

        Field validation always raises on failure; ``raise_exception`` is
        accepted for signature compatibility but is not consulted.

        :param raise_exception: unused, kept for compatibility with the base class
        :return: True when validation succeeds
        :raises AppApiException: when the paragraph or the problem is missing
        """
        super().is_valid(raise_exception=True)
        dataset_id = self.data.get('dataset_id')
        paragraph_id = self.data.get('paragraph_id')
        problem_id = self.data.get("problem_id")
        if not QuerySet(Paragraph).filter(dataset_id=dataset_id, id=paragraph_id).exists():
            raise AppApiException(500, "段落不存在")
        if not QuerySet(Problem).filter(dataset_id=dataset_id, id=problem_id).exists():
            raise AppApiException(500, "问题不存在")
        # Honour the base-class contract: is_valid() reports success as a boolean.
        return True

    def association(self, with_valid=True, with_embedding=True):
        """
        Map the problem to the paragraph.

        :param with_valid: whether to run ``is_valid`` first
        :param with_embedding: whether to send the problem embedding signal
        :raises AppApiException: when the problem does not exist
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        problem = QuerySet(Problem).filter(id=self.data.get("problem_id")).first()
        if problem is None:
            # Reachable when with_valid=False; fail with the same message
            # is_valid uses instead of an AttributeError on ``problem.id``.
            raise AppApiException(500, "问题不存在")
        # NOTE(review): unlike the sibling problem-creation path in this file,
        # no check for an already existing mapping is made here - confirm
        # whether duplicates are prevented elsewhere.
        problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(),
                                                            document_id=self.data.get('document_id'),
                                                            paragraph_id=self.data.get('paragraph_id'),
                                                            dataset_id=self.data.get('dataset_id'),
                                                            problem_id=problem.id)
        problem_paragraph_mapping.save()
        if with_embedding:
            ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
                                                                 'is_active': True,
                                                                 'source_type': SourceType.PROBLEM,
                                                                 'source_id': problem_paragraph_mapping.id,
                                                                 'document_id': self.data.get('document_id'),
                                                                 'paragraph_id': self.data.get('paragraph_id'),
                                                                 'dataset_id': self.data.get('dataset_id'),
                                                                 })

    def un_association(self, with_valid=True):
        """
        Remove the mapping between the problem and the paragraph.

        :param with_valid: whether to run ``is_valid`` first
        :return: True once the mapping has been deleted
        :raises AppApiException: when no such mapping exists
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
            paragraph_id=self.data.get('paragraph_id'),
            dataset_id=self.data.get('dataset_id'),
            problem_id=self.data.get('problem_id')).first()
        if problem_paragraph_mapping is None:
            # Paragraph and problem may both exist without being linked.
            raise AppApiException(500, "问题与段落的关联关系不存在")
        # Read the id before delete(): Django clears the primary key on the
        # instance once the row has been removed.
        problem_paragraph_mapping_id = problem_paragraph_mapping.id
        problem_paragraph_mapping.delete()
        ListenerManagement.delete_embedding_by_source_signal.send(problem_paragraph_mapping_id)
        return True

    @staticmethod
    def get_request_params_api():
        """
        Describe the path parameters of the endpoint.

        :return: OpenAPI parameters for dataset_id, document_id, paragraph_id and problem_id
        """
        return [openapi.Parameter(name=name,
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description=description)
                for name, description in (('dataset_id', '知识库id'),
                                          ('document_id', '文档id'),
                                          ('paragraph_id', '段落id'),
                                          ('problem_id', '问题id'))]
class Operate(ApiMixin, serializers.Serializer):
    """Reads, edits and deletes a single paragraph identified by ``paragraph_id``."""
    # Paragraph id
    paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
        "段落id"))
    # Dataset (knowledge base) id
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
        "知识库id"))
    # Document id
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
        "文档id"))

    def is_valid(self, *, raise_exception=True):
        """
        Validate the ids and make sure the paragraph exists.

        Field validation always raises on failure; ``raise_exception`` is
        accepted for signature compatibility but is not consulted.

        :param raise_exception: unused, kept for compatibility with the base class
        :return: True when validation succeeds
        :raises AppApiException: when no paragraph has the given id
        """
        super().is_valid(raise_exception=True)
        if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
            raise AppApiException(500, "段落id不存在")
        # Honour the base-class contract: is_valid() reports success as a boolean.
        return True

    @staticmethod
    def post_embedding(paragraph, instance):
        """
        Send the embedding signal that matches an edit.

        When the payload carries a non-null ``is_active`` the enable or
        disable signal is sent; otherwise the paragraph is re-embedded.

        :param paragraph: serialized paragraph (mapping with an ``id`` key)
        :param instance: the edit payload
        :return: ``paragraph`` unchanged
        """
        is_active = instance.get('is_active')
        if is_active is not None:
            signal = (ListenerManagement.enable_embedding_by_paragraph_signal if is_active
                      else ListenerManagement.disable_embedding_by_paragraph_signal)
            signal.send(paragraph.get('id'))
        else:
            ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'))
        return paragraph

    # NOTE(review): ``post`` comes from common.util.common; presumably it calls
    # post_embedding with the two values returned by edit - confirm there.
    @post(post_embedding)
    @transaction.atomic
    def edit(self, instance: Dict):
        """
        Apply an edit payload to the paragraph.

        Non-null ``title``, ``content`` and ``is_active`` values are copied
        onto the paragraph. When ``problem_list`` is present, entries with an
        id are updated, entries without an id are created, and - only if at
        least one entry carries an id - stored problems missing from the
        payload are deleted.

        :param instance: the edit payload
        :return: tuple of (serialized paragraph with its problem list, ``instance``)
        :raises AppApiException: when a problem id in the payload is unknown
        """
        self.is_valid()
        EditParagraphSerializers(data=instance).is_valid(raise_exception=True)
        paragraph_id = self.data.get("paragraph_id")
        _paragraph = QuerySet(Paragraph).get(id=paragraph_id)
        for update_key in ('title', 'content', 'is_active'):
            value = instance.get(update_key)
            if value is not None:
                setattr(_paragraph, update_key, value)

        if 'problem_list' in instance:
            incoming_problem_list = instance.get('problem_list')
            update_problem_list = [row for row in incoming_problem_list if row.get('id') is not None]
            create_problem_list = [row for row in incoming_problem_list if row.get('id') is None]
            # Problems currently stored for this paragraph.
            # NOTE(review): other code in this file links problems to
            # paragraphs through ProblemParagraphMapping; confirm the Problem
            # model still has the paragraph_id/document_id fields used here.
            problem_list = QuerySet(Problem).filter(paragraph_id=paragraph_id)
            delete_problem_list = []
            if update_problem_list:
                # Validate the ids sent by the client against the stored ones.
                existing_ids = {str(row.id) for row in problem_list}
                for update_problem in update_problem_list:
                    if update_problem.get('id') not in existing_ids:
                        raise AppApiException(500, update_problem.get('id') + '问题id不存在')
                # Stored problems that the payload no longer mentions are removed.
                update_ids = {str(row.get('id')) for row in update_problem_list}
                delete_problem_list = [row for row in problem_list if str(row.id) not in update_ids]
            # Delete problems.
            if delete_problem_list:
                QuerySet(Problem).filter(id__in=[row.id for row in delete_problem_list]).delete()
            # Insert new problems.
            if create_problem_list:
                QuerySet(Problem).bulk_create(
                    [Problem(id=uuid.uuid1(), content=p.get('content'), paragraph_id=paragraph_id,
                             dataset_id=self.data.get('dataset_id'), document_id=self.data.get('document_id'))
                     for p in create_problem_list])
            # Update the content of existing problems.
            if update_problem_list:
                QuerySet(Problem).bulk_update(
                    [Problem(id=row.get('id'), content=row.get('content')) for row in update_problem_list],
                    ['content'])

        _paragraph.save()
        update_document_char_length(self.data.get('document_id'))
        return self.one(), instance

    def get_problem_list(self):
        """
        Get the problems mapped to the paragraph.

        :return: list of serialized problems, empty when nothing is mapped
        """
        problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
            paragraph_id=self.data.get("paragraph_id"))
        problem_id_list = [ppm.problem_id for ppm in problem_paragraph_mapping]
        if not problem_id_list:
            return []
        return [ProblemSerializer(problem).data for problem in
                QuerySet(Problem).filter(id__in=problem_id_list)]

    def one(self, with_valid=False):
        """
        Get the paragraph together with its problems.

        :param with_valid: whether to run ``is_valid`` first
        :return: serialized paragraph extended with a ``problem_list`` key
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        paragraph = QuerySet(model=Paragraph).get(id=self.data.get('paragraph_id'))
        return {**ParagraphSerializer(paragraph).data,
                'problem_list': self.get_problem_list()}

    def delete(self, with_valid=False):
        """
        Delete the paragraph, its problem mappings and its embeddings.

        :param with_valid: whether to run ``is_valid`` first
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        paragraph_id = self.data.get('paragraph_id')
        QuerySet(Paragraph).filter(id=paragraph_id).delete()
        QuerySet(ProblemParagraphMapping).filter(paragraph_id=paragraph_id).delete()
        # NOTE(review): unlike edit, this does not call
        # update_document_char_length - confirm whether the document length
        # should be refreshed after a paragraph is removed.
        ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id)

    @staticmethod
    def get_request_body_api():
        """
        Describe the request body.

        :return: the request schema of ``ParagraphInstanceSerializer``
        """
        return ParagraphInstanceSerializer.get_request_body_api()

    @staticmethod
    def get_response_body_api():
        """
        Describe the response body.

        :return: the request schema of ``ParagraphInstanceSerializer`` (reused as the response schema)
        """
        return ParagraphInstanceSerializer.get_request_body_api()

    @staticmethod
    def get_request_params_api():
        """
        Describe the path parameters of the endpoint.

        :return: a single OpenAPI parameter for paragraph_id
        """
        return [openapi.Parameter(type=openapi.TYPE_STRING, in_=openapi.IN_PATH, name='paragraph_id',
                                  description="段落id")]
class Create(ApiMixin, serializers.Serializer):
    """Creates a paragraph, optionally with problems, inside a document of a dataset."""
    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
        "知识库id"))
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
        "文档id"))

    def is_valid(self, *, raise_exception=False):
        """
        Validate the ids and make sure the document belongs to the dataset.

        Field validation always raises on failure; ``raise_exception`` is
        accepted for signature compatibility but is not consulted.

        :param raise_exception: unused, kept for compatibility with the base class
        :return: True when validation succeeds
        :raises AppApiException: when the document is not found in the dataset
        """
        super().is_valid(raise_exception=True)
        if not QuerySet(Document).filter(id=self.data.get('document_id'),
                                         dataset_id=self.data.get('dataset_id')).exists():
            raise AppApiException(500, "文档id不正确")
        # Honour the base-class contract: is_valid() reports success as a boolean.
        return True

    def save(self, instance: Dict, with_valid=True, with_embedding=True):
        """
        Persist a new paragraph with its problems and mappings.

        :param instance: paragraph payload (``content``, optional ``title`` and ``problem_list``)
        :param with_valid: whether to validate the payload and this serializer
        :param with_embedding: whether to send the paragraph embedding signal
        :return: the created paragraph as returned by ``ParagraphSerializers.Operate.one``
        """
        if with_valid:
            ParagraphSerializers(data=instance).is_valid(raise_exception=True)
            self.is_valid()
        dataset_id = self.data.get("dataset_id")
        document_id = self.data.get('document_id')
        paragraph_problem_model = self.get_paragraph_problem_model(dataset_id, document_id, instance)
        paragraph = paragraph_problem_model.get('paragraph')
        problem_model_list = paragraph_problem_model.get('problem_model_list')
        problem_paragraph_mapping_list = paragraph_problem_model.get('problem_paragraph_mapping_list')
        # NOTE(review): these writes are not wrapped in a transaction, unlike
        # the other mutating methods in this file decorated with
        # transaction.atomic - confirm whether that is intentional.
        # Insert the paragraph.
        paragraph.save()
        # Insert the problems that do not exist yet.
        if problem_model_list:
            QuerySet(Problem).bulk_create(problem_model_list)
        # Insert the problem/paragraph mappings.
        if problem_paragraph_mapping_list:
            QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list)
        # Refresh the document length.
        update_document_char_length(document_id)
        if with_embedding:
            ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
        return ParagraphSerializers.Operate(
            data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
            with_valid=True)

    @staticmethod
    def get_paragraph_problem_model(dataset_id: str, document_id: str, instance: Dict):
        """
        Build the unsaved model objects for a paragraph payload.

        Problems of the dataset that already have the same content are reused
        rather than duplicated, and repeated contents inside the payload are
        collapsed to their first occurrence.

        :param dataset_id: id of the dataset
        :param document_id: id of the document
        :param instance: paragraph payload
        :return: dict with ``paragraph``, ``problem_model_list`` (only the
                 problems that still have to be inserted) and
                 ``problem_paragraph_mapping_list`` (one mapping per distinct problem)
        """
        paragraph = Paragraph(id=uuid.uuid1(),
                              document_id=document_id,
                              content=instance.get("content"),
                              dataset_id=dataset_id,
                              title=instance.get("title") if 'title' in instance else '')
        # A missing or null problem_list is treated as "no problems"; save()
        # does not validate this key, so None can reach this point.
        problem_list = instance.get('problem_list') or []
        exists_problem_list = []
        if problem_list:
            # Materialise once so the lookups below do not re-copy the rows.
            exists_problem_list = list(
                QuerySet(Problem).filter(dataset_id=dataset_id,
                                         content__in=[p.get('content') for p in problem_list]))

        # De-duplicate by content, keeping the first occurrence.
        # Assumes problem contents are hashable (strings) - TODO confirm.
        seen_content = set()
        problem_model_list = []
        for problem in problem_list:
            problem_model = ParagraphSerializers.Create.or_get(exists_problem_list, problem.get('content'),
                                                               dataset_id)
            if problem_model.content in seen_content:
                continue
            seen_content.add(problem_model.content)
            problem_model_list.append(problem_model)

        problem_paragraph_mapping_list = [
            ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id,
                                    paragraph_id=paragraph.id,
                                    dataset_id=dataset_id) for
            problem_model in problem_model_list]

        return {'paragraph': paragraph,
                # Only problems that are not already stored need inserting.
                'problem_model_list': [problem_model for problem_model in problem_model_list if
                                       problem_model not in exists_problem_list],
                'problem_paragraph_mapping_list': problem_paragraph_mapping_list}

    @staticmethod
    def or_get(exists_problem_list, content, dataset_id):
        """
        Return the stored problem with this content, or a new unsaved one.

        :param exists_problem_list: problems already stored for the dataset
        :param content: problem content to look up
        :param dataset_id: dataset id used when a new problem is built
        :return: the first matching stored problem, else a new ``Problem`` instance
        """
        for row in exists_problem_list:
            if row.content == content:
                return row
        return Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id)

    @staticmethod
    def get_request_body_api():
        """
        Describe the request body.

        :return: the request schema of ``ParagraphInstanceSerializer``
        """
        return ParagraphInstanceSerializer.get_request_body_api()

    @staticmethod
    def get_request_params_api():
        """
        Describe the path parameters of the endpoint.

        :return: OpenAPI parameters for dataset_id and document_id
        """
        return [openapi.Parameter(name='dataset_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='知识库id'),
                openapi.Parameter(name='document_id', in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description="文档id")
                ]
class Query(ApiMixin, serializers.Serializer):
    """Lists or pages the paragraphs of a document, optionally filtered by title and content."""
    dataset_id = serializers.UUIDField(
        required=True,
        error_messages=ErrMessage.char("知识库id"),
    )
    document_id = serializers.UUIDField(
        required=True,
        error_messages=ErrMessage.char("文档id"),
    )
    title = serializers.CharField(
        required=False,
        error_messages=ErrMessage.char("段落标题"),
    )
    content = serializers.CharField(required=False)

    def get_query_set(self):
        """
        Build the paragraph query for the current filters.

        :return: query set restricted to the dataset and document, further
                 narrowed by ``icontains`` lookups on title and content when
                 those keys are present in ``self.data``
        """
        data = self.data
        query_set = QuerySet(model=Paragraph).filter(dataset_id=data.get('dataset_id'),
                                                     document_id=data.get("document_id"))
        # Optional filters, applied in the same order: title first, then content.
        for field_name, lookup in (('title', 'title__icontains'), ('content', 'content__icontains')):
            if field_name in data:
                query_set = query_set.filter(**{lookup: data.get(field_name)})
        return query_set

    def list(self):
        """
        Get every matching paragraph.

        :return: list of serialized paragraphs
        """
        return [ParagraphSerializer(row).data for row in self.get_query_set()]

    def page(self, current_page, page_size):
        """
        Get one page of matching paragraphs.

        :param current_page: page number handed to ``page_search``
        :param page_size: page size handed to ``page_search``
        :return: whatever ``page_search`` returns for the serialized rows
        """

        def to_dict(row):
            return ParagraphSerializer(row).data

        return page_search(current_page, page_size, self.get_query_set(), to_dict)

    @staticmethod
    def get_request_params_api():
        """
        Describe the request parameters.

        :return: OpenAPI parameters: the document_id path parameter and the
                 optional title / content query parameters
        """
        document_param = openapi.Parameter(name='document_id',
                                           in_=openapi.IN_PATH,
                                           type=openapi.TYPE_STRING,
                                           required=True,
                                           description='文档id')
        filter_params = [openapi.Parameter(name=name,
                                           in_=openapi.IN_QUERY,
                                           type=openapi.TYPE_STRING,
                                           required=False,
                                           description=description)
                         for name, description in (('title', '标题'), ('content', '内容'))]
        return [document_param, *filter_params]

    @staticmethod
    def get_response_body_api():
        """
        Describe the response body.

        :return: an object schema describing one paragraph
        """

        def string_schema(title, description, default):
            return openapi.Schema(type=openapi.TYPE_STRING, title=title,
                                  description=description, default=default)

        def integer_schema(title, description, default):
            return openapi.Schema(type=openapi.TYPE_INTEGER, title=title,
                                  description=description, default=default)

        required_keys = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id',
                         'document_id', 'title',
                         'create_time', 'update_time']
        properties = {
            'id': string_schema("id", "id", "xx"),
            'content': string_schema("段落内容", "段落内容", '段落内容'),
            'title': string_schema("标题", "标题", "xxx的描述"),
            'hit_num': integer_schema("命中数量", "命中数量", 1),
            'star_num': integer_schema("点赞数量", "点赞数量", 1),
            'trample_num': integer_schema("点踩数量", "点踩数", 1),
            'dataset_id': string_schema("知识库id", "知识库id", 'xxx'),
            'document_id': string_schema("文档id", "文档id", 'xxx'),
            'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
                                        description="是否可用", default=True),
            'update_time': string_schema("修改时间", "修改时间", "1970-01-01 00:00:00"),
            'create_time': string_schema("创建时间", "创建时间", "1970-01-01 00:00:00"),
        }
        return openapi.Schema(type=openapi.TYPE_OBJECT, required=required_keys, properties=properties)