Skip to content

Commit

Permalink
Merge pull request #190 from lanougue/not_numerical_coords
Browse files Browse the repository at this point in the history
Explicit error when coordinates are not numercial
  • Loading branch information
roxyboy committed Feb 8, 2023
2 parents 8baebcb + 74730fc commit f9ab8d2
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 8 deletions.
18 changes: 18 additions & 0 deletions xrft/tests/test_xrft.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,3 +1344,21 @@ def test_nondim_coords():
xrft.power_spectrum(da)

xrft.power_spectrum(da, dim=["time", "y"])


def test_non_numerical_or_datetime_coords():
"""Error should be raised if there are non-numerical or non-datetime coordinate"""
da = xr.DataArray(
np.random.rand(2, 5, 3),
dims=["time", "x", "y"],
coords={
"time": np.array(["2019-04-18", "2019-04-19"], dtype="datetime64"),
"x": range(5),
"y": ["a", "b", "c"],
},
)

with pytest.raises(ValueError):
xrft.power_spectrum(da)

xrft.power_spectrum(da, dim=["time", "x"])
45 changes: 37 additions & 8 deletions xrft/xrft.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import scipy.linalg as spl

from .detrend import detrend as _detrend

from pandas.api.types import is_numeric_dtype, is_datetime64_any_dtype

__all__ = [
"fft",
Expand Down Expand Up @@ -230,9 +230,9 @@ def _lag_coord(coord):
decoded_time = cftime.date2num(lag, ref_units, calendar)
return decoded_time
elif pd.api.types.is_datetime64_dtype(v0):
return lag.astype("timedelta64[s]").astype("f8").data
return lag.astype("timedelta64[s]").astype("f8")
else:
return lag.data
return lag


def dft(
Expand Down Expand Up @@ -330,7 +330,6 @@ def fft(
daft : `xarray.DataArray`
The output of the Fourier transformation, with appropriate dimensions.
"""

if dim is None:
dim = list(da.dims)
else:
Expand All @@ -352,6 +351,20 @@ def fft(
real_dim
] # real dim has to be moved or added at the end !

if not np.all(
[
(
is_numeric_dtype(da.coords[d])
or is_datetime64_any_dtype(da.coords[d])
or bool(getattr(da.coords[d][0].item(), "calendar", False))
)
for d in dim
]
): # checking if coodinates are numerical or datetime
raise ValueError(
"All transformed dimensions coordinates must be numerical or datetime."
)

if chunks_to_segments:
da = _stack_chunks(da, dim)

Expand Down Expand Up @@ -452,7 +465,7 @@ def fft(
dims=up_dim,
coords={up_dim: newcoords[up_dim]},
) # taking advantage of xarray broadcasting and ordered coordinates
daft[up_dim].attrs.update({"direct_lag": lag.obj})
daft[up_dim].attrs.update({"direct_lag": lag})

if true_amplitude:
daft = daft * np.prod(delta_x)
Expand Down Expand Up @@ -520,7 +533,6 @@ def ifft(
da : `xarray.DataArray`
The output of the Inverse Fourier transformation, with appropriate dimensions.
"""

if dim is None:
dim = list(daft.dims)
else:
Expand All @@ -540,6 +552,21 @@ def ifft(
dim = [d for d in dim if d != real_dim] + [
real_dim
] # real dim has to be moved or added at the end !

if not np.all(
[
(
is_numeric_dtype(daft.coords[d])
or is_datetime64_any_dtype(daft.coords[d])
or bool(getattr(daft.coords[d][0].item(), "calendar", False))
)
for d in dim
]
): # checking if coodinates are numerical or datetime
raise ValueError(
"All transformed dimensions coordinates must be numerical or datetime."
)

if lag is None:
lag = [daft[d].attrs.get("direct_lag", 0.0) for d in dim]
msg = "Default ifft's behaviour (lag=None) changed! Default value of lag was zero (centered output coordinates) and is now set to transformed coordinate's attribute: 'direct_lag'."
Expand Down Expand Up @@ -898,8 +925,10 @@ def cross_phase(da1, da2, dim=None, true_phase=True, **kwargs):
kwargs : dict : see xrft.fft for argument list
"""

cp = xr.ufuncs.angle(
cross_spectrum(da1, da2, dim=dim, true_phase=true_phase, **kwargs)
cp = xr.apply_ufunc(
np.angle,
cross_spectrum(da1, da2, dim=dim, true_phase=true_phase, **kwargs),
dask="allowed",
)

if da1.name and da2.name:
Expand Down

0 comments on commit f9ab8d2

Please sign in to comment.