fix: 修复新增段落迁移文档未索引
Some checks are pending
sync2gitee / repo-sync (push) Waiting to run
Typos Check / Spell Check with Typos (push) Waiting to run

This commit is contained in:
zhangshaohu 2024-08-28 00:28:47 +08:00 committed by shaohuzhang1
parent a53c6829c7
commit cfb6307293
3 changed files with 28 additions and 21 deletions

View File

@ -52,10 +52,9 @@ class UpdateProblemArgs:
class UpdateEmbeddingDatasetIdArgs:
def __init__(self, paragraph_id_list: List[str], target_dataset_id: str, target_embedding_model: Embeddings):
def __init__(self, paragraph_id_list: List[str], target_dataset_id: str):
self.paragraph_id_list = paragraph_id_list
self.target_dataset_id = target_dataset_id
self.target_embedding_model = target_embedding_model
class UpdateEmbeddingDocumentIdArgs:
@ -88,7 +87,6 @@ class ListenerManagement:
max_kb_error.error(f'查询向量数据:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
@staticmethod
@embedding_poxy
def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings):
max_kb.info(f'开始--->向量化段落:{paragraph_id_list}')
status = Status.success
@ -109,7 +107,6 @@ class ListenerManagement:
max_kb.info(f'结束--->向量化段落:{paragraph_id_list}')
@staticmethod
@embedding_poxy
def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
"""
向量化段落 根据段落id
@ -238,15 +235,8 @@ class ListenerManagement:
@staticmethod
def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
if args.target_embedding_model is None:
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
{'dataset_id': args.target_dataset_id})
else:
# 删除向量数据
ListenerManagement.delete_embedding_by_paragraph_ids(args.paragraph_id_list)
# 向量数据
ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
embedding_model=args.target_embedding_model)
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
{'dataset_id': args.target_dataset_id})
@staticmethod
def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs):

View File

@ -42,11 +42,12 @@ from common.util.fork import Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id
get_embedding_model_id_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from dataset.task import sync_web_document
from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
delete_embedding_by_document, update_embedding_dataset_id
delete_embedding_by_document, update_embedding_dataset_id, delete_embedding_by_paragraph_ids, \
embedding_by_document_list
from smartdoc.conf import PROJECT_DIR
parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()]
@ -246,7 +247,12 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
# 修改段落信息
paragraph_list.update(dataset_id=target_dataset_id)
# 修改向量信息
update_embedding_dataset_id(pid_list, target_dataset_id, model_id)
if model_id:
delete_embedding_by_paragraph_ids(pid_list)
QuerySet(Document).filter(id__in=document_id_list).update(status=Status.queue_up)
embedding_by_document_list.delay(document_id_list, model_id)
else:
update_embedding_dataset_id(pid_list, target_dataset_id)
@staticmethod
def get_target_dataset_problem(target_dataset_id: str,

View File

@ -63,6 +63,19 @@ def embedding_by_document(document_id, model_id):
ListenerManagement.embedding_by_document(document_id, embedding_model)
@celery_app.task(name='celery:embedding_by_document_list')
def embedding_by_document_list(document_id_list, model_id):
    """
    Vectorize a batch of documents.

    Every id in ``document_id_list`` is handed to ``embedding_by_document``
    through ``.delay`` together with the same ``model_id``; this function
    does nothing else with the ids.

    @param document_id_list: list of document ids to vectorize
    @param model_id: embedding model id, forwarded unchanged for each document
    :return: None
    """
    for document_id in document_id_list:
        embedding_by_document.delay(document_id, model_id)
@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']}, name='celery:embedding_by_dataset')
def embedding_by_dataset(dataset_id, model_id):
"""
@ -183,18 +196,16 @@ def update_problem_embedding(problem_id: str, problem_content: str, model_id):
ListenerManagement.update_problem(UpdateProblemArgs(problem_id, problem_content, model))
def update_embedding_dataset_id(paragraph_id_list, target_dataset_id, target_embedding_model_id=None):
def update_embedding_dataset_id(paragraph_id_list, target_dataset_id):
"""
修改向量数据到指定知识库
@param paragraph_id_list: 指定段落的向量数据
@param target_dataset_id: 知识库id
@param target_embedding_model_id: 目标知识库
@return:
"""
target_embedding_model = get_embedding_model(
target_embedding_model_id) if target_embedding_model_id is not None else None
ListenerManagement.update_embedding_dataset_id(
UpdateEmbeddingDatasetIdArgs(paragraph_id_list, target_dataset_id, target_embedding_model))
UpdateEmbeddingDatasetIdArgs(paragraph_id_list, target_dataset_id))
def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):