diff --git a/jax_cfd/spectral/utils.py b/jax_cfd/spectral/utils.py index b40ee15..97a6b49 100644 --- a/jax_cfd/spectral/utils.py +++ b/jax_cfd/spectral/utils.py @@ -127,10 +127,10 @@ def circular_filter_2d(grid: grids.Grid) -> spectral_types.Array: def brick_wall_filter_2d(grid: grids.Grid): """Implements the 2/3 rule.""" - n, _ = grid.shape - filter_ = jnp.zeros((n, n // 2 + 1)) - filter_ = filter_.at[:int(2 / 3 * n) // 2, :int(2 / 3 * (n // 2 + 1))].set(1) - filter_ = filter_.at[-int(2 / 3 * n) // 2:, :int(2 / 3 * (n // 2 + 1))].set(1) + n, m = grid.shape + filter_ = jnp.zeros((n, m // 2 + 1)) + filter_ = filter_.at[:int(2 / 3 * n) // 2, :int(2 / 3 * (m // 2 + 1))].set(1) + filter_ = filter_.at[-int(2 / 3 * n) // 2:, :int(2 / 3 * (m // 2 + 1))].set(1) return filter_