Skip to content

Commit

Permalink
make cultivated classification conservative
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Nov 4, 2024
1 parent 39c9445 commit 4ce1d71
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions odc/stats/plugins/lc_treelite_cultivated.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def predict(self, input_array):

def aggregate_results_from_group(self, predict_output):
# if there are >= 2 images
# any is cultivated -> final class is cultivated
# any is natural -> final class is natrual
# any is valid -> final class is valid
# for each pixel
m_size = len(predict_output)
Expand All @@ -268,13 +268,13 @@ def aggregate_results_from_group(self, predict_output):
else:
predict_output = predict_output[0]

predict_output = expr_eval(
"where(a<nodata, 1-a, a)",
{"a": predict_output},
name="invert_output",
dtype="float32",
**{"nodata": NODATA},
)
# predict_output = expr_eval(
# "where(a<nodata, 1-a, a)",
# {"a": predict_output},
# name="invert_output",
# dtype="float32",
# **{"nodata": NODATA},
# )

if m_size > 1:
predict_output = predict_output.sum(axis=0)
Expand All @@ -290,17 +290,17 @@ def aggregate_results_from_group(self, predict_output):
predict_output = expr_eval(
"where((a>0)&(a<nodata), _u, a)",
{"a": predict_output},
name="output_classes_cultivated",
name="output_classes_natural",
dtype="float32",
**{"_u": self.output_classes["cultivated"], "nodata": NODATA},
**{"_u": self.output_classes["natural"], "nodata": NODATA},
)

predict_output = expr_eval(
"where(a<=0, _nu, a)",
{"a": predict_output},
name="output_classes_natural",
name="output_classes_cultivated",
dtype="uint8",
**{"_nu": self.output_classes["natural"]},
**{"_nu": self.output_classes["cultivated"]},
)

return predict_output.rechunk(-1, -1)
Expand Down

0 comments on commit 4ce1d71

Please sign in to comment.