MaxKB/apps/dataset/serializers/problem_serializers.py
2025-01-13 18:11:20 +08:00

240 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
@project: maxkb
@Author:
@file: problem_serializers.py
@date: 2023/10/23 13:55
@desc:
"""
import os
import uuid
from functools import reduce
from typing import Dict, List
from django.db import transaction
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.db.search import native_search, native_page_search
from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id
from embedding.models import SourceType
from embedding.task import delete_embedding_by_source_ids, update_problem_embedding, embedding_by_data_list
from smartdoc.conf import PROJECT_DIR
from django.utils.translation import gettext_lazy as _
class ProblemSerializer(serializers.ModelSerializer):
    """Model serializer exposing the basic columns of a ``Problem`` record."""

    class Meta:
        # Backed directly by the Problem model; only the listed fields are serialized.
        model = Problem
        fields = ['id', 'content', 'dataset_id', 'create_time', 'update_time']
class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
    """Payload for a single problem: an optional ``id`` and a required ``content``."""

    # Optional problem identifier.
    id = serializers.CharField(required=False, error_messages=ErrMessage.char(_('problem id')))
    # Problem text; mandatory, at most 256 characters.
    content = serializers.CharField(required=True, max_length=256, error_messages=ErrMessage.char(_('content')))

    @staticmethod
    def get_request_body_api():
        """Return the OpenAPI object schema describing the request body."""
        id_schema = openapi.Schema(
            type=openapi.TYPE_STRING,
            title=_('Issue ID is passed when modifying, not when creating.'))
        content_schema = openapi.Schema(type=openapi.TYPE_STRING, title=_('content'))
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=["content"],
            properties={'id': id_schema, 'content': content_schema})
class AssociationParagraph(serializers.Serializer):
    """Identifies one paragraph (and the document it is given with) to associate with problems."""

    # Both identifiers are mandatory and must be valid UUIDs.
    paragraph_id = serializers.UUIDField(
        required=True,
        error_messages=ErrMessage.uuid(_('paragraph id')))
    document_id = serializers.UUIDField(
        required=True,
        error_messages=ErrMessage.uuid(_('document id')))
class BatchAssociation(serializers.Serializer):
    """Request body for associating several problems with several paragraphs at once."""

    # Required list of problem UUIDs.
    problem_id_list = serializers.ListField(
        required=True,
        error_messages=ErrMessage.list(_('problem id list')),
        child=serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('problem id'))))
    # Paragraphs to link, each validated by AssociationParagraph.
    paragraph_list = AssociationParagraph(many=True)
def is_exits(exits_problem_paragraph_mapping_list, new_paragraph_mapping):
    """Report whether an equivalent problem/paragraph mapping already exists.

    Two mappings are considered equivalent when their ``paragraph_id``,
    ``problem_id`` and ``dataset_id`` all match after conversion to ``str``.
    Both sides are stringified, so a candidate carrying ``uuid.UUID`` values
    matches an existing row just as a candidate carrying plain strings does.

    :param exits_problem_paragraph_mapping_list: iterable of existing mappings
    :param new_paragraph_mapping: candidate mapping to look for
    :return: ``True`` if at least one existing mapping is equivalent
    """
    wanted = (str(new_paragraph_mapping.paragraph_id),
              str(new_paragraph_mapping.problem_id),
              str(new_paragraph_mapping.dataset_id))
    # any() stops at the first hit instead of materialising every match.
    return any(
        (str(existing.paragraph_id), str(existing.problem_id), str(existing.dataset_id)) == wanted
        for existing in exits_problem_paragraph_mapping_list)
def to_problem_paragraph_mapping(problem, document_id: str, paragraph_id: str, dataset_id: str):
    """Build an unsaved mapping that links ``problem`` to a paragraph.

    :param problem: problem object; its ``id`` is stored as a string
    :param document_id: id of the document stored on the mapping
    :param paragraph_id: id of the paragraph stored on the mapping
    :param dataset_id: id of the dataset stored on the mapping
    :return: tuple ``(mapping, problem)`` so callers keep the problem alongside its mapping
    """
    mapping = ProblemParagraphMapping(
        id=uuid.uuid1(),
        document_id=document_id,
        paragraph_id=paragraph_id,
        dataset_id=dataset_id,
        problem_id=str(problem.id))
    return mapping, problem
class ProblemSerializers(ApiMixin, serializers.Serializer):
    """Namespace grouping the problem operations of a dataset.

    Each nested serializer validates the identifiers it needs and exposes the
    operations (create, query, batch delete/associate, single-item operations)
    as methods. Every method takes ``with_valid``; when it is true the
    serializer input is validated first and a validation error is raised on
    bad input.
    """

    class Create(serializers.Serializer):
        """Bulk creation of problems inside one dataset."""

        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('dataset id')))
        problem_list = serializers.ListField(required=True, error_messages=ErrMessage.list(_('problem list')),
                                             child=serializers.CharField(required=True,
                                                                         max_length=256,
                                                                         error_messages=ErrMessage.char(_('problem'))))

        def batch(self, with_valid=True):
            """Create every problem in ``problem_list`` that the dataset does not have yet.

            Duplicate entries in the input are collapsed and contents already
            stored for the dataset are skipped.

            :param with_valid: validate the serializer input before running
            :return: serialized data of the newly created problems, in input order
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            dataset_id = self.data.get('dataset_id')
            # dict.fromkeys removes duplicates while keeping the caller's order
            # (a plain set() would return the created problems in arbitrary order).
            problem_list = list(dict.fromkeys(self.data.get('problem_list')))
            exists_problem_content_set = {
                problem.content for problem in
                QuerySet(Problem).filter(dataset_id=dataset_id, content__in=problem_list)}
            problem_instance_list = [
                Problem(id=uuid.uuid1(), dataset_id=dataset_id, content=problem_content)
                for problem_content in problem_list
                if problem_content not in exists_problem_content_set]
            if problem_instance_list:
                QuerySet(Problem).bulk_create(problem_instance_list)
            return [ProblemSerializer(problem_instance).data for problem_instance in problem_instance_list]

    class Query(serializers.Serializer):
        """Listing of a dataset's problems, optionally filtered by content."""

        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('dataset id')))
        content = serializers.CharField(required=False, error_messages=ErrMessage.char(_('content')))

        def get_query_set(self):
            """Return the Problem queryset for the dataset, newest first.

            When ``content`` is present in the serializer data, only problems
            whose content contains it (case-insensitively) are kept.
            """
            query_set = QuerySet(model=Problem).filter(dataset_id=self.data.get('dataset_id'))
            if 'content' in self.data:
                query_set = query_set.filter(content__icontains=self.data.get('content'))
            return query_set.order_by("-create_time")

        def list(self):
            """Run ``list_problem.sql`` over the filtered queryset and return all rows."""
            query_set = self.get_query_set()
            return native_search(query_set, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem.sql')))

        def page(self, current_page, page_size):
            """Run ``list_problem.sql`` over the filtered queryset and return one page.

            :param current_page: page number forwarded to ``native_page_search``
            :param page_size: page size forwarded to ``native_page_search``
            """
            query_set = self.get_query_set()
            return native_page_search(current_page, page_size, query_set, select_string=get_file_content(
                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_problem.sql')))

    class BatchOperate(serializers.Serializer):
        """Operations applied to several problems of one dataset at once."""

        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('dataset id')))

        @transaction.atomic
        def delete(self, problem_id_list: List, with_valid=True):
            """Delete the given problems, their paragraph mappings and the mapping embeddings.

            Only problems that belong to this serializer's dataset are removed.
            The database work runs in one transaction, matching the single-item
            ``Operate.delete``.

            :param problem_id_list: ids of the problems to delete
            :param with_valid: validate the serializer input before running
            :return: ``True``
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            dataset_id = self.data.get('dataset_id')
            problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(
                dataset_id=dataset_id,
                problem_id__in=problem_id_list)
            # Collect the mapping ids before the rows disappear; the embeddings use them as source ids.
            source_ids = [row.id for row in problem_paragraph_mapping_list]
            problem_paragraph_mapping_list.delete()
            # Scope by dataset so ids belonging to another dataset are never deleted through this one.
            QuerySet(Problem).filter(id__in=problem_id_list, dataset_id=dataset_id).delete()
            delete_embedding_by_source_ids(source_ids)
            return True

        def association(self, instance: Dict, with_valid=True):
            """Link every listed problem to every listed paragraph and embed the new links.

            Mappings that already exist are left alone. Problems that do not
            belong to this serializer's dataset are ignored.

            :param instance: dict with ``problem_id_list`` and ``paragraph_list``
                (items carrying ``paragraph_id`` and ``document_id``)
            :param with_valid: validate the serializer input and ``instance`` before running
            """
            if with_valid:
                self.is_valid(raise_exception=True)
                BatchAssociation(data=instance).is_valid(raise_exception=True)
            dataset_id = self.data.get('dataset_id')
            paragraph_list = instance.get('paragraph_list')
            problem_id_list = instance.get('problem_id_list')
            problem_list = QuerySet(Problem).filter(id__in=problem_id_list, dataset_id=dataset_id)
            exits_problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
                problem_id__in=problem_id_list,
                paragraph_id__in=[paragraph.get('paragraph_id') for paragraph in paragraph_list])
            # Cartesian product problem x paragraph, flattened directly instead of via reduce().
            candidate_list = [
                to_problem_paragraph_mapping(problem,
                                             paragraph.get('document_id'),
                                             paragraph.get('paragraph_id'),
                                             dataset_id)
                for problem in problem_list
                for paragraph in paragraph_list]
            problem_paragraph_mapping_list = [
                (problem_paragraph_mapping, problem)
                for problem_paragraph_mapping, problem in candidate_list
                if not is_exits(exits_problem_paragraph_mapping, problem_paragraph_mapping)]
            QuerySet(ProblemParagraphMapping).bulk_create(
                [problem_paragraph_mapping for problem_paragraph_mapping, problem in problem_paragraph_mapping_list])
            data_list = [{'text': problem.content,
                          'is_active': True,
                          'source_type': SourceType.PROBLEM,
                          'source_id': str(problem_paragraph_mapping.id),
                          'document_id': str(problem_paragraph_mapping.document_id),
                          'paragraph_id': str(problem_paragraph_mapping.paragraph_id),
                          'dataset_id': dataset_id,
                          } for problem_paragraph_mapping, problem in problem_paragraph_mapping_list]
            model_id = get_embedding_model_id_by_dataset_id(dataset_id)
            embedding_by_data_list(data_list, model_id=model_id)

    class Operate(serializers.Serializer):
        """Operations on a single problem of a dataset."""

        dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('dataset id')))
        problem_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('problem id')))

        def list_paragraph(self, with_valid=True):
            """Return the paragraphs mapped to the problem.

            :param with_valid: validate the serializer input before running
            :return: rows produced by ``list_paragraph.sql``, or ``[]`` when the
                problem has no mapping
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
                dataset_id=self.data.get("dataset_id"),
                problem_id=self.data.get("problem_id"))
            paragraph_id_list = [row.paragraph_id for row in problem_paragraph_mapping]
            if not paragraph_id_list:
                return []
            return native_search(
                QuerySet(Paragraph).filter(id__in=paragraph_id_list),
                select_string=get_file_content(
                    os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql')))

        def one(self, with_valid=True):
            """Return the serialized problem.

            :param with_valid: validate the serializer input before running
            :raises Problem.DoesNotExist: no such problem in this dataset
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            # Look the problem up within its dataset, as edit() does.
            problem = QuerySet(Problem).get(id=self.data.get('problem_id'),
                                            dataset_id=self.data.get('dataset_id'))
            return ProblemInstanceSerializer(problem).data

        @transaction.atomic
        def delete(self, with_valid=True):
            """Delete the problem, its paragraph mappings and the mapping embeddings.

            :param with_valid: validate the serializer input before running
            :return: ``True``
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            dataset_id = self.data.get('dataset_id')
            problem_id = self.data.get('problem_id')
            problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(
                dataset_id=dataset_id,
                problem_id=problem_id)
            # Collect the mapping ids before the rows disappear; the embeddings use them as source ids.
            source_ids = [row.id for row in problem_paragraph_mapping_list]
            problem_paragraph_mapping_list.delete()
            # Scope by dataset so a problem of another dataset is never deleted through this one.
            QuerySet(Problem).filter(id=problem_id, dataset_id=dataset_id).delete()
            delete_embedding_by_source_ids(source_ids)
            return True

        @transaction.atomic
        def edit(self, instance: Dict, with_valid=True):
            """Replace the problem's content and refresh its embedding.

            :param instance: dict whose ``content`` value is the new problem text
            :param with_valid: validate the serializer input before running
            :raises Problem.DoesNotExist: no such problem in this dataset
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            problem_id = self.data.get('problem_id')
            dataset_id = self.data.get('dataset_id')
            content = instance.get('content')
            # get() fails with DoesNotExist for an unknown problem, rather than
            # handing back None and crashing on the attribute assignment below.
            problem = QuerySet(Problem).get(id=problem_id, dataset_id=dataset_id)
            problem.content = content
            problem.save()
            model_id = get_embedding_model_id_by_dataset_id(dataset_id)
            update_problem_embedding(problem_id, content, model_id)