Skip to content

Commit

Permalink
Add inherit=False option to DataTree.copy() (#9628)
Browse files Browse the repository at this point in the history
* Add inherit=False option to DataTree.copy()

This PR adds an inherit=False option to DataTree.copy, so users can
decide if they want to inherit coordinates from parents or not when
creating a subtree.

The default behavior is `inherit=True`, which is a breaking change from
the current behavior where parent coordinates are dropped (which I
believe should be considered a bug).

* fix typing

* add migration guide note

* ignore typing error
  • Loading branch information
shoyer authored Oct 15, 2024
1 parent c3dabe1 commit 56f0e48
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 43 deletions.
3 changes: 2 additions & 1 deletion DATATREE_MIGRATION_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ A number of other API changes have been made, which should only require minor mo
- The top-level import has changed, from `from datatree import DataTree, open_datatree` to `from xarray import DataTree, open_datatree`. Alternatively you can now just use the `import xarray as xr` namespace convention for everything datatree-related.
- The `DataTree.ds` property has been changed to `DataTree.dataset`, though `DataTree.ds` remains as an alias for `DataTree.dataset`.
- Similarly the `ds` kwarg in the `DataTree.__init__` constructor has been replaced by `dataset`, i.e. use `DataTree(dataset=...)` instead of `DataTree(ds=...)`.
- The method `DataTree.to_dataset()` still exists but now has different options for controlling which variables are present on the resulting `Dataset`, e.g. `inherited=True/False`.
- The method `DataTree.to_dataset()` still exists but now has different options for controlling which variables are present on the resulting `Dataset`, e.g. `inherit=True/False`.
- `DataTree.copy()` also has a new `inherit` keyword argument for controlling whether or not coordinates defined on parents are copied (only relevant when copying a non-root node).
- The `DataTree.parent` property is now read-only. To assign ancestral relationships directly you must instead use the `.children` property on the parent node, which remains settable.
- Similarly the `parent` kwarg has been removed from the `DataTree.__init__` constructor.
- DataTree objects passed to the `children` kwarg in `DataTree.__init__` are now shallow-copied.
Expand Down
18 changes: 7 additions & 11 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,19 +826,13 @@ def _replace_node(

self.children = children

def _copy_node(
self: DataTree,
deep: bool = False,
) -> DataTree:
"""Copy just one node of a tree"""

new_node = super()._copy_node()

data = self._to_dataset_view(rebuild_dims=False, inherit=False)
def _copy_node(self, inherit: bool, deep: bool = False) -> Self:
"""Copy just one node of a tree."""
new_node = super()._copy_node(inherit=inherit, deep=deep)
data = self._to_dataset_view(rebuild_dims=False, inherit=inherit)
if deep:
data = data.copy(deep=True)
new_node._set_node_data(data)

return new_node

def get( # type: ignore[override]
Expand Down Expand Up @@ -1159,7 +1153,9 @@ def depth(item) -> int:
new_nodes_along_path=True,
)

return obj
# TODO: figure out why mypy is raising an error here, likely something
# to do with the return type of Dataset.copy()
return obj # type: ignore[return-value]

def to_dict(self) -> dict[str, Dataset]:
"""
Expand Down
50 changes: 21 additions & 29 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TypeVar,
)

from xarray.core.types import Self
from xarray.core.utils import Frozen, is_dict_like

if TYPE_CHECKING:
Expand Down Expand Up @@ -238,10 +239,7 @@ def _post_attach_children(self: Tree, children: Mapping[str, Tree]) -> None:
"""Method call after attaching `children`."""
pass

def copy(
self: Tree,
deep: bool = False,
) -> Tree:
def copy(self, *, inherit: bool = True, deep: bool = False) -> Self:
"""
Returns a copy of this subtree.
Expand All @@ -254,7 +252,12 @@ def copy(
Parameters
----------
deep : bool, default: False
inherit : bool
Whether inherited coordinates defined on parents of this node should
also be copied onto the new tree. Only relevant if the `parent` of
this node is not None, and "Inherited coordinates" appear in its
repr.
deep : bool
Whether each component variable is loaded into memory and copied onto
the new object. Default is False.
Expand All @@ -269,35 +272,27 @@ def copy(
xarray.Dataset.copy
pandas.DataFrame.copy
"""
return self._copy_subtree(deep=deep)
return self._copy_subtree(inherit=inherit, deep=deep)

def _copy_subtree(
self: Tree,
deep: bool = False,
memo: dict[int, Any] | None = None,
) -> Tree:
def _copy_subtree(self, inherit: bool, deep: bool = False) -> Self:
"""Copy entire subtree recursively."""

new_tree = self._copy_node(deep=deep)
new_tree = self._copy_node(inherit=inherit, deep=deep)
for name, child in self.children.items():
# TODO use `.children[name] = ...` once #9477 is implemented
new_tree._set(name, child._copy_subtree(deep=deep))

new_tree._set(name, child._copy_subtree(inherit=False, deep=deep))
return new_tree

def _copy_node(
self: Tree,
deep: bool = False,
) -> Tree:
def _copy_node(self, inherit: bool, deep: bool = False) -> Self:
"""Copy just one node of a tree"""
new_empty_node = type(self)()
return new_empty_node

def __copy__(self: Tree) -> Tree:
return self._copy_subtree(deep=False)
def __copy__(self) -> Self:
return self._copy_subtree(inherit=True, deep=False)

def __deepcopy__(self: Tree, memo: dict[int, Any] | None = None) -> Tree:
return self._copy_subtree(deep=True, memo=memo)
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
del memo # nodes cannot be reused in a DataTree
return self._copy_subtree(inherit=True, deep=True)

def _iter_parents(self: Tree) -> Iterator[Tree]:
"""Iterate up the tree, starting from the current node's parent."""
Expand Down Expand Up @@ -693,17 +688,14 @@ def __str__(self) -> str:
name_repr = repr(self.name) if self.name is not None else ""
return f"NamedNode({name_repr})"

def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None:
def _post_attach(self, parent: Self, name: str) -> None:
"""Ensures child has name attribute corresponding to key under which it has been stored."""
_validate_name(name) # is this check redundant?
self._name = name

def _copy_node(
self: AnyNamedNode,
deep: bool = False,
) -> AnyNamedNode:
def _copy_node(self, inherit: bool, deep: bool = False) -> Self:
"""Copy just one node of a tree"""
new_node = super()._copy_node()
new_node = super()._copy_node(inherit=inherit, deep=deep)
new_node._name = self.name
return new_node

Expand Down
12 changes: 10 additions & 2 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,18 @@ def test_copy_coord_inheritance(self) -> None:
tree = DataTree.from_dict(
{"/": xr.Dataset(coords={"x": [0, 1]}), "/c": DataTree()}
)
tree2 = tree.copy()
node_ds = tree2.children["c"].to_dataset(inherit=False)
actual = tree.copy()
node_ds = actual.children["c"].to_dataset(inherit=False)
assert_identical(node_ds, xr.Dataset())

actual = tree.children["c"].copy()
expected = DataTree(Dataset(coords={"x": [0, 1]}), name="c")
assert_identical(expected, actual)

actual = tree.children["c"].copy(inherit=False)
expected = DataTree(name="c")
assert_identical(expected, actual)

def test_deepcopy(self, create_test_datatree):
dt = create_test_datatree()

Expand Down

0 comments on commit 56f0e48

Please sign in to comment.