Skip to content

Commit

Permalink
move water_season from level1 to condition of wo frequency in level34
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Dec 2, 2024
1 parent 30b6f13 commit 52b842e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 125 deletions.
44 changes: 3 additions & 41 deletions odc/stats/plugins/lc_veg_class_a1.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def __init__(
saltpan_threshold: Optional[int] = None,
water_threshold: Optional[float] = None,
veg_threshold: Optional[int] = None,
water_seasonality_threshold: Optional[float] = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -70,9 +69,6 @@ def __init__(
)
self.water_threshold = water_threshold if water_threshold is not None else 0.2
self.veg_threshold = veg_threshold if veg_threshold is not None else 2
self.water_seasonality_threshold = (
water_seasonality_threshold if water_seasonality_threshold else 0.25
)
self.output_classes = output_classes

def fuser(self, xx):
Expand Down Expand Up @@ -191,49 +187,15 @@ def l3_class(self, xx: xr.Dataset):
**{"nodata": NODATA},
)

# Now add the water frequency
# Divide water frequency into following classes:
# 0 --> 0
# (0,0.25] --> 1
# (0.25,1] --> 2

water_seasonality = expr_eval(
"where((a > 0) & (a <= wt), 1, a)",
{"a": xx["frequency"].data},
name="mark_wo_fq",
dtype="float32",
**{"wt": self.water_seasonality_threshold},
)

water_seasonality = expr_eval(
"where((a > wt) & (a <= 1), 2, b)",
{"a": xx["frequency"].data, "b": water_seasonality},
name="mark_wo_fq",
dtype="float32",
**{"wt": self.water_seasonality_threshold},
)

water_seasonality = expr_eval(
"where((a != a), nodata, a)",
{
"a": water_seasonality,
},
name="mark_nodata",
dtype="uint8",
**{"nodata": NODATA},
)

return l3_mask, water_seasonality
return l3_mask

def reduce(self, xx: xr.Dataset) -> xr.Dataset:
l3_mask, water_seasonality = self.l3_class(xx)
l3_mask = self.l3_class(xx)
attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
k: xr.DataArray(v, dims=xx["veg_frequency"].dims[1:], attrs=attrs)
for k, v in zip(
self.measurements, [l3_mask.squeeze(0), water_seasonality.squeeze(0)]
)
for k, v in zip(self.measurements, [l3_mask.squeeze(0)])
}
coords = dict((dim, xx.coords[dim]) for dim in xx["veg_frequency"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)
Expand Down
66 changes: 2 additions & 64 deletions tests/test_landcover_plugin_a1.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_l3_classes(dataset):
"surface": 210,
},
optional_bands=["canopy_cover_class", "elevation"],
measurements=["level_3_4", "water_season"],
measurements=["level_3_4"],
)

expected_res = np.array(
Expand All @@ -151,72 +151,10 @@ def test_l3_classes(dataset):
dtype="uint8",
)

res, water_seasonality = stats_l3.l3_class(dataset)
res = stats_l3.l3_class(dataset)
assert (res == expected_res).all()


def test_l4_water_seasonality(dataset):
stats_l3 = StatsVegClassL1(
output_classes={
"aquatic_veg_wood": 124,
"aquatic_veg_herb": 125,
"terrestrial_veg": 110,
"water": 221,
"intertidal": 223,
"surface": 210,
},
optional_bands=["canopy_cover_class", "elevation"],
measurements=["level_3_4", "water_season"],
)

wo_fq = np.array(
[
[
[0.0, 0.021, 0.152, 255],
[0.249, 0.273, 0.252, 0.0375],
[0.302, 0, 0.789, 0.078],
[0.021, 0.243, 255, 0.255],
]
],
dtype="float32",
)
wo_fq = da.from_array(wo_fq, chunks=(1, -1, -1))

dataset["frequency"] = xr.DataArray(
wo_fq, dims=("spec", "y", "x"), attrs={"nodata": np.nan}
)

expected_water_seasonality = np.array(
[
[
[0, 1, 1, 255],
[1, 2, 2, 1],
[2, 0, 2, 1],
[1, 1, 255, 2],
]
],
dtype="float32",
)

res, water_seasonality = stats_l3.l3_class(dataset)
assert np.allclose(water_seasonality, expected_water_seasonality)


def test_reduce(dataset):
stats_l3 = StatsVegClassL1(
output_classes={
"aquatic_veg_wood": 124,
"aquatic_veg_herb": 125,
"terrestrial_veg": 110,
"water": 221,
"intertidal": 223,
"surface": 210,
},
optional_bands=["canopy_cover_class", "elevation"],
measurements=["level_3_4", "water_season"],
)
res = stats_l3.reduce(dataset)

for var in res:
assert res[var].attrs.get("nodata") is not None
if res[var].dtype == "uint8":
Expand Down
40 changes: 20 additions & 20 deletions tests/test_lc_level34.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,16 @@ def image_groups():
dtype="uint8",
)

water_season = np.array(
frequency = np.array(
[
[
[1, 2, 1],
[2, 1, 2],
[1, 1, 2],
[2, 2, 1],
[0.1, 0.3, 0.15],
[0.9, 0.1, 0.5],
[0, 0.15, 0.4],
[0.9, 0.8, 0],
]
],
dtype="uint8",
dtype="float32",
)

tuples = [
Expand Down Expand Up @@ -165,10 +165,10 @@ def image_groups():
dims=("spec", "y", "x"),
attrs={"nodata": 255},
),
"water_season": xr.DataArray(
da.from_array(water_season, chunks=(1, -1, -1)),
"frequency": xr.DataArray(
da.from_array(frequency, chunks=(1, -1, -1)),
dims=("spec", "y", "x"),
attrs={"nodata": 255},
attrs={"nodata": "nan"},
),
}

Expand Down Expand Up @@ -199,7 +199,7 @@ def test_l4_classes(image_groups, urban_shape):
"level1",
"level3",
"woody",
"water_season",
"frequency",
"water_frequency",
"pv_pc_50",
"bs_pc_50",
Expand Down Expand Up @@ -354,16 +354,16 @@ def test_level4(urban_shape):
dtype="float32",
)

water_season = np.array(
frequency = np.array(
[
[
[255, 255, 255, 255, 255, 255],
[255, 255, 255, 255, 255, 255],
[255, 255, 255, 255, 255, 255],
[255, 255, 255, 255, 255, 255],
[np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
[np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
]
],
dtype="uint8",
dtype="float32",
)

water_frequency = np.array(
Expand Down Expand Up @@ -442,10 +442,10 @@ def test_level4(urban_shape):
dims=("spec", "y", "x"),
attrs={"nodata": 255},
),
"water_season": xr.DataArray(
da.from_array(water_season, chunks=(1, -1, -1)),
"frequency": xr.DataArray(
da.from_array(frequency, chunks=(1, -1, -1)),
dims=("spec", "y", "x"),
attrs={"nodata": 255},
attrs={"nodata": "nan"},
),
}

Expand Down Expand Up @@ -482,7 +482,7 @@ def test_level4(urban_shape):
"level1",
"level3",
"woody",
"water_season",
"frequency",
"water_frequency",
"pv_pc_50",
"bs_pc_50",
Expand Down

0 comments on commit 52b842e

Please sign in to comment.