Replace push implementation with map_overlap for Dask #9712

Open
wants to merge 1 commit into base: main

Conversation

@phofl (Contributor) commented Nov 4, 2024

  • Tests added
  • User visible changes (including notable bug fixes) are documented in whats-new.rst
  • New functions/methods are listed in api.rst

Our benchmarks here showed us that ffill alone adds 4.5 million tasks to the graph, which isn't great (the dataset has 550k chunks, so roughly a 9x multiplication).

Rewriting this with map_overlap gets it down to 1.5 million tasks, i.e. basically the number of chunks times 3, which is the minimum we can get to at the moment.

We merged a few map_overlap improvements on the Dask side today to make this possible; with those in place this is now a nice improvement (it also simplifies the code on the xarray side).
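(For reference, a rough way to reproduce task-count measurements like the ones above on a toy array; the array, chunking, and numbers below are made up for illustration and won't match the real benchmark, and ffill needs bottleneck or numbagg installed.)

    import dask.array as da
    import xarray as xr

    # Toy stand-in: 1,000 chunks along "time" (the benchmark dataset had ~550k chunks).
    arr = xr.DataArray(da.random.random((1_000_000,), chunks=1_000), dims="time")

    base_tasks = len(arr.data.__dask_graph__())
    filled = arr.ffill("time")
    extra_tasks = len(filled.data.__dask_graph__()) - base_tasks
    print(f"{extra_tasks} tasks added by ffill (~{extra_tasks / arr.data.npartitions:.1f} per chunk)")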

cc @dcherian

@dcherian (Contributor) commented Nov 4, 2024

This can't work in general.

We switched from map_overlap to cumreduction intentionally: see #6118. I fully support improving cumreduction; it's a fundamental parallel primitive.
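(For readers following along, a minimal sketch of what a cumreduction-based forward fill looks like, shown here for the unlimited case n=None; the helper names are illustrative rather than the exact xarray code.)

    import bottleneck
    import numpy as np
    from dask.array.reductions import cumreduction

    def _fill_with_last_one(a, b):
        # Combine step of the scan: where the locally filled block ``b`` is still
        # NaN, take the value carried over from the preceding blocks ``a``.
        return np.where(~np.isnan(b), b, a)

    def push_unlimited(array, axis=-1):
        # Unlimited forward fill (n=None) expressed with the scan primitive:
        # each block is forward-filled independently, then the combine step
        # propagates the last valid value across block boundaries.
        return cumreduction(
            func=bottleneck.push,        # per-block forward fill
            binop=_fill_with_last_one,
            ident=np.nan,
            x=array,
            axis=axis,
            dtype=array.dtype,
        )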

@dcherian (Contributor) commented Nov 4, 2024

See this test:

    @requires_dask
    @requires_bottleneck
    def test_push_dask():
        import bottleneck
        import dask.array

        array = np.array([np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6])
        for n in [None, 1, 2, 3, 4, 5, 11]:
            expected = bottleneck.push(array, axis=0, n=n)
            for c in range(1, 11):
                with raise_if_dask_computes():
                    actual = push(dask.array.from_array(array, chunks=c), axis=0, n=n)
                np.testing.assert_equal(actual, expected)

            # some chunks of size-1 with NaN
            with raise_if_dask_computes():
                actual = push(
                    dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=n
                )
            np.testing.assert_equal(actual, expected)

@phofl (Contributor, Author) commented Nov 4, 2024

Oh, I didn't consider None as a limit, sorry about that.

@dcherian (Contributor) commented Nov 4, 2024

You could experiment with method="blelloch" vs method="sequential" in cumreduction.

@phofl (Contributor, Author) commented Nov 4, 2024

I thought more about this; I hadn't considered large limit values properly. cumreduction itself works okay-ish; what makes the task graph that large is this section:


     if n is not None and 0 < n < array.shape[axis] - 1:
         arange = da.broadcast_to(
             da.arange(
                 array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
             ).reshape(
                 tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
             ),
             array.shape,
             array.chunks,
         )
         valid_arange = da.where(da.notnull(array), arange, np.nan)
         valid_limits = (arange - push(valid_arange, None, axis)) <= n
         # omit the forward fill that violate the limit
         return da.where(valid_limits, push(array, None, axis), np.nan)

That's a bunch of operations that can't be fused properly because of the interactions between the two different arrays. I'll think a bit more about whether we can reduce this somehow, but there isn't anything obvious right away (at least not to me).

cumreduction itself is equivalent to map_overlap from a topology perspective if the overlapping part only reaches a single neighbouring chunk. Would you be open to calling overlap in those cases? It makes things a bit uglier, but that seems to be a reasonably common use case from what I have seen so far (I might be totally wrong here).
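(A sketch of that special case, assuming a limit n no larger than the smallest chunk along the axis and a recent dask where asymmetric depth with boundary="none" behaves as expected; the function and argument names are illustrative.)

    import bottleneck
    import dask.array as da

    def push_via_overlap(array, n, axis=-1):
        # Bounded forward fill where the lookback never crosses more than one
        # chunk boundary: each block only needs the n trailing elements of its
        # single preceding neighbour.
        axis = axis % array.ndim
        assert n is not None and n <= min(array.chunks[axis])
        return da.map_overlap(
            bottleneck.push,
            array,
            depth={axis: (n, 0)},   # n elements from the previous chunk, none from the next
            boundary="none",        # no padding at the array edges
            n=n,                    # forwarded to bottleneck.push
            axis=axis,
            dtype=array.dtype,
        )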

@dcherian (Contributor) commented Nov 4, 2024

Can you add push to dask.array? Then you can add whatever optimization you want :). We'd be happy to dispatch instead of vendoring all this if we can.

Also, you should be able to write this as a single cumreduction that takes a 1D array of axis-indices and the input array as inputs. I wrote it for grouped ffill in flox: https://github.com/xarray-contrib/flox/blob/672be8ceeebfa588a15ebdc9861999efa12fa44e/flox/aggregations.py#L651
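(Sketched at the NumPy level only, this is the state such a single scan would carry and its combine step; wiring it into cumreduction over the stacked value/index arrays is omitted here, and the names are illustrative.)

    import numpy as np

    def combine(state_a, state_b):
        # State carried through the scan: (last valid value, axis index where it
        # was seen). Take the right-hand state wherever it already has a value,
        # otherwise fall back to what the earlier blocks carried in.
        value_a, index_a = state_a
        value_b, index_b = state_b
        use_b = ~np.isnan(value_b)
        return np.where(use_b, value_b, value_a), np.where(use_b, index_b, index_a)

    def apply_limit(value, last_index, positions, n):
        # After the scan, the limit becomes a single mask instead of a second
        # scan: drop fills whose source is more than ``n`` positions behind.
        return np.where(positions - last_index <= n, value, np.nan)

The appeal is that the limit handling folds into the same single scan, replacing the extra arange/where layer quoted further up.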

@phofl (Contributor, Author) commented Nov 4, 2024

Oh, I'll check that one out.

Sure, there shouldn't be any reason not to add this. I'll check out the flox implementation; getting the number of tasks down would be nice, but adding it in Dask should be a good option anyway.

@dcherian (Contributor) commented Nov 4, 2024

Ah, I guess the core issue is a dependency on numbagg and/or bottleneck.

@phofl (Contributor, Author) commented Nov 4, 2024

That should be fine; we can just raise if neither is installed, similar to what you are doing here.
