fix: 修复索引中的文档,知识库删除后依然再执行 (#934)

This commit is contained in:
shaohuzhang1 2024-08-06 16:22:53 +08:00 committed by GitHub
parent b3c7120372
commit 864bca6450
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 9 deletions

View File

@ -110,11 +110,16 @@ class ListenerManagement:
@embedding_poxy
def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings):
max_kb.info(f'开始--->向量化段落:{paragraph_id_list}')
status = Status.success
try:
# 删除段落
VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list)
def is_save_function():
return QuerySet(Paragraph).filter(id__in=paragraph_id_list).exists()
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
except Exception as e:
max_kb_error.error(f'向量化段落:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
@ -141,8 +146,12 @@ class ListenerManagement:
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
# 删除段落
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
def is_save_function():
return QuerySet(Paragraph).filter(id=paragraph_id).exists()
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
except Exception as e:
max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
@ -175,8 +184,12 @@ class ListenerManagement:
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
# 删除文档向量数据
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
def is_save_function():
return QuerySet(Document).filter(id=document_id).exists()
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
except Exception as e:
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error

View File

@ -84,9 +84,9 @@ class BaseVectorStore(ABC):
chunk_list = chunk_data(data)
result = sub_array(chunk_list)
for child_array in result:
self._batch_save(child_array, embedding)
self._batch_save(child_array, embedding, lambda: True)
def batch_save(self, data_list: List[Dict], embedding: Embeddings):
def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_save_function):
# 获取锁
lock.acquire()
try:
@ -100,7 +100,10 @@ class BaseVectorStore(ABC):
chunk_list = chunk_data_list(data_list)
result = sub_array(chunk_list)
for child_array in result:
self._batch_save(child_array, embedding)
if is_save_function():
self._batch_save(child_array, embedding, is_save_function)
else:
break
finally:
# 释放锁
lock.release()
@ -113,7 +116,7 @@ class BaseVectorStore(ABC):
pass
@abstractmethod
def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function):
pass
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],

View File

@ -55,7 +55,7 @@ class PGVector(BaseVectorStore):
embedding.save()
return True
def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function):
texts = [row.get('text') for row in text_list]
embeddings = embedding.embed_documents(texts)
embedding_list = [Embedding(id=uuid.uuid1(),
@ -68,7 +68,8 @@ class PGVector(BaseVectorStore):
embedding=embeddings[index],
search_vector=to_ts_vector(text_list[index]['text'])) for index in
range(0, len(text_list))]
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
if is_save_function():
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
return True
def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,