Skip to content

Commit

Permalink
Refactor detrend (#118)
Browse files Browse the repository at this point in the history
* first pass at refactoring detrend

* remove comments

* fix chunking

* add 3D case

* add error checking; implement 2D detrend

* add more 2D test cases

* black

* refactor xrft and update tests

* fix typo in detrend test

* doc updates
  • Loading branch information
rabernat committed Nov 16, 2020
1 parent db176c1 commit 3859d63
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 412 deletions.
15 changes: 15 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,23 @@ This page provides an auto-generated summary of xrft's API. For more details
and examples, refer to the relevant chapters in the main part of the
documentation.


.. note::

    None of xrft's functions will work correctly in the presence of NaNs or
    missing data. It's the user's responsibility to ensure data are free of
    NaNs or that NaNs have been filled somehow.

xrft
====

.. automodule:: xrft.xrft
:members:

detrend
=======

You may also wish to use xrft's detrend function on its own.

.. automodule:: xrft.detrend
:members:
1 change: 1 addition & 0 deletions xrft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
del get_versions

from .xrft import * # noqa
from .detrend import detrend
96 changes: 96 additions & 0 deletions xrft/detrend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""
Functions for detrending xarray data.
"""

import numpy as np
import xarray as xr
import scipy.signal as sps
import scipy.linalg as spl


def detrend(da, dim, detrend_type="constant"):
    """Detrend a DataArray.

    Parameters
    ----------
    da : xarray.DataArray
        The data to detrend.
    dim : str or list
        Dimensions along which to apply detrend.
        Can be either one dimension (given as a string or as a one-element
        list) or a list with two dimensions.
        Higher-dimensional detrending is not supported.
        If dask data are passed and ``detrend_type`` is ``'linear'``, the
        data must consist of a single chunk along each dimension in ``dim``.
    detrend_type : {'constant', 'linear', None}
        If ``None``, ``da`` is returned unchanged.
        If ``constant``, the mean over ``dim`` is subtracted from the data.
        If ``linear``, a linear least-squares fit will be estimated and removed
        from the data.

    Returns
    -------
    da : xarray.DataArray
        The detrended data, with dimensions in the same order as the input.

    Raises
    ------
    NotImplementedError
        If ``detrend_type`` is not one of the valid options, or if linear
        detrending is requested along a number of dimensions other than
        one or two.
    ValueError
        If linear detrending is requested on chunked data that is split
        into more than one chunk along any dimension in ``dim``.

    Notes
    -----
    This function will act lazily in the presence of dask arrays on the
    input.
    """

    if detrend_type not in ["constant", "linear", None]:
        raise NotImplementedError(
            "%s is not a valid detrending option. Valid "
            "options are: 'constant','linear', or None." % detrend_type
        )

    if detrend_type is None:
        return da

    if detrend_type == "constant":
        return da - da.mean(dim=dim)

    # detrend_type == "linear" from here on.
    # A bare string must be wrapped before it is iterated or measured;
    # otherwise a name such as "time" would be treated as the four
    # dimensions "t", "i", "m", "e".
    dims = [dim] if isinstance(dim, str) else list(dim)

    axis_num = [da.get_axis_num(d) for d in dims]
    chunks = getattr(da.data, "chunks", None)
    if chunks:
        # The fit needs every sample along the detrended axes at once, so
        # each of those axes has to live in exactly one chunk.
        if any(len(chunks[a]) != 1 for a in axis_num):
            raise ValueError("Contiguous chunks required for detrending.")

    if len(dims) == 1:
        return xr.apply_ufunc(
            sps.detrend,
            da,
            axis_num[0],
            output_dtypes=[da.dtype],
            dask="parallelized",
        )

    if len(dims) == 2:
        dt = xr.apply_ufunc(
            _detrend_2d_ufunc,
            da,
            input_core_dims=[dims],
            output_core_dims=[dims],
            output_dtypes=[da.dtype],
            vectorize=True,
            dask="parallelized",
        )
        # apply_ufunc moves core dimensions to the end of the result;
        # restore the caller's dimension order so that the 1D and 2D code
        # paths behave consistently.
        return dt.transpose(*da.dims)

    raise NotImplementedError(  # pragma: no cover
        "Only 1D and 2D detrending are implemented so far."
    )


def _detrend_2d_ufunc(arr):
assert arr.ndim == 2
N = arr.shape

col0 = np.ones(N[0] * N[1])
col1 = np.repeat(np.arange(N[0]), N[1]) + 1
col2 = np.tile(np.arange(N[1]), N[0]) + 1
G = np.stack([col0, col1, col2]).transpose()

d_obs = np.reshape(arr, (N[0] * N[1], 1))
m_est = np.dot(np.dot(spl.inv(np.dot(G.T, G)), G.T), d_obs)
d_est = np.dot(G, m_est)
linear_fit = np.reshape(d_est, N)
return arr - linear_fit
106 changes: 106 additions & 0 deletions xrft/tests/test_detrend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import numpy as np
import xarray as xr
import scipy.signal as sps

import pytest
import numpy.testing as npt
import xarray.testing as xrt

import xrft
from xrft.detrend import detrend


def detrended_noise(N, amplitude=1.0):
    """Return ``N`` uniform random samples scaled by ``amplitude``, linearly detrended."""
    samples = np.random.rand(N)
    scaled = amplitude * samples
    return sps.detrend(scaled)


def noise(dims, shape):
    """Build a random DataArray with the linear trend removed along every axis.

    ``dims`` and ``shape`` must have the same length; each dimension gets an
    integer coordinate ``0 .. n-1``.
    """
    assert len(dims) == len(shape)
    index_coords = {
        name: (name, np.arange(size)) for name, size in zip(dims, shape)
    }
    values = np.random.rand(*shape)
    # Detrend successively along each axis so no axis retains a linear trend.
    for axis, _ in enumerate(shape):
        values = sps.detrend(values, axis)
    return xr.DataArray(values, dims=dims, coords=index_coords)


@pytest.mark.parametrize(
    "array_dims, array_shape, detrend_dim, chunks, linear_error",
    (
        (["x"], [16], "x", None, None),
        (["y", "x"], [32, 16], "x", None, None),
        (["y", "x"], [32, 16], "x", {"y": 4}, None),
        (["y", "x"], [32, 16], "y", None, None),
        (["y", "x"], [32, 16], "y", {"x": 4}, None),
        (["time", "y", "x"], [4, 32, 16], "x", None, None),
        (["time", "y", "x"], [4, 32, 16], "x", {"y": 4}, None),
        (["time", "y", "x"], [4, 32, 16], "x", {"time": 1, "y": 4}, None),
        # error cases for linear detrending
        (["x"], [16], "x", {"x": 1}, ValueError),
        (["y", "x"], [32, 16], "x", {"x": 4}, ValueError),
    ),
)
@pytest.mark.parametrize("detrend_type", [None, "constant", "linear"])
@pytest.mark.parametrize("trend_amplitude", [0.01, 100])
def test_detrend_1D(
    array_dims,
    array_shape,
    detrend_dim,
    chunks,
    detrend_type,
    trend_amplitude,
    linear_error,
):
    """Check 1D detrending of trend-free noise plus a known linear ramp."""
    clean = noise(array_dims, array_shape)
    trended = clean + trend_amplitude * clean[detrend_dim]
    if chunks:
        trended = trended.chunk(chunks)

    # Linear detrending across a chunked dimension is expected to fail.
    expect_failure = detrend_type == "linear" and linear_error
    if expect_failure:
        with pytest.raises(linear_error):
            detrend(trended, detrend_dim, detrend_type=detrend_type)
        return

    result = detrend(trended, detrend_dim, detrend_type=detrend_type)
    assert result.chunks == trended.chunks

    if detrend_type is None:
        # No detrending requested: data must pass through untouched.
        xrt.assert_equal(result, trended)
        return
    if detrend_type == "constant":
        expected = trended - trended.mean(dim=detrend_dim)
        xrt.assert_allclose(result, expected)
        return
    if detrend_type == "linear":
        # Removing the ramp should recover the original trend-free noise.
        xrt.assert_allclose(result, clean)


# always detrend on x y dims
@pytest.mark.parametrize(
    "array_dims, array_shape, chunks",
    (
        (["y", "x"], [32, 16], None),
        (["z", "y", "x"], [2, 32, 16], None),
        (["z", "y", "x"], [2, 32, 16], {"z": 1}),
    ),
)
@pytest.mark.parametrize("detrend_type", [None, "constant", "linear"])
@pytest.mark.parametrize(
    "trend_amplitude", [{"x": 0.1, "y": 0.1}, {"x": 10.0, "y": 0.01}]
)
def test_detrend_2D(array_dims, array_shape, chunks, detrend_type, trend_amplitude):
    """Check 2D (y, x) detrending of trend-free noise plus a known plane."""
    clean = noise(array_dims, array_shape)

    # Superimpose a planar trend: first the x ramp, then the y ramp.
    trended = clean
    for name in ("x", "y"):
        trended = trended + trend_amplitude[name] * clean[name]
    if chunks:
        trended = trended.chunk(chunks)

    detrend_dim = ["y", "x"]
    result = detrend(trended, detrend_dim, detrend_type=detrend_type)
    assert result.chunks == trended.chunks

    if detrend_type is None:
        # No detrending requested: data must pass through untouched.
        xrt.assert_equal(result, trended)
        return
    if detrend_type == "constant":
        expected = trended - trended.mean(dim=detrend_dim)
        xrt.assert_allclose(result, expected)
        return
    if detrend_type == "linear":
        # Removing the plane should recover the original trend-free noise.
        xrt.assert_allclose(result, clean)
Loading

0 comments on commit 3859d63

Please sign in to comment.