Skip to content

Commit

Permalink
Updates to DataTree.equals and DataTree.identical (#9627)
Browse files Browse the repository at this point in the history
* Updates to DataTree.equals and DataTree.identical

In contrast to `equals`, `identical` now also checks that any
inherited variables are inherited on both objects. However, they do
not need to be inherited from the same source. This aligns the
behavior of `identical` with the DataTree `__repr__`.

I've also removed the `from_root` argument from `equals` and `identical`.
If a user wants to compare trees from their roots, a better (simpler)
inference is to simply call these methods on the `.root` properties.
I would also like to remove the `strict_names` argument, but that will
require switching to use the new `zip_subtrees` (#9623) first.

* More efficient check for inherited coordinates
  • Loading branch information
shoyer authored Oct 16, 2024
1 parent 3c01ced commit de3fce8
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 71 deletions.
44 changes: 22 additions & 22 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,61 +1252,61 @@ def isomorphic(
except (TypeError, TreeIsomorphismError):
return False

def equals(self, other: DataTree, from_root: bool = True) -> bool:
def equals(self, other: DataTree) -> bool:
"""
Two DataTrees are equal if they have isomorphic node structures, with matching node names,
and if they have matching variables and coordinates, all of which are equal.
By default this method will check the whole tree above the given node.
Two DataTrees are equal if they have isomorphic node structures, with
matching node names, and if they have matching variables and
coordinates, all of which are equal.
Parameters
----------
other : DataTree
The other tree object to compare to.
from_root : bool, optional, default is True
Whether or not to first traverse to the root of the two trees before checking for isomorphism.
If neither tree has a parent then this has no effect.
See Also
--------
Dataset.equals
DataTree.isomorphic
DataTree.identical
"""
if not self.isomorphic(other, from_root=from_root, strict_names=True):
if not self.isomorphic(other, strict_names=True):
return False

return all(
[
node.dataset.equals(other_node.dataset)
for node, other_node in zip(self.subtree, other.subtree, strict=True)
]
node.dataset.equals(other_node.dataset)
for node, other_node in zip(self.subtree, other.subtree, strict=True)
)

def identical(self, other: DataTree, from_root=True) -> bool:
"""
Like equals, but will also check all dataset attributes and the attributes on
all variables and coordinates.
def _inherited_coords_set(self) -> set[str]:
return set(self.parent.coords if self.parent else [])

By default this method will check the whole tree above the given node.
def identical(self, other: DataTree) -> bool:
"""
Like equals, but also checks attributes on all datasets, variables and
coordinates, and requires that any inherited coordinates at the tree
root are also inherited on the other tree.
Parameters
----------
other : DataTree
The other tree object to compare to.
from_root : bool, optional, default is True
Whether or not to first traverse to the root of the two trees before checking for isomorphism.
If neither tree has a parent then this has no effect.
See Also
--------
Dataset.identical
DataTree.isomorphic
DataTree.equals
"""
if not self.isomorphic(other, from_root=from_root, strict_names=True):
if not self.isomorphic(other, strict_names=True):
return False

if self.name != other.name:
return False

if self._inherited_coords_set() != other._inherited_coords_set():
return False

# TODO: switch to zip_subtrees, when available
return all(
node.dataset.identical(other_node.dataset)
for node, other_node in zip(self.subtree, other.subtree, strict=True)
Expand Down
45 changes: 4 additions & 41 deletions xarray/testing/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import functools
import warnings
from collections.abc import Hashable
from typing import overload

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -107,16 +106,8 @@ def maybe_transpose_dims(a, b, check_dim_order: bool):
return b


@overload
def assert_equal(a, b): ...


@overload
def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ...


@ensure_warnings
def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
def assert_equal(a, b, check_dim_order: bool = True):
"""Like :py:func:`numpy.testing.assert_array_equal`, but for xarray
objects.
Expand All @@ -135,10 +126,6 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
or xarray.core.datatree.DataTree. The first object to compare.
b : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates
or xarray.core.datatree.DataTree. The second object to compare.
from_root : bool, optional, default is True
Only used when comparing DataTree objects. Indicates whether or not to
first traverse to the root of the trees before checking for isomorphism.
If a & b have no parents then this has no effect.
check_dim_order : bool, optional, default is True
Whether dimensions must be in the same order.
Expand All @@ -159,25 +146,13 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
elif isinstance(a, Coordinates):
assert a.equals(b), formatting.diff_coords_repr(a, b, "equals")
elif isinstance(a, DataTree):
if from_root:
a = a.root
b = b.root

assert a.equals(b, from_root=from_root), diff_datatree_repr(a, b, "equals")
assert a.equals(b), diff_datatree_repr(a, b, "equals")
else:
raise TypeError(f"{type(a)} not supported by assertion comparison")


@overload
def assert_identical(a, b): ...


@overload
def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): ...


@ensure_warnings
def assert_identical(a, b, from_root=True):
def assert_identical(a, b):
"""Like :py:func:`xarray.testing.assert_equal`, but also matches the
objects' names and attributes.
Expand All @@ -193,12 +168,6 @@ def assert_identical(a, b, from_root=True):
The first object to compare.
b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates
The second object to compare.
from_root : bool, optional, default is True
Only used when comparing DataTree objects. Indicates whether or not to
first traverse to the root of the trees before checking for isomorphism.
If a & b have no parents then this has no effect.
check_dim_order : bool, optional, default is True
Whether dimensions must be in the same order.
See Also
--------
Expand All @@ -220,13 +189,7 @@ def assert_identical(a, b, from_root=True):
elif isinstance(a, Coordinates):
assert a.identical(b), formatting.diff_coords_repr(a, b, "identical")
elif isinstance(a, DataTree):
if from_root:
a = a.root
b = b.root

assert a.identical(b, from_root=from_root), diff_datatree_repr(
a, b, "identical"
)
assert a.identical(b), diff_datatree_repr(a, b, "identical")
else:
raise TypeError(f"{type(a)} not supported by assertion comparison")

Expand Down
118 changes: 111 additions & 7 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1538,6 +1538,110 @@ def f(x, tree, y):
assert actual is dt and actual.attrs == attrs


class TestEqualsAndIdentical:

def test_minimal_variations(self):
tree = DataTree.from_dict(
{
"/": Dataset({"x": 1}),
"/child": Dataset({"x": 2}),
}
)
assert tree.equals(tree)
assert tree.identical(tree)

child = tree.children["child"]
assert child.equals(child)
assert child.identical(child)

new_child = DataTree(dataset=Dataset({"x": 2}), name="child")
assert child.equals(new_child)
assert child.identical(new_child)

anonymous_child = DataTree(dataset=Dataset({"x": 2}))
# TODO: re-enable this after fixing .equals() not to require matching
# names on the root node (i.e., after switching to use zip_subtrees)
# assert child.equals(anonymous_child)
assert not child.identical(anonymous_child)

different_variables = DataTree.from_dict(
{
"/": Dataset(),
"/other": Dataset({"x": 2}),
}
)
assert not tree.equals(different_variables)
assert not tree.identical(different_variables)

different_root_data = DataTree.from_dict(
{
"/": Dataset({"x": 4}),
"/child": Dataset({"x": 2}),
}
)
assert not tree.equals(different_root_data)
assert not tree.identical(different_root_data)

different_child_data = DataTree.from_dict(
{
"/": Dataset({"x": 1}),
"/child": Dataset({"x": 3}),
}
)
assert not tree.equals(different_child_data)
assert not tree.identical(different_child_data)

different_child_node_attrs = DataTree.from_dict(
{
"/": Dataset({"x": 1}),
"/child": Dataset({"x": 2}, attrs={"foo": "bar"}),
}
)
assert tree.equals(different_child_node_attrs)
assert not tree.identical(different_child_node_attrs)

different_child_variable_attrs = DataTree.from_dict(
{
"/": Dataset({"x": 1}),
"/child": Dataset({"x": ((), 2, {"foo": "bar"})}),
}
)
assert tree.equals(different_child_variable_attrs)
assert not tree.identical(different_child_variable_attrs)

different_name = DataTree.from_dict(
{
"/": Dataset({"x": 1}),
"/child": Dataset({"x": 2}),
},
name="different",
)
# TODO: re-enable this after fixing .equals() not to require matching
# names on the root node (i.e., after switching to use zip_subtrees)
# assert tree.equals(different_name)
assert not tree.identical(different_name)

def test_differently_inherited_coordinates(self):
root = DataTree.from_dict(
{
"/": Dataset(coords={"x": [1, 2]}),
"/child": Dataset(),
}
)
child = root.children["child"]
assert child.equals(child)
assert child.identical(child)

new_child = DataTree(dataset=Dataset(coords={"x": [1, 2]}), name="child")
assert child.equals(new_child)
assert not child.identical(new_child)

deeper_root = DataTree(children={"root": root})
grandchild = deeper_root["/root/child"]
assert child.equals(grandchild)
assert child.identical(grandchild)


class TestSubset:
def test_match(self) -> None:
# TODO is this example going to cause problems with case sensitivity?
Expand Down Expand Up @@ -1599,7 +1703,7 @@ def test_isel_siblings(self) -> None:
}
)
actual = tree.isel(x=-1)
assert_equal(actual, expected)
assert_identical(actual, expected)

expected = DataTree.from_dict(
{
Expand All @@ -1608,13 +1712,13 @@ def test_isel_siblings(self) -> None:
}
)
actual = tree.isel(x=slice(1))
assert_equal(actual, expected)
assert_identical(actual, expected)

actual = tree.isel(x=[0])
assert_equal(actual, expected)
assert_identical(actual, expected)

actual = tree.isel(x=slice(None))
assert_equal(actual, tree)
assert_identical(actual, tree)

def test_isel_inherited(self) -> None:
tree = DataTree.from_dict(
Expand All @@ -1631,15 +1735,15 @@ def test_isel_inherited(self) -> None:
}
)
actual = tree.isel(x=-1)
assert_equal(actual, expected)
assert_identical(actual, expected)

expected = DataTree.from_dict(
{
"/child": xr.Dataset({"foo": 4}),
}
)
actual = tree.isel(x=-1, drop=True)
assert_equal(actual, expected)
assert_identical(actual, expected)

expected = DataTree.from_dict(
{
Expand All @@ -1648,7 +1752,7 @@ def test_isel_inherited(self) -> None:
}
)
actual = tree.isel(x=[0])
assert_equal(actual, expected)
assert_identical(actual, expected)

actual = tree.isel(x=slice(None))

Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def times_ten(ds):

expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"]
result_tree = times_ten(subtree)
assert_equal(result_tree, expected, from_root=False)
assert_equal(result_tree, expected)

def test_skip_empty_nodes_with_attrs(self, create_test_datatree):
# inspired by xarray-datatree GH262
Expand Down

0 comments on commit de3fce8

Please sign in to comment.