feat: implement web knowledge synchronization with ForkManage and related handlers

This commit is contained in:
CaptainB 2025-04-28 10:39:23 +08:00
parent 99fd32897c
commit ee37d7c320
39 changed files with 1564 additions and 192 deletions

View File

@ -166,13 +166,13 @@ class PermissionConstants(Enum):
role_list=[RoleConstants.ADMIN, RoleConstants.USER])
MODEL_DELETE = Permission(group=Group.MODEL, operate=Operate.DELETE,
role_list=[RoleConstants.ADMIN, RoleConstants.USER])
TOOL_MODULE_CREATE = Permission(group=Group.TOOL, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
TOOL_FOLDER_CREATE = Permission(group=Group.TOOL, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
RoleConstants.USER])
TOOL_MODULE_READ = Permission(group=Group.TOOL, operate=Operate.READ, role_list=[RoleConstants.ADMIN,
TOOL_FOLDER_READ = Permission(group=Group.TOOL, operate=Operate.READ, role_list=[RoleConstants.ADMIN,
RoleConstants.USER])
TOOL_MODULE_EDIT = Permission(group=Group.TOOL, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN,
TOOL_FOLDER_EDIT = Permission(group=Group.TOOL, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN,
RoleConstants.USER])
TOOL_MODULE_DELETE = Permission(group=Group.TOOL, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN,
TOOL_FOLDER_DELETE = Permission(group=Group.TOOL, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN,
RoleConstants.USER])
TOOL_CREATE = Permission(group=Group.TOOL, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
@ -190,20 +190,20 @@ class PermissionConstants(Enum):
TOOL_EXPORT = Permission(group=Group.TOOL, operate=Operate.USE, role_list=[RoleConstants.ADMIN,
RoleConstants.USER])
KNOWLEDGE_MODULE_CREATE = Permission(group=Group.KNOWLEDGE, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
KNOWLEDGE_FOLDER_CREATE = Permission(group=Group.KNOWLEDGE, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
RoleConstants.USER])
KNOWLEDGE_MODULE_READ = Permission(group=Group.KNOWLEDGE, operate=Operate.READ, role_list=[RoleConstants.ADMIN,
KNOWLEDGE_FOLDER_READ = Permission(group=Group.KNOWLEDGE, operate=Operate.READ, role_list=[RoleConstants.ADMIN,
RoleConstants.USER],
resource_permission_group_list=[
ResourcePermissionGroup.VIEW
])
KNOWLEDGE_MODULE_EDIT = Permission(group=Group.KNOWLEDGE, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN,
KNOWLEDGE_FOLDER_EDIT = Permission(group=Group.KNOWLEDGE, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN,
RoleConstants.USER],
resource_permission_group_list=[
ResourcePermissionGroup.MANAGE
]
)
KNOWLEDGE_MODULE_DELETE = Permission(group=Group.KNOWLEDGE, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN,
KNOWLEDGE_FOLDER_DELETE = Permission(group=Group.KNOWLEDGE, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN,
RoleConstants.USER],
resource_permission_group_list=[
ResourcePermissionGroup.MANAGE

178
apps/common/utils/fork.py Normal file
View File

@ -0,0 +1,178 @@
import copy
import logging
import re
import traceback
from functools import reduce
from typing import List, Set
from urllib.parse import urljoin, urlparse, ParseResult, urlsplit, urlunparse
import html2text as ht
import requests
from bs4 import BeautifulSoup
requests.packages.urllib3.disable_warnings()
class ChildLink:
    """A link discovered on a crawled page: its URL plus a snapshot of the <a> tag."""

    def __init__(self, url, tag):
        # Deep-copy the tag so later mutation of the soup cannot alter this record.
        self.url, self.tag = url, copy.deepcopy(tag)
class ForkManage:
def __init__(self, base_url: str, selector_list: List[str]):
self.base_url = base_url
self.selector_list = selector_list
def fork(self, level: int, exclude_link_url: Set[str], fork_handler):
self.fork_child(ChildLink(self.base_url, None), self.selector_list, level, exclude_link_url, fork_handler)
@staticmethod
def fork_child(child_link: ChildLink, selector_list: List[str], level: int, exclude_link_url: Set[str],
fork_handler):
if level < 0:
return
else:
child_link.url = remove_fragment(child_link.url)
child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url
if not exclude_link_url.__contains__(child_url):
exclude_link_url.add(child_url)
response = Fork(child_link.url, selector_list).fork()
fork_handler(child_link, response)
for child_link in response.child_link_list:
child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url
if not exclude_link_url.__contains__(child_url):
ForkManage.fork_child(child_link, selector_list, level - 1, exclude_link_url, fork_handler)
def remove_fragment(url: str) -> str:
    """Return *url* with its ``#fragment`` removed; every other component is kept."""
    parsed = urlparse(url)
    return urlunparse(parsed._replace(fragment=''))
class Fork:
    """Fetches a single page and extracts its content (as markdown) plus the
    same-site child links found on it."""

    class Response:
        """Result of one fetch: extracted content, discovered child links,
        a status code and an error message."""

        def __init__(self, content: str, child_link_list: List[ChildLink], status, message: str):
            self.content = content
            self.child_link_list = child_link_list
            # 200 on success, 500 on any failure (see success()/error()).
            self.status = status
            self.message = message

        @staticmethod
        def success(html_content: str, child_link_list: List[ChildLink]):
            """Build a successful response carrying content and child links."""
            return Fork.Response(html_content, child_link_list, 200, '')

        @staticmethod
        def error(message: str):
            """Build a failed response carrying only an error message."""
            return Fork.Response('', [], 500, message)

    def __init__(self, base_fork_url: str, selector_list: List[str]):
        """
        :param base_fork_url: page URL to fetch; its fragment is stripped, the path is
            normalised by resolving against '.', the trailing '/' is removed, and the
            original query string is re-appended.
        :param selector_list: selectors ('.class', '#id' or a tag name) that pick the
            content elements; empty/None entries are dropped.
        """
        base_fork_url = remove_fragment(base_fork_url)
        # Resolve against '.' to collapse any relative path segments.
        self.base_fork_url = urljoin(base_fork_url if base_fork_url.endswith("/") else base_fork_url + '/', '.')
        parsed = urlsplit(base_fork_url)
        query = parsed.query
        # urljoin left a trailing '/': drop it, then restore the query string.
        self.base_fork_url = self.base_fork_url[:-1]
        if query is not None and len(query) > 0:
            self.base_fork_url = self.base_fork_url + '?' + query
        self.selector_list = [selector for selector in selector_list if selector is not None and len(selector) > 0]
        self.urlparse = urlparse(self.base_fork_url)
        # scheme://netloc with empty path — used to absolutise root-relative hrefs.
        self.base_url = ParseResult(scheme=self.urlparse.scheme, netloc=self.urlparse.netloc, path='', params='',
                                    query='',
                                    fragment='').geturl()

    def get_child_link_list(self, bf: BeautifulSoup):
        """Collect <a> links that stay under ``base_fork_url``.

        fork() calls reset_beautiful_soup() first, so hrefs here are expected to
        already be absolute; root-relative ones are still prefixed with base_url.
        """
        # Accept hrefs that are not special-scheme links (http:/https:/tel:/#/
        # mailto:/javascript:), or that start with base_fork_url, or with '/'.
        pattern = "^((?!(http:|https:|tel:/|#|mailto:|javascript:))|" + self.base_fork_url + "|/).*"
        link_list = bf.find_all(name='a', href=re.compile(pattern))
        result = [ChildLink(link.get('href'), link) if link.get('href').startswith(self.base_url) else ChildLink(
            self.base_url + link.get('href'), link) for link in link_list]
        # Keep only links inside the fork root so the crawl never leaves the section.
        result = [row for row in result if row.url.startswith(self.base_fork_url)]
        return result

    def get_content_html(self, bf: BeautifulSoup):
        """Return the HTML of the selected elements, or the whole document when
        no selector is configured."""
        if self.selector_list is None or len(self.selector_list) == 0:
            return str(bf)
        # Fold the selectors into one find_all() kwargs dict:
        # '.x' -> class_='x', '#x' -> id='x', anything else -> tag name.
        params = reduce(lambda x, y: {**x, **y},
                        [{'class_': selector.replace('.', '')} if selector.startswith('.') else
                         {'id': selector.replace("#", "")} if selector.startswith("#") else {'name': selector} for
                         selector in
                         self.selector_list], {})
        f = bf.find_all(**params)
        return "\n".join([str(row) for row in f])

    @staticmethod
    def reset_url(tag, field, base_fork_url):
        """Rewrite a relative *field* (href/src) on *tag* to an absolute URL
        without a trailing '/'."""
        field_value: str = tag[field]
        if field_value.startswith("/"):
            # Root-relative: splice the path onto scheme://netloc of the fork URL.
            result = urlparse(base_fork_url)
            result_url = ParseResult(scheme=result.scheme, netloc=result.netloc, path=field_value, params='', query='',
                                     fragment='').geturl()
        else:
            # Document-relative: resolve against '.' to collapse ./ and ../ segments.
            result_url = urljoin(
                base_fork_url + '/' + (field_value if field_value.endswith('/') else field_value + '/'),
                ".")
        result_url = result_url[:-1] if result_url.endswith('/') else result_url
        tag[field] = result_url

    def reset_beautiful_soup(self, bf: BeautifulSoup):
        """Absolutise every relative href/src in the document (mutates and returns *bf*)."""
        reset_config_list = [
            {
                'field': 'href',
            },
            {
                'field': 'src',
            }
        ]
        for reset_config in reset_config_list:
            field = reset_config.get('field')
            # Only rewrite values that are not already absolute or special-scheme links.
            tag_list = bf.find_all(**{field: re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*')})
            for tag in tag_list:
                self.reset_url(tag, field, self.base_fork_url)
        return bf

    @staticmethod
    def get_beautiful_soup(response):
        """Decode *response* into a BeautifulSoup, honouring a <meta charset>
        that disagrees with the HTTP-level encoding."""
        # requests reports ISO-8859-1 when the header omits a charset; fall back
        # to the content-sniffed apparent_encoding in that case.
        encoding = response.encoding if response.encoding is not None and response.encoding != 'ISO-8859-1' else response.apparent_encoding
        html_content = response.content.decode(encoding)
        beautiful_soup = BeautifulSoup(html_content, "html.parser")
        meta_list = beautiful_soup.find_all('meta')
        charset_list = [meta.attrs.get('charset') for meta in meta_list if
                        meta.attrs is not None and 'charset' in meta.attrs]
        if len(charset_list) > 0:
            charset = charset_list[0]
            if charset != encoding:
                # Re-decode with the page's own charset; on failure keep the
                # first decode and just log the error.
                try:
                    html_content = response.content.decode(charset)
                except Exception as e:
                    logging.getLogger("max_kb").error(f'{e}')
                return BeautifulSoup(html_content, "html.parser")
        return beautiful_soup

    def fork(self):
        """Fetch the page and return a Response with markdown content and child links."""
        try:
            headers = {
                'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/99.0.4844.51 Safari/537.36'
            }
            logging.getLogger("max_kb").info(f'fork:{self.base_fork_url}')
            # NOTE(review): verify=False disables TLS verification — confirm intended.
            response = requests.get(self.base_fork_url, verify=False, headers=headers)
            if response.status_code != 200:
                logging.getLogger("max_kb").error(f"url: {self.base_fork_url} code:{response.status_code}")
                return Fork.Response.error(f"url: {self.base_fork_url} code:{response.status_code}")
            bf = self.get_beautiful_soup(response)
        except Exception as e:
            logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
            return Fork.Response.error(str(e))
        # Absolutise links first so extraction and child discovery see final URLs.
        bf = self.reset_beautiful_soup(bf)
        link_list = self.get_child_link_list(bf)
        content = self.get_content_html(bf)
        r = ht.html2text(content)
        return Fork.Response.success(r, link_list)
def handler(base_url, response: Fork.Response):
    """Example fork handler: print each crawled page's URL, link text and content."""
    link_text = base_url.tag.text if base_url.tag else None
    print(base_url.url, link_text, response.content)
# ForkManage('https://bbs.fit2cloud.com/c/de/6', ['.md-content']).fork(3, set(), handler)

View File

@ -0,0 +1,417 @@
# coding=utf-8
"""
@project: qabot
@Author
@file split_model.py
@date2023/9/1 15:12
@desc:
"""
import re
from functools import reduce
from typing import List, Dict
import jieba
def get_level_block(text, level_content_list, level_content_index, cursor):
    """
    Extract the body between one title and the next from *text*.

    :param text: full text being split
    :param level_content_list: list of title nodes ({'content': ...})
    :param level_content_index: index of the title whose body is wanted
    :param cursor: position to start searching from
    :return: (body text after the title, index just past the body)
    """
    start_content: str = level_content_list[level_content_index].get('content')
    has_next = level_content_index + 1 < len(level_content_list)
    next_content = level_content_list[level_content_index + 1].get("content") if has_next else None
    start_index = text.index(start_content, cursor)
    if next_content is None:
        end_index = len(text)
    else:
        end_index = text.index(next_content, start_index + 1)
    return text[start_index + len(start_content):end_index], end_index
def to_tree_obj(content, state='title'):
    """
    Wrap raw content in a tree-node dict.

    :param content: node text
    :param state: node kind, 'title' or 'block'
    :return: {'content': ..., 'state': ...}
    """
    return dict(content=content, state=state)
def remove_special_symbol(str_source: str):
    """
    Remove special characters from the given text.

    NOTE(review): currently a no-op — it returns the input unchanged. Kept as a
    hook; callers (filter_special_symbol, result_tree_to_paragraph) rely on it.

    :param str_source: text to clean
    :return: the cleaned text (currently identical to the input)
    """
    return str_source
def filter_special_symbol(content: dict):
    """
    Clean the 'content' field of a tree node in place.

    :param content: node dict with a 'content' key
    :return: the same dict, with its 'content' passed through remove_special_symbol
    """
    cleaned = remove_special_symbol(content['content'])
    content['content'] = cleaned
    return content
def flat(tree_data_list: List[dict], parent_chain: List[dict], result: List[dict]):
    """
    Flatten a tree of title/block nodes depth-first.

    :param tree_data_list: tree nodes
    :param parent_chain: ancestor chain for this level (None treated as [])
    :param result: accumulator list (None treated as []); mutated and returned
    :return: flat list of {'parent_chain', 'level', 'content', 'state'} dicts
    """
    parent_chain = [] if parent_chain is None else parent_chain
    result = [] if result is None else result
    for node in tree_data_list:
        # The node's own chain excludes itself; children see chain + node.
        chain_with_node = parent_chain.copy()
        chain_with_node.append(node)
        result.append(to_flat_obj(parent_chain, content=node["content"], state=node["state"]))
        child_nodes = node.get('children')
        if child_nodes:
            flat(child_nodes, chain_with_node, result)
    return result
def to_paragraph(obj: dict):
    """
    Convert a flattened block node into a paragraph dict.

    :param obj: flat node with 'content' and 'parent_chain'
    :return: {'keywords', 'parent_chain', 'content'} — content is prefixed with
             the comma-joined ancestor titles
    """
    content = obj['content']
    chain_contents = [parent['content'] for parent in obj['parent_chain']]
    return {
        "keywords": get_keyword(content),
        'parent_chain': chain_contents,
        'content': ",".join(chain_contents) + content,
    }
def get_keyword(content: str):
    """
    Extract keywords from *content* by segmenting it with jieba.

    A token is kept when it is not a stopword, or when it is longer than one
    character.

    :param content: text to extract keywords from
    :return: de-duplicated list of keywords
    """
    # NOTE(review): the first four entries look like they were mangled to empty
    # strings (probably punctuation originally) — confirm against upstream.
    stopwords = ['', '', '', '', '\n', '\\s']
    cutworms = jieba.lcut(content)
    # Bug fix: the original used `(k not in stopwords) | len(k) > 1`; bitwise `|`
    # binds tighter than `>`, so it evaluated `((k not in stopwords) | len(k)) > 1`,
    # dropping valid single-character keywords and keeping multi-char stopwords.
    return list({k for k in cutworms if k not in stopwords or len(k) > 1})
def titles_to_paragraph(list_title: List[dict]):
    """
    Merge sibling titles (same parent chain) into a single block paragraph.

    :param list_title: flat title nodes sharing one parent chain
    :return: {'keywords', 'parent_chain', 'content'} dict, or None for an empty list
    """
    if not list_title:
        return None

    def _clean(value: str) -> str:
        # Trim newline characters and literal '\' / 's' characters from the edges.
        return value.strip("\r\n").strip("\n").strip("\\s")

    content = "\n,".join(_clean(item['content']) for item in list_title)
    chain = [_clean(parent['content']) for parent in list_title[0]['parent_chain']]
    return {'keywords': '',
            'parent_chain': chain,
            'content': ",".join(chain) + content}
def parse_group_key(level_list: List[dict]):
    """
    Build paragraphs for one tree level: sibling titles grouped by parent chain
    become merged title paragraphs, followed by the level's block paragraphs.

    :param level_list: flat nodes of a single depth level
    :return: merged title paragraphs + block paragraphs
    """
    titles = [f for f in level_list if f['state'] == 'title' and len(f['parent_chain']) > 0]
    grouped = group_by(titles,
                       key=lambda d: ",".join(p['content'] for p in d['parent_chain']))
    result = [titles_to_paragraph(grouped[group_key]) for group_key in grouped]
    result += [to_paragraph(f) for f in level_list if f['state'] == 'block']
    return result
def to_block_paragraph(tree_data_list: List[dict]):
    """
    Convert a parsed tree into block paragraphs, grouped by depth level.

    :param tree_data_list: tree nodes
    :return: one list of paragraphs per depth level
    """
    flat_list = flat(tree_data_list, [], [])
    level_group_dict: dict = group_by(flat_list, key=lambda f: f['level'])
    return [parse_group_key(level_group_dict[level]) for level in level_group_dict]
def parse_title_level(text, content_level_pattern: List, index):
    """
    Match title patterns against *text* starting at pattern *index*, falling
    through to the next pattern whenever the current one finds nothing.

    :return: title nodes for the first pattern that matches, else []
    """
    if index >= len(content_level_pattern):
        return []
    matched = parse_level(text, content_level_pattern[index])
    if matched:
        return matched
    return parse_title_level(text, content_level_pattern, index + 1)
def parse_level(text, pattern: str):
    """
    Find title candidates in *text* with *pattern* and wrap them as title nodes.

    :param text: text to match against
    :param pattern: regex (string or compiled) for this title level
    :return: cleaned title nodes; each match is capped at 255 characters
    """
    nodes = [to_tree_obj(match[0:255]) for match in re_findall(pattern, text) if match is not None]
    return [filter_special_symbol(node) for node in nodes]
def re_findall(pattern, text):
    """
    re.findall wrapper that flattens tuple results (from capture groups) into
    one flat list and drops None/empty matches.
    """
    flattened = []
    for row in re.findall(pattern, text, flags=0):
        if isinstance(row, tuple):
            flattened.extend(row)
        else:
            flattened.append(row)
    return [item for item in flattened if item is not None and len(item) > 0]
def to_flat_obj(parent_chain: List[dict], content: str, state: str):
    """
    Build a flattened tree node; its level is the length of the ancestor chain.

    :param parent_chain: ancestor nodes
    :param content: node text
    :param state: 'title' or 'block'
    :return: {'parent_chain', 'level', 'content', 'state'}
    """
    return {
        'parent_chain': parent_chain,
        'level': len(parent_chain),
        'content': content,
        'state': state,
    }
def flat_map(array: List[List]):
    """
    Flatten one nesting level: a list of lists becomes a single list.

    :param array: two-dimensional list
    :return: one-dimensional list, order preserved
    """
    return [item for sub in array for item in sub]
def group_by(list_source: List, key):
    """
    Group elements of *list_source* by *key*.

    :param list_source: elements to group
    :param key: function mapping an element to its group key
    :return: dict of key -> list of elements, insertion order preserved
    """
    result = {}
    for element in list_source:
        result.setdefault(key(element), []).append(element)
    return result
def result_tree_to_paragraph(result_tree: List[dict], result, parent_chain, with_filter: bool):
    """
    Walk a parsed tree and collect titled paragraphs.

    :param result_tree: parsed title/block tree
    :param result: accumulator list (pass []); mutated and returned
    :param parent_chain: ancestor title texts (pass [])
    :param with_filter: when True, block content is cleaned via filter_special_char
    :return: list of {'title': 'space-joined ancestors', 'content': ...}
    """
    for node in result_tree:
        if node.get('state') == 'block':
            body = node.get("content")
            result.append({
                'title': " ".join(parent_chain),
                'content': filter_special_char(body) if with_filter else body,
            })
        child_nodes = node.get("children")
        if child_nodes:
            result_tree_to_paragraph(child_nodes, result,
                                     [*parent_chain, remove_special_symbol(node.get('content'))], with_filter)
    return result
def post_handler_paragraph(content: str, limit: int):
    """
    Split *content* into chunks of at most *limit* characters, preferring to
    break at newline boundaries.

    :param content: text to split
    :param limit: maximum chunk size in characters
    :return: list of chunks, each no longer than *limit*
    """
    result = []
    temp_char, start = '', 0
    # Walk line by line; each split keeps its trailing newline.
    while (pos := content.find("\n", start)) != -1:
        split, start = content[start:pos + 1], pos + 1
        if len(temp_char + split) > limit:
            # Flush the accumulated chunk before it would exceed the limit.
            # (Removed two dead `if len(temp_char) > 4096: pass` no-op branches.)
            result.append(temp_char)
            temp_char = ''
        temp_char = temp_char + split
    temp_char = temp_char + content[start:]
    if len(temp_char) > 0:
        result.append(temp_char)
    # A single line longer than the limit still needs cutting: slice every chunk
    # into at-most-`limit` pieces (this also drops any empty chunks).
    pattern = "[\\S\\s]{1," + str(limit) + '}'
    return reduce(lambda x, y: [*x, *y], map(lambda row: re.findall(pattern, row), result), [])
# Normalisation rules applied by filter_special_char: each compiled pattern is
# replaced by the string it maps to (collapse repeated newlines and spaces,
# drop '#' runs and tabs).
replace_map = {
    re.compile('\n+'): '\n',
    re.compile(' +'): ' ',
    re.compile('#+'): "",
    re.compile("\t+"): ''
}
def filter_special_char(content: str):
    """
    Apply every replace_map rule to *content* and return the cleaned text.

    :param content: text to clean
    :return: text with newlines/spaces collapsed and '#'/tab runs removed
    """
    for pattern, replacement in replace_map.items():
        content = re.sub(pattern, replacement, content)
    return content
class SplitModel:
    """Splits raw text into a title/block tree and then into titled paragraphs,
    driven by an ordered list of title-level regex patterns."""

    def __init__(self, content_level_pattern, with_filter=True, limit=100000):
        """
        :param content_level_pattern: ordered regex list, one per title level
        :param with_filter: whether to strip special characters from block content
        :param limit: maximum characters per paragraph, clamped to [50, 100000]
        """
        self.content_level_pattern = content_level_pattern
        self.with_filter = with_filter
        # Clamp the paragraph size limit to a sane range.
        if limit is None or limit > 100000:
            limit = 100000
        if limit < 50:
            limit = 50
        self.limit = limit

    def parse_to_tree(self, text: str, index=0):
        """
        Parse *text* into a tree of title/block nodes.

        :param text: text to parse
        :param index: which title-level pattern to start matching with
        :return: list of tree nodes ({'content', 'state', optional 'children'})
        """
        level_content_list = parse_title_level(text, self.content_level_pattern, index)
        # No titles found at any remaining level: the whole text becomes blocks.
        if len(level_content_list) == 0:
            return [to_tree_obj(row, 'block') for row in post_handler_paragraph(text, limit=self.limit)]
        # Text before the first top-level title gets an empty synthetic title.
        if index == 0 and text.lstrip().index(level_content_list[0]["content"].lstrip()) != 0:
            level_content_list.insert(0, to_tree_obj(""))
        cursor = 0
        level_title_content_list = [item for item in level_content_list if item.get('state') == 'title']
        for i in range(len(level_title_content_list)):
            start_content: str = level_title_content_list[i].get('content')
            # Text between the cursor and this title becomes block nodes.
            if cursor < text.index(start_content, cursor):
                for row in post_handler_paragraph(text[cursor: text.index(start_content, cursor)], limit=self.limit):
                    level_content_list.insert(0, to_tree_obj(row, 'block'))
            block, cursor = get_level_block(text, level_title_content_list, i, cursor)
            if len(block) == 0:
                continue
            # Recursively parse this title's body at the next level down.
            children = self.parse_to_tree(text=block, index=index + 1)
            level_title_content_list[i]['children'] = children
            # Body text that precedes the first child title is parsed and appended too.
            first_child_idx_in_block = block.lstrip().index(children[0]["content"].lstrip())
            if first_child_idx_in_block != 0:
                inner_children = self.parse_to_tree(block[:first_child_idx_in_block], index + 1)
                level_title_content_list[i]['children'].extend(inner_children)
        return level_content_list

    def parse(self, text: str):
        """
        Parse text into paragraph dicts.

        :param text: text data
        :return: list of {'title': ..., 'content': ...}; empty-content items dropped
        """
        # Normalise line endings and strip NUL bytes before parsing.
        text = text.replace('\r\n', '\n')
        text = text.replace('\r', '\n')
        text = text.replace("\0", '')
        result_tree = self.parse_to_tree(text, 0)
        result = result_tree_to_paragraph(result_tree, [], [], self.with_filter)
        for e in result:
            # No-op branch — appears to be a leftover oversize-paragraph hook.
            if len(e['content']) > 4096:
                pass
        title_list = list(set([row.get('title') for row in result]))
        return [item for item in [self.post_reset_paragraph(row, title_list) for row in result] if
                'content' in item and len(item.get('content').strip()) > 0]

    def post_reset_paragraph(self, paragraph: Dict, title_list: List[str]):
        """Post-process one paragraph: resolve empty content, clean the title,
        then cap the title length."""
        result = self.content_is_null(paragraph, title_list)
        result = self.filter_title_special_characters(result)
        result = self.sub_title(result)
        return result

    @staticmethod
    def sub_title(paragraph: Dict):
        """Cap titles at 255 characters, pushing the overflow into the content."""
        if 'title' in paragraph:
            title = paragraph.get('title')
            if len(title) > 255:
                return {**paragraph, 'title': title[0:255], 'content': title[255:len(title)] + paragraph.get('content')}
        return paragraph

    @staticmethod
    def content_is_null(paragraph: Dict, title_list: List[str]):
        """Resolve a paragraph that has a title but no content: drop it entirely
        when a longer title contains this one (it was just a heading), otherwise
        demote the title to content."""
        if 'title' in paragraph:
            title = paragraph.get('title')
            content = paragraph.get('content')
            if (content is None or len(content.strip()) == 0) and (title is not None and len(title) > 0):
                find = [t for t in title_list if t.__contains__(title) and t != title]
                if find:
                    return {'title': '', 'content': ''}
                return {'title': '', 'content': title}
        return paragraph

    @staticmethod
    def filter_title_special_characters(paragraph: Dict):
        """Strip markdown/whitespace special characters from the paragraph title."""
        title = paragraph.get('title') if 'title' in paragraph else ''
        for title_special_characters in title_special_characters_list:
            title = title.replace(title_special_characters, '')
        return {**paragraph,
                'title': title}
# Characters removed from titles by SplitModel.filter_title_special_characters.
title_special_characters_list = ['#', '\n', '\r', '\\s']
# Title-level regex sets: 'md' matches markdown headings '#'..'######' at line
# start (one pattern per level); 'default' splits on blank-line gaps.
default_split_pattern = {
    'md': [re.compile('(?<=^)# .*|(?<=\\n)# .*'),
           re.compile('(?<=\\n)(?<!#)## (?!#).*|(?<=^)(?<!#)## (?!#).*'),
           re.compile("(?<=\\n)(?<!#)### (?!#).*|(?<=^)(?<!#)### (?!#).*"),
           re.compile("(?<=\\n)(?<!#)#### (?!#).*|(?<=^)(?<!#)#### (?!#).*"),
           re.compile("(?<=\\n)(?<!#)##### (?!#).*|(?<=^)(?<!#)##### (?!#).*"),
           re.compile("(?<=\\n)(?<!#)###### (?!#).*|(?<=^)(?<!#)###### (?!#).*")],
    'default': [re.compile("(?<!\n)\n\n+")]
}
def get_split_model(filename: str, with_filter: bool = False, limit: int = 100000):
    """
    Pick a split model based on the file name.

    :param limit: maximum size of each paragraph
    :param with_filter: whether to filter special characters
    :param filename: file name
    :return: split model
    """
    if filename.endswith(".md"):
        pattern_list = default_split_pattern.get('md')
        return SplitModel(pattern_list, with_filter=with_filter, limit=limit)
    # NOTE(review): the fallback also uses the 'md' patterns; 'default' exists in
    # default_split_pattern but is never selected — confirm this is intended.
    pattern_list = default_split_pattern.get('md')
    return SplitModel(pattern_list, with_filter=with_filter, limit=limit)
def to_title_tree_string(result_tree: List):
    """Render the title nodes of a parsed tree as an indented outline string."""
    flattened = flat(result_tree, [], [])
    title_nodes = [row for row in flattened if row.get('state') == 'title']
    return "\n".join(title_tostring(row) for row in title_nodes)
def title_tostring(title_obj):
    """Indent a flattened title by its depth and prefix it with a branch marker."""
    indent = " " * len(title_obj.get("parent_chain"))
    return indent + "├───" + title_obj.get('content')

View File

@ -4,16 +4,16 @@ from drf_spectacular.utils import OpenApiParameter
from common.mixins.api_mixin import APIMixin
from common.result import ResultSerializer, DefaultResultSerializer
from modules.models.module import ModuleCreateRequest, ModuleEditRequest
from modules.serializers.module import ModuleSerializer
from folders.models.folder import FolderCreateRequest, FolderEditRequest
from folders.serializers.folder import FolderSerializer
class ModuleCreateResponse(ResultSerializer):
class FolderCreateResponse(ResultSerializer):
def get_data(self):
return ModuleSerializer()
return FolderSerializer()
class ModuleCreateAPI(APIMixin):
class FolderCreateAPI(APIMixin):
@staticmethod
def get_parameters():
return [
@ -36,14 +36,14 @@ class ModuleCreateAPI(APIMixin):
@staticmethod
def get_request():
return ModuleCreateRequest
return FolderCreateRequest
@staticmethod
def get_response():
return ModuleCreateResponse
return FolderCreateResponse
class ModuleReadAPI(APIMixin):
class FolderReadAPI(APIMixin):
@staticmethod
def get_parameters():
return [
@ -63,8 +63,8 @@ class ModuleReadAPI(APIMixin):
required=True,
),
OpenApiParameter(
name="module_id",
description="模块id",
name="folder_id",
description="文件夹id",
type=OpenApiTypes.STR,
location='path',
required=True,
@ -73,23 +73,23 @@ class ModuleReadAPI(APIMixin):
@staticmethod
def get_response():
return ModuleCreateResponse
return FolderCreateResponse
class ModuleEditAPI(ModuleReadAPI):
class FolderEditAPI(FolderReadAPI):
@staticmethod
def get_request():
return ModuleEditRequest
return FolderEditRequest
class ModuleDeleteAPI(ModuleReadAPI):
class FolderDeleteAPI(FolderReadAPI):
@staticmethod
def get_response():
return DefaultResultSerializer
class ModuleTreeReadAPI(APIMixin):
class FolderTreeReadAPI(APIMixin):
@staticmethod
def get_parameters():
return [

View File

@ -2,14 +2,14 @@ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
class ModuleCreateRequest(serializers.Serializer):
name = serializers.CharField(required=True, label=_('module name'))
class FolderCreateRequest(serializers.Serializer):
name = serializers.CharField(required=True, label=_('folder name'))
parent_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root',
label=_('parent id'))
class ModuleEditRequest(serializers.Serializer):
name = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('module name'))
class FolderEditRequest(serializers.Serializer):
name = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('folder name'))
parent_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root',
label=_('parent id'))

View File

@ -7,30 +7,30 @@ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.constants.permission_constants import Group
from knowledge.models import KnowledgeModule
from modules.api.module import ModuleCreateRequest
from tools.models import ToolModule
from tools.serializers.tool_module import ToolModuleTreeSerializer
from knowledge.models import KnowledgeFolder
from folders.api.folder import FolderCreateRequest
from tools.models import ToolFolder
from tools.serializers.tool_folder import ToolFolderTreeSerializer
def get_module_type(source):
def get_folder_type(source):
if source == Group.TOOL.name:
return ToolModule
return ToolFolder
elif source == Group.APPLICATION.name:
# todo app module
# todo app folder
return None
elif source == Group.KNOWLEDGE.name:
return KnowledgeModule
return KnowledgeFolder
else:
return None
MODULE_DEPTH = 2 # Module 不能超过3层
FOLDER_DEPTH = 2 # Folder 不能超过3层
def check_depth(source, parent_id, current_depth=0):
# Module 不能超过3层
Module = get_module_type(source)
# Folder 不能超过3层
Folder = get_folder_type(source)
if parent_id != 'root':
# 计算当前层级
@ -40,14 +40,14 @@ def check_depth(source, parent_id, current_depth=0):
# 向上追溯父节点
while current_parent_id != 'root':
depth += 1
parent_node = QuerySet(Module).filter(id=current_parent_id).first()
parent_node = QuerySet(Folder).filter(id=current_parent_id).first()
if parent_node is None:
break
current_parent_id = parent_node.parent_id
# 验证层级深度
if depth + current_depth > MODULE_DEPTH:
raise serializers.ValidationError(_('Module depth cannot exceed 3 levels'))
if depth + current_depth > FOLDER_DEPTH:
raise serializers.ValidationError(_('Folder depth cannot exceed 3 levels'))
def get_max_depth(current_node):
@ -68,10 +68,10 @@ def get_max_depth(current_node):
return max_depth
class ModuleSerializer(serializers.Serializer):
id = serializers.CharField(required=True, label=_('module id'))
name = serializers.CharField(required=True, label=_('module name'))
user_id = serializers.CharField(required=True, label=_('module user id'))
class FolderSerializer(serializers.Serializer):
id = serializers.CharField(required=True, label=_('folder id'))
name = serializers.CharField(required=True, label=_('folder name'))
user_id = serializers.CharField(required=True, label=_('folder user id'))
workspace_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('workspace id'))
parent_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('parent id'))
@ -82,84 +82,84 @@ class ModuleSerializer(serializers.Serializer):
def insert(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
ModuleCreateRequest(data=instance).is_valid(raise_exception=True)
FolderCreateRequest(data=instance).is_valid(raise_exception=True)
workspace_id = self.data.get('workspace_id', 'default')
parent_id = instance.get('parent_id', 'root')
name = instance.get('name')
Module = get_module_type(self.data.get('source'))
if QuerySet(Module).filter(name=name, workspace_id=workspace_id, parent_id=parent_id).exists():
raise serializers.ValidationError(_('Module name already exists'))
# Module 不能超过3层
Folder = get_folder_type(self.data.get('source'))
if QuerySet(Folder).filter(name=name, workspace_id=workspace_id, parent_id=parent_id).exists():
raise serializers.ValidationError(_('Folder name already exists'))
# Folder 不能超过3层
check_depth(self.data.get('source'), parent_id)
module = Module(
folder = Folder(
id=uuid.uuid7(),
name=instance.get('name'),
user_id=self.data.get('user_id'),
workspace_id=workspace_id,
parent_id=parent_id
)
module.save()
return ModuleSerializer(module).data
folder.save()
return FolderSerializer(folder).data
class Operate(serializers.Serializer):
id = serializers.CharField(required=True, label=_('module id'))
id = serializers.CharField(required=True, label=_('folder id'))
workspace_id = serializers.CharField(required=True, allow_null=True, allow_blank=True, label=_('workspace id'))
source = serializers.CharField(required=True, label=_('source'))
@transaction.atomic
def edit(self, instance):
self.is_valid(raise_exception=True)
Module = get_module_type(self.data.get('source'))
Folder = get_folder_type(self.data.get('source'))
current_id = self.data.get('id')
current_node = Module.objects.get(id=current_id)
current_node = Folder.objects.get(id=current_id)
if current_node is None:
raise serializers.ValidationError(_('Module does not exist'))
raise serializers.ValidationError(_('Folder does not exist'))
edit_field_list = ['name']
edit_dict = {field: instance.get(field) for field in edit_field_list if (
field in instance and instance.get(field) is not None)}
QuerySet(Module).filter(id=current_id).update(**edit_dict)
QuerySet(Folder).filter(id=current_id).update(**edit_dict)
# 模块间的移动
parent_id = instance.get('parent_id')
if parent_id is not None and current_id != 'root':
# Module 不能超过3层
# Folder 不能超过3层
current_depth = get_max_depth(current_node)
check_depth(self.data.get('source'), parent_id, current_depth)
parent = Module.objects.get(id=parent_id)
parent = Folder.objects.get(id=parent_id)
current_node.move_to(parent)
return self.one()
def one(self):
self.is_valid(raise_exception=True)
Module = get_module_type(self.data.get('source'))
module = QuerySet(Module).filter(id=self.data.get('id')).first()
return ModuleSerializer(module).data
Folder = get_folder_type(self.data.get('source'))
folder = QuerySet(Folder).filter(id=self.data.get('id')).first()
return FolderSerializer(folder).data
def delete(self):
self.is_valid(raise_exception=True)
if self.data.get('id') == 'root':
raise serializers.ValidationError(_('Cannot delete root module'))
Module = get_module_type(self.data.get('source'))
QuerySet(Module).filter(id=self.data.get('id')).delete()
raise serializers.ValidationError(_('Cannot delete root folder'))
Folder = get_folder_type(self.data.get('source'))
QuerySet(Folder).filter(id=self.data.get('id')).delete()
class ModuleTreeSerializer(serializers.Serializer):
class FolderTreeSerializer(serializers.Serializer):
workspace_id = serializers.CharField(required=True, allow_null=True, allow_blank=True, label=_('workspace id'))
source = serializers.CharField(required=True, label=_('source'))
def get_module_tree(self, name=None):
def get_folder_tree(self, name=None):
self.is_valid(raise_exception=True)
Module = get_module_type(self.data.get('source'))
Folder = get_folder_type(self.data.get('source'))
if name is not None:
nodes = Module.objects.filter(Q(workspace_id=self.data.get('workspace_id')) &
nodes = Folder.objects.filter(Q(workspace_id=self.data.get('workspace_id')) &
Q(name__contains=name)).get_cached_trees()
else:
nodes = Module.objects.filter(Q(workspace_id=self.data.get('workspace_id'))).get_cached_trees()
serializer = ToolModuleTreeSerializer(nodes, many=True)
nodes = Folder.objects.filter(Q(workspace_id=self.data.get('workspace_id'))).get_cached_trees()
serializer = ToolFolderTreeSerializer(nodes, many=True)
return serializer.data # 这是可序列化的字典

9
apps/folders/urls.py Normal file
View File

@ -0,0 +1,9 @@
from django.urls import path
from . import views
app_name = "folder"
urlpatterns = [
path('workspace/<str:workspace_id>/<str:source>/folder', views.FolderView.as_view()),
path('workspace/<str:workspace_id>/<str:source>/folder/<str:folder_id>', views.FolderView.Operate.as_view()),
]

View File

@ -0,0 +1 @@
from .folder import *

View File

@ -7,26 +7,26 @@ from common.auth import TokenAuth
from common.auth.authentication import has_permissions
from common.constants.permission_constants import Permission, Group, Operate
from common.result import result
from modules.api.module import ModuleCreateAPI, ModuleEditAPI, ModuleReadAPI, ModuleTreeReadAPI, ModuleDeleteAPI
from modules.serializers.module import ModuleSerializer, ModuleTreeSerializer
from folders.api.folder import FolderCreateAPI, FolderEditAPI, FolderReadAPI, FolderTreeReadAPI, FolderDeleteAPI
from folders.serializers.folder import FolderSerializer, FolderTreeSerializer
class ModuleView(APIView):
class FolderView(APIView):
authentication_classes = [TokenAuth]
@extend_schema(
methods=['POST'],
description=_('Create module'),
operation_id=_('Create module'),
parameters=ModuleCreateAPI.get_parameters(),
request=ModuleCreateAPI.get_request(),
responses=ModuleCreateAPI.get_response(),
tags=[_('Module')]
description=_('Create folder'),
operation_id=_('Create folder'),
parameters=FolderCreateAPI.get_parameters(),
request=FolderCreateAPI.get_request(),
responses=FolderCreateAPI.get_response(),
tags=[_('Folder')]
)
@has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.CREATE,
resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}"))
def post(self, request: Request, workspace_id: str, source: str):
return result.success(ModuleSerializer.Create(
return result.success(FolderSerializer.Create(
data={'user_id': request.user.id,
'source': source,
'workspace_id': workspace_id}
@ -34,64 +34,64 @@ class ModuleView(APIView):
@extend_schema(
methods=['GET'],
description=_('Get module tree'),
operation_id=_('Get module tree'),
parameters=ModuleTreeReadAPI.get_parameters(),
responses=ModuleTreeReadAPI.get_response(),
tags=[_('Module')]
description=_('Get folder tree'),
operation_id=_('Get folder tree'),
parameters=FolderTreeReadAPI.get_parameters(),
responses=FolderTreeReadAPI.get_response(),
tags=[_('Folder')]
)
@has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.READ,
resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}"))
def get(self, request: Request, workspace_id: str, source: str):
return result.success(ModuleTreeSerializer(
return result.success(FolderTreeSerializer(
data={'workspace_id': workspace_id, 'source': source}
).get_module_tree(request.query_params.get('name')))
).get_folder_tree(request.query_params.get('name')))
class Operate(APIView):
authentication_classes = [TokenAuth]
@extend_schema(
methods=['PUT'],
description=_('Update module'),
operation_id=_('Update module'),
parameters=ModuleEditAPI.get_parameters(),
request=ModuleEditAPI.get_request(),
responses=ModuleEditAPI.get_response(),
tags=[_('Module')]
description=_('Update folder'),
operation_id=_('Update folder'),
parameters=FolderEditAPI.get_parameters(),
request=FolderEditAPI.get_request(),
responses=FolderEditAPI.get_response(),
tags=[_('Folder')]
)
@has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.EDIT,
resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}"))
def put(self, request: Request, workspace_id: str, source: str, module_id: str):
return result.success(ModuleSerializer.Operate(
data={'id': module_id, 'workspace_id': workspace_id, 'source': source}
def put(self, request: Request, workspace_id: str, source: str, folder_id: str):
return result.success(FolderSerializer.Operate(
data={'id': folder_id, 'workspace_id': workspace_id, 'source': source}
).edit(request.data))
@extend_schema(
methods=['GET'],
description=_('Get module'),
operation_id=_('Get module'),
parameters=ModuleReadAPI.get_parameters(),
responses=ModuleReadAPI.get_response(),
tags=[_('Module')]
description=_('Get folder'),
operation_id=_('Get folder'),
parameters=FolderReadAPI.get_parameters(),
responses=FolderReadAPI.get_response(),
tags=[_('Folder')]
)
@has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.READ,
resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}"))
def get(self, request: Request, workspace_id: str, source: str, module_id: str):
return result.success(ModuleSerializer.Operate(
data={'id': module_id, 'workspace_id': workspace_id, 'source': source}
def get(self, request: Request, workspace_id: str, source: str, folder_id: str):
return result.success(FolderSerializer.Operate(
data={'id': folder_id, 'workspace_id': workspace_id, 'source': source}
).one())
@extend_schema(
methods=['DELETE'],
description=_('Delete module'),
operation_id=_('Delete module'),
parameters=ModuleDeleteAPI.get_parameters(),
responses=ModuleDeleteAPI.get_response(),
tags=[_('Module')]
description=_('Delete folder'),
operation_id=_('Delete folder'),
parameters=FolderDeleteAPI.get_parameters(),
responses=FolderDeleteAPI.get_response(),
tags=[_('Folder')]
)
@has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.DELETE,
resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}"))
def delete(self, request: Request, workspace_id: str, source: str, module_id: str):
return result.success(ModuleSerializer.Operate(
data={'id': module_id, 'workspace_id': workspace_id, 'source': source}
def delete(self, request: Request, workspace_id: str, source: str, folder_id: str):
return result.success(FolderSerializer.Operate(
data={'id': folder_id, 'workspace_id': workspace_id, 'source': source}
).delete())

View File

@ -67,8 +67,8 @@ class KnowledgeTreeReadAPI(APIMixin):
required=True,
),
OpenApiParameter(
name="module_id",
description="模块id",
name="folder_id",
description="文件夹id",
type=OpenApiTypes.STR,
location='query',
required=False,

View File

@ -6,12 +6,12 @@ import mptt.fields
import uuid_utils.compat
from django.db import migrations, models
from knowledge.models import KnowledgeModule
from knowledge.models import KnowledgeFolder
def insert_default_data(apps, schema_editor):
# 创建一个根模块(没有父节点)
KnowledgeModule.objects.create(id='root', name='根目录', user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab')
KnowledgeFolder.objects.create(id='root', name='根目录', user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab')
class Migration(migrations.Migration):
@ -42,7 +42,7 @@ class Migration(migrations.Migration):
},
),
migrations.CreateModel(
name='KnowledgeModule',
name='KnowledgeFolder',
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
@ -57,12 +57,12 @@ class Migration(migrations.Migration):
('level', models.PositiveIntegerField(editable=False)),
('parent',
mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE,
related_name='children', to='knowledge.knowledgemodule')),
related_name='children', to='knowledge.knowledgefolder')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
verbose_name='用户id')),
],
options={
'db_table': 'knowledge_module',
'db_table': 'knowledge_folder',
},
),
migrations.CreateModel(
@ -84,10 +84,10 @@ class Migration(migrations.Migration):
('scope',
models.CharField(choices=[('SHARED', '共享'), ('WORKSPACE', '工作空间可用')], default='WORKSPACE',
max_length=20, verbose_name='可用范围')),
('module',
('folder',
models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE,
to='knowledge.knowledgemodule',
verbose_name='模块id')),
to='knowledge.knowledgefolder',
verbose_name='文件夹id')),
('embedding_model', models.ForeignKey(default=knowledge.models.knowledge.default_model,
on_delete=django.db.models.deletion.DO_NOTHING,
to='models_provider.model', verbose_name='向量模型')),

View File

@ -28,7 +28,7 @@ def default_model():
return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab')
class KnowledgeModule(MPTTModel, AppModelMixin):
class KnowledgeFolder(MPTTModel, AppModelMixin):
id = models.CharField(primary_key=True, max_length=64, editable=False, verbose_name="主键id")
name = models.CharField(max_length=64, verbose_name="文件夹名称")
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="用户id")
@ -36,7 +36,7 @@ class KnowledgeModule(MPTTModel, AppModelMixin):
parent = TreeForeignKey('self', on_delete=models.CASCADE, null=True, blank=True, related_name='children')
class Meta:
db_table = "knowledge_module"
db_table = "knowledge_folder"
class MPTTMeta:
order_insertion_by = ['name']
@ -51,7 +51,7 @@ class Knowledge(AppModelMixin):
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
type = models.IntegerField(verbose_name='类型', choices=KnowledgeType.choices, default=KnowledgeType.BASE)
scope = models.CharField(max_length=20, verbose_name='可用范围', choices=KnowledgeScope.choices, default=KnowledgeScope.WORKSPACE)
module = models.ForeignKey(KnowledgeModule, on_delete=models.CASCADE, verbose_name="模块id", default='root')
folder = models.ForeignKey(KnowledgeFolder, on_delete=models.CASCADE, verbose_name="文件夹id", default='root')
embedding_model = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型",
default=default_model)
meta = models.JSONField(verbose_name="元数据", default=dict)

View File

@ -1,25 +1,36 @@
from typing import Dict
import uuid_utils as uuid
from django.db import transaction
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.exception.app_exception import AppApiException
from common.utils.common import valid_license
from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType
class KnowledgeModelSerializer(serializers.ModelSerializer):
class Meta:
model = Knowledge
fields = ['id', 'name', 'desc', 'meta', 'module_id', 'type', 'workspace_id', 'create_time', 'update_time']
fields = ['id', 'name', 'desc', 'meta', 'folder_id', 'type', 'workspace_id', 'create_time', 'update_time']
class KnowledgeBaseCreateRequest(serializers.Serializer):
    """Request payload for creating a plain (BASE) knowledge base."""
    # Display name; uniqueness within the workspace is enforced at save time.
    name = serializers.CharField(required=True, label=_('knowledge name'))
    # Target folder; 'root' is the default folder id elsewhere in this module.
    folder_id = serializers.CharField(required=True, label=_('folder id'))
    desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('knowledge description'))
    # Id of the embedding model used to vectorize documents.
    embedding = serializers.CharField(required=True, label=_('knowledge embedding'))
class KnowledgeWebCreateRequest(serializers.Serializer):
    """Request payload for creating a WEB-type knowledge base (crawled site)."""
    name = serializers.CharField(required=True, label=_('knowledge name'))
    folder_id = serializers.CharField(required=True, label=_('folder id'))
    desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('knowledge description'))
    embedding = serializers.CharField(required=True, label=_('knowledge embedding'))
    # Root URL the crawler starts from.
    source_url = serializers.CharField(required=True, label=_('source url'))
    # Whitespace-separated CSS selector string used to extract page content.
    selector = serializers.CharField(required=True, label=_('knowledge selector'))
class KnowledgeSerializer(serializers.Serializer):
@ -27,10 +38,19 @@ class KnowledgeSerializer(serializers.Serializer):
user_id = serializers.UUIDField(required=True, label=_('user id'))
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
def insert(self, instance, with_valid=True):
@valid_license(model=Knowledge, count=50,
message=_(
'The community version supports up to 50 knowledge bases. If you need more knowledge bases, please contact us (https://fit2cloud.com/).'))
# @post(post_function=post_embedding_dataset)
@transaction.atomic
def save_base(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
KnowledgeBaseCreateRequest(data=instance).is_valid(raise_exception=True)
if QuerySet(Knowledge).filter(workspace_id=self.data.get('workspace_id'),
name=instance.get('name')).exists():
raise AppApiException(500, _('Knowledge base name duplicate!'))
knowledge = Knowledge(
id=uuid.uuid7(),
name=instance.get('name'),
@ -39,13 +59,43 @@ class KnowledgeSerializer(serializers.Serializer):
type=instance.get('type', KnowledgeType.BASE),
user_id=self.data.get('user_id'),
scope=KnowledgeScope.WORKSPACE,
module_id=instance.get('module_id', 'root'),
folder_id=instance.get('folder_id', 'root'),
embedding_model_id=instance.get('embedding'),
meta=instance.get('meta', {}),
)
knowledge.save()
return KnowledgeModelSerializer(knowledge).data
    def save_web(self, instance: Dict, with_valid=True):
        """Create a WEB-type knowledge base whose content is crawled from
        ``source_url`` using ``selector``.

        :param instance: request payload; validated against KnowledgeWebCreateRequest.
        :param with_valid: when True, validate serializer data and payload first.
        :raises AppApiException: if a knowledge base with the same name already
            exists in the workspace.
        :return: serialized knowledge base plus an (initially empty) document_list.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            KnowledgeWebCreateRequest(data=instance).is_valid(raise_exception=True)
        # Names must be unique within a workspace.
        if QuerySet(Knowledge).filter(workspace_id=self.data.get('workspace_id'),
                                      name=instance.get('name')).exists():
            raise AppApiException(500, _('Knowledge base name duplicate!'))
        knowledge_id = uuid.uuid7()
        knowledge = Knowledge(
            id=knowledge_id,
            name=instance.get('name'),
            desc=instance.get('desc'),
            user_id=self.data.get('user_id'),
            type=instance.get('type', KnowledgeType.WEB),
            scope=KnowledgeScope.WORKSPACE,
            folder_id=instance.get('folder_id', 'root'),
            embedding_model_id=instance.get('embedding'),
            # Crawl parameters are stored in meta so a later re-sync can reuse them.
            meta={
                'source_url': instance.get('source_url'),
                'selector': instance.get('selector'),
                'embedding_model_id': instance.get('embedding')
            },
        )
        knowledge.save()
        # NOTE(review): crawl kick-off is currently disabled — confirm when the
        # celery task wiring is enabled.
        # sync_web_knowledge.delay(str(knowledge_id), instance.get('source_url'), instance.get('selector'))
        return {**KnowledgeModelSerializer(knowledge).data,
                'document_list': []}
class KnowledgeTreeSerializer(serializers.Serializer):
def get_knowledge_list(self, param):

View File

@ -0,0 +1 @@
from .sync import *

View File

@ -0,0 +1,63 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file sync.py
@date2024/8/20 21:37
@desc:
"""
import logging
import traceback
from typing import List
from celery_once import QueueOnce
from common.utils.fork import ForkManage, Fork
from .tools import get_save_handler, get_sync_web_document_handler, get_sync_handler
from ops import celery_app
from django.utils.translation import gettext_lazy as _
max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")
@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:sync_web_knowledge')
def sync_web_knowledge(knowledge_id: str, url: str, selector: str):
    """Crawl `url` and import matched pages as documents of the knowledge base.

    QueueOnce keyed on `knowledge_id` prevents two concurrent syncs of the
    same knowledge base. Errors are logged and swallowed so the task always
    completes from Celery's perspective.
    """
    try:
        max_kb.info(
            _('Start--->Start synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id))
        # Whitespace-separated selector string -> selector list; None -> no selectors.
        # fork(2, ...) — presumably the max crawl depth; confirm in ForkManage.
        ForkManage(url, selector.split(" ") if selector is not None else []).fork(2, set(),
                                                                                 get_save_handler(knowledge_id,
                                                                                                  selector))
        max_kb.info(_('End--->End synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id))
    except Exception as e:
        max_kb_error.error(_('Synchronize web knowledge base:{knowledge_id} error{error}{traceback}').format(
            knowledge_id=knowledge_id, error=str(e), traceback=traceback.format_exc()))
@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:sync_replace_web_knowledge')
def sync_replace_web_knowledge(knowledge_id: str, url: str, selector: str):
    """Re-crawl `url` and synchronize the knowledge base in place.

    Unlike sync_web_knowledge, the handler from get_sync_handler re-syncs
    documents that already exist (matched by source_url) instead of always
    creating new ones. Deduplicated per knowledge base via QueueOnce.
    """
    try:
        max_kb.info(
            _('Start--->Start synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id))
        ForkManage(url, selector.split(" ") if selector is not None else []).fork(2, set(),
                                                                                 get_sync_handler(knowledge_id
                                                                                                  ))
        max_kb.info(_('End--->End synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id))
    except Exception as e:
        max_kb_error.error(_('Synchronize web knowledge base:{knowledge_id} error{error}{traceback}').format(
            knowledge_id=knowledge_id, error=str(e), traceback=traceback.format_exc()))
@celery_app.task(name='celery:sync_web_document')
def sync_web_document(knowledge_id, source_url_list: List[str], selector: str):
    """Fetch a fixed list of URLs and import each as a document.

    :param knowledge_id: target knowledge base id.
    :param source_url_list: URLs to fetch; each is processed independently.
    :param selector: whitespace-separated selector string for content extraction.
    """
    handler = get_sync_web_document_handler(knowledge_id)
    for source_url in source_url_list:
        try:
            result = Fork(base_fork_url=source_url, selector_list=selector.split(' ')).fork()
            handler(source_url, selector, result)
        except Exception as e:
            # One failed URL must not abort the rest, but the failure must not
            # be silent either (was: bare `pass`). Log like the sibling tasks.
            max_kb_error.error(_('Synchronize web knowledge base:{knowledge_id} error{error}{traceback}').format(
                knowledge_id=knowledge_id, error=str(e), traceback=traceback.format_exc()))

View File

@ -0,0 +1,114 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file tools.py
@date2024/8/20 21:48
@desc:
"""
import logging
import re
import traceback
from django.db.models import QuerySet
from common.utils.fork import ChildLink, Fork
from common.utils.split_model import get_split_model
from knowledge.models.knowledge import KnowledgeType, Document, DataSet, Status
from django.utils.translation import gettext_lazy as _
max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")
def get_save_handler(dataset_id, selector):
    """Return a crawl callback that creates a new WEB document per fetched page.

    :param dataset_id: target dataset/knowledge id the documents belong to.
    :param selector: selector string recorded in each document's meta.
    """
    from knowledge.serializers.document_serializers import DocumentSerializers

    def handler(child_link: ChildLink, response: Fork.Response):
        # Only successful fetches are imported; non-200 responses are skipped.
        if response.status == 200:
            try:
                # Prefer the link's anchor text as the document name; fall back
                # to the URL when the text is missing or blank.
                document_name = child_link.tag.text if child_link.tag is not None and len(
                    child_link.tag.text.strip()) > 0 else child_link.url
                paragraphs = get_split_model('web.md').parse(response.content)
                DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
                    {'name': document_name, 'paragraphs': paragraphs,
                     'meta': {'source_url': child_link.url, 'selector': selector},
                     'type': KnowledgeType.WEB}, with_valid=True)
            except Exception as e:
                # A bad page must not abort the crawl; log and continue.
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')

    return handler
def get_sync_handler(dataset_id):
    """Return a crawl callback that synchronizes pages into the dataset.

    Pages already imported (matched by meta.source_url) are re-synced via the
    document sync serializer; unseen pages are created as new WEB documents.

    :param dataset_id: id of the dataset being synchronized.
    """
    from knowledge.serializers.document_serializers import DocumentSerializers
    dataset = QuerySet(DataSet).filter(id=dataset_id).first()

    def handler(child_link: ChildLink, response: Fork.Response):
        if response.status == 200:
            try:
                # Anchor text as name, URL as fallback.
                document_name = child_link.tag.text if child_link.tag is not None and len(
                    child_link.tag.text.strip()) > 0 else child_link.url
                paragraphs = get_split_model('web.md').parse(response.content)
                first = QuerySet(Document).filter(meta__source_url=child_link.url.strip(),
                                                  dataset=dataset).first()
                if first is not None:
                    # Already imported: delegate to document-level sync.
                    DocumentSerializers.Sync(data={'document_id': first.id}).sync()
                else:
                    # New page: create the document.
                    # Fix: was `Type.web`, an undefined name (NameError at
                    # runtime); this module imports KnowledgeType and the
                    # sibling handlers use KnowledgeType.WEB.
                    DocumentSerializers.Create(data={'dataset_id': dataset.id}).save(
                        {'name': document_name, 'paragraphs': paragraphs,
                         'meta': {'source_url': child_link.url.strip(), 'selector': dataset.meta.get('selector')},
                         'type': KnowledgeType.WEB}, with_valid=True)
            except Exception as e:
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')

    return handler
def get_sync_web_document_handler(dataset_id):
    """Return a handler that imports a single already-known URL.

    On a 200 response the page is parsed and saved as a WEB document; on any
    other status an empty document in error status is recorded so the failure
    stays visible.
    """
    from knowledge.serializers.document_serializers import DocumentSerializers

    def handler(source_url: str, selector, response: Fork.Response):
        if response.status == 200:
            try:
                paragraphs = get_split_model('web.md').parse(response.content)
                # Name is the URL truncated to 128 chars (column limit).
                DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
                    {'name': source_url[0:128], 'paragraphs': paragraphs,
                     'meta': {'source_url': source_url, 'selector': selector},
                     'type': KnowledgeType.WEB}, with_valid=True)
            except Exception as e:
                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
        else:
            # Fetch failed: persist an empty error-status placeholder document.
            Document(name=source_url[0:128],
                     dataset_id=dataset_id,
                     meta={'source_url': source_url, 'selector': selector},
                     type=KnowledgeType.WEB,
                     char_length=0,
                     status=Status.error).save()

    return handler
def save_problem(dataset_id, document_id, paragraph_id, problem):
    """Extract a question from a model-generated string and attach it to a
    paragraph as an associated problem.

    `problem` arrives as e.g. ``1. <question>...</question>``; the numeric
    prefix is stripped and the <question> payload extracted. When no payload
    is found the call is a silent no-op. Association failures are logged and
    never propagated. (Removed leftover commented-out debug prints.)
    """
    from knowledge.serializers.paragraph_serializers import ParagraphSerializers
    # Strip a leading "1. "-style enumeration prefix.
    problem = re.sub(r"^\d+\.\s*", "", problem)
    pattern = r"<question>(.*?)</question>"
    match = re.search(pattern, problem)
    problem = match.group(1) if match else None
    if problem is None or len(problem) == 0:
        return
    try:
        ParagraphSerializers.Problem(
            data={"dataset_id": dataset_id, 'document_id': document_id,
                  'paragraph_id': paragraph_id}).save(instance={"content": problem}, with_valid=True)
    except Exception as e:
        max_kb_error.error(_('Association problem failed {error}').format(error=str(e)))

View File

@ -16,8 +16,8 @@ class KnowledgeView(APIView):
@extend_schema(
methods=['GET'],
description=_('Get knowledge by module'),
operation_id=_('Get knowledge by module'),
description=_('Get knowledge by folder'),
operation_id=_('Get knowledge by folder'),
parameters=KnowledgeTreeReadAPI.get_parameters(),
responses=KnowledgeTreeReadAPI.get_response(),
tags=[_('Knowledge Base')]
@ -26,7 +26,7 @@ class KnowledgeView(APIView):
def get(self, request: Request, workspace_id: str):
return result.success(KnowledgeTreeSerializer(
data={'workspace_id': workspace_id}
).get_knowledge_list(request.query_params.get('module_id')))
).get_knowledge_list(request.query_params.get('folder_id')))
class KnowledgeBaseView(APIView):
@ -45,7 +45,7 @@ class KnowledgeBaseView(APIView):
def post(self, request: Request, workspace_id: str):
return result.success(KnowledgeSerializer.Create(
data={'user_id': request.user.id, 'workspace_id': workspace_id}
).insert(request.data))
).save_base(request.data))
class KnowledgeWebView(APIView):
@ -64,4 +64,4 @@ class KnowledgeWebView(APIView):
def post(self, request: Request, workspace_id: str):
return result.success(KnowledgeSerializer.Create(
data={'user_id': request.user.id, 'workspace_id': workspace_id}
).insert(request.data))
).save_web(request.data))

View File

@ -24,7 +24,7 @@ urlpatterns = [
path("api/", include("users.urls")),
path("api/", include("tools.urls")),
path("api/", include("models_provider.urls")),
path("api/", include("modules.urls")),
path("api/", include("folders.urls")),
path("api/", include("knowledge.urls")),
]
urlpatterns += [

View File

@ -1,9 +0,0 @@
from django.urls import path
from . import views
app_name = "module"
urlpatterns = [
path('workspace/<str:workspace_id>/<str:source>/module', views.ModuleView.as_view()),
path('workspace/<str:workspace_id>/<str:source>/module/<str:module_id>', views.ModuleView.Operate.as_view()),
]

View File

@ -1 +0,0 @@
from .module import *

9
apps/ops/__init__.py Normal file
View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py.py
@date2024/8/16 14:47
@desc:
"""
from .celery import app as celery_app

View File

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
import os
from celery import Celery
from celery.schedules import crontab
from kombu import Exchange, Queue
from maxkb import settings
from .heartbeat import *
# set the default Django settings module for the 'celery' program.
# NOTE(review): was 'smartdoc.settings' — a leftover from the previous project
# name; this module imports settings from `maxkb`, so point at maxkb.settings.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'maxkb.settings')

app = Celery('MaxKB')

# Collect every CELERY* entry from the Django settings module.
configs = {k: v for k, v in settings.__dict__.items() if k.startswith('CELERY')}
configs['worker_concurrency'] = 5

# Using a string here means the worker will not have to
# pickle the object when using Windows.
# app.config_from_object('django.conf:settings', namespace='CELERY')
configs["task_queues"] = [
    Queue("celery", Exchange("celery"), routing_key="celery"),
    Queue("model", Exchange("model"), routing_key="model")
]

app.namespace = 'CELERY'
# Strip the CELERY_ prefix only when the remainder is a lowercase (new-style)
# option name; uppercase Django-style keys are passed through unchanged.
app.conf.update(
    {key.replace('CELERY_', '') if key.replace('CELERY_', '').lower() == key.replace('CELERY_',
                                                                                     '') else key: configs.get(
        key) for
        key
        in configs.keys()})

# Discover tasks.py in every installed Django app.
app.autodiscover_tasks(lambda: [app_config.split('.')[0] for app_config in settings.INSTALLED_APPS])

4
apps/ops/celery/const.py Normal file
View File

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
#
CELERY_LOG_MAGIC_MARK = b'\x00\x00\x00\x00\x00'

View File

@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
#
from functools import wraps
# Module-level registries, filled at import time by the decorators below and
# consumed by the worker lifecycle signal handlers.
_need_registered_period_tasks = []
_after_app_ready_start_tasks = []
_after_app_shutdown_clean_periodic_tasks = []


def add_register_period_task(task):
    """Queue a period-task spec (dict) for registration."""
    _need_registered_period_tasks.append(task)


def get_register_period_tasks():
    return _need_registered_period_tasks


def add_after_app_shutdown_clean_task(name):
    _after_app_shutdown_clean_periodic_tasks.append(name)


def get_after_app_shutdown_clean_tasks():
    return _after_app_shutdown_clean_periodic_tasks


def add_after_app_ready_task(name):
    _after_app_ready_start_tasks.append(name)


def get_after_app_ready_tasks():
    return _after_app_ready_start_tasks


def register_as_period_task(
        crontab=None, interval=None, name=None,
        args=(), kwargs=None,
        description=''):
    """
    Register the decorated function as a periodic task.

    Warning: Task must have not any args and kwargs
    :param crontab: "* * * * *"
    :param interval: 60*60*60
    :param args: ()
    :param kwargs: {}
    :param description: ""
    :param name: ""
    :return: decorator
    """
    if crontab is None and interval is None:
        raise SyntaxError("Must set crontab or interval one")

    def decorate(func):
        # (Removed a duplicated dead re-check of crontab/interval here — the
        # guard above already raised before this decorator could run.)
        # Because when this decorator runs the task object was not created
        # yet, so we can't use func.name; derive the dotted task path instead.
        task = '{func.__module__}.{func.__name__}'.format(func=func)
        _name = name if name else task
        add_register_period_task({
            _name: {
                'task': task,
                'interval': interval,
                'crontab': crontab,
                'args': args,
                'kwargs': kwargs if kwargs else {},
                'description': description
            }
        })

        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper

    return decorate


def after_app_ready_start(func):
    """Mark `func`'s task to be fired once when the worker reports ready."""
    # Task object not created yet, so use the dotted function path as the name.
    name = '{func.__module__}.{func.__name__}'.format(func=func)
    if name not in _after_app_ready_start_tasks:
        add_after_app_ready_task(name)

    @wraps(func)
    def decorate(*args, **kwargs):
        return func(*args, **kwargs)

    return decorate


def after_app_shutdown_clean_periodic(func):
    """Mark `func`'s periodic task for deletion when the worker shuts down."""
    name = '{func.__module__}.{func.__name__}'.format(func=func)
    if name not in _after_app_shutdown_clean_periodic_tasks:
        add_after_app_shutdown_clean_task(name)

    @wraps(func)
    def decorate(*args, **kwargs):
        return func(*args, **kwargs)

    return decorate

View File

@ -0,0 +1,25 @@
from pathlib import Path
from celery.signals import heartbeat_sent, worker_ready, worker_shutdown
@heartbeat_sent.connect
def heartbeat(sender, **kwargs):
    """Touch a per-worker heartbeat file in /tmp so liveness can be probed."""
    worker_name = sender.eventer.hostname.split('@')[0]
    heartbeat_path = Path('/tmp/worker_heartbeat_{}'.format(worker_name))
    heartbeat_path.touch()


@worker_ready.connect
def on_worker_ready(sender, **kwargs):
    """Create the per-worker ready marker file.

    Renamed from `worker_ready`: the original definition shadowed the celery
    signal of the same name imported at the top of this file.
    """
    worker_name = sender.hostname.split('@')[0]
    ready_path = Path('/tmp/worker_ready_{}'.format(worker_name))
    ready_path.touch()


@worker_shutdown.connect
def on_worker_shutdown(sender, **kwargs):
    """Remove this worker's ready and heartbeat marker files on shutdown.

    Renamed from `worker_shutdown` for the same shadowing reason as above.
    """
    worker_name = sender.hostname.split('@')[0]
    for signal in ['ready', 'heartbeat']:
        path = Path('/tmp/worker_{}_{}'.format(signal, worker_name))
        path.unlink(missing_ok=True)

225
apps/ops/celery/logger.py Normal file
View File

@ -0,0 +1,225 @@
from logging import StreamHandler
from threading import get_ident
from celery import current_task
from celery.signals import task_prerun, task_postrun
from django.conf import settings
from kombu import Connection, Exchange, Queue, Producer
from kombu.mixins import ConsumerMixin
from .utils import get_celery_task_log_path
from .const import CELERY_LOG_MAGIC_MARK
routing_key = 'celery_log'
celery_log_exchange = Exchange('celery_log_exchange', type='direct')
celery_log_queue = [Queue('celery_log', celery_log_exchange, routing_key=routing_key)]
class CeleryLoggerConsumer(ConsumerMixin):
    """Consumes task-log messages from the `celery_log` queue.

    The handle_* hooks are intentionally empty; subclasses override them to
    persist or forward log lines for a given task_id.
    """

    def __init__(self):
        self.connection = Connection(settings.CELERY_LOG_BROKER_URL)

    def get_consumers(self, Consumer, channel):
        # Accept both pickle and json since the producer publishes json.
        return [Consumer(queues=celery_log_queue,
                         accept=['pickle', 'json'],
                         callbacks=[self.process_task])
                ]

    def handle_task_start(self, task_id, message):
        pass

    def handle_task_end(self, task_id, message):
        pass

    def handle_task_log(self, task_id, msg, message):
        pass

    def process_task(self, body, message):
        # Dispatch on the action code written by CeleryLoggerProducer.
        action = body.get('action')
        task_id = body.get('task_id')
        msg = body.get('msg')
        if action == CeleryLoggerProducer.ACTION_TASK_LOG:
            self.handle_task_log(task_id, msg, message)
        elif action == CeleryLoggerProducer.ACTION_TASK_START:
            self.handle_task_start(task_id, message)
        elif action == CeleryLoggerProducer.ACTION_TASK_END:
            self.handle_task_end(task_id, message)
class CeleryLoggerProducer:
    """Publishes task lifecycle/log events to the `celery_log` exchange."""
    # Wire-format action codes shared with CeleryLoggerConsumer.
    ACTION_TASK_START, ACTION_TASK_LOG, ACTION_TASK_END = range(3)

    def __init__(self):
        self.connection = Connection(settings.CELERY_LOG_BROKER_URL)

    @property
    def producer(self):
        # A fresh Producer per access; the underlying connection is reused.
        return Producer(self.connection)

    def publish(self, payload):
        self.producer.publish(
            payload, serializer='json', exchange=celery_log_exchange,
            declare=[celery_log_exchange], routing_key=routing_key
        )

    def log(self, task_id, msg):
        """Publish one log line for the given task."""
        payload = {'task_id': task_id, 'msg': msg, 'action': self.ACTION_TASK_LOG}
        return self.publish(payload)

    def read(self):
        pass

    def flush(self):
        # No producer-side buffering; kept so handlers can call flush() uniformly.
        pass

    def task_end(self, task_id):
        payload = {'task_id': task_id, 'action': self.ACTION_TASK_END}
        return self.publish(payload)

    def task_start(self, task_id):
        payload = {'task_id': task_id, 'action': self.ACTION_TASK_START}
        return self.publish(payload)
class CeleryTaskLoggerHandler(StreamHandler):
    """Logging handler scoped to the currently-running Celery task.

    Subscribes to task_prerun/task_postrun so subclasses can open and close a
    per-task log destination; emit() routes each record by the task's root_id.
    The write/handle hooks are no-ops here and are overridden by subclasses.
    """
    terminator = '\r\n'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        task_prerun.connect(self.on_task_start)
        task_postrun.connect(self.on_start_end)

    @staticmethod
    def get_current_task_id():
        # root_id groups sub-task output under the originating task's log.
        if not current_task:
            return
        task_id = current_task.request.root_id
        return task_id

    def on_task_start(self, sender, task_id, **kwargs):
        return self.handle_task_start(task_id)

    def on_start_end(self, sender, task_id, **kwargs):
        return self.handle_task_end(task_id)

    def after_task_publish(self, sender, body, **kwargs):
        pass

    def emit(self, record):
        # Records emitted outside a task context are silently dropped.
        task_id = self.get_current_task_id()
        if not task_id:
            return
        try:
            self.write_task_log(task_id, record)
            self.flush()
        except Exception:
            self.handleError(record)

    def write_task_log(self, task_id, msg):
        pass

    def handle_task_start(self, task_id):
        pass

    def handle_task_end(self, task_id):
        pass
class CeleryThreadingLoggerHandler(CeleryTaskLoggerHandler):
    """Variant that routes log records by worker thread id instead of task id."""

    @staticmethod
    def get_current_thread_id():
        return str(get_ident())

    def emit(self, record):
        thread_id = self.get_current_thread_id()
        try:
            self.write_thread_task_log(thread_id, record)
            self.flush()
        except ValueError:
            self.handleError(record)

    def write_thread_task_log(self, thread_id, msg):
        pass

    def handle_task_start(self, task_id):
        pass

    def handle_task_end(self, task_id):
        pass

    def handleError(self, record) -> None:
        # Swallow logging errors entirely: a broken log sink must never break
        # the task being logged.
        pass
class CeleryTaskMQLoggerHandler(CeleryTaskLoggerHandler):
    """Ships per-task log records to the message queue via CeleryLoggerProducer."""

    def __init__(self):
        self.producer = CeleryLoggerProducer()
        # stream=None: records go to the producer, not to a stream.
        super().__init__(stream=None)

    def write_task_log(self, task_id, record):
        msg = self.format(record)
        self.producer.log(task_id, msg)

    def flush(self):
        self.producer.flush()
class CeleryTaskFileHandler(CeleryTaskLoggerHandler):
    """Writes each task's log records to its own file (text mode, append)."""

    def __init__(self, *args, **kwargs):
        # Opened lazily in handle_task_start; None until a task begins.
        self.f = None
        super().__init__(*args, **kwargs)

    def emit(self, record):
        msg = self.format(record)
        # Drop records arriving before handle_task_start or after close.
        if not self.f or self.f.closed:
            return
        self.f.write(msg)
        self.f.write(self.terminator)
        self.flush()

    def flush(self):
        self.f and self.f.flush()

    def handle_task_start(self, task_id):
        log_path = get_celery_task_log_path(task_id)
        self.f = open(log_path, 'a')

    def handle_task_end(self, task_id):
        self.f and self.f.close()
class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler):
    """Per-thread task log files for a threaded worker pool.

    Each starting task opens a binary log file keyed by the current worker
    thread id; on task end the file is terminated with CELERY_LOG_MAGIC_MARK
    so readers can detect completion. (Removed leftover debug print() calls
    from handle_task_start/handle_task_end.)
    """

    def __init__(self, *args, **kwargs):
        # thread id -> open file object; task id -> owning thread id.
        self.thread_id_fd_mapper = {}
        self.task_id_thread_id_mapper = {}
        super().__init__(*args, **kwargs)

    def write_thread_task_log(self, thread_id, record):
        f = self.thread_id_fd_mapper.get(thread_id, None)
        if not f:
            raise ValueError('Not found thread task file')
        msg = self.format(record)
        f.write(msg.encode())
        f.write(self.terminator.encode())
        f.flush()

    def flush(self):
        for f in self.thread_id_fd_mapper.values():
            f.flush()

    def handle_task_start(self, task_id):
        log_path = get_celery_task_log_path(task_id)
        thread_id = self.get_current_thread_id()
        self.task_id_thread_id_mapper[task_id] = thread_id
        f = open(log_path, 'ab')
        self.thread_id_fd_mapper[thread_id] = f

    def handle_task_end(self, task_id):
        ident_id = self.task_id_thread_id_mapper.get(task_id, '')
        f = self.thread_id_fd_mapper.pop(ident_id, None)
        if f and not f.closed:
            f.write(CELERY_LOG_MAGIC_MARK)
            f.close()
        self.task_id_thread_id_mapper.pop(task_id, None)

View File

@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
#
import logging
import os
from celery import subtask
from celery.signals import (
worker_ready, worker_shutdown, after_setup_logger, task_revoked, task_prerun
)
from django.core.cache import cache
from django_celery_beat.models import PeriodicTask
from .decorator import get_after_app_ready_tasks, get_after_app_shutdown_clean_tasks
from .logger import CeleryThreadTaskFileHandler
logger = logging.getLogger(__file__)
safe_str = lambda x: x
@worker_ready.connect
def on_app_ready(sender=None, headers=None, **kwargs):
    """On worker start, fire every registered 'after app ready' task once.

    The 10-second cache flag debounces the signal when several workers come
    up at the same time, so the tasks run only once per restart.
    """
    if cache.get("CELERY_APP_READY", 0) == 1:
        return
    cache.set("CELERY_APP_READY", 1, 10)
    tasks = get_after_app_ready_tasks()
    logger.debug("Work ready signal recv")
    logger.debug("Start need start task: [{}]".format(", ".join(tasks)))
    for task in tasks:
        # Respect a disabled PeriodicTask entry: skip instead of firing.
        periodic_task = PeriodicTask.objects.filter(task=task).first()
        if periodic_task and not periodic_task.enabled:
            logger.debug("Periodic task [{}] is disabled!".format(task))
            continue
        subtask(task).delay()
def delete_files(directory):
    """Remove every regular file directly inside *directory*.

    Subdirectories (and their contents) are left untouched. A path that is
    not an existing directory is silently ignored.
    """
    if not os.path.isdir(directory):
        return
    for entry in os.listdir(directory):
        full_path = os.path.join(directory, entry)
        if os.path.isfile(full_path):
            os.remove(full_path)
@worker_shutdown.connect
def after_app_shutdown_periodic_tasks(sender=None, **kwargs):
    """Delete registered cleanup periodic tasks when a worker shuts down.

    Mirrors on_app_ready: a 10-second cache flag ensures the cleanup runs
    only once even if several workers stop together.
    """
    if cache.get("CELERY_APP_SHUTDOWN", 0) == 1:
        return
    cache.set("CELERY_APP_SHUTDOWN", 1, 10)
    cleanup_names = get_after_app_shutdown_clean_tasks()
    logger.debug("Worker shutdown signal recv")
    logger.debug("Clean period tasks: [{}]".format(', '.join(cleanup_names)))
    PeriodicTask.objects.filter(name__in=cleanup_names).delete()
@after_setup_logger.connect
def add_celery_logger_handler(sender=None, logger=None, loglevel=None, format=None, **kwargs):
    """Attach the per-task file handler to celery's logger after its setup.

    NOTE: ``format`` shadows the builtin, but it is the keyword argument
    name celery supplies with this signal, so it cannot be renamed.
    """
    if not logger:
        return
    handler = CeleryThreadTaskFileHandler()
    handler.setLevel(loglevel)
    handler.setFormatter(logging.Formatter(format))
    logger.addHandler(handler)
@task_revoked.connect
def on_task_revoked(request, terminated, signum, expired, **kwargs):
    # Debug-only hook: traces whether the revoked task was actually terminated.
    print('task_revoked', terminated)
@task_prerun.connect
def on_taskaa_start(sender, task_id, **kwargs):
    # Placeholder pre-run hook (name looks like a typo for on_task_start —
    # NOTE(review): confirm before renaming; it is only bound via the signal).
    # The disabled snippet below would mark the task REVOKED with a
    # "pause task" message; kept for reference.
    pass
    # sender.update_state(state='REVOKED',
    #                     meta={'exc_type': 'Exception', 'exc': 'Exception', 'message': '暂停任务', 'exc_message': ''})

68
apps/ops/celery/utils.py Normal file
View File

@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
#
import logging
import os
import uuid
from django.conf import settings
from django_celery_beat.models import (
PeriodicTasks
)
from maxkb.const import PROJECT_DIR
logger = logging.getLogger(__file__)
def disable_celery_periodic_task(task_name):
    """Disable the periodic task named *task_name* and flag the schedule as changed."""
    from django_celery_beat.models import PeriodicTask
    matched = PeriodicTask.objects.filter(name=task_name)
    matched.update(enabled=False)
    PeriodicTasks.update_changed()
def delete_celery_periodic_task(task_name):
    """Delete the periodic task named *task_name* and flag the schedule as changed."""
    from django_celery_beat.models import PeriodicTask
    matched = PeriodicTask.objects.filter(name=task_name)
    matched.delete()
    PeriodicTasks.update_changed()
def get_celery_periodic_task(task_name):
    """Return the first PeriodicTask named *task_name*, or None if absent."""
    from django_celery_beat.models import PeriodicTask
    return PeriodicTask.objects.filter(name=task_name).first()
def make_dirs(name, mode=0o755, exist_ok=False):
    """Create directory *name* like os.makedirs, defaulting permissions to 0o755."""
    os.makedirs(name, mode=mode, exist_ok=exist_ok)
def get_task_log_path(base_path, task_id, level=2):
    """Build (and create the directory for) the log file path of *task_id*.

    The first *level* characters of the task UUID are sharded into nested
    directories (e.g. ``<base>/a/b/ab12....log``) so one flat directory does
    not accumulate every log file. A task id that is not a valid UUID falls
    back to a static ``caution.txt`` path instead of being interpolated into
    the filesystem layout.
    """
    task_id = str(task_id)
    try:
        uuid.UUID(task_id)
    except ValueError:
        # Fix: was a bare ``except:`` which also swallowed SystemExit /
        # KeyboardInterrupt; uuid.UUID raises ValueError on a malformed string.
        return os.path.join(PROJECT_DIR, 'data', 'caution.txt')
    rel_path = os.path.join(*task_id[:level], task_id + '.log')
    path = os.path.join(base_path, rel_path)
    make_dirs(os.path.dirname(path), exist_ok=True)
    return path
def get_celery_task_log_path(task_id):
    # Convenience wrapper: celery task logs live under settings.CELERY_LOG_DIR.
    return get_task_log_path(settings.CELERY_LOG_DIR, task_id)
def get_celery_status(expected_workers=2):
    """Ping celery workers and report whether enough of them responded.

    Returns True when at least *expected_workers* distinct worker names
    answered the ping with 'pong', False otherwise. The previous hard-coded
    count of 2 is now the default of a backward-compatible parameter.
    """
    from . import app
    inspector = app.control.inspect()
    ping_data = inspector.ping() or {}
    active_nodes = [node for node, reply in ping_data.items() if reply.get('ok') == 'pong']
    # Node names look like '<worker>@<host>'; count distinct worker names.
    active_queue_worker = {node.split('@')[0] for node in active_nodes if node}
    if len(active_queue_worker) < expected_workers:
        print("Not all celery worker worked")
        return False
    return True

View File

@ -84,8 +84,8 @@ class ToolTreeReadAPI(APIMixin):
required=True,
),
OpenApiParameter(
name="module_id",
description="模块id",
name="folder_id",
description="文件夹id",
type=OpenApiTypes.STR,
location='query',
required=False,
@ -178,8 +178,8 @@ class ToolPageAPI(ToolReadAPI):
required=True,
),
OpenApiParameter(
name="module_id",
description="模块id",
name="folder_id",
description="文件夹id",
type=OpenApiTypes.STR,
location='query',
required=True,

View File

@ -5,12 +5,12 @@ import mptt.fields
import uuid_utils.compat
from django.db import migrations, models
from tools.models import ToolModule
from tools.models import ToolFolder
def insert_default_data(apps, schema_editor):
# 创建一个根模块(没有父节点)
ToolModule.objects.create(id='root', name='根目录', user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab')
ToolFolder.objects.create(id='root', name='根目录', user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab')
class Migration(migrations.Migration):
@ -22,7 +22,7 @@ class Migration(migrations.Migration):
operations = [
migrations.CreateModel(
name='ToolModule',
name='ToolFolder',
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
@ -37,12 +37,12 @@ class Migration(migrations.Migration):
('level', models.PositiveIntegerField(editable=False)),
('parent',
mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE,
related_name='children', to='tools.toolmodule')),
related_name='children', to='tools.toolfolder')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
verbose_name='用户id')),
],
options={
'db_table': 'tool_module',
'db_table': 'tool_folder',
},
),
migrations.RunPython(insert_default_data),
@ -72,9 +72,9 @@ class Migration(migrations.Migration):
('init_params', models.CharField(max_length=102400, null=True, verbose_name='初始化参数')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
verbose_name='用户id')),
('module',
models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE, to='tools.toolmodule',
verbose_name='模块id')),
('folder',
models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE, to='tools.toolfolder',
verbose_name='文件夹id')),
],
options={
'db_table': 'tool',

View File

@ -7,7 +7,7 @@ from common.mixins.app_model_mixin import AppModelMixin
from users.models import User
class ToolModule(MPTTModel, AppModelMixin):
class ToolFolder(MPTTModel, AppModelMixin):
id = models.CharField(primary_key=True, max_length=64, editable=False, verbose_name="主键id")
name = models.CharField(max_length=64, verbose_name="文件夹名称")
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="用户id")
@ -15,7 +15,7 @@ class ToolModule(MPTTModel, AppModelMixin):
parent = TreeForeignKey('self', on_delete=models.CASCADE, null=True, blank=True, related_name='children')
class Meta:
db_table = "tool_module"
db_table = "tool_folder"
class MPTTMeta:
order_insertion_by = ['name']
@ -46,7 +46,7 @@ class Tool(AppModelMixin):
tool_type = models.CharField(max_length=20, verbose_name='工具类型', choices=ToolType.choices,
default=ToolType.CUSTOM, db_index=True)
template_id = models.UUIDField(max_length=128, verbose_name="模版id", null=True, default=None)
module = models.ForeignKey(ToolModule, on_delete=models.CASCADE, verbose_name="模块id", default='root')
folder = models.ForeignKey(ToolFolder, on_delete=models.CASCADE, verbose_name="文件夹id", default='root')
workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True)
init_params = models.CharField(max_length=102400, verbose_name="初始化参数", null=True)

View File

@ -17,7 +17,7 @@ from common.exception.app_exception import AppApiException
from common.result import result
from common.utils.tool_code import ToolExecutor
from maxkb.const import CONFIG
from tools.models import Tool, ToolScope, ToolModule
from tools.models import Tool, ToolScope, ToolFolder
tool_executor = ToolExecutor(CONFIG.get('SANDBOX'))
@ -37,11 +37,11 @@ ALLOWED_CLASSES = {
class RestrictedUnpickler(pickle.Unpickler):
def find_class(self, module, name):
if (module, name) in ALLOWED_CLASSES:
return super().find_class(module, name)
def find_class(self, folder, name):
if (folder, name) in ALLOWED_CLASSES:
return super().find_class(folder, name)
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
(module, name))
(folder, name))
def encryption(message: str):
@ -75,7 +75,7 @@ class ToolModelSerializer(serializers.ModelSerializer):
class Meta:
model = Tool
fields = ['id', 'name', 'icon', 'desc', 'code', 'input_field_list', 'init_field_list', 'init_params',
'scope', 'is_active', 'user_id', 'template_id', 'workspace_id', 'module_id', 'tool_type',
'scope', 'is_active', 'user_id', 'template_id', 'workspace_id', 'folder_id', 'tool_type',
'create_time', 'update_time']
@ -126,7 +126,7 @@ class ToolCreateRequest(serializers.Serializer):
is_active = serializers.BooleanField(required=False, label=_('Is active'))
module_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root')
folder_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root')
class ToolEditRequest(serializers.Serializer):
@ -146,7 +146,7 @@ class ToolEditRequest(serializers.Serializer):
is_active = serializers.BooleanField(required=False, label=_('Is active'))
module_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root')
folder_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root')
class DebugField(serializers.Serializer):
@ -180,7 +180,7 @@ class ToolSerializer(serializers.Serializer):
input_field_list=instance.get('input_field_list', []),
init_field_list=instance.get('init_field_list', []),
scope=ToolScope.WORKSPACE,
module_id=instance.get('module_id', 'root'),
folder_id=instance.get('folder_id', 'root'),
is_active=False)
tool.save()
return ToolModelSerializer(tool).data
@ -321,42 +321,42 @@ class ToolSerializer(serializers.Serializer):
class ToolTreeSerializer(serializers.Serializer):
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
def get_tools(self, module_id):
def get_tools(self, folder_id):
self.is_valid(raise_exception=True)
if not module_id:
module_id = 'root'
root = ToolModule.objects.filter(id=module_id).first()
if not folder_id:
folder_id = 'root'
root = ToolFolder.objects.filter(id=folder_id).first()
if not root:
raise serializers.ValidationError(_('Module not found'))
raise serializers.ValidationError(_('Folder not found'))
# 使用MPTT的get_descendants()方法获取所有相关节点
all_modules = root.get_descendants(include_self=True)
all_folders = root.get_descendants(include_self=True)
tools = QuerySet(Tool).filter(workspace_id=self.data.get('workspace_id'), module_id__in=all_modules)
tools = QuerySet(Tool).filter(workspace_id=self.data.get('workspace_id'), folder_id__in=all_folders)
return ToolModelSerializer(tools, many=True).data
class Query(serializers.Serializer):
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
module_id = serializers.CharField(required=True, label=_('module id'))
folder_id = serializers.CharField(required=True, label=_('folder id'))
name = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('tool name'))
tool_type = serializers.CharField(required=True, label=_('tool type'))
def page(self, current_page: int, page_size: int):
self.is_valid(raise_exception=True)
module_id = self.data.get('module_id', 'root')
root = ToolModule.objects.filter(id=module_id).first()
folder_id = self.data.get('folder_id', 'root')
root = ToolFolder.objects.filter(id=folder_id).first()
if not root:
raise serializers.ValidationError(_('Module not found'))
raise serializers.ValidationError(_('Folder not found'))
# 使用MPTT的get_descendants()方法获取所有相关节点
all_modules = root.get_descendants(include_self=True)
all_folders = root.get_descendants(include_self=True)
if self.data.get('name'):
tools = QuerySet(Tool).filter(Q(workspace_id=self.data.get('workspace_id')) &
Q(module_id__in=all_modules) &
Q(folder_id__in=all_folders) &
Q(tool_type=self.data.get('tool_type')) &
Q(name__contains=self.data.get('name')))
else:
tools = QuerySet(Tool).filter(Q(workspace_id=self.data.get('workspace_id')) &
Q(module_id__in=all_modules) &
Q(folder_id__in=all_folders) &
Q(tool_type=self.data.get('tool_type')))
return page_search(current_page, page_size, tools, lambda record: ToolModelSerializer(record).data)

View File

@ -2,15 +2,15 @@
from rest_framework import serializers
from tools.models import ToolModule
from tools.models import ToolFolder
class ToolModuleTreeSerializer(serializers.ModelSerializer):
class ToolFolderTreeSerializer(serializers.ModelSerializer):
children = serializers.SerializerMethodField()
class Meta:
model = ToolModule
model = ToolFolder
fields = ['id', 'name', 'user_id', 'workspace_id', 'parent_id', 'children']
def get_children(self, obj):
return ToolModuleTreeSerializer(obj.get_children(), many=True).data
return ToolFolderTreeSerializer(obj.get_children(), many=True).data

View File

@ -33,8 +33,8 @@ class ToolView(APIView):
@extend_schema(
methods=['GET'],
description=_('Get tool by module'),
operation_id=_('Get tool by module'),
description=_('Get tool by folder'),
operation_id=_('Get tool by folder'),
parameters=ToolTreeReadAPI.get_parameters(),
responses=ToolTreeReadAPI.get_response(),
tags=[_('Tool')]
@ -43,7 +43,7 @@ class ToolView(APIView):
def get(self, request: Request, workspace_id: str):
return result.success(ToolTreeSerializer(
data={'workspace_id': workspace_id}
).get_tools(request.query_params.get('module_id')))
).get_tools(request.query_params.get('folder_id')))
class Debug(APIView):
authentication_classes = [TokenAuth]
@ -124,7 +124,7 @@ class ToolView(APIView):
return result.success(ToolTreeSerializer.Query(
data={
'workspace_id': workspace_id,
'module_id': request.query_params.get('module_id'),
'folder_id': request.query_params.get('folder_id'),
'name': request.query_params.get('name'),
'tool_type': request.query_params.get('tool_type'),
}

View File

@ -40,6 +40,12 @@ cffi = "1.17.1"
pysilk = "0.0.1"
sentence-transformers = "4.1.0"
websockets = "15.0.1"
celery = { extras = ["sqlalchemy"], version = "5.5.2" }
django-celery-beat = "2.8.0"
celery-once = "3.0.1"
beautifulsoup4 = "4.13.4"
html2text = "2025.4.15"
jieba = "0.42.1"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"