def rechunk(  # type: ignore[override]
    self,
    data: T_ChunkedArray,
    chunks: _NormalizedChunks | tuple[int, ...] | _Chunks,
    **kwargs: Any,
) -> Any:
    """
    Change the chunking pattern of the given array.

    Called when the .chunk method is called on an xarray object that is
    already chunked.

    Parameters
    ----------
    data : dask array
        Array to be rechunked.
    chunks : int, tuple, dict or str
        The new block dimensions to create. -1 indicates the full size of the
        corresponding dimension. ``"auto"`` asks for chunk sizes to be
        determined automatically. Any explicit specification is forwarded to
        ``data.rechunk`` unchanged.
    **kwargs : dict
        Additional keyword arguments forwarded to ``data.rechunk``.

    Returns
    -------
    chunked array

    Notes
    -----
    For arrays of cftime datetimes with ``chunks="auto"``, chunk sizes are
    computed here from the ``array.chunk-size`` dask config value, using an
    approximate per-element size of 64 bytes. The element budget is spread
    as evenly as possible over the axes; budget that a short axis cannot use
    is passed on to the longer ones. Specifications that mix ``"auto"`` with
    explicit sizes (e.g. ``{0: "auto", 1: 5}``) are not preprocessed and are
    handed to ``data.rechunk`` as given.

    See Also
    --------
    dask.array.Array.rechunk
    cubed.Array.rechunk
    """
    # Explicit chunk requests must be honoured as given: only a plain "auto"
    # request is eligible for the cftime size heuristic below.
    if not (isinstance(chunks, str) and chunks == "auto"):
        return data.rechunk(chunks, **kwargs)

    # Imported lazily so the block does not rely on a module-level import.
    from xarray.core.common import _contains_cftime_datetimes

    if not _contains_cftime_datetimes(data):
        return data.rechunk(chunks, **kwargs)

    from dask import config as dask_config
    from dask.utils import parse_bytes

    # Rough per-element footprint used to turn the byte target into an
    # element count for cftime objects.
    cftime_nbytes_approx = 64
    target_chunksize = parse_bytes(dask_config.get("array.chunk-size"))

    # Never let the budget drop to zero, otherwise non-empty axes would be
    # given zero-length chunks.
    budget = max(1, target_chunksize // cftime_nbytes_approx)

    shape = tuple(data.shape)  # type: ignore[attr-defined]
    sizes = [0] * len(shape)
    dims_left = len(shape)

    # Visit the shortest axes first so that whatever part of the budget they
    # cannot use is redistributed to the longer axes.
    for axis in sorted(range(len(shape)), key=shape.__getitem__):
        side = max(1, round(budget ** (1.0 / dims_left)))
        # Correct floating point error so that `side` is the exact integer
        # root: the largest value with side ** dims_left <= budget.
        while side > 1 and side**dims_left > budget:
            side -= 1
        while (side + 1) ** dims_left <= budget:
            side += 1

        sizes[axis] = min(side, shape[axis])
        budget = max(1, budget // max(1, sizes[axis]))
        dims_left -= 1

    # A zero-dimensional array yields the empty tuple here.
    return data.rechunk(tuple(sizes), **kwargs)
+ + def test_map_blocks_error(map_da, map_ds): def bad_func(darray): return (darray * darray.x + 5 * darray.y)[:1, :1]