Patch summary (commentary placed ahead of the first "diff --git" header).

  * xarray/core/dask_array_ops.py: push() drops the cumreduction-based
    implementation and instead calls array.map_overlap(_push, ...) with an
    overlap depth of (n, 0) along `axis`, after clamping n to the axis length.
  * xarray/tests/test_missing.py: adds test_ffill_bfill_nans_on_chunk_edge,
    comparing dask-backed ffill/bfill against the computed (in-memory) result.

Review notes (to be confirmed, not changed by this patch):
  * The removed code guarded with `n is not None` and called
    push(..., None, axis); the added `min(n, array.shape[axis])` has no such
    guard. Presumably callers always pass an integer limit -- TODO confirm.
  * The new test chunks along "time" (size 2) but fills along "x", whose chunk
    size is given as -1. NOTE(review): verify that "x" is the intended fill
    dimension for a test named "..._on_chunk_edge".

diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py
index 98ff9002856..0feb4d41345 100644
--- a/xarray/core/dask_array_ops.py
+++ b/xarray/core/dask_array_ops.py
@@ -59,37 +59,8 @@ def push(array, n, axis):
     """
     Dask-aware bottleneck.push
     """
-    import dask.array as da
-    import numpy as np
-
     from xarray.core.duck_array_ops import _push
 
-    def _fill_with_last_one(a, b):
-        # cumreduction apply the push func over all the blocks first so, the only missing part is filling
-        # the missing values using the last data of the previous chunk
-        return np.where(~np.isnan(b), b, a)
+    n = min(n, array.shape[axis])
 
-    if n is not None and 0 < n < array.shape[axis] - 1:
-        arange = da.broadcast_to(
-            da.arange(
-                array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
-            ).reshape(
-                tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
-            ),
-            array.shape,
-            array.chunks,
-        )
-        valid_arange = da.where(da.notnull(array), arange, np.nan)
-        valid_limits = (arange - push(valid_arange, None, axis)) <= n
-        # omit the forward fill that violate the limit
-        return da.where(valid_limits, push(array, None, axis), np.nan)
-
-    # The method parameter makes that the tests for python 3.7 fails.
-    return da.reductions.cumreduction(
-        func=_push,
-        binop=_fill_with_last_one,
-        ident=np.nan,
-        x=array,
-        axis=axis,
-        dtype=array.dtype,
-    )
+    return array.map_overlap(_push, depth={axis: (n, 0)}, n=n, axis=axis)
diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py
index 58d8a9dcf5d..720329a9af2 100644
--- a/xarray/tests/test_missing.py
+++ b/xarray/tests/test_missing.py
@@ -498,6 +498,24 @@ def test_ffill_bfill_dask(method):
     assert_equal(actual, expected)
 
 
+@requires_bottleneck
+@requires_dask
+@pytest.mark.parametrize("method", ["ffill", "bfill"])
+def test_ffill_bfill_nans_on_chunk_edge(method):
+    arr = np.array(
+        [[1, 2], [np.nan, 4], [5, 6], [7, 8], [np.nan, 10], [11, 12]], dtype=np.float64
+    )
+    da = xr.DataArray(arr, dims=("time", "x")).chunk({"time": 2, "x": -1})
+
+    dask_method = getattr(da, method)
+    numpy_method = getattr(da.compute(), method)
+
+    with raise_if_dask_computes():
+        actual = dask_method("x", limit=41)
+    expected = numpy_method("x", limit=41)
+    assert_equal(actual, expected)
+
+
 @requires_bottleneck
 def test_ffill_bfill_nonans():
     vals = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64)