diff --git a/apps/knowledge/api/document.py b/apps/knowledge/api/document.py index 14d53142b..fbed798b3 100644 --- a/apps/knowledge/api/document.py +++ b/apps/knowledge/api/document.py @@ -5,7 +5,8 @@ from common.mixins.api_mixin import APIMixin from common.result import DefaultResultSerializer from knowledge.serializers.common import BatchSerializer from knowledge.serializers.document import DocumentInstanceSerializer, DocumentWebInstanceSerializer, \ - CancelInstanceSerializer, BatchCancelInstanceSerializer, DocumentRefreshSerializer, BatchEditHitHandlingSerializer + CancelInstanceSerializer, BatchCancelInstanceSerializer, DocumentRefreshSerializer, BatchEditHitHandlingSerializer, \ + DocumentBatchRefreshSerializer, DocumentBatchGenerateRelatedSerializer class DocumentSplitAPI(APIMixin): @@ -356,3 +357,52 @@ class DocumentSplitPatternAPI(APIMixin): @staticmethod def get_response(): return DefaultResultSerializer + + +class BatchRefreshAPI(APIMixin): + @staticmethod + def get_parameters(): + return [ + OpenApiParameter( + name="workspace_id", + description="工作空间id", + type=OpenApiTypes.STR, + location='path', + required=True, + ), + OpenApiParameter( + name="knowledge_id", + description="知识库id", + type=OpenApiTypes.STR, + location='path', + required=True, + ), + ] + + @staticmethod + def get_request(): + return DocumentBatchRefreshSerializer + +class BatchGenerateRelatedAPI(APIMixin): + @staticmethod + def get_parameters(): + return [ + OpenApiParameter( + name="workspace_id", + description="工作空间id", + type=OpenApiTypes.STR, + location='path', + required=True, + ), + OpenApiParameter( + name="knowledge_id", + description="知识库id", + type=OpenApiTypes.STR, + location='path', + required=True, + ), + ] + + @staticmethod + def get_request(): + return DocumentBatchGenerateRelatedSerializer \ No newline at end of file diff --git a/apps/knowledge/serializers/document.py b/apps/knowledge/serializers/document.py index 814fbb269..8d6416fd5 100644 --- a/apps/knowledge/serializers/document.py +++ b/apps/knowledge/serializers/document.py @@ -44,6 +44,7 @@ from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInsta delete_problems_and_mappings from knowledge.task.embedding import embedding_by_document, delete_embedding_by_document_list, \ delete_embedding_by_document +from knowledge.task.generate import generate_related_by_document_id from knowledge.task.sync import sync_web_document from maxkb.const import PROJECT_DIR @@ -109,17 +110,17 @@ class DocumentEditInstanceSerializer(serializers.Serializer): @staticmethod def get_meta_valid_map(): - dataset_meta_valid_map = { + knowledge_meta_valid_map = { KnowledgeType.BASE: MetaSerializer.BaseMeta, KnowledgeType.WEB: MetaSerializer.WebMeta } - return dataset_meta_valid_map + return knowledge_meta_valid_map def is_valid(self, *, document: Document = None): super().is_valid(raise_exception=True) if 'meta' in self.data and self.data.get('meta') is not None: - dataset_meta_valid_map = self.get_meta_valid_map() - valid_class = dataset_meta_valid_map.get(document.type) + knowledge_meta_valid_map = self.get_meta_valid_map() + valid_class = knowledge_meta_valid_map.get(document.type) valid_class(data=self.data.get('meta')).is_valid(raise_exception=True) @@ -154,6 +155,18 @@ class DocumentRefreshSerializer(serializers.Serializer): state_list = serializers.ListField(required=True, label=_('state list')) +class DocumentBatchRefreshSerializer(serializers.Serializer): + id_list = serializers.ListField(required=True, label=_('id list')) + state_list = serializers.ListField(required=True, label=_('state list')) + + +class DocumentBatchGenerateRelatedSerializer(serializers.Serializer): + document_id_list = serializers.ListField(required=True, label=_('document id list')) + model_id = serializers.UUIDField(required=True, label=_('model id')) + prompt = serializers.CharField(required=True, label=_('prompt')) + state_list = serializers.ListField(required=True, label=_('state list')) + + class BatchEditHitHandlingSerializer(serializers.Serializer): id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list')) hit_handling_method = serializers.CharField(required=True, label=_('hit handling method')) @@ -521,10 +534,10 @@ class DocumentSerializers(serializers.Serializer): if with_valid: DocumentWebInstanceSerializer(data=instance).is_valid(raise_exception=True) self.is_valid(raise_exception=True) - dataset_id = self.data.get('dataset_id') + knowledge_id = self.data.get('knowledge_id') source_url_list = instance.get('source_url_list') selector = instance.get('selector') - sync_web_document.delay(dataset_id, source_url_list, selector) + sync_web_document.delay(knowledge_id, source_url_list, selector) def save_qa(self, instance: Dict, with_valid=True): if with_valid: @@ -532,7 +545,8 @@ class DocumentSerializers(serializers.Serializer): self.is_valid(raise_exception=True) file_list = instance.get('file_list') document_list = flat_map([self.parse_qa_file(file) for file in file_list]) - return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list) + return DocumentSerializers.Batch(data={'knowledge_id': self.data.get('knowledge_id')}).batch_save( + document_list) def save_table(self, instance: Dict, with_valid=True): if with_valid: @@ -540,7 +554,8 @@ class DocumentSerializers(serializers.Serializer): self.is_valid(raise_exception=True) file_list = instance.get('file_list') document_list = flat_map([self.parse_table_file(file) for file in file_list]) - return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list) + return DocumentSerializers.Batch(data={'knowledge_id': self.data.get('knowledge_id')}).batch_save( + document_list) def parse_qa_file(self, file): get_buffer = FileBufferHandle().get_buffer @@ -788,6 +803,42 @@ class DocumentSerializers(serializers.Serializer): except AlreadyQueued as e: pass + class BatchGenerateRelated(serializers.Serializer): + workspace_id = serializers.CharField(required=True, label=_('workspace id')) + knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id')) + + def batch_generate_related(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + document_id_list = instance.get("document_id_list") + model_id = instance.get("model_id") + prompt = instance.get("prompt") + state_list = instance.get('state_list') + ListenerManagement.update_status( + QuerySet(Document).filter(id__in=document_id_list), + TaskType.GENERATE_PROBLEM, + State.PENDING + ) + ListenerManagement.update_status( + QuerySet(Paragraph).annotate( + reversed_status=Reverse('status'), + task_type_status=Substr('reversed_status', TaskType.GENERATE_PROBLEM.value, + 1), + ).filter( + task_type_status__in=state_list, document_id__in=document_id_list + ) + .values('id'), + TaskType.GENERATE_PROBLEM, + State.PENDING + ) + ListenerManagement.get_aggregation_document_status_by_query_set( + QuerySet(Document).filter(id__in=document_id_list))() + try: + for document_id in document_id_list: + generate_related_by_document_id.delay(document_id, model_id, prompt, state_list) + except AlreadyQueued as e: + pass + class FileBufferHandle: buffer = None diff --git a/apps/knowledge/urls.py b/apps/knowledge/urls.py index 12187e38a..1ec613a33 100644 --- a/apps/knowledge/urls.py +++ b/apps/knowledge/urls.py @@ -18,6 +18,8 @@ urlpatterns = [ path('workspace//knowledge//document/batch_create', views.DocumentView.BatchCreate.as_view()), path('workspace//knowledge//document/batch_sync', views.DocumentView.BatchSync.as_view()), path('workspace//knowledge//document/batch_delete', views.DocumentView.BatchDelete.as_view()), + path('workspace//knowledge//document/batch_refresh', views.DocumentView.BatchRefresh.as_view()), + path('workspace//knowledge//document/batch_generate_related', views.DocumentView.BatchGenerateRelated.as_view()), path('workspace//knowledge//document/web', views.WebDocumentView.as_view()), path('workspace//knowledge//document/qa', views.QaDocumentView.as_view()), path('workspace//knowledge//document/table', views.TableDocumentView.as_view()), diff --git a/apps/knowledge/views/document.py b/apps/knowledge/views/document.py index 2ab81aedc..51b617e78 100644 --- a/apps/knowledge/views/document.py +++ b/apps/knowledge/views/document.py @@ -11,7 +11,7 @@ from common.result import result from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI, DocumentBatchCreateAPI, DocumentCreateAPI, \ DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI, TableDocumentCreateAPI, QaDocumentCreateAPI, \ WebDocumentCreateAPI, CancelTaskAPI, BatchCancelTaskAPI, SyncWebAPI, RefreshAPI, BatchEditHitHandlingAPI, \ - DocumentTreeReadAPI, DocumentSplitPatternAPI + DocumentTreeReadAPI, DocumentSplitPatternAPI, BatchRefreshAPI, BatchGenerateRelatedAPI from knowledge.serializers.document import DocumentSerializers @@ -314,6 +314,49 @@ class DocumentView(APIView): data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id} ).batch_delete(request.data)) + class BatchRefresh(APIView): + authentication_classes = [TokenAuth] + + @extend_schema( + methods=['PUT'], + summary=_('Batch refresh document vector library'), + operation_id=_('Batch refresh document vector library'), + request=BatchRefreshAPI.get_request(), + parameters=BatchRefreshAPI.get_parameters(), + responses=BatchRefreshAPI.get_response(), + tags=[_('Knowledge Base/Documentation')] + ) + @has_permissions([ + PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(), + PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(), + ]) + def put(self, request: Request, workspace_id: str, knowledge_id: str): + return result.success( + DocumentSerializers.Batch( + data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id} + ).batch_refresh(request.data)) + + class BatchGenerateRelated(APIView): + authentication_classes = [TokenAuth] + + @extend_schema( + methods=['PUT'], + summary=_('Batch refresh document vector library'), + operation_id=_('Batch refresh document vector library'), + request=BatchGenerateRelatedAPI.get_request(), + parameters=BatchGenerateRelatedAPI.get_parameters(), + responses=BatchGenerateRelatedAPI.get_response(), + tags=[_('Knowledge Base/Documentation')] + ) + @has_permissions([ + PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(), + PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(), + ]) + def put(self, request: Request, workspace_id: str, knowledge_id: str): + return result.success(DocumentSerializers.BatchGenerateRelated( + data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id} + ).batch_generate_related(request.data)) + class Page(APIView): authentication_classes = [TokenAuth]