From 70c2c2c38e09c17d942a1f92e28455662744e6e4 Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Mon, 24 Jun 2024 16:09:25 +1000 Subject: [PATCH] Dask based zarr loader Implement Dask driver backend for Zarr driver based dask version of xr_reproject. --- odc/loader/_zarr.py | 155 +++++++++++++++++++++++++++-------- odc/loader/test_memreader.py | 35 +++++++- 2 files changed, 151 insertions(+), 39 deletions(-) diff --git a/odc/loader/_zarr.py b/odc/loader/_zarr.py index 53639a6..102b053 100644 --- a/odc/loader/_zarr.py +++ b/odc/loader/_zarr.py @@ -6,11 +6,12 @@ import json from contextlib import contextmanager -from typing import Any, Iterator, Sequence +from typing import Any, Iterator, Mapping, Sequence import fsspec import numpy as np import xarray as xr +from dask.delayed import Delayed, delayed from odc.geo.geobox import GeoBox from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_coords, xr_reproject @@ -26,20 +27,23 @@ ReaderSubsetSelection, ) -SomeDoc = dict[str, Any] -ZarrSpec = dict[str, Any] -ZarrSpecFs = dict[str, Any] +# TODO: tighten specs for Zarr* +SomeDoc = Mapping[str, Any] +ZarrSpec = Mapping[str, Any] +ZarrSpecFs = Mapping[str, Any] +ZarrSpecFsDict = dict[str, Any] +# pylint: disable=too-few-public-methods -def extract_zarr_spec(src: SomeDoc) -> ZarrSpecFs | None: +def extract_zarr_spec(src: SomeDoc) -> ZarrSpecFsDict | None: if ".zmetadata" in src: - return src + return dict(src) if "zarr:metadata" in src: # TODO: handle zarr:chunks for reference filesystem zmd = {"zarr_consolidated_format": 1, "metadata": src["zarr:metadata"]} elif "zarr_consolidated_format" in src: - zmd = src + zmd = dict(src) else: zmd = {"zarr_consolidated_format": 1, "metadata": src} @@ -118,9 +122,7 @@ def extract(self, md: Any) -> RasterGroupMetadata: return self._template - def driver_data( - self, md: Any, band_key: BandKey - ) -> xr.DataArray | ZarrSpec | ZarrSpecFs | SomeDoc | None: + def driver_data(self, md: Any, band_key: BandKey) -> xr.DataArray | SomeDoc | None: """ Extract driver specific data for a given band. """ @@ -161,27 +163,24 @@ def with_env(self, env: dict[str, Any]) -> "Context": return Context(self.geobox, self.chunks) -class XrMemReader: +class XrSource: """ - Implements protocol for raster readers. - - - Read from in-memory xarray.Dataset - - Read from zarr spec + RasterSource -> xr.DataArray|xr.Dataset """ - # pylint: disable=too-few-public-methods - - def __init__(self, src: RasterSource, ctx: Context) -> None: + def __init__(self, src: RasterSource, chunks: Any | None = None) -> None: driver_data: xr.DataArray | xr.Dataset | SomeDoc = src.driver_data self._spec: ZarrSpecFs | None = None + self._ds: xr.Dataset | None = None self._xx: xr.DataArray | None = None - self._ctx = ctx self._src = src + self._chunks = chunks if isinstance(driver_data, xr.DataArray): self._xx = driver_data elif isinstance(driver_data, xr.Dataset): subdataset = src.subdataset + self._ds = driver_data assert subdataset in driver_data.data_vars self._xx = driver_data.data_vars[subdataset] elif isinstance(driver_data, dict): @@ -191,18 +190,31 @@ def __init__(self, src: RasterSource, ctx: Context) -> None: assert driver_data is None or (self._spec is not None or self._xx is not None) - def _resolve_data_array(self) -> xr.DataArray: - if self._xx is not None: - return self._xx + @property + def spec(self) -> ZarrSpecFs | None: + return self._spec - assert self._spec is not None - src_ds = _from_zarr_spec( + def base(self, regen_coords: bool = False) -> xr.Dataset | None: + if self._ds is not None: + return self._ds + if self._spec is None: + return None + self._ds = _from_zarr_spec( self._spec, - regen_coords=True, + regen_coords=regen_coords, target=self._src.uri, - chunks=self._ctx.chunks, + chunks=self._chunks, ) + return self._ds + def resolve( + self, + regen_coords: bool = False, + ) -> xr.DataArray: + if self._xx is not None: + return self._xx + + src_ds = self.base(regen_coords=regen_coords) if src_ds is None: raise ValueError("Failed to interpret driver data") @@ -213,9 +225,35 @@ def _resolve_data_array(self) -> xr.DataArray: if subdataset not in src_ds.data_vars: raise ValueError(f"Band {subdataset!r} not found in dataset") + self._xx = src_ds.data_vars[subdataset] return self._xx + +def _subset_src( + src: xr.DataArray, selection: ReaderSubsetSelection, cfg: RasterLoadParams +) -> xr.DataArray: + if selection is None: + return src + + assert isinstance(selection, (slice, int)) or len(selection) == 1 + assert len(cfg.extra_dims) == 1 + (band_dim,) = cfg.extra_dims + return src.isel({band_dim: selection}) + + +class XrMemReader: + """ + Implements protocol for raster readers. + + - Read from in-memory xarray.Dataset + - Read from zarr spec + """ + + def __init__(self, src: RasterSource, ctx: Context) -> None: + self._src = XrSource(src, chunks=None) + self._ctx = ctx + def read( self, cfg: RasterLoadParams, @@ -224,14 +262,8 @@ def read( dst: np.ndarray | None = None, selection: ReaderSubsetSelection | None = None, ) -> tuple[tuple[slice, slice], np.ndarray]: - src = self._resolve_data_array() - - if selection is not None: - # only support single extra dimension - assert isinstance(selection, (slice, int)) or len(selection) == 1 - assert len(cfg.extra_dims) == 1 - (band_dim,) = cfg.extra_dims - src = src.isel({band_dim: selection}) + src = self._src.resolve(regen_coords=True) + src = _subset_src(src, selection, cfg) warped = xr_reproject(src, dst_geobox, resampling=cfg.resampling) assert isinstance(warped.data, np.ndarray) @@ -245,6 +277,59 @@ def read( return yx_roi, dst +def _with_roi(xx: np.ndarray) -> tuple[tuple[slice, slice], np.ndarray]: + return (slice(None), slice(None)), xx + + +class XrMemReaderDask: + """ + Dask version of the reader. + """ + + def __init__( + self, + src: RasterSource | None = None, + ctx: Context | None = None, + layer_name: str = "", + idx: int = -1, + ) -> None: + self._src = XrSource(src, chunks="auto") if src is not None else None + self._ctx = ctx + self._layer_name = layer_name + self._idx = idx + + def read( + self, + cfg: RasterLoadParams, + dst_geobox: GeoBox, + *, + selection: ReaderSubsetSelection | None = None, + idx: tuple[int, ...] = (), + ) -> Delayed: + assert self._src is not None + assert isinstance(idx, tuple) + + xx = self._src.resolve(regen_coords=True) + xx = _subset_src(xx, selection, cfg) + yy = xr_reproject( + xx, + dst_geobox, + resampling=cfg.resampling, + chunks=dst_geobox.shape.yx, + ) + return delayed(_with_roi)(yy.data, dask_key_name=(self._layer_name, *idx)) + + def open( + self, + src: RasterSource, + ctx: Any, + *, + layer_name: str, + idx: int, + ) -> DaskRasterReader: + return XrMemReaderDask(src, ctx, layer_name=layer_name, idx=idx) + + class XrMemReaderDriver: """ Read from in memory xarray.Dataset or zarr spec document. @@ -293,7 +378,7 @@ def md_parser(self) -> MDParser: @property def dask_reader(self) -> DaskRasterReader | None: - return None + return XrMemReaderDask() def band_info(xx: xr.DataArray) -> RasterBandMetadata: diff --git a/odc/loader/test_memreader.py b/odc/loader/test_memreader.py index 5741bbb..9ba364e 100644 --- a/odc/loader/test_memreader.py +++ b/odc/loader/test_memreader.py @@ -9,16 +9,28 @@ import numpy as np import pytest import xarray as xr +from dask import is_dask_collection from odc.geo.data import country_geom from odc.geo.gcp import GCPGeoBox from odc.geo.geobox import GeoBox from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, rasterize -from ._zarr import Context, XrMemReader, XrMemReaderDriver, raster_group_md -from .types import FixedCoord, RasterGroupMetadata, RasterLoadParams, RasterSource +from odc.loader._zarr import ( + Context, + XrMemReader, + XrMemReaderDask, + XrMemReaderDriver, + raster_group_md, +) +from odc.loader.types import ( + FixedCoord, + RasterGroupMetadata, + RasterLoadParams, + RasterSource, +) -# pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison -# pylint: disable=too-many-locals,too-many-statements,redefined-outer-name +# pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison,protected-access +# pylint: disable=too-many-locals,too-many-statements,redefined-outer-name,import-outside-toplevel @pytest.fixture @@ -70,6 +82,7 @@ def test_mem_reader(sample_ds: xr.Dataset) -> None: driver = XrMemReaderDriver(ds) assert driver.md_parser is not None + assert driver.dask_reader is not None md = driver.md_parser.extract(fake_item) assert isinstance(md, RasterGroupMetadata) @@ -201,3 +214,17 @@ def test_memreader_zarr(sample_ds: xr.Dataset): assert isinstance(xx, np.ndarray) assert xx.shape == gbox[roi].shape.yx assert gbox == gbox[roi] + + rdr = XrMemReaderDask().open(src, ctx, layer_name="xx", idx=0) + assert isinstance(rdr, XrMemReaderDask) + assert rdr._src is not None + + assert rdr._src._chunks == "auto" + + fut = rdr.read(cfg, gbox) + assert is_dask_collection(fut) + + roi, xx = fut.compute(scheduler="synchronous") + assert isinstance(xx, np.ndarray) + assert roi == (slice(None), slice(None)) + assert xx.shape == gbox.shape.yx