Skip to content

Commit

Permalink
change woody cover band name and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Nov 6, 2024
1 parent 59e2358 commit e176cd8
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 15 deletions.
8 changes: 0 additions & 8 deletions odc/stats/plugins/lc_treelite_cultivated.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,6 @@ def aggregate_results_from_group(self, predict_output):
else:
predict_output = predict_output[0]

# predict_output = expr_eval(
# "where(a<nodata, 1-a, a)",
# {"a": predict_output},
# name="invert_output",
# dtype="float32",
# **{"nodata": NODATA},
# )

if m_size > 1:
predict_output = predict_output.sum(axis=0)

Expand Down
4 changes: 2 additions & 2 deletions odc/stats/plugins/lc_treelite_woody.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class StatsWoodyCover(StatsMLTree):

@property
def measurements(self) -> Tuple[str, ...]:
_measurements = ["woody_cover"]
_measurements = ["woody"]
return _measurements

def predict(self, input_array):
Expand Down Expand Up @@ -101,7 +101,7 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
attrs = res[var].attrs.copy()
attrs["nodata"] = int(NODATA)
res[var].attrs = attrs
var_rename = {var: "woody_cover"}
var_rename = dict(zip(res.data_vars, self.measurements))
return res.rename(var_rename)


Expand Down
9 changes: 4 additions & 5 deletions tests/test_rf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def test_cultivated_aggregate_results(
res = cultivated.aggregate_results_from_group([cultivated_results[0]])
assert (res.compute() == np.array([[112, 255], [111, 112]], dtype="uint8")).all()
res = cultivated.aggregate_results_from_group(cultivated_results)
assert (res.compute() == np.array([[111, 112], [111, 112]], dtype="uint8")).all()
assert (res.compute() == np.array([[112, 112], [111, 112]], dtype="uint8")).all()


def test_cultivated_reduce(
Expand Down Expand Up @@ -528,9 +528,8 @@ def test_woody_reduce(
)
dask_client.register_plugin(woody_cover.dask_worker_plugin)
res = woody_cover.reduce(woody_inputs)
assert res["woody_cover"].attrs["nodata"] == 255
assert res["woody_cover"].data.dtype == "uint8"
assert res["woody"].attrs["nodata"] == 255
assert res["woody"].data.dtype == "uint8"
assert (
res["woody_cover"].data.compute()
== np.array([[114, 255], [114, 114]], dtype="uint8")
res["woody"].data.compute() == np.array([[114, 255], [114, 114]], dtype="uint8")
).all()

0 comments on commit e176cd8

Please sign in to comment.