From af0254494d2471ad5f5c28c4c1e2ebdaa65780d8 Mon Sep 17 00:00:00 2001
From: Emma Ai
Date: Tue, 3 Dec 2024 14:08:48 +0000
Subject: [PATCH] normalize water frequency

---
 odc/stats/plugins/lc_fc_wo_a0.py  | 176 +++++++++++++++++-------------
 tests/test_landcover_plugin_a0.py |  33 +++----
 2 files changed, 117 insertions(+), 92 deletions(-)

diff --git a/odc/stats/plugins/lc_fc_wo_a0.py b/odc/stats/plugins/lc_fc_wo_a0.py
index 3a396fb..e47615b 100644
--- a/odc/stats/plugins/lc_fc_wo_a0.py
+++ b/odc/stats/plugins/lc_fc_wo_a0.py
@@ -52,102 +52,117 @@ def native_transform(self, xx):
         5. Drop the WOfS band
         """

-        # clear and dry pixels not mask against bit 4: terrain high slope,
+        # valid and dry pixels not masked against bit 4: terrain high slope,
         # bit 3: terrain shadow, and
         # bit 2: low solar angle
-        valid = (xx["water"] & ~((1 << 4) | (1 << 3) | (1 << 2))) == 0
+        valid = (xx["water"].data & ~((1 << 4) | (1 << 3) | (1 << 2))) == 0

-        # clear and wet pixels not mask against bit 2: low solar angle
-        wet = (xx["water"] & ~(1 << 2)) == 128
+        # clear wet pixels not masked against bit 2: low solar angle
+        wet = (xx["water"].data & ~(1 << 2)) == 128

-        # dilate both 'valid' and 'water'
-        for key, val in self.BAD_BITS_MASK.items():
-            if self.cloud_filters.get(key) is not None:
-                raw_mask = (xx["water"] & val) > 0
-                raw_mask = mask_cleanup(
-                    raw_mask, mask_filters=self.cloud_filters.get(key)
-                )
-                valid &= ~raw_mask
-                wet &= ~raw_mask
-
-        xx = xx.drop_vars(["water"])
+        # clear dry pixels
+        clear = xx["water"].data == 0

-        # get valid wo pixels, both dry and wet
-        wet = expr_eval(
+        # get "valid" wo pixels, both dry and wet, used in veg_frequency
+        wet_valid = expr_eval(
             "where(a|b, a, _nan)",
-            {"a": wet.data, "b": valid.data},
+            {"a": wet, "b": valid},
             name="get_valid_pixels",
             dtype="float32",
             **{"_nan": np.nan},
         )

-        # pick all valid fc pixels
-        xx = keep_good_only(xx, valid, nodata=NODATA)
-        xx = to_float(xx, dtype="float32")
-
-        # get high ue valid pixels
-        ue = expr_eval(
-            "where(a>=_v, 1, _nan)",
-            {"a": xx["ue"].data},
-            name="get_high_ue",
+        # get "clear" wo pixels, both dry and wet, used in water_frequency
+        wet_clear = expr_eval(
+            "where(a|b, a, _nan)",
+            {"a": wet, "b": clear},
+            name="get_clear_pixels",
             dtype="float32",
-            **{
-                "_v": self.ue_threshold,
-                "_nan": np.nan,
-            },
+            **{"_nan": np.nan},
         )

-        # get low ue valid pixels
+        # dilate both 'valid' and 'water'
+        for key, val in self.BAD_BITS_MASK.items():
+            if self.cloud_filters.get(key) is not None:
+                raw_mask = (xx["water"] & val) > 0
+                raw_mask = mask_cleanup(
+                    raw_mask, mask_filters=self.cloud_filters.get(key)
+                )
+                valid = expr_eval(
+                    "where(b>0, 0, a)",
+                    {"a": valid, "b": raw_mask.data},
+                    name="get_valid_pixels",
+                    dtype="uint8",
+                )
+                wet_clear = expr_eval(
+                    "where(b>0, _nan, a)",
+                    {"a": wet_clear, "b": raw_mask.data},
+                    name="get_clear_pixels",
+                    dtype="float32",
+                    **{"_nan": np.nan},
+                )
+                wet_valid = expr_eval(
+                    "where(b>0, _nan, a)",
+                    {"a": wet_valid, "b": raw_mask.data},
+                    name="get_valid_pixels",
+                    dtype="float32",
+                    **{"_nan": np.nan},
+                )
+
+        xx = xx.drop_vars(["water"])
+
+        # Pick out the fc pixels that have an unmixing error of less than the threshold
         valid = expr_eval(
-            "where(b<_v, 1, 0)",
-            {"b": xx["ue"].data},
-            name="get_valid_pixels",
+            "where(b<_v, a, 0)",
+            {"a": valid, "b": xx["ue"].data},
+            name="get_low_ue",
             dtype="bool",
             **{"_v": self.ue_threshold},
         )
         xx = xx.drop_vars(["ue"])
         valid = xr.DataArray(valid, dims=xx["pv"].dims, coords=xx["pv"].coords)

-        xx = keep_good_only(xx, valid, nodata=np.nan)
-        xx["wet"] = xr.DataArray(wet, dims=xx["pv"].dims, coords=xx["pv"].coords)
-        xx["ue"] = xr.DataArray(ue, dims=xx["pv"].dims, coords=xx["pv"].coords)
+        xx = keep_good_only(xx, valid, nodata=NODATA)
+        xx = to_float(xx, dtype="float32")
+
+        xx["wet_valid"] = xr.DataArray(
+            wet_valid, dims=xx["pv"].dims, coords=xx["pv"].coords
+        )
+        xx["wet_clear"] = xr.DataArray(
+            wet_clear, dims=xx["pv"].dims, coords=xx["pv"].coords
+        )
+
         return xx

     def fuser(self, xx):
-        wet = xx["wet"]
-        ue = xx["ue"]
+        wet_valid = xx["wet_valid"]
+        wet_clear = xx["wet_clear"]

         xx = _xr_fuse(
-            xx.drop_vars(["wet", "ue"]),
+            xx.drop_vars(["wet_valid", "wet_clear"]),
             partial(_fuse_mean_np, nodata=np.nan),
-            "fuse_fc",
+            "",
         )

-        xx["wet"] = _nodata_fuser(wet, nodata=np.nan)
-        xx["ue"] = _nodata_fuser(ue, nodata=np.nan)
+        xx["wet_valid"] = _nodata_fuser(wet_valid, nodata=np.nan)
+        xx["wet_clear"] = _nodata_fuser(wet_clear, nodata=np.nan)

         return xx

     def _veg_or_not(self, xx: xr.Dataset):
-        # pv or npv > bs: 1
+        # either pv or npv > bs: 1
         # otherwise 0
         data = expr_eval(
             "where((a>b)|(c>b), 1, 0)",
-            {
-                "a": xx["pv"].data,
-                "c": xx["npv"].data,
-                "b": xx["bs"].data,
-            },
+            {"a": xx["pv"].data, "c": xx["npv"].data, "b": xx["bs"].data},
             name="get_veg",
             dtype="uint8",
         )

-        # mark nans only if not valid & low ue
-        # if any high ue valid (ue is not nan): 0
+        # mark nans
         data = expr_eval(
-            "where((a!=a)&(c!=c), nodata, b)",
-            {"a": xx["pv"].data, "c": xx["ue"].data, "b": data},
+            "where(a!=a, nodata, b)",
+            {"a": xx["pv"].data, "b": data},
             name="get_veg",
             dtype="uint8",
             **{"nodata": int(NODATA)},
@@ -156,7 +171,7 @@ def _veg_or_not(self, xx: xr.Dataset):
         # mark water freq >= 0.5 as 0
         data = expr_eval(
             "where(a>0, 0, b)",
-            {"a": xx["wet"].data, "b": data},
+            {"a": xx["wet_valid"].data, "b": data},
             name="get_veg",
             dtype="uint8",
         )
@@ -167,7 +182,7 @@ def _water_or_not(self, xx: xr.Dataset):
         # mark water freq > 0.5 as 1
         data = expr_eval(
             "where(a>0.5, 1, 0)",
-            {"a": xx["wet"].data},
+            {"a": xx["wet_clear"].data},
             name="get_water",
             dtype="uint8",
         )
@@ -175,17 +190,17 @@ def _water_or_not(self, xx: xr.Dataset):
         # mark nans
         data = expr_eval(
             "where(a!=a, nodata, b)",
-            {"a": xx["wet"].data, "b": data},
+            {"a": xx["wet_clear"].data, "b": data},
             name="get_water",
             dtype="uint8",
             **{"nodata": int(NODATA)},
         )
         return data

-    def _max_consecutive_months(self, data, nodata):
-        nan_mask = da.ones(data.shape[1:], chunks=data.chunks[1:], dtype="bool")
+    def _max_consecutive_months(self, data, nodata, normalize=False):
         tmp = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
         max_count = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
+        total = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")

         for t in data:
             # +1 if not nodata
@@ -213,23 +228,34 @@ def _max_consecutive_months(self, data, nodata):
                 dtype="uint8",
             )

-            # mark nodata
-            nan_mask = expr_eval(
-                "where(a==nodata, b, False)",
-                {"a": t, "b": nan_mask},
-                name="mark_nodata",
-                dtype="bool",
+            # total valid
+            total = expr_eval(
+                "where(a==nodata, b, b+1)",
+                {"a": t, "b": total},
+                name="get_total_valid",
+                dtype="uint8",
                 **{"nodata": nodata},
             )

         # mark nodata
-        max_count = expr_eval(
-            "where(a, nodata, b)",
-            {"a": nan_mask, "b": max_count},
-            name="mark_nodata",
-            dtype="uint8",
-            **{"nodata": int(nodata)},
-        )
+        if normalize:
+            max_count = expr_eval(
+                "where(a<=0, nodata, b/a*12)",
+                {"a": total, "b": max_count},
+                name="normalize_max_count",
+                dtype="float32",
+                **{"nodata": int(nodata)},
+            )
+            max_count = da.ceil(max_count).astype("uint8")
+        else:
+            max_count = expr_eval(
+                "where(a<=0, nodata, b)",
+                {"a": total, "b": max_count},
+                name="mark_nodata",
+                dtype="uint8",
+                **{"nodata": int(nodata)},
+            )
+
         return max_count

     def reduce(self, xx: xr.Dataset) -> xr.Dataset:
@@ -240,15 +266,15 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
         max_count_veg = self._max_consecutive_months(data, NODATA)

         data = self._water_or_not(xx)
-        max_count_water = self._max_consecutive_months(data, NODATA)
+        max_count_water = self._max_consecutive_months(data, NODATA, normalize=True)

         attrs = xx.attrs.copy()
         attrs["nodata"] = int(NODATA)
         data_vars = {
-            k: xr.DataArray(v, dims=xx["wet"].dims[1:], attrs=attrs)
+            k: xr.DataArray(v, dims=xx["pv"].dims[1:], attrs=attrs)
             for k, v in zip(self.measurements, [max_count_veg, max_count_water])
         }
-        coords = dict((dim, xx.coords[dim]) for dim in xx["wet"].dims[1:])
+        coords = dict((dim, xx.coords[dim]) for dim in xx["pv"].dims[1:])

         return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)

diff --git a/tests/test_landcover_plugin_a0.py b/tests/test_landcover_plugin_a0.py
index 5ed1d92..095f7ca 100644
--- a/tests/test_landcover_plugin_a0.py
+++ b/tests/test_landcover_plugin_a0.py
@@ -327,12 +327,12 @@ def test_native_transform(fc_wo_dataset, bits):
         np.array([1, 1, 3, 5, 6, 2, 6, 2, 2, 5, 6, 0, 0, 2, 3]),
         np.array([0, 3, 2, 1, 3, 5, 6, 1, 4, 5, 6, 0, 2, 4, 2]),
     )
-    result = np.where(out_xx["wet"].data == out_xx["wet"].data)
+    result = np.where(out_xx["wet_valid"].data == out_xx["wet_valid"].data)
     for a, b in zip(expected_valid, result):
         assert (a == b).all()

     expected_valid = (np.array([1, 2, 3]), np.array([6, 2, 0]), np.array([6, 1, 2]))
-    result = np.where(out_xx["wet"].data == 1)
+    result = np.where(out_xx["wet_valid"].data == 1)
     for a, b in zip(expected_valid, result):
         assert (a == b).all()

@@ -391,11 +391,11 @@ def test_water_or_not(fc_wo_dataset):
     xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
     yy = stats_veg._water_or_not(xx).compute()
     valid_index = (
-        np.array([0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2]),
-        np.array([1, 1, 3, 5, 6, 2, 6, 0, 0, 2, 2, 3, 5, 6]),
-        np.array([0, 3, 2, 1, 3, 5, 6, 0, 2, 1, 4, 2, 5, 6]),
+        np.array([0, 0, 1, 1, 2, 2, 2]),
+        np.array([3, 6, 2, 6, 0, 2, 2]),
+        np.array([2, 3, 5, 6, 2, 1, 4]),
     )
-    expected_value = np.array([0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0])
+    expected_value = np.array([0, 0, 0, 1, 1, 1, 0])
     i = 0
     for idx in zip(*valid_index):
         assert yy[idx] == expected_value[i]
@@ -410,27 +410,26 @@ def test_reduce(fc_wo_dataset):
     expected_value = np.array(
         [
             [1, 255, 0, 255, 255, 255, 255],
-            [1, 255, 255, 0, 255, 255, 255],
-            [255, 0, 255, 255, 1, 0, 255],
+            [1, 255, 255, 255, 255, 255, 255],
+            [255, 0, 255, 255, 1, 255, 255],
             [255, 255, 1, 255, 255, 255, 255],
             [255, 255, 255, 255, 255, 255, 255],
-            [255, 1, 255, 255, 255, 0, 255],
-            [255, 255, 255, 0, 255, 255, 0],
-        ],
-        dtype="uint8",
+            [255, 1, 255, 255, 255, 255, 255],
+            [255, 255, 255, 255, 255, 255, 0],
+        ]
     )

     assert (xx.veg_frequency.data == expected_value).all()

     expected_value = np.array(
         [
-            [0, 255, 1, 255, 255, 255, 255],
-            [0, 255, 255, 0, 255, 255, 255],
-            [255, 1, 255, 255, 0, 0, 255],
+            [255, 255, 12, 255, 255, 255, 255],
+            [255, 255, 255, 255, 255, 255, 255],
+            [255, 12, 255, 255, 0, 0, 255],
             [255, 255, 0, 255, 255, 255, 255],
             [255, 255, 255, 255, 255, 255, 255],
-            [255, 0, 255, 255, 255, 0, 255],
-            [255, 255, 255, 0, 255, 255, 1],
+            [255, 255, 255, 255, 255, 255, 255],
+            [255, 255, 255, 0, 255, 255, 12],
         ],
         dtype="uint8",
     )
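
Editor's note (not part of the patch): with normalize=True, which reduce now passes for the water measurement, _max_consecutive_months divides the longest consecutive count by the total number of clear observations and rescales it to a 12-month year ("where(a<=0, nodata, b/a*12)" followed by da.ceil), so water_frequency is expressed on a common 0-12 scale regardless of how many clear observations a pixel has. The plain-NumPy sketch below illustrates that arithmetic for a single pixel; the helper name and the simplified run counting are assumptions for illustration only, not the plugin's dask implementation.

    import numpy as np

    NODATA = 255  # uint8 nodata value used by the plugin


    def normalized_max_consecutive(months, nodata=NODATA):
        """Toy per-pixel stand-in for _max_consecutive_months(..., normalize=True).

        months: 1-D sequence of 1 (wet), 0 (dry) or nodata (no clear observation).
        """
        max_run = run = total = 0
        for v in months:
            if v == nodata:
                run = 0                 # simplification: a gap breaks the run
                continue
            total += 1                  # count every clear observation, wet or dry
            run = run + 1 if v == 1 else 0
            max_run = max(max_run, run)
        if total == 0:
            return nodata               # no clear observations at all -> nodata
        # normalize to a 12-month year and round up, mirroring
        # "where(a<=0, nodata, b/a*12)" followed by da.ceil in the patch
        return int(np.ceil(max_run / total * 12))


    # 6 clear months, 3 of them consecutively wet -> ceil(3 / 6 * 12) = 6
    print(normalized_max_consecutive([NODATA, 1, 1, 1, 0, 0, 0, NODATA, NODATA, NODATA, NODATA, NODATA]))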