diff --git a/odc/loader/test_memreader.py b/odc/loader/test_memreader.py index 0e64051..8ff6a97 100644 --- a/odc/loader/test_memreader.py +++ b/odc/loader/test_memreader.py @@ -10,8 +10,8 @@ from odc.geo.geobox import GeoBox from odc.geo.xr import ODCExtensionDa, rasterize -from odc.loader.testing.mem_reader import XrMemReader, XrMemReaderDriver -from odc.loader.types import RasterGroupMetadata, RasterLoadParams, RasterSource +from .testing.mem_reader import XrMemReader, XrMemReaderDriver, raster_group_md +from .types import FixedCoord, RasterGroupMetadata, RasterLoadParams, RasterSource # pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison # pylint: disable=too-many-locals,too-many-statements @@ -126,3 +126,26 @@ def test_mem_reader() -> None: loader = loaders["zz"] roi, pix = loader.read(cfgs["zz"], gbox, selection=np.s_[:2]) assert pix.shape == (2, *gbox.shape.yx) + + +def test_raster_group_md(): + rgm = raster_group_md(xr.Dataset()) + assert rgm.bands == {} + assert rgm.aliases == {} + assert rgm.extra_dims == {} + + coord = FixedCoord("band", ["r", "g", "b"], dim="band") + + rgm = raster_group_md( + xr.Dataset(), base=RasterGroupMetadata({}, {}, {"band": 3}, []) + ) + assert rgm.extra_dims == {"band": 3} + assert len(rgm.extra_coords) == 0 + + rgm = raster_group_md( + xr.Dataset(), base=RasterGroupMetadata({}, extra_coords=[coord]) + ) + assert rgm.extra_dims == {} + assert rgm.extra_dims_full() == {"band": 3} + assert len(rgm.extra_coords) == 1 + assert rgm.extra_coords[0] == coord diff --git a/odc/loader/testing/mem_reader.py b/odc/loader/testing/mem_reader.py index d7f5bcc..10d9a61 100644 --- a/odc/loader/testing/mem_reader.py +++ b/odc/loader/testing/mem_reader.py @@ -4,13 +4,15 @@ from __future__ import annotations +import json from contextlib import contextmanager -from typing import Any, Iterator +from typing import Any, Iterator, Sequence +import fsspec import numpy as np import xarray as xr from odc.geo.geobox import GeoBox -from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_reproject +from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, xr_coords, xr_reproject from ..types import ( BandKey, @@ -25,6 +27,18 @@ ) +def extract_zarr_spec(src: dict[str, Any]) -> dict[str, Any] | None: + if "zarr:metadata" in src: + # TODO: handle zarr:chunks for reference filesystem + zmd = {"zarr_consolidated_format": 1, "metadata": src["zarr:metadata"]} + elif "zarr_consolidated_format" in src: + zmd = src + else: + zmd = {"zarr_consolidated_format": 1, "metadata": src} + + return {".zmetadata": json.dumps(zmd)} + + class XrMDPlugin: """ Convert xarray.Dataset to RasterGroupMetadata. @@ -35,22 +49,65 @@ class XrMDPlugin: - Driver data is xarray.DataArray for each band """ - def __init__(self, src: xr.Dataset) -> None: + def __init__( + self, + template: RasterGroupMetadata, + src: xr.Dataset | None = None, + ) -> None: + self._template = template self._src = src - self._md = raster_group_md(src) + + def _from_zarr_spec( + self, + spec_doc: dict[str, Any], + regen_coords: bool = False, + ) -> xr.Dataset: + rfs = fsspec.filesystem("reference", fo=spec_doc) + xx = xr.open_dataset(rfs.get_mapper(""), engine="zarr") + gbox = xx.odc.geobox + if gbox is not None and regen_coords: + # re-gen x,y coords from geobox + xx = xx.assign_coords(xr_coords(gbox)) + + return xx + + def _resolve_src(self, md: Any, regen_coords: bool = False) -> xr.Dataset | None: + src = self._src + + if isinstance(md, dict) and (spec_doc := extract_zarr_spec(md)) is not None: + src = self._from_zarr_spec(spec_doc, regen_coords=regen_coords) + + if isinstance(md, xr.Dataset): + src = md + + # TODO: support stac items and datacube datasets + + return src def extract(self, md: Any) -> RasterGroupMetadata: """Fixed description of src dataset.""" - assert md is not None - return self._md + if isinstance(md, RasterGroupMetadata): + return md + + if (src := self._resolve_src(md, regen_coords=False)) is not None: + return raster_group_md(src, base=self._template) + + return self._template - def driver_data(self, md: Any, band_key: BandKey) -> xr.DataArray: + def driver_data(self, md: Any, band_key: BandKey) -> xr.DataArray | None: """ Extract driver specific data for a given band. """ - assert md is not None name, _ = band_key - return self._src[name] + + if isinstance(md, xr.DataArray): + return md + + if (src := self._resolve_src(md, regen_coords=True)) is not None: + if (aa := src.data_vars.get(name)) is not None: + return aa + + return None class Context: @@ -60,17 +117,15 @@ class Context: def __init__( self, - src: xr.Dataset, geobox: GeoBox, chunks: None | dict[str, int], ) -> None: - self.src = src self.geobox = geobox self.chunks = chunks def with_env(self, env: dict[str, Any]) -> "Context": assert isinstance(env, dict) - return Context(self.src, self.geobox, self.chunks) + return Context(self.geobox, self.chunks) class XrMemReader: @@ -81,7 +136,6 @@ class XrMemReader: # pylint: disable=too-few-public-methods def __init__(self, src: RasterSource, ctx: Context) -> None: - self._src = src self._xx: xr.DataArray = src.driver_data self._ctx = ctx @@ -121,8 +175,17 @@ class XrMemReaderDriver: Reader = XrMemReader - def __init__(self, src: xr.Dataset) -> None: + def __init__( + self, + src: xr.Dataset | None = None, + template: RasterGroupMetadata | None = None, + ) -> None: + if src is not None and template is None: + template = raster_group_md(src) + if template is None: + template = RasterGroupMetadata({}, {}, {}, []) self.src = src + self.template = template def new_load( self, @@ -130,7 +193,7 @@ def new_load( *, chunks: None | dict[str, int] = None, ) -> Context: - return Context(self.src, geobox, chunks) + return Context(geobox, chunks) def finalise_load(self, load_state: Context) -> Context: return load_state @@ -149,7 +212,7 @@ def open(self, src: RasterSource, ctx: Context) -> XrMemReader: @property def md_parser(self) -> MDParser: - return XrMDPlugin(self.src) + return XrMDPlugin(self.template, src=self.src) @property def dask_reader(self) -> DaskRasterReader | None: @@ -177,25 +240,46 @@ def band_info(xx: xr.DataArray) -> RasterBandMetadata: ) -def raster_group_md(src: xr.Dataset) -> RasterGroupMetadata: +def raster_group_md( + src: xr.Dataset, + *, + base: RasterGroupMetadata | None = None, + aliases: dict[str, list[BandKey]] | None = None, + extra_coords: Sequence[FixedCoord] = (), + extra_dims: dict[str, int] | None = None, +) -> RasterGroupMetadata: oo: ODCExtensionDs = src.odc sdims = oo.spatial_dims or ("y", "x") - bands: dict[BandKey, RasterBandMetadata] = { - (str(k), 1): band_info(v) for k, v in src.data_vars.items() if v.ndim >= 2 - } + if base is None: + base = RasterGroupMetadata( + bands={}, + aliases=aliases or {}, + extra_coords=extra_coords, + extra_dims=extra_dims or {}, + ) + + bands = base.bands.copy() + bands.update( + {(str(k), 1): band_info(v) for k, v in src.data_vars.items() if v.ndim >= 2} + ) + + edims = base.extra_dims.copy() + edims.update({str(name): sz for name, sz in src.sizes.items() if name not in sdims}) - extra_dims: dict[str, int] = { - str(name): sz for name, sz in src.sizes.items() if name not in sdims - } + aliases: dict[str, list[BandKey]] = base.aliases.copy() - aliases: dict[str, list[BandKey]] = {} + extra_coords: list[FixedCoord] = list(base.extra_coords) + supplied_coords = set(coord.name for coord in extra_coords) - extra_coords: list[FixedCoord] = [] for coord in src.coords.values(): if len(coord.dims) != 1 or coord.dims[0] in sdims: # Only 1-d non-spatial coords continue + + if coord.name in supplied_coords: + continue + extra_coords.append( FixedCoord( coord.name, @@ -205,4 +289,9 @@ def raster_group_md(src: xr.Dataset) -> RasterGroupMetadata: ) ) - return RasterGroupMetadata(bands, aliases, extra_dims, extra_coords) + return RasterGroupMetadata( + bands=bands, + aliases=aliases, + extra_dims=edims, + extra_coords=extra_coords, + )