From f70f189afff4c4864139d9aa2da6608e66106731 Mon Sep 17 00:00:00 2001
From: CaptainB
Date: Fri, 25 Jul 2025 10:52:53 +0800
Subject: [PATCH] feat: add dataset index creation and deletion functions

---
 apps/common/event/listener_manage.py          |  3 ++
 .../dataset/serializers/common_serializers.py | 44 ++++++++++++++++++-
 .../serializers/dataset_serializers.py        |  3 +-
 apps/embedding/sql/blend_search.sql           |  8 ++--
 apps/embedding/sql/embedding_search.sql       |  6 +--
 apps/embedding/task/embedding.py              |  2 +
 apps/embedding/vector/pg_vector.py            | 29 ++++++++----
 7 files changed, 78 insertions(+), 17 deletions(-)

diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py
index dd2a54a0c..6899c31f3 100644
--- a/apps/common/event/listener_manage.py
+++ b/apps/common/event/listener_manage.py
@@ -24,6 +24,7 @@ from common.util.file_util import get_file_content
 from common.util.lock import try_lock, un_lock
 from common.util.page_utils import page_desc
 from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State
+from dataset.serializers.common_serializers import create_dataset_index
 from embedding.models import SourceType, SearchMode
 from smartdoc.conf import PROJECT_DIR
 from django.utils.translation import gettext_lazy as _
@@ -281,6 +282,8 @@ class ListenerManagement:
                                            ListenerManagement.get_aggregation_document_status(
                                                document_id)),
                                        is_the_task_interrupted)
+            # Check whether the per-dataset index already exists; create it if missing
+            create_dataset_index(document_id=document_id)
         except Exception as e:
             max_kb_error.error(_('Vectorized document: {document_id} error {error} {traceback}').format(
                 document_id=document_id, error=str(e), traceback=traceback.format_exc()))
diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py
index 856f3da15..edf064236 100644
--- a/apps/dataset/serializers/common_serializers.py
+++ b/apps/dataset/serializers/common_serializers.py
@@ -18,13 +18,13 @@ from rest_framework import serializers
 
 from common.config.embedding_config import ModelManage
 from common.db.search import native_search
-from common.db.sql_execute import update_execute
+from common.db.sql_execute import update_execute, sql_execute
 from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
 from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from common.util.fork import Fork
-from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet, File, Image
+from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet, File, Image, Document
 from setting.models_provider import get_model
 from smartdoc.conf import PROJECT_DIR
 from django.utils.translation import gettext_lazy as _
@@ -224,6 +224,46 @@ def get_embedding_model_id_by_dataset_id_list(dataset_id_list: List):
     return str(dataset_list[0].embedding_mode_id)
 
+
+def create_dataset_index(dataset_id=None, document_id=None):
+    if dataset_id is None and document_id is None:
+        raise AppApiException(500, _('Dataset ID or Document ID must be provided'))
+
+    if dataset_id is not None:
+        k_id = dataset_id
+    else:
+        document = QuerySet(Document).filter(id=document_id).first()
+        k_id = document.dataset_id
+
+    sql = f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = 'embedding' AND indexname = 'embedding_hnsw_idx_{k_id}'"
+    index = sql_execute(sql, [])
+    if not index:
+        sql = f"SELECT vector_dims(embedding) AS dims FROM embedding WHERE dataset_id = '{k_id}' LIMIT 1"
+        result = sql_execute(sql, [])
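+        # No embeddings stored for this dataset yet: the vector dimension is
+        # unknown, so index creation is deferred to the next vectorization pass.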
+        if len(result) == 0:
+            return
+        dims = result[0]['dims']
+        sql = f"""CREATE INDEX "embedding_hnsw_idx_{k_id}" ON embedding USING hnsw ((embedding::vector({dims})) vector_cosine_ops) WHERE dataset_id = '{k_id}'"""
+        update_execute(sql, [])
+
+
+def drop_dataset_index(dataset_id=None, document_id=None):
+    if dataset_id is None and document_id is None:
+        raise AppApiException(500, _('Dataset ID or Document ID must be provided'))
+
+    if dataset_id is not None:
+        k_id = dataset_id
+    else:
+        document = QuerySet(Document).filter(id=document_id).first()
+        k_id = document.dataset_id
+
+    sql = f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = 'embedding' AND indexname = 'embedding_hnsw_idx_{k_id}'"
+    index = sql_execute(sql, [])
+    if index:
+        sql = f'DROP INDEX "embedding_hnsw_idx_{k_id}"'
+        update_execute(sql, [])
+
+
 class GenerateRelatedSerializer(ApiMixin, serializers.Serializer):
     model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_('Model id')))
     prompt = serializers.CharField(required=True, error_messages=ErrMessage.uuid(_('Prompt word')))
diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py
index 718916619..e5955dc92 100644
--- a/apps/dataset/serializers/dataset_serializers.py
+++ b/apps/dataset/serializers/dataset_serializers.py
@@ -44,7 +44,7 @@ from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, State, File, Image
 from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
     get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id, write_image, zip_dir, \
-    GenerateRelatedSerializer
+    GenerateRelatedSerializer, drop_dataset_index
 from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
 from dataset.task import sync_web_dataset, sync_replace_web_dataset, generate_related_by_dataset_id
 from embedding.models import SearchMode
@@ -788,6 +788,7 @@ class DataSetSerializers(serializers.ModelSerializer):
             QuerySet(ProblemParagraphMapping).filter(dataset=dataset).delete()
             QuerySet(Paragraph).filter(dataset=dataset).delete()
             QuerySet(Problem).filter(dataset=dataset).delete()
+            drop_dataset_index(dataset_id=dataset.id)
             dataset.delete()
             delete_embedding_by_dataset(self.data.get('id'))
             return True
diff --git a/apps/embedding/sql/blend_search.sql b/apps/embedding/sql/blend_search.sql
index afb1f0040..c70e66464 100644
--- a/apps/embedding/sql/blend_search.sql
+++ b/apps/embedding/sql/blend_search.sql
@@ -5,15 +5,17 @@ SELECT
 FROM
 	(
 	SELECT DISTINCT ON
-		( "paragraph_id" ) ( similarity ),* ,
-		similarity AS comprehensive_score
+		( "paragraph_id" ) ( 1 - distance + ts_similarity ) AS similarity, *,
+		( 1 - distance + ts_similarity ) AS comprehensive_score
 	FROM
 		(
 		SELECT *,
-			(( 1 - ( embedding.embedding <=> %s ) )+ts_rank_cd( embedding.search_vector, websearch_to_tsquery('simple', %s ), 32 )) AS similarity
+			( embedding.embedding::vector(%s) <=> %s ) AS distance,
+			ts_rank_cd( embedding.search_vector, websearch_to_tsquery('simple', %s ), 32 ) AS ts_similarity
 		FROM embedding ${embedding_query}
+		ORDER BY distance
 		) TEMP
 	ORDER BY
 		paragraph_id,
 		similarity DESC
 	) DISTINCT_TEMP
diff --git a/apps/embedding/sql/embedding_search.sql b/apps/embedding/sql/embedding_search.sql
index ce3d4a580..1b5689959 100644
--- a/apps/embedding/sql/embedding_search.sql
+++ b/apps/embedding/sql/embedding_search.sql
@@ -5,12 +5,12 @@ SELECT
 FROM
 	(
 	SELECT DISTINCT ON
-		("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score
+		("paragraph_id") ( 1 - distance ),* ,(1 - distance) AS comprehensive_score
 	FROM
-		( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query}) TEMP
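+		-- Order the inner scan by ascending cosine distance so it can be served
+		-- by the per-dataset HNSW index this patch introduces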
+		( SELECT *, ( embedding.embedding::vector(%s) <=> %s ) AS distance FROM embedding ${embedding_query} ORDER BY distance) TEMP
 	ORDER BY
 		paragraph_id,
-		similarity DESC
+		distance
 	) DISTINCT_TEMP
 WHERE comprehensive_score>%s
 ORDER BY comprehensive_score DESC
diff --git a/apps/embedding/task/embedding.py b/apps/embedding/task/embedding.py
index 3b26bd7a1..488467500 100644
--- a/apps/embedding/task/embedding.py
+++ b/apps/embedding/task/embedding.py
@@ -17,6 +17,7 @@ from common.config.embedding_config import ModelManage
 from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingDatasetIdArgs, \
     UpdateEmbeddingDocumentIdArgs
 from dataset.models import Document, TaskType, State
+from dataset.serializers.common_serializers import drop_dataset_index
 from ops import celery_app
 from setting.models import Model
 from setting.models_provider import get_model
@@ -110,6 +111,7 @@ def embedding_by_dataset(dataset_id, model_id):
     max_kb.info(_('Start--->Vectorized dataset: {dataset_id}').format(dataset_id=dataset_id))
     try:
         ListenerManagement.delete_embedding_by_dataset(dataset_id)
+        drop_dataset_index(dataset_id=dataset_id)
         document_list = QuerySet(Document).filter(dataset_id=dataset_id)
         max_kb.info(_('Dataset documentation: {document_names}').format(
             document_names=", ".join([d.name for d in document_list])))
diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py
index 7929685a3..af9ff7e4c 100644
--- a/apps/embedding/vector/pg_vector.py
+++ b/apps/embedding/vector/pg_vector.py
@@ -12,7 +12,6 @@ import uuid
 from abc import ABC, abstractmethod
 from typing import Dict, List
 
-import jieba
 from django.contrib.postgres.search import SearchVector
 from django.db.models import QuerySet, Value
 from langchain_core.embeddings import Embeddings
@@ -169,8 +168,13 @@
                 os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'embedding_search.sql')),
             with_table_name=True)
-        embedding_model = select_list(exec_sql,
-                                      [json.dumps(query_embedding), *exec_params, similarity, top_number])
+        embedding_model = select_list(exec_sql, [
+            len(query_embedding),
+            json.dumps(query_embedding),
+            *exec_params,
+            similarity,
+            top_number
+        ])
         return embedding_model
 
     def support(self, search_mode: SearchMode):
@@ -190,8 +194,12 @@
                 os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'keywords_search.sql')),
             with_table_name=True)
-        embedding_model = select_list(exec_sql,
-                                      [to_query(query_text), *exec_params, similarity, top_number])
+        embedding_model = select_list(exec_sql, [
+            to_query(query_text),
+            *exec_params,
+            similarity,
+            top_number
+        ])
         return embedding_model
 
     def support(self, search_mode: SearchMode):
@@ -211,9 +219,14 @@
                 os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'blend_search.sql')),
             with_table_name=True)
-        embedding_model = select_list(exec_sql,
-                                      [json.dumps(query_embedding), to_query(query_text), *exec_params, similarity,
-                                       top_number])
+        embedding_model = select_list(exec_sql, [
+            len(query_embedding),
+            json.dumps(query_embedding),
+            to_query(query_text),
+            *exec_params,
+            similarity,
+            top_number
+        ])
         return embedding_model
 
     def support(self, search_mode: SearchMode):
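
Note on parameter order: the leading %s that both rewritten SQL files now expect
is the query vector's dimension, bound from len(query_embedding) in
pg_vector.py, so the embedding::vector(%s) cast matches the expression the
partial HNSW index was built on. Below is a minimal sketch of the DDL that
create_dataset_index emits, plus a query shaped to use the index; the dataset
id and the 3-dimensional vectors are placeholders purely for illustration:

    -- Hypothetical values: dataset id and dimension are placeholders.
    CREATE INDEX "embedding_hnsw_idx_11111111-1111-1111-1111-111111111111"
        ON embedding USING hnsw ((embedding::vector(3)) vector_cosine_ops)
        WHERE dataset_id = '11111111-1111-1111-1111-111111111111';

    -- The partial index is only usable when the dataset_id predicate, the
    -- vector(3) cast and the cosine operator (<=>) all match, with rows
    -- ordered by ascending distance:
    SELECT paragraph_id, (embedding::vector(3) <=> '[0.1, 0.2, 0.3]') AS distance
    FROM embedding
    WHERE dataset_id = '11111111-1111-1111-1111-111111111111'
    ORDER BY distance
    LIMIT 10;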