Skip to content

Commit

Permalink
Migrate datatreee assertions/extensions/formatting (#8967)
Browse files Browse the repository at this point in the history
* DAS-2067 - Migrate formatting.py.

* DAS-2067 - Migrate datatree/extensions.py.

* DAS-2067 - Migrate datatree/tests/test_dataset_api.py.

* DAS-2067 - Migrate datatree_render.py.

* DAS-2067 - Migrate DataTree assertions into xarray/testing/assertions.py.

* DAS-2067 - Update doc/whats-new.rst.

* DAS-2067 - Fix doctests for DataTreeRender.by_attr.

* DAS-2067 - Fix comments in doctests examples for datatree_render.

* DAS-2067 - Implement PR feedback, fix RenderDataTree.__str__.

* DAS-2067 - Add overload for xarray.testing.assert_equal and xarray.testing.assert_identical.

* DAS-2067 - Remove out-of-date comments.

* Remove test of printing datatree

---------

Co-authored-by: Tom Nicholas <[email protected]>
  • Loading branch information
owenlittlejohns and TomNicholas authored Apr 26, 2024
1 parent 8a23e24 commit 6d62719
Show file tree
Hide file tree
Showing 23 changed files with 763 additions and 945 deletions.
9 changes: 7 additions & 2 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,17 @@ Bug fixes

Internal Changes
~~~~~~~~~~~~~~~~
- Migrates ``formatting_html`` functionality for `DataTree` into ``xarray/core`` (:pull: `8930`)
- Migrates ``formatting_html`` functionality for ``DataTree`` into ``xarray/core`` (:pull: `8930`)
By `Eni Awowale <https://github.com/eni-awowale>`_, `Julia Signell <https://github.com/jsignell>`_
and `Tom Nicholas <https://github.com/TomNicholas>`_.
- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`)
By `Matt Savoie <https://github.com/flamingbear>`_ `Owen Littlejohns
<https://github.com/owenlittlejohns>` and `Tom Nicholas <https://github.com/TomNicholas>`_.
<https://github.com/owenlittlejohns>`_ and `Tom Nicholas <https://github.com/TomNicholas>`_.
- Migrates ``extensions``, ``formatting`` and ``datatree_render`` functionality for
``DataTree`` into ``xarray/core``. Also migrates ``testing`` functionality into
``xarray/testing/assertions`` for ``DataTree``. (:pull:`8967`)
By `Owen Littlejohns <https://github.com/owenlittlejohns>`_ and
`Tom Nicholas <https://github.com/TomNicholas>`_.


.. _whats-new.2024.03.0:
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
check_isomorphic,
map_over_subtree,
)
from xarray.core.datatree_render import RenderDataTree
from xarray.core.formatting import datatree_repr
from xarray.core.formatting_html import (
datatree_repr as datatree_repr_html,
)
Expand All @@ -40,13 +42,11 @@
)
from xarray.core.variable import Variable
from xarray.datatree_.datatree.common import TreeAttrAccessMixin
from xarray.datatree_.datatree.formatting import datatree_repr
from xarray.datatree_.datatree.ops import (
DataTreeArithmeticMixin,
MappedDatasetMethodsMixin,
MappedDataWithCoords,
)
from xarray.datatree_.datatree.render import RenderTree

try:
from xarray.core.variable import calculate_dimensions
Expand Down Expand Up @@ -1451,7 +1451,7 @@ def pipe(

def render(self):
"""Print tree structure, including any data stored at each node."""
for pre, fill, node in RenderTree(self):
for pre, fill, node in RenderDataTree(self):
print(f"{pre}DataTree('{self.name}')")
for ds_line in repr(node.ds)[1:]:
print(f"{fill}{ds_line}")
Expand Down
37 changes: 3 additions & 34 deletions xarray/core/datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import functools
import sys
from itertools import repeat
from textwrap import dedent
from typing import TYPE_CHECKING, Callable

from xarray import DataArray, Dataset
from xarray.core.iterators import LevelOrderIter
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.formatting import diff_treestructure
from xarray.core.treenode import NodePath, TreeNode

if TYPE_CHECKING:
Expand Down Expand Up @@ -71,37 +71,6 @@ def check_isomorphic(
raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff)


def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str:
"""
Return a summary of why two trees are not isomorphic.
If they are isomorphic return an empty string.
"""

# Walking nodes in "level-order" fashion means walking down from the root breadth-first.
# Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree
# (which it is so long as children are stored in a tuple or list rather than in a set).
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)):
path_a, path_b = node_a.path, node_b.path

if require_names_equal and node_a.name != node_b.name:
diff = dedent(
f"""\
Node '{path_a}' in the left object has name '{node_a.name}'
Node '{path_b}' in the right object has name '{node_b.name}'"""
)
return diff

if len(node_a.children) != len(node_b.children):
diff = dedent(
f"""\
Number of children on node '{path_a}' of the left object: {len(node_a.children)}
Number of children on node '{path_b}' of the right object: {len(node_b.children)}"""
)
return diff

return ""


def map_over_subtree(func: Callable) -> Callable:
"""
Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.
Expand Down
266 changes: 266 additions & 0 deletions xarray/core/datatree_render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
"""
String Tree Rendering. Copied from anytree.
Minor changes to `RenderDataTree` include accessing `children.values()`, and
type hints.
"""

from __future__ import annotations

from collections import namedtuple
from collections.abc import Iterable, Iterator
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from xarray.core.datatree import DataTree

Row = namedtuple("Row", ("pre", "fill", "node"))


class AbstractStyle:
def __init__(self, vertical: str, cont: str, end: str):
"""
Tree Render Style.
Args:
vertical: Sign for vertical line.
cont: Chars for a continued branch.
end: Chars for the last branch.
"""
super().__init__()
self.vertical = vertical
self.cont = cont
self.end = end
assert (
len(cont) == len(vertical) == len(end)
), f"'{vertical}', '{cont}' and '{end}' need to have equal length"

@property
def empty(self) -> str:
"""Empty string as placeholder."""
return " " * len(self.end)

def __repr__(self) -> str:
return f"{self.__class__.__name__}()"


class ContStyle(AbstractStyle):
def __init__(self):
"""
Continued style, without gaps.
>>> from xarray.core.datatree import DataTree
>>> from xarray.core.datatree_render import RenderDataTree
>>> root = DataTree(name="root")
>>> s0 = DataTree(name="sub0", parent=root)
>>> s0b = DataTree(name="sub0B", parent=s0)
>>> s0a = DataTree(name="sub0A", parent=s0)
>>> s1 = DataTree(name="sub1", parent=root)
>>> print(RenderDataTree(root))
DataTree('root', parent=None)
├── DataTree('sub0')
│ ├── DataTree('sub0B')
│ └── DataTree('sub0A')
└── DataTree('sub1')
"""
super().__init__("\u2502 ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 ")


class RenderDataTree:
def __init__(
self,
node: DataTree,
style=ContStyle(),
childiter: type = list,
maxlevel: int | None = None,
):
"""
Render tree starting at `node`.
Keyword Args:
style (AbstractStyle): Render Style.
childiter: Child iterator. Note, due to the use of node.children.values(),
Iterables that change the order of children cannot be used
(e.g., `reversed`).
maxlevel: Limit rendering to this depth.
:any:`RenderDataTree` is an iterator, returning a tuple with 3 items:
`pre`
tree prefix.
`fill`
filling for multiline entries.
`node`
:any:`NodeMixin` object.
It is up to the user to assemble these parts to a whole.
Examples
--------
>>> from xarray import Dataset
>>> from xarray.core.datatree import DataTree
>>> from xarray.core.datatree_render import RenderDataTree
>>> root = DataTree(name="root", data=Dataset({"a": 0, "b": 1}))
>>> s0 = DataTree(name="sub0", parent=root, data=Dataset({"c": 2, "d": 3}))
>>> s0b = DataTree(name="sub0B", parent=s0, data=Dataset({"e": 4}))
>>> s0a = DataTree(name="sub0A", parent=s0, data=Dataset({"f": 5, "g": 6}))
>>> s1 = DataTree(name="sub1", parent=root, data=Dataset({"h": 7}))
# Simple one line:
>>> for pre, _, node in RenderDataTree(root):
... print(f"{pre}{node.name}")
...
root
├── sub0
│ ├── sub0B
│ └── sub0A
└── sub1
# Multiline:
>>> for pre, fill, node in RenderDataTree(root):
... print(f"{pre}{node.name}")
... for variable in node.variables:
... print(f"{fill}{variable}")
...
root
a
b
├── sub0
│ c
│ d
│ ├── sub0B
│ │ e
│ └── sub0A
│ f
│ g
└── sub1
h
:any:`by_attr` simplifies attribute rendering and supports multiline:
>>> print(RenderDataTree(root).by_attr())
root
├── sub0
│ ├── sub0B
│ └── sub0A
└── sub1
# `maxlevel` limits the depth of the tree:
>>> print(RenderDataTree(root, maxlevel=2).by_attr("name"))
root
├── sub0
└── sub1
"""
if not isinstance(style, AbstractStyle):
style = style()
self.node = node
self.style = style
self.childiter = childiter
self.maxlevel = maxlevel

def __iter__(self) -> Iterator[Row]:
return self.__next(self.node, tuple())

def __next(
self, node: DataTree, continues: tuple[bool, ...], level: int = 0
) -> Iterator[Row]:
yield RenderDataTree.__item(node, continues, self.style)
children = node.children.values()
level += 1
if children and (self.maxlevel is None or level < self.maxlevel):
children = self.childiter(children)
for child, is_last in _is_last(children):
yield from self.__next(child, continues + (not is_last,), level=level)

@staticmethod
def __item(
node: DataTree, continues: tuple[bool, ...], style: AbstractStyle
) -> Row:
if not continues:
return Row("", "", node)
else:
items = [style.vertical if cont else style.empty for cont in continues]
indent = "".join(items[:-1])
branch = style.cont if continues[-1] else style.end
pre = indent + branch
fill = "".join(items)
return Row(pre, fill, node)

def __str__(self) -> str:
return str(self.node)

def __repr__(self) -> str:
classname = self.__class__.__name__
args = [
repr(self.node),
f"style={repr(self.style)}",
f"childiter={repr(self.childiter)}",
]
return f"{classname}({', '.join(args)})"

def by_attr(self, attrname: str = "name") -> str:
"""
Return rendered tree with node attribute `attrname`.
Examples
--------
>>> from xarray import Dataset
>>> from xarray.core.datatree import DataTree
>>> from xarray.core.datatree_render import RenderDataTree
>>> root = DataTree(name="root")
>>> s0 = DataTree(name="sub0", parent=root)
>>> s0b = DataTree(
... name="sub0B", parent=s0, data=Dataset({"foo": 4, "bar": 109})
... )
>>> s0a = DataTree(name="sub0A", parent=s0)
>>> s1 = DataTree(name="sub1", parent=root)
>>> s1a = DataTree(name="sub1A", parent=s1)
>>> s1b = DataTree(name="sub1B", parent=s1, data=Dataset({"bar": 8}))
>>> s1c = DataTree(name="sub1C", parent=s1)
>>> s1ca = DataTree(name="sub1Ca", parent=s1c)
>>> print(RenderDataTree(root).by_attr("name"))
root
├── sub0
│ ├── sub0B
│ └── sub0A
└── sub1
├── sub1A
├── sub1B
└── sub1C
└── sub1Ca
"""

def get() -> Iterator[str]:
for pre, fill, node in self:
attr = (
attrname(node)
if callable(attrname)
else getattr(node, attrname, "")
)
if isinstance(attr, (list, tuple)):
lines = attr
else:
lines = str(attr).split("\n")
yield f"{pre}{lines[0]}"
for line in lines[1:]:
yield f"{fill}{line}"

return "\n".join(get())


def _is_last(iterable: Iterable) -> Iterator[tuple[DataTree, bool]]:
iter_ = iter(iterable)
try:
nextitem = next(iter_)
except StopIteration:
pass
else:
item = nextitem
while True:
try:
nextitem = next(iter_)
yield item, False
except StopIteration:
yield nextitem, True
break
item = nextitem
Loading

0 comments on commit 6d62719

Please sign in to comment.