diff --git a/apps/models_provider/serializers/model_serializer.py b/apps/models_provider/serializers/model_serializer.py index 2a5ee3088..6270420e7 100644 --- a/apps/models_provider/serializers/model_serializer.py +++ b/apps/models_provider/serializers/model_serializer.py @@ -8,7 +8,7 @@ import uuid_utils.compat as uuid from django.db.models import QuerySet from django.utils.translation import gettext_lazy as _ from rest_framework import serializers - +from django.db.models.query_utils import Q from common.config.embedding_config import ModelManage from common.database_model_manage.database_model_manage import DatabaseModelManage from common.exception.app_exception import AppApiException @@ -413,16 +413,26 @@ class ModelSerializer(serializers.Serializer): def get_authorized_tool(tool_query_set, workspace_id, model_workspace_authorization): + # 对所有工作空间拉黑的工具 + non_auths = QuerySet(model_workspace_authorization).filter( + Q(workspace_id='None') & Q(authentication_type='WHITE_LIST') + ).values_list('model_id', flat=True) + # 授权给所有工作空间的工具 + all_auths = QuerySet(model_workspace_authorization).filter( + Q(workspace_id='None') & Q(authentication_type='BLACK_LIST') + ).values_list('model_id', flat=True) + # 查询白名单授权的工具 white_authorized_tool_ids = QuerySet(model_workspace_authorization).filter( workspace_id=workspace_id, authentication_type='WHITE_LIST' ).values_list('model_id', flat=True) + # 查询黑名单授权的工具 black_authorized_tool_ids = QuerySet(model_workspace_authorization).filter( workspace_id=workspace_id, authentication_type='BLACK_LIST' ).values_list('model_id', flat=True) tool_query_set = tool_query_set.filter( - id__in=white_authorized_tool_ids + id__in=list(white_authorized_tool_ids) + list(all_auths) ).exclude( - id__in=black_authorized_tool_ids + id__in=list(black_authorized_tool_ids) + list(non_auths) ) return tool_query_set