Skip to content

Commit

Permalink
add tolerance
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Aug 21, 2024
1 parent 4480ab1 commit d7307d9
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 27 deletions.
2 changes: 1 addition & 1 deletion docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ odc-dscache>=0.2.3
odc-stac @ git+https://github.com/opendatacube/odc-stac@69bdf64

# odc-stac is in PyPI
odc-stats[ows] @ git+https://github.com/opendatacube/odc-stats@eee2ed1
odc-stats[ows] @ git+https://github.com/opendatacube/odc-stats@4480ab1

# For ML
tflite-runtime
Expand Down
6 changes: 3 additions & 3 deletions odc/stats/plugins/lc_tf_urban.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,15 @@ def aggregate_results_from_group(self, urban_masks):
urban_masks = urban_masks[0]

urban_masks = expr_eval(
"where((a/nodata)>=_l, nodata, a%nodata)",
"where((a/nodata)+0.5>=_l, nodata, a%nodata)",
{"a": urban_masks},
name="mark_nodata",
dtype="float32",
**{"_l": m_size, "nodata": NODATA},
)

urban_masks = expr_eval(
"where((a>0)&(a<nodata), _u, a)",
"where((a>0.5)&(a<nodata), _u, a)",
{"a": urban_masks},
name="output_classes_artificial",
dtype="float32",
Expand All @@ -197,7 +197,7 @@ def aggregate_results_from_group(self, urban_masks):
)

urban_masks = expr_eval(
"where(a<=0, _nu, a)",
"where(a<0.5, _nu, a)",
{"a": urban_masks},
name="output_classes_natrual",
dtype="uint8",
Expand Down
39 changes: 19 additions & 20 deletions odc/stats/plugins/lc_treelite_cultivated.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,21 +164,6 @@ def generate_features(input_block, bands_indices):
norm, 1e-8
) # Avoid division by zero, fine if it's nan

# reassemble the array
output_block = np.concatenate(
[output_block, input_block[..., bands_indices["sdev"]][..., np.newaxis]],
axis=-1,
).astype("float32")
# scale edev \in [0, 1]
edev = input_block[..., bands_indices["edev"]] / 1e4
output_block = np.concatenate(
[output_block, edev[..., np.newaxis]], axis=-1
).astype("float32")
output_block = np.concatenate(
[output_block, input_block[..., bands_indices["bcdev"]][..., np.newaxis]],
axis=-1,
).astype("float32")

feature_block = None
for f, p in zip(
[
Expand All @@ -193,15 +178,29 @@ def generate_features(input_block, bands_indices):
],
feature_input_indices,
):
ib = f(output_block[..., : bands_indices["nbart_swir_2"] + 1], *p)
ib = f(output_block, *p)
if feature_block is None:
feature_block = ib[..., np.newaxis]
else:
feature_block = np.concatenate(
[feature_block, ib[..., np.newaxis]], axis=-1
)

# reassemble the array
output_block = np.concatenate(
[output_block, input_block[..., bands_indices["sdev"]][..., np.newaxis]],
axis=-1,
).astype("float32")
# scale edev \in [0, 1]
edev = input_block[..., bands_indices["edev"]] / 1e4
output_block = np.concatenate(
[output_block, edev[..., np.newaxis]], axis=-1
).astype("float32")
output_block = np.concatenate(
[output_block, input_block[..., bands_indices["bcdev"]][..., np.newaxis]],
axis=-1,
).astype("float32")
output_block = np.concatenate([output_block, feature_block], axis=-1)

selected_indices = np.r_[
[
bands_indices[k]
Expand Down Expand Up @@ -288,23 +287,23 @@ def aggregate_results_from_group(self, predict_output):
predict_output = predict_output.sum(axis=0)

predict_output = expr_eval(
"where((m/nodata)>=_l, nodata, m%nodata)",
"where((m/nodata)+0.5>=_l, nodata, m%nodata)",
{"m": predict_output},
name="mark_nodata",
dtype="float32",
**{"_l": m_size, "nodata": NODATA},
)

predict_output = expr_eval(
"where((m>0)&(m<nodata), _u, m)",
"where((m>0.5)&(m<nodata), _u, m)",
{"m": predict_output},
name="output_classes_cultivated",
dtype="float32",
**{"_u": self.output_classes["cultivated"], "nodata": NODATA},
)

predict_output = expr_eval(
"where(m<=0, _nu, m)",
"where(m<0.5, _nu, m)",
{"m": predict_output},
name="output_classes_natural",
dtype="uint8",
Expand Down
6 changes: 3 additions & 3 deletions odc/stats/plugins/lc_treelite_woody.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def aggregate_results_from_group(self, predict_output):
predict_output = predict_output.sum(axis=0)

predict_output = expr_eval(
"where((a/nodata)>=_l, nodata, a%nodata)",
"where((a/nodata)+0.5>=_l, nodata, a%nodata)",
{"a": predict_output},
name="summary_over_classes",
dtype="float32",
Expand All @@ -77,15 +77,15 @@ def aggregate_results_from_group(self, predict_output):
)

predict_output = expr_eval(
"where((a>0)&(a<nodata), _nw, a)",
"where((a>0.5)&(a<nodata), _nw, a)",
{"a": predict_output},
name="output_classes_herbaceous",
dtype="float32",
**{"nodata": NODATA, "_nw": self.output_classes["herbaceous"]},
)

predict_output = expr_eval(
"where(a<=0, _nw, a)",
"where(a<0.5, _nw, a)",
{"a": predict_output},
name="output_classes_woody",
dtype="uint8",
Expand Down

0 comments on commit d7307d9

Please sign in to comment.