Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bugs in cultivated and woody cover plugins #149

Merged
merged 7 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ ENV GDAL_DRIVER_PATH=/env/lib/gdalplugins \
GDAL_DATA=/env/share/gdal \
PATH=/env/bin:$PATH

# HACK: work around the numexpr threading issue by overwriting necompiler.py with the copy from a fork.
# TODO: MUST follow up with the package owner and replace this with a proper fix.

RUN wget -q -O /env/lib/python3.10/site-packages/numexpr/necompiler.py https://raw.githubusercontent.com/emmaai/numexpr/master/numexpr/necompiler.py

WORKDIR /tmp

RUN odc-stats --version
16 changes: 15 additions & 1 deletion odc/stats/plugins/lc_ml_treelite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict, Sequence, Optional

import os
import sys
import numpy as np
import numexpr as ne
import xarray as xr
Expand All @@ -21,6 +22,7 @@
from ._registry import StatsPluginInterface
from ._worker import TreeliteModelPlugin
import tl2cgen
import logging


def mask_and_predict(
Expand All @@ -44,6 +46,8 @@ def mask_and_predict(
if block_masked.shape[0] > 0:
dmat = tl2cgen.DMatrix(block_masked)
output_data = predictor.predict(dmat).squeeze(axis=1)
# round to 6 decimal places (roughly float32 resolution)
output_data = np.round(output_data, 6)
if ptype == "categorical":
prediction[mask_flat] = output_data.argmax(axis=-1)[..., np.newaxis]
else:
Expand All @@ -70,6 +74,7 @@ def __init__(
self.dask_worker_plugin = TreeliteModelPlugin(model_path)
self.output_classes = output_classes
self.mask_bands = mask_bands
self._log = logging.getLogger(__name__)

def input_data(
self, datasets: Sequence[Dataset], geobox: GeoBox, **kwargs
Expand Down Expand Up @@ -117,6 +122,7 @@ def input_data(

def preprocess_predict_input(self, xx: xr.Dataset):
images = []
veg_mask = None
for var in xx.data_vars:
image = xx[var].data
if var not in self.mask_bands:
Expand All @@ -140,6 +146,9 @@ def preprocess_predict_input(self, xx: xr.Dataset):
**{"_v": int(self.mask_bands[var])},
)

if veg_mask is None:
raise TypeError("Missing Veg Mask")

images = [
da.concatenate([image, veg_mask[..., np.newaxis]], axis=-1).rechunk(
(None, None, image.shape[-1] + veg_mask.shape[-1])
Expand All @@ -157,7 +166,12 @@ def aggregate_results_from_group(self, predict_output):
pass

def reduce(self, xx: xr.Dataset) -> xr.Dataset:
images = self.preprocess_predict_input(xx)
try:
images = self.preprocess_predict_input(xx)
except TypeError as e:
self._log.warning(e)
sys.exit(0)

res = []

for image in images:
Expand Down
6 changes: 3 additions & 3 deletions odc/stats/plugins/lc_treelite_woody.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ def aggregate_results_from_group(self, predict_output):
)

if m_size > 1:
predict_output = predict_output.sum(axis=0).astype("int")
predict_output = predict_output.sum(axis=0)

predict_output = expr_eval(
"where((a/nodata)>=_l, nodata, a%nodata)",
{"a": predict_output},
name="summary_over_classes",
dtype="uint8",
dtype="float32",
**{
"_l": m_size,
"nodata": NODATA,
Expand All @@ -80,7 +80,7 @@ def aggregate_results_from_group(self, predict_output):
"where((a>0)&(a<nodata), _nw, a)",
{"a": predict_output},
name="output_classes_herbaceous",
dtype="uint8",
dtype="float32",
**{"nodata": NODATA, "_nw": self.output_classes["herbaceous"]},
)

Expand Down
4 changes: 4 additions & 0 deletions tests/test_rf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,10 @@ def test_cultivated_reduce(
== np.array([[112, 255], [112, 112]], dtype="uint8")
).all()

with pytest.raises(SystemExit) as excinfo:
cultivated.reduce(input_datasets.drop("classes_l3_l4"))
assert excinfo.value.code == 0


def test_woody_aggregate_results(
woody_input_bands,
Expand Down
Loading