diff --git a/apps/common/constants/permission_constants.py b/apps/common/constants/permission_constants.py
index ce167e83d..2d9eb9966 100644
--- a/apps/common/constants/permission_constants.py
+++ b/apps/common/constants/permission_constants.py
@@ -218,6 +218,14 @@ class PermissionConstants(Enum):
RoleConstants.USER])
KNOWLEDGE_DELETE = Permission(group=Group.KNOWLEDGE, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN,
RoleConstants.USER])
+    DOCUMENT_READ = Permission(group=Group.KNOWLEDGE, operate=Operate.READ, role_list=[RoleConstants.ADMIN,
+                                                                                      RoleConstants.USER])
+    DOCUMENT_CREATE = Permission(group=Group.KNOWLEDGE, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
+                                                                                           RoleConstants.USER])
+    DOCUMENT_EDIT = Permission(group=Group.KNOWLEDGE, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN,
+                                                                                       RoleConstants.USER])
+ DOCUMENT_DELETE = Permission(group=Group.KNOWLEDGE, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN,
+ RoleConstants.USER])
def get_workspace_application_permission(self):
return lambda r, kwargs: Permission(group=self.value.group, operate=self.value.operate,
diff --git a/apps/common/handle/__init__.py b/apps/common/handle/__init__.py
new file mode 100644
index 000000000..ad09602db
--- /dev/null
+++ b/apps/common/handle/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: qabot
+ @Author:虎
+ @file: __init__.py.py
+ @date:2023/9/6 10:09
+ @desc:
+"""
diff --git a/apps/common/handle/base_parse_qa_handle.py b/apps/common/handle/base_parse_qa_handle.py
new file mode 100644
index 000000000..8cd1cd1cd
--- /dev/null
+++ b/apps/common/handle/base_parse_qa_handle.py
@@ -0,0 +1,52 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_parse_qa_handle.py
+ @date:2024/5/21 14:56
+ @desc:
+"""
+from abc import ABC, abstractmethod
+
+
+def get_row_value(row, title_row_index_dict, field):
+ index = title_row_index_dict.get(field)
+ if index is None:
+ return None
+ if (len(row) - 1) >= index:
+ return row[index]
+ return None
+
+
+def get_title_row_index_dict(title_row_list):
+ title_row_index_dict = {}
+ if len(title_row_list) == 1:
+ title_row_index_dict['content'] = 0
+    elif len(title_row_list) == 2:
+ title_row_index_dict['title'] = 0
+ title_row_index_dict['content'] = 1
+ else:
+ title_row_index_dict['title'] = 0
+ title_row_index_dict['content'] = 1
+ title_row_index_dict['problem_list'] = 2
+ for index in range(len(title_row_list)):
+ title_row = title_row_list[index]
+ if title_row is None:
+ title_row = ''
+ if title_row.startswith('分段标题'):
+ title_row_index_dict['title'] = index
+ if title_row.startswith('分段内容'):
+ title_row_index_dict['content'] = index
+ if title_row.startswith('问题'):
+ title_row_index_dict['problem_list'] = index
+ return title_row_index_dict
+
+
+class BaseParseQAHandle(ABC):
+ @abstractmethod
+ def support(self, file, get_buffer):
+ pass
+
+ @abstractmethod
+ def handle(self, file, get_buffer, save_image):
+ pass
diff --git a/apps/common/handle/base_parse_table_handle.py b/apps/common/handle/base_parse_table_handle.py
new file mode 100644
index 000000000..65eaf897f
--- /dev/null
+++ b/apps/common/handle/base_parse_table_handle.py
@@ -0,0 +1,23 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_parse_qa_handle.py
+ @date:2024/5/21 14:56
+ @desc:
+"""
+from abc import ABC, abstractmethod
+
+
+class BaseParseTableHandle(ABC):
+ @abstractmethod
+ def support(self, file, get_buffer):
+ pass
+
+ @abstractmethod
+    def handle(self, file, get_buffer, save_image):
+ pass
+
+ @abstractmethod
+ def get_content(self, file, save_image):
+ pass
\ No newline at end of file
diff --git a/apps/common/handle/base_split_handle.py b/apps/common/handle/base_split_handle.py
new file mode 100644
index 000000000..bedaad5e1
--- /dev/null
+++ b/apps/common/handle/base_split_handle.py
@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: base_split_handle.py
+ @date:2024/3/27 18:13
+ @desc:
+"""
+from abc import ABC, abstractmethod
+from typing import List
+
+
+class BaseSplitHandle(ABC):
+ @abstractmethod
+ def support(self, file, get_buffer):
+ pass
+
+ @abstractmethod
+ def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+ pass
+
+ @abstractmethod
+ def get_content(self, file, save_image):
+ pass
diff --git a/apps/common/handle/base_to_response.py b/apps/common/handle/base_to_response.py
new file mode 100644
index 000000000..376d1a9dd
--- /dev/null
+++ b/apps/common/handle/base_to_response.py
@@ -0,0 +1,30 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: base_to_response.py
+ @date:2024/9/6 16:04
+ @desc:
+"""
+from abc import ABC, abstractmethod
+
+from rest_framework import status
+
+
+class BaseToResponse(ABC):
+
+ @abstractmethod
+ def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens,
+ prompt_tokens, other_params: dict = None,
+ _status=status.HTTP_200_OK):
+ pass
+
+ @abstractmethod
+ def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end,
+ completion_tokens,
+ prompt_tokens, other_params: dict = None):
+ pass
+
+ @staticmethod
+ def format_stream_chunk(response_str):
+ return 'data: ' + response_str + '\n\n'
diff --git a/apps/common/handle/handle_exception.py b/apps/common/handle/handle_exception.py
new file mode 100644
index 000000000..c25d4c01f
--- /dev/null
+++ b/apps/common/handle/handle_exception.py
@@ -0,0 +1,94 @@
+# coding=utf-8
+"""
+ @project: qabot
+ @Author:虎
+ @file: handle_exception.py
+ @date:2023/9/5 19:29
+ @desc:
+"""
+import logging
+import traceback
+
+from rest_framework.exceptions import ValidationError, ErrorDetail, APIException
+from rest_framework.views import exception_handler
+
+from common.exception.app_exception import AppApiException
+
+from django.utils.translation import gettext_lazy as _
+
+from common.result import result
+
+
+def to_result(key, args, parent_key=None):
+ """
+ 将校验异常 args转换为统一数据
+ :param key: 校验key
+ :param args: 校验异常参数
+ :param parent_key 父key
+ :return: 接口响应对象
+ """
+ error_detail = list(filter(
+ lambda d: True if isinstance(d, ErrorDetail) else True if isinstance(d, dict) and len(
+ d.keys()) > 0 else False,
+ (args[0] if len(args) > 0 else {key: [ErrorDetail(_('Unknown exception'), code='unknown')]}).get(key)))[0]
+
+ if isinstance(error_detail, dict):
+ return list(map(lambda k: to_result(k, args=[error_detail],
+ parent_key=key if parent_key is None else parent_key + '.' + key),
+ error_detail.keys() if len(error_detail) > 0 else []))[0]
+
+ return result.Result(500 if isinstance(error_detail.code, str) else error_detail.code,
+ message=f"【{key if parent_key is None else parent_key + '.' + key}】为必填参数" if str(
+ error_detail) == "This field is required." else error_detail)
+
+
+def validation_error_to_result(exc: ValidationError):
+ """
+ 校验异常转响应对象
+ :param exc: 校验异常
+ :return: 接口响应对象
+ """
+ try:
+ v = find_err_detail(exc.detail)
+ if v is None:
+ return result.error(str(exc.detail))
+ return result.error(str(v))
+ except Exception as e:
+ return result.error(str(exc.detail))
+
+
+def find_err_detail(exc_detail):
+ if isinstance(exc_detail, ErrorDetail):
+ return exc_detail
+ if isinstance(exc_detail, dict):
+ keys = exc_detail.keys()
+ for key in keys:
+ _value = exc_detail[key]
+ if isinstance(_value, list):
+ return find_err_detail(_value)
+ if isinstance(_value, ErrorDetail):
+ return _value
+ if isinstance(_value, dict) and len(_value.keys()) > 0:
+ return find_err_detail(_value)
+ if isinstance(exc_detail, list):
+ for v in exc_detail:
+ r = find_err_detail(v)
+ if r is not None:
+ return r
+
+
+def handle_exception(exc, context):
+ exception_class = exc.__class__
+ # 先调用REST framework默认的异常处理方法获得标准错误响应对象
+ response = exception_handler(exc, context)
+ # 在此处补充自定义的异常处理
+ if issubclass(exception_class, ValidationError):
+ return validation_error_to_result(exc)
+ if issubclass(exception_class, AppApiException):
+ return result.Result(exc.code, exc.message, response_status=exc.status_code)
+ if issubclass(exception_class, APIException):
+ return result.error(exc.detail)
+ if response is None:
+ logging.getLogger("max_kb_error").error(f'{str(exc)}:{traceback.format_exc()}')
+ return result.error(str(exc))
+ return response
diff --git a/apps/common/handle/impl/__init__.py b/apps/common/handle/impl/__init__.py
new file mode 100644
index 000000000..ad09602db
--- /dev/null
+++ b/apps/common/handle/impl/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: qabot
+ @Author:虎
+ @file: __init__.py.py
+ @date:2023/9/6 10:09
+ @desc:
+"""
diff --git a/apps/common/handle/impl/common_handle.py b/apps/common/handle/impl/common_handle.py
new file mode 100644
index 000000000..48692f244
--- /dev/null
+++ b/apps/common/handle/impl/common_handle.py
@@ -0,0 +1,116 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: tools.py
+ @date:2024/9/11 16:41
+ @desc:
+"""
+import io
+import uuid_utils.compat as uuid
+from functools import reduce
+from io import BytesIO
+from xml.etree.ElementTree import fromstring
+from zipfile import ZipFile
+
+from PIL import Image as PILImage
+from openpyxl.drawing.image import Image as openpyxl_Image
+from openpyxl.packaging.relationship import get_rels_path, get_dependents
+from openpyxl.xml.constants import SHEET_DRAWING_NS, REL_NS, SHEET_MAIN_NS
+
+from common.handle.base_parse_qa_handle import get_title_row_index_dict, get_row_value
+from knowledge.models import File
+
+
+def parse_element(element) -> {}:
+ data = {}
+ xdr_namespace = "{%s}" % SHEET_DRAWING_NS
+ targets = level_order_traversal(element, xdr_namespace + "nvPicPr")
+ for target in targets:
+ cNvPr = embed = ""
+ for child in target:
+ if child.tag == xdr_namespace + "nvPicPr":
+ cNvPr = child[0].attrib["name"]
+ elif child.tag == xdr_namespace + "blipFill":
+ _rel_embed = "{%s}embed" % REL_NS
+ embed = child[0].attrib[_rel_embed]
+ if cNvPr:
+ data[cNvPr] = embed
+ return data
+
+
+def parse_element_sheet_xml(element) -> []:
+ data = []
+ xdr_namespace = "{%s}" % SHEET_MAIN_NS
+ targets = level_order_traversal(element, xdr_namespace + "f")
+ for target in targets:
+ for child in target:
+ if child.tag == xdr_namespace + "f":
+ data.append(child.text)
+ return data
+
+
+def level_order_traversal(root, flag: str) -> []:
+ queue = [root]
+ targets = []
+ while queue:
+ node = queue.pop(0)
+ children = [child.tag for child in node]
+ if flag in children:
+ targets.append(node)
+ continue
+ for child in node:
+ queue.append(child)
+ return targets
+
+
+def handle_images(deps, archive: ZipFile) -> []:
+ images = []
+ if not PILImage: # Pillow not installed, drop images
+ return images
+ for dep in deps:
+ try:
+ image_io = archive.read(dep.target)
+ image = openpyxl_Image(BytesIO(image_io))
+ except Exception as e:
+ print(e)
+ continue
+ image.embed = dep.id # 文件rId
+ image.target = dep.target # 文件地址
+ images.append(image)
+ return images
+
+
+def xlsx_embed_cells_images(buffer) -> {}:
+ archive = ZipFile(buffer)
+ # 解析cellImage.xml文件
+ deps = get_dependents(archive, get_rels_path("xl/cellimages.xml"))
+ image_rel = handle_images(deps=deps, archive=archive)
+ # 工作表及其中图片ID
+ sheet_list = {}
+ for item in archive.namelist():
+ if not item.startswith('xl/worksheets/sheet'):
+ continue
+ key = item.split('/')[-1].split('.')[0].split('sheet')[-1]
+ sheet_list[key] = parse_element_sheet_xml(fromstring(archive.read(item)))
+ cell_images_xml = parse_element(fromstring(archive.read("xl/cellimages.xml")))
+ cell_images_rel = {}
+ for image in image_rel:
+ cell_images_rel[image.embed] = image
+ for cnv, embed in cell_images_xml.items():
+ cell_images_xml[cnv] = cell_images_rel.get(embed)
+ result = {}
+ for key, img in cell_images_xml.items():
+ image_excel_id_list = [_xl for _xl in
+ reduce(lambda x, y: [*x, *y], [sheet for sheet_id, sheet in sheet_list.items()], []) if
+ key in _xl]
+ if len(image_excel_id_list) > 0:
+ image_excel_id = image_excel_id_list[-1]
+ f = archive.open(img.target)
+ img_byte = io.BytesIO()
+ im = PILImage.open(f).convert('RGB')
+ im.save(img_byte, format='JPEG')
+ image = File(id=uuid.uuid7(), file_name=img.path, meta={'debug': False, 'content': img_byte.getvalue()})
+ result['=' + image_excel_id] = image
+ archive.close()
+ return result
diff --git a/apps/common/handle/impl/csv_split_handle.py b/apps/common/handle/impl/csv_split_handle.py
new file mode 100644
index 000000000..3ea690e0e
--- /dev/null
+++ b/apps/common/handle/impl/csv_split_handle.py
@@ -0,0 +1,72 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: csv_parse_qa_handle.py
+ @date:2024/5/21 14:59
+ @desc:
+"""
+import csv
+import io
+import os
+from typing import List
+
+from charset_normalizer import detect
+
+from common.handle.base_split_handle import BaseSplitHandle
+
+
+def post_cell(cell_value):
+    return cell_value.replace('\n', '<br>').replace('|', '&#124;')
+
+
+def row_to_md(row):
+ return '| ' + ' | '.join(
+ [post_cell(cell) if cell is not None else '' for cell in row]) + ' |\n'
+
+
+class CsvSplitHandle(BaseSplitHandle):
+ def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+ buffer = get_buffer(file)
+ paragraphs = []
+ file_name = os.path.basename(file.name)
+ result = {'name': file_name, 'content': paragraphs}
+ try:
+ reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding']))
+ try:
+ title_row_list = reader.__next__()
+ title_md_content = row_to_md(title_row_list)
+ title_md_content += '| ' + ' | '.join(
+ ['---' if cell is not None else '' for cell in title_row_list]) + ' |\n'
+ except Exception as e:
+ return result
+ if len(title_row_list) == 0:
+ return result
+ result_item_content = ''
+ for row in reader:
+ next_md_content = row_to_md(row)
+ next_md_content_len = len(next_md_content)
+ result_item_content_len = len(result_item_content)
+ if len(result_item_content) == 0:
+ result_item_content += title_md_content
+ result_item_content += next_md_content
+ else:
+ if result_item_content_len + next_md_content_len < limit:
+ result_item_content += next_md_content
+ else:
+ paragraphs.append({'content': result_item_content, 'title': ''})
+ result_item_content = title_md_content + next_md_content
+ if len(result_item_content) > 0:
+ paragraphs.append({'content': result_item_content, 'title': ''})
+ return result
+ except Exception as e:
+ return result
+
+ def get_content(self, file, save_image):
+ pass
+
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".csv"):
+ return True
+ return False
diff --git a/apps/common/handle/impl/doc_split_handle.py b/apps/common/handle/impl/doc_split_handle.py
new file mode 100644
index 000000000..752f726a2
--- /dev/null
+++ b/apps/common/handle/impl/doc_split_handle.py
@@ -0,0 +1,235 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: text_split_handle.py
+ @date:2024/3/27 18:19
+ @desc:
+"""
+import io
+import os
+import re
+import traceback
+import uuid_utils.compat as uuid
+from functools import reduce
+from typing import List
+
+from docx import Document, ImagePart
+from docx.oxml import ns
+from docx.table import Table
+from docx.text.paragraph import Paragraph
+
+from common.handle.base_split_handle import BaseSplitHandle
+from common.utils.split_model import SplitModel
+from knowledge.models import File
+from django.utils.translation import gettext_lazy as _
+
+default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'),
+ re.compile('(?<=\\n)(? 0:
+ for image in _images:
+ images.append({'image': image, 'get_image_id_handle': get_image_id_handle})
+ except Exception as e:
+ pass
+ return images
+
+
+def images_to_string(images, doc: Document, images_list, get_image_id):
+ return "".join(
+ [item for item in [image_to_mode(image, doc, images_list, get_image_id) for image in images] if
+ item is not None])
+
+
+def get_paragraph_element_txt(paragraph_element, doc: Document, images_list, get_image_id):
+ try:
+ images = get_paragraph_element_images(paragraph_element, doc, images_list, get_image_id)
+ if len(images) > 0:
+ return images_to_string(images, doc, images_list, get_image_id)
+ elif paragraph_element.text is not None:
+ return paragraph_element.text
+ return ""
+ except Exception as e:
+ print(e)
+ return ""
+
+
+def get_paragraph_txt(paragraph: Paragraph, doc: Document, images_list, get_image_id):
+ try:
+ return "".join([get_paragraph_element_txt(e, doc, images_list, get_image_id) for e in paragraph._element])
+ except Exception as e:
+ return ""
+
+
+def get_cell_text(cell, doc: Document, images_list, get_image_id):
+ try:
+ return "".join(
+ [get_paragraph_txt(paragraph, doc, images_list, get_image_id) for paragraph in cell.paragraphs]).replace(
+ "\n", '')
+ except Exception as e:
+ return ""
+
+
+def get_image_id_func():
+ image_map = {}
+
+ def get_image_id(image_id):
+ _v = image_map.get(image_id)
+ if _v is None:
+ image_map[image_id] = uuid.uuid7()
+ return image_map.get(image_id)
+ return _v
+
+ return get_image_id
+
+
+title_font_list = [
+ [36, 100],
+ [30, 36]
+]
+
+
+def get_title_level(paragraph: Paragraph):
+ try:
+ if paragraph.style is not None:
+ psn = paragraph.style.name
+ if psn.startswith('Heading') or psn.startswith('TOC 标题') or psn.startswith('标题'):
+ return int(psn.replace("Heading ", '').replace('TOC 标题', '').replace('标题',
+ ''))
+ if len(paragraph.runs) == 1:
+ font_size = paragraph.runs[0].font.size
+ pt = font_size.pt
+ if pt >= 30:
+ for _value, index in zip(title_font_list, range(len(title_font_list))):
+ if pt >= _value[0] and pt < _value[1]:
+ return index + 1
+ except Exception as e:
+ pass
+ return None
+
+
+class DocSplitHandle(BaseSplitHandle):
+ @staticmethod
+ def paragraph_to_md(paragraph: Paragraph, doc: Document, images_list, get_image_id):
+ try:
+ title_level = get_title_level(paragraph)
+ if title_level is not None:
+ title = "".join(["#" for i in range(title_level)]) + " " + paragraph.text
+ images = reduce(lambda x, y: [*x, *y],
+ [get_paragraph_element_images(e, doc, images_list, get_image_id) for e in
+ paragraph._element],
+ [])
+ if len(images) > 0:
+ return title + '\n' + images_to_string(images, doc, images_list, get_image_id) if len(
+ paragraph.text) > 0 else images_to_string(images, doc, images_list, get_image_id)
+ return title
+
+ except Exception as e:
+ traceback.print_exc()
+ return paragraph.text
+ return get_paragraph_txt(paragraph, doc, images_list, get_image_id)
+
+ @staticmethod
+ def table_to_md(table, doc: Document, images_list, get_image_id):
+ rows = table.rows
+
+ # 创建 Markdown 格式的表格
+ md_table = '| ' + ' | '.join(
+ [get_cell_text(cell, doc, images_list, get_image_id) for cell in rows[0].cells]) + ' |\n'
+ md_table += '| ' + ' | '.join(['---' for i in range(len(rows[0].cells))]) + ' |\n'
+ for row in rows[1:]:
+ md_table += '| ' + ' | '.join(
+ [get_cell_text(cell, doc, images_list, get_image_id) for cell in row.cells]) + ' |\n'
+ return md_table
+
+ def to_md(self, doc, images_list, get_image_id):
+ elements = []
+ for element in doc.element.body:
+ tag = str(element.tag)
+ if tag.endswith('tbl'):
+ # 处理表格
+ table = Table(element, doc)
+ elements.append(table)
+ elif tag.endswith('p'):
+ # 处理段落
+ paragraph = Paragraph(element, doc)
+ elements.append(paragraph)
+ return "\n".join(
+ [self.paragraph_to_md(element, doc, images_list, get_image_id) if isinstance(element,
+ Paragraph) else self.table_to_md(
+ element,
+ doc,
+ images_list, get_image_id)
+ for element
+ in elements])
+
+ def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+ file_name = os.path.basename(file.name)
+ try:
+ image_list = []
+ buffer = get_buffer(file)
+ doc = Document(io.BytesIO(buffer))
+ content = self.to_md(doc, image_list, get_image_id_func())
+ if len(image_list) > 0:
+ save_image(image_list)
+ if pattern_list is not None and len(pattern_list) > 0:
+ split_model = SplitModel(pattern_list, with_filter, limit)
+ else:
+ split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit)
+ except BaseException as e:
+ traceback.print_exception(e)
+ return {'name': file_name,
+ 'content': []}
+ return {'name': file_name,
+ 'content': split_model.parse(content)
+ }
+
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".docx") or file_name.endswith(".doc") or file_name.endswith(
+ ".DOC") or file_name.endswith(".DOCX"):
+ return True
+ return False
+
+ def get_content(self, file, save_image):
+ try:
+ image_list = []
+ buffer = file.read()
+ doc = Document(io.BytesIO(buffer))
+ content = self.to_md(doc, image_list, get_image_id_func())
+ if len(image_list) > 0:
+ content = content.replace('/api/image/', '/api/file/')
+ save_image(image_list)
+ return content
+ except BaseException as e:
+ traceback.print_exception(e)
+ return f'{e}'
diff --git a/apps/common/handle/impl/html_split_handle.py b/apps/common/handle/impl/html_split_handle.py
new file mode 100644
index 000000000..fd4b84fc4
--- /dev/null
+++ b/apps/common/handle/impl/html_split_handle.py
@@ -0,0 +1,73 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: html_split_handle.py
+ @date:2024/5/23 10:58
+ @desc:
+"""
+import re
+import traceback
+from typing import List
+
+from bs4 import BeautifulSoup
+from charset_normalizer import detect
+from html2text import html2text
+
+from common.handle.base_split_handle import BaseSplitHandle
+from common.utils.split_model import SplitModel
+
+default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'),
+ re.compile('(?<=\\n)(? 0:
+ charset = charset_list[0]
+ return charset
+ return detect(buffer)['encoding']
+
+
+class HTMLSplitHandle(BaseSplitHandle):
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".html") or file_name.endswith(".HTML"):
+ return True
+ return False
+
+ def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+ buffer = get_buffer(file)
+
+ if pattern_list is not None and len(pattern_list) > 0:
+ split_model = SplitModel(pattern_list, with_filter, limit)
+ else:
+ split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit)
+ try:
+ encoding = get_encoding(buffer)
+ content = buffer.decode(encoding)
+ content = html2text(content)
+ except BaseException as e:
+ return {'name': file.name,
+ 'content': []}
+ return {'name': file.name,
+ 'content': split_model.parse(content)
+ }
+
+ def get_content(self, file, save_image):
+ buffer = file.read()
+
+ try:
+ encoding = get_encoding(buffer)
+ content = buffer.decode(encoding)
+ return html2text(content)
+ except BaseException as e:
+ traceback.print_exception(e)
+ return f'{e}'
\ No newline at end of file
diff --git a/apps/common/handle/impl/pdf_split_handle.py b/apps/common/handle/impl/pdf_split_handle.py
new file mode 100644
index 000000000..3e011d1d5
--- /dev/null
+++ b/apps/common/handle/impl/pdf_split_handle.py
@@ -0,0 +1,339 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: text_split_handle.py
+ @date:2024/3/27 18:19
+ @desc:
+"""
+import logging
+import os
+import re
+import tempfile
+import time
+import traceback
+from typing import List
+
+import fitz
+from langchain_community.document_loaders import PyPDFLoader
+
+from common.handle.base_split_handle import BaseSplitHandle
+from common.utils.split_model import SplitModel
+from django.utils.translation import gettext_lazy as _
+
+default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'),
+ re.compile('(?<=\\n)(? 0:
+ return {'name': file.name, 'content': result}
+
+ # 没有目录的pdf
+ content = self.handle_pdf_content(file, pdf_document)
+
+ if pattern_list is not None and len(pattern_list) > 0:
+ split_model = SplitModel(pattern_list, with_filter, limit)
+ else:
+ split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit)
+ except BaseException as e:
+ max_kb.error(f"File: {file.name}, error: {e}")
+ return {'name': file.name,
+ 'content': []}
+ finally:
+ pdf_document.close()
+ # 处理完后可以删除临时文件
+ os.remove(temp_file_path)
+
+ return {'name': file.name,
+ 'content': split_model.parse(content)
+ }
+
+ @staticmethod
+ def handle_pdf_content(file, pdf_document):
+ content = ""
+ for page_num in range(len(pdf_document)):
+ start_time = time.time()
+ page = pdf_document.load_page(page_num)
+ text = page.get_text()
+
+ if text and text.strip(): # 如果页面中有文本内容
+ page_content = text
+ else:
+ try:
+ new_doc = fitz.open()
+ new_doc.insert_pdf(pdf_document, from_page=page_num, to_page=page_num)
+ page_num_pdf = tempfile.gettempdir() + f"/{file.name}_{page_num}.pdf"
+ new_doc.save(page_num_pdf)
+ new_doc.close()
+
+ loader = PyPDFLoader(page_num_pdf, extract_images=True)
+ page_content = "\n" + loader.load()[0].page_content
+ except NotImplementedError as e:
+ # 文件格式不支持,直接退出
+ raise e
+ except BaseException as e:
+ # 当页出错继续进行下一页,防止一个页面出错导致整个文件解析失败
+ max_kb.error(f"File: {file.name}, Page: {page_num + 1}, error: {e}")
+ continue
+ finally:
+ os.remove(page_num_pdf)
+
+ content += page_content
+
+ # Null characters are not allowed.
+ content = content.replace('\0', '')
+
+ elapsed_time = time.time() - start_time
+ max_kb.debug(
+ f"File: {file.name}, Page: {page_num + 1}, Time : {elapsed_time: .3f}s, content-length: {len(page_content)}")
+
+ return content
+
+ @staticmethod
+ def handle_toc(doc, limit):
+ # 找到目录
+ toc = doc.get_toc()
+ if toc is None or len(toc) == 0:
+ return None
+
+ # 创建存储章节内容的数组
+ chapters = []
+
+ # 遍历目录并按章节提取文本
+ for i, entry in enumerate(toc):
+ level, title, start_page = entry
+ start_page -= 1 # PyMuPDF 页码从 0 开始,书签页码从 1 开始
+ chapter_title = title
+ # 确定结束页码,如果是最后一个章节则到文档末尾
+ if i + 1 < len(toc):
+ end_page = toc[i + 1][2] - 1
+ else:
+ end_page = doc.page_count - 1
+
+ # 去掉标题中的符号
+ title = PdfSplitHandle.handle_chapter_title(title)
+
+ # 提取该章节的文本内容
+ chapter_text = ""
+ for page_num in range(start_page, end_page + 1):
+ page = doc.load_page(page_num) # 加载页面
+ text = page.get_text("text")
+ text = re.sub(r'(? -1:
+ text = text[idx + len(title):]
+
+ if i + 1 < len(toc):
+ l, next_title, next_start_page = toc[i + 1]
+ next_title = PdfSplitHandle.handle_chapter_title(next_title)
+ # print(f'next_title: {next_title}')
+ idx = text.find(next_title)
+ if idx > -1:
+ text = text[:idx]
+
+ chapter_text += text # 提取文本
+
+ # Null characters are not allowed.
+ chapter_text = chapter_text.replace('\0', '')
+ # 限制标题长度
+ real_chapter_title = chapter_title[:256]
+ # 限制章节内容长度
+ if 0 < limit < len(chapter_text):
+ split_text = PdfSplitHandle.split_text(chapter_text, limit)
+ for text in split_text:
+ chapters.append({"title": real_chapter_title, "content": text})
+ else:
+ chapters.append({"title": real_chapter_title, "content": chapter_text if chapter_text else real_chapter_title})
+ # 保存章节内容和章节标题
+ return chapters
+
+ @staticmethod
+ def handle_links(doc, pattern_list, with_filter, limit):
+ # 检查文档是否包含内部链接
+ if not check_links_in_pdf(doc):
+ return
+ # 创建存储章节内容的数组
+ chapters = []
+ toc_start_page = -1
+ page_content = ""
+ handle_pre_toc = True
+ # 遍历 PDF 的每一页,查找带有目录链接的页
+ for page_num in range(doc.page_count):
+ page = doc.load_page(page_num)
+ links = page.get_links()
+ # 如果目录开始页码未设置,则设置为当前页码
+ if len(links) > 0:
+ toc_start_page = page_num
+ if toc_start_page < 0:
+ page_content += page.get_text('text')
+ # 检查该页是否包含内部链接(即指向文档内部的页面)
+ for num in range(len(links)):
+ link = links[num]
+ if link['kind'] == 1: # 'kind' 为 1 表示内部链接
+ # 获取链接目标的页面
+ dest_page = link['page']
+ rect = link['from'] # 获取链接的矩形区域
+ # 如果目录开始页码包括前言部分,则不处理前言部分
+ if dest_page < toc_start_page:
+ handle_pre_toc = False
+
+ # 提取链接区域的文本作为标题
+ link_title = page.get_text("text", clip=rect).strip().split("\n")[0].replace('.', '').strip()
+ # print(f'link_title: {link_title}')
+ # 提取目标页面内容作为章节开始
+ start_page = dest_page
+ end_page = dest_page
+ # 下一个link
+ next_link = links[num + 1] if num + 1 < len(links) else None
+ next_link_title = None
+ if next_link is not None and next_link['kind'] == 1:
+ rect = next_link['from']
+ next_link_title = page.get_text("text", clip=rect).strip() \
+ .split("\n")[0].replace('.', '').strip()
+ # print(f'next_link_title: {next_link_title}')
+ end_page = next_link['page']
+
+ # 提取章节内容
+ chapter_text = ""
+ for p_num in range(start_page, end_page + 1):
+ p = doc.load_page(p_num)
+ text = p.get_text("text")
+ text = re.sub(r'(? -1:
+ text = text[idx + len(link_title):]
+
+ if next_link_title is not None:
+ idx = text.find(next_link_title)
+ if idx > -1:
+ text = text[:idx]
+ chapter_text += text
+
+ # Null characters are not allowed.
+ chapter_text = chapter_text.replace('\0', '')
+
+ # 限制章节内容长度
+ if 0 < limit < len(chapter_text):
+ split_text = PdfSplitHandle.split_text(chapter_text, limit)
+ for text in split_text:
+ chapters.append({"title": link_title, "content": text})
+ else:
+ # 保存章节信息
+ chapters.append({"title": link_title, "content": chapter_text})
+
+ # 目录中没有前言部分,手动处理
+ if handle_pre_toc:
+ pre_toc = []
+ lines = page_content.strip().split('\n')
+ try:
+ for line in lines:
+ if re.match(r'^前\s*言', line):
+ pre_toc.append({'title': line, 'content': ''})
+ else:
+ pre_toc[-1]['content'] += line
+ for i in range(len(pre_toc)):
+ pre_toc[i]['content'] = re.sub(r'(? 0:
+ split_model = SplitModel(pattern_list, with_filter, limit)
+ else:
+ split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit)
+ # 插入目录前的部分
+ page_content = re.sub(r'(?= length:
+ # 查找最近的句号
+ last_period_index = current_segment.rfind('.')
+ if last_period_index != -1:
+ segments.append(current_segment[:last_period_index + 1])
+ current_segment = current_segment[last_period_index + 1:] # 更新当前段落
+ else:
+ segments.append(current_segment)
+ current_segment = ""
+
+ # 处理剩余的部分
+ if current_segment:
+ segments.append(current_segment)
+
+ return segments
+
+ @staticmethod
+ def handle_chapter_title(title):
+ title = re.sub(r'[一二三四五六七八九十\s*]、\s*', '', title)
+ title = re.sub(r'第[一二三四五六七八九十]章\s*', '', title)
+ return title
+
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".pdf") or file_name.endswith(".PDF"):
+ return True
+ return False
+
+ def get_content(self, file, save_image):
+ with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+ # 将上传的文件保存到临时文件中
+ temp_file.write(file.read())
+ # 获取临时文件的路径
+ temp_file_path = temp_file.name
+
+ pdf_document = fitz.open(temp_file_path)
+ try:
+ return self.handle_pdf_content(file, pdf_document)
+ except BaseException as e:
+ traceback.print_exception(e)
+ return f'{e}'
\ No newline at end of file
diff --git a/apps/common/handle/impl/qa/__init__.py b/apps/common/handle/impl/qa/__init__.py
new file mode 100644
index 000000000..ad09602db
--- /dev/null
+++ b/apps/common/handle/impl/qa/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: qabot
+ @Author:虎
+ @file: __init__.py.py
+ @date:2023/9/6 10:09
+ @desc:
+"""
diff --git a/apps/common/handle/impl/qa/csv_parse_qa_handle.py b/apps/common/handle/impl/qa/csv_parse_qa_handle.py
new file mode 100644
index 000000000..75c22cbda
--- /dev/null
+++ b/apps/common/handle/impl/qa/csv_parse_qa_handle.py
@@ -0,0 +1,59 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: csv_parse_qa_handle.py
+ @date:2024/5/21 14:59
+ @desc:
+"""
+import csv
+import io
+
+from charset_normalizer import detect
+
+from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value
+
+
+def read_csv_standard(file_path):
+ data = []
+ with open(file_path, 'r') as file:
+ reader = csv.reader(file)
+ for row in reader:
+ data.append(row)
+ return data
+
+
+class CsvParseQAHandle(BaseParseQAHandle):
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".csv"):
+ return True
+ return False
+
+ def handle(self, file, get_buffer, save_image):
+ buffer = get_buffer(file)
+ try:
+ reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding']))
+ try:
+ title_row_list = reader.__next__()
+ except Exception as e:
+ return [{'name': file.name, 'paragraphs': []}]
+ if len(title_row_list) == 0:
+ return [{'name': file.name, 'paragraphs': []}]
+ title_row_index_dict = get_title_row_index_dict(title_row_list)
+ paragraph_list = []
+ for row in reader:
+ content = get_row_value(row, title_row_index_dict, 'content')
+ if content is None:
+ continue
+ problem = get_row_value(row, title_row_index_dict, 'problem_list')
+ problem = str(problem) if problem is not None else ''
+ problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
+ title = get_row_value(row, title_row_index_dict, 'title')
+ title = str(title) if title is not None else ''
+ paragraph_list.append({'title': title[0:255],
+ 'content': content[0:102400],
+ 'problem_list': problem_list})
+ return [{'name': file.name, 'paragraphs': paragraph_list}]
+ except Exception as e:
+ return [{'name': file.name, 'paragraphs': []}]
diff --git a/apps/common/handle/impl/qa/xls_parse_qa_handle.py b/apps/common/handle/impl/qa/xls_parse_qa_handle.py
new file mode 100644
index 000000000..06edb1fb3
--- /dev/null
+++ b/apps/common/handle/impl/qa/xls_parse_qa_handle.py
@@ -0,0 +1,61 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: xls_parse_qa_handle.py
+ @date:2024/5/21 14:59
+ @desc:
+"""
+
+import xlrd
+
+from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value
+
+
+def handle_sheet(file_name, sheet):
+ rows = iter([sheet.row_values(i) for i in range(sheet.nrows)])
+ try:
+ title_row_list = next(rows)
+ except Exception as e:
+ return {'name': file_name, 'paragraphs': []}
+ if len(title_row_list) == 0:
+ return {'name': file_name, 'paragraphs': []}
+ title_row_index_dict = get_title_row_index_dict(title_row_list)
+ paragraph_list = []
+ for row in rows:
+ content = get_row_value(row, title_row_index_dict, 'content')
+ if content is None:
+ continue
+ problem = get_row_value(row, title_row_index_dict, 'problem_list')
+ problem = str(problem) if problem is not None else ''
+ problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
+ title = get_row_value(row, title_row_index_dict, 'title')
+ title = str(title) if title is not None else ''
+ content = str(content)
+ paragraph_list.append({'title': title[0:255],
+ 'content': content[0:102400],
+ 'problem_list': problem_list})
+ return {'name': file_name, 'paragraphs': paragraph_list}
+
+
+class XlsParseQAHandle(BaseParseQAHandle):
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ buffer = get_buffer(file)
+ if file_name.endswith(".xls") and xlrd.inspect_format(content=buffer):
+ return True
+ return False
+
+ def handle(self, file, get_buffer, save_image):
+ buffer = get_buffer(file)
+ try:
+ workbook = xlrd.open_workbook(file_contents=buffer)
+ worksheets = workbook.sheets()
+ worksheets_size = len(worksheets)
+ return [row for row in
+ [handle_sheet(file.name,
+ sheet) if worksheets_size == 1 and sheet.name == 'Sheet1' else handle_sheet(
+ sheet.name, sheet) for sheet
+ in worksheets] if row is not None]
+ except Exception as e:
+ return [{'name': file.name, 'paragraphs': []}]
diff --git a/apps/common/handle/impl/qa/xlsx_parse_qa_handle.py b/apps/common/handle/impl/qa/xlsx_parse_qa_handle.py
new file mode 100644
index 000000000..ecadb5139
--- /dev/null
+++ b/apps/common/handle/impl/qa/xlsx_parse_qa_handle.py
@@ -0,0 +1,72 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: xlsx_parse_qa_handle.py
+ @date:2024/5/21 14:59
+ @desc:
+"""
+import io
+
+import openpyxl
+
+from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value
+from common.handle.impl.common_handle import xlsx_embed_cells_images
+
+
+def handle_sheet(file_name, sheet, image_dict):
+ rows = sheet.rows
+ try:
+ title_row_list = next(rows)
+ title_row_list = [row.value for row in title_row_list]
+ except Exception as e:
+ return {'name': file_name, 'paragraphs': []}
+ if len(title_row_list) == 0:
+ return {'name': file_name, 'paragraphs': []}
+ title_row_index_dict = get_title_row_index_dict(title_row_list)
+ paragraph_list = []
+ for row in rows:
+ content = get_row_value(row, title_row_index_dict, 'content')
+ if content is None or content.value is None:
+ continue
+ problem = get_row_value(row, title_row_index_dict, 'problem_list')
+ problem = str(problem.value) if problem is not None and problem.value is not None else ''
+ problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0]
+ title = get_row_value(row, title_row_index_dict, 'title')
+ title = str(title.value) if title is not None and title.value is not None else ''
+ content = str(content.value)
+ image = image_dict.get(content, None)
+ if image is not None:
+                    content = f'![](/api/image/{image.id})'
+ paragraph_list.append({'title': title[0:255],
+ 'content': content[0:102400],
+ 'problem_list': problem_list})
+ return {'name': file_name, 'paragraphs': paragraph_list}
+
+
+class XlsxParseQAHandle(BaseParseQAHandle):
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".xlsx"):
+ return True
+ return False
+
+ def handle(self, file, get_buffer, save_image):
+ buffer = get_buffer(file)
+ try:
+ workbook = openpyxl.load_workbook(io.BytesIO(buffer))
+ try:
+ image_dict: dict = xlsx_embed_cells_images(io.BytesIO(buffer))
+ save_image([item for item in image_dict.values()])
+ except Exception as e:
+ image_dict = {}
+ worksheets = workbook.worksheets
+ worksheets_size = len(worksheets)
+ return [row for row in
+ [handle_sheet(file.name,
+ sheet,
+ image_dict) if worksheets_size == 1 and sheet.title == 'Sheet1' else handle_sheet(
+ sheet.title, sheet, image_dict) for sheet
+ in worksheets] if row is not None]
+ except Exception as e:
+ return [{'name': file.name, 'paragraphs': []}]
diff --git a/apps/common/handle/impl/qa/zip_parse_qa_handle.py b/apps/common/handle/impl/qa/zip_parse_qa_handle.py
new file mode 100644
index 000000000..d00bc14dd
--- /dev/null
+++ b/apps/common/handle/impl/qa/zip_parse_qa_handle.py
@@ -0,0 +1,161 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: text_split_handle.py
+ @date:2024/3/27 18:19
+ @desc:
+"""
+import io
+import os
+import re
+import uuid_utils.compat as uuid
+import zipfile
+from typing import List
+from urllib.parse import urljoin
+
+from django.db.models import QuerySet
+
+from common.handle.base_parse_qa_handle import BaseParseQAHandle
+from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
+from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
+from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
+from common.utils.common import parse_md_image
+from knowledge.models import File
+from django.utils.translation import gettext_lazy as _
+
+
+class FileBufferHandle:
+ buffer = None
+
+ def get_buffer(self, file):
+ if self.buffer is None:
+ self.buffer = file.read()
+ return self.buffer
+
+
+split_handles = [
+ XlsParseQAHandle(),
+ XlsxParseQAHandle(),
+ CsvParseQAHandle()
+]
+
+
+def file_to_paragraph(file, save_inner_image):
+ """
+ 文件转换为段落列表
+ @param file: 文件
+ @return: {
+ name:文件名
+ paragraphs:段落列表
+ }
+ """
+ get_buffer = FileBufferHandle().get_buffer
+ for split_handle in split_handles:
+ if split_handle.support(file, get_buffer):
+ return split_handle.handle(file, get_buffer, save_inner_image)
+ raise Exception(_("Unsupported file format"))
+
+
+def is_valid_uuid(uuid_str: str):
+ """
+ 校验字符串是否是uuid
+ @param uuid_str: 需要校验的字符串
+ @return: bool
+ """
+ try:
+ uuid.UUID(uuid_str)
+ except ValueError:
+ return False
+ return True
+
+
+def get_image_list(result_list: list, zip_files: List[str]):
+ """
+ 获取图片文件列表
+ @param result_list:
+ @param zip_files:
+ @return:
+ """
+ image_file_list = []
+ for result in result_list:
+ for p in result.get('paragraphs', []):
+ content: str = p.get('content', '')
+ image_list = parse_md_image(content)
+ for image in image_list:
+ search = re.search("\(.*\)", image)
+ if search:
+ new_image_id = str(uuid.uuid7())
+ source_image_path = search.group().replace('(', '').replace(')', '')
+ image_path = urljoin(result.get('name'), '.' + source_image_path if source_image_path.startswith(
+ '/') else source_image_path)
+ if not zip_files.__contains__(image_path):
+ continue
+ if image_path.startswith('api/file/') or image_path.startswith('api/image/'):
+ image_id = image_path.replace('api/file/', '').replace('api/image/', '')
+ if is_valid_uuid(image_id):
+ image_file_list.append({'source_file': image_path,
+ 'image_id': image_id})
+ else:
+ image_file_list.append({'source_file': image_path,
+ 'image_id': new_image_id})
+ content = content.replace(source_image_path, f'/api/image/{new_image_id}')
+ p['content'] = content
+ else:
+ image_file_list.append({'source_file': image_path,
+ 'image_id': new_image_id})
+ content = content.replace(source_image_path, f'/api/image/{new_image_id}')
+ p['content'] = content
+
+ return image_file_list
+
+
+def filter_image_file(result_list: list, image_list):
+ image_source_file_list = [image.get('source_file') for image in image_list]
+ return [r for r in result_list if not image_source_file_list.__contains__(r.get('name', ''))]
+
+
+class ZipParseQAHandle(BaseParseQAHandle):
+
+ def handle(self, file, get_buffer, save_image):
+ buffer = get_buffer(file)
+ bytes_io = io.BytesIO(buffer)
+ result = []
+ # 打开zip文件
+ with zipfile.ZipFile(bytes_io, 'r') as zip_ref:
+ # 获取压缩包中的文件名列表
+ files = zip_ref.namelist()
+ # 读取压缩包中的文件内容
+ for file in files:
+ # 跳过 macOS 特有的元数据目录和文件
+ if file.endswith('/') or file.startswith('__MACOSX'):
+ continue
+ with zip_ref.open(file) as f:
+ # 对文件内容进行处理
+ try:
+ value = file_to_paragraph(f, save_image)
+ if isinstance(value, list):
+ result = [*result, *value]
+ else:
+ result.append(value)
+ except Exception:
+ pass
+ image_list = get_image_list(result, files)
+ result = filter_image_file(result, image_list)
+ image_mode_list = []
+ for image in image_list:
+ with zip_ref.open(image.get('source_file')) as f:
+ i = File(
+ id=image.get('image_id'),
+ file_name=os.path.basename(image.get('source_file')),
+ meta={'debug': False, 'content': f.read()}
+ )
+ image_mode_list.append(i)
+ save_image(image_mode_list)
+ return result
+
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".zip") or file_name.endswith(".ZIP"):
+ return True
+ return False
diff --git a/apps/common/handle/impl/response/__init__.py b/apps/common/handle/impl/response/__init__.py
new file mode 100644
index 000000000..ad09602db
--- /dev/null
+++ b/apps/common/handle/impl/response/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: qabot
+ @Author:虎
+ @file: __init__.py.py
+ @date:2023/9/6 10:09
+ @desc:
+"""
diff --git a/apps/common/handle/impl/response/openai_to_response.py b/apps/common/handle/impl/response/openai_to_response.py
new file mode 100644
index 000000000..f2b69384e
--- /dev/null
+++ b/apps/common/handle/impl/response/openai_to_response.py
@@ -0,0 +1,52 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: openai_to_response.py
+ @date:2024/9/6 16:08
+ @desc:
+"""
+import datetime
+
+from django.http import JsonResponse
+from openai.types import CompletionUsage
+from openai.types.chat import ChatCompletionChunk, ChatCompletionMessage, ChatCompletion
+from openai.types.chat.chat_completion import Choice as BlockChoice
+from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
+from rest_framework import status
+
+from common.handle.base_to_response import BaseToResponse
+
+
+class OpenaiToResponse(BaseToResponse):
+ def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
+ other_params: dict = None,
+ _status=status.HTTP_200_OK):
+ if other_params is None:
+ other_params = {}
+ data = ChatCompletion(id=chat_record_id, choices=[
+ BlockChoice(finish_reason='stop', index=0, chat_id=chat_id,
+ answer_list=other_params.get('answer_list', ""),
+ message=ChatCompletionMessage(role='assistant', content=content))],
+ created=datetime.datetime.now().second, model='', object='chat.completion',
+ usage=CompletionUsage(completion_tokens=completion_tokens,
+ prompt_tokens=prompt_tokens,
+ total_tokens=completion_tokens + prompt_tokens)
+ ).dict()
+ return JsonResponse(data=data, status=_status)
+
+ def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end,
+ completion_tokens,
+ prompt_tokens, other_params: dict = None):
+ if other_params is None:
+ other_params = {}
+ chunk = ChatCompletionChunk(id=chat_record_id, model='', object='chat.completion.chunk',
+ created=datetime.datetime.now().second, choices=[
+ Choice(delta=ChoiceDelta(content=content, reasoning_content=other_params.get('reasoning_content', ""),
+ chat_id=chat_id),
+ finish_reason='stop' if is_end else None,
+ index=0)],
+ usage=CompletionUsage(completion_tokens=completion_tokens,
+ prompt_tokens=prompt_tokens,
+ total_tokens=completion_tokens + prompt_tokens)).json()
+ return super().format_stream_chunk(chunk)
diff --git a/apps/common/handle/impl/response/system_to_response.py b/apps/common/handle/impl/response/system_to_response.py
new file mode 100644
index 000000000..a1a530dba
--- /dev/null
+++ b/apps/common/handle/impl/response/system_to_response.py
@@ -0,0 +1,41 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎
+ @file: system_to_response.py
+ @date:2024/9/6 18:03
+ @desc:
+"""
+import json
+
+from rest_framework import status
+
+from common.handle.base_to_response import BaseToResponse
+from common.result import result
+
+
+class SystemToResponse(BaseToResponse):
+ def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens,
+ prompt_tokens, other_params: dict = None,
+ _status=status.HTTP_200_OK):
+ if other_params is None:
+ other_params = {}
+ return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+ 'content': content, 'is_end': is_end, **other_params,
+ 'completion_tokens': completion_tokens, 'prompt_tokens': prompt_tokens},
+ response_status=_status,
+ code=_status)
+
+ def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end,
+ completion_tokens,
+ prompt_tokens, other_params: dict = None):
+ if other_params is None:
+ other_params = {}
+ chunk = json.dumps({'chat_id': str(chat_id), 'chat_record_id': str(chat_record_id), 'operate': True,
+ 'content': content, 'node_id': node_id, 'up_node_id_list': up_node_id_list,
+ 'is_end': is_end,
+ 'usage': {'completion_tokens': completion_tokens,
+ 'prompt_tokens': prompt_tokens,
+ 'total_tokens': completion_tokens + prompt_tokens},
+ **other_params})
+ return super().format_stream_chunk(chunk)
diff --git a/apps/common/handle/impl/table/__init__.py b/apps/common/handle/impl/table/__init__.py
new file mode 100644
index 000000000..ad09602db
--- /dev/null
+++ b/apps/common/handle/impl/table/__init__.py
@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+ @project: qabot
+ @Author:虎
+ @file: __init__.py.py
+ @date:2023/9/6 10:09
+ @desc:
+"""
diff --git a/apps/common/handle/impl/table/csv_parse_table_handle.py b/apps/common/handle/impl/table/csv_parse_table_handle.py
new file mode 100644
index 000000000..e2fc7ce86
--- /dev/null
+++ b/apps/common/handle/impl/table/csv_parse_table_handle.py
@@ -0,0 +1,44 @@
+# coding=utf-8
+import logging
+
+from charset_normalizer import detect
+
+from common.handle.base_parse_table_handle import BaseParseTableHandle
+
+max_kb = logging.getLogger("max_kb")
+
+
+class CsvSplitHandle(BaseParseTableHandle):
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".csv"):
+ return True
+ return False
+
+ def handle(self, file, get_buffer,save_image):
+ buffer = get_buffer(file)
+ try:
+ content = buffer.decode(detect(buffer)['encoding'])
+ except BaseException as e:
+ max_kb.error(f'csv split handle error: {e}')
+ return [{'name': file.name, 'paragraphs': []}]
+
+ csv_model = content.split('\n')
+ paragraphs = []
+ # 第一行为标题
+ title = csv_model[0].split(',')
+ for row in csv_model[1:]:
+ if not row:
+ continue
+ line = '; '.join([f'{key}:{value}' for key, value in zip(title, row.split(','))])
+ paragraphs.append({'title': '', 'content': line})
+
+ return [{'name': file.name, 'paragraphs': paragraphs}]
+
+ def get_content(self, file, save_image):
+ buffer = file.read()
+ try:
+ return buffer.decode(detect(buffer)['encoding'])
+ except BaseException as e:
+ max_kb.error(f'csv split handle error: {e}')
+ return f'error: {e}'
\ No newline at end of file
diff --git a/apps/common/handle/impl/table/xls_parse_table_handle.py b/apps/common/handle/impl/table/xls_parse_table_handle.py
new file mode 100644
index 000000000..897e347e8
--- /dev/null
+++ b/apps/common/handle/impl/table/xls_parse_table_handle.py
@@ -0,0 +1,94 @@
+# coding=utf-8
+import logging
+
+import xlrd
+
+from common.handle.base_parse_table_handle import BaseParseTableHandle
+
+max_kb = logging.getLogger("max_kb")
+
+
+class XlsSplitHandle(BaseParseTableHandle):
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ buffer = get_buffer(file)
+ if file_name.endswith(".xls") and xlrd.inspect_format(content=buffer):
+ return True
+ return False
+
+ def handle(self, file, get_buffer, save_image):
+ buffer = get_buffer(file)
+ try:
+ wb = xlrd.open_workbook(file_contents=buffer, formatting_info=True)
+ result = []
+ sheets = wb.sheets()
+ for sheet in sheets:
+ # 获取合并单元格的范围信息
+ merged_cells = sheet.merged_cells
+ print(merged_cells)
+ data = []
+ paragraphs = []
+ # 获取第一行作为标题行
+ headers = [sheet.cell_value(0, col_idx) for col_idx in range(sheet.ncols)]
+ # 从第二行开始遍历每一行(跳过标题行)
+ for row_idx in range(1, sheet.nrows):
+ row_data = {}
+ for col_idx in range(sheet.ncols):
+ cell_value = sheet.cell_value(row_idx, col_idx)
+
+ # 检查是否为空单元格,如果为空检查是否在合并区域中
+ if cell_value == "":
+ # 检查当前单元格是否在合并区域
+ for (rlo, rhi, clo, chi) in merged_cells:
+ if rlo <= row_idx < rhi and clo <= col_idx < chi:
+ # 使用合并区域的左上角单元格的值
+ cell_value = sheet.cell_value(rlo, clo)
+ break
+
+ # 将标题作为键,单元格的值作为值存入字典
+ row_data[headers[col_idx]] = cell_value
+ data.append(row_data)
+
+ for row in data:
+ row_output = "; ".join([f"{key}: {value}" for key, value in row.items()])
+ # print(row_output)
+ paragraphs.append({'title': '', 'content': row_output})
+
+ result.append({'name': sheet.name, 'paragraphs': paragraphs})
+
+ except BaseException as e:
+ max_kb.error(f'excel split handle error: {e}')
+ return [{'name': file.name, 'paragraphs': []}]
+ return result
+
+ def get_content(self, file, save_image):
+ # 打开 .xls 文件
+ try:
+ workbook = xlrd.open_workbook(file_contents=file.read(), formatting_info=True)
+ sheets = workbook.sheets()
+ md_tables = ''
+ for sheet in sheets:
+ # 过滤空白的sheet
+ if sheet.nrows == 0 or sheet.ncols == 0:
+ continue
+
+ # 获取表头和内容
+ headers = sheet.row_values(0)
+ data = [sheet.row_values(row_idx) for row_idx in range(1, sheet.nrows)]
+
+ # 构建 Markdown 表格
+ md_table = '| ' + ' | '.join(headers) + ' |\n'
+ md_table += '| ' + ' | '.join(['---'] * len(headers)) + ' |\n'
+ for row in data:
+                    # 将每个单元格中的内容替换换行符为 <br> 以保留原始格式
+                    md_table += '| ' + ' | '.join(
+                        [str(cell)
+                         .replace('\r\n', '<br>')
+                         .replace('\n', '<br>')
+                         if cell else '' for cell in row]) + ' |\n'
+ md_tables += md_table + '\n\n'
+
+ return md_tables
+ except Exception as e:
+ max_kb.error(f'excel split handle error: {e}')
+ return f'error: {e}'
diff --git a/apps/common/handle/impl/table/xlsx_parse_table_handle.py b/apps/common/handle/impl/table/xlsx_parse_table_handle.py
new file mode 100644
index 000000000..7b50683fa
--- /dev/null
+++ b/apps/common/handle/impl/table/xlsx_parse_table_handle.py
@@ -0,0 +1,118 @@
+# coding=utf-8
+import io
+import logging
+
+from openpyxl import load_workbook
+
+from common.handle.base_parse_table_handle import BaseParseTableHandle
+from common.handle.impl.common_handle import xlsx_embed_cells_images
+
+max_kb = logging.getLogger("max_kb")
+
+
+class XlsxSplitHandle(BaseParseTableHandle):
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith('.xlsx'):
+ return True
+ return False
+
+ def fill_merged_cells(self, sheet, image_dict):
+ data = []
+
+ # 获取第一行作为标题行
+ headers = []
+ for idx, cell in enumerate(sheet[1]):
+ if cell.value is None:
+ headers.append(' ' * (idx + 1))
+ else:
+ headers.append(cell.value)
+
+ # 从第二行开始遍历每一行
+ for row in sheet.iter_rows(min_row=2, values_only=False):
+ row_data = {}
+ for col_idx, cell in enumerate(row):
+ cell_value = cell.value
+
+ # 如果单元格为空,并且该单元格在合并单元格内,获取合并单元格的值
+ if cell_value is None:
+ for merged_range in sheet.merged_cells.ranges:
+ if cell.coordinate in merged_range:
+ cell_value = sheet[merged_range.min_row][merged_range.min_col - 1].value
+ break
+
+ image = image_dict.get(cell_value, None)
+ if image is not None:
+                        cell_value = f'![](/api/image/{image.id})'
+
+ # 使用标题作为键,单元格的值作为值存入字典
+ row_data[headers[col_idx]] = cell_value
+ data.append(row_data)
+
+ return data
+
+ def handle(self, file, get_buffer, save_image):
+ buffer = get_buffer(file)
+ try:
+ wb = load_workbook(io.BytesIO(buffer))
+ try:
+ image_dict: dict = xlsx_embed_cells_images(io.BytesIO(buffer))
+ save_image([item for item in image_dict.values()])
+ except Exception as e:
+ image_dict = {}
+ result = []
+ for sheetname in wb.sheetnames:
+ paragraphs = []
+ ws = wb[sheetname]
+ data = self.fill_merged_cells(ws, image_dict)
+
+ for row in data:
+ row_output = "; ".join([f"{key}: {value}" for key, value in row.items()])
+ # print(row_output)
+ paragraphs.append({'title': '', 'content': row_output})
+
+ result.append({'name': sheetname, 'paragraphs': paragraphs})
+
+ except BaseException as e:
+ max_kb.error(f'excel split handle error: {e}')
+ return [{'name': file.name, 'paragraphs': []}]
+ return result
+
+
+ def get_content(self, file, save_image):
+ try:
+ # 加载 Excel 文件
+ workbook = load_workbook(file)
+ try:
+ image_dict: dict = xlsx_embed_cells_images(file)
+ if len(image_dict) > 0:
+ save_image(image_dict.values())
+ except Exception as e:
+ print(f'{e}')
+ image_dict = {}
+ md_tables = ''
+ # 如果未指定 sheet_name,则使用第一个工作表
+ for sheetname in workbook.sheetnames:
+ sheet = workbook[sheetname] if sheetname else workbook.active
+ rows = self.fill_merged_cells(sheet, image_dict)
+ if len(rows) == 0:
+ continue
+ # 提取表头和内容
+
+ headers = [f"{key}" for key, value in rows[0].items()]
+
+ # 构建 Markdown 表格
+ md_table = '| ' + ' | '.join(headers) + ' |\n'
+ md_table += '| ' + ' | '.join(['---'] * len(headers)) + ' |\n'
+ for row in rows:
+ r = [f'{value}' for key, value in row.items()]
+ md_table += '| ' + ' | '.join(
+                        [str(cell).replace('\n', '<br>') if cell is not None else '' for cell in r]) + ' |\n'
+
+ md_tables += md_table + '\n\n'
+
+ md_tables = md_tables.replace('/api/image/', '/api/file/')
+ return md_tables
+ except Exception as e:
+ max_kb.error(f'excel split handle error: {e}')
+ return f'error: {e}'
diff --git a/apps/common/handle/impl/text_split_handle.py b/apps/common/handle/impl/text_split_handle.py
new file mode 100644
index 000000000..1a18ab030
--- /dev/null
+++ b/apps/common/handle/impl/text_split_handle.py
@@ -0,0 +1,60 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: text_split_handle.py
+ @date:2024/3/27 18:19
+ @desc:
+"""
+import re
+import traceback
+from typing import List
+
+from charset_normalizer import detect
+
+from common.handle.base_split_handle import BaseSplitHandle
+from common.utils.split_model import SplitModel
+
+default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'),
+ re.compile('(?<=\\n)(? 0.5:
+ return True
+ return False
+
+ def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+ buffer = get_buffer(file)
+ if pattern_list is not None and len(pattern_list) > 0:
+ split_model = SplitModel(pattern_list, with_filter, limit)
+ else:
+ split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit)
+ try:
+ content = buffer.decode(detect(buffer)['encoding'])
+ except BaseException as e:
+ return {'name': file.name,
+ 'content': []}
+ return {'name': file.name,
+ 'content': split_model.parse(content)
+ }
+
+ def get_content(self, file, save_image):
+ buffer = file.read()
+ try:
+ return buffer.decode(detect(buffer)['encoding'])
+ except BaseException as e:
+ traceback.print_exception(e)
+ return f'{e}'
\ No newline at end of file
diff --git a/apps/common/handle/impl/xls_split_handle.py b/apps/common/handle/impl/xls_split_handle.py
new file mode 100644
index 000000000..dbdcc9550
--- /dev/null
+++ b/apps/common/handle/impl/xls_split_handle.py
@@ -0,0 +1,80 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: xls_parse_qa_handle.py
+ @date:2024/5/21 14:59
+ @desc:
+"""
+from typing import List
+
+import xlrd
+
+from common.handle.base_split_handle import BaseSplitHandle
+
+
+def post_cell(cell_value):
+    return cell_value.replace('\r\n', '<br>').replace('\n', '<br>').replace('|', '&#124;')
+
+
+def row_to_md(row):
+ return '| ' + ' | '.join(
+ [post_cell(str(cell)) if cell is not None else '' for cell in row]) + ' |\n'
+
+
+def handle_sheet(file_name, sheet, limit: int):
+ rows = iter([sheet.row_values(i) for i in range(sheet.nrows)])
+ paragraphs = []
+ result = {'name': file_name, 'content': paragraphs}
+ try:
+ title_row_list = next(rows)
+ title_md_content = row_to_md(title_row_list)
+ title_md_content += '| ' + ' | '.join(
+ ['---' if cell is not None else '' for cell in title_row_list]) + ' |\n'
+ except Exception as e:
+ return result
+ if len(title_row_list) == 0:
+ return result
+ result_item_content = ''
+ for row in rows:
+ next_md_content = row_to_md(row)
+ next_md_content_len = len(next_md_content)
+ result_item_content_len = len(result_item_content)
+ if len(result_item_content) == 0:
+ result_item_content += title_md_content
+ result_item_content += next_md_content
+ else:
+ if result_item_content_len + next_md_content_len < limit:
+ result_item_content += next_md_content
+ else:
+ paragraphs.append({'content': result_item_content, 'title': ''})
+ result_item_content = title_md_content + next_md_content
+ if len(result_item_content) > 0:
+ paragraphs.append({'content': result_item_content, 'title': ''})
+ return result
+
+
+class XlsSplitHandle(BaseSplitHandle):
+ def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+ buffer = get_buffer(file)
+ try:
+ workbook = xlrd.open_workbook(file_contents=buffer)
+ worksheets = workbook.sheets()
+ worksheets_size = len(worksheets)
+ return [row for row in
+ [handle_sheet(file.name,
+ sheet, limit) if worksheets_size == 1 and sheet.name == 'Sheet1' else handle_sheet(
+ sheet.name, sheet, limit) for sheet
+ in worksheets] if row is not None]
+ except Exception as e:
+ return [{'name': file.name, 'content': []}]
+
+ def get_content(self, file, save_image):
+ pass
+
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ buffer = get_buffer(file)
+ if file_name.endswith(".xls") and xlrd.inspect_format(content=buffer):
+ return True
+ return False
diff --git a/apps/common/handle/impl/xlsx_split_handle.py b/apps/common/handle/impl/xlsx_split_handle.py
new file mode 100644
index 000000000..abda06d2e
--- /dev/null
+++ b/apps/common/handle/impl/xlsx_split_handle.py
@@ -0,0 +1,92 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: xlsx_parse_qa_handle.py
+ @date:2024/5/21 14:59
+ @desc:
+"""
+import io
+from typing import List
+
+import openpyxl
+
+from common.handle.base_split_handle import BaseSplitHandle
+from common.handle.impl.common_handle import xlsx_embed_cells_images
+
+
+def post_cell(image_dict, cell_value):
+ image = image_dict.get(cell_value, None)
+ if image is not None:
+        return f'![](/api/image/{image.id})'
+    return cell_value.replace('\n', '<br>').replace('|', '&#124;')
+
+
+def row_to_md(row, image_dict):
+ return '| ' + ' | '.join(
+ [post_cell(image_dict, str(cell.value if cell.value is not None else '')) if cell is not None else '' for cell
+ in row]) + ' |\n'
+
+
+def handle_sheet(file_name, sheet, image_dict, limit: int):
+ rows = sheet.rows
+ paragraphs = []
+ result = {'name': file_name, 'content': paragraphs}
+ try:
+ title_row_list = next(rows)
+ title_md_content = row_to_md(title_row_list, image_dict)
+ title_md_content += '| ' + ' | '.join(
+ ['---' if cell is not None else '' for cell in title_row_list]) + ' |\n'
+ except Exception as e:
+ return result
+ if len(title_row_list) == 0:
+ return result
+ result_item_content = ''
+ for row in rows:
+ next_md_content = row_to_md(row, image_dict)
+ next_md_content_len = len(next_md_content)
+ result_item_content_len = len(result_item_content)
+ if len(result_item_content) == 0:
+ result_item_content += title_md_content
+ result_item_content += next_md_content
+ else:
+ if result_item_content_len + next_md_content_len < limit:
+ result_item_content += next_md_content
+ else:
+ paragraphs.append({'content': result_item_content, 'title': ''})
+ result_item_content = title_md_content + next_md_content
+ if len(result_item_content) > 0:
+ paragraphs.append({'content': result_item_content, 'title': ''})
+ return result
+
+
+class XlsxSplitHandle(BaseSplitHandle):
+ def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+ buffer = get_buffer(file)
+ try:
+ workbook = openpyxl.load_workbook(io.BytesIO(buffer))
+ try:
+ image_dict: dict = xlsx_embed_cells_images(io.BytesIO(buffer))
+ save_image([item for item in image_dict.values()])
+ except Exception as e:
+ image_dict = {}
+ worksheets = workbook.worksheets
+ worksheets_size = len(worksheets)
+ return [row for row in
+ [handle_sheet(file.name,
+ sheet,
+ image_dict,
+ limit) if worksheets_size == 1 and sheet.title == 'Sheet1' else handle_sheet(
+ sheet.title, sheet, image_dict, limit) for sheet
+ in worksheets] if row is not None]
+ except Exception as e:
+ return [{'name': file.name, 'content': []}]
+
+ def get_content(self, file, save_image):
+ pass
+
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".xlsx"):
+ return True
+ return False
diff --git a/apps/common/handle/impl/zip_split_handle.py b/apps/common/handle/impl/zip_split_handle.py
new file mode 100644
index 000000000..f5beec141
--- /dev/null
+++ b/apps/common/handle/impl/zip_split_handle.py
@@ -0,0 +1,164 @@
+# coding=utf-8
+"""
+ @project: maxkb
+ @Author:虎
+ @file: text_split_handle.py
+ @date:2024/3/27 18:19
+ @desc:
+"""
+import io
+import os
+import re
+import zipfile
+from typing import List
+from urllib.parse import urljoin
+
+import uuid_utils.compat as uuid
+from charset_normalizer import detect
+from django.utils.translation import gettext_lazy as _
+
+from common.handle.base_split_handle import BaseSplitHandle
+from common.handle.impl.csv_split_handle import CsvSplitHandle
+from common.handle.impl.doc_split_handle import DocSplitHandle
+from common.handle.impl.html_split_handle import HTMLSplitHandle
+from common.handle.impl.pdf_split_handle import PdfSplitHandle
+from common.handle.impl.text_split_handle import TextSplitHandle
+from common.handle.impl.xls_split_handle import XlsSplitHandle
+from common.handle.impl.xlsx_split_handle import XlsxSplitHandle
+from common.utils.common import parse_md_image
+from knowledge.models import File
+
+
+class FileBufferHandle:
+ buffer = None
+
+ def get_buffer(self, file):
+ if self.buffer is None:
+ self.buffer = file.read()
+ return self.buffer
+
+
+default_split_handle = TextSplitHandle()
+split_handles = [
+ HTMLSplitHandle(),
+ DocSplitHandle(),
+ PdfSplitHandle(),
+ XlsxSplitHandle(),
+ XlsSplitHandle(),
+ CsvSplitHandle(),
+ default_split_handle
+]
+
+
+def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int, save_inner_image):
+ get_buffer = FileBufferHandle().get_buffer
+ for split_handle in split_handles:
+ if split_handle.support(file, get_buffer):
+ return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_inner_image)
+ raise Exception(_('Unsupported file format'))
+
+
+def is_valid_uuid(uuid_str: str):
+ try:
+ uuid.UUID(uuid_str)
+ except ValueError:
+ return False
+ return True
+
+
+def get_image_list(result_list: list, zip_files: List[str]):
+ image_file_list = []
+ for result in result_list:
+ for p in result.get('content', []):
+ content: str = p.get('content', '')
+ image_list = parse_md_image(content)
+ for image in image_list:
+ search = re.search("\(.*\)", image)
+ if search:
+ new_image_id = str(uuid.uuid7())
+ source_image_path = search.group().replace('(', '').replace(')', '')
+ source_image_path = source_image_path.strip().split(" ")[0]
+ image_path = urljoin(result.get('name'), '.' + source_image_path if source_image_path.startswith(
+ '/') else source_image_path)
+ if not zip_files.__contains__(image_path):
+ continue
+ if image_path.startswith('api/file/') or image_path.startswith('api/image/'):
+ image_id = image_path.replace('api/file/', '').replace('api/image/', '')
+ if is_valid_uuid(image_id):
+ image_file_list.append({'source_file': image_path,
+ 'image_id': image_id})
+ else:
+ image_file_list.append({'source_file': image_path,
+ 'image_id': new_image_id})
+ content = content.replace(source_image_path, f'/api/image/{new_image_id}')
+ p['content'] = content
+ else:
+ image_file_list.append({'source_file': image_path,
+ 'image_id': new_image_id})
+ content = content.replace(source_image_path, f'/api/image/{new_image_id}')
+ p['content'] = content
+
+ return image_file_list
+
+
+def get_file_name(file_name):
+ try:
+ file_name_code = file_name.encode('cp437')
+ charset = detect(file_name_code)['encoding']
+ return file_name_code.decode(charset)
+ except Exception as e:
+ return file_name
+
+
+def filter_image_file(result_list: list, image_list):
+ image_source_file_list = [image.get('source_file') for image in image_list]
+ return [r for r in result_list if not image_source_file_list.__contains__(r.get('name', ''))]
+
+
+class ZipSplitHandle(BaseSplitHandle):
+ def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
+ buffer = get_buffer(file)
+ bytes_io = io.BytesIO(buffer)
+ result = []
+ # 打开zip文件
+ with zipfile.ZipFile(bytes_io, 'r') as zip_ref:
+ # 获取压缩包中的文件名列表
+ files = zip_ref.namelist()
+ # 读取压缩包中的文件内容
+ for file in files:
+ if file.endswith('/') or file.startswith('__MACOSX'):
+ continue
+ with zip_ref.open(file) as f:
+ # 对文件内容进行处理
+ try:
+ # 处理一下文件名
+ f.name = get_file_name(f.name)
+ value = file_to_paragraph(f, pattern_list, with_filter, limit, save_image)
+ if isinstance(value, list):
+ result = [*result, *value]
+ else:
+ result.append(value)
+ except Exception:
+ pass
+ image_list = get_image_list(result, files)
+ result = filter_image_file(result, image_list)
+ image_mode_list = []
+ for image in image_list:
+ with zip_ref.open(image.get('source_file')) as f:
+ i = File(
+ id=image.get('image_id'),
+ image_name=os.path.basename(image.get('source_file')),
+ meta={'debug': False, 'content': f.read()} # 这里的content是二进制数据
+ )
+ image_mode_list.append(i)
+ save_image(image_mode_list)
+ return result
+
+ def support(self, file, get_buffer):
+ file_name: str = file.name.lower()
+ if file_name.endswith(".zip") or file_name.endswith(".ZIP"):
+ return True
+ return False
+
+ def get_content(self, file, save_image):
+ return ""
diff --git a/apps/common/utils/common.py b/apps/common/utils/common.py
index e4c6c3443..d5d3dac10 100644
--- a/apps/common/utils/common.py
+++ b/apps/common/utils/common.py
@@ -257,3 +257,9 @@ def post(post_function):
return inner
+
+def parse_md_image(content: str):
+    matches = re.finditer(r"!\[.*?\]\(.*?\)", content)  # raw string fixes invalid "\[" escape warnings; matches markdown ![alt](src)
+    image_list = [match.group() for match in matches]
+    return image_list
+
diff --git a/apps/knowledge/api/document.py b/apps/knowledge/api/document.py
index 7cc248f71..47a7ddb03 100644
--- a/apps/knowledge/api/document.py
+++ b/apps/knowledge/api/document.py
@@ -32,3 +32,46 @@ class DocumentCreateAPI(APIMixin):
@staticmethod
def get_response():
return DocumentCreateResponse
+
+
+class DocumentSplitAPI(APIMixin):  # OpenAPI schema for the document-split endpoint; get_request/get_response presumably inherited from APIMixin -- verify
+    @staticmethod
+    def get_parameters():
+        return [
+            OpenApiParameter(
+                name="workspace_id",
+                description="工作空间id",
+                type=OpenApiTypes.STR,
+                location='path',
+                required=True,
+            ),
+            OpenApiParameter(
+                name="file",
+                description="文件",
+                type=OpenApiTypes.BINARY,  # NOTE(review): a BINARY upload declared in location='query' is unusual; the view reads it from multipart form data -- confirm
+                location='query',
+                required=False,
+            ),
+            OpenApiParameter(
+                name="limit",
+                description="分段长度",
+                type=OpenApiTypes.INT,
+                location='query',
+                required=False,
+            ),
+            OpenApiParameter(
+                name="patterns",
+                description="分段正则列表",
+                type=OpenApiTypes.STR,
+                location='query',
+                required=False,
+            ),
+            OpenApiParameter(
+                name="with_filter",
+                description="是否清除特殊字符",
+                type=OpenApiTypes.BOOL,
+                location='query',
+                required=False,
+            ),
+        ]
+
+
diff --git a/apps/knowledge/serializers/document.py b/apps/knowledge/serializers/document.py
index 789f682d5..5870f3fd0 100644
--- a/apps/knowledge/serializers/document.py
+++ b/apps/knowledge/serializers/document.py
@@ -13,14 +13,34 @@ from rest_framework import serializers
from common.db.search import native_search
from common.event import ListenerManagement
from common.exception.app_exception import AppApiException
+from common.handle.impl.csv_split_handle import CsvSplitHandle
+from common.handle.impl.doc_split_handle import DocSplitHandle
+from common.handle.impl.html_split_handle import HTMLSplitHandle
+from common.handle.impl.pdf_split_handle import PdfSplitHandle
+from common.handle.impl.text_split_handle import TextSplitHandle
+from common.handle.impl.xls_split_handle import XlsSplitHandle
+from common.handle.impl.xlsx_split_handle import XlsxSplitHandle
+from common.handle.impl.zip_split_handle import ZipSplitHandle
from common.utils.common import post, get_file_content
from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
- TaskType
+ TaskType, File
from knowledge.serializers.common import ProblemParagraphManage
from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer
from knowledge.task import embedding_by_document
from maxkb.const import PROJECT_DIR
+default_split_handle = TextSplitHandle()  # catch-all handler used when no specific handler supports a file
+split_handles = [
+    HTMLSplitHandle(),
+    DocSplitHandle(),
+    PdfSplitHandle(),
+    XlsxSplitHandle(),
+    XlsSplitHandle(),
+    CsvSplitHandle(),
+    ZipSplitHandle(),
+    default_split_handle  # order matters: first handler whose support() returns True wins, so the text fallback stays last
+]
+
class DocumentInstanceSerializer(serializers.Serializer):
name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1)
@@ -34,6 +54,17 @@ class DocumentCreateRequest(serializers.Serializer):
documents = DocumentInstanceSerializer(required=False, many=True)
+class DocumentSplitRequest(serializers.Serializer):  # request body for splitting uploaded files into paragraphs
+    file = serializers.ListField(required=True, label=_('file list'))  # uploaded files (multipart)
+    limit = serializers.IntegerField(required=False, label=_('limit'))  # max characters per paragraph
+    patterns = serializers.ListField(
+        required=False,
+        child=serializers.CharField(required=True, label=_('patterns')),  # each entry is a split regex
+        label=_('patterns')
+    )
+    with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))  # strip special characters when true
+
+
class DocumentSerializers(serializers.Serializer):
class Operate(serializers.Serializer):
document_id = serializers.UUIDField(required=True, label=_('document id'))
@@ -177,3 +208,67 @@ class DocumentSerializers(serializers.Serializer):
document_model,
instance.get('paragraphs') if 'paragraphs' in instance else []
)
+
+    class Split(serializers.Serializer):  # validates ids and splits uploaded files into paragraph dicts
+        workspace_id = serializers.CharField(required=True, label=_('workspace id'))
+        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
+
+        def is_valid(self, *, raise_exception=True):
+            super().is_valid(raise_exception=True)  # NOTE(review): the caller's raise_exception argument is ignored here
+            files = self.data.get('file')  # NOTE(review): 'file' is not a declared field, so this looks like it yields None -- confirm against callers
+            for f in files:
+                if f.size > 1024 * 1024 * 100:  # reject any single upload above 100MB
+                    raise AppApiException(500, _(
+                        'The maximum size of the uploaded file cannot exceed {}MB'
+                    ).format(100))
+
+        def parse(self, instance):
+            self.is_valid(raise_exception=True)
+            DocumentSplitRequest(data=instance).is_valid(raise_exception=True)  # fix: DRF requires data=; positional arg is treated as `instance` and is_valid() then raises AssertionError
+
+            file_list = instance.get("file")
+            return reduce(
+                lambda x, y: [*x, *y],  # flatten the per-file paragraph lists into one list
+                [self.file_to_paragraph(
+                    f,
+                    instance.get("patterns", None),
+                    instance.get("with_filter", None),
+                    instance.get("limit", 4096)  # default paragraph length limit
+                ) for f in file_list],
+                []
+            )
+
+        def save_image(self, image_list):
+            if image_list is not None and len(image_list) > 0:
+                exist_image_list = [str(i.get('id')) for i in
+                                    QuerySet(File).filter(id__in=[i.id for i in image_list]).values('id')]
+                save_image_list = [image for image in image_list if str(image.id) not in exist_image_list]
+                save_image_list = list({img.id: img for img in save_image_list}.values())  # dedupe by id, keeping the last occurrence
+                # save image
+                for file in save_image_list:
+                    file_bytes = file.meta.pop('content')  # pop the raw bytes so they are not persisted in meta
+                    file.workspace_id = self.data.get('workspace_id')
+                    file.meta['knowledge_id'] = self.data.get('knowledge_id')
+                    file.save(file_bytes)
+
+        def file_to_paragraph(self, file, pattern_list: List, with_filter: bool, limit: int):
+            get_buffer = FileBufferHandle().get_buffer  # one cached buffer per file, shared across handler probes
+            for split_handle in split_handles:
+                if split_handle.support(file, get_buffer):  # first supporting handler wins
+                    result = split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, self.save_image)
+                    if isinstance(result, list):
+                        return result
+                    return [result]
+            result = default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, self.save_image)
+            if isinstance(result, list):
+                return result
+            return [result]
+
+
+class FileBufferHandle:  # per-file lazy read cache so several split handlers can share one read()
+    buffer = None  # class-level default; the first get_buffer call shadows it with an instance attribute
+
+    def get_buffer(self, file):
+        if self.buffer is None:  # lazy: read the whole file only once, on first request
+            self.buffer = file.read()
+        return self.buffer
diff --git a/apps/knowledge/urls.py b/apps/knowledge/urls.py
index 3715e11f7..eaa95c729 100644
--- a/apps/knowledge/urls.py
+++ b/apps/knowledge/urls.py
@@ -8,5 +8,6 @@ urlpatterns = [
     path('workspace/<str:workspace_id>/knowledge/base', views.KnowledgeBaseView.as_view()),
     path('workspace/<str:workspace_id>/knowledge/web', views.KnowledgeWebView.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>', views.KnowledgeView.Operate.as_view()),
+    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
     path('workspace/<str:workspace_id>/knowledge/<int:current_page>/<int:page_size>', views.KnowledgeView.Page.as_view()),
]
diff --git a/apps/knowledge/views/__init__.py b/apps/knowledge/views/__init__.py
index 4a43fce73..ce309a4e5 100644
--- a/apps/knowledge/views/__init__.py
+++ b/apps/knowledge/views/__init__.py
@@ -1 +1,2 @@
from .knowledge import *
+from .document import *
diff --git a/apps/knowledge/views/document.py b/apps/knowledge/views/document.py
new file mode 100644
index 000000000..fbe276a63
--- /dev/null
+++ b/apps/knowledge/views/document.py
@@ -0,0 +1,70 @@
+from django.utils.translation import gettext_lazy as _
+from drf_spectacular.utils import extend_schema
+from rest_framework.parsers import MultiPartParser
+from rest_framework.request import Request
+from rest_framework.views import APIView
+
+from common.auth import TokenAuth
+from common.auth.authentication import has_permissions
+from common.constants.permission_constants import PermissionConstants, CompareConstants
+from common.result import result
+from knowledge.api.document import DocumentSplitAPI
+from knowledge.api.knowledge import KnowledgeTreeReadAPI
+from knowledge.serializers.document import DocumentSerializers
+from knowledge.serializers.knowledge import KnowledgeSerializer
+
+
+class DocumentView(APIView):
+    authentication_classes = [TokenAuth]
+
+    @extend_schema(
+        methods=['GET'],
+        description=_('Get document'),
+        operation_id=_('Get document'),
+        parameters=KnowledgeTreeReadAPI.get_parameters(),
+        responses=KnowledgeTreeReadAPI.get_response(),
+        tags=[_('Knowledge Base')]
+    )
+    @has_permissions(PermissionConstants.DOCUMENT_READ.get_workspace_permission())
+    def get(self, request: Request, workspace_id: str):
+        # NOTE(review): labelled "Get document" but delegates to KnowledgeSerializer.Query -- confirm this is intentional
+        return result.success(KnowledgeSerializer.Query(
+            data={
+                'workspace_id': workspace_id,
+                'folder_id': request.query_params.get('folder_id'),
+                'name': request.query_params.get('name'),
+                'desc': request.query_params.get("desc"),
+                'user_id': request.query_params.get('user_id')
+            }
+        ).list())
+
+    class Split(APIView):
+        authentication_classes = [TokenAuth]
+        parser_classes = [MultiPartParser]  # files arrive as multipart form data
+
+        @extend_schema(
+            methods=['POST'],
+            description=_('Segmented document'),
+            operation_id=_('Segmented document'),
+            parameters=DocumentSplitAPI.get_parameters(),
+            request=DocumentSplitAPI.get_request(),
+            responses=DocumentSplitAPI.get_response(),
+            tags=[_('Knowledge Base/Documentation')]
+        )
+        @has_permissions([
+            PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
+            PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
+        ])
+        def post(self, request: Request, workspace_id: str, knowledge_id: str):
+            split_data = {'file': request.FILES.getlist('file')}
+            request_data = request.data
+            if 'patterns' in request.data and request.data.get('patterns') is not None and len(
+                    request.data.get('patterns')) > 0:
+                split_data['patterns'] = request_data.getlist('patterns')  # subscript assignment (was dict.__setitem__)
+            if 'limit' in request.data:
+                split_data['limit'] = request_data.get('limit')
+            if 'with_filter' in request.data:
+                split_data['with_filter'] = request_data.get('with_filter')
+            return result.success(DocumentSerializers.Split(data={
+                'workspace_id': workspace_id,
+                'knowledge_id': knowledge_id,
+            }).parse(split_data))
diff --git a/pyproject.toml b/pyproject.toml
index 3e8dd05c6..543c02f0a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,12 @@ celery-once = "3.0.1"
beautifulsoup4 = "4.13.4"
html2text = "2025.4.15"
jieba = "0.42.1"
+openpyxl = "3.1.5"
+python-docx = "1.1.2"
+xlrd = "2.0.1"
+xlwt = "1.3.0"
+pymupdf = "1.24.9"
+pypdf = "4.3.1"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"