From 614f17b70c80b452d18c9e344b09cf2d681e8a02 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Fri, 30 Aug 2024 18:45:25 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E7=9F=A5=E8=AF=86=E5=BA=93=E5=88=97?= =?UTF-8?q?=E8=A1=A8=E5=A2=9E=E5=8A=A0=E7=AD=9B=E9=80=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --story=1016081 --user=王孝刚 【知识库】- 文档列表支持按文档状态、启用状态、命中处理方式进行筛选 https://www.tapd.cn/57709429/s/1571832 --- .../serializers/document_serializers.py | 21 +++++++++++++++++++ apps/dataset/urls.py | 1 + apps/dataset/views/document.py | 21 +++++++++++++++++++ ui/src/enums/document.ts | 12 +++++++++++ 4 files changed, 55 insertions(+) diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 0c29d348a..f88d45516 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -337,6 +337,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): error_messages=ErrMessage.char( "文档名称")) hit_handling_method = serializers.CharField(required=False, error_messages=ErrMessage.char("命中处理方式")) + is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("文档是否可用")) + status = serializers.CharField(required=False, error_messages=ErrMessage.char("文档状态")) def get_query_set(self): query_set = QuerySet(model=Document) @@ -345,6 +347,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): query_set = query_set.filter(**{'name__icontains': self.data.get('name')}) if 'hit_handling_method' in self.data and self.data.get('hit_handling_method') is not None: query_set = query_set.filter(**{'hit_handling_method': self.data.get('hit_handling_method')}) + if 'is_active' in self.data and self.data.get('is_active') is not None: + query_set = query_set.filter(**{'is_active': self.data.get('is_active')}) + if 'status' in self.data and self.data.get('status') is not None: + query_set = 
query_set.filter(**{'status': self.data.get('status')}) query_set = query_set.order_by('-create_time') return query_set @@ -936,6 +942,21 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): update_dict['directly_return_similarity'] = directly_return_similarity QuerySet(Document).filter(id__in=document_id_list).update(**update_dict) + def batch_refresh(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + document_id_list = instance.get("id_list") + with transaction.atomic(): + Document.objects.filter(id__in=document_id_list).update(status=Status.queue_up) + Paragraph.objects.filter(document_id__in=document_id_list).update(status=Status.queue_up) + dataset_id = self.data.get('dataset_id') + embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=dataset_id) + for document_id in document_id_list: + try: + embedding_by_document.delay(document_id, embedding_model_id) + except AlreadyQueued as e: + raise AppApiException(500, "任务正在执行中,请勿重复下发") + class FileBufferHandle: buffer = None diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index 026492d18..2068922ee 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -34,6 +34,7 @@ urlpatterns = [ name="document_export"), path('dataset/<str:dataset_id>/document/<str:document_id>/sync', views.Document.SyncWeb.as_view()), path('dataset/<str:dataset_id>/document/<str:document_id>/refresh', views.Document.Refresh.as_view()), + path('dataset/<str:dataset_id>/document/batch_refresh', views.Document.BatchRefresh.as_view()), path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()), path( 'dataset/<str:dataset_id>/document/<str:document_id>/paragraph/migrate/dataset/<str:target_dataset_id>/document/<str:target_document_id>', diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index 1988ca75a..c2ef152a0 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -33,6 +33,7 @@ class Template(APIView): def get(self, request: Request): return DocumentSerializers.Export(data={'type': request.query_params.get('type')}).export(with_valid=True) + class 
TableTemplate(APIView): authentication_classes = [TokenAuth] @@ -82,6 +83,7 @@ class QaDocument(APIView): {'file_list': request.FILES.getlist('file')}, with_valid=True)) + class TableDocument(APIView): authentication_classes = [TokenAuth] parser_classes = [MultiPartParser] @@ -101,6 +103,7 @@ class TableDocument(APIView): {'file_list': request.FILES.getlist('file')}, with_valid=True)) + class Document(APIView): authentication_classes = [TokenAuth] @@ -233,6 +236,24 @@ class Document(APIView): DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).refresh( )) + class BatchRefresh(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="批量刷新文档向量库", + operation_id="批量刷新文档向量库", + request_body= + DocumentApi.BatchEditHitHandlingApi.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str): + return result.success( + DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_refresh(request.data)) + class Migrate(APIView): authentication_classes = [TokenAuth] diff --git a/ui/src/enums/document.ts b/ui/src/enums/document.ts index 0a03c9c12..f3a7d24d7 100644 --- a/ui/src/enums/document.ts +++ b/ui/src/enums/document.ts @@ -2,3 +2,15 @@ export enum hitHandlingMethod { optimization = '模型优化', directly_return = '直接回答' } + +export enum hitStatus { + waiting = '等待中', + processing = '处理中', + completed = '已完成', + failed = '失败' +} + +export enum isActivated { + true = '启用', + false = '禁用' +}