diff --git a/apps/models_provider/serializers/model_serializer.py b/apps/models_provider/serializers/model_serializer.py index 13b1f20e0..d970e1772 100644 --- a/apps/models_provider/serializers/model_serializer.py +++ b/apps/models_provider/serializers/model_serializer.py @@ -325,16 +325,18 @@ class ModelSerializer(serializers.Serializer): create_user = serializers.CharField(required=False, label=_('create user')) workspace_id = serializers.CharField(required=False, label=_('workspace id')) - def list(self, with_valid): + def list(self, workspace_id, with_valid): if with_valid: self.is_valid(raise_exception=True) - query_params = self._build_query_params() + query_params = self._build_query_params(workspace_id) return self._fetch_models(query_params) - def _build_query_params(self): + def _build_query_params(self, workspace_id): query_params = {} - for field in ['name', 'model_type', 'model_name', 'provider', 'create_user', 'workspace_id']: + if workspace_id: + query_params['workspace_id'] = workspace_id + for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']: value = self.data.get(field) if value is not None: if field == 'name': diff --git a/apps/models_provider/views/model.py b/apps/models_provider/views/model.py index b3ea3fd9d..f3df85cac 100644 --- a/apps/models_provider/views/model.py +++ b/apps/models_provider/views/model.py @@ -56,11 +56,11 @@ class Model(APIView): responses=ModelListResponse.get_response(), tags=[_('Model')]) @has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission()) - def get(self, request: Request): + def get(self, request: Request, workspace_id: str): return result.success( ModelSerializer.Query( - data={**query_params_to_single_dict(request.query_params)}).list( - with_valid=True)) + data={**query_params_to_single_dict(request.query_params)}).list(workspace_id=workspace_id, + with_valid=True)) class Operate(APIView): authentication_classes = [TokenAuth]