diff --git a/odc/loader/test_memreader.py b/odc/loader/test_memreader.py index 8ff6a97..b769ec4 100644 --- a/odc/loader/test_memreader.py +++ b/odc/loader/test_memreader.py @@ -4,30 +4,40 @@ from __future__ import annotations +import json + import numpy as np +import pytest import xarray as xr from odc.geo.data import country_geom from odc.geo.geobox import GeoBox -from odc.geo.xr import ODCExtensionDa, rasterize +from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, rasterize -from .testing.mem_reader import XrMemReader, XrMemReaderDriver, raster_group_md +from .testing.mem_reader import Context, XrMemReader, XrMemReaderDriver, raster_group_md from .types import FixedCoord, RasterGroupMetadata, RasterLoadParams, RasterSource # pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison -# pylint: disable=too-many-locals,too-many-statements - +# pylint: disable=too-many-locals,too-many-statements,redefined-outer-name -def test_mem_reader() -> None: - fake_item = object() +@pytest.fixture +def sample_ds() -> xr.Dataset: poly = country_geom("AUS", 3857) gbox = GeoBox.from_geopolygon(poly, resolution=10_000) xx = rasterize(poly, gbox).astype("int16") xx.attrs["units"] = "uu" xx.attrs["nodata"] = -33 - ds = xx.to_dataset(name="xx") - driver = XrMemReaderDriver(ds) + return xx.to_dataset(name="xx") + + +def test_mem_reader(sample_ds: xr.Dataset) -> None: + fake_item = object() + + assert isinstance(sample_ds.odc, ODCExtensionDs) + gbox = sample_ds.odc.geobox + assert gbox is not None + driver = XrMemReaderDriver(sample_ds) assert driver.md_parser is not None @@ -43,6 +53,8 @@ def test_mem_reader() -> None: assert md.extra_dims == {} assert md.extra_coords == [] + ds = sample_ds.copy() + xx = ds.xx yy = xx.astype("uint8", keep_attrs=False).rename("yy") yy = yy.expand_dims("band", 2) yy = xr.concat([yy, yy + 1, yy + 2], "band").assign_coords(band=["r", "g", "b"]) @@ -149,3 +161,35 @@ def test_raster_group_md(): assert rgm.extra_dims_full() == {"band": 3} assert len(rgm.extra_coords) == 1 assert rgm.extra_coords[0] == coord + + +def test_memreader_zarr(sample_ds: xr.Dataset): + assert isinstance(sample_ds.odc, ODCExtensionDs) + assert "xx" in sample_ds + + gbox = sample_ds.odc.geobox + chunks = None + + md_store = {} + chunk_store = {} + sample_ds.to_zarr(md_store, chunk_store, compute=False, consolidated=True) + + assert ".zmetadata" in md_store + zmd = json.loads(md_store[".zmetadata"])["metadata"] + + src = RasterSource( + "file:///tmp/no-such-dir/xx.zarr", + subdataset="xx", + driver_data=zmd, + ) + assert src.driver_data is zmd + + cfg = RasterLoadParams.same_as(src) + + ctx = Context(gbox, chunks) + rdr = XrMemReader(src, ctx) + + roi, xx = rdr.read(cfg, gbox) + assert isinstance(xx, np.ndarray) + assert xx.shape == gbox[roi].shape.yx + assert gbox == gbox[roi] diff --git a/odc/loader/testing/mem_reader.py b/odc/loader/testing/mem_reader.py index 10d9a61..3f60c83 100644 --- a/odc/loader/testing/mem_reader.py +++ b/odc/loader/testing/mem_reader.py @@ -26,8 +26,15 @@ ReaderSubsetSelection, ) +SomeDoc = dict[str, Any] +ZarrSpec = dict[str, Any] +ZarrSpecFs = dict[str, Any] + + +def extract_zarr_spec(src: SomeDoc) -> ZarrSpecFs | None: + if ".zmetadata" in src: + return src -def extract_zarr_spec(src: dict[str, Any]) -> dict[str, Any] | None: if "zarr:metadata" in src: # TODO: handle zarr:chunks for reference filesystem zmd = {"zarr_consolidated_format": 1, "metadata": src["zarr:metadata"]} @@ -39,6 +46,45 @@ def extract_zarr_spec(src: dict[str, Any]) -> dict[str, Any] | None: return {".zmetadata": json.dumps(zmd)} +def _from_zarr_spec( + spec_doc: ZarrSpecFs, + regen_coords: bool = False, + fs: fsspec.AbstractFileSystem | None = None, + chunks=None, + target: str | None = None, + fsspec_opts: dict[str, Any] | None = None, +) -> xr.Dataset: + fsspec_opts = fsspec_opts or {} + rfs = fsspec.filesystem( + "reference", fo=spec_doc, fs=fs, target=target, **fsspec_opts + ) + + xx = xr.open_dataset(rfs.get_mapper(""), engine="zarr", mode="r", chunks=chunks) + gbox = xx.odc.geobox + if gbox is not None and regen_coords: + # re-gen x,y coords from geobox + xx = xx.assign_coords(xr_coords(gbox)) + + return xx + + +def _resolve_src_dataset( + md: Any, + *, + regen_coords: bool = False, + fallback: xr.Dataset | None = None, + **kw, +) -> xr.Dataset | None: + if isinstance(md, dict) and (spec_doc := extract_zarr_spec(md)) is not None: + return _from_zarr_spec(spec_doc, regen_coords=regen_coords, **kw) + + if isinstance(md, xr.Dataset): + return md + + # TODO: support stac items and datacube datasets + return fallback + + class XrMDPlugin: """ Convert xarray.Dataset to RasterGroupMetadata. @@ -57,32 +103,10 @@ def __init__( self._template = template self._src = src - def _from_zarr_spec( - self, - spec_doc: dict[str, Any], - regen_coords: bool = False, - ) -> xr.Dataset: - rfs = fsspec.filesystem("reference", fo=spec_doc) - xx = xr.open_dataset(rfs.get_mapper(""), engine="zarr") - gbox = xx.odc.geobox - if gbox is not None and regen_coords: - # re-gen x,y coords from geobox - xx = xx.assign_coords(xr_coords(gbox)) - - return xx - def _resolve_src(self, md: Any, regen_coords: bool = False) -> xr.Dataset | None: - src = self._src - - if isinstance(md, dict) and (spec_doc := extract_zarr_spec(md)) is not None: - src = self._from_zarr_spec(spec_doc, regen_coords=regen_coords) - - if isinstance(md, xr.Dataset): - src = md - - # TODO: support stac items and datacube datasets - - return src + return _resolve_src_dataset( + md, regen_coords=regen_coords, fallback=self._src, chunks={} + ) def extract(self, md: Any) -> RasterGroupMetadata: """Fixed description of src dataset.""" @@ -94,20 +118,27 @@ def extract(self, md: Any) -> RasterGroupMetadata: return self._template - def driver_data(self, md: Any, band_key: BandKey) -> xr.DataArray | None: + def driver_data( + self, md: Any, band_key: BandKey + ) -> xr.DataArray | ZarrSpec | ZarrSpecFs | SomeDoc | None: """ Extract driver specific data for a given band. """ name, _ = band_key + if isinstance(md, dict): + if (spec_doc := extract_zarr_spec(md)) is not None: + return spec_doc + return md + if isinstance(md, xr.DataArray): return md - if (src := self._resolve_src(md, regen_coords=True)) is not None: - if (aa := src.data_vars.get(name)) is not None: - return aa + src = self._resolve_src(md, regen_coords=False) + if src is None or name not in src.data_vars: + return None - return None + return src.data_vars[name] class Context: @@ -136,8 +167,48 @@ class XrMemReader: # pylint: disable=too-few-public-methods def __init__(self, src: RasterSource, ctx: Context) -> None: - self._xx: xr.DataArray = src.driver_data + driver_data: xr.DataArray | xr.Dataset | SomeDoc = src.driver_data + self._spec: ZarrSpecFs | None = None + self._xx: xr.DataArray | None = None self._ctx = ctx + self._src = src + + if isinstance(driver_data, xr.DataArray): + self._xx = driver_data + elif isinstance(driver_data, xr.Dataset): + subdataset = src.subdataset + assert subdataset in driver_data.data_vars + self._xx = driver_data.data_vars[subdataset] + elif isinstance(driver_data, dict): + self._spec = extract_zarr_spec(driver_data) + elif driver_data is not None: + raise ValueError(f"Unsupported driver data type: {type(driver_data)}") + + assert driver_data is None or (self._spec is not None or self._xx is not None) + + def _resolve_data_array(self) -> xr.DataArray: + if self._xx is not None: + return self._xx + + assert self._spec is not None + kw = { + "target": self._src.uri, + "chunks": self._ctx.chunks, + } + src_ds = _from_zarr_spec(self._spec, regen_coords=True, **kw) + + if src_ds is None: + raise ValueError("Failed to interpret driver data") + + subdataset = self._src.subdataset + if subdataset is None: + _first, *_ = src_ds.data_vars + subdataset = str(_first) + + if subdataset not in src_ds.data_vars: + raise ValueError(f"Band {subdataset!r} not found in dataset") + self._xx = src_ds.data_vars[subdataset] + return self._xx def read( self, @@ -147,7 +218,7 @@ def read( dst: np.ndarray | None = None, selection: ReaderSubsetSelection | None = None, ) -> tuple[tuple[slice, slice], np.ndarray]: - src = self._xx + src = self._resolve_data_array() if selection is not None: # only support single extra dimension