diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py
index e1205f3af..fc94fd838 100644
--- a/apps/common/event/listener_manage.py
+++ b/apps/common/event/listener_manage.py
@@ -47,22 +47,26 @@ class SyncWebDocumentArgs:
 
 
 class UpdateProblemArgs:
-    def __init__(self, problem_id: str, problem_content: str):
+    def __init__(self, problem_id: str, problem_content: str, embedding_model: Embeddings):
         self.problem_id = problem_id
         self.problem_content = problem_content
+        self.embedding_model = embedding_model
 
 
 class UpdateEmbeddingDatasetIdArgs:
-    def __init__(self, paragraph_id_list: List[str], target_dataset_id: str):
+    def __init__(self, paragraph_id_list: List[str], target_dataset_id: str, target_embedding_model: Embeddings = None):
         self.paragraph_id_list = paragraph_id_list
         self.target_dataset_id = target_dataset_id
+        self.target_embedding_model = target_embedding_model
 
 
 class UpdateEmbeddingDocumentIdArgs:
-    def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str):
+    def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str,
+                 target_embedding_model: Embeddings = None):
         self.paragraph_id_list = paragraph_id_list
         self.target_document_id = target_document_id
         self.target_dataset_id = target_dataset_id
+        self.target_embedding_model = target_embedding_model
 
 
 class ListenerManagement:
@@ -88,6 +92,38 @@ class ListenerManagement:
     def embedding_by_problem(args, embedding_model: Embeddings):
         VectorStore.get_embedding_vector().save(**args, embedding=embedding_model)
 
+    @staticmethod
+    def embedding_by_paragraph_list(paragraph_id_list, embedding_model: Embeddings):
+        try:
+            data_list = native_search(
+                {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
+                    **{'paragraph.id__in': paragraph_id_list}),
+                    'paragraph': QuerySet(Paragraph).filter(id__in=paragraph_id_list)},
+                select_string=get_file_content(
+                    os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
+            ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list=paragraph_id_list,
+                                                                embedding_model=embedding_model)
+        except Exception as e:
+            max_kb_error.error(f'查询向量数据:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
+
+    @staticmethod
+    @embedding_poxy
+    def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings):
+        max_kb.info(f'开始--->向量化段落:{paragraph_id_list}')
+        # default to success; flipped on failure so the finally-block never reads an unbound name
+        status = Status.success
+        try:
+            # drop any existing vectors for these paragraphs
+            VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list)
+            # batch re-embed from the freshly queried paragraph/problem text
+            VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
+        except Exception as e:
+            max_kb_error.error(f'向量化段落:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
+            status = Status.error
+        finally:
+            QuerySet(Paragraph).filter(id__in=paragraph_id_list).update(**{'status': status})
+            max_kb.info(f'结束--->向量化段落:{paragraph_id_list}')
+
     @staticmethod
     @embedding_poxy
     def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
@@ -227,14 +263,24 @@ class ListenerManagement:
 
     @staticmethod
     def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
-        VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
-                                                                   {'dataset_id': args.target_dataset_id})
+        if args.target_embedding_model is None:
+            VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
+                                                                       {'dataset_id': args.target_dataset_id})
+        else:
+            # embedding model changed: rewrite the vectors instead of only patching dataset_id
+            ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
+                                                           embedding_model=args.target_embedding_model)
 
     @staticmethod
     def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs):
-        VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
-                                                                   {'document_id': args.target_document_id,
-                                                                    'dataset_id': args.target_dataset_id})
+        if args.target_embedding_model is None:
+            VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
+                                                                       {'document_id': args.target_document_id,
+                                                                        'dataset_id': args.target_dataset_id})
+        else:
+            # embedding model changed: rewrite the vectors instead of only patching ids
+            ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
+                                                           embedding_model=args.target_embedding_model)
 
     @staticmethod
     def delete_embedding_by_source_ids(source_ids: List[str]):
diff --git a/apps/dataset/serializers/common_serializers.py b/apps/dataset/serializers/common_serializers.py
index 887772f11..5c2eb853b 100644
--- a/apps/dataset/serializers/common_serializers.py
+++ b/apps/dataset/serializers/common_serializers.py
@@ -147,3 +147,8 @@ def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
 def get_embedding_model_by_dataset_id(dataset_id: str):
     dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
     return EmbeddingModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode))
+
+
+def get_embedding_model_by_dataset(dataset):
+    """Return the cached embedding model for an already-loaded dataset (avoids a second query)."""
+    return EmbeddingModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))
diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index 3bd5a1c9c..aac5feff9 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -235,12 +235,17 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                                   meta={})
         else:
             document_list.update(dataset_id=target_dataset_id)
-        # 修改向量信息
-        ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
-            [paragraph.id for paragraph in paragraph_list],
-            target_dataset_id))
+        model = None
+        if dataset.embedding_mode_id != target_dataset.embedding_mode_id:
+            model = get_embedding_model_by_dataset_id(target_dataset_id)
+
+        pid_list = [paragraph.id for paragraph in paragraph_list]
         # 修改段落信息
         paragraph_list.update(dataset_id=target_dataset_id)
+        # update vectors after the paragraph rows carry the new dataset_id (re-embed reads them back)
+        ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
+            pid_list,
+            target_dataset_id, model))
 
     @staticmethod
     def get_target_dataset_problem(target_dataset_id: str,
diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py
index 7f84f61c0..36f625ad7 100644
--- a/apps/dataset/serializers/paragraph_serializers.py
+++ b/apps/dataset/serializers/paragraph_serializers.py
@@ -20,9 +20,9 @@ from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
 from common.util.common import post
 from common.util.field_message import ErrMessage
-from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
+from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet
 from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
-    ProblemParagraphManage, get_embedding_model_by_dataset_id
+    ProblemParagraphManage, get_embedding_model_by_dataset_id, get_embedding_model_by_dataset
 from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
 from embedding.models import SourceType
 
@@ -338,10 +338,11 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
 
             # 修改mapping
             QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list, ['document_id'])
 
+            # same dataset: only the vector rows' document_id changes, no re-embedding needed
             ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
                 [paragraph.id for paragraph in paragraph_list],
-                target_document_id, target_dataset_id))
+                target_document_id, target_dataset_id, target_embedding_model=None))
             # 修改段落信息
             paragraph_list.update(document_id=target_document_id)
             # 不同数据集迁移
@@ -368,12 +369,19 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
             # 修改mapping
             QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
                                                           ['problem_id', 'dataset_id', 'document_id'])
-            # 修改向量段落信息
-            ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
-                [paragraph.id for paragraph in paragraph_list],
-                target_document_id, target_dataset_id))
+            target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first()
+            dataset = QuerySet(DataSet).filter(id=dataset_id).first()
+            embedding_model = None
+            if target_dataset.embedding_mode_id != dataset.embedding_mode_id:
+                embedding_model = get_embedding_model_by_dataset(target_dataset)
+            pid_list = [paragraph.id for paragraph in paragraph_list]
             # 修改段落信息
             paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)
+            # re-embed with the target dataset's model when it differs, otherwise just patch the ids
+            ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
+                pid_list,
+                target_document_id, target_dataset_id, target_embedding_model=embedding_model))
+
         update_document_char_length(document_id)
         update_document_char_length(target_document_id)
 
diff --git a/apps/dataset/serializers/problem_serializers.py b/apps/dataset/serializers/problem_serializers.py
index 5d00d5be4..34064f9a9 100644
--- a/apps/dataset/serializers/problem_serializers.py
+++ b/apps/dataset/serializers/problem_serializers.py
@@ -20,7 +20,8 @@ from common.event import ListenerManagement, UpdateProblemArgs
 from common.mixins.api_mixin import ApiMixin
 from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from dataset.models import Problem, Paragraph, ProblemParagraphMapping
+from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
 from smartdoc.conf import PROJECT_DIR
 
 
@@ -157,6 +158,7 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
         content = instance.get('content')
         problem = QuerySet(Problem).filter(id=problem_id, dataset_id=dataset_id).first()
         problem.content = content
         problem.save()
-        ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content))
+        model = get_embedding_model_by_dataset_id(dataset_id)
+        ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content, model))
 