Skip to content

Commit

Permalink
open_groups for zarr backends (pydata#9469)
Browse files Browse the repository at this point in the history
* open groups zarr initial commit

* added tests

* Added requested changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* TypeHint for zarr groups

* update for parent

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
eni-awowale and pre-commit-ci[bot] authored Sep 12, 2024
1 parent 637f820 commit aeaa082
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 136 deletions.
121 changes: 72 additions & 49 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core import indexing
from xarray.core.treenode import NodePath
from xarray.core.types import ZarrWriteModes
from xarray.core.utils import (
FrozenDict,
Expand All @@ -33,6 +34,8 @@
if TYPE_CHECKING:
from io import BufferedIOBase

from zarr import Group as ZarrGroup

from xarray.backends.common import AbstractDataStore
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
Expand Down Expand Up @@ -1218,66 +1221,86 @@ def open_datatree(
zarr_version=None,
**kwargs,
) -> DataTree:
from xarray.backends.api import open_dataset
from xarray.core.datatree import DataTree

filename_or_obj = _normalize_path(filename_or_obj)
groups_dict = self.open_groups_as_dict(filename_or_obj, **kwargs)

return DataTree.from_dict(groups_dict)

def open_groups_as_dict(
self,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
mask_and_scale=True,
decode_times=True,
concat_characters=True,
decode_coords=True,
drop_variables: str | Iterable[str] | None = None,
use_cftime=None,
decode_timedelta=None,
group: str | Iterable[str] | Callable | None = None,
mode="r",
synchronizer=None,
consolidated=None,
chunk_store=None,
storage_options=None,
stacklevel=3,
zarr_version=None,
**kwargs,
) -> dict[str, Dataset]:

from xarray.core.treenode import NodePath

filename_or_obj = _normalize_path(filename_or_obj)

# Check for a group and make it a parent if it exists
if group:
parent = NodePath("/") / NodePath(group)
stores = ZarrStore.open_store(
filename_or_obj,
group=parent,
mode=mode,
synchronizer=synchronizer,
consolidated=consolidated,
consolidate_on_close=False,
chunk_store=chunk_store,
storage_options=storage_options,
stacklevel=stacklevel + 1,
zarr_version=zarr_version,
)
if not stores:
ds = open_dataset(
filename_or_obj, group=parent, engine="zarr", **kwargs
)
return DataTree.from_dict({str(parent): ds})
parent = str(NodePath("/") / NodePath(group))
else:
parent = NodePath("/")
stores = ZarrStore.open_store(
filename_or_obj,
group=parent,
mode=mode,
synchronizer=synchronizer,
consolidated=consolidated,
consolidate_on_close=False,
chunk_store=chunk_store,
storage_options=storage_options,
stacklevel=stacklevel + 1,
zarr_version=zarr_version,
)
ds = open_dataset(filename_or_obj, group=parent, engine="zarr", **kwargs)
tree_root = DataTree.from_dict({str(parent): ds})
parent = str(NodePath("/"))

stores = ZarrStore.open_store(
filename_or_obj,
group=parent,
mode=mode,
synchronizer=synchronizer,
consolidated=consolidated,
consolidate_on_close=False,
chunk_store=chunk_store,
storage_options=storage_options,
stacklevel=stacklevel + 1,
zarr_version=zarr_version,
)

groups_dict = {}

for path_group, store in stores.items():
ds = open_dataset(
filename_or_obj, store=store, group=path_group, engine="zarr", **kwargs
)
new_node = DataTree(name=NodePath(path_group).name, dataset=ds)
tree_root._set_item(
path_group,
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
)
return tree_root
store_entrypoint = StoreBackendEntrypoint()

with close_on_error(store):
group_ds = store_entrypoint.open_dataset(
store,
mask_and_scale=mask_and_scale,
decode_times=decode_times,
concat_characters=concat_characters,
decode_coords=decode_coords,
drop_variables=drop_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)
group_name = str(NodePath(path_group))
groups_dict[group_name] = group_ds

return groups_dict


def _iter_zarr_groups(root, parent="/"):
from xarray.core.treenode import NodePath
def _iter_zarr_groups(root: ZarrGroup, parent: str = "/") -> Iterable[str]:

parent = NodePath(parent)
parent_nodepath = NodePath(parent)
yield str(parent_nodepath)
for path, group in root.groups():
gpath = parent / path
gpath = parent_nodepath / path
yield str(gpath)
yield from _iter_zarr_groups(group, parent=gpath)

Expand Down
Loading

0 comments on commit aeaa082

Please sign in to comment.