diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 96005e17f78..976f7851821 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -17,7 +17,7 @@ What's New .. _whats-new.2024.05.1: -v2024.05.1 (unreleased) +v2024.06 (unreleased) ----------------------- New Features @@ -54,6 +54,10 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Migrates remainder of ``io.py`` to ``xarray/core/datatree_io.py`` and + ``TreeAttrAccessMixin`` into ``xarray/core/common.py`` (:pull:`9011`) + By `Owen Littlejohns <https://github.com/owenlittlejohns>`_ and + `Tom Nicholas <https://github.com/TomNicholas>`_. .. _whats-new.2024.05.0: @@ -136,10 +140,9 @@ Internal Changes By `Owen Littlejohns <https://github.com/owenlittlejohns>`_, `Matt Savoie <https://github.com/flamingbear>`_ and `Tom Nicholas <https://github.com/TomNicholas>`_. - ``transpose``, ``set_dims``, ``stack`` & ``unstack`` now use a ``dim`` kwarg - rather than ``dims`` or ``dimensions``. This is the final change to unify - xarray functions to use ``dim``. Using the existing kwarg will raise a - warning. - By `Maximilian Roos <https://github.com/max-sixty>`_ + rather than ``dims`` or ``dimensions``. This is the final change to make xarray methods + consistent with their use of ``dim``. Using the existing kwarg will raise a + warning. By `Maximilian Roos <https://github.com/max-sixty>`_ ..
_whats-new.2024.03.0: diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 76fcac62cd3..4b7f1052655 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -36,7 +36,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk from xarray.core.indexes import Index -from xarray.core.types import ZarrWriteModes +from xarray.core.types import NetcdfWriteModes, ZarrWriteModes from xarray.core.utils import is_remote_uri from xarray.namedarray.daskmanager import DaskManager from xarray.namedarray.parallelcompat import guess_chunkmanager @@ -1120,7 +1120,7 @@ def open_mfdataset( def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike | None = None, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -1138,7 +1138,7 @@ def to_netcdf( def to_netcdf( dataset: Dataset, path_or_file: None = None, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -1155,7 +1155,7 @@ def to_netcdf( def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -1173,7 +1173,7 @@ def to_netcdf( def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -1191,7 +1191,7 @@ def to_netcdf( def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -1209,7 +1209,7 @@ 
def to_netcdf( def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -1226,7 +1226,7 @@ def to_netcdf( def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike | None, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -1241,7 +1241,7 @@ def to_netcdf( def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike | None = None, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, diff --git a/xarray/core/common.py b/xarray/core/common.py index 7b9a049c662..936a6bb2489 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -347,6 +347,24 @@ def _ipython_key_completions_(self) -> list[str]: return list(items) +class TreeAttrAccessMixin(AttrAccessMixin): + """Mixin class that allows getting keys with attribute access""" + + # TODO: Ensure ipython tab completion can include both child datatrees and + # variables from Dataset objects on relevant nodes. + + __slots__ = () + + def __init_subclass__(cls, **kwargs): + """This method overrides the check from ``AttrAccessMixin`` that ensures + ``__dict__`` is absent in a class, with ``__slots__`` used instead. + ``DataTree`` has some dynamically defined attributes in addition to those + defined in ``__slots__``. 
(GH9068) + """ + if not hasattr(object.__new__(cls), "__dict__"): + pass + + def get_squeeze_dims( xarray_obj, dim: Hashable | Iterable[Hashable] | None = None, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 4dc897c1878..16b9330345b 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -54,6 +54,7 @@ from xarray.core.options import OPTIONS, _get_keep_attrs from xarray.core.types import ( DaCompatible, + NetcdfWriteModes, T_DataArray, T_DataArrayOrSet, ZarrWriteModes, @@ -3945,7 +3946,7 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray: def to_netcdf( self, path: None = None, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -3960,7 +3961,7 @@ def to_netcdf( def to_netcdf( self, path: str | PathLike, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -3976,7 +3977,7 @@ def to_netcdf( def to_netcdf( self, path: str | PathLike, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -3992,7 +3993,7 @@ def to_netcdf( def to_netcdf( self, path: str | PathLike, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -4005,7 +4006,7 @@ def to_netcdf( def to_netcdf( self, path: str | PathLike | None = None, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 09597670573..ca152b499a7 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -88,6 +88,7 @@ from 
xarray.core.missing import get_clean_interp_index from xarray.core.options import OPTIONS, _get_keep_attrs from xarray.core.types import ( + NetcdfWriteModes, QuantileMethods, Self, T_ChunkDim, @@ -2171,7 +2172,7 @@ def dump_to_store(self, store: AbstractDataStore, **kwargs) -> None: def to_netcdf( self, path: None = None, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -2186,7 +2187,7 @@ def to_netcdf( def to_netcdf( self, path: str | PathLike, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -2202,7 +2203,7 @@ def to_netcdf( def to_netcdf( self, path: str | PathLike, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -2218,7 +2219,7 @@ def to_netcdf( def to_netcdf( self, path: str | PathLike, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, @@ -2231,7 +2232,7 @@ def to_netcdf( def to_netcdf( self, path: str | PathLike | None = None, - mode: Literal["w", "a"] = "w", + mode: NetcdfWriteModes = "w", format: T_NetcdfTypes | None = None, group: str | None = None, engine: T_NetcdfEngine | None = None, diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 5737cdcb686..4e4d30885a3 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -9,12 +9,14 @@ Any, Callable, Generic, + Literal, NoReturn, Union, overload, ) from xarray.core import utils +from xarray.core.common import TreeAttrAccessMixin from xarray.core.coordinates import DatasetCoordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset, DataVariables @@ -46,7 +48,6 @@ 
maybe_wrap_array, ) from xarray.core.variable import Variable -from xarray.datatree_.datatree.common import TreeAttrAccessMixin try: from xarray.core.variable import calculate_dimensions @@ -57,8 +58,9 @@ if TYPE_CHECKING: import pandas as pd + from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes from xarray.core.merge import CoercibleValue - from xarray.core.types import ErrorOptions + from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes # """ # DEVELOPERS' NOTE @@ -1475,7 +1477,16 @@ def groups(self): return tuple(node.path for node in self.subtree) def to_netcdf( - self, filepath, mode: str = "w", encoding=None, unlimited_dims=None, **kwargs + self, + filepath, + mode: NetcdfWriteModes = "w", + encoding=None, + unlimited_dims=None, + format: T_DataTreeNetcdfTypes | None = None, + engine: T_DataTreeNetcdfEngine | None = None, + group: str | None = None, + compute: bool = True, + **kwargs, ): """ Write datatree contents to a netCDF file. @@ -1499,10 +1510,25 @@ def to_netcdf( By default, no dimensions are treated as unlimited dimensions. Note that unlimited_dims may also be set via ``dataset.encoding["unlimited_dims"]``. + format : {"NETCDF4", }, optional + File format for the resulting netCDF file: + + * NETCDF4: Data is stored in an HDF5 file, using netCDF4 API features. + engine : {"netcdf4", "h5netcdf"}, optional + Engine to use when writing netCDF files. If not provided, the + default engine is chosen based on available dependencies, with a + preference for "netcdf4" if writing to a file on disk. + group : str, optional + Path to the netCDF4 group in the given file to open as the root group + of the ``DataTree``. Currently, specifying a group is not supported. + compute : bool, default: True + If true compute immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed later. + Currently, ``compute=False`` is not supported. 
kwargs : Addional keyword arguments to be passed to ``xarray.Dataset.to_netcdf`` """ - from xarray.datatree_.datatree.io import _datatree_to_netcdf + from xarray.core.datatree_io import _datatree_to_netcdf _datatree_to_netcdf( self, @@ -1510,15 +1536,21 @@ def to_netcdf( mode=mode, encoding=encoding, unlimited_dims=unlimited_dims, + format=format, + engine=engine, + group=group, + compute=compute, **kwargs, ) def to_zarr( self, store, - mode: str = "w-", + mode: ZarrWriteModes = "w-", encoding=None, consolidated: bool = True, + group: str | None = None, + compute: Literal[True] = True, **kwargs, ): """ @@ -1541,10 +1573,17 @@ def to_zarr( consolidated : bool If True, apply zarr's `consolidate_metadata` function to the store after writing metadata for all groups. + group : str, optional + Group path. (a.k.a. `path` in zarr terminology.) + compute : bool, default: True + If true compute immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed later. Metadata + is always updated eagerly. Currently, ``compute=False`` is not + supported. 
kwargs : Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr`` """ - from xarray.datatree_.datatree.io import _datatree_to_zarr + from xarray.core.datatree_io import _datatree_to_zarr _datatree_to_zarr( self, @@ -1552,6 +1591,8 @@ def to_zarr( mode=mode, encoding=encoding, consolidated=consolidated, + group=group, + compute=compute, **kwargs, ) diff --git a/xarray/datatree_/datatree/io.py b/xarray/core/datatree_io.py similarity index 64% rename from xarray/datatree_/datatree/io.py rename to xarray/core/datatree_io.py index 6c8e9617da3..1473e624d9e 100644 --- a/xarray/datatree_/datatree/io.py +++ b/xarray/core/datatree_io.py @@ -1,7 +1,17 @@ +from __future__ import annotations + +from collections.abc import Mapping, MutableMapping +from os import PathLike +from typing import Any, Literal, get_args + from xarray.core.datatree import DataTree +from xarray.core.types import NetcdfWriteModes, ZarrWriteModes + +T_DataTreeNetcdfEngine = Literal["netcdf4", "h5netcdf"] +T_DataTreeNetcdfTypes = Literal["NETCDF4"] -def _get_nc_dataset_class(engine): +def _get_nc_dataset_class(engine: T_DataTreeNetcdfEngine | None): if engine == "netcdf4": from netCDF4 import Dataset elif engine == "h5netcdf": @@ -16,7 +26,12 @@ def _get_nc_dataset_class(engine): return Dataset -def _create_empty_netcdf_group(filename, group, mode, engine): +def _create_empty_netcdf_group( + filename: str | PathLike, + group: str, + mode: NetcdfWriteModes, + engine: T_DataTreeNetcdfEngine | None, +): ncDataset = _get_nc_dataset_class(engine) with ncDataset(filename, mode=mode) as rootgrp: @@ -25,25 +40,34 @@ def _create_empty_netcdf_group(filename, group, mode, engine): def _datatree_to_netcdf( dt: DataTree, - filepath, - mode: str = "w", - encoding=None, - unlimited_dims=None, + filepath: str | PathLike, + mode: NetcdfWriteModes = "w", + encoding: Mapping[str, Any] | None = None, + unlimited_dims: Mapping | None = None, + format: T_DataTreeNetcdfTypes | None = None, + engine: 
T_DataTreeNetcdfEngine | None = None, + group: str | None = None, + compute: bool = True, **kwargs, ): - if kwargs.get("format", None) not in [None, "NETCDF4"]: + """This function creates an appropriate datastore for writing a datatree to + disk as a netCDF file. + + See `DataTree.to_netcdf` for full API docs. + """ + + if format not in [None, *get_args(T_DataTreeNetcdfTypes)]: raise ValueError("to_netcdf only supports the NETCDF4 format") - engine = kwargs.get("engine", None) - if engine not in [None, "netcdf4", "h5netcdf"]: + if engine not in [None, *get_args(T_DataTreeNetcdfEngine)]: raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines") - if kwargs.get("group", None) is not None: + if group is not None: raise NotImplementedError( "specifying a root group for the tree has not been implemented" ) - if not kwargs.get("compute", True): + if not compute: raise NotImplementedError("compute=False has not been implemented yet") if encoding is None: @@ -72,12 +96,17 @@ def _datatree_to_netcdf( mode=mode, encoding=encoding.get(node.path), unlimited_dims=unlimited_dims.get(node.path), + engine=engine, + format=format, + compute=compute, **kwargs, ) - mode = "r+" + mode = "a" -def _create_empty_zarr_group(store, group, mode): +def _create_empty_zarr_group( + store: MutableMapping | str | PathLike[str], group: str, mode: ZarrWriteModes +): import zarr root = zarr.open_group(store, mode=mode) @@ -86,20 +115,28 @@ def _create_empty_zarr_group(store, group, mode): def _datatree_to_zarr( dt: DataTree, - store, - mode: str = "w-", - encoding=None, + store: MutableMapping | str | PathLike[str], + mode: ZarrWriteModes = "w-", + encoding: Mapping[str, Any] | None = None, consolidated: bool = True, + group: str | None = None, + compute: Literal[True] = True, **kwargs, ): + """This function creates an appropriate datastore for writing a datatree + to a zarr store. + + See `DataTree.to_zarr` for full API docs. 
+ """ + from zarr.convenience import consolidate_metadata - if kwargs.get("group", None) is not None: + if group is not None: raise NotImplementedError( "specifying a root group for the tree has not been implemented" ) - if not kwargs.get("compute", True): + if not compute: raise NotImplementedError("compute=False has not been implemented yet") if encoding is None: diff --git a/xarray/core/types.py b/xarray/core/types.py index 8f58e54d8cf..41078d29c0e 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -281,4 +281,5 @@ def copy( ] +NetcdfWriteModes = Literal["w", "a"] ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"] diff --git a/xarray/datatree_/datatree/common.py b/xarray/datatree_/datatree/common.py deleted file mode 100644 index f4f74337c50..00000000000 --- a/xarray/datatree_/datatree/common.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -This file and class only exists because it was easier to copy the code for AttrAccessMixin from xarray.core.common -with some slight modifications than it was to change the behaviour of an inherited xarray internal here. - -The modifications are marked with # TODO comments. -""" - -import warnings -from collections.abc import Hashable, Iterable, Mapping -from contextlib import suppress -from typing import Any - - -class TreeAttrAccessMixin: - """Mixin class that allows getting keys with attribute access""" - - __slots__ = () - - def __init_subclass__(cls, **kwargs): - """Verify that all subclasses explicitly define ``__slots__``. If they don't, - raise error in the core xarray module and a FutureWarning in third-party - extensions. 
- """ - if not hasattr(object.__new__(cls), "__dict__"): - pass - # TODO reinstate this once integrated upstream - # elif cls.__module__.startswith("datatree."): - # raise AttributeError(f"{cls.__name__} must explicitly define __slots__") - # else: - # cls.__setattr__ = cls._setattr_dict - # warnings.warn( - # f"xarray subclass {cls.__name__} should explicitly define __slots__", - # FutureWarning, - # stacklevel=2, - # ) - super().__init_subclass__(**kwargs) - - @property - def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: - """Places to look-up items for attribute-style access""" - yield from () - - @property - def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: - """Places to look-up items for key-autocompletion""" - yield from () - - def __getattr__(self, name: str) -> Any: - if name not in {"__dict__", "__setstate__"}: - # this avoids an infinite loop when pickle looks for the - # __setstate__ attribute before the xarray object is initialized - for source in self._attr_sources: - with suppress(KeyError): - return source[name] - raise AttributeError( - f"{type(self).__name__!r} object has no attribute {name!r}" - ) - - # This complicated two-method design boosts overall performance of simple operations - # - particularly DataArray methods that perform a _to_temp_dataset() round-trip - by - # a whopping 8% compared to a single method that checks hasattr(self, "__dict__") at - # runtime before every single assignment. All of this is just temporary until the - # FutureWarning can be changed into a hard crash. - def _setattr_dict(self, name: str, value: Any) -> None: - """Deprecated third party subclass (see ``__init_subclass__`` above)""" - object.__setattr__(self, name, value) - if name in self.__dict__: - # Custom, non-slotted attr, or improperly assigned variable? - warnings.warn( - f"Setting attribute {name!r} on a {type(self).__name__!r} object. 
Explicitly define __slots__ " - "to suppress this warning for legitimate custom attributes and " - "raise an error when attempting variables assignments.", - FutureWarning, - stacklevel=2, - ) - - def __setattr__(self, name: str, value: Any) -> None: - """Objects with ``__slots__`` raise AttributeError if you try setting an - undeclared attribute. This is desirable, but the error message could use some - improvement. - """ - try: - object.__setattr__(self, name, value) - except AttributeError as e: - # Don't accidentally shadow custom AttributeErrors, e.g. - # DataArray.dims.setter - if str(e) != f"{type(self).__name__!r} object has no attribute {name!r}": - raise - raise AttributeError( - f"cannot set attribute {name!r} on a {type(self).__name__!r} object. Use __setitem__ style" - "assignment (e.g., `ds['name'] = ...`) instead of assigning variables." - ) from e - - def __dir__(self) -> list[str]: - """Provide method name lookup and completion. Only provide 'public' - methods. - """ - extra_attrs = { - item - for source in self._attr_sources - for item in source - if isinstance(item, str) - } - return sorted(set(dir(type(self))) | extra_attrs)