Skip to content

Commit

Permalink
sqme: memreader
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Jun 24, 2024
1 parent 62295b7 commit 6f185ba
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 41 deletions.
60 changes: 52 additions & 8 deletions odc/loader/test_memreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,40 @@

from __future__ import annotations

import json

import numpy as np
import pytest
import xarray as xr
from odc.geo.data import country_geom
from odc.geo.geobox import GeoBox
from odc.geo.xr import ODCExtensionDa, rasterize
from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, rasterize

from .testing.mem_reader import XrMemReader, XrMemReaderDriver, raster_group_md
from .testing.mem_reader import Context, XrMemReader, XrMemReaderDriver, raster_group_md
from .types import FixedCoord, RasterGroupMetadata, RasterLoadParams, RasterSource

# pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison
# pylint: disable=too-many-locals,too-many-statements

# pylint: disable=too-many-locals,too-many-statements,redefined-outer-name

def test_mem_reader() -> None:
fake_item = object()

@pytest.fixture
def sample_ds() -> xr.Dataset:
poly = country_geom("AUS", 3857)
gbox = GeoBox.from_geopolygon(poly, resolution=10_000)
xx = rasterize(poly, gbox).astype("int16")
xx.attrs["units"] = "uu"
xx.attrs["nodata"] = -33

ds = xx.to_dataset(name="xx")
driver = XrMemReaderDriver(ds)
return xx.to_dataset(name="xx")


def test_mem_reader(sample_ds: xr.Dataset) -> None:
fake_item = object()

assert isinstance(sample_ds.odc, ODCExtensionDs)
gbox = sample_ds.odc.geobox
assert gbox is not None
driver = XrMemReaderDriver(sample_ds)

assert driver.md_parser is not None

Expand All @@ -43,6 +53,8 @@ def test_mem_reader() -> None:
assert md.extra_dims == {}
assert md.extra_coords == []

ds = sample_ds.copy()
xx = ds.xx
yy = xx.astype("uint8", keep_attrs=False).rename("yy")
yy = yy.expand_dims("band", 2)
yy = xr.concat([yy, yy + 1, yy + 2], "band").assign_coords(band=["r", "g", "b"])
Expand Down Expand Up @@ -149,3 +161,35 @@ def test_raster_group_md():
assert rgm.extra_dims_full() == {"band": 3}
assert len(rgm.extra_coords) == 1
assert rgm.extra_coords[0] == coord


def test_memreader_zarr(sample_ds: xr.Dataset) -> None:
    """
    Read a band via ``XrMemReader`` when ``driver_data`` is zarr metadata.

    The sample dataset is written to in-memory zarr stores with consolidated
    metadata, the parsed ``"metadata"`` section of ``.zmetadata`` is handed to
    ``RasterSource`` as ``driver_data``, and the result of ``read`` is checked
    against the dataset's geobox.
    """
    assert isinstance(sample_ds.odc, ODCExtensionDs)
    assert "xx" in sample_ds

    gbox = sample_ds.odc.geobox
    # Same narrowing as in test_mem_reader: fail here with a clear message
    # rather than further down when the geobox is indexed or compared.
    assert gbox is not None
    chunks = None

    md_store = {}
    chunk_store = {}
    sample_ds.to_zarr(md_store, chunk_store, compute=False, consolidated=True)

    assert ".zmetadata" in md_store
    zmd = json.loads(md_store[".zmetadata"])["metadata"]

    src = RasterSource(
        "file:///tmp/no-such-dir/xx.zarr",
        subdataset="xx",
        driver_data=zmd,
    )
    # driver_data must be passed through untouched (same object, not a copy)
    assert src.driver_data is zmd

    cfg = RasterLoadParams.same_as(src)

    ctx = Context(gbox, chunks)
    rdr = XrMemReader(src, ctx)

    roi, xx = rdr.read(cfg, gbox)
    assert isinstance(xx, np.ndarray)
    assert xx.shape == gbox[roi].shape.yx
    # reading with the dataset's own geobox must cover it entirely
    assert gbox == gbox[roi]
137 changes: 104 additions & 33 deletions odc/loader/testing/mem_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,15 @@
ReaderSubsetSelection,
)

# Generic document: string keys mapped to arbitrary values.
SomeDoc = dict[str, Any]
# Zarr spec document.
# NOTE(review): presumably the decoded (non-filesystem) metadata form -- confirm.
ZarrSpec = dict[str, Any]
# Zarr spec in the form returned by extract_zarr_spec: a mapping that holds a
# ".zmetadata" entry and is passed as ``fo=`` to the fsspec "reference" filesystem.
ZarrSpecFs = dict[str, Any]


def extract_zarr_spec(src: SomeDoc) -> ZarrSpecFs | None:
if ".zmetadata" in src:
return src

def extract_zarr_spec(src: dict[str, Any]) -> dict[str, Any] | None:
if "zarr:metadata" in src:
# TODO: handle zarr:chunks for reference filesystem
zmd = {"zarr_consolidated_format": 1, "metadata": src["zarr:metadata"]}
Expand All @@ -39,6 +46,45 @@ def extract_zarr_spec(src: dict[str, Any]) -> dict[str, Any] | None:
return {".zmetadata": json.dumps(zmd)}


def _from_zarr_spec(
    spec_doc: ZarrSpecFs,
    regen_coords: bool = False,
    fs: fsspec.AbstractFileSystem | None = None,
    chunks=None,
    target: str | None = None,
    fsspec_opts: dict[str, Any] | None = None,
) -> xr.Dataset:
    """
    Open a zarr spec document as an ``xarray.Dataset``.

    ``spec_doc`` is mounted with the fsspec ``"reference"`` filesystem
    (``fs``, ``target`` and any ``fsspec_opts`` are forwarded to it) and the
    resulting mapper is opened with the zarr engine in ``"r"`` mode, using
    ``chunks`` as given.

    When ``regen_coords`` is set and the dataset has a geobox, the coordinates
    produced by ``xr_coords`` for that geobox are assigned to the result.
    """
    extra_opts = fsspec_opts if fsspec_opts else {}
    ref_fs = fsspec.filesystem(
        "reference",
        fo=spec_doc,
        fs=fs,
        target=target,
        **extra_opts,
    )

    store = ref_fs.get_mapper("")
    ds = xr.open_dataset(store, engine="zarr", mode="r", chunks=chunks)

    geobox = ds.odc.geobox
    if geobox is None or not regen_coords:
        return ds

    # rebuild x,y coordinates from the geobox
    return ds.assign_coords(xr_coords(geobox))


def _resolve_src_dataset(
    md: Any,
    *,
    regen_coords: bool = False,
    fallback: xr.Dataset | None = None,
    **kw,
) -> xr.Dataset | None:
    """
    Turn ``md`` into an ``xarray.Dataset`` when possible.

    - a ``dict`` from which ``extract_zarr_spec`` yields a spec is opened with
      ``_from_zarr_spec`` (``regen_coords`` and ``**kw`` are forwarded),
    - an ``xarray.Dataset`` is returned as is,
    - anything else results in ``fallback``.
    """
    if isinstance(md, dict):
        spec_doc = extract_zarr_spec(md)
        if spec_doc is not None:
            return _from_zarr_spec(spec_doc, regen_coords=regen_coords, **kw)

    if isinstance(md, xr.Dataset):
        return md

    # TODO: support stac items and datacube datasets
    return fallback


class XrMDPlugin:
"""
Convert xarray.Dataset to RasterGroupMetadata.
Expand All @@ -57,32 +103,10 @@ def __init__(
self._template = template
self._src = src

def _from_zarr_spec(
self,
spec_doc: dict[str, Any],
regen_coords: bool = False,
) -> xr.Dataset:
rfs = fsspec.filesystem("reference", fo=spec_doc)
xx = xr.open_dataset(rfs.get_mapper(""), engine="zarr")
gbox = xx.odc.geobox
if gbox is not None and regen_coords:
# re-gen x,y coords from geobox
xx = xx.assign_coords(xr_coords(gbox))

return xx

def _resolve_src(self, md: Any, regen_coords: bool = False) -> xr.Dataset | None:
src = self._src

if isinstance(md, dict) and (spec_doc := extract_zarr_spec(md)) is not None:
src = self._from_zarr_spec(spec_doc, regen_coords=regen_coords)

if isinstance(md, xr.Dataset):
src = md

# TODO: support stac items and datacube datasets

return src
return _resolve_src_dataset(
md, regen_coords=regen_coords, fallback=self._src, chunks={}
)

def extract(self, md: Any) -> RasterGroupMetadata:
"""Fixed description of src dataset."""
Expand All @@ -94,20 +118,27 @@ def extract(self, md: Any) -> RasterGroupMetadata:

return self._template

def driver_data(self, md: Any, band_key: BandKey) -> xr.DataArray | None:
def driver_data(
self, md: Any, band_key: BandKey
) -> xr.DataArray | ZarrSpec | ZarrSpecFs | SomeDoc | None:
"""
Extract driver specific data for a given band.
"""
name, _ = band_key

if isinstance(md, dict):
if (spec_doc := extract_zarr_spec(md)) is not None:
return spec_doc
return md

if isinstance(md, xr.DataArray):
return md

if (src := self._resolve_src(md, regen_coords=True)) is not None:
if (aa := src.data_vars.get(name)) is not None:
return aa
src = self._resolve_src(md, regen_coords=False)
if src is None or name not in src.data_vars:
return None

return None
return src.data_vars[name]


class Context:
Expand Down Expand Up @@ -136,8 +167,48 @@ class XrMemReader:
# pylint: disable=too-few-public-methods

def __init__(self, src: RasterSource, ctx: Context) -> None:
self._xx: xr.DataArray = src.driver_data
driver_data: xr.DataArray | xr.Dataset | SomeDoc = src.driver_data
self._spec: ZarrSpecFs | None = None
self._xx: xr.DataArray | None = None
self._ctx = ctx
self._src = src

if isinstance(driver_data, xr.DataArray):
self._xx = driver_data
elif isinstance(driver_data, xr.Dataset):
subdataset = src.subdataset
assert subdataset in driver_data.data_vars
self._xx = driver_data.data_vars[subdataset]
elif isinstance(driver_data, dict):
self._spec = extract_zarr_spec(driver_data)
elif driver_data is not None:
raise ValueError(f"Unsupported driver data type: {type(driver_data)}")

assert driver_data is None or (self._spec is not None or self._xx is not None)

def _resolve_data_array(self) -> xr.DataArray:
    """
    Return the data array to read from, opening it on first use.

    If an array is already held it is returned directly. Otherwise the zarr
    spec captured at construction time is opened with ``_from_zarr_spec``
    (using the source uri as ``target`` and the context's ``chunks``), the
    requested band is picked and cached for subsequent calls.

    The band is ``src.subdataset``; when that is ``None`` the first data
    variable of the opened dataset is used.

    Raises ``ValueError`` when there is neither an array nor a spec to open,
    when the opened dataset has no data variables, or when the requested band
    is missing from it.
    """
    if self._xx is not None:
        return self._xx

    if self._spec is None:
        # Reachable when the source carried no driver data at all; raise
        # explicitly since ``assert`` is stripped under ``python -O``.
        raise ValueError("No driver data available to read from")

    src_ds = _from_zarr_spec(
        self._spec,
        regen_coords=True,
        target=self._src.uri,
        chunks=self._ctx.chunks,
    )

    subdataset = self._src.subdataset
    if subdataset is None:
        # no band requested: default to the first data variable
        names = list(src_ds.data_vars)
        if not names:
            raise ValueError("Dataset contains no data variables")
        subdataset = str(names[0])

    if subdataset not in src_ds.data_vars:
        raise ValueError(f"Band {subdataset!r} not found in dataset")

    # cache so the dataset is opened only once per reader
    self._xx = src_ds.data_vars[subdataset]
    return self._xx

def read(
self,
Expand All @@ -147,7 +218,7 @@ def read(
dst: np.ndarray | None = None,
selection: ReaderSubsetSelection | None = None,
) -> tuple[tuple[slice, slice], np.ndarray]:
src = self._xx
src = self._resolve_data_array()

if selection is not None:
# only support single extra dimension
Expand Down

0 comments on commit 6f185ba

Please sign in to comment.