MaxKB/apps/knowledge/vector/base_vector.py

# coding=utf-8
"""
@project: maxkb
@Author:
@file: base_vector.py
@date: 2023/10/18 19:16
@desc:
"""
import threading
from abc import ABC, abstractmethod
from functools import reduce
from typing import List, Dict
from langchain_core.embeddings import Embeddings
from common.chunk import text_to_chunk
from common.utils.common import sub_array
from knowledge.models import SourceType, SearchMode

lock = threading.Lock()


def chunk_data(data: Dict):
    """Split a PARAGRAPH record into one record per text chunk; other source types pass through unchanged."""
    if str(data.get('source_type')) == SourceType.PARAGRAPH.value:
        text = data.get('text')
        chunk_list = text_to_chunk(text)
        return [{**data, 'text': chunk} for chunk in chunk_list]
    return [data]


def chunk_data_list(data_list: List[Dict]):
    """Chunk every record in data_list and flatten the results into a single list."""
    result = [chunk_data(data) for data in data_list]
    return reduce(lambda x, y: [*x, *y], result, [])
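
# Worked example (illustrative; assumes text_to_chunk('long text ...') returns
# ['chunk one', 'chunk two']):
#   chunk_data_list([{'source_type': SourceType.PARAGRAPH.value, 'text': 'long text ...',
#                     'paragraph_id': 'p-1'}])
#   -> [{'source_type': ..., 'text': 'chunk one', 'paragraph_id': 'p-1'},
#       {'source_type': ..., 'text': 'chunk two', 'paragraph_id': 'p-1'}]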


class BaseVectorStore(ABC):
    # Process-wide flag caching whether the underlying vector store has been created,
    # so save_pre_handler only has to check/create it once.
    vector_exists = False

    @abstractmethod
    def vector_is_create(self) -> bool:
        """
        Check whether the vector store has been created.
        :return: True if the vector store already exists
        """
        pass

    @abstractmethod
    def vector_create(self):
        """
        Create the vector store.
        :return:
        """
        pass

    def save_pre_handler(self):
        """
        Pre-insert handler: checks whether the vector store exists and creates it on first use.
        :return: True
        """
        if not BaseVectorStore.vector_exists:
            if not self.vector_is_create():
                self.vector_create()
            BaseVectorStore.vector_exists = True
        return True

    def save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, source_id: str,
             is_active: bool,
             embedding: Embeddings):
        """
        Insert a single piece of vector data.
        :param source_id: source id
        :param knowledge_id: knowledge base id
        :param text: text
        :param source_type: source type
        :param document_id: document id
        :param is_active: whether the record is active
        :param embedding: embedding model used to vectorize the text
        :param paragraph_id: paragraph id
        :return: bool
        """
        self.save_pre_handler()
        data = {'document_id': document_id, 'paragraph_id': paragraph_id, 'knowledge_id': knowledge_id,
                'is_active': is_active, 'source_id': source_id, 'source_type': source_type, 'text': text}
        chunk_list = chunk_data(data)
        result = sub_array(chunk_list)
        for child_array in result:
            self._batch_save(child_array, embedding, lambda: False)

    def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        """
        Batch insert.
        @param data_list: list of records to insert
        @param embedding: embedding model used to vectorize the text
        @param is_the_task_interrupted: zero-argument callable that returns True when the task has been interrupted
        :return: bool
        """
        self.save_pre_handler()
        chunk_list = chunk_data_list(data_list)
        result = sub_array(chunk_list)
        for child_array in result:
            if not is_the_task_interrupted():
                self._batch_save(child_array, embedding, is_the_task_interrupted)
            else:
                break
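
    # Example (hypothetical cancel_event): any zero-argument callable returning bool
    # works as is_the_task_interrupted, e.g.
    #   store.batch_save(rows, embedding_model, lambda: cancel_event.is_set())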

    @abstractmethod
    def _save(self, text, source_type: SourceType, knowledge_id: str, document_id: str, paragraph_id: str, source_id: str,
              is_active: bool,
              embedding: Embeddings):
        pass

    @abstractmethod
    def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
        pass

    def search(self, query_text, knowledge_id_list: list[str], exclude_document_id_list: list[str],
               exclude_paragraph_list: list[str],
               is_active: bool,
               embedding: Embeddings):
        if knowledge_id_list is None or len(knowledge_id_list) == 0:
            return []
        embedding_query = embedding.embed_query(query_text)
        # Delegates to query() following the abstract signature below: top 3 results with
        # a 0.65 similarity threshold, embedding search assumed as the default mode
        # (the exact SearchMode member name is an assumption).
        result = self.query(query_text, embedding_query, knowledge_id_list, None, exclude_document_id_list,
                            exclude_paragraph_list, is_active, 3, 0.65, SearchMode.EMBEDDING)
        return result[0]

    @abstractmethod
    def query(self, query_text: str, query_embedding: List[float], knowledge_id_list: list[str],
              document_id_list: list[str] | None,
              exclude_document_id_list: list[str],
              exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
              search_mode: SearchMode):
        pass

    @abstractmethod
    def hit_test(self, query_text, knowledge_id: list[str], exclude_document_id_list: list[str], top_number: int,
                 similarity: float,
                 search_mode: SearchMode,
                 embedding: Embeddings):
        pass

    @abstractmethod
    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        pass

    @abstractmethod
    def update_by_paragraph_ids(self, paragraph_ids: List[str], instance: Dict):
        pass

    @abstractmethod
    def update_by_source_id(self, source_id: str, instance: Dict):
        pass

    @abstractmethod
    def update_by_source_ids(self, source_ids: List[str], instance: Dict):
        pass

    @abstractmethod
    def delete_by_knowledge_id(self, knowledge_id: str):
        pass

    @abstractmethod
    def delete_by_document_id(self, document_id: str):
        pass

    @abstractmethod
    def delete_by_document_id_list(self, document_id_list: List[str]):
        pass

    @abstractmethod
    def delete_by_knowledge_id_list(self, knowledge_id_list: List[str]):
        pass

    @abstractmethod
    def delete_by_source_id(self, source_id: str, source_type: str):
        pass

    @abstractmethod
    def delete_by_source_ids(self, source_ids: List[str], source_type: str):
        pass

    @abstractmethod
    def delete_by_paragraph_id(self, paragraph_id: str):
        pass

    @abstractmethod
    def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
        pass