Skip to content

Commit

Permalink
add test for woody cover
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Aug 6, 2024
1 parent 919fb97 commit 3b1caa3
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions tests/test_rf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
StatsCultivatedClass,
generate_features,
)
from odc.stats.plugins.lc_treelite_woody import StatsWoodyCover
from pathlib import Path
import pytest
import boto3
Expand Down Expand Up @@ -381,6 +382,33 @@ def test_genrate_features(cultivated_input_bands, input_arrays):
assert np.allclose(res, expected_res, rtol=1e-6, atol=1e-8)


@pytest.fixture(scope="module")
def woody_input_bands():
return [
"nbart_blue",
"nbart_green",
"nbart_red",
"nbart_nir",
"nbart_swir_1",
"nbart_swir_2",
"classes_l3_l4",
]


@pytest.fixture(scope="module")
def woody_results():
res = [
da.array([[20.1, 255.0], [19.5, 20.1]], dtype="float32"),
da.array([[19.5, 20.1], [255, 20.1]], dtype="float32"),
]
return res


@pytest.fixture(scope="module")
def woody_classes():
return {"woody": 113, "herbaceous": 114}


def test_preprocess_predict_intput(
cultivated_input_bands,
cultivated_model_path,
Expand Down Expand Up @@ -458,7 +486,47 @@ def test_cultivated_reduce(
dask_client.register_plugin(cultivated.dask_worker_plugin)
res = cultivated.reduce(input_datasets)
assert res["cultivated_class"].attrs["nodata"] == 255
assert res["cultivated_class"].data.dtype == "uint8"
assert (
res["cultivated_class"].data.compute()
== np.array([[112, 255], [112, 112]], dtype="uint8")
).all()


def test_woody_aggregate_results(
woody_input_bands,
woody_model_path,
mask_bands,
woody_classes,
woody_results,
):

woody_cover = StatsWoodyCover(
woody_classes, woody_model_path, mask_bands, input_bands=woody_input_bands
)
res = woody_cover.aggregate_results_from_group([woody_results[0]])
assert (res.compute() == np.array([[113, 255], [114, 113]], dtype="uint8")).all()
res = woody_cover.aggregate_results_from_group(woody_results)
assert (res.compute() == np.array([[114, 113], [114, 113]], dtype="uint8")).all()


def test_woody_reduce(
woody_input_bands,
woody_model_path,
mask_bands,
woody_classes,
input_datasets,
dask_client,
):
woody_inputs = input_datasets.sel(bands=woody_input_bands[:-1])
woody_cover = StatsWoodyCover(
woody_classes, woody_model_path, mask_bands, input_bands=woody_input_bands
)
dask_client.register_plugin(woody_cover.dask_worker_plugin)
res = woody_cover.reduce(woody_inputs)
assert res["woody_cover"].attrs["nodata"] == 255
assert res["woody_cover"].data.dtype == "uint8"
assert (
res["woody_cover"].data.compute()
== np.array([[114, 255], [114, 114]], dtype="uint8")
).all()

0 comments on commit 3b1caa3

Please sign in to comment.