mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 10:12:51 +00:00
80 lines
3.3 KiB
Python
80 lines
3.3 KiB
Python
# coding=utf-8
|
||
"""
|
||
@project: maxkb
|
||
@Author:虎
|
||
@file: pg_vector.py
|
||
@date:2023/10/19 15:28
|
||
@desc:
|
||
"""
|
||
import uuid
|
||
from typing import Dict, List
|
||
|
||
from django.db.models import QuerySet
|
||
from langchain.embeddings import HuggingFaceEmbeddings
|
||
|
||
from embedding.models import Embedding, SourceType
|
||
from embedding.vector.base_vector import BaseVectorStore
|
||
|
||
|
||
class PGVector(BaseVectorStore):
    """Vector store that persists embeddings as ``Embedding`` rows via the Django ORM.

    All reads and writes go through ``QuerySet(Embedding)``; no separate vector
    storage needs to be created at runtime.
    """

    def vector_is_create(self) -> bool:
        """Report whether the vector storage exists.

        The storage is created by default when the project starts, so there is
        nothing to create here.

        :return: always ``True``
        """
        return True

    def vector_create(self):
        """No-op: the storage is already created when the project starts.

        :return: always ``True``
        """
        return True

    def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
              is_active: bool,
              embedding: HuggingFaceEmbeddings):
        """Embed a single text and store it as one ``Embedding`` row.

        :param text: text to embed
        :param source_type: type of the source the text belongs to
        :param dataset_id: id of the owning dataset
        :param document_id: id of the owning document
        :param paragraph_id: id of the owning paragraph
        :param source_id: id of the source record
        :param is_active: value stored in the row's ``is_active`` field
        :param embedding: embedding model used to compute the vector
        :return: ``True``
        """
        text_embedding = embedding.embed_query(text)
        # Use a distinct name for the row so the ``embedding`` model parameter
        # is not shadowed by the model instance being saved.
        row = Embedding(id=uuid.uuid1(),
                        dataset_id=dataset_id,
                        document_id=document_id,
                        is_active=is_active,
                        paragraph_id=paragraph_id,
                        source_id=source_id,
                        embedding=text_embedding,
                        source_type=source_type,
                        )
        row.save()
        return True

    def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
        """Embed many texts with one model call and bulk-insert the rows.

        :param text_list: dicts read via the keys ``text``, ``document_id``,
            ``paragraph_id``, ``dataset_id``, ``source_id``, ``source_type`` and
            the optional ``is_active`` (defaults to ``True`` when missing)
        :param embedding: embedding model used to compute the vectors
        :return: ``True``
        """
        if not text_list:
            # Nothing to store: skip the embedding-model call entirely.
            return True
        texts = [row.get('text') for row in text_list]
        embeddings = embedding.embed_documents(texts)
        # Index into ``embeddings`` (rather than zip) so a length mismatch from
        # the model raises instead of silently dropping rows.
        QuerySet(Embedding).bulk_create([Embedding(id=uuid.uuid1(),
                                                   document_id=row.get('document_id'),
                                                   paragraph_id=row.get('paragraph_id'),
                                                   dataset_id=row.get('dataset_id'),
                                                   is_active=row.get('is_active', True),
                                                   source_id=row.get('source_id'),
                                                   source_type=row.get('source_type'),
                                                   embedding=embeddings[index])
                                         for index, row in enumerate(text_list)])
        return True

    def search(self, query_text, dataset_id_list: list[str], is_active: bool, embedding: HuggingFaceEmbeddings):
        """Similarity search over stored embeddings.

        NOTE(review): not implemented yet — currently returns ``None``.
        """
        pass

    def update_by_source_id(self, source_id: str, instance: Dict):
        """Update every row with the given ``source_id`` using the fields in ``instance``.

        :return: ``True``
        """
        QuerySet(Embedding).filter(source_id=source_id).update(**instance)
        return True

    def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
        """Update every row with the given ``paragraph_id`` using the fields in ``instance``.

        :return: ``True``
        """
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)
        return True

    def delete_by_dataset_id(self, dataset_id: str):
        """Delete every row belonging to the given dataset.

        :return: ``True``
        """
        QuerySet(Embedding).filter(dataset_id=dataset_id).delete()
        return True

    def delete_by_document_id(self, document_id: str):
        """Delete every row belonging to the given document.

        :return: ``True``
        """
        QuerySet(Embedding).filter(document_id=document_id).delete()
        return True

    def delete_by_source_id(self, source_id: str, source_type: str):
        """Delete every row matching both ``source_id`` and ``source_type``.

        :return: ``True``
        """
        QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete()
        return True

    def delete_by_paragraph_id(self, paragraph_id: str):
        """Delete every row belonging to the given paragraph.

        :return: ``True``
        """
        QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()
        return True
|