Skip to content

Commit

Permalink
sqme: zarr dask
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Jun 24, 2024
1 parent ad77d7b commit 2ecadf3
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 9 deletions.
19 changes: 14 additions & 5 deletions odc/loader/_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import fsspec
import numpy as np
import xarray as xr
from dask.delayed import Delayed, delayed
from odc.geo.geobox import GeoBox
from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_coords, xr_reproject

Expand Down Expand Up @@ -276,6 +277,10 @@ def read(
return yx_roi, dst


def _with_roi(xx: np.ndarray) -> tuple[tuple[slice, slice], np.ndarray]:
return (slice(None), slice(None)), xx


class XrMemReaderDask:
"""
Dask version of the reader.
Expand All @@ -288,7 +293,7 @@ def __init__(
layer_name: str = "",
idx: int = -1,
) -> None:
self._src = XrSource(src) if src is not None else None
self._src = XrSource(src, chunks="auto") if src is not None else None
self._ctx = ctx
self._layer_name = layer_name
self._idx = idx
Expand All @@ -300,15 +305,19 @@ def read(
*,
selection: ReaderSubsetSelection | None = None,
idx: tuple[int, ...] = (),
) -> Any:
) -> Delayed:
assert self._src is not None
assert isinstance(idx, tuple)

xx = self._src.resolve(regen_coords=True)
xx = _subset_src(xx, selection, cfg)
yy = xr_reproject(xx, dst_geobox, resampling=cfg.resampling, chunks={})
yx_roi = (slice(None), slice(None))
return yx_roi, yy.data
yy = xr_reproject(
xx,
dst_geobox,
resampling=cfg.resampling,
chunks=dst_geobox.shape.yx,
)
return delayed(_with_roi)(yy.data, dask_key_name=(self._layer_name, *idx))

def open(
self,
Expand Down
34 changes: 30 additions & 4 deletions odc/loader/test_memreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,28 @@
import numpy as np
import pytest
import xarray as xr
from dask import is_dask_collection
from odc.geo.data import country_geom
from odc.geo.gcp import GCPGeoBox
from odc.geo.geobox import GeoBox
from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, rasterize

from ._zarr import Context, XrMemReader, XrMemReaderDriver, raster_group_md
from .types import FixedCoord, RasterGroupMetadata, RasterLoadParams, RasterSource
from odc.loader._zarr import (
Context,
XrMemReader,
XrMemReaderDask,
XrMemReaderDriver,
raster_group_md,
)
from odc.loader.types import (
FixedCoord,
RasterGroupMetadata,
RasterLoadParams,
RasterSource,
)

# pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison
# pylint: disable=too-many-locals,too-many-statements,redefined-outer-name
# pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison,protected-access
# pylint: disable=too-many-locals,too-many-statements,redefined-outer-name,import-outside-toplevel


@pytest.fixture
Expand Down Expand Up @@ -70,6 +82,7 @@ def test_mem_reader(sample_ds: xr.Dataset) -> None:

driver = XrMemReaderDriver(ds)
assert driver.md_parser is not None
assert driver.dask_reader is not None
md = driver.md_parser.extract(fake_item)

assert isinstance(md, RasterGroupMetadata)
Expand Down Expand Up @@ -201,3 +214,16 @@ def test_memreader_zarr(sample_ds: xr.Dataset):
assert isinstance(xx, np.ndarray)
assert xx.shape == gbox[roi].shape.yx
assert gbox == gbox[roi]

rdr = XrMemReaderDask().open(src, ctx, layer_name="xx", idx=0)
assert isinstance(rdr, XrMemReaderDask)

assert rdr._src._chunks == "auto"

fut = rdr.read(cfg, gbox)
assert is_dask_collection(fut)

roi, xx = fut.compute(scheduler="synchronous")
assert isinstance(xx, np.ndarray)
assert roi == (slice(None), slice(None))
assert xx.shape == gbox.shape.yx

0 comments on commit 2ecadf3

Please sign in to comment.