mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: add dynamic SQL execution and pagination functionality with custom query compiler
This commit is contained in:
parent
fbb4e7d449
commit
4c38e8a82b
|
|
@ -0,0 +1,217 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: compiler.py
|
||||
@date:2023/10/7 10:53
|
||||
@desc:
|
||||
"""
|
||||
|
||||
from django.core.exceptions import EmptyResultSet, FullResultSet
|
||||
from django.db import NotSupportedError
|
||||
from django.db.models.sql.compiler import SQLCompiler
|
||||
from django.db.transaction import TransactionManagementError
|
||||
|
||||
|
||||
class AppSQLCompiler(SQLCompiler):
|
||||
def __init__(self, query, connection, using, elide_empty=True, field_replace_dict=None):
|
||||
super().__init__(query, connection, using, elide_empty)
|
||||
if field_replace_dict is None:
|
||||
field_replace_dict = {}
|
||||
self.field_replace_dict = field_replace_dict
|
||||
|
||||
def get_query_str(self, with_limits=True, with_table_name=False, with_col_aliases=False):
|
||||
refcounts_before = self.query.alias_refcount.copy()
|
||||
try:
|
||||
combinator = self.query.combinator
|
||||
extra_select, order_by, group_by = self.pre_sql_setup(
|
||||
with_col_aliases=with_col_aliases or bool(combinator),
|
||||
)
|
||||
for_update_part = None
|
||||
# Is a LIMIT/OFFSET clause needed?
|
||||
with_limit_offset = with_limits and self.query.is_sliced
|
||||
combinator = self.query.combinator
|
||||
features = self.connection.features
|
||||
if combinator:
|
||||
if not getattr(features, "supports_select_{}".format(combinator)):
|
||||
raise NotSupportedError(
|
||||
"{} is not supported on this database backend.".format(
|
||||
combinator
|
||||
)
|
||||
)
|
||||
result, params = self.get_combinator_sql(
|
||||
combinator, self.query.combinator_all
|
||||
)
|
||||
elif self.qualify:
|
||||
result, params = self.get_qualify_sql()
|
||||
order_by = None
|
||||
else:
|
||||
distinct_fields, distinct_params = self.get_distinct()
|
||||
# This must come after 'select', 'ordering', and 'distinct'
|
||||
# (see docstring of get_from_clause() for details).
|
||||
from_, f_params = self.get_from_clause()
|
||||
try:
|
||||
where, w_params = (
|
||||
self.compile(self.where) if self.where is not None else ("", [])
|
||||
)
|
||||
except EmptyResultSet:
|
||||
if self.elide_empty:
|
||||
raise
|
||||
# Use a predicate that's always False.
|
||||
where, w_params = "0 = 1", []
|
||||
except FullResultSet:
|
||||
where, w_params = "", []
|
||||
try:
|
||||
having, h_params = (
|
||||
self.compile(self.having)
|
||||
if self.having is not None
|
||||
else ("", [])
|
||||
)
|
||||
except FullResultSet:
|
||||
having, h_params = "", []
|
||||
result = []
|
||||
params = []
|
||||
|
||||
if self.query.distinct:
|
||||
distinct_result, distinct_params = self.connection.ops.distinct_sql(
|
||||
distinct_fields,
|
||||
distinct_params,
|
||||
)
|
||||
result += distinct_result
|
||||
params += distinct_params
|
||||
|
||||
out_cols = []
|
||||
for _, (s_sql, s_params), alias in self.select + extra_select:
|
||||
if alias:
|
||||
s_sql = "%s AS %s" % (
|
||||
s_sql,
|
||||
self.connection.ops.quote_name(alias),
|
||||
)
|
||||
params.extend(s_params)
|
||||
out_cols.append(s_sql)
|
||||
|
||||
params.extend(f_params)
|
||||
|
||||
if self.query.select_for_update and features.has_select_for_update:
|
||||
if (
|
||||
self.connection.get_autocommit()
|
||||
# Don't raise an exception when database doesn't
|
||||
# support transactions, as it's a noop.
|
||||
and features.supports_transactions
|
||||
):
|
||||
raise TransactionManagementError(
|
||||
"select_for_update cannot be used outside of a transaction."
|
||||
)
|
||||
|
||||
if (
|
||||
with_limit_offset
|
||||
and not features.supports_select_for_update_with_limit
|
||||
):
|
||||
raise NotSupportedError(
|
||||
"LIMIT/OFFSET is not supported with "
|
||||
"select_for_update on this database backend."
|
||||
)
|
||||
nowait = self.query.select_for_update_nowait
|
||||
skip_locked = self.query.select_for_update_skip_locked
|
||||
of = self.query.select_for_update_of
|
||||
no_key = self.query.select_for_no_key_update
|
||||
# If it's a NOWAIT/SKIP LOCKED/OF/NO KEY query but the
|
||||
# backend doesn't support it, raise NotSupportedError to
|
||||
# prevent a possible deadlock.
|
||||
if nowait and not features.has_select_for_update_nowait:
|
||||
raise NotSupportedError(
|
||||
"NOWAIT is not supported on this database backend."
|
||||
)
|
||||
elif skip_locked and not features.has_select_for_update_skip_locked:
|
||||
raise NotSupportedError(
|
||||
"SKIP LOCKED is not supported on this database backend."
|
||||
)
|
||||
elif of and not features.has_select_for_update_of:
|
||||
raise NotSupportedError(
|
||||
"FOR UPDATE OF is not supported on this database backend."
|
||||
)
|
||||
elif no_key and not features.has_select_for_no_key_update:
|
||||
raise NotSupportedError(
|
||||
"FOR NO KEY UPDATE is not supported on this "
|
||||
"database backend."
|
||||
)
|
||||
for_update_part = self.connection.ops.for_update_sql(
|
||||
nowait=nowait,
|
||||
skip_locked=skip_locked,
|
||||
of=self.get_select_for_update_of_arguments(),
|
||||
no_key=no_key,
|
||||
)
|
||||
|
||||
if for_update_part and features.for_update_after_from:
|
||||
result.append(for_update_part)
|
||||
|
||||
if where:
|
||||
result.append("WHERE %s" % where)
|
||||
params.extend(w_params)
|
||||
|
||||
grouping = []
|
||||
for g_sql, g_params in group_by:
|
||||
grouping.append(g_sql)
|
||||
params.extend(g_params)
|
||||
if grouping:
|
||||
if distinct_fields:
|
||||
raise NotImplementedError(
|
||||
"annotate() + distinct(fields) is not implemented."
|
||||
)
|
||||
order_by = order_by or self.connection.ops.force_no_ordering()
|
||||
result.append("GROUP BY %s" % ", ".join(grouping))
|
||||
if self._meta_ordering:
|
||||
order_by = None
|
||||
if having:
|
||||
result.append("HAVING %s" % having)
|
||||
params.extend(h_params)
|
||||
|
||||
if self.query.explain_info:
|
||||
result.insert(
|
||||
0,
|
||||
self.connection.ops.explain_query_prefix(
|
||||
self.query.explain_info.format,
|
||||
**self.query.explain_info.options,
|
||||
),
|
||||
)
|
||||
|
||||
if order_by:
|
||||
ordering = []
|
||||
for _, (o_sql, o_params, _) in order_by:
|
||||
ordering.append(o_sql)
|
||||
params.extend(o_params)
|
||||
order_by_sql = "ORDER BY %s" % ", ".join(ordering)
|
||||
if combinator and features.requires_compound_order_by_subquery:
|
||||
result = ["SELECT * FROM (", *result, ")", order_by_sql]
|
||||
else:
|
||||
result.append(order_by_sql)
|
||||
|
||||
if with_limit_offset:
|
||||
result.append(
|
||||
self.connection.ops.limit_offset_sql(
|
||||
self.query.low_mark, self.query.high_mark
|
||||
)
|
||||
)
|
||||
|
||||
if for_update_part and not features.for_update_after_from:
|
||||
result.append(for_update_part)
|
||||
|
||||
from_, f_params = self.get_from_clause()
|
||||
sql = " ".join(result)
|
||||
if not with_table_name:
|
||||
for table_name in from_:
|
||||
sql = sql.replace(table_name + ".", "")
|
||||
for key in self.field_replace_dict.keys():
|
||||
value = self.field_replace_dict.get(key)
|
||||
sql = sql.replace(key, value)
|
||||
return sql, tuple(params)
|
||||
finally:
|
||||
# Finally do cleanup - get rid of the joins we created above.
|
||||
self.query.reset_refcounts(refcounts_before)
|
||||
|
||||
def as_sql(self, with_limits=True, with_col_aliases=False, select_string=None):
|
||||
if select_string is None:
|
||||
return super().as_sql(with_limits, with_col_aliases)
|
||||
else:
|
||||
sql, params = self.get_query_str(with_table_name=False)
|
||||
return (select_string + " " + sql), params
|
||||
|
|
@ -0,0 +1,194 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: search.py
|
||||
@date:2023/10/7 18:20
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
|
||||
from django.db import DEFAULT_DB_ALIAS, models, connections
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.db.compiler import AppSQLCompiler
|
||||
from common.db.sql_execute import select_one, select_list, update_execute
|
||||
from common.result import Page
|
||||
|
||||
|
||||
def get_dynamics_model(attr: dict, table_name='dynamics'):
|
||||
"""
|
||||
获取一个动态的django模型
|
||||
:param attr: 模型字段
|
||||
:param table_name: 表名
|
||||
:return: django 模型
|
||||
"""
|
||||
attributes = {
|
||||
"__module__": "dataset.models",
|
||||
"Meta": type("Meta", (), {'db_table': table_name}),
|
||||
**attr
|
||||
}
|
||||
return type('Dynamics', (models.Model,), attributes)
|
||||
|
||||
|
||||
def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string: str,
|
||||
field_replace_dict: None | Dict[str, Dict[str, str]] = None, with_table_name=False):
|
||||
"""
|
||||
生成 查询sql
|
||||
:param with_table_name:
|
||||
:param queryset_dict: 多条件 查询条件
|
||||
:param select_string: 查询sql
|
||||
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
|
||||
:return: sql:需要查询的sql params: sql 参数
|
||||
"""
|
||||
|
||||
params_dict: Dict[int, Any] = {}
|
||||
result_params = []
|
||||
for key in queryset_dict.keys():
|
||||
value = queryset_dict.get(key)
|
||||
sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key),
|
||||
with_table_name)
|
||||
params_dict = {**params_dict, select_string.index("${" + key + "}"): params}
|
||||
select_string = select_string.replace("${" + key + "}", sql)
|
||||
|
||||
for key in sorted(list(params_dict.keys())):
|
||||
result_params = [*result_params, *params_dict.get(key)]
|
||||
return select_string, result_params
|
||||
|
||||
|
||||
def generate_sql_by_query(queryset: QuerySet, select_string: str,
|
||||
field_replace_dict: None | Dict[str, str] = None, with_table_name=False):
|
||||
"""
|
||||
生成 查询sql
|
||||
:param queryset: 查询条件
|
||||
:param select_string: 原始sql
|
||||
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
|
||||
:return: sql:需要查询的sql params: sql 参数
|
||||
"""
|
||||
sql, params = compiler_queryset(queryset, field_replace_dict, with_table_name)
|
||||
return select_string + " " + sql, params
|
||||
|
||||
|
||||
def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None, with_table_name=False):
|
||||
"""
|
||||
解析 queryset查询对象
|
||||
:param with_table_name:
|
||||
:param queryset: 查询对象
|
||||
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
|
||||
:return: sql:需要查询的sql params: sql 参数
|
||||
"""
|
||||
q = queryset.query
|
||||
compiler = q.get_compiler(DEFAULT_DB_ALIAS)
|
||||
if field_replace_dict is None:
|
||||
field_replace_dict = get_field_replace_dict(queryset)
|
||||
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
|
||||
field_replace_dict=field_replace_dict)
|
||||
sql, params = app_sql_compiler.get_query_str(with_table_name=with_table_name)
|
||||
return sql, params
|
||||
|
||||
|
||||
def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
|
||||
field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
|
||||
with_search_one=False, with_table_name=False):
|
||||
"""
|
||||
复杂查询
|
||||
:param with_table_name: 生成sql是否包含表名
|
||||
:param queryset: 查询条件构造器
|
||||
:param select_string: 查询前缀 不包括 where limit 等信息
|
||||
:param field_replace_dict: 需要替换的字段
|
||||
:param with_search_one: 查询
|
||||
:return: 查询结果
|
||||
"""
|
||||
if isinstance(queryset, Dict):
|
||||
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name)
|
||||
else:
|
||||
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name)
|
||||
if with_search_one:
|
||||
return select_one(exec_sql, exec_params)
|
||||
else:
|
||||
return select_list(exec_sql, exec_params)
|
||||
|
||||
|
||||
def native_update(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
|
||||
field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
|
||||
with_table_name=False):
|
||||
"""
|
||||
复杂查询
|
||||
:param with_table_name: 生成sql是否包含表名
|
||||
:param queryset: 查询条件构造器
|
||||
:param select_string: 查询前缀 不包括 where limit 等信息
|
||||
:param field_replace_dict: 需要替换的字段
|
||||
:return: 查询结果
|
||||
"""
|
||||
if isinstance(queryset, Dict):
|
||||
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name)
|
||||
else:
|
||||
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name)
|
||||
return update_execute(exec_sql, exec_params)
|
||||
|
||||
|
||||
def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler):
|
||||
"""
|
||||
分页查询
|
||||
:param current_page: 当前页
|
||||
:param page_size: 每页大小
|
||||
:param queryset: 查询条件
|
||||
:param post_records_handler: 数据处理器
|
||||
:return: 分页结果
|
||||
"""
|
||||
total = QuerySet(query=queryset.query.clone(), model=queryset.model).count()
|
||||
result = queryset.all()[((current_page - 1) * page_size):(current_page * page_size)]
|
||||
return Page(total, list(map(post_records_handler, result)), current_page, page_size)
|
||||
|
||||
|
||||
def native_page_search(current_page: int, page_size: int, queryset: QuerySet | Dict[str, QuerySet], select_string: str,
|
||||
field_replace_dict=None,
|
||||
post_records_handler=lambda r: r,
|
||||
with_table_name=False):
|
||||
"""
|
||||
复杂分页查询
|
||||
:param with_table_name:
|
||||
:param current_page: 当前页
|
||||
:param page_size: 每页大小
|
||||
:param queryset: 查询条件
|
||||
:param select_string: 查询
|
||||
:param field_replace_dict: 特殊字段替换
|
||||
:param post_records_handler: 数据row处理器
|
||||
:return: 分页结果
|
||||
"""
|
||||
if isinstance(queryset, Dict):
|
||||
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name)
|
||||
else:
|
||||
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name)
|
||||
total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql
|
||||
total = select_one(total_sql, exec_params)
|
||||
limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql(
|
||||
((current_page - 1) * page_size), (current_page * page_size)
|
||||
)
|
||||
page_sql = exec_sql + " " + limit_sql
|
||||
result = select_list(page_sql, exec_params)
|
||||
return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size)
|
||||
|
||||
|
||||
def get_field_replace_dict(queryset: QuerySet):
|
||||
"""
|
||||
获取需要替换的字段 默认 “xxx.xxx”需要被替换成 “xxx”."xxx"
|
||||
:param queryset: 查询对象
|
||||
:return: 需要替换的字典
|
||||
"""
|
||||
result = {}
|
||||
for field in queryset.model._meta.local_fields:
|
||||
if field.attname.__contains__("."):
|
||||
replace_field = to_replace_field(field.attname)
|
||||
result.__setitem__('"' + field.attname + '"', replace_field)
|
||||
return result
|
||||
|
||||
|
||||
def to_replace_field(field: str):
|
||||
"""
|
||||
将field 转换为 需要替换的field “xxx.xxx”需要被替换成 “xxx”."xxx" 只替换 field包含.的字段
|
||||
:param field: django field字段
|
||||
:return: 替换字段
|
||||
"""
|
||||
split_field = field.split(".")
|
||||
return ".".join(list(map(lambda sf: '"' + sf + '"', split_field)))
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: sql_execute.py
|
||||
@date:2023/9/25 20:05
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from django.db import connection
|
||||
|
||||
|
||||
def sql_execute(sql: str, params):
|
||||
"""
|
||||
执行一条sql
|
||||
:param sql: 需要执行的sql
|
||||
:param params: sql参数
|
||||
:return: 执行结果
|
||||
"""
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(sql, params)
|
||||
columns = list(map(lambda d: d.name, cursor.description))
|
||||
res = cursor.fetchall()
|
||||
result = list(map(lambda row: dict(list(zip(columns, row))), res))
|
||||
cursor.close()
|
||||
return result
|
||||
|
||||
|
||||
def update_execute(sql: str, params):
|
||||
"""
|
||||
执行一条sql
|
||||
:param sql: 需要执行的sql
|
||||
:param params: sql参数
|
||||
:return: 执行结果
|
||||
"""
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute(sql, params)
|
||||
affected_rows = cursor.rowcount
|
||||
cursor.close()
|
||||
return affected_rows
|
||||
|
||||
|
||||
def select_list(sql: str, params: List):
|
||||
"""
|
||||
执行sql 查询列表数据
|
||||
:param sql: 需要执行的sql
|
||||
:param params: sql的参数
|
||||
:return: 查询结果
|
||||
"""
|
||||
result_list = sql_execute(sql, params)
|
||||
if result_list is None:
|
||||
return []
|
||||
return result_list
|
||||
|
||||
|
||||
def select_one(sql: str, params: List):
|
||||
"""
|
||||
执行sql 查询一条数据
|
||||
:param sql: 需要执行的sql
|
||||
:param params: 参数
|
||||
:return: 查询结果
|
||||
"""
|
||||
result_list = sql_execute(sql, params)
|
||||
if result_list is None or len(result_list) == 0:
|
||||
return None
|
||||
return result_list[0]
|
||||
|
|
@ -12,6 +12,7 @@ from django.http import HttpResponse
|
|||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers, status
|
||||
|
||||
from common.db.search import page_search
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.result import result
|
||||
from common.utils.tool_code import ToolExecutor
|
||||
|
|
@ -127,6 +128,7 @@ class ToolCreateRequest(serializers.Serializer):
|
|||
|
||||
module_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root')
|
||||
|
||||
|
||||
class ToolEditRequest(serializers.Serializer):
|
||||
name = serializers.CharField(required=False, label=_('tool name'))
|
||||
|
||||
|
|
@ -357,4 +359,4 @@ class ToolTreeSerializer(serializers.Serializer):
|
|||
tools = QuerySet(Tool).filter(Q(workspace_id=self.data.get('workspace_id')) &
|
||||
Q(module_id__in=all_modules) &
|
||||
Q(tool_type=self.data.get('tool_type')))
|
||||
return ToolModelSerializer(tools, many=True).data
|
||||
return page_search(current_page, page_size, tools, lambda record: ToolModelSerializer(record).data)
|
||||
|
|
|
|||
Loading…
Reference in New Issue