Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace push implementation with map_overlap for Dask #9712

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 2 additions & 31 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,37 +59,8 @@ def push(array, n, axis):
"""
Dask-aware bottleneck.push
"""
import dask.array as da
import numpy as np

from xarray.core.duck_array_ops import _push

def _fill_with_last_one(a, b):
# cumreduction applies the push func to every block first, so the only remaining step is to fill
# the values that are still missing using the last data of the previous chunk
return np.where(~np.isnan(b), b, a)
n = min(n, array.shape[axis])

if n is not None and 0 < n < array.shape[axis] - 1:
arange = da.broadcast_to(
da.arange(
array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
).reshape(
tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
),
array.shape,
array.chunks,
)
valid_arange = da.where(da.notnull(array), arange, np.nan)
valid_limits = (arange - push(valid_arange, None, axis)) <= n
# omit the forward-filled values that violate the limit
return da.where(valid_limits, push(array, None, axis), np.nan)

# The method parameter makes the tests fail on Python 3.7.
return da.reductions.cumreduction(
func=_push,
binop=_fill_with_last_one,
ident=np.nan,
x=array,
axis=axis,
dtype=array.dtype,
)
return array.map_overlap(_push, depth={axis: (n, 0)}, n=n, axis=axis)
18 changes: 18 additions & 0 deletions xarray/tests/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,24 @@ def test_ffill_bfill_dask(method):
assert_equal(actual, expected)


@requires_bottleneck
@requires_dask
@pytest.mark.parametrize("method", ["ffill", "bfill"])
def test_ffill_bfill_nans_on_chunk_edge(method):
    """Compare dask-backed ffill/bfill with the eager result when NaNs sit on chunk edges.

    The array has six rows along ``time`` split into chunks of two, i.e. rows
    (0, 1), (2, 3) and (4, 5). The NaNs are at rows 1 and 4 of column 0:

    * row 4 is the first row of its chunk, so a forward fill along ``time``
      has to take its value from row 3 in the previous chunk;
    * row 1 is the last row of its chunk, so a backward fill along ``time``
      has to take its value from row 2 in the next chunk.

    Filling along ``x`` (a single chunk) is checked as well, so the chunked
    and the unchunked dimension are both covered.
    """
    arr = np.array(
        [[1, 2], [np.nan, 4], [5, 6], [7, 8], [np.nan, 10], [11, 12]], dtype=np.float64
    )
    da = xr.DataArray(arr, dims=("time", "x")).chunk({"time": 2, "x": -1})

    dask_method = getattr(da, method)
    numpy_method = getattr(da.compute(), method)

    # "time" is the chunked dimension whose boundaries carry the NaNs; "x" is
    # kept so the originally tested call is still covered.
    for dim in ("time", "x"):
        # limit=41 is larger than either axis (6 and 2), so the limit itself
        # never truncates a fill here.
        with raise_if_dask_computes():
            actual = dask_method(dim, limit=41)
        expected = numpy_method(dim, limit=41)
        assert_equal(actual, expected)


@requires_bottleneck
def test_ffill_bfill_nonans():
vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64)
Expand Down
Loading