diff --git a/odc/stats/plugins/lc_fc_wo_a0.py b/odc/stats/plugins/lc_fc_wo_a0.py index 8ff8e64..a9c37ab 100644 --- a/odc/stats/plugins/lc_fc_wo_a0.py +++ b/odc/stats/plugins/lc_fc_wo_a0.py @@ -34,12 +34,14 @@ class StatsVegCount(StatsPluginInterface): def __init__( self, ue_threshold: Optional[int] = None, + veg_threshold: Optional[int] = None, cloud_filters: Dict[str, Iterable[Tuple[str, int]]] = None, **kwargs, ): super().__init__(input_bands=["water", "pv", "bs", "npv", "ue"], **kwargs) self.ue_threshold = ue_threshold if ue_threshold is not None else 30 + self.veg_threshold = veg_threshold if veg_threshold is not None else 2 self.cloud_filters = cloud_filters if cloud_filters is not None else {} def native_transform(self, xx): @@ -63,15 +65,6 @@ def native_transform(self, xx): # clear dry pixels clear = xx["water"].data == 0 - # get "valid" wo pixels, both dry and wet used in veg_frequency - wet_valid = expr_eval( - "where(a|b, a, _nan)", - {"a": wet, "b": valid}, - name="get_valid_pixels", - dtype="float32", - **{"_nan": np.nan}, - ) - # get "clear" wo pixels, both dry and wet used in water_frequency wet_clear = expr_eval( "where(a|b, a, _nan)", @@ -101,13 +94,7 @@ def native_transform(self, xx): dtype="float32", **{"_nan": np.nan}, ) - wet_valid = expr_eval( - "where(b>0, _nan, a)", - {"a": wet_valid, "b": raw_mask.data}, - name="get_valid_pixels", - dtype="float32", - **{"_nan": np.nan}, - ) + xx = xx.drop_vars(["water"]) # Pick out the fc pixels that have an unmixing error of less than the threshold @@ -124,9 +111,6 @@ def native_transform(self, xx): xx = keep_good_only(xx, valid, nodata=NODATA) xx = to_float(xx, dtype="float32") - xx["wet_valid"] = xr.DataArray( - wet_valid, dims=xx["pv"].dims, coords=xx["pv"].coords - ) xx["wet_clear"] = xr.DataArray( wet_clear, dims=xx["pv"].dims, coords=xx["pv"].coords ) @@ -135,16 +119,14 @@ def native_transform(self, xx): def fuser(self, xx): - wet_valid = xx["wet_valid"] wet_clear = xx["wet_clear"] xx = _xr_fuse( - xx.drop_vars(["wet_valid", "wet_clear"]), + xx.drop_vars(["wet_clear"]), partial(_fuse_mean_np, nodata=np.nan), "", ) - xx["wet_valid"] = _nodata_fuser(wet_valid, nodata=np.nan) xx["wet_clear"] = _nodata_fuser(wet_clear, nodata=np.nan) return xx @@ -168,14 +150,6 @@ def _veg_or_not(self, xx: xr.Dataset): **{"nodata": int(NODATA)}, ) - # mark water freq >= 0.5 as 0 - data = expr_eval( - "where(a>0, 0, b)", - {"a": xx["wet_valid"].data, "b": data}, - name="get_veg", - dtype="uint8", - ) - return data def _water_or_not(self, xx: xr.Dataset): @@ -262,8 +236,30 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset: xx = xx.groupby("time.month").map(median_ds, dim="spec") - data = self._veg_or_not(xx) - max_count_veg = self._max_consecutive_months(data, NODATA) + # consecutive observation of veg + veg_data = self._veg_or_not(xx) + max_count_veg = self._max_consecutive_months(veg_data, NODATA) + + # consecutive observation of non-veg + non_veg_data = expr_eval( + "where(a= threshold + # implies any "wet" area potentially veg + + max_count_veg = expr_eval( + "where((a<_v)&(b<_v), _v, b)", + {"a": max_count_non_veg, "b": max_count_veg}, + name="clip_veg", + dtype="uint8", + **{"_v": self.veg_threshold}, + ) data = self._water_or_not(xx) max_count_water = self._max_consecutive_months(data, NODATA, normalize=True) diff --git a/odc/stats/plugins/lc_ml_treelite.py b/odc/stats/plugins/lc_ml_treelite.py index 794fc8a..7f2b6a0 100644 --- a/odc/stats/plugins/lc_ml_treelite.py +++ b/odc/stats/plugins/lc_ml_treelite.py @@ -76,7 +76,9 @@ def __init__( self.dask_worker_plugin = TreeliteModelPlugin(model_path) self.output_classes = output_classes self.mask_bands = mask_bands - self.temporal_coverage = temporal_coverage + self.temporal_coverage = ( + temporal_coverage if temporal_coverage is not None else {} + ) self._log = logging.getLogger(__name__) def input_data( @@ -160,8 +162,8 @@ def convert_dtype(var): for var in xx.data_vars: if var not in self.mask_bands: - if self.temporal_coverage is not None: - # filter and impute by sensors + # filter and impute by sensors + if self.temporal_coverage.get(var) is not None: temporal_range = [ DateTimeRange(v) for v in self.temporal_coverage.get(var) ] diff --git a/tests/test_landcover_plugin_a0.py b/tests/test_landcover_plugin_a0.py index 095f7ca..0bd8cc8 100644 --- a/tests/test_landcover_plugin_a0.py +++ b/tests/test_landcover_plugin_a0.py @@ -322,17 +322,8 @@ def test_native_transform(fc_wo_dataset, bits): stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"]) out_xx = stats_veg.native_transform(xx).compute() - expected_valid = ( - np.array([0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]), - np.array([1, 1, 3, 5, 6, 2, 6, 2, 2, 5, 6, 0, 0, 2, 3]), - np.array([0, 3, 2, 1, 3, 5, 6, 1, 4, 5, 6, 0, 2, 4, 2]), - ) - result = np.where(out_xx["wet_valid"].data == out_xx["wet_valid"].data) - for a, b in zip(expected_valid, result): - assert (a == b).all() - expected_valid = (np.array([1, 2, 3]), np.array([6, 2, 0]), np.array([6, 1, 2])) - result = np.where(out_xx["wet_valid"].data == 1) + result = np.where(out_xx["wet_clear"].data == 1) for a, b in zip(expected_valid, result): assert (a == b).all() @@ -374,11 +365,11 @@ def test_veg_or_not(fc_wo_dataset): xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None)) yy = stats_veg._veg_or_not(xx).compute() valid_index = ( - np.array([0, 0, 1, 2, 2, 2, 2, 2]), - np.array([1, 5, 6, 0, 0, 2, 2, 3]), - np.array([0, 1, 6, 0, 2, 1, 4, 2]), + np.array([0, 0, 2, 2, 2]), + np.array([1, 5, 0, 2, 3]), + np.array([0, 1, 0, 4, 2]), ) - expected_value = np.array([1, 1, 0, 1, 0, 0, 1, 1]) + expected_value = np.array([1, 1, 1, 1, 1]) i = 0 for idx in zip(*valid_index): assert yy[idx] == expected_value[i]