diff --git a/apps/common/db/sql_execute.py b/apps/common/db/sql_execute.py index 79e7de46a..b12297e1f 100644 --- a/apps/common/db/sql_execute.py +++ b/apps/common/db/sql_execute.py @@ -36,8 +36,9 @@ def update_execute(sql: str, params): """ with connection.cursor() as cursor: cursor.execute(sql, params) + affected_rows = cursor.rowcount cursor.close() - return None + return affected_rows def select_list(sql: str, params: List): diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index a98b29bf1..9c16ad5c2 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -10,11 +10,12 @@ import datetime import logging import os import threading +import time import traceback from typing import List import django.db.models -from django.db import models +from django.db import models, transaction from django.db.models import QuerySet from django.db.models.functions import Substr, Reverse from langchain_core.embeddings import Embeddings @@ -168,6 +169,7 @@ class ListenerManagement: @staticmethod def get_aggregation_document_status(document_id): def aggregation_document_status(): + pass sql = get_file_content( os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')) native_update({'document_custom_sql': QuerySet(Document).filter(id=document_id)}, sql, with_table_name=True) diff --git a/apps/common/util/page_utils.py b/apps/common/util/page_utils.py index 7fc176b68..92f21849b 100644 --- a/apps/common/util/page_utils.py +++ b/apps/common/util/page_utils.py @@ -18,10 +18,11 @@ def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False): @param is_the_task_interrupted: 任务是否被中断 @return: """ + query = query_set.order_by("id") count = query_set.count() for i in range(0, ceil(count / page_size)): if is_the_task_interrupted(): return offset = i * page_size - paragraph_list = query_set[offset: offset + page_size] + paragraph_list = query.all()[offset: offset + page_size] handler(paragraph_list) diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 1ab74ead2..70facd8db 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -613,7 +613,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): document_id = self.data.get("document_id") ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.PENDING) - ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id), TaskType.EMBEDDING, + ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id), + TaskType.EMBEDDING, State.PENDING) ListenerManagement.get_aggregation_document_status(document_id)() embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id')) @@ -708,8 +709,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): @staticmethod def post_embedding(result, document_id, dataset_id): - model_id = get_embedding_model_id_by_dataset_id(dataset_id) - embedding_by_document.delay(document_id, model_id) + DocumentSerializers.Operate( + data={'dataset_id': dataset_id, 'document_id': document_id}).refresh() return result @staticmethod @@ -907,8 +908,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): @staticmethod def post_embedding(document_list, dataset_id): for document_dict in document_list: - model_id = get_embedding_model_id_by_dataset_id(dataset_id) - embedding_by_document.delay(document_dict.get('id'), model_id) + DocumentSerializers.Operate( + data={'dataset_id': dataset_id, 'document_id': document_dict.get('id')}).refresh() return document_list @post(post_function=post_embedding) diff --git a/ui/src/views/document/component/Status.vue b/ui/src/views/document/component/Status.vue index 8bbab6784..139365371 100644 --- a/ui/src/views/document/component/Status.vue +++ b/ui/src/views/document/component/Status.vue @@ -1,51 +1,13 @@