Skip to content

Commit

Permalink
Re-implement map_over_datasets using group_subtrees (#9636)
Browse files Browse the repository at this point in the history
* Add zip_subtrees for paired iteration over DataTrees

This should be used for implementing DataTree arithmetic inside
map_over_datasets, so the result does not depend on the order in which
child nodes are defined.

I have also added a minimal implementation of breadth-first-search with
an explicit queue the current recursion based solution in
xarray.core.iterators (which has been removed). The new implementation
is also slightly faster in my microbenchmark:

    In [1]: import xarray as xr

    In [2]: tree = xr.DataTree.from_dict({f"/x{i}": None for i in range(100)})

    In [3]: %timeit _ = list(tree.subtree)
    # on main
    87.2 μs ± 394 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

    # with this branch
    55.1 μs ± 294 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

* fix pytype error

* Re-implement map_over_datasets

The main changes:

- It is implemented using zip_subtrees, which means it should properly
  handle DataTrees where the nodes are defined in a different order.
- For simplicity, I removed handling of `**kwargs`, in order to preserve
  some flexibility for adding keyword arugments.
- I removed automatic skipping of empty nodes, because there are almost
  assuredly cases where that would make sense. This could be restored
  with a option keyword arugment.

* fix typing of map_over_datasets

* add group_subtrees

* wip fixes

* update isomorphic

* documentation and API change for map_over_datasets

* mypy fixes

* fix test

* diff formatting

* more mypy

* doc fix

* more doc fix

* add api docs

* add utility for joining path on windows

* docstring

* add an overload for two return values from map_over_datasets

* partial fixes per review

* fixes per review

* remove a couple of xfails
  • Loading branch information
shoyer authored Oct 21, 2024
1 parent 8f6e45b commit e58edcc
Show file tree
Hide file tree
Showing 19 changed files with 627 additions and 1,069 deletions.
4 changes: 3 additions & 1 deletion DATATREE_MIGRATION_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ This guide is for previous users of the prototype `datatree.DataTree` class in t
> [!IMPORTANT]
> There are breaking changes! You should not expect that code written with `xarray-contrib/datatree` will work without any modifications. At the absolute minimum you will need to change the top-level import statement, but there are other changes too.
We have made various changes compared to the prototype version. These can be split into three categories: data model changes, which affect the hierarchal structure itself, integration with xarray's IO backends; and minor API changes, which mostly consist of renaming methods to be more self-consistent.
We have made various changes compared to the prototype version. These can be split into three categories: data model changes, which affect the hierarchal structure itself; integration with xarray's IO backends; and minor API changes, which mostly consist of renaming methods to be more self-consistent.

### Data model changes

Expand All @@ -17,6 +17,8 @@ These alignment checks happen at tree construction time, meaning there are some

The alignment checks allowed us to add "Coordinate Inheritance", a much-requested feature where indexed coordinate variables are now "inherited" down to child nodes. This allows you to define common coordinates in a parent group that are then automatically available on every child node. The distinction between a locally-defined coordinate variables and an inherited coordinate that was defined on a parent node is reflected in the `DataTree.__repr__`. Generally if you prefer not to have these variables be inherited you can get more similar behaviour to the old `datatree` package by removing indexes from coordinates, as this prevents inheritance.

Tree structure checks between multiple trees (i.e., `DataTree.isomorophic`) and pairing of nodes in arithmetic has also changed. Nodes are now matched (with `xarray.group_subtrees`) based on their relative paths, without regard to the order in which child nodes are defined.

For further documentation see the page in the user guide on Hierarchical Data.

### Integrated backends
Expand Down
12 changes: 11 additions & 1 deletion doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,17 @@ Manipulate the contents of a single ``DataTree`` node.
DataTree.assign
DataTree.drop_nodes

DataTree Operations
-------------------

Apply operations over multiple ``DataTree`` objects.

.. autosummary::
:toctree: generated/

map_over_datasets
group_subtrees

Comparisons
-----------

Expand Down Expand Up @@ -954,7 +965,6 @@ DataTree methods

open_datatree
open_groups
map_over_datasets
DataTree.to_dict
DataTree.to_netcdf
DataTree.to_zarr
Expand Down
81 changes: 63 additions & 18 deletions doc/user-guide/hierarchical-data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -362,21 +362,26 @@ This returns an iterable of nodes, which yields them in depth-first order.
for node in vertebrates.subtree:
print(node.path)
A very useful pattern is to use :py:class:`~xarray.DataTree.subtree` conjunction with the :py:class:`~xarray.DataTree.path` property to manipulate the nodes however you wish,
then rebuild a new tree using :py:meth:`xarray.DataTree.from_dict()`.
Similarly, :py:class:`~xarray.DataTree.subtree_with_keys` returns an iterable of
relative paths and corresponding nodes.

A very useful pattern is to iterate over :py:class:`~xarray.DataTree.subtree_with_keys`
to manipulate nodes however you wish, then rebuild a new tree using
:py:meth:`xarray.DataTree.from_dict()`.
For example, we could keep only the nodes containing data by looping over all nodes,
checking if they contain any data using :py:class:`~xarray.DataTree.has_data`,
then rebuilding a new tree using only the paths of those nodes:

.. ipython:: python
non_empty_nodes = {node.path: node.dataset for node in dt.subtree if node.has_data}
non_empty_nodes = {
path: node.dataset for path, node in dt.subtree_with_keys if node.has_data
}
xr.DataTree.from_dict(non_empty_nodes)
You can see this tree is similar to the ``dt`` object above, except that it is missing the empty nodes ``a/c`` and ``a/c/d``.

(If you want to keep the name of the root node, you will need to add the ``name`` kwarg to :py:class:`~xarray.DataTree.from_dict`, i.e. ``DataTree.from_dict(non_empty_nodes, name=dt.root.name)``.)
(If you want to keep the name of the root node, you will need to add the ``name`` kwarg to :py:class:`~xarray.DataTree.from_dict`, i.e. ``DataTree.from_dict(non_empty_nodes, name=dt.name)``.)

.. _manipulating trees:

Expand Down Expand Up @@ -573,38 +578,78 @@ Then calculate the RMS value of these signals:
.. _multiple trees:

We can also use the :py:meth:`~xarray.map_over_datasets` decorator to promote a function which accepts datasets into one which
accepts datatrees.
We can also use :py:func:`~xarray.map_over_datasets` to apply a function over
the data in multiple trees, by passing the trees as positional arguments.

Operating on Multiple Trees
---------------------------

The examples so far have involved mapping functions or methods over the nodes of a single tree,
but we can generalize this to mapping functions over multiple trees at once.

Iterating Over Multiple Trees
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To iterate over the corresponding nodes in multiple trees, use
:py:func:`~xarray.group_subtrees` instead of
:py:class:`~xarray.DataTree.subtree_with_keys`. This combines well with
:py:meth:`xarray.DataTree.from_dict()` to build a new tree:

.. ipython:: python
dt1 = xr.DataTree.from_dict({"a": xr.Dataset({"x": 1}), "b": xr.Dataset({"x": 2})})
dt2 = xr.DataTree.from_dict(
{"a": xr.Dataset({"x": 10}), "b": xr.Dataset({"x": 20})}
)
result = {}
for path, (node1, node2) in xr.group_subtrees(dt1, dt2):
result[path] = node1.dataset + node2.dataset
xr.DataTree.from_dict(result)
Alternatively, you apply a function directly to paired datasets at every node
using :py:func:`xarray.map_over_datasets`:

.. ipython:: python
xr.map_over_datasets(lambda x, y: x + y, dt1, dt2)
Comparing Trees for Isomorphism
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For it to make sense to map a single non-unary function over the nodes of multiple trees at once,
each tree needs to have the same structure. Specifically two trees can only be considered similar, or "isomorphic",
if they have the same number of nodes, and each corresponding node has the same number of children.
We can check if any two trees are isomorphic using the :py:meth:`~xarray.DataTree.isomorphic` method.
each tree needs to have the same structure. Specifically two trees can only be considered similar,
or "isomorphic", if the full paths to all of their descendent nodes are the same.

Applying :py:func:`~xarray.group_subtrees` to trees with different structures
raises :py:class:`~xarray.TreeIsomorphismError`:

.. ipython:: python
:okexcept:
dt1 = xr.DataTree.from_dict({"a": None, "a/b": None})
dt2 = xr.DataTree.from_dict({"a": None})
dt1.isomorphic(dt2)
tree = xr.DataTree.from_dict({"a": None, "a/b": None, "a/c": None})
simple_tree = xr.DataTree.from_dict({"a": None})
for _ in xr.group_subtrees(tree, simple_tree):
...
We can explicitly also check if any two trees are isomorphic using the :py:meth:`~xarray.DataTree.isomorphic` method:

.. ipython:: python
tree.isomorphic(simple_tree)
dt3 = xr.DataTree.from_dict({"a": None, "b": None})
dt1.isomorphic(dt3)
Corresponding tree nodes do not need to have the same data in order to be considered isomorphic:

dt4 = xr.DataTree.from_dict({"A": None, "A/B": xr.Dataset({"foo": 1})})
dt1.isomorphic(dt4)
.. ipython:: python
tree_with_data = xr.DataTree.from_dict({"a": xr.Dataset({"foo": 1})})
simple_tree.isomorphic(tree_with_data)
They also do not need to define child nodes in the same order:

.. ipython:: python
If the trees are not isomorphic a :py:class:`~xarray.TreeIsomorphismError` will be raised.
Notice that corresponding tree nodes do not need to have the same name or contain the same data in order to be considered isomorphic.
reordered_tree = xr.DataTree.from_dict({"a": None, "a/c": None, "a/b": None})
tree.isomorphic(reordered_tree)
Arithmetic Between Multiple Trees
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ New Features
~~~~~~~~~~~~
- ``DataTree`` related functionality is now exposed in the main ``xarray`` public
API. This includes: ``xarray.DataTree``, ``xarray.open_datatree``, ``xarray.open_groups``,
``xarray.map_over_datasets``, ``xarray.register_datatree_accessor`` and
``xarray.testing.assert_isomorphic``.
``xarray.map_over_datasets``, ``xarray.group_subtrees``,
``xarray.register_datatree_accessor`` and ``xarray.testing.assert_isomorphic``.
By `Owen Littlejohns <https://github.com/owenlittlejohns>`_,
`Eni Awowale <https://github.com/eni-awowale>`_,
`Matt Savoie <https://github.com/flamingbear>`_,
Expand Down
10 changes: 8 additions & 2 deletions xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
from xarray.core.datatree_mapping import TreeIsomorphismError, map_over_datasets
from xarray.core.datatree_mapping import map_over_datasets
from xarray.core.extensions import (
register_dataarray_accessor,
register_dataset_accessor,
Expand All @@ -45,7 +45,12 @@
from xarray.core.merge import Context, MergeError, merge
from xarray.core.options import get_options, set_options
from xarray.core.parallel import map_blocks
from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError
from xarray.core.treenode import (
InvalidTreeError,
NotFoundInTreeError,
TreeIsomorphismError,
group_subtrees,
)
from xarray.core.variable import IndexVariable, Variable, as_variable
from xarray.namedarray.core import NamedArray
from xarray.util.print_versions import show_versions
Expand Down Expand Up @@ -82,6 +87,7 @@
"cross",
"full_like",
"get_options",
"group_subtrees",
"infer_freq",
"load_dataarray",
"load_dataset",
Expand Down
15 changes: 1 addition & 14 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from xarray.core.merge import merge_attrs, merge_coordinates_without_align
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set
from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set, result_name
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
Expand All @@ -46,7 +46,6 @@
MissingCoreDimOptions = Literal["raise", "copy", "drop"]

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})


Expand Down Expand Up @@ -186,18 +185,6 @@ def _enumerate(dim):
return str(alt_signature)


def result_name(objects: Iterable[Any]) -> Any:
# use the same naming heuristics as pandas:
# https://github.com/blaze/blaze/issues/458#issuecomment-51936356
names = {getattr(obj, "name", _DEFAULT_NAME) for obj in objects}
names.discard(_DEFAULT_NAME)
if len(names) == 1:
(name,) = names
else:
name = None
return name


def _get_coords_list(args: Iterable[Any]) -> list[Coordinates]:
coords_list = []
for arg in args:
Expand Down
12 changes: 2 additions & 10 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
either_dict_or_kwargs,
hashable,
infix_dims,
result_name,
)
from xarray.core.variable import (
IndexVariable,
Expand Down Expand Up @@ -4726,15 +4727,6 @@ def identical(self, other: Self) -> bool:
except (TypeError, AttributeError):
return False

def _result_name(self, other: Any = None) -> Hashable | None:
# use the same naming heuristics as pandas:
# https://github.com/ContinuumIO/blaze/issues/458#issuecomment-51936356
other_name = getattr(other, "name", _default)
if other_name is _default or other_name == self.name:
return self.name
else:
return None

def __array_wrap__(self, obj, context=None) -> Self:
new_var = self.variable.__array_wrap__(obj, context)
return self._replace(new_var)
Expand Down Expand Up @@ -4782,7 +4774,7 @@ def _binary_op(
else f(other_variable_or_arraylike, self.variable)
)
coords, indexes = self.coords._merge_raw(other_coords, reflexive)
name = self._result_name(other)
name = result_name([self, other])

return self._replace(variable, coords, name, indexes=indexes)

Expand Down
Loading

0 comments on commit e58edcc

Please sign in to comment.