feat: Document vectorization supports processing based on status (#1984)

This commit is contained in:
shaohuzhang1 2025-01-07 11:15:10 +08:00 committed by GitHub
parent 9a310bfb98
commit 54381ffaf3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 140 additions and 45 deletions

View File

@ -6,26 +6,22 @@
@date2023/10/20 14:01
@desc:
"""
import datetime
import logging
import os
import threading
import time
import traceback
from typing import List
import django.db.models
from django.db import models, transaction
from django.db.models import QuerySet
from django.db.models.functions import Substr, Reverse
from langchain_core.embeddings import Embeddings
from common.config.embedding_config import VectorStore
from common.db.search import native_search, get_dynamics_model, native_update
from common.db.sql_execute import sql_execute, update_execute
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from common.util.page_utils import page
from common.util.page_utils import page_desc
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State
from embedding.models import SourceType, SearchMode
from smartdoc.conf import PROJECT_DIR
@ -162,7 +158,7 @@ class ListenerManagement:
if is_the_task_interrupted():
break
ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model)
post_apply()
post_apply()
return embedding_paragraph_apply
@ -241,13 +237,16 @@ class ListenerManagement:
lock.release()
@staticmethod
def embedding_by_document(document_id, embedding_model: Embeddings):
def embedding_by_document(document_id, embedding_model: Embeddings, state_list=None):
"""
向量化文档
@param state_list:
@param document_id: 文档id
@param embedding_model 向量模型
:return: None
"""
if state_list is None:
state_list = [State.PENDING, State.SUCCESS, State.FAILURE, State.REVOKE, State.REVOKED]
if not try_lock('embedding' + str(document_id)):
return
try:
@ -268,11 +267,17 @@ class ListenerManagement:
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
# 根据段落进行向量化处理
page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 5,
ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
ListenerManagement.get_aggregation_document_status(
document_id)),
is_the_task_interrupted)
page_desc(QuerySet(Paragraph)
.annotate(
reversed_status=Reverse('status'),
task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value,
1),
).filter(task_type_status__in=state_list, document_id=document_id)
.values('id'), 5,
ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
ListenerManagement.get_aggregation_document_status(
document_id)),
is_the_task_interrupted)
except Exception as e:
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
finally:

View File

@ -26,3 +26,22 @@ def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
offset = i * page_size
paragraph_list = query.all()[offset: offset + page_size]
handler(paragraph_list)
def page_desc(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
"""
@param query_set: 查询query_set
@param page_size: 每次查询大小
@param handler: 数据处理器
@param is_the_task_interrupted: 任务是否被中断
@return:
"""
query = query_set.order_by("id")
count = query_set.count()
for i in sorted(range(0, ceil(count / page_size)), reverse=True):
if is_the_task_interrupted():
return
offset = i * page_size
paragraph_list = query.all()[offset: offset + page_size]
handler(paragraph_list)

View File

@ -700,20 +700,24 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
_document.save()
return self.one()
@transaction.atomic
def refresh(self, with_valid=True):
def refresh(self, state_list, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
document_id = self.data.get("document_id")
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
State.PENDING)
ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id),
ListenerManagement.update_status(QuerySet(Paragraph).annotate(
reversed_status=Reverse('status'),
task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value,
1),
).filter(task_type_status__in=state_list, document_id=document_id)
.values('id'),
TaskType.EMBEDDING,
State.PENDING)
ListenerManagement.get_aggregation_document_status(document_id)()
embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))
try:
embedding_by_document.delay(document_id, embedding_model_id)
embedding_by_document.delay(document_id, embedding_model_id, state_list)
except AlreadyQueued as e:
raise AppApiException(500, "任务正在执行中,请勿重复下发")
@ -1122,14 +1126,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
if with_valid:
self.is_valid(raise_exception=True)
document_id_list = instance.get("id_list")
with transaction.atomic():
dataset_id = self.data.get('dataset_id')
for document_id in document_id_list:
try:
DocumentSerializers.Operate(
data={'dataset_id': dataset_id, 'document_id': document_id}).refresh()
except AlreadyQueued as e:
pass
state_list = instance.get("state_list")
dataset_id = self.data.get('dataset_id')
for document_id in document_id_list:
try:
DocumentSerializers.Operate(
data={'dataset_id': dataset_id, 'document_id': document_id}).refresh(state_list)
except AlreadyQueued as e:
pass
class GenerateRelated(ApiMixin, serializers.Serializer):
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))

View File

@ -51,3 +51,16 @@ class DocumentApi(ApiMixin):
description="1|2|3 1:向量化|2:生成问题|3:同步文档", default=1)
}
)
class EmbeddingState(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
properties={
'state_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title="状态列表",
description="状态列表")
}
)

View File

@ -262,6 +262,7 @@ class Document(APIView):
@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="刷新文档向量库",
operation_id="刷新文档向量库",
request_body=DocumentApi.EmbeddingState.get_request_body_api(),
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
responses=result.get_default_response(),
tags=["知识库/文档"]
@ -272,6 +273,7 @@ class Document(APIView):
def put(self, request: Request, dataset_id: str, document_id: str):
return result.success(
DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).refresh(
request.data.get('state_list')
))
class BatchRefresh(APIView):

View File

@ -56,14 +56,20 @@ def embedding_by_paragraph_list(paragraph_id_list, model_id):
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:embedding_by_document')
def embedding_by_document(document_id, model_id):
def embedding_by_document(document_id, model_id, state_list=None):
"""
向量化文档
@param state_list:
@param document_id: 文档id
@param model_id 向量模型
:return: None
"""
if state_list is None:
state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
State.REVOKE.value,
State.REVOKED.value, State.IGNORED.value]
def exception_handler(e):
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
State.FAILURE)
@ -71,7 +77,7 @@ def embedding_by_document(document_id, model_id):
f'获取向量模型失败:{str(e)}{traceback.format_exc()}')
embedding_model = get_embedding_model(model_id, exception_handler)
ListenerManagement.embedding_by_document(document_id, embedding_model)
ListenerManagement.embedding_by_document(document_id, embedding_model, state_list)
@celery_app.task(name='celery:embedding_by_document_list')

View File

@ -129,11 +129,12 @@ const delMulDocument: (
const batchRefresh: (
dataset_id: string,
data: any,
stateList: Array<string>,
loading?: Ref<boolean>
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
) => Promise<Result<boolean>> = (dataset_id, data, stateList, loading) => {
return put(
`${prefix}/${dataset_id}/document/batch_refresh`,
{ id_list: data },
{ id_list: data, state_list: stateList },
undefined,
loading
)
@ -157,11 +158,12 @@ const getDocumentDetail: (dataset_id: string, document_id: string) => Promise<Re
const putDocumentRefresh: (
dataset_id: string,
document_id: string,
state_list: Array<string>,
loading?: Ref<boolean>
) => Promise<Result<any>> = (dataset_id, document_id, loading) => {
) => Promise<Result<any>> = (dataset_id, document_id, state_list, loading) => {
return put(
`${prefix}/${dataset_id}/document/${document_id}/refresh`,
undefined,
{ state_list },
undefined,
loading
)

View File

@ -0,0 +1,41 @@
<template>
<el-dialog v-model="dialogVisible" title="选择向量化内容" width="500" :before-close="close">
<el-radio-group v-model="state">
<el-radio value="error" size="large">向量化未成功的分段</el-radio>
<el-radio value="all" size="large">全部分段</el-radio>
</el-radio-group>
<template #footer>
<div class="dialog-footer">
<el-button @click="close">取消</el-button>
<el-button type="primary" @click="submit"> 提交 </el-button>
</div>
</template>
</el-dialog>
</template>
<script setup lang="ts">
import { ref } from 'vue'
const dialogVisible = ref<boolean>(false)
const state = ref<'all' | 'error'>('error')
const stateMap = {
all: ['0', '1', '2', '3', '4', '5', 'n'],
error: ['0', '1', '3', '4', '5', 'n']
}
const submit_handle = ref<(stateList: Array<string>) => void>()
const submit = () => {
if (submit_handle.value) {
submit_handle.value(stateMap[state.value])
}
close()
}
const open = (handle: (stateList: Array<string>) => void) => {
submit_handle.value = handle
dialogVisible.value = true
}
const close = () => {
submit_handle.value = undefined
dialogVisible.value = false
}
defineExpose({ open, close })
</script>
<style lang="scss" scoped></style>

View File

@ -422,6 +422,7 @@
</el-text>
<el-button class="ml-16" type="primary" link @click="clearSelection"> 清空 </el-button>
</div>
<EmbeddingContentDialog ref="embeddingContentDialogRef"></EmbeddingContentDialog>
</LayoutContainer>
</template>
<script setup lang="ts">
@ -439,6 +440,7 @@ import { MsgSuccess, MsgConfirm, MsgError } from '@/utils/message'
import useStore from '@/stores'
import StatusVlue from '@/views/document/component/Status.vue'
import GenerateRelatedDialog from '@/components/generate-related-dialog/index.vue'
import EmbeddingContentDialog from '@/views/document/component/EmbeddingContentDialog.vue'
import { TaskType, State } from '@/utils/status'
const router = useRouter()
const route = useRoute()
@ -469,7 +471,7 @@ onBeforeRouteLeave((to: any) => {
})
const beforePagination = computed(() => common.paginationConfig[storeKey])
const beforeSearch = computed(() => common.search[storeKey])
const embeddingContentDialogRef = ref<InstanceType<typeof EmbeddingContentDialog>>()
const SyncWebDialogRef = ref()
const loading = ref(false)
let interval: any
@ -621,10 +623,14 @@ function syncDocument(row: any) {
.catch(() => {})
}
}
function refreshDocument(row: any) {
documentApi.putDocumentRefresh(row.dataset_id, row.id).then(() => {
getList()
})
const embeddingDocument = (stateList: Array<string>) => {
return documentApi.putDocumentRefresh(row.dataset_id, row.id, stateList).then(() => {
getList()
})
}
embeddingContentDialogRef.value?.open(embeddingDocument)
}
function rowClickHandle(row: any, column: any) {
@ -691,19 +697,16 @@ function deleteMulDocument() {
}
function batchRefresh() {
const arr: string[] = []
multipleSelection.value.map((v) => {
if (v) {
arr.push(v.id)
}
})
documentApi.batchRefresh(id, arr, loading).then(() => {
MsgSuccess('批量向量化成功')
multipleTableRef.value?.clearSelection()
})
const arr: string[] = multipleSelection.value.map((v) => v.id)
const embeddingBatchDocument = (stateList: Array<string>) => {
documentApi.batchRefresh(id, arr, stateList, loading).then(() => {
MsgSuccess('批量向量化成功')
multipleTableRef.value?.clearSelection()
})
}
embeddingContentDialogRef.value?.open(embeddingBatchDocument)
}
function deleteDocument(row: any) {
MsgConfirm(
`是否删除文档:${row.name} ?`,