fix: 【知识库/应用】关闭文档后,命中测试还是可以命中文档

This commit is contained in:
shaohuzhang1 2024-02-29 15:51:35 +08:00
parent 699cc0b084
commit 739fc9808c
4 changed files with 26 additions and 8 deletions

View File

@ -27,7 +27,7 @@ from common.db.search import get_dynamics_model, native_search, native_page_sear
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404
from common.util.file_util import get_file_content
from dataset.models import DataSet
from dataset.models import DataSet, Document
from dataset.serializers.common_serializers import list_paragraph
from setting.models import AuthOperate
from setting.models.model_management import Model
@ -215,10 +215,16 @@ class ApplicationSerializer(serializers.Serializer):
def hit_test(self):
self.is_valid()
vector = VectorStore.get_embedding_vector()
dataset_id_list = [ad.dataset_id for ad in
QuerySet(ApplicationDatasetMapping).filter(
application_id=self.data.get('id'))]
exclude_document_id_list = [str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)]
# 向量库检索
hit_list = vector.hit_test(self.data.get('query_text'), [ad.dataset_id for ad in
QuerySet(ApplicationDatasetMapping).filter(
application_id=self.data.get('id'))],
hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
self.data.get('top_number'),
self.data.get('similarity'),
EmbeddingModel.get_embedding_model())
@ -377,7 +383,8 @@ class ApplicationSerializer(serializers.Serializer):
application = QuerySet(Application).get(id=self.data.get("application_id"))
return select_list(get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_dataset.sql')),
[self.data.get('user_id') if self.data.get('user_id')==str(application.user_id) else None, application.user_id, self.data.get('user_id')])
[self.data.get('user_id') if self.data.get('user_id') == str(application.user_id) else None,
application.user_id, self.data.get('user_id')])
class ApplicationKeySerializerModel(serializers.ModelSerializer):
class Meta:

View File

@ -461,8 +461,13 @@ class DataSetSerializers(serializers.ModelSerializer):
def hit_test(self):
self.is_valid()
vector = VectorStore.get_embedding_vector()
exclude_document_id_list = [str(document.id) for document in
QuerySet(Document).filter(
dataset_id=self.data.get('id'),
is_active=False)]
# 向量库检索
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], self.data.get('top_number'),
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
self.data.get('top_number'),
self.data.get('similarity'),
EmbeddingModel.get_embedding_model())
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})

View File

@ -118,7 +118,8 @@ class BaseVectorStore(ABC):
pass
@abstractmethod
def hit_test(self, query_text, dataset_id: list[str], top_number: int, similarity: float,
def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
similarity: float,
embedding: HuggingFaceEmbeddings):
pass

View File

@ -61,10 +61,15 @@ class PGVector(BaseVectorStore):
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
return True
def hit_test(self, query_text, dataset_id_list: list[str], top_number: int, similarity: float,
def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
similarity: float,
embedding: HuggingFaceEmbeddings):
exclude_dict = {}
embedding_query = embedding.embed_query(query_text)
query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=True)
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
query_set = query_set.exclude(**exclude_dict)
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',