From f792ffa54ce3df7b250dde2d24b0ba0d1e44fe89 Mon Sep 17 00:00:00 2001 From: Emma Ai Date: Tue, 12 Nov 2024 04:50:22 +0000 Subject: [PATCH] make output band name as config options --- odc/stats/plugins/_base.py | 7 +++++-- odc/stats/plugins/lc_fc_wo_a0.py | 13 ++---------- odc/stats/plugins/lc_tf_urban.py | 9 ++------- odc/stats/plugins/lc_treelite_cultivated.py | 6 ------ odc/stats/plugins/lc_treelite_woody.py | 6 ------ odc/stats/plugins/lc_veg_class_a1.py | 22 ++++----------------- tests/test_landcover_plugin_a0.py | 12 +++++------ tests/test_landcover_plugin_a1.py | 3 +++ tests/test_rf_models.py | 16 +++++++++++++-- 9 files changed, 36 insertions(+), 58 deletions(-) diff --git a/odc/stats/plugins/_base.py b/odc/stats/plugins/_base.py index b45745e6..f6df4ac6 100644 --- a/odc/stats/plugins/_base.py +++ b/odc/stats/plugins/_base.py @@ -29,6 +29,7 @@ def __init__( rgb_clamp: Tuple[float, float] = (1.0, 3_000.0), transform_code: Optional[str] = None, area_of_interest: Optional[Sequence[float]] = None, + measurements: Optional[Sequence[str]] = None, ): self.resampling = resampling self.input_bands = input_bands if input_bands is not None else [] @@ -40,12 +41,14 @@ def __init__( self.rgb_clamp = rgb_clamp self.transform_code = transform_code self.area_of_interest = area_of_interest + self._measurements = measurements self.dask_worker_plugin = None @property - @abstractmethod def measurements(self) -> Tuple[str, ...]: - pass + if self._measurements is None: + raise NotImplementedError("Plugins must provide 'measurements'") + return self._measurements def native_transform(self, xx: xr.Dataset) -> xr.Dataset: for var in xx.data_vars: diff --git a/odc/stats/plugins/lc_fc_wo_a0.py b/odc/stats/plugins/lc_fc_wo_a0.py index 620db9ea..8ee16ab3 100644 --- a/odc/stats/plugins/lc_fc_wo_a0.py +++ b/odc/stats/plugins/lc_fc_wo_a0.py @@ -42,11 +42,6 @@ def __init__( self.ue_threshold = ue_threshold if ue_threshold is not None else 30 self.cloud_filters = cloud_filters if cloud_filters is not None else {} - @property - def measurements(self) -> Tuple[str, ...]: - _measurements = ["veg_frequency", "water_frequency"] - return _measurements - def native_transform(self, xx): """ Loads data in its native projection. It performs the following: @@ -217,12 +212,8 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset: attrs = xx.attrs.copy() attrs["nodata"] = int(NODATA) data_vars = { - "veg_frequency": xr.DataArray( - max_count_veg, dims=xx["wet"].dims[1:], attrs=attrs - ), - "water_frequency": xr.DataArray( - max_count_water, dims=xx["wet"].dims[1:], attrs=attrs - ), + k: xr.DataArray(v, dims=xx["wet"].dims[1:], attrs=attrs) + for k, v in zip(self.measurements, [max_count_veg, max_count_water]) } coords = dict((dim, xx.coords[dim]) for dim in xx["wet"].dims[1:]) return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs) diff --git a/odc/stats/plugins/lc_tf_urban.py b/odc/stats/plugins/lc_tf_urban.py index b941fca7..d899ee40 100644 --- a/odc/stats/plugins/lc_tf_urban.py +++ b/odc/stats/plugins/lc_tf_urban.py @@ -2,7 +2,7 @@ Plugin of TF urban model in LandCover PipeLine """ -from typing import Tuple, Dict, Sequence +from typing import Dict, Sequence import os import numpy as np @@ -91,11 +91,6 @@ def __init__( else: self.crop_size = crop_size - @property - def measurements(self) -> Tuple[str, ...]: - _measurements = ["urban_classes"] - return _measurements - def input_data( self, datasets: Sequence[Dataset], geobox: GeoBox, **kwargs ) -> xr.Dataset: @@ -219,7 +214,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset: attrs = xx.attrs.copy() attrs["nodata"] = int(NODATA) dims = list(xx.dims.keys())[:2] - data_vars = {"urban_classes": xr.DataArray(um, dims=dims, attrs=attrs)} + data_vars = {self.measurements[0]: xr.DataArray(um, dims=dims, attrs=attrs)} coords = {dim: xx.coords[dim] for dim in dims} return xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs) diff --git a/odc/stats/plugins/lc_treelite_cultivated.py b/odc/stats/plugins/lc_treelite_cultivated.py index 81edc6ab..096b9e24 100644 --- a/odc/stats/plugins/lc_treelite_cultivated.py +++ b/odc/stats/plugins/lc_treelite_cultivated.py @@ -2,7 +2,6 @@ Plugin of RFclassfication cultivated model in LandCover PipeLine """ -from typing import Tuple import numpy as np import xarray as xr import dask.array as da @@ -226,11 +225,6 @@ class StatsCultivatedClass(StatsMLTree): VERSION = "0.0.1" PRODUCT_FAMILY = "lccs" - @property - def measurements(self) -> Tuple[str, ...]: - _measurements = ["cultivated"] - return _measurements - def predict(self, input_array): bands_indices = dict(zip(self.input_bands, np.arange(len(self.input_bands)))) input_features = da.map_blocks( diff --git a/odc/stats/plugins/lc_treelite_woody.py b/odc/stats/plugins/lc_treelite_woody.py index 17adda35..949e7743 100644 --- a/odc/stats/plugins/lc_treelite_woody.py +++ b/odc/stats/plugins/lc_treelite_woody.py @@ -2,7 +2,6 @@ Plugin of RFregressor woody cover model in LandCover PipeLine """ -from typing import Tuple import xarray as xr import dask.array as da @@ -19,11 +18,6 @@ class StatsWoodyCover(StatsMLTree): VERSION = "0.0.1" PRODUCT_FAMILY = "lccs" - @property - def measurements(self) -> Tuple[str, ...]: - _measurements = ["woody"] - return _measurements - def predict(self, input_array): wc = da.map_blocks( mask_and_predict, diff --git a/odc/stats/plugins/lc_veg_class_a1.py b/odc/stats/plugins/lc_veg_class_a1.py index f41df6db..537cec82 100644 --- a/odc/stats/plugins/lc_veg_class_a1.py +++ b/odc/stats/plugins/lc_veg_class_a1.py @@ -25,13 +25,6 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self._measurements = ( - measurements if measurements is not None else self.input_bands - ) - - @property - def measurements(self) -> Tuple[str, ...]: - return self._measurements def native_transform(self, xx): # reproject cannot work with nodata being int for float @@ -89,11 +82,6 @@ def __init__( ) self.output_classes = output_classes - @property - def measurements(self) -> Tuple[str, ...]: - _measurements = ["classes_l3_l4", "water_seasonality"] - return _measurements - def fuser(self, xx): return xx @@ -249,12 +237,10 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset: attrs = xx.attrs.copy() attrs["nodata"] = int(NODATA) data_vars = { - "classes_l3_l4": xr.DataArray( - l3_mask[0], dims=xx["veg_frequency"].dims[1:], attrs=attrs - ), - "water_seasonality": xr.DataArray( - water_seasonality[0], dims=xx["veg_frequency"].dims[1:], attrs=attrs - ), + k: xr.DataArray(v, dims=xx["veg_frequency"].dims[1:], attrs=attrs) + for k, v in zip( + self.measurements, [l3_mask.squeeze(0), water_seasonality.squeeze(0)] + ) } coords = dict((dim, xx.coords[dim]) for dim in xx["veg_frequency"].dims[1:]) return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs) diff --git a/tests/test_landcover_plugin_a0.py b/tests/test_landcover_plugin_a0.py index 5a2369a3..eb636ed4 100644 --- a/tests/test_landcover_plugin_a0.py +++ b/tests/test_landcover_plugin_a0.py @@ -319,7 +319,7 @@ def fc_wo_dataset(): def test_native_transform(fc_wo_dataset, bits): xx = fc_wo_dataset.copy() xx["water"] = da.bitwise_or(xx["water"], bits) - stats_veg = StatsVegCount() + stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"]) out_xx = stats_veg.native_transform(xx).compute() expected_valid = ( @@ -349,7 +349,7 @@ def test_native_transform(fc_wo_dataset, bits): def test_fusing(fc_wo_dataset): - stats_veg = StatsVegCount() + stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"]) xx = stats_veg.native_transform(fc_wo_dataset) xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None)).compute() valid_index = ( @@ -369,7 +369,7 @@ def test_fusing(fc_wo_dataset): def test_veg_or_not(fc_wo_dataset): - stats_veg = StatsVegCount() + stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"]) xx = stats_veg.native_transform(fc_wo_dataset) xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None)) yy = stats_veg._veg_or_not(xx).compute() @@ -386,7 +386,7 @@ def test_veg_or_not(fc_wo_dataset): def test_water_or_not(fc_wo_dataset): - stats_veg = StatsVegCount() + stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"]) xx = stats_veg.native_transform(fc_wo_dataset) xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None)) yy = stats_veg._water_or_not(xx).compute() @@ -403,7 +403,7 @@ def test_water_or_not(fc_wo_dataset): def test_reduce(fc_wo_dataset): - stats_veg = StatsVegCount() + stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"]) xx = stats_veg.native_transform(fc_wo_dataset) xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None)) xx = stats_veg.reduce(xx).compute() @@ -437,7 +437,7 @@ def test_reduce(fc_wo_dataset): def test_consecutive_month(consecutive_count): - stats_veg = StatsVegCount() + stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"]) xx = stats_veg._max_consecutive_months(consecutive_count, 255).compute() expected_value = np.array( [ diff --git a/tests/test_landcover_plugin_a1.py b/tests/test_landcover_plugin_a1.py index 580c7f8f..22e4beea 100644 --- a/tests/test_landcover_plugin_a1.py +++ b/tests/test_landcover_plugin_a1.py @@ -135,6 +135,7 @@ def test_l3_classes(dataset): "surface": 210, }, optional_bands=["canopy_cover_class", "elevation"], + measurements=["level_3_4", "water_season"], ) expected_res = np.array( @@ -163,6 +164,7 @@ def test_l4_water_seasonality(dataset): "surface": 210, }, optional_bands=["canopy_cover_class", "elevation"], + measurements=["level_3_4", "water_season"], ) wo_fq = np.array( @@ -208,6 +210,7 @@ def test_reduce(dataset): "surface": 210, }, optional_bands=["canopy_cover_class", "elevation"], + measurements=["level_3_4", "water_season"], ) res = stats_l3.reduce(dataset) diff --git a/tests/test_rf_models.py b/tests/test_rf_models.py index ad16e6f5..b297f691 100644 --- a/tests/test_rf_models.py +++ b/tests/test_rf_models.py @@ -421,6 +421,7 @@ def test_preprocess_predict_intput( cultivated_model_path, mask_bands, input_bands=cultivated_input_bands, + measurements=["cultivated"], ) res = cultivated.preprocess_predict_input(input_datasets) for r in res: @@ -440,6 +441,7 @@ def test_cultivated_predict( cultivated_model_path, mask_bands, input_bands=cultivated_input_bands, + measurements=["cultivated"], ) dask_client.register_plugin(cultivated.dask_worker_plugin) imgs = cultivated.preprocess_predict_input(input_datasets) @@ -462,6 +464,7 @@ def test_cultivated_aggregate_results( cultivated_model_path, mask_bands, input_bands=cultivated_input_bands, + measurements=["cultivated"], ) res = cultivated.aggregate_results_from_group([cultivated_results[0]]) assert (res.compute() == np.array([[112, 255], [111, 112]], dtype="uint8")).all() @@ -482,6 +485,7 @@ def test_cultivated_reduce( cultivated_model_path, mask_bands, input_bands=cultivated_input_bands, + measurements=["cultivated"], ) dask_client.register_plugin(cultivated.dask_worker_plugin) res = cultivated.reduce(input_datasets) @@ -506,7 +510,11 @@ def test_woody_aggregate_results( ): woody_cover = StatsWoodyCover( - woody_classes, woody_model_path, mask_bands, input_bands=woody_input_bands + woody_classes, + woody_model_path, + mask_bands, + input_bands=woody_input_bands, + measurements=["woody"], ) res = woody_cover.aggregate_results_from_group([woody_results[0]]) assert (res.compute() == np.array([[113, 255], [114, 113]], dtype="uint8")).all() @@ -524,7 +532,11 @@ def test_woody_reduce( ): woody_inputs = input_datasets.sel(bands=woody_input_bands[:-1]) woody_cover = StatsWoodyCover( - woody_classes, woody_model_path, mask_bands, input_bands=woody_input_bands + woody_classes, + woody_model_path, + mask_bands, + input_bands=woody_input_bands, + measurements=["woody"], ) dask_client.register_plugin(woody_cover.dask_worker_plugin) res = woody_cover.reduce(woody_inputs)