From 521b087e96f637606daec7250f7953440e5f9707 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Thu, 24 Oct 2024 19:15:38 +0200 Subject: [PATCH] support `chunks` in `open_groups` and `open_datatree` (#9660) * support chunking and default values in `open_groups` * same for `open_datatree` * use `group_subtrees` instead of `map_over_datasets` * check that `chunks` on `open_datatree` works * specify the chunksizes when opening from disk * check that `open_groups` with chunks works, too * require dask for `test_open_groups_chunks` * protect variables from write operations * copy over `_close` from the backend tree * copy a lot of the docstring from `open_dataset` * same for `open_groups` * reuse `_protect_dataset_variables_inplace` * final missing `requires_dask` * typing for the test utils Co-authored-by: Tom Nicholas * type hints for `_protect_datatree_variables_inplace` Co-authored-by: Tom Nicholas * type hints for `_protect_dataset_variables_inplace` * copy over the name of the backend tree Co-authored-by: Tom Nicholas * typo * swap the order of arguments to `assert_identical` * try explicitly typing `data` * typo * use `Hashable` for variable names --------- Co-authored-by: Tom Nicholas Co-authored-by: Tom Nicholas --- xarray/backends/api.py | 468 ++++++++++++++++++++++++- xarray/tests/test_backends_datatree.py | 150 ++++++++ 2 files changed, 602 insertions(+), 16 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 16e9a34f240..a52e73701ab 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -42,7 +42,9 @@ ) from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk +from xarray.core.datatree import DataTree from xarray.core.indexes import Index +from xarray.core.treenode import group_subtrees from xarray.core.types import NetcdfWriteModes, ZarrWriteModes from xarray.core.utils import is_remote_uri from xarray.namedarray.daskmanager import DaskManager @@ -75,7 +77,6 @@ T_NetcdfTypes = Literal[ "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC" ] - from xarray.core.datatree import DataTree DATAARRAY_NAME = "__xarray_dataarray_name__" DATAARRAY_VARIABLE = "__xarray_dataarray_variable__" @@ -255,16 +256,22 @@ def _get_mtime(filename_or_obj): return mtime -def _protect_dataset_variables_inplace(dataset, cache): +def _protect_dataset_variables_inplace(dataset: Dataset, cache: bool) -> None: for name, variable in dataset.variables.items(): if name not in dataset._indexes: # no need to protect IndexVariable objects + data: indexing.ExplicitlyIndexedNDArrayMixin data = indexing.CopyOnWriteArray(variable._data) if cache: data = indexing.MemoryCachedArray(data) variable.data = data +def _protect_datatree_variables_inplace(tree: DataTree, cache: bool) -> None: + for node in tree.subtree: + _protect_dataset_variables_inplace(node, cache) + + def _finalize_store(write, store): """Finalize this store by explicitly syncing and closing""" del write # ensure writing is done first @@ -415,6 +422,58 @@ def _dataset_from_backend_dataset( return ds +def _datatree_from_backend_datatree( + backend_tree, + filename_or_obj, + engine, + chunks, + cache, + overwrite_encoded_chunks, + inline_array, + chunked_array_type, + from_array_kwargs, + **extra_tokens, +): + if not isinstance(chunks, int | dict) and chunks not in {None, "auto"}: + raise ValueError( + f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}." 
+ ) + + _protect_datatree_variables_inplace(backend_tree, cache) + if chunks is None: + tree = backend_tree + else: + tree = DataTree.from_dict( + { + path: _chunk_ds( + node.dataset, + filename_or_obj, + engine, + chunks, + overwrite_encoded_chunks, + inline_array, + chunked_array_type, + from_array_kwargs, + **extra_tokens, + ) + for path, [node] in group_subtrees(backend_tree) + }, + name=backend_tree.name, + ) + + for path, [node] in group_subtrees(backend_tree): + tree[path].set_close(node._close) + + # Ensure source filename always stored in dataset object + if "source" not in tree.encoding: + path = getattr(filename_or_obj, "path", filename_or_obj) + + if isinstance(path, str | os.PathLike): + tree.encoding["source"] = _normalize_path(path) + + return tree + + def open_dataset( filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, *, @@ -839,7 +898,22 @@ def open_dataarray( def open_datatree( filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, engine: T_Engine = None, + chunks: T_Chunks = None, + cache: bool | None = None, + decode_cf: bool | None = None, + mask_and_scale: bool | Mapping[str, bool] | None = None, + decode_times: bool | Mapping[str, bool] | None = None, + decode_timedelta: bool | Mapping[str, bool] | None = None, + use_cftime: bool | Mapping[str, bool] | None = None, + concat_characters: bool | Mapping[str, bool] | None = None, + decode_coords: Literal["coordinates", "all"] | bool | None = None, + drop_variables: str | Iterable[str] | None = None, + inline_array: bool = False, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, + backend_kwargs: dict[str, Any] | None = None, **kwargs, ) -> DataTree: """ @@ -849,29 +923,217 @@ def open_datatree( ---------- filename_or_obj : str, Path, file-like, or DataStore Strings and Path objects are interpreted as a path to a netCDF file or Zarr store. - engine : str, optional - Xarray backend engine to use. Valid options include `{"netcdf4", "h5netcdf", "zarr"}`. - **kwargs : dict - Additional keyword arguments passed to :py:func:`~xarray.open_dataset` for each group. + engine : {"netcdf4", "h5netcdf", "zarr", None}, \ + installed backend or xarray.backends.BackendEntrypoint, optional + Engine to use when reading files. If not provided, the default engine + is chosen based on available dependencies, with a preference for + "netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``) + can also be used. + chunks : int, dict, 'auto' or None, default: None + If provided, used to load the data into dask arrays. + + - ``chunks="auto"`` will use dask ``auto`` chunking taking into account the + engine preferred chunks. + - ``chunks=None`` skips using dask, which is generally faster for + small arrays. + - ``chunks=-1`` loads the data with dask using a single chunk for all arrays. + - ``chunks={}`` loads the data with dask using the engine's preferred chunk + size, generally identical to the format's chunk size. If not available, a + single chunk for all arrays. + + See dask chunking for more details. + cache : bool, optional + If True, cache data loaded from the underlying datastore in memory as + NumPy arrays when accessed to avoid reading from the underlying data- + store multiple times. Defaults to True unless you specify the `chunks` + argument to use dask, in which case it defaults to False. 
Does not + change the behavior of coordinates corresponding to dimensions, which + always load their data from disk into a ``pandas.Index``. + decode_cf : bool, optional + Whether to decode these variables, assuming they were saved according + to CF conventions. + mask_and_scale : bool or dict-like, optional + If True, replace array values equal to `_FillValue` with NA and scale + values according to the formula `original_values * scale_factor + + add_offset`, where `_FillValue`, `scale_factor` and `add_offset` are + taken from variable attributes (if they exist). If the `_FillValue` or + `missing_value` attribute contains multiple values a warning will be + issued and all array values matching one of the multiple values will + be replaced by NA. Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. + This keyword may not be supported by all the backends. + decode_times : bool or dict-like, optional + If True, decode times encoded in the standard NetCDF datetime format + into datetime objects. Otherwise, leave them encoded as numbers. + Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. + This keyword may not be supported by all the backends. + decode_timedelta : bool or dict-like, optional + If True, decode variables and coordinates with time units in + {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} + into timedelta objects. If False, leave them encoded as numbers. + If None (default), assume the same value of decode_time. + Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. + This keyword may not be supported by all the backends. + use_cftime: bool or dict-like, optional + Only relevant if encoded dates come from a standard calendar + (e.g. "gregorian", "proleptic_gregorian", "standard", or not + specified). If None (default), attempt to decode times to + ``np.datetime64[ns]`` objects; if this is not possible, decode times to + ``cftime.datetime`` objects. If True, always decode times to + ``cftime.datetime`` objects, regardless of whether or not they can be + represented using ``np.datetime64[ns]`` objects. If False, always + decode times to ``np.datetime64[ns]`` objects; if this is not possible + raise an error. Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. + This keyword may not be supported by all the backends. + concat_characters : bool or dict-like, optional + If True, concatenate along the last dimension of character arrays to + form string arrays. Dimensions will only be concatenated over (and + removed) if they have no corresponding variable and if they are only + used as the last dimension of character arrays. + Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. + This keyword may not be supported by all the backends. + decode_coords : bool or {"coordinates", "all"}, optional + Controls which variables are set as coordinate variables: + + - "coordinates" or True: Set variables referred to in the + ``'coordinates'`` attribute of the datasets or individual variables + as coordinate variables. + - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and + other attributes as coordinate variables. + + Only existing variables can be set as coordinates. Missing variables + will be silently ignored. 
+    drop_variables: str or iterable of str, optional
+        A variable or list of variables to exclude from being parsed from the
+        dataset. This may be useful to drop variables with problems or
+        inconsistent values.
+    inline_array: bool, default: False
+        How to include the array in the dask task graph.
+        By default (``inline_array=False``) the array is included in a task by
+        itself, and each chunk refers to that task by its key. With
+        ``inline_array=True``, Dask will instead inline the array directly
+        in the values of the task graph. See :py:func:`dask.array.from_array`.
+    chunked_array_type: str, optional
+        Which chunked array type to coerce this dataset's arrays to.
+        Defaults to 'dask' if installed, else whatever is registered via the
+        `ChunkManagerEntrypoint` system. Experimental API that should not be relied upon.
+    from_array_kwargs: dict
+        Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array`
+        method used to create chunked arrays, via whichever chunk manager is specified
+        through the `chunked_array_type` kwarg. For example, if :py:func:`dask.array.Array`
+        objects are used for chunking, additional kwargs will be passed to
+        :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
+    backend_kwargs: dict
+        Additional keyword arguments passed on to the engine open function,
+        equivalent to `**kwargs`.
+    **kwargs: dict
+        Additional keyword arguments passed on to the engine open function.
+        For example:
+
+        - 'group': str, path to the group in the given file to open as the
+          root group.
+        - 'lock': resource lock to use when reading data from disk. Only
+          relevant when using dask or another form of parallelism. By default,
+          appropriate locks are chosen to safely read and write files with the
+          currently active dask scheduler. Supported by "netcdf4", "h5netcdf",
+          "scipy".
+
+        See the documentation of the specific engine's open function for the
+        kwargs it accepts.
+
     Returns
     -------
-    xarray.DataTree
+    tree : DataTree
+        The newly created datatree.
+
+    Notes
+    -----
+    ``open_datatree`` opens the file with read-only access. When you modify
+    values of a DataTree, even one linked to files on disk, only the in-memory
+    copy you are manipulating in xarray is modified: the original file on disk
+    is never touched.
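To make the ``chunks`` options above concrete, here is a minimal usage sketch (the file name, group name, and dimension names are hypothetical; the chunk sizes mirror those used in the new tests, and dask must be installed for any non-``None`` value of ``chunks``):

import xarray as xr

# Open every group of the file as a dask-backed node of the DataTree.
# chunks={} would instead use the on-disk (engine-preferred) chunking,
# and chunks=None (the default) loads plain NumPy arrays.
tree = xr.open_datatree("example.nc", engine="netcdf4", chunks={"x": 2, "y": 1})

# Each node's variables are now dask arrays with the requested chunk sizes.
print(tree["/group1"].dataset.chunksizes)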
+ + See Also + -------- + xarray.open_groups + xarray.open_dataset """ + if cache is None: + cache = chunks is None + + if backend_kwargs is not None: + kwargs.update(backend_kwargs) + if engine is None: engine = plugins.guess_engine(filename_or_obj) + if from_array_kwargs is None: + from_array_kwargs = {} + backend = plugins.get_backend(engine) - return backend.open_datatree(filename_or_obj, **kwargs) + decoders = _resolve_decoders_kwargs( + decode_cf, + open_backend_dataset_parameters=(), + mask_and_scale=mask_and_scale, + decode_times=decode_times, + decode_timedelta=decode_timedelta, + concat_characters=concat_characters, + use_cftime=use_cftime, + decode_coords=decode_coords, + ) + overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None) + + backend_tree = backend.open_datatree( + filename_or_obj, + drop_variables=drop_variables, + **decoders, + **kwargs, + ) + + tree = _datatree_from_backend_datatree( + backend_tree, + filename_or_obj, + engine, + chunks, + cache, + overwrite_encoded_chunks, + inline_array, + chunked_array_type, + from_array_kwargs, + drop_variables=drop_variables, + **decoders, + **kwargs, + ) + + return tree def open_groups( filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore, + *, engine: T_Engine = None, + chunks: T_Chunks = None, + cache: bool | None = None, + decode_cf: bool | None = None, + mask_and_scale: bool | Mapping[str, bool] | None = None, + decode_times: bool | Mapping[str, bool] | None = None, + decode_timedelta: bool | Mapping[str, bool] | None = None, + use_cftime: bool | Mapping[str, bool] | None = None, + concat_characters: bool | Mapping[str, bool] | None = None, + decode_coords: Literal["coordinates", "all"] | bool | None = None, + drop_variables: str | Iterable[str] | None = None, + inline_array: bool = False, + chunked_array_type: str | None = None, + from_array_kwargs: dict[str, Any] | None = None, + backend_kwargs: dict[str, Any] | None = None, **kwargs, ) -> dict[str, Dataset]: """ Open and decode a file or file-like object, creating a dictionary containing one xarray Dataset for each group in the file. + Useful for an HDF file ("netcdf4" or "h5netcdf") containing many groups that are not alignable with their parents and cannot be opened directly with ``open_datatree``. It is encouraged to use this function to inspect your data, then make the necessary changes to make the structure coercible to a `DataTree` object before calling `DataTree.from_dict()` and proceeding with your analysis. @@ -880,26 +1142,200 @@ def open_groups( ---------- filename_or_obj : str, Path, file-like, or DataStore Strings and Path objects are interpreted as a path to a netCDF file. - engine : str, optional - Xarray backend engine to use. Valid options include `{"netcdf4", "h5netcdf"}`. - **kwargs : dict - Additional keyword arguments passed to :py:func:`~xarray.open_dataset` for each group. + Parameters + ---------- + filename_or_obj : str, Path, file-like, or DataStore + Strings and Path objects are interpreted as a path to a netCDF file or Zarr store. + engine : {"netcdf4", "h5netcdf", "zarr", None}, \ + installed backend or xarray.backends.BackendEntrypoint, optional + Engine to use when reading files. If not provided, the default engine + is chosen based on available dependencies, with a preference for + "netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``) + can also be used. + chunks : int, dict, 'auto' or None, default: None + If provided, used to load the data into dask arrays. 
+ + - ``chunks="auto"`` will use dask ``auto`` chunking taking into account the + engine preferred chunks. + - ``chunks=None`` skips using dask, which is generally faster for + small arrays. + - ``chunks=-1`` loads the data with dask using a single chunk for all arrays. + - ``chunks={}`` loads the data with dask using the engine's preferred chunk + size, generally identical to the format's chunk size. If not available, a + single chunk for all arrays. + + See dask chunking for more details. + cache : bool, optional + If True, cache data loaded from the underlying datastore in memory as + NumPy arrays when accessed to avoid reading from the underlying data- + store multiple times. Defaults to True unless you specify the `chunks` + argument to use dask, in which case it defaults to False. Does not + change the behavior of coordinates corresponding to dimensions, which + always load their data from disk into a ``pandas.Index``. + decode_cf : bool, optional + Whether to decode these variables, assuming they were saved according + to CF conventions. + mask_and_scale : bool or dict-like, optional + If True, replace array values equal to `_FillValue` with NA and scale + values according to the formula `original_values * scale_factor + + add_offset`, where `_FillValue`, `scale_factor` and `add_offset` are + taken from variable attributes (if they exist). If the `_FillValue` or + `missing_value` attribute contains multiple values a warning will be + issued and all array values matching one of the multiple values will + be replaced by NA. Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. + This keyword may not be supported by all the backends. + decode_times : bool or dict-like, optional + If True, decode times encoded in the standard NetCDF datetime format + into datetime objects. Otherwise, leave them encoded as numbers. + Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. + This keyword may not be supported by all the backends. + decode_timedelta : bool or dict-like, optional + If True, decode variables and coordinates with time units in + {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} + into timedelta objects. If False, leave them encoded as numbers. + If None (default), assume the same value of decode_time. + Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. + This keyword may not be supported by all the backends. + use_cftime: bool or dict-like, optional + Only relevant if encoded dates come from a standard calendar + (e.g. "gregorian", "proleptic_gregorian", "standard", or not + specified). If None (default), attempt to decode times to + ``np.datetime64[ns]`` objects; if this is not possible, decode times to + ``cftime.datetime`` objects. If True, always decode times to + ``cftime.datetime`` objects, regardless of whether or not they can be + represented using ``np.datetime64[ns]`` objects. If False, always + decode times to ``np.datetime64[ns]`` objects; if this is not possible + raise an error. Pass a mapping, e.g. ``{"my_variable": False}``, + to toggle this feature per-variable individually. + This keyword may not be supported by all the backends. + concat_characters : bool or dict-like, optional + If True, concatenate along the last dimension of character arrays to + form string arrays. 
Dimensions will only be concatenated over (and
+        removed) if they have no corresponding variable and if they are only
+        used as the last dimension of character arrays.
+        Pass a mapping, e.g. ``{"my_variable": False}``,
+        to toggle this feature per-variable individually.
+        This keyword may not be supported by all the backends.
+    decode_coords : bool or {"coordinates", "all"}, optional
+        Controls which variables are set as coordinate variables:
+
+        - "coordinates" or True: Set variables referred to in the
+          ``'coordinates'`` attribute of the datasets or individual variables
+          as coordinate variables.
+        - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and
+          other attributes as coordinate variables.
+
+        Only existing variables can be set as coordinates. Missing variables
+        will be silently ignored.
+    drop_variables: str or iterable of str, optional
+        A variable or list of variables to exclude from being parsed from the
+        dataset. This may be useful to drop variables with problems or
+        inconsistent values.
+    inline_array: bool, default: False
+        How to include the array in the dask task graph.
+        By default (``inline_array=False``) the array is included in a task by
+        itself, and each chunk refers to that task by its key. With
+        ``inline_array=True``, Dask will instead inline the array directly
+        in the values of the task graph. See :py:func:`dask.array.from_array`.
+    chunked_array_type: str, optional
+        Which chunked array type to coerce this dataset's arrays to.
+        Defaults to 'dask' if installed, else whatever is registered via the
+        `ChunkManagerEntrypoint` system. Experimental API that should not be relied upon.
+    from_array_kwargs: dict
+        Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array`
+        method used to create chunked arrays, via whichever chunk manager is specified
+        through the `chunked_array_type` kwarg. For example, if :py:func:`dask.array.Array`
+        objects are used for chunking, additional kwargs will be passed to
+        :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
+    backend_kwargs: dict
+        Additional keyword arguments passed on to the engine open function,
+        equivalent to `**kwargs`.
+    **kwargs: dict
+        Additional keyword arguments passed on to the engine open function.
+        For example:
+
+        - 'group': str, path to the group in the given file to open as the
+          root group.
+        - 'lock': resource lock to use when reading data from disk. Only
+          relevant when using dask or another form of parallelism. By default,
+          appropriate locks are chosen to safely read and write files with the
+          currently active dask scheduler. Supported by "netcdf4", "h5netcdf",
+          "scipy".
+
+        See the documentation of the specific engine's open function for the
+        kwargs it accepts.
+
     Returns
     -------
-    dict[str, xarray.Dataset]
+    groups : dict of str to xarray.Dataset
+        The groups as Dataset objects.
+
+    Notes
+    -----
+    ``open_groups`` opens the file with read-only access. When you modify
+    values of a Dataset, even one linked to files on disk, only the in-memory
+    copy you are manipulating in xarray is modified: the original file on disk
+    is never touched.
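As above, a minimal sketch of how ``open_groups`` combines with ``chunks`` and ``DataTree.from_dict`` (hypothetical file name and chunk sizes; it assumes dask is installed and that the groups have been made alignable before building the tree):

import xarray as xr

# Lazily open each group of the file as its own dask-backed Dataset,
# keyed by its path within the file (e.g. "/", "/group1", ...).
groups = xr.open_groups("example.nc", engine="netcdf4", chunks={"x": 2, "y": 1})

for path, ds in groups.items():
    print(path, dict(ds.chunksizes))

# Once the groups are alignable with their parents, they can be assembled
# into a single tree.
tree = xr.DataTree.from_dict(groups)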
See Also -------- - open_datatree() - DataTree.from_dict() + xarray.open_datatree + xarray.open_dataset + xarray.DataTree.from_dict """ + if cache is None: + cache = chunks is None + + if backend_kwargs is not None: + kwargs.update(backend_kwargs) + if engine is None: engine = plugins.guess_engine(filename_or_obj) + if from_array_kwargs is None: + from_array_kwargs = {} + backend = plugins.get_backend(engine) - return backend.open_groups_as_dict(filename_or_obj, **kwargs) + decoders = _resolve_decoders_kwargs( + decode_cf, + open_backend_dataset_parameters=(), + mask_and_scale=mask_and_scale, + decode_times=decode_times, + decode_timedelta=decode_timedelta, + concat_characters=concat_characters, + use_cftime=use_cftime, + decode_coords=decode_coords, + ) + overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None) + + backend_groups = backend.open_groups_as_dict( + filename_or_obj, + drop_variables=drop_variables, + **decoders, + **kwargs, + ) + + groups = { + name: _dataset_from_backend_dataset( + backend_ds, + filename_or_obj, + engine, + chunks, + cache, + overwrite_encoded_chunks, + inline_array, + chunked_array_type, + from_array_kwargs, + drop_variables=drop_variables, + **decoders, + **kwargs, + ) + for name, backend_ds in backend_groups.items() + } + + return groups def open_mfdataset( diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 16598194e1d..01b8b0ae81b 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +from collections.abc import Hashable from typing import TYPE_CHECKING, cast import numpy as np @@ -11,6 +12,7 @@ from xarray.core.datatree import DataTree from xarray.testing import assert_equal, assert_identical from xarray.tests import ( + requires_dask, requires_h5netcdf, requires_netCDF4, requires_zarr, @@ -27,6 +29,47 @@ have_zarr_v3 = xr.backends.zarr._zarr_v3() +def diff_chunks( + comparison: dict[tuple[str, Hashable], bool], tree1: DataTree, tree2: DataTree +) -> str: + mismatching_variables = [loc for loc, equals in comparison.items() if not equals] + + variable_messages = [ + "\n".join( + [ + f"L {path}:{name}: {tree1[path].variables[name].chunksizes}", + f"R {path}:{name}: {tree2[path].variables[name].chunksizes}", + ] + ) + for path, name in mismatching_variables + ] + return "\n".join(["Differing chunk sizes:"] + variable_messages) + + +def assert_chunks_equal( + actual: DataTree, expected: DataTree, enforce_dask: bool = False +) -> None: + __tracebackhide__ = True + + from xarray.namedarray.pycompat import array_type + + dask_array_type = array_type("dask") + + comparison = { + (path, name): ( + ( + not enforce_dask + or isinstance(node1.variables[name].data, dask_array_type) + ) + and node1.variables[name].chunksizes == node2.variables[name].chunksizes + ) + for path, (node1, node2) in xr.group_subtrees(actual, expected) + for name in node1.variables.keys() + } + + assert all(comparison.values()), diff_chunks(comparison, actual, expected) + + @pytest.fixture(scope="module") def unaligned_datatree_nc(tmp_path_factory): """Creates a test netCDF4 file with the following unaligned structure, writes it to a /tmp directory @@ -172,6 +215,29 @@ def test_open_datatree(self, unaligned_datatree_nc) -> None: ): open_datatree(unaligned_datatree_nc) + @requires_dask + def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None: + filepath = tmpdir / "test.nc" + + chunks = {"x": 2, "y": 1} + + 
root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])}) + set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])}) + original_tree = DataTree.from_dict( + { + "/": root_data.chunk(chunks), + "/group1": set1_data.chunk(chunks), + "/group2": set2_data.chunk(chunks), + } + ) + original_tree.to_netcdf(filepath, engine="netcdf4") + + with open_datatree(filepath, engine="netcdf4", chunks=chunks) as tree: + xr.testing.assert_identical(tree, original_tree) + + assert_chunks_equal(tree, original_tree, enforce_dask=True) + def test_open_groups(self, unaligned_datatree_nc) -> None: """Test `open_groups` with a netCDF4 file with an unaligned group hierarchy.""" unaligned_dict_of_datasets = open_groups(unaligned_datatree_nc) @@ -193,6 +259,37 @@ def test_open_groups(self, unaligned_datatree_nc) -> None: for ds in unaligned_dict_of_datasets.values(): ds.close() + @requires_dask + def test_open_groups_chunks(self, tmpdir) -> None: + """Test `open_groups` with chunks on a netcdf4 file.""" + + chunks = {"x": 2, "y": 1} + filepath = tmpdir / "test.nc" + + chunks = {"x": 2, "y": 1} + + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])}) + set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])}) + original_tree = DataTree.from_dict( + { + "/": root_data.chunk(chunks), + "/group1": set1_data.chunk(chunks), + "/group2": set2_data.chunk(chunks), + } + ) + original_tree.to_netcdf(filepath, mode="w") + + dict_of_datasets = open_groups(filepath, engine="netcdf4", chunks=chunks) + + for path, ds in dict_of_datasets.items(): + assert { + k: max(vs) for k, vs in ds.chunksizes.items() + } == chunks, f"unexpected chunking for {path}" + + for ds in dict_of_datasets.values(): + ds.close() + def test_open_groups_to_dict(self, tmpdir) -> None: """Create an aligned netCDF4 with the following structure to test `open_groups` and `DataTree.from_dict`. 
@@ -353,6 +450,28 @@ def test_open_datatree(self, unaligned_datatree_zarr) -> None: ): open_datatree(unaligned_datatree_zarr, engine="zarr") + @requires_dask + def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None: + filepath = tmpdir / "test.zarr" + + chunks = {"x": 2, "y": 1} + + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])}) + set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])}) + original_tree = DataTree.from_dict( + { + "/": root_data.chunk(chunks), + "/group1": set1_data.chunk(chunks), + "/group2": set2_data.chunk(chunks), + } + ) + original_tree.to_zarr(filepath) + + with open_datatree(filepath, engine="zarr", chunks=chunks) as tree: + xr.testing.assert_identical(tree, original_tree) + assert_chunks_equal(tree, original_tree, enforce_dask=True) + def test_open_groups(self, unaligned_datatree_zarr) -> None: """Test `open_groups` with a zarr store of an unaligned group hierarchy.""" @@ -382,3 +501,34 @@ def test_open_groups(self, unaligned_datatree_zarr) -> None: for ds in unaligned_dict_of_datasets.values(): ds.close() + + @requires_dask + def test_open_groups_chunks(self, tmpdir) -> None: + """Test `open_groups` with chunks on a zarr store.""" + + chunks = {"x": 2, "y": 1} + filepath = tmpdir / "test.zarr" + + chunks = {"x": 2, "y": 1} + + root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}) + set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])}) + set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])}) + original_tree = DataTree.from_dict( + { + "/": root_data.chunk(chunks), + "/group1": set1_data.chunk(chunks), + "/group2": set2_data.chunk(chunks), + } + ) + original_tree.to_zarr(filepath, mode="w") + + dict_of_datasets = open_groups(filepath, engine="zarr", chunks=chunks) + + for path, ds in dict_of_datasets.items(): + assert { + k: max(vs) for k, vs in ds.chunksizes.items() + } == chunks, f"unexpected chunking for {path}" + + for ds in dict_of_datasets.values(): + ds.close()
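For reference, a condensed round-trip of the behaviour these tests exercise might look like the following (hypothetical store name, reusing the same toy datasets as the tests; requires dask and zarr):

import xarray as xr
from xarray.core.datatree import DataTree

chunks = {"x": 2, "y": 1}

root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])})
original_tree = DataTree.from_dict(
    {"/": root_data.chunk(chunks), "/group1": set1_data.chunk(chunks)}
)
original_tree.to_zarr("roundtrip.zarr", mode="w")

# Reopening with the same chunks gives an identical, dask-backed tree.
with xr.open_datatree("roundtrip.zarr", engine="zarr", chunks=chunks) as tree:
    xr.testing.assert_identical(tree, original_tree)
    for node in tree.subtree:
        assert {k: max(v) for k, v in node.dataset.chunksizes.items()} == chunks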