diff --git a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py
index 3cb1d601c..4e7b2f660 100644
--- a/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py
+++ b/apps/application/flow/step_node/document_extract_node/impl/base_document_extract_node.py
@@ -29,18 +29,20 @@ class BaseDocumentExtractNode(IDocumentExtractNode):
                 # 回到文件头
                 buffer.seek(0)
                 file_content = split_handle.get_content(buffer)
-                content.append( '## ' + doc['name'] + '\n' + file_content)
+                content.append('## ' + doc['name'] + '\n' + file_content)
                 break
         return NodeResult({'content': splitter.join(content)}, {})

     def get_details(self, index: int, **kwargs):
+        # 不保存content全部内容,因为content内容可能会很大
+        content = (self.context.get('content')[:500] + '...') if len(self.context.get('content')) > 0 else ''
         return {
             'name': self.node.properties.get('stepName'),
             "index": index,
             'run_time': self.context.get('run_time'),
             'type': self.node.type,
-            # 'content': self.context.get('content'),  # 不保存content内容,因为content内容可能会很大
+            'content': content,
             'status': self.status,
             'err_message': self.err_message,
             'document_list': self.context.get('document_list')
diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py
index 919cb71cf..22fc18bae 100644
--- a/apps/application/serializers/chat_message_serializers.py
+++ b/apps/application/serializers/chat_message_serializers.py
@@ -154,6 +154,7 @@ def get_post_handler(chat_info: ChatInfo):
                                  details=manage.get_details(),
                                  message_tokens=manage.context['message_tokens'],
                                  answer_tokens=manage.context['answer_tokens'],
+                                 answer_text_list=[answer_text],
                                  run_time=manage.context['run_time'],
                                  index=len(chat_info.chat_record_list) + 1)
         chat_info.append_chat_record(chat_record, client_id)
diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py
index f7bb11f24..d07b5676e 100644
--- a/apps/application/serializers/chat_serializers.py
+++ b/apps/application/serializers/chat_serializers.py
@@ -395,7 +395,8 @@ class ChatRecordSerializerModel(serializers.ModelSerializer):
     class Meta:
         model = ChatRecord
         fields = ['id', 'chat_id', 'vote_status', 'problem_text', 'answer_text',
-                  'message_tokens', 'answer_tokens', 'const', 'improve_paragraph_id_list', 'run_time', 'index','answer_text_list',
+                  'message_tokens', 'answer_tokens', 'const', 'improve_paragraph_id_list', 'run_time', 'index',
+                  'answer_text_list',
                   'create_time', 'update_time']


@@ -457,6 +458,7 @@ class ChatRecordSerializer(serializers.Serializer):
     def reset_chat_record(chat_record):
         dataset_list = []
         paragraph_list = []
+
         if 'search_step' in chat_record.details and chat_record.details.get('search_step').get(
                 'paragraph_list') is not None:
             paragraph_list = chat_record.details.get('search_step').get(
@@ -468,6 +470,14 @@ class ChatRecordSerializer(serializers.Serializer):
                        row in paragraph_list],
                       {}).items()]

+        if len(chat_record.improve_paragraph_id_list) > 0:
+            paragraph_model_list = QuerySet(Paragraph).filter(id__in=chat_record.improve_paragraph_id_list)
+            if len(paragraph_model_list) < len(chat_record.improve_paragraph_id_list):
+                paragraph_model_id_list = [str(p.id) for p in paragraph_model_list]
+                chat_record.improve_paragraph_id_list = list(
+                    filter(lambda p_id: paragraph_model_id_list.__contains__(p_id),
+                           chat_record.improve_paragraph_id_list))
+                chat_record.save()
         return {
            **ChatRecordSerializerModel(chat_record).data,
@@ -608,13 +618,11 @@ class ChatRecordSerializer(serializers.Serializer):
                                   title=instance.get("title") if 'title' in instance else '')
             problem_text = instance.get('problem_text') if instance.get(
                 'problem_text') is not None else chat_record.problem_text
-            problem = Problem(id=uuid.uuid1(), content=problem_text, dataset_id=dataset_id)
+            problem, _ = Problem.objects.get_or_create(content=problem_text, dataset_id=dataset_id)
             problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(), dataset_id=dataset_id,
                                                                 document_id=document_id,
                                                                 problem_id=problem.id,
                                                                 paragraph_id=paragraph.id)
-            # 插入问题
-            problem.save()
             # 插入段落
             paragraph.save()
             # 插入关联问题
diff --git a/apps/common/db/sql_execute.py b/apps/common/db/sql_execute.py
index 79e7de46a..b12297e1f 100644
--- a/apps/common/db/sql_execute.py
+++ b/apps/common/db/sql_execute.py
@@ -36,8 +36,9 @@ def update_execute(sql: str, params):
     """
     with connection.cursor() as cursor:
         cursor.execute(sql, params)
+        affected_rows = cursor.rowcount
         cursor.close()
-        return None
+        return affected_rows


 def select_list(sql: str, params: List):
diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py
index a98b29bf1..dc57bfaaf 100644
--- a/apps/common/event/listener_manage.py
+++ b/apps/common/event/listener_manage.py
@@ -10,11 +10,12 @@ import datetime
 import logging
 import os
 import threading
+import time
 import traceback
 from typing import List

 import django.db.models
-from django.db import models
+from django.db import models, transaction
 from django.db.models import QuerySet
 from django.db.models.functions import Substr, Reverse
 from langchain_core.embeddings import Embeddings
@@ -168,6 +169,7 @@ class ListenerManagement:
     @staticmethod
     def get_aggregation_document_status(document_id):
         def aggregation_document_status():
+            pass
             sql = get_file_content(
                 os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql'))
             native_update({'document_custom_sql': QuerySet(Document).filter(id=document_id)}, sql, with_table_name=True)
@@ -179,7 +181,8 @@ class ListenerManagement:
         def aggregation_document_status():
             sql = get_file_content(
                 os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql'))
-            native_update({'document_custom_sql': QuerySet(Document).filter(dataset_id=dataset_id)}, sql)
+            native_update({'document_custom_sql': QuerySet(Document).filter(dataset_id=dataset_id)}, sql,
+                          with_table_name=True)

         return aggregation_document_status

@@ -188,7 +191,7 @@ class ListenerManagement:
         def aggregation_document_status():
             sql = get_file_content(
                 os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql'))
-            native_update({'document_custom_sql': queryset}, sql)
+            native_update({'document_custom_sql': queryset}, sql, with_table_name=True)

         return aggregation_document_status

@@ -247,19 +250,23 @@ class ListenerManagement:
         """
         if not try_lock('embedding' + str(document_id)):
             return
-        max_kb.info(f"开始--->向量化文档:{document_id}")
-        # 批量修改状态为PADDING
-        ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.STARTED)
         try:
-            # 删除文档向量数据
-            VectorStore.get_embedding_vector().delete_by_document_id(document_id)
-
             def is_the_task_interrupted():
                 document = QuerySet(Document).filter(id=document_id).first()
                 if document is None or Status(document.status)[TaskType.EMBEDDING] == State.REVOKE:
                     return True
                 return False
+
+            if is_the_task_interrupted():
+                return
+            max_kb.info(f"开始--->向量化文档:{document_id}")
+            # 批量修改状态为PADDING
+            ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
+                                             State.STARTED)
+
+            # 删除文档向量数据
+            VectorStore.get_embedding_vector().delete_by_document_id(document_id)
+
             # 根据段落进行向量化处理
             page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 5,
                  ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
diff --git a/apps/common/handle/impl/doc_split_handle.py b/apps/common/handle/impl/doc_split_handle.py
index 350a3921a..6ac6f43f9 100644
--- a/apps/common/handle/impl/doc_split_handle.py
+++ b/apps/common/handle/impl/doc_split_handle.py
@@ -198,4 +198,4 @@ class DocSplitHandle(BaseSplitHandle):
             return self.to_md(doc, image_list, get_image_id_func())
         except BaseException as e:
             traceback.print_exception(e)
-            return ''
\ No newline at end of file
+            return f'{e}'
\ No newline at end of file
diff --git a/apps/common/handle/impl/html_split_handle.py b/apps/common/handle/impl/html_split_handle.py
index 688904567..bb69e0af0 100644
--- a/apps/common/handle/impl/html_split_handle.py
+++ b/apps/common/handle/impl/html_split_handle.py
@@ -70,4 +70,4 @@ class HTMLSplitHandle(BaseSplitHandle):
             return html2text(content)
         except BaseException as e:
             traceback.print_exception(e)
-            return ''
\ No newline at end of file
+            return f'{e}'
\ No newline at end of file
diff --git a/apps/common/handle/impl/pdf_split_handle.py b/apps/common/handle/impl/pdf_split_handle.py
index b759c6d6a..21d243058 100644
--- a/apps/common/handle/impl/pdf_split_handle.py
+++ b/apps/common/handle/impl/pdf_split_handle.py
@@ -321,4 +321,4 @@ class PdfSplitHandle(BaseSplitHandle):
             return self.handle_pdf_content(file, pdf_document)
         except BaseException as e:
             traceback.print_exception(e)
-            return ''
\ No newline at end of file
+            return f'{e}'
\ No newline at end of file
diff --git a/apps/common/handle/impl/text_split_handle.py b/apps/common/handle/impl/text_split_handle.py
index 984c4e1e9..1ae22f95f 100644
--- a/apps/common/handle/impl/text_split_handle.py
+++ b/apps/common/handle/impl/text_split_handle.py
@@ -57,4 +57,4 @@ class TextSplitHandle(BaseSplitHandle):
             return buffer.decode(detect(buffer)['encoding'])
         except BaseException as e:
             traceback.print_exception(e)
-            return ''
\ No newline at end of file
+            return f'{e}'
\ No newline at end of file
diff --git a/apps/common/util/page_utils.py b/apps/common/util/page_utils.py
index 7fc176b68..92f21849b 100644
--- a/apps/common/util/page_utils.py
+++ b/apps/common/util/page_utils.py
@@ -18,10 +18,11 @@ def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
     @param is_the_task_interrupted: 任务是否被中断
     @return:
     """
+    query = query_set.order_by("id")
     count = query_set.count()
     for i in range(0, ceil(count / page_size)):
         if is_the_task_interrupted():
             return
         offset = i * page_size
-        paragraph_list = query_set[offset: offset + page_size]
+        paragraph_list = query.all()[offset: offset + page_size]
         handler(paragraph_list)
diff --git a/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py b/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py
index c64a4db20..e47bfd60c 100644
--- a/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py
+++ b/apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py
@@ -7,6 +7,11 @@ import dataset
 from common.event import ListenerManagement
 from dataset.models import State, TaskType

+sql = """
+UPDATE "document"
+SET status ="replace"(status, '1', '3')
+"""
+

 def updateDocumentStatus(apps, schema_editor):
     ParagraphModel = apps.get_model('dataset', 'Paragraph')
@@ -43,5 +48,6 @@ class Migration(migrations.Migration):
             name='status',
             field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'),
         ),
+        migrations.RunSQL(sql),
         migrations.RunPython(updateDocumentStatus)
     ]
diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py
index 1ab74ead2..45057d9bc 100644
--- a/apps/dataset/serializers/document_serializers.py
+++ b/apps/dataset/serializers/document_serializers.py
@@ -297,6 +297,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                 ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list),
                                                  TaskType.EMBEDDING,
                                                  State.PENDING)
+                ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id__in=document_id_list),
+                                                 TaskType.EMBEDDING,
+                                                 State.PENDING)
                 embedding_by_document_list.delay(document_id_list, model_id)
             else:
                 update_embedding_dataset_id(pid_list, target_dataset_id)
@@ -613,7 +616,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
             document_id = self.data.get("document_id")
             ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
                                              State.PENDING)
-            ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id), TaskType.EMBEDDING,
+            ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id),
+                                             TaskType.EMBEDDING,
                                              State.PENDING)
             ListenerManagement.get_aggregation_document_status(document_id)()
             embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))
@@ -708,8 +712,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):

         @staticmethod
         def post_embedding(result, document_id, dataset_id):
-            model_id = get_embedding_model_id_by_dataset_id(dataset_id)
-            embedding_by_document.delay(document_id, model_id)
+            DocumentSerializers.Operate(
+                data={'dataset_id': dataset_id, 'document_id': document_id}).refresh()
             return result

         @staticmethod
@@ -907,8 +911,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
         @staticmethod
         def post_embedding(document_list, dataset_id):
             for document_dict in document_list:
-                model_id = get_embedding_model_id_by_dataset_id(dataset_id)
-                embedding_by_document.delay(document_dict.get('id'), model_id)
+                DocumentSerializers.Operate(
+                    data={'dataset_id': dataset_id, 'document_id': document_dict.get('id')}).refresh()
             return document_list

         @post(post_function=post_embedding)
diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py
index 82aacc79d..a115e544b 100644
--- a/apps/dataset/serializers/paragraph_serializers.py
+++ b/apps/dataset/serializers/paragraph_serializers.py
@@ -540,8 +540,16 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
             if with_valid:
                 self.is_valid(raise_exception=True)
             paragraph_id = self.data.get('paragraph_id')
-            QuerySet(Paragraph).filter(id=paragraph_id).delete()
-            QuerySet(ProblemParagraphMapping).filter(paragraph_id=paragraph_id).delete()
+            Paragraph.objects.filter(id=paragraph_id).delete()
+
+            problem_id = ProblemParagraphMapping.objects.filter(paragraph_id=paragraph_id).values_list('problem_id',
+                                                                                                        flat=True).first()
+
+            if problem_id is not None:
+                if ProblemParagraphMapping.objects.filter(problem_id=problem_id).count() == 1:
+                    Problem.objects.filter(id=problem_id).delete()
+            ProblemParagraphMapping.objects.filter(paragraph_id=paragraph_id).delete()
+
             update_document_char_length(self.data.get('document_id'))
             delete_embedding_by_paragraph(paragraph_id)
diff --git a/apps/dataset/task/generate.py b/apps/dataset/task/generate.py
index e81039744..6a085c448 100644
--- a/apps/dataset/task/generate.py
+++ b/apps/dataset/task/generate.py
@@ -51,21 +51,28 @@ def get_generate_problem(llm_model, prompt, post_apply=lambda: None, is_the_task
     return generate_problem


+def get_is_the_task_interrupted(document_id):
+    def is_the_task_interrupted():
+        document = QuerySet(Document).filter(id=document_id).first()
+        if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE:
+            return True
+        return False
+
+    return is_the_task_interrupted
+
+
 @celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
                  name='celery:generate_related_by_document')
 def generate_related_by_document_id(document_id, model_id, prompt):
     try:
+        is_the_task_interrupted = get_is_the_task_interrupted(document_id)
+        if is_the_task_interrupted():
+            return
         ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.GENERATE_PROBLEM,
                                          State.STARTED)
         llm_model = get_llm_model(model_id)

-        def is_the_task_interrupted():
-            document = QuerySet(Document).filter(id=document_id).first()
-            if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE:
-                return True
-            return False
-
         # 生成问题函数
         generate_problem = get_generate_problem(llm_model, prompt,
                                                 ListenerManagement.get_aggregation_document_status(
@@ -82,6 +89,12 @@
                  name='celery:generate_related_by_paragraph_list')
 def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, prompt):
     try:
+        is_the_task_interrupted = get_is_the_task_interrupted(document_id)
+        if is_the_task_interrupted():
+            ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
+                                             TaskType.GENERATE_PROBLEM,
+                                             State.REVOKED)
+            return
         ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.GENERATE_PROBLEM,
                                          State.STARTED)
diff --git a/apps/embedding/task/embedding.py b/apps/embedding/task/embedding.py
index b6d5dfb75..3e63c26b2 100644
--- a/apps/embedding/task/embedding.py
+++ b/apps/embedding/task/embedding.py
@@ -102,6 +102,7 @@ def embedding_by_dataset(dataset_id, model_id):
         max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
         for document in document_list:
             try:
+                print(document.id, model_id)
                 embedding_by_document.delay(document.id, model_id)
             except Exception as e:
                 pass
diff --git a/apps/smartdoc/settings/lib.py b/apps/smartdoc/settings/lib.py
index e7b6d39dd..a4c1aaabb 100644
--- a/apps/smartdoc/settings/lib.py
+++ b/apps/smartdoc/settings/lib.py
@@ -32,9 +32,11 @@ CELERY_WORKER_REDIRECT_STDOUTS = True
 CELERY_WORKER_REDIRECT_STDOUTS_LEVEL = "INFO"
 CELERY_TASK_SOFT_TIME_LIMIT = 3600
 CELERY_WORKER_CANCEL_LONG_RUNNING_TASKS_ON_CONNECTION_LOSS = True
+CELERY_ACKS_LATE = True
+celery_once_path = os.path.join(celery_data_dir, "celery_once")
 CELERY_ONCE = {
     'backend': 'celery_once.backends.File',
-    'settings': {'location': os.path.join(celery_data_dir, "celery_once")}
+    'settings': {'location': celery_once_path}
 }
 CELERY_BROKER_CONNECTION_RETRY_ON_STARTUP = True
 CELERY_LOG_DIR = os.path.join(PROJECT_DIR, 'logs', 'celery')
diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue
index 46a8f3946..673de74d7 100644
--- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue
+++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue
@@ -63,11 +63,10 @@
             {{ f.label }}: {{ f.value }}
-            上传的文档:
+            文档:
-            上传的图片:
+            图片: