Skip to content

Commit

Permalink
Make it possible to set source subdataset
Browse files Browse the repository at this point in the history
When driver data is a dict with `subdataset` key,
set that on RasterSource.
  • Loading branch information
Kirill888 committed Jun 24, 2024
1 parent 975dada commit 7814ca3
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 3 deletions.
15 changes: 12 additions & 3 deletions odc/loader/testing/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,11 @@ def __init__(
self,
group_md: RasterGroupMetadata,
driver_data,
add_subdataset: bool = False,
):
self._group_md = group_md
self._driver_data = driver_data
self._add_subdataset = add_subdataset

def extract(self, md: Any) -> RasterGroupMetadata:
assert md is not None
Expand All @@ -107,12 +109,19 @@ def extract(self, md: Any) -> RasterGroupMetadata:
def driver_data(self, md, band_key: BandKey) -> Any:
assert md is not None
name, _ = band_key

def _patch(x):
if not isinstance(x, dict) or self._add_subdataset is False:
return x
return {"subdataset": name, **x}

if isinstance(self._driver_data, dict):
if name in self._driver_data:
return self._driver_data[name]
return _patch(self._driver_data[name])
if band_key in self._driver_data:
return self._driver_data[band_key]
return self._driver_data
return _patch(self._driver_data[band_key])

return _patch(self._driver_data)


class FakeReader:
Expand Down
4 changes: 4 additions & 0 deletions odc/stac/_mdtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,11 @@ def _get_grid(grid_name: str, asset: pystac.asset.Asset) -> GeoBox:
) # pragma: no cover (https://github.com/stac-utils/pystac/issues/754)

driver_data: Any = None
subdataset: str | None = None
if md_plugin is not None:
driver_data = md_plugin.driver_data(asset, bk)
if isinstance(driver_data, dict):
subdataset = driver_data.get("subdataset", None)

# Assumption: if extra dims are defined then asset bands are loaded into 3d+ array
if meta.extra_dims:
Expand All @@ -702,6 +705,7 @@ def _get_grid(grid_name: str, asset: pystac.asset.Asset) -> GeoBox:
bands[bk] = RasterSource(
uri=uri,
band=band_idx,
subdataset=subdataset,
geobox=geobox,
meta=meta,
driver_data=driver_data,
Expand Down
10 changes: 10 additions & 0 deletions tests/test_mdtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,16 @@ def test_parse_item_with_plugin():
assert pit.collection["b2"].data_type == "float32"
assert pit["b1"].driver_data == {"foo": "bar"}
assert pit["b2"].driver_data == {"foo": "bar"}
assert pit["b1"].subdataset is None
assert pit["b2"].subdataset is None

md_plugin = FakeMDPlugin(group_md, {"foo": "bar"}, add_subdataset=True)
pit = parse_item(item, md_plugin=md_plugin)
assert isinstance(pit, ParsedItem)
assert pit.collection["b1"].data_type == "uint8"
assert pit["b1"].driver_data == {"foo": "bar", "subdataset": "AA"}
assert pit["b1"].subdataset == "AA"
assert pit["b2"].subdataset == "AA"


def test_noassets_case(no_bands_stac):
Expand Down

0 comments on commit 7814ca3

Please sign in to comment.