diff --git a/apps/application/flow/step_node/document_split_node/i_document_split_node.py b/apps/application/flow/step_node/document_split_node/i_document_split_node.py index 9a6327365..39684d6bd 100644 --- a/apps/application/flow/step_node/document_split_node/i_document_split_node.py +++ b/apps/application/flow/step_node/document_split_node/i_document_split_node.py @@ -35,6 +35,7 @@ class DocumentSplitNodeSerializer(serializers.Serializer): required=False, label=_("document name relate problem reference"), child=serializers.CharField(), default=[] ) limit = serializers.IntegerField(required=False, label=_("limit"), default=4096) + chunk_size = serializers.IntegerField(required=False, label=_("chunk size"), default=256) patterns = serializers.ListField( required=False, label=_("patterns"), child=serializers.CharField(), default=[] ) @@ -53,12 +54,10 @@ class IDocumentSplitNode(INode): return DocumentSplitNodeSerializer def _run(self): - # res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('file_list')[0], - # self.node_params_serializer.data.get('file_list')[1:]) return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) def execute(self, document_list, knowledge_id, split_strategy, paragraph_title_relate_problem_type, paragraph_title_relate_problem, paragraph_title_relate_problem_reference, document_name_relate_problem_type, document_name_relate_problem, - document_name_relate_problem_reference, limit, patterns, with_filter, **kwargs) -> NodeResult: + document_name_relate_problem_reference, limit, chunk_size, patterns, with_filter, **kwargs) -> NodeResult: pass diff --git a/apps/application/flow/step_node/document_split_node/impl/base_document_split_node.py b/apps/application/flow/step_node/document_split_node/impl/base_document_split_node.py index ec4055ce0..7806ffc85 100644 --- a/apps/application/flow/step_node/document_split_node/impl/base_document_split_node.py +++ b/apps/application/flow/step_node/document_split_node/impl/base_document_split_node.py @@ -7,6 +7,7 @@ from django.core.files.uploadedfile import InMemoryUploadedFile from application.flow.i_step_node import NodeResult from application.flow.step_node.document_split_node.i_document_split_node import IDocumentSplitNode +from common.chunk import text_to_chunk from knowledge.serializers.document import default_split_handle, FileBufferHandle @@ -43,7 +44,7 @@ class BaseDocumentSplitNode(IDocumentSplitNode): def execute(self, document_list, knowledge_id, split_strategy, paragraph_title_relate_problem_type, paragraph_title_relate_problem, paragraph_title_relate_problem_reference, document_name_relate_problem_type, document_name_relate_problem, - document_name_relate_problem_reference, limit, patterns, with_filter, **kwargs) -> NodeResult: + document_name_relate_problem_reference, limit, chunk_size, patterns, with_filter, **kwargs) -> NodeResult: self.context['knowledge_id'] = knowledge_id file_list = self.workflow_manage.get_reference_field(document_list[0], document_list[1:]) paragraph_list = [] @@ -62,7 +63,7 @@ class BaseDocumentSplitNode(IDocumentSplitNode): split_strategy, paragraph_title_relate_problem_type, paragraph_title_relate_problem, paragraph_title_relate_problem_reference, document_name_relate_problem_type, document_name_relate_problem, - document_name_relate_problem_reference + document_name_relate_problem_reference, chunk_size ) paragraph_list += results @@ -79,7 +80,7 @@ class BaseDocumentSplitNode(IDocumentSplitNode): split_strategy, paragraph_title_relate_problem_type, paragraph_title_relate_problem, paragraph_title_relate_problem_reference, document_name_relate_problem_type, document_name_relate_problem, - document_name_relate_problem_reference + document_name_relate_problem_reference, chunk_size ): """处理文档分割结果""" item['meta'] = { @@ -99,6 +100,7 @@ class BaseDocumentSplitNode(IDocumentSplitNode): document_name_relate_problem_reference ) paragraph['is_active'] = True + paragraph['chunks'] = text_to_chunk(paragraph['content'], chunk_size) def _generate_problem_list( self, paragraph, document_name, split_strategy, paragraph_title_relate_problem_type, diff --git a/apps/common/chunk/__init__.py b/apps/common/chunk/__init__.py index a4babde76..9e5e28680 100644 --- a/apps/common/chunk/__init__.py +++ b/apps/common/chunk/__init__.py @@ -11,8 +11,8 @@ from common.chunk.impl.mark_chunk_handle import MarkChunkHandle handles = [MarkChunkHandle()] -def text_to_chunk(text: str): +def text_to_chunk(text: str, chunk_size: int = 256): chunk_list = [text] for handle in handles: - chunk_list = handle.handle(chunk_list) + chunk_list = handle.handle(chunk_list, chunk_size) return chunk_list diff --git a/apps/common/chunk/i_chunk_handle.py b/apps/common/chunk/i_chunk_handle.py index d53575d11..3cd022ab3 100644 --- a/apps/common/chunk/i_chunk_handle.py +++ b/apps/common/chunk/i_chunk_handle.py @@ -12,5 +12,5 @@ from typing import List class IChunkHandle(ABC): @abstractmethod - def handle(self, chunk_list: List[str]): + def handle(self, chunk_list: List[str], chunk_size: int = 256): pass diff --git a/apps/common/chunk/impl/mark_chunk_handle.py b/apps/common/chunk/impl/mark_chunk_handle.py index 5bca2f445..c57ad3d13 100644 --- a/apps/common/chunk/impl/mark_chunk_handle.py +++ b/apps/common/chunk/impl/mark_chunk_handle.py @@ -11,13 +11,11 @@ from typing import List from common.chunk.i_chunk_handle import IChunkHandle -max_chunk_len = 256 -split_chunk_pattern = r'.{1,%d}[。| |\\.|!|;|;|!|\n]' % max_chunk_len -max_chunk_pattern = r'.{1,%d}' % max_chunk_len - - class MarkChunkHandle(IChunkHandle): - def handle(self, chunk_list: List[str]): + def handle(self, chunk_list: List[str], chunk_size: int = 256): + split_chunk_pattern = r'.{1,%d}[。| |\\.|!|;|;|!|\n]' % chunk_size + max_chunk_pattern = r'.{1,%d}' % chunk_size + result = [] for chunk in chunk_list: chunk_result = re.findall(split_chunk_pattern, chunk, flags=re.DOTALL) @@ -28,7 +26,7 @@ class MarkChunkHandle(IChunkHandle): other_chunk_list = re.split(split_chunk_pattern, chunk, flags=re.DOTALL) for other_chunk in other_chunk_list: if len(other_chunk) > 0: - if len(other_chunk) < max_chunk_len: + if len(other_chunk) < chunk_size: if len(other_chunk.strip()) > 0: result.append(other_chunk.strip()) else: