diff --git a/apps/common/db/__init__.py b/apps/common/db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/common/db/compiler.py b/apps/common/db/compiler.py new file mode 100644 index 000000000..69640c8a0 --- /dev/null +++ b/apps/common/db/compiler.py @@ -0,0 +1,217 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: compiler.py + @date:2023/10/7 10:53 + @desc: +""" + +from django.core.exceptions import EmptyResultSet, FullResultSet +from django.db import NotSupportedError +from django.db.models.sql.compiler import SQLCompiler +from django.db.transaction import TransactionManagementError + + +class AppSQLCompiler(SQLCompiler): + def __init__(self, query, connection, using, elide_empty=True, field_replace_dict=None): + super().__init__(query, connection, using, elide_empty) + if field_replace_dict is None: + field_replace_dict = {} + self.field_replace_dict = field_replace_dict + + def get_query_str(self, with_limits=True, with_table_name=False, with_col_aliases=False): + refcounts_before = self.query.alias_refcount.copy() + try: + combinator = self.query.combinator + extra_select, order_by, group_by = self.pre_sql_setup( + with_col_aliases=with_col_aliases or bool(combinator), + ) + for_update_part = None + # Is a LIMIT/OFFSET clause needed? + with_limit_offset = with_limits and self.query.is_sliced + combinator = self.query.combinator + features = self.connection.features + if combinator: + if not getattr(features, "supports_select_{}".format(combinator)): + raise NotSupportedError( + "{} is not supported on this database backend.".format( + combinator + ) + ) + result, params = self.get_combinator_sql( + combinator, self.query.combinator_all + ) + elif self.qualify: + result, params = self.get_qualify_sql() + order_by = None + else: + distinct_fields, distinct_params = self.get_distinct() + # This must come after 'select', 'ordering', and 'distinct' + # (see docstring of get_from_clause() for details). + from_, f_params = self.get_from_clause() + try: + where, w_params = ( + self.compile(self.where) if self.where is not None else ("", []) + ) + except EmptyResultSet: + if self.elide_empty: + raise + # Use a predicate that's always False. + where, w_params = "0 = 1", [] + except FullResultSet: + where, w_params = "", [] + try: + having, h_params = ( + self.compile(self.having) + if self.having is not None + else ("", []) + ) + except FullResultSet: + having, h_params = "", [] + result = [] + params = [] + + if self.query.distinct: + distinct_result, distinct_params = self.connection.ops.distinct_sql( + distinct_fields, + distinct_params, + ) + result += distinct_result + params += distinct_params + + out_cols = [] + for _, (s_sql, s_params), alias in self.select + extra_select: + if alias: + s_sql = "%s AS %s" % ( + s_sql, + self.connection.ops.quote_name(alias), + ) + params.extend(s_params) + out_cols.append(s_sql) + + params.extend(f_params) + + if self.query.select_for_update and features.has_select_for_update: + if ( + self.connection.get_autocommit() + # Don't raise an exception when database doesn't + # support transactions, as it's a noop. + and features.supports_transactions + ): + raise TransactionManagementError( + "select_for_update cannot be used outside of a transaction." + ) + + if ( + with_limit_offset + and not features.supports_select_for_update_with_limit + ): + raise NotSupportedError( + "LIMIT/OFFSET is not supported with " + "select_for_update on this database backend." + ) + nowait = self.query.select_for_update_nowait + skip_locked = self.query.select_for_update_skip_locked + of = self.query.select_for_update_of + no_key = self.query.select_for_no_key_update + # If it's a NOWAIT/SKIP LOCKED/OF/NO KEY query but the + # backend doesn't support it, raise NotSupportedError to + # prevent a possible deadlock. + if nowait and not features.has_select_for_update_nowait: + raise NotSupportedError( + "NOWAIT is not supported on this database backend." + ) + elif skip_locked and not features.has_select_for_update_skip_locked: + raise NotSupportedError( + "SKIP LOCKED is not supported on this database backend." + ) + elif of and not features.has_select_for_update_of: + raise NotSupportedError( + "FOR UPDATE OF is not supported on this database backend." + ) + elif no_key and not features.has_select_for_no_key_update: + raise NotSupportedError( + "FOR NO KEY UPDATE is not supported on this " + "database backend." + ) + for_update_part = self.connection.ops.for_update_sql( + nowait=nowait, + skip_locked=skip_locked, + of=self.get_select_for_update_of_arguments(), + no_key=no_key, + ) + + if for_update_part and features.for_update_after_from: + result.append(for_update_part) + + if where: + result.append("WHERE %s" % where) + params.extend(w_params) + + grouping = [] + for g_sql, g_params in group_by: + grouping.append(g_sql) + params.extend(g_params) + if grouping: + if distinct_fields: + raise NotImplementedError( + "annotate() + distinct(fields) is not implemented." + ) + order_by = order_by or self.connection.ops.force_no_ordering() + result.append("GROUP BY %s" % ", ".join(grouping)) + if self._meta_ordering: + order_by = None + if having: + result.append("HAVING %s" % having) + params.extend(h_params) + + if self.query.explain_info: + result.insert( + 0, + self.connection.ops.explain_query_prefix( + self.query.explain_info.format, + **self.query.explain_info.options, + ), + ) + + if order_by: + ordering = [] + for _, (o_sql, o_params, _) in order_by: + ordering.append(o_sql) + params.extend(o_params) + order_by_sql = "ORDER BY %s" % ", ".join(ordering) + if combinator and features.requires_compound_order_by_subquery: + result = ["SELECT * FROM (", *result, ")", order_by_sql] + else: + result.append(order_by_sql) + + if with_limit_offset: + result.append( + self.connection.ops.limit_offset_sql( + self.query.low_mark, self.query.high_mark + ) + ) + + if for_update_part and not features.for_update_after_from: + result.append(for_update_part) + + from_, f_params = self.get_from_clause() + sql = " ".join(result) + if not with_table_name: + for table_name in from_: + sql = sql.replace(table_name + ".", "") + for key in self.field_replace_dict.keys(): + value = self.field_replace_dict.get(key) + sql = sql.replace(key, value) + return sql, tuple(params) + finally: + # Finally do cleanup - get rid of the joins we created above. + self.query.reset_refcounts(refcounts_before) + + def as_sql(self, with_limits=True, with_col_aliases=False, select_string=None): + if select_string is None: + return super().as_sql(with_limits, with_col_aliases) + else: + sql, params = self.get_query_str(with_table_name=False) + return (select_string + " " + sql), params diff --git a/apps/common/db/search.py b/apps/common/db/search.py new file mode 100644 index 000000000..2955ba691 --- /dev/null +++ b/apps/common/db/search.py @@ -0,0 +1,194 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: search.py + @date:2023/10/7 18:20 + @desc: +""" +from typing import Dict, Any + +from django.db import DEFAULT_DB_ALIAS, models, connections +from django.db.models import QuerySet + +from common.db.compiler import AppSQLCompiler +from common.db.sql_execute import select_one, select_list, update_execute +from common.result import Page + + +def get_dynamics_model(attr: dict, table_name='dynamics'): + """ + 获取一个动态的django模型 + :param attr: 模型字段 + :param table_name: 表名 + :return: django 模型 + """ + attributes = { + "__module__": "dataset.models", + "Meta": type("Meta", (), {'db_table': table_name}), + **attr + } + return type('Dynamics', (models.Model,), attributes) + + +def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string: str, + field_replace_dict: None | Dict[str, Dict[str, str]] = None, with_table_name=False): + """ + 生成 查询sql + :param with_table_name: + :param queryset_dict: 多条件 查询条件 + :param select_string: 查询sql + :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入 + :return: sql:需要查询的sql params: sql 参数 + """ + + params_dict: Dict[int, Any] = {} + result_params = [] + for key in queryset_dict.keys(): + value = queryset_dict.get(key) + sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key), + with_table_name) + params_dict = {**params_dict, select_string.index("${" + key + "}"): params} + select_string = select_string.replace("${" + key + "}", sql) + + for key in sorted(list(params_dict.keys())): + result_params = [*result_params, *params_dict.get(key)] + return select_string, result_params + + +def generate_sql_by_query(queryset: QuerySet, select_string: str, + field_replace_dict: None | Dict[str, str] = None, with_table_name=False): + """ + 生成 查询sql + :param queryset: 查询条件 + :param select_string: 原始sql + :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入 + :return: sql:需要查询的sql params: sql 参数 + """ + sql, params = compiler_queryset(queryset, field_replace_dict, with_table_name) + return select_string + " " + sql, params + + +def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None, with_table_name=False): + """ + 解析 queryset查询对象 + :param with_table_name: + :param queryset: 查询对象 + :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入 + :return: sql:需要查询的sql params: sql 参数 + """ + q = queryset.query + compiler = q.get_compiler(DEFAULT_DB_ALIAS) + if field_replace_dict is None: + field_replace_dict = get_field_replace_dict(queryset) + app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, + field_replace_dict=field_replace_dict) + sql, params = app_sql_compiler.get_query_str(with_table_name=with_table_name) + return sql, params + + +def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str, + field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None, + with_search_one=False, with_table_name=False): + """ + 复杂查询 + :param with_table_name: 生成sql是否包含表名 + :param queryset: 查询条件构造器 + :param select_string: 查询前缀 不包括 where limit 等信息 + :param field_replace_dict: 需要替换的字段 + :param with_search_one: 查询 + :return: 查询结果 + """ + if isinstance(queryset, Dict): + exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name) + else: + exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name) + if with_search_one: + return select_one(exec_sql, exec_params) + else: + return select_list(exec_sql, exec_params) + + +def native_update(queryset: QuerySet | Dict[str, QuerySet], select_string: str, + field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None, + with_table_name=False): + """ + 复杂查询 + :param with_table_name: 生成sql是否包含表名 + :param queryset: 查询条件构造器 + :param select_string: 查询前缀 不包括 where limit 等信息 + :param field_replace_dict: 需要替换的字段 + :return: 查询结果 + """ + if isinstance(queryset, Dict): + exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name) + else: + exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name) + return update_execute(exec_sql, exec_params) + + +def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler): + """ + 分页查询 + :param current_page: 当前页 + :param page_size: 每页大小 + :param queryset: 查询条件 + :param post_records_handler: 数据处理器 + :return: 分页结果 + """ + total = QuerySet(query=queryset.query.clone(), model=queryset.model).count() + result = queryset.all()[((current_page - 1) * page_size):(current_page * page_size)] + return Page(total, list(map(post_records_handler, result)), current_page, page_size) + + +def native_page_search(current_page: int, page_size: int, queryset: QuerySet | Dict[str, QuerySet], select_string: str, + field_replace_dict=None, + post_records_handler=lambda r: r, + with_table_name=False): + """ + 复杂分页查询 + :param with_table_name: + :param current_page: 当前页 + :param page_size: 每页大小 + :param queryset: 查询条件 + :param select_string: 查询 + :param field_replace_dict: 特殊字段替换 + :param post_records_handler: 数据row处理器 + :return: 分页结果 + """ + if isinstance(queryset, Dict): + exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name) + else: + exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name) + total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql + total = select_one(total_sql, exec_params) + limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql( + ((current_page - 1) * page_size), (current_page * page_size) + ) + page_sql = exec_sql + " " + limit_sql + result = select_list(page_sql, exec_params) + return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size) + + +def get_field_replace_dict(queryset: QuerySet): + """ + 获取需要替换的字段 默认 “xxx.xxx”需要被替换成 “xxx”."xxx" + :param queryset: 查询对象 + :return: 需要替换的字典 + """ + result = {} + for field in queryset.model._meta.local_fields: + if field.attname.__contains__("."): + replace_field = to_replace_field(field.attname) + result.__setitem__('"' + field.attname + '"', replace_field) + return result + + +def to_replace_field(field: str): + """ + 将field 转换为 需要替换的field “xxx.xxx”需要被替换成 “xxx”."xxx" 只替换 field包含.的字段 + :param field: django field字段 + :return: 替换字段 + """ + split_field = field.split(".") + return ".".join(list(map(lambda sf: '"' + sf + '"', split_field))) diff --git a/apps/common/db/sql_execute.py b/apps/common/db/sql_execute.py new file mode 100644 index 000000000..b12297e1f --- /dev/null +++ b/apps/common/db/sql_execute.py @@ -0,0 +1,67 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: sql_execute.py + @date:2023/9/25 20:05 + @desc: +""" +from typing import List + +from django.db import connection + + +def sql_execute(sql: str, params): + """ + 执行一条sql + :param sql: 需要执行的sql + :param params: sql参数 + :return: 执行结果 + """ + with connection.cursor() as cursor: + cursor.execute(sql, params) + columns = list(map(lambda d: d.name, cursor.description)) + res = cursor.fetchall() + result = list(map(lambda row: dict(list(zip(columns, row))), res)) + cursor.close() + return result + + +def update_execute(sql: str, params): + """ + 执行一条sql + :param sql: 需要执行的sql + :param params: sql参数 + :return: 执行结果 + """ + with connection.cursor() as cursor: + cursor.execute(sql, params) + affected_rows = cursor.rowcount + cursor.close() + return affected_rows + + +def select_list(sql: str, params: List): + """ + 执行sql 查询列表数据 + :param sql: 需要执行的sql + :param params: sql的参数 + :return: 查询结果 + """ + result_list = sql_execute(sql, params) + if result_list is None: + return [] + return result_list + + +def select_one(sql: str, params: List): + """ + 执行sql 查询一条数据 + :param sql: 需要执行的sql + :param params: 参数 + :return: 查询结果 + """ + result_list = sql_execute(sql, params) + if result_list is None or len(result_list) == 0: + return None + return result_list[0] diff --git a/apps/tools/serializers/tool.py b/apps/tools/serializers/tool.py index 872e7657a..a7c1bcde7 100644 --- a/apps/tools/serializers/tool.py +++ b/apps/tools/serializers/tool.py @@ -12,6 +12,7 @@ from django.http import HttpResponse from django.utils.translation import gettext_lazy as _ from rest_framework import serializers, status +from common.db.search import page_search from common.exception.app_exception import AppApiException from common.result import result from common.utils.tool_code import ToolExecutor @@ -127,6 +128,7 @@ class ToolCreateRequest(serializers.Serializer): module_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, default='root') + class ToolEditRequest(serializers.Serializer): name = serializers.CharField(required=False, label=_('tool name')) @@ -357,4 +359,4 @@ class ToolTreeSerializer(serializers.Serializer): tools = QuerySet(Tool).filter(Q(workspace_id=self.data.get('workspace_id')) & Q(module_id__in=all_modules) & Q(tool_type=self.data.get('tool_type'))) - return ToolModelSerializer(tools, many=True).data + return page_search(current_page, page_size, tools, lambda record: ToolModelSerializer(record).data)