mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: implement web knowledge synchronization with ForkManage and related handlers
This commit is contained in:
parent
99fd32897c
commit
ee37d7c320
|
|
@ -166,13 +166,13 @@ class PermissionConstants(Enum):
|
|||
role_list=[RoleConstants.ADMIN, RoleConstants.USER])
|
||||
MODEL_DELETE = Permission(group=Group.MODEL, operate=Operate.DELETE,
|
||||
role_list=[RoleConstants.ADMIN, RoleConstants.USER])
|
||||
TOOL_MODULE_CREATE = Permission(group=Group.TOOL, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
|
||||
TOOL_FOLDER_CREATE = Permission(group=Group.TOOL, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
|
||||
RoleConstants.USER])
|
||||
TOOL_MODULE_READ = Permission(group=Group.TOOL, operate=Operate.READ, role_list=[RoleConstants.ADMIN,
|
||||
TOOL_FOLDER_READ = Permission(group=Group.TOOL, operate=Operate.READ, role_list=[RoleConstants.ADMIN,
|
||||
RoleConstants.USER])
|
||||
TOOL_MODULE_EDIT = Permission(group=Group.TOOL, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN,
|
||||
TOOL_FOLDER_EDIT = Permission(group=Group.TOOL, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN,
|
||||
RoleConstants.USER])
|
||||
TOOL_MODULE_DELETE = Permission(group=Group.TOOL, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN,
|
||||
TOOL_FOLDER_DELETE = Permission(group=Group.TOOL, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN,
|
||||
RoleConstants.USER])
|
||||
|
||||
TOOL_CREATE = Permission(group=Group.TOOL, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
|
||||
|
|
@ -190,20 +190,20 @@ class PermissionConstants(Enum):
|
|||
TOOL_EXPORT = Permission(group=Group.TOOL, operate=Operate.USE, role_list=[RoleConstants.ADMIN,
|
||||
RoleConstants.USER])
|
||||
|
||||
KNOWLEDGE_MODULE_CREATE = Permission(group=Group.KNOWLEDGE, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
|
||||
KNOWLEDGE_FOLDER_CREATE = Permission(group=Group.KNOWLEDGE, operate=Operate.CREATE, role_list=[RoleConstants.ADMIN,
|
||||
RoleConstants.USER])
|
||||
KNOWLEDGE_MODULE_READ = Permission(group=Group.KNOWLEDGE, operate=Operate.READ, role_list=[RoleConstants.ADMIN,
|
||||
KNOWLEDGE_FOLDER_READ = Permission(group=Group.KNOWLEDGE, operate=Operate.READ, role_list=[RoleConstants.ADMIN,
|
||||
RoleConstants.USER],
|
||||
resource_permission_group_list=[
|
||||
ResourcePermissionGroup.VIEW
|
||||
])
|
||||
KNOWLEDGE_MODULE_EDIT = Permission(group=Group.KNOWLEDGE, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN,
|
||||
KNOWLEDGE_FOLDER_EDIT = Permission(group=Group.KNOWLEDGE, operate=Operate.EDIT, role_list=[RoleConstants.ADMIN,
|
||||
RoleConstants.USER],
|
||||
resource_permission_group_list=[
|
||||
ResourcePermissionGroup.MANAGE
|
||||
]
|
||||
)
|
||||
KNOWLEDGE_MODULE_DELETE = Permission(group=Group.KNOWLEDGE, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN,
|
||||
KNOWLEDGE_FOLDER_DELETE = Permission(group=Group.KNOWLEDGE, operate=Operate.DELETE, role_list=[RoleConstants.ADMIN,
|
||||
RoleConstants.USER],
|
||||
resource_permission_group_list=[
|
||||
ResourcePermissionGroup.MANAGE
|
||||
|
|
|
|||
|
|
@ -0,0 +1,178 @@
|
|||
import copy
|
||||
import logging
|
||||
import re
|
||||
import traceback
|
||||
from functools import reduce
|
||||
from typing import List, Set
|
||||
from urllib.parse import urljoin, urlparse, ParseResult, urlsplit, urlunparse
|
||||
|
||||
import html2text as ht
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
requests.packages.urllib3.disable_warnings()
|
||||
|
||||
|
||||
class ChildLink:
    """A discovered hyperlink: target URL plus a snapshot of its source tag."""

    def __init__(self, url, tag):
        self.url = url
        # Deep-copy the tag so later mutations of the soup cannot alter
        # this record.
        self.tag = copy.deepcopy(tag)
|
||||
|
||||
|
||||
class ForkManage:
    """Crawl a base URL and its same-site child links, invoking a handler per page."""

    def __init__(self, base_url: str, selector_list: List[str]):
        # base_url: crawl root; selector_list: CSS-ish selectors for content extraction.
        self.base_url = base_url
        self.selector_list = selector_list

    def fork(self, level: int, exclude_link_url: Set[str], fork_handler):
        """Start the crawl at base_url.

        :param level: how many link levels deep to follow
        :param exclude_link_url: mutated in place to record visited URLs
        :param fork_handler: callable(child_link, Fork.Response) invoked per page
        """
        self.fork_child(ChildLink(self.base_url, None), self.selector_list, level, exclude_link_url, fork_handler)

    @staticmethod
    def _normalize(url: str) -> str:
        """Drop a single trailing '/' so equivalent URLs dedupe to one entry."""
        return url[:-1] if url.endswith('/') else url

    @staticmethod
    def fork_child(child_link: ChildLink, selector_list: List[str], level: int, exclude_link_url: Set[str],
                   fork_handler):
        """Fetch *child_link* and recurse into its children until *level* is exhausted.

        URLs are normalized (fragment removed, trailing '/' stripped) before
        the visited-set check so equivalent URLs are fetched only once.
        """
        if level < 0:
            return
        child_link.url = remove_fragment(child_link.url)
        child_url = ForkManage._normalize(child_link.url)
        # `in` instead of the original __contains__ dunder call.
        if child_url in exclude_link_url:
            return
        exclude_link_url.add(child_url)
        response = Fork(child_link.url, selector_list).fork()
        fork_handler(child_link, response)
        # Renamed from the original, which shadowed the `child_link` parameter.
        for sub_link in response.child_link_list:
            sub_url = ForkManage._normalize(sub_link.url)
            if sub_url not in exclude_link_url:
                ForkManage.fork_child(sub_link, selector_list, level - 1, exclude_link_url, fork_handler)
|
||||
|
||||
|
||||
def remove_fragment(url: str) -> str:
    """Return *url* with its #fragment removed; all other components are kept.

    :param url: any absolute or relative URL
    :return: the same URL without the fragment part
    """
    # _replace keeps every other component intact. '' (not None, as the
    # original passed) is the canonical "no fragment" value for a
    # ParseResult — the components are documented as strings.
    return urlunparse(urlparse(url)._replace(fragment=''))
|
||||
|
||||
|
||||
class Fork:
    """Fetch one page, extract its content (via selectors) and same-site child links."""

    class Response:
        # content: extracted text ('' on error)
        # child_link_list: same-site links found on the page
        # status: 200 on success, 500 on failure
        # message: error detail ('' on success)
        def __init__(self, content: str, child_link_list: List[ChildLink], status, message: str):
            self.content = content
            self.child_link_list = child_link_list
            self.status = status
            self.message = message

        @staticmethod
        def success(html_content: str, child_link_list: List[ChildLink]):
            """Build a 200 response."""
            return Fork.Response(html_content, child_link_list, 200, '')

        @staticmethod
        def error(message: str):
            """Build a 500 response carrying an error message."""
            return Fork.Response('', [], 500, message)

    def __init__(self, base_fork_url: str, selector_list: List[str]):
        """Normalize *base_fork_url* and precompute the site root.

        :param base_fork_url: page URL to fetch; fragment is discarded
        :param selector_list: '.class' / '#id' / tag-name selectors (empty entries dropped)
        """
        base_fork_url = remove_fragment(base_fork_url)
        # Resolve to the URL's "directory" via urljoin(..., '.'), then drop the
        # trailing slash. urljoin also drops the query string, so it is
        # captured from the original URL and re-appended below.
        self.base_fork_url = urljoin(base_fork_url if base_fork_url.endswith("/") else base_fork_url + '/', '.')
        parsed = urlsplit(base_fork_url)
        query = parsed.query
        self.base_fork_url = self.base_fork_url[:-1]
        if query is not None and len(query) > 0:
            self.base_fork_url = self.base_fork_url + '?' + query
        self.selector_list = [selector for selector in selector_list if selector is not None and len(selector) > 0]
        self.urlparse = urlparse(self.base_fork_url)
        # scheme://netloc with no path — used to absolutize relative hrefs.
        self.base_url = ParseResult(scheme=self.urlparse.scheme, netloc=self.urlparse.netloc, path='', params='',
                                    query='',
                                    fragment='').geturl()

    def get_child_link_list(self, bf: BeautifulSoup):
        """Collect <a> links that stay under base_fork_url.

        NOTE(review): base_fork_url is interpolated into the regex without
        re.escape, so '.' / '?' in the URL act as metacharacters — confirm
        this is acceptable for the URLs being crawled.
        """
        pattern = "^((?!(http:|https:|tel:/|#|mailto:|javascript:))|" + self.base_fork_url + "|/).*"
        link_list = bf.find_all(name='a', href=re.compile(pattern))
        # Absolutize hrefs that are not already under scheme://netloc.
        result = [ChildLink(link.get('href'), link) if link.get('href').startswith(self.base_url) else ChildLink(
            self.base_url + link.get('href'), link) for link in link_list]
        # Keep only links inside the fork root.
        result = [row for row in result if row.url.startswith(self.base_fork_url)]
        return result

    def get_content_html(self, bf: BeautifulSoup):
        """Extract the HTML matched by selector_list; with no selectors, the whole document.

        '.x' selects by class, '#x' by id, anything else by tag name.
        """
        if self.selector_list is None or len(self.selector_list) == 0:
            return str(bf)
        # Merge all selectors into a single find_all(**kwargs) call.
        params = reduce(lambda x, y: {**x, **y},
                        [{'class_': selector.replace('.', '')} if selector.startswith('.') else
                         {'id': selector.replace("#", "")} if selector.startswith("#") else {'name': selector} for
                         selector in
                         self.selector_list], {})
        f = bf.find_all(**params)
        return "\n".join([str(row) for row in f])

    @staticmethod
    def reset_url(tag, field, base_fork_url):
        """Rewrite a relative *field* (href/src) on *tag* into an absolute URL.

        :param tag: BeautifulSoup tag to mutate in place
        :param field: attribute name ('href' or 'src')
        :param base_fork_url: URL the relative value is resolved against
        """
        field_value: str = tag[field]
        if field_value.startswith("/"):
            # Site-absolute path: keep scheme/netloc, replace the path.
            result = urlparse(base_fork_url)
            result_url = ParseResult(scheme=result.scheme, netloc=result.netloc, path=field_value, params='', query='',
                                     fragment='').geturl()
        else:
            # Relative path: resolve against base, normalize via urljoin(..., '.'),
            # then drop the trailing slash.
            result_url = urljoin(
                base_fork_url + '/' + (field_value if field_value.endswith('/') else field_value + '/'),
                ".")
            result_url = result_url[:-1] if result_url.endswith('/') else result_url
        tag[field] = result_url

    def reset_beautiful_soup(self, bf: BeautifulSoup):
        """Absolutize every relative href/src attribute in the soup, in place."""
        reset_config_list = [
            {
                'field': 'href',
            },
            {
                'field': 'src',
            }
        ]
        for reset_config in reset_config_list:
            field = reset_config.get('field')
            # Match attribute values that are NOT already absolute/special schemes.
            tag_list = bf.find_all(**{field: re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*')})
            for tag in tag_list:
                self.reset_url(tag, field, self.base_fork_url)
        return bf

    @staticmethod
    def get_beautiful_soup(response):
        """Decode *response*, honoring a <meta charset> that differs from the HTTP header.

        :param response: a requests.Response
        :return: BeautifulSoup parsed with the best-guess encoding
        """
        # 'ISO-8859-1' is requests' fallback when the header names no charset —
        # prefer the content-sniffed apparent_encoding in that case.
        encoding = response.encoding if response.encoding is not None and response.encoding != 'ISO-8859-1' else response.apparent_encoding
        html_content = response.content.decode(encoding)
        beautiful_soup = BeautifulSoup(html_content, "html.parser")
        meta_list = beautiful_soup.find_all('meta')
        charset_list = [meta.attrs.get('charset') for meta in meta_list if
                        meta.attrs is not None and 'charset' in meta.attrs]
        if len(charset_list) > 0:
            charset = charset_list[0]
            if charset != encoding:
                # Re-decode with the document's own charset; on failure, log
                # and keep the first decoding (best effort).
                try:
                    html_content = response.content.decode(charset)
                except Exception as e:
                    logging.getLogger("max_kb").error(f'{e}')
                return BeautifulSoup(html_content, "html.parser")
        return beautiful_soup

    def fork(self):
        """Fetch base_fork_url and return a Response with extracted text and child links.

        Network/parse failures are caught and reported as an error Response,
        never raised.
        """
        try:

            headers = {
                'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/99.0.4844.51 Safari/537.36'
            }

            logging.getLogger("max_kb").info(f'fork:{self.base_fork_url}')
            # verify=False: TLS certificate verification is deliberately
            # disabled for arbitrary crawled sites (warnings silenced at import).
            response = requests.get(self.base_fork_url, verify=False, headers=headers)
            if response.status_code != 200:
                logging.getLogger("max_kb").error(f"url: {self.base_fork_url} code:{response.status_code}")
                return Fork.Response.error(f"url: {self.base_fork_url} code:{response.status_code}")
            bf = self.get_beautiful_soup(response)
        except Exception as e:
            logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
            return Fork.Response.error(str(e))
        bf = self.reset_beautiful_soup(bf)
        link_list = self.get_child_link_list(bf)
        content = self.get_content_html(bf)
        # Convert the extracted HTML to markdown-ish plain text.
        r = ht.html2text(content)
        return Fork.Response.success(r, link_list)
|
||||
|
||||
|
||||
def handler(base_url, response: Fork.Response):
    """Example fork handler: dump the visited link and page content to stdout."""
    tag_text = base_url.tag.text if base_url.tag else None
    print(base_url.url, tag_text, response.content)


# Example usage:
# ForkManage('https://bbs.fit2cloud.com/c/de/6', ['.md-content']).fork(3, set(), handler)
|
||||
|
|
@ -0,0 +1,417 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: qabot
|
||||
@Author:虎
|
||||
@file: split_model.py
|
||||
@date:2023/9/1 15:12
|
||||
@desc:
|
||||
"""
|
||||
import re
|
||||
from functools import reduce
|
||||
from typing import List, Dict
|
||||
|
||||
import jieba
|
||||
|
||||
|
||||
def get_level_block(text, level_content_list, level_content_index, cursor):
    """Slice out the block of *text* that belongs to one title.

    :param text: full source text
    :param level_content_list: parsed title entries ({'content': ..., 'state': ...})
    :param level_content_index: index of the title whose block is wanted
    :param cursor: position in *text* to start searching from
    :return: (block_text, end_index) — the text between this title and the
        next one (or end of text), and where the scan stopped
    """
    start_content: str = level_content_list[level_content_index].get('content')
    next_content = level_content_list[level_content_index + 1].get("content") if level_content_index + 1 < len(
        level_content_list) else None
    # Locate this title at/after the cursor, then the next title after it;
    # str.index raises ValueError if the title text is not found.
    start_index = text.index(start_content, cursor)
    end_index = text.index(next_content, start_index + 1) if next_content is not None else len(text)
    return text[start_index + len(start_content):end_index], end_index
|
||||
|
||||
|
||||
def to_tree_obj(content, state='title'):
    """Wrap *content* in a tree-node dict.

    :param content: node text
    :param state: node kind, 'title' or 'block'
    :return: {'content': ..., 'state': ...}
    """
    return dict(content=content, state=state)
|
||||
|
||||
|
||||
def remove_special_symbol(str_source: str):
    """Hook for stripping special characters from text.

    Currently a no-op that returns the input unchanged.

    :param str_source: text to clean
    :return: the cleaned text
    """
    return str_source


def filter_special_symbol(content: dict):
    """Run the special-character filter over a tree node's 'content' field.

    :param content: node dict, mutated in place
    :return: the same node
    """
    content['content'] = remove_special_symbol(content['content'])
    return content
|
||||
|
||||
|
||||
def flat(tree_data_list: List[dict], parent_chain: List[dict], result: List[dict]):
    """Flatten a tree of title/block nodes into *result*, depth-first.

    :param tree_data_list: tree nodes ({'content', 'state', optional 'children'})
    :param parent_chain: ancestor chain so far — pass [] at the top level
    :param result: accumulator — pass [] at the top level
    :return: *result*, one flat record per node (with its parent chain and level)
    """
    if parent_chain is None:
        parent_chain = []
    if result is None:
        result = []
    for tree_data in tree_data_list:
        # Children get a copied chain that includes this node; the node's own
        # flat record keeps the chain WITHOUT itself.
        p = parent_chain.copy()
        p.append(tree_data)
        result.append(to_flat_obj(parent_chain, content=tree_data["content"], state=tree_data["state"]))
        children = tree_data.get('children')
        if children is not None and len(children) > 0:
            flat(children, p, result)
    return result
|
||||
|
||||
|
||||
def to_paragraph(obj: dict):
    """Convert a flat record into a paragraph dict.

    :param obj: flat record with 'content' and 'parent_chain'
    :return: {'keywords', 'parent_chain', 'content'} — content is prefixed with
        the comma-joined ancestor title texts
    """
    content = obj['content']
    return {"keywords": get_keyword(content),
            'parent_chain': list(map(lambda p: p['content'], obj['parent_chain'])),
            'content': ",".join(list(map(lambda p: p['content'], obj['parent_chain']))) + content}
|
||||
|
||||
|
||||
def get_keyword(content: str):
    """Extract keywords from *content* via jieba word segmentation.

    Stop words and single-character tokens are dropped; duplicates removed.

    :param content: text to segment
    :return: de-duplicated keyword list (order unspecified)
    """
    stopwords = [':', '“', '!', '”', '\n', '\\s']
    cutworms = jieba.lcut(content)
    # BUG FIX: the original used bitwise '|', which binds tighter than '>',
    # so `(k not in stopwords) | len(k) > 1` parsed as
    # `((k not in stopwords) | len(k)) > 1` — dropping single-character
    # non-stopwords and keeping multi-character stopwords. Use the intended
    # boolean filter instead: keep tokens that are not stop words and are
    # longer than one character.
    return list({k for k in cutworms if k not in stopwords and len(k) > 1})
|
||||
|
||||
|
||||
def titles_to_paragraph(list_title: List[dict]):
    """Merge sibling titles (same parent chain) into one paragraph-like block.

    :param list_title: flat title records that share the same parent chain
    :return: {'keywords', 'parent_chain', 'content'} dict, or None when empty
    """
    if len(list_title) > 0:
        # NOTE(review): .strip("\\s") strips the literal characters '\' and
        # 's', not whitespace — looks like a regex habit leaking into
        # str.strip; confirm intent before changing.
        content = "\n,".join(
            list(map(lambda d: d['content'].strip("\r\n").strip("\n").strip("\\s"), list_title)))

        # All entries share a parent chain, so the first entry's chain stands
        # in for the group's.
        return {'keywords': '',
                'parent_chain': list(
                    map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"), list_title[0]['parent_chain'])),
                'content': ",".join(list(
                    map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"),
                        list_title[0]['parent_chain']))) + content}
    return None
|
||||
|
||||
|
||||
def parse_group_key(level_list: List[dict]):
    """Build paragraphs for one tree level.

    Sibling titles (same parent chain) are merged into one paragraph; block
    records each become their own paragraph.

    :param level_list: flat records sharing the same depth
    :return: merged-title paragraphs plus per-block paragraphs
    """
    result = []
    # Group titles by their full parent chain so siblings merge together;
    # root-level titles (empty chain) are excluded.
    group_data = group_by(list(filter(lambda f: f['state'] == 'title' and len(f['parent_chain']) > 0, level_list)),
                          key=lambda d: ",".join(list(map(lambda p: p['content'], d['parent_chain']))))
    result += list(map(lambda group_data_key: titles_to_paragraph(group_data[group_data_key]), group_data))
    result += list(map(to_paragraph, list(filter(lambda f: f['state'] == 'block', level_list))))
    return result
|
||||
|
||||
|
||||
def to_block_paragraph(tree_data_list: List[dict]):
    """Convert a parse tree into block paragraphs, grouped by depth level.

    :param tree_data_list: parse tree
    :return: one paragraph list per depth level
    """
    flat_list = flat(tree_data_list, [], [])
    level_group_dict: dict = group_by(flat_list, key=lambda f: f['level'])
    return list(map(lambda level: parse_group_key(level_group_dict[level]), level_group_dict))
|
||||
|
||||
|
||||
def parse_title_level(text, content_level_pattern: List, index):
    """Try each title pattern from *index* on; return the first non-empty parse.

    :param text: text to scan
    :param content_level_pattern: ordered regex list, one per title level
    :param index: pattern position to start from
    :return: parsed title nodes, or [] if no pattern matches
    """
    for pattern_index in range(index, len(content_level_pattern)):
        parsed = parse_level(text, content_level_pattern[pattern_index])
        if parsed:
            return parsed
    return []
|
||||
|
||||
|
||||
def parse_level(text, pattern: str):
    """Find all titles in *text* matching *pattern*.

    :param text: text to scan
    :param pattern: title regex for one level
    :return: matched titles as tree nodes, special symbols filtered
    """
    # Each match is capped at 255 characters — the same title-length limit
    # enforced later in SplitModel.sub_title.
    level_content_list = list(map(to_tree_obj, [r[0:255] for r in re_findall(pattern, text) if r is not None]))
    return list(map(filter_special_symbol, level_content_list))
|
||||
|
||||
|
||||
def re_findall(pattern, text):
    """re.findall that flattens group tuples and drops empty/None matches.

    With multiple capture groups, re.findall yields tuples; this flattens
    them into a single list and removes empty strings.

    :param pattern: regex (string or compiled)
    :param text: text to scan
    :return: flat list of non-empty matched strings
    """
    matches = re.findall(pattern, text, flags=0)
    # Extend-in-a-loop replaces the original reduce-based concatenation,
    # which rebuilt the accumulator list per element (accidental O(n²)).
    flattened = []
    for row in matches:
        flattened.extend(row if isinstance(row, tuple) else [row])
    return [r for r in flattened if r is not None and len(r) > 0]
|
||||
|
||||
|
||||
def to_flat_obj(parent_chain: List[dict], content: str, state: str):
    """Build a flat record from a node plus its ancestor chain.

    :param parent_chain: ancestor nodes (the node itself excluded)
    :param content: node text
    :param state: node kind, 'title' or 'block'
    :return: flat record with a 'level' equal to the chain length
    """
    record = {'parent_chain': parent_chain, 'level': len(parent_chain)}
    record['content'] = content
    record['state'] = state
    return record
|
||||
|
||||
|
||||
def flat_map(array: List[List]):
    """Concatenate a list of lists into a single flat list.

    :param array: two-dimensional list
    :return: one-dimensional list preserving element order
    """
    flattened = []
    for sub_list in array:
        flattened.extend(sub_list)
    return flattened
|
||||
|
||||
|
||||
def group_by(list_source: List, key):
    """Group *list_source* into a dict of key(element) -> [elements].

    Insertion order of first appearance is preserved (plain dict semantics).

    :param list_source: elements to group
    :param key: function mapping an element to its group key
    :return: dict of key -> list of elements
    """
    grouped = {}
    for element in list_source:
        grouped.setdefault(key(element), []).append(element)
    return grouped
|
||||
|
||||
|
||||
def result_tree_to_paragraph(result_tree: List[dict], result, parent_chain, with_filter: bool):
    """Flatten a parse tree into paragraph dicts.

    :param result_tree: parse tree (title/block nodes with optional children)
    :param result: accumulator — pass [] at the top level
    :param parent_chain: ancestor title texts — pass [] at the top level
    :param with_filter: apply filter_special_char to block contents
    :return: list of {'title', 'content'} dicts
    """
    for item in result_tree:
        if item.get('state') == 'block':
            # A block's title is the space-joined chain of its ancestor titles.
            result.append({'title': " ".join(parent_chain),
                           'content': filter_special_char(item.get("content")) if with_filter else item.get("content")})
        children = item.get("children")
        if children is not None and len(children) > 0:
            # Recurse with this node's (cleaned) content appended to the chain.
            result_tree_to_paragraph(children, result,
                                     [*parent_chain, remove_special_symbol(item.get('content'))], with_filter)
    return result
|
||||
|
||||
|
||||
def post_handler_paragraph(content: str, limit: int):
    """Split *content* into chunks of at most *limit* characters.

    Splits preferentially at newline boundaries, accumulating whole lines
    until adding the next line would exceed *limit*; any single piece still
    longer than *limit* is hard-split afterwards.

    :param content: text to split
    :param limit: maximum characters per chunk
    :return: list of chunks
    """
    result = []
    temp_char, start = '', 0
    while (pos := content.find("\n", start)) != -1:
        # Take one line including its trailing newline.
        split, start = content[start:pos + 1], pos + 1
        # len(a) + len(b) avoids building the throwaway `temp_char + split`
        # string the original concatenated just to measure.
        if len(temp_char) + len(split) > limit:
            # Flush the accumulated chunk before starting a new one.
            # (Removed the original's dead `if len(temp_char) > 4096: pass`
            # debug leftover.)
            result.append(temp_char)
            temp_char = ''
        temp_char = temp_char + split
    temp_char = temp_char + content[start:]
    if len(temp_char) > 0:
        result.append(temp_char)

    # Hard-split any piece that alone exceeds the limit. The comprehension
    # replaces the original reduce-based concatenation (accidental O(n²)).
    pattern = "[\\S\\s]{1," + str(limit) + '}'
    return [chunk for row in result for chunk in re.findall(pattern, row)]
|
||||
|
||||
|
||||
# Compiled pattern -> replacement pairs applied in order by filter_special_char:
# collapse repeated newlines/spaces, strip '#' runs, remove tabs.
replace_map = {
    re.compile('\n+'): '\n',
    re.compile(' +'): ' ',
    re.compile('#+'): "",
    re.compile("\t+"): ''
}


def filter_special_char(content: str):
    """Normalize whitespace and strip markdown '#' runs from *content*.

    :param content: text to clean
    :return: cleaned text
    """
    for pattern, replacement in replace_map.items():
        content = re.sub(pattern, replacement, content)
    return content
|
||||
|
||||
|
||||
class SplitModel:
    """Splits text into title/block paragraphs using an ordered list of title regexes."""

    def __init__(self, content_level_pattern, with_filter=True, limit=100000):
        # content_level_pattern: regexes, one per heading level, tried in order.
        self.content_level_pattern = content_level_pattern
        # with_filter: whether block contents pass through filter_special_char.
        self.with_filter = with_filter
        # Clamp limit into [50, 100000].
        if limit is None or limit > 100000:
            limit = 100000
        if limit < 50:
            limit = 50
        self.limit = limit

    def parse_to_tree(self, text: str, index=0):
        """Parse *text* into a tree of title nodes with block/child nodes.

        :param text: text to parse
        :param index: which title pattern to start from (grows with recursion depth)
        :return: list of tree nodes ({'content', 'state', optional 'children'})
        """
        level_content_list = parse_title_level(text, self.content_level_pattern, index)
        # No titles at any remaining level: the whole text becomes block nodes.
        if len(level_content_list) == 0:
            return [to_tree_obj(row, 'block') for row in post_handler_paragraph(text, limit=self.limit)]
        # Text before the first top-level title gets an empty synthetic title.
        if index == 0 and text.lstrip().index(level_content_list[0]["content"].lstrip()) != 0:
            level_content_list.insert(0, to_tree_obj(""))

        cursor = 0
        level_title_content_list = [item for item in level_content_list if item.get('state') == 'title']
        for i in range(len(level_title_content_list)):
            start_content: str = level_title_content_list[i].get('content')
            # Text between the cursor and this title becomes block nodes,
            # inserted at the front of the level list.
            if cursor < text.index(start_content, cursor):
                for row in post_handler_paragraph(text[cursor: text.index(start_content, cursor)], limit=self.limit):
                    level_content_list.insert(0, to_tree_obj(row, 'block'))

            block, cursor = get_level_block(text, level_title_content_list, i, cursor)
            if len(block) == 0:
                continue
            # Recurse into this title's block using the next pattern level.
            children = self.parse_to_tree(text=block, index=index + 1)
            level_title_content_list[i]['children'] = children
            # Any text preceding the first child is parsed separately and
            # appended to the children.
            first_child_idx_in_block = block.lstrip().index(children[0]["content"].lstrip())
            if first_child_idx_in_block != 0:
                inner_children = self.parse_to_tree(block[:first_child_idx_in_block], index + 1)
                level_title_content_list[i]['children'].extend(inner_children)
        return level_content_list

    def parse(self, text: str):
        """Parse *text* into paragraph dicts.

        :param text: source text
        :return: list of {'title', 'content'} dicts; empty contents removed
        """
        # Normalize newlines and drop NUL characters before parsing.
        text = text.replace('\r\n', '\n')
        text = text.replace('\r', '\n')
        text = text.replace("\0", '')
        result_tree = self.parse_to_tree(text, 0)
        result = result_tree_to_paragraph(result_tree, [], [], self.with_filter)
        for e in result:
            # NOTE(review): dead branch — presumably a leftover debug hook
            # for over-long paragraphs; it has no effect.
            if len(e['content']) > 4096:
                pass
        title_list = list(set([row.get('title') for row in result]))
        return [item for item in [self.post_reset_paragraph(row, title_list) for row in result] if
                'content' in item and len(item.get('content').strip()) > 0]

    def post_reset_paragraph(self, paragraph: Dict, title_list: List[str]):
        """Run the post-processing pipeline over one paragraph and return the result."""
        result = self.content_is_null(paragraph, title_list)
        result = self.filter_title_special_characters(result)
        result = self.sub_title(result)
        return result

    @staticmethod
    def sub_title(paragraph: Dict):
        """Cap the title at 255 chars, pushing the overflow into the content."""
        if 'title' in paragraph:
            title = paragraph.get('title')
            if len(title) > 255:
                return {**paragraph, 'title': title[0:255], 'content': title[255:len(title)] + paragraph.get('content')}
        return paragraph

    @staticmethod
    def content_is_null(paragraph: Dict, title_list: List[str]):
        """Handle paragraphs with a title but empty content.

        If another title extends this one, the paragraph is emptied (and later
        dropped by parse); otherwise the title itself becomes the content.
        """
        if 'title' in paragraph:
            title = paragraph.get('title')
            content = paragraph.get('content')
            if (content is None or len(content.strip()) == 0) and (title is not None and len(title) > 0):
                find = [t for t in title_list if t.__contains__(title) and t != title]
                if find:
                    return {'title': '', 'content': ''}
                return {'title': '', 'content': title}
        return paragraph

    @staticmethod
    def filter_title_special_characters(paragraph: Dict):
        """Strip the characters in title_special_characters_list from the title."""
        title = paragraph.get('title') if 'title' in paragraph else ''
        for title_special_characters in title_special_characters_list:
            title = title.replace(title_special_characters, '')
        return {**paragraph,
                'title': title}
|
||||
|
||||
|
||||
# Characters removed from paragraph titles. Note: '\\s' is the literal
# two-character sequence backslash+s, not a regex whitespace class.
title_special_characters_list = ['#', '\n', '\r', '\\s']

# Title regexes per source type: markdown headings '#'..'######' for 'md'
# (each level matches its exact '#' count at start of line), blank-line
# separation for everything else.
default_split_pattern = {
    'md': [re.compile('(?<=^)# .*|(?<=\\n)# .*'),
           re.compile('(?<=\\n)(?<!#)## (?!#).*|(?<=^)(?<!#)## (?!#).*'),
           re.compile("(?<=\\n)(?<!#)### (?!#).*|(?<=^)(?<!#)### (?!#).*"),
           re.compile("(?<=\\n)(?<!#)#### (?!#).*|(?<=^)(?<!#)#### (?!#).*"),
           re.compile("(?<=\\n)(?<!#)##### (?!#).*|(?<=^)(?<!#)##### (?!#).*"),
           re.compile("(?<=\\n)(?<!#)###### (?!#).*|(?<=^)(?<!#)###### (?!#).*")],
    'default': [re.compile("(?<!\n)\n\n+")]
}
|
||||
|
||||
|
||||
def get_split_model(filename: str, with_filter: bool = False, limit: int = 100000):
    """Pick a SplitModel for *filename*.

    :param filename: file name (extension selects the pattern set)
    :param with_filter: whether to strip special characters from content
    :param limit: maximum characters per paragraph
    :return: configured SplitModel
    """
    # The original had an `if filename.endswith(".md")` branch whose body was
    # identical to the fall-through (both used the 'md' patterns), so the
    # check was dead code; collapsed here with behavior unchanged.
    # NOTE(review): the non-.md path presumably ought to use
    # default_split_pattern['default'] — confirm before changing behavior.
    pattern_list = default_split_pattern.get('md')
    return SplitModel(pattern_list, with_filter=with_filter, limit=limit)
|
||||
|
||||
|
||||
def to_title_tree_string(result_tree: List):
    """Render the title hierarchy of *result_tree* as an ASCII tree string."""
    flat_list = flat(result_tree, [], [])
    title_rows = [row for row in flat_list if row.get('state') == 'title']
    return "\n│".join(title_tostring(row) for row in title_rows)
|
||||
|
||||
|
||||
def title_tostring(title_obj):
    """Render one title as a tree-drawing line, indented by its depth."""
    depth = len(title_obj.get("parent_chain"))
    prefix = "│ ".join(" " for _ in range(depth))
    return prefix + "├───" + title_obj.get('content')
|
||||
|
|
@ -4,16 +4,16 @@ from drf_spectacular.utils import OpenApiParameter
|
|||
|
||||
from common.mixins.api_mixin import APIMixin
|
||||
from common.result import ResultSerializer, DefaultResultSerializer
|
||||
from modules.models.module import ModuleCreateRequest, ModuleEditRequest
|
||||
from modules.serializers.module import ModuleSerializer
|
||||
from folders.models.folder import FolderCreateRequest, FolderEditRequest
|
||||
from folders.serializers.folder import FolderSerializer
|
||||
|
||||
|
||||
class ModuleCreateResponse(ResultSerializer):
|
||||
class FolderCreateResponse(ResultSerializer):
|
||||
def get_data(self):
|
||||
return ModuleSerializer()
|
||||
return FolderSerializer()
|
||||
|
||||
|
||||
class ModuleCreateAPI(APIMixin):
|
||||
class FolderCreateAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
|
|
@ -36,14 +36,14 @@ class ModuleCreateAPI(APIMixin):
|
|||
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return ModuleCreateRequest
|
||||
return FolderCreateRequest
|
||||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return ModuleCreateResponse
|
||||
return FolderCreateResponse
|
||||
|
||||
|
||||
class ModuleReadAPI(APIMixin):
|
||||
class FolderReadAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
|
|
@ -63,8 +63,8 @@ class ModuleReadAPI(APIMixin):
|
|||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="module_id",
|
||||
description="模块id",
|
||||
name="folder_id",
|
||||
description="文件夹id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='path',
|
||||
required=True,
|
||||
|
|
@ -73,23 +73,23 @@ class ModuleReadAPI(APIMixin):
|
|||
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return ModuleCreateResponse
|
||||
return FolderCreateResponse
|
||||
|
||||
|
||||
class ModuleEditAPI(ModuleReadAPI):
|
||||
class FolderEditAPI(FolderReadAPI):
|
||||
|
||||
@staticmethod
|
||||
def get_request():
|
||||
return ModuleEditRequest
|
||||
return FolderEditRequest
|
||||
|
||||
|
||||
class ModuleDeleteAPI(ModuleReadAPI):
|
||||
class FolderDeleteAPI(FolderReadAPI):
|
||||
@staticmethod
|
||||
def get_response():
|
||||
return DefaultResultSerializer
|
||||
|
||||
|
||||
class ModuleTreeReadAPI(APIMixin):
|
||||
class FolderTreeReadAPI(APIMixin):
|
||||
@staticmethod
|
||||
def get_parameters():
|
||||
return [
|
||||
|
|
@ -2,14 +2,14 @@ from django.utils.translation import gettext_lazy as _
|
|||
from rest_framework import serializers
|
||||
|
||||
|
||||
class ModuleCreateRequest(serializers.Serializer):
|
||||
name = serializers.CharField(required=True, label=_('module name'))
|
||||
class FolderCreateRequest(serializers.Serializer):
|
||||
name = serializers.CharField(required=True, label=_('folder name'))
|
||||
|
||||
parent_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root',
|
||||
label=_('parent id'))
|
||||
|
||||
|
||||
class ModuleEditRequest(serializers.Serializer):
|
||||
name = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('module name'))
|
||||
class FolderEditRequest(serializers.Serializer):
|
||||
name = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('folder name'))
|
||||
parent_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root',
|
||||
label=_('parent id'))
|
||||
|
|
@ -7,30 +7,30 @@ from django.utils.translation import gettext_lazy as _
|
|||
from rest_framework import serializers
|
||||
|
||||
from common.constants.permission_constants import Group
|
||||
from knowledge.models import KnowledgeModule
|
||||
from modules.api.module import ModuleCreateRequest
|
||||
from tools.models import ToolModule
|
||||
from tools.serializers.tool_module import ToolModuleTreeSerializer
|
||||
from knowledge.models import KnowledgeFolder
|
||||
from folders.api.folder import FolderCreateRequest
|
||||
from tools.models import ToolFolder
|
||||
from tools.serializers.tool_folder import ToolFolderTreeSerializer
|
||||
|
||||
|
||||
def get_module_type(source):
|
||||
def get_folder_type(source):
|
||||
if source == Group.TOOL.name:
|
||||
return ToolModule
|
||||
return ToolFolder
|
||||
elif source == Group.APPLICATION.name:
|
||||
# todo app module
|
||||
# todo app folder
|
||||
return None
|
||||
elif source == Group.KNOWLEDGE.name:
|
||||
return KnowledgeModule
|
||||
return KnowledgeFolder
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
MODULE_DEPTH = 2 # Module 不能超过3层
|
||||
FOLDER_DEPTH = 2 # Folder 不能超过3层
|
||||
|
||||
|
||||
def check_depth(source, parent_id, current_depth=0):
|
||||
# Module 不能超过3层
|
||||
Module = get_module_type(source)
|
||||
# Folder 不能超过3层
|
||||
Folder = get_folder_type(source)
|
||||
|
||||
if parent_id != 'root':
|
||||
# 计算当前层级
|
||||
|
|
@ -40,14 +40,14 @@ def check_depth(source, parent_id, current_depth=0):
|
|||
# 向上追溯父节点
|
||||
while current_parent_id != 'root':
|
||||
depth += 1
|
||||
parent_node = QuerySet(Module).filter(id=current_parent_id).first()
|
||||
parent_node = QuerySet(Folder).filter(id=current_parent_id).first()
|
||||
if parent_node is None:
|
||||
break
|
||||
current_parent_id = parent_node.parent_id
|
||||
|
||||
# 验证层级深度
|
||||
if depth + current_depth > MODULE_DEPTH:
|
||||
raise serializers.ValidationError(_('Module depth cannot exceed 3 levels'))
|
||||
if depth + current_depth > FOLDER_DEPTH:
|
||||
raise serializers.ValidationError(_('Folder depth cannot exceed 3 levels'))
|
||||
|
||||
|
||||
def get_max_depth(current_node):
|
||||
|
|
@ -68,10 +68,10 @@ def get_max_depth(current_node):
|
|||
return max_depth
|
||||
|
||||
|
||||
class ModuleSerializer(serializers.Serializer):
|
||||
id = serializers.CharField(required=True, label=_('module id'))
|
||||
name = serializers.CharField(required=True, label=_('module name'))
|
||||
user_id = serializers.CharField(required=True, label=_('module user id'))
|
||||
class FolderSerializer(serializers.Serializer):
|
||||
id = serializers.CharField(required=True, label=_('folder id'))
|
||||
name = serializers.CharField(required=True, label=_('folder name'))
|
||||
user_id = serializers.CharField(required=True, label=_('folder user id'))
|
||||
workspace_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('workspace id'))
|
||||
parent_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('parent id'))
|
||||
|
||||
|
|
@ -82,84 +82,84 @@ class ModuleSerializer(serializers.Serializer):
|
|||
def insert(self, instance, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
ModuleCreateRequest(data=instance).is_valid(raise_exception=True)
|
||||
FolderCreateRequest(data=instance).is_valid(raise_exception=True)
|
||||
|
||||
workspace_id = self.data.get('workspace_id', 'default')
|
||||
parent_id = instance.get('parent_id', 'root')
|
||||
name = instance.get('name')
|
||||
|
||||
Module = get_module_type(self.data.get('source'))
|
||||
if QuerySet(Module).filter(name=name, workspace_id=workspace_id, parent_id=parent_id).exists():
|
||||
raise serializers.ValidationError(_('Module name already exists'))
|
||||
# Module 不能超过3层
|
||||
Folder = get_folder_type(self.data.get('source'))
|
||||
if QuerySet(Folder).filter(name=name, workspace_id=workspace_id, parent_id=parent_id).exists():
|
||||
raise serializers.ValidationError(_('Folder name already exists'))
|
||||
# Folder 不能超过3层
|
||||
check_depth(self.data.get('source'), parent_id)
|
||||
|
||||
module = Module(
|
||||
folder = Folder(
|
||||
id=uuid.uuid7(),
|
||||
name=instance.get('name'),
|
||||
user_id=self.data.get('user_id'),
|
||||
workspace_id=workspace_id,
|
||||
parent_id=parent_id
|
||||
)
|
||||
module.save()
|
||||
return ModuleSerializer(module).data
|
||||
folder.save()
|
||||
return FolderSerializer(folder).data
|
||||
|
||||
class Operate(serializers.Serializer):
|
||||
id = serializers.CharField(required=True, label=_('module id'))
|
||||
id = serializers.CharField(required=True, label=_('folder id'))
|
||||
workspace_id = serializers.CharField(required=True, allow_null=True, allow_blank=True, label=_('workspace id'))
|
||||
source = serializers.CharField(required=True, label=_('source'))
|
||||
|
||||
@transaction.atomic
|
||||
def edit(self, instance):
|
||||
self.is_valid(raise_exception=True)
|
||||
Module = get_module_type(self.data.get('source'))
|
||||
Folder = get_folder_type(self.data.get('source'))
|
||||
current_id = self.data.get('id')
|
||||
current_node = Module.objects.get(id=current_id)
|
||||
current_node = Folder.objects.get(id=current_id)
|
||||
if current_node is None:
|
||||
raise serializers.ValidationError(_('Module does not exist'))
|
||||
raise serializers.ValidationError(_('Folder does not exist'))
|
||||
|
||||
edit_field_list = ['name']
|
||||
edit_dict = {field: instance.get(field) for field in edit_field_list if (
|
||||
field in instance and instance.get(field) is not None)}
|
||||
|
||||
QuerySet(Module).filter(id=current_id).update(**edit_dict)
|
||||
QuerySet(Folder).filter(id=current_id).update(**edit_dict)
|
||||
|
||||
# 模块间的移动
|
||||
parent_id = instance.get('parent_id')
|
||||
if parent_id is not None and current_id != 'root':
|
||||
# Module 不能超过3层
|
||||
# Folder 不能超过3层
|
||||
current_depth = get_max_depth(current_node)
|
||||
check_depth(self.data.get('source'), parent_id, current_depth)
|
||||
parent = Module.objects.get(id=parent_id)
|
||||
parent = Folder.objects.get(id=parent_id)
|
||||
current_node.move_to(parent)
|
||||
|
||||
return self.one()
|
||||
|
||||
def one(self):
|
||||
self.is_valid(raise_exception=True)
|
||||
Module = get_module_type(self.data.get('source'))
|
||||
module = QuerySet(Module).filter(id=self.data.get('id')).first()
|
||||
return ModuleSerializer(module).data
|
||||
Folder = get_folder_type(self.data.get('source'))
|
||||
folder = QuerySet(Folder).filter(id=self.data.get('id')).first()
|
||||
return FolderSerializer(folder).data
|
||||
|
||||
def delete(self):
|
||||
self.is_valid(raise_exception=True)
|
||||
if self.data.get('id') == 'root':
|
||||
raise serializers.ValidationError(_('Cannot delete root module'))
|
||||
Module = get_module_type(self.data.get('source'))
|
||||
QuerySet(Module).filter(id=self.data.get('id')).delete()
|
||||
raise serializers.ValidationError(_('Cannot delete root folder'))
|
||||
Folder = get_folder_type(self.data.get('source'))
|
||||
QuerySet(Folder).filter(id=self.data.get('id')).delete()
|
||||
|
||||
|
||||
class ModuleTreeSerializer(serializers.Serializer):
|
||||
class FolderTreeSerializer(serializers.Serializer):
|
||||
workspace_id = serializers.CharField(required=True, allow_null=True, allow_blank=True, label=_('workspace id'))
|
||||
source = serializers.CharField(required=True, label=_('source'))
|
||||
|
||||
def get_module_tree(self, name=None):
|
||||
def get_folder_tree(self, name=None):
|
||||
self.is_valid(raise_exception=True)
|
||||
Module = get_module_type(self.data.get('source'))
|
||||
Folder = get_folder_type(self.data.get('source'))
|
||||
if name is not None:
|
||||
nodes = Module.objects.filter(Q(workspace_id=self.data.get('workspace_id')) &
|
||||
nodes = Folder.objects.filter(Q(workspace_id=self.data.get('workspace_id')) &
|
||||
Q(name__contains=name)).get_cached_trees()
|
||||
else:
|
||||
nodes = Module.objects.filter(Q(workspace_id=self.data.get('workspace_id'))).get_cached_trees()
|
||||
serializer = ToolModuleTreeSerializer(nodes, many=True)
|
||||
nodes = Folder.objects.filter(Q(workspace_id=self.data.get('workspace_id'))).get_cached_trees()
|
||||
serializer = ToolFolderTreeSerializer(nodes, many=True)
|
||||
return serializer.data # 这是可序列化的字典
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
from django.urls import path
|
||||
|
||||
from . import views
|
||||
|
||||
app_name = "folder"
|
||||
urlpatterns = [
|
||||
path('workspace/<str:workspace_id>/<str:source>/folder', views.FolderView.as_view()),
|
||||
path('workspace/<str:workspace_id>/<str:source>/folder/<str:folder_id>', views.FolderView.Operate.as_view()),
|
||||
]
|
||||
|
|
@ -0,0 +1 @@
|
|||
from .folder import *
|
||||
|
|
@ -7,26 +7,26 @@ from common.auth import TokenAuth
|
|||
from common.auth.authentication import has_permissions
|
||||
from common.constants.permission_constants import Permission, Group, Operate
|
||||
from common.result import result
|
||||
from modules.api.module import ModuleCreateAPI, ModuleEditAPI, ModuleReadAPI, ModuleTreeReadAPI, ModuleDeleteAPI
|
||||
from modules.serializers.module import ModuleSerializer, ModuleTreeSerializer
|
||||
from folders.api.folder import FolderCreateAPI, FolderEditAPI, FolderReadAPI, FolderTreeReadAPI, FolderDeleteAPI
|
||||
from folders.serializers.folder import FolderSerializer, FolderTreeSerializer
|
||||
|
||||
|
||||
class ModuleView(APIView):
|
||||
class FolderView(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@extend_schema(
|
||||
methods=['POST'],
|
||||
description=_('Create module'),
|
||||
operation_id=_('Create module'),
|
||||
parameters=ModuleCreateAPI.get_parameters(),
|
||||
request=ModuleCreateAPI.get_request(),
|
||||
responses=ModuleCreateAPI.get_response(),
|
||||
tags=[_('Module')]
|
||||
description=_('Create folder'),
|
||||
operation_id=_('Create folder'),
|
||||
parameters=FolderCreateAPI.get_parameters(),
|
||||
request=FolderCreateAPI.get_request(),
|
||||
responses=FolderCreateAPI.get_response(),
|
||||
tags=[_('Folder')]
|
||||
)
|
||||
@has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.CREATE,
|
||||
resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}"))
|
||||
def post(self, request: Request, workspace_id: str, source: str):
|
||||
return result.success(ModuleSerializer.Create(
|
||||
return result.success(FolderSerializer.Create(
|
||||
data={'user_id': request.user.id,
|
||||
'source': source,
|
||||
'workspace_id': workspace_id}
|
||||
|
|
@ -34,64 +34,64 @@ class ModuleView(APIView):
|
|||
|
||||
@extend_schema(
|
||||
methods=['GET'],
|
||||
description=_('Get module tree'),
|
||||
operation_id=_('Get module tree'),
|
||||
parameters=ModuleTreeReadAPI.get_parameters(),
|
||||
responses=ModuleTreeReadAPI.get_response(),
|
||||
tags=[_('Module')]
|
||||
description=_('Get folder tree'),
|
||||
operation_id=_('Get folder tree'),
|
||||
parameters=FolderTreeReadAPI.get_parameters(),
|
||||
responses=FolderTreeReadAPI.get_response(),
|
||||
tags=[_('Folder')]
|
||||
)
|
||||
@has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.READ,
|
||||
resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}"))
|
||||
def get(self, request: Request, workspace_id: str, source: str):
|
||||
return result.success(ModuleTreeSerializer(
|
||||
return result.success(FolderTreeSerializer(
|
||||
data={'workspace_id': workspace_id, 'source': source}
|
||||
).get_module_tree(request.query_params.get('name')))
|
||||
).get_folder_tree(request.query_params.get('name')))
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@extend_schema(
|
||||
methods=['PUT'],
|
||||
description=_('Update module'),
|
||||
operation_id=_('Update module'),
|
||||
parameters=ModuleEditAPI.get_parameters(),
|
||||
request=ModuleEditAPI.get_request(),
|
||||
responses=ModuleEditAPI.get_response(),
|
||||
tags=[_('Module')]
|
||||
description=_('Update folder'),
|
||||
operation_id=_('Update folder'),
|
||||
parameters=FolderEditAPI.get_parameters(),
|
||||
request=FolderEditAPI.get_request(),
|
||||
responses=FolderEditAPI.get_response(),
|
||||
tags=[_('Folder')]
|
||||
)
|
||||
@has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.EDIT,
|
||||
resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}"))
|
||||
def put(self, request: Request, workspace_id: str, source: str, module_id: str):
|
||||
return result.success(ModuleSerializer.Operate(
|
||||
data={'id': module_id, 'workspace_id': workspace_id, 'source': source}
|
||||
def put(self, request: Request, workspace_id: str, source: str, folder_id: str):
|
||||
return result.success(FolderSerializer.Operate(
|
||||
data={'id': folder_id, 'workspace_id': workspace_id, 'source': source}
|
||||
).edit(request.data))
|
||||
|
||||
@extend_schema(
|
||||
methods=['GET'],
|
||||
description=_('Get module'),
|
||||
operation_id=_('Get module'),
|
||||
parameters=ModuleReadAPI.get_parameters(),
|
||||
responses=ModuleReadAPI.get_response(),
|
||||
tags=[_('Module')]
|
||||
description=_('Get folder'),
|
||||
operation_id=_('Get folder'),
|
||||
parameters=FolderReadAPI.get_parameters(),
|
||||
responses=FolderReadAPI.get_response(),
|
||||
tags=[_('Folder')]
|
||||
)
|
||||
@has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.READ,
|
||||
resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}"))
|
||||
def get(self, request: Request, workspace_id: str, source: str, module_id: str):
|
||||
return result.success(ModuleSerializer.Operate(
|
||||
data={'id': module_id, 'workspace_id': workspace_id, 'source': source}
|
||||
def get(self, request: Request, workspace_id: str, source: str, folder_id: str):
|
||||
return result.success(FolderSerializer.Operate(
|
||||
data={'id': folder_id, 'workspace_id': workspace_id, 'source': source}
|
||||
).one())
|
||||
|
||||
@extend_schema(
|
||||
methods=['DELETE'],
|
||||
description=_('Delete module'),
|
||||
operation_id=_('Delete module'),
|
||||
parameters=ModuleDeleteAPI.get_parameters(),
|
||||
responses=ModuleDeleteAPI.get_response(),
|
||||
tags=[_('Module')]
|
||||
description=_('Delete folder'),
|
||||
operation_id=_('Delete folder'),
|
||||
parameters=FolderDeleteAPI.get_parameters(),
|
||||
responses=FolderDeleteAPI.get_response(),
|
||||
tags=[_('Folder')]
|
||||
)
|
||||
@has_permissions(lambda r, kwargs: Permission(group=Group(kwargs.get('source')), operate=Operate.DELETE,
|
||||
resource_path=f"/WORKSPACE/{kwargs.get('workspace_id')}"))
|
||||
def delete(self, request: Request, workspace_id: str, source: str, module_id: str):
|
||||
return result.success(ModuleSerializer.Operate(
|
||||
data={'id': module_id, 'workspace_id': workspace_id, 'source': source}
|
||||
def delete(self, request: Request, workspace_id: str, source: str, folder_id: str):
|
||||
return result.success(FolderSerializer.Operate(
|
||||
data={'id': folder_id, 'workspace_id': workspace_id, 'source': source}
|
||||
).delete())
|
||||
|
|
@ -67,8 +67,8 @@ class KnowledgeTreeReadAPI(APIMixin):
|
|||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="module_id",
|
||||
description="模块id",
|
||||
name="folder_id",
|
||||
description="文件夹id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='query',
|
||||
required=False,
|
||||
|
|
|
|||
|
|
@ -6,12 +6,12 @@ import mptt.fields
|
|||
import uuid_utils.compat
|
||||
from django.db import migrations, models
|
||||
|
||||
from knowledge.models import KnowledgeModule
|
||||
from knowledge.models import KnowledgeFolder
|
||||
|
||||
|
||||
def insert_default_data(apps, schema_editor):
|
||||
# 创建一个根模块(没有父节点)
|
||||
KnowledgeModule.objects.create(id='root', name='根目录', user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab')
|
||||
KnowledgeFolder.objects.create(id='root', name='根目录', user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab')
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
|
@ -42,7 +42,7 @@ class Migration(migrations.Migration):
|
|||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='KnowledgeModule',
|
||||
name='KnowledgeFolder',
|
||||
fields=[
|
||||
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
|
||||
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||
|
|
@ -57,12 +57,12 @@ class Migration(migrations.Migration):
|
|||
('level', models.PositiveIntegerField(editable=False)),
|
||||
('parent',
|
||||
mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name='children', to='knowledge.knowledgemodule')),
|
||||
related_name='children', to='knowledge.knowledgefolder')),
|
||||
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
|
||||
verbose_name='用户id')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'knowledge_module',
|
||||
'db_table': 'knowledge_folder',
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
|
|
@ -84,10 +84,10 @@ class Migration(migrations.Migration):
|
|||
('scope',
|
||||
models.CharField(choices=[('SHARED', '共享'), ('WORKSPACE', '工作空间可用')], default='WORKSPACE',
|
||||
max_length=20, verbose_name='可用范围')),
|
||||
('module',
|
||||
('folder',
|
||||
models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE,
|
||||
to='knowledge.knowledgemodule',
|
||||
verbose_name='模块id')),
|
||||
to='knowledge.knowledgefolder',
|
||||
verbose_name='文件夹id')),
|
||||
('embedding_model', models.ForeignKey(default=knowledge.models.knowledge.default_model,
|
||||
on_delete=django.db.models.deletion.DO_NOTHING,
|
||||
to='models_provider.model', verbose_name='向量模型')),
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def default_model():
|
|||
return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab')
|
||||
|
||||
|
||||
class KnowledgeModule(MPTTModel, AppModelMixin):
|
||||
class KnowledgeFolder(MPTTModel, AppModelMixin):
|
||||
id = models.CharField(primary_key=True, max_length=64, editable=False, verbose_name="主键id")
|
||||
name = models.CharField(max_length=64, verbose_name="文件夹名称")
|
||||
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="用户id")
|
||||
|
|
@ -36,7 +36,7 @@ class KnowledgeModule(MPTTModel, AppModelMixin):
|
|||
parent = TreeForeignKey('self', on_delete=models.CASCADE, null=True, blank=True, related_name='children')
|
||||
|
||||
class Meta:
|
||||
db_table = "knowledge_module"
|
||||
db_table = "knowledge_folder"
|
||||
|
||||
class MPTTMeta:
|
||||
order_insertion_by = ['name']
|
||||
|
|
@ -51,7 +51,7 @@ class Knowledge(AppModelMixin):
|
|||
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
|
||||
type = models.IntegerField(verbose_name='类型', choices=KnowledgeType.choices, default=KnowledgeType.BASE)
|
||||
scope = models.CharField(max_length=20, verbose_name='可用范围', choices=KnowledgeScope.choices, default=KnowledgeScope.WORKSPACE)
|
||||
module = models.ForeignKey(KnowledgeModule, on_delete=models.CASCADE, verbose_name="模块id", default='root')
|
||||
folder = models.ForeignKey(KnowledgeFolder, on_delete=models.CASCADE, verbose_name="文件夹id", default='root')
|
||||
embedding_model = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型",
|
||||
default=default_model)
|
||||
meta = models.JSONField(verbose_name="元数据", default=dict)
|
||||
|
|
|
|||
|
|
@ -1,25 +1,36 @@
|
|||
from typing import Dict
|
||||
|
||||
import uuid_utils as uuid
|
||||
from django.db import transaction
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.utils.common import valid_license
|
||||
from knowledge.models import Knowledge, KnowledgeScope, KnowledgeType
|
||||
|
||||
|
||||
class KnowledgeModelSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Knowledge
|
||||
fields = ['id', 'name', 'desc', 'meta', 'module_id', 'type', 'workspace_id', 'create_time', 'update_time']
|
||||
fields = ['id', 'name', 'desc', 'meta', 'folder_id', 'type', 'workspace_id', 'create_time', 'update_time']
|
||||
|
||||
|
||||
class KnowledgeBaseCreateRequest(serializers.Serializer):
|
||||
name = serializers.CharField(required=True, label=_('knowledge name'))
|
||||
folder_id = serializers.CharField(required=True, label=_('folder id'))
|
||||
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('knowledge description'))
|
||||
embedding = serializers.CharField(required=True, label=_('knowledge embedding'))
|
||||
|
||||
|
||||
class KnowledgeWebCreateRequest(serializers.Serializer):
|
||||
name = serializers.CharField(required=True, label=_('knowledge name'))
|
||||
folder_id = serializers.CharField(required=True, label=_('folder id'))
|
||||
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('knowledge description'))
|
||||
embedding = serializers.CharField(required=True, label=_('knowledge embedding'))
|
||||
source_url = serializers.CharField(required=True, label=_('source url'))
|
||||
selector = serializers.CharField(required=True, label=_('knowledge selector'))
|
||||
|
||||
|
||||
class KnowledgeSerializer(serializers.Serializer):
|
||||
|
|
@ -27,10 +38,19 @@ class KnowledgeSerializer(serializers.Serializer):
|
|||
user_id = serializers.UUIDField(required=True, label=_('user id'))
|
||||
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
|
||||
|
||||
def insert(self, instance, with_valid=True):
|
||||
@valid_license(model=Knowledge, count=50,
|
||||
message=_(
|
||||
'The community version supports up to 50 knowledge bases. If you need more knowledge bases, please contact us (https://fit2cloud.com/).'))
|
||||
# @post(post_function=post_embedding_dataset)
|
||||
@transaction.atomic
|
||||
def save_base(self, instance, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
KnowledgeBaseCreateRequest(data=instance).is_valid(raise_exception=True)
|
||||
if QuerySet(Knowledge).filter(workspace_id=self.data.get('workspace_id'),
|
||||
name=instance.get('name')).exists():
|
||||
raise AppApiException(500, _('Knowledge base name duplicate!'))
|
||||
|
||||
knowledge = Knowledge(
|
||||
id=uuid.uuid7(),
|
||||
name=instance.get('name'),
|
||||
|
|
@ -39,13 +59,43 @@ class KnowledgeSerializer(serializers.Serializer):
|
|||
type=instance.get('type', KnowledgeType.BASE),
|
||||
user_id=self.data.get('user_id'),
|
||||
scope=KnowledgeScope.WORKSPACE,
|
||||
module_id=instance.get('module_id', 'root'),
|
||||
folder_id=instance.get('folder_id', 'root'),
|
||||
embedding_model_id=instance.get('embedding'),
|
||||
meta=instance.get('meta', {}),
|
||||
)
|
||||
knowledge.save()
|
||||
return KnowledgeModelSerializer(knowledge).data
|
||||
|
||||
def save_web(self, instance: Dict, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
KnowledgeWebCreateRequest(data=instance).is_valid(raise_exception=True)
|
||||
|
||||
if QuerySet(Knowledge).filter(workspace_id=self.data.get('workspace_id'),
|
||||
name=instance.get('name')).exists():
|
||||
raise AppApiException(500, _('Knowledge base name duplicate!'))
|
||||
|
||||
knowledge_id = uuid.uuid7()
|
||||
knowledge = Knowledge(
|
||||
id=knowledge_id,
|
||||
name=instance.get('name'),
|
||||
desc=instance.get('desc'),
|
||||
user_id=self.data.get('user_id'),
|
||||
type=instance.get('type', KnowledgeType.WEB),
|
||||
scope=KnowledgeScope.WORKSPACE,
|
||||
folder_id=instance.get('folder_id', 'root'),
|
||||
embedding_model_id=instance.get('embedding'),
|
||||
meta={
|
||||
'source_url': instance.get('source_url'),
|
||||
'selector': instance.get('selector'),
|
||||
'embedding_model_id': instance.get('embedding')
|
||||
},
|
||||
)
|
||||
knowledge.save()
|
||||
# sync_web_knowledge.delay(str(knowledge_id), instance.get('source_url'), instance.get('selector'))
|
||||
return {**KnowledgeModelSerializer(knowledge).data,
|
||||
'document_list': []}
|
||||
|
||||
|
||||
class KnowledgeTreeSerializer(serializers.Serializer):
|
||||
def get_knowledge_list(self, param):
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
from .sync import *
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: sync.py
|
||||
@date:2024/8/20 21:37
|
||||
@desc:
|
||||
"""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
from celery_once import QueueOnce
|
||||
|
||||
from common.utils.fork import ForkManage, Fork
|
||||
from .tools import get_save_handler, get_sync_web_document_handler, get_sync_handler
|
||||
|
||||
from ops import celery_app
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
max_kb_error = logging.getLogger("max_kb_error")
|
||||
max_kb = logging.getLogger("max_kb")
|
||||
|
||||
|
||||
@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:sync_web_knowledge')
|
||||
def sync_web_knowledge(knowledge_id: str, url: str, selector: str):
|
||||
try:
|
||||
max_kb.info(
|
||||
_('Start--->Start synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id))
|
||||
ForkManage(url, selector.split(" ") if selector is not None else []).fork(2, set(),
|
||||
get_save_handler(knowledge_id,
|
||||
selector))
|
||||
|
||||
max_kb.info(_('End--->End synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id))
|
||||
except Exception as e:
|
||||
max_kb_error.error(_('Synchronize web knowledge base:{knowledge_id} error{error}{traceback}').format(
|
||||
knowledge_id=knowledge_id, error=str(e), traceback=traceback.format_exc()))
|
||||
|
||||
|
||||
@celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:sync_replace_web_knowledge')
|
||||
def sync_replace_web_knowledge(knowledge_id: str, url: str, selector: str):
|
||||
try:
|
||||
max_kb.info(
|
||||
_('Start--->Start synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id))
|
||||
ForkManage(url, selector.split(" ") if selector is not None else []).fork(2, set(),
|
||||
get_sync_handler(knowledge_id
|
||||
))
|
||||
max_kb.info(_('End--->End synchronization web knowledge base:{knowledge_id}').format(knowledge_id=knowledge_id))
|
||||
except Exception as e:
|
||||
max_kb_error.error(_('Synchronize web knowledge base:{knowledge_id} error{error}{traceback}').format(
|
||||
knowledge_id=knowledge_id, error=str(e), traceback=traceback.format_exc()))
|
||||
|
||||
|
||||
@celery_app.task(name='celery:sync_web_document')
|
||||
def sync_web_document(knowledge_id, source_url_list: List[str], selector: str):
|
||||
handler = get_sync_web_document_handler(knowledge_id)
|
||||
for source_url in source_url_list:
|
||||
try:
|
||||
result = Fork(base_fork_url=source_url, selector_list=selector.split(' ')).fork()
|
||||
handler(source_url, selector, result)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: tools.py
|
||||
@date:2024/8/20 21:48
|
||||
@desc:
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import traceback
|
||||
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.utils.fork import ChildLink, Fork
|
||||
from common.utils.split_model import get_split_model
|
||||
from knowledge.models.knowledge import KnowledgeType, Document, DataSet, Status
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
max_kb_error = logging.getLogger("max_kb_error")
|
||||
max_kb = logging.getLogger("max_kb")
|
||||
|
||||
|
||||
def get_save_handler(dataset_id, selector):
|
||||
from knowledge.serializers.document_serializers import DocumentSerializers
|
||||
|
||||
def handler(child_link: ChildLink, response: Fork.Response):
|
||||
if response.status == 200:
|
||||
try:
|
||||
document_name = child_link.tag.text if child_link.tag is not None and len(
|
||||
child_link.tag.text.strip()) > 0 else child_link.url
|
||||
paragraphs = get_split_model('web.md').parse(response.content)
|
||||
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
|
||||
{'name': document_name, 'paragraphs': paragraphs,
|
||||
'meta': {'source_url': child_link.url, 'selector': selector},
|
||||
'type': KnowledgeType.WEB}, with_valid=True)
|
||||
except Exception as e:
|
||||
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def get_sync_handler(dataset_id):
|
||||
from knowledge.serializers.document_serializers import DocumentSerializers
|
||||
dataset = QuerySet(DataSet).filter(id=dataset_id).first()
|
||||
|
||||
def handler(child_link: ChildLink, response: Fork.Response):
|
||||
if response.status == 200:
|
||||
try:
|
||||
|
||||
document_name = child_link.tag.text if child_link.tag is not None and len(
|
||||
child_link.tag.text.strip()) > 0 else child_link.url
|
||||
paragraphs = get_split_model('web.md').parse(response.content)
|
||||
first = QuerySet(Document).filter(meta__source_url=child_link.url.strip(),
|
||||
dataset=dataset).first()
|
||||
if first is not None:
|
||||
# 如果存在,使用文档同步
|
||||
DocumentSerializers.Sync(data={'document_id': first.id}).sync()
|
||||
else:
|
||||
# 插入
|
||||
DocumentSerializers.Create(data={'dataset_id': dataset.id}).save(
|
||||
{'name': document_name, 'paragraphs': paragraphs,
|
||||
'meta': {'source_url': child_link.url.strip(), 'selector': dataset.meta.get('selector')},
|
||||
'type': Type.web}, with_valid=True)
|
||||
except Exception as e:
|
||||
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def get_sync_web_document_handler(dataset_id):
|
||||
from knowledge.serializers.document_serializers import DocumentSerializers
|
||||
|
||||
def handler(source_url: str, selector, response: Fork.Response):
|
||||
if response.status == 200:
|
||||
try:
|
||||
paragraphs = get_split_model('web.md').parse(response.content)
|
||||
# 插入
|
||||
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
|
||||
{'name': source_url[0:128], 'paragraphs': paragraphs,
|
||||
'meta': {'source_url': source_url, 'selector': selector},
|
||||
'type': KnowledgeType.WEB}, with_valid=True)
|
||||
except Exception as e:
|
||||
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
|
||||
else:
|
||||
Document(name=source_url[0:128],
|
||||
dataset_id=dataset_id,
|
||||
meta={'source_url': source_url, 'selector': selector},
|
||||
type=KnowledgeType.WEB,
|
||||
char_length=0,
|
||||
status=Status.error).save()
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def save_problem(dataset_id, document_id, paragraph_id, problem):
|
||||
from knowledge.serializers.paragraph_serializers import ParagraphSerializers
|
||||
# print(f"dataset_id: {dataset_id}")
|
||||
# print(f"document_id: {document_id}")
|
||||
# print(f"paragraph_id: {paragraph_id}")
|
||||
# print(f"problem: {problem}")
|
||||
problem = re.sub(r"^\d+\.\s*", "", problem)
|
||||
pattern = r"<question>(.*?)</question>"
|
||||
match = re.search(pattern, problem)
|
||||
problem = match.group(1) if match else None
|
||||
if problem is None or len(problem) == 0:
|
||||
return
|
||||
try:
|
||||
ParagraphSerializers.Problem(
|
||||
data={"dataset_id": dataset_id, 'document_id': document_id,
|
||||
'paragraph_id': paragraph_id}).save(instance={"content": problem}, with_valid=True)
|
||||
except Exception as e:
|
||||
max_kb_error.error(_('Association problem failed {error}').format(error=str(e)))
|
||||
|
|
@ -16,8 +16,8 @@ class KnowledgeView(APIView):
|
|||
|
||||
@extend_schema(
|
||||
methods=['GET'],
|
||||
description=_('Get knowledge by module'),
|
||||
operation_id=_('Get knowledge by module'),
|
||||
description=_('Get knowledge by folder'),
|
||||
operation_id=_('Get knowledge by folder'),
|
||||
parameters=KnowledgeTreeReadAPI.get_parameters(),
|
||||
responses=KnowledgeTreeReadAPI.get_response(),
|
||||
tags=[_('Knowledge Base')]
|
||||
|
|
@ -26,7 +26,7 @@ class KnowledgeView(APIView):
|
|||
def get(self, request: Request, workspace_id: str):
|
||||
return result.success(KnowledgeTreeSerializer(
|
||||
data={'workspace_id': workspace_id}
|
||||
).get_knowledge_list(request.query_params.get('module_id')))
|
||||
).get_knowledge_list(request.query_params.get('folder_id')))
|
||||
|
||||
|
||||
class KnowledgeBaseView(APIView):
|
||||
|
|
@ -45,7 +45,7 @@ class KnowledgeBaseView(APIView):
|
|||
def post(self, request: Request, workspace_id: str):
|
||||
return result.success(KnowledgeSerializer.Create(
|
||||
data={'user_id': request.user.id, 'workspace_id': workspace_id}
|
||||
).insert(request.data))
|
||||
).save_base(request.data))
|
||||
|
||||
|
||||
class KnowledgeWebView(APIView):
|
||||
|
|
@ -64,4 +64,4 @@ class KnowledgeWebView(APIView):
|
|||
def post(self, request: Request, workspace_id: str):
|
||||
return result.success(KnowledgeSerializer.Create(
|
||||
data={'user_id': request.user.id, 'workspace_id': workspace_id}
|
||||
).insert(request.data))
|
||||
).save_web(request.data))
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ urlpatterns = [
|
|||
path("api/", include("users.urls")),
|
||||
path("api/", include("tools.urls")),
|
||||
path("api/", include("models_provider.urls")),
|
||||
path("api/", include("modules.urls")),
|
||||
path("api/", include("folders.urls")),
|
||||
path("api/", include("knowledge.urls")),
|
||||
]
|
||||
urlpatterns += [
|
||||
|
|
|
|||
|
|
@ -1,9 +0,0 @@
|
|||
from django.urls import path
|
||||
|
||||
from . import views
|
||||
|
||||
app_name = "module"
|
||||
urlpatterns = [
|
||||
path('workspace/<str:workspace_id>/<str:source>/module', views.ModuleView.as_view()),
|
||||
path('workspace/<str:workspace_id>/<str:source>/module/<str:module_id>', views.ModuleView.Operate.as_view()),
|
||||
]
|
||||
|
|
@ -1 +0,0 @@
|
|||
from .module import *
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py.py
|
||||
@date:2024/8/16 14:47
|
||||
@desc:
|
||||
"""
|
||||
from .celery import app as celery_app
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
from kombu import Exchange, Queue
|
||||
from maxkb import settings
|
||||
from .heartbeat import *
|
||||
|
||||
# set the default Django settings module for the 'celery' program.
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'smartdoc.settings')
|
||||
|
||||
app = Celery('MaxKB')
|
||||
|
||||
configs = {k: v for k, v in settings.__dict__.items() if k.startswith('CELERY')}
|
||||
configs['worker_concurrency'] = 5
|
||||
# Using a string here means the worker will not have to
|
||||
# pickle the object when using Windows.
|
||||
# app.config_from_object('django.conf:settings', namespace='CELERY')
|
||||
|
||||
configs["task_queues"] = [
|
||||
Queue("celery", Exchange("celery"), routing_key="celery"),
|
||||
Queue("model", Exchange("model"), routing_key="model")
|
||||
]
|
||||
app.namespace = 'CELERY'
|
||||
app.conf.update(
|
||||
{key.replace('CELERY_', '') if key.replace('CELERY_', '').lower() == key.replace('CELERY_',
|
||||
'') else key: configs.get(
|
||||
key) for
|
||||
key
|
||||
in configs.keys()})
|
||||
app.autodiscover_tasks(lambda: [app_config.split('.')[0] for app_config in settings.INSTALLED_APPS])
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
|
||||
CELERY_LOG_MAGIC_MARK = b'\x00\x00\x00\x00\x00'
|
||||
|
|
@ -0,0 +1,104 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from functools import wraps
|
||||
|
||||
_need_registered_period_tasks = []
|
||||
_after_app_ready_start_tasks = []
|
||||
_after_app_shutdown_clean_periodic_tasks = []
|
||||
|
||||
|
||||
def add_register_period_task(task):
|
||||
_need_registered_period_tasks.append(task)
|
||||
|
||||
|
||||
def get_register_period_tasks():
|
||||
return _need_registered_period_tasks
|
||||
|
||||
|
||||
def add_after_app_shutdown_clean_task(name):
|
||||
_after_app_shutdown_clean_periodic_tasks.append(name)
|
||||
|
||||
|
||||
def get_after_app_shutdown_clean_tasks():
|
||||
return _after_app_shutdown_clean_periodic_tasks
|
||||
|
||||
|
||||
def add_after_app_ready_task(name):
|
||||
_after_app_ready_start_tasks.append(name)
|
||||
|
||||
|
||||
def get_after_app_ready_tasks():
|
||||
return _after_app_ready_start_tasks
|
||||
|
||||
|
||||
def register_as_period_task(
|
||||
crontab=None, interval=None, name=None,
|
||||
args=(), kwargs=None,
|
||||
description=''):
|
||||
"""
|
||||
Warning: Task must have not any args and kwargs
|
||||
:param crontab: "* * * * *"
|
||||
:param interval: 60*60*60
|
||||
:param args: ()
|
||||
:param kwargs: {}
|
||||
:param description: "
|
||||
:param name: ""
|
||||
:return:
|
||||
"""
|
||||
if crontab is None and interval is None:
|
||||
raise SyntaxError("Must set crontab or interval one")
|
||||
|
||||
def decorate(func):
|
||||
if crontab is None and interval is None:
|
||||
raise SyntaxError("Interval and crontab must set one")
|
||||
|
||||
# Because when this decorator run, the task was not created,
|
||||
# So we can't use func.name
|
||||
task = '{func.__module__}.{func.__name__}'.format(func=func)
|
||||
_name = name if name else task
|
||||
add_register_period_task({
|
||||
_name: {
|
||||
'task': task,
|
||||
'interval': interval,
|
||||
'crontab': crontab,
|
||||
'args': args,
|
||||
'kwargs': kwargs if kwargs else {},
|
||||
'description': description
|
||||
}
|
||||
})
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def after_app_ready_start(func):
|
||||
# Because when this decorator run, the task was not created,
|
||||
# So we can't use func.name
|
||||
name = '{func.__module__}.{func.__name__}'.format(func=func)
|
||||
if name not in _after_app_ready_start_tasks:
|
||||
add_after_app_ready_task(name)
|
||||
|
||||
@wraps(func)
|
||||
def decorate(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def after_app_shutdown_clean_periodic(func):
|
||||
# Because when this decorator run, the task was not created,
|
||||
# So we can't use func.name
|
||||
name = '{func.__module__}.{func.__name__}'.format(func=func)
|
||||
if name not in _after_app_shutdown_clean_periodic_tasks:
|
||||
add_after_app_shutdown_clean_task(name)
|
||||
|
||||
@wraps(func)
|
||||
def decorate(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return decorate
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
from pathlib import Path
|
||||
|
||||
from celery.signals import heartbeat_sent, worker_ready, worker_shutdown
|
||||
|
||||
|
||||
@heartbeat_sent.connect
|
||||
def heartbeat(sender, **kwargs):
|
||||
worker_name = sender.eventer.hostname.split('@')[0]
|
||||
heartbeat_path = Path('/tmp/worker_heartbeat_{}'.format(worker_name))
|
||||
heartbeat_path.touch()
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def worker_ready(sender, **kwargs):
|
||||
worker_name = sender.hostname.split('@')[0]
|
||||
ready_path = Path('/tmp/worker_ready_{}'.format(worker_name))
|
||||
ready_path.touch()
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def worker_shutdown(sender, **kwargs):
|
||||
worker_name = sender.hostname.split('@')[0]
|
||||
for signal in ['ready', 'heartbeat']:
|
||||
path = Path('/tmp/worker_{}_{}'.format(signal, worker_name))
|
||||
path.unlink(missing_ok=True)
|
||||
|
|
@ -0,0 +1,225 @@
|
|||
from logging import StreamHandler
|
||||
from threading import get_ident
|
||||
|
||||
from celery import current_task
|
||||
from celery.signals import task_prerun, task_postrun
|
||||
from django.conf import settings
|
||||
from kombu import Connection, Exchange, Queue, Producer
|
||||
from kombu.mixins import ConsumerMixin
|
||||
|
||||
from .utils import get_celery_task_log_path
|
||||
from .const import CELERY_LOG_MAGIC_MARK
|
||||
|
||||
routing_key = 'celery_log'
|
||||
celery_log_exchange = Exchange('celery_log_exchange', type='direct')
|
||||
celery_log_queue = [Queue('celery_log', celery_log_exchange, routing_key=routing_key)]
|
||||
|
||||
|
||||
class CeleryLoggerConsumer(ConsumerMixin):
|
||||
def __init__(self):
|
||||
self.connection = Connection(settings.CELERY_LOG_BROKER_URL)
|
||||
|
||||
def get_consumers(self, Consumer, channel):
|
||||
return [Consumer(queues=celery_log_queue,
|
||||
accept=['pickle', 'json'],
|
||||
callbacks=[self.process_task])
|
||||
]
|
||||
|
||||
def handle_task_start(self, task_id, message):
|
||||
pass
|
||||
|
||||
def handle_task_end(self, task_id, message):
|
||||
pass
|
||||
|
||||
def handle_task_log(self, task_id, msg, message):
|
||||
pass
|
||||
|
||||
def process_task(self, body, message):
|
||||
action = body.get('action')
|
||||
task_id = body.get('task_id')
|
||||
msg = body.get('msg')
|
||||
if action == CeleryLoggerProducer.ACTION_TASK_LOG:
|
||||
self.handle_task_log(task_id, msg, message)
|
||||
elif action == CeleryLoggerProducer.ACTION_TASK_START:
|
||||
self.handle_task_start(task_id, message)
|
||||
elif action == CeleryLoggerProducer.ACTION_TASK_END:
|
||||
self.handle_task_end(task_id, message)
|
||||
|
||||
|
||||
class CeleryLoggerProducer:
|
||||
ACTION_TASK_START, ACTION_TASK_LOG, ACTION_TASK_END = range(3)
|
||||
|
||||
def __init__(self):
|
||||
self.connection = Connection(settings.CELERY_LOG_BROKER_URL)
|
||||
|
||||
@property
|
||||
def producer(self):
|
||||
return Producer(self.connection)
|
||||
|
||||
def publish(self, payload):
|
||||
self.producer.publish(
|
||||
payload, serializer='json', exchange=celery_log_exchange,
|
||||
declare=[celery_log_exchange], routing_key=routing_key
|
||||
)
|
||||
|
||||
def log(self, task_id, msg):
|
||||
payload = {'task_id': task_id, 'msg': msg, 'action': self.ACTION_TASK_LOG}
|
||||
return self.publish(payload)
|
||||
|
||||
def read(self):
|
||||
pass
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
def task_end(self, task_id):
|
||||
payload = {'task_id': task_id, 'action': self.ACTION_TASK_END}
|
||||
return self.publish(payload)
|
||||
|
||||
def task_start(self, task_id):
|
||||
payload = {'task_id': task_id, 'action': self.ACTION_TASK_START}
|
||||
return self.publish(payload)
|
||||
|
||||
|
||||
class CeleryTaskLoggerHandler(StreamHandler):
|
||||
terminator = '\r\n'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
task_prerun.connect(self.on_task_start)
|
||||
task_postrun.connect(self.on_start_end)
|
||||
|
||||
@staticmethod
|
||||
def get_current_task_id():
|
||||
if not current_task:
|
||||
return
|
||||
task_id = current_task.request.root_id
|
||||
return task_id
|
||||
|
||||
def on_task_start(self, sender, task_id, **kwargs):
|
||||
return self.handle_task_start(task_id)
|
||||
|
||||
def on_start_end(self, sender, task_id, **kwargs):
|
||||
return self.handle_task_end(task_id)
|
||||
|
||||
def after_task_publish(self, sender, body, **kwargs):
|
||||
pass
|
||||
|
||||
def emit(self, record):
|
||||
task_id = self.get_current_task_id()
|
||||
if not task_id:
|
||||
return
|
||||
try:
|
||||
self.write_task_log(task_id, record)
|
||||
self.flush()
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
def write_task_log(self, task_id, msg):
|
||||
pass
|
||||
|
||||
def handle_task_start(self, task_id):
|
||||
pass
|
||||
|
||||
def handle_task_end(self, task_id):
|
||||
pass
|
||||
|
||||
|
||||
class CeleryThreadingLoggerHandler(CeleryTaskLoggerHandler):
|
||||
@staticmethod
|
||||
def get_current_thread_id():
|
||||
return str(get_ident())
|
||||
|
||||
def emit(self, record):
|
||||
thread_id = self.get_current_thread_id()
|
||||
try:
|
||||
self.write_thread_task_log(thread_id, record)
|
||||
self.flush()
|
||||
except ValueError:
|
||||
self.handleError(record)
|
||||
|
||||
def write_thread_task_log(self, thread_id, msg):
|
||||
pass
|
||||
|
||||
def handle_task_start(self, task_id):
|
||||
pass
|
||||
|
||||
def handle_task_end(self, task_id):
|
||||
pass
|
||||
|
||||
def handleError(self, record) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class CeleryTaskMQLoggerHandler(CeleryTaskLoggerHandler):
|
||||
def __init__(self):
|
||||
self.producer = CeleryLoggerProducer()
|
||||
super().__init__(stream=None)
|
||||
|
||||
def write_task_log(self, task_id, record):
|
||||
msg = self.format(record)
|
||||
self.producer.log(task_id, msg)
|
||||
|
||||
def flush(self):
|
||||
self.producer.flush()
|
||||
|
||||
|
||||
class CeleryTaskFileHandler(CeleryTaskLoggerHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.f = None
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def emit(self, record):
|
||||
msg = self.format(record)
|
||||
if not self.f or self.f.closed:
|
||||
return
|
||||
self.f.write(msg)
|
||||
self.f.write(self.terminator)
|
||||
self.flush()
|
||||
|
||||
def flush(self):
|
||||
self.f and self.f.flush()
|
||||
|
||||
def handle_task_start(self, task_id):
|
||||
log_path = get_celery_task_log_path(task_id)
|
||||
self.f = open(log_path, 'a')
|
||||
|
||||
def handle_task_end(self, task_id):
|
||||
self.f and self.f.close()
|
||||
|
||||
|
||||
class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.thread_id_fd_mapper = {}
|
||||
self.task_id_thread_id_mapper = {}
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def write_thread_task_log(self, thread_id, record):
|
||||
f = self.thread_id_fd_mapper.get(thread_id, None)
|
||||
if not f:
|
||||
raise ValueError('Not found thread task file')
|
||||
msg = self.format(record)
|
||||
f.write(msg.encode())
|
||||
f.write(self.terminator.encode())
|
||||
f.flush()
|
||||
|
||||
def flush(self):
|
||||
for f in self.thread_id_fd_mapper.values():
|
||||
f.flush()
|
||||
|
||||
def handle_task_start(self, task_id):
|
||||
print('handle_task_start')
|
||||
log_path = get_celery_task_log_path(task_id)
|
||||
thread_id = self.get_current_thread_id()
|
||||
self.task_id_thread_id_mapper[task_id] = thread_id
|
||||
f = open(log_path, 'ab')
|
||||
self.thread_id_fd_mapper[thread_id] = f
|
||||
|
||||
def handle_task_end(self, task_id):
|
||||
print('handle_task_end')
|
||||
ident_id = self.task_id_thread_id_mapper.get(task_id, '')
|
||||
f = self.thread_id_fd_mapper.pop(ident_id, None)
|
||||
if f and not f.closed:
|
||||
f.write(CELERY_LOG_MAGIC_MARK)
|
||||
f.close()
|
||||
self.task_id_thread_id_mapper.pop(task_id, None)
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
|
||||
from celery import subtask
|
||||
from celery.signals import (
|
||||
worker_ready, worker_shutdown, after_setup_logger, task_revoked, task_prerun
|
||||
)
|
||||
from django.core.cache import cache
|
||||
from django_celery_beat.models import PeriodicTask
|
||||
|
||||
from .decorator import get_after_app_ready_tasks, get_after_app_shutdown_clean_tasks
|
||||
from .logger import CeleryThreadTaskFileHandler
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
safe_str = lambda x: x
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_app_ready(sender=None, headers=None, **kwargs):
|
||||
if cache.get("CELERY_APP_READY", 0) == 1:
|
||||
return
|
||||
cache.set("CELERY_APP_READY", 1, 10)
|
||||
tasks = get_after_app_ready_tasks()
|
||||
logger.debug("Work ready signal recv")
|
||||
logger.debug("Start need start task: [{}]".format(", ".join(tasks)))
|
||||
for task in tasks:
|
||||
periodic_task = PeriodicTask.objects.filter(task=task).first()
|
||||
if periodic_task and not periodic_task.enabled:
|
||||
logger.debug("Periodic task [{}] is disabled!".format(task))
|
||||
continue
|
||||
subtask(task).delay()
|
||||
|
||||
|
||||
def delete_files(directory):
|
||||
if os.path.isdir(directory):
|
||||
for filename in os.listdir(directory):
|
||||
file_path = os.path.join(directory, filename)
|
||||
if os.path.isfile(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def after_app_shutdown_periodic_tasks(sender=None, **kwargs):
|
||||
if cache.get("CELERY_APP_SHUTDOWN", 0) == 1:
|
||||
return
|
||||
cache.set("CELERY_APP_SHUTDOWN", 1, 10)
|
||||
tasks = get_after_app_shutdown_clean_tasks()
|
||||
logger.debug("Worker shutdown signal recv")
|
||||
logger.debug("Clean period tasks: [{}]".format(', '.join(tasks)))
|
||||
PeriodicTask.objects.filter(name__in=tasks).delete()
|
||||
|
||||
|
||||
@after_setup_logger.connect
|
||||
def add_celery_logger_handler(sender=None, logger=None, loglevel=None, format=None, **kwargs):
|
||||
if not logger:
|
||||
return
|
||||
task_handler = CeleryThreadTaskFileHandler()
|
||||
task_handler.setLevel(loglevel)
|
||||
formatter = logging.Formatter(format)
|
||||
task_handler.setFormatter(formatter)
|
||||
logger.addHandler(task_handler)
|
||||
|
||||
|
||||
@task_revoked.connect
|
||||
def on_task_revoked(request, terminated, signum, expired, **kwargs):
|
||||
print('task_revoked', terminated)
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def on_taskaa_start(sender, task_id, **kwargs):
|
||||
pass
|
||||
# sender.update_state(state='REVOKED',
|
||||
# meta={'exc_type': 'Exception', 'exc': 'Exception', 'message': '暂停任务', 'exc_message': ''})
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from django.conf import settings
|
||||
from django_celery_beat.models import (
|
||||
PeriodicTasks
|
||||
)
|
||||
|
||||
from maxkb.const import PROJECT_DIR
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def disable_celery_periodic_task(task_name):
|
||||
from django_celery_beat.models import PeriodicTask
|
||||
PeriodicTask.objects.filter(name=task_name).update(enabled=False)
|
||||
PeriodicTasks.update_changed()
|
||||
|
||||
|
||||
def delete_celery_periodic_task(task_name):
|
||||
from django_celery_beat.models import PeriodicTask
|
||||
PeriodicTask.objects.filter(name=task_name).delete()
|
||||
PeriodicTasks.update_changed()
|
||||
|
||||
|
||||
def get_celery_periodic_task(task_name):
|
||||
from django_celery_beat.models import PeriodicTask
|
||||
task = PeriodicTask.objects.filter(name=task_name).first()
|
||||
return task
|
||||
|
||||
|
||||
def make_dirs(name, mode=0o755, exist_ok=False):
|
||||
""" 默认权限设置为 0o755 """
|
||||
return os.makedirs(name, mode=mode, exist_ok=exist_ok)
|
||||
|
||||
|
||||
def get_task_log_path(base_path, task_id, level=2):
|
||||
task_id = str(task_id)
|
||||
try:
|
||||
uuid.UUID(task_id)
|
||||
except:
|
||||
return os.path.join(PROJECT_DIR, 'data', 'caution.txt')
|
||||
|
||||
rel_path = os.path.join(*task_id[:level], task_id + '.log')
|
||||
path = os.path.join(base_path, rel_path)
|
||||
make_dirs(os.path.dirname(path), exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def get_celery_task_log_path(task_id):
|
||||
return get_task_log_path(settings.CELERY_LOG_DIR, task_id)
|
||||
|
||||
|
||||
def get_celery_status():
|
||||
from . import app
|
||||
i = app.control.inspect()
|
||||
ping_data = i.ping() or {}
|
||||
active_nodes = [k for k, v in ping_data.items() if v.get('ok') == 'pong']
|
||||
active_queue_worker = set([n.split('@')[0] for n in active_nodes if n])
|
||||
# Celery Worker 数量: 2
|
||||
if len(active_queue_worker) < 2:
|
||||
print("Not all celery worker worked")
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
|
@ -84,8 +84,8 @@ class ToolTreeReadAPI(APIMixin):
|
|||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="module_id",
|
||||
description="模块id",
|
||||
name="folder_id",
|
||||
description="文件夹id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='query',
|
||||
required=False,
|
||||
|
|
@ -178,8 +178,8 @@ class ToolPageAPI(ToolReadAPI):
|
|||
required=True,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="module_id",
|
||||
description="模块id",
|
||||
name="folder_id",
|
||||
description="文件夹id",
|
||||
type=OpenApiTypes.STR,
|
||||
location='query',
|
||||
required=True,
|
||||
|
|
|
|||
|
|
@ -5,12 +5,12 @@ import mptt.fields
|
|||
import uuid_utils.compat
|
||||
from django.db import migrations, models
|
||||
|
||||
from tools.models import ToolModule
|
||||
from tools.models import ToolFolder
|
||||
|
||||
|
||||
def insert_default_data(apps, schema_editor):
|
||||
# 创建一个根模块(没有父节点)
|
||||
ToolModule.objects.create(id='root', name='根目录', user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab')
|
||||
ToolFolder.objects.create(id='root', name='根目录', user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab')
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
|
@ -22,7 +22,7 @@ class Migration(migrations.Migration):
|
|||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='ToolModule',
|
||||
name='ToolFolder',
|
||||
fields=[
|
||||
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
|
||||
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||
|
|
@ -37,12 +37,12 @@ class Migration(migrations.Migration):
|
|||
('level', models.PositiveIntegerField(editable=False)),
|
||||
('parent',
|
||||
mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE,
|
||||
related_name='children', to='tools.toolmodule')),
|
||||
related_name='children', to='tools.toolfolder')),
|
||||
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
|
||||
verbose_name='用户id')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'tool_module',
|
||||
'db_table': 'tool_folder',
|
||||
},
|
||||
),
|
||||
migrations.RunPython(insert_default_data),
|
||||
|
|
@ -72,9 +72,9 @@ class Migration(migrations.Migration):
|
|||
('init_params', models.CharField(max_length=102400, null=True, verbose_name='初始化参数')),
|
||||
('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user',
|
||||
verbose_name='用户id')),
|
||||
('module',
|
||||
models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE, to='tools.toolmodule',
|
||||
verbose_name='模块id')),
|
||||
('folder',
|
||||
models.ForeignKey(default='root', on_delete=django.db.models.deletion.CASCADE, to='tools.toolfolder',
|
||||
verbose_name='文件夹id')),
|
||||
],
|
||||
options={
|
||||
'db_table': 'tool',
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from common.mixins.app_model_mixin import AppModelMixin
|
|||
from users.models import User
|
||||
|
||||
|
||||
class ToolModule(MPTTModel, AppModelMixin):
|
||||
class ToolFolder(MPTTModel, AppModelMixin):
|
||||
id = models.CharField(primary_key=True, max_length=64, editable=False, verbose_name="主键id")
|
||||
name = models.CharField(max_length=64, verbose_name="文件夹名称")
|
||||
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="用户id")
|
||||
|
|
@ -15,7 +15,7 @@ class ToolModule(MPTTModel, AppModelMixin):
|
|||
parent = TreeForeignKey('self', on_delete=models.CASCADE, null=True, blank=True, related_name='children')
|
||||
|
||||
class Meta:
|
||||
db_table = "tool_module"
|
||||
db_table = "tool_folder"
|
||||
|
||||
class MPTTMeta:
|
||||
order_insertion_by = ['name']
|
||||
|
|
@ -46,7 +46,7 @@ class Tool(AppModelMixin):
|
|||
tool_type = models.CharField(max_length=20, verbose_name='工具类型', choices=ToolType.choices,
|
||||
default=ToolType.CUSTOM, db_index=True)
|
||||
template_id = models.UUIDField(max_length=128, verbose_name="模版id", null=True, default=None)
|
||||
module = models.ForeignKey(ToolModule, on_delete=models.CASCADE, verbose_name="模块id", default='root')
|
||||
folder = models.ForeignKey(ToolFolder, on_delete=models.CASCADE, verbose_name="文件夹id", default='root')
|
||||
workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True)
|
||||
init_params = models.CharField(max_length=102400, verbose_name="初始化参数", null=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from common.exception.app_exception import AppApiException
|
|||
from common.result import result
|
||||
from common.utils.tool_code import ToolExecutor
|
||||
from maxkb.const import CONFIG
|
||||
from tools.models import Tool, ToolScope, ToolModule
|
||||
from tools.models import Tool, ToolScope, ToolFolder
|
||||
|
||||
tool_executor = ToolExecutor(CONFIG.get('SANDBOX'))
|
||||
|
||||
|
|
@ -37,11 +37,11 @@ ALLOWED_CLASSES = {
|
|||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
|
||||
def find_class(self, module, name):
|
||||
if (module, name) in ALLOWED_CLASSES:
|
||||
return super().find_class(module, name)
|
||||
def find_class(self, folder, name):
|
||||
if (folder, name) in ALLOWED_CLASSES:
|
||||
return super().find_class(folder, name)
|
||||
raise pickle.UnpicklingError("global '%s.%s' is forbidden" %
|
||||
(module, name))
|
||||
(folder, name))
|
||||
|
||||
|
||||
def encryption(message: str):
|
||||
|
|
@ -75,7 +75,7 @@ class ToolModelSerializer(serializers.ModelSerializer):
|
|||
class Meta:
|
||||
model = Tool
|
||||
fields = ['id', 'name', 'icon', 'desc', 'code', 'input_field_list', 'init_field_list', 'init_params',
|
||||
'scope', 'is_active', 'user_id', 'template_id', 'workspace_id', 'module_id', 'tool_type',
|
||||
'scope', 'is_active', 'user_id', 'template_id', 'workspace_id', 'folder_id', 'tool_type',
|
||||
'create_time', 'update_time']
|
||||
|
||||
|
||||
|
|
@ -126,7 +126,7 @@ class ToolCreateRequest(serializers.Serializer):
|
|||
|
||||
is_active = serializers.BooleanField(required=False, label=_('Is active'))
|
||||
|
||||
module_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root')
|
||||
folder_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root')
|
||||
|
||||
|
||||
class ToolEditRequest(serializers.Serializer):
|
||||
|
|
@ -146,7 +146,7 @@ class ToolEditRequest(serializers.Serializer):
|
|||
|
||||
is_active = serializers.BooleanField(required=False, label=_('Is active'))
|
||||
|
||||
module_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root')
|
||||
folder_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root')
|
||||
|
||||
|
||||
class DebugField(serializers.Serializer):
|
||||
|
|
@ -180,7 +180,7 @@ class ToolSerializer(serializers.Serializer):
|
|||
input_field_list=instance.get('input_field_list', []),
|
||||
init_field_list=instance.get('init_field_list', []),
|
||||
scope=ToolScope.WORKSPACE,
|
||||
module_id=instance.get('module_id', 'root'),
|
||||
folder_id=instance.get('folder_id', 'root'),
|
||||
is_active=False)
|
||||
tool.save()
|
||||
return ToolModelSerializer(tool).data
|
||||
|
|
@ -321,42 +321,42 @@ class ToolSerializer(serializers.Serializer):
|
|||
class ToolTreeSerializer(serializers.Serializer):
|
||||
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
|
||||
|
||||
def get_tools(self, module_id):
|
||||
def get_tools(self, folder_id):
|
||||
self.is_valid(raise_exception=True)
|
||||
if not module_id:
|
||||
module_id = 'root'
|
||||
root = ToolModule.objects.filter(id=module_id).first()
|
||||
if not folder_id:
|
||||
folder_id = 'root'
|
||||
root = ToolFolder.objects.filter(id=folder_id).first()
|
||||
if not root:
|
||||
raise serializers.ValidationError(_('Module not found'))
|
||||
raise serializers.ValidationError(_('Folder not found'))
|
||||
# 使用MPTT的get_descendants()方法获取所有相关节点
|
||||
all_modules = root.get_descendants(include_self=True)
|
||||
all_folders = root.get_descendants(include_self=True)
|
||||
|
||||
tools = QuerySet(Tool).filter(workspace_id=self.data.get('workspace_id'), module_id__in=all_modules)
|
||||
tools = QuerySet(Tool).filter(workspace_id=self.data.get('workspace_id'), folder_id__in=all_folders)
|
||||
return ToolModelSerializer(tools, many=True).data
|
||||
|
||||
class Query(serializers.Serializer):
|
||||
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
|
||||
module_id = serializers.CharField(required=True, label=_('module id'))
|
||||
folder_id = serializers.CharField(required=True, label=_('folder id'))
|
||||
name = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('tool name'))
|
||||
tool_type = serializers.CharField(required=True, label=_('tool type'))
|
||||
|
||||
def page(self, current_page: int, page_size: int):
|
||||
self.is_valid(raise_exception=True)
|
||||
|
||||
module_id = self.data.get('module_id', 'root')
|
||||
root = ToolModule.objects.filter(id=module_id).first()
|
||||
folder_id = self.data.get('folder_id', 'root')
|
||||
root = ToolFolder.objects.filter(id=folder_id).first()
|
||||
if not root:
|
||||
raise serializers.ValidationError(_('Module not found'))
|
||||
raise serializers.ValidationError(_('Folder not found'))
|
||||
# 使用MPTT的get_descendants()方法获取所有相关节点
|
||||
all_modules = root.get_descendants(include_self=True)
|
||||
all_folders = root.get_descendants(include_self=True)
|
||||
|
||||
if self.data.get('name'):
|
||||
tools = QuerySet(Tool).filter(Q(workspace_id=self.data.get('workspace_id')) &
|
||||
Q(module_id__in=all_modules) &
|
||||
Q(folder_id__in=all_folders) &
|
||||
Q(tool_type=self.data.get('tool_type')) &
|
||||
Q(name__contains=self.data.get('name')))
|
||||
else:
|
||||
tools = QuerySet(Tool).filter(Q(workspace_id=self.data.get('workspace_id')) &
|
||||
Q(module_id__in=all_modules) &
|
||||
Q(folder_id__in=all_folders) &
|
||||
Q(tool_type=self.data.get('tool_type')))
|
||||
return page_search(current_page, page_size, tools, lambda record: ToolModelSerializer(record).data)
|
||||
|
|
|
|||
|
|
@ -2,15 +2,15 @@
|
|||
|
||||
from rest_framework import serializers
|
||||
|
||||
from tools.models import ToolModule
|
||||
from tools.models import ToolFolder
|
||||
|
||||
|
||||
class ToolModuleTreeSerializer(serializers.ModelSerializer):
|
||||
class ToolFolderTreeSerializer(serializers.ModelSerializer):
|
||||
children = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
model = ToolModule
|
||||
model = ToolFolder
|
||||
fields = ['id', 'name', 'user_id', 'workspace_id', 'parent_id', 'children']
|
||||
|
||||
def get_children(self, obj):
|
||||
return ToolModuleTreeSerializer(obj.get_children(), many=True).data
|
||||
return ToolFolderTreeSerializer(obj.get_children(), many=True).data
|
||||
|
|
@ -33,8 +33,8 @@ class ToolView(APIView):
|
|||
|
||||
@extend_schema(
|
||||
methods=['GET'],
|
||||
description=_('Get tool by module'),
|
||||
operation_id=_('Get tool by module'),
|
||||
description=_('Get tool by folder'),
|
||||
operation_id=_('Get tool by folder'),
|
||||
parameters=ToolTreeReadAPI.get_parameters(),
|
||||
responses=ToolTreeReadAPI.get_response(),
|
||||
tags=[_('Tool')]
|
||||
|
|
@ -43,7 +43,7 @@ class ToolView(APIView):
|
|||
def get(self, request: Request, workspace_id: str):
|
||||
return result.success(ToolTreeSerializer(
|
||||
data={'workspace_id': workspace_id}
|
||||
).get_tools(request.query_params.get('module_id')))
|
||||
).get_tools(request.query_params.get('folder_id')))
|
||||
|
||||
class Debug(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
|
@ -124,7 +124,7 @@ class ToolView(APIView):
|
|||
return result.success(ToolTreeSerializer.Query(
|
||||
data={
|
||||
'workspace_id': workspace_id,
|
||||
'module_id': request.query_params.get('module_id'),
|
||||
'folder_id': request.query_params.get('folder_id'),
|
||||
'name': request.query_params.get('name'),
|
||||
'tool_type': request.query_params.get('tool_type'),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,6 +40,12 @@ cffi = "1.17.1"
|
|||
pysilk = "0.0.1"
|
||||
sentence-transformers = "4.1.0"
|
||||
websockets = "15.0.1"
|
||||
celery = { extras = ["sqlalchemy"], version = "5.5.2" }
|
||||
django-celery-beat = "2.8.0"
|
||||
celery-once = "3.0.1"
|
||||
beautifulsoup4 = "4.13.4"
|
||||
html2text = "2025.4.15"
|
||||
jieba = "0.42.1"
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
|
|
|||
Loading…
Reference in New Issue