Add chunks='auto' support for cftime datasets #10527


Draft: wants to merge 5 commits into main
55 changes: 55 additions & 0 deletions xarray/namedarray/daskmanager.py
@@ -5,6 +5,7 @@

import numpy as np

from xarray.core.common import _contains_cftime_datetimes
from xarray.core.indexing import ImplicitToExplicitIndexingAdapter
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint, T_ChunkedArray
from xarray.namedarray.utils import is_duck_dask_array, module_available
@@ -16,6 +17,7 @@
        _NormalizedChunks,
        duckarray,
    )
    from xarray.namedarray.parallelcompat import _Chunks

try:
from dask.array import Array as DaskArray
@@ -264,3 +266,56 @@ def shuffle(
        if chunks != "auto":
            raise NotImplementedError("Only chunks='auto' is supported at present.")
        return dask.array.shuffle(x, indexer, axis, chunks="auto")

    def rechunk(  # type: ignore[override]
        self,
        data: T_ChunkedArray,
        chunks: _NormalizedChunks | tuple[int, ...] | _Chunks,
        **kwargs: Any,
    ) -> Any:
        """
        Changes the chunking pattern of the given array.

        Called when the .chunk method is called on an xarray object that is already chunked.

        Parameters
        ----------
        data : dask array
            Array to be rechunked.
        chunks : int, tuple, dict or str, optional
            The new block dimensions to create. -1 indicates the full size of the
            corresponding dimension. Default is "auto" which automatically
            determines chunk sizes.

        Returns
        -------
        chunked array

        See Also
        --------
        dask.array.Array.rechunk
        cubed.Array.rechunk
        """

        if _contains_cftime_datetimes(data):
            # Preprocess chunks if they're cftime
            cftime_nbytes_approx = 64
Contributor review comment (suggested change):
-            cftime_nbytes_approx = 64
+            cftime_nbytes_approx = sys.getsizeof(first_n_items(data, 1))

first_n_items is in xarray.core.formatting

                from dask import config as dask_config
                from dask.utils import parse_bytes

                target_chunksize = parse_bytes(dask_config.get("array.chunk-size"))
Contributor review comment:
Very nice; everything but this line is dask-specific. Can we abstract all this logic to a helper function (say in namedarray.utils) that takes target_chunksize as input? Then we can reuse it for cubed too.


                # Calculate total elements per chunk
                elements_per_chunk = target_chunksize // cftime_nbytes_approx

                # Distribute elements across dimensions
                # Simple approach: try to make chunks roughly cubic
                ndim = data.ndim  # type:ignore
                shape = data.shape  # type:ignore
                if ndim > 0:
                    chunk_size_per_dim = int(elements_per_chunk ** (1.0 / ndim))
                    chunks = tuple(min(chunk_size_per_dim, dim_size) for dim_size in shape)
                else:
                    chunks = ()

        return data.rechunk(chunks, **kwargs)
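
A sketch of the backend-agnostic helper suggested in the review comment above. The function name _auto_chunks_for_object_dtype and its placement in xarray/namedarray/utils.py are illustrative assumptions, not part of this PR; DaskManager.rechunk would read dask's "array.chunk-size" config and pass the result in as target_chunksize, and a cubed manager could supply its own target size.

# Hypothetical helper (not in this PR): only backend-agnostic logic lives here.
# The dask-specific call, parse_bytes(dask_config.get("array.chunk-size")),
# stays in DaskManager.rechunk and is passed in as target_chunksize.
def _auto_chunks_for_object_dtype(
    shape: tuple[int, ...],
    target_chunksize: int,
    item_nbytes: int = 64,
) -> tuple[int, ...]:
    """Pick roughly cubic chunk sizes for object-dtype (e.g. cftime) arrays."""
    ndim = len(shape)
    if ndim == 0:
        return ()
    # How many items of item_nbytes fit into one chunk of target_chunksize bytes.
    elements_per_chunk = target_chunksize // item_nbytes
    # Spread that element budget evenly across dimensions, capped at each dimension's size.
    chunk_size_per_dim = int(elements_per_chunk ** (1.0 / ndim))
    return tuple(min(chunk_size_per_dim, dim_size) for dim_size in shape)
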
31 changes: 31 additions & 0 deletions xarray/tests/test_dask.py
@@ -1060,6 +1060,29 @@ def make_da():
    return da


def make_da_cftime():
    yrs = np.arange(2000, 2120)
    cftime_dates = xr.date_range(
        start=f"{yrs[0]}-01-01",
        end=f"{yrs[-1]}-12-31",
        freq="1YE",
        use_cftime=True,
    )
    yr_array = np.tile(cftime_dates.values, (10, 1))
    da = xr.DataArray(
        yr_array,
        dims=["x", "t"],
        coords={"x": np.arange(10), "t": cftime_dates},
        name="a",
    ).chunk({"x": 4, "t": 5})
    da.x.attrs["long_name"] = "x"
    da.attrs["test"] = "test"
    da.coords["c2"] = 0.5
    da.coords["ndcoord"] = da.x * 2

    return da


def make_ds():
    map_ds = xr.Dataset()
    map_ds["a"] = make_da()
@@ -1141,6 +1164,14 @@ def test_auto_chunk_da(obj):
    assert actual.chunks == expected.chunks


@pytest.mark.parametrize("obj", [make_da_cftime()])
def test_auto_chunk_da_cftime(obj):
    actual = obj.chunk("auto").data
    expected = obj.data.rechunk({0: 10, 1: 120})
    np.testing.assert_array_equal(actual, expected)
    assert actual.chunks == expected.chunks


def test_map_blocks_error(map_da, map_ds):
    def bad_func(darray):
        return (darray * darray.x + 5 * darray.y)[:1, :1]
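
For reference, the expected chunks in test_auto_chunk_da_cftime follow from the sizing arithmetic in DaskManager.rechunk above, assuming dask's default "array.chunk-size" of 128 MiB (configuration-dependent, so treat the exact number as an assumption):

# Worked example of the chunk sizing for the (10, 120) cftime array built by make_da_cftime().
from dask.utils import parse_bytes

cftime_nbytes_approx = 64
target_chunksize = parse_bytes("128MiB")                       # 134_217_728 bytes
elements_per_chunk = target_chunksize // cftime_nbytes_approx  # 2_097_152 elements

shape = (10, 120)
chunk_size_per_dim = int(elements_per_chunk ** (1.0 / len(shape)))  # int(sqrt(2_097_152)) == 1448
chunks = tuple(min(chunk_size_per_dim, dim_size) for dim_size in shape)
assert chunks == (10, 120)  # hence expected = obj.data.rechunk({0: 10, 1: 120})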