From 67e6138066cf36e7fd8ca3958f87de811261e7fa Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Tue, 23 Jan 2024 15:52:15 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=89=B9=E9=87=8F=E5=88=A0=E9=99=A4?= =?UTF-8?q?=E6=96=87=E6=A1=A3,=E6=9C=AA=E5=88=A0=E9=99=A4=E5=85=B3?= =?UTF-8?q?=E8=81=94=E6=AE=B5=E8=90=BD=E4=BF=A1=E6=81=AF,=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E5=85=B3=E8=81=94=E9=97=AE=E9=A2=98=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/event/listener_manage.py | 7 +++++++ apps/dataset/serializers/document_serializers.py | 7 ++++++- apps/dataset/serializers/problem_serializers.py | 2 ++ apps/embedding/vector/base_vector.py | 11 +++++------ apps/embedding/vector/pg_vector.py | 3 +++ 5 files changed, 23 insertions(+), 7 deletions(-) diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 66fe14320..14b371bfd 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -50,6 +50,7 @@ class ListenerManagement: embedding_by_dataset_signal = signal("embedding_by_dataset") embedding_by_document_signal = signal("embedding_by_document") delete_embedding_by_document_signal = signal("delete_embedding_by_document") + delete_embedding_by_document_list_signal = signal("delete_embedding_by_document_list") delete_embedding_by_dataset_signal = signal("delete_embedding_by_dataset") delete_embedding_by_paragraph_signal = signal("delete_embedding_by_paragraph") delete_embedding_by_source_signal = signal("delete_embedding_by_source") @@ -144,6 +145,10 @@ class ListenerManagement: def delete_embedding_by_document(document_id): VectorStore.get_embedding_vector().delete_by_document_id(document_id) + @staticmethod + def delete_embedding_by_document_list(document_id_list: List[str]): + VectorStore.get_embedding_vector().delete_bu_document_id_list(document_id_list) + @staticmethod def delete_embedding_by_dataset(dataset_id): VectorStore.get_embedding_vector().delete_by_dataset_id(dataset_id) @@ -201,6 +206,8 @@ class ListenerManagement: self.embedding_by_document) # 删除 向量 根据文档 ListenerManagement.delete_embedding_by_document_signal.connect(self.delete_embedding_by_document) + # 删除 向量 根据文档id列表 + ListenerManagement.delete_embedding_by_document_list_signal.connect(self.delete_embedding_by_document_list) # 删除 向量 根据知识库id ListenerManagement.delete_embedding_by_dataset_signal.connect(self.delete_embedding_by_dataset) # 删除向量 根据段落id diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index dd32a7f4a..d92b26825 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -547,7 +547,12 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): if with_valid: BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True) self.is_valid(raise_exception=True) - QuerySet(Document).filter(id__in=instance.get('id_list')).delete() + document_id_list = instance.get("id_list") + QuerySet(Document).filter(id__in=document_id_list).delete() + QuerySet(Paragraph).filter(document_id__in=document_id_list).delete() + QuerySet(Problem).filter(document_id__in=document_id_list).delete() + # 删除向量库 + ListenerManagement.delete_embedding_by_document_list_signal.send(document_id_list) return True diff --git a/apps/dataset/serializers/problem_serializers.py b/apps/dataset/serializers/problem_serializers.py index 2f5fdc0a4..25e35dfbd 100644 --- a/apps/dataset/serializers/problem_serializers.py +++ b/apps/dataset/serializers/problem_serializers.py @@ -9,6 +9,7 @@ import uuid from typing import Dict +from django.db import transaction from django.db.models import QuerySet from drf_yasg import openapi from rest_framework import serializers @@ -61,6 +62,7 @@ class ProblemSerializers(ApiMixin, serializers.Serializer): dataset_id=self.data.get('dataset_id')).exists(): raise AppApiException(500, "段落id不正确") + @transaction.atomic def save(self, instance: Dict, with_valid=True, with_embedding=True): if with_valid: self.is_valid() diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index ce96aa3df..a8580e3c6 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -51,8 +51,6 @@ class BaseVectorStore(ABC): def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, is_active: bool, - star_num: int, - trample_num: int, embedding=None): """ 插入向量数据 @@ -64,16 +62,13 @@ class BaseVectorStore(ABC): :param is_active: 是否禁用 :param embedding: 向量化处理器 :param paragraph_id 段落id - :param star_num 点赞数量 - :param trample_num 点踩数量 :return: bool """ if embedding is None: embedding = EmbeddingModel.get_embedding_model() self.save_pre_handler() - self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, star_num, - trample_num, embedding) + self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding) def batch_save(self, data_list: List[Dict], embedding=None): # 获取锁 @@ -143,6 +138,10 @@ class BaseVectorStore(ABC): def delete_by_document_id(self, document_id: str): pass + @abstractmethod + def delete_bu_document_id_list(self, document_id_list: List[str]): + pass + @abstractmethod def delete_by_source_id(self, source_id: str, source_type: str): pass diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index 0091e31da..bb49923c2 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -107,6 +107,9 @@ class PGVector(BaseVectorStore): QuerySet(Embedding).filter(document_id=document_id).delete() return True + def delete_bu_document_id_list(self, document_id_list: List[str]): + return QuerySet(Embedding).filter(document_id__in=document_id_list).delete() + def delete_by_source_id(self, source_id: str, source_type: str): QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete() return True