correct veg/non logic in landcover veg frequency (#176)
* correct veg/non logic

* fix veg frequency test

---------

Co-authored-by: Emma Ai <[email protected]>
emmaai and Emma Ai authored Dec 6, 2024
1 parent d04b1c0 commit 97df5ad
Showing 3 changed files with 46 additions and 56 deletions.
60 changes: 28 additions & 32 deletions odc/stats/plugins/lc_fc_wo_a0.py
@@ -34,12 +34,14 @@ class StatsVegCount(StatsPluginInterface):
     def __init__(
         self,
         ue_threshold: Optional[int] = None,
+        veg_threshold: Optional[int] = None,
         cloud_filters: Dict[str, Iterable[Tuple[str, int]]] = None,
         **kwargs,
     ):
         super().__init__(input_bands=["water", "pv", "bs", "npv", "ue"], **kwargs)
 
         self.ue_threshold = ue_threshold if ue_threshold is not None else 30
+        self.veg_threshold = veg_threshold if veg_threshold is not None else 2
         self.cloud_filters = cloud_filters if cloud_filters is not None else {}
 
     def native_transform(self, xx):
@@ -63,15 +65,6 @@ def native_transform(self, xx):
         # clear dry pixels
         clear = xx["water"].data == 0
 
-        # get "valid" wo pixels, both dry and wet used in veg_frequency
-        wet_valid = expr_eval(
-            "where(a|b, a, _nan)",
-            {"a": wet, "b": valid},
-            name="get_valid_pixels",
-            dtype="float32",
-            **{"_nan": np.nan},
-        )
-
         # get "clear" wo pixels, both dry and wet used in water_frequency
         wet_clear = expr_eval(
             "where(a|b, a, _nan)",
@@ -101,13 +94,7 @@ def native_transform(self, xx):
             dtype="float32",
             **{"_nan": np.nan},
         )
-        wet_valid = expr_eval(
-            "where(b>0, _nan, a)",
-            {"a": wet_valid, "b": raw_mask.data},
-            name="get_valid_pixels",
-            dtype="float32",
-            **{"_nan": np.nan},
-        )
 
         xx = xx.drop_vars(["water"])
 
         # Pick out the fc pixels that have an unmixing error of less than the threshold
@@ -124,9 +111,6 @@
         xx = keep_good_only(xx, valid, nodata=NODATA)
         xx = to_float(xx, dtype="float32")
 
-        xx["wet_valid"] = xr.DataArray(
-            wet_valid, dims=xx["pv"].dims, coords=xx["pv"].coords
-        )
         xx["wet_clear"] = xr.DataArray(
             wet_clear, dims=xx["pv"].dims, coords=xx["pv"].coords
         )
@@ -135,16 +119,14 @@
 
     def fuser(self, xx):
 
-        wet_valid = xx["wet_valid"]
         wet_clear = xx["wet_clear"]
 
         xx = _xr_fuse(
-            xx.drop_vars(["wet_valid", "wet_clear"]),
+            xx.drop_vars(["wet_clear"]),
             partial(_fuse_mean_np, nodata=np.nan),
             "",
         )
 
-        xx["wet_valid"] = _nodata_fuser(wet_valid, nodata=np.nan)
         xx["wet_clear"] = _nodata_fuser(wet_clear, nodata=np.nan)
 
         return xx
@@ -168,14 +150,6 @@ def _veg_or_not(self, xx: xr.Dataset):
             **{"nodata": int(NODATA)},
         )
 
-        # mark water freq >= 0.5 as 0
-        data = expr_eval(
-            "where(a>0, 0, b)",
-            {"a": xx["wet_valid"].data, "b": data},
-            name="get_veg",
-            dtype="uint8",
-        )
-
         return data
 
     def _water_or_not(self, xx: xr.Dataset):
@@ -262,8 +236,30 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
 
         xx = xx.groupby("time.month").map(median_ds, dim="spec")
 
-        data = self._veg_or_not(xx)
-        max_count_veg = self._max_consecutive_months(data, NODATA)
+        # consecutive observation of veg
+        veg_data = self._veg_or_not(xx)
+        max_count_veg = self._max_consecutive_months(veg_data, NODATA)
+
+        # consecutive observation of non-veg
+        non_veg_data = expr_eval(
+            "where(a<nodata, 1-a, nodata)",
+            {"a": veg_data},
+            name="invert_veg",
+            dtype="uint8",
+            **{"nodata": NODATA},
+        )
+        max_count_non_veg = self._max_consecutive_months(non_veg_data, NODATA)
+
+        # non-veg < threshold implies veg >= threshold
+        # implies any "wet" area potentially veg
+
+        max_count_veg = expr_eval(
+            "where((a<_v)&(b<_v), _v, b)",
+            {"a": max_count_non_veg, "b": max_count_veg},
+            name="clip_veg",
+            dtype="uint8",
+            **{"_v": self.veg_threshold},
+        )
 
         data = self._water_or_not(xx)
         max_count_water = self._max_consecutive_months(data, NODATA, normalize=True)
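The core of the fix is the new reduce() block above: instead of zeroing out wet pixels inside _veg_or_not, the plugin now counts consecutive veg and non-veg months separately and, where both runs fall below veg_threshold, reports the pixel as vegetated at the threshold. Below is a minimal NumPy sketch of that logic; the toy data, variable names, and the simplified run-length helper are illustrative only and stand in for the plugin's expr_eval calls over dask arrays and its own _max_consecutive_months.

import numpy as np

NODATA = 255
VEG_THRESHOLD = 2  # mirrors the new veg_threshold default


def max_consecutive_months(flags):
    """Longest run of 1s along the month axis (axis 0); nodata simply breaks a run."""
    run = np.zeros(flags.shape[1:], dtype="uint8")
    best = np.zeros_like(run)
    for month in flags:
        run = np.where(month == 1, run + 1, 0).astype("uint8")
        best = np.maximum(best, run)
    return best


# Toy monthly veg flags for two pixels, shape (month, y, x): 1 = veg, 0 = non-veg.
# Pixel A (x=0) is vegetated two months in a row; pixel B (x=1) alternates.
veg_flags = np.array([[[1, 1]], [[1, 0]], [[0, 1]]], dtype="uint8")

max_veg = max_consecutive_months(veg_flags)

# Invert to non-veg while keeping nodata: the diff's "where(a<nodata, 1-a, nodata)"
non_veg_flags = np.where(
    veg_flags < NODATA, 1 - veg_flags.astype("int16"), NODATA
).astype("uint8")
max_non_veg = max_consecutive_months(non_veg_flags)

# Where both runs fall below the threshold (e.g. intermittently wet pixels),
# report the pixel as potentially veg: the diff's "where((a<_v)&(b<_v), _v, b)"
max_veg = np.where(
    (max_non_veg < VEG_THRESHOLD) & (max_veg < VEG_THRESHOLD), VEG_THRESHOLD, max_veg
).astype("uint8")

print(max_veg)  # [[2 2]]: pixel A has a real 2-month run, pixel B is raised to 2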
8 changes: 5 additions & 3 deletions odc/stats/plugins/lc_ml_treelite.py
@@ -76,7 +76,9 @@ def __init__(
         self.dask_worker_plugin = TreeliteModelPlugin(model_path)
         self.output_classes = output_classes
         self.mask_bands = mask_bands
-        self.temporal_coverage = temporal_coverage
+        self.temporal_coverage = (
+            temporal_coverage if temporal_coverage is not None else {}
+        )
         self._log = logging.getLogger(__name__)
 
     def input_data(
@@ -160,8 +162,8 @@ def convert_dtype(var):
 
         for var in xx.data_vars:
             if var not in self.mask_bands:
-                if self.temporal_coverage is not None:
-                    # filter and impute by sensors
+                # filter and impute by sensors
+                if self.temporal_coverage.get(var) is not None:
                     temporal_range = [
                         DateTimeRange(v) for v in self.temporal_coverage.get(var)
                     ]
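For the treelite plugin, the change makes temporal_coverage default to an empty dict and moves the None check to the individual band, so only bands with a configured date range are filtered and imputed; the old check only guarded against the whole dict being None, so a band missing from the mapping would still reach temporal_coverage.get(var) and iterate over None. A small sketch of the new lookup pattern, assuming temporal_coverage maps band names to lists of date-range strings; the band names and range string below are made up, and plain strings stand in for DateTimeRange.

# Hypothetical configuration: only one band has a date range.
temporal_coverage = {"nbart_red": ["2015-01--P5Y"]}

for band in ["nbart_red", "water"]:
    ranges = temporal_coverage.get(band)
    if ranges is not None:
        print(f"{band}: filter and impute within {ranges}")
    else:
        print(f"{band}: no temporal coverage configured, leave as is")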
34 changes: 13 additions & 21 deletions tests/test_landcover_plugin_a0.py
@@ -322,17 +322,8 @@ def test_native_transform(fc_wo_dataset, bits):
     stats_veg = StatsVegCount(measurements=["veg_frequency", "water_frequency"])
     out_xx = stats_veg.native_transform(xx).compute()
 
-    expected_valid = (
-        np.array([0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]),
-        np.array([1, 1, 3, 5, 6, 2, 6, 2, 2, 5, 6, 0, 0, 2, 3]),
-        np.array([0, 3, 2, 1, 3, 5, 6, 1, 4, 5, 6, 0, 2, 4, 2]),
-    )
-    result = np.where(out_xx["wet_valid"].data == out_xx["wet_valid"].data)
-    for a, b in zip(expected_valid, result):
-        assert (a == b).all()
-
     expected_valid = (np.array([1, 2, 3]), np.array([6, 2, 0]), np.array([6, 1, 2]))
-    result = np.where(out_xx["wet_valid"].data == 1)
+    result = np.where(out_xx["wet_clear"].data == 1)
 
     for a, b in zip(expected_valid, result):
         assert (a == b).all()
@@ -374,11 +365,11 @@ def test_veg_or_not(fc_wo_dataset):
     xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
     yy = stats_veg._veg_or_not(xx).compute()
     valid_index = (
-        np.array([0, 0, 1, 2, 2, 2, 2, 2]),
-        np.array([1, 5, 6, 0, 0, 2, 2, 3]),
-        np.array([0, 1, 6, 0, 2, 1, 4, 2]),
+        np.array([0, 0, 2, 2, 2]),
+        np.array([1, 5, 0, 2, 3]),
+        np.array([0, 1, 0, 4, 2]),
     )
-    expected_value = np.array([1, 1, 0, 1, 0, 0, 1, 1])
+    expected_value = np.array([1, 1, 1, 1, 1])
     i = 0
     for idx in zip(*valid_index):
         assert yy[idx] == expected_value[i]
@@ -409,14 +400,15 @@ def test_reduce(fc_wo_dataset):
     xx = stats_veg.reduce(xx).compute()
     expected_value = np.array(
         [
-            [1, 255, 0, 255, 255, 255, 255],
-            [1, 255, 255, 255, 255, 255, 255],
-            [255, 0, 255, 255, 1, 255, 255],
-            [255, 255, 1, 255, 255, 255, 255],
+            [2, 255, 255, 255, 255, 255, 255],
+            [2, 255, 255, 255, 255, 255, 255],
+            [255, 255, 255, 255, 2, 255, 255],
+            [255, 255, 2, 255, 255, 255, 255],
             [255, 255, 255, 255, 255, 255, 255],
-            [255, 1, 255, 255, 255, 255, 255],
-            [255, 255, 255, 255, 255, 255, 0],
-        ]
+            [255, 2, 255, 255, 255, 255, 255],
+            [255, 255, 255, 255, 255, 255, 255],
+        ],
+        dtype="uint8",
     )
 
     assert (xx.veg_frequency.data == expected_value).all()
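Since veg_threshold is now a constructor argument, the behaviour exercised by these tests can be tuned when the plugin is instantiated. A hedged usage sketch, assuming the module path matches the file path shown above; the measurement names are the ones used in the tests, and the threshold value is just the default repeated for illustration.

from odc.stats.plugins.lc_fc_wo_a0 import StatsVegCount

stats_veg = StatsVegCount(
    measurements=["veg_frequency", "water_frequency"],
    veg_threshold=2,  # default; raise it to require longer consecutive runs
)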
