feat: 支持复杂sql查询,修改数据集查询

This commit is contained in:
shaohuzhang1 2023-10-11 15:07:10 +08:00
parent 1e3097fa3f
commit 64e93679f8
5 changed files with 161 additions and 54 deletions

View File

@ -6,8 +6,9 @@
@date2023/10/7 18:20
@desc:
"""
from typing import Dict, Any
from django.db import DEFAULT_DB_ALIAS, models
from django.db import DEFAULT_DB_ALIAS, models, connections
from django.db.models import QuerySet
from common.db.compiler import AppSQLCompiler
@ -30,8 +31,61 @@ def get_dynamics_model(attr: dict, table_name='dynamics'):
return type('Dynamics', (models.Model,), attributes)
def native_search(queryset: QuerySet, select_string: str,
field_replace_dict=None,
def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string: str,
field_replace_dict: None | Dict[str, Dict[str, str]] = None):
"""
生成 查询sql
:param queryset_dict: 多条件 查询条件
:param select_string: 查询sql
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
:return: sql:需要查询的sql params: sql 参数
"""
params_dict: Dict[int, Any] = {}
result_params = []
for key in queryset_dict.keys():
value = queryset_dict.get(key)
sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key))
params_dict = {**params_dict, select_string.index("${" + key + "}"): params}
select_string = select_string.replace("${" + key + "}", sql)
for key in sorted(list(params_dict.keys())):
result_params = [*result_params, *params_dict.get(key)]
return select_string, result_params
def generate_sql_by_query(queryset: QuerySet, select_string: str,
field_replace_dict: None | Dict[str, str] = None):
"""
生成 查询sql
:param queryset: 查询条件
:param select_string: 原始sql
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
:return: sql:需要查询的sql params: sql 参数
"""
sql, params = compiler_queryset(queryset, field_replace_dict)
return select_string + " " + sql, params
def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None):
"""
解析 queryset查询对象
:param queryset: 查询对象
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
:return: sql:需要查询的sql params: sql 参数
"""
q = queryset.query
compiler = q.get_compiler(DEFAULT_DB_ALIAS)
if field_replace_dict is None:
field_replace_dict = get_field_replace_dict(queryset)
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
field_replace_dict=field_replace_dict)
sql, params = app_sql_compiler.get_query_str(with_table_name=False)
return sql, params
def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
with_search_one=False):
"""
复杂查询
@ -41,19 +95,14 @@ def native_search(queryset: QuerySet, select_string: str,
:param with_search_one: 查询
:return: 查询结果
"""
if field_replace_dict is None:
field_replace_dict = get_field_replace_dict(queryset)
q = queryset.query
compiler = q.get_compiler(DEFAULT_DB_ALIAS)
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
field_replace_dict=field_replace_dict)
sql, params = app_sql_compiler.get_query_str(with_table_name=False)
if with_search_one:
return select_one(select_string + " " +
sql, params)
if isinstance(queryset, Dict):
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict)
else:
return select_list(select_string + " " +
sql, params)
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict)
if with_search_one:
return select_one(exec_sql, exec_params)
else:
return select_list(exec_sql, exec_params)
def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler):
@ -70,7 +119,7 @@ def page_search(current_page: int, page_size: int, queryset: QuerySet, post_reco
return Page(total, list(map(post_records_handler, result)), current_page, page_size)
def native_page_search(current_page: int, page_size: int, queryset: QuerySet, select_string: str,
def native_page_search(current_page: int, page_size: int, queryset: QuerySet | Dict[str, QuerySet], select_string: str,
field_replace_dict=None,
post_records_handler=lambda r: r):
"""
@ -83,20 +132,17 @@ def native_page_search(current_page: int, page_size: int, queryset: QuerySet, se
:param post_records_handler: 数据row处理器
:return: 分页结果
"""
if field_replace_dict is None:
field_replace_dict = get_field_replace_dict(queryset)
q = queryset.query
compiler = q.get_compiler(DEFAULT_DB_ALIAS)
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
field_replace_dict=field_replace_dict)
page_sql, params = app_sql_compiler.get_query_str(with_table_name=False)
total_sql = "SELECT \"count\"(*) FROM (%s) temp" % (select_string + " " + page_sql)
total = select_one(total_sql, params)
q.set_limits(((current_page - 1) * page_size), (current_page * page_size))
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
field_replace_dict=field_replace_dict)
page_sql, params = app_sql_compiler.get_query_str(with_table_name=False)
result = select_list(select_string + " " + page_sql, params)
if isinstance(queryset, Dict):
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict)
else:
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict)
total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql
total = select_one(total_sql, exec_params)
limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql(
((current_page - 1) * page_size), (current_page * page_size)
)
page_sql = exec_sql + " " + limit_sql
result = select_list(page_sql, exec_params)
return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size)

View File

@ -11,7 +11,7 @@ class Page(dict):
"""
def __init__(self, total: int, records: List, current_page: int, page_size: int, **kwargs):
super().__init__(**{'total': total, 'records': records, 'current_page': current_page, 'page_size': page_size})
super().__init__(**{'total': total, 'records': records, 'current': current_page, 'size': page_size})
class Result(JsonResponse):
@ -71,12 +71,12 @@ def get_page_api_response(response_data_schema: openapi.Schema):
default=1,
description="数据总条数"),
"records": response_data_schema,
"current_page": openapi.Schema(
"current": openapi.Schema(
type=openapi.TYPE_INTEGER,
title="当前页",
default=1,
description="当前页"),
"page_size": openapi.Schema(
"size": openapi.Schema(
type=openapi.TYPE_INTEGER,
title="每页大小",
default=10,

View File

@ -11,6 +11,7 @@ import uuid
from functools import reduce
from typing import Dict
from django.contrib.postgres.fields import ArrayField
from django.core import validators
from django.db import transaction, models
from django.db.models import QuerySet
@ -23,6 +24,7 @@ from common.mixins.api_mixin import ApiMixin
from common.util.file_util import get_file_content
from dataset.models.data_set import DataSet, Document, Paragraph
from dataset.serializers.document_serializers import CreateDocumentSerializers
from setting.models import AuthOperate
from smartdoc.conf import PROJECT_DIR
from users.models import User
@ -73,17 +75,38 @@ class DataSetSerializers(serializers.ModelSerializer):
message="数据集名称在1-256个字符之间")
])
user_id = serializers.CharField(required=True)
def get_query_set(self):
user_id = self.data.get("user_id")
query_set_dict = {}
query_set = QuerySet(model=get_dynamics_model(
{'dataset.name': models.CharField(), 'dataset.desc': models.CharField(),
"document_temp.char_length": models.IntegerField()}))
if "desc" in self.data:
query_string = {'dataset.desc__contains', self.data.get("desc")}
query_set = query_set.filter(query_string)
query_set = query_set.filter(**{'dataset.desc__contains': self.data.get("desc")})
if "name" in self.data:
query_string = {'dataset.name__contains', self.data.get("name")}
query_set = query_set.filter(query_string)
return query_set
query_set = query_set.filter(**{'dataset.name__contains': self.data.get("name")})
query_set_dict['default_sql'] = query_set
query_set_dict['dataset_custom_sql'] = QuerySet(model=get_dynamics_model(
{'dataset.user_id': models.CharField(),
})).filter(
**{'dataset.user_id': user_id}
)
query_set_dict['team_member_permission_custom_sql'] = QuerySet(model=get_dynamics_model(
{'user_id': models.CharField(),
'team_member_permission.operate': ArrayField(verbose_name="权限操作列表",
base_field=models.CharField(max_length=256,
blank=True,
choices=AuthOperate.choices,
default=AuthOperate.USE)
)})).filter(
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']})
return query_set_dict
def page(self, current_page: int, page_size: int):
return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content(
@ -200,18 +223,32 @@ class DataSetSerializers(serializers.ModelSerializer):
dataset.delete()
return True
def one(self, with_valid=True):
def one(self, user_id, with_valid=True):
if with_valid:
self.is_valid()
query_string = {'dataset.id', self.data.get("id")}
query_set = QuerySet(model=get_dynamics_model(
{'dataset.id': models.UUIDField()})).filter(query_string)
return native_search(query_set, select_string=get_file_content(
query_set_dict = {'default_sql': QuerySet(model=get_dynamics_model(
{'temp.id': models.UUIDField()})).filter(**{'temp.id': self.data.get("id")}),
'dataset_custom_sql': QuerySet(model=get_dynamics_model(
{'dataset.user_id': models.CharField()})).filter(
**{'dataset.user_id': user_id}
), 'team_member_permission_custom_sql': QuerySet(
model=get_dynamics_model({'user_id': models.CharField(),
'team_member_permission.operate': ArrayField(
verbose_name="权限操作列表",
base_field=models.CharField(max_length=256,
blank=True,
choices=AuthOperate.choices,
default=AuthOperate.USE)
)})).filter(
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']})}
return native_search(query_set_dict, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), with_search_one=True)
def edit(self, dataset: Dict):
def edit(self, dataset: Dict, user_id: str):
"""
修改数据集
:param user_id: 用户id
:param dataset: Dict name desc
:return:
"""
@ -222,7 +259,7 @@ class DataSetSerializers(serializers.ModelSerializer):
if 'desc' in dataset:
_dataset.desc = dataset.get("desc")
_dataset.save()
return self.one(with_valid=False)
return self.one(with_valid=False, user_id=user_id)
@staticmethod
def get_request_body_api():

View File

@ -1,7 +1,30 @@
SELECT
dataset.*,
document_temp."char_length",
"document_temp".document_count
*
FROM
dataset dataset
LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON dataset."id" = "document_temp".dataset_id
(
SELECT
"temp_dataset".*,
"document_temp"."char_length",
"document_temp".document_count FROM (
SELECT dataset.*
FROM
dataset dataset
${dataset_custom_sql}
UNION
SELECT
*
FROM
dataset
WHERE
dataset."id" IN (
SELECT
team_member_permission.target
FROM
team_member team_member
LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id"
${team_member_permission_custom_sql}
)
) temp_dataset
LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON temp_dataset."id" = "document_temp".dataset_id
) temp
${default_sql}

View File

@ -29,7 +29,7 @@ class Dataset(APIView):
responses=get_api_response(DataSetSerializers.Query.get_response_body_api()))
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
def get(self, request: Request):
d = DataSetSerializers.Query(data=request.query_params)
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
d.is_valid()
return result.success(d.list())
@ -63,7 +63,7 @@ class Dataset(APIView):
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=keywords.get('dataset_id')))
def get(self, request: Request, dataset_id: str):
return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).one())
return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).one(user_id=request.user.id))
@action(methods="PUT", detail=False)
@swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息",
@ -72,7 +72,8 @@ class Dataset(APIView):
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=keywords.get('dataset_id')))
def put(self, request: Request, dataset_id: str):
return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).edit(request.data))
return result.success(
DataSetSerializers.Operate(data={'id': dataset_id}).edit(request.data, user_id=request.user.id))
class Page(APIView):
authentication_classes = [TokenAuth]
@ -85,6 +86,6 @@ class Dataset(APIView):
responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()))
@has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND)
def get(self, request: Request, current_page, page_size):
d = DataSetSerializers.Query(data=request.query_params)
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
d.is_valid()
return result.success(d.page(current_page, page_size))