Skip to content

Commit

Permalink
normalize water frequency
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Dec 3, 2024
1 parent 6c33ccd commit af02544
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 92 deletions.
176 changes: 101 additions & 75 deletions odc/stats/plugins/lc_fc_wo_a0.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,102 +52,117 @@ def native_transform(self, xx):
5. Drop the WOfS band
"""

# clear and dry pixels not mask against bit 4: terrain high slope,
        # valid and dry pixels are not masked against bit 4: terrain high slope,
# bit 3: terrain shadow, and
# bit 2: low solar angle
valid = (xx["water"] & ~((1 << 4) | (1 << 3) | (1 << 2))) == 0
valid = (xx["water"].data & ~((1 << 4) | (1 << 3) | (1 << 2))) == 0

# clear and wet pixels not mask against bit 2: low solar angle
wet = (xx["water"] & ~(1 << 2)) == 128
# clear wet pixels not mask against bit 2: low solar angle
wet = (xx["water"].data & ~(1 << 2)) == 128

# dilate both 'valid' and 'water'
for key, val in self.BAD_BITS_MASK.items():
if self.cloud_filters.get(key) is not None:
raw_mask = (xx["water"] & val) > 0
raw_mask = mask_cleanup(
raw_mask, mask_filters=self.cloud_filters.get(key)
)
valid &= ~raw_mask
wet &= ~raw_mask

xx = xx.drop_vars(["water"])
# clear dry pixels
clear = xx["water"].data == 0

# get valid wo pixels, both dry and wet
wet = expr_eval(
# get "valid" wo pixels, both dry and wet used in veg_frequency
wet_valid = expr_eval(
"where(a|b, a, _nan)",
{"a": wet.data, "b": valid.data},
{"a": wet, "b": valid},
name="get_valid_pixels",
dtype="float32",
**{"_nan": np.nan},
)

# pick all valid fc pixels
xx = keep_good_only(xx, valid, nodata=NODATA)
xx = to_float(xx, dtype="float32")

# get high ue valid pixels
ue = expr_eval(
"where(a>=_v, 1, _nan)",
{"a": xx["ue"].data},
name="get_high_ue",
# get "clear" wo pixels, both dry and wet used in water_frequency
wet_clear = expr_eval(
"where(a|b, a, _nan)",
{"a": wet, "b": clear},
name="get_clear_pixels",
dtype="float32",
**{
"_v": self.ue_threshold,
"_nan": np.nan,
},
**{"_nan": np.nan},
)

# get low ue valid pixels
# dilate both 'valid' and 'water'
for key, val in self.BAD_BITS_MASK.items():
if self.cloud_filters.get(key) is not None:
raw_mask = (xx["water"] & val) > 0
raw_mask = mask_cleanup(
raw_mask, mask_filters=self.cloud_filters.get(key)
)
valid = expr_eval(
"where(b>0, 0, a)",
{"a": valid, "b": raw_mask.data},
name="get_valid_pixels",
dtype="uint8",
)
wet_clear = expr_eval(
"where(b>0, _nan, a)",
{"a": wet_clear, "b": raw_mask.data},
                    name="get_clear_pixels",
dtype="float32",
**{"_nan": np.nan},
)
wet_valid = expr_eval(
"where(b>0, _nan, a)",
{"a": wet_valid, "b": raw_mask.data},
name="get_valid_pixels",
dtype="float32",
**{"_nan": np.nan},
)
xx = xx.drop_vars(["water"])

# Pick out the fc pixels that have an unmixing error of less than the threshold
valid = expr_eval(
"where(b<_v, 1, 0)",
{"b": xx["ue"].data},
name="get_valid_pixels",
"where(b<_v, a, 0)",
{"a": valid, "b": xx["ue"].data},
name="get_low_ue",
dtype="bool",
**{"_v": self.ue_threshold},
)
xx = xx.drop_vars(["ue"])
valid = xr.DataArray(valid, dims=xx["pv"].dims, coords=xx["pv"].coords)
xx = keep_good_only(xx, valid, nodata=np.nan)

xx["wet"] = xr.DataArray(wet, dims=xx["pv"].dims, coords=xx["pv"].coords)
xx["ue"] = xr.DataArray(ue, dims=xx["pv"].dims, coords=xx["pv"].coords)
xx = keep_good_only(xx, valid, nodata=NODATA)
xx = to_float(xx, dtype="float32")

xx["wet_valid"] = xr.DataArray(
wet_valid, dims=xx["pv"].dims, coords=xx["pv"].coords
)
xx["wet_clear"] = xr.DataArray(
wet_clear, dims=xx["pv"].dims, coords=xx["pv"].coords
)

return xx

def fuser(self, xx):

wet = xx["wet"]
ue = xx["ue"]
wet_valid = xx["wet_valid"]
wet_clear = xx["wet_clear"]

xx = _xr_fuse(
xx.drop_vars(["wet", "ue"]),
xx.drop_vars(["wet_valid", "wet_clear"]),
partial(_fuse_mean_np, nodata=np.nan),
"fuse_fc",
"",
)

xx["wet"] = _nodata_fuser(wet, nodata=np.nan)
xx["ue"] = _nodata_fuser(ue, nodata=np.nan)
xx["wet_valid"] = _nodata_fuser(wet_valid, nodata=np.nan)
xx["wet_clear"] = _nodata_fuser(wet_clear, nodata=np.nan)

return xx

def _veg_or_not(self, xx: xr.Dataset):
# pv or npv > bs: 1
# either pv or npv > bs: 1
# otherwise 0
data = expr_eval(
"where((a>b)|(c>b), 1, 0)",
{
"a": xx["pv"].data,
"c": xx["npv"].data,
"b": xx["bs"].data,
},
{"a": xx["pv"].data, "c": xx["npv"].data, "b": xx["bs"].data},
name="get_veg",
dtype="uint8",
)

# mark nans only if not valid & low ue
# if any high ue valid (ue is not nan): 0
# mark nans
data = expr_eval(
"where((a!=a)&(c!=c), nodata, b)",
{"a": xx["pv"].data, "c": xx["ue"].data, "b": data},
"where(a!=a, nodata, b)",
{"a": xx["pv"].data, "b": data},
name="get_veg",
dtype="uint8",
**{"nodata": int(NODATA)},
Expand All @@ -156,7 +171,7 @@ def _veg_or_not(self, xx: xr.Dataset):
# mark water freq >= 0.5 as 0
data = expr_eval(
"where(a>0, 0, b)",
{"a": xx["wet"].data, "b": data},
{"a": xx["wet_valid"].data, "b": data},
name="get_veg",
dtype="uint8",
)
Expand All @@ -167,25 +182,25 @@ def _water_or_not(self, xx: xr.Dataset):
# mark water freq > 0.5 as 1
data = expr_eval(
"where(a>0.5, 1, 0)",
{"a": xx["wet"].data},
{"a": xx["wet_clear"].data},
name="get_water",
dtype="uint8",
)

# mark nans
data = expr_eval(
"where(a!=a, nodata, b)",
{"a": xx["wet"].data, "b": data},
{"a": xx["wet_clear"].data, "b": data},
name="get_water",
dtype="uint8",
**{"nodata": int(NODATA)},
)
return data

def _max_consecutive_months(self, data, nodata):
nan_mask = da.ones(data.shape[1:], chunks=data.chunks[1:], dtype="bool")
def _max_consecutive_months(self, data, nodata, normalize=False):
tmp = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
max_count = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
total = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")

for t in data:
# +1 if not nodata
Expand Down Expand Up @@ -213,23 +228,34 @@ def _max_consecutive_months(self, data, nodata):
dtype="uint8",
)

# mark nodata
nan_mask = expr_eval(
"where(a==nodata, b, False)",
{"a": t, "b": nan_mask},
name="mark_nodata",
dtype="bool",
# total valid
total = expr_eval(
"where(a==nodata, b, b+1)",
{"a": t, "b": total},
name="get_total_valid",
dtype="uint8",
**{"nodata": nodata},
)

# mark nodata
max_count = expr_eval(
"where(a, nodata, b)",
{"a": nan_mask, "b": max_count},
name="mark_nodata",
dtype="uint8",
**{"nodata": int(nodata)},
)
if normalize:
max_count = expr_eval(
"where(a<=0, nodata, b/a*12)",
{"a": total, "b": max_count},
name="normalize_max_count",
dtype="float32",
**{"nodata": int(nodata)},
)
max_count = da.ceil(max_count).astype("uint8")
else:
max_count = expr_eval(
"where(a<=0, nodata, b)",
{"a": total, "b": max_count},
name="mark_nodata",
dtype="uint8",
**{"nodata": int(nodata)},
)

return max_count

def reduce(self, xx: xr.Dataset) -> xr.Dataset:
Expand All @@ -240,15 +266,15 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
max_count_veg = self._max_consecutive_months(data, NODATA)

data = self._water_or_not(xx)
max_count_water = self._max_consecutive_months(data, NODATA)
max_count_water = self._max_consecutive_months(data, NODATA, normalize=True)

attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
k: xr.DataArray(v, dims=xx["wet"].dims[1:], attrs=attrs)
k: xr.DataArray(v, dims=xx["pv"].dims[1:], attrs=attrs)
for k, v in zip(self.measurements, [max_count_veg, max_count_water])
}
coords = dict((dim, xx.coords[dim]) for dim in xx["wet"].dims[1:])
coords = dict((dim, xx.coords[dim]) for dim in xx["pv"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)


Expand Down
33 changes: 16 additions & 17 deletions tests/test_landcover_plugin_a0.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,12 +327,12 @@ def test_native_transform(fc_wo_dataset, bits):
np.array([1, 1, 3, 5, 6, 2, 6, 2, 2, 5, 6, 0, 0, 2, 3]),
np.array([0, 3, 2, 1, 3, 5, 6, 1, 4, 5, 6, 0, 2, 4, 2]),
)
result = np.where(out_xx["wet"].data == out_xx["wet"].data)
result = np.where(out_xx["wet_valid"].data == out_xx["wet_valid"].data)
for a, b in zip(expected_valid, result):
assert (a == b).all()

expected_valid = (np.array([1, 2, 3]), np.array([6, 2, 0]), np.array([6, 1, 2]))
result = np.where(out_xx["wet"].data == 1)
result = np.where(out_xx["wet_valid"].data == 1)

for a, b in zip(expected_valid, result):
assert (a == b).all()
Expand Down Expand Up @@ -391,11 +391,11 @@ def test_water_or_not(fc_wo_dataset):
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
yy = stats_veg._water_or_not(xx).compute()
valid_index = (
np.array([0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2]),
np.array([1, 1, 3, 5, 6, 2, 6, 0, 0, 2, 2, 3, 5, 6]),
np.array([0, 3, 2, 1, 3, 5, 6, 0, 2, 1, 4, 2, 5, 6]),
np.array([0, 0, 1, 1, 2, 2, 2]),
np.array([3, 6, 2, 6, 0, 2, 2]),
np.array([2, 3, 5, 6, 2, 1, 4]),
)
expected_value = np.array([0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0])
expected_value = np.array([0, 0, 0, 1, 1, 1, 0])
i = 0
for idx in zip(*valid_index):
assert yy[idx] == expected_value[i]
Expand All @@ -410,27 +410,26 @@ def test_reduce(fc_wo_dataset):
expected_value = np.array(
[
[1, 255, 0, 255, 255, 255, 255],
[1, 255, 255, 0, 255, 255, 255],
[255, 0, 255, 255, 1, 0, 255],
[1, 255, 255, 255, 255, 255, 255],
[255, 0, 255, 255, 1, 255, 255],
[255, 255, 1, 255, 255, 255, 255],
[255, 255, 255, 255, 255, 255, 255],
[255, 1, 255, 255, 255, 0, 255],
[255, 255, 255, 0, 255, 255, 0],
],
dtype="uint8",
[255, 1, 255, 255, 255, 255, 255],
[255, 255, 255, 255, 255, 255, 0],
]
)

assert (xx.veg_frequency.data == expected_value).all()

expected_value = np.array(
[
[0, 255, 1, 255, 255, 255, 255],
[0, 255, 255, 0, 255, 255, 255],
[255, 1, 255, 255, 0, 0, 255],
[255, 255, 12, 255, 255, 255, 255],
[255, 255, 255, 255, 255, 255, 255],
[255, 12, 255, 255, 0, 0, 255],
[255, 255, 0, 255, 255, 255, 255],
[255, 255, 255, 255, 255, 255, 255],
[255, 0, 255, 255, 255, 0, 255],
[255, 255, 255, 0, 255, 255, 1],
[255, 255, 255, 255, 255, 255, 255],
[255, 255, 255, 0, 255, 255, 12],
],
dtype="uint8",
)
Expand Down

0 comments on commit af02544

Please sign in to comment.