diff --git a/scripts/inference/cropland_mapping.py b/scripts/inference/cropland_mapping.py index 03eb7d6c..48f7c1ba 100644 --- a/scripts/inference/cropland_mapping.py +++ b/scripts/inference/cropland_mapping.py @@ -8,6 +8,7 @@ from openeo_gfmap.backend import Backend, BackendContext from openeo_gfmap.features.feature_extractor import apply_feature_extractor from openeo_gfmap.inference.model_inference import apply_model_inference +from openeo_gfmap.preprocessing.scaling import compress_uint8 from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor from worldcereal.openeo.inference import CroplandClassifier @@ -105,9 +106,12 @@ ], ) + # Cast to uint8 + classes = compress_uint8(classes) + classes.execute_batch( outputfile=args.output_path, - out_format="NetCDF", + out_format="GTiff", job_options={ "driver-memory": "4g", "executor-memoryOverhead": "12g", diff --git a/src/worldcereal/openeo/inference.py b/src/worldcereal/openeo/inference.py index e1adf289..cdfd2bf5 100644 --- a/src/worldcereal/openeo/inference.py +++ b/src/worldcereal/openeo/inference.py @@ -28,7 +28,7 @@ def dependencies(self) -> list: return [] # Disable the dependencies from PIP install def output_labels(self) -> list: - return ["classification"] + return ["classification", "probability"] def predict(self, features: np.ndarray) -> np.ndarray: """ @@ -48,9 +48,12 @@ def predict(self, features: np.ndarray) -> np.ndarray: # Extract all prediction values and convert them to binary labels prediction_values = [sublist["True"] for sublist in outputs[1]] binary_labels = np.array(prediction_values) >= threshold - binary_labels = binary_labels.astype(int) + binary_labels = binary_labels.astype("uint8") - return binary_labels + prediction_values = np.array(prediction_values) * 100.0 + prediction_values = np.round(prediction_values).astype("uint8") + + return np.stack([binary_labels, prediction_values], axis=0) def execute(self, inarr: xr.DataArray) -> xr.DataArray: classifier_url = self._parameters.get("classifier_url", self.CATBOOST_PATH) @@ -62,14 +65,18 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: self.onnx_session = self.load_ort_session(classifier_url) # Run catboost classification - self.logger.info(f"Catboost classification with input shape: {inarr.shape}") + self.logger.info("Catboost classification with input shape: %s", inarr.shape) classification = self.predict(inarr.values) - self.logger.info(f"Classification done with shape: {classification.shape}") + self.logger.info("Classification done with shape: %s", inarr.shape) classification = xr.DataArray( - classification.reshape((1, len(x_coords), len(y_coords))), + classification.reshape((2, len(x_coords), len(y_coords))), dims=["bands", "x", "y"], - coords={"bands": ["classification"], "x": x_coords, "y": y_coords}, + coords={ + "bands": ["classification", "probability"], + "x": x_coords, + "y": y_coords, + }, ) return classification