Skip to content

Commit

Permalink
Merge pull request #75 from WorldCereal/croptype-inference
Browse files Browse the repository at this point in the history
Add croptype inference
  • Loading branch information
GriffinBabe authored Jun 24, 2024
2 parents 3b70eee + db9025a commit 05c1cc7
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 21 deletions.
68 changes: 68 additions & 0 deletions scripts/inference/croptype_mapping_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Perform cropland mapping inference using a local execution of presto.
Make sure you test this script on the Python version 3.9+, and have worldcereal
dependencies installed with the presto wheel file installed with it's dependencies.
"""

from pathlib import Path

import requests
import xarray as xr
from openeo_gfmap.features.feature_extractor import (
EPSG_HARMONIZED_NAME,
apply_feature_extractor_local,
)
from openeo_gfmap.inference.model_inference import apply_model_inference_local

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroptypeClassifier

TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc"
TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc"
PRESTO_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt"
CATBOOST_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx"

if __name__ == "__main__":
if not TEST_FILE_PATH.exists():
print("Downloading test input data...")
# Download the test input data
with requests.get(TEST_FILE_URL, stream=True, timeout=180) as response:
response.raise_for_status()
with open(TEST_FILE_PATH, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)

print("Loading array in-memory...")
arr = (
xr.open_dataset(TEST_FILE_PATH)
.to_array(dim="bands")
.drop_sel(bands="crs")
.astype("uint16")
)

print("Running presto UDF locally")
features = apply_feature_extractor_local(
PrestoFeatureExtractor,
arr,
parameters={
EPSG_HARMONIZED_NAME: 32631,
"ignore_dependencies": True,
"presto_model_url": PRESTO_URL,
},
)

features.to_netcdf(Path.cwd() / "presto_test_features_croptype.nc")

print("Running classification inference UDF locally")

classification = apply_model_inference_local(
CroptypeClassifier,
features,
parameters={
EPSG_HARMONIZED_NAME: 32631,
"ignore_dependencies": True,
"classifier_url": CATBOOST_URL,
},
)

classification.to_netcdf(Path.cwd() / "test_classification_croptype.nc")
110 changes: 89 additions & 21 deletions src/worldcereal/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
from openeo_gfmap.preprocessing.scaling import compress_uint8, compress_uint16

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier
from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier
from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap

ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"


class WorldCerealProduct(Enum):
"""Enum to define the different WorldCereal products."""
Expand All @@ -24,6 +22,42 @@ class WorldCerealProduct(Enum):
CROPTYPE = "croptype"


ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"

PRODUCT_SETTINGS = {
WorldCerealProduct.CROPLAND: {
"features": {
"extractor": PrestoFeatureExtractor,
"parameters": {
"rescale_s1": False, # Will be done in the Presto UDF itself!
"presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt", # NOQA
},
},
"classification": {
"classifier": CroplandClassifier,
"parameters": {
"classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA
},
},
},
WorldCerealProduct.CROPTYPE: {
"features": {
"extractor": PrestoFeatureExtractor,
"parameters": {
"rescale_s1": False, # Will be done in the Presto UDF itself!
"presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt", # NOQA
},
},
"classification": {
"classifier": CroptypeClassifier,
"parameters": {
"classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA
},
},
},
}


@dataclass
class InferenceResults:
"""Dataclass to store the results of the WorldCereal job.
Expand Down Expand Up @@ -69,6 +103,9 @@ def generate_map(
"""

if product_type not in PRODUCT_SETTINGS.keys():
raise ValueError(f"Product {product_type.value} not supported.")

# Connect to openeo
connection = openeo.connect(
"https://openeo.creo.vito.be/openeo/"
Expand All @@ -83,14 +120,10 @@ def generate_map(
)

# Run feature computer
presto_parameters = {
"rescale_s1": False, # Will be done in the Presto UDF itself!
}

features = apply_feature_extractor(
feature_extractor_class=PrestoFeatureExtractor,
feature_extractor_class=PRODUCT_SETTINGS[product_type]["features"]["extractor"],
cube=inputs,
parameters=presto_parameters,
parameters=PRODUCT_SETTINGS[product_type]["features"]["parameters"],
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
Expand All @@ -101,20 +134,15 @@ def generate_map(
],
)

if product_type == WorldCerealProduct.CROPLAND:
# initiate default cropland model
model_inference_class = CroplandClassifier
model_inference_parameters = {}
else:
raise ValueError(f"Product {product_type} not supported.")

if out_format not in ["GTiff", "NetCDF"]:
raise ValueError(f"Format {out_format} not supported.")
raise ValueError(f"Format {format} not supported.")

classes = apply_model_inference(
model_inference_class=model_inference_class,
model_inference_class=PRODUCT_SETTINGS[product_type]["classification"][
"classifier"
],
cube=features,
parameters=model_inference_parameters,
parameters=PRODUCT_SETTINGS[product_type]["classification"]["parameters"],
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
Expand All @@ -137,11 +165,11 @@ def generate_map(
out_format=out_format,
job_options={
"driver-memory": "4g",
"executor-memoryOverhead": "12g",
"executor-memoryOverhead": "6g",
"udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"],
},
)
# Should contain a single job as this is a single-jon tile inference.

asset = job.get_results().get_assets()[0]

return InferenceResults(
Expand All @@ -150,3 +178,43 @@ def generate_map(
output_path=output_path,
product=product_type,
)


def collect_inputs(
spatial_extent: BoundingBoxExtent,
temporal_extent: TemporalContext,
backend_context: BackendContext,
output_path: Union[Path, str],
):
"""Function to retrieve preprocessed inputs that are being
used in the generation of WorldCereal products.
Args:
spatial_extent (BoundingBoxExtent): spatial extent of the map
temporal_extent (TemporalContext): temporal range to consider
backend_context (BackendContext): backend to run the job on
output_path (Union[Path, str]): output path to download the product to
Raises:
ValueError: if the product is not supported
"""

# Connect to openeo
connection = openeo.connect(
"https://openeo.creo.vito.be/openeo/"
).authenticate_oidc()

# Preparing the input cube for the inference
inputs = worldcereal_preprocessed_inputs_gfmap(
connection=connection,
backend_context=backend_context,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
)

inputs.execute_batch(
outputfile=output_path,
out_format="NetCDF",
job_options={"driver-memory": "4g", "executor-memoryOverhead": "4g"},
)
86 changes: 86 additions & 0 deletions src/worldcereal/openeo/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,89 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
)

return classification


class CroptypeClassifier(ModelInference):
"""Multi-class crop classifier using ONNX to load a catboost model.
The classifier use the embeddings computed from the Presto Feature
Extractor.
Interesting UDF parameters:
- classifier_url: A public URL to the ONNX classification model. Default is
the public Presto model.
"""

import numpy as np

CATBOOST_PATH = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA

def __init__(self):
super().__init__()

self.onnx_session = None

def dependencies(self) -> list:
return [] # Disable the dependencies from PIP install

def output_labels(self) -> list:
return ["classification", "probability"]

def predict(self, features: np.ndarray) -> np.ndarray:
"""
Predicts labels using the provided features array.
"""
import numpy as np

if self.onnx_session is None:
raise ValueError("Model has not been loaded. Please load a model first.")

# Prepare input data for ONNX model
outputs = self.onnx_session.run(None, {"features": features})

# Apply LUT: TODO:this needs an update!
LUT = {
"barley": 1,
"maize": 2,
"millet_sorghum": 3,
"other_crop": 4,
"rapeseed_rape": 5,
"soy_soybeans": 6,
"sunflower": 7,
"wheat": 8,
}

# Extract classes as INTs and probability of winning class values
labels = np.zeros((len(outputs[0]),), dtype=np.uint16)
probabilities = np.zeros((len(outputs[0]),), dtype=np.uint8)
for i, (label, prob) in enumerate(zip(outputs[0], outputs[1])):
labels[i] = LUT[label]
probabilities[i] = int(prob[label] * 100)

return np.stack([labels, probabilities], axis=0)

def execute(self, inarr: xr.DataArray) -> xr.DataArray:
classifier_url = self._parameters.get("classifier_url", self.CATBOOST_PATH)

# shape and indices for output ("xy", "bands")
x_coords, y_coords = inarr.x.values, inarr.y.values
inarr = inarr.transpose("bands", "x", "y").stack(xy=["x", "y"]).transpose()

self.onnx_session = self.load_ort_session(classifier_url)

# Run catboost classification
self.logger.info("Catboost classification with input shape: %s", inarr.shape)
classification = self.predict(inarr.values)
self.logger.info("Classification done with shape: %s", inarr.shape)

classification = xr.DataArray(
classification.reshape((2, len(x_coords), len(y_coords))),
dims=["bands", "x", "y"],
coords={
"bands": ["classification", "probability"],
"x": x_coords,
"y": y_coords,
},
)

return classification
1 change: 1 addition & 0 deletions src/worldcereal/openeo/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def precomposited_datacube_METEO(
temporal_extent=temporal_extent,
bands=["precipitation-flux", "temperature-mean"],
)
cube.result_node().update_arguments(featureflags={"tilesize": 1})
cube = cube.rename_labels(
dimension="bands", target=["AGERA5-PRECIP", "AGERA5-TMEAN"]
)
Expand Down

0 comments on commit 05c1cc7

Please sign in to comment.