Skip to content

Commit

Permalink
Dask based zarr loader
Browse files Browse the repository at this point in the history
Implement a Dask driver backend for the Zarr driver,
based on the Dask version of xr_reproject.
  • Loading branch information
Kirill888 committed Jun 24, 2024
1 parent 7814ca3 commit 70c2c2c
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 39 deletions.
155 changes: 120 additions & 35 deletions odc/loader/_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

import json
from contextlib import contextmanager
from typing import Any, Iterator, Sequence
from typing import Any, Iterator, Mapping, Sequence

import fsspec
import numpy as np
import xarray as xr
from dask.delayed import Delayed, delayed
from odc.geo.geobox import GeoBox
from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_coords, xr_reproject

Expand All @@ -26,20 +27,23 @@
ReaderSubsetSelection,
)

SomeDoc = dict[str, Any]
ZarrSpec = dict[str, Any]
ZarrSpecFs = dict[str, Any]
# TODO: tighten specs for Zarr*
SomeDoc = Mapping[str, Any]
ZarrSpec = Mapping[str, Any]
ZarrSpecFs = Mapping[str, Any]
ZarrSpecFsDict = dict[str, Any]
# pylint: disable=too-few-public-methods


def extract_zarr_spec(src: SomeDoc) -> ZarrSpecFs | None:
def extract_zarr_spec(src: SomeDoc) -> ZarrSpecFsDict | None:
if ".zmetadata" in src:
return src
return dict(src)

if "zarr:metadata" in src:
# TODO: handle zarr:chunks for reference filesystem
zmd = {"zarr_consolidated_format": 1, "metadata": src["zarr:metadata"]}
elif "zarr_consolidated_format" in src:
zmd = src
zmd = dict(src)
else:
zmd = {"zarr_consolidated_format": 1, "metadata": src}

Expand Down Expand Up @@ -118,9 +122,7 @@ def extract(self, md: Any) -> RasterGroupMetadata:

return self._template

def driver_data(
self, md: Any, band_key: BandKey
) -> xr.DataArray | ZarrSpec | ZarrSpecFs | SomeDoc | None:
def driver_data(self, md: Any, band_key: BandKey) -> xr.DataArray | SomeDoc | None:
"""
Extract driver specific data for a given band.
"""
Expand Down Expand Up @@ -161,27 +163,24 @@ def with_env(self, env: dict[str, Any]) -> "Context":
return Context(self.geobox, self.chunks)


class XrMemReader:
class XrSource:
"""
Implements protocol for raster readers.
- Read from in-memory xarray.Dataset
- Read from zarr spec
RasterSource -> xr.DataArray|xr.Dataset
"""

# pylint: disable=too-few-public-methods

def __init__(self, src: RasterSource, ctx: Context) -> None:
def __init__(self, src: RasterSource, chunks: Any | None = None) -> None:
driver_data: xr.DataArray | xr.Dataset | SomeDoc = src.driver_data
self._spec: ZarrSpecFs | None = None
self._ds: xr.Dataset | None = None
self._xx: xr.DataArray | None = None
self._ctx = ctx
self._src = src
self._chunks = chunks

if isinstance(driver_data, xr.DataArray):
self._xx = driver_data
elif isinstance(driver_data, xr.Dataset):
subdataset = src.subdataset
self._ds = driver_data
assert subdataset in driver_data.data_vars
self._xx = driver_data.data_vars[subdataset]
elif isinstance(driver_data, dict):
Expand All @@ -191,18 +190,31 @@ def __init__(self, src: RasterSource, ctx: Context) -> None:

assert driver_data is None or (self._spec is not None or self._xx is not None)

def _resolve_data_array(self) -> xr.DataArray:
if self._xx is not None:
return self._xx
@property
def spec(self) -> ZarrSpecFs | None:
return self._spec

assert self._spec is not None
src_ds = _from_zarr_spec(
def base(self, regen_coords: bool = False) -> xr.Dataset | None:
if self._ds is not None:
return self._ds
if self._spec is None:
return None
self._ds = _from_zarr_spec(
self._spec,
regen_coords=True,
regen_coords=regen_coords,
target=self._src.uri,
chunks=self._ctx.chunks,
chunks=self._chunks,
)
return self._ds

def resolve(
self,
regen_coords: bool = False,
) -> xr.DataArray:
if self._xx is not None:
return self._xx

src_ds = self.base(regen_coords=regen_coords)
if src_ds is None:
raise ValueError("Failed to interpret driver data")

Expand All @@ -213,9 +225,35 @@ def _resolve_data_array(self) -> xr.DataArray:

if subdataset not in src_ds.data_vars:
raise ValueError(f"Band {subdataset!r} not found in dataset")

self._xx = src_ds.data_vars[subdataset]
return self._xx


def _subset_src(
    src: xr.DataArray, selection: ReaderSubsetSelection, cfg: RasterLoadParams
) -> xr.DataArray:
    """
    Apply an optional selection along the single extra dimension of ``src``.

    :param src: Array to subset.
    :param selection: ``None`` to keep ``src`` untouched, otherwise a slice,
       an int, or a one-element sequence that is forwarded to ``isel``.
    :param cfg: Load parameters; ``cfg.extra_dims`` must name exactly one
       dimension, which is the dimension ``selection`` is applied to.
    :returns: ``src`` itself when ``selection`` is ``None``, otherwise the
       result of ``src.isel`` on the extra dimension.
    """
    if selection is None:
        return src

    # only a single extra dimension is supported
    assert isinstance(selection, (slice, int)) or len(selection) == 1
    assert len(cfg.extra_dims) == 1
    (band_dim,) = cfg.extra_dims
    return src.isel({band_dim: selection})


class XrMemReader:
"""
Implements protocol for raster readers.
- Read from in-memory xarray.Dataset
- Read from zarr spec
"""

def __init__(self, src: RasterSource, ctx: Context) -> None:
self._src = XrSource(src, chunks=None)
self._ctx = ctx

def read(
self,
cfg: RasterLoadParams,
Expand All @@ -224,14 +262,8 @@ def read(
dst: np.ndarray | None = None,
selection: ReaderSubsetSelection | None = None,
) -> tuple[tuple[slice, slice], np.ndarray]:
src = self._resolve_data_array()

if selection is not None:
# only support single extra dimension
assert isinstance(selection, (slice, int)) or len(selection) == 1
assert len(cfg.extra_dims) == 1
(band_dim,) = cfg.extra_dims
src = src.isel({band_dim: selection})
src = self._src.resolve(regen_coords=True)
src = _subset_src(src, selection, cfg)

warped = xr_reproject(src, dst_geobox, resampling=cfg.resampling)
assert isinstance(warped.data, np.ndarray)
Expand All @@ -245,6 +277,59 @@ def read(
return yx_roi, dst


def _with_roi(xx: np.ndarray) -> tuple[tuple[slice, slice], np.ndarray]:
    """
    Pair ``xx`` with a ROI made of two unbounded slices.

    Returns ``(roi, xx)`` where ``roi`` is ``(slice(None), slice(None))``,
    i.e. the same ``(yx_roi, pixels)`` tuple layout that ``XrMemReader.read``
    returns. ``xx`` is passed through as is.
    """
    return (slice(None), slice(None)), xx


class XrMemReaderDask:
    """
    Dask version of the reader.

    An instance created without ``src`` is not bound to any source and can
    only be used as a factory: call :py:meth:`open` to obtain a reader bound
    to a concrete source. :py:meth:`read` requires a bound source.
    """

    def __init__(
        self,
        src: RasterSource | None = None,
        ctx: Context | None = None,
        layer_name: str = "",
        idx: int = -1,
    ) -> None:
        """
        :param src: Source to read from, ``None`` for an unbound instance.
        :param ctx: Loader context, stored as is.
        :param layer_name: First element of the dask key used by ``read``.
        :param idx: Index of this source, stored as is.
        """
        # the source is wrapped with chunks="auto" so XrSource opens it chunked
        self._src = XrSource(src, chunks="auto") if src is not None else None
        self._ctx = ctx
        self._layer_name = layer_name
        self._idx = idx

    def read(
        self,
        cfg: RasterLoadParams,
        dst_geobox: GeoBox,
        *,
        selection: ReaderSubsetSelection | None = None,
        idx: tuple[int, ...] = (),
    ) -> Delayed:
        """
        Build a delayed read of the source reprojected into ``dst_geobox``.

        :param cfg: Load parameters, supplies ``resampling`` and ``extra_dims``.
        :param dst_geobox: Destination geobox to reproject into.
        :param selection: Optional selection along the extra dimension.
        :param idx: Index tuple appended to the layer name to form the dask key.
        :returns: ``Delayed`` wrapping ``_with_roi`` applied to the reprojected
           data, named ``(layer_name, *idx)``.
        """
        assert self._src is not None
        assert isinstance(idx, tuple)

        xx = self._src.resolve(regen_coords=True)
        xx = _subset_src(xx, selection, cfg)
        yy = xr_reproject(
            xx,
            dst_geobox,
            resampling=cfg.resampling,
            # chunk size is the full destination shape
            chunks=dst_geobox.shape.yx,
        )
        return delayed(_with_roi)(yy.data, dask_key_name=(self._layer_name, *idx))

    def open(
        self,
        src: RasterSource,
        ctx: Any,
        *,
        layer_name: str,
        idx: int,
    ) -> DaskRasterReader:
        """
        Create a new reader bound to ``src``.

        State of ``self`` is not used; all arguments are forwarded to the
        constructor of a fresh ``XrMemReaderDask``.
        """
        return XrMemReaderDask(src, ctx, layer_name=layer_name, idx=idx)


class XrMemReaderDriver:
"""
Read from in memory xarray.Dataset or zarr spec document.
Expand Down Expand Up @@ -293,7 +378,7 @@ def md_parser(self) -> MDParser:

@property
def dask_reader(self) -> DaskRasterReader | None:
return None
return XrMemReaderDask()


def band_info(xx: xr.DataArray) -> RasterBandMetadata:
Expand Down
35 changes: 31 additions & 4 deletions odc/loader/test_memreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,28 @@
import numpy as np
import pytest
import xarray as xr
from dask import is_dask_collection
from odc.geo.data import country_geom
from odc.geo.gcp import GCPGeoBox
from odc.geo.geobox import GeoBox
from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, rasterize

from ._zarr import Context, XrMemReader, XrMemReaderDriver, raster_group_md
from .types import FixedCoord, RasterGroupMetadata, RasterLoadParams, RasterSource
from odc.loader._zarr import (
Context,
XrMemReader,
XrMemReaderDask,
XrMemReaderDriver,
raster_group_md,
)
from odc.loader.types import (
FixedCoord,
RasterGroupMetadata,
RasterLoadParams,
RasterSource,
)

# pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison
# pylint: disable=too-many-locals,too-many-statements,redefined-outer-name
# pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison,protected-access
# pylint: disable=too-many-locals,too-many-statements,redefined-outer-name,import-outside-toplevel


@pytest.fixture
Expand Down Expand Up @@ -70,6 +82,7 @@ def test_mem_reader(sample_ds: xr.Dataset) -> None:

driver = XrMemReaderDriver(ds)
assert driver.md_parser is not None
assert driver.dask_reader is not None
md = driver.md_parser.extract(fake_item)

assert isinstance(md, RasterGroupMetadata)
Expand Down Expand Up @@ -201,3 +214,17 @@ def test_memreader_zarr(sample_ds: xr.Dataset):
assert isinstance(xx, np.ndarray)
assert xx.shape == gbox[roi].shape.yx
assert gbox == gbox[roi]

rdr = XrMemReaderDask().open(src, ctx, layer_name="xx", idx=0)
assert isinstance(rdr, XrMemReaderDask)
assert rdr._src is not None

assert rdr._src._chunks == "auto"

fut = rdr.read(cfg, gbox)
assert is_dask_collection(fut)

roi, xx = fut.compute(scheduler="synchronous")
assert isinstance(xx, np.ndarray)
assert roi == (slice(None), slice(None))
assert xx.shape == gbox.shape.yx

0 comments on commit 70c2c2c

Please sign in to comment.