From 11331a60e0ed945b5be9b889df4ce62d4816d1f8 Mon Sep 17 00:00:00 2001 From: CaptainB Date: Tue, 22 Apr 2025 17:34:19 +0800 Subject: [PATCH] feat: enhance Tool API by adding module_id and tool_type parameters, and refactor query handling --- apps/tools/api/tool.py | 22 ++++++++++++------ apps/tools/migrations/0001_initial.py | 3 +++ apps/tools/models/tool.py | 7 ++++++ apps/tools/serializers/tool.py | 33 ++++++++++++++++++++++++--- apps/tools/views/tool.py | 20 ++++++++-------- 5 files changed, 64 insertions(+), 21 deletions(-) diff --git a/apps/tools/api/tool.py b/apps/tools/api/tool.py index 827476fe6..a070d3467 100644 --- a/apps/tools/api/tool.py +++ b/apps/tools/api/tool.py @@ -163,13 +163,6 @@ class ToolPageAPI(ToolReadAPI): location='path', required=True, ), - OpenApiParameter( - name="tool_id", - description="工具id", - type=OpenApiTypes.STR, - location='path', - required=True, - ), OpenApiParameter( name="current_page", description="当前页码", @@ -184,6 +177,21 @@ class ToolPageAPI(ToolReadAPI): location='path', required=True, ), + OpenApiParameter( + name="module_id", + description="模块id", + type=OpenApiTypes.STR, + location='query', + required=True, + ), + OpenApiParameter( + name="tool_type", + description="工具类型", + type=OpenApiTypes.STR, + enum=["CUSTOM", "INTERNAL"], + location='query', + required=True, + ), OpenApiParameter( name="name", description="工具名称", diff --git a/apps/tools/migrations/0001_initial.py b/apps/tools/migrations/0001_initial.py index c8e7ac471..12098a864 100644 --- a/apps/tools/migrations/0001_initial.py +++ b/apps/tools/migrations/0001_initial.py @@ -62,6 +62,9 @@ class Migration(migrations.Migration): ('scope', models.CharField(choices=[('SHARED', '共享'), ('WORKSPACE', '工作空间可用')], default='WORKSPACE', max_length=20, verbose_name='可用范围')), + ('tool_type', + models.CharField(choices=[('INTERNAL', '内置'), ('CUSTOM', '自定义')], default='CUSTOM', + max_length=20, verbose_name='函数类型', db_index=True)), ('template_id', models.UUIDField(default=None, null=True, verbose_name='模版id')), ('workspace_id', models.CharField(default='default', max_length=64, verbose_name='工作空间id', db_index=True)), ('init_params', models.CharField(max_length=102400, null=True, verbose_name='初始化参数')), diff --git a/apps/tools/models/tool.py b/apps/tools/models/tool.py index 55bcc18a8..cb4873dd9 100644 --- a/apps/tools/models/tool.py +++ b/apps/tools/models/tool.py @@ -10,6 +10,11 @@ class ToolScope(models.TextChoices): WORKSPACE = "WORKSPACE", "工作空间可用" +class ToolType(models.TextChoices): + INTERNAL = "INTERNAL", '内置' + CUSTOM = "CUSTOM", "自定义" + + class Tool(models.Model): id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户id") @@ -22,6 +27,8 @@ class Tool(models.Model): is_active = models.BooleanField(default=True) scope = models.CharField(max_length=20, verbose_name='可用范围', choices=ToolScope.choices, default=ToolScope.WORKSPACE) + tool_type = models.CharField(max_length=20, verbose_name='函数类型', choices=ToolType.choices, + default=ToolType.CUSTOM, db_index=True) template_id = models.UUIDField(max_length=128, verbose_name="模版id", null=True, default=None) module = models.ForeignKey(ToolModule, on_delete=models.CASCADE, verbose_name="模块id", default='root') workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True) diff --git a/apps/tools/serializers/tool.py b/apps/tools/serializers/tool.py index 634af5c4a..3781b73ea 100644 --- a/apps/tools/serializers/tool.py +++ b/apps/tools/serializers/tool.py @@ -6,7 +6,7 @@ import re import uuid_utils.compat as uuid from django.core import validators from django.db import transaction -from django.db.models import QuerySet +from django.db.models import QuerySet, Q from django.http import HttpResponse from django.utils.translation import gettext_lazy as _ from rest_framework import serializers, status @@ -57,7 +57,7 @@ class ToolModelSerializer(serializers.ModelSerializer): class Meta: model = Tool fields = ['id', 'name', 'icon', 'desc', 'code', 'input_field_list', 'init_field_list', 'init_params', - 'scope', 'is_active', 'user_id', 'template_id', 'workspace_id', 'module_id', + 'scope', 'is_active', 'user_id', 'template_id', 'workspace_id', 'module_id', 'tool_type', 'create_time', 'update_time'] @@ -291,8 +291,35 @@ class ToolTreeSerializer(serializers.Serializer): root = ToolModule.objects.filter(id=module_id).first() if not root: raise serializers.ValidationError(_('Module not found')) - # 使用MPTT的get_family()方法获取所有相关节点 + # 使用MPTT的get_descendants()方法获取所有相关节点 all_modules = root.get_descendants(include_self=True) tools = QuerySet(Tool).filter(workspace_id=self.data.get('workspace_id'), module_id__in=all_modules) return ToolModelSerializer(tools, many=True).data + + class Query(serializers.Serializer): + workspace_id = serializers.CharField(required=True, label=_('workspace id')) + module_id = serializers.CharField(required=True, label=_('module id')) + name = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('tool name')) + tool_type = serializers.CharField(required=True, label=_('tool type')) + + def page(self, current_page: int, page_size: int): + self.is_valid(raise_exception=True) + + module_id = self.data.get('module_id', 'root') + root = ToolModule.objects.filter(id=module_id).first() + if not root: + raise serializers.ValidationError(_('Module not found')) + # 使用MPTT的get_descendants()方法获取所有相关节点 + all_modules = root.get_descendants(include_self=True) + + if self.data.get('name'): + tools = QuerySet(Tool).filter(Q(workspace_id=self.data.get('workspace_id')) & + Q(module_id__in=all_modules) & + Q(tool_type=self.data.get('tool_type')) & + Q(name__contains=self.data.get('name'))) + else: + tools = QuerySet(Tool).filter(Q(workspace_id=self.data.get('workspace_id')) & + Q(module_id__in=all_modules) & + Q(tool_type=self.data.get('tool_type'))) + return ToolModelSerializer(tools, many=True).data diff --git a/apps/tools/views/tool.py b/apps/tools/views/tool.py index 454fc3315..e717fc136 100644 --- a/apps/tools/views/tool.py +++ b/apps/tools/views/tool.py @@ -120,17 +120,15 @@ class ToolView(APIView): tags=[_('Tool')] ) @has_permissions(PermissionConstants.TOOL_READ.get_workspace_permission()) - def get(self, request: Request, current_page: int, page_size: int): - return result.success( - ToolSerializer.Query( - data={ - 'name': request.query_params.get('name'), - 'desc': request.query_params.get('desc'), - 'function_type': request.query_params.get('function_type'), - 'user_id': request.user.id, - 'select_user_id': request.query_params.get('select_user_id') - } - ).page(current_page, page_size)) + def get(self, request: Request, workspace_id: str, current_page: int, page_size: int): + return result.success(ToolTreeSerializer.Query( + data={ + 'workspace_id': workspace_id, + 'module_id': request.query_params.get('module_id'), + 'name': request.query_params.get('name'), + 'tool_type': request.query_params.get('tool_type'), + } + ).page(current_page, page_size)) class Import(APIView): authentication_classes = [TokenAuth]