Skip to content

Commit

Permalink
Implement zarr support in xrmemreader
Browse files Browse the repository at this point in the history
- Refactor xrmemreader to support multiple
  xarray inputs
- Add zarr support to xrmemreader
  - Construct xarrays from zarr store when needed
    rather than only using in-memory xarray
  • Loading branch information
Kirill888 committed Jun 24, 2024
1 parent 634e590 commit c06020a
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 37 deletions.
93 changes: 84 additions & 9 deletions odc/loader/test_memreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,43 @@

from __future__ import annotations

import json

import numpy as np
import pytest
import xarray as xr
from odc.geo.data import country_geom
from odc.geo.gcp import GCPGeoBox
from odc.geo.geobox import GeoBox
from odc.geo.xr import ODCExtensionDa, rasterize
from odc.geo.xr import ODCExtensionDa, ODCExtensionDs, rasterize

from odc.loader.testing.mem_reader import XrMemReader, XrMemReaderDriver
from odc.loader.types import RasterGroupMetadata, RasterLoadParams, RasterSource
from .testing.mem_reader import Context, XrMemReader, XrMemReaderDriver, raster_group_md
from .types import FixedCoord, RasterGroupMetadata, RasterLoadParams, RasterSource

# pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison
# pylint: disable=too-many-locals,too-many-statements

# pylint: disable=too-many-locals,too-many-statements,redefined-outer-name

def test_mem_reader() -> None:
fake_item = object()

@pytest.fixture
def sample_ds() -> xr.Dataset:
    """Build the single-band dataset shared by the tests in this module.

    The geometry returned by ``country_geom("AUS", 3857)`` is rasterized onto
    a ``GeoBox`` derived from that same geometry (``resolution=10_000``),
    cast to ``int16`` and tagged with ``units`` and ``nodata`` attributes.

    Returns:
        A dataset holding that raster as its only variable, named ``"xx"``.
    """
    poly = country_geom("AUS", 3857)
    gbox = GeoBox.from_geopolygon(poly, resolution=10_000)
    xx = rasterize(poly, gbox).astype("int16")
    xx.attrs["units"] = "uu"
    xx.attrs["nodata"] = -33

    # The fixture only hands back the dataset; each test builds whatever
    # driver/reader it needs from it.
    return xx.to_dataset(name="xx")


def test_mem_reader(sample_ds: xr.Dataset) -> None:
fake_item = object()

assert isinstance(sample_ds.odc, ODCExtensionDs)
gbox = sample_ds.odc.geobox
assert gbox is not None
assert isinstance(gbox, GeoBox)

driver = XrMemReaderDriver(sample_ds)

assert driver.md_parser is not None

Expand All @@ -43,6 +56,8 @@ def test_mem_reader() -> None:
assert md.extra_dims == {}
assert md.extra_coords == []

ds = sample_ds.copy()
xx = ds.xx
yy = xx.astype("uint8", keep_attrs=False).rename("yy")
yy = yy.expand_dims("band", 2)
yy = xr.concat([yy, yy + 1, yy + 2], "band").assign_coords(band=["r", "g", "b"])
Expand Down Expand Up @@ -126,3 +141,63 @@ def test_mem_reader() -> None:
loader = loaders["zz"]
roi, pix = loader.read(cfgs["zz"], gbox, selection=np.s_[:2])
assert pix.shape == (2, *gbox.shape.yx)


def test_raster_group_md():
    """Exercise ``raster_group_md`` on an empty dataset, with and without ``base``."""
    # Without any base metadata every field comes back empty.
    bare = raster_group_md(xr.Dataset())
    assert bare.bands == {}
    assert bare.aliases == {}
    assert bare.extra_dims == {}

    # Base metadata that declares an extra dimension but no coordinates.
    dims_only_base = RasterGroupMetadata({}, {}, {"band": 3}, [])
    with_dims = raster_group_md(xr.Dataset(), base=dims_only_base)
    assert with_dims.extra_dims == {"band": 3}
    assert len(with_dims.extra_coords) == 0

    # Base metadata that declares only a coordinate: ``extra_dims`` stays
    # empty while ``extra_dims_full()`` reports the coordinate's dimension.
    band_coord = FixedCoord("band", ["r", "g", "b"], dim="band")
    coord_only_base = RasterGroupMetadata({}, extra_coords=[band_coord])
    with_coord = raster_group_md(xr.Dataset(), base=coord_only_base)
    assert with_coord.extra_dims == {}
    assert with_coord.extra_dims_full() == {"band": 3}
    assert len(with_coord.extra_coords) == 1
    assert with_coord.extra_coords[0] == band_coord


def test_memreader_zarr(sample_ds: xr.Dataset):
    """Read band ``xx`` through ``XrMemReader`` using zarr metadata as driver data."""
    assert isinstance(sample_ds.odc, ODCExtensionDs)
    assert "xx" in sample_ds

    # Skip the rest of the test when zarr is not installed.
    zarr = pytest.importorskip("zarr")
    assert zarr is not None

    src_gbox = sample_ds.odc.geobox
    assert src_gbox is not None
    # A GCP-based geobox is swapped for its ``approx`` counterpart.
    if isinstance(src_gbox, GCPGeoBox):
        gbox = src_gbox.approx
    else:
        gbox = src_gbox

    # Plain dicts act as the zarr stores; metadata and chunks are kept apart.
    meta_store: dict[str, bytes] = {}
    data_store: dict[str, bytes] = {}
    sample_ds.to_zarr(meta_store, data_store, compute=False, consolidated=True)

    # The consolidated metadata document is what gets passed to the reader.
    assert ".zmetadata" in meta_store
    consolidated = json.loads(meta_store[".zmetadata"])
    zmd = consolidated["metadata"]

    src = RasterSource(
        "file:///tmp/no-such-dir/xx.zarr",
        subdataset="xx",
        driver_data=zmd,
    )
    assert src.driver_data is zmd

    load_params = RasterLoadParams.same_as(src)
    reader = XrMemReader(src, Context(gbox, None))

    roi, pixels = reader.read(load_params, gbox)
    assert isinstance(pixels, np.ndarray)

    # Reading at the dataset's own geobox must cover that geobox exactly.
    covered = gbox[roi]
    assert pixels.shape == covered.shape.yx
    assert gbox == covered
Loading

0 comments on commit c06020a

Please sign in to comment.