diff --git a/apps/knowledge/api/problem.py b/apps/knowledge/api/problem.py
index e69de29bb..428e278c8 100644
--- a/apps/knowledge/api/problem.py
+++ b/apps/knowledge/api/problem.py
@@ -0,0 +1,50 @@
+from drf_spectacular.types import OpenApiTypes
+from drf_spectacular.utils import OpenApiParameter
+
+from common.mixins.api_mixin import APIMixin
+from common.result import DefaultResultSerializer
+from knowledge.serializers.problem import ProblemBatchSerializer, \
+    ProblemBatchDeleteSerializer, BatchAssociation
+
+
+class ProblemReadAPI(APIMixin):
+    @staticmethod
+    def get_parameters():
+        return [
+            OpenApiParameter(
+                name="workspace_id",
+                description="workspace id",
+                type=OpenApiTypes.STR,
+                location='path',
+                required=True,
+            ),
+            OpenApiParameter(
+                name="knowledge_id",
+                description="knowledge base id",
+                type=OpenApiTypes.STR,
+                location='path',
+                required=True,
+            ),
+        ]
+
+    @staticmethod
+    def get_response():
+        return DefaultResultSerializer
+
+
+class ProblemBatchCreateAPI(ProblemReadAPI):
+    @staticmethod
+    def get_request():
+        return ProblemBatchSerializer
+
+
+class BatchAssociationAPI(ProblemReadAPI):
+    @staticmethod
+    def get_request():
+        return BatchAssociation
+
+
+class BatchDeleteAPI(ProblemReadAPI):
+    @staticmethod
+    def get_request():
+        return ProblemBatchDeleteSerializer
diff --git a/apps/knowledge/serializers/problem.py b/apps/knowledge/serializers/problem.py
index 91855d4dc..96dc64152 100644
--- a/apps/knowledge/serializers/problem.py
+++ b/apps/knowledge/serializers/problem.py
@@ -1,16 +1,18 @@
 import os
-from typing import Dict
+from functools import reduce
+from typing import Dict, List
 
+import uuid_utils.compat as uuid
 from django.db import transaction
 from django.db.models import QuerySet
 from django.utils.translation import gettext_lazy as _
 from rest_framework import serializers
 
-from common.db.search import native_search
+from common.db.search import native_search, native_page_search
 from common.utils.common import get_file_content
-from knowledge.models import Problem, ProblemParagraphMapping, Paragraph, Knowledge
+from knowledge.models import Problem, ProblemParagraphMapping, Paragraph, Knowledge, SourceType
 from knowledge.serializers.common import get_embedding_model_id_by_knowledge_id
-from knowledge.task.embedding import delete_embedding_by_source_ids, update_problem_embedding
+from knowledge.task.embedding import delete_embedding_by_source_ids, update_problem_embedding, embedding_by_data_list
 from maxkb.const import PROJECT_DIR
 
 
@@ -25,7 +27,114 @@ class ProblemInstanceSerializer(serializers.Serializer):
     content = serializers.CharField(required=True, max_length=256, label=_('content'))
 
 
+class ProblemMappingSerializer(serializers.Serializer):
+    paragraph_id = serializers.UUIDField(required=True, label=_('paragraph id'))
+    document_id = serializers.UUIDField(required=True, label=_('document id'))
+
+
+class ProblemBatchSerializer(serializers.Serializer):
+    problem_list = serializers.ListField(required=True, label=_('problem list'),
+                                         child=serializers.CharField(required=True, max_length=256,
+                                                                     label=_('problem')))
+
+
+class ProblemBatchDeleteSerializer(serializers.Serializer):
+    problem_id_list = serializers.ListField(required=True, label=_('problem id list'),
+                                            child=serializers.UUIDField(required=True, label=_('problem id')))
+
+
+class AssociationParagraph(serializers.Serializer):
+    paragraph_id = serializers.UUIDField(required=True, label=_('paragraph id'))
+    document_id = serializers.UUIDField(required=True, label=_('document id'))
+
+
+class BatchAssociation(serializers.Serializer):
+    problem_id_list = serializers.ListField(required=True, label=_('problem id list'),
+                                            child=serializers.UUIDField(required=True, label=_('problem id')))
+    paragraph_list = AssociationParagraph(many=True)
+
+
+def is_exists(exists_problem_paragraph_mapping_list, new_paragraph_mapping):
+    filter_list = [exists_problem_paragraph_mapping for exists_problem_paragraph_mapping in
+                   exists_problem_paragraph_mapping_list if
+                   str(exists_problem_paragraph_mapping.paragraph_id) == new_paragraph_mapping.paragraph_id
+                   and str(exists_problem_paragraph_mapping.problem_id) == new_paragraph_mapping.problem_id
+                   and str(exists_problem_paragraph_mapping.knowledge_id) == new_paragraph_mapping.knowledge_id]
+    return len(filter_list) > 0
+
+
+def to_problem_paragraph_mapping(problem, document_id: str, paragraph_id: str, knowledge_id: str):
+    return ProblemParagraphMapping(id=uuid.uuid1(),
+                                   document_id=document_id,
+                                   paragraph_id=paragraph_id,
+                                   knowledge_id=knowledge_id,
+                                   problem_id=str(problem.id)), problem
+
+
 class ProblemSerializers(serializers.Serializer):
+    class BatchOperate(serializers.Serializer):
+        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
+
+        def delete(self, problem_id_list: List, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            knowledge_id = self.data.get('knowledge_id')
+            problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(
+                knowledge_id=knowledge_id,
+                problem_id__in=problem_id_list)
+            source_ids = [row.id for row in problem_paragraph_mapping_list]
+            problem_paragraph_mapping_list.delete()
+            QuerySet(Problem).filter(id__in=problem_id_list).delete()
+            delete_embedding_by_source_ids(source_ids)
+            return True
+
+        def association(self, instance: Dict, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+                BatchAssociation(data=instance).is_valid(raise_exception=True)
+            knowledge_id = self.data.get('knowledge_id')
+            paragraph_list = instance.get('paragraph_list')
+            problem_id_list = instance.get('problem_id_list')
+            problem_list = QuerySet(Problem).filter(id__in=problem_id_list)
+
+            exists_problem_paragraph_mapping = QuerySet(
+                ProblemParagraphMapping
+            ).filter(problem_id__in=problem_id_list, paragraph_id__in=[p.get('paragraph_id') for p in paragraph_list])
+
+            problem_paragraph_mapping_list = [
+                (problem_paragraph_mapping, problem) for problem_paragraph_mapping, problem in
+                reduce(
+                    lambda x, y: [*x, *y],
+                    [
+                        [
+                            to_problem_paragraph_mapping(
+                                problem, paragraph.get('document_id'),
+                                paragraph.get('paragraph_id'),
+                                knowledge_id
+                            ) for paragraph in paragraph_list
+                        ] for problem in problem_list
+                    ],
+                    []
+                ) if not is_exists(exists_problem_paragraph_mapping, problem_paragraph_mapping)
+            ]
+
+            QuerySet(ProblemParagraphMapping).bulk_create([
+                problem_paragraph_mapping for problem_paragraph_mapping, problem in problem_paragraph_mapping_list
+            ])
+
+            data_list = [
+                {
+                    'text': problem.content,
+                    'is_active': True,
+                    'source_type': SourceType.PROBLEM,
+                    'source_id': str(problem_paragraph_mapping.id),
+                    'document_id': str(problem_paragraph_mapping.document_id),
+                    'paragraph_id': str(problem_paragraph_mapping.paragraph_id),
+                    'knowledge_id': knowledge_id,
+                } for problem_paragraph_mapping, problem in problem_paragraph_mapping_list
+            ]
+            model_id = get_embedding_model_id_by_knowledge_id(self.data.get('knowledge_id'))
+            embedding_by_data_list(data_list, model_id=model_id)
+
     class Operate(serializers.Serializer):
         knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
         problem_id = serializers.UUIDField(required=True, label=_('problem id'))
@@ -75,3 +184,55 @@ class ProblemSerializers(serializers.Serializer):
             problem.save()
             model_id = get_embedding_model_id_by_knowledge_id(knowledge_id)
             update_problem_embedding(problem_id, content, model_id)
+
+    class Create(serializers.Serializer):
+        workspace_id = serializers.CharField(required=True, label=_('workspace id'))
+        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
+
+        def batch(self, problem_list, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+                ProblemBatchSerializer(data={'problem_list': problem_list}).is_valid(raise_exception=True)
+            problem_list = list(set(problem_list))
+            knowledge_id = self.data.get('knowledge_id')
+            exists_problem_content_list = [
+                problem.content for problem in QuerySet(
+                    Problem
+                ).filter(knowledge_id=knowledge_id, content__in=problem_list)
+            ]
+            problem_instance_list = [
+                Problem(
+                    id=uuid.uuid7(), knowledge_id=knowledge_id, content=problem_content
+                ) for problem_content in problem_list
+                if problem_content not in exists_problem_content_list
+            ]
+
+            if len(problem_instance_list) > 0:
+                QuerySet(Problem).bulk_create(problem_instance_list)
+            return [ProblemSerializer(problem_instance).data for problem_instance in problem_instance_list]
+
+    class Query(serializers.Serializer):
+        workspace_id = serializers.CharField(required=True, label=_('workspace id'))
+        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
+        content = serializers.CharField(required=False, label=_('content'))
+
+        def get_query_set(self):
+            query_set = QuerySet(model=Problem)
+            query_set = query_set.filter(**{'knowledge_id': self.data.get('knowledge_id')})
+            if 'content' in self.data:
+                query_set = query_set.filter(**{'content__icontains': self.data.get('content')})
+            query_set = query_set.order_by("-create_time")
+            return query_set
+
+        def list(self):
+            query_set = self.get_query_set()
+            return native_search(query_set, select_string=get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_problem.sql')))
+
+        def page(self, current_page, page_size):
+            query_set = self.get_query_set()
+            return native_page_search(current_page, page_size, query_set, select_string=get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_problem.sql')))
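
(Not part of the diff.) The association() method added above takes the cross product of the selected problems and paragraphs, drops pairs that already have a ProblemParagraphMapping row, bulk-creates the rest, and queues them for embedding. Below is a minimal standalone sketch of that cross-product-plus-dedup shape; build_new_mappings, its dict rows, and the (problem_id, paragraph_id) key set are illustrative stand-ins, not code from this repository.

from itertools import product


def build_new_mappings(problem_ids, paragraphs, existing_keys):
    # problem_ids: list of problem id strings
    # paragraphs: list of {'paragraph_id': ..., 'document_id': ...} dicts
    # existing_keys: set of (problem_id, paragraph_id) pairs that are already mapped
    new_mappings = []
    for problem_id, paragraph in product(problem_ids, paragraphs):
        key = (problem_id, paragraph['paragraph_id'])
        if key in existing_keys:
            # same filtering role as is_exists() in the diff: skip pairs already associated
            continue
        new_mappings.append({
            'problem_id': problem_id,
            'paragraph_id': paragraph['paragraph_id'],
            'document_id': paragraph['document_id'],
        })
    return new_mappings


if __name__ == '__main__':
    existing = {('p1', 'para1')}
    pairs = build_new_mappings(
        ['p1', 'p2'],
        [{'paragraph_id': 'para1', 'document_id': 'doc1'}],
        existing,
    )
    print(pairs)  # only the ('p2', 'para1') pair is new
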
diff --git a/apps/knowledge/urls.py b/apps/knowledge/urls.py
index 660aeb68e..0f1a5985a 100644
--- a/apps/knowledge/urls.py
+++ b/apps/knowledge/urls.py
@@ -30,6 +30,9 @@ urlpatterns = [
     path('workspace//knowledge//document//paragraph//', views.ParagraphView.Page.as_view()),
     path('workspace//knowledge//document//paragraph//problem//association', views.ParagraphView.Association.as_view()),
     path('workspace//knowledge//document//paragraph//problem//unassociation', views.ParagraphView.UnAssociation.as_view()),
+    path('workspace//knowledge//problem', views.ProblemView.as_view()),
+    path('workspace//knowledge//problem/batch_delete', views.ProblemView.BatchDelete.as_view()),
+    path('workspace//knowledge//problem/batch_association', views.ProblemView.BatchAssociation.as_view()),
     path('workspace//knowledge//document//', views.DocumentView.Page.as_view()),
     path('workspace//knowledge//', views.KnowledgeView.Page.as_view()),
 ]
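
(Not part of the diff.) A rough client-side sketch of exercising the three routes added above, assuming a locally running server, a bearer-token header, and placeholder ids; the URL prefix and header name are assumptions, and the payload shapes follow a reading of the views added below rather than a documented contract.

import requests

# Assumed base URL and auth header -- placeholders, not values from the repository.
BASE = 'http://localhost:8080/api/workspace/<workspace-id>/knowledge/<knowledge-uuid>'
HEADERS = {'Authorization': 'Bearer <token>'}

# Create questions in batch: the view passes request.data straight to Create.batch(),
# so the body is read here as a bare JSON list of question strings.
requests.post(f'{BASE}/problem', json=['How do I reset my password?'], headers=HEADERS)

# List questions, optionally filtered by content.
requests.get(f'{BASE}/problem', params={'content': 'password'}, headers=HEADERS)

# Associate questions with paragraphs in batch.
requests.put(f'{BASE}/problem/batch_association', headers=HEADERS, json={
    'problem_id_list': ['<problem-uuid>'],
    'paragraph_list': [{'paragraph_id': '<paragraph-uuid>', 'document_id': '<document-uuid>'}],
})

# Delete questions in batch: BatchOperate.delete() receives request.data as the id list,
# so the body is read here as a bare JSON list of problem ids.
requests.put(f'{BASE}/problem/batch_delete', json=['<problem-uuid>'], headers=HEADERS)
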
diff --git a/apps/knowledge/views/__init__.py b/apps/knowledge/views/__init__.py
index c1501a122..586b2d335 100644
--- a/apps/knowledge/views/__init__.py
+++ b/apps/knowledge/views/__init__.py
@@ -1,3 +1,4 @@
 from .document import *
 from .knowledge import *
 from .paragraph import *
+from .problem import *
diff --git a/apps/knowledge/views/problem.py b/apps/knowledge/views/problem.py
new file mode 100644
index 000000000..a219fbebe
--- /dev/null
+++ b/apps/knowledge/views/problem.py
@@ -0,0 +1,90 @@
+from django.utils.translation import gettext_lazy as _
+from drf_spectacular.utils import extend_schema
+from rest_framework.views import APIView
+from rest_framework.views import Request
+
+from common.auth import TokenAuth
+from common.auth.authentication import has_permissions
+from common.constants.permission_constants import PermissionConstants
+from common.result import result
+from common.utils.common import query_params_to_single_dict
+from knowledge.api.problem import ProblemReadAPI, ProblemBatchCreateAPI, BatchAssociationAPI, BatchDeleteAPI
+from knowledge.serializers.problem import ProblemSerializers
+
+
+class ProblemView(APIView):
+    authentication_classes = [TokenAuth]
+
+    @extend_schema(
+        methods=['GET'],
+        summary=_('Question list'),
+        description=_('Question list'),
+        operation_id=_('Question list'),
+        parameters=ProblemReadAPI.get_parameters(),
+        responses=ProblemReadAPI.get_response(),
+        tags=[_('Knowledge Base/Documentation/Paragraph/Question')]
+    )
+    @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
+    def get(self, request: Request, workspace_id: str, knowledge_id: str):
+        q = ProblemSerializers.Query(
+            data={
+                **query_params_to_single_dict(request.query_params),
+                'workspace_id': workspace_id,
+                'knowledge_id': knowledge_id
+            }
+        )
+        q.is_valid(raise_exception=True)
+        return result.success(q.list())
+
+    @extend_schema(
+        methods=['POST'],
+        summary=_('Create question'),
+        description=_('Create question'),
+        operation_id=_('Create question'),
+        parameters=ProblemBatchCreateAPI.get_parameters(),
+        responses=ProblemBatchCreateAPI.get_response(),
+        request=ProblemBatchCreateAPI.get_request(),
+        tags=[_('Knowledge Base/Documentation/Paragraph/Question')]
+    )
+    @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
+    def post(self, request: Request, workspace_id: str, knowledge_id: str):
+        return result.success(ProblemSerializers.Create(
+            data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id}
+        ).batch(request.data))
+
+    class BatchAssociation(APIView):
+        authentication_classes = [TokenAuth]
+
+        @extend_schema(
+            summary=_('Batch associate paragraphs'),
+            description=_('Batch associate paragraphs'),
+            operation_id=_('Batch associate paragraphs'),
+            request=BatchAssociationAPI.get_request(),
+            parameters=BatchAssociationAPI.get_parameters(),
+            responses=BatchAssociationAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation/Paragraph/Question')]
+        )
+        @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
+        def put(self, request: Request, workspace_id: str, knowledge_id: str):
+            return result.success(ProblemSerializers.BatchOperate(
+                data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
+            ).association(request.data))
+
+    class BatchDelete(APIView):
+        authentication_classes = [TokenAuth]
+
+        @extend_schema(
+            methods=['PUT'],
+            summary=_('Batch delete questions'),
+            description=_('Batch delete questions'),
+            operation_id=_('Batch delete questions'),
+            request=BatchDeleteAPI.get_request(),
+            parameters=BatchDeleteAPI.get_parameters(),
+            responses=BatchDeleteAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation/Paragraph/Question')]
+        )
+        @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
+        def put(self, request: Request, workspace_id: str, knowledge_id: str):
+            return result.success(ProblemSerializers.BatchOperate(
+                data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
+            ).delete(request.data))
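
(Not part of the diff.) The same serializers can also be driven directly, for example from python manage.py shell, which is a quick way to sanity-check the batch-create and query paths without going through the HTTP layer. A rough sketch follows, assuming an existing workspace and knowledge base; the ids below are placeholders.

# Rough shell sketch -- placeholder ids, assuming a knowledge base already exists.
from knowledge.serializers.problem import ProblemSerializers

workspace_id = 'default'
knowledge_id = '<knowledge-uuid>'  # placeholder; must be a real knowledge base id

# Batch-create questions; contents that already exist in this knowledge base are skipped.
create = ProblemSerializers.Create(data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id})
created_rows = create.batch(['How do I reset my password?', 'How do I change my email?'])

# Query questions by content substring, backed by apps/knowledge/sql/list_problem.sql.
query = ProblemSerializers.Query(data={
    'workspace_id': workspace_id,
    'knowledge_id': knowledge_id,
    'content': 'password',
})
query.is_valid(raise_exception=True)
rows = query.list()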