diff --git a/odc/stats/plugins/lc_fc_wo_a0.py b/odc/stats/plugins/lc_fc_wo_a0.py index 8ee16ab..9475a5f 100644 --- a/odc/stats/plugins/lc_fc_wo_a0.py +++ b/odc/stats/plugins/lc_fc_wo_a0.py @@ -73,7 +73,7 @@ def native_transform(self, xx): xx = xx.drop_vars(["water"]) # get valid wo pixels, both dry and wet - data = expr_eval( + wet = expr_eval( "where(a|b, a, _nan)", {"a": wet.data, "b": valid.data}, name="get_valid_pixels", @@ -81,40 +81,68 @@ def native_transform(self, xx): **{"_nan": np.nan}, ) - # Pick out the fc pixels that have an unmixing error of less than the threshold - valid &= xx["ue"] < self.ue_threshold + # get high ue pixels + ue = expr_eval( + "where((a<_v), _nan, 1)", + {"a": xx["ue"].data}, + name="get_high_ue", + dtype="float32", + **{"_v": self.ue_threshold, "_nan": np.nan}, + ) xx = xx.drop_vars(["ue"]) + + # Pick out the fc pixels that have an unmixing error of less than the threshold + valid = expr_eval( + "where(a&(b!=b), 1, 0)", + {"a": valid.data, "b": ue}, + name="get_valid_pixels", + dtype="bool", + ) + valid = xr.DataArray(valid, dims=xx["pv"].dims, coords=xx["pv"].coords) + xx = keep_good_only(xx, valid, nodata=NODATA) xx = to_float(xx, dtype="float32") - - xx["wet"] = xr.DataArray(data, dims=wet.dims, coords=wet.coords) + xx["wet"] = xr.DataArray(wet, dims=xx["pv"].dims, coords=xx["pv"].coords) + xx["ue"] = xr.DataArray(ue, dims=xx["pv"].dims, coords=xx["pv"].coords) return xx def fuser(self, xx): wet = xx["wet"] + ue = xx["ue"] - xx = _xr_fuse(xx.drop_vars(["wet"]), partial(_fuse_mean_np, nodata=np.nan), "") + xx = _xr_fuse( + xx.drop_vars(["wet", "ue"]), + partial(_fuse_mean_np, nodata=np.nan), + "fuse_fc", + ) xx["wet"] = _nodata_fuser(wet, nodata=np.nan) + xx["ue"] = _nodata_fuser(ue, nodata=np.nan) return xx def _veg_or_not(self, xx: xr.Dataset): - # either pv or npv > bs: 1 + # pv or npv > bs: 1 # otherwise 0 data = expr_eval( "where((a>b)|(c>b), 1, 0)", - {"a": xx["pv"].data, "c": xx["npv"].data, "b": xx["bs"].data}, + { + "a": xx["pv"].data, + "c": xx["npv"].data, + "b": xx["bs"].data, + "d": xx["ue"].data, + }, name="get_veg", dtype="uint8", ) - # mark nans + # mark nans only if not valid & low ue + # if any high ue (c is not nan): 0 data = expr_eval( - "where(a!=a, nodata, b)", - {"a": xx["pv"].data, "b": data}, + "where((a!=a)&(c!=c), nodata, b)", + {"a": xx["pv"].data, "c": xx["ue"].data, "b": data}, name="get_veg", dtype="uint8", **{"nodata": int(NODATA)}, diff --git a/tests/test_landcover_plugin_a0.py b/tests/test_landcover_plugin_a0.py index 2e04b15..a013bd4 100644 --- a/tests/test_landcover_plugin_a0.py +++ b/tests/test_landcover_plugin_a0.py @@ -409,14 +409,15 @@ def test_reduce(fc_wo_dataset): xx = stats_veg.reduce(xx).compute() expected_value = np.array( [ - [1, 255, 0, 255, 255, 255, 255], - [1, 255, 255, 255, 255, 255, 255], - [255, 0, 255, 255, 1, 255, 255], - [255, 255, 1, 255, 255, 255, 255], - [255, 255, 255, 255, 255, 255, 255], - [255, 1, 255, 255, 255, 255, 255], - [255, 255, 255, 255, 255, 255, 0], - ] + [1, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0], + [0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + ], + dtype="uint8", ) assert (xx.veg_frequency.data == expected_value).all() @@ -430,7 +431,8 @@ def test_reduce(fc_wo_dataset): [255, 255, 255, 255, 255, 255, 255], [255, 0, 255, 255, 255, 0, 255], [255, 255, 255, 0, 255, 255, 1], - ] + ], + dtype="uint8", ) assert (xx.water_frequency.data == expected_value).all()