mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
fix: 【知识库/应用】关闭文档后,命中测试还是可以命中文档
This commit is contained in:
parent
699cc0b084
commit
739fc9808c
|
|
@ -27,7 +27,7 @@ from common.db.search import get_dynamics_model, native_search, native_page_sear
|
|||
from common.db.sql_execute import select_list
|
||||
from common.exception.app_exception import AppApiException, NotFound404
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import DataSet
|
||||
from dataset.models import DataSet, Document
|
||||
from dataset.serializers.common_serializers import list_paragraph
|
||||
from setting.models import AuthOperate
|
||||
from setting.models.model_management import Model
|
||||
|
|
@ -215,10 +215,16 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
def hit_test(self):
|
||||
self.is_valid()
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
dataset_id_list = [ad.dataset_id for ad in
|
||||
QuerySet(ApplicationDatasetMapping).filter(
|
||||
application_id=self.data.get('id'))]
|
||||
|
||||
exclude_document_id_list = [str(document.id) for document in
|
||||
QuerySet(Document).filter(
|
||||
dataset_id__in=dataset_id_list,
|
||||
is_active=False)]
|
||||
# 向量库检索
|
||||
hit_list = vector.hit_test(self.data.get('query_text'), [ad.dataset_id for ad in
|
||||
QuerySet(ApplicationDatasetMapping).filter(
|
||||
application_id=self.data.get('id'))],
|
||||
hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
|
||||
self.data.get('top_number'),
|
||||
self.data.get('similarity'),
|
||||
EmbeddingModel.get_embedding_model())
|
||||
|
|
@ -377,7 +383,8 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
application = QuerySet(Application).get(id=self.data.get("application_id"))
|
||||
return select_list(get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_dataset.sql')),
|
||||
[self.data.get('user_id') if self.data.get('user_id')==str(application.user_id) else None, application.user_id, self.data.get('user_id')])
|
||||
[self.data.get('user_id') if self.data.get('user_id') == str(application.user_id) else None,
|
||||
application.user_id, self.data.get('user_id')])
|
||||
|
||||
class ApplicationKeySerializerModel(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
|
|
|
|||
|
|
@ -461,8 +461,13 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
def hit_test(self):
|
||||
self.is_valid()
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
exclude_document_id_list = [str(document.id) for document in
|
||||
QuerySet(Document).filter(
|
||||
dataset_id=self.data.get('id'),
|
||||
is_active=False)]
|
||||
# 向量库检索
|
||||
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], self.data.get('top_number'),
|
||||
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
|
||||
self.data.get('top_number'),
|
||||
self.data.get('similarity'),
|
||||
EmbeddingModel.get_embedding_model())
|
||||
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
|
||||
|
|
|
|||
|
|
@ -118,7 +118,8 @@ class BaseVectorStore(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def hit_test(self, query_text, dataset_id: list[str], top_number: int, similarity: float,
|
||||
def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
|
||||
similarity: float,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -61,10 +61,15 @@ class PGVector(BaseVectorStore):
|
|||
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
|
||||
return True
|
||||
|
||||
def hit_test(self, query_text, dataset_id_list: list[str], top_number: int, similarity: float,
|
||||
def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
|
||||
similarity: float,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
exclude_dict = {}
|
||||
embedding_query = embedding.embed_query(query_text)
|
||||
query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=True)
|
||||
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
|
||||
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
|
||||
query_set = query_set.exclude(**exclude_dict)
|
||||
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
|
||||
select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
|
||||
|
|
|
|||
Loading…
Reference in New Issue