mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: Document vectorization supports processing based on status (#1984)
(cherry picked from commit 54381ffaf3)
This commit is contained in:
parent
ae56590f3f
commit
1bfaf2f024
|
|
@ -6,26 +6,22 @@
|
|||
@date:2023/10/20 14:01
|
||||
@desc:
|
||||
"""
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
import django.db.models
|
||||
from django.db import models, transaction
|
||||
from django.db.models import QuerySet
|
||||
from django.db.models.functions import Substr, Reverse
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from common.config.embedding_config import VectorStore
|
||||
from common.db.search import native_search, get_dynamics_model, native_update
|
||||
from common.db.sql_execute import sql_execute, update_execute
|
||||
from common.util.file_util import get_file_content
|
||||
from common.util.lock import try_lock, un_lock
|
||||
from common.util.page_utils import page
|
||||
from common.util.page_utils import page_desc
|
||||
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State
|
||||
from embedding.models import SourceType, SearchMode
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
|
@ -162,7 +158,7 @@ class ListenerManagement:
|
|||
if is_the_task_interrupted():
|
||||
break
|
||||
ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model)
|
||||
post_apply()
|
||||
post_apply()
|
||||
|
||||
return embedding_paragraph_apply
|
||||
|
||||
|
|
@ -241,13 +237,16 @@ class ListenerManagement:
|
|||
lock.release()
|
||||
|
||||
@staticmethod
|
||||
def embedding_by_document(document_id, embedding_model: Embeddings):
|
||||
def embedding_by_document(document_id, embedding_model: Embeddings, state_list=None):
|
||||
"""
|
||||
向量化文档
|
||||
@param state_list:
|
||||
@param document_id: 文档id
|
||||
@param embedding_model 向量模型
|
||||
:return: None
|
||||
"""
|
||||
if state_list is None:
|
||||
state_list = [State.PENDING, State.SUCCESS, State.FAILURE, State.REVOKE, State.REVOKED]
|
||||
if not try_lock('embedding' + str(document_id)):
|
||||
return
|
||||
try:
|
||||
|
|
@ -268,11 +267,17 @@ class ListenerManagement:
|
|||
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
|
||||
|
||||
# 根据段落进行向量化处理
|
||||
page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 5,
|
||||
ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
|
||||
ListenerManagement.get_aggregation_document_status(
|
||||
document_id)),
|
||||
is_the_task_interrupted)
|
||||
page_desc(QuerySet(Paragraph)
|
||||
.annotate(
|
||||
reversed_status=Reverse('status'),
|
||||
task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value,
|
||||
1),
|
||||
).filter(task_type_status__in=state_list, document_id=document_id)
|
||||
.values('id'), 5,
|
||||
ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
|
||||
ListenerManagement.get_aggregation_document_status(
|
||||
document_id)),
|
||||
is_the_task_interrupted)
|
||||
except Exception as e:
|
||||
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -26,3 +26,22 @@ def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
|
|||
offset = i * page_size
|
||||
paragraph_list = query.all()[offset: offset + page_size]
|
||||
handler(paragraph_list)
|
||||
|
||||
|
||||
def page_desc(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
|
||||
"""
|
||||
|
||||
@param query_set: 查询query_set
|
||||
@param page_size: 每次查询大小
|
||||
@param handler: 数据处理器
|
||||
@param is_the_task_interrupted: 任务是否被中断
|
||||
@return:
|
||||
"""
|
||||
query = query_set.order_by("id")
|
||||
count = query_set.count()
|
||||
for i in sorted(range(0, ceil(count / page_size)), reverse=True):
|
||||
if is_the_task_interrupted():
|
||||
return
|
||||
offset = i * page_size
|
||||
paragraph_list = query.all()[offset: offset + page_size]
|
||||
handler(paragraph_list)
|
||||
|
|
|
|||
|
|
@ -700,20 +700,24 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
_document.save()
|
||||
return self.one()
|
||||
|
||||
@transaction.atomic
|
||||
def refresh(self, with_valid=True):
|
||||
def refresh(self, state_list, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
document_id = self.data.get("document_id")
|
||||
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
|
||||
State.PENDING)
|
||||
ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id=document_id),
|
||||
ListenerManagement.update_status(QuerySet(Paragraph).annotate(
|
||||
reversed_status=Reverse('status'),
|
||||
task_type_status=Substr('reversed_status', TaskType.EMBEDDING.value,
|
||||
1),
|
||||
).filter(task_type_status__in=state_list, document_id=document_id)
|
||||
.values('id'),
|
||||
TaskType.EMBEDDING,
|
||||
State.PENDING)
|
||||
ListenerManagement.get_aggregation_document_status(document_id)()
|
||||
embedding_model_id = get_embedding_model_id_by_dataset_id(dataset_id=self.data.get('dataset_id'))
|
||||
try:
|
||||
embedding_by_document.delay(document_id, embedding_model_id)
|
||||
embedding_by_document.delay(document_id, embedding_model_id, state_list)
|
||||
except AlreadyQueued as e:
|
||||
raise AppApiException(500, "任务正在执行中,请勿重复下发")
|
||||
|
||||
|
|
@ -1122,14 +1126,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
document_id_list = instance.get("id_list")
|
||||
with transaction.atomic():
|
||||
dataset_id = self.data.get('dataset_id')
|
||||
for document_id in document_id_list:
|
||||
try:
|
||||
DocumentSerializers.Operate(
|
||||
data={'dataset_id': dataset_id, 'document_id': document_id}).refresh()
|
||||
except AlreadyQueued as e:
|
||||
pass
|
||||
state_list = instance.get("state_list")
|
||||
dataset_id = self.data.get('dataset_id')
|
||||
for document_id in document_id_list:
|
||||
try:
|
||||
DocumentSerializers.Operate(
|
||||
data={'dataset_id': dataset_id, 'document_id': document_id}).refresh(state_list)
|
||||
except AlreadyQueued as e:
|
||||
pass
|
||||
|
||||
class GenerateRelated(ApiMixin, serializers.Serializer):
|
||||
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
|
||||
|
|
|
|||
|
|
@ -51,3 +51,16 @@ class DocumentApi(ApiMixin):
|
|||
description="1|2|3 1:向量化|2:生成问题|3:同步文档", default=1)
|
||||
}
|
||||
)
|
||||
|
||||
class EmbeddingState(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
properties={
|
||||
'state_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||
items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||
title="状态列表",
|
||||
description="状态列表")
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -262,6 +262,7 @@ class Document(APIView):
|
|||
@action(methods=['PUT'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="刷新文档向量库",
|
||||
operation_id="刷新文档向量库",
|
||||
request_body=DocumentApi.EmbeddingState.get_request_body_api(),
|
||||
manual_parameters=DocumentSerializers.Operate.get_request_params_api(),
|
||||
responses=result.get_default_response(),
|
||||
tags=["知识库/文档"]
|
||||
|
|
@ -272,6 +273,7 @@ class Document(APIView):
|
|||
def put(self, request: Request, dataset_id: str, document_id: str):
|
||||
return result.success(
|
||||
DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).refresh(
|
||||
request.data.get('state_list')
|
||||
))
|
||||
|
||||
class BatchRefresh(APIView):
|
||||
|
|
|
|||
|
|
@ -56,14 +56,20 @@ def embedding_by_paragraph_list(paragraph_id_list, model_id):
|
|||
|
||||
|
||||
@celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:embedding_by_document')
|
||||
def embedding_by_document(document_id, model_id):
|
||||
def embedding_by_document(document_id, model_id, state_list=None):
|
||||
"""
|
||||
向量化文档
|
||||
@param state_list:
|
||||
@param document_id: 文档id
|
||||
@param model_id 向量模型
|
||||
:return: None
|
||||
"""
|
||||
|
||||
if state_list is None:
|
||||
state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
|
||||
State.REVOKE.value,
|
||||
State.REVOKED.value, State.IGNORED.value]
|
||||
|
||||
def exception_handler(e):
|
||||
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING,
|
||||
State.FAILURE)
|
||||
|
|
@ -71,7 +77,7 @@ def embedding_by_document(document_id, model_id):
|
|||
f'获取向量模型失败:{str(e)}{traceback.format_exc()}')
|
||||
|
||||
embedding_model = get_embedding_model(model_id, exception_handler)
|
||||
ListenerManagement.embedding_by_document(document_id, embedding_model)
|
||||
ListenerManagement.embedding_by_document(document_id, embedding_model, state_list)
|
||||
|
||||
|
||||
@celery_app.task(name='celery:embedding_by_document_list')
|
||||
|
|
|
|||
|
|
@ -129,11 +129,12 @@ const delMulDocument: (
|
|||
const batchRefresh: (
|
||||
dataset_id: string,
|
||||
data: any,
|
||||
stateList: Array<string>,
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
|
||||
) => Promise<Result<boolean>> = (dataset_id, data, stateList, loading) => {
|
||||
return put(
|
||||
`${prefix}/${dataset_id}/document/batch_refresh`,
|
||||
{ id_list: data },
|
||||
{ id_list: data, state_list: stateList },
|
||||
undefined,
|
||||
loading
|
||||
)
|
||||
|
|
@ -157,11 +158,12 @@ const getDocumentDetail: (dataset_id: string, document_id: string) => Promise<Re
|
|||
const putDocumentRefresh: (
|
||||
dataset_id: string,
|
||||
document_id: string,
|
||||
state_list: Array<string>,
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<any>> = (dataset_id, document_id, loading) => {
|
||||
) => Promise<Result<any>> = (dataset_id, document_id, state_list, loading) => {
|
||||
return put(
|
||||
`${prefix}/${dataset_id}/document/${document_id}/refresh`,
|
||||
undefined,
|
||||
{ state_list },
|
||||
undefined,
|
||||
loading
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,41 @@
|
|||
<template>
|
||||
<el-dialog v-model="dialogVisible" title="选择向量化内容" width="500" :before-close="close">
|
||||
<el-radio-group v-model="state">
|
||||
<el-radio value="error" size="large">向量化未成功的分段</el-radio>
|
||||
<el-radio value="all" size="large">全部分段</el-radio>
|
||||
</el-radio-group>
|
||||
<template #footer>
|
||||
<div class="dialog-footer">
|
||||
<el-button @click="close">取消</el-button>
|
||||
<el-button type="primary" @click="submit"> 提交 </el-button>
|
||||
</div>
|
||||
</template>
|
||||
</el-dialog>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
import { ref } from 'vue'
|
||||
const dialogVisible = ref<boolean>(false)
|
||||
const state = ref<'all' | 'error'>('error')
|
||||
const stateMap = {
|
||||
all: ['0', '1', '2', '3', '4', '5', 'n'],
|
||||
error: ['0', '1', '3', '4', '5', 'n']
|
||||
}
|
||||
const submit_handle = ref<(stateList: Array<string>) => void>()
|
||||
const submit = () => {
|
||||
if (submit_handle.value) {
|
||||
submit_handle.value(stateMap[state.value])
|
||||
}
|
||||
close()
|
||||
}
|
||||
|
||||
const open = (handle: (stateList: Array<string>) => void) => {
|
||||
submit_handle.value = handle
|
||||
dialogVisible.value = true
|
||||
}
|
||||
const close = () => {
|
||||
submit_handle.value = undefined
|
||||
dialogVisible.value = false
|
||||
}
|
||||
defineExpose({ open, close })
|
||||
</script>
|
||||
<style lang="scss" scoped></style>
|
||||
|
|
@ -422,6 +422,7 @@
|
|||
</el-text>
|
||||
<el-button class="ml-16" type="primary" link @click="clearSelection"> 清空 </el-button>
|
||||
</div>
|
||||
<EmbeddingContentDialog ref="embeddingContentDialogRef"></EmbeddingContentDialog>
|
||||
</LayoutContainer>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
|
|
@ -439,6 +440,7 @@ import { MsgSuccess, MsgConfirm, MsgError } from '@/utils/message'
|
|||
import useStore from '@/stores'
|
||||
import StatusVlue from '@/views/document/component/Status.vue'
|
||||
import GenerateRelatedDialog from '@/components/generate-related-dialog/index.vue'
|
||||
import EmbeddingContentDialog from '@/views/document/component/EmbeddingContentDialog.vue'
|
||||
import { TaskType, State } from '@/utils/status'
|
||||
const router = useRouter()
|
||||
const route = useRoute()
|
||||
|
|
@ -469,7 +471,7 @@ onBeforeRouteLeave((to: any) => {
|
|||
})
|
||||
const beforePagination = computed(() => common.paginationConfig[storeKey])
|
||||
const beforeSearch = computed(() => common.search[storeKey])
|
||||
|
||||
const embeddingContentDialogRef = ref<InstanceType<typeof EmbeddingContentDialog>>()
|
||||
const SyncWebDialogRef = ref()
|
||||
const loading = ref(false)
|
||||
let interval: any
|
||||
|
|
@ -621,10 +623,14 @@ function syncDocument(row: any) {
|
|||
.catch(() => {})
|
||||
}
|
||||
}
|
||||
|
||||
function refreshDocument(row: any) {
|
||||
documentApi.putDocumentRefresh(row.dataset_id, row.id).then(() => {
|
||||
getList()
|
||||
})
|
||||
const embeddingDocument = (stateList: Array<string>) => {
|
||||
return documentApi.putDocumentRefresh(row.dataset_id, row.id, stateList).then(() => {
|
||||
getList()
|
||||
})
|
||||
}
|
||||
embeddingContentDialogRef.value?.open(embeddingDocument)
|
||||
}
|
||||
|
||||
function rowClickHandle(row: any, column: any) {
|
||||
|
|
@ -691,19 +697,16 @@ function deleteMulDocument() {
|
|||
}
|
||||
|
||||
function batchRefresh() {
|
||||
const arr: string[] = []
|
||||
multipleSelection.value.map((v) => {
|
||||
if (v) {
|
||||
arr.push(v.id)
|
||||
}
|
||||
})
|
||||
documentApi.batchRefresh(id, arr, loading).then(() => {
|
||||
MsgSuccess('批量向量化成功')
|
||||
multipleTableRef.value?.clearSelection()
|
||||
})
|
||||
const arr: string[] = multipleSelection.value.map((v) => v.id)
|
||||
const embeddingBatchDocument = (stateList: Array<string>) => {
|
||||
documentApi.batchRefresh(id, arr, stateList, loading).then(() => {
|
||||
MsgSuccess('批量向量化成功')
|
||||
multipleTableRef.value?.clearSelection()
|
||||
})
|
||||
}
|
||||
embeddingContentDialogRef.value?.open(embeddingBatchDocument)
|
||||
}
|
||||
|
||||
|
||||
function deleteDocument(row: any) {
|
||||
MsgConfirm(
|
||||
`是否删除文档:${row.name} ?`,
|
||||
|
|
|
|||
Loading…
Reference in New Issue