diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py
index 96a759134..9c101cd40 100644
--- a/apps/common/event/listener_manage.py
+++ b/apps/common/event/listener_manage.py
@@ -110,11 +110,16 @@ class ListenerManagement:
     @embedding_poxy
     def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings):
         max_kb.info(f'start ---> embedding paragraphs: {paragraph_id_list}')
+        status = Status.success
         try:
             # Delete the paragraph vectors
             VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list)
+
+            def is_save_function():
+                return QuerySet(Paragraph).filter(id__in=paragraph_id_list).exists()
+
             # Batch embed
-            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
+            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
         except Exception as e:
             max_kb_error.error(f'error embedding paragraphs {paragraph_id_list}: {str(e)}{traceback.format_exc()}')
             status = Status.error
@@ -141,8 +146,12 @@ class ListenerManagement:
                 os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
             # Delete the paragraph vectors
             VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
+
+            def is_save_function():
+                return QuerySet(Paragraph).filter(id=paragraph_id).exists()
+
             # Batch embed
-            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
+            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
         except Exception as e:
             max_kb_error.error(f'error embedding paragraph {paragraph_id}: {str(e)}{traceback.format_exc()}')
             status = Status.error
@@ -175,8 +184,12 @@ class ListenerManagement:
                 os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
             # Delete the document's vector data
             VectorStore.get_embedding_vector().delete_by_document_id(document_id)
+
+            def is_save_function():
+                return QuerySet(Document).filter(id=document_id).exists()
+
             # Batch embed
-            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
+            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
         except Exception as e:
             max_kb_error.error(f'error embedding document {document_id}: {str(e)}{traceback.format_exc()}')
             status = Status.error
diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py
index 15a4449b5..fd9ca3391 100644
--- a/apps/embedding/vector/base_vector.py
+++ b/apps/embedding/vector/base_vector.py
@@ -84,9 +84,9 @@ class BaseVectorStore(ABC):
             chunk_list = chunk_data(data)
             result = sub_array(chunk_list)
             for child_array in result:
-                self._batch_save(child_array, embedding)
+                self._batch_save(child_array, embedding, lambda: True)

-    def batch_save(self, data_list: List[Dict], embedding: Embeddings):
+    def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_save_function):
         # Acquire the lock
         lock.acquire()
         try:
@@ -100,7 +100,10 @@ class BaseVectorStore(ABC):
             chunk_list = chunk_data_list(data_list)
             result = sub_array(chunk_list)
             for child_array in result:
-                self._batch_save(child_array, embedding)
+                if is_save_function():
+                    self._batch_save(child_array, embedding, is_save_function)
+                else:
+                    break
         finally:
             # Release the lock
             lock.release()
@@ -113,7 +116,7 @@ class BaseVectorStore(ABC):
         pass

     @abstractmethod
-    def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
+    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function):
         pass

     def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
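The `batch_save` change above boils down to a simple pattern: the caller passes a zero-argument predicate, and the store re-evaluates it before embedding each chunk, so work stops as soon as the source rows disappear (for example, when a user deletes the paragraphs mid-embedding). A minimal, self-contained sketch of that pattern; the names (`fake_embed`, `saved`, `chunk_size`) are illustrative stand-ins, not MaxKB code:

```python
from typing import Callable, Dict, List

Vector = List[float]
saved: List[Dict] = []  # stand-in for the Embedding table


def fake_embed(texts: List[str]) -> List[Vector]:
    # Stand-in for Embeddings.embed_documents(): one trivial vector per text.
    return [[float(len(t))] for t in texts]


def batch_save(data_list: List[Dict],
               embed: Callable[[List[str]], List[Vector]],
               is_save_function: Callable[[], bool],
               chunk_size: int = 2) -> None:
    for start in range(0, len(data_list), chunk_size):
        # Re-check before every chunk: if the source rows were deleted while
        # earlier chunks were embedding, stop instead of writing orphan vectors.
        if not is_save_function():
            break
        chunk = data_list[start:start + chunk_size]
        vectors = embed([row['text'] for row in chunk])
        saved.extend({**row, 'embedding': v} for row, v in zip(chunk, vectors))


rows = [{'text': f'paragraph {i}'} for i in range(5)]
batch_save(rows, fake_embed, is_save_function=lambda: True)
assert len(saved) == 5  # nothing deleted, so every chunk is written
```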
diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py
index 935461b00..1866971d5 100644
--- a/apps/embedding/vector/pg_vector.py
+++ b/apps/embedding/vector/pg_vector.py
@@ -55,7 +55,7 @@ class PGVector(BaseVectorStore):
         embedding.save()
         return True

-    def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
+    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function):
         texts = [row.get('text') for row in text_list]
         embeddings = embedding.embed_documents(texts)
         embedding_list = [Embedding(id=uuid.uuid1(),
@@ -68,7 +68,8 @@ class PGVector(BaseVectorStore):
                                     embedding=embeddings[index],
                                     search_vector=to_ts_vector(text_list[index]['text']))
                           for index in range(0, len(text_list))]
-        QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
+        if is_save_function():
+            QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
         return True

     def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
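The PGVector hunk evaluates the same predicate a second time immediately before the `bulk_create`, which narrows (though, without a transaction spanning the check and the insert, does not fully close) the window in which vectors for already-deleted rows could still be written. To see the abort path in the sketch above, flip the predicate to False after the first chunk:

```python
calls = {'n': 0}


def deleted_after_first_chunk() -> bool:
    # Pretend the source rows vanish right after the first existence check.
    calls['n'] += 1
    return calls['n'] == 1


saved.clear()
batch_save(rows, fake_embed, is_save_function=deleted_after_first_chunk)
assert len(saved) == 2  # only the first chunk (chunk_size=2) was written
```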