From 2ecadf34bcb8f1ddea2f3782d479f41e514c89fb Mon Sep 17 00:00:00 2001 From: Kirill Kouzoubov Date: Mon, 24 Jun 2024 19:20:56 +1000 Subject: [PATCH] sqme: zarr dask --- odc/loader/_zarr.py | 19 ++++++++++++++----- odc/loader/test_memreader.py | 34 ++++++++++++++++++++++++++++++---- 2 files changed, 44 insertions(+), 9 deletions(-) diff --git a/odc/loader/_zarr.py b/odc/loader/_zarr.py index feb3bd1..102b053 100644 --- a/odc/loader/_zarr.py +++ b/odc/loader/_zarr.py @@ -11,6 +11,7 @@ import fsspec import numpy as np import xarray as xr +from dask.delayed import Delayed, delayed from odc.geo.geobox import GeoBox from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_coords, xr_reproject @@ -276,6 +277,10 @@ def read( return yx_roi, dst +def _with_roi(xx: np.ndarray) -> tuple[tuple[slice, slice], np.ndarray]: + return (slice(None), slice(None)), xx + + class XrMemReaderDask: """ Dask version of the reader. @@ -288,7 +293,7 @@ def __init__( layer_name: str = "", idx: int = -1, ) -> None: - self._src = XrSource(src) if src is not None else None + self._src = XrSource(src, chunks="auto") if src is not None else None self._ctx = ctx self._layer_name = layer_name self._idx = idx @@ -300,15 +305,19 @@ def read( *, selection: ReaderSubsetSelection | None = None, idx: tuple[int, ...] = (), - ) -> Any: + ) -> Delayed: assert self._src is not None assert isinstance(idx, tuple) xx = self._src.resolve(regen_coords=True) xx = _subset_src(xx, selection, cfg) - yy = xr_reproject(xx, dst_geobox, resampling=cfg.resampling, chunks={}) - yx_roi = (slice(None), slice(None)) - return yx_roi, yy.data + yy = xr_reproject( + xx, + dst_geobox, + resampling=cfg.resampling, + chunks=dst_geobox.shape.yx, + ) + return delayed(_with_roi)(yy.data, dask_key_name=(self._layer_name, *idx)) def open( self, diff --git a/odc/loader/test_memreader.py b/odc/loader/test_memreader.py index 5741bbb..2398502 100644 --- a/odc/loader/test_memreader.py +++ b/odc/loader/test_memreader.py @@ -9,16 +9,28 @@ import numpy as np import pytest import xarray as xr +from dask import is_dask_collection from odc.geo.data import country_geom from odc.geo.gcp import GCPGeoBox from odc.geo.geobox import GeoBox from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, rasterize -from ._zarr import Context, XrMemReader, XrMemReaderDriver, raster_group_md -from .types import FixedCoord, RasterGroupMetadata, RasterLoadParams, RasterSource +from odc.loader._zarr import ( + Context, + XrMemReader, + XrMemReaderDask, + XrMemReaderDriver, + raster_group_md, +) +from odc.loader.types import ( + FixedCoord, + RasterGroupMetadata, + RasterLoadParams, + RasterSource, +) -# pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison -# pylint: disable=too-many-locals,too-many-statements,redefined-outer-name +# pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison,protected-access +# pylint: disable=too-many-locals,too-many-statements,redefined-outer-name,import-outside-toplevel @pytest.fixture @@ -70,6 +82,7 @@ def test_mem_reader(sample_ds: xr.Dataset) -> None: driver = XrMemReaderDriver(ds) assert driver.md_parser is not None + assert driver.dask_reader is not None md = driver.md_parser.extract(fake_item) assert isinstance(md, RasterGroupMetadata) @@ -201,3 +214,16 @@ def test_memreader_zarr(sample_ds: xr.Dataset): assert isinstance(xx, np.ndarray) assert xx.shape == gbox[roi].shape.yx assert gbox == gbox[roi] + + rdr = XrMemReaderDask().open(src, ctx, layer_name="xx", idx=0) + assert isinstance(rdr, XrMemReaderDask) + + assert rdr._src._chunks == "auto" + + fut = rdr.read(cfg, gbox) + assert is_dask_collection(fut) + + roi, xx = fut.compute(scheduler="synchronous") + assert isinstance(xx, np.ndarray) + assert roi == (slice(None), slice(None)) + assert xx.shape == gbox.shape.yx