From 1721344cc6386a4fcdbb4b055f3f498cb9339988 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Tue, 6 Aug 2024 16:22:53 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E7=B4=A2=E5=BC=95?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E6=96=87=E6=A1=A3,=E7=9F=A5=E8=AF=86?= =?UTF-8?q?=E5=BA=93=E5=88=A0=E9=99=A4=E5=90=8E=E4=BE=9D=E7=84=B6=E5=86=8D?= =?UTF-8?q?=E6=89=A7=E8=A1=8C=20(#934)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit 864bca64502e2cedc2fe9e2582fa426d39cedec2) --- apps/common/event/listener_manage.py | 19 ++++++++++++++++--- apps/embedding/vector/base_vector.py | 11 +++++++---- apps/embedding/vector/pg_vector.py | 5 +++-- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 96a759134..9c101cd40 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -110,11 +110,16 @@ class ListenerManagement: @embedding_poxy def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings): max_kb.info(f'开始--->向量化段落:{paragraph_id_list}') + status = Status.success try: # 删除段落 VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list) + + def is_save_function(): + return QuerySet(Paragraph).filter(id__in=paragraph_id_list).exists() + # 批量向量化 - VectorStore.get_embedding_vector().batch_save(data_list, embedding_model) + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function) except Exception as e: max_kb_error.error(f'向量化段落:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}') status = Status.error @@ -141,8 +146,12 @@ class ListenerManagement: os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql'))) # 删除段落 VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id) + + def is_save_function(): + return QuerySet(Paragraph).filter(id=paragraph_id).exists() + # 批量向量化 - VectorStore.get_embedding_vector().batch_save(data_list, embedding_model) + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function) except Exception as e: max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}') status = Status.error @@ -175,8 +184,12 @@ class ListenerManagement: os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql'))) # 删除文档向量数据 VectorStore.get_embedding_vector().delete_by_document_id(document_id) + + def is_save_function(): + return QuerySet(Document).filter(id=document_id).exists() + # 批量向量化 - VectorStore.get_embedding_vector().batch_save(data_list, embedding_model) + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function) except Exception as e: max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}') status = Status.error diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index 15a4449b5..fd9ca3391 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -84,9 +84,9 @@ class BaseVectorStore(ABC): chunk_list = chunk_data(data) result = sub_array(chunk_list) for child_array in result: - self._batch_save(child_array, embedding) + self._batch_save(child_array, embedding, lambda: True) - def batch_save(self, data_list: List[Dict], embedding: Embeddings): + def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_save_function): # 获取锁 lock.acquire() try: @@ -100,7 +100,10 @@ class BaseVectorStore(ABC): chunk_list = chunk_data_list(data_list) result = sub_array(chunk_list) for child_array in result: - self._batch_save(child_array, embedding) + if is_save_function(): + self._batch_save(child_array, embedding, is_save_function) + else: + break finally: # 释放锁 lock.release() @@ -113,7 +116,7 @@ class BaseVectorStore(ABC): pass @abstractmethod - def _batch_save(self, text_list: List[Dict], embedding: Embeddings): + def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function): pass def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index 935461b00..1866971d5 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -55,7 +55,7 @@ class PGVector(BaseVectorStore): embedding.save() return True - def _batch_save(self, text_list: List[Dict], embedding: Embeddings): + def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function): texts = [row.get('text') for row in text_list] embeddings = embedding.embed_documents(texts) embedding_list = [Embedding(id=uuid.uuid1(), @@ -68,7 +68,8 @@ class PGVector(BaseVectorStore): embedding=embeddings[index], search_vector=to_ts_vector(text_list[index]['text'])) for index in range(0, len(text_list))] - QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None + if is_save_function(): + QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None return True def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,