feat: add migration endpoints for documents and paragraphs with parameter handling

This commit is contained in:
CaptainB 2025-06-03 18:37:23 +08:00
parent 070a5d4ed3
commit 852313217d
8 changed files with 367 additions and 15 deletions

View File

@@ -6,7 +6,7 @@ from common.result import DefaultResultSerializer
from knowledge.serializers.common import BatchSerializer
from knowledge.serializers.document import DocumentInstanceSerializer, DocumentWebInstanceSerializer, \
CancelInstanceSerializer, BatchCancelInstanceSerializer, DocumentRefreshSerializer, BatchEditHitHandlingSerializer, \
DocumentBatchRefreshSerializer, DocumentBatchGenerateRelatedSerializer
DocumentBatchRefreshSerializer, DocumentBatchGenerateRelatedSerializer, DocumentMigrateSerializer
class DocumentSplitAPI(APIMixin):
@@ -471,3 +471,35 @@ class DocumentExportAPI(APIMixin):
@staticmethod
def get_response():
    """Return the serializer describing this endpoint's (default) response schema."""
    return DefaultResultSerializer
class DocumentMigrateAPI(APIMixin):
    """OpenAPI schema for the batch document migration endpoint."""

    @staticmethod
    def get_parameters():
        """Path parameters: workspace, source knowledge base, target knowledge base."""
        param_specs = (
            ("workspace_id", "工作空间id"),
            ("knowledge_id", "知识库id"),
            ("target_knowledge_id", "目标知识库id"),
        )
        return [
            OpenApiParameter(
                name=param_name,
                description=param_desc,
                type=OpenApiTypes.STR,
                location='path',
                required=True,
            )
            for param_name, param_desc in param_specs
        ]

    @staticmethod
    def get_request():
        """Request body serializer: the list of document ids to migrate."""
        return DocumentMigrateSerializer

View File

@ -265,3 +265,49 @@ class ParagraphPageAPI(APIMixin):
required=True,
),
]
class ParagraphMigrateAPI(APIMixin):
    """OpenAPI schema for the batch paragraph migration endpoint."""

    @staticmethod
    def get_parameters():
        """Path parameters: workspace, source/target knowledge base and document ids."""
        param_specs = (
            ("workspace_id", "工作空间id"),
            ("knowledge_id", "知识库id"),
            ("document_id", "文档id"),
            ("target_knowledge_id", "目标知识库id"),
            ("target_document_id", "目标文档id"),
        )
        return [
            OpenApiParameter(
                name=param_name,
                description=param_desc,
                type=OpenApiTypes.STR,
                location='path',
                required=True,
            )
            for param_name, param_desc in param_specs
        ]

    @staticmethod
    def get_request():
        """Request body serializer: a batch of paragraph ids."""
        return BatchSerializer

View File

@ -49,7 +49,8 @@ from knowledge.serializers.common import ProblemParagraphManage, BatchSerializer
from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer, \
delete_problems_and_mappings
from knowledge.task.embedding import embedding_by_document, delete_embedding_by_document_list, \
delete_embedding_by_document
delete_embedding_by_document, delete_embedding_by_paragraph_ids, embedding_by_document_list, \
update_embedding_knowledge_id
from knowledge.task.generate import generate_related_by_document_id
from knowledge.task.sync import sync_web_document
from maxkb.const import PROJECT_DIR
@ -173,6 +174,10 @@ class DocumentBatchGenerateRelatedSerializer(serializers.Serializer):
state_list = serializers.ListField(required=True, label=_('state list'))
class DocumentMigrateSerializer(serializers.Serializer):
    """Request body for batch document migration."""
    # Ids of the documents to move to the target knowledge base.
    document_id_list = serializers.ListField(required=True, label=_('document id list'))
class BatchEditHitHandlingSerializer(serializers.Serializer):
    """Request body for editing the hit-handling method of several documents at once."""
    # Document ids to update (each must be a valid UUID).
    id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_('id list'))
    # New hit handling method applied to every listed document.
    hit_handling_method = serializers.CharField(required=True, label=_('hit handling method'))
@ -199,7 +204,8 @@ class DocumentSerializers(serializers.Serializer):
language = get_language()
if self.data.get('type') == 'csv':
file = open(
os.path.join(PROJECT_DIR, "apps", "knowledge", 'template', f'csv_template_{to_locale(language)}.csv'),
os.path.join(PROJECT_DIR, "apps", "knowledge", 'template',
f'csv_template_{to_locale(language)}.csv'),
"rb")
content = file.read()
file.close()
@ -239,6 +245,98 @@ class DocumentSerializers(serializers.Serializer):
else:
return None
class Migrate(serializers.Serializer):
    """Migrate documents — together with their paragraphs, problems, mappings and
    embeddings — from one knowledge base to another.
    """
    workspace_id = serializers.CharField(required=True, label=_('workspace id'))
    knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
    target_knowledge_id = serializers.UUIDField(required=True, label=_('target knowledge id'))
    document_id_list = serializers.ListField(required=True, label=_('document list'),
                                             child=serializers.UUIDField(required=True, label=_('document id')))

    @transaction.atomic
    def migrate(self, with_valid=True):
        """Move the selected documents to the target knowledge base.

        Runs in a single transaction: re-homes problems and their mappings,
        updates the documents (adjusting type/meta when crossing BASE/WEB
        knowledge-base types), updates the paragraphs, then either re-embeds
        everything (when the embedding model differs) or just re-tags the
        existing embeddings with the new knowledge id.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        knowledge_id = self.data.get('knowledge_id')
        target_knowledge_id = self.data.get('target_knowledge_id')
        knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()
        target_knowledge = QuerySet(Knowledge).filter(id=target_knowledge_id).first()
        document_id_list = self.data.get('document_id_list')
        document_list = QuerySet(Document).filter(knowledge_id=knowledge_id, id__in=document_id_list)
        paragraph_list = QuerySet(Paragraph).filter(knowledge_id=knowledge_id, document_id__in=document_id_list)
        problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(paragraph__in=paragraph_list)
        problem_list = QuerySet(Problem).filter(
            id__in=[problem_paragraph_mapping.problem_id for problem_paragraph_mapping in
                    problem_paragraph_mapping_list])
        # Problems that already exist in the target base with identical content
        # (reused instead of duplicated).
        target_problem_list = list(
            QuerySet(Problem).filter(content__in=[problem.content for problem in problem_list],
                                     knowledge_id=target_knowledge_id))
        # NOTE(review): get_target_knowledge_problem returns bare None (not a
        # 2-tuple) when a mapping's problem id has no matching Problem row; the
        # tuple unpacking below would then raise TypeError. Presumably every
        # mapping references an existing problem — confirm.
        target_handle_problem_list = [
            self.get_target_knowledge_problem(target_knowledge_id, problem_paragraph_mapping,
                                              problem_list, target_problem_list) for
            problem_paragraph_mapping
            in
            problem_paragraph_mapping_list]
        create_problem_list = [problem for problem, is_create in target_handle_problem_list if
                               is_create is not None and is_create]
        # Insert the problems that had to be created in the target base.
        QuerySet(Problem).bulk_create(create_problem_list)
        # Persist the mapping mutations made in get_target_knowledge_problem.
        QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
                                                      ['problem_id', 'knowledge_id'])
        # Update the documents; adjust type/meta when migrating between
        # BASE and WEB knowledge-base types.
        if knowledge.type == KnowledgeType.BASE.value and target_knowledge.type == KnowledgeType.WEB.value:
            document_list.update(knowledge_id=target_knowledge_id, type=KnowledgeType.WEB,
                                 meta={'source_url': '', 'selector': ''})
        elif target_knowledge.type == KnowledgeType.BASE.value and knowledge.type == KnowledgeType.WEB.value:
            document_list.update(knowledge_id=target_knowledge_id, type=KnowledgeType.BASE,
                                 meta={})
        else:
            document_list.update(knowledge_id=target_knowledge_id)
        model_id = None
        if knowledge.embedding_mode_id != target_knowledge.embedding_mode_id:
            model_id = get_embedding_model_id_by_knowledge_id(target_knowledge_id)
        pid_list = [paragraph.id for paragraph in paragraph_list]
        # Update the paragraphs' knowledge id.
        paragraph_list.update(knowledge_id=target_knowledge_id)
        # Update the embeddings: re-embed asynchronously when the target base
        # uses a different embedding model, otherwise just re-tag in place.
        if model_id:
            delete_embedding_by_paragraph_ids(pid_list)
            ListenerManagement.update_status(QuerySet(Document).filter(id__in=document_id_list),
                                             TaskType.EMBEDDING,
                                             State.PENDING)
            ListenerManagement.update_status(QuerySet(Paragraph).filter(document_id__in=document_id_list),
                                             TaskType.EMBEDDING,
                                             State.PENDING)
            ListenerManagement.get_aggregation_document_status_by_query_set(
                QuerySet(Document).filter(id__in=document_id_list))()
            embedding_by_document_list.delay(document_id_list, model_id)
        else:
            update_embedding_knowledge_id(pid_list, target_knowledge_id)

    @staticmethod
    def get_target_knowledge_problem(target_knowledge_id: str,
                                     problem_paragraph_mapping,
                                     source_problem_list,
                                     target_problem_list):
        """Re-home one problem↔paragraph mapping to the target knowledge base.

        Mutates problem_paragraph_mapping in place (knowledge_id, problem_id)
        and appends newly created problems to target_problem_list so later
        mappings with the same content reuse them.

        Returns (problem, created) where created is False when an existing
        target problem was reused and True when a new one was built; returns
        None when the mapping's problem id is not in source_problem_list.
        """
        source_problem_list = [source_problem for source_problem in source_problem_list if
                               source_problem.id == problem_paragraph_mapping.problem_id]
        problem_paragraph_mapping.knowledge_id = target_knowledge_id
        if len(source_problem_list) > 0:
            problem_content = source_problem_list[-1].content
            problem_list = [problem for problem in target_problem_list if problem.content == problem_content]
            if len(problem_list) > 0:
                # Reuse the last target problem with identical content.
                problem = problem_list[-1]
                problem_paragraph_mapping.problem_id = problem.id
                return problem, False
            else:
                problem = Problem(id=uuid.uuid1(), knowledge_id=target_knowledge_id, content=problem_content)
                target_problem_list.append(problem)
                problem_paragraph_mapping.problem_id = problem.id
                return problem, True
        return None
class Query(serializers.Serializer):
# 知识库id

View File

@ -13,14 +13,15 @@ from common.db.search import page_search
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException
from common.utils.common import post
from knowledge.models import Paragraph, Problem, Document, ProblemParagraphMapping, SourceType, TaskType, State
from knowledge.models import Paragraph, Problem, Document, ProblemParagraphMapping, SourceType, TaskType, State, \
Knowledge
from knowledge.serializers.common import ProblemParagraphObject, ProblemParagraphManage, \
get_embedding_model_id_by_knowledge_id, update_document_char_length, BatchSerializer
from knowledge.serializers.problem import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from knowledge.task.embedding import embedding_by_paragraph, enable_embedding_by_paragraph, \
disable_embedding_by_paragraph, \
delete_embedding_by_paragraph, embedding_by_problem as embedding_by_problem_task, delete_embedding_by_paragraph_ids, \
embedding_by_problem, delete_embedding_by_source
embedding_by_problem, delete_embedding_by_source, update_embedding_document_id
from knowledge.task.generate import generate_related_by_paragraph_id_list
@ -414,6 +415,126 @@ class ParagraphSerializers(serializers.Serializer):
except AlreadyQueued as e:
raise AppApiException(500, _('The task is being executed, please do not send it again.'))
class Migrate(serializers.Serializer):
    """Migrate a batch of paragraphs from one document to another, optionally
    across knowledge bases, keeping problems, mappings and embeddings in sync.
    """
    workspace_id = serializers.UUIDField(required=True, label=_('workspace id'))
    knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
    document_id = serializers.UUIDField(required=True, label=_('document id'))
    target_knowledge_id = serializers.UUIDField(required=True, label=_('target knowledge id'))
    target_document_id = serializers.UUIDField(required=True, label=_('target document id'))
    paragraph_id_list = serializers.ListField(required=True, label=_('paragraph id list'),
                                              child=serializers.UUIDField(required=True, label=_('paragraph id')))

    def is_valid(self, *, raise_exception=False):
        """Field validation plus cross-checks: source and target documents must
        differ and both must exist.

        NOTE(review): the raise_exception argument is ignored — field errors
        always raise because super().is_valid is called with raise_exception=True.
        """
        super().is_valid(raise_exception=True)
        document_list = QuerySet(Document).filter(
            id__in=[self.data.get('document_id'), self.data.get('target_document_id')])
        document_id = self.data.get('document_id')
        target_document_id = self.data.get('target_document_id')
        if document_id == target_document_id:
            raise AppApiException(5000, _('The document to be migrated is consistent with the target document'))
        if len([document for document in document_list if str(document.id) == self.data.get('document_id')]) < 1:
            raise AppApiException(5000, _('The document id does not exist [{document_id}]').format(
                document_id=self.data.get('document_id')))
        if len([document for document in document_list if
                str(document.id) == self.data.get('target_document_id')]) < 1:
            raise AppApiException(5000, _('The target document id does not exist [{document_id}]').format(
                document_id=self.data.get('target_document_id')))

    @transaction.atomic
    def migrate(self, with_valid=True):
        """Move the selected paragraphs to the target document in one transaction."""
        if with_valid:
            self.is_valid(raise_exception=True)
        knowledge_id = self.data.get('knowledge_id')
        target_knowledge_id = self.data.get('target_knowledge_id')
        document_id = self.data.get('document_id')
        target_document_id = self.data.get('target_document_id')
        paragraph_id_list = self.data.get('paragraph_id_list')
        paragraph_list = QuerySet(Paragraph).filter(knowledge_id=knowledge_id, document_id=document_id,
                                                    id__in=paragraph_id_list)
        problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(paragraph__in=paragraph_list)
        # Migration within the same knowledge base: only document ids change.
        if target_knowledge_id == knowledge_id:
            if len(problem_paragraph_mapping_list):
                problem_paragraph_mapping_list = [
                    self.update_problem_paragraph_mapping(target_document_id,
                                                          problem_paragraph_mapping) for problem_paragraph_mapping
                    in
                    problem_paragraph_mapping_list]
                # Persist the re-pointed mappings.
                QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
                                                              ['document_id'])
            update_embedding_document_id([paragraph.id for paragraph in paragraph_list],
                                         target_document_id, target_knowledge_id, None)
            # Update the paragraphs themselves.
            paragraph_list.update(document_id=target_document_id)
        # Migration across knowledge bases: problems must be re-homed too.
        else:
            problem_list = QuerySet(Problem).filter(
                id__in=[problem_paragraph_mapping.problem_id for problem_paragraph_mapping in
                        problem_paragraph_mapping_list])
            # Problems already present in the target knowledge base (reused).
            target_problem_list = list(
                QuerySet(Problem).filter(content__in=[problem.content for problem in problem_list],
                                         knowledge_id=target_knowledge_id))
            # NOTE(review): get_target_knowledge_problem returns bare None when a
            # mapping's problem id has no matching Problem row; the tuple
            # unpacking below would then raise — presumably unreachable.
            target_handle_problem_list = [
                self.get_target_knowledge_problem(target_knowledge_id, target_document_id,
                                                  problem_paragraph_mapping,
                                                  problem_list, target_problem_list) for
                problem_paragraph_mapping
                in
                problem_paragraph_mapping_list]
            create_problem_list = [problem for problem, is_create in target_handle_problem_list if
                                   is_create is not None and is_create]
            # Insert the problems that had to be created in the target base.
            QuerySet(Problem).bulk_create(create_problem_list)
            # Persist the mapping mutations (problem/knowledge/document ids).
            QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
                                                          ['problem_id', 'knowledge_id', 'document_id'])
            target_knowledge = QuerySet(Knowledge).filter(id=target_knowledge_id).first()
            knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()
            embedding_model_id = None
            if target_knowledge.embedding_mode_id != knowledge.embedding_mode_id:
                embedding_model_id = str(target_knowledge.embedding_mode_id)
            pid_list = [paragraph.id for paragraph in paragraph_list]
            # Update the paragraphs themselves.
            paragraph_list.update(knowledge_id=target_knowledge_id, document_id=target_document_id)
            # Update the paragraphs' embedding records (re-embeds downstream
            # when embedding_model_id is set).
            update_embedding_document_id(pid_list, target_document_id, target_knowledge_id, embedding_model_id)
        # Recompute character counts for both affected documents.
        update_document_char_length(document_id)
        update_document_char_length(target_document_id)

    @staticmethod
    def update_problem_paragraph_mapping(target_document_id: str, problem_paragraph_mapping):
        """Point one mapping at the target document (mutates and returns it)."""
        problem_paragraph_mapping.document_id = target_document_id
        return problem_paragraph_mapping

    @staticmethod
    def get_target_knowledge_problem(target_knowledge_id: str,
                                     target_document_id: str,
                                     problem_paragraph_mapping,
                                     source_problem_list,
                                     target_problem_list):
        """Re-home one problem↔paragraph mapping to the target base/document.

        Mutates problem_paragraph_mapping in place and appends newly created
        problems to target_problem_list so later mappings with the same
        content reuse them.

        Returns (problem, created) — created False when an existing target
        problem was reused, True when a new one was built — or None when the
        mapping's problem id is not in source_problem_list.
        """
        source_problem_list = [source_problem for source_problem in source_problem_list if
                               source_problem.id == problem_paragraph_mapping.problem_id]
        problem_paragraph_mapping.knowledge_id = target_knowledge_id
        problem_paragraph_mapping.document_id = target_document_id
        if len(source_problem_list) > 0:
            problem_content = source_problem_list[-1].content
            problem_list = [problem for problem in target_problem_list if problem.content == problem_content]
            if len(problem_list) > 0:
                # Reuse the last target problem with identical content.
                problem = problem_list[-1]
                problem_paragraph_mapping.problem_id = problem.id
                return problem, False
            else:
                problem = Problem(id=uuid.uuid1(), knowledge_id=target_knowledge_id, content=problem_content)
                target_problem_list.append(problem)
                problem_paragraph_mapping.problem_id = problem.id
                return problem, True
        return None
def delete_problems_and_mappings(paragraph_ids):
problem_paragraph_mappings = ProblemParagraphMapping.objects.filter(paragraph_id__in=paragraph_ids)

View File

@ -30,6 +30,7 @@ urlpatterns = [
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/batch_hit_handling', views.DocumentView.BatchEditHitHandling.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/template/export', views.Template.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/table_template/export', views.TableTemplate.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/migrate/target_knowledge_id', views.DocumentView.Migrate.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>', views.DocumentView.Operate.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/sync', views.DocumentView.SyncWeb.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/refresh', views.DocumentView.Refresh.as_view()),
@ -40,6 +41,7 @@ urlpatterns = [
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph', views.ParagraphView.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/batch_delete', views.ParagraphView.BatchDelete.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/batch_generate_related', views.ParagraphView.BatchGenerateRelated.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/migrate/knowledge/<str:target_knowledge_id>/document/<str:target_document_id>', views.ParagraphView.BatchMigrate.as_view()),
path( 'workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/association', views.ParagraphView.Association.as_view()),
path( 'workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/unassociation', views.ParagraphView.UnAssociation.as_view()),
path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/<str:document_id>/paragraph/<str:paragraph_id>', views.ParagraphView.Operate.as_view()),

View File

@ -12,7 +12,7 @@ from knowledge.api.document import DocumentSplitAPI, DocumentBatchAPI, DocumentB
DocumentReadAPI, DocumentEditAPI, DocumentDeleteAPI, TableDocumentCreateAPI, QaDocumentCreateAPI, \
WebDocumentCreateAPI, CancelTaskAPI, BatchCancelTaskAPI, SyncWebAPI, RefreshAPI, BatchEditHitHandlingAPI, \
DocumentTreeReadAPI, DocumentSplitPatternAPI, BatchRefreshAPI, BatchGenerateRelatedAPI, TemplateExportAPI, \
DocumentExportAPI
DocumentExportAPI, DocumentMigrateAPI
from knowledge.serializers.document import DocumentSerializers
@ -396,8 +396,10 @@ class DocumentView(APIView):
tags=[_('Knowledge Base/Documentation')] # type: ignore
)
@has_permissions(PermissionConstants.KNOWLEDGE_DOCUMENT_EXPORT.get_workspace_permission())
def get(self, request: Request, dataset_id: str, document_id: str):
return DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).export()
def get(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
    """Export the given document (delegates to DocumentSerializers.Operate.export)."""
    return DocumentSerializers.Operate(data={
        'workspace_id': workspace_id, 'document_id': document_id, 'knowledge_id': knowledge_id
    }).export()
class ExportZip(APIView):
authentication_classes = [TokenAuth]
@ -410,8 +412,31 @@ class DocumentView(APIView):
tags=[_('Knowledge Base/Documentation')] # type: ignore
)
@has_permissions(PermissionConstants.KNOWLEDGE_DOCUMENT_EXPORT.get_workspace_permission())
def get(self, request: Request, dataset_id: str, document_id: str):
return DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).export_zip()
def get(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str):
    """Export the given document as a zip archive (delegates to Operate.export_zip)."""
    return DocumentSerializers.Operate(data={
        'workspace_id': workspace_id, 'document_id': document_id, 'knowledge_id': knowledge_id
    }).export_zip()
class Migrate(APIView):
    """Batch-migrate documents from one knowledge base to another."""
    authentication_classes = [TokenAuth]

    @extend_schema(
        summary=_('Migrate documents in batches'),
        operation_id=_('Migrate documents in batches'),  # type: ignore
        parameters=DocumentMigrateAPI.get_parameters(),
        request=DocumentMigrateAPI.get_request(),
        responses=DocumentMigrateAPI.get_response(),
        tags=[_('Knowledge Base/Documentation')]  # type: ignore
    )
    @has_permissions(PermissionConstants.KNOWLEDGE_DOCUMENT_MIGRATE.get_workspace_permission())
    def put(self, request: Request, workspace_id, knowledge_id: str, target_knowledge_id: str):
        # Request body is the list of document ids to move.
        payload = {
            'workspace_id': workspace_id,
            'knowledge_id': knowledge_id,
            'target_knowledge_id': target_knowledge_id,
            'document_id_list': request.data,
        }
        serializer = DocumentSerializers.Migrate(data=payload)
        return result.success(serializer.migrate())
class WebDocumentView(APIView):

View File

@ -193,8 +193,10 @@ class KnowledgeView(APIView):
tags=[_('Knowledge Base')] # type: ignore
)
@has_permissions(PermissionConstants.KNOWLEDGE_EXPORT.get_workspace_permission())
def get(self, request: Request, knowledge_id: str):
return KnowledgeSerializer.Operate(data={'id': knowledge_id, 'user_id': request.user.id}).export_excel()
def get(self, request: Request, workspace_id: str, knowledge_id: str):
    """Export the knowledge base as an Excel file (delegates to Operate.export_excel)."""
    return KnowledgeSerializer.Operate(data={
        'workspace_id': workspace_id, 'knowledge_id': knowledge_id, 'user_id': request.user.id
    }).export_excel()
class ExportZip(APIView):
authentication_classes = [TokenAuth]
@ -207,8 +209,10 @@ class KnowledgeView(APIView):
tags=[_('Knowledge Base')] # type: ignore
)
@has_permissions(PermissionConstants.KNOWLEDGE_EXPORT.get_workspace_permission())
def get(self, request: Request, knowledge_id: str):
return KnowledgeSerializer.Operate(data={'id': knowledge_id, 'user_id': request.user.id}).export_zip()
def get(self, request: Request, workspace_id: str, knowledge_id: str):
    """Export the knowledge base as a zip archive (delegates to Operate.export_zip)."""
    return KnowledgeSerializer.Operate(data={
        'workspace_id': workspace_id, 'knowledge_id': knowledge_id, 'user_id': request.user.id
    }).export_zip()
class GenerateRelated(APIView):
authentication_classes = [TokenAuth]

View File

@ -10,7 +10,7 @@ from common.result import result
from common.utils.common import query_params_to_single_dict
from knowledge.api.paragraph import ParagraphReadAPI, ParagraphCreateAPI, ParagraphBatchDeleteAPI, ParagraphEditAPI, \
ParagraphGetAPI, ProblemCreateAPI, UnAssociationAPI, AssociationAPI, ParagraphPageAPI, \
ParagraphBatchGenerateRelatedAPI
ParagraphBatchGenerateRelatedAPI, ParagraphMigrateAPI
from knowledge.serializers.paragraph import ParagraphSerializers
@ -71,6 +71,30 @@ class ParagraphView(APIView):
data={'workspace_id': workspace_id, 'knowledge_id': knowledge_id, 'document_id': document_id}
).batch_delete(request.data))
class BatchMigrate(APIView):
    """Batch-migrate paragraphs to another knowledge base/document."""
    authentication_classes = [TokenAuth]

    @extend_schema(
        summary=_('Migrate paragraphs in batches'),
        operation_id=_('Migrate paragraphs in batches'),  # type: ignore
        parameters=ParagraphMigrateAPI.get_parameters(),
        request=ParagraphMigrateAPI.get_request(),
        responses=ParagraphMigrateAPI.get_response(),
        tags=[_('Knowledge Base/Documentation/Paragraph')]  # type: ignore
    )
    @has_permissions(PermissionConstants.KNOWLEDGE_DOCUMENT_MIGRATE.get_workspace_permission())
    def put(self, request: Request, workspace_id: str, knowledge_id: str, document_id: str,
            target_knowledge_id: str, target_document_id):
        # Request body carries the paragraph ids under 'id_list'.
        payload = {
            'workspace_id': workspace_id,
            'knowledge_id': knowledge_id,
            'target_knowledge_id': target_knowledge_id,
            'document_id': document_id,
            'target_document_id': target_document_id,
            'paragraph_id_list': request.data.get('id_list'),
        }
        return result.success(ParagraphSerializers.Migrate(data=payload).migrate())
class BatchGenerateRelated(APIView):
authentication_classes = [TokenAuth]