Skip to content

Commit

Permalink
Rename DataTree's "ds" and "data" to "dataset" (pydata#9476)
Browse files Browse the repository at this point in the history
* Rename DataTree's "ds" and "data" to "dataset"

.ds is kept around as a soft-deprecated alias to facilitate the
transition from xarray-contrib/datatree, though I verified that all
tests pass without it.

* fix data= usage in test_formatting_html.py

* fix formatting test
  • Loading branch information
shoyer authored Sep 11, 2024
1 parent a6bacfe commit fac2c89
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 92 deletions.
2 changes: 1 addition & 1 deletion xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,7 @@ def open_datatree(
ds = open_dataset(
filename_or_obj, store=store, group=path_group, engine="zarr", **kwargs
)
new_node: DataTree = DataTree(name=NodePath(path_group).name, data=ds)
new_node = DataTree(name=NodePath(path_group).name, dataset=ds)
tree_root._set_item(
path_group,
new_node,
Expand Down
66 changes: 35 additions & 31 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ class DataTree(

def __init__(
self,
data: Dataset | None = None,
dataset: Dataset | None = None,
children: Mapping[str, DataTree] | None = None,
name: str | None = None,
):
Expand All @@ -430,12 +430,12 @@ def __init__(
Parameters
----------
data : Dataset, optional
Data to store under the .ds attribute of this node.
dataset : Dataset, optional
Data to store directly at this node.
children : Mapping[str, DataTree], optional
Any child nodes of this node. Default is None.
Any child nodes of this node.
name : str, optional
Name for this node of the tree. Default is None.
Name for this node of the tree.
Returns
-------
Expand All @@ -449,24 +449,24 @@ def __init__(
children = {}

super().__init__(name=name)
self._set_node_data(_to_new_dataset(data))
self._set_node_data(_to_new_dataset(dataset))

# shallow copy to avoid modifying arguments in-place (see GH issue #9196)
self.children = {name: child.copy() for name, child in children.items()}

def _set_node_data(self, ds: Dataset):
data_vars, coord_vars = _collect_data_and_coord_variables(ds)
def _set_node_data(self, dataset: Dataset):
data_vars, coord_vars = _collect_data_and_coord_variables(dataset)
self._data_variables = data_vars
self._node_coord_variables = coord_vars
self._node_dims = ds._dims
self._node_indexes = ds._indexes
self._encoding = ds._encoding
self._attrs = ds._attrs
self._close = ds._close
self._node_dims = dataset._dims
self._node_indexes = dataset._indexes
self._encoding = dataset._encoding
self._attrs = dataset._attrs
self._close = dataset._close

def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
super()._pre_attach(parent, name)
if name in parent.ds.variables:
if name in parent.dataset.variables:
raise KeyError(
f"parent {parent.name} already contains a variable named {name}"
)
Expand Down Expand Up @@ -534,7 +534,7 @@ def _to_dataset_view(self, rebuild_dims: bool, inherited: bool) -> DatasetView:
)

@property
def ds(self) -> DatasetView:
def dataset(self) -> DatasetView:
"""
An immutable Dataset-like view onto the data in this node.
Expand All @@ -549,11 +549,15 @@ def ds(self) -> DatasetView:
"""
return self._to_dataset_view(rebuild_dims=True, inherited=True)

@ds.setter
def ds(self, data: Dataset | None = None) -> None:
@dataset.setter
def dataset(self, data: Dataset | None = None) -> None:
ds = _to_new_dataset(data)
self._replace_node(ds)

# soft-deprecated alias, to facilitate the transition from
# xarray-contrib/datatree
ds = dataset

def to_dataset(self, inherited: bool = True) -> Dataset:
"""
Return the data in this node as a new xarray.Dataset object.
Expand All @@ -566,7 +570,7 @@ def to_dataset(self, inherited: bool = True) -> Dataset:
See Also
--------
DataTree.ds
DataTree.dataset
"""
coord_vars = self._coord_variables if inherited else self._node_coord_variables
variables = dict(self._data_variables)
Expand Down Expand Up @@ -845,8 +849,8 @@ def get( # type: ignore[override]
"""
if key in self.children:
return self.children[key]
elif key in self.ds:
return self.ds[key]
elif key in self.dataset:
return self.dataset[key]
else:
return default

Expand Down Expand Up @@ -1114,7 +1118,7 @@ def from_dict(
if isinstance(root_data, DataTree):
obj = root_data.copy()
elif root_data is None or isinstance(root_data, Dataset):
obj = cls(name=name, data=root_data, children=None)
obj = cls(name=name, dataset=root_data, children=None)
else:
raise TypeError(
f'root node data (at "/") must be a Dataset or DataTree, got {type(root_data)}'
Expand All @@ -1133,7 +1137,7 @@ def depth(item) -> int:
if isinstance(data, DataTree):
new_node = data.copy()
elif isinstance(data, Dataset) or data is None:
new_node = cls(name=node_name, data=data)
new_node = cls(name=node_name, dataset=data)
else:
raise TypeError(f"invalid values: {data}")
obj._set_item(
Expand Down Expand Up @@ -1264,7 +1268,7 @@ def equals(self, other: DataTree, from_root: bool = True) -> bool:

return all(
[
node.ds.equals(other_node.ds)
node.dataset.equals(other_node.dataset)
for node, other_node in zip(self.subtree, other.subtree, strict=True)
]
)
Expand Down Expand Up @@ -1294,7 +1298,7 @@ def identical(self, other: DataTree, from_root=True) -> bool:
return False

return all(
node.ds.identical(other_node.ds)
node.dataset.identical(other_node.dataset)
for node, other_node in zip(self.subtree, other.subtree, strict=True)
)

Expand All @@ -1321,7 +1325,7 @@ def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree:
map_over_subtree
"""
filtered_nodes = {
node.path: node.ds for node in self.subtree if filterfunc(node)
node.path: node.dataset for node in self.subtree if filterfunc(node)
}
return DataTree.from_dict(filtered_nodes, name=self.root.name)

Expand Down Expand Up @@ -1365,7 +1369,7 @@ def match(self, pattern: str) -> DataTree:
└── Group: /b/B
"""
matching_nodes = {
node.path: node.ds
node.path: node.dataset
for node in self.subtree
if NodePath(node.path).match(pattern)
}
Expand All @@ -1389,7 +1393,7 @@ def map_over_subtree(
----------
func : callable
Function to apply to datasets with signature:
`func(node.ds, *args, **kwargs) -> Dataset`.
`func(node.dataset, *args, **kwargs) -> Dataset`.
Function will not be applied to any nodes without datasets.
*args : tuple, optional
Expand Down Expand Up @@ -1420,7 +1424,7 @@ def map_over_subtree_inplace(
----------
func : callable
Function to apply to datasets with signature:
`func(node.ds, *args, **kwargs) -> Dataset`.
`func(node.dataset, *args, **kwargs) -> Dataset`.
Function will not be applied to any nodes without datasets,
*args : tuple, optional
Expand All @@ -1433,7 +1437,7 @@ def map_over_subtree_inplace(

for node in self.subtree:
if node.has_data:
node.ds = func(node.ds, *args, **kwargs)
node.dataset = func(node.dataset, *args, **kwargs)

def pipe(
self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any
Expand Down Expand Up @@ -1499,7 +1503,7 @@ def render(self):
"""Print tree structure, including any data stored at each node."""
for pre, fill, node in RenderDataTree(self):
print(f"{pre}DataTree('{self.name}')")
for ds_line in repr(node.ds)[1:]:
for ds_line in repr(node.dataset)[1:]:
print(f"{fill}{ds_line}")

def merge(self, datatree: DataTree) -> DataTree:
Expand All @@ -1513,7 +1517,7 @@ def merge_child_nodes(self, *paths, new_path: T_Path) -> DataTree:
# TODO some kind of .collapse() or .flatten() method to merge a subtree

def to_dataarray(self) -> DataArray:
return self.ds.to_dataarray()
return self.dataset.to_dataarray()

@property
def groups(self):
Expand Down
11 changes: 6 additions & 5 deletions xarray/core/datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ def map_over_subtree(func: Callable) -> Callable:
Function will not be applied to any nodes without datasets.
*args : tuple, optional
Positional arguments passed on to `func`. If DataTrees any data-containing nodes will be converted to Datasets
via `.ds`.
via `.dataset`.
**kwargs : Any
Keyword arguments passed on to `func`. If DataTrees any data-containing nodes will be converted to Datasets
via `.ds`.
via `.dataset`.
Returns
-------
Expand Down Expand Up @@ -160,13 +160,14 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
strict=False,
):
node_args_as_datasetviews = [
a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args]
a.dataset if isinstance(a, DataTree) else a
for a in all_node_args[:n_args]
]
node_kwargs_as_datasetviews = dict(
zip(
[k for k in kwargs_as_tree_length_iterables.keys()],
[
v.ds if isinstance(v, DataTree) else v
v.dataset if isinstance(v, DataTree) else v
for v in all_node_args[n_args:]
],
strict=True,
Expand All @@ -183,7 +184,7 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
)
elif node_of_first_tree.has_attrs:
# propagate attrs
results = node_of_first_tree.ds
results = node_of_first_tree.dataset
else:
# nothing to propagate so use fastpath to create empty node in new tree
results = None
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,7 @@ def diff_nodewise_summary(a: DataTree, b: DataTree, compat):

summary = []
for node_a, node_b in zip(a.subtree, b.subtree, strict=True):
a_ds, b_ds = node_a.ds, node_b.ds
a_ds, b_ds = node_a.dataset, node_b.dataset

if not a_ds._all_compat(b_ds, compat):
dataset_diff = diff_dataset_repr(a_ds, b_ds, compat_str)
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_backends_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):

# add compression
comp = dict(zlib=True, complevel=9)
enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}}
enc = {"/set2": {var: comp for var in original_dt["/set2"].dataset.data_vars}}

original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)
roundtrip_dt = open_datatree(filepath, engine=self.engine)
Expand Down Expand Up @@ -246,7 +246,7 @@ def test_zarr_encoding(self, tmpdir, simple_datatree):
original_dt = simple_datatree

comp = {"compressor": zarr.Blosc(cname="zstd", clevel=3, shuffle=2)}
enc = {"/set2": {var: comp for var in original_dt["/set2"].ds.data_vars}}
enc = {"/set2": {var: comp for var in original_dt["/set2"].dataset.data_vars}}
original_dt.to_zarr(filepath, encoding=enc)
roundtrip_dt = open_datatree(filepath, engine="zarr")

Expand Down
Loading

0 comments on commit fac2c89

Please sign in to comment.