diff --git a/apps/common/constants/permission_constants.py b/apps/common/constants/permission_constants.py index 8251ac39d..22f3c3a5f 100644 --- a/apps/common/constants/permission_constants.py +++ b/apps/common/constants/permission_constants.py @@ -166,13 +166,13 @@ class PermissionConstants(Enum): role_list=[RoleConstants.ADMIN, RoleConstants.USER]) MODEL_DELETE = Permission(group=Group.MODEL, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN, RoleConstants.USER]) - TOOL_MODULE_CREATE = Permission(group=Group.TOOL, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN, + TOOL_FOLDER_CREATE = Permission(group=Group.TOOL, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN, RoleConstants.USER]) - TOOL_MODULE_READ = Permission(group=Group.TOOL, operate=Operate.READ, role_list=[RoleConstants.ADMIN, + TOOL_FOLDER_READ = Permission(group=Group.TOOL, operate=Operate.READ, role_list=[RoleConstants.ADMIN, RoleConstants.USER]) - TOOL_MODULE_EDIT = Permission(group=Group.TOOL, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN, + TOOL_FOLDER_EDIT = Permission(group=Group.TOOL, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN, RoleConstants.USER]) - TOOL_MODULE_DELETE = Permission(group=Group.TOOL, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN, + TOOL_FOLDER_DELETE = Permission(group=Group.TOOL, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN, RoleConstants.USER]) TOOL_CREATE = Permission(group=Group.TOOL, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN, @@ -190,20 +190,20 @@ class PermissionConstants(Enum): TOOL_EXPORT = Permission(group=Group.TOOL, operate=Operate.USE, role_list=[RoleConstants.ADMIN, RoleConstants.USER]) - KNOWLEDGE_MODULE_CREATE = Permission(group=Group.KNOWLEDGE, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN, + KNOWLEDGE_FOLDER_CREATE = Permission(group=Group.KNOWLEDGE, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN, RoleConstants.USER]) - KNOWLEDGE_MODULE_READ = Permission(group=Group.KNOWLEDGE, 
operate=Operate.READ, role_list=[RoleConstants.ADMIN, + KNOWLEDGE_FOLDER_READ = Permission(group=Group.KNOWLEDGE, operate=Operate.READ, role_list=[RoleConstants.ADMIN, RoleConstants.USER], resource_permission_group_list=[ ResourcePermissionGroup.VIEW ]) - KNOWLEDGE_MODULE_EDIT = Permission(group=Group.KNOWLEDGE, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN, + KNOWLEDGE_FOLDER_EDIT = Permission(group=Group.KNOWLEDGE, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN, RoleConstants.USER], resource_permission_group_list=[ ResourcePermissionGroup.MANAGE ] ) - KNOWLEDGE_MODULE_DELETE = Permission(group=Group.KNOWLEDGE, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN, + KNOWLEDGE_FOLDER_DELETE = Permission(group=Group.KNOWLEDGE, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN, RoleConstants.USER], resource_permission_group_list=[ ResourcePermissionGroup.MANAGE diff --git a/apps/common/utils/fork.py b/apps/common/utils/fork.py new file mode 100644 index 000000000..4405b9b76 --- /dev/null +++ b/apps/common/utils/fork.py @@ -0,0 +1,178 @@ +import copy +import logging +import re +import traceback +from functools import reduce +from typing import List, Set +from urllib.parse import urljoin, urlparse, ParseResult, urlsplit, urlunparse + +import html2text as ht +import requests +from bs4 import BeautifulSoup + +requests.packages.urllib3.disable_warnings() + + +class ChildLink: + def __init__(self, url, tag): + self.url = url + self.tag = copy.deepcopy(tag) + + +class ForkManage: + def __init__(self, base_url: str, selector_list: List[str]): + self.base_url = base_url + self.selector_list = selector_list + + def fork(self, level: int, exclude_link_url: Set[str], fork_handler): + self.fork_child(ChildLink(self.base_url, None), self.selector_list, level, exclude_link_url, fork_handler) + + @staticmethod + def fork_child(child_link: ChildLink, selector_list: List[str], level: int, exclude_link_url: Set[str], + fork_handler): + if level < 0: + return + else: 
+ child_link.url = remove_fragment(child_link.url) + child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url + if not exclude_link_url.__contains__(child_url): + exclude_link_url.add(child_url) + response = Fork(child_link.url, selector_list).fork() + fork_handler(child_link, response) + for child_link in response.child_link_list: + child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url + if not exclude_link_url.__contains__(child_url): + ForkManage.fork_child(child_link, selector_list, level - 1, exclude_link_url, fork_handler) + + +def remove_fragment(url: str) -> str: + parsed_url = urlparse(url) + modified_url = ParseResult(scheme=parsed_url.scheme, netloc=parsed_url.netloc, path=parsed_url.path, + params=parsed_url.params, query=parsed_url.query, fragment=None) + return urlunparse(modified_url) + + +class Fork: + class Response: + def __init__(self, content: str, child_link_list: List[ChildLink], status, message: str): + self.content = content + self.child_link_list = child_link_list + self.status = status + self.message = message + + @staticmethod + def success(html_content: str, child_link_list: List[ChildLink]): + return Fork.Response(html_content, child_link_list, 200, '') + + @staticmethod + def error(message: str): + return Fork.Response('', [], 500, message) + + def __init__(self, base_fork_url: str, selector_list: List[str]): + base_fork_url = remove_fragment(base_fork_url) + self.base_fork_url = urljoin(base_fork_url if base_fork_url.endswith("/") else base_fork_url + '/', '.') + parsed = urlsplit(base_fork_url) + query = parsed.query + self.base_fork_url = self.base_fork_url[:-1] + if query is not None and len(query) > 0: + self.base_fork_url = self.base_fork_url + '?' 
+ query + self.selector_list = [selector for selector in selector_list if selector is not None and len(selector) > 0] + self.urlparse = urlparse(self.base_fork_url) + self.base_url = ParseResult(scheme=self.urlparse.scheme, netloc=self.urlparse.netloc, path='', params='', + query='', + fragment='').geturl() + + def get_child_link_list(self, bf: BeautifulSoup): + pattern = "^((?!(http:|https:|tel:/|#|mailto:|javascript:))|" + self.base_fork_url + "|/).*" + link_list = bf.find_all(name='a', href=re.compile(pattern)) + result = [ChildLink(link.get('href'), link) if link.get('href').startswith(self.base_url) else ChildLink( + self.base_url + link.get('href'), link) for link in link_list] + result = [row for row in result if row.url.startswith(self.base_fork_url)] + return result + + def get_content_html(self, bf: BeautifulSoup): + if self.selector_list is None or len(self.selector_list) == 0: + return str(bf) + params = reduce(lambda x, y: {**x, **y}, + [{'class_': selector.replace('.', '')} if selector.startswith('.') else + {'id': selector.replace("#", "")} if selector.startswith("#") else {'name': selector} for + selector in + self.selector_list], {}) + f = bf.find_all(**params) + return "\n".join([str(row) for row in f]) + + @staticmethod + def reset_url(tag, field, base_fork_url): + field_value: str = tag[field] + if field_value.startswith("/"): + result = urlparse(base_fork_url) + result_url = ParseResult(scheme=result.scheme, netloc=result.netloc, path=field_value, params='', query='', + fragment='').geturl() + else: + result_url = urljoin( + base_fork_url + '/' + (field_value if field_value.endswith('/') else field_value + '/'), + ".") + result_url = result_url[:-1] if result_url.endswith('/') else result_url + tag[field] = result_url + + def reset_beautiful_soup(self, bf: BeautifulSoup): + reset_config_list = [ + { + 'field': 'href', + }, + { + 'field': 'src', + } + ] + for reset_config in reset_config_list: + field = reset_config.get('field') + tag_list = 
bf.find_all(**{field: re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*')}) + for tag in tag_list: + self.reset_url(tag, field, self.base_fork_url) + return bf + + @staticmethod + def get_beautiful_soup(response): + encoding = response.encoding if response.encoding is not None and response.encoding != 'ISO-8859-1' else response.apparent_encoding + html_content = response.content.decode(encoding) + beautiful_soup = BeautifulSoup(html_content, "html.parser") + meta_list = beautiful_soup.find_all('meta') + charset_list = [meta.attrs.get('charset') for meta in meta_list if + meta.attrs is not None and 'charset' in meta.attrs] + if len(charset_list) > 0: + charset = charset_list[0] + if charset != encoding: + try: + html_content = response.content.decode(charset) + except Exception as e: + logging.getLogger("max_kb").error(f'{e}') + return BeautifulSoup(html_content, "html.parser") + return beautiful_soup + + def fork(self): + try: + + headers = { + 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/99.0.4844.51 Safari/537.36' + } + + logging.getLogger("max_kb").info(f'fork:{self.base_fork_url}') + response = requests.get(self.base_fork_url, verify=False, headers=headers) + if response.status_code != 200: + logging.getLogger("max_kb").error(f"url: {self.base_fork_url} code:{response.status_code}") + return Fork.Response.error(f"url: {self.base_fork_url} code:{response.status_code}") + bf = self.get_beautiful_soup(response) + except Exception as e: + logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') + return Fork.Response.error(str(e)) + bf = self.reset_beautiful_soup(bf) + link_list = self.get_child_link_list(bf) + content = self.get_content_html(bf) + r = ht.html2text(content) + return Fork.Response.success(r, link_list) + + +def handler(base_url, response: Fork.Response): + print(base_url.url, base_url.tag.text if base_url.tag else None, response.content) + +# 
ForkManage('https://bbs.fit2cloud.com/c/de/6', ['.md-content']).fork(3, set(), handler) diff --git a/apps/common/utils/split_model.py b/apps/common/utils/split_model.py new file mode 100644 index 000000000..81b253531 --- /dev/null +++ b/apps/common/utils/split_model.py @@ -0,0 +1,417 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: split_model.py + @date:2023/9/1 15:12 + @desc: +""" +import re +from functools import reduce +from typing import List, Dict + +import jieba + + +def get_level_block(text, level_content_list, level_content_index, cursor): + """ + 从文本中获取块数据 + :param text: 文本 + :param level_content_list: 拆分的title数组 + :param level_content_index: 指定的下标 + :param cursor: 开始的下标位置 + :return: 拆分后的文本数据 + """ + start_content: str = level_content_list[level_content_index].get('content') + next_content = level_content_list[level_content_index + 1].get("content") if level_content_index + 1 < len( + level_content_list) else None + start_index = text.index(start_content, cursor) + end_index = text.index(next_content, start_index + 1) if next_content is not None else len(text) + return text[start_index + len(start_content):end_index], end_index + + +def to_tree_obj(content, state='title'): + """ + 转换为树形对象 + :param content: 文本数据 + :param state: 状态: title block + :return: 转换后的数据 + """ + return {'content': content, 'state': state} + + +def remove_special_symbol(str_source: str): + """ + 删除特殊字符 + :param str_source: 需要删除的文本数据 + :return: 删除后的数据 + """ + return str_source + + +def filter_special_symbol(content: dict): + """ + 过滤文本中的特殊字符 + :param content: 需要过滤的对象 + :return: 过滤后返回 + """ + content['content'] = remove_special_symbol(content['content']) + return content + + +def flat(tree_data_list: List[dict], parent_chain: List[dict], result: List[dict]): + """ + 扁平化树形结构数据 + :param tree_data_list: 树形接口数据 + :param parent_chain: 父级数据 传[] 用于递归存储数据 + :param result: 响应数据 传[] 用于递归存放数据 + :return: result 扁平化后的数据 + """ + if parent_chain is None: + parent_chain = [] + if result 
is None: + result = [] + for tree_data in tree_data_list: + p = parent_chain.copy() + p.append(tree_data) + result.append(to_flat_obj(parent_chain, content=tree_data["content"], state=tree_data["state"])) + children = tree_data.get('children') + if children is not None and len(children) > 0: + flat(children, p, result) + return result + + +def to_paragraph(obj: dict): + """ + 转换为段落 + :param obj: 需要转换的对象 + :return: 段落对象 + """ + content = obj['content'] + return {"keywords": get_keyword(content), + 'parent_chain': list(map(lambda p: p['content'], obj['parent_chain'])), + 'content': ",".join(list(map(lambda p: p['content'], obj['parent_chain']))) + content} + + +def get_keyword(content: str): + """ + 获取content中的关键词 + :param content: 文本 + :return: 关键词数组 + """ + stopwords = [':', '“', '!', '”', '\n', '\\s'] + cutworms = jieba.lcut(content) + return list(set(list(filter(lambda k: (k not in stopwords) | len(k) > 1, cutworms)))) + + +def titles_to_paragraph(list_title: List[dict]): + """ + 将同一父级的title转换为块段落 + :param list_title: 同父级title + :return: 块段落 + """ + if len(list_title) > 0: + content = "\n,".join( + list(map(lambda d: d['content'].strip("\r\n").strip("\n").strip("\\s"), list_title))) + + return {'keywords': '', + 'parent_chain': list( + map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"), list_title[0]['parent_chain'])), + 'content': ",".join(list( + map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"), + list_title[0]['parent_chain']))) + content} + return None + + +def parse_group_key(level_list: List[dict]): + """ + 将同级别同父级的title生成段落,加上本身的段落数据形成新的数据 + :param level_list: title n 级数据 + :return: 根据title生成的数据 + 段落数据 + """ + result = [] + group_data = group_by(list(filter(lambda f: f['state'] == 'title' and len(f['parent_chain']) > 0, level_list)), + key=lambda d: ",".join(list(map(lambda p: p['content'], d['parent_chain'])))) + result += list(map(lambda group_data_key: titles_to_paragraph(group_data[group_data_key]), group_data)) + result 
+= list(map(to_paragraph, list(filter(lambda f: f['state'] == 'block', level_list)))) + return result + + +def to_block_paragraph(tree_data_list: List[dict]): + """ + 转换为块段落对象 + :param tree_data_list: 树数据 + :return: 块段落 + """ + flat_list = flat(tree_data_list, [], []) + level_group_dict: dict = group_by(flat_list, key=lambda f: f['level']) + return list(map(lambda level: parse_group_key(level_group_dict[level]), level_group_dict)) + + +def parse_title_level(text, content_level_pattern: List, index): + if index >= len(content_level_pattern): + return [] + result = parse_level(text, content_level_pattern[index]) + if len(result) == 0 and len(content_level_pattern) > index: + return parse_title_level(text, content_level_pattern, index + 1) + return result + + +def parse_level(text, pattern: str): + """ + 获取正则匹配到的文本 + :param text: 需要匹配的文本 + :param pattern: 正则 + :return: 符合正则的文本 + """ + level_content_list = list(map(to_tree_obj, [r[0:255] for r in re_findall(pattern, text) if r is not None])) + return list(map(filter_special_symbol, level_content_list)) + + +def re_findall(pattern, text): + result = re.findall(pattern, text, flags=0) + return list(filter(lambda r: r is not None and len(r) > 0, reduce(lambda x, y: [*x, *y], list( + map(lambda row: [*(row if isinstance(row, tuple) else [row])], result)), + []))) + + +def to_flat_obj(parent_chain: List[dict], content: str, state: str): + """ + 将树形属性转换为扁平对象 + :param parent_chain: + :param content: + :param state: + :return: + """ + return {'parent_chain': parent_chain, 'level': len(parent_chain), "content": content, 'state': state} + + +def flat_map(array: List[List]): + """ + 将二位数组转为一维数组 + :param array: 二维数组 + :return: 一维数组 + """ + result = [] + for e in array: + result += e + return result + + +def group_by(list_source: List, key): + """ + 將數組分組 + :param list_source: 需要分組的數組 + :param key: 分組函數 + :return: key->[] + """ + result = {} + for e in list_source: + k = key(e) + array = result.get(k) if k in result else [] + 
array.append(e) + result[k] = array + return result + + +def result_tree_to_paragraph(result_tree: List[dict], result, parent_chain, with_filter: bool): + """ + 转换为分段对象 + :param result_tree: 解析文本的树 + :param result: 传[] 用于递归 + :param parent_chain: 传[] 用户递归存储数据 + :param with_filter: 是否过滤block + :return: List[{'problem':'xx','content':'xx'}] + """ + for item in result_tree: + if item.get('state') == 'block': + result.append({'title': " ".join(parent_chain), + 'content': filter_special_char(item.get("content")) if with_filter else item.get("content")}) + children = item.get("children") + if children is not None and len(children) > 0: + result_tree_to_paragraph(children, result, + [*parent_chain, remove_special_symbol(item.get('content'))], with_filter) + return result + + +def post_handler_paragraph(content: str, limit: int): + """ + 根据文本的最大字符分段 + :param content: 需要分段的文本字段 + :param limit: 最大分段字符 + :return: 分段后数据 + """ + result = [] + temp_char, start = '', 0 + while (pos := content.find("\n", start)) != -1: + split, start = content[start:pos + 1], pos + 1 + if len(temp_char + split) > limit: + if len(temp_char) > 4096: + pass + result.append(temp_char) + temp_char = '' + temp_char = temp_char + split + temp_char = temp_char + content[start:] + if len(temp_char) > 0: + if len(temp_char) > 4096: + pass + result.append(temp_char) + + pattern = "[\\S\\s]{1," + str(limit) + '}' + # 如果\n 单段超过限制,则继续拆分 + return reduce(lambda x, y: [*x, *y], map(lambda row: re.findall(pattern, row), result), []) + + +replace_map = { + re.compile('\n+'): '\n', + re.compile(' +'): ' ', + re.compile('#+'): "", + re.compile("\t+"): '' +} + + +def filter_special_char(content: str): + """ + 过滤特殊字段 + :param content: 文本 + :return: 过滤后字段 + """ + items = replace_map.items() + for key, value in items: + content = re.sub(key, value, content) + return content + + +class SplitModel: + + def __init__(self, content_level_pattern, with_filter=True, limit=100000): + self.content_level_pattern = 
content_level_pattern + self.with_filter = with_filter + if limit is None or limit > 100000: + limit = 100000 + if limit < 50: + limit = 50 + self.limit = limit + + def parse_to_tree(self, text: str, index=0): + """ + 解析文本 + :param text: 需要解析的文本 + :param index: 从那个正则开始解析 + :return: 解析后的树形结果数据 + """ + level_content_list = parse_title_level(text, self.content_level_pattern, index) + if len(level_content_list) == 0: + return [to_tree_obj(row, 'block') for row in post_handler_paragraph(text, limit=self.limit)] + if index == 0 and text.lstrip().index(level_content_list[0]["content"].lstrip()) != 0: + level_content_list.insert(0, to_tree_obj("")) + + cursor = 0 + level_title_content_list = [item for item in level_content_list if item.get('state') == 'title'] + for i in range(len(level_title_content_list)): + start_content: str = level_title_content_list[i].get('content') + if cursor < text.index(start_content, cursor): + for row in post_handler_paragraph(text[cursor: text.index(start_content, cursor)], limit=self.limit): + level_content_list.insert(0, to_tree_obj(row, 'block')) + + block, cursor = get_level_block(text, level_title_content_list, i, cursor) + if len(block) == 0: + continue + children = self.parse_to_tree(text=block, index=index + 1) + level_title_content_list[i]['children'] = children + first_child_idx_in_block = block.lstrip().index(children[0]["content"].lstrip()) + if first_child_idx_in_block != 0: + inner_children = self.parse_to_tree(block[:first_child_idx_in_block], index + 1) + level_title_content_list[i]['children'].extend(inner_children) + return level_content_list + + def parse(self, text: str): + """ + 解析文本 + :param text: 文本数据 + :return: 解析后数据 {content:段落数据,keywords:[‘段落关键词’],parent_chain:['段落父级链路']} + """ + text = text.replace('\r\n', '\n') + text = text.replace('\r', '\n') + text = text.replace("\0", '') + result_tree = self.parse_to_tree(text, 0) + result = result_tree_to_paragraph(result_tree, [], [], self.with_filter) + for e in result: + 
if len(e['content']) > 4096: + pass + title_list = list(set([row.get('title') for row in result])) + return [item for item in [self.post_reset_paragraph(row, title_list) for row in result] if + 'content' in item and len(item.get('content').strip()) > 0] + + def post_reset_paragraph(self, paragraph: Dict, title_list: List[str]): + result = self.content_is_null(paragraph, title_list) + result = self.filter_title_special_characters(result) + result = self.sub_title(result) + return result + + @staticmethod + def sub_title(paragraph: Dict): + if 'title' in paragraph: + title = paragraph.get('title') + if len(title) > 255: + return {**paragraph, 'title': title[0:255], 'content': title[255:len(title)] + paragraph.get('content')} + return paragraph + + @staticmethod + def content_is_null(paragraph: Dict, title_list: List[str]): + if 'title' in paragraph: + title = paragraph.get('title') + content = paragraph.get('content') + if (content is None or len(content.strip()) == 0) and (title is not None and len(title) > 0): + find = [t for t in title_list if t.__contains__(title) and t != title] + if find: + return {'title': '', 'content': ''} + return {'title': '', 'content': title} + return paragraph + + @staticmethod + def filter_title_special_characters(paragraph: Dict): + title = paragraph.get('title') if 'title' in paragraph else '' + for title_special_characters in title_special_characters_list: + title = title.replace(title_special_characters, '') + return {**paragraph, + 'title': title} + + +title_special_characters_list = ['#', '\n', '\r', '\\s'] + +default_split_pattern = { + 'md': [re.compile('(?<=^)# .*|(?<=\\n)# .*'), + re.compile('(?<=\\n)(? 
MODULE_DEPTH: - raise serializers.ValidationError(_('Module depth cannot exceed 3 levels')) + if depth + current_depth > FOLDER_DEPTH: + raise serializers.ValidationError(_('Folder depth cannot exceed 3 levels')) def get_max_depth(current_node): @@ -68,10 +68,10 @@ def get_max_depth(current_node): return max_depth -class ModuleSerializer(serializers.Serializer): - id = serializers.CharField(required=True, label=_('module id')) - name = serializers.CharField(required=True, label=_('module name')) - user_id = serializers.CharField(required=True, label=_('module user id')) +class FolderSerializer(serializers.Serializer): + id = serializers.CharField(required=True, label=_('folder id')) + name = serializers.CharField(required=True, label=_('folder name')) + user_id = serializers.CharField(required=True, label=_('folder user id')) workspace_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('workspace id')) parent_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('parent id')) @@ -82,84 +82,84 @@ class ModuleSerializer(serializers.Serializer): def insert(self, instance, with_valid=True): if with_valid: self.is_valid(raise_exception=True) - ModuleCreateRequest(data=instance).is_valid(raise_exception=True) + FolderCreateRequest(data=instance).is_valid(raise_exception=True) workspace_id = self.data.get('workspace_id', 'default') parent_id = instance.get('parent_id', 'root') name = instance.get('name') - Module = get_module_type(self.data.get('source')) - if QuerySet(Module).filter(name=name, workspace_id=workspace_id, parent_id=parent_id).exists(): - raise serializers.ValidationError(_('Module name already exists')) - # Module 不能超过3层 + Folder = get_folder_type(self.data.get('source')) + if QuerySet(Folder).filter(name=name, workspace_id=workspace_id, parent_id=parent_id).exists(): + raise serializers.ValidationError(_('Folder name already exists')) + # Folder 不能超过3层 check_depth(self.data.get('source'), 
parent_id) - module = Module( + folder = Folder( id=uuid.uuid7(), name=instance.get('name'), user_id=self.data.get('user_id'), workspace_id=workspace_id, parent_id=parent_id ) - module.save() - return ModuleSerializer(module).data + folder.save() + return FolderSerializer(folder).data class Operate(serializers.Serializer): - id = serializers.CharField(required=True, label=_('module id')) + id = serializers.CharField(required=True, label=_('folder id')) workspace_id = serializers.CharField(required=True, allow_null=True, allow_blank=True, label=_('workspace id')) source = serializers.CharField(required=True, label=_('source')) @transaction.atomic def edit(self, instance): self.is_valid(raise_exception=True) - Module = get_module_type(self.data.get('source')) + Folder = get_folder_type(self.data.get('source')) current_id = self.data.get('id') - current_node = Module.objects.get(id=current_id) + current_node = Folder.objects.get(id=current_id) if current_node is None: - raise serializers.ValidationError(_('Module does not exist')) + raise serializers.ValidationError(_('Folder does not exist')) edit_field_list = ['name'] edit_dict = {field: instance.get(field) for field in edit_field_list if ( field in instance and instance.get(field) is not None)} - QuerySet(Module).filter(id=current_id).update(**edit_dict) + QuerySet(Folder).filter(id=current_id).update(**edit_dict) # 模块间的移动 parent_id = instance.get('parent_id') if parent_id is not None and current_id != 'root': - # Module 不能超过3层 + # Folder 不能超过3层 current_depth = get_max_depth(current_node) check_depth(self.data.get('source'), parent_id, current_depth) - parent = Module.objects.get(id=parent_id) + parent = Folder.objects.get(id=parent_id) current_node.move_to(parent) return self.one() def one(self): self.is_valid(raise_exception=True) - Module = get_module_type(self.data.get('source')) - module = QuerySet(Module).filter(id=self.data.get('id')).first() - return ModuleSerializer(module).data + Folder = 
get_folder_type(self.data.get('source')) + folder = QuerySet(Folder).filter(id=self.data.get('id')).first() + return FolderSerializer(folder).data def delete(self): self.is_valid(raise_exception=True) if self.data.get('id') == 'root': - raise serializers.ValidationError(_('Cannot delete root module')) - Module = get_module_type(self.data.get('source')) - QuerySet(Module).filter(id=self.data.get('id')).delete() + raise serializers.ValidationError(_('Cannot delete root folder')) + Folder = get_folder_type(self.data.get('source')) + QuerySet(Folder).filter(id=self.data.get('id')).delete() -class ModuleTreeSerializer(serializers.Serializer): +class FolderTreeSerializer(serializers.Serializer): workspace_id = serializers.CharField(required=True, allow_null=True, allow_blank=True, label=_('workspace id')) source = serializers.CharField(required=True, label=_('source')) - def get_module_tree(self, name=None): + def get_folder_tree(self, name=None): self.is_valid(raise_exception=True) - Module = get_module_type(self.data.get('source')) + Folder = get_folder_type(self.data.get('source')) if name is not None: - nodes = Module.objects.filter(Q(workspace_id=self.data.get('workspace_id')) & + nodes = Folder.objects.filter(Q(workspace_id=self.data.get('workspace_id')) & Q(name__contains=name)).get_cached_trees() else: - nodes = Module.objects.filter(Q(workspace_id=self.data.get('workspace_id'))).get_cached_trees() - serializer = ToolModuleTreeSerializer(nodes, many=True) + nodes = Folder.objects.filter(Q(workspace_id=self.data.get('workspace_id'))).get_cached_trees() + serializer = ToolFolderTreeSerializer(nodes, many=True) return serializer.data # 这是可序列化的字典 diff --git a/apps/folders/urls.py b/apps/folders/urls.py new file mode 100644 index 000000000..5ced9cbbd --- /dev/null +++ b/apps/folders/urls.py @@ -0,0 +1,9 @@ +from django.urls import path + +from . 
import views + +app_name = "folder" +urlpatterns = [ + path('workspace/<str:workspace_id>/<str:source>/folder', views.FolderView.as_view()), + path('workspace/<str:workspace_id>/<str:source>/folder/<str:folder_id>', views.FolderView.Operate.as_view()), +] diff --git a/apps/folders/views/__init__.py b/apps/folders/views/__init__.py new file mode 100644 index 000000000..c5e517b86 --- /dev/null +++ b/apps/folders/views/__init__.py @@ -0,0 +1 @@ +from .folder import * diff --git a/apps/modules/views/module.py b/apps/folders/views/folder.py similarity index 56% rename from apps/modules/views/module.py rename to apps/folders/views/folder.py index da67e90c0..9786b1143 100644 --- a/apps/modules/views/module.py +++ b/apps/folders/views/folder.py @@ -7,26 +7,26 @@ from common.auth import TokenAuth from common.auth.authentication import has_permissions from common.constants.permission_constants import Permission, Group, Operate from common.result import result -from modules.api.module import ModuleCreateAPI, ModuleEditAPI, ModuleReadAPI, ModuleTreeReadAPI, ModuleDeleteAPI -from modules.serializers.module import ModuleSerializer, ModuleTreeSerializer +from folders.api.folder import FolderCreateAPI, FolderEditAPI, FolderReadAPI, FolderTreeReadAPI, FolderDeleteAPI +from folders.serializers.folder import FolderSerializer, FolderTreeSerializer -class ModuleView(APIView): +class FolderView(APIView): authentication_classes = [TokenAuth] @extend_schema( methods=['POST'], - description=_('Create module'), - operation_id=_('Create module'), - parameters=ModuleCreateAPI.get_parameters(), - request=ModuleCreateAPI.get_request(), - responses=ModuleCreateAPI.get_response(), - tags=[_('Module')] + description=_('Create folder'), + operation_id=_('Create folder'), + parameters=FolderCreateAPI.get_parameters(), + request=FolderCreateAPI.get_request(), + responses=FolderCreateAPI.get_response(), + tags=[_('Folder')] ) @has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.CREATE,
resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}")) def post(self, request: Request, workspace_id: str, source: str): - return result.success(ModuleSerializer.Create( + return result.success(FolderSerializer.Create( data={'user_id': request.user.id, 'source': source, 'workspace_id': workspace_id} @@ -34,64 +34,64 @@ class ModuleView(APIView): @extend_schema( methods=['GET'], - description=_('Get module tree'), - operation_id=_('Get module tree'), - parameters=ModuleTreeReadAPI.get_parameters(), - responses=ModuleTreeReadAPI.get_response(), - tags=[_('Module')] + description=_('Get folder tree'), + operation_id=_('Get folder tree'), + parameters=FolderTreeReadAPI.get_parameters(), + responses=FolderTreeReadAPI.get_response(), + tags=[_('Folder')] ) @has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.READ, resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}")) def get(self, request: Request, workspace_id: str, source: str): - return result.success(ModuleTreeSerializer( + return result.success(FolderTreeSerializer( data={'workspace_id': workspace_id, 'source': source} - ).get_module_tree(request.query_params.get('name'))) + ).get_folder_tree(request.query_params.get('name'))) class Operate(APIView): authentication_classes = [TokenAuth] @extend_schema( methods=['PUT'], - description=_('Update module'), - operation_id=_('Update module'), - parameters=ModuleEditAPI.get_parameters(), - request=ModuleEditAPI.get_request(), - responses=ModuleEditAPI.get_response(), - tags=[_('Module')] + description=_('Update folder'), + operation_id=_('Update folder'), + parameters=FolderEditAPI.get_parameters(), + request=FolderEditAPI.get_request(), + responses=FolderEditAPI.get_response(), + tags=[_('Folder')] ) @has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.EDIT, resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}")) - def put(self, request: Request, workspace_id: str, source: 
str, module_id: str): - return result.success(ModuleSerializer.Operate( - data={'id': module_id, 'workspace_id': workspace_id, 'source': source} + def put(self, request: Request, workspace_id: str, source: str, folder_id: str): + return result.success(FolderSerializer.Operate( + data={'id': folder_id, 'workspace_id': workspace_id, 'source': source} ).edit(request.data)) @extend_schema( methods=['GET'], - description=_('Get module'), - operation_id=_('Get module'), - parameters=ModuleReadAPI.get_parameters(), - responses=ModuleReadAPI.get_response(), - tags=[_('Module')] + description=_('Get folder'), + operation_id=_('Get folder'), + parameters=FolderReadAPI.get_parameters(), + responses=FolderReadAPI.get_response(), + tags=[_('Folder')] ) @has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.READ, resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}")) - def get(self, request: Request, workspace_id: str, source: str, module_id: str): - return result.success(ModuleSerializer.Operate( - data={'id': module_id, 'workspace_id': workspace_id, 'source': source} + def get(self, request: Request, workspace_id: str, source: str, folder_id: str): + return result.success(FolderSerializer.Operate( + data={'id': folder_id, 'workspace_id': workspace_id, 'source': source} ).one()) @extend_schema( methods=['DELETE'], - description=_('Delete module'), - operation_id=_('Delete module'), - parameters=ModuleDeleteAPI.get_parameters(), - responses=ModuleDeleteAPI.get_response(), - tags=[_('Module')] + description=_('Delete folder'), + operation_id=_('Delete folder'), + parameters=FolderDeleteAPI.get_parameters(), + responses=FolderDeleteAPI.get_response(), + tags=[_('Folder')] ) @has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.DELETE, resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}")) - def delete(self, request: Request, workspace_id: str, source: str, module_id: str): - return 
result.success(ModuleSerializer.Operate( - data={'id': module_id, 'workspace_id': workspace_id, 'source': source} + def delete(self, request: Request, workspace_id: str, source: str, folder_id: str): + return result.success(FolderSerializer.Operate( + data={'id': folder_id, 'workspace_id': workspace_id, 'source': source} ).delete()) diff --git a/apps/knowledge/api/knowledge.py b/apps/knowledge/api/knowledge.py index 3199cfc0a..1caecfac8 100644 --- a/apps/knowledge/api/knowledge.py +++ b/apps/knowledge/api/knowledge.py @@ -67,8 +67,8 @@ class KnowledgeTreeReadAPI(APIMixin): required=True, ), OpenApiParameter( - name="module_id", - description="模块id", + name="folder_id", + description="文件夹id", type=OpenApiTypes.STR, location='query', required=False, diff --git a/apps/knowledge/migrations/0001_initial.py b/apps/knowledge/migrations/0001_initial.py index bec7d6646..d206fc38b 100644 --- a/apps/knowledge/migrations/0001_initial.py +++ b/apps/knowledge/migrations/0001_initial.py @@ -6,12 +6,12 @@ import mptt.fields import uuid_utils.compat from django.db import migrations, models -from knowledge.models import KnowledgeModule +from knowledge.models import KnowledgeFolder def insert_default_data(apps, schema_editor): # 创建一个根模块(没有父节点) - KnowledgeModule.objects.create(id='root', name='根目录', user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab') + KnowledgeFolder.objects.create(id='root', name='根目录', user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab') class Migration(migrations.Migration): @@ -42,7 +42,7 @@ class Migration(migrations.Migration): }, ), migrations.CreateModel( - name='KnowledgeModule', + name='KnowledgeFolder', fields=[ ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), @@ -57,12 +57,12 @@ class Migration(migrations.Migration): ('level', models.PositiveIntegerField(editable=False)), ('parent', mptt.fields.TreeForeignKey(blank=True, null=True, 
on_delete=django.db.models.deletion.CASCADE, - related_name='children', to='knowledge.knowledgemodule')), + related_name='children', to='knowledge.knowledgefolder')), ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', verbose_name='用户id')), ], options={ - 'db_table': 'knowledge_module', + 'db_table': 'knowledge_folder', }, ), migrations.CreateModel( @@ -84,10 +84,10 @@ class Migration(migrations.Migration): ('scope', models.CharField(choices=[('SHARED', '共享'), ('WORKSPACE', '工作空间可用')], default='WORKSPACE', max_length=20, verbose_name='可用范围')), - ('module', + ('folder', models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE, - to='knowledge.knowledgemodule', - verbose_name='模块id')), + to='knowledge.knowledgefolder', + verbose_name='文件夹id')), ('embedding_model', models.ForeignKey(default=knowledge.models.knowledge.default_model, on_delete=django.db.models.deletion.DO_NOTHING, to='models_provider.model', verbose_name='向量模型')), diff --git a/apps/knowledge/models/knowledge.py b/apps/knowledge/models/knowledge.py index 2facd2956..dfa13c931 100644 --- a/apps/knowledge/models/knowledge.py +++ b/apps/knowledge/models/knowledge.py @@ -28,7 +28,7 @@ def default_model(): return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab') -class KnowledgeModule(MPTTModel, AppModelMixin): +class KnowledgeFolder(MPTTModel, AppModelMixin): id = models.CharField(primary_key=True, max_length=64, editable=False, verbose_name="主键id") name = models.CharField(max_length=64, verbose_name="文件夹名称") user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="用户id") @@ -36,7 +36,7 @@ class KnowledgeModule(MPTTModel, AppModelMixin): parent = TreeForeignKey('self', on_delete=models.CASCADE, null=True, blank=True, related_name='children') class Meta: - db_table = "knowledge_module" + db_table = "knowledge_folder" class MPTTMeta: order_insertion_by = ['name'] @@ -51,7 +51,7 @@ class Knowledge(AppModelMixin): user = 
models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户") type = models.IntegerField(verbose_name='类型', choices=KnowledgeType.choices, default=KnowledgeType.BASE) scope = models.CharField(max_length=20, verbose_name='可用范围', choices=KnowledgeScope.choices, default=KnowledgeScope.WORKSPACE) - module = models.ForeignKey(KnowledgeModule, on_delete=models.CASCADE, verbose_name="模块id", default='root') + folder = models.ForeignKey(KnowledgeFolder, on_delete=models.CASCADE, verbose_name="文件夹id", default='root') embedding_model = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型", default=default_model) meta = models.JSONField(verbose_name="元数据", default=dict) diff --git a/apps/knowledge/serializers/knowledge.py b/apps/knowledge/serializers/knowledge.py index cfacccef8..75e6e45f2 100644 --- a/apps/knowledge/serializers/knowledge.py +++ b/apps/knowledge/serializers/knowledge.py @@ -1,25 +1,36 @@ +from typing import Dict + import uuid_utils as uuid +from django.db import transaction +from django.db.models import QuerySet from django.utils.translation import gettext_lazy as _ from rest_framework import serializers +from common.exception.app_exception import AppApiException +from common.utils.common import valid_license from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType class KnowledgeModelSerializer(serializers.ModelSerializer): class Meta: model = Knowledge - fields = ['id', 'name', 'desc', 'meta', 'module_id', 'type', 'workspace_id', 'create_time', 'update_time'] + fields = ['id', 'name', 'desc', 'meta', 'folder_id', 'type', 'workspace_id', 'create_time', 'update_time'] class KnowledgeBaseCreateRequest(serializers.Serializer): name = serializers.CharField(required=True, label=_('knowledge name')) + folder_id = serializers.CharField(required=True, label=_('folder id')) desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('knowledge description')) embedding = 
serializers.CharField(required=True, label=_('knowledge embedding')) + class KnowledgeWebCreateRequest(serializers.Serializer): name = serializers.CharField(required=True, label=_('knowledge name')) + folder_id = serializers.CharField(required=True, label=_('folder id')) desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('knowledge description')) embedding = serializers.CharField(required=True, label=_('knowledge embedding')) + source_url = serializers.CharField(required=True, label=_('source url')) + selector = serializers.CharField(required=True, label=_('knowledge selector')) class KnowledgeSerializer(serializers.Serializer): @@ -27,10 +38,19 @@ class KnowledgeSerializer(serializers.Serializer): user_id = serializers.UUIDField(required=True, label=_('user id')) workspace_id = serializers.CharField(required=True, label=_('workspace id')) - def insert(self, instance, with_valid=True): + @valid_license(model=Knowledge, count=50, + message=_( + 'The community version supports up to 50 knowledge bases. 
If you need more knowledge bases, please contact us (https://fit2cloud.com/).')) + # @post(post_function=post_embedding_dataset) + @transaction.atomic + def save_base(self, instance, with_valid=True): if with_valid: self.is_valid(raise_exception=True) KnowledgeBaseCreateRequest(data=instance).is_valid(raise_exception=True) + if QuerySet(Knowledge).filter(workspace_id=self.data.get('workspace_id'), + name=instance.get('name')).exists(): + raise AppApiException(500, _('Knowledge base name duplicate!')) + knowledge = Knowledge( id=uuid.uuid7(), name=instance.get('name'), @@ -39,13 +59,43 @@ class KnowledgeSerializer(serializers.Serializer): type=instance.get('type', KnowledgeType.BASE), user_id=self.data.get('user_id'), scope=KnowledgeScope.WORKSPACE, - module_id=instance.get('module_id', 'root'), + folder_id=instance.get('folder_id', 'root'), embedding_model_id=instance.get('embedding'), meta=instance.get('meta', {}), ) knowledge.save() return KnowledgeModelSerializer(knowledge).data + def save_web(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + KnowledgeWebCreateRequest(data=instance).is_valid(raise_exception=True) + + if QuerySet(Knowledge).filter(workspace_id=self.data.get('workspace_id'), + name=instance.get('name')).exists(): + raise AppApiException(500, _('Knowledge base name duplicate!')) + + knowledge_id = uuid.uuid7() + knowledge = Knowledge( + id=knowledge_id, + name=instance.get('name'), + desc=instance.get('desc'), + user_id=self.data.get('user_id'), + type=instance.get('type', KnowledgeType.WEB), + scope=KnowledgeScope.WORKSPACE, + folder_id=instance.get('folder_id', 'root'), + embedding_model_id=instance.get('embedding'), + meta={ + 'source_url': instance.get('source_url'), + 'selector': instance.get('selector'), + 'embedding_model_id': instance.get('embedding') + }, + ) + knowledge.save() + # sync_web_knowledge.delay(str(knowledge_id), instance.get('source_url'), instance.get('selector')) + return 
{**KnowledgeModelSerializer(knowledge).data, + 'document_list': []} + class KnowledgeTreeSerializer(serializers.Serializer): def get_knowledge_list(self, param): diff --git a/apps/knowledge/task/__init__.py b/apps/knowledge/task/__init__.py new file mode 100644 index 000000000..4cc48c901 --- /dev/null +++ b/apps/knowledge/task/__init__.py @@ -0,0 +1 @@ +from .sync import * diff --git a/apps/knowledge/task/sync.py b/apps/knowledge/task/sync.py new file mode 100644 index 000000000..99ad15c92 --- /dev/null +++ b/apps/knowledge/task/sync.py @@ -0,0 +1,63 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: sync.py + @date:2024/8/20 21:37 + @desc: +""" + +import logging +import traceback +from typing import List + +from celery_once import QueueOnce + +from common.utils.fork import ForkManage, Fork +from .tools import get_save_handler, get_sync_web_document_handler, get_sync_handler + +from ops import celery_app +from django.utils.translation import gettext_lazy as _ + +max_kb_error = logging.getLogger("max_kb_error") +max_kb = logging.getLogger("max_kb") + + +@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:sync_web_knowledge') +def sync_web_knowledge(knowledge_id: str, url: str, selector: str): + try: + max_kb.info( + _('Start--->Start synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id)) + ForkManage(url, selector.split(" ") if selector is not None else []).fork(2, set(), + get_save_handler(knowledge_id, + selector)) + + max_kb.info(_('End--->End synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id)) + except Exception as e: + max_kb_error.error(_('Synchronize web knowledge base:{knowledge_id} error{error}{traceback}').format( + knowledge_id=knowledge_id, error=str(e), traceback=traceback.format_exc())) + + +@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:sync_replace_web_knowledge') +def sync_replace_web_knowledge(knowledge_id: str, 
url: str, selector: str): + try: + max_kb.info( + _('Start--->Start synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id)) + ForkManage(url, selector.split(" ") if selector is not None else []).fork(2, set(), + get_sync_handler(knowledge_id + )) + max_kb.info(_('End--->End synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id)) + except Exception as e: + max_kb_error.error(_('Synchronize web knowledge base:{knowledge_id} error{error}{traceback}').format( + knowledge_id=knowledge_id, error=str(e), traceback=traceback.format_exc())) + + +@celery_app.task(name='celery:sync_web_document') +def sync_web_document(knowledge_id, source_url_list: List[str], selector: str): + handler = get_sync_web_document_handler(knowledge_id) + for source_url in source_url_list: + try: + result = Fork(base_fork_url=source_url, selector_list=selector.split(' ')).fork() + handler(source_url, selector, result) + except Exception as e: + pass diff --git a/apps/knowledge/task/tools.py b/apps/knowledge/task/tools.py new file mode 100644 index 000000000..6d624c24b --- /dev/null +++ b/apps/knowledge/task/tools.py @@ -0,0 +1,114 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: tools.py + @date:2024/8/20 21:48 + @desc: +""" + +import logging +import re +import traceback + +from django.db.models import QuerySet + +from common.utils.fork import ChildLink, Fork +from common.utils.split_model import get_split_model +from knowledge.models.knowledge import KnowledgeType, Document, DataSet, Status +from django.utils.translation import gettext_lazy as _ + +max_kb_error = logging.getLogger("max_kb_error") +max_kb = logging.getLogger("max_kb") + + +def get_save_handler(dataset_id, selector): + from knowledge.serializers.document_serializers import DocumentSerializers + + def handler(child_link: ChildLink, response: Fork.Response): + if response.status == 200: + try: + document_name = child_link.tag.text if child_link.tag is not None 
and len( + child_link.tag.text.strip()) > 0 else child_link.url + paragraphs = get_split_model('web.md').parse(response.content) + DocumentSerializers.Create(data={'dataset_id': dataset_id}).save( + {'name': document_name, 'paragraphs': paragraphs, + 'meta': {'source_url': child_link.url, 'selector': selector}, + 'type': KnowledgeType.WEB}, with_valid=True) + except Exception as e: + logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') + + return handler + + +def get_sync_handler(dataset_id): + from knowledge.serializers.document_serializers import DocumentSerializers + dataset = QuerySet(DataSet).filter(id=dataset_id).first() + + def handler(child_link: ChildLink, response: Fork.Response): + if response.status == 200: + try: + + document_name = child_link.tag.text if child_link.tag is not None and len( + child_link.tag.text.strip()) > 0 else child_link.url + paragraphs = get_split_model('web.md').parse(response.content) + first = QuerySet(Document).filter(meta__source_url=child_link.url.strip(), + dataset=dataset).first() + if first is not None: + # 如果存在,使用文档同步 + DocumentSerializers.Sync(data={'document_id': first.id}).sync() + else: + # 插入 + DocumentSerializers.Create(data={'dataset_id': dataset.id}).save( + {'name': document_name, 'paragraphs': paragraphs, + 'meta': {'source_url': child_link.url.strip(), 'selector': dataset.meta.get('selector')}, + 'type': KnowledgeType.WEB}, with_valid=True) + except Exception as e: + logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') + + return handler + + +def get_sync_web_document_handler(dataset_id): + from knowledge.serializers.document_serializers import DocumentSerializers + + def handler(source_url: str, selector, response: Fork.Response): + if response.status == 200: + try: + paragraphs = get_split_model('web.md').parse(response.content) + # 插入 + DocumentSerializers.Create(data={'dataset_id': dataset_id}).save( + {'name': source_url[0:128], 'paragraphs': paragraphs, + 'meta':
{'source_url': source_url, 'selector': selector}, + 'type': KnowledgeType.WEB}, with_valid=True) + except Exception as e: + logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') + else: + Document(name=source_url[0:128], + dataset_id=dataset_id, + meta={'source_url': source_url, 'selector': selector}, + type=KnowledgeType.WEB, + char_length=0, + status=Status.error).save() + + return handler + + +def save_problem(dataset_id, document_id, paragraph_id, problem): + from knowledge.serializers.paragraph_serializers import ParagraphSerializers + # print(f"dataset_id: {dataset_id}") + # print(f"document_id: {document_id}") + # print(f"paragraph_id: {paragraph_id}") + # print(f"problem: {problem}") + problem = re.sub(r"^\d+\.\s*", "", problem) + pattern = r"<question>(.*?)</question>" + match = re.search(pattern, problem) + problem = match.group(1) if match else None + if problem is None or len(problem) == 0: + return + try: + ParagraphSerializers.Problem( + data={"dataset_id": dataset_id, 'document_id': document_id, + 'paragraph_id': paragraph_id}).save(instance={"content": problem}, with_valid=True) + except Exception as e: + max_kb_error.error(_('Association problem failed {error}').format(error=str(e))) diff --git a/apps/knowledge/views/knowledge.py b/apps/knowledge/views/knowledge.py index 80cdb40ce..2ea2e312d 100644 --- a/apps/knowledge/views/knowledge.py +++ b/apps/knowledge/views/knowledge.py @@ -16,8 +16,8 @@ class KnowledgeView(APIView): @extend_schema( methods=['GET'], - description=_('Get knowledge by module'), - operation_id=_('Get knowledge by module'), + description=_('Get knowledge by folder'), + operation_id=_('Get knowledge by folder'), parameters=KnowledgeTreeReadAPI.get_parameters(), responses=KnowledgeTreeReadAPI.get_response(), tags=[_('Knowledge Base')] @@ -26,7 +26,7 @@ class KnowledgeView(APIView): def get(self, request: Request, workspace_id: str): return result.success(KnowledgeTreeSerializer( data={'workspace_id': workspace_id} - 
).get_knowledge_list(request.query_params.get('module_id'))) + ).get_knowledge_list(request.query_params.get('folder_id'))) class KnowledgeBaseView(APIView): @@ -45,7 +45,7 @@ class KnowledgeBaseView(APIView): def post(self, request: Request, workspace_id: str): return result.success(KnowledgeSerializer.Create( data={'user_id': request.user.id, 'workspace_id': workspace_id} - ).insert(request.data)) + ).save_base(request.data)) class KnowledgeWebView(APIView): @@ -64,4 +64,4 @@ class KnowledgeWebView(APIView): def post(self, request: Request, workspace_id: str): return result.success(KnowledgeSerializer.Create( data={'user_id': request.user.id, 'workspace_id': workspace_id} - ).insert(request.data)) + ).save_web(request.data)) diff --git a/apps/maxkb/urls.py b/apps/maxkb/urls.py index 8ea7ece3f..d761fa423 100644 --- a/apps/maxkb/urls.py +++ b/apps/maxkb/urls.py @@ -24,7 +24,7 @@ urlpatterns = [ path("api/", include("users.urls")), path("api/", include("tools.urls")), path("api/", include("models_provider.urls")), - path("api/", include("modules.urls")), + path("api/", include("folders.urls")), path("api/", include("knowledge.urls")), ] urlpatterns += [ diff --git a/apps/modules/urls.py b/apps/modules/urls.py deleted file mode 100644 index 286a05a67..000000000 --- a/apps/modules/urls.py +++ /dev/null @@ -1,9 +0,0 @@ -from django.urls import path - -from . 
import views - -app_name = "module" -urlpatterns = [ - path('workspace///module', views.ModuleView.as_view()), - path('workspace///module/', views.ModuleView.Operate.as_view()), -] diff --git a/apps/modules/views/__init__.py b/apps/modules/views/__init__.py deleted file mode 100644 index 267692405..000000000 --- a/apps/modules/views/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .module import * diff --git a/apps/ops/__init__.py b/apps/ops/__init__.py new file mode 100644 index 000000000..a02f13af3 --- /dev/null +++ b/apps/ops/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py.py + @date:2024/8/16 14:47 + @desc: +""" +from .celery import app as celery_app diff --git a/apps/ops/celery/__init__.py b/apps/ops/celery/__init__.py new file mode 100644 index 000000000..6f1c9af8c --- /dev/null +++ b/apps/ops/celery/__init__.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- + +import os + +from celery import Celery +from celery.schedules import crontab +from kombu import Exchange, Queue +from maxkb import settings +from .heartbeat import * + +# set the default Django settings module for the 'celery' program. +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'maxkb.settings') + +app = Celery('MaxKB') + +configs = {k: v for k, v in settings.__dict__.items() if k.startswith('CELERY')} +configs['worker_concurrency'] = 5 +# Using a string here means the worker will not have to +# pickle the object when using Windows.
+# app.config_from_object('django.conf:settings', namespace='CELERY') + +configs["task_queues"] = [ + Queue("celery", Exchange("celery"), routing_key="celery"), + Queue("model", Exchange("model"), routing_key="model") +] +app.namespace = 'CELERY' +app.conf.update( + {key.replace('CELERY_', '') if key.replace('CELERY_', '').lower() == key.replace('CELERY_', + '') else key: configs.get( + key) for + key + in configs.keys()}) +app.autodiscover_tasks(lambda: [app_config.split('.')[0] for app_config in settings.INSTALLED_APPS]) diff --git a/apps/ops/celery/const.py b/apps/ops/celery/const.py new file mode 100644 index 000000000..2f887023f --- /dev/null +++ b/apps/ops/celery/const.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# + +CELERY_LOG_MAGIC_MARK = b'\x00\x00\x00\x00\x00' \ No newline at end of file diff --git a/apps/ops/celery/decorator.py b/apps/ops/celery/decorator.py new file mode 100644 index 000000000..317a7f7ae --- /dev/null +++ b/apps/ops/celery/decorator.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# +from functools import wraps + +_need_registered_period_tasks = [] +_after_app_ready_start_tasks = [] +_after_app_shutdown_clean_periodic_tasks = [] + + +def add_register_period_task(task): + _need_registered_period_tasks.append(task) + + +def get_register_period_tasks(): + return _need_registered_period_tasks + + +def add_after_app_shutdown_clean_task(name): + _after_app_shutdown_clean_periodic_tasks.append(name) + + +def get_after_app_shutdown_clean_tasks(): + return _after_app_shutdown_clean_periodic_tasks + + +def add_after_app_ready_task(name): + _after_app_ready_start_tasks.append(name) + + +def get_after_app_ready_tasks(): + return _after_app_ready_start_tasks + + +def register_as_period_task( + crontab=None, interval=None, name=None, + args=(), kwargs=None, + description=''): + """ + Warning: Task must have not any args and kwargs + :param crontab: "* * * * *" + :param interval: 60*60*60 + :param args: () + :param kwargs: {} + :param description: " + 
:param name: "" + :return: + """ + if crontab is None and interval is None: + raise SyntaxError("Must set crontab or interval one") + + def decorate(func): + if crontab is None and interval is None: + raise SyntaxError("Interval and crontab must set one") + + # Because when this decorator run, the task was not created, + # So we can't use func.name + task = '{func.__module__}.{func.__name__}'.format(func=func) + _name = name if name else task + add_register_period_task({ + _name: { + 'task': task, + 'interval': interval, + 'crontab': crontab, + 'args': args, + 'kwargs': kwargs if kwargs else {}, + 'description': description + } + }) + + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + return decorate + + +def after_app_ready_start(func): + # Because when this decorator run, the task was not created, + # So we can't use func.name + name = '{func.__module__}.{func.__name__}'.format(func=func) + if name not in _after_app_ready_start_tasks: + add_after_app_ready_task(name) + + @wraps(func) + def decorate(*args, **kwargs): + return func(*args, **kwargs) + + return decorate + + +def after_app_shutdown_clean_periodic(func): + # Because when this decorator run, the task was not created, + # So we can't use func.name + name = '{func.__module__}.{func.__name__}'.format(func=func) + if name not in _after_app_shutdown_clean_periodic_tasks: + add_after_app_shutdown_clean_task(name) + + @wraps(func) + def decorate(*args, **kwargs): + return func(*args, **kwargs) + + return decorate diff --git a/apps/ops/celery/heartbeat.py b/apps/ops/celery/heartbeat.py new file mode 100644 index 000000000..339a3c60a --- /dev/null +++ b/apps/ops/celery/heartbeat.py @@ -0,0 +1,25 @@ +from pathlib import Path + +from celery.signals import heartbeat_sent, worker_ready, worker_shutdown + + +@heartbeat_sent.connect +def heartbeat(sender, **kwargs): + worker_name = sender.eventer.hostname.split('@')[0] + heartbeat_path = 
Path('/tmp/worker_heartbeat_{}'.format(worker_name)) + heartbeat_path.touch() + + +@worker_ready.connect +def worker_ready(sender, **kwargs): + worker_name = sender.hostname.split('@')[0] + ready_path = Path('/tmp/worker_ready_{}'.format(worker_name)) + ready_path.touch() + + +@worker_shutdown.connect +def worker_shutdown(sender, **kwargs): + worker_name = sender.hostname.split('@')[0] + for signal in ['ready', 'heartbeat']: + path = Path('/tmp/worker_{}_{}'.format(signal, worker_name)) + path.unlink(missing_ok=True) diff --git a/apps/ops/celery/logger.py b/apps/ops/celery/logger.py new file mode 100644 index 000000000..1b2843c2b --- /dev/null +++ b/apps/ops/celery/logger.py @@ -0,0 +1,225 @@ +from logging import StreamHandler +from threading import get_ident + +from celery import current_task +from celery.signals import task_prerun, task_postrun +from django.conf import settings +from kombu import Connection, Exchange, Queue, Producer +from kombu.mixins import ConsumerMixin + +from .utils import get_celery_task_log_path +from .const import CELERY_LOG_MAGIC_MARK + +routing_key = 'celery_log' +celery_log_exchange = Exchange('celery_log_exchange', type='direct') +celery_log_queue = [Queue('celery_log', celery_log_exchange, routing_key=routing_key)] + + +class CeleryLoggerConsumer(ConsumerMixin): + def __init__(self): + self.connection = Connection(settings.CELERY_LOG_BROKER_URL) + + def get_consumers(self, Consumer, channel): + return [Consumer(queues=celery_log_queue, + accept=['pickle', 'json'], + callbacks=[self.process_task]) + ] + + def handle_task_start(self, task_id, message): + pass + + def handle_task_end(self, task_id, message): + pass + + def handle_task_log(self, task_id, msg, message): + pass + + def process_task(self, body, message): + action = body.get('action') + task_id = body.get('task_id') + msg = body.get('msg') + if action == CeleryLoggerProducer.ACTION_TASK_LOG: + self.handle_task_log(task_id, msg, message) + elif action == 
# NOTE(review): this span was reconstructed from a line-mangled git diff.  The
# fragment that preceded it (the tail of a consumer-side dispatch method whose
# definition starts before this chunk) is cut off and intentionally omitted.

# --- apps/ops/celery/logger.py (continuation).  Names such as Connection,
# --- Producer, StreamHandler, task_prerun, task_postrun, current_task,
# --- get_ident, settings, celery_log_exchange, routing_key and
# --- CELERY_LOG_MAGIC_MARK come from the unseen head of that file. ---


class CeleryLoggerProducer:
    """Publish per-task log events onto the Celery log exchange.

    Payloads are JSON dicts carrying a ``task_id``, an ``action`` code and,
    for log lines, the message text.  The matching consumer dispatches on
    the ``action`` field.
    """

    # Event codes understood by the consumer side.
    ACTION_TASK_START, ACTION_TASK_LOG, ACTION_TASK_END = range(3)

    def __init__(self):
        # presumably an AMQP/Redis broker URL — TODO confirm against settings
        self.connection = Connection(settings.CELERY_LOG_BROKER_URL)

    @property
    def producer(self):
        # A fresh Producer per access; the underlying Connection is reused.
        return Producer(self.connection)

    def publish(self, payload):
        """Send *payload* to the log exchange, declaring it if needed."""
        self.producer.publish(
            payload, serializer='json', exchange=celery_log_exchange,
            declare=[celery_log_exchange], routing_key=routing_key
        )

    def log(self, task_id, msg):
        """Publish one log line for *task_id*."""
        payload = {'task_id': task_id, 'msg': msg, 'action': self.ACTION_TASK_LOG}
        return self.publish(payload)

    def read(self):
        # File-like no-op so this object can stand in for a stream.
        pass

    def flush(self):
        # File-like no-op; publishing is not buffered here.
        pass

    def task_end(self, task_id):
        """Publish the end-of-task marker for *task_id*."""
        payload = {'task_id': task_id, 'action': self.ACTION_TASK_END}
        return self.publish(payload)

    def task_start(self, task_id):
        """Publish the start-of-task marker for *task_id*."""
        payload = {'task_id': task_id, 'action': self.ACTION_TASK_START}
        return self.publish(payload)


class CeleryTaskLoggerHandler(StreamHandler):
    """logging handler base that routes records to per-Celery-task sinks.

    Subclasses override ``write_task_log`` / ``handle_task_start`` /
    ``handle_task_end``; this base wires those hooks to the Celery
    ``task_prerun`` / ``task_postrun`` signals.
    """

    terminator = '\r\n'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        task_prerun.connect(self.on_task_start)
        task_postrun.connect(self.on_start_end)

    @staticmethod
    def get_current_task_id():
        """Return the root id of the task running in this worker, or None."""
        if not current_task:
            return
        task_id = current_task.request.root_id
        return task_id

    def on_task_start(self, sender, task_id, **kwargs):
        return self.handle_task_start(task_id)

    def on_start_end(self, sender, task_id, **kwargs):
        return self.handle_task_end(task_id)

    def after_task_publish(self, sender, body, **kwargs):
        pass

    def emit(self, record):
        task_id = self.get_current_task_id()
        if not task_id:
            # Not inside a Celery task: nothing to attribute the record to.
            return
        try:
            self.write_task_log(task_id, record)
            self.flush()
        except Exception:
            self.handleError(record)

    def write_task_log(self, task_id, msg):
        # Hook: persist one LogRecord for *task_id* (msg is the record).
        pass

    def handle_task_start(self, task_id):
        # Hook: open/prepare the sink for *task_id*.
        pass

    def handle_task_end(self, task_id):
        # Hook: finalize the sink for *task_id*.
        pass


class CeleryThreadingLoggerHandler(CeleryTaskLoggerHandler):
    """Variant keyed by worker-thread ident instead of task id."""

    @staticmethod
    def get_current_thread_id():
        return str(get_ident())

    def emit(self, record):
        thread_id = self.get_current_thread_id()
        try:
            self.write_thread_task_log(thread_id, record)
            self.flush()
        except ValueError:
            # Raised by subclasses when no sink is registered for this thread.
            self.handleError(record)

    def write_thread_task_log(self, thread_id, msg):
        # Hook: persist one LogRecord for *thread_id*.
        pass

    def handle_task_start(self, task_id):
        pass

    def handle_task_end(self, task_id):
        pass

    def handleError(self, record) -> None:
        # Swallow handler errors: task logging must never break the task.
        pass


class CeleryTaskMQLoggerHandler(CeleryTaskLoggerHandler):
    """Forward task log records to the message queue via CeleryLoggerProducer."""

    def __init__(self):
        self.producer = CeleryLoggerProducer()
        super().__init__(stream=None)

    def write_task_log(self, task_id, record):
        msg = self.format(record)
        self.producer.log(task_id, msg)

    def flush(self):
        self.producer.flush()


class CeleryTaskFileHandler(CeleryTaskLoggerHandler):
    """Write each task's log records to its own on-disk file."""

    def __init__(self, *args, **kwargs):
        # Current task's log file; opened in handle_task_start.
        self.f = None
        super().__init__(*args, **kwargs)

    def emit(self, record):
        msg = self.format(record)
        if not self.f or self.f.closed:
            # No open sink (task not started, or already finished): drop.
            return
        self.f.write(msg)
        self.f.write(self.terminator)
        self.flush()

    def flush(self):
        # FIX: also guard against a closed file — the logging machinery may
        # call flush() after handle_task_end closed the file, which previously
        # raised "ValueError: I/O operation on closed file".
        if self.f and not self.f.closed:
            self.f.flush()

    def handle_task_start(self, task_id):
        log_path = get_celery_task_log_path(task_id)
        self.f = open(log_path, 'a')

    def handle_task_end(self, task_id):
        self.f and self.f.close()


class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler):
    """Per-thread task log files, for workers that run tasks in threads."""

    def __init__(self, *args, **kwargs):
        # thread ident -> open binary file handle
        self.thread_id_fd_mapper = {}
        # task id -> thread ident that is running it
        self.task_id_thread_id_mapper = {}
        super().__init__(*args, **kwargs)

    def write_thread_task_log(self, thread_id, record):
        f = self.thread_id_fd_mapper.get(thread_id, None)
        if not f:
            # Caught by CeleryThreadingLoggerHandler.emit.
            raise ValueError('Not found thread task file')
        msg = self.format(record)
        f.write(msg.encode())
        f.write(self.terminator.encode())
        f.flush()

    def flush(self):
        for f in self.thread_id_fd_mapper.values():
            f.flush()

    def handle_task_start(self, task_id):
        # FIX: removed leftover debug print('handle_task_start').
        log_path = get_celery_task_log_path(task_id)
        thread_id = self.get_current_thread_id()
        self.task_id_thread_id_mapper[task_id] = thread_id
        f = open(log_path, 'ab')
        self.thread_id_fd_mapper[thread_id] = f

    def handle_task_end(self, task_id):
        # FIX: removed leftover debug print('handle_task_end').
        ident_id = self.task_id_thread_id_mapper.get(task_id, '')
        f = self.thread_id_fd_mapper.pop(ident_id, None)
        if f and not f.closed:
            # End-of-log sentinel for readers tailing the file.
            f.write(CELERY_LOG_MAGIC_MARK)
            f.close()
        self.task_id_thread_id_mapper.pop(task_id, None)


# --- apps/ops/celery/signal_handler.py (new file in the original diff) ---
# -*- coding: utf-8 -*-
#
import logging
import os

from celery import subtask
from celery.signals import (
    worker_ready, worker_shutdown, after_setup_logger, task_revoked, task_prerun
)
from django.core.cache import cache
from django_celery_beat.models import PeriodicTask

from .decorator import get_after_app_ready_tasks, get_after_app_shutdown_clean_tasks
from .logger import CeleryThreadTaskFileHandler

logger = logging.getLogger(__file__)


def safe_str(x):
    """Identity passthrough (FIX: was a lambda assignment, PEP 8 E731)."""
    return x


@worker_ready.connect
def on_app_ready(sender=None, headers=None, **kwargs):
    """Kick off registered startup tasks once when the first worker is ready.

    The short-lived cache key de-duplicates the signal across concurrently
    starting workers.
    """
    if cache.get("CELERY_APP_READY", 0) == 1:
        return
    cache.set("CELERY_APP_READY", 1, 10)
    tasks = get_after_app_ready_tasks()
    logger.debug("Work ready signal recv")
    logger.debug("Start need start task: [{}]".format(", ".join(tasks)))
    for task in tasks:
        # Respect a disabled periodic-task entry of the same name.
        periodic_task = PeriodicTask.objects.filter(task=task).first()
        if periodic_task and not periodic_task.enabled:
            logger.debug("Periodic task [{}] is disabled!".format(task))
            continue
        subtask(task).delay()


def delete_files(directory):
    """Delete every regular file directly inside *directory* (non-recursive)."""
    if os.path.isdir(directory):
        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            if os.path.isfile(file_path):
                os.remove(file_path)


@worker_shutdown.connect
def after_app_shutdown_periodic_tasks(sender=None, **kwargs):
    """Remove registered cleanup periodic tasks once on worker shutdown."""
    if cache.get("CELERY_APP_SHUTDOWN", 0) == 1:
        return
    cache.set("CELERY_APP_SHUTDOWN", 1, 10)
    tasks = get_after_app_shutdown_clean_tasks()
    logger.debug("Worker shutdown signal recv")
    logger.debug("Clean period tasks: [{}]".format(', '.join(tasks)))
    PeriodicTask.objects.filter(name__in=tasks).delete()


@after_setup_logger.connect
def add_celery_logger_handler(sender=None, logger=None, loglevel=None, format=None, **kwargs):
    """Attach the per-thread task file handler to Celery's logger."""
    if not logger:
        return
    task_handler = CeleryThreadTaskFileHandler()
    task_handler.setLevel(loglevel)
    formatter = logging.Formatter(format)
    task_handler.setFormatter(formatter)
    logger.addHandler(task_handler)


@task_revoked.connect
def on_task_revoked(request, terminated, signum, expired, **kwargs):
    # FIX: leftover debug print replaced with structured logging.
    logger.debug('task_revoked terminated=%s', terminated)


@task_prerun.connect
def on_taskaa_start(sender, task_id, **kwargs):
    # NOTE(review): name looks like a typo of on_task_start; kept as-is since
    # it is only invoked through the signal, never by name.
    pass
    # sender.update_state(state='REVOKED',
    #     meta={'exc_type': 'Exception', 'exc': 'Exception',
    #           'message': 'pause task', 'exc_message': ''})


# --- apps/ops/celery/utils.py (new file in the original diff) ---
import uuid

from django.conf import settings
from django_celery_beat.models import (
    PeriodicTasks
)

from maxkb.const import PROJECT_DIR

logger = logging.getLogger(__file__)


def disable_celery_periodic_task(task_name):
    """Disable the periodic task named *task_name* and notify the scheduler."""
    from django_celery_beat.models import PeriodicTask
    PeriodicTask.objects.filter(name=task_name).update(enabled=False)
    PeriodicTasks.update_changed()


def delete_celery_periodic_task(task_name):
    """Delete the periodic task named *task_name* and notify the scheduler."""
    from django_celery_beat.models import PeriodicTask
    PeriodicTask.objects.filter(name=task_name).delete()
    PeriodicTasks.update_changed()


def get_celery_periodic_task(task_name):
    """Return the PeriodicTask named *task_name*, or None."""
    from django_celery_beat.models import PeriodicTask
    task = PeriodicTask.objects.filter(name=task_name).first()
    return task


def make_dirs(name, mode=0o755, exist_ok=False):
    """Create directories; default mode is 0o755."""
    return os.makedirs(name, mode=mode, exist_ok=exist_ok)


def get_task_log_path(base_path, task_id, level=2):
    """Return a sharded log path under *base_path* for a UUID *task_id*.

    Non-UUID ids fall back to a static caution file.  *level* controls how
    many single-character shard directories are used.
    """
    task_id = str(task_id)
    try:
        uuid.UUID(task_id)
    # FIX: was a bare except; uuid.UUID raises ValueError on a bad string.
    except ValueError:
        return os.path.join(PROJECT_DIR, 'data', 'caution.txt')

    rel_path = os.path.join(*task_id[:level], task_id + '.log')
    path = os.path.join(base_path, rel_path)
    make_dirs(os.path.dirname(path), exist_ok=True)
    return path


def get_celery_task_log_path(task_id):
    """Return the log path for *task_id* under the configured Celery log dir."""
    return get_task_log_path(settings.CELERY_LOG_DIR, task_id)


def get_celery_status():
    """Return True when the expected number of Celery workers answer ping."""
    from . import app
    i = app.control.inspect()
    ping_data = i.ping() or {}
    active_nodes = [k for k, v in ping_data.items() if v.get('ok') == 'pong']
    active_queue_worker = set([n.split('@')[0] for n in active_nodes if n])
    # Expected Celery worker count: 2
    if len(active_queue_worker) < 2:
        # FIX: leftover print replaced with the module logger.
        logger.warning("Not all celery worker worked")
        return False
    return True

# --- the original diff continues past this chunk with modification hunks for
# --- apps/tools/* (module -> folder rename) and pyproject.toml additions. ---
mptt.fields import uuid_utils.compat from django.db import migrations, models -from tools.models import ToolModule +from tools.models import ToolFolder def insert_default_data(apps, schema_editor): # 创建一个根模块(没有父节点) - ToolModule.objects.create(id='root', name='根目录', user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab') + ToolFolder.objects.create(id='root', name='根目录', user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab') class Migration(migrations.Migration): @@ -22,7 +22,7 @@ class Migration(migrations.Migration): operations = [ migrations.CreateModel( - name='ToolModule', + name='ToolFolder', fields=[ ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), @@ -37,12 +37,12 @@ class Migration(migrations.Migration): ('level', models.PositiveIntegerField(editable=False)), ('parent', mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, - related_name='children', to='tools.toolmodule')), + related_name='children', to='tools.toolfolder')), ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', verbose_name='用户id')), ], options={ - 'db_table': 'tool_module', + 'db_table': 'tool_folder', }, ), migrations.RunPython(insert_default_data), @@ -72,9 +72,9 @@ class Migration(migrations.Migration): ('init_params', models.CharField(max_length=102400, null=True, verbose_name='初始化参数')), ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', verbose_name='用户id')), - ('module', - models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE, to='tools.toolmodule', - verbose_name='模块id')), + ('folder', + models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE, to='tools.toolfolder', + verbose_name='文件夹id')), ], options={ 'db_table': 'tool', diff --git a/apps/tools/models/tool.py b/apps/tools/models/tool.py index f8c98a5a0..5cd204040 100644 --- 
a/apps/tools/models/tool.py +++ b/apps/tools/models/tool.py @@ -7,7 +7,7 @@ from common.mixins.app_model_mixin import AppModelMixin from users.models import User -class ToolModule(MPTTModel, AppModelMixin): +class ToolFolder(MPTTModel, AppModelMixin): id = models.CharField(primary_key=True, max_length=64, editable=False, verbose_name="主键id") name = models.CharField(max_length=64, verbose_name="文件夹名称") user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="用户id") @@ -15,7 +15,7 @@ class ToolModule(MPTTModel, AppModelMixin): parent = TreeForeignKey('self', on_delete=models.CASCADE, null=True, blank=True, related_name='children') class Meta: - db_table = "tool_module" + db_table = "tool_folder" class MPTTMeta: order_insertion_by = ['name'] @@ -46,7 +46,7 @@ class Tool(AppModelMixin): tool_type = models.CharField(max_length=20, verbose_name='工具类型', choices=ToolType.choices, default=ToolType.CUSTOM, db_index=True) template_id = models.UUIDField(max_length=128, verbose_name="模版id", null=True, default=None) - module = models.ForeignKey(ToolModule, on_delete=models.CASCADE, verbose_name="模块id", default='root') + folder = models.ForeignKey(ToolFolder, on_delete=models.CASCADE, verbose_name="文件夹id", default='root') workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True) init_params = models.CharField(max_length=102400, verbose_name="初始化参数", null=True) diff --git a/apps/tools/serializers/tool.py b/apps/tools/serializers/tool.py index a7c1bcde7..e8ee72901 100644 --- a/apps/tools/serializers/tool.py +++ b/apps/tools/serializers/tool.py @@ -17,7 +17,7 @@ from common.exception.app_exception import AppApiException from common.result import result from common.utils.tool_code import ToolExecutor from maxkb.const import CONFIG -from tools.models import Tool, ToolScope, ToolModule +from tools.models import Tool, ToolScope, ToolFolder tool_executor = ToolExecutor(CONFIG.get('SANDBOX')) @@ -37,11 +37,11 @@ 
ALLOWED_CLASSES = { class RestrictedUnpickler(pickle.Unpickler): - def find_class(self, module, name): - if (module, name) in ALLOWED_CLASSES: - return super().find_class(module, name) + def find_class(self, folder, name): + if (folder, name) in ALLOWED_CLASSES: + return super().find_class(folder, name) raise pickle.UnpicklingError("global '%s.%s' is forbidden" % - (module, name)) + (folder, name)) def encryption(message: str): @@ -75,7 +75,7 @@ class ToolModelSerializer(serializers.ModelSerializer): class Meta: model = Tool fields = ['id', 'name', 'icon', 'desc', 'code', 'input_field_list', 'init_field_list', 'init_params', - 'scope', 'is_active', 'user_id', 'template_id', 'workspace_id', 'module_id', 'tool_type', + 'scope', 'is_active', 'user_id', 'template_id', 'workspace_id', 'folder_id', 'tool_type', 'create_time', 'update_time'] @@ -126,7 +126,7 @@ class ToolCreateRequest(serializers.Serializer): is_active = serializers.BooleanField(required=False, label=_('Is active')) - module_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root') + folder_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root') class ToolEditRequest(serializers.Serializer): @@ -146,7 +146,7 @@ class ToolEditRequest(serializers.Serializer): is_active = serializers.BooleanField(required=False, label=_('Is active')) - module_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root') + folder_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root') class DebugField(serializers.Serializer): @@ -180,7 +180,7 @@ class ToolSerializer(serializers.Serializer): input_field_list=instance.get('input_field_list', []), init_field_list=instance.get('init_field_list', []), scope=ToolScope.WORKSPACE, - module_id=instance.get('module_id', 'root'), + folder_id=instance.get('folder_id', 'root'), is_active=False) tool.save() return ToolModelSerializer(tool).data @@ 
-321,42 +321,42 @@ class ToolSerializer(serializers.Serializer): class ToolTreeSerializer(serializers.Serializer): workspace_id = serializers.CharField(required=True, label=_('workspace id')) - def get_tools(self, module_id): + def get_tools(self, folder_id): self.is_valid(raise_exception=True) - if not module_id: - module_id = 'root' - root = ToolModule.objects.filter(id=module_id).first() + if not folder_id: + folder_id = 'root' + root = ToolFolder.objects.filter(id=folder_id).first() if not root: - raise serializers.ValidationError(_('Module not found')) + raise serializers.ValidationError(_('Folder not found')) # 使用MPTT的get_descendants()方法获取所有相关节点 - all_modules = root.get_descendants(include_self=True) + all_folders = root.get_descendants(include_self=True) - tools = QuerySet(Tool).filter(workspace_id=self.data.get('workspace_id'), module_id__in=all_modules) + tools = QuerySet(Tool).filter(workspace_id=self.data.get('workspace_id'), folder_id__in=all_folders) return ToolModelSerializer(tools, many=True).data class Query(serializers.Serializer): workspace_id = serializers.CharField(required=True, label=_('workspace id')) - module_id = serializers.CharField(required=True, label=_('module id')) + folder_id = serializers.CharField(required=True, label=_('folder id')) name = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('tool name')) tool_type = serializers.CharField(required=True, label=_('tool type')) def page(self, current_page: int, page_size: int): self.is_valid(raise_exception=True) - module_id = self.data.get('module_id', 'root') - root = ToolModule.objects.filter(id=module_id).first() + folder_id = self.data.get('folder_id', 'root') + root = ToolFolder.objects.filter(id=folder_id).first() if not root: - raise serializers.ValidationError(_('Module not found')) + raise serializers.ValidationError(_('Folder not found')) # 使用MPTT的get_descendants()方法获取所有相关节点 - all_modules = root.get_descendants(include_self=True) + all_folders = 
root.get_descendants(include_self=True) if self.data.get('name'): tools = QuerySet(Tool).filter(Q(workspace_id=self.data.get('workspace_id')) & - Q(module_id__in=all_modules) & + Q(folder_id__in=all_folders) & Q(tool_type=self.data.get('tool_type')) & Q(name__contains=self.data.get('name'))) else: tools = QuerySet(Tool).filter(Q(workspace_id=self.data.get('workspace_id')) & - Q(module_id__in=all_modules) & + Q(folder_id__in=all_folders) & Q(tool_type=self.data.get('tool_type'))) return page_search(current_page, page_size, tools, lambda record: ToolModelSerializer(record).data) diff --git a/apps/tools/serializers/tool_module.py b/apps/tools/serializers/tool_folder.py similarity index 58% rename from apps/tools/serializers/tool_module.py rename to apps/tools/serializers/tool_folder.py index 24151ecdf..e8dfb6dc5 100644 --- a/apps/tools/serializers/tool_module.py +++ b/apps/tools/serializers/tool_folder.py @@ -2,15 +2,15 @@ from rest_framework import serializers -from tools.models import ToolModule +from tools.models import ToolFolder -class ToolModuleTreeSerializer(serializers.ModelSerializer): +class ToolFolderTreeSerializer(serializers.ModelSerializer): children = serializers.SerializerMethodField() class Meta: - model = ToolModule + model = ToolFolder fields = ['id', 'name', 'user_id', 'workspace_id', 'parent_id', 'children'] def get_children(self, obj): - return ToolModuleTreeSerializer(obj.get_children(), many=True).data + return ToolFolderTreeSerializer(obj.get_children(), many=True).data diff --git a/apps/tools/views/tool.py b/apps/tools/views/tool.py index e717fc136..20aee12d9 100644 --- a/apps/tools/views/tool.py +++ b/apps/tools/views/tool.py @@ -33,8 +33,8 @@ class ToolView(APIView): @extend_schema( methods=['GET'], - description=_('Get tool by module'), - operation_id=_('Get tool by module'), + description=_('Get tool by folder'), + operation_id=_('Get tool by folder'), parameters=ToolTreeReadAPI.get_parameters(), responses=ToolTreeReadAPI.get_response(), 
tags=[_('Tool')] @@ -43,7 +43,7 @@ class ToolView(APIView): def get(self, request: Request, workspace_id: str): return result.success(ToolTreeSerializer( data={'workspace_id': workspace_id} - ).get_tools(request.query_params.get('module_id'))) + ).get_tools(request.query_params.get('folder_id'))) class Debug(APIView): authentication_classes = [TokenAuth] @@ -124,7 +124,7 @@ class ToolView(APIView): return result.success(ToolTreeSerializer.Query( data={ 'workspace_id': workspace_id, - 'module_id': request.query_params.get('module_id'), + 'folder_id': request.query_params.get('folder_id'), 'name': request.query_params.get('name'), 'tool_type': request.query_params.get('tool_type'), } diff --git a/pyproject.toml b/pyproject.toml index cf4e320f0..d9183af1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,12 @@ cffi = "1.17.1" pysilk = "0.0.1" sentence-transformers = "4.1.0" websockets = "15.0.1" +celery = { extras = ["sqlalchemy"], version = "5.5.2" } +django-celery-beat = "2.8.0" +celery-once = "3.0.1" +beautifulsoup4 = "4.13.4" +html2text = "2025.4.15" +jieba = "0.42.1" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api"