diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 73fa0f173..58ff6d5be 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -27,7 +27,7 @@ from common.db.search import get_dynamics_model, native_search, native_page_sear from common.db.sql_execute import select_list from common.exception.app_exception import AppApiException, NotFound404 from common.util.file_util import get_file_content -from dataset.models import DataSet +from dataset.models import DataSet, Document from dataset.serializers.common_serializers import list_paragraph from setting.models import AuthOperate from setting.models.model_management import Model @@ -215,10 +215,16 @@ class ApplicationSerializer(serializers.Serializer): def hit_test(self): self.is_valid() vector = VectorStore.get_embedding_vector() + dataset_id_list = [ad.dataset_id for ad in + QuerySet(ApplicationDatasetMapping).filter( + application_id=self.data.get('id'))] + + exclude_document_id_list = [str(document.id) for document in + QuerySet(Document).filter( + dataset_id__in=dataset_id_list, + is_active=False)] # 向量库检索 - hit_list = vector.hit_test(self.data.get('query_text'), [ad.dataset_id for ad in - QuerySet(ApplicationDatasetMapping).filter( - application_id=self.data.get('id'))], + hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list, self.data.get('top_number'), self.data.get('similarity'), EmbeddingModel.get_embedding_model()) @@ -377,7 +383,8 @@ class ApplicationSerializer(serializers.Serializer): application = QuerySet(Application).get(id=self.data.get("application_id")) return select_list(get_file_content( os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_dataset.sql')), - [self.data.get('user_id') if self.data.get('user_id')==str(application.user_id) else None, application.user_id, self.data.get('user_id')]) + [self.data.get('user_id') if self.data.get('user_id') == str(application.user_id) else None, + application.user_id, self.data.get('user_id')]) class ApplicationKeySerializerModel(serializers.ModelSerializer): class Meta: diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index bbe1f67c4..a4ba19749 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -461,8 +461,13 @@ class DataSetSerializers(serializers.ModelSerializer): def hit_test(self): self.is_valid() vector = VectorStore.get_embedding_vector() + exclude_document_id_list = [str(document.id) for document in + QuerySet(Document).filter( + dataset_id=self.data.get('id'), + is_active=False)] # 向量库检索 - hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], self.data.get('top_number'), + hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list, + self.data.get('top_number'), self.data.get('similarity'), EmbeddingModel.get_embedding_model()) hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index a8580e3c6..9b3518707 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -118,7 +118,8 @@ class BaseVectorStore(ABC): pass @abstractmethod - def hit_test(self, query_text, dataset_id: list[str], top_number: int, similarity: float, + def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int, + similarity: float, embedding: HuggingFaceEmbeddings): pass diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index bb49923c2..fd5378aea 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -61,10 +61,15 @@ class PGVector(BaseVectorStore): QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None return True - def hit_test(self, query_text, dataset_id_list: list[str], top_number: int, similarity: float, + def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int, + similarity: float, embedding: HuggingFaceEmbeddings): + exclude_dict = {} embedding_query = embedding.embed_query(query_text) query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=True) + if exclude_document_id_list is not None and len(exclude_document_id_list) > 0: + exclude_dict.__setitem__('document_id__in', exclude_document_id_list) + query_set = query_set.exclude(**exclude_dict) exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set}, select_string=get_file_content( os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',