Skip to content

Commit 1166fa0

Browse files
author
Darius Couchard
committed
Now saves uint16 values with 0-1 for crop/no-crop and 0-100 for probabilities
1 parent ea4317a commit 1166fa0

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

scripts/inference/cropland_mapping.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from openeo_gfmap.backend import Backend, BackendContext
99
from openeo_gfmap.features.feature_extractor import apply_feature_extractor
1010
from openeo_gfmap.inference.model_inference import apply_model_inference
11+
from openeo_gfmap.preprocessing.scaling import compress_uint16
1112

1213
from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
1314
from worldcereal.openeo.inference import CroplandClassifier
@@ -105,6 +106,9 @@
105106
],
106107
)
107108

109+
# Cast to uint16
110+
classes = compress_uint16(classes)
111+
108112
classes.execute_batch(
109113
outputfile=args.output_path,
110114
out_format="GTiff",

src/worldcereal/openeo/inference.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,12 @@ def predict(self, features: np.ndarray) -> np.ndarray:
4848
# Extract all prediction values and convert them to binary labels
4949
prediction_values = [sublist["True"] for sublist in outputs[1]]
5050
binary_labels = np.array(prediction_values) >= threshold
51-
binary_labels = binary_labels.astype(int)
51+
binary_labels = binary_labels.astype("uint8")
5252

53-
return np.stack([binary_labels, prediction_values], axis=0).astype(np.float32)
53+
prediction_values = np.array(prediction_values) * 100.0
54+
prediction_values = np.round(prediction_values).astype("uint8")
55+
56+
return np.stack([binary_labels, prediction_values], axis=0)
5457

5558
def execute(self, inarr: xr.DataArray) -> xr.DataArray:
5659
classifier_url = self._parameters.get("classifier_url", self.CATBOOST_PATH)
@@ -62,9 +65,9 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
6265
self.onnx_session = self.load_ort_session(classifier_url)
6366

6467
# Run catboost classification
65-
self.logger.info(f"Catboost classification with input shape: {inarr.shape}")
68+
self.logger.info("Catboost classification with input shape: %s", inarr.shape)
6669
classification = self.predict(inarr.values)
67-
self.logger.info(f"Classification done with shape: {classification.shape}")
70+
self.logger.info("Classification done with shape: %s", inarr.shape)
6871

6972
classification = xr.DataArray(
7073
classification.reshape((2, len(x_coords), len(y_coords))),

0 commit comments

Comments
 (0)