diff --git a/xrft/tests/test_xrft.py b/xrft/tests/test_xrft.py index 482e3432..dc661879 100644 --- a/xrft/tests/test_xrft.py +++ b/xrft/tests/test_xrft.py @@ -1306,3 +1306,23 @@ def test_reversed_coordinates(): xrt.assert_allclose( xrft.dft(s, dim="x", true_phase=True), xrft.dft(s2, dim="x", true_phase=True) ) + + +def test_nondim_coords(): + """Error should be raised if there are non-dimensional coordinates attached to the dimension(s) over which the FFT is being taken""" + N = 16 + da = xr.DataArray( + np.random.rand(2, N, N), + dims=["time", "x", "y"], + coords={ + "time": np.array(["2019-04-18", "2019-04-19"], dtype="datetime64"), + "x": range(N), + "y": range(N), + "x_nondim": ("x", np.arange(N)), + }, + ) + + with pytest.raises(ValueError): + xrft.power_spectrum(da) + + xrft.power_spectrum(da, dim=["time", "y"]) diff --git a/xrft/xrft.py b/xrft/xrft.py index ea4a9e86..1c38b6e0 100644 --- a/xrft/xrft.py +++ b/xrft/xrft.py @@ -383,6 +383,17 @@ def fft( N = [da.shape[n] for n in axis_num] + # raise error if there are multiple coordinates attached to the dimension(s) over which the FFT is taken + for d in dim: + bad_coords = [ + cname for cname in da.coords if cname != d and d in da[cname].dims + ] + if bad_coords: + raise ValueError( + f"The input array contains coordinate variable(s) ({bad_coords}) whose dims include the transform dimension(s) `{d}`. " + f"Please drop these coordinates (`.drop({bad_coords}`) before invoking xrft." + ) + # verify even spacing of input coordinates delta_x = [] lag_x = []