diff --git a/apps/knowledge/api/document.py b/apps/knowledge/api/document.py
index cfafaa583..71fd7794c 100644
--- a/apps/knowledge/api/document.py
+++ b/apps/knowledge/api/document.py
@@ -4,7 +4,8 @@ from drf_spectacular.utils import OpenApiParameter
 from common.mixins.api_mixin import APIMixin
 from common.result import DefaultResultSerializer
 from knowledge.serializers.common import BatchSerializer
-from knowledge.serializers.document import DocumentInstanceSerializer, DocumentWebInstanceSerializer
+from knowledge.serializers.document import DocumentInstanceSerializer, DocumentWebInstanceSerializer, \
+    CancelInstanceSerializer, BatchCancelInstanceSerializer, DocumentRefreshSerializer, BatchEditHitHandlingSerializer
 
 
 class DocumentSplitAPI(APIMixin):
@@ -218,3 +219,54 @@ class WebDocumentCreateAPI(APIMixin):
     @staticmethod
     def get_request():
         return DocumentWebInstanceSerializer
+
+
+class CancelTaskAPI(DocumentReadAPI):
+    """Schema helper for the single-document cancel-task endpoint (body: CancelInstanceSerializer)."""
+    @staticmethod
+    def get_request():
+        return CancelInstanceSerializer
+
+
+class BatchCancelTaskAPI(DocumentReadAPI):
+    """Schema helper for the batch cancel-task endpoint (body: BatchCancelInstanceSerializer)."""
+    @staticmethod
+    def get_request():
+        return BatchCancelInstanceSerializer
+
+
+class SyncWebAPI(DocumentReadAPI):
+    """Schema helper for the web-document sync endpoint; inherits everything from DocumentReadAPI."""
+
+
+class RefreshAPI(DocumentReadAPI):
+    """Schema helper for the document refresh endpoint (body: DocumentRefreshSerializer)."""
+    @staticmethod
+    def get_request():
+        return DocumentRefreshSerializer
+
+
+class BatchEditHitHandlingAPI(APIMixin):
+    """Schema helper for batch-editing hit handling; declares the two path parameters explicitly."""
+    @staticmethod
+    def get_parameters():
+        return [
+            OpenApiParameter(
+                name="workspace_id",
+                description="工作空间id",
+                type=OpenApiTypes.STR,
+                location='path',
+                required=True,
+            ),
+            OpenApiParameter(
+                name="knowledge_id",
+                description="知识库id",
+                type=OpenApiTypes.STR,
+                location='path',
+                required=True,
+            ),
+        ]
+
+    @staticmethod
+    def get_request():
+        return BatchEditHitHandlingSerializer
diff --git a/apps/knowledge/serializers/document.py b/apps/knowledge/serializers/document.py
index 001d17f4d..84e080eb2 100644
--- a/apps/knowledge/serializers/document.py
+++ b/apps/knowledge/serializers/document.py
@@ -64,7 +64,7 @@ parse_table_handle_list = [CsvParseTableHandle(), XlsParseTableHandle(), XlsxPar
 
 
 class BatchCancelInstanceSerializer(serializers.Serializer):
-    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=('id list'))
+    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))
     type = serializers.IntegerField(required=True, label=_('task type'))
 
     def is_valid(self, *, raise_exception=False):
@@ -81,6 +81,20 @@ class DocumentInstanceSerializer(serializers.Serializer):
     paragraphs = ParagraphInstanceSerializer(required=False, many=True, allow_null=True)
 
 
+class CancelInstanceSerializer(serializers.Serializer):
+    type = serializers.IntegerField(required=True, label=_('task type'))
+
+    def is_valid(self, *, raise_exception=False):
+        """Validate the payload and that ``type`` maps to a TaskType; raises on failure, else returns True."""
+        super().is_valid(raise_exception=True)
+        _type = self.data.get('type')
+        try:
+            TaskType(_type)
+        except Exception as e:
+            raise AppApiException(500, _('task type not support')) from e
+        return True
+
+
 class DocumentEditInstanceSerializer(serializers.Serializer):
     meta = serializers.DictField(required=False)
     name = serializers.CharField(required=False, max_length=128, min_length=1, label=_('document name'))
@@ -138,6 +152,24 @@ class DocumentInstanceTableSerializer(serializers.Serializer):
                                            child=serializers.FileField(required=True, label=_('file')))
 
 
+class DocumentRefreshSerializer(serializers.Serializer):
+    state_list = serializers.ListField(required=True, label=_('state list'))
+
+
+class BatchEditHitHandlingSerializer(serializers.Serializer):
+    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))
+    hit_handling_method = serializers.CharField(required=True, label=_('hit handling method'))
+    directly_return_similarity = serializers.FloatField(required=False, max_value=2, min_value=0,
+                                                        label=_('directly return similarity'))
+
+    def is_valid(self, *, raise_exception=False):
+        """Validate the payload and restrict ``hit_handling_method`` to the two supported values."""
+        super().is_valid(raise_exception=True)
+        if self.data.get('hit_handling_method') not in ['optimization', 'directly_return']:
+            raise AppApiException(500, _('The type only supports optimization|directly_return'))
+        return True
+
+
 class DocumentSerializers(serializers.Serializer):
     class Query(serializers.Serializer):
         # 知识库id
@@ -201,6 +233,8 @@ class DocumentSerializers(serializers.Serializer):
                 os.path.join(PROJECT_DIR, "apps", "knowledge", 'sql', 'list_document.sql')))
 
     class Sync(serializers.Serializer):
+        workspace_id = serializers.CharField(required=False, label=_('workspace id'))
+        knowledge_id = serializers.UUIDField(required=False, label=_('knowledge id'))
         document_id = serializers.UUIDField(required=True, label=_('document id'))
 
         def is_valid(self, *, raise_exception=False):
@@ -320,6 +354,40 @@ class DocumentSerializers(serializers.Serializer):
             _document.save()
             return self.one()
 
+        def cancel(self, instance, with_valid=True):
+            """Pass this document's PENDING/STARTED paragraphs and document row to update_status as REVOKE."""
+            if with_valid:
+                self.is_valid(raise_exception=True)
+                CancelInstanceSerializer(data=instance).is_valid()
+            document_id = self.data.get("document_id")
+            # Resolve the task type once: its value is the 1-based position read from the reversed status.
+            task_type = TaskType(instance.get('type'))
+            ListenerManagement.update_status(
+                QuerySet(Paragraph).annotate(
+                    reversed_status=Reverse('status'),
+                    task_type_status=Substr('reversed_status', task_type.value, 1),
+                ).filter(
+                    task_type_status__in=[State.PENDING.value, State.STARTED.value]
+                ).filter(
+                    document_id=document_id
+                ).values('id'),
+                task_type,
+                State.REVOKE
+            )
+            ListenerManagement.update_status(
+                QuerySet(Document).annotate(
+                    reversed_status=Reverse('status'),
+                    task_type_status=Substr('reversed_status', task_type.value, 1),
+                ).filter(
+                    task_type_status__in=[State.PENDING.value, State.STARTED.value]
+                ).filter(
+                    id=document_id
+                ).values('id'),
+                task_type,
+                State.REVOKE
+            )
+            return True
+
         @transaction.atomic
         def delete(self):
             document_id = self.data.get("document_id")
diff --git a/apps/knowledge/task/handler.py b/apps/knowledge/task/handler.py
index 3d8eae735..6069c9aa0 100644
--- a/apps/knowledge/task/handler.py
+++ b/apps/knowledge/task/handler.py
@@ -11,14 +11,14 @@ from django.utils.translation import gettext_lazy as _
 from common.utils.fork import ChildLink, Fork
 from common.utils.split_model import get_split_model
 from knowledge.models.knowledge import KnowledgeType, Document, Knowledge, Status
-from knowledge.serializers.document import DocumentSerializers
-from knowledge.serializers.paragraph import ParagraphSerializers
 
 max_kb_error = logging.getLogger("max_kb_error")
 max_kb = logging.getLogger("max_kb")
 
 
 def get_save_handler(knowledge_id, selector):
+    from knowledge.serializers.document import DocumentSerializers
+
     def handler(child_link: ChildLink, response: Fork.Response):
         if response.status == 200:
             try:
@@ -40,6 +40,8 @@ def get_save_handler(knowledge_id, selector):
 
 
 def get_sync_handler(knowledge_id):
+    from knowledge.serializers.document import DocumentSerializers
+
     knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()
 
     def handler(child_link: ChildLink, response: Fork.Response):
@@ -70,6 +72,8 @@ def get_sync_handler(knowledge_id):
 
 
 def get_sync_web_document_handler(knowledge_id):
+    from knowledge.serializers.document import DocumentSerializers
+
     def handler(source_url: str, selector, response: Fork.Response):
         if response.status == 200:
             try:
@@ -93,6 +97,8 @@ def get_sync_web_document_handler(knowledge_id):
 
 
 def save_problem(knowledge_id, document_id, paragraph_id, problem):
+    from knowledge.serializers.paragraph import ParagraphSerializers
+
     # print(f"knowledge_id: {knowledge_id}")
     # print(f"document_id: {document_id}")
     # print(f"paragraph_id: {paragraph_id}")
diff --git a/apps/knowledge/task/sync.py b/apps/knowledge/task/sync.py
index b3bc8bb1f..63f5e26c6 100644
--- a/apps/knowledge/task/sync.py
+++ b/apps/knowledge/task/sync.py
@@ -16,7 +16,6 @@ from django.utils.translation import gettext_lazy as _
 
 from common.utils.fork import ForkManage, Fork
 from ops import celery_app
-from .handler import get_save_handler, get_sync_web_document_handler, get_sync_handler
 
 max_kb_error = logging.getLogger("max_kb_error")
 max_kb = logging.getLogger("max_kb")
@@ -24,6 +23,8 @@ max_kb = logging.getLogger("max_kb")
 
 @celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:sync_web_knowledge')
 def sync_web_knowledge(knowledge_id: str, url: str, selector: str):
+    from knowledge.task.handler import get_save_handler
+
     try:
         max_kb.info(
             _('Start--->Start synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id))
@@ -39,6 +40,8 @@ def sync_web_knowledge(knowledge_id: str, url: str, selector: str):
 
 @celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:sync_replace_web_knowledge')
 def sync_replace_web_knowledge(knowledge_id: str, url: str, selector: str):
+    from knowledge.task.handler import get_sync_handler
+
     try:
         max_kb.info(
             _('Start--->Start synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id))
@@ -53,6 +56,8 @@ def sync_replace_web_knowledge(knowledge_id: str, url: str, selector: str):
 
 @celery_app.task(name='celery:sync_web_document')
 def sync_web_document(knowledge_id, source_url_list: List[str], selector: str):
+    from knowledge.task.handler import get_sync_web_document_handler
+
     handler = get_sync_web_document_handler(knowledge_id)
     for source_url in source_url_list:
         try:
diff --git a/apps/knowledge/urls.py b/apps/knowledge/urls.py
index d85b7c32d..a69da97da 100644
--- a/apps/knowledge/urls.py
+++ b/apps/knowledge/urls.py
@@ -14,6 +14,11 @@ urlpatterns = [
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/web', views.WebDocumentView.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/qa', views.QaDocumentView.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/table', views.TableDocumentView.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch_hit_handling', views.DocumentView.BatchEditHitHandling.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>', views.DocumentView.Operate.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/sync', views.DocumentView.SyncWeb.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/refresh', views.DocumentView.Refresh.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/cancel_task', views.DocumentView.CancelTask.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/cancel_task/batch', views.DocumentView.BatchCancelTask.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<int:current_page>/<int:page_size>', views.KnowledgeView.Page.as_view()),
 ]
diff --git a/apps/knowledge/views/document.py b/apps/knowledge/views/document.py
index b12f1437d..34b242b22 100644
--- a/apps/knowledge/views/document.py
+++ b/apps/knowledge/views/document.py
@@ -10,7 +10,7 @@ from common.constants.permission_constants import PermissionConstants
 from common.result import result
 from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI, DocumentBatchCreateAPI, DocumentCreateAPI, \
     DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI, TableDocumentCreateAPI, QaDocumentCreateAPI, \
-    WebDocumentCreateAPI
+    WebDocumentCreateAPI, CancelTaskAPI, BatchCancelTaskAPI, SyncWebAPI, RefreshAPI, BatchEditHitHandlingAPI
 from knowledge.api.knowledge import KnowledgeTreeReadAPI
 from knowledge.serializers.document import DocumentSerializers
 
@@ -140,6 +140,102 @@ class DocumentView(APIView):
                 'knowledge_id': knowledge_id,
             }).parse(split_data))
 
+    class BatchEditHitHandling(APIView):
+        authentication_classes = [TokenAuth]
+
+        @extend_schema(
+            methods=['PUT'],
+            summary=_('Modify document hit processing methods in batches'),
+            description=_('Modify document hit processing methods in batches'),
+            operation_id=_('Modify document hit processing methods in batches'),
+            request=BatchEditHitHandlingAPI.get_request(),
+            parameters=BatchEditHitHandlingAPI.get_parameters(),
+            responses=BatchEditHitHandlingAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation')]
+        )
+        @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
+        def put(self, request: Request, workspace_id: str, knowledge_id: str):
+            return result.success(DocumentSerializers.Batch(
+                data={'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
+            ).batch_edit_hit_handling(request.data))
+
+    class SyncWeb(APIView):
+        authentication_classes = [TokenAuth]
+
+        @extend_schema(
+            methods=['PUT'],
+            description=_('Synchronize web site types'),
+            summary=_('Synchronize web site types'),
+            operation_id=_('Synchronize web site types'),
+            parameters=SyncWebAPI.get_parameters(),
+            responses=SyncWebAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation')]
+        )
+        @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
+        def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
+            return result.success(DocumentSerializers.Sync(
+                data={'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
+            ).sync())
+
+    class Refresh(APIView):
+        authentication_classes = [TokenAuth]
+
+        @extend_schema(
+            methods=['PUT'],
+            summary=_('Refresh document vector library'),
+            description=_('Refresh document vector library'),
+            operation_id=_('Refresh document vector library'),
+            parameters=RefreshAPI.get_parameters(),
+            request=RefreshAPI.get_request(),
+            responses=RefreshAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation')]
+        )
+        @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
+        def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
+            return result.success(DocumentSerializers.Operate(
+                data={'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
+            ).refresh(request.data.get('state_list')))
+
+    class CancelTask(APIView):
+        authentication_classes = [TokenAuth]
+
+        @extend_schema(
+            methods=['PUT'],
+            summary=_('Cancel task'),
+            description=_('Cancel task'),
+            operation_id=_('Cancel task'),
+            parameters=CancelTaskAPI.get_parameters(),
+            request=CancelTaskAPI.get_request(),
+            responses=CancelTaskAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation')]
+        )
+        @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
+        def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
+            return result.success(DocumentSerializers.Operate(
+                data={'document_id': document_id, 'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
+            ).cancel(request.data))
+
+    class BatchCancelTask(APIView):
+        authentication_classes = [TokenAuth]
+
+        @extend_schema(
+            methods=['PUT'],
+            summary=_('Cancel tasks in batches'),
+            description=_('Cancel tasks in batches'),
+            operation_id=_('Cancel tasks in batches'),
+            parameters=BatchCancelTaskAPI.get_parameters(),
+            request=BatchCancelTaskAPI.get_request(),
+            responses=BatchCancelTaskAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation')]
+        )
+        @has_permissions(PermissionConstants.DOCUMENT_EDIT.get_workspace_permission())
+        def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str = None):
+            # document_id is accepted only so a route that captures it cannot break dispatch;
+            # it is not used here, the request body is handed to batch_cancel as-is.
+            return result.success(DocumentSerializers.Batch(data={
+                'knowledge_id': knowledge_id, 'workspace_id': workspace_id}
+            ).batch_cancel(request.data))
+
     class Batch(APIView):
         authentication_classes = [TokenAuth]
 