# Source file: MaxKB/apps/dataset/serializers/document_serializers.py
# (GitHub page residue — "Raw Blame History", Unicode warning, file stats — removed.)
# coding=utf-8
"""
@project: maxkb
@Author
@file document_serializers.py
@date2023/9/22 13:43
@desc:
"""
import os
import uuid
from functools import reduce
from typing import List, Dict
from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.db.search import native_search, native_page_search
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.file_util import get_file_content
from common.util.split_model import SplitModel, get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
class DocumentInstanceSerializer(ApiMixin, serializers.Serializer):
    """Request-payload serializer for creating a document together with its
    (optional) paragraph list."""
    # Document name, 1-128 characters.
    # Fixed: the min-length message previously said "知识库名称" (knowledge-base
    # name) — a copy-paste slip; this field is the document name.
    name = serializers.CharField(required=True,
                                 validators=[
                                     validators.MaxLengthValidator(limit_value=128,
                                                                   message="文档名称在1-128个字符之间"),
                                     validators.MinLengthValidator(limit_value=1,
                                                                   message="文档名称在1-128个字符之间")
                                 ])
    # Paragraphs to create alongside the document; may be absent or null.
    paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)

    @staticmethod
    def get_request_body_api():
        """OpenAPI request-body schema for document creation."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['name', 'paragraphs'],
            properties={
                'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
                'paragraphs': openapi.Schema(type=openapi.TYPE_ARRAY, title="段落列表", description="段落列表",
                                             items=ParagraphSerializers.Create.get_request_body_api())
            }
        )
class DocumentSerializers(ApiMixin, serializers.Serializer):
    """Namespace class grouping the document-related serializers: listing/paging
    (``Query``), single-document operations (``Operate``), creation (``Create``),
    file splitting (``Split``/``SplitPattern``) and batch import (``Batch``)."""

    class Query(ApiMixin, serializers.Serializer):
        """Filter serializer for listing / paging the documents of a dataset."""
        # Knowledge-base (dataset) id.
        dataset_id = serializers.UUIDField(required=True)
        # Optional fuzzy name filter, 1-128 characters.
        # Fixed: min-length message previously said "知识库名称" — copy-paste slip.
        name = serializers.CharField(required=False,
                                     validators=[
                                         validators.MaxLengthValidator(limit_value=128,
                                                                       message="文档名称在1-128个字符之间"),
                                         validators.MinLengthValidator(limit_value=1,
                                                                       message="文档名称在1-128个字符之间")
                                     ])

        def get_query_set(self):
            """Build the queryset: documents of the dataset, optionally filtered
            by a substring match on the name, newest first."""
            query_set = QuerySet(model=Document)
            query_set = query_set.filter(**{'dataset_id': self.data.get("dataset_id")})
            if 'name' in self.data and self.data.get('name') is not None:
                query_set = query_set.filter(**{'name__contains': self.data.get('name')})
            query_set = query_set.order_by('-create_time')
            return query_set

        def list(self, with_valid=False):
            """Return every matching document, projected through list_document.sql."""
            if with_valid:
                self.is_valid(raise_exception=True)
            query_set = self.get_query_set()
            return native_search(query_set, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')))

        def page(self, current_page, page_size):
            """Return one page of matching documents."""
            query_set = self.get_query_set()
            return native_page_search(current_page, page_size, query_set, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')))

        @staticmethod
        def get_request_params_api():
            """OpenAPI query-parameter description for document listing."""
            return [openapi.Parameter(name='name',
                                      in_=openapi.IN_QUERY,
                                      type=openapi.TYPE_STRING,
                                      required=False,
                                      description='文档名称')]

        @staticmethod
        def get_response_body_api():
            """OpenAPI response schema: an array of document objects."""
            return openapi.Schema(type=openapi.TYPE_ARRAY,
                                  title="文档列表", description="文档列表",
                                  items=DocumentSerializers.Operate.get_response_body_api())

    class Operate(ApiMixin, serializers.Serializer):
        """Operations on a single document: fetch, edit, re-embed, delete."""
        document_id = serializers.UUIDField(required=True)

        @staticmethod
        def get_request_params_api():
            """OpenAPI path-parameter descriptions for single-document endpoints."""
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='知识库id'),
                    openapi.Parameter(name='document_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='文档id')
                    ]

        def is_valid(self, *, raise_exception=False):
            """Field validation plus an existence check on the document id.

            :raises AppApiException: when the document id does not exist.
            """
            super().is_valid(raise_exception=True)
            document_id = self.data.get('document_id')
            if not QuerySet(Document).filter(id=document_id).exists():
                raise AppApiException(500, "文档id不存在")

        def one(self, with_valid=False):
            """Return the single document row via the list_document.sql projection."""
            if with_valid:
                self.is_valid(raise_exception=True)
            query_set = QuerySet(model=Document)
            query_set = query_set.filter(**{'id': self.data.get("document_id")})
            return native_search(query_set, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=True)

        def edit(self, instance: Dict, with_valid=False):
            """Update the editable fields ('name', 'is_active') that are present
            and non-None in *instance*, then return the refreshed row."""
            if with_valid:
                self.is_valid()
            _document = QuerySet(Document).get(id=self.data.get("document_id"))
            update_keys = ['name', 'is_active']
            for update_key in update_keys:
                if update_key in instance and instance.get(update_key) is not None:
                    setattr(_document, update_key, instance.get(update_key))
            _document.save()
            return self.one()

        def refresh(self, with_valid=True):
            """Trigger re-vectorization (embedding) of the document."""
            if with_valid:
                self.is_valid(raise_exception=True)
            document_id = self.data.get("document_id")
            ListenerManagement.embedding_by_document_signal.send(document_id)
            return True

        @transaction.atomic
        def delete(self):
            """Delete the document together with its paragraphs, problems and
            vector-store entries, atomically."""
            document_id = self.data.get("document_id")
            QuerySet(model=Document).filter(id=document_id).delete()
            # Delete paragraphs.
            QuerySet(model=Paragraph).filter(document_id=document_id).delete()
            # Delete problems.
            QuerySet(model=Problem).filter(document_id=document_id).delete()
            # Delete vector-store entries.
            ListenerManagement.delete_embedding_by_document_signal.send(document_id)
            return True

        @staticmethod
        def get_response_body_api():
            """OpenAPI response schema for a single document object."""
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                # Fixed: a missing comma previously fused 'is_active' and
                # 'update_time' into the bogus key 'is_activeupdate_time'.
                required=['id', 'name', 'char_length', 'user_id', 'paragraph_count', 'is_active',
                          'update_time', 'create_time'],
                properties={
                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                         description="id", default="xx"),
                    'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
                                           description="名称", default="测试知识库"),
                    'char_length': openapi.Schema(type=openapi.TYPE_INTEGER, title="字符数",
                                                  description="字符数", default=10),
                    'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id", description="用户id"),
                    'paragraph_count': openapi.Schema(type=openapi.TYPE_INTEGER, title="文档数量",
                                                      description="文档数量", default=1),
                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
                                                description="是否可用", default=True),
                    'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                                  description="修改时间",
                                                  default="1970-01-01 00:00:00"),
                    'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                                  description="创建时间",
                                                  default="1970-01-01 00:00:00"
                                                  )
                }
            )

        @staticmethod
        def get_request_body_api():
            """OpenAPI request-body schema for editing a document."""
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                properties={
                    'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"),
                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
                }
            )

    class Create(ApiMixin, serializers.Serializer):
        """Create a single document (and its paragraphs/problems) in a dataset."""
        dataset_id = serializers.UUIDField(required=True)

        def is_valid(self, *, raise_exception=False):
            """Field validation plus an existence check on the dataset id.

            :raises AppApiException: when the dataset id does not exist.
            """
            super().is_valid(raise_exception=True)
            if not QuerySet(DataSet).filter(id=self.data.get('dataset_id')).exists():
                raise AppApiException(10000, "知识库id不存在")
            return True

        @staticmethod
        def post_embedding(result, document_id):
            """Post-save hook: kick off embedding for the new document."""
            ListenerManagement.embedding_by_document_signal.send(document_id)
            return result

        @post(post_function=post_embedding)
        @transaction.atomic
        def save(self, instance: Dict, with_valid=False, **kwargs):
            """Persist a document with its paragraphs and problems atomically.

            Returns a (document_row, document_id) tuple; the @post decorator
            appears to unpack it into post_embedding's arguments.
            """
            if with_valid:
                DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
                self.is_valid(raise_exception=True)
            dataset_id = self.data.get('dataset_id')
            document_paragraph_model = self.get_document_paragraph_model(dataset_id, instance)
            document_model = document_paragraph_model.get('document')
            paragraph_model_list = document_paragraph_model.get('paragraph_model_list')
            problem_model_list = document_paragraph_model.get('problem_model_list')
            # Insert the document.
            document_model.save()
            # Bulk-insert paragraphs.
            if paragraph_model_list:
                QuerySet(Paragraph).bulk_create(paragraph_model_list)
            # Bulk-insert problems.
            if problem_model_list:
                QuerySet(Problem).bulk_create(problem_model_list)
            document_id = str(document_model.id)
            return DocumentSerializers.Operate(
                data={'dataset_id': dataset_id, 'document_id': document_id}).one(
                with_valid=True), document_id

        @staticmethod
        def get_document_paragraph_model(dataset_id, instance: Dict):
            """Build (without saving) the Document model plus the paragraph and
            problem model lists derived from *instance*."""
            document_model = Document(
                **{'dataset_id': dataset_id,
                   'id': uuid.uuid1(),
                   'name': instance.get('name'),
                   # Total character count across all paragraph contents.
                   'char_length': sum(len(p.get('content')) for p in instance.get('paragraphs', [])),
                   'meta': instance.get('meta') if instance.get('meta') is not None else {},
                   'type': instance.get('type') if instance.get('type') is not None else Type.base})
            paragraph_model_dict_list = [ParagraphSerializers.Create(
                data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).get_paragraph_problem_model(
                dataset_id, document_model.id, paragraph) for paragraph in instance.get('paragraphs', [])]
            paragraph_model_list = []
            problem_model_list = []
            for paragraphs in paragraph_model_dict_list:
                paragraph_model_list.append(paragraphs.get('paragraph'))
                problem_model_list.extend(paragraphs.get('problem_model_list'))
            return {'document': document_model, 'paragraph_model_list': paragraph_model_list,
                    'problem_model_list': problem_model_list}

        @staticmethod
        def get_request_body_api():
            """OpenAPI request-body schema (delegates to DocumentInstanceSerializer)."""
            return DocumentInstanceSerializer.get_request_body_api()

        @staticmethod
        def get_request_params_api():
            """OpenAPI path-parameter description for document creation."""
            return [openapi.Parameter(name='dataset_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='知识库id')
                    ]

    class Split(ApiMixin, serializers.Serializer):
        """Split uploaded files into paragraph candidates."""
        file = serializers.ListField(required=True)
        # Maximum paragraph length.
        limit = serializers.IntegerField(required=False)
        # Regex patterns used as split points.
        patterns = serializers.ListField(required=False,
                                         child=serializers.CharField(required=True))
        # Whether to strip special characters.
        with_filter = serializers.BooleanField(required=False)

        def is_valid(self, *, raise_exception=True):
            """Field validation plus a 10 MB per-file size limit.

            :raises AppApiException: when any uploaded file exceeds 10 MB.
            """
            super().is_valid(raise_exception=True)
            files = self.data.get('file')
            for f in files:
                if f.size > 1024 * 1024 * 10:
                    raise AppApiException(500, "上传文件最大不能超过10m")

        @staticmethod
        def get_request_params_api():
            """OpenAPI form-parameter descriptions for the split endpoint."""
            return [
                openapi.Parameter(name='file',
                                  in_=openapi.IN_FORM,
                                  type=openapi.TYPE_ARRAY,
                                  items=openapi.Items(type=openapi.TYPE_FILE),
                                  required=True,
                                  description='上传文件'),
                openapi.Parameter(name='limit',
                                  in_=openapi.IN_FORM,
                                  required=False,
                                  type=openapi.TYPE_INTEGER, title="分段长度", description="分段长度"),
                openapi.Parameter(name='patterns',
                                  in_=openapi.IN_FORM,
                                  required=False,
                                  type=openapi.TYPE_ARRAY, items=openapi.Items(type=openapi.TYPE_STRING),
                                  title="分段正则列表", description="分段正则列表"),
                openapi.Parameter(name='with_filter',
                                  in_=openapi.IN_FORM,
                                  required=False,
                                  type=openapi.TYPE_BOOLEAN, title="是否清除特殊字符", description="是否清除特殊字符"),
            ]

        def parse(self):
            """Split each uploaded file and return the per-file paragraph lists."""
            file_list = self.data.get("file")
            return [file_to_paragraph(f, self.data.get("patterns", None), self.data.get("with_filter", None),
                                      self.data.get("limit", None))
                    for f in file_list]

    class SplitPattern(ApiMixin, serializers.Serializer):
        """Catalog of the predefined split patterns offered to the front end."""

        @staticmethod
        def list():
            # Fixed: the 分号/逗号 values had lost their fullwidth '；'/'，'
            # characters (stripped as "ambiguous Unicode" during extraction);
            # restored to mirror the intact 句号 pattern.
            return [{'key': "#", 'value': '^# .*'}, {'key': '##', 'value': '(?<!#)## (?!#).*'},
                    {'key': '###', 'value': "(?<!#)### (?!#).*"}, {'key': '####', 'value': "(?<!#)####(?!#).*"},
                    {'key': '#####', 'value': "(?<!#)#####(?!#).*"}, {'key': '######', 'value': "(?<!#)######(?!#).*"},
                    {'key': '-', 'value': '(?<! )- .*'},
                    {'key': '空格', 'value': '(?<!\\s)\\s(?!\\s)'},
                    {'key': '分号', 'value': '(?<!；)；(?!；)'}, {'key': '逗号', 'value': '(?<!，)，(?!，)'},
                    {'key': '句号', 'value': '(?<!。)。(?!。)'}, {'key': '回车', 'value': '(?<!\\n)\\n(?!\\n)'},
                    {'key': '空行', 'value': '(?<!\\n)\\n\\n(?!\\n)'}]

    class Batch(ApiMixin, serializers.Serializer):
        """Batch-import several documents into one dataset."""
        dataset_id = serializers.UUIDField(required=True)

        @staticmethod
        def get_request_body_api():
            """OpenAPI request-body schema: an array of document payloads."""
            return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())

        @staticmethod
        def post_embedding(document_list):
            """Post-save hook: kick off embedding for every imported document."""
            for document_dict in document_list:
                ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'))
            return document_list

        @post(post_function=post_embedding)
        @transaction.atomic
        def batch_save(self, instance_list: List[Dict], with_valid=True):
            """Persist a batch of documents (with paragraphs and problems)
            atomically and return the freshly queried document rows."""
            if with_valid:
                self.is_valid(raise_exception=True)
                DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True)
            dataset_id = self.data.get("dataset_id")
            document_model_list = []
            paragraph_model_list = []
            problem_model_list = []
            # Build the model objects for every document in the batch.
            for document in instance_list:
                document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
                                                                                                        document)
                document_model_list.append(document_paragraph_dict_model.get('document'))
                paragraph_model_list.extend(document_paragraph_dict_model.get('paragraph_model_list'))
                problem_model_list.extend(document_paragraph_dict_model.get('problem_model_list'))
            # Bulk-insert documents.
            if document_model_list:
                QuerySet(Document).bulk_create(document_model_list)
            # Bulk-insert paragraphs.
            if paragraph_model_list:
                QuerySet(Paragraph).bulk_create(paragraph_model_list)
            # Bulk-insert problems.
            if problem_model_list:
                QuerySet(Problem).bulk_create(problem_model_list)
            # Re-query the inserted documents.
            query_set = QuerySet(model=Document)
            query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
            # NOTE(review): the trailing comma (returning a 1-tuple) looks
            # deliberate — the @post decorator appears to unpack the returned
            # tuple into post_embedding's arguments (cf. Create.save). Confirm
            # against common.util.common.post before "fixing".
            return native_search(query_set, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),
def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):
    """Split one uploaded file into paragraphs.

    :param file: uploaded file object exposing ``name`` and ``read()`` (bytes)
    :param pattern_list: regex split patterns; when empty/None a split model is
        chosen from the file name instead
    :param with_filter: whether to strip special characters
    :param limit: maximum paragraph length
    :return: ``{'name': <file name>, 'content': <paragraph list>}``; ``content``
        is empty when the file is not valid UTF-8 text
    """
    data = file.read()
    if pattern_list is not None and len(pattern_list) > 0:
        split_model = SplitModel(pattern_list, with_filter, limit)
    else:
        split_model = get_split_model(file.name, with_filter=with_filter, limit=limit)
    try:
        content = data.decode('utf-8')
    except UnicodeDecodeError:
        # Fixed: previously caught BaseException, which also swallowed
        # KeyboardInterrupt/SystemExit. Only a decode failure (binary or
        # non-UTF-8 upload) should fall back to an empty paragraph list.
        return {'name': file.name,
                'content': []}
    return {'name': file.name,
            'content': split_model.parse(content)
            }