From cfb63072936231531ceb0efa36a253e8c2d90ba2 Mon Sep 17 00:00:00 2001 From: zhangshaohu Date: Wed, 28 Aug 2024 00:28:47 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E6=AE=B5=E8=90=BD=E8=BF=81=E7=A7=BB=E6=96=87=E6=A1=A3=E6=9C=AA?= =?UTF-8?q?=E7=B4=A2=E5=BC=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/event/listener_manage.py | 16 +++----------- .../serializers/document_serializers.py | 12 ++++++++--- apps/embedding/task/embedding.py | 21 ++++++++++++++----- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 77daabd80..537677ba2 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -52,10 +52,9 @@ class UpdateProblemArgs: class UpdateEmbeddingDatasetIdArgs: - def __init__(self, paragraph_id_list: List[str], target_dataset_id: str, target_embedding_model: Embeddings): + def __init__(self, paragraph_id_list: List[str], target_dataset_id: str): self.paragraph_id_list = paragraph_id_list self.target_dataset_id = target_dataset_id - self.target_embedding_model = target_embedding_model class UpdateEmbeddingDocumentIdArgs: @@ -88,7 +87,6 @@ class ListenerManagement: max_kb_error.error(f'查询向量数据:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}') @staticmethod - @embedding_poxy def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings): max_kb.info(f'开始--->向量化段落:{paragraph_id_list}') status = Status.success @@ -109,7 +107,6 @@ class ListenerManagement: max_kb.info(f'结束--->向量化段落:{paragraph_id_list}') @staticmethod - @embedding_poxy def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings): """ 向量化段落 根据段落id @@ -238,15 +235,8 @@ class ListenerManagement: @staticmethod def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs): - if args.target_embedding_model is None: - VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, - {'dataset_id': args.target_dataset_id}) - else: - # 删除向量数据 - ListenerManagement.delete_embedding_by_paragraph_ids(args.paragraph_id_list) - # 向量数据 - ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list, - embedding_model=args.target_embedding_model) + VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, + {'dataset_id': args.target_dataset_id}) @staticmethod def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs): diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index d3f0e5553..43d401428 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -42,11 +42,12 @@ from common.util.fork import Fork from common.util.split_model import get_split_model from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \ - get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id + get_embedding_model_id_by_dataset_id from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer from dataset.task import sync_web_document from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \ - delete_embedding_by_document, update_embedding_dataset_id + delete_embedding_by_document, update_embedding_dataset_id, delete_embedding_by_paragraph_ids, \ + embedding_by_document_list from smartdoc.conf import PROJECT_DIR parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()] @@ -246,7 +247,12 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): # 修改段落信息 paragraph_list.update(dataset_id=target_dataset_id) # 修改向量信息 - update_embedding_dataset_id(pid_list, target_dataset_id, model_id) + if model_id: + delete_embedding_by_paragraph_ids(pid_list) + QuerySet(Document).filter(id__in=document_id_list).update(status=Status.queue_up) + embedding_by_document_list.delay(document_id_list, model_id) + else: + update_embedding_dataset_id(pid_list, target_dataset_id) @staticmethod def get_target_dataset_problem(target_dataset_id: str, diff --git a/apps/embedding/task/embedding.py b/apps/embedding/task/embedding.py index 46bf81a4e..779969c64 100644 --- a/apps/embedding/task/embedding.py +++ b/apps/embedding/task/embedding.py @@ -63,6 +63,19 @@ def embedding_by_document(document_id, model_id): ListenerManagement.embedding_by_document(document_id, embedding_model) +@celery_app.task(name='celery:embedding_by_document_list') +def embedding_by_document_list(document_id_list, model_id): + """ + 向量化文档 + @param document_id_list: 文档id列表 + @param model_id 向量模型 + :return: None + """ + print(document_id_list) + for document_id in document_id_list: + embedding_by_document.delay(document_id, model_id) + + @celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']}, name='celery:embedding_by_dataset') def embedding_by_dataset(dataset_id, model_id): """ @@ -183,18 +196,16 @@ def update_problem_embedding(problem_id: str, problem_content: str, model_id): ListenerManagement.update_problem(UpdateProblemArgs(problem_id, problem_content, model)) -def update_embedding_dataset_id(paragraph_id_list, target_dataset_id, target_embedding_model_id=None): +def update_embedding_dataset_id(paragraph_id_list, target_dataset_id): """ 修改向量数据到指定知识库 @param paragraph_id_list: 指定段落的向量数据 @param target_dataset_id: 知识库id - @param target_embedding_model_id: 目标知识库 @return: """ - target_embedding_model = get_embedding_model( - target_embedding_model_id) if target_embedding_model_id is not None else None + ListenerManagement.update_embedding_dataset_id( - UpdateEmbeddingDatasetIdArgs(paragraph_id_list, target_dataset_id, target_embedding_model)) + UpdateEmbeddingDatasetIdArgs(paragraph_id_list, target_dataset_id)) def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):