Skip to content

Commit

Permalink
Merge pull request #65 from WorldCereal/61-probabilities-inference
Browse files Browse the repository at this point in the history
Now writing GeoTiff and having two bands: binary mask and probabilities
  • Loading branch information
GriffinBabe authored Jun 20, 2024
2 parents 8fbc9b6 + c14db91 commit 77d14af
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
6 changes: 5 additions & 1 deletion scripts/inference/cropland_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from openeo_gfmap.backend import Backend, BackendContext
from openeo_gfmap.features.feature_extractor import apply_feature_extractor
from openeo_gfmap.inference.model_inference import apply_model_inference
from openeo_gfmap.preprocessing.scaling import compress_uint8

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier
Expand Down Expand Up @@ -105,9 +106,12 @@
],
)

# Cast to uint8
classes = compress_uint8(classes)

classes.execute_batch(
outputfile=args.output_path,
out_format="NetCDF",
out_format="GTiff",
job_options={
"driver-memory": "4g",
"executor-memoryOverhead": "12g",
Expand Down
21 changes: 14 additions & 7 deletions src/worldcereal/openeo/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def dependencies(self) -> list:
return [] # Disable the dependencies from PIP install

def output_labels(self) -> list:
return ["classification"]
return ["classification", "probability"]

def predict(self, features: np.ndarray) -> np.ndarray:
"""
Expand All @@ -48,9 +48,12 @@ def predict(self, features: np.ndarray) -> np.ndarray:
# Extract all prediction values and convert them to binary labels
prediction_values = [sublist["True"] for sublist in outputs[1]]
binary_labels = np.array(prediction_values) >= threshold
binary_labels = binary_labels.astype(int)
binary_labels = binary_labels.astype("uint8")

return binary_labels
prediction_values = np.array(prediction_values) * 100.0
prediction_values = np.round(prediction_values).astype("uint8")

return np.stack([binary_labels, prediction_values], axis=0)

def execute(self, inarr: xr.DataArray) -> xr.DataArray:
classifier_url = self._parameters.get("classifier_url", self.CATBOOST_PATH)
Expand All @@ -62,14 +65,18 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
self.onnx_session = self.load_ort_session(classifier_url)

# Run catboost classification
self.logger.info(f"Catboost classification with input shape: {inarr.shape}")
self.logger.info("Catboost classification with input shape: %s", inarr.shape)
classification = self.predict(inarr.values)
self.logger.info(f"Classification done with shape: {classification.shape}")
self.logger.info("Classification done with shape: %s", inarr.shape)

classification = xr.DataArray(
classification.reshape((1, len(x_coords), len(y_coords))),
classification.reshape((2, len(x_coords), len(y_coords))),
dims=["bands", "x", "y"],
coords={"bands": ["classification"], "x": x_coords, "y": y_coords},
coords={
"bands": ["classification", "probability"],
"x": x_coords,
"y": y_coords,
},
)

return classification

0 comments on commit 77d14af

Please sign in to comment.