diff --git a/odc/loader/test_memreader.py b/odc/loader/test_memreader.py index 0e64051..1eeddc6 100644 --- a/odc/loader/test_memreader.py +++ b/odc/loader/test_memreader.py @@ -4,30 +4,43 @@ from __future__ import annotations +import json + import numpy as np +import pytest import xarray as xr from odc.geo.data import country_geom +from odc.geo.gcp import GCPGeoBox from odc.geo.geobox import GeoBox -from odc.geo.xr import ODCExtensionDa, rasterize +from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, rasterize -from odc.loader.testing.mem_reader import XrMemReader, XrMemReaderDriver -from odc.loader.types import RasterGroupMetadata, RasterLoadParams, RasterSource +from .testing.mem_reader import Context, XrMemReader, XrMemReaderDriver, raster_group_md +from .types import FixedCoord, RasterGroupMetadata, RasterLoadParams, RasterSource # pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison -# pylint: disable=too-many-locals,too-many-statements - +# pylint: disable=too-many-locals,too-many-statements,redefined-outer-name -def test_mem_reader() -> None: - fake_item = object() +@pytest.fixture +def sample_ds() -> xr.Dataset: poly = country_geom("AUS", 3857) gbox = GeoBox.from_geopolygon(poly, resolution=10_000) xx = rasterize(poly, gbox).astype("int16") xx.attrs["units"] = "uu" xx.attrs["nodata"] = -33 - ds = xx.to_dataset(name="xx") - driver = XrMemReaderDriver(ds) + return xx.to_dataset(name="xx") + + +def test_mem_reader(sample_ds: xr.Dataset) -> None: + fake_item = object() + + assert isinstance(sample_ds.odc, ODCExtensionDs) + gbox = sample_ds.odc.geobox + assert gbox is not None + assert isinstance(gbox, GeoBox) + + driver = XrMemReaderDriver(sample_ds) assert driver.md_parser is not None @@ -43,6 +56,8 @@ def test_mem_reader() -> None: assert md.extra_dims == {} assert md.extra_coords == [] + ds = sample_ds.copy() + xx = ds.xx yy = xx.astype("uint8", keep_attrs=False).rename("yy") yy = yy.expand_dims("band", 2) yy = xr.concat([yy, yy + 1, yy + 2], "band").assign_coords(band=["r", "g", "b"]) @@ -126,3 +141,63 @@ def test_mem_reader() -> None: loader = loaders["zz"] roi, pix = loader.read(cfgs["zz"], gbox, selection=np.s_[:2]) assert pix.shape == (2, *gbox.shape.yx) + + +def test_raster_group_md(): + rgm = raster_group_md(xr.Dataset()) + assert rgm.bands == {} + assert rgm.aliases == {} + assert rgm.extra_dims == {} + + coord = FixedCoord("band", ["r", "g", "b"], dim="band") + + rgm = raster_group_md( + xr.Dataset(), base=RasterGroupMetadata({}, {}, {"band": 3}, []) + ) + assert rgm.extra_dims == {"band": 3} + assert len(rgm.extra_coords) == 0 + + rgm = raster_group_md( + xr.Dataset(), base=RasterGroupMetadata({}, extra_coords=[coord]) + ) + assert rgm.extra_dims == {} + assert rgm.extra_dims_full() == {"band": 3} + assert len(rgm.extra_coords) == 1 + assert rgm.extra_coords[0] == coord + + +def test_memreader_zarr(sample_ds: xr.Dataset): + assert isinstance(sample_ds.odc, ODCExtensionDs) + assert "xx" in sample_ds + + zarr = pytest.importorskip("zarr") + assert zarr is not None + + _gbox = sample_ds.odc.geobox + chunks = None + assert _gbox is not None + gbox = _gbox.approx if isinstance(_gbox, GCPGeoBox) else _gbox + + md_store: dict[str, bytes] = {} + chunk_store: dict[str, bytes] = {} + sample_ds.to_zarr(md_store, chunk_store, compute=False, consolidated=True) + + assert ".zmetadata" in md_store + zmd = json.loads(md_store[".zmetadata"])["metadata"] + + src = RasterSource( + "file:///tmp/no-such-dir/xx.zarr", + subdataset="xx", + driver_data=zmd, + ) + assert src.driver_data is zmd + + cfg = RasterLoadParams.same_as(src) + + ctx = Context(gbox, chunks) + rdr = XrMemReader(src, ctx) + + roi, xx = rdr.read(cfg, gbox) + assert isinstance(xx, np.ndarray) + assert xx.shape == gbox[roi].shape.yx + assert gbox == gbox[roi] diff --git a/odc/loader/testing/mem_reader.py b/odc/loader/testing/mem_reader.py index d7f5bcc..6b294fb 100644 --- a/odc/loader/testing/mem_reader.py +++ b/odc/loader/testing/mem_reader.py @@ -4,13 +4,15 @@ from __future__ import annotations +import json from contextlib import contextmanager -from typing import Any, Iterator +from typing import Any, Iterator, Sequence +import fsspec import numpy as np import xarray as xr from odc.geo.geobox import GeoBox -from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_reproject +from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_coords, xr_reproject from ..types import ( BandKey, @@ -24,6 +26,64 @@ ReaderSubsetSelection, ) +SomeDoc = dict[str, Any] +ZarrSpec = dict[str, Any] +ZarrSpecFs = dict[str, Any] + + +def extract_zarr_spec(src: SomeDoc) -> ZarrSpecFs | None: + if ".zmetadata" in src: + return src + + if "zarr:metadata" in src: + # TODO: handle zarr:chunks for reference filesystem + zmd = {"zarr_consolidated_format": 1, "metadata": src["zarr:metadata"]} + elif "zarr_consolidated_format" in src: + zmd = src + else: + zmd = {"zarr_consolidated_format": 1, "metadata": src} + + return {".zmetadata": json.dumps(zmd)} + + +def _from_zarr_spec( + spec_doc: ZarrSpecFs, + regen_coords: bool = False, + fs: fsspec.AbstractFileSystem | None = None, + chunks=None, + target: str | None = None, + fsspec_opts: dict[str, Any] | None = None, +) -> xr.Dataset: + fsspec_opts = fsspec_opts or {} + rfs = fsspec.filesystem( + "reference", fo=spec_doc, fs=fs, target=target, **fsspec_opts + ) + + xx = xr.open_dataset(rfs.get_mapper(""), engine="zarr", mode="r", chunks=chunks) + gbox = xx.odc.geobox + if gbox is not None and regen_coords: + # re-gen x,y coords from geobox + xx = xx.assign_coords(xr_coords(gbox)) + + return xx + + +def _resolve_src_dataset( + md: Any, + *, + regen_coords: bool = False, + fallback: xr.Dataset | None = None, + **kw, +) -> xr.Dataset | None: + if isinstance(md, dict) and (spec_doc := extract_zarr_spec(md)) is not None: + return _from_zarr_spec(spec_doc, regen_coords=regen_coords, **kw) + + if isinstance(md, xr.Dataset): + return md + + # TODO: support stac items and datacube datasets + return fallback + class XrMDPlugin: """ @@ -35,22 +95,50 @@ class XrMDPlugin: - Driver data is xarray.DataArray for each band """ - def __init__(self, src: xr.Dataset) -> None: + def __init__( + self, + template: RasterGroupMetadata, + src: xr.Dataset | None = None, + ) -> None: + self._template = template self._src = src - self._md = raster_group_md(src) + + def _resolve_src(self, md: Any, regen_coords: bool = False) -> xr.Dataset | None: + return _resolve_src_dataset( + md, regen_coords=regen_coords, fallback=self._src, chunks={} + ) def extract(self, md: Any) -> RasterGroupMetadata: """Fixed description of src dataset.""" - assert md is not None - return self._md + if isinstance(md, RasterGroupMetadata): + return md + + if (src := self._resolve_src(md, regen_coords=False)) is not None: + return raster_group_md(src, base=self._template) + + return self._template - def driver_data(self, md: Any, band_key: BandKey) -> xr.DataArray: + def driver_data( + self, md: Any, band_key: BandKey + ) -> xr.DataArray | ZarrSpec | ZarrSpecFs | SomeDoc | None: """ Extract driver specific data for a given band. """ - assert md is not None name, _ = band_key - return self._src[name] + + if isinstance(md, dict): + if (spec_doc := extract_zarr_spec(md)) is not None: + return spec_doc + return md + + if isinstance(md, xr.DataArray): + return md + + src = self._resolve_src(md, regen_coords=False) + if src is None or name not in src.data_vars: + return None + + return src.data_vars[name] class Context: @@ -60,17 +148,15 @@ class Context: def __init__( self, - src: xr.Dataset, geobox: GeoBox, chunks: None | dict[str, int], ) -> None: - self.src = src self.geobox = geobox self.chunks = chunks def with_env(self, env: dict[str, Any]) -> "Context": assert isinstance(env, dict) - return Context(self.src, self.geobox, self.chunks) + return Context(self.geobox, self.chunks) class XrMemReader: @@ -81,9 +167,49 @@ class XrMemReader: # pylint: disable=too-few-public-methods def __init__(self, src: RasterSource, ctx: Context) -> None: - self._src = src - self._xx: xr.DataArray = src.driver_data + driver_data: xr.DataArray | xr.Dataset | SomeDoc = src.driver_data + self._spec: ZarrSpecFs | None = None + self._xx: xr.DataArray | None = None self._ctx = ctx + self._src = src + + if isinstance(driver_data, xr.DataArray): + self._xx = driver_data + elif isinstance(driver_data, xr.Dataset): + subdataset = src.subdataset + assert subdataset in driver_data.data_vars + self._xx = driver_data.data_vars[subdataset] + elif isinstance(driver_data, dict): + self._spec = extract_zarr_spec(driver_data) + elif driver_data is not None: + raise ValueError(f"Unsupported driver data type: {type(driver_data)}") + + assert driver_data is None or (self._spec is not None or self._xx is not None) + + def _resolve_data_array(self) -> xr.DataArray: + if self._xx is not None: + return self._xx + + assert self._spec is not None + src_ds = _from_zarr_spec( + self._spec, + regen_coords=True, + target=self._src.uri, + chunks=self._ctx.chunks, + ) + + if src_ds is None: + raise ValueError("Failed to interpret driver data") + + subdataset = self._src.subdataset + if subdataset is None: + _first, *_ = src_ds.data_vars + subdataset = str(_first) + + if subdataset not in src_ds.data_vars: + raise ValueError(f"Band {subdataset!r} not found in dataset") + self._xx = src_ds.data_vars[subdataset] + return self._xx def read( self, @@ -93,7 +219,7 @@ def read( dst: np.ndarray | None = None, selection: ReaderSubsetSelection | None = None, ) -> tuple[tuple[slice, slice], np.ndarray]: - src = self._xx + src = self._resolve_data_array() if selection is not None: # only support single extra dimension @@ -121,8 +247,17 @@ class XrMemReaderDriver: Reader = XrMemReader - def __init__(self, src: xr.Dataset) -> None: + def __init__( + self, + src: xr.Dataset | None = None, + template: RasterGroupMetadata | None = None, + ) -> None: + if src is not None and template is None: + template = raster_group_md(src) + if template is None: + template = RasterGroupMetadata({}, {}, {}, []) self.src = src + self.template = template def new_load( self, @@ -130,7 +265,7 @@ def new_load( *, chunks: None | dict[str, int] = None, ) -> Context: - return Context(self.src, geobox, chunks) + return Context(geobox, chunks) def finalise_load(self, load_state: Context) -> Context: return load_state @@ -149,7 +284,7 @@ def open(self, src: RasterSource, ctx: Context) -> XrMemReader: @property def md_parser(self) -> MDParser: - return XrMDPlugin(self.src) + return XrMDPlugin(self.template, src=self.src) @property def dask_reader(self) -> DaskRasterReader | None: @@ -177,25 +312,46 @@ def band_info(xx: xr.DataArray) -> RasterBandMetadata: ) -def raster_group_md(src: xr.Dataset) -> RasterGroupMetadata: +def raster_group_md( + src: xr.Dataset, + *, + base: RasterGroupMetadata | None = None, + aliases: dict[str, list[BandKey]] | None = None, + extra_coords: Sequence[FixedCoord] = (), + extra_dims: dict[str, int] | None = None, +) -> RasterGroupMetadata: oo: ODCExtensionDs = src.odc sdims = oo.spatial_dims or ("y", "x") - bands: dict[BandKey, RasterBandMetadata] = { - (str(k), 1): band_info(v) for k, v in src.data_vars.items() if v.ndim >= 2 - } + if base is None: + base = RasterGroupMetadata( + bands={}, + aliases=aliases or {}, + extra_coords=extra_coords, + extra_dims=extra_dims or {}, + ) + + bands = base.bands.copy() + bands.update( + {(str(k), 1): band_info(v) for k, v in src.data_vars.items() if v.ndim >= 2} + ) + + edims = base.extra_dims.copy() + edims.update({str(name): sz for name, sz in src.sizes.items() if name not in sdims}) - extra_dims: dict[str, int] = { - str(name): sz for name, sz in src.sizes.items() if name not in sdims - } + aliases: dict[str, list[BandKey]] = base.aliases.copy() - aliases: dict[str, list[BandKey]] = {} + extra_coords: list[FixedCoord] = list(base.extra_coords) + supplied_coords = set(coord.name for coord in extra_coords) - extra_coords: list[FixedCoord] = [] for coord in src.coords.values(): if len(coord.dims) != 1 or coord.dims[0] in sdims: # Only 1-d non-spatial coords continue + + if coord.name in supplied_coords: + continue + extra_coords.append( FixedCoord( coord.name, @@ -205,4 +361,9 @@ def raster_group_md(src: xr.Dataset) -> RasterGroupMetadata: ) ) - return RasterGroupMetadata(bands, aliases, extra_dims, extra_coords) + return RasterGroupMetadata( + bands=bands, + aliases=aliases, + extra_dims=edims, + extra_coords=extra_coords, + ) diff --git a/tests/test-env-py310.yml b/tests/test-env-py310.yml index 93116c3..bbdc374 100644 --- a/tests/test-env-py310.yml +++ b/tests/test-env-py310.yml @@ -36,6 +36,7 @@ dependencies: - pystac-client >=0.2.0 - geopandas - stackstac + - zarr # for docs - sphinx