feat: 批量关联问题 (#1235)

This commit is contained in:
shaohuzhang1 2024-09-20 17:07:52 +08:00 committed by GitHub
parent 808941ba70
commit 1f26744cef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 245 additions and 53 deletions

View File

@ -137,6 +137,11 @@ class ListenerManagement:
QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status})
max_kb.info(f'结束--->向量化段落:{paragraph_id}')
@staticmethod
def embedding_by_data_list(data_list: List, embedding_model: Embeddings):
    """
    Vectorize and store a prepared list of records in one batch.

    :param data_list: records passed unchanged to the vector store's ``batch_save``
    :param embedding_model: embedding model passed to ``batch_save``
    """
    vector_store = VectorStore.get_embedding_vector()
    # The "keep saving?" predicate handed to batch_save always answers True.
    vector_store.batch_save(data_list, embedding_model, lambda: True)
@staticmethod
def embedding_by_document(document_id, embedding_model: Embeddings):
"""

View File

@ -8,6 +8,7 @@
"""
import os
import uuid
from functools import reduce
from typing import Dict, List
from django.db import transaction
@ -21,7 +22,8 @@ from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id
from embedding.task import delete_embedding_by_source_ids, update_problem_embedding
from embedding.models import SourceType
from embedding.task import delete_embedding_by_source_ids, update_problem_embedding, embedding_by_data_list
from smartdoc.conf import PROJECT_DIR
@ -50,6 +52,35 @@ class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
})
class AssociationParagraph(serializers.Serializer):
    """Validates one paragraph entry: a paragraph id and a document id, both required UUIDs."""
    paragraph_id = serializers.UUIDField(
        required=True,
        error_messages=ErrMessage.uuid("段落id"),
    )
    document_id = serializers.UUIDField(
        required=True,
        error_messages=ErrMessage.uuid("文档id"),
    )
class BatchAssociation(serializers.Serializer):
    """Validates a batch-association body: a list of problem UUIDs plus a list of paragraph entries."""
    problem_id_list = serializers.ListField(
        child=serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("问题id")),
        required=True,
        error_messages=ErrMessage.list("问题id列表"),
    )
    # Each entry is validated by AssociationParagraph.
    paragraph_list = AssociationParagraph(many=True)
def _normalize_id(value):
    """Return a canonical string for an id so UUID objects and differently formatted UUID strings compare equal."""
    text = str(value)
    try:
        return str(uuid.UUID(text))
    except ValueError:
        # Not a UUID at all: fall back to plain string comparison.
        return text


def is_exits(exits_problem_paragraph_mapping_list, new_paragraph_mapping):
    """
    Tell whether a mapping equivalent to ``new_paragraph_mapping`` already exists.

    Two mappings are equivalent when their ``paragraph_id``, ``problem_id`` and
    ``dataset_id`` all match. Both sides are normalized before comparing, so it does
    not matter whether an id is held as a ``uuid.UUID`` or as a string, nor which
    textual UUID format (upper-case, without hyphens, ...) the string uses.

    :param exits_problem_paragraph_mapping_list: iterable of already existing mappings
    :param new_paragraph_mapping: candidate mapping
    :return: True if an equivalent mapping is present, otherwise False
    """
    new_key = (_normalize_id(new_paragraph_mapping.paragraph_id),
               _normalize_id(new_paragraph_mapping.problem_id),
               _normalize_id(new_paragraph_mapping.dataset_id))
    return any((_normalize_id(existing.paragraph_id),
                _normalize_id(existing.problem_id),
                _normalize_id(existing.dataset_id)) == new_key
               for existing in exits_problem_paragraph_mapping_list)
def to_problem_paragraph_mapping(problem, document_id: str, paragraph_id: str, dataset_id: str):
    """
    Construct a ProblemParagraphMapping for ``problem`` and return it paired with the problem.

    The mapping is only constructed here; nothing is written to the database.

    :param problem: problem object; its ``id`` is stored as a string in ``problem_id``
    :param document_id: document id stored on the mapping
    :param paragraph_id: paragraph id stored on the mapping
    :param dataset_id: dataset id stored on the mapping
    :return: tuple ``(mapping, problem)``
    """
    mapping = ProblemParagraphMapping(
        id=uuid.uuid1(),
        document_id=document_id,
        paragraph_id=paragraph_id,
        dataset_id=dataset_id,
        problem_id=str(problem.id),
    )
    return mapping, problem
class ProblemSerializers(ApiMixin, serializers.Serializer):
class Create(serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
@ -115,6 +146,47 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
delete_embedding_by_source_ids(source_ids)
return True
def association(self, instance: Dict, with_valid=True):
    """
    Associate every requested problem with every requested paragraph and embed the new links.

    A mapping is built for each (problem, paragraph) combination; combinations for
    which ``is_exits`` finds an existing mapping are skipped. The remaining mappings
    are bulk-created and one embedding record per new mapping is passed to
    ``embedding_by_data_list``.

    :param instance: request body with ``problem_id_list`` and ``paragraph_list``
                     (each paragraph entry carries ``paragraph_id`` and ``document_id``)
    :param with_valid: when True, validate this serializer and ``instance`` first
    :return: True, matching the other methods of this serializer
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        BatchAssociation(data=instance).is_valid(raise_exception=True)
    dataset_id = self.data.get('dataset_id')
    paragraph_list = instance.get('paragraph_list')
    problem_id_list = instance.get('problem_id_list')
    # NOTE(review): problems are looked up by id only, not restricted to dataset_id — confirm intended.
    problem_list = QuerySet(Problem).filter(id__in=problem_id_list)
    exits_problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
        problem_id__in=problem_id_list,
        paragraph_id__in=[paragraph.get('paragraph_id') for paragraph in paragraph_list])
    # Every problem x paragraph combination, problems in the outer loop.
    candidate_list = [to_problem_paragraph_mapping(problem,
                                                   paragraph.get('document_id'),
                                                   paragraph.get('paragraph_id'),
                                                   dataset_id)
                      for problem in problem_list
                      for paragraph in paragraph_list]
    # Keep only the combinations that are not mapped yet.
    problem_paragraph_mapping_list = [
        (problem_paragraph_mapping, problem)
        for problem_paragraph_mapping, problem in candidate_list
        if not is_exits(exits_problem_paragraph_mapping, problem_paragraph_mapping)]
    QuerySet(ProblemParagraphMapping).bulk_create(
        [problem_paragraph_mapping for problem_paragraph_mapping, _ in problem_paragraph_mapping_list])
    # One embedding record per new mapping; the mapping id is the embedding's source id.
    data_list = [{'text': problem.content,
                  'is_active': True,
                  'source_type': SourceType.PROBLEM,
                  'source_id': str(problem_paragraph_mapping.id),
                  'document_id': str(problem_paragraph_mapping.document_id),
                  'paragraph_id': str(problem_paragraph_mapping.paragraph_id),
                  'dataset_id': dataset_id,
                  } for problem_paragraph_mapping, problem in problem_paragraph_mapping_list]
    model_id = get_embedding_model_id_by_dataset_id(dataset_id)
    embedding_by_data_list(data_list, model_id=model_id)
    return True
class Operate(serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))

View File

@ -36,6 +36,36 @@ class ProblemApi(ApiMixin):
}
)
class BatchAssociation(ApiMixin):
    """OpenAPI description of the batch "associate problems with paragraphs" request."""

    @staticmethod
    def get_request_params_api():
        # Same request parameters as the other batch operations.
        return ProblemApi.BatchOperate.get_request_params_api()

    @staticmethod
    def get_request_body_api():
        """Return the request-body schema: a list of problem ids and a list of paragraph entries."""
        paragraph_schema = openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['paragraph_id', 'document_id'],
            properties={
                'paragraph_id': openapi.Schema(type=openapi.TYPE_STRING, title="段落id"),
                'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id")
            })
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            # Both lists are mandatory: the request serializer declares paragraph_list
            # without required=False, so it is required as well.
            required=['problem_id_list', 'paragraph_list'],
            properties={
                'problem_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="问题id列表",
                                                  description="问题id列表",
                                                  items=openapi.Schema(type=openapi.TYPE_STRING)),
                'paragraph_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="关联段落信息列表",
                                                 description="关联段落信息列表",
                                                 items=paragraph_schema)
            }
        )
class BatchOperate(ApiMixin):
@staticmethod
def get_request_params_api():

View File

@ -88,6 +88,20 @@ class Problem(APIView):
return result.success(
ProblemSerializers.BatchOperate(data={'dataset_id': dataset_id}).delete(request.data))
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="批量关联段落",
                     operation_id="批量关联段落",
                     request_body=ProblemApi.BatchAssociation.get_request_body_api(),
                     manual_parameters=ProblemApi.BatchOperate.get_request_params_api(),
                     responses=result.get_default_response(),
                     tags=["知识库/文档/段落/问题"])
@has_permissions(
    lambda req, kwargs: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                   dynamic_tag=kwargs.get('dataset_id')))
def post(self, request: Request, dataset_id: str):
    """
    Batch-associate problems with paragraphs for the given dataset.

    The request body is passed as-is to ``BatchOperate.association``; its return
    value is wrapped in a success response.
    """
    serializer = ProblemSerializers.BatchOperate(data={'dataset_id': dataset_id})
    return result.success(serializer.association(request.data))
class Operate(APIView):
authentication_classes = [TokenAuth]

View File

@ -111,6 +111,11 @@ def embedding_by_problem(args, model_id):
ListenerManagement.embedding_by_problem(args, embedding_model)
def embedding_by_data_list(args: List, model_id):
    """
    Vectorize a prepared list of records with the embedding model identified by ``model_id``.

    :param args: records passed to ``ListenerManagement.embedding_by_data_list``
    :param model_id: id passed to ``get_embedding_model`` to obtain the model
    """
    ListenerManagement.embedding_by_data_list(args, get_embedding_model(model_id))
def delete_embedding_by_document(document_id):
"""
删除指定文档id的向量

View File

@ -87,27 +87,21 @@ class BaseVectorStore(ABC):
self._batch_save(child_array, embedding, lambda: True)
def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_save_function):
    """
    Insert records in batches.

    :param data_list: records to store
    :param embedding: embedding model passed to ``_batch_save``
    :param is_save_function: zero-argument callable checked before each batch;
                             as soon as it returns a falsy value the remaining
                             batches are skipped
    :return: bool, always True (also when saving stopped early)
    """
    self.save_pre_handler()
    chunk_list = chunk_data_list(data_list)
    result = sub_array(chunk_list)
    for child_array in result:
        if not is_save_function():
            break
        self._batch_save(child_array, embedding, is_save_function)
    # The docstring has always promised a bool; keep returning True for callers that read it.
    return True
@abstractmethod
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,

View File

@ -97,11 +97,28 @@ const getDetailProblems: (
return get(`${prefix}/${dataset_id}/problem/${problem_id}/paragraph`, undefined, loading)
}
/**
 * Batch-associate problems with paragraphs.
 * @param dataset_id dataset id
 * @param data request body:
 * {
 *   "problem_id_list": "Array",
 *   "paragraph_list": "Array"
 * }
 * @param loading optional loading flag passed to the request helper
 */
const postMulAssociationProblem: (
  dataset_id: string,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
  const url = `${prefix}/${dataset_id}/problem/_batch`
  return post(url, data, undefined, loading)
}
export default {
getProblems,
postProblems,
delProblems,
putProblems,
getDetailProblems,
delMulProblem
delMulProblem,
postMulAssociationProblem
}

View File

@ -91,6 +91,12 @@
</el-scrollbar>
</el-col>
</el-row>
<template #footer v-if="isMul">
<div class="dialog-footer">
<el-button @click="dialogVisible = false"> 取消</el-button>
<el-button type="primary" @click="mulAssociation"> 确认 </el-button>
</div>
</template>
</el-dialog>
</template>
<script setup lang="ts">
@ -99,6 +105,7 @@ import { useRoute } from 'vue-router'
import problemApi from '@/api/problem'
import paragraphApi from '@/api/paragraph'
import useStore from '@/stores'
import { MsgSuccess } from '@/utils/message'
const { problem, document } = useStore()
@ -116,6 +123,7 @@ const documentList = ref<any[]>([])
const cloneDocumentList = ref<any[]>([])
const paragraphList = ref<any[]>([])
const currentProblemId = ref<String>('')
const currentMulProblemId = ref<string[]>([])
//
const associationParagraph = ref<any[]>([])
@ -124,6 +132,8 @@ const currentDocument = ref<String>('')
const search = ref('')
const searchType = ref('title')
const filterDoc = ref('')
//
const isMul = ref(false)
const paginationConfig = reactive({
current_page: 1,
@ -131,31 +141,53 @@ const paginationConfig = reactive({
total: 0
})
/**
 * Submit the batch association: all selected problems x all paragraphs picked in
 * the dialog. On success a message is shown and the dialog is closed.
 */
function mulAssociation() {
  const problem_id_list = currentMulProblemId.value
  const paragraph_list = associationParagraph.value.map((paragraph) => {
    return { paragraph_id: paragraph.id, document_id: paragraph.document_id }
  })
  const payload = { problem_id_list, paragraph_list }
  problemApi.postMulAssociationProblem(id, payload, loading).then(() => {
    MsgSuccess('批量关联分段成功')
    dialogVisible.value = false
  })
}
/**
 * Toggle the association state of a paragraph.
 *
 * Batch mode (isMul): only the local selection is changed; nothing is sent until
 * the dialog is confirmed.
 * Single mode: the association is created or removed on the server right away and
 * the record of the current problem is reloaded afterwards.
 */
function associationClick(item: any) {
  if (isMul.value) {
    if (isAssociation(item.id)) {
      // associationParagraph holds paragraph objects, so look the entry up by its id.
      // indexOf(item.id) would return -1 and splice(-1, 1) would drop the LAST entry.
      const index = associationParagraph.value.findIndex(
        (paragraph) => paragraph.id === item.id
      )
      if (index !== -1) {
        associationParagraph.value.splice(index, 1)
      }
    } else {
      associationParagraph.value.push(item)
    }
    return
  }

  const problemId = currentProblemId.value as string
  const request = isAssociation(item.id)
    ? problem.asyncDisassociationProblem(id, item.document_id, item.id, problemId, loading)
    : problem.asyncAssociationProblem(id, item.document_id, item.id, problemId, loading)
  request.then(() => {
    getRecord(currentProblemId.value)
  })
}
@ -216,6 +248,7 @@ watch(dialogVisible, (bool) => {
cloneDocumentList.value = []
paragraphList.value = []
associationParagraph.value = []
isMul.value = false
currentDocument.value = ''
search.value = ''
@ -232,10 +265,15 @@ watch(filterDoc, (val) => {
currentDocument.value = documentList.value?.length > 0 ? documentList.value[0].id : ''
})
/**
 * Open the dialog for one problem or for several problems.
 *
 * @param problemId a single problem id or an array of problem ids.
 *   One id      -> single mode: the id becomes the current problem and its record is loaded.
 *   Several ids -> batch mode: the ids are stored and isMul is switched on.
 */
const open = (problemId: any) => {
  // Accept both the array form and a plain id.
  const idList: string[] = Array.isArray(problemId) ? problemId : [problemId]
  getDocument()
  if (idList.length === 1) {
    currentProblemId.value = idList[0]
    // Pass the id itself, not the wrapping array.
    getRecord(idList[0])
  } else if (idList.length > 1) {
    currentMulProblemId.value = idList
    isMul.value = true
  }
  dialogVisible.value = true
}

View File

@ -5,6 +5,9 @@
<div class="flex-between">
<div>
<el-button type="primary" @click="createProblem">创建问题</el-button>
<el-button @click="relateProblem()" :disabled="multipleSelection.length === 0"
>关联分段</el-button
>
<el-button @click="deleteMulDocument" :disabled="multipleSelection.length === 0"
>批量删除</el-button
>
@ -104,7 +107,7 @@
:next_disable="next_disable"
@refresh="refresh"
/>
<RelateProblemDialog ref="RelateProblemDialogRef" @refresh="refresh" />
<RelateProblemDialog ref="RelateProblemDialogRef" @refresh="refreshRelate" />
</LayoutContainer>
</template>
<script setup lang="ts">
@ -157,8 +160,19 @@ const problemIndexMap = computed<Dict<number>>(() => {
const multipleTableRef = ref<InstanceType<typeof ElTable>>()
const multipleSelection = ref<any[]>([])
/**
 * Open the relate dialog.
 *
 * @param row when given, only this row's problem is related; otherwise every
 *   currently selected row is used.
 */
function relateProblem(row?: any) {
  // Build the id list directly instead of pushing from inside a map() callback.
  const idList: string[] = row
    ? [row.id]
    : multipleSelection.value.filter((selected) => selected).map((selected) => selected.id)
  RelateProblemDialogRef.value.open(idList)
}
function createProblem() {
@ -330,7 +344,10 @@ function getList() {
paginationConfig.total = res.data.total
})
}
/**
 * Refresh handler for the relate dialog: call getList() as-is (the page number is
 * not touched here) and clear the table selection if the table is mounted.
 */
function refreshRelate() {
  getList()
  const table = multipleTableRef.value
  if (table != null) {
    table.clearSelection()
  }
}
function refresh() {
paginationConfig.current_page = 1
getList()