From 9d808b4ccdcd487c4b176b41126a00b1e65adbd8 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Fri, 26 Apr 2024 18:35:46 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E8=BF=81=E7=A7=BB(#52)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/event/listener_manage.py | 11 ++ .../serializers/document_serializers.py | 102 +++++++++++++- apps/dataset/urls.py | 1 + apps/dataset/views/document.py | 28 +++- ui/src/api/document.ts | 21 ++- ui/src/components/icons/index.ts | 25 ++++ .../component/SelectDatasetDialog.vue | 125 ++++++++++++++++++ ui/src/views/document/index.vue | 93 +++++++++---- 8 files changed, 377 insertions(+), 29 deletions(-) create mode 100644 ui/src/views/document/component/SelectDatasetDialog.vue diff --git a/apps/common/event/listener_manage.py b/apps/common/event/listener_manage.py index 98e22c821..b266bcef0 100644 --- a/apps/common/event/listener_manage.py +++ b/apps/common/event/listener_manage.py @@ -50,6 +50,12 @@ class UpdateProblemArgs: self.problem_content = problem_content +class UpdateEmbeddingDatasetIdArgs: + def __init__(self, source_id_list: List[str], target_dataset_id: str): + self.source_id_list = source_id_list + self.target_dataset_id = target_dataset_id + + class ListenerManagement: embedding_by_problem_signal = signal("embedding_by_problem") embedding_by_paragraph_signal = signal("embedding_by_paragraph") @@ -205,6 +211,11 @@ class ListenerManagement: VectorStore.get_embedding_vector().update_by_source_ids([v.id for v in problem_paragraph_mapping_list], {'embedding': embed_value}) + @staticmethod + def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs): + VectorStore.get_embedding_vector().update_by_source_ids(args.source_id_list, + {'dataset_id': args.target_dataset_id}) + @staticmethod def delete_embedding_by_source_ids(source_ids: List[str]): VectorStore.get_embedding_vector().delete_by_source_ids(source_ids, SourceType.PROBLEM) diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index 3a7bdeba4..f2acfb9df 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -22,7 +22,7 @@ from rest_framework import serializers from common.db.search import native_search, native_page_search from common.event.common import work_thread_pool -from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs +from common.event.listener_manage import ListenerManagement, SyncWebDocumentArgs, UpdateEmbeddingDatasetIdArgs from common.exception.app_exception import AppApiException from common.handle.impl.doc_split_handle import DocSplitHandle from common.handle.impl.pdf_split_handle import PdfSplitHandle @@ -114,6 +114,106 @@ class DocumentInstanceSerializer(ApiMixin, serializers.Serializer): class DocumentSerializers(ApiMixin, serializers.Serializer): + class Migrate(ApiMixin, serializers.Serializer): + dataset_id = serializers.UUIDField(required=True, + error_messages=ErrMessage.char( + "知识库id")) + target_dataset_id = serializers.UUIDField(required=True, + error_messages=ErrMessage.char( + "目标知识库id")) + document_id_list = serializers.ListField(required=True, error_messages=ErrMessage.char("文档列表"), + child=serializers.UUIDField(required=True, + error_messages=ErrMessage.uuid("文档id"))) + + @transaction.atomic + def migrate(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + dataset_id = self.data.get('dataset_id') + target_dataset_id = self.data.get('target_dataset_id') + dataset = QuerySet(DataSet).filter(id=dataset_id).first() + target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first() + document_id_list = self.data.get('document_id_list') + document_list = QuerySet(Document).filter(dataset_id=dataset_id, id__in=document_id_list) + paragraph_list = QuerySet(Paragraph).filter(dataset_id=dataset_id, document_id__in=document_id_list) + + problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(paragraph__in=paragraph_list) + problem_list = QuerySet(Problem).filter( + id__in=[problem_paragraph_mapping.problem_id for problem_paragraph_mapping in + problem_paragraph_mapping_list]) + target_problem_list = list( + QuerySet(Problem).filter(content__in=[problem.content for problem in problem_list], + dataset_id=target_dataset_id)) + target_handle_problem_list = [ + self.get_target_dataset_problem(target_dataset_id, problem_paragraph_mapping, + problem_list, target_problem_list) for + problem_paragraph_mapping + in + problem_paragraph_mapping_list] + + create_problem_list = [problem for problem, is_create in target_handle_problem_list if + is_create is not None and is_create] + # 插入问题 + QuerySet(Problem).bulk_create(create_problem_list) + # 修改mapping + QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list, ['problem_id', 'dataset_id']) + # 修改文档 + if dataset.type == Type.base.value and target_dataset.type == Type.web.value: + document_list.update(dataset_id=target_dataset_id, type=Type.web, + meta={'source_url': '', 'selector': ''}) + elif target_dataset.type == Type.base.value and dataset.type == Type.web.value: + document_list.update(dataset_id=target_dataset_id, type=Type.base, + meta={}) + paragraph_list.update(dataset_id=target_dataset_id) + ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs( + [problem_paragraph_mapping.id for problem_paragraph_mapping in problem_paragraph_mapping_list], + target_dataset_id)) + + @staticmethod + def get_target_dataset_problem(target_dataset_id: str, + problem_paragraph_mapping, + source_problem_list, + target_problem_list): + source_problem_list = [source_problem for source_problem in source_problem_list if + source_problem.id == problem_paragraph_mapping.problem_id] + problem_paragraph_mapping.dataset_id = target_dataset_id + if len(source_problem_list) > 0: + problem_content = source_problem_list[-1].content + problem_list = [problem for problem in target_problem_list if problem.content == problem_content] + if len(problem_list) > 0: + problem = problem_list[-1] + problem_paragraph_mapping.problem_id = problem.id + return problem, False + else: + problem = Problem(id=uuid.uuid1(), dataset_id=target_dataset_id, content=problem_content) + target_problem_list.append(problem) + problem_paragraph_mapping.problem_id = problem.id + return problem, True + return None + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id'), + openapi.Parameter(name='target_dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='目标知识库id') + ] + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_ARRAY, + items=openapi.Schema(type=openapi.TYPE_STRING), + title='文档id列表', + description="文档id列表" + ) + class Query(ApiMixin, serializers.Serializer): # 知识库id dataset_id = serializers.UUIDField(required=True, diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index 2868bcbbd..5ed09a199 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -21,6 +21,7 @@ urlpatterns = [ name="document_operate"), path('dataset/document/split_pattern', views.Document.SplitPattern.as_view(), name="document_operate"), + path('dataset//document/migrate/', views.Document.Migrate.as_view()), path('dataset//document//refresh', views.Document.Refresh.as_view()), path('dataset//document//paragraph', views.Paragraph.as_view()), path('dataset//document//paragraph//', diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index ad42e5554..fd6797b01 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -14,7 +14,7 @@ from rest_framework.views import APIView from rest_framework.views import Request from common.auth import TokenAuth, has_permissions -from common.constants.permission_constants import Permission, Group, Operate +from common.constants.permission_constants import Permission, Group, Operate, CompareConstants from common.response import result from common.util.common import query_params_to_single_dict from dataset.serializers.common_serializers import BatchSerializer @@ -135,6 +135,32 @@ class Document(APIView): DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).refresh( )) + class Migrate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="批量迁移文档", + operation_id="批量迁移文档", + manual_parameters=DocumentSerializers.Migrate.get_request_params_api(), + request_body=DocumentSerializers.Migrate.get_request_body_api(), + responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()), + tags=["知识库/文档"] + ) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id')), + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('target_dataset_id')), + compare=CompareConstants.AND + ) + def put(self, request: Request, dataset_id: str, target_dataset_id: str): + return result.success( + DocumentSerializers.Migrate( + data={'dataset_id': dataset_id, 'target_dataset_id': target_dataset_id, + 'document_id_list': request.data}).migrate( + + )) + class Operate(APIView): authentication_classes = [TokenAuth] diff --git a/ui/src/api/document.ts b/ui/src/api/document.ts index 618b357f3..fdd070573 100644 --- a/ui/src/api/document.ts +++ b/ui/src/api/document.ts @@ -188,6 +188,24 @@ const postWebDocument: ( return post(`${prefix}/${dataset_id}/document/web`, data, undefined, loading) } +/** + * 批量迁移文档 + * @param 参数 dataset_id,target_dataset_id, + */ +const putMigrateMulDocument: ( + dataset_id: string, + target_dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, target_dataset_id, data, loading) => { + return put( + `${prefix}/${dataset_id}/document/migrate/${target_dataset_id}`, + data, + undefined, + loading + ) +} + export default { postSplitDocument, getDocument, @@ -200,5 +218,6 @@ export default { listSplitPattern, putDocumentRefresh, delMulSyncDocument, - postWebDocument + postWebDocument, + putMigrateMulDocument } diff --git a/ui/src/components/icons/index.ts b/ui/src/components/icons/index.ts index 16e2fac23..983ac7de2 100644 --- a/ui/src/components/icons/index.ts +++ b/ui/src/components/icons/index.ts @@ -774,5 +774,30 @@ export const iconMap: any = { ) ]) } + }, + 'app-migrate': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M537.6 665.6c-12.8 12.8-12.8 32 0 44.8 6.4 6.4 12.8 6.4 25.6 6.4 6.4 0 19.2 0 25.6-6.4l128-134.4s6.4-6.4 6.4-12.8v-19.2-6.4c0-6.4-6.4-12.8-6.4-12.8l-134.4-128c-12.8-12.8-32-12.8-44.8 0-12.8 12.8-12.8 38.4 0 51.2l76.8 76.8H96c-19.2 0-32 12.8-32 32s12.8 32 32 32h524.8l-83.2 76.8z', + fill: 'currentColor' + }), + h('path', { + d: 'M960 384c0-6.4-6.4-12.8-6.4-19.2L704 128c-6.4-6.4-6.4-6.4-12.8-6.4h-6.4-371.2c-76.8 0-140.8 64-140.8 140.8v172.8c0 19.2 12.8 32 32 32s25.6-19.2 25.6-38.4V262.4c0-44.8 38.4-76.8 76.8-76.8h339.2v211.2c0 19.2 12.8 32 32 32H896V768c0 44.8-38.4 76.8-76.8 76.8H313.6c-44.8 0-76.8-38.4-76.8-76.8v-89.6c0-19.2-12.8-32-32-32s-32 12.8-32 32V768c0 76.8 64 140.8 140.8 140.8h505.6c76.8 0 140.8-64 140.8-140.8V384c0 6.4 0 6.4 0 0z m-243.2-25.6V224l134.4 134.4h-134.4z', + fill: 'currentColor' + }) + ] + ) + ]) + } } } diff --git a/ui/src/views/document/component/SelectDatasetDialog.vue b/ui/src/views/document/component/SelectDatasetDialog.vue new file mode 100644 index 000000000..6c812f36b --- /dev/null +++ b/ui/src/views/document/component/SelectDatasetDialog.vue @@ -0,0 +1,125 @@ + + + diff --git a/ui/src/views/document/index.vue b/ui/src/views/document/index.vue index 55df89c81..02a0434ec 100644 --- a/ui/src/views/document/index.vue +++ b/ui/src/views/document/index.vue @@ -20,6 +20,9 @@ v-if="datasetDetail.type === '1'" >同步文档 + 批量迁移 批量删除 @@ -91,7 +94,7 @@ - + - + - + @@ -151,12 +154,23 @@ - - - - + + + + - + +
@@ -180,6 +194,10 @@ 设置 + + + 迁移 删除 @@ -194,6 +212,8 @@
+ + @@ -204,6 +224,7 @@ import { ElTable } from 'element-plus' import documentApi from '@/api/document' import ImportDocumentDialog from './component/ImportDocumentDialog.vue' import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue' +import SelectDatasetDialog from './component/SelectDatasetDialog.vue' import { numberFormat } from '@/utils/utils' import { datetimeFormat } from '@/utils/time' import { hitHandlingMethod } from './utils' @@ -257,6 +278,23 @@ const multipleTableRef = ref>() const multipleSelection = ref([]) const title = ref('') +const SelectDatasetDialogRef = ref() + +function openDatasetDialog(row?: any) { + const arr: string[] = [] + if (row) { + arr.push(row.id) + } else { + multipleSelection.value.map((v) => { + if (v) { + arr.push(v.id) + } + }) + } + + SelectDatasetDialogRef.value.open(arr) +} + function dropdownHandle(val: string) { filterMethod.value = val getList() @@ -284,12 +322,6 @@ const handleSelectionChange = (val: any[]) => { */ const initInterval = () => { interval = setInterval(() => { - // if ( - // documentData.value.length === 0 || - // documentData.value.some((item) => item.status === '0' || item.status === '2') - // ) { - // getList(true) - // } getList(true) }, 6000) } @@ -304,20 +336,29 @@ const closeInterval = () => { } function refreshDocument(row: any) { if (row.type === '1') { - MsgConfirm(`确认同步文档?`, `同步将删除已有数据重新获取新数据,请谨慎操作。`, { - confirmButtonText: '同步', - confirmButtonClass: 'danger' - }) - .then(() => { - documentApi.putDocumentRefresh(row.dataset_id, row.id).then((res) => { - getList() - }) + if (row.meta?.source_url) { + MsgConfirm(`确认同步文档?`, `同步将删除已有数据重新获取新数据,请谨慎操作。`, { + confirmButtonText: '同步', + confirmButtonClass: 'danger' }) - .catch(() => {}) + .then(() => { + documentApi.putDocumentRefresh(row.dataset_id, row.id).then((res) => { + getList() + }) + }) + .catch(() => {}) + } else { + MsgConfirm(`提示`, `无法同步,请先去设置文档 URL地址`, { + confirmButtonText: '确认', + type: 'warning' + }) + .then(() => {}) + .catch(() => {}) + } } else { - documentApi.putDocumentRefresh(row.dataset_id, row.id).then((res) => { - getList() - }) + // documentApi.putDocumentRefresh(row.dataset_id, row.id).then((res) => { + // getList() + // }) } }