feat: Document status

shaohuzhang1 2024-11-26 12:08:13 +08:00 committed by GitHub
parent 663ffcdc4d
commit e978c83f02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 828 additions and 135 deletions

View File

@ -12,7 +12,7 @@ from django.db import DEFAULT_DB_ALIAS, models, connections
from django.db.models import QuerySet
from common.db.compiler import AppSQLCompiler
from common.db.sql_execute import select_one, select_list
from common.db.sql_execute import select_one, select_list, update_execute
from common.response.result import Page
@ -109,6 +109,24 @@ def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
return select_list(exec_sql, exec_params)
def native_update(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
with_table_name=False):
"""
复杂查询
:param with_table_name: 生成sql是否包含表名
:param queryset: 查询条件构造器
:param select_string: 查询前缀 不包括 where limit 等信息
:param field_replace_dict: 需要替换的字段
:return: 查询结果
"""
if isinstance(queryset, Dict):
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name)
else:
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name)
return update_execute(exec_sql, exec_params)
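A minimal usage sketch (the statement and the bound document_id are illustrative, not from this commit): as with native_search above, the WHERE clause is rendered from the queryset and appended to the statement prefix, so an UPDATE can reuse Django's query builder.
from django.db.models import QuerySet
from dataset.models import Paragraph
# Hypothetical: zero the hit counter for one document's paragraphs.
# The queryset supplies the WHERE clause; the prefix carries the UPDATE ... SET part.
native_update(QuerySet(Paragraph).filter(document_id=document_id),
              'UPDATE "paragraph" SET hit_num = 0')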
def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler):
"""
分页查询

View File

@ -9,24 +9,29 @@
import datetime
import logging
import os
import threading
import traceback
from typing import List
import django.db.models
from django.db import models
from django.db.models import QuerySet
from django.db.models.functions import Substr, Reverse
from langchain_core.embeddings import Embeddings
from common.config.embedding_config import VectorStore
from common.db.search import native_search, get_dynamics_model
from common.event.common import embedding_poxy
from common.db.search import native_search, get_dynamics_model, native_update
from common.db.sql_execute import sql_execute, update_execute
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
from common.util.page_utils import page
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State
from embedding.models import SourceType, SearchMode
from smartdoc.conf import PROJECT_DIR
max_kb_error = logging.getLogger(__file__)
max_kb = logging.getLogger(__file__)
lock = threading.Lock()
class SyncWebDatasetArgs:
@ -114,7 +119,8 @@ class ListenerManagement:
@param embedding_model: embedding model
"""
max_kb.info(f"开始--->向量化段落:{paragraph_id}")
status = Status.success
# Update to STARTED state
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, State.STARTED)
try:
data_list = native_search(
{'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
@ -125,16 +131,22 @@ class ListenerManagement:
# Delete the paragraph's vectors
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
def is_save_function():
return QuerySet(Paragraph).filter(id=paragraph_id).exists()
def is_the_task_interrupted():
_paragraph = QuerySet(Paragraph).filter(id=paragraph_id).first()
if _paragraph is None or Status(_paragraph.status)[TaskType.EMBEDDING] == State.REVOKE:
return True
return False
# Batch embedding
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_the_task_interrupted)
# Update to SUCCESS state
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
State.SUCCESS)
except Exception as e:
max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
State.FAILURE)
finally:
QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status})
max_kb.info(f'结束--->向量化段落:{paragraph_id}')
@staticmethod
@ -142,6 +154,91 @@ class ListenerManagement:
# Batch embedding
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, lambda: True)
@staticmethod
def get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, post_apply=lambda: None):
def embedding_paragraph_apply(paragraph_list):
for paragraph in paragraph_list:
if is_the_task_interrupted():
break
ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model)
post_apply()
return embedding_paragraph_apply
@staticmethod
def get_aggregation_document_status(document_id):
def aggregation_document_status():
sql = get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql'))
native_update({'document_custom_sql': QuerySet(Document).filter(id=document_id),
'default_sql': QuerySet(Document).filter(id=document_id)}, sql, with_table_name=True)
return aggregation_document_status
@staticmethod
def get_aggregation_document_status_by_dataset_id(dataset_id):
def aggregation_document_status():
sql = get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql'))
native_update({'document_custom_sql': QuerySet(Document).filter(dataset_id=dataset_id),
'default_sql': QuerySet(Document).filter(dataset_id=dataset_id)}, sql)
return aggregation_document_status
@staticmethod
def get_aggregation_document_status_by_query_set(queryset):
def aggregation_document_status():
sql = get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql'))
native_update({'document_custom_sql': queryset, 'default_sql': queryset}, sql)
return aggregation_document_status
@staticmethod
def post_update_document_status(document_id, task_type: TaskType):
_document = QuerySet(Document).filter(id=document_id).first()
status = Status(_document.status)
if status[task_type] == State.REVOKE:
status[task_type] = State.REVOKED
else:
status[task_type] = State.SUCCESS
for item in _document.status_meta.get('aggs', []):
agg_status = item.get('status')
agg_count = item.get('count')
if Status(agg_status)[task_type] == State.FAILURE and agg_count > 0:
status[task_type] = State.FAILURE
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), task_type, status[task_type])
ListenerManagement.update_status(QuerySet(Paragraph).annotate(
reversed_status=Reverse('status'),
task_type_status=Substr('reversed_status', task_type.value, 1),
).filter(task_type_status=State.REVOKE.value).filter(document_id=document_id).values('id'),
task_type,
State.REVOKED)
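For reference, a sketch of the status_meta payload this method consumes; the shape follows the aggregation SQL added in this commit (one {status, count} entry per distinct paragraph status string), while the concrete values are hypothetical.
status_meta = {
    'aggs': [
        {'status': 'nn2', 'count': 8},  # 8 paragraphs: EMBEDDING succeeded
        {'status': 'nn3', 'count': 1},  # 1 paragraph: EMBEDDING failed
    ]
}
# One FAILURE entry with count > 0 is enough to mark the whole document FAILURE.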
@staticmethod
def update_status(query_set: QuerySet, taskType: TaskType, state: State):
exec_sql = get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_paragraph_status.sql'))
bit_number = len(TaskType)
up_index = taskType.value - 1
next_index = taskType.value + 1
current_index = taskType.value
status_number = state.value
params_dict = {'${bit_number}': bit_number, '${up_index}': up_index,
'${status_number}': status_number, '${next_index}': next_index,
'${table_name}': query_set.model._meta.db_table, '${current_index}': current_index}
for key in params_dict:
_value_ = params_dict[key]
exec_sql = exec_sql.replace(key, str(_value_))
lock.acquire()
try:
native_update(query_set, exec_sql)
finally:
lock.release()
@staticmethod
def embedding_by_document(document_id, embedding_model: Embeddings):
"""
@ -153,33 +250,29 @@ class ListenerManagement:
if not try_lock('embedding' + str(document_id)):
return
max_kb.info(f"开始--->向量化文档:{document_id}")
QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding})
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding})
status = Status.success
# Update the document's embedding status to STARTED
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.STARTED)
try:
data_list = native_search(
{'problem': QuerySet(
get_dynamics_model({'paragraph.document_id': django.db.models.CharField()})).filter(
**{'paragraph.document_id': document_id}),
'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
# Delete the document's vector data
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
def is_save_function():
return QuerySet(Document).filter(id=document_id).exists()
def is_the_task_interrupted():
document = QuerySet(Document).filter(id=document_id).first()
if document is None or Status(document.status)[TaskType.EMBEDDING] == State.REVOKE:
return True
return False
# Batch embedding
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
# Embed paragraph by paragraph, page by page
page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 5,
ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
ListenerManagement.get_aggregation_document_status(
document_id)),
is_the_task_interrupted)
except Exception as e:
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
finally:
# Update status
QuerySet(Document).filter(id=document_id).update(
**{'status': status, 'update_time': datetime.datetime.now()})
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status})
ListenerManagement.post_update_document_status(document_id, TaskType.EMBEDDING)
ListenerManagement.get_aggregation_document_status(document_id)()
max_kb.info(f"结束--->向量化文档:{document_id}")
un_lock('embedding' + str(document_id))

View File

@ -0,0 +1,27 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file: page_utils.py
@date: 2024/11/21 10:32
@desc:
"""
from math import ceil
def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
"""
@param query_set: 查询query_set
@param page_size: 每次查询大小
@param handler: 数据处理器
@param is_the_task_interrupted: 任务是否被中断
@return:
"""
count = query_set.count()
for i in range(0, ceil(count / page_size)):
if is_the_task_interrupted():
return
offset = i * page_size
paragraph_list = query_set[offset: offset + page_size]
handler(paragraph_list)
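A hypothetical usage, mirroring how the embedding listener drives this helper (the handler body is illustrative):
from django.db.models import QuerySet
from dataset.models import Paragraph
def handle(batch):
    for row in batch:
        print(row.get('id'))
# Process paragraph ids five at a time; a revoke check can stop the loop early.
page(QuerySet(Paragraph).values('id'), 5, handle, lambda: False)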

View File

@ -0,0 +1,34 @@
# Generated by Django 4.2.15 on 2024-11-22 14:44
import dataset.models.data_set
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('dataset', '0010_file_meta'),
]
operations = [
migrations.AddField(
model_name='document',
name='status_meta',
field=models.JSONField(default=dataset.models.data_set.default_status_meta, verbose_name='状态统计数据'),
),
migrations.AddField(
model_name='paragraph',
name='status_meta',
field=models.JSONField(default=dataset.models.data_set.default_status_meta, verbose_name='状态数据'),
),
migrations.AlterField(
model_name='document',
name='status',
field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'),
),
migrations.AlterField(
model_name='paragraph',
name='status',
field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'),
),
]

View File

@ -7,6 +7,7 @@
@desc: Dataset
"""
import uuid
from enum import Enum
from django.db import models
from django.db.models.signals import pre_delete
@ -18,13 +19,62 @@ from setting.models import Model
from users.models import User
class Status(models.TextChoices):
"""订单类型"""
embedding = 0, '导入中'
success = 1, '已完成'
error = 2, '导入失败'
queue_up = 3, '排队中'
generating = 4, '生成问题中'
class TaskType(Enum):
# Embedding
EMBEDDING = 1
# Generate problems
GENERATE_PROBLEM = 2
# Sync
SYNC = 3
class State(Enum):
# Waiting
PENDING = '0'
# Running
STARTED = '1'
# Succeeded
SUCCESS = '2'
# Failed
FAILURE = '3'
# Revocation requested
REVOKE = '4'
# Revoked
REVOKED = '5'
# Ignored
IGNORED = 'n'
class Status:
type_cls = TaskType
state_cls = State
def __init__(self, status: str = None):
self.task_status = {}
status_list = list(status[::-1] if status is not None else '')
for _type in self.type_cls:
index = _type.value - 1
_state = self.state_cls(status_list[index] if len(status_list) > index else 'n')
self.task_status[_type] = _state
@staticmethod
def of(status: str):
return Status(status)
def __str__(self):
result = []
for _type in sorted(self.type_cls, key=lambda item: item.value, reverse=True):
result.insert(len(self.type_cls) - _type.value, self.task_status[_type].value)
return ''.join(result)
def __setitem__(self, key, value):
self.task_status[key] = value
def __getitem__(self, item):
return self.task_status[item]
def update_status(self, task_type: TaskType, state: State):
self.task_status[task_type] = state
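A quick sanity check of the encoding, using the classes above: the string is read right-to-left by task value, missing positions default to IGNORED, and __str__ reverses it back.
s = Status('')  # no history yet: every task reads as 'n' (IGNORED)
assert str(s) == 'nnn'
s[TaskType.EMBEDDING] = State.SUCCESS
s[TaskType.GENERATE_PROBLEM] = State.STARTED
assert str(s) == 'n12'  # SYNC | GENERATE_PROBLEM | EMBEDDING
assert Status('n12')[TaskType.EMBEDDING] == State.SUCCESS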
class Type(models.TextChoices):
@ -42,6 +92,10 @@ def default_model():
return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab')
def default_status_meta():
return {"state_time": {}}
class DataSet(AppModelMixin):
"""
数据集表
@ -68,8 +122,8 @@ class Document(AppModelMixin):
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
name = models.CharField(max_length=150, verbose_name="文档名称")
char_length = models.IntegerField(verbose_name="文档字符数 冗余字段")
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
default=Status.queue_up)
status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__)
status_meta = models.JSONField(verbose_name="状态统计数据", default=default_status_meta)
is_active = models.BooleanField(default=True)
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
@ -94,8 +148,8 @@ class Paragraph(AppModelMixin):
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
content = models.CharField(max_length=102400, verbose_name="段落内容")
title = models.CharField(max_length=256, verbose_name="标题", default="")
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
default=Status.embedding)
status = models.CharField(verbose_name='状态', max_length=20, default=Status('').__str__)
status_meta = models.JSONField(verbose_name="状态数据", default=default_status_meta)
hit_num = models.IntegerField(verbose_name="命中次数", default=0)
is_active = models.BooleanField(default=True)
@ -145,7 +199,6 @@ class File(AppModelMixin):
meta = models.JSONField(verbose_name="文件关联数据", default=dict)
class Meta:
db_table = "file"
@ -161,7 +214,6 @@ class File(AppModelMixin):
return result['data']
@receiver(pre_delete, sender=File)
def on_delete_file(sender, instance, **kwargs):
select_one(f'SELECT lo_unlink({instance.loid})', [])

View File

@ -27,6 +27,7 @@ from application.models import ApplicationDatasetMapping
from common.config.embedding_config import VectorStore
from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post, flat_map, valid_license
@ -34,7 +35,8 @@ from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status, \
TaskType, State
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
@ -733,9 +735,13 @@ class DataSetSerializers(serializers.ModelSerializer):
def re_embedding(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
QuerySet(Document).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up})
QuerySet(Paragraph).filter(dataset_id=self.data.get('id')).update(**{'status': Status.queue_up})
ListenerManagement.update_status(QuerySet(Document).filter(dataset_id=self.data.get('id')),
TaskType.EMBEDDING,
State.PENDING)
ListenerManagement.update_status(QuerySet(Paragraph).filter(dataset_id=self.data.get('id')),
TaskType.EMBEDDING,
State.PENDING)
ListenerManagement.get_aggregation_document_status_by_dataset_id(self.data.get('id'))()
embedding_model_id = get_embedding_model_id_by_dataset_id(self.data.get('id'))
embedding_by_dataset.delay(self.data.get('id'), embedding_model_id)

View File

@ -19,6 +19,7 @@ from celery_once import AlreadyQueued
from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
from django.db.models.functions import Substr, Reverse
from django.http import HttpResponse
from drf_yasg import openapi
from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE
@ -26,6 +27,7 @@ from rest_framework import serializers
from xlwt import Utils
from common.db.search import native_search, native_page_search
from common.event import ListenerManagement
from common.event.common import work_thread_pool
from common.exception.app_exception import AppApiException
from common.handle.impl.doc_split_handle import DocSplitHandle
@ -44,7 +46,8 @@ from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Image, \
TaskType, State
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_id_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
@ -67,6 +70,19 @@ class FileBufferHandle:
return self.buffer
class CancelInstanceSerializer(serializers.Serializer):
type = serializers.IntegerField(required=True, error_messages=ErrMessage.boolean(
"任务类型"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
_type = self.data.get('type')
try:
TaskType(_type)
except Exception as e:
raise AppApiException(500, '任务类型不支持')
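A small validation sketch (values are illustrative): any integer outside TaskType is rejected.
CancelInstanceSerializer(data={'type': 1}).is_valid()  # ok: 1=EMBEDDING, 2=GENERATE_PROBLEM, 3=SYNC
CancelInstanceSerializer(data={'type': 9}).is_valid()  # raises AppApiException: unsupported task type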
class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer):
meta = serializers.DictField(required=False)
name = serializers.CharField(required=False, max_length=128, min_length=1,
@ -278,7 +294,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
# Update vector data
if model_id:
delete_embedding_by_paragraph_ids(pid_list)
QuerySet(Document).filter(id__in=document_id_list).update(status=Status.queue_up)
ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list),
TaskType.EMBEDDING,
State.PENDING)
embedding_by_document_list.delay(document_id_list, model_id)
else:
update_embedding_dataset_id(pid_list, target_dataset_id)
@ -404,11 +422,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
self.is_valid(raise_exception=True)
document_id = self.data.get('document_id')
document = QuerySet(Document).filter(id=document_id).first()
state = State.SUCCESS
if document.type != Type.web:
return True
try:
document.status = Status.queue_up
document.save()
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
TaskType.SYNC,
State.PENDING)
source_url = document.meta.get('source_url')
selector_list = document.meta.get('selector').split(
" ") if 'selector' in document.meta and document.meta.get('selector') is not None else []
@ -442,13 +462,18 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
if with_embedding:
embedding_model_id = get_embedding_model_id_by_dataset_id(document.dataset_id)
embedding_by_document.delay(document_id, embedding_model_id)
else:
document.status = Status.error
document.save()
state = State.FAILURE
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
document.status = Status.error
document.save()
state = State.FAILURE
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
TaskType.SYNC,
state)
ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id),
TaskType.SYNC,
state)
return True
class Operate(ApiMixin, serializers.Serializer):
@ -586,14 +611,35 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
if with_valid:
self.is_valid(raise_exception=True)
document_id = self.data.get("document_id")
QuerySet(Document).filter(id=document_id).update(**{'status': Status.queue_up})
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.queue_up})
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
State.PENDING)
ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id), TaskType.EMBEDDING,
State.PENDING)
ListenerManagement.get_aggregation_document_status(document_id)()
embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))
try:
embedding_by_document.delay(document_id, embedding_model_id)
except AlreadyQueued as e:
raise AppApiException(500, "任务正在执行中,请勿重复下发")
def cancel(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
CancelInstanceSerializer(data=instance).is_valid()
document_id = self.data.get("document_id")
ListenerManagement.update_status(QuerySet(Paragraph).annotate(
reversed_status=Reverse('status'),
task_type_status=Substr('reversed_status', TaskType(instance.get('type')).value, 1),
).filter(task_type_status__in=[State.PENDING.value, State.STARTED.value]).filter(
document_id=document_id).values('id'),
TaskType(instance.get('type')),
State.REVOKE)
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType(instance.get('type')),
State.REVOKE)
return True
@transaction.atomic
def delete(self):
document_id = self.data.get("document_id")
@ -955,15 +1001,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
self.is_valid(raise_exception=True)
document_id_list = instance.get("id_list")
with transaction.atomic():
Document.objects.filter(id__in=document_id_list).update(status=Status.queue_up)
Paragraph.objects.filter(document_id__in=document_id_list).update(status=Status.queue_up)
dataset_id = self.data.get('dataset_id')
embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=dataset_id)
for document_id in document_id_list:
try:
embedding_by_document.delay(document_id, embedding_model_id)
DocumentSerializers.Operate(
data={'dataset_id': dataset_id, 'document_id': document_id}).refresh()
except AlreadyQueued as e:
raise AppApiException(500, "任务正在执行中,请勿重复下发")
pass
class GenerateRelated(ApiMixin, serializers.Serializer):
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
@ -978,7 +1022,13 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
if with_valid:
self.is_valid(raise_exception=True)
document_id = self.data.get('document_id')
QuerySet(Document).filter(id=document_id).update(status=Status.queue_up)
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
TaskType.GENERATE_PROBLEM,
State.PENDING)
ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id),
TaskType.GENERATE_PROBLEM,
State.PENDING)
ListenerManagement.get_aggregation_document_status(document_id)()
try:
generate_related_by_document_id.delay(document_id, model_id, prompt)
except AlreadyQueued as e:

View File

@ -16,11 +16,12 @@ from drf_yasg import openapi
from rest_framework import serializers
from common.db.search import page_search
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet, TaskType, State
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
ProblemParagraphManage, get_embedding_model_id_by_dataset_id
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
@ -722,7 +723,6 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
}
)
class BatchGenerateRelated(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
@ -734,10 +734,16 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
paragraph_id_list = instance.get("paragraph_id_list")
model_id = instance.get("model_id")
prompt = instance.get("prompt")
document_id = self.data.get('document_id')
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
TaskType.GENERATE_PROBLEM,
State.PENDING)
ListenerManagement.update_status(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
TaskType.GENERATE_PROBLEM,
State.PENDING)
ListenerManagement.get_aggregation_document_status(document_id)()
try:
generate_related_by_paragraph_id_list.delay(paragraph_id_list, model_id, prompt)
generate_related_by_paragraph_id_list.delay(document_id, paragraph_id_list, model_id,
prompt)
except AlreadyQueued as e:
raise AppApiException(500, "任务正在执行中,请勿重复下发")

View File

@ -1,6 +1,7 @@
SELECT
"document".* ,
to_json("document"."meta") as meta,
to_json("document"."status_meta") as status_meta,
(SELECT "count"("id") FROM "paragraph" WHERE document_id="document"."id") as "paragraph_count"
FROM
"document" "document"

View File

@ -0,0 +1,25 @@
UPDATE "document" "document"
SET status_meta = jsonb_set ( "document".status_meta, '{aggs}', tmp.status_meta )
FROM
(
SELECT COALESCE
( jsonb_agg ( jsonb_delete ( ( row_to_json ( record ) :: JSONB ), 'document_id' ) ), '[]' :: JSONB ) AS status_meta,
document_id AS document_id
FROM
(
SELECT
"paragraph".status,
"count" ( "paragraph"."id" ),
"document"."id" AS document_id
FROM
"document" "document"
LEFT JOIN "paragraph" "paragraph" ON "document"."id" = paragraph.document_id
${document_custom_sql}
GROUP BY
"paragraph".status,
"document"."id"
) record
GROUP BY
document_id
) tmp
${default_sql}
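For context, this template is driven by native_update with a dict of querysets (see get_aggregation_document_status above), and each key's rendered WHERE clause replaces the matching marker: '${document_custom_sql}' filters the inner aggregation and '${default_sql}' the outer UPDATE.
from django.db.models import QuerySet
from dataset.models import Document
native_update({'document_custom_sql': QuerySet(Document).filter(id=document_id),
               'default_sql': QuerySet(Document).filter(id=document_id)},
              sql, with_table_name=True)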

View File

@ -0,0 +1,13 @@
UPDATE "${table_name}"
SET status = reverse (
SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM 1 FOR ${up_index} ) || ${status_number} || SUBSTRING ( reverse ( LPAD( status, ${bit_number}, 'n' ) ) :: TEXT FROM ${next_index} )
),
status_meta = jsonb_set (
"${table_name}".status_meta,
'{state_time,${current_index}}',
jsonb_set (
COALESCE ( "${table_name}".status_meta #> '{state_time,${current_index}}', jsonb_build_object ( '${status_number}', now( ) ) ),
'{${status_number}}',
CONCAT ( '"', now( ), '"' ) :: JSONB
)
)
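To make the template concrete, here is a sketch of the substitution ListenerManagement.update_status performs, shown for TaskType.EMBEDDING with State.STARTED and three task types; the rendered SQL is what actually runs.
template = open('update_paragraph_status.sql').read()  # the template above
params = {'${table_name}': 'paragraph',
          '${bit_number}': '3',     # len(TaskType): pad the status string to 3 chars
          '${up_index}': '0',       # characters kept before the updated position
          '${current_index}': '1',  # TaskType.EMBEDDING.value
          '${next_index}': '2',     # where the tail of the string resumes
          '${status_number}': '1'}  # State.STARTED.value
for key, value in params.items():
    template = template.replace(key, value)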

View File

@ -26,3 +26,14 @@ class DocumentApi(ApiMixin):
'directly_return_similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title="直接返回相似度")
}
)
class Cancel(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
properties={
'type': openapi.Schema(type=openapi.TYPE_INTEGER, title="任务类型",
description="1|2|3 1:向量化|2:生成问题|3:同步文档")
}
)

View File

@ -1,12 +1,14 @@
import logging
from math import ceil
import traceback
from celery_once import QueueOnce
from django.db.models import QuerySet
from langchain_core.messages import HumanMessage
from common.config.embedding_config import ModelManage
from dataset.models import Paragraph, Document, Status
from common.event import ListenerManagement
from common.util.page_utils import page
from dataset.models import Paragraph, Document, Status, TaskType, State
from dataset.task.tools import save_problem
from ops import celery_app
from setting.models import Model
@ -21,44 +23,79 @@ def get_llm_model(model_id):
return ModelManage.get_model(model_id, lambda _id: get_model(model))
def generate_problem_by_paragraph(paragraph, llm_model, prompt):
try:
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
State.STARTED)
res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))])
if (res.content is None) or (len(res.content) == 0):
return
problems = res.content.split('\n')
for problem in problems:
save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem)
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
State.SUCCESS)
except Exception as e:
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph.id), TaskType.GENERATE_PROBLEM,
State.FAILURE)
def get_generate_problem(llm_model, prompt, post_apply=lambda: None, is_the_task_interrupted=lambda: False):
def generate_problem(paragraph_list):
for paragraph in paragraph_list:
if is_the_task_interrupted():
return
generate_problem_by_paragraph(paragraph, llm_model, prompt)
post_apply()
return generate_problem
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
name='celery:generate_related_by_document')
def generate_related_by_document_id(document_id, model_id, prompt):
llm_model = get_llm_model(model_id)
offset = 0
page_size = 10
QuerySet(Document).filter(id=document_id).update(status=Status.generating)
try:
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
TaskType.GENERATE_PROBLEM,
State.STARTED)
llm_model = get_llm_model(model_id)
count = QuerySet(Paragraph).filter(document_id=document_id).count()
for i in range(0, ceil(count / page_size)):
paragraph_list = QuerySet(Paragraph).filter(document_id=document_id).all()[offset:offset + page_size]
offset += page_size
for paragraph in paragraph_list:
res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))])
if (res.content is None) or (len(res.content) == 0):
continue
problems = res.content.split('\n')
for problem in problems:
save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem)
QuerySet(Document).filter(id=document_id).update(status=Status.success)
def is_the_task_interrupted():
document = QuerySet(Document).filter(id=document_id).first()
if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE:
return True
return False
# Problem-generation callback
generate_problem = get_generate_problem(llm_model, prompt,
ListenerManagement.get_aggregation_document_status(
document_id), is_the_task_interrupted)
page(QuerySet(Paragraph).filter(document_id=document_id), 10, generate_problem, is_the_task_interrupted)
except Exception as e:
max_kb_error.error(f'根据文档生成问题:{document_id}出现错误{str(e)}{traceback.format_exc()}')
finally:
ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM)
max_kb.info(f"结束--->生成问题:{document_id}")
@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']},
name='celery:generate_related_by_paragraph_list')
def generate_related_by_paragraph_id_list(paragraph_id_list, model_id, prompt):
llm_model = get_llm_model(model_id)
offset = 0
page_size = 10
count = QuerySet(Paragraph).filter(id__in=paragraph_id_list).count()
for i in range(0, ceil(count / page_size)):
paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list).all()[offset:offset + page_size]
offset += page_size
for paragraph in paragraph_list:
res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))])
if (res.content is None) or (len(res.content) == 0):
continue
problems = res.content.split('\n')
for problem in problems:
save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem)
def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, prompt):
try:
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
TaskType.GENERATE_PROBLEM,
State.STARTED)
llm_model = get_llm_model(model_id)
# Problem-generation callback
generate_problem = get_generate_problem(llm_model, prompt, ListenerManagement.get_aggregation_document_status(
document_id))
def is_the_task_interrupted():
document = QuerySet(Document).filter(id=document_id).first()
if document is None or Status(document.status)[TaskType.GENERATE_PROBLEM] == State.REVOKE:
return True
return False
page(QuerySet(Paragraph).filter(id__in=paragraph_id_list), 10, generate_problem, is_the_task_interrupted)
finally:
ListenerManagement.post_update_document_status(document_id, TaskType.GENERATE_PROBLEM)

View File

@ -37,6 +37,7 @@ urlpatterns = [
name="document_export"),
path('dataset/<str:dataset_id>/document/<str:document_id>/sync', views.Document.SyncWeb.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/refresh', views.Document.Refresh.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/cancel_task', views.Document.CancelTask.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()),
path('dataset/<str:dataset_id>/document/batch_generate_related', views.Document.BatchGenerateRelated.as_view()),
path(
@ -45,7 +46,8 @@ urlpatterns = [
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/_batch', views.Paragraph.Batch.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>',
views.Paragraph.Page.as_view(), name='paragraph_page'),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/batch_generate_related', views.Paragraph.BatchGenerateRelated.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/batch_generate_related',
views.Paragraph.BatchGenerateRelated.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<paragraph_id>',
views.Paragraph.Operate.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<str:paragraph_id>/problem',

View File

@ -218,6 +218,26 @@ class Document(APIView):
DocumentSerializers.Sync(data={'document_id': document_id, 'dataset_id': dataset_id}).sync(
))
class CancelTask(APIView):
authentication_classes = [TokenAuth]
@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="取消任务",
operation_id="取消任务",
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
request_body=DocumentApi.Cancel.get_request_body_api(),
responses=result.get_default_response(),
tags=["知识库/文档"]
)
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def put(self, request: Request, dataset_id: str, document_id: str):
return result.success(
DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).cancel(
request.data
))
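A hypothetical client call; the URL prefix, ids and auth header are placeholders, not part of this commit.
import requests
# Ask the server to revoke the document's embedding task (type 1).
requests.put(
    f'{base_url}/dataset/{dataset_id}/document/{document_id}/cancel_task',
    json={'type': 1},
    headers={'AUTHORIZATION': token},
)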
class Refresh(APIView):
authentication_classes = [TokenAuth]

View File

@ -86,20 +86,20 @@ class BaseVectorStore(ABC):
for child_array in result:
self._batch_save(child_array, embedding, lambda: True)
def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_save_function):
def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
"""
批量插入
@param data_list: 数据列表
@param embedding: 向量化处理器
@param is_save_function:
@param is_the_task_interrupted: 判断是否中断任务
:return: bool
"""
self.save_pre_handler()
chunk_list = chunk_data_list(data_list)
result = sub_array(chunk_list)
for child_array in result:
if is_save_function():
self._batch_save(child_array, embedding, is_save_function)
if not is_the_task_interrupted():
self._batch_save(child_array, embedding, is_the_task_interrupted)
else:
break
@ -110,7 +110,7 @@ class BaseVectorStore(ABC):
pass
@abstractmethod
def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function):
def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
pass
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],

View File

@ -57,7 +57,7 @@ class PGVector(BaseVectorStore):
embedding.save()
return True
def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function):
def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
texts = [row.get('text') for row in text_list]
embeddings = embedding.embed_documents(texts)
embedding_list = [Embedding(id=uuid.uuid1(),
@ -70,7 +70,7 @@ class PGVector(BaseVectorStore):
embedding=embeddings[index],
search_vector=to_ts_vector(text_list[index]['text'])) for index in
range(0, len(texts))]
if is_save_function():
if not is_the_task_interrupted():
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
return True

View File

@ -208,6 +208,7 @@ class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler):
f.flush()
def handle_task_start(self, task_id):
print('handle_task_start')
log_path = get_celery_task_log_path(task_id)
thread_id = self.get_current_thread_id()
self.task_id_thread_id_mapper[task_id] = thread_id
@ -215,6 +216,7 @@ class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler):
self.thread_id_fd_mapper[thread_id] = f
def handle_task_end(self, task_id):
print('handle_task_end')
ident_id = self.task_id_thread_id_mapper.get(task_id, '')
f = self.thread_id_fd_mapper.pop(ident_id, None)
if f and not f.closed:

View File

@ -5,7 +5,7 @@ import os
from celery import subtask
from celery.signals import (
worker_ready, worker_shutdown, after_setup_logger
worker_ready, worker_shutdown, after_setup_logger, task_revoked, task_prerun
)
from django.core.cache import cache
from django_celery_beat.models import PeriodicTask
@ -61,3 +61,15 @@ def add_celery_logger_handler(sender=None, logger=None, loglevel=None, format=No
formatter = logging.Formatter(format)
task_handler.setFormatter(formatter)
logger.addHandler(task_handler)
@task_revoked.connect
def on_task_revoked(request, terminated, signum, expired, **kwargs):
print('task_revoked', terminated)
@task_prerun.connect
def on_task_start(sender, task_id, **kwargs):
pass
# sender.update_state(state='REVOKED',
# meta={'exc_type': 'Exception', 'exc': 'Exception', 'message': '暂停任务', 'exc_message': ''})

View File

@ -322,8 +322,17 @@ const batchGenerateRelated: (
data: any,
loading?: Ref<boolean>
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
return put(`${prefix}/${dataset_id}/document/batch_generate_related`, data, undefined, loading)
}
const cancelTask: (
dataset_id: string,
document_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<boolean>> = (dataset_id, document_id, data, loading) => {
return put(
`${prefix}/${dataset_id}/document/batch_generate_related`,
`${prefix}/${dataset_id}/document/${document_id}/cancel_task`,
data,
undefined,
loading
@ -352,5 +361,6 @@ export default {
postTableDocument,
exportDocument,
batchRefresh,
batchGenerateRelated
batchGenerateRelated,
cancelTask
}

View File

@ -1307,5 +1307,26 @@ export const iconMap: any = {
)
])
}
},
'app-close': {
iconReader: () => {
return h('i', [
h(
'svg',
{
style: { height: '100%', width: '100%' },
viewBox: '0 0 16 16',
version: '1.1',
xmlns: 'http://www.w3.org/2000/svg'
},
[
h('path', {
d: 'M7.96141 6.98572L12.4398 2.50738C12.5699 2.3772 12.781 2.3772 12.9112 2.50738L13.3826 2.97878C13.5127 3.10895 13.5127 3.32001 13.3826 3.45018L8.90422 7.92853L13.3826 12.4069C13.5127 12.537 13.5127 12.7481 13.3826 12.8783L12.9112 13.3497C12.781 13.4799 12.5699 13.4799 12.4398 13.3497L7.96141 8.87134L3.48307 13.3497C3.35289 13.4799 3.14184 13.4799 3.01166 13.3497L2.54026 12.8783C2.41008 12.7481 2.41008 12.537 2.54026 12.4069L7.0186 7.92853L2.54026 3.45018C2.41008 3.32001 2.41008 3.10895 2.54026 2.97878L3.01166 2.50738C3.14184 2.3772 3.35289 2.3772 3.48307 2.50738L7.96141 6.98572Z',
fill: 'currentColor'
})
]
)
])
}
}
}

ui/src/utils/status.ts Normal file
View File

@ -0,0 +1,68 @@
import { type Dict } from '@/api/type/common'
interface TaskTypeInterface {
// Embedding
EMBEDDING: number
// Generate problems
GENERATE_PROBLEM: number
// Sync
SYNC: number
}
interface StateInterface {
// Waiting
PENDING: '0'
// Running
STARTED: '1'
// Succeeded
SUCCESS: '2'
// Failed
FAILURE: '3'
// Revocation requested
REVOKE: '4'
// Revoked
REVOKED: '5'
IGNORED: 'n'
}
const TaskType: TaskTypeInterface = {
EMBEDDING: 1,
GENERATE_PROBLEM: 2,
SYNC: 3
}
const State: StateInterface = {
// Waiting
PENDING: '0',
// Running
STARTED: '1',
// Succeeded
SUCCESS: '2',
// Failed
FAILURE: '3',
// Revocation requested
REVOKE: '4',
// Revoked
REVOKED: '5',
IGNORED: 'n'
}
class Status {
task_status: Dict<any>
constructor(status?: string) {
if (!status) {
status = ''
}
status = status.split('').reverse().join('')
this.task_status = {}
for (let key in TaskType) {
const value = TaskType[key as keyof TaskTypeInterface]
const index = value - 1
this.task_status[value] = status[index] ? status[index] : 'n'
}
}
toString() {
const r = []
for (let key in TaskType) {
const value = TaskType[key as keyof TaskTypeInterface]
r.push(this.task_status[value])
}
return r.reverse().join('')
}
}
export { Status, State, TaskType, TaskTypeInterface, StateInterface }

View File

@ -0,0 +1,167 @@
<template>
<el-popover placement="top" :width="450" trigger="hover">
<template #default>
<el-row :gutter="3" v-for="status in statusTable" :key="status.type">
<el-col :span="4">{{ taskTypeMap[status.type] }} </el-col>
<el-col :span="4">
<el-text v-if="status.state === State.SUCCESS || status.state === State.REVOKED">
<el-icon class="success"><SuccessFilled /></el-icon>
{{ stateMap[status.state](status.type) }}
</el-text>
<el-text v-else-if="status.state === State.FAILURE">
<el-icon class="danger"><CircleCloseFilled /></el-icon>
{{ stateMap[status.state](status.type) }}
</el-text>
<el-text v-else-if="status.state === State.STARTED">
<el-icon class="is-loading primary"><Loading /></el-icon>
{{ stateMap[status.state](status.type) }}
</el-text>
<el-text v-else-if="status.state === State.PENDING">
<el-icon class="is-loading primary"><Loading /></el-icon>
{{ stateMap[status.state](status.type) }}
</el-text>
<el-text v-else-if="aggStatus?.value === State.REVOKE">
<el-icon class="is-loading primary"><Loading /></el-icon>
{{ stateMap[aggStatus.value](aggStatus.key) }}
</el-text>
</el-col>
<el-col :span="5">
完成
{{
Object.keys(status.aggs ? status.aggs : {})
.filter((k) => k == State.SUCCESS)
.map((k) => status.aggs[k])
.reduce((x: any, y: any) => x + y, 0)
}}/{{
Object.values(status.aggs ? status.aggs : {}).reduce((x: any, y: any) => x + y, 0)
}}
</el-col>
<el-col :span="9">
{{
status.time
? status.time[
status.state == State.REVOKED ? State.REVOKED : State.PENDING
]?.substring(0, 19)
: undefined
}}
</el-col>
</el-row>
</template>
<template #reference>
<el-text v-if="aggStatus?.value === State.SUCCESS || aggStatus?.value === State.REVOKED">
<el-icon class="success"><SuccessFilled /></el-icon>
{{ stateMap[aggStatus.value](aggStatus.key) }}
</el-text>
<el-text v-else-if="aggStatus?.value === State.FAILURE">
<el-icon class="danger"><CircleCloseFilled /></el-icon>
{{ stateMap[aggStatus.value](aggStatus.key) }}
</el-text>
<el-text v-else-if="aggStatus?.value === State.STARTED">
<el-icon class="is-loading primary"><Loading /></el-icon>
{{ stateMap[aggStatus.value](aggStatus.key) }}
</el-text>
<el-text v-else-if="aggStatus?.value === State.PENDING">
<el-icon class="is-loading primary"><Loading /></el-icon>
{{ stateMap[aggStatus.value](aggStatus.key) }}
</el-text>
<el-text v-else-if="aggStatus?.value === State.REVOKE">
<el-icon class="is-loading primary"><Loading /></el-icon>
{{ stateMap[aggStatus.value](aggStatus.key) }}
</el-text>
</template>
</el-popover>
</template>
<script setup lang="ts">
import { computed } from 'vue'
import { Status, TaskType, State, type TaskTypeInterface } from '@/utils/status'
import { mergeWith } from 'lodash'
const props = defineProps<{ status: string; statusMeta: any }>()
const checkList: Array<string> = [
State.REVOKE,
State.STARTED,
State.PENDING,
State.REVOKED,
State.FAILURE,
State.SUCCESS
]
const aggStatus = computed(() => {
for (const i in checkList) {
const state = checkList[i]
const index = props.status.indexOf(state)
if (index > -1) {
return { key: props.status.length - index, value: state }
}
}
})
const startedMap = {
[TaskType.EMBEDDING]: '索引中',
[TaskType.GENERATE_PROBLEM]: '生成中',
[TaskType.SYNC]: '同步中'
}
const taskTypeMap = {
[TaskType.EMBEDDING]: '向量化',
[TaskType.GENERATE_PROBLEM]: '生成问题',
[TaskType.SYNC]: '同步'
}
const stateMap: any = {
[State.PENDING]: (type: number) => '排队中',
[State.STARTED]: (type: number) => startedMap[type],
[State.REVOKE]: (type: number) => '取消中',
[State.REVOKED]: (type: number) => '成功',
[State.FAILURE]: (type: number) => '失败',
[State.SUCCESS]: (type: number) => '成功'
}
const parseAgg = (agg: { count: number; status: string }) => {
const status = new Status(agg.status)
return Object.keys(TaskType)
.map((key) => {
const value = TaskType[key as keyof TaskTypeInterface]
return { [value]: { [status.task_status[value]]: agg.count } }
})
.reduce((x, y) => ({ ...x, ...y }), {})
}
const customizer: (x: any, y: any) => any = (objValue: any, srcValue: any) => {
if (objValue == undefined && srcValue) {
return srcValue
}
if (srcValue == undefined && objValue) {
return objValue
}
if (typeof objValue === 'object' && typeof srcValue === 'object') {
// nested objects: merge recursively with the same rules
return mergeWith(objValue, srcValue, customizer)
} else {
// leaf counts: add them up
return objValue + srcValue
}
}
const aggs = computed(() => {
return (props.statusMeta.aggs ? props.statusMeta.aggs : [])
.map((agg: any) => {
return parseAgg(agg)
})
.reduce((x: any, y: any) => {
return mergeWith(x, y, customizer)
}, {})
})
const statusTable = computed(() => {
return Object.keys(TaskType)
.map((key) => {
const value = TaskType[key as keyof TaskTypeInterface]
const parseStatus = new Status(props.status)
return {
type: value,
state: parseStatus.task_status[value],
aggs: aggs.value[value],
time: props.statusMeta.state_time[value]
}
})
.filter((item) => item.state !== State.IGNORED)
})
</script>
<style lang="scss" scoped></style>

View File

@ -134,21 +134,7 @@
</div>
</template>
<template #default="{ row }">
<el-text v-if="row.status === '1'">
<el-icon class="success"><SuccessFilled /></el-icon>
</el-text>
<el-text v-else-if="row.status === '2'">
<el-icon class="danger"><CircleCloseFilled /></el-icon>
</el-text>
<el-text v-else-if="row.status === '0'">
<el-icon class="is-loading primary"><Loading /></el-icon>
</el-text>
<el-text v-else-if="row.status === '3'">
<el-icon class="is-loading primary"><Loading /></el-icon>
</el-text>
<el-text v-else-if="row.status === '4'">
<el-icon class="is-loading primary"><Loading /></el-icon>
</el-text>
<StatusValue :status="row.status" :status-meta="row.status_meta"></StatusValue>
</template>
</el-table-column>
<el-table-column width="130">
@ -249,7 +235,7 @@
<template #default="{ row }">
<div v-if="datasetDetail.type === '0'">
<span class="mr-4">
<el-tooltip effect="dark" content="重新向量化" placement="top">
<el-tooltip effect="dark" content="向量化" placement="top">
<el-button type="primary" text @click.stop="refreshDocument(row)">
<AppIcon iconName="app-document-refresh" style="font-size: 16px"></AppIcon>
</el-button>
@ -298,7 +284,22 @@
</el-tooltip>
</span>
<span class="mr-4">
<el-tooltip effect="dark" content="重新向量化" placement="top">
<el-tooltip
effect="dark"
v-if="getTaskState(row.status, TaskType.EMBEDDING) == State.STARTED"
content="取消向量化"
placement="top"
>
<el-button
type="primary"
text
@click.stop="cancelTask(row, TaskType.EMBEDDING)"
>
<AppIcon iconName="app-close" style="font-size: 16px"></AppIcon>
</el-button>
</el-tooltip>
<el-tooltip effect="dark" v-else content="向量化" placement="top">
<el-button type="primary" text @click.stop="refreshDocument(row)">
<AppIcon iconName="app-document-refresh" style="font-size: 16px"></AppIcon>
</el-button>
@ -315,9 +316,18 @@
<el-dropdown-item icon="Setting" @click="settingDoc(row)"
>设置</el-dropdown-item
>
<el-dropdown-item @click="openGenerateDialog(row)">
<el-dropdown-item
v-if="
getTaskState(row.status, TaskType.GENERATE_PROBLEM) == State.STARTED
"
@click="cancelTask(row, TaskType.GENERATE_PROBLEM)"
>
<el-icon><Connection /></el-icon>
生成关联问题
取消生成问题
</el-dropdown-item>
<el-dropdown-item v-else @click="openGenerateDialog(row)">
<el-icon><Connection /></el-icon>
生成问题
</el-dropdown-item>
<el-dropdown-item @click="openDatasetDialog(row)">
<AppIcon iconName="app-migrate"></AppIcon>
@ -360,7 +370,9 @@ import { datetimeFormat } from '@/utils/time'
import { hitHandlingMethod } from '@/enums/document'
import { MsgSuccess, MsgConfirm, MsgError } from '@/utils/message'
import useStore from '@/stores'
import StatusValue from '@/views/document/component/Status.vue'
import GenerateRelatedDialog from '@/views/document/component/GenerateRelatedDialog.vue'
import { TaskType, State } from '@/utils/status'
const router = useRouter()
const route = useRoute()
const {
@ -368,9 +380,11 @@ const {
} = route as any
const { common, dataset, document } = useStore()
const storeKey = 'documents'
const getTaskState = (status: string, taskType: number) => {
const statusList = status.split('').reverse()
return taskType - 1 >= statusList.length ? 'n' : statusList[taskType - 1]
}
onBeforeRouteUpdate(() => {
common.savePage(storeKey, null)
common.saveCondition(storeKey, null)
@ -441,7 +455,11 @@ function beforeCommand(attr: string, val: any) {
command: val
}
}
const cancelTask = (row: any, task_type: number) => {
documentApi.cancelTask(row.dataset_id, row.id, { type: task_type }).then(() => {
MsgSuccess('发送成功')
})
}
function syncDataset() {
SyncWebDialogRef.value.open(id)
}