make output band names config options
Emma Ai committed Nov 12, 2024
1 parent d45be2a commit f792ffa
Showing 9 changed files with 36 additions and 58 deletions.
7 changes: 5 additions & 2 deletions odc/stats/plugins/_base.py
@@ -29,6 +29,7 @@ def __init__(
rgb_clamp: Tuple[float, float] = (1.0, 3_000.0),
transform_code: Optional[str] = None,
area_of_interest: Optional[Sequence[float]] = None,
measurements: Optional[Sequence[str]] = None,
):
self.resampling = resampling
self.input_bands = input_bands if input_bands is not None else []
@@ -40,12 +41,14 @@ def __init__(
self.rgb_clamp = rgb_clamp
self.transform_code = transform_code
self.area_of_interest = area_of_interest
self._measurements = measurements
self.dask_worker_plugin = None

@property
@abstractmethod
def measurements(self) -> Tuple[str, ...]:
pass
if self._measurements is None:
raise NotImplementedError("Plugins must provide 'measurements'")
return self._measurements

def native_transform(self, xx: xr.Dataset) -> xr.Dataset:
for var in xx.data_vars:
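Taken together, the base-class change means a plugin either receives `measurements` through its constructor (and hence through config) or keeps overriding the property itself. Below is a minimal sketch of the new behaviour; `StatsPluginBase` here is a stripped-down stand-in for the class in `odc/stats/plugins/_base.py`, and `DummyPlugin` is a hypothetical subclass used only for illustration.

# Minimal sketch of the reworked measurements handling; this stand-in base class
# keeps only the measurements logic, and DummyPlugin is hypothetical.
from typing import Optional, Sequence, Tuple


class StatsPluginBase:
    def __init__(self, measurements: Optional[Sequence[str]] = None):
        self._measurements = measurements

    @property
    def measurements(self) -> Tuple[str, ...]:
        # New behaviour: use the configured names, or fail loudly if none were given
        if self._measurements is None:
            raise NotImplementedError("Plugins must provide 'measurements'")
        return self._measurements


class DummyPlugin(StatsPluginBase):
    """A plugin that no longer hard-codes its output band names."""


print(DummyPlugin(measurements=["veg_frequency", "water_frequency"]).measurements)
# ['veg_frequency', 'water_frequency']

try:
    DummyPlugin().measurements
except NotImplementedError as e:
    print(e)  # Plugins must provide 'measurements'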
13 changes: 2 additions & 11 deletions odc/stats/plugins/lc_fc_wo_a0.py
@@ -42,11 +42,6 @@ def __init__(
self.ue_threshold = ue_threshold if ue_threshold is not None else 30
self.cloud_filters = cloud_filters if cloud_filters is not None else {}

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["veg_frequency", "water_frequency"]
return _measurements

def native_transform(self, xx):
"""
Loads data in its native projection. It performs the following:
@@ -217,12 +212,8 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
"veg_frequency": xr.DataArray(
max_count_veg, dims=xx["wet"].dims[1:], attrs=attrs
),
"water_frequency": xr.DataArray(
max_count_water, dims=xx["wet"].dims[1:], attrs=attrs
),
k: xr.DataArray(v, dims=xx["wet"].dims[1:], attrs=attrs)
for k, v in zip(self.measurements, [max_count_veg, max_count_water])
}
coords = dict((dim, xx.coords[dim]) for dim in xx["wet"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)
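One consequence of building `data_vars` by zipping `self.measurements` with the computed arrays is that the configured names are assigned positionally: the list in config must keep the vegetation count first and the water count second. A small self-contained illustration follows; the toy arrays and dimension names are invented, only the zip-based naming mirrors the new code.

# Toy illustration of the positional pairing done in reduce(); the data and
# dims are made up, the dict comprehension mirrors the committed change.
import numpy as np
import xarray as xr

max_count_veg = np.array([[3, 1], [0, 2]], dtype="uint8")
max_count_water = np.array([[0, 4], [2, 1]], dtype="uint8")

measurements = ["veg_frequency", "water_frequency"]  # order must match the arrays
data_vars = {
    k: xr.DataArray(v, dims=("y", "x"))
    for k, v in zip(measurements, [max_count_veg, max_count_water])
}
print(list(xr.Dataset(data_vars).data_vars))
# ['veg_frequency', 'water_frequency'] -- swapping the configured order swaps the bands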
9 changes: 2 additions & 7 deletions odc/stats/plugins/lc_tf_urban.py
@@ -2,7 +2,7 @@
Plugin of TF urban model in LandCover PipeLine
"""

from typing import Tuple, Dict, Sequence
from typing import Dict, Sequence

import os
import numpy as np
@@ -91,11 +91,6 @@ def __init__(
else:
self.crop_size = crop_size

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["urban_classes"]
return _measurements

def input_data(
self, datasets: Sequence[Dataset], geobox: GeoBox, **kwargs
) -> xr.Dataset:
@@ -219,7 +214,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
dims = list(xx.dims.keys())[:2]
data_vars = {"urban_classes": xr.DataArray(um, dims=dims, attrs=attrs)}
data_vars = {self.measurements[0]: xr.DataArray(um, dims=dims, attrs=attrs)}
coords = {dim: xx.coords[dim] for dim in dims}
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)

6 changes: 0 additions & 6 deletions odc/stats/plugins/lc_treelite_cultivated.py
@@ -2,7 +2,6 @@
Plugin of RFclassfication cultivated model in LandCover PipeLine
"""

from typing import Tuple
import numpy as np
import xarray as xr
import dask.array as da
@@ -226,11 +225,6 @@ class StatsCultivatedClass(StatsMLTree):
VERSION = "0.0.1"
PRODUCT_FAMILY = "lccs"

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["cultivated"]
return _measurements

def predict(self, input_array):
bands_indices = dict(zip(self.input_bands, np.arange(len(self.input_bands))))
input_features = da.map_blocks(
6 changes: 0 additions & 6 deletions odc/stats/plugins/lc_treelite_woody.py
@@ -2,7 +2,6 @@
Plugin of RFregressor woody cover model in LandCover PipeLine
"""

from typing import Tuple
import xarray as xr
import dask.array as da

@@ -19,11 +18,6 @@ class StatsWoodyCover(StatsMLTree):
VERSION = "0.0.1"
PRODUCT_FAMILY = "lccs"

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["woody"]
return _measurements

def predict(self, input_array):
wc = da.map_blocks(
mask_and_predict,
22 changes: 4 additions & 18 deletions odc/stats/plugins/lc_veg_class_a1.py
@@ -25,13 +25,6 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self._measurements = (
measurements if measurements is not None else self.input_bands
)

@property
def measurements(self) -> Tuple[str, ...]:
return self._measurements

def native_transform(self, xx):
# reproject cannot work with nodata being int for float
@@ -89,11 +82,6 @@ def __init__(
)
self.output_classes = output_classes

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["classes_l3_l4", "water_seasonality"]
return _measurements

def fuser(self, xx):
return xx

@@ -249,12 +237,10 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
"classes_l3_l4": xr.DataArray(
l3_mask[0], dims=xx["veg_frequency"].dims[1:], attrs=attrs
),
"water_seasonality": xr.DataArray(
water_seasonality[0], dims=xx["veg_frequency"].dims[1:], attrs=attrs
),
k: xr.DataArray(v, dims=xx["veg_frequency"].dims[1:], attrs=attrs)
for k, v in zip(
self.measurements, [l3_mask.squeeze(0), water_seasonality.squeeze(0)]
)
}
coords = dict((dim, xx.coords[dim]) for dim in xx["veg_frequency"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)
12 changes: 6 additions & 6 deletions tests/test_landcover_plugin_a0.py
@@ -319,7 +319,7 @@ def fc_wo_dataset():
def test_native_transform(fc_wo_dataset, bits):
xx = fc_wo_dataset.copy()
xx["water"] = da.bitwise_or(xx["water"], bits)
stats_veg = StatsVegCount()
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
out_xx = stats_veg.native_transform(xx).compute()

expected_valid = (
@@ -349,7 +349,7 @@ def test_native_transform(fc_wo_dataset, bits):


def test_fusing(fc_wo_dataset):
stats_veg = StatsVegCount()
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
xx = stats_veg.native_transform(fc_wo_dataset)
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None)).compute()
valid_index = (
@@ -369,7 +369,7 @@ def test_fusing(fc_wo_dataset):


def test_veg_or_not(fc_wo_dataset):
stats_veg = StatsVegCount()
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
xx = stats_veg.native_transform(fc_wo_dataset)
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
yy = stats_veg._veg_or_not(xx).compute()
@@ -386,7 +386,7 @@ def test_veg_or_not(fc_wo_dataset):


def test_water_or_not(fc_wo_dataset):
stats_veg = StatsVegCount()
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
xx = stats_veg.native_transform(fc_wo_dataset)
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
yy = stats_veg._water_or_not(xx).compute()
@@ -403,7 +403,7 @@ def test_water_or_not(fc_wo_dataset):


def test_reduce(fc_wo_dataset):
stats_veg = StatsVegCount()
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
xx = stats_veg.native_transform(fc_wo_dataset)
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
xx = stats_veg.reduce(xx).compute()
@@ -437,7 +437,7 @@ def test_reduce(fc_wo_dataset):


def test_consecutive_month(consecutive_count):
stats_veg = StatsVegCount()
stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
xx = stats_veg._max_consecutive_months(consecutive_count, 255).compute()
expected_value = np.array(
[
3 changes: 3 additions & 0 deletions tests/test_landcover_plugin_a1.py
@@ -135,6 +135,7 @@ def test_l3_classes(dataset):
"surface": 210,
},
optional_bands=["canopy_cover_class", "elevation"],
measurements=["level_3_4", "water_season"],
)

expected_res = np.array(
@@ -163,6 +164,7 @@ def test_l4_water_seasonality(dataset):
"surface": 210,
},
optional_bands=["canopy_cover_class", "elevation"],
measurements=["level_3_4", "water_season"],
)

wo_fq = np.array(
@@ -208,6 +210,7 @@ def test_reduce(dataset):
"surface": 210,
},
optional_bands=["canopy_cover_class", "elevation"],
measurements=["level_3_4", "water_season"],
)
res = stats_l3.reduce(dataset)

16 changes: 14 additions & 2 deletions tests/test_rf_models.py
@@ -421,6 +421,7 @@ def test_preprocess_predict_intput(
cultivated_model_path,
mask_bands,
input_bands=cultivated_input_bands,
measurements=["cultivated"],
)
res = cultivated.preprocess_predict_input(input_datasets)
for r in res:
Expand All @@ -440,6 +441,7 @@ def test_cultivated_predict(
cultivated_model_path,
mask_bands,
input_bands=cultivated_input_bands,
measurements=["cultivated"],
)
dask_client.register_plugin(cultivated.dask_worker_plugin)
imgs = cultivated.preprocess_predict_input(input_datasets)
@@ -462,6 +464,7 @@ def test_cultivated_aggregate_results(
cultivated_model_path,
mask_bands,
input_bands=cultivated_input_bands,
measurements=["cultivated"],
)
res = cultivated.aggregate_results_from_group([cultivated_results[0]])
assert (res.compute() == np.array([[112, 255], [111, 112]], dtype="uint8")).all()
@@ -482,6 +485,7 @@ def test_cultivated_reduce(
cultivated_model_path,
mask_bands,
input_bands=cultivated_input_bands,
measurements=["cultivated"],
)
dask_client.register_plugin(cultivated.dask_worker_plugin)
res = cultivated.reduce(input_datasets)
@@ -506,7 +510,11 @@ def test_woody_aggregate_results(
):

woody_cover = StatsWoodyCover(
woody_classes, woody_model_path, mask_bands, input_bands=woody_input_bands
woody_classes,
woody_model_path,
mask_bands,
input_bands=woody_input_bands,
measurements=["woody"],
)
res = woody_cover.aggregate_results_from_group([woody_results[0]])
assert (res.compute() == np.array([[113, 255], [114, 113]], dtype="uint8")).all()
@@ -524,7 +532,11 @@ def test_woody_reduce(
):
woody_inputs = input_datasets.sel(bands=woody_input_bands[:-1])
woody_cover = StatsWoodyCover(
woody_classes, woody_model_path, mask_bands, input_bands=woody_input_bands
woody_classes,
woody_model_path,
mask_bands,
input_bands=woody_input_bands,
measurements=["woody"],
)
dask_client.register_plugin(woody_cover.dask_worker_plugin)
res = woody_cover.reduce(woody_inputs)
