feat: enhance Tool API by adding module_id and tool_type parameters, and refactor query handling
Some checks are pending
sync2gitee / repo-sync (push) Waiting to run

This commit is contained in:
CaptainB 2025-04-22 17:34:19 +08:00
parent bbd7079166
commit 11331a60e0
5 changed files with 64 additions and 21 deletions

View File

@ -163,13 +163,6 @@ class ToolPageAPI(ToolReadAPI):
location='path',
required=True,
),
OpenApiParameter(
name="tool_id",
description="工具id",
type=OpenApiTypes.STR,
location='path',
required=True,
),
OpenApiParameter(
name="current_page",
description="当前页码",
@ -184,6 +177,21 @@ class ToolPageAPI(ToolReadAPI):
location='path',
required=True,
),
OpenApiParameter(
name="module_id",
description="模块id",
type=OpenApiTypes.STR,
location='query',
required=True,
),
OpenApiParameter(
name="tool_type",
description="工具类型",
type=OpenApiTypes.STR,
enum=["CUSTOM", "INTERNAL"],
location='query',
required=True,
),
OpenApiParameter(
name="name",
description="工具名称",

View File

@ -62,6 +62,9 @@ class Migration(migrations.Migration):
('scope',
models.CharField(choices=[('SHARED', '共享'), ('WORKSPACE', '工作空间可用')], default='WORKSPACE',
max_length=20, verbose_name='可用范围')),
('tool_type',
models.CharField(choices=[('INTERNAL', '内置'), ('CUSTOM', '自定义')], default='CUSTOM',
max_length=20, verbose_name='函数类型', db_index=True)),
('template_id', models.UUIDField(default=None, null=True, verbose_name='模版id')),
('workspace_id', models.CharField(default='default', max_length=64, verbose_name='工作空间id', db_index=True)),
('init_params', models.CharField(max_length=102400, null=True, verbose_name='初始化参数')),

View File

@ -10,6 +10,11 @@ class ToolScope(models.TextChoices):
WORKSPACE = "WORKSPACE", "工作空间可用"
class ToolType(models.TextChoices):
INTERNAL = "INTERNAL", '内置'
CUSTOM = "CUSTOM", "自定义"
class Tool(models.Model):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户id")
@ -22,6 +27,8 @@ class Tool(models.Model):
is_active = models.BooleanField(default=True)
scope = models.CharField(max_length=20, verbose_name='可用范围', choices=ToolScope.choices,
default=ToolScope.WORKSPACE)
tool_type = models.CharField(max_length=20, verbose_name='函数类型', choices=ToolType.choices,
default=ToolType.CUSTOM, db_index=True)
template_id = models.UUIDField(max_length=128, verbose_name="模版id", null=True, default=None)
module = models.ForeignKey(ToolModule, on_delete=models.CASCADE, verbose_name="模块id", default='root')
workspace_id = models.CharField(max_length=64, verbose_name="工作空间id", default="default", db_index=True)

View File

@ -6,7 +6,7 @@ import re
import uuid_utils.compat as uuid
from django.core import validators
from django.db import transaction
from django.db.models import QuerySet
from django.db.models import QuerySet, Q
from django.http import HttpResponse
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers, status
@ -57,7 +57,7 @@ class ToolModelSerializer(serializers.ModelSerializer):
class Meta:
model = Tool
fields = ['id', 'name', 'icon', 'desc', 'code', 'input_field_list', 'init_field_list', 'init_params',
'scope', 'is_active', 'user_id', 'template_id', 'workspace_id', 'module_id',
'scope', 'is_active', 'user_id', 'template_id', 'workspace_id', 'module_id', 'tool_type',
'create_time', 'update_time']
@ -291,8 +291,35 @@ class ToolTreeSerializer(serializers.Serializer):
root = ToolModule.objects.filter(id=module_id).first()
if not root:
raise serializers.ValidationError(_('Module not found'))
# 使用MPTT的get_family()方法获取所有相关节点
# 使用MPTT的get_descendants()方法获取所有相关节点
all_modules = root.get_descendants(include_self=True)
tools = QuerySet(Tool).filter(workspace_id=self.data.get('workspace_id'), module_id__in=all_modules)
return ToolModelSerializer(tools, many=True).data
class Query(serializers.Serializer):
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
module_id = serializers.CharField(required=True, label=_('module id'))
name = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('tool name'))
tool_type = serializers.CharField(required=True, label=_('tool type'))
def page(self, current_page: int, page_size: int):
self.is_valid(raise_exception=True)
module_id = self.data.get('module_id', 'root')
root = ToolModule.objects.filter(id=module_id).first()
if not root:
raise serializers.ValidationError(_('Module not found'))
# 使用MPTT的get_descendants()方法获取所有相关节点
all_modules = root.get_descendants(include_self=True)
if self.data.get('name'):
tools = QuerySet(Tool).filter(Q(workspace_id=self.data.get('workspace_id')) &
Q(module_id__in=all_modules) &
Q(tool_type=self.data.get('tool_type')) &
Q(name__contains=self.data.get('name')))
else:
tools = QuerySet(Tool).filter(Q(workspace_id=self.data.get('workspace_id')) &
Q(module_id__in=all_modules) &
Q(tool_type=self.data.get('tool_type')))
return ToolModelSerializer(tools, many=True).data

View File

@ -120,17 +120,15 @@ class ToolView(APIView):
tags=[_('Tool')]
)
@has_permissions(PermissionConstants.TOOL_READ.get_workspace_permission())
def get(self, request: Request, current_page: int, page_size: int):
return result.success(
ToolSerializer.Query(
data={
'name': request.query_params.get('name'),
'desc': request.query_params.get('desc'),
'function_type': request.query_params.get('function_type'),
'user_id': request.user.id,
'select_user_id': request.query_params.get('select_user_id')
}
).page(current_page, page_size))
def get(self, request: Request, workspace_id: str, current_page: int, page_size: int):
return result.success(ToolTreeSerializer.Query(
data={
'workspace_id': workspace_id,
'module_id': request.query_params.get('module_id'),
'name': request.query_params.get('name'),
'tool_type': request.query_params.get('tool_type'),
}
).page(current_page, page_size))
class Import(APIView):
authentication_classes = [TokenAuth]