Skip to content

Commit

Permalink
Hierarchical coordinates in DataTree (#9063)
Browse files Browse the repository at this point in the history
* Inheritance of data coordinates

* Simplify __init__

* Include path name in alignment errors

* Fix some mypy errors

* mypy fix

* simplify DataTree data model

* Add to_dataset(local=True)

* Fix mypy failure in tests

* Fix to_zarr for inherited coords

* Fix to_netcdf for hierarchical coords

* Add ChainSet

* Revise internal data model; remove ChainSet

* add another way to construct inherited indexes

* Finish refactoring error message

* include inherited dimensions in HTML repr, too

* Construct ChainMap objects on demand.

* slightly better error message with mis-aligned data trees

* mypy fix

* use float64 instead of float32 for windows

* clean-up per review

* Add note about inheritance to .ds docs
  • Loading branch information
shoyer authored Jul 3, 2024
1 parent 6c2d8c3 commit a86c3ff
Show file tree
Hide file tree
Showing 8 changed files with 527 additions and 240 deletions.
443 changes: 228 additions & 215 deletions xarray/core/datatree.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions xarray/core/datatree_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _datatree_to_netcdf(
unlimited_dims = {}

for node in dt.subtree:
ds = node.ds
ds = node.to_dataset(inherited=False)
group_path = node.path
if ds is None:
_create_empty_netcdf_group(filepath, group_path, mode, engine)
Expand Down Expand Up @@ -151,7 +151,7 @@ def _datatree_to_zarr(
)

for node in dt.subtree:
ds = node.ds
ds = node.to_dataset(inherited=False)
group_path = node.path
if ds is None:
_create_empty_zarr_group(store, group_path, mode)
Expand Down
23 changes: 22 additions & 1 deletion xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,27 @@ def dataset_repr(ds):
return "\n".join(summary)


def dims_and_coords_repr(ds) -> str:
    """Partial Dataset repr for use inside DataTree inheritance errors.

    Only the dimension summary, the coordinates section and the
    unindexed-dimensions section are rendered.

    Parameters
    ----------
    ds
        Object exposing ``coords`` and ``dims``; it is also handed as a whole
        to ``dim_summary_limited``.

    Returns
    -------
    str
        The rendered sections joined with newlines. The "Dimensions:" line is
        always present; the coordinates section is included only when
        ``ds.coords`` is truthy, and the unindexed-dimensions section only
        when its rendering is non-empty.
    """
    width = _calculate_col_width(ds.coords)
    row_limit = OPTIONS["display_max_rows"]

    # First line: padded "Dimensions:" label followed by the size summary.
    header = pretty_print("Dimensions:", width)
    sizes = dim_summary_limited(ds, col_width=width + 1, max_rows=row_limit)
    sections = [f"{header}({sizes})"]

    if ds.coords:
        sections.append(coords_repr(ds.coords, col_width=width, max_rows=row_limit))

    unindexed = unindexed_dims_repr(ds.dims, ds.coords, max_rows=row_limit)
    if unindexed:
        sections.append(unindexed)

    return "\n".join(sections)


def diff_dim_summary(a, b):
if a.sizes != b.sizes:
return f"Differing dimensions:\n ({dim_summary(a)}) != ({dim_summary(b)})"
Expand Down Expand Up @@ -1030,7 +1051,7 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat):
def _single_node_repr(node: DataTree) -> str:
"""Information about this node, not including its relationships to other nodes."""
if node.has_data or node.has_attrs:
ds_info = "\n" + repr(node.ds)
ds_info = "\n" + repr(node._to_dataset_view(rebuild_dims=False))
else:
ds_info = ""
return f"Group: {node.path}{ds_info}"
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/formatting_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def summarize_datatree_children(children: Mapping[str, DataTree]) -> str:
def datatree_node_repr(group_title: str, dt: DataTree) -> str:
header_components = [f"<div class='xr-obj-type'>{escape(group_title)}</div>"]

ds = dt.ds
ds = dt._to_dataset_view(rebuild_dims=False)

sections = [
children_section(dt.children),
Expand Down
16 changes: 9 additions & 7 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,14 @@ def _attach(self, parent: Tree | None, child_name: str | None = None) -> None:
"To directly set parent, child needs a name, but child is unnamed"
)

self._pre_attach(parent)
self._pre_attach(parent, child_name)
parentchildren = parent._children
assert not any(
child is self for child in parentchildren
), "Tree is corrupt."
parentchildren[child_name] = self
self._parent = parent
self._post_attach(parent)
self._post_attach(parent, child_name)
else:
self._parent = None

Expand Down Expand Up @@ -415,11 +415,11 @@ def _post_detach(self: Tree, parent: Tree) -> None:
"""Method call after detaching from `parent`."""
pass

def _pre_attach(self: Tree, parent: Tree) -> None:
def _pre_attach(self: Tree, parent: Tree, name: str) -> None:
"""Method call before attaching to `parent`."""
pass

def _post_attach(self: Tree, parent: Tree) -> None:
def _post_attach(self: Tree, parent: Tree, name: str) -> None:
"""Method call after attaching to `parent`."""
pass

Expand Down Expand Up @@ -567,6 +567,9 @@ def same_tree(self, other: Tree) -> bool:
return self.root is other.root


AnyNamedNode = TypeVar("AnyNamedNode", bound="NamedNode")


class NamedNode(TreeNode, Generic[Tree]):
"""
A TreeNode which knows its own name.
Expand Down Expand Up @@ -606,10 +609,9 @@ def __repr__(self, level=0):
def __str__(self) -> str:
return f"NamedNode('{self.name}')" if self.name else "NamedNode()"

def _post_attach(self: NamedNode, parent: NamedNode) -> None:
def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None:
"""Ensures child has name attribute corresponding to key under which it has been stored."""
key = next(k for k, v in parent.children.items() if v is self)
self.name = key
self.name = name

@property
def path(self) -> str:
Expand Down
42 changes: 37 additions & 5 deletions xarray/tests/test_backends_datatree.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import pytest

import xarray as xr
from xarray.backends.api import open_datatree
from xarray.core.datatree import DataTree
from xarray.testing import assert_equal
from xarray.tests import (
requires_h5netcdf,
Expand All @@ -13,11 +15,11 @@
)

if TYPE_CHECKING:
from xarray.backends.api import T_NetcdfEngine
from xarray.core.datatree_io import T_DataTreeNetcdfEngine


class DatatreeIOBase:
engine: T_NetcdfEngine | None = None
engine: T_DataTreeNetcdfEngine | None = None

def test_to_netcdf(self, tmpdir, simple_datatree):
filepath = tmpdir / "test.nc"
Expand All @@ -27,6 +29,21 @@ def test_to_netcdf(self, tmpdir, simple_datatree):
roundtrip_dt = open_datatree(filepath, engine=self.engine)
assert_equal(original_dt, roundtrip_dt)

def test_to_netcdf_inherited_coords(self, tmpdir):
    """Round-trip a tree with a parent-defined coordinate through netCDF.

    The root group defines coordinate ``x``; the ``/sub`` group only uses
    dimension ``x``. After writing and re-opening, the trees must compare
    equal and ``x`` must not appear among the child's own (non-inherited)
    coordinates.
    """
    filepath = tmpdir / "test.nc"
    root_ds = xr.Dataset({"a": (("x",), [1, 2])}, coords={"x": [3, 4]})
    child_ds = xr.Dataset({"b": (("x",), [5, 6])})
    original_dt = DataTree.from_dict({"/": root_ds, "/sub": child_ds})
    original_dt.to_netcdf(filepath, engine=self.engine)

    roundtrip_dt = open_datatree(filepath, engine=self.engine)
    assert_equal(original_dt, roundtrip_dt)

    # Inspect only what is stored locally on the child node.
    subtree = cast(DataTree, roundtrip_dt["/sub"])
    local_coords = subtree.to_dataset(inherited=False).coords
    assert "x" not in local_coords

def test_netcdf_encoding(self, tmpdir, simple_datatree):
filepath = tmpdir / "test.nc"
original_dt = simple_datatree
Expand All @@ -48,12 +65,12 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):

@requires_netCDF4
class TestNetCDF4DatatreeIO(DatatreeIOBase):
engine: T_NetcdfEngine | None = "netcdf4"
engine: T_DataTreeNetcdfEngine | None = "netcdf4"


@requires_h5netcdf
class TestH5NetCDFDatatreeIO(DatatreeIOBase):
engine: T_NetcdfEngine | None = "h5netcdf"
engine: T_DataTreeNetcdfEngine | None = "h5netcdf"


@requires_zarr
Expand Down Expand Up @@ -119,3 +136,18 @@ def test_to_zarr_default_write_mode(self, tmpdir, simple_datatree):
# with default settings, to_zarr should not overwrite an existing dir
with pytest.raises(zarr.errors.ContainsGroupError):
simple_datatree.to_zarr(tmpdir)

def test_to_zarr_inherited_coords(self, tmpdir):
    """Round-trip a tree with a parent-defined coordinate through zarr.

    The root group defines coordinate ``x``; the ``/sub`` group only uses
    dimension ``x``. After writing and re-opening, the trees must compare
    equal and ``x`` must not appear among the child's own (non-inherited)
    coordinates.
    """
    root_ds = xr.Dataset({"a": (("x",), [1, 2])}, coords={"x": [3, 4]})
    child_ds = xr.Dataset({"b": (("x",), [5, 6])})
    original_dt = DataTree.from_dict({"/": root_ds, "/sub": child_ds})
    filepath = tmpdir / "test.zarr"
    original_dt.to_zarr(filepath)

    roundtrip_dt = open_datatree(filepath, engine="zarr")
    assert_equal(original_dt, roundtrip_dt)

    # Inspect only what is stored locally on the child node.
    subtree = cast(DataTree, roundtrip_dt["/sub"])
    local_coords = subtree.to_dataset(inherited=False).coords
    assert "x" not in local_coords
Loading

0 comments on commit a86c3ff

Please sign in to comment.