MaxKB/apps/dataset/serializers/problem_serializers.py
2024-01-16 16:46:54 +08:00

232 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
@project: maxkb
@Author
@file problem_serializers.py
@date2023/10/23 13:55
@desc:
"""
import uuid
from typing import Dict
from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from dataset.models import Problem, Paragraph
from embedding.models import SourceType
from embedding.vector.pg_vector import PGVector
class ProblemSerializer(serializers.ModelSerializer):
class Meta:
model = Problem
fields = ['id', 'content', 'dataset_id', 'document_id',
'create_time', 'update_time']
class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
id = serializers.CharField(required=False)
content = serializers.CharField(required=True)
@staticmethod
def get_request_body_api():
return openapi.Schema(type=openapi.TYPE_OBJECT,
required=["content"],
properties={
'id': openapi.Schema(
type=openapi.TYPE_STRING,
title="问题id,修改的时候传递,创建的时候不传"),
'content': openapi.Schema(
type=openapi.TYPE_STRING, title="内容")
})
class ProblemSerializers(ApiMixin, serializers.Serializer):
class Create(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
document_id = serializers.UUIDField(required=True)
paragraph_id = serializers.UUIDField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id'),
document_id=self.data.get('document_id'),
dataset_id=self.data.get('dataset_id')).exists():
raise AppApiException(500, "段落id不正确")
def save(self, instance: Dict, with_valid=True, with_embedding=True):
if with_valid:
self.is_valid()
ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
problem = Problem(id=uuid.uuid1(), paragraph_id=self.data.get('paragraph_id'),
document_id=self.data.get('document_id'), dataset_id=self.data.get('dataset_id'),
content=instance.get('content'))
problem.save()
if with_embedding:
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
'is_active': True,
'source_type': SourceType.PROBLEM,
'source_id': problem.id,
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
})
return ProblemSerializers.Operate(
data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'), 'problem_id': problem.id}).one(with_valid=True)
@staticmethod
def get_request_body_api():
return ProblemInstanceSerializer.get_request_body_api()
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='知识库id'),
openapi.Parameter(name='document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='文档id'),
openapi.Parameter(name='paragraph_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='段落id')]
class Query(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
document_id = serializers.UUIDField(required=True)
paragraph_id = serializers.UUIDField(required=True)
def is_valid(self, *, raise_exception=True):
super().is_valid(raise_exception=True)
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
raise AppApiException(500, "段落id不存在")
def get_query_set(self):
dataset_id = self.data.get('dataset_id')
document_id = self.data.get('document_id')
paragraph_id = self.data.get("paragraph_id")
return QuerySet(Problem).filter(
**{'paragraph_id': paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id})
def list(self, with_valid=False):
"""
获取问题列表
:param with_valid: 是否校验
:return: 问题列表
"""
if with_valid:
self.is_valid(raise_exception=True)
query_set = self.get_query_set()
return [ProblemSerializer(p).data for p in query_set]
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='知识库id'),
openapi.Parameter(name='document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='文档id')
, openapi.Parameter(name='paragraph_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='段落id')]
class Operate(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True)
document_id = serializers.UUIDField(required=True)
paragraph_id = serializers.UUIDField(required=True)
problem_id = serializers.UUIDField(required=True)
def delete(self, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
QuerySet(Problem).filter(**{'id': self.data.get('problem_id')}).delete()
PGVector().delete_by_source_id(self.data.get('problem_id'), SourceType.PROBLEM)
ListenerManagement.delete_embedding_by_source_signal.send(self.data.get('problem_id'))
return True
def one(self, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
return ProblemInstanceSerializer(QuerySet(Problem).get(**{'id': self.data.get('problem_id')})).data
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='知识库id'),
openapi.Parameter(name='document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='文档id')
, openapi.Parameter(name='paragraph_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='段落id'),
openapi.Parameter(name='problem_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='问题id')
]
@staticmethod
def get_response_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id',
'document_id',
'create_time', 'update_time'],
properties={
'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
description="id", default="xx"),
'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容",
description="问题内容", default='问题内容'),
'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
default=1),
'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
description="点赞数量", default=1),
'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
description="点踩数", default=1),
'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
description="文档id", default='xxx'),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
description="修改时间",
default="1970-01-01 00:00:00"),
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
description="创建时间",
default="1970-01-01 00:00:00"
)
}
)