diff --git a/apps/embedding/task/embedding.py b/apps/embedding/task/embedding.py
index 67590e2c7..b6d5dfb75 100644
--- a/apps/embedding/task/embedding.py
+++ b/apps/embedding/task/embedding.py
@@ -6,7 +6,7 @@
     @date:2024/8/19 14:13
     @desc:
 """
-
+import datetime
 import logging
 import traceback
 from typing import List
@@ -17,7 +17,7 @@ from django.db.models import QuerySet
 from common.config.embedding_config import ModelManage
 from common.event import ListenerManagement, UpdateProblemArgs, UpdateEmbeddingDatasetIdArgs, \
     UpdateEmbeddingDocumentIdArgs
-from dataset.models import Document
+from dataset.models import Document, Status
 from ops import celery_app
 from setting.models import Model
 from setting.models_provider import get_model
@@ -26,10 +26,15 @@ max_kb_error = logging.getLogger("max_kb_error")
 max_kb = logging.getLogger("max_kb")
 
 
-def get_embedding_model(model_id):
-    model = QuerySet(Model).filter(id=model_id).first()
-    embedding_model = ModelManage.get_model(model_id,
-                                            lambda _id: get_model(model))
+def get_embedding_model(model_id, exception_handler=lambda e: max_kb_error.error(
+    f'获取向量模型失败:{str(e)}{traceback.format_exc()}')):
+    try:
+        model = QuerySet(Model).filter(id=model_id).first()
+        embedding_model = ModelManage.get_model(model_id,
+                                                lambda _id: get_model(model))
+    except Exception as e:
+        exception_handler(e)
+        raise
     return embedding_model
 
 
@@ -59,7 +64,14 @@ def embedding_by_document(document_id, model_id):
     @param model_id 向量模型
     :return: None
     """
-    embedding_model = get_embedding_model(model_id)
+
+    def exception_handler(e):
+        QuerySet(Document).filter(id=document_id).update(
+            **{'status': Status.error, 'update_time': datetime.datetime.now()})
+        max_kb_error.error(
+            f'获取向量模型失败:{str(e)}{traceback.format_exc()}')
+
+    embedding_model = get_embedding_model(model_id, exception_handler)
     ListenerManagement.embedding_by_document(document_id, embedding_model)
 
 
@@ -71,7 +83,6 @@ def embedding_by_document_list(document_id_list, model_id):
     @param model_id 向量模型
     :return: None
     """
-    print(document_id_list)
     for document_id in document_id_list:
         embedding_by_document.delay(document_id, model_id)
 