Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

open_groups for zarr backends #9469

Merged
merged 8 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 64 additions & 43 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,64 +1218,85 @@ def open_datatree(
zarr_version=None,
**kwargs,
) -> DataTree:
from xarray.backends.api import open_dataset
from xarray.core.datatree import DataTree

filename_or_obj = _normalize_path(filename_or_obj)
groups_dict = self.open_groups_as_dict(filename_or_obj, **kwargs)

return DataTree.from_dict(groups_dict)

def open_groups_as_dict(
eni-awowale marked this conversation as resolved.
Show resolved Hide resolved
self,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
mask_and_scale=True,
decode_times=True,
concat_characters=True,
decode_coords=True,
drop_variables: str | Iterable[str] | None = None,
use_cftime=None,
decode_timedelta=None,
group: str | Iterable[str] | Callable | None = None,
mode="r",
synchronizer=None,
consolidated=None,
chunk_store=None,
storage_options=None,
stacklevel=3,
zarr_version=None,
**kwargs,
) -> dict[str, Dataset]:

from xarray.core.treenode import NodePath

filename_or_obj = _normalize_path(filename_or_obj)

# Check for a group and make it a parent if it exists
if group:
parent = NodePath("/") / NodePath(group)
stores = ZarrStore.open_store(
filename_or_obj,
group=parent,
mode=mode,
synchronizer=synchronizer,
consolidated=consolidated,
consolidate_on_close=False,
chunk_store=chunk_store,
storage_options=storage_options,
stacklevel=stacklevel + 1,
zarr_version=zarr_version,
)
if not stores:
ds = open_dataset(
filename_or_obj, group=parent, engine="zarr", **kwargs
)
return DataTree.from_dict({str(parent): ds})
else:
parent = NodePath("/")
stores = ZarrStore.open_store(
filename_or_obj,
group=parent,
mode=mode,
synchronizer=synchronizer,
consolidated=consolidated,
consolidate_on_close=False,
chunk_store=chunk_store,
storage_options=storage_options,
stacklevel=stacklevel + 1,
zarr_version=zarr_version,
)
ds = open_dataset(filename_or_obj, group=parent, engine="zarr", **kwargs)
tree_root = DataTree.from_dict({str(parent): ds})

stores = ZarrStore.open_store(
filename_or_obj,
group=parent,
TomNicholas marked this conversation as resolved.
Show resolved Hide resolved
mode=mode,
synchronizer=synchronizer,
consolidated=consolidated,
consolidate_on_close=False,
chunk_store=chunk_store,
storage_options=storage_options,
stacklevel=stacklevel + 1,
zarr_version=zarr_version,
)

groups_dict = {}

for path_group, store in stores.items():
ds = open_dataset(
filename_or_obj, store=store, group=path_group, engine="zarr", **kwargs
)
new_node: DataTree = DataTree(name=NodePath(path_group).name, data=ds)
tree_root._set_item(
path_group,
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
)
return tree_root
store_entrypoint = StoreBackendEntrypoint()

with close_on_error(store):
group_ds = store_entrypoint.open_dataset(
store,
mask_and_scale=mask_and_scale,
decode_times=decode_times,
concat_characters=concat_characters,
decode_coords=decode_coords,
drop_variables=drop_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
group_name = str(NodePath(path_group))
groups_dict[group_name] = group_ds

return groups_dict


def _iter_zarr_groups(root, parent="/"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try to add type hints (though I don't know what the type of root is)

Suggested change
def _iter_zarr_groups(root, parent="/"):
def _iter_zarr_groups(root: ZarrGroup?, parent: str = "/") -> Iterable[str]:

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what the appropriate type hint would be here that wouldn't break mypy.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a play with this. I traced root back and found it is indeed a zarr.Group, so I think the type hint above works (along with adding from zarr import Group as ZarrGroup in the if TYPE_CHECKING block at the top of the file).

The bit that was causing an issue was the assignment of a NodePath object to parent, which is a str according to the function signature. (Apologies if I'm just playing catch up here from what you guys have been saying all along)

I could make mypy happy if I did the following in the function:

def _iter_zarr_groups(root: ZarrGroup, parent: str = "/") -> Iterable[str]:
    """Yield the path of ``parent`` followed by the paths of all groups below ``root``.

    Parameters
    ----------
    root : ZarrGroup
        Group whose ``groups()`` members are walked recursively.
    parent : str, default: "/"
        Path under which ``root`` is reported; child paths are built by joining
        each child name onto it.

    Yields
    ------
    str
        ``parent`` first, then one path per nested group, each exactly once.
    """
    from xarray.core.treenode import NodePath

    parent_nodepath = NodePath(parent)
    yield str(parent_nodepath)
    for path, group in root.groups():
        # The recursive call yields the child's own path as its first item, so
        # it is not yielded separately here (doing both would emit it twice).
        # The path is converted back to `str` to match the annotation of
        # `parent`, which keeps the signature consistent for type checkers.
        yield from _iter_zarr_groups(group, parent=str(parent_nodepath / path))

Maybe that's an option? (Unless I'm missing a trick with this breaking the recursion?)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

(No need for from xarray.core.treenode import NodePath to live inside the function though)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Owen! That did the trick!

from xarray.core.treenode import NodePath

parent = NodePath(parent)
yield str(parent)
eni-awowale marked this conversation as resolved.
Show resolved Hide resolved
for path, group in root.groups():
gpath = parent / path
yield str(gpath)
Expand Down
225 changes: 138 additions & 87 deletions xarray/tests/test_backends_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,88 @@
pass


@pytest.fixture(scope="module")
eni-awowale marked this conversation as resolved.
Show resolved Hide resolved
def unaligned_datatree_nc(tmp_path_factory):
    """Write a netCDF4 file with an unaligned group hierarchy and yield its path.

    The file is created in a temporary directory obtained from
    ``tmp_path_factory`` and has the following structure:

    Group: /
    │   Dimensions:        (lat: 1, lon: 2)
    │   Dimensions without coordinates: lat, lon
    │   Data variables:
    │       root_variable  (lat, lon) float64 16B ...
    └── Group: /Group1
        │   Dimensions:      (lat: 1, lon: 2)
        │   Dimensions without coordinates: lat, lon
        │   Data variables:
        │       group_1_var  (lat, lon) float64 16B ...
        └── Group: /Group1/subgroup1
                Dimensions:        (lat: 2, lon: 2)
                Dimensions without coordinates: lat, lon
                Data variables:
                    subgroup1_var  (lat, lon) float64 32B ...
    """
    nc_path = tmp_path_factory.mktemp("data") / "unaligned_subgroups.nc"
    dims = ("lat", "lon")

    with nc4.Dataset(nc_path, "w", format="NETCDF4") as root:
        parent_group = root.createGroup("/Group1")
        child_group = parent_group.createGroup("/subgroup1")

        root.createDimension("lat", 1)
        root.createDimension("lon", 2)
        root.createVariable("root_variable", np.float64, dims)

        # Group1 declares no dimensions of its own; its variable is created
        # against the dimension names declared on the root group.
        parent_var = parent_group.createVariable("group_1_var", np.float64, dims)
        parent_var[:] = np.array([[0.1, 0.2]])
        parent_var.units = "K"
        parent_var.long_name = "air_temperature"

        # subgroup1 declares its own "lat" of size 2, unlike the size-1 "lat"
        # used above it (see the structure in the docstring).
        child_group.createDimension("lat", 2)

        child_var = child_group.createVariable("subgroup1_var", np.float64, dims)
        child_var[:] = np.array([[0.1, 0.2]])

    yield nc_path


@pytest.fixture(scope="module")
def unaligned_datatree_zarr(tmp_path_factory):
    """Write a zarr store with an unaligned group hierarchy and yield its path.

    The store is created in a temporary directory obtained from
    ``tmp_path_factory`` and has the following structure:

    Group: /
    │   Dimensions:  (y: 3, x: 2)
    │   Dimensions without coordinates: y, x
    │   Data variables:
    │       a        (y) int64 24B ...
    │       set0     (x) int64 16B ...
    ├── Group: /Group1
    │   │   Dimensions:  ()
    │   │   Data variables:
    │   │       a        int64 8B ...
    │   │       b        int64 8B ...
    │   └── Group: /Group1/subgroup1
    │           Dimensions:  ()
    │           Data variables:
    │               a        int64 8B ...
    │               b        int64 8B ...
    └── Group: /Group2
            Dimensions:  (y: 2, x: 2)
            Dimensions without coordinates: y, x
            Data variables:
                a        (y) int64 16B ...
                b        (x) float64 16B ...
    """
    store_path = tmp_path_factory.mktemp("data") / "unaligned_simple_datatree.zarr"

    root_ds = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
    scalar_ds = xr.Dataset({"a": 0, "b": 1})
    group2_ds = xr.Dataset({"a": ("y", [2, 3]), "b": ("x", [0.1, 0.2])})

    # The root is written first; every group is then appended to that store.
    root_ds.to_zarr(store_path)
    appended_groups = (
        ("/Group1", scalar_ds),
        ("/Group2", group2_ds),
        ("/Group1/subgroup1", scalar_ds),
    )
    for group_path, group_ds in appended_groups:
        group_ds.to_zarr(store_path, group=group_path, mode="a")

    yield store_path


class DatatreeIOBase:
engine: T_DataTreeNetcdfEngine | None = None

Expand Down Expand Up @@ -73,111 +155,35 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):
class TestNetCDF4DatatreeIO(DatatreeIOBase):
engine: T_DataTreeNetcdfEngine | None = "netcdf4"

def test_open_datatree(self, tmpdir) -> None:
"""Create a test netCDF4 file with this unaligned structure:
Group: /
│ Dimensions: (lat: 1, lon: 2)
│ Dimensions without coordinates: lat, lon
│ Data variables:
│ root_variable (lat, lon) float64 16B ...
└── Group: /Group1
│ Dimensions: (lat: 1, lon: 2)
│ Dimensions without coordinates: lat, lon
│ Data variables:
│ group_1_var (lat, lon) float64 16B ...
└── Group: /Group1/subgroup1
Dimensions: (lat: 2, lon: 2)
Dimensions without coordinates: lat, lon
Data variables:
subgroup1_var (lat, lon) float64 32B ...
"""
filepath = tmpdir + "/unaligned_subgroups.nc"
with nc4.Dataset(filepath, "w", format="NETCDF4") as root_group:
group_1 = root_group.createGroup("/Group1")
subgroup_1 = group_1.createGroup("/subgroup1")

root_group.createDimension("lat", 1)
root_group.createDimension("lon", 2)
root_group.createVariable("root_variable", np.float64, ("lat", "lon"))

group_1_var = group_1.createVariable(
"group_1_var", np.float64, ("lat", "lon")
)
group_1_var[:] = np.array([[0.1, 0.2]])
group_1_var.units = "K"
group_1_var.long_name = "air_temperature"

subgroup_1.createDimension("lat", 2)

subgroup1_var = subgroup_1.createVariable(
"subgroup1_var", np.float64, ("lat", "lon")
)
subgroup1_var[:] = np.array([[0.1, 0.2]])
def test_open_datatree(self, unaligned_datatree_nc) -> None:
"""Test if `open_datatree` fails to open a netCDF4 with an unaligned group hierarchy."""
with pytest.raises(ValueError):
open_datatree(filepath)
open_datatree(unaligned_datatree_nc)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should definitely use

pytest.raises(ValueError, match=...):

to match a much more specific error. We are expecting this test to fail because of an alignment error, not any other type of ValueError.

(Maybe we should create a new exception AlignmentError(ValueError), given that we do already have MergeError(ValueError)? @shoyer)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about something like this but in the treenode.py module?

class AlignmentError(ValueError):
    """Raised when a child node has coordinates that are not aligned with its parents."""

If we do this I think we might have to change datatree.py to raise AlignmentError in
`_check_alignment`.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alignment is a core xarray concept separate from datatree. My suggestion was to change the ValueErrors raised in alignment.py to become AlignmentErrors, then datatree.py would automatically throw AlignmentError when _check_alignment is called (because it calls xr.align internally).

This suggestion is outside the scope of this PR though - you should go ahead with just using match to match a specific ValueError and then we can maybe do AlignmentError in a separate PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in the last commit 🚀

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure the order of merging to expect, but if @flamingbear merges #9033 first, then we can clean up the imports and use xr.open_datatree. If it's the other way around, then it'll probably be best for one of us to take a quick skim over the tests and tidy them up in a separate PR.

My guess is that this PR will merge before #9033, so we'll likely end up with a separate tiny PR for this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tidying up these imports isn't a big deal, and can be left for another PR


def test_open_groups(self, tmpdir) -> None:
"""Test `open_groups` with netCDF4 file with the same unaligned structure:
Group: /
│ Dimensions: (lat: 1, lon: 2)
│ Dimensions without coordinates: lat, lon
│ Data variables:
│ root_variable (lat, lon) float64 16B ...
└── Group: /Group1
│ Dimensions: (lat: 1, lon: 2)
│ Dimensions without coordinates: lat, lon
│ Data variables:
│ group_1_var (lat, lon) float64 16B ...
└── Group: /Group1/subgroup1
Dimensions: (lat: 2, lon: 2)
Dimensions without coordinates: lat, lon
Data variables:
subgroup1_var (lat, lon) float64 32B ...
"""
filepath = tmpdir + "/unaligned_subgroups.nc"
with nc4.Dataset(filepath, "w", format="NETCDF4") as root_group:
group_1 = root_group.createGroup("/Group1")
subgroup_1 = group_1.createGroup("/subgroup1")

root_group.createDimension("lat", 1)
root_group.createDimension("lon", 2)
root_group.createVariable("root_variable", np.float64, ("lat", "lon"))

group_1_var = group_1.createVariable(
"group_1_var", np.float64, ("lat", "lon")
)
group_1_var[:] = np.array([[0.1, 0.2]])
group_1_var.units = "K"
group_1_var.long_name = "air_temperature"

subgroup_1.createDimension("lat", 2)

subgroup1_var = subgroup_1.createVariable(
"subgroup1_var", np.float64, ("lat", "lon")
)
subgroup1_var[:] = np.array([[0.1, 0.2]])

unaligned_dict_of_datasets = open_groups(filepath)
def test_open_groups(self, unaligned_datatree_nc) -> None:
"""Test `open_groups` with a netCDF4 file with an unaligned group hierarchy."""
unaligned_dict_of_datasets = open_groups(unaligned_datatree_nc)

# Check that group names are keys in the dictionary of `xr.Datasets`
assert "/" in unaligned_dict_of_datasets.keys()
assert "/Group1" in unaligned_dict_of_datasets.keys()
assert "/Group1/subgroup1" in unaligned_dict_of_datasets.keys()
# Check that group name returns the correct datasets
assert_identical(
unaligned_dict_of_datasets["/"], xr.open_dataset(filepath, group="/")
unaligned_dict_of_datasets["/"],
xr.open_dataset(unaligned_datatree_nc, group="/"),
)
assert_identical(
unaligned_dict_of_datasets["/Group1"],
xr.open_dataset(filepath, group="Group1"),
xr.open_dataset(unaligned_datatree_nc, group="Group1"),
)
assert_identical(
unaligned_dict_of_datasets["/Group1/subgroup1"],
xr.open_dataset(filepath, group="/Group1/subgroup1"),
xr.open_dataset(unaligned_datatree_nc, group="/Group1/subgroup1"),
)

def test_open_groups_to_dict(self, tmpdir) -> None:
"""Create a an aligned netCDF4 with the following structure to test `open_groups`
"""Create an aligned netCDF4 with the following structure to test `open_groups`
and `DataTree.from_dict`.
Group: /
│ Dimensions: (lat: 1, lon: 2)
Expand Down Expand Up @@ -305,3 +311,48 @@ def test_to_zarr_inherited_coords(self, tmpdir):
assert_equal(original_dt, roundtrip_dt)
subtree = cast(DataTree, roundtrip_dt["/sub"])
assert "x" not in subtree.to_dataset(inherited=False).coords

def test_open_groups_round_trip(self, tmpdir, simple_datatree) -> None:
    """Check that `open_groups` reads back a zarr store written from `simple_datatree`."""
    store_path = tmpdir / "test.zarr"
    simple_datatree.to_zarr(store_path)

    # Rebuild a tree from the dictionary of per-group datasets ...
    groups = open_groups(store_path, engine="zarr")
    rebuilt_tree = DataTree.from_dict(groups)

    # ... and compare it with the tree opened directly from the same store.
    expected_tree = open_datatree(store_path, engine="zarr")
    assert expected_tree.identical(rebuilt_tree)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.identical is going to be different to check if the coordinates are defined on disk.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifically see #9473.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay so from looking at this and testing this locally I don't think this will affect these tests.


def test_open_datatree(self, unaligned_datatree_zarr) -> None:
    """Test that `open_datatree` refuses a zarr store with an unaligned group hierarchy."""
    # Match on the message so an unrelated ValueError cannot make this test pass.
    # NOTE(review): the pattern assumes the wording of DataTree's alignment
    # error ("group ... is not aligned with its parents") -- confirm against
    # the installed xarray version.
    with pytest.raises(ValueError, match="is not aligned with its parents"):
        open_datatree(unaligned_datatree_zarr, engine="zarr")

def test_open_groups(self, unaligned_datatree_zarr) -> None:
    """Check that `open_groups` returns one dataset per group of an unaligned zarr store."""
    groups = open_groups(unaligned_datatree_zarr, engine="zarr")

    # Expected dictionary key -> `group` argument handed to `xr.open_dataset`
    # when opening the same group directly for comparison.
    group_args = {
        "/": "/",
        "/Group1": "Group1",
        "/Group1/subgroup1": "/Group1/subgroup1",
        "/Group2": "/Group2",
    }

    # Every group path must be present as a key ...
    for key in group_args:
        assert key in groups.keys()

    # ... and map to the dataset stored in that group.
    for key, group_arg in group_args.items():
        assert_identical(
            groups[key],
            xr.open_dataset(unaligned_datatree_zarr, group=group_arg, engine="zarr"),
        )
Loading