MaxKB/apps/knowledge/serializers/problem.py
2025-05-08 10:00:58 +08:00

245 lines
12 KiB
Python

import os
from functools import reduce
from typing import Dict, List
import uuid_utils.compat as uuid
from django.db import transaction
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.db.search import native_search, native_page_search
from common.utils.common import get_file_content
from knowledge.models import Problem, ProblemParagraphMapping, Paragraph, Knowledge, SourceType
from knowledge.serializers.common import get_embedding_model_id_by_knowledge_id
from knowledge.task.embedding import delete_embedding_by_source_ids, update_problem_embedding, embedding_by_data_list
from maxkb.const import PROJECT_DIR
class ProblemSerializer(serializers.ModelSerializer):
    """Model serializer for ``Problem`` rows, exposing ids, content and timestamps."""

    class Meta:
        model = Problem
        fields = [
            'id',
            'content',
            'knowledge_id',
            'create_time',
            'update_time',
        ]
class ProblemInstanceSerializer(serializers.Serializer):
    """Minimal ``id`` + ``content`` representation of a single problem."""

    # Optional: the id may be absent on input.
    id = serializers.CharField(label=_('problem id'), required=False)
    # Problem text, limited to 256 characters.
    content = serializers.CharField(label=_('content'), max_length=256, required=True)
class ProblemEditSerializer(serializers.Serializer):
    """Validates a required ``content`` string of at most 256 characters.

    Presumably the payload for editing a problem's text — confirm against callers.
    """

    content = serializers.CharField(label=_('content'), max_length=256, required=True)
class ProblemMappingSerializer(serializers.Serializer):
    """Validates a (paragraph, document) id pair; both ids are required UUIDs.

    NOTE(review): not referenced in this module — presumably used by views elsewhere.
    """

    paragraph_id = serializers.UUIDField(label=_('paragraph id'), required=True)
    document_id = serializers.UUIDField(label=_('document id'), required=True)
class ProblemBatchSerializer(serializers.Serializer):
    """Validates a list of problem texts, each at most 256 characters long."""

    problem_list = serializers.ListField(
        label=_('problem list'),
        required=True,
        child=serializers.CharField(label=_('problem'), max_length=256, required=True),
    )
class ProblemBatchDeleteSerializer(serializers.Serializer):
    """Validates a list of problem ids, each of which must be a UUID.

    Presumably the payload for batch deletion of problems — confirm against callers.
    """

    problem_id_list = serializers.ListField(
        label=_('problem id list'),
        required=True,
        child=serializers.UUIDField(label=_('problem id'), required=True),
    )
class AssociationParagraph(serializers.Serializer):
    """One paragraph target of a batch association: paragraph id plus its document id."""

    paragraph_id = serializers.UUIDField(label=_('paragraph id'), required=True)
    document_id = serializers.UUIDField(label=_('document id'), required=True)
class BatchAssociation(serializers.Serializer):
    """Payload for associating several problems with several paragraphs at once."""

    # Ids of the problems to associate.
    problem_id_list = serializers.ListField(
        label=_('problem id list'),
        required=True,
        child=serializers.UUIDField(label=_('problem id'), required=True),
    )
    # Paragraphs (with their document ids) each problem is associated with.
    paragraph_list = AssociationParagraph(many=True)
class ProblemSerializers(serializers.Serializer):
    """Namespace for the serializers that implement problem operations of a knowledge base.

    Each nested serializer receives ``knowledge_id``; all ``Problem`` reads and writes are
    scoped to that knowledge base so ids from another knowledge base are never touched.
    """

    class BatchOperate(serializers.Serializer):
        """Bulk operations (delete, paragraph association) on problems of one knowledge base."""

        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))

        @transaction.atomic
        def delete(self, problem_id_list: List, with_valid=True):
            """Delete the given problems together with their paragraph mappings and embeddings.

            :param problem_id_list: ids of the problems to delete.
            :param with_valid: validate ``knowledge_id`` and ``problem_id_list`` first.
            :return: ``True``.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
                ProblemBatchDeleteSerializer(
                    data={'problem_id_list': problem_id_list}
                ).is_valid(raise_exception=True)
            knowledge_id = self.data.get('knowledge_id')
            mapping_query_set = QuerySet(ProblemParagraphMapping).filter(
                knowledge_id=knowledge_id,
                problem_id__in=problem_id_list)
            # Mapping ids are the embedding source ids; capture them before the rows are gone.
            source_ids = [row.id for row in mapping_query_set]
            mapping_query_set.delete()
            # Scoped by knowledge_id so ids of another knowledge base cannot be deleted here.
            QuerySet(Problem).filter(knowledge_id=knowledge_id, id__in=problem_id_list).delete()
            delete_embedding_by_source_ids(source_ids)
            return True

        def association(self, instance: Dict, with_valid=True):
            """Associate every given problem with every given paragraph and embed the new pairs.

            Pairs that are already mapped, or repeated within the request, are skipped.

            :param instance: dict with ``problem_id_list`` and ``paragraph_list`` (items carry
                ``paragraph_id`` and ``document_id``), as validated by ``BatchAssociation``.
            :param with_valid: validate ``knowledge_id`` and ``instance`` first.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
                BatchAssociation(data=instance).is_valid(raise_exception=True)
            knowledge_id = self.data.get('knowledge_id')
            paragraph_list = instance.get('paragraph_list')
            problem_id_list = instance.get('problem_id_list')
            # Only problems owned by this knowledge base may be associated.
            problem_list = QuerySet(Problem).filter(knowledge_id=knowledge_id, id__in=problem_id_list)
            # Already-mapped (problem_id, paragraph_id) pairs. Ids are compared as strings because
            # the database yields UUID objects while the request may carry strings.
            seen_pairs = {
                (str(problem_id), str(paragraph_id))
                for problem_id, paragraph_id in QuerySet(ProblemParagraphMapping).filter(
                    knowledge_id=knowledge_id,
                    problem_id__in=problem_id_list,
                    paragraph_id__in=[paragraph.get('paragraph_id') for paragraph in paragraph_list],
                ).values_list('problem_id', 'paragraph_id')
            }
            problem_paragraph_mapping_list = []
            for problem in problem_list:
                for paragraph in paragraph_list:
                    pair = (str(problem.id), str(paragraph.get('paragraph_id')))
                    # Skip pairs that exist already or were produced earlier in this request.
                    if pair in seen_pairs:
                        continue
                    seen_pairs.add(pair)
                    problem_paragraph_mapping_list.append(
                        to_problem_paragraph_mapping(
                            problem,
                            paragraph.get('document_id'),
                            paragraph.get('paragraph_id'),
                            knowledge_id,
                        )
                    )
            QuerySet(ProblemParagraphMapping).bulk_create(
                [mapping for mapping, problem in problem_paragraph_mapping_list]
            )
            data_list = [
                {
                    'text': problem.content,
                    'is_active': True,
                    'source_type': SourceType.PROBLEM,
                    'source_id': str(mapping.id),
                    'document_id': str(mapping.document_id),
                    'paragraph_id': str(mapping.paragraph_id),
                    'knowledge_id': knowledge_id,
                } for mapping, problem in problem_paragraph_mapping_list
            ]
            model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
            embedding_by_data_list(data_list, model_id=model_id)

    class Operate(serializers.Serializer):
        """Operations on a single problem of a knowledge base."""

        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
        problem_id = serializers.UUIDField(required=True, label=_('problem id'))

        def list_paragraph(self, with_valid=True):
            """Return the paragraphs mapped to this problem (``[]`` when there are none)."""
            if with_valid:
                self.is_valid(raise_exception=True)
            paragraph_id_list = list(
                QuerySet(ProblemParagraphMapping).filter(
                    knowledge_id=self.data.get("knowledge_id"),
                    problem_id=self.data.get("problem_id"),
                ).values_list('paragraph_id', flat=True)
            )
            if not paragraph_id_list:
                return []
            return native_search(
                QuerySet(Paragraph).filter(id__in=paragraph_id_list),
                select_string=get_file_content(
                    os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_paragraph.sql')))

        def one(self, with_valid=True):
            """Return ``{'id', 'content'}`` of the problem.

            :raises Problem.DoesNotExist: if the problem is not part of this knowledge base.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            problem = QuerySet(Problem).get(
                id=self.data.get('problem_id'),
                knowledge_id=self.data.get('knowledge_id'))
            return ProblemInstanceSerializer(problem).data

        @transaction.atomic
        def delete(self, with_valid=True):
            """Delete the problem, its paragraph mappings and their embeddings; return ``True``."""
            if with_valid:
                self.is_valid(raise_exception=True)
            knowledge_id = self.data.get('knowledge_id')
            problem_id = self.data.get('problem_id')
            mapping_query_set = QuerySet(ProblemParagraphMapping).filter(
                knowledge_id=knowledge_id,
                problem_id=problem_id)
            # Mapping ids are the embedding source ids; capture them before the rows are gone.
            source_ids = [row.id for row in mapping_query_set]
            mapping_query_set.delete()
            # Scoped by knowledge_id so a problem of another knowledge base cannot be deleted.
            QuerySet(Problem).filter(id=problem_id, knowledge_id=knowledge_id).delete()
            delete_embedding_by_source_ids(source_ids)
            return True

        @transaction.atomic
        def edit(self, instance: Dict, with_valid=True):
            """Replace the problem's content and pass it to ``update_problem_embedding``.

            :param instance: dict with the new ``content`` (see ``ProblemEditSerializer``).
            :raises Problem.DoesNotExist: if the problem is not part of this knowledge base.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
                ProblemEditSerializer(data=instance).is_valid(raise_exception=True)
            problem_id = self.data.get('problem_id')
            knowledge_id = self.data.get('knowledge_id')
            content = instance.get('content')
            # .get() fails with a clear DoesNotExist instead of AttributeError on a None row.
            problem = QuerySet(Problem).get(id=problem_id, knowledge_id=knowledge_id)
            problem.content = content
            problem.save()
            model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
            update_problem_embedding(problem_id, content, model_id)

    class Create(serializers.Serializer):
        """Creates problems in a knowledge base."""

        workspace_id = serializers.CharField(required=True, label=_('workspace id'))
        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))

        def batch(self, problem_list, with_valid=True):
            """Create one problem per distinct text that does not exist yet in the knowledge base.

            :param problem_list: problem texts; duplicates are collapsed.
            :param with_valid: validate the serializer fields and ``problem_list`` first.
            :return: serialized data of the newly created problems, in input order.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
                ProblemBatchSerializer(data={'problem_list': problem_list}).is_valid(raise_exception=True)
            knowledge_id = self.data.get('knowledge_id')
            # dict.fromkeys de-duplicates while keeping the caller's order (set() did not).
            unique_content_list = list(dict.fromkeys(problem_list))
            exists_content_set = set(
                QuerySet(Problem).filter(
                    knowledge_id=knowledge_id, content__in=unique_content_list
                ).values_list('content', flat=True)
            )
            problem_instance_list = [
                Problem(id=uuid.uuid7(), knowledge_id=knowledge_id, content=problem_content)
                for problem_content in unique_content_list
                if problem_content not in exists_content_set
            ]
            if problem_instance_list:
                QuerySet(Problem).bulk_create(problem_instance_list)
            return [ProblemSerializer(problem_instance).data for problem_instance in problem_instance_list]

    class Query(serializers.Serializer):
        """Lists the problems of a knowledge base, optionally filtered by content."""

        workspace_id = serializers.CharField(required=True, label=_('workspace id'))
        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
        content = serializers.CharField(required=False, label=_('content'))

        def get_query_set(self):
            """Problems of the knowledge base, newest first, filtered by ``content`` if given."""
            query_set = QuerySet(model=Problem).filter(knowledge_id=self.data.get('knowledge_id'))
            if 'content' in self.data:
                query_set = query_set.filter(content__icontains=self.data.get('content'))
            return query_set.order_by("-create_time")

        @staticmethod
        def _list_problem_sql():
            """Load the SELECT clause shared by ``list`` and ``page``."""
            return get_file_content(
                os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_problem.sql'))

        def list(self, with_valid=True):
            """Return all matching problems.

            DRF refuses ``self.data`` before ``is_valid()``; the call is a no-op if the
            caller already validated.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            return native_search(self.get_query_set(), select_string=self._list_problem_sql())

        def page(self, current_page, page_size, with_valid=True):
            """Return one page of matching problems (validates first, like ``list``)."""
            if with_valid:
                self.is_valid(raise_exception=True)
            return native_page_search(current_page, page_size, self.get_query_set(),
                                      select_string=self._list_problem_sql())
def is_exits(exits_problem_paragraph_mapping_list, new_paragraph_mapping):
    """Return True if *new_paragraph_mapping* duplicates an entry of the existing list.

    Two mappings are equal when their paragraph, problem and knowledge ids match. Both sides
    are compared as strings, so UUID objects (as loaded from the database) and UUID strings
    (as received from a request) compare equal.

    :param exits_problem_paragraph_mapping_list: iterable of existing mappings.
    :param new_paragraph_mapping: candidate mapping with the same id attributes.
    """
    new_key = (
        str(new_paragraph_mapping.paragraph_id),
        str(new_paragraph_mapping.problem_id),
        str(new_paragraph_mapping.knowledge_id),
    )
    # any() stops at the first match instead of building a full filtered list.
    return any(
        (str(mapping.paragraph_id), str(mapping.problem_id), str(mapping.knowledge_id)) == new_key
        for mapping in exits_problem_paragraph_mapping_list
    )
def to_problem_paragraph_mapping(problem, document_id: str, paragraph_id: str, knowledge_id: str):
    """Build an unsaved ``ProblemParagraphMapping`` linking *problem* to a paragraph.

    The mapping gets a fresh UUIDv7 id. A ``(mapping, problem)`` tuple is returned so the
    caller keeps the originating problem next to the new mapping row.
    """
    new_id = uuid.uuid7()
    mapping = ProblemParagraphMapping(
        id=new_id,
        knowledge_id=knowledge_id,
        document_id=document_id,
        paragraph_id=paragraph_id,
        problem_id=str(problem.id),
    )
    return mapping, problem