diff --git a/apps/application/__init__.py b/apps/application/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/application/admin.py b/apps/application/admin.py new file mode 100644 index 000000000..8c38f3f3d --- /dev/null +++ b/apps/application/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. diff --git a/apps/application/apps.py b/apps/application/apps.py new file mode 100644 index 000000000..30c0916b0 --- /dev/null +++ b/apps/application/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class ApplicationConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'application' diff --git a/apps/application/migrations/0001_initial.py b/apps/application/migrations/0001_initial.py new file mode 100644 index 000000000..91af9d5ea --- /dev/null +++ b/apps/application/migrations/0001_initial.py @@ -0,0 +1,50 @@ +# Generated by Django 4.1.10 on 2023-10-09 06:33 + +import django.contrib.postgres.fields +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('users', '0001_initial'), + ('dataset', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='Application', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('name', models.CharField(max_length=128, verbose_name='应用名称')), + ('desc', models.CharField(max_length=128, verbose_name='引用描述')), + ('prologue', models.CharField(max_length=1024, verbose_name='开场白')), + ('example', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=256), size=None, verbose_name='示例列表')), + ('status', models.BooleanField(default=True, verbose_name='是否发布')), + ('is_active', models.BooleanField(default=True)), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user')), + ], + options={ + 'db_table': 'application', + }, + ), + migrations.CreateModel( + name='ApplicationDatasetMapping', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('application', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='application.application')), + ('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')), + ], + options={ + 'db_table': 'application_dataset_mapping', + }, + ), + ] diff --git a/apps/application/migrations/__init__.py b/apps/application/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/application/models/__init__.py b/apps/application/models/__init__.py new file mode 100644 index 000000000..0d579760e --- /dev/null +++ b/apps/application/models/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/9/25 14:25 + @desc: +""" +from .application import * diff --git a/apps/application/models/application.py b/apps/application/models/application.py new file mode 100644 index 000000000..79cef354d --- /dev/null +++ b/apps/application/models/application.py @@ -0,0 +1,37 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: application.py + @date:2023/9/25 14:24 + @desc: +""" +import uuid + +from django.contrib.postgres.fields import ArrayField +from django.db import models + +from common.mixins.app_model_mixin import AppModelMixin +from dataset.models.data_set import DataSet +from users.models import User + + +class Application(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + name = models.CharField(max_length=128, verbose_name="应用名称") + desc = models.CharField(max_length=128, verbose_name="引用描述") + prologue = models.CharField(max_length=1024, verbose_name="开场白") + example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True)) + status = models.BooleanField(default=True, verbose_name="是否发布") + user = models.ForeignKey(User, on_delete=models.DO_NOTHING) + is_active = models.BooleanField(default=True) + class Meta: + db_table = "application" + + +class ApplicationDatasetMapping(AppModelMixin): + application = models.ForeignKey(Application, on_delete=models.DO_NOTHING) + dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING) + + class Meta: + db_table = "application_dataset_mapping" diff --git a/apps/application/tests.py b/apps/application/tests.py new file mode 100644 index 000000000..7ce503c2d --- /dev/null +++ b/apps/application/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/apps/application/views.py b/apps/application/views.py new file mode 100644 index 000000000..91ea44a21 --- /dev/null +++ b/apps/application/views.py @@ -0,0 +1,3 @@ +from django.shortcuts import render + +# Create your views here. diff --git a/apps/common/auth/authenticate.py b/apps/common/auth/authenticate.py index f7cf39a94..bac2dd9b2 100644 --- a/apps/common/auth/authenticate.py +++ b/apps/common/auth/authenticate.py @@ -7,13 +7,15 @@ @desc: 认证类 """ -from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants -from common.exception.app_exception import AppAuthenticationFailed from django.core import cache from django.core import signing +from django.db.models import QuerySet from rest_framework.authentication import TokenAuthentication + +from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants +from common.exception.app_exception import AppAuthenticationFailed from smartdoc.settings import JWT_AUTH -from users.models.user import User +from users.models.user import User, get_user_dynamics_permission token_cache = cache.caches['token_cache'] @@ -38,12 +40,15 @@ class TokenAuth(TokenAuthentication): cache_token = token_cache.get(auth) if cache_token is None: raise AppAuthenticationFailed(1002, "登录过期") - user = User.objects.get(id=user['id']) + user = QuerySet(User).get(id=user['id']) # 续期 token_cache.touch(auth, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds()) rule = RoleConstants[user.role] + permission_list = get_permission_list_by_role(RoleConstants[user.role]) + # 获取用户的应用和数据集的权限 + permission_list += get_user_dynamics_permission(str(user.id)) return user, Auth(role_list=[rule], - permission_list=get_permission_list_by_role(RoleConstants[user.role])) + permission_list=permission_list) else: raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户") diff --git a/apps/common/auth/authentication.py b/apps/common/auth/authentication.py index a82404766..5a155bc66 100644 --- a/apps/common/auth/authentication.py +++ b/apps/common/auth/authentication.py @@ -8,7 +8,8 @@ """ from typing import List -from common.constants.permission_constants import ViewPermission, CompareConstants, RoleConstants, PermissionConstants +from common.constants.permission_constants import ViewPermission, CompareConstants, RoleConstants, PermissionConstants, \ + Permission from common.exception.app_exception import AppUnauthorizedFailed @@ -55,9 +56,19 @@ def exist_permissions(user_role: List[RoleConstants], user_permission: List[Perm return exist_role_by_role_constants(user_role, [permission]) elif isinstance(permission, PermissionConstants): return exist_permissions_by_permission_constants(user_permission, [permission]) + elif isinstance(permission, Permission): + return user_permission.__contains__(permission) return False +def exist(user_role: List[RoleConstants], user_permission: List[PermissionConstants], permission, request, **kwargs): + if callable(permission): + p = permission(request, kwargs) + return exist_permissions(user_role, user_permission, p) + else: + return exist_permissions(user_role, user_permission, permission) + + def has_permissions(*permission, compare=CompareConstants.OR): """ 权限 role or permission @@ -69,7 +80,8 @@ def has_permissions(*permission, compare=CompareConstants.OR): def inner(func): def run(view, request, **kwargs): exit_list = list( - map(lambda p: exist_permissions(request.auth.role_list, request.auth.permission_list, p), permission)) + map(lambda p: exist(request.auth.role_list, request.auth.permission_list, p, request, **kwargs), + permission)) # 判断是否有权限 if any(exit_list) if compare == CompareConstants.OR else all(exit_list): return func(view, request, **kwargs) diff --git a/apps/common/constants/permission_constants.py b/apps/common/constants/permission_constants.py index 51c7df5c7..8ee8d9230 100644 --- a/apps/common/constants/permission_constants.py +++ b/apps/common/constants/permission_constants.py @@ -16,6 +16,12 @@ class Group(Enum): """ USER = "USER" + DATASET = "DATASET" + + APPLICATION = "APPLICATION" + + SETTING = "SETTING" + class Operate(Enum): """ @@ -25,6 +31,14 @@ class Operate(Enum): EDIT = "EDIT" CREATE = "CREATE" DELETE = "DELETE" + """ + 管理权限 + """ + MANAGE = "MANAGE" + """ + 使用权限 + """ + USE = "USE" class Role: @@ -44,13 +58,20 @@ class Permission: 权限信息 """ - def __init__(self, group: Group, operate: Operate, roles: List[RoleConstants]): + def __init__(self, group: Group, operate: Operate, roles=None, dynamic_tag=None): + if roles is None: + roles = [] self.group = group self.operate = operate self.roleList = roles + self.dynamic_tag = dynamic_tag def __str__(self): - return self.group.value + ":" + self.operate.value + return self.group.value + ":" + self.operate.value + ( + (":" + self.dynamic_tag) if self.dynamic_tag is not None else '') + + def __eq__(self, other): + return str(self) == str(other) class PermissionConstants(Enum): @@ -59,7 +80,19 @@ class PermissionConstants(Enum): """ USER_READ = Permission(group=Group.USER, operate=Operate.READ, roles=[RoleConstants.ADMIN, RoleConstants.USER]) USER_EDIT = Permission(group=Group.USER, operate=Operate.EDIT, roles=[RoleConstants.ADMIN, RoleConstants.USER]) - USER_DELETE = Permission(group=Group.USER, operate=Operate.EDIT, roles=[RoleConstants.USER]) + USER_DELETE = Permission(group=Group.USER, operate=Operate.DELETE, roles=[RoleConstants.USER]) + + DATASET_CREATE = Permission(group=Group.DATASET, operate=Operate.CREATE, + roles=[RoleConstants.ADMIN, RoleConstants.USER]) + + DATASET_READ = Permission(group=Group.DATASET, operate=Operate.READ, + roles=[RoleConstants.ADMIN, RoleConstants.USER]) + + APPLICATION_READ = Permission(group=Group.APPLICATION, operate=Operate.READ, + roles=[RoleConstants.ADMIN, RoleConstants.USER]) + + SETTING_READ = Permission(group=Group.SETTING, operate=Operate.READ, + roles=[RoleConstants.ADMIN, RoleConstants.USER]) def get_permission_list_by_role(role: RoleConstants): diff --git a/apps/common/db/compiler.py b/apps/common/db/compiler.py new file mode 100644 index 000000000..81f21664f --- /dev/null +++ b/apps/common/db/compiler.py @@ -0,0 +1,123 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: compiler.py + @date:2023/10/7 10:53 + @desc: +""" + +from django.core.exceptions import EmptyResultSet +from django.db import NotSupportedError +from django.db.models.sql.compiler import SQLCompiler + + +class AppSQLCompiler(SQLCompiler): + def __init__(self, query, connection, using, elide_empty=True, field_replace_dict=None): + super().__init__(query, connection, using, elide_empty) + if field_replace_dict is None: + field_replace_dict = {} + self.field_replace_dict = field_replace_dict + + def get_query_str(self, with_limits=True, with_table_name=True): + refcounts_before = self.query.alias_refcount.copy() + try: + extra_select, order_by, group_by = self.pre_sql_setup() + for_update_part = None + # Is a LIMIT/OFFSET clause needed? + with_limit_offset = with_limits and ( + self.query.high_mark is not None or self.query.low_mark + ) + combinator = self.query.combinator + features = self.connection.features + if combinator: + if not getattr(features, "supports_select_{}".format(combinator)): + raise NotSupportedError( + "{} is not supported on this database backend.".format( + combinator + ) + ) + result, params = self.get_combinator_sql( + combinator, self.query.combinator_all + ) + else: + distinct_fields, distinct_params = self.get_distinct() + try: + where, w_params = ( + self.compile(self.where) if self.where is not None else ("", []) + ) + except EmptyResultSet: + if self.elide_empty: + raise + # Use a predicate that's always False. + where, w_params = "0 = 1", [] + having, h_params = ( + self.compile(self.having) if self.having is not None else ("", []) + ) + result = [] + params = [] + if where: + result.append("WHERE %s" % where) + params.extend(w_params) + + grouping = [] + for g_sql, g_params in group_by: + grouping.append(g_sql) + params.extend(g_params) + if grouping: + if distinct_fields: + raise NotImplementedError( + "annotate() + distinct(fields) is not implemented." + ) + order_by = order_by or self.connection.ops.force_no_ordering() + result.append("GROUP BY %s" % ", ".join(grouping)) + if self._meta_ordering: + order_by = None + if having: + result.append("HAVING %s" % having) + params.extend(h_params) + + if self.query.explain_info: + result.insert( + 0, + self.connection.ops.explain_query_prefix( + self.query.explain_info.format, + **self.query.explain_info.options, + ), + ) + + if order_by: + ordering = [] + for _, (o_sql, o_params, _) in order_by: + ordering.append(o_sql) + params.extend(o_params) + result.append("ORDER BY %s" % ", ".join(ordering)) + + if with_limit_offset: + result.append( + self.connection.ops.limit_offset_sql( + self.query.low_mark, self.query.high_mark + ) + ) + + if for_update_part and not features.for_update_after_from: + result.append(for_update_part) + from_, f_params = self.get_from_clause() + sql = " ".join(result) + if not with_table_name: + for table_name in from_: + sql = sql.replace(table_name + ".", "") + for key in self.field_replace_dict.keys(): + value = self.field_replace_dict.get(key) + sql = sql.replace(key, value) + return sql, tuple(params) + finally: + # Finally do cleanup - get rid of the joins we created above. + self.query.reset_refcounts(refcounts_before) + + def as_sql(self, with_limits=True, with_col_aliases=False, select_string=None): + if select_string is None: + return super().as_sql(with_limits, with_col_aliases) + else: + sql, params = self.get_query_str(with_table_name=False) + return (select_string + " " + sql), params diff --git a/apps/common/db/search.py b/apps/common/db/search.py new file mode 100644 index 000000000..6b772a632 --- /dev/null +++ b/apps/common/db/search.py @@ -0,0 +1,124 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: search.py + @date:2023/10/7 18:20 + @desc: +""" + +from django.db import DEFAULT_DB_ALIAS, models +from django.db.models import QuerySet + +from common.db.compiler import AppSQLCompiler +from common.db.sql_execute import select_one, select_list +from common.response.result import Page + + +def get_dynamics_model(attr: dict, table_name='dynamics'): + """ + 获取一个动态的django模型 + :param attr: 模型字段 + :param table_name: 表名 + :return: django 模型 + """ + attributes = { + "__module__": "dataset.models", + "Meta": type("Meta", (), {'db_table': table_name}), + **attr + } + return type('Dynamics', (models.Model,), attributes) + + +def native_search(queryset: QuerySet, select_string: str, + field_replace_dict=None, + with_search_one=False): + """ + 复杂查询 + :param queryset: 查询条件构造器 + :param select_string: 查询前缀 不包括 where limit 等信息 + :param field_replace_dict: 需要替换的字段 + :param with_search_one: 查询 + :return: 查询结果 + """ + if field_replace_dict is None: + field_replace_dict = get_field_replace_dict(queryset) + q = queryset.query + compiler = q.get_compiler(DEFAULT_DB_ALIAS) + app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, + field_replace_dict=field_replace_dict) + sql, params = app_sql_compiler.get_query_str(with_table_name=False) + if with_search_one: + return select_one(select_string + " " + + sql, params) + else: + return select_list(select_string + " " + + sql, params) + + +def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler): + """ + 分页查询 + :param current_page: 当前页 + :param page_size: 每页大小 + :param queryset: 查询条件 + :param post_records_handler: 数据处理器 + :return: 分页结果 + """ + total = QuerySet(query=queryset.query.clone(), model=queryset.model).count() + result = queryset.all()[((current_page - 1) * page_size):(current_page * page_size)] + return Page(total, list(map(post_records_handler, result)), current_page, page_size) + + +def native_page_search(current_page: int, page_size: int, queryset: QuerySet, select_string: str, + field_replace_dict=None, + post_records_handler=lambda r: r): + """ + 复杂分页查询 + :param current_page: 当前页 + :param page_size: 每页大小 + :param queryset: 查询条件 + :param select_string: 查询 + :param field_replace_dict: 特殊字段替换 + :param post_records_handler: 数据row处理器 + :return: 分页结果 + """ + if field_replace_dict is None: + field_replace_dict = get_field_replace_dict(queryset) + q = queryset.query + compiler = q.get_compiler(DEFAULT_DB_ALIAS) + app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, + field_replace_dict=field_replace_dict) + page_sql, params = app_sql_compiler.get_query_str(with_table_name=False) + total_sql = "SELECT \"count\"(*) FROM (%s) temp" % (select_string + " " + page_sql) + total = select_one(total_sql, params) + q.set_limits(((current_page - 1) * page_size), (current_page * page_size)) + app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, + field_replace_dict=field_replace_dict) + page_sql, params = app_sql_compiler.get_query_str(with_table_name=False) + result = select_list(select_string + " " + page_sql, params) + return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size) + + +def get_field_replace_dict(queryset: QuerySet): + """ + 获取需要替换的字段 默认 “xxx.xxx”需要被替换成 “xxx”."xxx" + :param queryset: 查询对象 + :return: 需要替换的字典 + """ + result = {} + for field in queryset.model._meta.local_fields: + if field.attname.__contains__("."): + replace_field = to_replace_field(field.attname) + result.__setitem__('"' + field.attname + '"', replace_field) + return result + + +def to_replace_field(field: str): + """ + 将field 转换为 需要替换的field “xxx.xxx”需要被替换成 “xxx”."xxx" 只替换 field包含.的字段 + :param field: django field字段 + :return: 替换字段 + """ + split_field = field.split(".") + return ".".join(list(map(lambda sf: '"' + sf + '"', split_field))) diff --git a/apps/common/db/sql_execute.py b/apps/common/db/sql_execute.py new file mode 100644 index 000000000..3f459759e --- /dev/null +++ b/apps/common/db/sql_execute.py @@ -0,0 +1,53 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: sql_execute.py + @date:2023/9/25 20:05 + @desc: +""" +from typing import List + +from django.db import connection + + +def sql_execute(sql: str, params): + """ + 执行一条sql + :param sql: 需要执行的sql + :param params: sql参数 + :return: 执行结果 + """ + with connection.cursor() as cursor: + cursor.execute(sql, params) + columns = list(map(lambda d: d.name, cursor.description)) + res = cursor.fetchall() + result = list(map(lambda row: dict(list(zip(columns, row))), res)) + cursor.close() + return result + + +def select_list(sql: str, params: List): + """ + 执行sql 查询列表数据 + :param sql: 需要执行的sql + :param params: sql的参数 + :return: 查询结果 + """ + result_list = sql_execute(sql, params) + if result_list is None: + return [] + return result_list + + +def select_one(sql: str, params: List): + """ + 执行sql 查询一条数据 + :param sql: 需要执行的sql + :param params: 参数 + :return: 查询结果 + """ + result_list = sql_execute(sql, params) + if result_list is None or len(result_list) == 0: + return None + return result_list[0] diff --git a/apps/common/handle/handle_exception.py b/apps/common/handle/handle_exception.py index c3fa367e5..5bab9e3f8 100644 --- a/apps/common/handle/handle_exception.py +++ b/apps/common/handle/handle_exception.py @@ -6,6 +6,8 @@ @date:2023/9/5 19:29 @desc: """ +import django.core.exceptions +from psycopg2 import IntegrityError from rest_framework.exceptions import ValidationError, ErrorDetail, APIException from rest_framework.views import exception_handler @@ -13,19 +15,27 @@ from common.exception.app_exception import AppApiException from common.response import result -def to_result(key, args): +def to_result(key, args, parent_key=None): """ 将校验异常 args转换为统一数据 - :param key: 校验key - :param args: 校验异常参数 + :param key: 校验key + :param args: 校验异常参数 + :param parent_key 父key :return: 接口响应对象 """ - error_detail = (args[0] if len(args) > 0 else {key: [ErrorDetail('未知异常', code='unknown')]}).get(key)[ - 0] + error_detail = list(filter( + lambda d: True if isinstance(d, ErrorDetail) else True if isinstance(d, dict) and len( + d.keys()) > 0 else False, + (args[0] if len(args) > 0 else {key: [ErrorDetail('未知异常', code='unknown')]}).get(key)))[0] + + if isinstance(error_detail, dict): + return list(map(lambda k: to_result(k, args=[error_detail], + parent_key=key if parent_key is None else parent_key + '.' + key), + error_detail.keys() if len(error_detail) > 0 else []))[0] return result.Result(500 if isinstance(error_detail.code, str) else error_detail.code, - message=f"【{key}】为必填参数" if str( - error_detail) == "This field is required." else error_detail) + message=f"【{key if parent_key is None else parent_key + '.' + key}】为必填参数" if str( + error_detail) == "This field is required." else f"【{key if parent_key is None else parent_key + '.' + key}】" + error_detail) def validation_error_to_result(exc: ValidationError): @@ -34,8 +44,11 @@ def validation_error_to_result(exc: ValidationError): :param exc: 校验异常 :return: 接口响应对象 """ - res = list(map(lambda key: to_result(key, args=exc.args), - exc.args[0].keys() if len(exc.args) > 0 else [])) + try: + res = list(map(lambda key: to_result(key, args=exc.args), + exc.args[0].keys() if len(exc.args) > 0 else [])) + except Exception as e: + return result.error(str(exc.detail)) if len(res) > 0: return res[0] else: @@ -53,4 +66,6 @@ def handle_exception(exc, context): return result.Result(exc.code, exc.message, response_status=exc.status_code) if issubclass(exception_class, APIException): return result.error(exc.detail) + if response is None: + return result.error(str(exc)) return response diff --git a/apps/common/mixins/api_mixin.py b/apps/common/mixins/api_mixin.py index d90a0dc4a..d2625a0d1 100644 --- a/apps/common/mixins/api_mixin.py +++ b/apps/common/mixins/api_mixin.py @@ -11,11 +11,14 @@ from rest_framework import serializers class ApiMixin(serializers.Serializer): - def get_request_params_api(self): + @staticmethod + def get_request_params_api(): pass - def get_request_body_api(self): + @staticmethod + def get_request_body_api(): pass - def get_response_body_api(self): + @staticmethod + def get_response_body_api(): pass diff --git a/apps/common/mixins/app_model_mixin.py b/apps/common/mixins/app_model_mixin.py new file mode 100644 index 000000000..412dbae00 --- /dev/null +++ b/apps/common/mixins/app_model_mixin.py @@ -0,0 +1,18 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: app_model_mixin.py + @date:2023/9/21 9:41 + @desc: +""" +from django.db import models + + +class AppModelMixin(models.Model): + create_time = models.DateTimeField(verbose_name="创建时间", auto_now_add=True) + update_time = models.DateTimeField(verbose_name="修改时间", auto_now=True) + + class Meta: + abstract = True + ordering = ['create_time'] diff --git a/apps/common/response/result.py b/apps/common/response/result.py index d42dc5a47..c3992e18a 100644 --- a/apps/common/response/result.py +++ b/apps/common/response/result.py @@ -1,8 +1,19 @@ +from typing import List + from django.http import JsonResponse from drf_yasg import openapi from rest_framework import status +class Page(dict): + """ + 分页对象 + """ + + def __init__(self, total: int, records: List, current_page: int, page_size: int, **kwargs): + super().__init__(**{'total': total, 'records': records, 'current_page': current_page, 'page_size': page_size}) + + class Result(JsonResponse): """ 接口统一返回对象 @@ -13,7 +24,73 @@ class Result(JsonResponse): super().__init__(data=back_info_dict, status=response_status) -def get_api_response(response_data_schema: openapi.Schema, data_examples): +def get_page_request_params(other_request_params=None): + if other_request_params is None: + other_request_params = [] + current_page = openapi.Parameter(name='current_page', + in_=openapi.IN_PATH, + type=openapi.TYPE_INTEGER, + required=True, + description='当前页') + + page_size = openapi.Parameter(name='page_size', + in_=openapi.IN_PATH, + type=openapi.TYPE_INTEGER, + required=True, + description='每页大小') + result = [current_page, page_size] + for other_request_param in other_request_params: + result.append(other_request_param) + return result + + +def get_page_api_response(response_data_schema: openapi.Schema): + """ + 获取统一返回 响应Api + """ + return openapi.Responses(responses={200: openapi.Response(description="响应参数", + schema=openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'code': openapi.Schema( + type=openapi.TYPE_INTEGER, + title="响应码", + default=200, + description="成功:200 失败:其他"), + "message": openapi.Schema( + type=openapi.TYPE_STRING, + title="提示", + default='成功', + description="错误提示"), + "data": openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'total': openapi.Schema( + type=openapi.TYPE_INTEGER, + title="总条数", + default=1, + description="数据总条数"), + "records": response_data_schema, + "current_page": openapi.Schema( + type=openapi.TYPE_INTEGER, + title="当前页", + default=1, + description="当前页"), + "page_size": openapi.Schema( + type=openapi.TYPE_INTEGER, + title="每页大小", + default=10, + description="每页大小") + + } + ) + + } + ), + )}) + + +def get_api_response(response_data_schema: openapi.Schema): """ 获取统一返回 响应Api """ @@ -35,9 +112,7 @@ def get_api_response(response_data_schema: openapi.Schema, data_examples): } ), - examples={'code': 200, - 'data': data_examples, - 'message': "成功"})}) + )}) def success(data): diff --git a/apps/common/util/file_util.py b/apps/common/util/file_util.py new file mode 100644 index 000000000..f46c460a7 --- /dev/null +++ b/apps/common/util/file_util.py @@ -0,0 +1,16 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: file_util.py + @date:2023/9/25 21:06 + @desc: +""" + + +def get_file_content(path): + file = open(path, "r", + encoding='utf-8') + content = file.read() + file.close() + return content diff --git a/apps/dataset/__init__.py b/apps/dataset/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/dataset/apps.py b/apps/dataset/apps.py new file mode 100644 index 000000000..166bedbe7 --- /dev/null +++ b/apps/dataset/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class DatasetConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'dataset' diff --git a/apps/dataset/migrations/0001_initial.py b/apps/dataset/migrations/0001_initial.py new file mode 100644 index 000000000..a5d4dd3f2 --- /dev/null +++ b/apps/dataset/migrations/0001_initial.py @@ -0,0 +1,91 @@ +# Generated by Django 4.1.10 on 2023-10-09 06:33 + +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('users', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='DataSet', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('name', models.CharField(max_length=150, verbose_name='数据集名称')), + ('desc', models.CharField(max_length=256, verbose_name='数据库描述')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', verbose_name='所属用户')), + ], + options={ + 'db_table': 'dataset', + }, + ), + migrations.CreateModel( + name='Document', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('name', models.CharField(max_length=150, verbose_name='文档名称')), + ('char_length', models.IntegerField(verbose_name='文档字符数 冗余字段')), + ('is_active', models.BooleanField(default=True)), + ('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')), + ], + options={ + 'db_table': 'document', + }, + ), + migrations.CreateModel( + name='Paragraph', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('content', models.CharField(max_length=1024, verbose_name='段落内容')), + ('hit_num', models.IntegerField(default=0, verbose_name='命中数量')), + ('star_num', models.IntegerField(default=0, verbose_name='点赞数')), + ('trample_num', models.IntegerField(default=0, verbose_name='点踩数')), + ('is_active', models.BooleanField(default=True)), + ('document', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')), + ], + options={ + 'db_table': 'paragraph', + }, + ), + migrations.CreateModel( + name='Problem', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('content', models.CharField(max_length=256, verbose_name='问题内容')), + ], + options={ + 'db_table': 'problem', + }, + ), + migrations.CreateModel( + name='ProblemAnswerMapping', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('hit_num', models.IntegerField(default=0, verbose_name='命中数量')), + ('star_num', models.IntegerField(default=0, verbose_name='点赞数')), + ('trample_num', models.IntegerField(default=0, verbose_name='点踩数')), + ('paragraph', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph')), + ('problem', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.problem')), + ], + options={ + 'db_table': 'problem_paragraph_mapping', + }, + ), + ] diff --git a/apps/dataset/migrations/__init__.py b/apps/dataset/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/dataset/models/__init__.py b/apps/dataset/models/__init__.py new file mode 100644 index 000000000..fdee77b4d --- /dev/null +++ b/apps/dataset/models/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/9/21 9:32 + @desc: +""" +from .data_set import * diff --git a/apps/dataset/models/data_set.py b/apps/dataset/models/data_set.py new file mode 100644 index 000000000..cc39a07f2 --- /dev/null +++ b/apps/dataset/models/data_set.py @@ -0,0 +1,84 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: data_set.py + @date:2023/9/21 9:35 + @desc: 数据集 +""" +import uuid + +from django.db import models + +from common.mixins.app_model_mixin import AppModelMixin +from users.models import User + + +class DataSet(AppModelMixin): + """ + 数据集表 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + name = models.CharField(max_length=150, verbose_name="数据集名称") + desc = models.CharField(max_length=256, verbose_name="数据库描述") + user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户") + + class Meta: + db_table = "dataset" + + +class Document(AppModelMixin): + """ + 文档表 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING) + name = models.CharField(max_length=150, verbose_name="文档名称") + char_length = models.IntegerField(verbose_name="文档字符数 冗余字段") + is_active = models.BooleanField(default=True) + + class Meta: + db_table = "document" + + +class Paragraph(AppModelMixin): + """ + 段落表 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + document = models.ForeignKey(Document, on_delete=models.DO_NOTHING) + content = models.CharField(max_length=1024, verbose_name="段落内容") + hit_num = models.IntegerField(verbose_name="命中数量", default=0) + star_num = models.IntegerField(verbose_name="点赞数", default=0) + trample_num = models.IntegerField(verbose_name="点踩数", default=0) + is_active = models.BooleanField(default=True) + + class Meta: + db_table = "paragraph" + + +class Problem(AppModelMixin): + """ + 问题表 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + content = models.CharField(max_length=256, verbose_name="问题内容") + + class Meta: + db_table = "problem" + + +class ProblemAnswerMapping(AppModelMixin): + """ + 问题 段落 映射表 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING) + problem = models.ForeignKey(Problem, on_delete=models.DO_NOTHING) + hit_num = models.IntegerField(verbose_name="命中数量", default=0) + star_num = models.IntegerField(verbose_name="点赞数", default=0) + trample_num = models.IntegerField(verbose_name="点踩数", default=0) + + class Meta: + db_table = "problem_paragraph_mapping" + diff --git a/apps/dataset/serializers/dataset_serializers.py b/apps/dataset/serializers/dataset_serializers.py new file mode 100644 index 000000000..5dfcf4c53 --- /dev/null +++ b/apps/dataset/serializers/dataset_serializers.py @@ -0,0 +1,274 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: dataset_serializers.py + @date:2023/9/21 16:14 + @desc: +""" +import os.path +import uuid +from functools import reduce +from typing import Dict + +from django.core import validators +from django.db import transaction, models +from django.db.models import QuerySet +from drf_yasg import openapi +from rest_framework import serializers + +from common.db.search import get_dynamics_model, native_page_search, native_search +from common.exception.app_exception import AppApiException +from common.mixins.api_mixin import ApiMixin +from common.util.file_util import get_file_content +from dataset.models.data_set import DataSet, Document, Paragraph +from dataset.serializers.document_serializers import CreateDocumentSerializers +from smartdoc.conf import PROJECT_DIR +from users.models import User + +""" +# __exact 精确等于 like ‘aaa’ +# __iexact 精确等于 忽略大小写 ilike 'aaa' +# __contains 包含like '%aaa%' +# __icontains 包含 忽略大小写 ilike ‘%aaa%’,但是对于sqlite来说,contains的作用效果等同于icontains。 +# __gt 大于 +# __gte 大于等于 +# __lt 小于 +# __lte 小于等于 +# __in 存在于一个list范围内 +# __startswith 以…开头 +# __istartswith 以…开头 忽略大小写 +# __endswith 以…结尾 +# __iendswith 以…结尾,忽略大小写 +# __range 在…范围内 +# __year 日期字段的年份 +# __month 日期字段的月份 +# __day 日期字段的日 +# __isnull=True/False +""" + + +class DataSetSerializers(serializers.ModelSerializer): + class Meta: + model = DataSet + fields = ['id', 'name', 'desc', 'create_time', 'update_time'] + + class Query(ApiMixin, serializers.Serializer): + """ + 查询对象 + """ + name = serializers.CharField(required=False, + validators=[ + validators.MaxLengthValidator(limit_value=20, + message="数据集名称在1-20个字符之间"), + validators.MinLengthValidator(limit_value=1, + message="数据集名称在1-20个字符之间") + ]) + + desc = serializers.CharField(required=False, + validators=[ + validators.MaxLengthValidator(limit_value=256, + message="数据集名称在1-256个字符之间"), + validators.MinLengthValidator(limit_value=1, + message="数据集名称在1-256个字符之间") + ]) + + def get_query_set(self): + query_set = QuerySet(model=get_dynamics_model( + {'dataset.name': models.CharField(), 'dataset.desc': models.CharField(), + "document_temp.char_length": models.IntegerField()})) + if "desc" in self.data: + query_string = {'dataset.desc__contains', self.data.get("desc")} + query_set = query_set.filter(query_string) + if "name" in self.data: + query_string = {'dataset.name__contains', self.data.get("name")} + query_set = query_set.filter(query_string) + return query_set + + def page(self, current_page: int, page_size: int): + return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), + post_records_handler=lambda r: r) + + def list(self): + return native_search(self.get_query_set(), select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql'))) + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='name', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='数据集名称'), + openapi.Parameter(name='desc', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='数据集描述') + ] + + @staticmethod + def get_response_body_api(): + return openapi.Schema(type=openapi.TYPE_ARRAY, + title="数据集列表", description="数据集列表", + items=DataSetSerializers.Operate.get_response_body_api()) + + class Create(ApiMixin, serializers.Serializer): + """ + 创建序列化对象 + """ + name = serializers.CharField(required=True, + validators=[ + validators.MaxLengthValidator(limit_value=20, + message="数据集名称在1-20个字符之间"), + validators.MinLengthValidator(limit_value=1, + message="数据集名称在1-20个字符之间") + ]) + + desc = serializers.CharField(required=True, + validators=[ + validators.MaxLengthValidator(limit_value=256, + message="数据集名称在1-256个字符之间"), + validators.MinLengthValidator(limit_value=1, + message="数据集名称在1-256个字符之间") + ]) + + documents = CreateDocumentSerializers(required=False, many=True) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + return True + + @transaction.atomic + def save(self, user: User): + dataset_id = uuid.uuid1() + dataset = DataSet( + **{'id': dataset_id, 'name': self.data.get("name"), 'desc': self.data.get('desc'), 'user': user}) + document_model_list = [] + paragraph_model_list = [] + if 'documents' in self.data: + documents = self.data.get('documents') + for document in documents: + document_model = Document(**{'dataset': dataset, 'id': uuid.uuid1(), 'name': document.get('name'), + 'char_length': reduce(lambda x, y: x + y, + list( + map(lambda p: len(p), + document.get("paragraphs"))), 0)}) + document_model_list.append(document_model) + if 'paragraphs' in document: + paragraph_model_list += list(map(lambda p: Paragraph( + **{'document': document_model, 'id': uuid.uuid1(), 'content': p}), + document.get('paragraphs'))) + # 插入数据集 + dataset.save() + # 插入文档 + QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None + # 插入段落 + QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None + return True + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['name', 'desc'], + properties={ + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="数据集名称", description="数据集名称"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="数据集描述", description="数据集描述"), + 'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据", + items=CreateDocumentSerializers().get_request_body_api() + ) + } + ) + + class Operate(ApiMixin, serializers.Serializer): + id = serializers.CharField(required=True) + + def is_valid(self, *, raise_exception=True): + super().is_valid(raise_exception=True) + if not QuerySet(DataSet).filter(id=self.data.get("id")).exists(): + raise AppApiException(300, "id不存在") + + @transaction.atomic + def delete(self): + self.is_valid() + dataset = QuerySet(DataSet).get(id=self.data.get("id")) + document_list = QuerySet(Document).filter(dataset=dataset) + QuerySet(Paragraph).filter(document__in=document_list).delete() + document_list.delete() + dataset.delete() + return True + + def one(self, with_valid=True): + if with_valid: + self.is_valid() + query_string = {'dataset.id', self.data.get("id")} + query_set = QuerySet(model=get_dynamics_model( + {'dataset.id': models.UUIDField()})).filter(query_string) + return native_search(query_set, select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), with_search_one=True) + + def edit(self, dataset: Dict): + """ + 修改数据集 + :param dataset: Dict name desc + :return: + """ + self.is_valid() + _dataset = QuerySet(DataSet).get(id=self.data.get("id")) + if "name" in dataset: + _dataset.name = dataset.get("name") + if 'desc' in dataset: + _dataset.desc = dataset.get("desc") + _dataset.save() + return self.one(with_valid=False) + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['name', 'desc'], + properties={ + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="数据集名称", description="数据集名称"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="数据集描述", description="数据集描述") + } + ) + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count', + 'update_time', 'create_time'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称", + description="名称", default="测试数据集"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述", + description="描述", default="测试数据集描述"), + 'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id", + description="所属用户id", default="user_xxxx"), + 'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数", + description="字符数", default=10), + 'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量", + description="文档数量", default=1), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ) + } + ) + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=False, + description='数据集id') + ] diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py new file mode 100644 index 000000000..1f1e92b04 --- /dev/null +++ b/apps/dataset/serializers/document_serializers.py @@ -0,0 +1,79 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: document_serializers.py + @date:2023/9/22 13:43 + @desc: +""" +import uuid +from functools import reduce + +from django.core import validators +from django.db.models import QuerySet +from drf_yasg import openapi +from rest_framework import serializers + +from common.exception.app_exception import AppApiException +from common.mixins.api_mixin import ApiMixin +from dataset.models.data_set import DataSet, Document, Paragraph + + +class CreateDocumentSerializers(ApiMixin, serializers.Serializer): + name = serializers.CharField(required=True, + validators=[ + validators.MaxLengthValidator(limit_value=128, + message="文档名称在1-128个字符之间"), + validators.MinLengthValidator(limit_value=1, + message="数据集名称在1-128个字符之间") + ]) + + paragraphs = serializers.ListField(required=False, + child=serializers.CharField(required=True, + validators=[ + validators.MaxLengthValidator(limit_value=256, + message="段落在1-256个字符之间"), + validators.MinLengthValidator(limit_value=1, + message="段落在1-256个字符之间") + ])) + + def is_valid(self, *, dataset_id=None, raise_exception=False): + if not QuerySet(DataSet).filter(id=dataset_id).exists(): + raise AppApiException(10000, "数据集id不存在") + return super().is_valid(raise_exception=True) + + def save(self, dataset_id: str, **kwargs): + document_model = Document( + **{'dataset': DataSet(id=dataset_id), + 'id': uuid.uuid1(), + 'name': self.data.get('name'), + 'char_length': reduce(lambda x, y: x + y, list(map(lambda p: len(p), self.data.get("paragraphs"))), 0)}) + + paragraph_model_list = list(map(lambda p: Paragraph( + **{'document': document_model, 'id': uuid.uuid1(), 'content': p}), + self.data.get('paragraphs'))) + + # 插入文档 + document_model.save() + # 插入段落 + QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None + return True + + def get_request_body_api(self): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['name', 'paragraph'], + properties={ + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="文档名称", description="文档名称"), + 'paragraphs': openapi.Schema(type=openapi.TYPE_ARRAY, title="段落列表", description="段落列表", + items=openapi.Schema(type=openapi.TYPE_STRING, title="段落数据", + description="段落数据")) + } + ) + + def get_request_params_api(self): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='数据集id')] diff --git a/apps/dataset/sql/list_dataset.sql b/apps/dataset/sql/list_dataset.sql new file mode 100644 index 000000000..c7827e6a8 --- /dev/null +++ b/apps/dataset/sql/list_dataset.sql @@ -0,0 +1,7 @@ +SELECT + dataset.*, + document_temp."char_length", + "document_temp".document_count +FROM + dataset dataset + LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON dataset."id" = "document_temp".dataset_id \ No newline at end of file diff --git a/apps/dataset/tests.py b/apps/dataset/tests.py new file mode 100644 index 000000000..7ce503c2d --- /dev/null +++ b/apps/dataset/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py new file mode 100644 index 000000000..8e196cd26 --- /dev/null +++ b/apps/dataset/urls.py @@ -0,0 +1,11 @@ +from django.urls import path + +from . import views + +app_name = "dataset" +urlpatterns = [ + path('dataset', views.Dataset.as_view(), name="dataset"), + path('dataset/', views.Dataset.Operate.as_view(), name="dataset_key"), + path('dataset//', views.Dataset.Page.as_view(), name="dataset"), + path('dataset//document', views.Document.as_view(), name='document') +] diff --git a/apps/dataset/views/__init__.py b/apps/dataset/views/__init__.py new file mode 100644 index 000000000..413cf8363 --- /dev/null +++ b/apps/dataset/views/__init__.py @@ -0,0 +1,10 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/9/21 9:32 + @desc: +""" +from .dataset import * +from .document import * diff --git a/apps/dataset/views/dataset.py b/apps/dataset/views/dataset.py new file mode 100644 index 000000000..48439d693 --- /dev/null +++ b/apps/dataset/views/dataset.py @@ -0,0 +1,90 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: dataset.py + @date:2023/9/21 15:52 + @desc: +""" + +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.views import APIView +from rest_framework.views import Request + +from common.auth import TokenAuth, has_permissions +from common.constants.permission_constants import PermissionConstants, CompareConstants, Permission, Group, Operate +from common.response import result +from common.response.result import get_page_request_params, get_page_api_response, get_api_response +from dataset.serializers.dataset_serializers import DataSetSerializers + + +class Dataset(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取数据集列表", + operation_id="获取数据集列表", + manual_parameters=DataSetSerializers.Query.get_request_params_api(), + responses=get_api_response(DataSetSerializers.Query.get_response_body_api())) + @has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND) + def get(self, request: Request): + d = DataSetSerializers.Query(data=request.query_params) + d.is_valid() + return result.success(d.list()) + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建数据集", + operation_id="创建数据集", + request_body=DataSetSerializers.Create.get_request_body_api()) + @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND) + def post(self, request: Request): + s = DataSetSerializers.Create(data=request.data) + if s.is_valid(): + s.save(request.user) + return result.success("ok") + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods="DELETE", detail=False) + @swagger_auto_schema(operation_summary="删除数据集", operation_id="删除数据集") + @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=keywords.get('dataset_id')), + lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE, + dynamic_tag=k.get('dataset_id')), compare=CompareConstants.AND) + def delete(self, request: Request, dataset_id: str): + operate = DataSetSerializers.Operate(data={'id': dataset_id}) + return result.success(operate.delete()) + + @action(methods="GET", detail=False) + @swagger_auto_schema(operation_summary="查询数据集详情根据数据集id", operation_id="查询数据集详情根据数据集id", + responses=get_api_response(DataSetSerializers.Operate.get_response_body_api())) + @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=keywords.get('dataset_id'))) + def get(self, request: Request, dataset_id: str): + return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).one()) + + @action(methods="PUT", detail=False) + @swagger_auto_schema(operation_summary="修改数据集信息", operation_id="修改数据集信息", + request_body=DataSetSerializers.Operate.get_request_body_api(), + responses=get_api_response(DataSetSerializers.Operate.get_response_body_api())) + @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=keywords.get('dataset_id'))) + def put(self, request: Request, dataset_id: str): + return result.success(DataSetSerializers.Operate(data={'id': dataset_id}).edit(request.data)) + + class Page(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取数据集分页列表", + operation_id="获取数据集分页列表", + manual_parameters=get_page_request_params( + DataSetSerializers.Query.get_request_params_api()), + responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api())) + @has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND) + def get(self, request: Request, current_page, page_size): + d = DataSetSerializers.Query(data=request.query_params) + d.is_valid() + return result.success(d.page(current_page, page_size)) diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py new file mode 100644 index 000000000..7f759e0ac --- /dev/null +++ b/apps/dataset/views/document.py @@ -0,0 +1,51 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: document.py + @date:2023/9/22 11:32 + @desc: +""" + +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.views import APIView +from rest_framework.views import Request + +from common.auth import TokenAuth, has_permissions +from common.constants.permission_constants import Permission, Group, Operate, PermissionConstants +from common.response import result +from dataset.serializers.dataset_serializers import CreateDocumentSerializers + + +class Document(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建文档", + operation_id="创建文档", + request_body=CreateDocumentSerializers().get_request_body_api(), + manual_parameters=CreateDocumentSerializers().get_request_params_api()) + @has_permissions(PermissionConstants.DATASET_CREATE) + def post(self, request: Request, dataset_id: str): + d = CreateDocumentSerializers(data=request.data) + if d.is_valid(dataset_id=dataset_id): + d.save(dataset_id) + return result.success("ok") + + +class DocumentDetails(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取文档详情", + operation_id="获取文档详情", + request_body=CreateDocumentSerializers().get_request_body_api(), + manual_parameters=CreateDocumentSerializers().get_request_params_api()) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str): + d = CreateDocumentSerializers(data=request.data) + if d.is_valid(dataset_id=dataset_id): + d.save(dataset_id) + return result.success("ok") diff --git a/apps/embedding/__init__.py b/apps/embedding/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/embedding/admin.py b/apps/embedding/admin.py new file mode 100644 index 000000000..8c38f3f3d --- /dev/null +++ b/apps/embedding/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. diff --git a/apps/embedding/apps.py b/apps/embedding/apps.py new file mode 100644 index 000000000..45a5d88ea --- /dev/null +++ b/apps/embedding/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class EmbeddingConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'embedding' diff --git a/apps/embedding/migrations/0001_initial.py b/apps/embedding/migrations/0001_initial.py new file mode 100644 index 000000000..968025f20 --- /dev/null +++ b/apps/embedding/migrations/0001_initial.py @@ -0,0 +1,30 @@ +# Generated by Django 4.1.10 on 2023-10-09 06:33 + +import common.field.vector_field +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('dataset', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='Embedding', + fields=[ + ('id', models.CharField(max_length=128, primary_key=True, serialize=False, verbose_name='主键id')), + ('source_id', models.CharField(max_length=128, verbose_name='资源id')), + ('source_type', models.CharField(choices=[('0', '问题'), ('1', '段落')], default='0', max_length=1, verbose_name='资源类型')), + ('embedding', common.field.vector_field.VectorField(verbose_name='向量')), + ('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='数据集关联')), + ], + options={ + 'db_table': 'embedding', + }, + ), + ] diff --git a/apps/embedding/migrations/__init__.py b/apps/embedding/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/embedding/models/__init__.py b/apps/embedding/models/__init__.py new file mode 100644 index 000000000..b5dcf44f5 --- /dev/null +++ b/apps/embedding/models/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/9/21 14:53 + @desc: +""" +from .embedding import * diff --git a/apps/embedding/models/embedding.py b/apps/embedding/models/embedding.py new file mode 100644 index 000000000..816dcd608 --- /dev/null +++ b/apps/embedding/models/embedding.py @@ -0,0 +1,34 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: embedding.py + @date:2023/9/21 15:46 + @desc: +""" +from django.db import models + +from common.field.vector_field import VectorField +from dataset.models.data_set import DataSet + + +class SourceType(models.TextChoices): + """订单类型""" + PROBLEM = 0, '问题' + PARAGRAPH = 1, '段落' + + +class Embedding(models.Model): + id = models.CharField(max_length=128, primary_key=True, verbose_name="主键id") + + source_id = models.CharField(max_length=128, verbose_name="资源id") + + source_type = models.CharField(verbose_name='资源类型', max_length=1, choices=SourceType.choices, + default=SourceType.PROBLEM) + + dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, verbose_name="数据集关联") + + embedding = VectorField(verbose_name="向量") + + class Meta: + db_table = "embedding" diff --git a/apps/embedding/tests.py b/apps/embedding/tests.py new file mode 100644 index 000000000..7ce503c2d --- /dev/null +++ b/apps/embedding/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/apps/embedding/views.py b/apps/embedding/views.py new file mode 100644 index 000000000..91ea44a21 --- /dev/null +++ b/apps/embedding/views.py @@ -0,0 +1,3 @@ +from django.shortcuts import render + +# Create your views here. diff --git a/apps/setting/__init__.py b/apps/setting/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/setting/admin.py b/apps/setting/admin.py new file mode 100644 index 000000000..8c38f3f3d --- /dev/null +++ b/apps/setting/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. diff --git a/apps/setting/apps.py b/apps/setting/apps.py new file mode 100644 index 000000000..57d346a1f --- /dev/null +++ b/apps/setting/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class SettingConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'setting' diff --git a/apps/setting/migrations/0001_initial.py b/apps/setting/migrations/0001_initial.py new file mode 100644 index 000000000..330047abb --- /dev/null +++ b/apps/setting/migrations/0001_initial.py @@ -0,0 +1,58 @@ +# Generated by Django 4.1.10 on 2023-10-09 06:33 + +import django.contrib.postgres.fields +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('users', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='Team', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('user', models.OneToOneField(on_delete=django.db.models.deletion.DO_NOTHING, primary_key=True, serialize=False, to='users.user', verbose_name='团队所有者')), + ('name', models.CharField(max_length=128, verbose_name='团队名称')), + ], + options={ + 'db_table': 'team', + }, + ), + migrations.CreateModel( + name='TeamMember', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('team', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='setting.team', verbose_name='团队id')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', verbose_name='成员用户id')), + ], + options={ + 'db_table': 'team_member', + }, + ), + migrations.CreateModel( + name='TeamMemberPermission', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('auth_target_type', models.CharField(choices=[('DATASET', '数据集'), ('APPLICATION', '应用')], default='DATASET', max_length=128, verbose_name='授权目标')), + ('target', models.UUIDField(verbose_name='数据集/应用id')), + ('operate', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, choices=[('MANAGE', '管理'), ('USE', '使用')], default='USE', max_length=256), size=None, verbose_name='权限操作列表')), + ('member', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='setting.teammember', verbose_name='团队成员')), + ], + options={ + 'db_table': 'team_member_permission', + }, + ), + ] diff --git a/apps/setting/migrations/__init__.py b/apps/setting/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/setting/models/__init__.py b/apps/setting/models/__init__.py new file mode 100644 index 000000000..42f719eea --- /dev/null +++ b/apps/setting/models/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/9/25 15:04 + @desc: +""" +from .team_management import * diff --git a/apps/setting/models/team_management.py b/apps/setting/models/team_management.py new file mode 100644 index 000000000..3e480d8d7 --- /dev/null +++ b/apps/setting/models/team_management.py @@ -0,0 +1,73 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: team_management.py + @date:2023/9/25 15:04 + @desc: +""" +import uuid + +from django.contrib.postgres.fields import ArrayField +from django.db import models + +from common.constants.permission_constants import Group, Operate +from common.mixins.app_model_mixin import AppModelMixin +from users.models import User + + +class AuthTargetType(models.TextChoices): + """授权目标""" + DATASET = Group.DATASET.value, '数据集' + APPLICATION = Group.APPLICATION.value, '应用' + + +class AuthOperate(models.TextChoices): + """授权权限""" + MANAGE = Operate.MANAGE.value, '管理' + + USE = Operate.USE.value, "使用" + + +class Team(AppModelMixin): + """ + 团队表 + """ + user = models.OneToOneField(User, primary_key=True, on_delete=models.DO_NOTHING, verbose_name="团队所有者") + + name = models.CharField(max_length=128, verbose_name="团队名称") + + class Meta: + db_table = "team" + + +class TeamMember(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + team = models.ForeignKey(Team, on_delete=models.DO_NOTHING, verbose_name="团队id") + user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="成员用户id") + + class Meta: + db_table = "team_member" + + +class TeamMemberPermission(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + """ + 团队成员权限 + """ + member = models.ForeignKey(TeamMember, on_delete=models.DO_NOTHING, verbose_name="团队成员") + + auth_target_type = models.CharField(verbose_name='授权目标', max_length=128, choices=AuthTargetType.choices, + default=AuthTargetType.DATASET) + + target = models.UUIDField(max_length=128, verbose_name="数据集/应用id") + + operate = ArrayField(verbose_name="权限操作列表", + base_field=models.CharField(max_length=256, + blank=True, + choices=AuthOperate.choices, + default=AuthOperate.USE), + ) + + class Meta: + db_table = "team_member_permission" diff --git a/apps/setting/serializers/team_serializers.py b/apps/setting/serializers/team_serializers.py new file mode 100644 index 000000000..b4cb1364f --- /dev/null +++ b/apps/setting/serializers/team_serializers.py @@ -0,0 +1,276 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: team_serializers.py + @date:2023/9/5 16:32 + @desc: +""" +import itertools +import json +import os +import uuid +from typing import Dict + +from django.core import cache +from django.db.models import QuerySet, Q +from drf_yasg import openapi +from rest_framework import serializers + +from common.constants.permission_constants import Operate +from common.db.sql_execute import select_list +from common.exception.app_exception import AppApiException +from common.mixins.api_mixin import ApiMixin +from common.response.result import get_api_response +from common.util.file_util import get_file_content +from setting.models import TeamMember, TeamMemberPermission +from smartdoc.conf import PROJECT_DIR +from users.models.user import User +from users.serializers.user_serializers import UserSerializer + +user_cache = cache.caches['user_cache'] + + +def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'username', 'email', 'role', 'is_active', 'team_id', 'member_id'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id", description="用户id"), + 'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"), + 'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址"), + 'role': openapi.Schema(type=openapi.TYPE_STRING, title="角色", description="角色"), + 'is_active': openapi.Schema(type=openapi.TYPE_STRING, title="是否可用", description="是否可用"), + 'team_id': openapi.Schema(type=openapi.TYPE_STRING, title="团队id", description="团队id"), + 'member_id': openapi.Schema(type=openapi.TYPE_STRING, title="成员id", description="成员id"), + } + ) + + +class TeamMemberPermissionOperate(ApiMixin, serializers.Serializer): + USE = serializers.BooleanField(required=True) + MANAGE = serializers.BooleanField(required=True) + + def get_request_body_api(self): + return openapi.Schema(type=openapi.TYPE_OBJECT, + title="类型", + description="操作权限USE,MANAGE权限", + properties={ + 'USE': openapi.Schema(type=openapi.TYPE_BOOLEAN, + title="使用权限", + description="使用权限 True|False"), + 'MANAGE': openapi.Schema(type=openapi.TYPE_BOOLEAN, + title="管理权限", + description="管理权限 True|False") + } + ) + + +class UpdateTeamMemberItemPermissionSerializer(ApiMixin, serializers.Serializer): + target_id = serializers.CharField(required=True) + type = serializers.CharField(required=True) + operate = TeamMemberPermissionOperate(required=True, many=False) + + def get_request_body_api(self): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'type', 'operate'], + properties={ + 'target_id': openapi.Schema(type=openapi.TYPE_STRING, title="数据集/应用id", + description="数据集或者应用的id"), + 'type': openapi.Schema(type=openapi.TYPE_STRING, + title="类型", + description="DATASET|APPLICATION", + ), + 'operate': TeamMemberPermissionOperate().get_request_body_api() + } + ) + + +class UpdateTeamMemberPermissionSerializer(ApiMixin, serializers.Serializer): + team_member_permission_list = UpdateTeamMemberItemPermissionSerializer(required=True, many=True) + + def is_valid(self, *, user_id=None): + super().is_valid(raise_exception=True) + permission_list = self.data.get("team_member_permission_list") + illegal_target_id_list = select_list( + get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'sql', 'check_member_permission_target_exists.sql')), + [json.dumps(permission_list), user_id, user_id]) + if illegal_target_id_list is not None and len(illegal_target_id_list) > 0: + raise AppApiException(500, '不存在的 应用|数据集id[' + str(illegal_target_id_list) + ']') + + def update_or_save(self, member_id: str): + team_member_permission_list = self.data.get("team_member_permission_list") + # 获取数据库已有权限 从而判断是否是插入还是更新 + team_member_permission_exist_list = QuerySet(TeamMemberPermission).filter( + member_id=member_id) + update_list = [] + save_list = [] + for item in team_member_permission_list: + exist_list = list( + filter(lambda use: str(use.target) == item.get('target_id'), team_member_permission_exist_list)) + if len(exist_list) > 0: + exist_list[0].operate = list( + filter(lambda key: item.get('operate').get(key), + item.get('operate').keys())) + update_list.append(exist_list[0]) + else: + save_list.append(TeamMemberPermission(target=item.get('target_id'), auth_target_type=item.get('type'), + operate=list( + filter(lambda key: item.get('operate').get(key), + item.get('operate').keys())), + member_id=member_id)) + # 批量更新 + QuerySet(TeamMemberPermission).bulk_update(update_list, ['operate']) + # 批量插入 + QuerySet(TeamMemberPermission).bulk_create(save_list) + + def get_request_body_api(self): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id'], + properties={ + 'team_member_permission_list': + openapi.Schema(type=openapi.TYPE_ARRAY, title="权限数据", + description="权限数据", + items=UpdateTeamMemberItemPermissionSerializer().get_request_body_api() + ), + } + ) + + +class TeamMemberSerializer(ApiMixin, serializers.Serializer): + team_id = serializers.CharField(required=True) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + + def get_request_body_api(self): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['username_or_email'], + properties={ + 'username_or_email': openapi.Schema(type=openapi.TYPE_STRING, title="用户名或者邮箱", + description="用户名或者邮箱"), + + } + ) + + def add_member(self, username_or_email: str, with_valid=True): + """ + 添加一个成员 + :param with_valid: 是否校驗參數 + :param username_or_email: 添加成员的邮箱或者用户名 + :return: 成员列表 + """ + if with_valid: + self.is_valid(raise_exception=True) + if username_or_email is None: + raise AppApiException(500, "用户名或者邮箱必填") + user = QuerySet(User).filter( + Q(username=username_or_email) | Q(email=username_or_email)).first() + if user is None: + raise AppApiException(500, "不存在的用户") + if QuerySet(TeamMember).filter(Q(team_id=self.data.get('team_id')) & Q(user=user)).exists(): + raise AppApiException(500, "团队中已存在当前成员,不要重复添加") + TeamMember(team_id=self.data.get("team_id"), user=user).save() + return TeamMemberSerializer(data={'team_id': self.data.get("team_id")}).list_member() + + def list_member(self, with_valid=True): + """ + 获取 团队中的成员列表 + :return: 成员列表 + """ + if with_valid: + self.is_valid(raise_exception=True) + # 普通成員列表 + member_list = list(map(lambda t: {"id": t.id, 'email': t.user.email, 'username': t.user.username, + 'team_id': self.data.get("team_id"), 'user_id': t.user.id, + 'type': 'member'}, + QuerySet(TeamMember).filter(team_id=self.data.get("team_id")))) + # 管理員成員 + manage_member = QuerySet(User).get(id=self.data.get('team_id')) + return [{'id': 'root', 'email': manage_member.email, 'username': manage_member.username, + 'team_id': self.data.get("team_id"), 'user_id': manage_member.id, 'type': 'manage' + }, *member_list] + + def get_response_body_api(self): + return get_api_response(openapi.Schema( + type=openapi.TYPE_ARRAY, title="成员列表", description="成员列表", + items=UserSerializer().get_response_body_api() + )) + + class Operate(ApiMixin, serializers.Serializer): + # 团队 成员id + member_id = serializers.CharField(required=True) + # 团队id + team_id = serializers.CharField(required=True) + + def is_valid(self, *, raise_exception=True): + super().is_valid(raise_exception=True) + if self.data.get('member_id') != 'root' and not QuerySet(TeamMember).filter( + team_id=self.data.get('team_id'), + id=self.data.get('member_id')).exists(): + raise AppApiException(500, "不存在的成员,请先添加成员") + + return True + + def list_member_permission(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + team_id = self.data.get('team_id') + member_id = self.data.get("member_id") + # 查询当前团队成员所有的数据集和应用的权限 注意 operate为null是为设置权限 默认值都是false + member_permission_list = select_list( + get_file_content(os.path.join(PROJECT_DIR, "apps", "setting", 'sql', 'get_member_permission.sql')), + [team_id, team_id, (member_id if member_id != 'root' else uuid.uuid1())]) + + # 如果是管理员 则拥有所有权限 默认赋值 + if member_id == 'root': + member_permission_list = list( + map(lambda row: {**row, 'operate': {Operate.USE.value: True, Operate.MANAGE.value: True}}, + member_permission_list)) + # 分为 APPLICATION DATASET俩组 + groups = itertools.groupby( + list(map(lambda m: {**m, 'member_id': member_id, + 'operate': dict( + map(lambda key: (key, True if m.get('operate') is not None and m.get( + 'operate').__contains__(key) else False), + [Operate.USE.value, Operate.MANAGE.value]))}, + member_permission_list)), + key=lambda x: x.get('type')) + return dict([(key, list(group)) for key, group in groups]) + + def edit(self, member_permission: Dict): + self.is_valid(raise_exception=True) + member_id = self.data.get("member_id") + if member_id == 'root': + raise AppApiException(500, "管理员权限不允许修改") + s = UpdateTeamMemberPermissionSerializer(data=member_permission) + s.is_valid(user_id=self.data.get("team_id")) + s.update_or_save(member_id) + return self.list_member_permission(with_valid=False) + + def delete(self): + """ + 移除成员 + :return: + """ + self.is_valid(raise_exception=True) + member_id = self.data.get("member_id") + if member_id == 'root': + raise AppApiException(500, "无法移除团队管理员") + # 删除成员权限 + QuerySet(TeamMemberPermission).filter(member_id=member_id).delete() + # 删除成员 + QuerySet(TeamMember).filter(id=member_id).delete() + return True + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='member_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='团队成员id')] diff --git a/apps/setting/sql/check_member_permission_target_exists.sql b/apps/setting/sql/check_member_permission_target_exists.sql new file mode 100644 index 000000000..13c1aaaaa --- /dev/null +++ b/apps/setting/sql/check_member_permission_target_exists.sql @@ -0,0 +1,32 @@ +SELECT + static_temp."target_id"::text +FROM + (SELECT * FROM json_to_recordset( + %s + ) AS x(target_id uuid,type text)) static_temp + LEFT JOIN ( + SELECT + "id", + 'DATASET' AS "type", + user_id, + ARRAY [ 'MANAGE', + 'USE', + 'DELETE' ] AS "operate" + FROM + dataset + WHERE + "user_id" = %s UNION + SELECT + "id", + 'APPLICATION' AS "type", + user_id, + ARRAY [ 'MANAGE', + 'USE', + 'DELETE' ] AS "operate" + FROM + application + WHERE + "user_id" = %s + ) "app_and_dataset_temp" + ON "app_and_dataset_temp"."id" = static_temp."target_id" and app_and_dataset_temp."type"=static_temp."type" + WHERE app_and_dataset_temp.id is NULL ; \ No newline at end of file diff --git a/apps/setting/sql/get_member_permission.sql b/apps/setting/sql/get_member_permission.sql new file mode 100644 index 000000000..f6b2d953f --- /dev/null +++ b/apps/setting/sql/get_member_permission.sql @@ -0,0 +1,26 @@ +SELECT + app_or_dataset.*, + team_member_permission.member_id, + team_member_permission.operate +FROM + ( + SELECT + "id", + "name", + 'DATASET' AS "type", + user_id + FROM + dataset + WHERE + "user_id" = %s UNION + SELECT + "id", + "name", + 'APPLICATION' AS "type", + user_id + FROM + application + WHERE + "user_id" = %s + ) app_or_dataset + LEFT JOIN ( SELECT * FROM team_member_permission WHERE member_id = %s ) team_member_permission ON team_member_permission.target = app_or_dataset."id" \ No newline at end of file diff --git a/apps/setting/sql/get_user_permission.sql b/apps/setting/sql/get_user_permission.sql new file mode 100644 index 000000000..c50e5ea6e --- /dev/null +++ b/apps/setting/sql/get_user_permission.sql @@ -0,0 +1,30 @@ +SELECT + "id", + 'DATASET' AS "type", + user_id, + ARRAY [ 'MANAGE', + 'USE','DELETE' ] AS "operate" +FROM + dataset +WHERE + "user_id" = %s UNION +SELECT + "id", + 'APPLICATION' AS "type", + user_id, + ARRAY [ 'MANAGE', + 'USE','DELETE' ] AS "operate" +FROM + application +WHERE + "user_id" = %s UNION +SELECT + team_member_permission.target AS "id", + team_member_permission.auth_target_type AS "type", + team_member.user_id AS user_id, + team_member_permission.operate AS "operate" +FROM + team_member team_member + LEFT JOIN team_member_permission team_member_permission ON team_member.ID = team_member_permission.member_id +WHERE + team_member.user_id = %s AND team_member_permission.target IS NOT NULL \ No newline at end of file diff --git a/apps/setting/tests.py b/apps/setting/tests.py new file mode 100644 index 000000000..7ce503c2d --- /dev/null +++ b/apps/setting/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/apps/setting/urls.py b/apps/setting/urls.py new file mode 100644 index 000000000..778d1080f --- /dev/null +++ b/apps/setting/urls.py @@ -0,0 +1,9 @@ +from django.urls import path + +from . import views + +app_name = "team" +urlpatterns = [ + path('team/member', views.TeamMember.as_view(), name="team"), + path('team/member/', views.TeamMember.Operate.as_view(), name='member') +] diff --git a/apps/setting/views/Team.py b/apps/setting/views/Team.py new file mode 100644 index 000000000..468a37c2b --- /dev/null +++ b/apps/setting/views/Team.py @@ -0,0 +1,66 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: Team.py + @date:2023/9/25 17:13 + @desc: +""" +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.views import APIView +from rest_framework.views import Request + +from common.auth import TokenAuth +from common.response import result +from setting.serializers.team_serializers import TeamMemberSerializer, get_response_body_api, \ + UpdateTeamMemberPermissionSerializer + + +class TeamMember(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取团队成员列表", + operation_id="获取团员成员列表", + responses=result.get_api_response(get_response_body_api())) + def get(self, request: Request): + return result.success(TeamMemberSerializer(data={'team_id': str(request.user.id)}).list_member()) + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="添加成员", + operation_id="添加成员", + request_body=TeamMemberSerializer().get_request_body_api()) + def post(self, request: Request): + team = TeamMemberSerializer(data={'team_id': str(request.user.id)}) + return result.success((team.add_member(**request.data))) + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取团队成员权限", + operation_id="获取团队成员权限", + manual_parameters=TeamMemberSerializer.Operate.get_request_params_api()) + def get(self, request: Request, member_id: str): + return result.success(TeamMemberSerializer.Operate( + data={'member_id': member_id, 'team_id': str(request.user.id)}).list_member_permission()) + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改团队成员权限", + operation_id="修改团队成员权限", + request_body=UpdateTeamMemberPermissionSerializer().get_request_body_api(), + manual_parameters=TeamMemberSerializer.Operate.get_request_params_api() + ) + def put(self, request: Request, member_id: str): + return result.success(TeamMemberSerializer.Operate( + data={'member_id': member_id, 'team_id': str(request.user.id)}).edit(request.data)) + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="移除成员", + operation_id="移除成员", + manual_parameters=TeamMemberSerializer.Operate.get_request_params_api() + ) + def delete(self, request: Request, member_id: str): + return result.success(TeamMemberSerializer.Operate( + data={'member_id': member_id, 'team_id': str(request.user.id)}).delete()) diff --git a/apps/setting/views/__init__.py b/apps/setting/views/__init__.py new file mode 100644 index 000000000..3eaf6ef6e --- /dev/null +++ b/apps/setting/views/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2023/9/25 17:12 + @desc: +""" +from .Team import * diff --git a/apps/smartdoc/settings/base.py b/apps/smartdoc/settings/base.py index 72113a74c..8dc1ebf0b 100644 --- a/apps/smartdoc/settings/base.py +++ b/apps/smartdoc/settings/base.py @@ -17,7 +17,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent SECRET_KEY = 'django-insecure-g1u*$)1ddn20_3orw^f+g4(i(2dacj^awe*2vh-$icgqwfnbq(' # SECURITY WARNING: don't run with debug turned on in production! -DEBUG = False +DEBUG = True ALLOWED_HOSTS = ['*'] @@ -29,12 +29,17 @@ DATABASES = { INSTALLED_APPS = [ 'users.apps.UsersConfig', + 'setting', + 'dataset', + 'application', + 'embedding', 'django.contrib.contenttypes', 'django.contrib.messages', 'django.contrib.staticfiles', 'rest_framework', "drf_yasg", # swagger 接口 'django_filters', # 条件过滤 + ] MIDDLEWARE = [ diff --git a/apps/smartdoc/settings/logging.py b/apps/smartdoc/settings/logging.py index ffdffbc49..3944cf84e 100644 --- a/apps/smartdoc/settings/logging.py +++ b/apps/smartdoc/settings/logging.py @@ -10,7 +10,7 @@ DRF_EXCEPTION_LOG_FILE = os.path.join(LOG_DIR, 'drf_exception.log') UNEXPECTED_EXCEPTION_LOG_FILE = os.path.join(LOG_DIR, 'unexpected_exception.log') ANSIBLE_LOG_FILE = os.path.join(LOG_DIR, 'ansible.log') GUNICORN_LOG_FILE = os.path.join(LOG_DIR, 'gunicorn.log') -LOG_LEVEL = CONFIG.LOG_LEVEL +LOG_LEVEL = "DEBUG" LOGGING = { 'version': 1, @@ -100,39 +100,20 @@ LOGGING = { 'level': LOG_LEVEL, 'propagate': False, }, + 'django.db.backends': { + 'handlers': ['console', 'file', 'syslog'], + 'propagate': False, + 'level': LOG_LEVEL, + }, 'django.server': { 'handlers': ['console', 'file', 'syslog'], 'level': LOG_LEVEL, 'propagate': False, }, - 'jumpserver': { + 'smartdoc': { 'handlers': ['console', 'file'], 'level': LOG_LEVEL, }, - 'drf_exception': { - 'handlers': ['console', 'drf_exception'], - 'level': LOG_LEVEL, - }, - 'unexpected_exception': { - 'handlers': ['unexpected_exception'], - 'level': LOG_LEVEL, - }, - 'ops.ansible_api': { - 'handlers': ['console', 'ansible_logs'], - 'level': LOG_LEVEL, - }, - 'django_auth_ldap': { - 'handlers': ['console', 'file'], - 'level': "INFO", - }, - 'syslog': { - 'handlers': ['syslog'], - 'level': 'INFO' - }, - 'azure': { - 'handlers': ['null'], - 'level': 'ERROR' - } } } diff --git a/apps/smartdoc/urls.py b/apps/smartdoc/urls.py index 12aac16db..e74af4c48 100644 --- a/apps/smartdoc/urls.py +++ b/apps/smartdoc/urls.py @@ -42,6 +42,8 @@ schema_view = get_schema_view( urlpatterns = [ path("api/", include("users.urls")), + path("api/", include("dataset.urls")), + path("api/", include("setting.urls")) ] diff --git a/apps/users/migrations/0001_initial.py b/apps/users/migrations/0001_initial.py index 851b71474..7bd5bb2fa 100644 --- a/apps/users/migrations/0001_initial.py +++ b/apps/users/migrations/0001_initial.py @@ -1,6 +1,7 @@ -# Generated by Django 4.1.10 on 2023-09-20 02:58 +# Generated by Django 4.1.10 on 2023-10-09 06:33 from django.db import migrations, models +import uuid class Migration(migrations.Migration): @@ -14,7 +15,7 @@ class Migration(migrations.Migration): migrations.CreateModel( name='User', fields=[ - ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), ('email', models.EmailField(max_length=254, unique=True, verbose_name='邮箱')), ('username', models.CharField(max_length=150, unique=True, verbose_name='用户名')), ('password', models.CharField(max_length=150, verbose_name='密码')), diff --git a/apps/users/models/user.py b/apps/users/models/user.py index 226acee70..d7cfc1f4c 100644 --- a/apps/users/models/user.py +++ b/apps/users/models/user.py @@ -7,9 +7,16 @@ @desc: """ import hashlib +import os +import uuid from django.db import models +from common.constants.permission_constants import Permission, Group, Operate +from common.db.sql_execute import select_list +from common.util.file_util import get_file_content +from smartdoc.conf import PROJECT_DIR + __all__ = ["User", "password_encrypt"] @@ -25,7 +32,36 @@ def password_encrypt(raw_password): return result +def to_dynamics_permission(group_type: str, operate: list[str], dynamic_tag: str): + """ + 转换为权限对象 + :param group_type: 分组类型 + :param operate: 操作 + :param dynamic_tag: 标记 + :return: 权限列表 + """ + return [Permission(group=Group[group_type], operate=Operate[o], dynamic_tag=dynamic_tag) + for o in operate] + + +def get_user_dynamics_permission(user_id: str): + """ + 获取 应用和数据集权限 + :param user_id: 用户id + :return: 用户 应用和数据集权限 + """ + member_permission_list = select_list( + get_file_content(os.path.join(PROJECT_DIR, "apps", "setting", 'sql', 'get_user_permission.sql')), + [user_id, user_id, user_id]) + result = [] + for member_permission in member_permission_list: + result += to_dynamics_permission(member_permission.get('type'), member_permission.get('operate'), + str(member_permission.get('id'))) + return result + + class User(models.Model): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") email = models.EmailField(unique=True, verbose_name="邮箱") username = models.CharField(max_length=150, unique=True, verbose_name="用户名") password = models.CharField(max_length=150, verbose_name="密码") diff --git a/apps/users/serializers/user_serializers.py b/apps/users/serializers/user_serializers.py index 95c0afbf4..17f1d507d 100644 --- a/apps/users/serializers/user_serializers.py +++ b/apps/users/serializers/user_serializers.py @@ -2,7 +2,7 @@ """ @project: qabot @Author:虎 - @file: user_serializers.py + @file: team_serializers.py @date:2023/9/5 16:32 @desc: """ @@ -10,22 +10,25 @@ import datetime import os import random import re +import uuid from django.core import validators, signing, cache from django.core.mail import send_mail +from django.db import transaction from django.db.models import Q from drf_yasg import openapi from rest_framework import serializers from common.constants.exception_code_constants import ExceptionCodeConstants -from common.constants.permission_constants import RoleConstants +from common.constants.permission_constants import RoleConstants, get_permission_list_by_role from common.exception.app_exception import AppApiException from common.mixins.api_mixin import ApiMixin from common.response.result import get_api_response from common.util.lock import lock +from setting.models import Team from smartdoc.conf import PROJECT_DIR from smartdoc.settings import EMAIL_ADDRESS -from users.models.user import User, password_encrypt +from users.models.user import User, password_encrypt, get_user_dynamics_permission user_cache = cache.caches['user_cache'] @@ -63,7 +66,7 @@ class LoginSerializer(ApiMixin, serializers.Serializer): :return: 用户Token(认证信息) """ user = self.is_valid() - token = signing.dumps({'username': user.username, 'id': user.id, 'email': user.email}) + token = signing.dumps({'username': user.username, 'id': str(user.id), 'email': user.email}) return token class Meta: @@ -86,7 +89,7 @@ class LoginSerializer(ApiMixin, serializers.Serializer): title="token", default="xxxx", description="认证token" - ), "token value") + )) class RegisterSerializer(ApiMixin, serializers.Serializer): @@ -138,19 +141,23 @@ class RegisterSerializer(ApiMixin, serializers.Serializer): return True + @transaction.atomic def save(self, **kwargs): m = User( - **{'email': self.data.get("email"), 'username': self.data.get("username"), + **{'id': uuid.uuid1(), 'email': self.data.get("email"), 'username': self.data.get("username"), 'role': RoleConstants.USER.name}) m.set_password(self.data.get("password")) # 插入用户 m.save() + # 初始化用户团队 + Team(**{'user': m, 'name': m.username + '的团队'}).save() email = self.data.get("email") code_cache_key = email + ":register" # 删除验证码缓存 user_cache.delete(code_cache_key) - def get_request_body_api(self): + @staticmethod + def get_request_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, required=['username', 'email', 'password', 're_password', 'code'], @@ -205,7 +212,7 @@ class CheckCodeSerializer(ApiMixin, serializers.Serializer): type=openapi.TYPE_BOOLEAN, title="是否成功", default=True, - description="错误提示"), True) + description="错误提示")) class RePasswordSerializer(ApiMixin, serializers.Serializer): @@ -334,11 +341,55 @@ class SendEmailSerializer(ApiMixin, serializers.Serializer): ) def get_response_body_api(self): - return get_api_response(openapi.Schema(type=openapi.TYPE_STRING, default=True), True) + return get_api_response(openapi.Schema(type=openapi.TYPE_STRING, default=True)) -class UserSerializer(serializers.ModelSerializer): +class UserProfile(ApiMixin): + + @staticmethod + def get_user_profile(user: User): + """ + 获取用户详情 + :param user: 用户对象 + :return: + """ + permission_list = get_user_dynamics_permission(str(user.id)) + permission_list += [p.value for p in get_permission_list_by_role(RoleConstants[user.role])] + return {'id': user.id, 'username': user.username, 'email': user.email, 'role': user.role, + 'permissions': [str(p) for p in permission_list]} + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'username', 'email', 'role', 'is_active'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id", description="用户id"), + 'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"), + 'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址"), + 'role': openapi.Schema(type=openapi.TYPE_STRING, title="角色", description="角色"), + 'is_active': openapi.Schema(type=openapi.TYPE_STRING, title="是否可用", description="是否可用"), + "permissions": openapi.Schema(type=openapi.TYPE_ARRAY, title="权限列表", description="权限列表", + items=openapi.Schema(type=openapi.TYPE_STRING)) + } + ) + + +class UserSerializer(ApiMixin, serializers.ModelSerializer): class Meta: model = User fields = ["email", "id", "username", ] + + def get_response_body_api(self): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'username', 'email', 'role', 'is_active'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id", description="用户id"), + 'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"), + 'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址"), + 'role': openapi.Schema(type=openapi.TYPE_STRING, title="角色", description="角色"), + 'is_active': openapi.Schema(type=openapi.TYPE_STRING, title="是否可用", description="是否可用") + } + ) diff --git a/apps/users/views/user.py b/apps/users/views/user.py index 407e79624..8c16d11b3 100644 --- a/apps/users/views/user.py +++ b/apps/users/views/user.py @@ -7,6 +7,7 @@ @desc: """ from django.core import cache +from django.db.models import QuerySet from drf_yasg import openapi from drf_yasg.utils import swagger_auto_schema from rest_framework.decorators import action @@ -23,7 +24,7 @@ from smartdoc.settings import JWT_AUTH from users.models.user import User as UserModel from users.serializers.user_serializers import RegisterSerializer, LoginSerializer, UserSerializer, CheckCodeSerializer, \ RePasswordSerializer, \ - SendEmailSerializer + SendEmailSerializer, UserProfile user_cache = cache.caches['user_cache'] token_cache = cache.caches['token_cache'] @@ -34,10 +35,11 @@ class User(APIView): @action(methods=['GET'], detail=False) @swagger_auto_schema(operation_summary="获取当前用户信息", - operation_id="获取当前用户信息") + operation_id="获取当前用户信息", + responses=result.get_api_response(UserProfile.get_response_body_api())) @has_permissions(PermissionConstants.USER_READ, compare=CompareConstants.AND) def get(self, request: Request): - return result.success(UserSerializer(instance=UserModel.objects.get(id=request.user.id)).data) + return result.success(UserProfile.get_user_profile(request.user)) class ResetCurrentUserPasswordView(APIView): diff --git a/ui/src/api/user/type.ts b/ui/src/api/user/type.ts index a793e6267..6462354f8 100644 --- a/ui/src/api/user/type.ts +++ b/ui/src/api/user/type.ts @@ -1,4 +1,8 @@ interface User { + /** + * 用户id + */ + id: string /** * 用户名 */ @@ -7,6 +11,14 @@ interface User { * 邮箱 */ email: string + /** + * 用户角色 + */ + role: string + /** + * 用户权限 + */ + permissions: Array } interface LoginRequest { diff --git a/ui/src/common/permission/index.ts b/ui/src/common/permission/index.ts new file mode 100644 index 000000000..1bf099209 --- /dev/null +++ b/ui/src/common/permission/index.ts @@ -0,0 +1,52 @@ +import { store } from '@/stores' +import { useUserStore } from '@/stores/user' +import { Role, Permission, ComplexPermission } from '@/common/permission/type' +/** + * 是否包含当前权限 + * @param permission 当前权限 + * @returns True 包含 false 不包含 + */ +const hasPermissionChild = (permission: Role | string | Permission | ComplexPermission) => { + const userStore = useUserStore(store) + const permissions = userStore.getPermissions() + const role = userStore.getRole() + if (permission instanceof Role) { + return role === permission.role + } + if (permission instanceof Permission) { + return permissions.includes(permission.permission) + } + if (permission instanceof ComplexPermission) { + const permissionOk = permission.permissionList.some((p) => permissions.includes(p)) + const roleOk = permission.roleList.includes(role) + return permission.compare === 'AND' ? permissionOk && roleOk : permissionOk || roleOk + } + if (typeof permission === 'string') { + return permissions.includes(permission) + } + return false +} +/** + * 判断是否有角色和权限 + * @param role 角色 + * @param permissions 权限 + * @param requiredPermissions 权限 + * @returns + */ +export const hasPermission = ( + permission: + | Array + | Role + | string + | Permission + | ComplexPermission, + compare: 'OR' | 'AND' +): boolean => { + if (permission instanceof Array) { + return compare === 'OR' + ? permission.some((p) => hasPermissionChild(p)) + : permission.every((p) => hasPermissionChild(p)) + } else { + return hasPermissionChild(permission) + } +} diff --git a/ui/src/common/permission/type.ts b/ui/src/common/permission/type.ts new file mode 100644 index 000000000..874fc830b --- /dev/null +++ b/ui/src/common/permission/type.ts @@ -0,0 +1,36 @@ +/** + * 角色对象 + */ +export class Role { + role: string + + constructor(role: string) { + this.role = role + } +} +/** + * 权限对象 + */ +export class Permission { + permission: string + + constructor(permission: string) { + this.permission = permission + } +} +/** + * 复杂权限对象 + */ +export class ComplexPermission { + roleList: Array + + permissionList: Array + + compare: 'OR' | 'AND' + + constructor(roleList: Array, permissionList: Array, compare: 'OR' | 'AND') { + this.roleList = roleList + this.permissionList = permissionList + this.compare = compare + } +} diff --git a/ui/src/components/layout/top-bar/components/top-menu/index.vue b/ui/src/components/layout/top-bar/components/top-menu/index.vue index 03f2d1548..79ab11a8b 100644 --- a/ui/src/components/layout/top-bar/components/top-menu/index.vue +++ b/ui/src/components/layout/top-bar/components/top-menu/index.vue @@ -1,6 +1,6 @@ diff --git a/ui/src/directives/hasPermission.ts b/ui/src/directives/hasPermission.ts new file mode 100644 index 000000000..bdd2344a9 --- /dev/null +++ b/ui/src/directives/hasPermission.ts @@ -0,0 +1,27 @@ +import type { App } from 'vue' +import { hasPermission } from '@/common/permission' + +const display = async (el: any, binding: any) => { + const has = hasPermission( + binding.value.permission ? binding.value.permission : binding.value, + binding.value.compare ? binding.value.compare : 'OR' + ) + if (!has) { + el.style.display = 'none' + } else { + delete el.style.display + } +} + +export default { + install: (app: App) => { + app.directive('hasPermission', { + async created(el: any, binding: any) { + display(el, binding) + }, + async beforeUpdate(el: any, binding: any) { + display(el, binding) + } + }) + } +} diff --git a/ui/src/directives/index.ts b/ui/src/directives/index.ts new file mode 100644 index 000000000..de8ee8018 --- /dev/null +++ b/ui/src/directives/index.ts @@ -0,0 +1,14 @@ +import type { App } from 'vue' + +const directives = import.meta.glob('./*.ts', { eager: true }) +const install = (app: App) => { + Object.keys(directives) + .filter((key: string) => { + return !key.endsWith('index.ts') + }) + .forEach((key: string) => { + const directive: any = directives[key] + app.use(directive.default) + }) +} +export default { install } diff --git a/ui/src/main.ts b/ui/src/main.ts index 83cc8d61b..80bd1614d 100644 --- a/ui/src/main.ts +++ b/ui/src/main.ts @@ -6,12 +6,13 @@ import 'element-plus/dist/index.css' import { createApp } from 'vue' import { store } from '@/stores' import theme from '@/theme' - +import directives from '@/directives' import App from './App.vue' -import router from './router' +import router from '@/router' const app = createApp(App) app.use(store) +app.use(directives) const ElementPlusIconsVue: object = ElementPlusIcons // 将elementIcon放到全局 app.config.globalProperties.$antIcons = ElementPlusIconsVue diff --git a/ui/src/router/data.ts b/ui/src/router/data.ts index f419e76f2..8c7733970 100644 --- a/ui/src/router/data.ts +++ b/ui/src/router/data.ts @@ -1,5 +1,5 @@ import type { RouteRecordRaw } from 'vue-router' - +import { Role } from '@/common/permission/type' export const routes: Array = [ { path: '/', @@ -15,19 +15,19 @@ export const routes: Array = [ { path: '/app', name: 'app', - meta: { icon: 'app', title: '应用' }, + meta: { icon: 'app', title: '应用', permission: 'APPLICATION:READ' }, component: () => import('@/views/app/index.vue') }, { path: '/dataset', name: 'dataset', - meta: { icon: 'dataset', title: '数据集' }, + meta: { icon: 'dataset', title: '数据集', permission: 'DATASET:READ' }, component: () => import('@/views/dataset/index.vue') }, { path: '/setting', name: 'setting', - meta: { icon: 'setting', title: '数据设置' }, + meta: { icon: 'setting', title: '数据设置', permission: 'SETTING:READ' }, component: () => import('@/views/setting/index.vue') } ] diff --git a/ui/src/router/index.ts b/ui/src/router/index.ts index a9b9520f5..920752e7a 100644 --- a/ui/src/router/index.ts +++ b/ui/src/router/index.ts @@ -1,3 +1,4 @@ +import { hasPermission } from '@/common/permission/index' import { createRouter, createWebHistory, @@ -37,11 +38,16 @@ router.beforeEach( return } if (!userStore.userInfo) { - userStore.profile() + await userStore.profile() } } - - next() + // 判断是否有菜单权限 + if (to.meta.permission ? hasPermission(to.meta.permission as any, 'OR') : true) { + next() + } else { + // 如果没有权限则直接取404页面 + next('404') + } } ) diff --git a/ui/src/stores/user.ts b/ui/src/stores/user.ts index 64c4f3b33..d0725b4c7 100644 --- a/ui/src/stores/user.ts +++ b/ui/src/stores/user.ts @@ -2,6 +2,7 @@ import { defineStore } from 'pinia' import type { User } from '@/api/user/type' import UserApi from '@/api/user' import { ref } from 'vue' + export const useUserStore = defineStore('user', () => { const userInfo = ref() // 用户认证token @@ -14,10 +15,25 @@ export const useUserStore = defineStore('user', () => { return localStorage.getItem('token') } + const getPermissions = () => { + if (userInfo.value) { + return userInfo.value.permissions + } else { + return [] + } + } + + const getRole = () => { + if (userInfo.value) { + return userInfo.value.role + } else { + return '' + } + } + const profile = () => { return UserApi.profile().then((ok) => { userInfo.value = ok.data - return ok.data }) } @@ -35,6 +51,5 @@ export const useUserStore = defineStore('user', () => { return true }) } - - return { token, getToken, userInfo, profile, login, logout } + return { token, getToken, userInfo, profile, login, logout, getPermissions, getRole } }) diff --git a/ui/src/views/first/index.vue b/ui/src/views/first/index.vue index 575c81178..11aa9c087 100644 --- a/ui/src/views/first/index.vue +++ b/ui/src/views/first/index.vue @@ -1,8 +1,16 @@ \ No newline at end of file