From d5c3f04ce4a158d18fdaad08897e37a55da3cf9b Mon Sep 17 00:00:00 2001
From: shaohuzhang1
Date: Wed, 13 Dec 2023 18:03:57 +0800
Subject: [PATCH] feat: support document refresh -> re-vectorization
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 apps/common/event/listener_manage.py          | 76 ++++++++++++-------
 .../serializers/document_serializers.py      |  7 ++
 apps/dataset/urls.py                          |  1 +
 apps/dataset/views/document.py                | 18 +++++
 4 files changed, 73 insertions(+), 29 deletions(-)

diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py
index 5f84ad335..068e42444 100644
--- a/apps/common/event/listener_manage.py
+++ b/apps/common/event/listener_manage.py
@@ -46,15 +46,21 @@ class ListenerManagement:
         :param paragraph_id: paragraph id
         :return: None
         """
-        data_list = native_search(
-            {'problem': QuerySet(get_dynamics_model({'problem.paragraph_id': django.db.models.CharField()})).filter(
-                **{'problem.paragraph_id': paragraph_id}),
-             'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
-            select_string=get_file_content(
-                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
-        # batch vectorize
-        VectorStore.get_embedding_vector().batch_save(data_list)
-        QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': Status.success.value})
+        status = Status.success
+        try:
+            data_list = native_search(
+                {'problem': QuerySet(get_dynamics_model({'problem.paragraph_id': django.db.models.CharField()})).filter(
+                    **{'problem.paragraph_id': paragraph_id}),
+                 'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)},
+                select_string=get_file_content(
+                    os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
+            # delete the paragraph's existing vectors before re-embedding
+            VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
+            # batch vectorize
+            VectorStore.get_embedding_vector().batch_save(data_list)
+        except Exception:
+            status = Status.error
+        QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status.value})
 
     @staticmethod
     @poxy
@@ -64,17 +70,23 @@ class ListenerManagement:
         :param document_id: document id
         :return: None
         """
-        data_list = native_search(
-            {'problem': QuerySet(get_dynamics_model({'problem.document_id': django.db.models.CharField()})).filter(
-                **{'problem.document_id': document_id}),
-             'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
-            select_string=get_file_content(
-                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
-        # batch vectorize
-        VectorStore.get_embedding_vector().batch_save(data_list)
+        status = Status.success
+        try:
+            data_list = native_search(
+                {'problem': QuerySet(get_dynamics_model({'problem.document_id': django.db.models.CharField()})).filter(
+                    **{'problem.document_id': document_id}),
+                 'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
+                select_string=get_file_content(
+                    os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
+            # delete the document's existing vector data before re-embedding
+            VectorStore.get_embedding_vector().delete_by_document_id(document_id)
+            # batch vectorize
+            VectorStore.get_embedding_vector().batch_save(data_list)
+        except Exception:
+            status = Status.error
         # update status
-        QuerySet(Document).filter(id=document_id).update(**{'status': Status.success.value})
-        QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.success.value})
+        QuerySet(Document).filter(id=document_id).update(**{'status': status.value})
+        QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status.value})
 
     @staticmethod
     @poxy
@@ -84,17 +96,23 @@ class ListenerManagement:
         :param dataset_id: dataset id
         :return: None
         """
-        data_list = native_search(
-            {'problem': QuerySet(get_dynamics_model({'problem.dataset_id': django.db.models.CharField()})).filter(
-                **{'problem.dataset_id': dataset_id}),
-             'paragraph': QuerySet(Paragraph).filter(dataset_id=dataset_id)},
-            select_string=get_file_content(
-                os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
-        # batch vectorize
-        VectorStore.get_embedding_vector().batch_save(data_list)
+        status = Status.success
+        try:
+            data_list = native_search(
+                {'problem': QuerySet(get_dynamics_model({'problem.dataset_id': django.db.models.CharField()})).filter(
+                    **{'problem.dataset_id': dataset_id}),
+                 'paragraph': QuerySet(Paragraph).filter(dataset_id=dataset_id)},
+                select_string=get_file_content(
+                    os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
+            # delete the dataset's existing vector data before re-embedding
+            VectorStore.get_embedding_vector().delete_by_dataset_id(dataset_id)
+            # batch vectorize
+            VectorStore.get_embedding_vector().batch_save(data_list)
+        except Exception:
+            status = Status.error
         # update the status of the documents and paragraphs
-        QuerySet(Document).filter(dataset_id=dataset_id).update(**{'status': Status.success.value})
-        QuerySet(Paragraph).filter(dataset_id=dataset_id).update(**{'status': Status.success.value})
+        QuerySet(Document).filter(dataset_id=dataset_id).update(**{'status': status.value})
+        QuerySet(Paragraph).filter(dataset_id=dataset_id).update(**{'status': status.value})
 
     @staticmethod
     def delete_embedding_by_document(document_id):
diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index b80301975..97647c5f5 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -140,6 +140,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
             _document.save()
             return self.one()
 
+        def refresh(self, with_valid=True):
+            if with_valid:
+                self.is_valid(raise_exception=True)
+            document_id = self.data.get("document_id")
+            ListenerManagement.embedding_by_document_signal.send(document_id)
+            return True
+
         @transaction.atomic
         def delete(self):
             document_id = self.data.get("document_id")
diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py
index e1ef1b154..bceccf274 100644
--- a/apps/dataset/urls.py
+++ b/apps/dataset/urls.py
@@ -17,6 +17,7 @@ urlpatterns = [
          name="document_operate"),
     path('dataset/document/split_pattern', views.Document.SplitPattern.as_view(),
          name="document_operate"),
+    path('dataset/<str:dataset_id>/document/<str:document_id>/refresh', views.Document.Refresh.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()),
     path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>',
          views.Paragraph.Page.as_view(), name='paragraph_page'),
diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py
index 382d5cc43..566ee1124 100644
--- a/apps/dataset/views/document.py
+++ b/apps/dataset/views/document.py
@@ -70,6 +70,24 @@ class Document(APIView):
         def post(self, request: Request, dataset_id: str):
             return result.success(DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_save(request.data))
 
+    class Refresh(APIView):
+        authentication_classes = [TokenAuth]
+
+        @action(methods=['PUT'], detail=False)
+        @swagger_auto_schema(operation_summary="Refresh the document vector store",
+                             operation_id="Refresh the document vector store",
+                             manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
+                             responses=result.get_default_response(),
+                             tags=["Dataset/Document"]
+                             )
+        @has_permissions(
+            lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
+                                    dynamic_tag=k.get('dataset_id')))
+        def put(self, request: Request, dataset_id: str, document_id: str):
+            return result.success(
+                DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).refresh(
+                ))
+
     class Operate(APIView):
         authentication_classes = [TokenAuth]
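
Note: the new endpoint re-vectorizes a single document: the listener first deletes
the document's stored vectors, then batch-saves fresh embeddings, and finally writes
success or error into the Document and Paragraph status fields. A minimal client
sketch for exercising it, assuming the API is mounted under an /api prefix on
localhost:8080 and that TokenAuth reads the token from the Authorization header
(both are assumptions not confirmed by this diff, as are the ids and token below):

    import requests  # third-party HTTP client (pip install requests)

    BASE = "http://localhost:8080/api"               # assumed service root
    dataset_id = "replace-with-a-real-dataset-id"    # hypothetical id
    document_id = "replace-with-a-real-document-id"  # hypothetical id
    headers = {"Authorization": "your-api-token"}    # assumed TokenAuth header

    # PUT /dataset/<dataset_id>/document/<document_id>/refresh triggers
    # delete_by_document_id + batch_save on the vector store for this document.
    resp = requests.put(
        f"{BASE}/dataset/{dataset_id}/document/{document_id}/refresh",
        headers=headers)
    resp.raise_for_status()
    print(resp.json())  # expected: the result.success(True) payload as JSON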