diff --git a/apps/common/constants/permission_constants.py b/apps/common/constants/permission_constants.py index ce167e83d..2d9eb9966 100644 --- a/apps/common/constants/permission_constants.py +++ b/apps/common/constants/permission_constants.py @@ -218,6 +218,14 @@ class PermissionConstants(Enum): RoleConstants.USER]) KNOWLEDGE_DELETE = Permission(group=Group.KNOWLEDGE, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN, RoleConstants.USER]) + DOCUMENT_READ = Permission(group=Group.KNOWLEDGE, operate=Operate.READ, role_list=[RoleConstants.ADMIN, + RoleConstants.USER]) + DOCUMENT_CREATE = Permission(group=Group.KNOWLEDGE, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN, + RoleConstants.USER]) + DOCUMENT_EDIT = Permission(group=Group.KNOWLEDGE, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN, + RoleConstants.USER]) + DOCUMENT_DELETE = Permission(group=Group.KNOWLEDGE, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN, + RoleConstants.USER]) def get_workspace_application_permission(self): return lambda r, kwargs: Permission(group=self.value.group, operate=self.value.operate, diff --git a/apps/common/handle/__init__.py b/apps/common/handle/__init__.py new file mode 100644 index 000000000..ad09602db --- /dev/null +++ b/apps/common/handle/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: __init__.py.py + @date:2023/9/6 10:09 + @desc: +""" diff --git a/apps/common/handle/base_parse_qa_handle.py b/apps/common/handle/base_parse_qa_handle.py new file mode 100644 index 000000000..8cd1cd1cd --- /dev/null +++ b/apps/common/handle/base_parse_qa_handle.py @@ -0,0 +1,52 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_parse_qa_handle.py + @date:2024/5/21 14:56 + @desc: +""" +from abc import ABC, abstractmethod + + +def get_row_value(row, title_row_index_dict, field): + index = title_row_index_dict.get(field) + if index is None: + return None + if (len(row) - 1) >= index: + return row[index] + return 
None + + +def get_title_row_index_dict(title_row_list): + title_row_index_dict = {} + if len(title_row_list) == 1: + title_row_index_dict['content'] = 0 + elif len(title_row_list) == 2: + title_row_index_dict['title'] = 0 + title_row_index_dict['content'] = 1 + else: + title_row_index_dict['title'] = 0 + title_row_index_dict['content'] = 1 + title_row_index_dict['problem_list'] = 2 + for index in range(len(title_row_list)): + title_row = title_row_list[index] + if title_row is None: + title_row = '' + if title_row.startswith('分段标题'): + title_row_index_dict['title'] = index + if title_row.startswith('分段内容'): + title_row_index_dict['content'] = index + if title_row.startswith('问题'): + title_row_index_dict['problem_list'] = index + return title_row_index_dict + + +class BaseParseQAHandle(ABC): + @abstractmethod + def support(self, file, get_buffer): + pass + + @abstractmethod + def handle(self, file, get_buffer, save_image): + pass diff --git a/apps/common/handle/base_parse_table_handle.py b/apps/common/handle/base_parse_table_handle.py new file mode 100644 index 000000000..65eaf897f --- /dev/null +++ b/apps/common/handle/base_parse_table_handle.py @@ -0,0 +1,23 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_parse_table_handle.py + @date:2024/5/21 14:56 + @desc: +""" +from abc import ABC, abstractmethod + + +class BaseParseTableHandle(ABC): + @abstractmethod + def support(self, file, get_buffer): + pass + + @abstractmethod + def handle(self, file, get_buffer, save_image): + pass + + @abstractmethod + def get_content(self, file, save_image): + pass \ No newline at end of file diff --git a/apps/common/handle/base_split_handle.py b/apps/common/handle/base_split_handle.py new file mode 100644 index 000000000..bedaad5e1 --- /dev/null +++ b/apps/common/handle/base_split_handle.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_split_handle.py + @date:2024/3/27 18:13 + @desc: +""" +from abc import ABC, abstractmethod +from 
typing import List + + +class BaseSplitHandle(ABC): + @abstractmethod + def support(self, file, get_buffer): + pass + + @abstractmethod + def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image): + pass + + @abstractmethod + def get_content(self, file, save_image): + pass diff --git a/apps/common/handle/base_to_response.py b/apps/common/handle/base_to_response.py new file mode 100644 index 000000000..376d1a9dd --- /dev/null +++ b/apps/common/handle/base_to_response.py @@ -0,0 +1,30 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: base_to_response.py + @date:2024/9/6 16:04 + @desc: +""" +from abc import ABC, abstractmethod + +from rest_framework import status + + +class BaseToResponse(ABC): + + @abstractmethod + def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, + prompt_tokens, other_params: dict = None, + _status=status.HTTP_200_OK): + pass + + @abstractmethod + def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, + completion_tokens, + prompt_tokens, other_params: dict = None): + pass + + @staticmethod + def format_stream_chunk(response_str): + return 'data: ' + response_str + '\n\n' diff --git a/apps/common/handle/handle_exception.py b/apps/common/handle/handle_exception.py new file mode 100644 index 000000000..c25d4c01f --- /dev/null +++ b/apps/common/handle/handle_exception.py @@ -0,0 +1,94 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: handle_exception.py + @date:2023/9/5 19:29 + @desc: +""" +import logging +import traceback + +from rest_framework.exceptions import ValidationError, ErrorDetail, APIException +from rest_framework.views import exception_handler + +from common.exception.app_exception import AppApiException + +from django.utils.translation import gettext_lazy as _ + +from common.result import result + + +def to_result(key, args, parent_key=None): + """ + 将校验异常 args转换为统一数据 + :param key: 校验key 
+ :param args: 校验异常参数 + :param parent_key 父key + :return: 接口响应对象 + """ + error_detail = list(filter( + lambda d: True if isinstance(d, ErrorDetail) else True if isinstance(d, dict) and len( + d.keys()) > 0 else False, + (args[0] if len(args) > 0 else {key: [ErrorDetail(_('Unknown exception'), code='unknown')]}).get(key)))[0] + + if isinstance(error_detail, dict): + return list(map(lambda k: to_result(k, args=[error_detail], + parent_key=key if parent_key is None else parent_key + '.' + key), + error_detail.keys() if len(error_detail) > 0 else []))[0] + + return result.Result(500 if isinstance(error_detail.code, str) else error_detail.code, + message=f"【{key if parent_key is None else parent_key + '.' + key}】为必填参数" if str( + error_detail) == "This field is required." else error_detail) + + +def validation_error_to_result(exc: ValidationError): + """ + 校验异常转响应对象 + :param exc: 校验异常 + :return: 接口响应对象 + """ + try: + v = find_err_detail(exc.detail) + if v is None: + return result.error(str(exc.detail)) + return result.error(str(v)) + except Exception as e: + return result.error(str(exc.detail)) + + +def find_err_detail(exc_detail): + if isinstance(exc_detail, ErrorDetail): + return exc_detail + if isinstance(exc_detail, dict): + keys = exc_detail.keys() + for key in keys: + _value = exc_detail[key] + if isinstance(_value, list): + return find_err_detail(_value) + if isinstance(_value, ErrorDetail): + return _value + if isinstance(_value, dict) and len(_value.keys()) > 0: + return find_err_detail(_value) + if isinstance(exc_detail, list): + for v in exc_detail: + r = find_err_detail(v) + if r is not None: + return r + + +def handle_exception(exc, context): + exception_class = exc.__class__ + # 先调用REST framework默认的异常处理方法获得标准错误响应对象 + response = exception_handler(exc, context) + # 在此处补充自定义的异常处理 + if issubclass(exception_class, ValidationError): + return validation_error_to_result(exc) + if issubclass(exception_class, AppApiException): + return result.Result(exc.code, 
exc.message, response_status=exc.status_code) + if issubclass(exception_class, APIException): + return result.error(exc.detail) + if response is None: + logging.getLogger("max_kb_error").error(f'{str(exc)}:{traceback.format_exc()}') + return result.error(str(exc)) + return response diff --git a/apps/common/handle/impl/__init__.py b/apps/common/handle/impl/__init__.py new file mode 100644 index 000000000..ad09602db --- /dev/null +++ b/apps/common/handle/impl/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: __init__.py.py + @date:2023/9/6 10:09 + @desc: +""" diff --git a/apps/common/handle/impl/common_handle.py b/apps/common/handle/impl/common_handle.py new file mode 100644 index 000000000..48692f244 --- /dev/null +++ b/apps/common/handle/impl/common_handle.py @@ -0,0 +1,116 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: tools.py + @date:2024/9/11 16:41 + @desc: +""" +import io +import uuid_utils.compat as uuid +from functools import reduce +from io import BytesIO +from xml.etree.ElementTree import fromstring +from zipfile import ZipFile + +from PIL import Image as PILImage +from openpyxl.drawing.image import Image as openpyxl_Image +from openpyxl.packaging.relationship import get_rels_path, get_dependents +from openpyxl.xml.constants import SHEET_DRAWING_NS, REL_NS, SHEET_MAIN_NS + +from common.handle.base_parse_qa_handle import get_title_row_index_dict, get_row_value +from knowledge.models import File + + +def parse_element(element) -> {}: + data = {} + xdr_namespace = "{%s}" % SHEET_DRAWING_NS + targets = level_order_traversal(element, xdr_namespace + "nvPicPr") + for target in targets: + cNvPr = embed = "" + for child in target: + if child.tag == xdr_namespace + "nvPicPr": + cNvPr = child[0].attrib["name"] + elif child.tag == xdr_namespace + "blipFill": + _rel_embed = "{%s}embed" % REL_NS + embed = child[0].attrib[_rel_embed] + if cNvPr: + data[cNvPr] = embed + return data + + +def 
parse_element_sheet_xml(element) -> []: + data = [] + xdr_namespace = "{%s}" % SHEET_MAIN_NS + targets = level_order_traversal(element, xdr_namespace + "f") + for target in targets: + for child in target: + if child.tag == xdr_namespace + "f": + data.append(child.text) + return data + + +def level_order_traversal(root, flag: str) -> []: + queue = [root] + targets = [] + while queue: + node = queue.pop(0) + children = [child.tag for child in node] + if flag in children: + targets.append(node) + continue + for child in node: + queue.append(child) + return targets + + +def handle_images(deps, archive: ZipFile) -> []: + images = [] + if not PILImage: # Pillow not installed, drop images + return images + for dep in deps: + try: + image_io = archive.read(dep.target) + image = openpyxl_Image(BytesIO(image_io)) + except Exception as e: + print(e) + continue + image.embed = dep.id # 文件rId + image.target = dep.target # 文件地址 + images.append(image) + return images + + +def xlsx_embed_cells_images(buffer) -> {}: + archive = ZipFile(buffer) + # 解析cellImage.xml文件 + deps = get_dependents(archive, get_rels_path("xl/cellimages.xml")) + image_rel = handle_images(deps=deps, archive=archive) + # 工作表及其中图片ID + sheet_list = {} + for item in archive.namelist(): + if not item.startswith('xl/worksheets/sheet'): + continue + key = item.split('/')[-1].split('.')[0].split('sheet')[-1] + sheet_list[key] = parse_element_sheet_xml(fromstring(archive.read(item))) + cell_images_xml = parse_element(fromstring(archive.read("xl/cellimages.xml"))) + cell_images_rel = {} + for image in image_rel: + cell_images_rel[image.embed] = image + for cnv, embed in cell_images_xml.items(): + cell_images_xml[cnv] = cell_images_rel.get(embed) + result = {} + for key, img in cell_images_xml.items(): + image_excel_id_list = [_xl for _xl in + reduce(lambda x, y: [*x, *y], [sheet for sheet_id, sheet in sheet_list.items()], []) if + key in _xl] + if len(image_excel_id_list) > 0: + image_excel_id = image_excel_id_list[-1] 
+ f = archive.open(img.target) + img_byte = io.BytesIO() + im = PILImage.open(f).convert('RGB') + im.save(img_byte, format='JPEG') + image = File(id=uuid.uuid7(), file_name=img.path, meta={'debug': False, 'content': img_byte.getvalue()}) + result['=' + image_excel_id] = image + archive.close() + return result diff --git a/apps/common/handle/impl/csv_split_handle.py b/apps/common/handle/impl/csv_split_handle.py new file mode 100644 index 000000000..3ea690e0e --- /dev/null +++ b/apps/common/handle/impl/csv_split_handle.py @@ -0,0 +1,72 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: csv_parse_qa_handle.py + @date:2024/5/21 14:59 + @desc: +""" +import csv +import io +import os +from typing import List + +from charset_normalizer import detect + +from common.handle.base_split_handle import BaseSplitHandle + + +def post_cell(cell_value): + return cell_value.replace('\n', '
').replace('|', '|') + + +def row_to_md(row): + return '| ' + ' | '.join( + [post_cell(cell) if cell is not None else '' for cell in row]) + ' |\n' + + +class CsvSplitHandle(BaseSplitHandle): + def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image): + buffer = get_buffer(file) + paragraphs = [] + file_name = os.path.basename(file.name) + result = {'name': file_name, 'content': paragraphs} + try: + reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding'])) + try: + title_row_list = reader.__next__() + title_md_content = row_to_md(title_row_list) + title_md_content += '| ' + ' | '.join( + ['---' if cell is not None else '' for cell in title_row_list]) + ' |\n' + except Exception as e: + return result + if len(title_row_list) == 0: + return result + result_item_content = '' + for row in reader: + next_md_content = row_to_md(row) + next_md_content_len = len(next_md_content) + result_item_content_len = len(result_item_content) + if len(result_item_content) == 0: + result_item_content += title_md_content + result_item_content += next_md_content + else: + if result_item_content_len + next_md_content_len < limit: + result_item_content += next_md_content + else: + paragraphs.append({'content': result_item_content, 'title': ''}) + result_item_content = title_md_content + next_md_content + if len(result_item_content) > 0: + paragraphs.append({'content': result_item_content, 'title': ''}) + return result + except Exception as e: + return result + + def get_content(self, file, save_image): + pass + + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".csv"): + return True + return False diff --git a/apps/common/handle/impl/doc_split_handle.py b/apps/common/handle/impl/doc_split_handle.py new file mode 100644 index 000000000..752f726a2 --- /dev/null +++ b/apps/common/handle/impl/doc_split_handle.py @@ -0,0 +1,235 @@ +# coding=utf-8 +""" + @project: maxkb 
+ @Author:虎 + @file: text_split_handle.py + @date:2024/3/27 18:19 + @desc: +""" +import io +import os +import re +import traceback +import uuid_utils.compat as uuid +from functools import reduce +from typing import List + +from docx import Document, ImagePart +from docx.oxml import ns +from docx.table import Table +from docx.text.paragraph import Paragraph + +from common.handle.base_split_handle import BaseSplitHandle +from common.utils.split_model import SplitModel +from knowledge.models import File +from django.utils.translation import gettext_lazy as _ + +default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'), + re.compile('(?<=\\n)(? 0: + for image in _images: + images.append({'image': image, 'get_image_id_handle': get_image_id_handle}) + except Exception as e: + pass + return images + + +def images_to_string(images, doc: Document, images_list, get_image_id): + return "".join( + [item for item in [image_to_mode(image, doc, images_list, get_image_id) for image in images] if + item is not None]) + + +def get_paragraph_element_txt(paragraph_element, doc: Document, images_list, get_image_id): + try: + images = get_paragraph_element_images(paragraph_element, doc, images_list, get_image_id) + if len(images) > 0: + return images_to_string(images, doc, images_list, get_image_id) + elif paragraph_element.text is not None: + return paragraph_element.text + return "" + except Exception as e: + print(e) + return "" + + +def get_paragraph_txt(paragraph: Paragraph, doc: Document, images_list, get_image_id): + try: + return "".join([get_paragraph_element_txt(e, doc, images_list, get_image_id) for e in paragraph._element]) + except Exception as e: + return "" + + +def get_cell_text(cell, doc: Document, images_list, get_image_id): + try: + return "".join( + [get_paragraph_txt(paragraph, doc, images_list, get_image_id) for paragraph in cell.paragraphs]).replace( + "\n", '
') + except Exception as e: + return "" + + +def get_image_id_func(): + image_map = {} + + def get_image_id(image_id): + _v = image_map.get(image_id) + if _v is None: + image_map[image_id] = uuid.uuid7() + return image_map.get(image_id) + return _v + + return get_image_id + + +title_font_list = [ + [36, 100], + [30, 36] +] + + +def get_title_level(paragraph: Paragraph): + try: + if paragraph.style is not None: + psn = paragraph.style.name + if psn.startswith('Heading') or psn.startswith('TOC 标题') or psn.startswith('标题'): + return int(psn.replace("Heading ", '').replace('TOC 标题', '').replace('标题', + '')) + if len(paragraph.runs) == 1: + font_size = paragraph.runs[0].font.size + pt = font_size.pt + if pt >= 30: + for _value, index in zip(title_font_list, range(len(title_font_list))): + if pt >= _value[0] and pt < _value[1]: + return index + 1 + except Exception as e: + pass + return None + + +class DocSplitHandle(BaseSplitHandle): + @staticmethod + def paragraph_to_md(paragraph: Paragraph, doc: Document, images_list, get_image_id): + try: + title_level = get_title_level(paragraph) + if title_level is not None: + title = "".join(["#" for i in range(title_level)]) + " " + paragraph.text + images = reduce(lambda x, y: [*x, *y], + [get_paragraph_element_images(e, doc, images_list, get_image_id) for e in + paragraph._element], + []) + if len(images) > 0: + return title + '\n' + images_to_string(images, doc, images_list, get_image_id) if len( + paragraph.text) > 0 else images_to_string(images, doc, images_list, get_image_id) + return title + + except Exception as e: + traceback.print_exc() + return paragraph.text + return get_paragraph_txt(paragraph, doc, images_list, get_image_id) + + @staticmethod + def table_to_md(table, doc: Document, images_list, get_image_id): + rows = table.rows + + # 创建 Markdown 格式的表格 + md_table = '| ' + ' | '.join( + [get_cell_text(cell, doc, images_list, get_image_id) for cell in rows[0].cells]) + ' |\n' + md_table += '| ' + ' | '.join(['---' for 
i in range(len(rows[0].cells))]) + ' |\n' + for row in rows[1:]: + md_table += '| ' + ' | '.join( + [get_cell_text(cell, doc, images_list, get_image_id) for cell in row.cells]) + ' |\n' + return md_table + + def to_md(self, doc, images_list, get_image_id): + elements = [] + for element in doc.element.body: + tag = str(element.tag) + if tag.endswith('tbl'): + # 处理表格 + table = Table(element, doc) + elements.append(table) + elif tag.endswith('p'): + # 处理段落 + paragraph = Paragraph(element, doc) + elements.append(paragraph) + return "\n".join( + [self.paragraph_to_md(element, doc, images_list, get_image_id) if isinstance(element, + Paragraph) else self.table_to_md( + element, + doc, + images_list, get_image_id) + for element + in elements]) + + def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image): + file_name = os.path.basename(file.name) + try: + image_list = [] + buffer = get_buffer(file) + doc = Document(io.BytesIO(buffer)) + content = self.to_md(doc, image_list, get_image_id_func()) + if len(image_list) > 0: + save_image(image_list) + if pattern_list is not None and len(pattern_list) > 0: + split_model = SplitModel(pattern_list, with_filter, limit) + else: + split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit) + except BaseException as e: + traceback.print_exception(e) + return {'name': file_name, + 'content': []} + return {'name': file_name, + 'content': split_model.parse(content) + } + + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".docx") or file_name.endswith(".doc") or file_name.endswith( + ".DOC") or file_name.endswith(".DOCX"): + return True + return False + + def get_content(self, file, save_image): + try: + image_list = [] + buffer = file.read() + doc = Document(io.BytesIO(buffer)) + content = self.to_md(doc, image_list, get_image_id_func()) + if len(image_list) > 0: + content = content.replace('/api/image/', '/api/file/') + 
save_image(image_list) + return content + except BaseException as e: + traceback.print_exception(e) + return f'{e}' diff --git a/apps/common/handle/impl/html_split_handle.py b/apps/common/handle/impl/html_split_handle.py new file mode 100644 index 000000000..fd4b84fc4 --- /dev/null +++ b/apps/common/handle/impl/html_split_handle.py @@ -0,0 +1,73 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: html_split_handle.py + @date:2024/5/23 10:58 + @desc: +""" +import re +import traceback +from typing import List + +from bs4 import BeautifulSoup +from charset_normalizer import detect +from html2text import html2text + +from common.handle.base_split_handle import BaseSplitHandle +from common.utils.split_model import SplitModel + +default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'), + re.compile('(?<=\\n)(? 0: + charset = charset_list[0] + return charset + return detect(buffer)['encoding'] + + +class HTMLSplitHandle(BaseSplitHandle): + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".html") or file_name.endswith(".HTML"): + return True + return False + + def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image): + buffer = get_buffer(file) + + if pattern_list is not None and len(pattern_list) > 0: + split_model = SplitModel(pattern_list, with_filter, limit) + else: + split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit) + try: + encoding = get_encoding(buffer) + content = buffer.decode(encoding) + content = html2text(content) + except BaseException as e: + return {'name': file.name, + 'content': []} + return {'name': file.name, + 'content': split_model.parse(content) + } + + def get_content(self, file, save_image): + buffer = file.read() + + try: + encoding = get_encoding(buffer) + content = buffer.decode(encoding) + return html2text(content) + except BaseException as e: + traceback.print_exception(e) + return f'{e}' \ No newline at 
end of file diff --git a/apps/common/handle/impl/pdf_split_handle.py b/apps/common/handle/impl/pdf_split_handle.py new file mode 100644 index 000000000..3e011d1d5 --- /dev/null +++ b/apps/common/handle/impl/pdf_split_handle.py @@ -0,0 +1,339 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: text_split_handle.py + @date:2024/3/27 18:19 + @desc: +""" +import logging +import os +import re +import tempfile +import time +import traceback +from typing import List + +import fitz +from langchain_community.document_loaders import PyPDFLoader + +from common.handle.base_split_handle import BaseSplitHandle +from common.utils.split_model import SplitModel +from django.utils.translation import gettext_lazy as _ + +default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'), + re.compile('(?<=\\n)(? 0: + return {'name': file.name, 'content': result} + + # 没有目录的pdf + content = self.handle_pdf_content(file, pdf_document) + + if pattern_list is not None and len(pattern_list) > 0: + split_model = SplitModel(pattern_list, with_filter, limit) + else: + split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit) + except BaseException as e: + max_kb.error(f"File: {file.name}, error: {e}") + return {'name': file.name, + 'content': []} + finally: + pdf_document.close() + # 处理完后可以删除临时文件 + os.remove(temp_file_path) + + return {'name': file.name, + 'content': split_model.parse(content) + } + + @staticmethod + def handle_pdf_content(file, pdf_document): + content = "" + for page_num in range(len(pdf_document)): + start_time = time.time() + page = pdf_document.load_page(page_num) + text = page.get_text() + + if text and text.strip(): # 如果页面中有文本内容 + page_content = text + else: + try: + new_doc = fitz.open() + new_doc.insert_pdf(pdf_document, from_page=page_num, to_page=page_num) + page_num_pdf = tempfile.gettempdir() + f"/{file.name}_{page_num}.pdf" + new_doc.save(page_num_pdf) + new_doc.close() + + loader = PyPDFLoader(page_num_pdf, extract_images=True) 
+ page_content = "\n" + loader.load()[0].page_content + except NotImplementedError as e: + # 文件格式不支持,直接退出 + raise e + except BaseException as e: + # 当页出错继续进行下一页,防止一个页面出错导致整个文件解析失败 + max_kb.error(f"File: {file.name}, Page: {page_num + 1}, error: {e}") + continue + finally: + os.remove(page_num_pdf) + + content += page_content + + # Null characters are not allowed. + content = content.replace('\0', '') + + elapsed_time = time.time() - start_time + max_kb.debug( + f"File: {file.name}, Page: {page_num + 1}, Time : {elapsed_time: .3f}s, content-length: {len(page_content)}") + + return content + + @staticmethod + def handle_toc(doc, limit): + # 找到目录 + toc = doc.get_toc() + if toc is None or len(toc) == 0: + return None + + # 创建存储章节内容的数组 + chapters = [] + + # 遍历目录并按章节提取文本 + for i, entry in enumerate(toc): + level, title, start_page = entry + start_page -= 1 # PyMuPDF 页码从 0 开始,书签页码从 1 开始 + chapter_title = title + # 确定结束页码,如果是最后一个章节则到文档末尾 + if i + 1 < len(toc): + end_page = toc[i + 1][2] - 1 + else: + end_page = doc.page_count - 1 + + # 去掉标题中的符号 + title = PdfSplitHandle.handle_chapter_title(title) + + # 提取该章节的文本内容 + chapter_text = "" + for page_num in range(start_page, end_page + 1): + page = doc.load_page(page_num) # 加载页面 + text = page.get_text("text") + text = re.sub(r'(? -1: + text = text[idx + len(title):] + + if i + 1 < len(toc): + l, next_title, next_start_page = toc[i + 1] + next_title = PdfSplitHandle.handle_chapter_title(next_title) + # print(f'next_title: {next_title}') + idx = text.find(next_title) + if idx > -1: + text = text[:idx] + + chapter_text += text # 提取文本 + + # Null characters are not allowed. 
+ chapter_text = chapter_text.replace('\0', '') + # 限制标题长度 + real_chapter_title = chapter_title[:256] + # 限制章节内容长度 + if 0 < limit < len(chapter_text): + split_text = PdfSplitHandle.split_text(chapter_text, limit) + for text in split_text: + chapters.append({"title": real_chapter_title, "content": text}) + else: + chapters.append({"title": real_chapter_title, "content": chapter_text if chapter_text else real_chapter_title}) + # 保存章节内容和章节标题 + return chapters + + @staticmethod + def handle_links(doc, pattern_list, with_filter, limit): + # 检查文档是否包含内部链接 + if not check_links_in_pdf(doc): + return + # 创建存储章节内容的数组 + chapters = [] + toc_start_page = -1 + page_content = "" + handle_pre_toc = True + # 遍历 PDF 的每一页,查找带有目录链接的页 + for page_num in range(doc.page_count): + page = doc.load_page(page_num) + links = page.get_links() + # 如果目录开始页码未设置,则设置为当前页码 + if len(links) > 0: + toc_start_page = page_num + if toc_start_page < 0: + page_content += page.get_text('text') + # 检查该页是否包含内部链接(即指向文档内部的页面) + for num in range(len(links)): + link = links[num] + if link['kind'] == 1: # 'kind' 为 1 表示内部链接 + # 获取链接目标的页面 + dest_page = link['page'] + rect = link['from'] # 获取链接的矩形区域 + # 如果目录开始页码包括前言部分,则不处理前言部分 + if dest_page < toc_start_page: + handle_pre_toc = False + + # 提取链接区域的文本作为标题 + link_title = page.get_text("text", clip=rect).strip().split("\n")[0].replace('.', '').strip() + # print(f'link_title: {link_title}') + # 提取目标页面内容作为章节开始 + start_page = dest_page + end_page = dest_page + # 下一个link + next_link = links[num + 1] if num + 1 < len(links) else None + next_link_title = None + if next_link is not None and next_link['kind'] == 1: + rect = next_link['from'] + next_link_title = page.get_text("text", clip=rect).strip() \ + .split("\n")[0].replace('.', '').strip() + # print(f'next_link_title: {next_link_title}') + end_page = next_link['page'] + + # 提取章节内容 + chapter_text = "" + for p_num in range(start_page, end_page + 1): + p = doc.load_page(p_num) + text = p.get_text("text") + text = re.sub(r'(? 
-1: + text = text[idx + len(link_title):] + + if next_link_title is not None: + idx = text.find(next_link_title) + if idx > -1: + text = text[:idx] + chapter_text += text + + # Null characters are not allowed. + chapter_text = chapter_text.replace('\0', '') + + # 限制章节内容长度 + if 0 < limit < len(chapter_text): + split_text = PdfSplitHandle.split_text(chapter_text, limit) + for text in split_text: + chapters.append({"title": link_title, "content": text}) + else: + # 保存章节信息 + chapters.append({"title": link_title, "content": chapter_text}) + + # 目录中没有前言部分,手动处理 + if handle_pre_toc: + pre_toc = [] + lines = page_content.strip().split('\n') + try: + for line in lines: + if re.match(r'^前\s*言', line): + pre_toc.append({'title': line, 'content': ''}) + else: + pre_toc[-1]['content'] += line + for i in range(len(pre_toc)): + pre_toc[i]['content'] = re.sub(r'(? 0: + split_model = SplitModel(pattern_list, with_filter, limit) + else: + split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit) + # 插入目录前的部分 + page_content = re.sub(r'(?= length: + # 查找最近的句号 + last_period_index = current_segment.rfind('.') + if last_period_index != -1: + segments.append(current_segment[:last_period_index + 1]) + current_segment = current_segment[last_period_index + 1:] # 更新当前段落 + else: + segments.append(current_segment) + current_segment = "" + + # 处理剩余的部分 + if current_segment: + segments.append(current_segment) + + return segments + + @staticmethod + def handle_chapter_title(title): + title = re.sub(r'[一二三四五六七八九十\s*]、\s*', '', title) + title = re.sub(r'第[一二三四五六七八九十]章\s*', '', title) + return title + + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".pdf") or file_name.endswith(".PDF"): + return True + return False + + def get_content(self, file, save_image): + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + # 将上传的文件保存到临时文件中 + temp_file.write(file.read()) + # 获取临时文件的路径 + temp_file_path = temp_file.name + + 
pdf_document = fitz.open(temp_file_path) + try: + return self.handle_pdf_content(file, pdf_document) + except BaseException as e: + traceback.print_exception(e) + return f'{e}' \ No newline at end of file diff --git a/apps/common/handle/impl/qa/__init__.py b/apps/common/handle/impl/qa/__init__.py new file mode 100644 index 000000000..ad09602db --- /dev/null +++ b/apps/common/handle/impl/qa/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: __init__.py.py + @date:2023/9/6 10:09 + @desc: +""" diff --git a/apps/common/handle/impl/qa/csv_parse_qa_handle.py b/apps/common/handle/impl/qa/csv_parse_qa_handle.py new file mode 100644 index 000000000..75c22cbda --- /dev/null +++ b/apps/common/handle/impl/qa/csv_parse_qa_handle.py @@ -0,0 +1,59 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: csv_parse_qa_handle.py + @date:2024/5/21 14:59 + @desc: +""" +import csv +import io + +from charset_normalizer import detect + +from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value + + +def read_csv_standard(file_path): + data = [] + with open(file_path, 'r') as file: + reader = csv.reader(file) + for row in reader: + data.append(row) + return data + + +class CsvParseQAHandle(BaseParseQAHandle): + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".csv"): + return True + return False + + def handle(self, file, get_buffer, save_image): + buffer = get_buffer(file) + try: + reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding'])) + try: + title_row_list = reader.__next__() + except Exception as e: + return [{'name': file.name, 'paragraphs': []}] + if len(title_row_list) == 0: + return [{'name': file.name, 'paragraphs': []}] + title_row_index_dict = get_title_row_index_dict(title_row_list) + paragraph_list = [] + for row in reader: + content = get_row_value(row, title_row_index_dict, 'content') + if 
content is None: + continue + problem = get_row_value(row, title_row_index_dict, 'problem_list') + problem = str(problem) if problem is not None else '' + problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] + title = get_row_value(row, title_row_index_dict, 'title') + title = str(title) if title is not None else '' + paragraph_list.append({'title': title[0:255], + 'content': content[0:102400], + 'problem_list': problem_list}) + return [{'name': file.name, 'paragraphs': paragraph_list}] + except Exception as e: + return [{'name': file.name, 'paragraphs': []}] diff --git a/apps/common/handle/impl/qa/xls_parse_qa_handle.py b/apps/common/handle/impl/qa/xls_parse_qa_handle.py new file mode 100644 index 000000000..06edb1fb3 --- /dev/null +++ b/apps/common/handle/impl/qa/xls_parse_qa_handle.py @@ -0,0 +1,61 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: xls_parse_qa_handle.py + @date:2024/5/21 14:59 + @desc: +""" + +import xlrd + +from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value + + +def handle_sheet(file_name, sheet): + rows = iter([sheet.row_values(i) for i in range(sheet.nrows)]) + try: + title_row_list = next(rows) + except Exception as e: + return {'name': file_name, 'paragraphs': []} + if len(title_row_list) == 0: + return {'name': file_name, 'paragraphs': []} + title_row_index_dict = get_title_row_index_dict(title_row_list) + paragraph_list = [] + for row in rows: + content = get_row_value(row, title_row_index_dict, 'content') + if content is None: + continue + problem = get_row_value(row, title_row_index_dict, 'problem_list') + problem = str(problem) if problem is not None else '' + problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] + title = get_row_value(row, title_row_index_dict, 'title') + title = str(title) if title is not None else '' + content = str(content) + paragraph_list.append({'title': title[0:255], + 
'content': content[0:102400], + 'problem_list': problem_list}) + return {'name': file_name, 'paragraphs': paragraph_list} + + +class XlsParseQAHandle(BaseParseQAHandle): + def support(self, file, get_buffer): + file_name: str = file.name.lower() + buffer = get_buffer(file) + if file_name.endswith(".xls") and xlrd.inspect_format(content=buffer): + return True + return False + + def handle(self, file, get_buffer, save_image): + buffer = get_buffer(file) + try: + workbook = xlrd.open_workbook(file_contents=buffer) + worksheets = workbook.sheets() + worksheets_size = len(worksheets) + return [row for row in + [handle_sheet(file.name, + sheet) if worksheets_size == 1 and sheet.name == 'Sheet1' else handle_sheet( + sheet.name, sheet) for sheet + in worksheets] if row is not None] + except Exception as e: + return [{'name': file.name, 'paragraphs': []}] diff --git a/apps/common/handle/impl/qa/xlsx_parse_qa_handle.py b/apps/common/handle/impl/qa/xlsx_parse_qa_handle.py new file mode 100644 index 000000000..ecadb5139 --- /dev/null +++ b/apps/common/handle/impl/qa/xlsx_parse_qa_handle.py @@ -0,0 +1,72 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: xlsx_parse_qa_handle.py + @date:2024/5/21 14:59 + @desc: +""" +import io + +import openpyxl + +from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value +from common.handle.impl.common_handle import xlsx_embed_cells_images + + +def handle_sheet(file_name, sheet, image_dict): + rows = sheet.rows + try: + title_row_list = next(rows) + title_row_list = [row.value for row in title_row_list] + except Exception as e: + return {'name': file_name, 'paragraphs': []} + if len(title_row_list) == 0: + return {'name': file_name, 'paragraphs': []} + title_row_index_dict = get_title_row_index_dict(title_row_list) + paragraph_list = [] + for row in rows: + content = get_row_value(row, title_row_index_dict, 'content') + if content is None or content.value is None: + continue + problem 
= get_row_value(row, title_row_index_dict, 'problem_list') + problem = str(problem.value) if problem is not None and problem.value is not None else '' + problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] + title = get_row_value(row, title_row_index_dict, 'title') + title = str(title.value) if title is not None and title.value is not None else '' + content = str(content.value) + image = image_dict.get(content, None) + if image is not None: + content = f'![](/api/image/{image.id})' + paragraph_list.append({'title': title[0:255], + 'content': content[0:102400], + 'problem_list': problem_list}) + return {'name': file_name, 'paragraphs': paragraph_list} + + +class XlsxParseQAHandle(BaseParseQAHandle): + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".xlsx"): + return True + return False + + def handle(self, file, get_buffer, save_image): + buffer = get_buffer(file) + try: + workbook = openpyxl.load_workbook(io.BytesIO(buffer)) + try: + image_dict: dict = xlsx_embed_cells_images(io.BytesIO(buffer)) + save_image([item for item in image_dict.values()]) + except Exception as e: + image_dict = {} + worksheets = workbook.worksheets + worksheets_size = len(worksheets) + return [row for row in + [handle_sheet(file.name, + sheet, + image_dict) if worksheets_size == 1 and sheet.title == 'Sheet1' else handle_sheet( + sheet.title, sheet, image_dict) for sheet + in worksheets] if row is not None] + except Exception as e: + return [{'name': file.name, 'paragraphs': []}] diff --git a/apps/common/handle/impl/qa/zip_parse_qa_handle.py b/apps/common/handle/impl/qa/zip_parse_qa_handle.py new file mode 100644 index 000000000..d00bc14dd --- /dev/null +++ b/apps/common/handle/impl/qa/zip_parse_qa_handle.py @@ -0,0 +1,161 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: text_split_handle.py + @date:2024/3/27 18:19 + @desc: +""" +import io +import os +import re +import uuid_utils.compat as 
uuid +import zipfile +from typing import List +from urllib.parse import urljoin + +from django.db.models import QuerySet + +from common.handle.base_parse_qa_handle import BaseParseQAHandle +from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle +from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle +from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle +from common.utils.common import parse_md_image +from knowledge.models import File +from django.utils.translation import gettext_lazy as _ + + +class FileBufferHandle: + buffer = None + + def get_buffer(self, file): + if self.buffer is None: + self.buffer = file.read() + return self.buffer + + +split_handles = [ + XlsParseQAHandle(), + XlsxParseQAHandle(), + CsvParseQAHandle() +] + + +def file_to_paragraph(file, save_inner_image): + """ + 文件转换为段落列表 + @param file: 文件 + @return: { + name:文件名 + paragraphs:段落列表 + } + """ + get_buffer = FileBufferHandle().get_buffer + for split_handle in split_handles: + if split_handle.support(file, get_buffer): + return split_handle.handle(file, get_buffer, save_inner_image) + raise Exception(_("Unsupported file format")) + + +def is_valid_uuid(uuid_str: str): + """ + 校验字符串是否是uuid + @param uuid_str: 需要校验的字符串 + @return: bool + """ + try: + uuid.UUID(uuid_str) + except ValueError: + return False + return True + + +def get_image_list(result_list: list, zip_files: List[str]): + """ + 获取图片文件列表 + @param result_list: + @param zip_files: + @return: + """ + image_file_list = [] + for result in result_list: + for p in result.get('paragraphs', []): + content: str = p.get('content', '') + image_list = parse_md_image(content) + for image in image_list: + search = re.search("\(.*\)", image) + if search: + new_image_id = str(uuid.uuid7()) + source_image_path = search.group().replace('(', '').replace(')', '') + image_path = urljoin(result.get('name'), '.' 
+ source_image_path if source_image_path.startswith( + '/') else source_image_path) + if not zip_files.__contains__(image_path): + continue + if image_path.startswith('api/file/') or image_path.startswith('api/image/'): + image_id = image_path.replace('api/file/', '').replace('api/image/', '') + if is_valid_uuid(image_id): + image_file_list.append({'source_file': image_path, + 'image_id': image_id}) + else: + image_file_list.append({'source_file': image_path, + 'image_id': new_image_id}) + content = content.replace(source_image_path, f'/api/image/{new_image_id}') + p['content'] = content + else: + image_file_list.append({'source_file': image_path, + 'image_id': new_image_id}) + content = content.replace(source_image_path, f'/api/image/{new_image_id}') + p['content'] = content + + return image_file_list + + +def filter_image_file(result_list: list, image_list): + image_source_file_list = [image.get('source_file') for image in image_list] + return [r for r in result_list if not image_source_file_list.__contains__(r.get('name', ''))] + + +class ZipParseQAHandle(BaseParseQAHandle): + + def handle(self, file, get_buffer, save_image): + buffer = get_buffer(file) + bytes_io = io.BytesIO(buffer) + result = [] + # 打开zip文件 + with zipfile.ZipFile(bytes_io, 'r') as zip_ref: + # 获取压缩包中的文件名列表 + files = zip_ref.namelist() + # 读取压缩包中的文件内容 + for file in files: + # 跳过 macOS 特有的元数据目录和文件 + if file.endswith('/') or file.startswith('__MACOSX'): + continue + with zip_ref.open(file) as f: + # 对文件内容进行处理 + try: + value = file_to_paragraph(f, save_image) + if isinstance(value, list): + result = [*result, *value] + else: + result.append(value) + except Exception: + pass + image_list = get_image_list(result, files) + result = filter_image_file(result, image_list) + image_mode_list = [] + for image in image_list: + with zip_ref.open(image.get('source_file')) as f: + i = File( + id=image.get('image_id'), + file_name=os.path.basename(image.get('source_file')), + meta={'debug': False, 'content': 
f.read()} + ) + image_mode_list.append(i) + save_image(image_mode_list) + return result + + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".zip") or file_name.endswith(".ZIP"): + return True + return False diff --git a/apps/common/handle/impl/response/__init__.py b/apps/common/handle/impl/response/__init__.py new file mode 100644 index 000000000..ad09602db --- /dev/null +++ b/apps/common/handle/impl/response/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: __init__.py.py + @date:2023/9/6 10:09 + @desc: +""" diff --git a/apps/common/handle/impl/response/openai_to_response.py b/apps/common/handle/impl/response/openai_to_response.py new file mode 100644 index 000000000..f2b69384e --- /dev/null +++ b/apps/common/handle/impl/response/openai_to_response.py @@ -0,0 +1,52 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: openai_to_response.py + @date:2024/9/6 16:08 + @desc: +""" +import datetime + +from django.http import JsonResponse +from openai.types import CompletionUsage +from openai.types.chat import ChatCompletionChunk, ChatCompletionMessage, ChatCompletion +from openai.types.chat.chat_completion import Choice as BlockChoice +from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta +from rest_framework import status + +from common.handle.base_to_response import BaseToResponse + + +class OpenaiToResponse(BaseToResponse): + def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens, + other_params: dict = None, + _status=status.HTTP_200_OK): + if other_params is None: + other_params = {} + data = ChatCompletion(id=chat_record_id, choices=[ + BlockChoice(finish_reason='stop', index=0, chat_id=chat_id, + answer_list=other_params.get('answer_list', ""), + message=ChatCompletionMessage(role='assistant', content=content))], + created=datetime.datetime.now().second, model='', object='chat.completion', + 
usage=CompletionUsage(completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=completion_tokens + prompt_tokens) + ).dict() + return JsonResponse(data=data, status=_status) + + def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, + completion_tokens, + prompt_tokens, other_params: dict = None): + if other_params is None: + other_params = {} + chunk = ChatCompletionChunk(id=chat_record_id, model='', object='chat.completion.chunk', + created=datetime.datetime.now().second, choices=[ + Choice(delta=ChoiceDelta(content=content, reasoning_content=other_params.get('reasoning_content', ""), + chat_id=chat_id), + finish_reason='stop' if is_end else None, + index=0)], + usage=CompletionUsage(completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + total_tokens=completion_tokens + prompt_tokens)).json() + return super().format_stream_chunk(chunk) diff --git a/apps/common/handle/impl/response/system_to_response.py b/apps/common/handle/impl/response/system_to_response.py new file mode 100644 index 000000000..a1a530dba --- /dev/null +++ b/apps/common/handle/impl/response/system_to_response.py @@ -0,0 +1,41 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: system_to_response.py + @date:2024/9/6 18:03 + @desc: +""" +import json + +from rest_framework import status + +from common.handle.base_to_response import BaseToResponse +from common.result import result + + +class SystemToResponse(BaseToResponse): + def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, + prompt_tokens, other_params: dict = None, + _status=status.HTTP_200_OK): + if other_params is None: + other_params = {} + return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + 'content': content, 'is_end': is_end, **other_params, + 'completion_tokens': completion_tokens, 'prompt_tokens': prompt_tokens}, + response_status=_status, + code=_status) + + 
def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, + completion_tokens, + prompt_tokens, other_params: dict = None): + if other_params is None: + other_params = {} + chunk = json.dumps({'chat_id': str(chat_id), 'chat_record_id': str(chat_record_id), 'operate': True, + 'content': content, 'node_id': node_id, 'up_node_id_list': up_node_id_list, + 'is_end': is_end, + 'usage': {'completion_tokens': completion_tokens, + 'prompt_tokens': prompt_tokens, + 'total_tokens': completion_tokens + prompt_tokens}, + **other_params}) + return super().format_stream_chunk(chunk) diff --git a/apps/common/handle/impl/table/__init__.py b/apps/common/handle/impl/table/__init__.py new file mode 100644 index 000000000..ad09602db --- /dev/null +++ b/apps/common/handle/impl/table/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: __init__.py.py + @date:2023/9/6 10:09 + @desc: +""" diff --git a/apps/common/handle/impl/table/csv_parse_table_handle.py b/apps/common/handle/impl/table/csv_parse_table_handle.py new file mode 100644 index 000000000..e2fc7ce86 --- /dev/null +++ b/apps/common/handle/impl/table/csv_parse_table_handle.py @@ -0,0 +1,44 @@ +# coding=utf-8 +import logging + +from charset_normalizer import detect + +from common.handle.base_parse_table_handle import BaseParseTableHandle + +max_kb = logging.getLogger("max_kb") + + +class CsvSplitHandle(BaseParseTableHandle): + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".csv"): + return True + return False + + def handle(self, file, get_buffer,save_image): + buffer = get_buffer(file) + try: + content = buffer.decode(detect(buffer)['encoding']) + except BaseException as e: + max_kb.error(f'csv split handle error: {e}') + return [{'name': file.name, 'paragraphs': []}] + + csv_model = content.split('\n') + paragraphs = [] + # 第一行为标题 + title = csv_model[0].split(',') + for row in csv_model[1:]: + if not 
row: + continue + line = '; '.join([f'{key}:{value}' for key, value in zip(title, row.split(','))]) + paragraphs.append({'title': '', 'content': line}) + + return [{'name': file.name, 'paragraphs': paragraphs}] + + def get_content(self, file, save_image): + buffer = file.read() + try: + return buffer.decode(detect(buffer)['encoding']) + except BaseException as e: + max_kb.error(f'csv split handle error: {e}') + return f'error: {e}' \ No newline at end of file diff --git a/apps/common/handle/impl/table/xls_parse_table_handle.py b/apps/common/handle/impl/table/xls_parse_table_handle.py new file mode 100644 index 000000000..897e347e8 --- /dev/null +++ b/apps/common/handle/impl/table/xls_parse_table_handle.py @@ -0,0 +1,94 @@ +# coding=utf-8 +import logging + +import xlrd + +from common.handle.base_parse_table_handle import BaseParseTableHandle + +max_kb = logging.getLogger("max_kb") + + +class XlsSplitHandle(BaseParseTableHandle): + def support(self, file, get_buffer): + file_name: str = file.name.lower() + buffer = get_buffer(file) + if file_name.endswith(".xls") and xlrd.inspect_format(content=buffer): + return True + return False + + def handle(self, file, get_buffer, save_image): + buffer = get_buffer(file) + try: + wb = xlrd.open_workbook(file_contents=buffer, formatting_info=True) + result = [] + sheets = wb.sheets() + for sheet in sheets: + # 获取合并单元格的范围信息 + merged_cells = sheet.merged_cells + print(merged_cells) + data = [] + paragraphs = [] + # 获取第一行作为标题行 + headers = [sheet.cell_value(0, col_idx) for col_idx in range(sheet.ncols)] + # 从第二行开始遍历每一行(跳过标题行) + for row_idx in range(1, sheet.nrows): + row_data = {} + for col_idx in range(sheet.ncols): + cell_value = sheet.cell_value(row_idx, col_idx) + + # 检查是否为空单元格,如果为空检查是否在合并区域中 + if cell_value == "": + # 检查当前单元格是否在合并区域 + for (rlo, rhi, clo, chi) in merged_cells: + if rlo <= row_idx < rhi and clo <= col_idx < chi: + # 使用合并区域的左上角单元格的值 + cell_value = sheet.cell_value(rlo, clo) + break + + # 将标题作为键,单元格的值作为值存入字典 + 
row_data[headers[col_idx]] = cell_value + data.append(row_data) + + for row in data: + row_output = "; ".join([f"{key}: {value}" for key, value in row.items()]) + # print(row_output) + paragraphs.append({'title': '', 'content': row_output}) + + result.append({'name': sheet.name, 'paragraphs': paragraphs}) + + except BaseException as e: + max_kb.error(f'excel split handle error: {e}') + return [{'name': file.name, 'paragraphs': []}] + return result + + def get_content(self, file, save_image): + # 打开 .xls 文件 + try: + workbook = xlrd.open_workbook(file_contents=file.read(), formatting_info=True) + sheets = workbook.sheets() + md_tables = '' + for sheet in sheets: + # 过滤空白的sheet + if sheet.nrows == 0 or sheet.ncols == 0: + continue + + # 获取表头和内容 + headers = sheet.row_values(0) + data = [sheet.row_values(row_idx) for row_idx in range(1, sheet.nrows)] + + # 构建 Markdown 表格 + md_table = '| ' + ' | '.join(headers) + ' |\n' + md_table += '| ' + ' | '.join(['---'] * len(headers)) + ' |\n' + for row in data: + # 将每个单元格中的内容替换换行符为
以保留原始格式 + md_table += '| ' + ' | '.join( + [str(cell) + .replace('\r\n', '
') + .replace('\n', '
') + if cell else '' for cell in row]) + ' |\n' + md_tables += md_table + '\n\n' + + return md_tables + except Exception as e: + max_kb.error(f'excel split handle error: {e}') + return f'error: {e}' diff --git a/apps/common/handle/impl/table/xlsx_parse_table_handle.py b/apps/common/handle/impl/table/xlsx_parse_table_handle.py new file mode 100644 index 000000000..7b50683fa --- /dev/null +++ b/apps/common/handle/impl/table/xlsx_parse_table_handle.py @@ -0,0 +1,118 @@ +# coding=utf-8 +import io +import logging + +from openpyxl import load_workbook + +from common.handle.base_parse_table_handle import BaseParseTableHandle +from common.handle.impl.common_handle import xlsx_embed_cells_images + +max_kb = logging.getLogger("max_kb") + + +class XlsxSplitHandle(BaseParseTableHandle): + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith('.xlsx'): + return True + return False + + def fill_merged_cells(self, sheet, image_dict): + data = [] + + # 获取第一行作为标题行 + headers = [] + for idx, cell in enumerate(sheet[1]): + if cell.value is None: + headers.append(' ' * (idx + 1)) + else: + headers.append(cell.value) + + # 从第二行开始遍历每一行 + for row in sheet.iter_rows(min_row=2, values_only=False): + row_data = {} + for col_idx, cell in enumerate(row): + cell_value = cell.value + + # 如果单元格为空,并且该单元格在合并单元格内,获取合并单元格的值 + if cell_value is None: + for merged_range in sheet.merged_cells.ranges: + if cell.coordinate in merged_range: + cell_value = sheet[merged_range.min_row][merged_range.min_col - 1].value + break + + image = image_dict.get(cell_value, None) + if image is not None: + cell_value = f'![](/api/image/{image.id})' + + # 使用标题作为键,单元格的值作为值存入字典 + row_data[headers[col_idx]] = cell_value + data.append(row_data) + + return data + + def handle(self, file, get_buffer, save_image): + buffer = get_buffer(file) + try: + wb = load_workbook(io.BytesIO(buffer)) + try: + image_dict: dict = xlsx_embed_cells_images(io.BytesIO(buffer)) + save_image([item for item 
in image_dict.values()]) + except Exception as e: + image_dict = {} + result = [] + for sheetname in wb.sheetnames: + paragraphs = [] + ws = wb[sheetname] + data = self.fill_merged_cells(ws, image_dict) + + for row in data: + row_output = "; ".join([f"{key}: {value}" for key, value in row.items()]) + # print(row_output) + paragraphs.append({'title': '', 'content': row_output}) + + result.append({'name': sheetname, 'paragraphs': paragraphs}) + + except BaseException as e: + max_kb.error(f'excel split handle error: {e}') + return [{'name': file.name, 'paragraphs': []}] + return result + + + def get_content(self, file, save_image): + try: + # 加载 Excel 文件 + workbook = load_workbook(file) + try: + image_dict: dict = xlsx_embed_cells_images(file) + if len(image_dict) > 0: + save_image(image_dict.values()) + except Exception as e: + print(f'{e}') + image_dict = {} + md_tables = '' + # 如果未指定 sheet_name,则使用第一个工作表 + for sheetname in workbook.sheetnames: + sheet = workbook[sheetname] if sheetname else workbook.active + rows = self.fill_merged_cells(sheet, image_dict) + if len(rows) == 0: + continue + # 提取表头和内容 + + headers = [f"{key}" for key, value in rows[0].items()] + + # 构建 Markdown 表格 + md_table = '| ' + ' | '.join(headers) + ' |\n' + md_table += '| ' + ' | '.join(['---'] * len(headers)) + ' |\n' + for row in rows: + r = [f'{value}' for key, value in row.items()] + md_table += '| ' + ' | '.join( + [str(cell).replace('\n', '
') if cell is not None else '' for cell in r]) + ' |\n' + + md_tables += md_table + '\n\n' + + md_tables = md_tables.replace('/api/image/', '/api/file/') + return md_tables + except Exception as e: + max_kb.error(f'excel split handle error: {e}') + return f'error: {e}' diff --git a/apps/common/handle/impl/text_split_handle.py b/apps/common/handle/impl/text_split_handle.py new file mode 100644 index 000000000..1a18ab030 --- /dev/null +++ b/apps/common/handle/impl/text_split_handle.py @@ -0,0 +1,60 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: text_split_handle.py + @date:2024/3/27 18:19 + @desc: +""" +import re +import traceback +from typing import List + +from charset_normalizer import detect + +from common.handle.base_split_handle import BaseSplitHandle +from common.utils.split_model import SplitModel + +default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'), + re.compile('(?<=\\n)(? 0.5: + return True + return False + + def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image): + buffer = get_buffer(file) + if pattern_list is not None and len(pattern_list) > 0: + split_model = SplitModel(pattern_list, with_filter, limit) + else: + split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit) + try: + content = buffer.decode(detect(buffer)['encoding']) + except BaseException as e: + return {'name': file.name, + 'content': []} + return {'name': file.name, + 'content': split_model.parse(content) + } + + def get_content(self, file, save_image): + buffer = file.read() + try: + return buffer.decode(detect(buffer)['encoding']) + except BaseException as e: + traceback.print_exception(e) + return f'{e}' \ No newline at end of file diff --git a/apps/common/handle/impl/xls_split_handle.py b/apps/common/handle/impl/xls_split_handle.py new file mode 100644 index 000000000..dbdcc9550 --- /dev/null +++ b/apps/common/handle/impl/xls_split_handle.py @@ -0,0 +1,80 @@ +# coding=utf-8 +""" + 
@project: maxkb + @Author:虎 + @file: xls_parse_qa_handle.py + @date:2024/5/21 14:59 + @desc: +""" +from typing import List + +import xlrd + +from common.handle.base_split_handle import BaseSplitHandle + + +def post_cell(cell_value): + return cell_value.replace('\r\n', '
').replace('\n', '
').replace('|', '|') + + +def row_to_md(row): + return '| ' + ' | '.join( + [post_cell(str(cell)) if cell is not None else '' for cell in row]) + ' |\n' + + +def handle_sheet(file_name, sheet, limit: int): + rows = iter([sheet.row_values(i) for i in range(sheet.nrows)]) + paragraphs = [] + result = {'name': file_name, 'content': paragraphs} + try: + title_row_list = next(rows) + title_md_content = row_to_md(title_row_list) + title_md_content += '| ' + ' | '.join( + ['---' if cell is not None else '' for cell in title_row_list]) + ' |\n' + except Exception as e: + return result + if len(title_row_list) == 0: + return result + result_item_content = '' + for row in rows: + next_md_content = row_to_md(row) + next_md_content_len = len(next_md_content) + result_item_content_len = len(result_item_content) + if len(result_item_content) == 0: + result_item_content += title_md_content + result_item_content += next_md_content + else: + if result_item_content_len + next_md_content_len < limit: + result_item_content += next_md_content + else: + paragraphs.append({'content': result_item_content, 'title': ''}) + result_item_content = title_md_content + next_md_content + if len(result_item_content) > 0: + paragraphs.append({'content': result_item_content, 'title': ''}) + return result + + +class XlsSplitHandle(BaseSplitHandle): + def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image): + buffer = get_buffer(file) + try: + workbook = xlrd.open_workbook(file_contents=buffer) + worksheets = workbook.sheets() + worksheets_size = len(worksheets) + return [row for row in + [handle_sheet(file.name, + sheet, limit) if worksheets_size == 1 and sheet.name == 'Sheet1' else handle_sheet( + sheet.name, sheet, limit) for sheet + in worksheets] if row is not None] + except Exception as e: + return [{'name': file.name, 'content': []}] + + def get_content(self, file, save_image): + pass + + def support(self, file, get_buffer): + file_name: str = 
file.name.lower() + buffer = get_buffer(file) + if file_name.endswith(".xls") and xlrd.inspect_format(content=buffer): + return True + return False diff --git a/apps/common/handle/impl/xlsx_split_handle.py b/apps/common/handle/impl/xlsx_split_handle.py new file mode 100644 index 000000000..abda06d2e --- /dev/null +++ b/apps/common/handle/impl/xlsx_split_handle.py @@ -0,0 +1,92 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: xlsx_parse_qa_handle.py + @date:2024/5/21 14:59 + @desc: +""" +import io +from typing import List + +import openpyxl + +from common.handle.base_split_handle import BaseSplitHandle +from common.handle.impl.common_handle import xlsx_embed_cells_images + + +def post_cell(image_dict, cell_value): + image = image_dict.get(cell_value, None) + if image is not None: + return f'![](/api/image/{image.id})' + return cell_value.replace('\n', '
').replace('|', '|') + + +def row_to_md(row, image_dict): + return '| ' + ' | '.join( + [post_cell(image_dict, str(cell.value if cell.value is not None else '')) if cell is not None else '' for cell + in row]) + ' |\n' + + +def handle_sheet(file_name, sheet, image_dict, limit: int): + rows = sheet.rows + paragraphs = [] + result = {'name': file_name, 'content': paragraphs} + try: + title_row_list = next(rows) + title_md_content = row_to_md(title_row_list, image_dict) + title_md_content += '| ' + ' | '.join( + ['---' if cell is not None else '' for cell in title_row_list]) + ' |\n' + except Exception as e: + return result + if len(title_row_list) == 0: + return result + result_item_content = '' + for row in rows: + next_md_content = row_to_md(row, image_dict) + next_md_content_len = len(next_md_content) + result_item_content_len = len(result_item_content) + if len(result_item_content) == 0: + result_item_content += title_md_content + result_item_content += next_md_content + else: + if result_item_content_len + next_md_content_len < limit: + result_item_content += next_md_content + else: + paragraphs.append({'content': result_item_content, 'title': ''}) + result_item_content = title_md_content + next_md_content + if len(result_item_content) > 0: + paragraphs.append({'content': result_item_content, 'title': ''}) + return result + + +class XlsxSplitHandle(BaseSplitHandle): + def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image): + buffer = get_buffer(file) + try: + workbook = openpyxl.load_workbook(io.BytesIO(buffer)) + try: + image_dict: dict = xlsx_embed_cells_images(io.BytesIO(buffer)) + save_image([item for item in image_dict.values()]) + except Exception as e: + image_dict = {} + worksheets = workbook.worksheets + worksheets_size = len(worksheets) + return [row for row in + [handle_sheet(file.name, + sheet, + image_dict, + limit) if worksheets_size == 1 and sheet.title == 'Sheet1' else handle_sheet( + sheet.title, sheet, 
image_dict, limit) for sheet + in worksheets] if row is not None] + except Exception as e: + return [{'name': file.name, 'content': []}] + + def get_content(self, file, save_image): + pass + + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".xlsx"): + return True + return False diff --git a/apps/common/handle/impl/zip_split_handle.py b/apps/common/handle/impl/zip_split_handle.py new file mode 100644 index 000000000..f5beec141 --- /dev/null +++ b/apps/common/handle/impl/zip_split_handle.py @@ -0,0 +1,164 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: text_split_handle.py + @date:2024/3/27 18:19 + @desc: +""" +import io +import os +import re +import zipfile +from typing import List +from urllib.parse import urljoin + +import uuid_utils.compat as uuid +from charset_normalizer import detect +from django.utils.translation import gettext_lazy as _ + +from common.handle.base_split_handle import BaseSplitHandle +from common.handle.impl.csv_split_handle import CsvSplitHandle +from common.handle.impl.doc_split_handle import DocSplitHandle +from common.handle.impl.html_split_handle import HTMLSplitHandle +from common.handle.impl.pdf_split_handle import PdfSplitHandle +from common.handle.impl.text_split_handle import TextSplitHandle +from common.handle.impl.xls_split_handle import XlsSplitHandle +from common.handle.impl.xlsx_split_handle import XlsxSplitHandle +from common.utils.common import parse_md_image +from knowledge.models import File + + +class FileBufferHandle: + buffer = None + + def get_buffer(self, file): + if self.buffer is None: + self.buffer = file.read() + return self.buffer + + +default_split_handle = TextSplitHandle() +split_handles = [ + HTMLSplitHandle(), + DocSplitHandle(), + PdfSplitHandle(), + XlsxSplitHandle(), + XlsSplitHandle(), + CsvSplitHandle(), + default_split_handle +] + + +def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int, save_inner_image): + 
get_buffer = FileBufferHandle().get_buffer + for split_handle in split_handles: + if split_handle.support(file, get_buffer): + return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_inner_image) + raise Exception(_('Unsupported file format')) + + +def is_valid_uuid(uuid_str: str): + try: + uuid.UUID(uuid_str) + except ValueError: + return False + return True + + +def get_image_list(result_list: list, zip_files: List[str]): + image_file_list = [] + for result in result_list: + for p in result.get('content', []): + content: str = p.get('content', '') + image_list = parse_md_image(content) + for image in image_list: + search = re.search("\(.*\)", image) + if search: + new_image_id = str(uuid.uuid7()) + source_image_path = search.group().replace('(', '').replace(')', '') + source_image_path = source_image_path.strip().split(" ")[0] + image_path = urljoin(result.get('name'), '.' + source_image_path if source_image_path.startswith( + '/') else source_image_path) + if not zip_files.__contains__(image_path): + continue + if image_path.startswith('api/file/') or image_path.startswith('api/image/'): + image_id = image_path.replace('api/file/', '').replace('api/image/', '') + if is_valid_uuid(image_id): + image_file_list.append({'source_file': image_path, + 'image_id': image_id}) + else: + image_file_list.append({'source_file': image_path, + 'image_id': new_image_id}) + content = content.replace(source_image_path, f'/api/image/{new_image_id}') + p['content'] = content + else: + image_file_list.append({'source_file': image_path, + 'image_id': new_image_id}) + content = content.replace(source_image_path, f'/api/image/{new_image_id}') + p['content'] = content + + return image_file_list + + +def get_file_name(file_name): + try: + file_name_code = file_name.encode('cp437') + charset = detect(file_name_code)['encoding'] + return file_name_code.decode(charset) + except Exception as e: + return file_name + + +def filter_image_file(result_list: list, 
image_list): + image_source_file_list = [image.get('source_file') for image in image_list] + return [r for r in result_list if not image_source_file_list.__contains__(r.get('name', ''))] + + +class ZipSplitHandle(BaseSplitHandle): + def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image): + buffer = get_buffer(file) + bytes_io = io.BytesIO(buffer) + result = [] + # 打开zip文件 + with zipfile.ZipFile(bytes_io, 'r') as zip_ref: + # 获取压缩包中的文件名列表 + files = zip_ref.namelist() + # 读取压缩包中的文件内容 + for file in files: + if file.endswith('/') or file.startswith('__MACOSX'): + continue + with zip_ref.open(file) as f: + # 对文件内容进行处理 + try: + # 处理一下文件名 + f.name = get_file_name(f.name) + value = file_to_paragraph(f, pattern_list, with_filter, limit, save_image) + if isinstance(value, list): + result = [*result, *value] + else: + result.append(value) + except Exception: + pass + image_list = get_image_list(result, files) + result = filter_image_file(result, image_list) + image_mode_list = [] + for image in image_list: + with zip_ref.open(image.get('source_file')) as f: + i = File( + id=image.get('image_id'), + image_name=os.path.basename(image.get('source_file')), + meta={'debug': False, 'content': f.read()} # 这里的content是二进制数据 + ) + image_mode_list.append(i) + save_image(image_mode_list) + return result + + def support(self, file, get_buffer): + file_name: str = file.name.lower() + if file_name.endswith(".zip") or file_name.endswith(".ZIP"): + return True + return False + + def get_content(self, file, save_image): + return "" diff --git a/apps/common/utils/common.py b/apps/common/utils/common.py index e4c6c3443..d5d3dac10 100644 --- a/apps/common/utils/common.py +++ b/apps/common/utils/common.py @@ -257,3 +257,9 @@ def post(post_function): return inner + +def parse_md_image(content: str): + matches = re.finditer("!\[.*?\]\(.*?\)", content) + image_list = [match.group() for match in matches] + return image_list + diff --git 
a/apps/knowledge/api/document.py b/apps/knowledge/api/document.py index 7cc248f71..47a7ddb03 100644 --- a/apps/knowledge/api/document.py +++ b/apps/knowledge/api/document.py @@ -32,3 +32,46 @@ class DocumentCreateAPI(APIMixin): @staticmethod def get_response(): return DocumentCreateResponse + + +class DocumentSplitAPI(APIMixin): + @staticmethod + def get_parameters(): + return [ + OpenApiParameter( + name="workspace_id", + description="工作空间id", + type=OpenApiTypes.STR, + location='path', + required=True, + ), + OpenApiParameter( + name="file", + description="文件", + type=OpenApiTypes.BINARY, + location='query', + required=False, + ), + OpenApiParameter( + name="limit", + description="分段长度", + type=OpenApiTypes.INT, + location='query', + required=False, + ), + OpenApiParameter( + name="patterns", + description="分段正则列表", + type=OpenApiTypes.STR, + location='query', + required=False, + ), + OpenApiParameter( + name="with_filter", + description="是否清除特殊字符", + type=OpenApiTypes.BOOL, + location='query', + required=False, + ), + ] + diff --git a/apps/knowledge/serializers/document.py b/apps/knowledge/serializers/document.py index 789f682d5..5870f3fd0 100644 --- a/apps/knowledge/serializers/document.py +++ b/apps/knowledge/serializers/document.py @@ -13,14 +13,34 @@ from rest_framework import serializers from common.db.search import native_search from common.event import ListenerManagement from common.exception.app_exception import AppApiException +from common.handle.impl.csv_split_handle import CsvSplitHandle +from common.handle.impl.doc_split_handle import DocSplitHandle +from common.handle.impl.html_split_handle import HTMLSplitHandle +from common.handle.impl.pdf_split_handle import PdfSplitHandle +from common.handle.impl.text_split_handle import TextSplitHandle +from common.handle.impl.xls_split_handle import XlsSplitHandle +from common.handle.impl.xlsx_split_handle import XlsxSplitHandle +from common.handle.impl.zip_split_handle import ZipSplitHandle from 
from common.utils.common import post, get_file_content  # (continuation of the previous chunk's 'from')
from knowledge.models import Knowledge, Paragraph, Problem, Document, KnowledgeType, ProblemParagraphMapping, State, \
    TaskType, File
from knowledge.serializers.common import ProblemParagraphManage
from knowledge.serializers.paragraph import ParagraphSerializers, ParagraphInstanceSerializer
from knowledge.task import embedding_by_document
from maxkb.const import PROJECT_DIR

default_split_handle = TextSplitHandle()
# Order matters: the first handler whose support() accepts the file wins, and
# the generic text handler stays last as the fallback.
split_handles = [
    HTMLSplitHandle(),
    DocSplitHandle(),
    PdfSplitHandle(),
    XlsxSplitHandle(),
    XlsSplitHandle(),
    CsvSplitHandle(),
    ZipSplitHandle(),
    default_split_handle
]


class DocumentInstanceSerializer(serializers.Serializer):
    name = serializers.CharField(required=True, label=_('document name'), max_length=128, min_length=1)
# @@ -34,6 +54,17 @@ (context elided by the diff)
class DocumentCreateRequest(serializers.Serializer):
    documents = DocumentInstanceSerializer(required=False, many=True)


class DocumentSplitRequest(serializers.Serializer):
    """Validates the multipart payload handed to DocumentSerializers.Split.parse()."""
    file = serializers.ListField(required=True, label=_('file list'))
    limit = serializers.IntegerField(required=False, label=_('limit'))
    patterns = serializers.ListField(
        required=False,
        child=serializers.CharField(required=True, label=_('patterns')),
        label=_('patterns')
    )
    with_filter = serializers.BooleanField(required=False, label=_('Auto Clean'))


class DocumentSerializers(serializers.Serializer):
    class Operate(serializers.Serializer):
        document_id = serializers.UUIDField(required=True, label=_('document id'))
# @@ -177,3 +208,67 @@ (context elided by the diff)

    class Split(serializers.Serializer):
        """Split uploaded files into paragraph previews without persisting documents."""
        workspace_id = serializers.CharField(required=True, label=_('workspace id'))
        knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))

        def is_valid(self, *, raise_exception=True):
            super().is_valid(raise_exception=True)
            # BUGFIX: the uploaded files are never part of this serializer's data
            # (the view passes only workspace_id/knowledge_id), so the old
            # `for f in self.data.get('file')` always crashed on None. Guard here;
            # the real per-file size check runs in parse() where the files arrive.
            files = self.data.get('file') or []
            for f in files:
                if f.size > 1024 * 1024 * 100:
                    raise AppApiException(500, _(
                        'The maximum size of the uploaded file cannot exceed {}MB'
                    ).format(100))

        def parse(self, instance):
            """Validate *instance* and return the flattened paragraph list for all files."""
            self.is_valid(raise_exception=True)
            # BUGFIX: the payload must be passed as data=, not as the serializer
            # `instance` positional — DRF's is_valid() asserts initial_data exists.
            DocumentSplitRequest(data=instance).is_valid(raise_exception=True)

            file_list = instance.get("file")
            # Enforce the 100MB-per-file cap where the files actually arrive.
            for f in file_list:
                if f.size > 1024 * 1024 * 100:
                    raise AppApiException(500, _(
                        'The maximum size of the uploaded file cannot exceed {}MB'
                    ).format(100))
            return reduce(
                lambda x, y: [*x, *y],
                [self.file_to_paragraph(
                    f,
                    instance.get("patterns", None),
                    instance.get("with_filter", None),
                    instance.get("limit", 4096)
                ) for f in file_list],
                []
            )

        def save_image(self, image_list):
            """Persist images that are not already stored, de-duplicated by id."""
            if image_list is not None and len(image_list) > 0:
                exist_ids = {str(i.get('id')) for i in
                             QuerySet(File).filter(id__in=[i.id for i in image_list]).values('id')}
                save_image_list = [image for image in image_list if str(image.id) not in exist_ids]
                # Keep a single File object per id.
                save_image_list = list({img.id: img for img in save_image_list}.values())
                # save image
                for file in save_image_list:
                    file_bytes = file.meta.pop('content')
                    file.workspace_id = self.data.get('workspace_id')
                    file.meta['knowledge_id'] = self.data.get('knowledge_id')
                    file.save(file_bytes)

        def file_to_paragraph(self, file, pattern_list: List, with_filter: bool, limit: int):
            """Split *file* with the first supporting handler; always return a list.

            Falls back to the plain-text handler when nothing matches (same
            behavior as before, with the duplicated wrap logic folded together).
            """
            get_buffer = FileBufferHandle().get_buffer
            handler = next((h for h in split_handles if h.support(file, get_buffer)),
                           default_split_handle)
            result = handler.handle(file, pattern_list, with_filter, limit, get_buffer, self.save_image)
            return result if isinstance(result, list) else [result]


class FileBufferHandle:
    """Caches a file's bytes so several handlers can probe it without re-reading."""
    buffer = None

    def get_buffer(self, file):
        if self.buffer is None:
            self.buffer = file.read()
        return self.buffer

# diff --git a/apps/knowledge/urls.py b/apps/knowledge/urls.py
# index 3715e11f7..eaa95c729 ---  (continues in the next chunk)
# diff --git a/apps/knowledge/urls.py b/apps/knowledge/urls.py (header began in the previous chunk)
# @@ -8,5 +8,6 @@
# NOTE(review): the '<...>' path converters were stripped by the diff extraction
# ('workspace//knowledge'); existing routes are kept byte-for-byte, and only the
# NEW split route is reconstructed from its view signature
# post(self, request, workspace_id, knowledge_id) — verify against the repo.
urlpatterns = [
    path('workspace//knowledge/base', views.KnowledgeBaseView.as_view()),
    path('workspace//knowledge/web', views.KnowledgeWebView.as_view()),
    path('workspace//knowledge/', views.KnowledgeView.Operate.as_view()),
    path('workspace/<str:workspace_id>/knowledge/<str:knowledge_id>/document/split', views.DocumentView.Split.as_view()),
    path('workspace//knowledge//', views.KnowledgeView.Page.as_view()),
]

# diff --git a/apps/knowledge/views/__init__.py b/apps/knowledge/views/__init__.py
# index 4a43fce73..ce309a4e5 @@ -1 +1,2 @@
from .knowledge import *
from .document import *

# diff --git a/apps/knowledge/views/document.py b/apps/knowledge/views/document.py
# new file mode 100644 index 000000000..fbe276a63 --- /dev/null +++ b/apps/knowledge/views/document.py
from django.utils.translation import gettext_lazy as _
from drf_spectacular.utils import extend_schema
from rest_framework.parsers import MultiPartParser
from rest_framework.request import Request
from rest_framework.views import APIView

from common.auth import TokenAuth
from common.auth.authentication import has_permissions
from common.constants.permission_constants import PermissionConstants, CompareConstants
from common.result import result
from knowledge.api.document import DocumentSplitAPI
from knowledge.api.knowledge import KnowledgeTreeReadAPI
from knowledge.serializers.document import DocumentSerializers
from knowledge.serializers.knowledge import KnowledgeSerializer


class DocumentView(APIView):
    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['GET'],
        description=_('Get document'),
        operation_id=_('Get document'),
        parameters=KnowledgeTreeReadAPI.get_parameters(),
        responses=KnowledgeTreeReadAPI.get_response(),
        tags=[_('Knowledge Base')]
    )
    # NOTE(review): PermissionConstants.DOCUMENT_READ is declared with
    # Operate.DELETE in permission_constants.py (copy-paste) — fix at the source.
    @has_permissions(PermissionConstants.DOCUMENT_READ.get_workspace_permission())
    def get(self, request: Request, workspace_id: str):
        # NOTE(review): this delegates to the *knowledge* query serializer and
        # reuses the knowledge-tree schema; it looks copied from KnowledgeView —
        # confirm it is meant to list knowledge rather than documents.
        return result.success(KnowledgeSerializer.Query(
            data={
                'workspace_id': workspace_id,
                'folder_id': request.query_params.get('folder_id'),
                'name': request.query_params.get('name'),
                'desc': request.query_params.get("desc"),
                'user_id': request.query_params.get('user_id')
            }
        ).list())

    class Split(APIView):
        authentication_classes = [TokenAuth]
        parser_classes = [MultiPartParser]

        @extend_schema(
            methods=['POST'],
            description=_('Segmented document'),
            operation_id=_('Segmented document'),
            parameters=DocumentSplitAPI.get_parameters(),
            request=DocumentSplitAPI.get_request(),
            responses=DocumentSplitAPI.get_response(),
            tags=[_('Knowledge Base/Documentation')]
        )
        @has_permissions([
            PermissionConstants.DOCUMENT_CREATE.get_workspace_permission(),
            PermissionConstants.DOCUMENT_EDIT.get_workspace_permission(),
        ])
        def post(self, request: Request, workspace_id: str, knowledge_id: str):
            """Split the uploaded files into paragraph previews (nothing is persisted)."""
            split_data = {'file': request.FILES.getlist('file')}
            request_data = request.data
            # Forward optional multipart fields only when actually supplied.
            patterns = request_data.get('patterns')
            if patterns is not None and len(patterns) > 0:
                split_data['patterns'] = request_data.getlist('patterns')
            if 'limit' in request_data:
                split_data['limit'] = request_data.get('limit')
            if 'with_filter' in request_data:
                split_data['with_filter'] = request_data.get('with_filter')
            return result.success(DocumentSerializers.Split(data={
                'workspace_id': workspace_id,
                'knowledge_id': knowledge_id,
            }).parse(split_data))

# diff --git a/pyproject.toml b/pyproject.toml
# index 3e8dd05c6..543c02f0a @@ -46,6 +46,12 @@
# celery-once = "3.0.1"  beautifulsoup4 = "4.13.4"  html2text = "2025.4.15"  jieba = "0.42.1"
# +openpyxl = "3.1.5"
# +python-docx = "1.1.2"
# +xlrd = "2.0.1"
# +xlwt = "1.3.0"
# +pymupdf = "1.24.9"
# +pypdf = "4.3.1"
# [build-system] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"