mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 批量关联问题 (#1235)
This commit is contained in:
parent
808941ba70
commit
1f26744cef
|
|
@ -137,6 +137,11 @@ class ListenerManagement:
|
|||
QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status})
|
||||
max_kb.info(f'结束--->向量化段落:{paragraph_id}')
|
||||
|
||||
@staticmethod
def embedding_by_data_list(data_list: List, embedding_model: Embeddings):
    """
    Vectorize a batch of records and store them through the vector store.

    :param data_list: records to embed; handed to ``batch_save`` unchanged
    :param embedding_model: embedding model used to vectorize the records
    """

    def keep_saving():
        # This entry point has no cancellation hook: every batch is saved.
        return True

    vector_store = VectorStore.get_embedding_vector()
    vector_store.batch_save(data_list, embedding_model, keep_saving)
|
||||
|
||||
@staticmethod
|
||||
def embedding_by_document(document_id, embedding_model: Embeddings):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
"""
|
||||
import os
|
||||
import uuid
|
||||
from functools import reduce
|
||||
from typing import Dict, List
|
||||
|
||||
from django.db import transaction
|
||||
|
|
@ -21,7 +22,8 @@ from common.util.field_message import ErrMessage
|
|||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
|
||||
from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id
|
||||
from embedding.task import delete_embedding_by_source_ids, update_problem_embedding
|
||||
from embedding.models import SourceType
|
||||
from embedding.task import delete_embedding_by_source_ids, update_problem_embedding, embedding_by_data_list
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
|
|
@ -50,6 +52,35 @@ class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
|
|||
})
|
||||
|
||||
|
||||
class AssociationParagraph(serializers.Serializer):
    # One paragraph entry of a batch-association request; both ids must be valid UUIDs.
    # Id of the paragraph the problems will be linked to.
    paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))
    # Id of the document sent alongside the paragraph (presumably its owning document — confirm).
    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
|
||||
|
||||
|
||||
class BatchAssociation(serializers.Serializer):
    # Request body of the batch association: a list of problem ids plus the paragraphs to link them to.
    # Ids of the problems to associate; each element must be a UUID.
    problem_id_list = serializers.ListField(required=True, error_messages=ErrMessage.list("问题id列表"),
                                            child=serializers.UUIDField(required=True,
                                                                        error_messages=ErrMessage.uuid("问题id")))
    # Paragraph entries, each validated by AssociationParagraph.
    paragraph_list = AssociationParagraph(many=True)
|
||||
|
||||
|
||||
def is_exits(exits_problem_paragraph_mapping_list, new_paragraph_mapping):
    """
    Tell whether a candidate problem/paragraph mapping is already stored.

    Two mappings are considered the same when their paragraph, problem and
    dataset ids all match. Ids are compared by their string form on BOTH
    sides, so the result is the same whether either side holds ``uuid.UUID``
    objects or plain strings.

    :param exits_problem_paragraph_mapping_list: iterable of already stored mappings
    :param new_paragraph_mapping: candidate mapping about to be created
    :return: True if an equivalent mapping exists, otherwise False
    """
    new_key = (str(new_paragraph_mapping.paragraph_id),
               str(new_paragraph_mapping.problem_id),
               str(new_paragraph_mapping.dataset_id))
    # any() stops at the first match instead of building a full filtered list.
    return any(
        (str(existing.paragraph_id), str(existing.problem_id), str(existing.dataset_id)) == new_key
        for existing in exits_problem_paragraph_mapping_list)
|
||||
|
||||
|
||||
def to_problem_paragraph_mapping(problem, document_id: str, paragraph_id: str, dataset_id: str):
    """
    Build an unsaved mapping that links ``problem`` to a paragraph.

    :param problem: problem object; its ``id`` is stored as a string on the mapping
    :param document_id: document id stored on the mapping
    :param paragraph_id: paragraph id stored on the mapping
    :param dataset_id: dataset id stored on the mapping
    :return: ``(mapping, problem)`` so callers keep the problem next to its new mapping
    """
    mapping_fields = {
        'id': uuid.uuid1(),
        'document_id': document_id,
        'paragraph_id': paragraph_id,
        'dataset_id': dataset_id,
        'problem_id': str(problem.id),
    }
    mapping = ProblemParagraphMapping(**mapping_fields)
    return mapping, problem
|
||||
|
||||
|
||||
class ProblemSerializers(ApiMixin, serializers.Serializer):
|
||||
class Create(serializers.Serializer):
|
||||
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
|
||||
|
|
@ -115,6 +146,47 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
|
|||
delete_embedding_by_source_ids(source_ids)
|
||||
return True
|
||||
|
||||
def association(self, instance: Dict, with_valid=True):
    """
    Associate every requested problem with every requested paragraph.

    Pairs that already have a mapping are skipped; the remaining pairs are
    bulk-inserted and then embedded as PROBLEM sources.

    :param instance: request body with ``problem_id_list`` and ``paragraph_list``
                     (each paragraph entry carries ``paragraph_id`` and ``document_id``)
    :param with_valid: validate this serializer and the request body first
    :return: True on success, like the sibling batch ``delete``
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        BatchAssociation(data=instance).is_valid(raise_exception=True)
    dataset_id = self.data.get('dataset_id')
    paragraph_list = instance.get('paragraph_list')
    problem_id_list = instance.get('problem_id_list')
    problem_list = QuerySet(Problem).filter(id__in=problem_id_list)
    exits_problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
        problem_id__in=problem_id_list,
        paragraph_id__in=[p.get('paragraph_id') for p in paragraph_list])
    # Cartesian product problem x paragraph, built flat (problem is the outer loop).
    candidate_list = [
        to_problem_paragraph_mapping(problem,
                                     paragraph.get('document_id'),
                                     paragraph.get('paragraph_id'),
                                     dataset_id)
        for problem in problem_list
        for paragraph in paragraph_list]
    # Keep only the pairs that are not stored yet.
    problem_paragraph_mapping_list = [
        (problem_paragraph_mapping, problem)
        for problem_paragraph_mapping, problem in candidate_list
        if not is_exits(exits_problem_paragraph_mapping, problem_paragraph_mapping)]
    if not problem_paragraph_mapping_list:
        # Nothing new: avoid an empty insert and loading the embedding model for no work.
        return True
    QuerySet(ProblemParagraphMapping).bulk_create(
        [problem_paragraph_mapping for problem_paragraph_mapping, problem in problem_paragraph_mapping_list])
    data_list = [{'text': problem.content,
                  'is_active': True,
                  'source_type': SourceType.PROBLEM,
                  'source_id': str(problem_paragraph_mapping.id),
                  'document_id': str(problem_paragraph_mapping.document_id),
                  'paragraph_id': str(problem_paragraph_mapping.paragraph_id),
                  'dataset_id': dataset_id,
                  } for problem_paragraph_mapping, problem in problem_paragraph_mapping_list]
    model_id = get_embedding_model_id_by_dataset_id(dataset_id)
    embedding_by_data_list(data_list, model_id=model_id)
    return True
|
||||
|
||||
class Operate(serializers.Serializer):
|
||||
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
|
||||
|
||||
|
|
|
|||
|
|
@ -36,6 +36,36 @@ class ProblemApi(ApiMixin):
|
|||
}
|
||||
)
|
||||
|
||||
class BatchAssociation(ApiMixin):
    # Swagger description of the batch "associate problems with paragraphs" endpoint.

    @staticmethod
    def get_request_params_api():
        # Same request parameters as the other batch problem operations.
        return ProblemApi.BatchOperate.get_request_params_api()

    @staticmethod
    def get_request_body_api():
        # Both lists are required: the request serializer declares problem_id_list with
        # required=True and paragraph_list as a nested many=True serializer (required by default).
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['problem_id_list', 'paragraph_list'],
            properties={
                'problem_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="问题id列表",
                                                  description="问题id列表",
                                                  items=openapi.Schema(type=openapi.TYPE_STRING)),
                'paragraph_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="关联段落信息列表",
                                                 description="关联段落信息列表",
                                                 items=openapi.Schema(type=openapi.TYPE_OBJECT,
                                                                      required=['paragraph_id', 'document_id'],
                                                                      properties={
                                                                          'paragraph_id': openapi.Schema(
                                                                              type=openapi.TYPE_STRING,
                                                                              title="段落id"),
                                                                          'document_id': openapi.Schema(
                                                                              type=openapi.TYPE_STRING,
                                                                              title="文档id")
                                                                      }))

            }
        )
|
||||
|
||||
class BatchOperate(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
|
|
|
|||
|
|
@ -88,6 +88,20 @@ class Problem(APIView):
|
|||
return result.success(
|
||||
ProblemSerializers.BatchOperate(data={'dataset_id': dataset_id}).delete(request.data))
|
||||
|
||||
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="批量关联段落",
                     operation_id="批量关联段落",
                     request_body=ProblemApi.BatchAssociation.get_request_body_api(),
                     manual_parameters=ProblemApi.BatchOperate.get_request_params_api(),
                     responses=result.get_default_response(),
                     tags=["知识库/文档/段落/问题"])
@has_permissions(
    lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                            dynamic_tag=k.get('dataset_id')))
def post(self, request: Request, dataset_id: str):
    # Batch-associate problems with paragraphs; requires MANAGE permission on the dataset.
    batch_serializer = ProblemSerializers.BatchOperate(data={'dataset_id': dataset_id})
    association_result = batch_serializer.association(request.data)
    return result.success(association_result)
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
|
|
|
|||
|
|
@ -111,6 +111,11 @@ def embedding_by_problem(args, model_id):
|
|||
ListenerManagement.embedding_by_problem(args, embedding_model)
|
||||
|
||||
|
||||
def embedding_by_data_list(args: List, model_id):
    """
    Resolve the embedding model for ``model_id`` and vectorize ``args`` with it.

    :param args: list of records to embed, forwarded unchanged
    :param model_id: id used to look up the embedding model
    """
    ListenerManagement.embedding_by_data_list(args, get_embedding_model(model_id))
|
||||
|
||||
|
||||
def delete_embedding_by_document(document_id):
|
||||
"""
|
||||
删除指定文档id的向量
|
||||
|
|
|
|||
|
|
@ -87,27 +87,21 @@ class BaseVectorStore(ABC):
|
|||
self._batch_save(child_array, embedding, lambda: True)
|
||||
|
||||
def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_save_function):
    """
    Embed and store ``data_list`` batch by batch.

    :param data_list: records to vectorize
    :param embedding: embedding model used for vectorization
    :param is_save_function: zero-argument callable checked before each batch;
                             a falsy result stops the remaining batches
    :return: True once the loop has ended (completed or stopped early)
    """
    self.save_pre_handler()
    chunk_list = chunk_data_list(data_list)
    result = sub_array(chunk_list)
    for child_array in result:
        if not is_save_function():
            # The caller asked to stop: skip all remaining batches.
            break
        self._batch_save(child_array, embedding, is_save_function)
    # The docstring promises a bool and the lock-based version returned True.
    return True
|
||||
|
||||
@abstractmethod
|
||||
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||
|
|
|
|||
|
|
@ -97,11 +97,28 @@ const getDetailProblems: (
|
|||
return get(`${prefix}/${dataset_id}/problem/${problem_id}/paragraph`, undefined, loading)
|
||||
}
|
||||
|
||||
/**
 * Associate several problems with several paragraphs in one request.
 * @param dataset_id dataset id used in the request path
 * @param data request body:
 * {
 *   "problem_id_list": "Array",
 *   "paragraph_list": "Array"
 * }
 * @param loading optional loading ref forwarded to `post`
 */
const postMulAssociationProblem: (
  dataset_id: string,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
  return post(`${prefix}/${dataset_id}/problem/_batch`, data, undefined, loading)
}
|
||||
|
||||
export default {
|
||||
getProblems,
|
||||
postProblems,
|
||||
delProblems,
|
||||
putProblems,
|
||||
getDetailProblems,
|
||||
delMulProblem
|
||||
delMulProblem,
|
||||
postMulAssociationProblem
|
||||
}
|
||||
|
|
|
|||
|
|
@ -91,6 +91,12 @@
|
|||
</el-scrollbar>
|
||||
</el-col>
|
||||
</el-row>
|
||||
<template #footer v-if="isMul">
|
||||
<div class="dialog-footer">
|
||||
<el-button @click="dialogVisible = false"> 取消</el-button>
|
||||
<el-button type="primary" @click="mulAssociation"> 确认 </el-button>
|
||||
</div>
|
||||
</template>
|
||||
</el-dialog>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
|
|
@ -99,6 +105,7 @@ import { useRoute } from 'vue-router'
|
|||
import problemApi from '@/api/problem'
|
||||
import paragraphApi from '@/api/paragraph'
|
||||
import useStore from '@/stores'
|
||||
import { MsgSuccess } from '@/utils/message'
|
||||
|
||||
const { problem, document } = useStore()
|
||||
|
||||
|
|
@ -116,6 +123,7 @@ const documentList = ref<any[]>([])
|
|||
const cloneDocumentList = ref<any[]>([])
|
||||
const paragraphList = ref<any[]>([])
|
||||
const currentProblemId = ref<String>('')
|
||||
const currentMulProblemId = ref<string[]>([])
|
||||
|
||||
// 回显
|
||||
const associationParagraph = ref<any[]>([])
|
||||
|
|
@ -124,6 +132,8 @@ const currentDocument = ref<String>('')
|
|||
const search = ref('')
|
||||
const searchType = ref('title')
|
||||
const filterDoc = ref('')
|
||||
// 批量
|
||||
const isMul = ref(false)
|
||||
|
||||
const paginationConfig = reactive({
|
||||
current_page: 1,
|
||||
|
|
@ -131,31 +141,53 @@ const paginationConfig = reactive({
|
|||
total: 0
|
||||
})
|
||||
|
||||
// Send the batch association (selected problems x selected paragraphs) and close the dialog on success.
function mulAssociation() {
  // Reduce each selected paragraph to the id pair the batch endpoint expects.
  const paragraph_list = associationParagraph.value.map((selected) => {
    return { paragraph_id: selected.id, document_id: selected.document_id }
  })
  const payload = { problem_id_list: currentMulProblemId.value, paragraph_list }
  problemApi.postMulAssociationProblem(id, payload, loading).then(() => {
    MsgSuccess('批量关联分段成功')
    dialogVisible.value = false
  })
}
|
||||
|
||||
// Toggle the association state of the clicked paragraph.
// Batch mode only edits the local selection; single mode calls the API immediately.
function associationClick(item: any) {
  if (isMul.value) {
    if (isAssociation(item.id)) {
      // The selection holds paragraph objects, so look the entry up by id.
      // indexOf(item.id) never matches an object and splice(-1, 1) would drop the LAST entry.
      const index = associationParagraph.value.findIndex(
        (paragraph) => paragraph.id === item.id
      )
      if (index !== -1) {
        associationParagraph.value.splice(index, 1)
      }
    } else {
      associationParagraph.value.push(item)
    }
  } else {
    if (isAssociation(item.id)) {
      problem
        .asyncDisassociationProblem(
          id,
          item.document_id,
          item.id,
          currentProblemId.value as string,
          loading
        )
        .then(() => {
          getRecord(currentProblemId.value)
        })
    } else {
      problem
        .asyncAssociationProblem(
          id,
          item.document_id,
          item.id,
          currentProblemId.value as string,
          loading
        )
        .then(() => {
          getRecord(currentProblemId.value)
        })
    }
  }
}
|
||||
|
||||
|
|
@ -216,6 +248,7 @@ watch(dialogVisible, (bool) => {
|
|||
cloneDocumentList.value = []
|
||||
paragraphList.value = []
|
||||
associationParagraph.value = []
|
||||
isMul.value = false
|
||||
|
||||
currentDocument.value = ''
|
||||
search.value = ''
|
||||
|
|
@ -232,10 +265,15 @@ watch(filterDoc, (val) => {
|
|||
currentDocument.value = documentList.value?.length > 0 ? documentList.value[0].id : ''
|
||||
})
|
||||
|
||||
// Open the dialog for a list of problem ids.
// One id: single mode, existing associations are loaded. Several ids: batch mode.
const open = (problemId: any) => {
  getDocument()
  if (problemId.length == 1) {
    currentProblemId.value = problemId[0]
    // Pass the id itself, not the one-element array, like every other getRecord call.
    getRecord(problemId[0])
  } else if (problemId.length > 1) {
    currentMulProblemId.value = problemId
    isMul.value = true
  }
  dialogVisible.value = true
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,9 @@
|
|||
<div class="flex-between">
|
||||
<div>
|
||||
<el-button type="primary" @click="createProblem">创建问题</el-button>
|
||||
<el-button @click="relateProblem()" :disabled="multipleSelection.length === 0"
|
||||
>关联分段</el-button
|
||||
>
|
||||
<el-button @click="deleteMulDocument" :disabled="multipleSelection.length === 0"
|
||||
>批量删除</el-button
|
||||
>
|
||||
|
|
@ -104,7 +107,7 @@
|
|||
:next_disable="next_disable"
|
||||
@refresh="refresh"
|
||||
/>
|
||||
<RelateProblemDialog ref="RelateProblemDialogRef" @refresh="refresh" />
|
||||
<RelateProblemDialog ref="RelateProblemDialogRef" @refresh="refreshRelate" />
|
||||
</LayoutContainer>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
|
|
@ -157,8 +160,19 @@ const problemIndexMap = computed<Dict<number>>(() => {
|
|||
const multipleTableRef = ref<InstanceType<typeof ElTable>>()
|
||||
const multipleSelection = ref<any[]>([])
|
||||
|
||||
// Open the relate dialog for one row, or for every selected row when called without a row.
function relateProblem(row?: any) {
  // filter + map builds the id list directly instead of using map() for side effects.
  const arr: string[] = row
    ? [row.id]
    : multipleSelection.value.filter((v) => v).map((v) => v.id)

  RelateProblemDialogRef.value.open(arr)
}
|
||||
|
||||
function createProblem() {
|
||||
|
|
@ -330,7 +344,10 @@ function getList() {
|
|||
paginationConfig.total = res.data.total
|
||||
})
|
||||
}
|
||||
|
||||
// Refresh handler for the relate dialog: reload the list and clear the table's row selection.
// Unlike refresh(), it does not reset the current page.
function refreshRelate() {
  getList()
  multipleTableRef.value?.clearSelection()
}
|
||||
function refresh() {
|
||||
paginationConfig.current_page = 1
|
||||
getList()
|
||||
|
|
|
|||
Loading…
Reference in New Issue