mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-25 17:22:55 +00:00
feat: 支持复杂sql查询,修改数据集查询
This commit is contained in:
parent
1e3097fa3f
commit
64e93679f8
|
|
@ -6,8 +6,9 @@
|
|||
@date:2023/10/7 18:20
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
|
||||
from django.db import DEFAULT_DB_ALIAS, models
|
||||
from django.db import DEFAULT_DB_ALIAS, models, connections
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.db.compiler import AppSQLCompiler
|
||||
|
|
@ -30,8 +31,61 @@ def get_dynamics_model(attr: dict, table_name='dynamics'):
|
|||
return type('Dynamics', (models.Model,), attributes)
|
||||
|
||||
|
||||
def native_search(queryset: QuerySet, select_string: str,
|
||||
field_replace_dict=None,
|
||||
def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string: str,
|
||||
field_replace_dict: None | Dict[str, Dict[str, str]] = None):
|
||||
"""
|
||||
生成 查询sql
|
||||
:param queryset_dict: 多条件 查询条件
|
||||
:param select_string: 查询sql
|
||||
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
|
||||
:return: sql:需要查询的sql params: sql 参数
|
||||
"""
|
||||
|
||||
params_dict: Dict[int, Any] = {}
|
||||
result_params = []
|
||||
for key in queryset_dict.keys():
|
||||
value = queryset_dict.get(key)
|
||||
sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key))
|
||||
params_dict = {**params_dict, select_string.index("${" + key + "}"): params}
|
||||
select_string = select_string.replace("${" + key + "}", sql)
|
||||
|
||||
for key in sorted(list(params_dict.keys())):
|
||||
result_params = [*result_params, *params_dict.get(key)]
|
||||
return select_string, result_params
|
||||
|
||||
|
||||
def generate_sql_by_query(queryset: QuerySet, select_string: str,
|
||||
field_replace_dict: None | Dict[str, str] = None):
|
||||
"""
|
||||
生成 查询sql
|
||||
:param queryset: 查询条件
|
||||
:param select_string: 原始sql
|
||||
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
|
||||
:return: sql:需要查询的sql params: sql 参数
|
||||
"""
|
||||
sql, params = compiler_queryset(queryset, field_replace_dict)
|
||||
return select_string + " " + sql, params
|
||||
|
||||
|
||||
def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None):
|
||||
"""
|
||||
解析 queryset查询对象
|
||||
:param queryset: 查询对象
|
||||
:param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入
|
||||
:return: sql:需要查询的sql params: sql 参数
|
||||
"""
|
||||
q = queryset.query
|
||||
compiler = q.get_compiler(DEFAULT_DB_ALIAS)
|
||||
if field_replace_dict is None:
|
||||
field_replace_dict = get_field_replace_dict(queryset)
|
||||
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
|
||||
field_replace_dict=field_replace_dict)
|
||||
sql, params = app_sql_compiler.get_query_str(with_table_name=False)
|
||||
return sql, params
|
||||
|
||||
|
||||
def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
|
||||
field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
|
||||
with_search_one=False):
|
||||
"""
|
||||
复杂查询
|
||||
|
|
@ -41,19 +95,14 @@ def native_search(queryset: QuerySet, select_string: str,
|
|||
:param with_search_one: 查询
|
||||
:return: 查询结果
|
||||
"""
|
||||
if field_replace_dict is None:
|
||||
field_replace_dict = get_field_replace_dict(queryset)
|
||||
q = queryset.query
|
||||
compiler = q.get_compiler(DEFAULT_DB_ALIAS)
|
||||
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
|
||||
field_replace_dict=field_replace_dict)
|
||||
sql, params = app_sql_compiler.get_query_str(with_table_name=False)
|
||||
if with_search_one:
|
||||
return select_one(select_string + " " +
|
||||
sql, params)
|
||||
if isinstance(queryset, Dict):
|
||||
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict)
|
||||
else:
|
||||
return select_list(select_string + " " +
|
||||
sql, params)
|
||||
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict)
|
||||
if with_search_one:
|
||||
return select_one(exec_sql, exec_params)
|
||||
else:
|
||||
return select_list(exec_sql, exec_params)
|
||||
|
||||
|
||||
def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler):
|
||||
|
|
@ -70,7 +119,7 @@ def page_search(current_page: int, page_size: int, queryset: QuerySet, post_reco
|
|||
return Page(total, list(map(post_records_handler, result)), current_page, page_size)
|
||||
|
||||
|
||||
def native_page_search(current_page: int, page_size: int, queryset: QuerySet, select_string: str,
|
||||
def native_page_search(current_page: int, page_size: int, queryset: QuerySet | Dict[str, QuerySet], select_string: str,
|
||||
field_replace_dict=None,
|
||||
post_records_handler=lambda r: r):
|
||||
"""
|
||||
|
|
@ -83,20 +132,17 @@ def native_page_search(current_page: int, page_size: int, queryset: QuerySet, se
|
|||
:param post_records_handler: 数据row处理器
|
||||
:return: 分页结果
|
||||
"""
|
||||
if field_replace_dict is None:
|
||||
field_replace_dict = get_field_replace_dict(queryset)
|
||||
q = queryset.query
|
||||
compiler = q.get_compiler(DEFAULT_DB_ALIAS)
|
||||
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
|
||||
field_replace_dict=field_replace_dict)
|
||||
page_sql, params = app_sql_compiler.get_query_str(with_table_name=False)
|
||||
total_sql = "SELECT \"count\"(*) FROM (%s) temp" % (select_string + " " + page_sql)
|
||||
total = select_one(total_sql, params)
|
||||
q.set_limits(((current_page - 1) * page_size), (current_page * page_size))
|
||||
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
|
||||
field_replace_dict=field_replace_dict)
|
||||
page_sql, params = app_sql_compiler.get_query_str(with_table_name=False)
|
||||
result = select_list(select_string + " " + page_sql, params)
|
||||
if isinstance(queryset, Dict):
|
||||
exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict)
|
||||
else:
|
||||
exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict)
|
||||
total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql
|
||||
total = select_one(total_sql, exec_params)
|
||||
limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql(
|
||||
((current_page - 1) * page_size), (current_page * page_size)
|
||||
)
|
||||
page_sql = exec_sql + " " + limit_sql
|
||||
result = select_list(page_sql, exec_params)
|
||||
return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ class Page(dict):
|
|||
"""
|
||||
|
||||
def __init__(self, total: int, records: List, current_page: int, page_size: int, **kwargs):
|
||||
super().__init__(**{'total': total, 'records': records, 'current_page': current_page, 'page_size': page_size})
|
||||
super().__init__(**{'total': total, 'records': records, 'current': current_page, 'size': page_size})
|
||||
|
||||
|
||||
class Result(JsonResponse):
|
||||
|
|
@ -71,12 +71,12 @@ def get_page_api_response(response_data_schema: openapi.Schema):
|
|||
default=1,
|
||||
description="数据总条数"),
|
||||
"records": response_data_schema,
|
||||
"current_page": openapi.Schema(
|
||||
"current": openapi.Schema(
|
||||
type=openapi.TYPE_INTEGER,
|
||||
title="当前页",
|
||||
default=1,
|
||||
description="当前页"),
|
||||
"page_size": openapi.Schema(
|
||||
"size": openapi.Schema(
|
||||
type=openapi.TYPE_INTEGER,
|
||||
title="每页大小",
|
||||
default=10,
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import uuid
|
|||
from functools import reduce
|
||||
from typing import Dict
|
||||
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.core import validators
|
||||
from django.db import transaction, models
|
||||
from django.db.models import QuerySet
|
||||
|
|
@ -23,6 +24,7 @@ from common.mixins.api_mixin import ApiMixin
|
|||
from common.util.file_util import get_file_content
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph
|
||||
from dataset.serializers.document_serializers import CreateDocumentSerializers
|
||||
from setting.models import AuthOperate
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
from users.models import User
|
||||
|
||||
|
|
@ -73,17 +75,38 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
message="数据集名称在1-256个字符之间")
|
||||
])
|
||||
|
||||
user_id = serializers.CharField(required=True)
|
||||
|
||||
def get_query_set(self):
|
||||
user_id = self.data.get("user_id")
|
||||
query_set_dict = {}
|
||||
query_set = QuerySet(model=get_dynamics_model(
|
||||
{'dataset.name': models.CharField(), 'dataset.desc': models.CharField(),
|
||||
"document_temp.char_length": models.IntegerField()}))
|
||||
if "desc" in self.data:
|
||||
query_string = {'dataset.desc__contains', self.data.get("desc")}
|
||||
query_set = query_set.filter(query_string)
|
||||
query_set = query_set.filter(**{'dataset.desc__contains': self.data.get("desc")})
|
||||
if "name" in self.data:
|
||||
query_string = {'dataset.name__contains', self.data.get("name")}
|
||||
query_set = query_set.filter(query_string)
|
||||
return query_set
|
||||
query_set = query_set.filter(**{'dataset.name__contains': self.data.get("name")})
|
||||
|
||||
query_set_dict['default_sql'] = query_set
|
||||
|
||||
query_set_dict['dataset_custom_sql'] = QuerySet(model=get_dynamics_model(
|
||||
{'dataset.user_id': models.CharField(),
|
||||
})).filter(
|
||||
**{'dataset.user_id': user_id}
|
||||
)
|
||||
|
||||
query_set_dict['team_member_permission_custom_sql'] = QuerySet(model=get_dynamics_model(
|
||||
{'user_id': models.CharField(),
|
||||
'team_member_permission.operate': ArrayField(verbose_name="权限操作列表",
|
||||
base_field=models.CharField(max_length=256,
|
||||
blank=True,
|
||||
choices=AuthOperate.choices,
|
||||
default=AuthOperate.USE)
|
||||
)})).filter(
|
||||
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']})
|
||||
|
||||
return query_set_dict
|
||||
|
||||
def page(self, current_page: int, page_size: int):
|
||||
return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content(
|
||||
|
|
@ -200,18 +223,32 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
dataset.delete()
|
||||
return True
|
||||
|
||||
def one(self, with_valid=True):
|
||||
def one(self, user_id, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid()
|
||||
query_string = {'dataset.id', self.data.get("id")}
|
||||
query_set = QuerySet(model=get_dynamics_model(
|
||||
{'dataset.id': models.UUIDField()})).filter(query_string)
|
||||
return native_search(query_set, select_string=get_file_content(
|
||||
query_set_dict = {'default_sql': QuerySet(model=get_dynamics_model(
|
||||
{'temp.id': models.UUIDField()})).filter(**{'temp.id': self.data.get("id")}),
|
||||
'dataset_custom_sql': QuerySet(model=get_dynamics_model(
|
||||
{'dataset.user_id': models.CharField()})).filter(
|
||||
**{'dataset.user_id': user_id}
|
||||
), 'team_member_permission_custom_sql': QuerySet(
|
||||
model=get_dynamics_model({'user_id': models.CharField(),
|
||||
'team_member_permission.operate': ArrayField(
|
||||
verbose_name="权限操作列表",
|
||||
base_field=models.CharField(max_length=256,
|
||||
blank=True,
|
||||
choices=AuthOperate.choices,
|
||||
default=AuthOperate.USE)
|
||||
)})).filter(
|
||||
**{'user_id': user_id, 'team_member_permission.operate__contains': ['USE']})}
|
||||
|
||||
return native_search(query_set_dict, select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), with_search_one=True)
|
||||
|
||||
def edit(self, dataset: Dict):
|
||||
def edit(self, dataset: Dict, user_id: str):
|
||||
"""
|
||||
修改数据集
|
||||
:param user_id: 用户id
|
||||
:param dataset: Dict name desc
|
||||
:return:
|
||||
"""
|
||||
|
|
@ -222,7 +259,7 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
if 'desc' in dataset:
|
||||
_dataset.desc = dataset.get("desc")
|
||||
_dataset.save()
|
||||
return self.one(with_valid=False)
|
||||
return self.one(with_valid=False, user_id=user_id)
|
||||
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
|
|
|
|||
|
|
@ -1,7 +1,30 @@
|
|||
SELECT
|
||||
dataset.*,
|
||||
document_temp."char_length",
|
||||
"document_temp".document_count
|
||||
*
|
||||
FROM
|
||||
dataset dataset
|
||||
LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON dataset."id" = "document_temp".dataset_id
|
||||
(
|
||||
SELECT
|
||||
"temp_dataset".*,
|
||||
"document_temp"."char_length",
|
||||
"document_temp".document_count FROM (
|
||||
SELECT dataset.*
|
||||
FROM
|
||||
dataset dataset
|
||||
${dataset_custom_sql}
|
||||
UNION
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
dataset
|
||||
WHERE
|
||||
dataset."id" IN (
|
||||
SELECT
|
||||
team_member_permission.target
|
||||
FROM
|
||||
team_member team_member
|
||||
LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id"
|
||||
${team_member_permission_custom_sql}
|
||||
)
|
||||
) temp_dataset
|
||||
LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON temp_dataset."id" = "document_temp".dataset_id
|
||||
) temp
|
||||
${default_sql}
|
||||
|
|
@ -29,7 +29,7 @@ class Dataset(APIView):
|
|||
responses=get_api_response(DataSetSerializers.Query.get_response_body_api()))
|
||||
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
|
||||
def get(self, request: Request):
|
||||
d = DataSetSerializers.Query(data=request.query_params)
|
||||
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
|
||||
d.is_valid()
|
||||
return result.success(d.list())
|
||||
|
||||
|
|
@ -63,7 +63,7 @@ class Dataset(APIView):
|
|||
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('dataset_id')))
|
||||
def get(self, request: Request, dataset_id: str):
|
||||
return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).one())
|
||||
return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).one(user_id=request.user.id))
|
||||
|
||||
@action(methods="PUT", detail=False)
|
||||
@swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息",
|
||||
|
|
@ -72,7 +72,8 @@ class Dataset(APIView):
|
|||
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=keywords.get('dataset_id')))
|
||||
def put(self, request: Request, dataset_id: str):
|
||||
return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).edit(request.data))
|
||||
return result.success(
|
||||
DataSetSerializers.Operate(data={'id': dataset_id}).edit(request.data, user_id=request.user.id))
|
||||
|
||||
class Page(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
|
@ -85,6 +86,6 @@ class Dataset(APIView):
|
|||
responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()))
|
||||
@has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND)
|
||||
def get(self, request: Request, current_page, page_size):
|
||||
d = DataSetSerializers.Query(data=request.query_params)
|
||||
d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)})
|
||||
d.is_valid()
|
||||
return result.success(d.page(current_page, page_size))
|
||||
|
|
|
|||
Loading…
Reference in New Issue