Skip to content

Commit

Permalink
Generalize cumulative reduction (scan) to non-dask types (#8019)
Browse files Browse the repository at this point in the history
* add scan to ChunkManager ABC

* implement scan for dask using cumreduction

* generalize push to work for non-dask chunked arrays

* whatsnew

* fix importerror

* Allow arbitrary kwargs

Co-authored-by: Deepak Cherian <[email protected]>

* Type hint return value of T_ChunkedArray

Co-authored-by: Illviljan <[email protected]>

* Type hint return value of Dask array

* ffill -> bfill in doc/whats-new.rst

Co-authored-by: Deepak Cherian <[email protected]>

* hopefully fix docs warning

---------

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Illviljan <[email protected]>
  • Loading branch information
3 people authored Dec 18, 2023
1 parent 2971994 commit 766da34
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 0 deletions.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,10 @@ Internal Changes

- :py:func:`as_variable` now consistently includes the variable name in any exceptions
raised. (:pull:`7995`). By `Peter Hill <https://github.com/ZedThree>`_
- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`,
potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to
use non-dask chunked array types.
(:pull:`8019`) By `Tom Nicholas <https://github.com/TomNicholas>`_.
- :py:func:`encode_dataset_coordinates` now sorts coordinates automatically assigned to
`coordinates` attributes during serialization (:issue:`8026`, :pull:`8034`).
`By Ian Carroll <https://github.com/itcarroll>`_.
Expand Down
22 changes: 22 additions & 0 deletions xarray/core/daskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,28 @@ def reduction(
keepdims=keepdims,
)

def scan(
self,
func: Callable,
binop: Callable,
ident: float,
arr: T_ChunkedArray,
axis: int | None = None,
dtype: np.dtype | None = None,
**kwargs,
) -> DaskArray:
from dask.array.reductions import cumreduction

return cumreduction(
func,
binop,
ident,
arr,
axis=axis,
dtype=dtype,
**kwargs,
)

def apply_gufunc(
self,
func: Callable,
Expand Down
37 changes: 37 additions & 0 deletions xarray/core/parallelcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,43 @@ def reduction(
"""
raise NotImplementedError()

def scan(
self,
func: Callable,
binop: Callable,
ident: float,
arr: T_ChunkedArray,
axis: int | None = None,
dtype: np.dtype | None = None,
**kwargs,
) -> T_ChunkedArray:
"""
General version of a 1D scan, also known as a cumulative array reduction.
Used in ``ffill`` and ``bfill`` in xarray.
Parameters
----------
func: callable
Cumulative function like np.cumsum or np.cumprod
binop: callable
Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
ident: Number
Associated identity like ``np.cumsum->0`` or ``np.cumprod->1``
arr: dask Array
axis: int, optional
dtype: dtype
Returns
-------
Chunked array
See also
--------
dask.array.cumreduction
"""
raise NotImplementedError()

@abstractmethod
def apply_gufunc(
self,
Expand Down

0 comments on commit 766da34

Please sign in to comment.