diff --git a/apps/application/flow/step_node/document_split_node/impl/base_document_split_node.py b/apps/application/flow/step_node/document_split_node/impl/base_document_split_node.py index ad757aabf..dfecbc0fb 100644 --- a/apps/application/flow/step_node/document_split_node/impl/base_document_split_node.py +++ b/apps/application/flow/step_node/document_split_node/impl/base_document_split_node.py @@ -8,7 +8,7 @@ from django.core.files.uploadedfile import InMemoryUploadedFile from application.flow.i_step_node import NodeResult from application.flow.step_node.document_split_node.i_document_split_node import IDocumentSplitNode from common.chunk import text_to_chunk -from knowledge.serializers.document import default_split_handle, FileBufferHandle +from knowledge.serializers.document import default_split_handle, FileBufferHandle, md_qa_split_handle def bytes_to_uploaded_file(file_bytes, file_name="file.txt"): @@ -65,7 +65,11 @@ class BaseDocumentSplitNode(IDocumentSplitNode): get_buffer = FileBufferHandle().get_buffer file_mem = bytes_to_uploaded_file(doc['content'].encode('utf-8')) - result = default_split_handle.handle(file_mem, patterns, with_filter, limit, get_buffer, self._save_image) + if split_strategy == 'qa': + result = md_qa_split_handle.handle(file_mem, get_buffer, self._save_image) + else: + result = default_split_handle.handle(file_mem, patterns, with_filter, limit, get_buffer, + self._save_image) # 统一处理结果为列表 results = result if isinstance(result, list) else [result] @@ -102,7 +106,7 @@ class BaseDocumentSplitNode(IDocumentSplitNode): } item['name'] = file_name item['source_file_id'] = source_file_id - item['paragraphs'] = item.pop('content', []) + item['paragraphs'] = item.pop('content', item.get('paragraphs', [])) for paragraph in item['paragraphs']: paragraph['problem_list'] = self._generate_problem_list( @@ -126,7 +130,11 @@ class BaseDocumentSplitNode(IDocumentSplitNode): if document_name_relate_problem_type == 'referencing': document_name_relate_problem = self.get_reference_content(document_name_relate_problem_reference) - problem_list = [] + problem_list = [ + item for p in paragraph.get('problem_list', []) for item in p.get('content', '').split('
') + if item.strip() + ] + if split_strategy == 'auto': if paragraph_title_relate_problem and paragraph.get('title'): problem_list.append(paragraph.get('title')) @@ -141,7 +149,7 @@ class BaseDocumentSplitNode(IDocumentSplitNode): if document_name_relate_problem and document_name: problem_list.append(document_name) - return problem_list + return list(set(problem_list)) def get_details(self, index: int, **kwargs): return { diff --git a/apps/common/handle/impl/qa/md_parse_qa_handle.py b/apps/common/handle/impl/qa/md_parse_qa_handle.py new file mode 100644 index 000000000..2ea8ce24f --- /dev/null +++ b/apps/common/handle/impl/qa/md_parse_qa_handle.py @@ -0,0 +1,108 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: md_parse_qa_handle.py + @date:2024/5/21 14:59 + @desc: +""" +import re +import traceback + +from charset_normalizer import detect + +from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value +from common.utils.logger import maxkb_logger + + +class MarkdownParseQAHandle(BaseParseQAHandle): + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".md") or file_name.endswith(".markdown"): + return True + return False + + def parse_markdown_table(self, content): + """解析 Markdown 表格,返回表格数据列表""" + tables = [] + lines = content.split('\n') + i = 0 + + while i < len(lines): + line = lines[i].strip() + # 检测表格开始(包含 | 符号) + if '|' in line and line.startswith('|'): + table_data = [] + # 读取表头 + header = [cell.strip() for cell in line.split('|')[1:-1]] + table_data.append(header) + i += 1 + + # 跳过分隔行 (例如: | --- | --- |) + if i < len(lines) and re.match(r'\s*\|[\s\-:]+\|\s*', lines[i]): + i += 1 + + # 读取数据行 + while i < len(lines): + line = lines[i].strip() + if not line.startswith('|'): + break + row = [cell.strip() for cell in line.split('|')[1:-1]] + if len(row) > 0: + table_data.append(row) + i += 1 + + if len(table_data) > 1: # 至少有表头和一行数据 + tables.append(table_data) + else: + i += 1 + + return tables + + def handle(self, file, get_buffer, save_image): + buffer = get_buffer(file) + try: + # 检测编码并读取文件内容 + encoding = detect(buffer)['encoding'] + content = buffer.decode(encoding if encoding else 'utf-8') + + # 解析 Markdown 表格 + tables = self.parse_markdown_table(content) + + if not tables: + return [{'name': file.name, 'paragraphs': []}] + + paragraph_list = [] + + # 处理每个表格 + for table in tables: + if len(table) < 2: + continue + + title_row_list = table[0] + title_row_index_dict = get_title_row_index_dict(title_row_list) + + # 处理表格的每一行数据 + for row in table[1:]: + content = get_row_value(row, title_row_index_dict, 'content') + if content is None: + continue + + problem = get_row_value(row, title_row_index_dict, 'problem_list') + problem = str(problem) if problem is not None else '' + problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] + + title = get_row_value(row, title_row_index_dict, 'title') + title = str(title) if title is not None else '' + + paragraph_list.append({ + 'title': title[0:255], + 'content': content[0:102400], + 'problem_list': problem_list + }) + + return [{'name': file.name, 'paragraphs': paragraph_list}] + + except Exception as e: + maxkb_logger.error(f"Error processing Markdown file {file.name}: {e}, {traceback.format_exc()}") + return [{'name': file.name, 'paragraphs': []}] diff --git a/apps/knowledge/serializers/document.py b/apps/knowledge/serializers/document.py index 21c47e5b6..f48904f04 100644 --- a/apps/knowledge/serializers/document.py +++ b/apps/knowledge/serializers/document.py @@ -30,6 +30,7 @@ from common.event.common import work_thread_pool from common.exception.app_exception import AppApiException from common.field.common import UploadedFileField from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle +from common.handle.impl.qa.md_parse_qa_handle import MarkdownParseQAHandle from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle from common.handle.impl.qa.zip_parse_qa_handle import ZipParseQAHandle @@ -75,6 +76,7 @@ split_handles = [ default_split_handle ] +md_qa_split_handle = MarkdownParseQAHandle() parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle(), ZipParseQAHandle()] parse_table_handle_list = [CsvParseTableHandle(), XlsParseTableHandle(), XlsxParseTableHandle()]