Skip to content

Commit

Permalink
amend me: dask based zarr loader
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Jun 24, 2024
1 parent 7814ca3 commit cb4c21d
Showing 1 changed file with 103 additions and 28 deletions.
131 changes: 103 additions & 28 deletions odc/loader/_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import json
from contextlib import contextmanager
from typing import Any, Iterator, Sequence
from typing import Any, Iterator, Mapping, Sequence

import fsspec
import numpy as np
Expand All @@ -26,20 +26,23 @@
ReaderSubsetSelection,
)

SomeDoc = dict[str, Any]
ZarrSpec = dict[str, Any]
ZarrSpecFs = dict[str, Any]
# TODO: tighten specs for Zarr*
SomeDoc = Mapping[str, Any]
ZarrSpec = Mapping[str, Any]
ZarrSpecFs = Mapping[str, Any]
ZarrSpecFsDict = dict[str, Any]
# pylint: disable=too-few-public-methods


def extract_zarr_spec(src: SomeDoc) -> ZarrSpecFs | None:
def extract_zarr_spec(src: SomeDoc) -> ZarrSpecFsDict | None:
if ".zmetadata" in src:
return src
return dict(src)

if "zarr:metadata" in src:
# TODO: handle zarr:chunks for reference filesystem
zmd = {"zarr_consolidated_format": 1, "metadata": src["zarr:metadata"]}
elif "zarr_consolidated_format" in src:
zmd = src
zmd = dict(src)
else:
zmd = {"zarr_consolidated_format": 1, "metadata": src}

Expand Down Expand Up @@ -118,9 +121,7 @@ def extract(self, md: Any) -> RasterGroupMetadata:

return self._template

def driver_data(
self, md: Any, band_key: BandKey
) -> xr.DataArray | ZarrSpec | ZarrSpecFs | SomeDoc | None:
def driver_data(self, md: Any, band_key: BandKey) -> xr.DataArray | SomeDoc | None:
"""
Extract driver specific data for a given band.
"""
Expand Down Expand Up @@ -161,27 +162,24 @@ def with_env(self, env: dict[str, Any]) -> "Context":
return Context(self.geobox, self.chunks)


class XrMemReader:
class XrSource:
"""
Implements protocol for raster readers.
- Read from in-memory xarray.Dataset
- Read from zarr spec
RasterSource -> xr.DataArray|xr.Dataset
"""

# pylint: disable=too-few-public-methods

def __init__(self, src: RasterSource, ctx: Context) -> None:
def __init__(self, src: RasterSource, chunks: Any | None = None) -> None:
driver_data: xr.DataArray | xr.Dataset | SomeDoc = src.driver_data
self._spec: ZarrSpecFs | None = None
self._ds: xr.Dataset | None = None
self._xx: xr.DataArray | None = None
self._ctx = ctx
self._src = src
self._chunks = chunks

if isinstance(driver_data, xr.DataArray):
self._xx = driver_data
elif isinstance(driver_data, xr.Dataset):
subdataset = src.subdataset
self._ds = driver_data
assert subdataset in driver_data.data_vars
self._xx = driver_data.data_vars[subdataset]
elif isinstance(driver_data, dict):
Expand All @@ -191,18 +189,31 @@ def __init__(self, src: RasterSource, ctx: Context) -> None:

assert driver_data is None or (self._spec is not None or self._xx is not None)

def _resolve_data_array(self) -> xr.DataArray:
if self._xx is not None:
return self._xx
@property
def spec(self) -> ZarrSpecFs | None:
return self._spec

assert self._spec is not None
src_ds = _from_zarr_spec(
def base(self, regen_coords: bool = False) -> xr.Dataset | None:
if self._ds is not None:
return self._ds
if self._spec is None:
return None
self._ds = _from_zarr_spec(
self._spec,
regen_coords=True,
regen_coords=regen_coords,
target=self._src.uri,
chunks=self._ctx.chunks,
chunks=self._chunks,
)
return self._ds

def resolve(
self,
regen_coords: bool = False,
) -> xr.DataArray:
if self._xx is not None:
return self._xx

src_ds = self.base(regen_coords=regen_coords)
if src_ds is None:
raise ValueError("Failed to interpret driver data")

Expand All @@ -213,9 +224,23 @@ def _resolve_data_array(self) -> xr.DataArray:

if subdataset not in src_ds.data_vars:
raise ValueError(f"Band {subdataset!r} not found in dataset")

self._xx = src_ds.data_vars[subdataset]
return self._xx


class XrMemReader:
"""
Implements protocol for raster readers.
- Read from in-memory xarray.Dataset
- Read from zarr spec
"""

def __init__(self, src: RasterSource, ctx: Context) -> None:
self._src = XrSource(src, chunks=None)
self._ctx = ctx

def read(
self,
cfg: RasterLoadParams,
Expand All @@ -224,7 +249,7 @@ def read(
dst: np.ndarray | None = None,
selection: ReaderSubsetSelection | None = None,
) -> tuple[tuple[slice, slice], np.ndarray]:
src = self._resolve_data_array()
src = self._src.resolve(regen_coords=True)

if selection is not None:
# only support single extra dimension
Expand All @@ -245,6 +270,56 @@ def read(
return yx_roi, dst


class XrMemReaderDask:
"""
Dask version of the reader.
"""

def __init__(
self,
src: RasterSource | None = None,
ctx: Context | None = None,
layer_name: str = "",
idx: int = -1,
) -> None:
self._src = XrSource(src) if src is not None else None
self._ctx = ctx
self._layer_name = layer_name
self._idx = idx

def read(
self,
cfg: RasterLoadParams,
dst_geobox: GeoBox,
*,
selection: ReaderSubsetSelection | None = None,
idx: tuple[int, ...],
) -> Any:
assert self._src is not None
assert len(idx)

xx = self._src.resolve(regen_coords=True)

assert xx is not None
assert isinstance(xx.odc, ODCExtensionDa)

# TODO: selection
assert selection is None
yy = xx.odc.reproject(dst_geobox, resampling=cfg.resampling, chunks={})

return yy.data

def open(
self,
src: RasterSource,
ctx: Any,
*,
layer_name: str,
idx: int,
) -> DaskRasterReader:
return XrMemReaderDask(src, ctx, layer_name=layer_name, idx=idx)


class XrMemReaderDriver:
"""
Read from in memory xarray.Dataset or zarr spec document.
Expand Down Expand Up @@ -293,7 +368,7 @@ def md_parser(self) -> MDParser:

@property
def dask_reader(self) -> DaskRasterReader | None:
return None
return XrMemReaderDask()


def band_info(xx: xr.DataArray) -> RasterBandMetadata:
Expand Down

0 comments on commit cb4c21d

Please sign in to comment.