From 739fc9808ceff8ce034973c1192399a15e1754a9 Mon Sep 17 00:00:00 2001
From: shaohuzhang1
Date: Thu, 29 Feb 2024 15:51:35 +0800
Subject: [PATCH] fix: [Knowledge base/Application] hit test still matches a document after the document has been disabled
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../serializers/application_serializers.py      | 17 ++++++++++++-----
 apps/dataset/serializers/dataset_serializers.py |  7 ++++++-
 apps/embedding/vector/base_vector.py            |  3 ++-
 apps/embedding/vector/pg_vector.py              |  7 ++++++-
 4 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py
index 73fa0f173..58ff6d5be 100644
--- a/apps/application/serializers/application_serializers.py
+++ b/apps/application/serializers/application_serializers.py
@@ -27,7 +27,7 @@ from common.db.search import get_dynamics_model, native_search, native_page_sear
 from common.db.sql_execute import select_list
 from common.exception.app_exception import AppApiException, NotFound404
 from common.util.file_util import get_file_content
-from dataset.models import DataSet
+from dataset.models import DataSet, Document
 from dataset.serializers.common_serializers import list_paragraph
 from setting.models import AuthOperate
 from setting.models.model_management import Model
@@ -215,10 +215,16 @@ class ApplicationSerializer(serializers.Serializer):
         def hit_test(self):
             self.is_valid()
             vector = VectorStore.get_embedding_vector()
+            dataset_id_list = [ad.dataset_id for ad in
+                               QuerySet(ApplicationDatasetMapping).filter(
+                                   application_id=self.data.get('id'))]
+
+            exclude_document_id_list = [str(document.id) for document in
+                                        QuerySet(Document).filter(
+                                            dataset_id__in=dataset_id_list,
+                                            is_active=False)]
             # 向量库检索
-            hit_list = vector.hit_test(self.data.get('query_text'), [ad.dataset_id for ad in
-                                                                     QuerySet(ApplicationDatasetMapping).filter(
-                                                                         application_id=self.data.get('id'))],
+            hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
                                        self.data.get('top_number'),
                                        self.data.get('similarity'),
                                        EmbeddingModel.get_embedding_model())
@@ -377,7 +383,8 @@ class ApplicationSerializer(serializers.Serializer):
             application = QuerySet(Application).get(id=self.data.get("application_id"))
             return select_list(get_file_content(
                 os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_dataset.sql')),
-                [self.data.get('user_id') if self.data.get('user_id')==str(application.user_id) else None, application.user_id, self.data.get('user_id')])
+                [self.data.get('user_id') if self.data.get('user_id') == str(application.user_id) else None,
+                 application.user_id, self.data.get('user_id')])

     class ApplicationKeySerializerModel(serializers.ModelSerializer):
         class Meta:
diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py
index bbe1f67c4..a4ba19749 100644
--- a/apps/dataset/serializers/dataset_serializers.py
+++ b/apps/dataset/serializers/dataset_serializers.py
@@ -461,8 +461,13 @@ class DataSetSerializers(serializers.ModelSerializer):
         def hit_test(self):
             self.is_valid()
             vector = VectorStore.get_embedding_vector()
+            exclude_document_id_list = [str(document.id) for document in
+                                        QuerySet(Document).filter(
+                                            dataset_id=self.data.get('id'),
+                                            is_active=False)]
             # 向量库检索
-            hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], self.data.get('top_number'),
+            hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
+                                       self.data.get('top_number'),
                                        self.data.get('similarity'),
                                        EmbeddingModel.get_embedding_model())
             hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py
index a8580e3c6..9b3518707 100644
--- a/apps/embedding/vector/base_vector.py
+++ b/apps/embedding/vector/base_vector.py
@@ -118,7 +118,8 @@ class BaseVectorStore(ABC):
         pass

     @abstractmethod
-    def hit_test(self, query_text, dataset_id: list[str], top_number: int, similarity: float,
+    def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
+                 similarity: float,
                  embedding: HuggingFaceEmbeddings):
         pass

diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py
index bb49923c2..fd5378aea 100644
--- a/apps/embedding/vector/pg_vector.py
+++ b/apps/embedding/vector/pg_vector.py
@@ -61,10 +61,15 @@ class PGVector(BaseVectorStore):
             QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
         return True

-    def hit_test(self, query_text, dataset_id_list: list[str], top_number: int, similarity: float,
+    def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
+                 similarity: float,
                  embedding: HuggingFaceEmbeddings):
+        exclude_dict = {}
         embedding_query = embedding.embed_query(query_text)
         query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=True)
+        if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
+            exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
+        query_set = query_set.exclude(**exclude_dict)
         exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                            select_string=get_file_content(
                                                                os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
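
Reviewer note: the standalone sketch below is not MaxKB code; it is a plain-Python approximation of the rule this patch enforces, with a hypothetical Hit dataclass and invented sample data. It shows candidates from disabled documents being dropped before ranking, mirroring the new query_set.exclude(document_id__in=exclude_document_id_list) step added to PGVector.hit_test.

# Illustrative only: plain-Python approximation of the exclusion behaviour added in this patch.
from dataclasses import dataclass


@dataclass
class Hit:
    # Hypothetical stand-in for one embedding row returned by the vector store.
    paragraph_id: str
    document_id: str
    similarity: float


def hit_test(candidates: list[Hit], exclude_document_id_list: list[str],
             top_number: int, similarity: float) -> list[Hit]:
    # Drop candidates that belong to disabled documents, mirroring
    # query_set.exclude(document_id__in=exclude_document_id_list).
    if exclude_document_id_list:
        candidates = [c for c in candidates if c.document_id not in exclude_document_id_list]
    # Keep candidates above the similarity threshold, best match first.
    ranked = sorted((c for c in candidates if c.similarity >= similarity),
                    key=lambda c: c.similarity, reverse=True)
    return ranked[:top_number]


if __name__ == '__main__':
    hits = [Hit('p1', 'doc-enabled', 0.91), Hit('p2', 'doc-disabled', 0.95)]
    # Before the fix the disabled document's paragraph could still win the hit test;
    # with the exclusion list it is filtered out up front.
    print(hit_test(hits, exclude_document_id_list=['doc-disabled'], top_number=3, similarity=0.5))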