diff --git a/apps/application/flow/step_node/document_split_node/i_document_split_node.py b/apps/application/flow/step_node/document_split_node/i_document_split_node.py
index 540b1da9a..25f4956d1 100644
--- a/apps/application/flow/step_node/document_split_node/i_document_split_node.py
+++ b/apps/application/flow/step_node/document_split_node/i_document_split_node.py
@@ -22,7 +22,7 @@ class DocumentSplitNodeSerializer(serializers.Serializer):
         required=False, label=_("paragraph title relate problem"), default=False
     )
     paragraph_title_relate_problem_reference = serializers.ListField(
-        required=False, label=_("paragraph title relate problem reference"), child=serializers.CharField()
+        required=False, label=_("paragraph title relate problem reference"), child=serializers.CharField(), default=[]
     )
     document_name_relate_problem_type = serializers.ChoiceField(
         choices=['custom', 'referencing'], required=False, label=_("document name relate problem type"),
@@ -32,7 +32,7 @@ class DocumentSplitNodeSerializer(serializers.Serializer):
         required=False, label=_("document name relate problem"), default=False
     )
     document_name_relate_problem_reference = serializers.ListField(
-        required=False, label=_("document name relate problem reference"), child=serializers.CharField()
+        required=False, label=_("document name relate problem reference"), child=serializers.CharField(), default=[]
     )
     limit = serializers.IntegerField(required=False, label=_("limit"), default=4096)
     patterns = serializers.ListField(
@@ -55,9 +55,9 @@ class IDocumentSplitNode(INode):
     def _run(self):
         res = self.workflow_manage.get_reference_field(self.node_params_serializer.data.get('file_list')[0],
                                                        self.node_params_serializer.data.get('file_list')[1:])
-        return self.execute(file_list=res, **self.flow_params_serializer.data)
+        return self.execute(files=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
 
-    def execute(self, file_list, knowledge_id, split_strategy, paragraph_title_relate_problem_type,
+    def execute(self, files, knowledge_id, split_strategy, paragraph_title_relate_problem_type,
                 paragraph_title_relate_problem, paragraph_title_relate_problem_reference,
                 document_name_relate_problem_type, document_name_relate_problem,
                 document_name_relate_problem_reference, limit, patterns, with_filter, **kwargs) -> NodeResult:
diff --git a/apps/application/flow/step_node/document_split_node/impl/base_document_split_node.py b/apps/application/flow/step_node/document_split_node/impl/base_document_split_node.py
index 89fcec523..713314787 100644
--- a/apps/application/flow/step_node/document_split_node/impl/base_document_split_node.py
+++ b/apps/application/flow/step_node/document_split_node/impl/base_document_split_node.py
@@ -1,59 +1,85 @@
 # coding=utf-8
+import io
+import mimetypes
+
+from django.core.files.uploadedfile import InMemoryUploadedFile
 from django.db.models import QuerySet
 
 from application.flow.i_step_node import NodeResult
 from application.flow.step_node.document_split_node.i_document_split_node import IDocumentSplitNode
-from knowledge.models import File
+from knowledge.models import File, FileSourceType
 from knowledge.serializers.document import split_handles, FileBufferHandle
 
 
+def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
+    content_type, _ = mimetypes.guess_type(file_name)
+    if content_type is None:
+        # 如果未能识别,设置为默认的二进制文件类型
+        content_type = "application/octet-stream"
+    # 创建一个内存中的字节流对象
+    file_stream = io.BytesIO(file_bytes)
+
+    # 获取文件大小
+    file_size = len(file_bytes)
+
+    # 创建 InMemoryUploadedFile 对象
+    uploaded_file = InMemoryUploadedFile(
+        file=file_stream,
+        field_name=None,
+        name=file_name,
+        content_type=content_type,
+        size=file_size,
+        charset=None,
+    )
+    return uploaded_file
+
+
 class BaseDocumentSplitNode(IDocumentSplitNode):
     def save_context(self, details, workflow_manage):
         self.context['content'] = details.get('content')
-        print(details)
 
-    def execute(self, file_list, knowledge_id, split_strategy, paragraph_title_relate_problem_type,
+    def execute(self, files, knowledge_id, split_strategy, paragraph_title_relate_problem_type,
                 paragraph_title_relate_problem, paragraph_title_relate_problem_reference,
                 document_name_relate_problem_type, document_name_relate_problem,
                 document_name_relate_problem_reference, limit, patterns, with_filter, **kwargs) -> NodeResult:
         get_buffer = FileBufferHandle().get_buffer
+        self.context['file_list'] = files
+        self.context['knowledge_id'] = knowledge_id
         paragraph_list = []
-        for doc in file_list:
+        for doc in files:
             file = QuerySet(File).filter(id=doc['file_id']).first()
-            file_id = file.id
+            file_mem = bytes_to_uploaded_file(file.get_bytes(), file_name=file.file_name)
+
             for split_handle in split_handles:
-                if split_handle.support(file, get_buffer):
-                    result = split_handle.handle(file, patterns, with_filter, limit, get_buffer, self.save_image)
+                if split_handle.support(file_mem, get_buffer):
+                    result = split_handle.handle(file_mem, patterns, with_filter, limit, get_buffer, self.save_image)
                     if isinstance(result, list):
                         for item in result:
-                            item['source_file_id'] = file_id
+                            item['source_file_id'] = file.id
                         paragraph_list = result
                     else:
-                        result['source_file_id'] = file_id
+                        result['source_file_id'] = file.id
                         paragraph_list = [result]
-        self.context['file_list'] = file_list
         self.context['paragraph_list'] = paragraph_list
-        print(paragraph_list)
         return NodeResult({'paragraph_list': paragraph_list}, {})
 
     def save_image(self, image_list):
-        # if image_list is not None and len(image_list) > 0:
-        #     exist_image_list = [str(i.get('id')) for i in
-        #                         QuerySet(File).filter(id__in=[i.id for i in image_list]).values('id')]
-        #     save_image_list = [image for image in image_list if not exist_image_list.__contains__(str(image.id))]
-        #     save_image_list = list({img.id: img for img in save_image_list}.values())
-        #     # save image
-        #     for file in save_image_list:
-        #         file_bytes = file.meta.pop('content')
-        #         file.meta['knowledge_id'] = self.data.get('knowledge_id')
-        #         file.source_type = FileSourceType.KNOWLEDGE
-        #         file.source_id = self.data.get('knowledge_id')
-        #         file.save(file_bytes)
-        pass
+        if image_list is not None and len(image_list) > 0:
+            exist_image_list = [str(i.get('id')) for i in
+                                QuerySet(File).filter(id__in=[i.id for i in image_list]).values('id')]
+            save_image_list = [image for image in image_list if not exist_image_list.__contains__(str(image.id))]
+            save_image_list = list({img.id: img for img in save_image_list}.values())
+            # save image
+            for file in save_image_list:
+                file_bytes = file.meta.pop('content')
+                file.meta['knowledge_id'] = self.context.get('knowledge_id')
+                file.source_type = FileSourceType.KNOWLEDGE
+                file.source_id = self.context.get('knowledge_id')
+                file.save(file_bytes)
 
     def get_details(self, index: int, **kwargs):
         return {