From 64e93679f833449391d3d55356ce3d8c255debbe Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Wed, 11 Oct 2023 15:07:10 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=A4=8D=E6=9D=82sql?= =?UTF-8?q?=E6=9F=A5=E8=AF=A2,=E4=BF=AE=E6=94=B9=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E9=9B=86=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/db/search.py | 106 +++++++++++++----- apps/common/response/result.py | 6 +- .../serializers/dataset_serializers.py | 61 ++++++++-- apps/dataset/sql/list_dataset.sql | 33 +++++- apps/dataset/views/dataset.py | 9 +- 5 files changed, 161 insertions(+), 54 deletions(-) diff --git a/apps/common/db/search.py b/apps/common/db/search.py index 6b772a632..007fb9f27 100644 --- a/apps/common/db/search.py +++ b/apps/common/db/search.py @@ -6,8 +6,9 @@ @date:2023/10/7 18:20 @desc: """ +from typing import Dict, Any -from django.db import DEFAULT_DB_ALIAS, models +from django.db import DEFAULT_DB_ALIAS, models, connections from django.db.models import QuerySet from common.db.compiler import AppSQLCompiler @@ -30,8 +31,61 @@ def get_dynamics_model(attr: dict, table_name='dynamics'): return type('Dynamics', (models.Model,), attributes) -def native_search(queryset: QuerySet, select_string: str, - field_replace_dict=None, +def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string: str, + field_replace_dict: None | Dict[str, Dict[str, str]] = None): + """ + 生成 查询sql + :param queryset_dict: 多条件 查询条件 + :param select_string: 查询sql + :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入 + :return: sql:需要查询的sql params: sql 参数 + """ + + params_dict: Dict[int, Any] = {} + result_params = [] + for key in queryset_dict.keys(): + value = queryset_dict.get(key) + sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key)) + params_dict = {**params_dict, select_string.index("${" + key + "}"): params} + select_string = select_string.replace("${" + key + "}", sql) + + for key in sorted(list(params_dict.keys())): + result_params = [*result_params, *params_dict.get(key)] + return select_string, result_params + + +def generate_sql_by_query(queryset: QuerySet, select_string: str, + field_replace_dict: None | Dict[str, str] = None): + """ + 生成 查询sql + :param queryset: 查询条件 + :param select_string: 原始sql + :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入 + :return: sql:需要查询的sql params: sql 参数 + """ + sql, params = compiler_queryset(queryset, field_replace_dict) + return select_string + " " + sql, params + + +def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None): + """ + 解析 queryset查询对象 + :param queryset: 查询对象 + :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入 + :return: sql:需要查询的sql params: sql 参数 + """ + q = queryset.query + compiler = q.get_compiler(DEFAULT_DB_ALIAS) + if field_replace_dict is None: + field_replace_dict = get_field_replace_dict(queryset) + app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, + field_replace_dict=field_replace_dict) + sql, params = app_sql_compiler.get_query_str(with_table_name=False) + return sql, params + + +def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str, + field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None, with_search_one=False): """ 复杂查询 @@ -41,19 +95,14 @@ def native_search(queryset: QuerySet, select_string: str, :param with_search_one: 查询 :return: 查询结果 """ - if field_replace_dict is None: - field_replace_dict = get_field_replace_dict(queryset) - q = queryset.query - compiler = q.get_compiler(DEFAULT_DB_ALIAS) - app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, - field_replace_dict=field_replace_dict) - sql, params = app_sql_compiler.get_query_str(with_table_name=False) - if with_search_one: - return select_one(select_string + " " + - sql, params) + if isinstance(queryset, Dict): + exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict) else: - return select_list(select_string + " " + - sql, params) + exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict) + if with_search_one: + return select_one(exec_sql, exec_params) + else: + return select_list(exec_sql, exec_params) def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler): @@ -70,7 +119,7 @@ def page_search(current_page: int, page_size: int, queryset: QuerySet, post_reco return Page(total, list(map(post_records_handler, result)), current_page, page_size) -def native_page_search(current_page: int, page_size: int, queryset: QuerySet, select_string: str, +def native_page_search(current_page: int, page_size: int, queryset: QuerySet | Dict[str, QuerySet], select_string: str, field_replace_dict=None, post_records_handler=lambda r: r): """ @@ -83,20 +132,17 @@ def native_page_search(current_page: int, page_size: int, queryset: QuerySet, se :param post_records_handler: 数据row处理器 :return: 分页结果 """ - if field_replace_dict is None: - field_replace_dict = get_field_replace_dict(queryset) - q = queryset.query - compiler = q.get_compiler(DEFAULT_DB_ALIAS) - app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, - field_replace_dict=field_replace_dict) - page_sql, params = app_sql_compiler.get_query_str(with_table_name=False) - total_sql = "SELECT \"count\"(*) FROM (%s) temp" % (select_string + " " + page_sql) - total = select_one(total_sql, params) - q.set_limits(((current_page - 1) * page_size), (current_page * page_size)) - app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, - field_replace_dict=field_replace_dict) - page_sql, params = app_sql_compiler.get_query_str(with_table_name=False) - result = select_list(select_string + " " + page_sql, params) + if isinstance(queryset, Dict): + exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict) + else: + exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict) + total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql + total = select_one(total_sql, exec_params) + limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql( + ((current_page - 1) * page_size), (current_page * page_size) + ) + page_sql = exec_sql + " " + limit_sql + result = select_list(page_sql, exec_params) return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size) diff --git a/apps/common/response/result.py b/apps/common/response/result.py index c3992e18a..1b205069b 100644 --- a/apps/common/response/result.py +++ b/apps/common/response/result.py @@ -11,7 +11,7 @@ class Page(dict): """ def __init__(self, total: int, records: List, current_page: int, page_size: int, **kwargs): - super().__init__(**{'total': total, 'records': records, 'current_page': current_page, 'page_size': page_size}) + super().__init__(**{'total': total, 'records': records, 'current': current_page, 'size': page_size}) class Result(JsonResponse): @@ -71,12 +71,12 @@ def get_page_api_response(response_data_schema: openapi.Schema): default=1, description="数据总条数"), "records": response_data_schema, - "current_page": openapi.Schema( + "current": openapi.Schema( type=openapi.TYPE_INTEGER, title="当前页", default=1, description="当前页"), - "page_size": openapi.Schema( + "size": openapi.Schema( type=openapi.TYPE_INTEGER, title="每页大小", default=10, diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py index 5dfcf4c53..9998422d7 100644 --- a/apps/dataset/serializers/dataset_serializers.py +++ b/apps/dataset/serializers/dataset_serializers.py @@ -11,6 +11,7 @@ import uuid from functools import reduce from typing import Dict +from django.contrib.postgres.fields import ArrayField from django.core import validators from django.db import transaction, models from django.db.models import QuerySet @@ -23,6 +24,7 @@ from common.mixins.api_mixin import ApiMixin from common.util.file_util import get_file_content from dataset.models.data_set import DataSet, Document, Paragraph from dataset.serializers.document_serializers import CreateDocumentSerializers +from setting.models import AuthOperate from smartdoc.conf import PROJECT_DIR from users.models import User @@ -73,17 +75,38 @@ class DataSetSerializers(serializers.ModelSerializer): message="数据集名称在1-256个字符之间") ]) + user_id = serializers.CharField(required=True) + def get_query_set(self): + user_id = self.data.get("user_id") + query_set_dict = {} query_set = QuerySet(model=get_dynamics_model( {'dataset.name': models.CharField(), 'dataset.desc': models.CharField(), "document_temp.char_length": models.IntegerField()})) if "desc" in self.data: - query_string = {'dataset.desc__contains', self.data.get("desc")} - query_set = query_set.filter(query_string) + query_set = query_set.filter(**{'dataset.desc__contains': self.data.get("desc")}) if "name" in self.data: - query_string = {'dataset.name__contains', self.data.get("name")} - query_set = query_set.filter(query_string) - return query_set + query_set = query_set.filter(**{'dataset.name__contains': self.data.get("name")}) + + query_set_dict['default_sql'] = query_set + + query_set_dict['dataset_custom_sql'] = QuerySet(model=get_dynamics_model( + {'dataset.user_id': models.CharField(), + })).filter( + **{'dataset.user_id': user_id} + ) + + query_set_dict['team_member_permission_custom_sql'] = QuerySet(model=get_dynamics_model( + {'user_id': models.CharField(), + 'team_member_permission.operate': ArrayField(verbose_name="权限操作列表", + base_field=models.CharField(max_length=256, + blank=True, + choices=AuthOperate.choices, + default=AuthOperate.USE) + )})).filter( + **{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']}) + + return query_set_dict def page(self, current_page: int, page_size: int): return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content( @@ -200,18 +223,32 @@ class DataSetSerializers(serializers.ModelSerializer): dataset.delete() return True - def one(self, with_valid=True): + def one(self, user_id, with_valid=True): if with_valid: self.is_valid() - query_string = {'dataset.id', self.data.get("id")} - query_set = QuerySet(model=get_dynamics_model( - {'dataset.id': models.UUIDField()})).filter(query_string) - return native_search(query_set, select_string=get_file_content( + query_set_dict = {'default_sql': QuerySet(model=get_dynamics_model( + {'temp.id': models.UUIDField()})).filter(**{'temp.id': self.data.get("id")}), + 'dataset_custom_sql': QuerySet(model=get_dynamics_model( + {'dataset.user_id': models.CharField()})).filter( + **{'dataset.user_id': user_id} + ), 'team_member_permission_custom_sql': QuerySet( + model=get_dynamics_model({'user_id': models.CharField(), + 'team_member_permission.operate': ArrayField( + verbose_name="权限操作列表", + base_field=models.CharField(max_length=256, + blank=True, + choices=AuthOperate.choices, + default=AuthOperate.USE) + )})).filter( + **{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']})} + + return native_search(query_set_dict, select_string=get_file_content( os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), with_search_one=True) - def edit(self, dataset: Dict): + def edit(self, dataset: Dict, user_id: str): """ 修改数据集 + :param user_id: 用户id :param dataset: Dict name desc :return: """ @@ -222,7 +259,7 @@ class DataSetSerializers(serializers.ModelSerializer): if 'desc' in dataset: _dataset.desc = dataset.get("desc") _dataset.save() - return self.one(with_valid=False) + return self.one(with_valid=False, user_id=user_id) @staticmethod def get_request_body_api(): diff --git a/apps/dataset/sql/list_dataset.sql b/apps/dataset/sql/list_dataset.sql index c7827e6a8..e3e85e229 100644 --- a/apps/dataset/sql/list_dataset.sql +++ b/apps/dataset/sql/list_dataset.sql @@ -1,7 +1,30 @@ SELECT - dataset.*, - document_temp."char_length", - "document_temp".document_count + * FROM - dataset dataset - LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON dataset."id" = "document_temp".dataset_id \ No newline at end of file + ( + SELECT + "temp_dataset".*, + "document_temp"."char_length", + "document_temp".document_count FROM ( + SELECT dataset.* + FROM + dataset dataset + ${dataset_custom_sql} + UNION + SELECT + * + FROM + dataset + WHERE + dataset."id" IN ( + SELECT + team_member_permission.target + FROM + team_member team_member + LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id" + ${team_member_permission_custom_sql} + ) + ) temp_dataset + LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON temp_dataset."id" = "document_temp".dataset_id + ) temp + ${default_sql} \ No newline at end of file diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py index 48439d693..3079c7bc2 100644 --- a/apps/dataset/views/dataset.py +++ b/apps/dataset/views/dataset.py @@ -29,7 +29,7 @@ class Dataset(APIView): responses=get_api_response(DataSetSerializers.Query.get_response_body_api())) @has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND) def get(self, request: Request): - d = DataSetSerializers.Query(data=request.query_params) + d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)}) d.is_valid() return result.success(d.list()) @@ -63,7 +63,7 @@ class Dataset(APIView): @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE, dynamic_tag=keywords.get('dataset_id'))) def get(self, request: Request, dataset_id: str): - return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).one()) + return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).one(user_id=request.user.id)) @action(methods="PUT", detail=False) @swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息", @@ -72,7 +72,8 @@ class Dataset(APIView): @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=keywords.get('dataset_id'))) def put(self, request: Request, dataset_id: str): - return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).edit(request.data)) + return result.success( + DataSetSerializers.Operate(data={'id': dataset_id}).edit(request.data, user_id=request.user.id)) class Page(APIView): authentication_classes = [TokenAuth] @@ -85,6 +86,6 @@ class Dataset(APIView): responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api())) @has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND) def get(self, request: Request, current_page, page_size): - d = DataSetSerializers.Query(data=request.query_params) + d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)}) d.is_valid() return result.success(d.page(current_page, page_size))