def get_reference_content(self, fields: List[str]):
    """Resolve a workflow reference path to the value it points at.

    Args:
        fields: Reference path. The first element is handed to
            ``workflow_manage.get_reference_field`` as its first argument
            and the remaining elements as its second argument.

    Returns:
        Whatever ``workflow_manage.get_reference_field`` returns, or
        ``None`` when ``fields`` is empty or ``None``.
    """
    # An unset reference would otherwise blow up on ``fields[0]``; treat it
    # as "nothing referenced" so callers can rely on a falsy result.
    if not fields:
        return None
    return self.workflow_manage.get_reference_field(fields[0], fields[1:])
def _process_split_result(
        self, item, knowledge_id, source_file_id, file_name,
        split_strategy, paragraph_title_relate_problem_type,
        paragraph_title_relate_problem, paragraph_title_relate_problem_reference,
        document_name_relate_problem_type, document_name_relate_problem,
        document_name_relate_problem_reference
):
    """Post-process one document split result in place.

    The result dict is mutated as follows:

    * ``item['meta']`` is set to the knowledge id and source file id.
    * ``item['content']`` is removed and re-exposed as ``item['paragraphs']``.
    * every paragraph gets ``problem_list`` (see ``_generate_problem_list``)
      and ``is_active = True``.

    Args:
        item: A single split result (a dict) as produced by a split handler.
        knowledge_id: Stored under ``meta['knowledge_id']``.
        source_file_id: Stored under ``meta['source_file_id']``.
        file_name: Passed to ``_generate_problem_list`` as the document name.
        split_strategy, paragraph_title_relate_problem_type,
        paragraph_title_relate_problem,
        paragraph_title_relate_problem_reference,
        document_name_relate_problem_type, document_name_relate_problem,
        document_name_relate_problem_reference: Forwarded unchanged to
            ``_generate_problem_list`` for every paragraph.

    Returns:
        None. ``item`` is modified in place.
    """
    item['meta'] = {
        'knowledge_id': knowledge_id,
        'source_file_id': source_file_id
    }
    # A missing *or* None 'content' both mean "no paragraphs"; without the
    # ``or []`` a None value would make the loop below raise TypeError.
    item['paragraphs'] = item.pop('content', None) or []

    for paragraph in item['paragraphs']:
        paragraph['problem_list'] = self._generate_problem_list(
            paragraph, file_name,
            split_strategy, paragraph_title_relate_problem_type,
            paragraph_title_relate_problem, paragraph_title_relate_problem_reference,
            document_name_relate_problem_type, document_name_relate_problem,
            document_name_relate_problem_reference
        )
        paragraph['is_active'] = True


def _generate_problem_list(
        self, paragraph, document_name, split_strategy, paragraph_title_relate_problem_type,
        paragraph_title_relate_problem, paragraph_title_relate_problem_reference,
        document_name_relate_problem_type, document_name_relate_problem,
        document_name_relate_problem_reference
):
    """Build the list of problems (questions) to associate with a paragraph.

    When a ``*_type`` argument equals ``'referencing'`` the matching value is
    replaced by ``self.get_reference_content(<reference>)`` before use.

    Strategy handling:

    * ``'auto'``: append the paragraph title and/or the document name when
      the corresponding setting is truthy and the text is non-empty.
    * ``'custom'``: add the configured values themselves; a string counts
      as one problem, any other iterable contributes each of its elements.
    * ``'qa'``: append the document name when its setting is truthy.
    * anything else: nothing is added.

    Args:
        paragraph: Paragraph dict; ``title`` and ``problem_list`` are read.
        document_name: Name used for the document-name problem.
        split_strategy: One of the strategies listed above.
        paragraph_title_relate_problem_type: ``'referencing'`` to resolve
            the value through ``paragraph_title_relate_problem_reference``.
        paragraph_title_relate_problem: Setting/value for title problems.
        paragraph_title_relate_problem_reference: Reference path used when
            the type is ``'referencing'``.
        document_name_relate_problem_type: ``'referencing'`` to resolve the
            value through ``document_name_relate_problem_reference``.
        document_name_relate_problem: Setting/value for name problems.
        document_name_relate_problem_reference: Reference path used when
            the type is ``'referencing'``.

    Returns:
        A new list: the paragraph's pre-existing problems followed by the
        generated ones.
    """
    if paragraph_title_relate_problem_type == 'referencing':
        paragraph_title_relate_problem = self.get_reference_content(paragraph_title_relate_problem_reference)
    if document_name_relate_problem_type == 'referencing':
        document_name_relate_problem = self.get_reference_content(document_name_relate_problem_reference)

    # Start from problems already attached to the paragraph so the caller's
    # assignment does not silently discard them.
    # NOTE(review): assumes existing entries are plain strings or dicts with
    # a 'content' key - confirm against the split handlers.
    problem_list = [
        entry.get('content') if isinstance(entry, dict) else entry
        for entry in paragraph.get('problem_list') or []
    ]

    if split_strategy == 'auto':
        title = paragraph.get('title')
        if paragraph_title_relate_problem and title:
            problem_list.append(title)
        if document_name_relate_problem and document_name:
            problem_list.append(document_name)
    elif split_strategy == 'custom':
        for configured in (paragraph_title_relate_problem, document_name_relate_problem):
            if not configured:
                continue
            if isinstance(configured, str):
                # list.extend() would split a string into single characters.
                problem_list.append(configured)
            else:
                problem_list.extend(configured)
    elif split_strategy == 'qa':
        if document_name_relate_problem and document_name:
            problem_list.append(document_name)

    return problem_list