Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add croptype inference #75

Merged
merged 7 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions scripts/inference/croptype_mapping_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Perform croptype mapping inference using a local execution of presto.

Make sure you test this script on Python 3.9+, and have the worldcereal
dependencies installed, along with the presto wheel file and its dependencies.
"""

from pathlib import Path

import requests
import xarray as xr
from openeo_gfmap.features.feature_extractor import (
EPSG_HARMONIZED_NAME,
apply_feature_extractor_local,
)
from openeo_gfmap.inference.model_inference import apply_model_inference_local

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroptypeClassifier

TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc"
TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc"
PRESTO_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt"
CATBOOST_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx"

if __name__ == "__main__":
    # Fetch the test inputs once; later runs reuse the local copy.
    if not TEST_FILE_PATH.exists():
        print("Downloading test input data...")
        # Stream into a sibling ".part" file and only move it into place once
        # the transfer has completed. Writing straight to TEST_FILE_PATH would
        # leave a truncated file behind on failure, which the exists() check
        # above would then accept on every later run.
        partial_path = TEST_FILE_PATH.with_name(TEST_FILE_PATH.name + ".part")
        try:
            with requests.get(TEST_FILE_URL, stream=True, timeout=180) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            partial_path.replace(TEST_FILE_PATH)
        finally:
            # Nothing to remove after a successful replace; on failure this
            # cleans up the incomplete download.
            partial_path.unlink(missing_ok=True)

    print("Loading array in-memory...")
    # Turn the dataset variables into a single "bands" dimension, drop the
    # "crs" entry from it and cast the remaining values to uint16.
    arr = (
        xr.open_dataset(TEST_FILE_PATH)
        .to_array(dim="bands")
        .drop_sel(bands="crs")
        .astype("uint16")
    )

    print("Running presto UDF locally")
    features = apply_feature_extractor_local(
        PrestoFeatureExtractor,
        arr,
        parameters={
            EPSG_HARMONIZED_NAME: 32631,
            "ignore_dependencies": True,
            "presto_model_url": PRESTO_URL,
        },
    )

    # Keep the intermediate features on disk next to the final product.
    features.to_netcdf(Path.cwd() / "presto_test_features_croptype.nc")

    print("Running classification inference UDF locally")

    classification = apply_model_inference_local(
        CroptypeClassifier,
        features,
        parameters={
            EPSG_HARMONIZED_NAME: 32631,
            "ignore_dependencies": True,
            "classifier_url": CATBOOST_URL,
        },
    )

    classification.to_netcdf(Path.cwd() / "test_classification_croptype.nc")
110 changes: 89 additions & 21 deletions src/worldcereal/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
from openeo_gfmap.preprocessing.scaling import compress_uint8, compress_uint16

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier
from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier
from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap

ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"


class WorldCerealProduct(Enum):
"""Enum to define the different WorldCereal products."""
Expand All @@ -24,6 +22,42 @@ class WorldCerealProduct(Enum):
CROPTYPE = "croptype"


ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"

# Per-product inference configuration, keyed by WorldCerealProduct member.
# Each entry holds two sections:
#   "features":       the feature extractor class and the parameters passed to it
#   "classification": the classifier class and the parameters passed to it
# generate_map looks both sections up by product type and rejects any product
# that has no entry here.
PRODUCT_SETTINGS = {
    WorldCerealProduct.CROPLAND: {
        "features": {
            "extractor": PrestoFeatureExtractor,
            "parameters": {
                "rescale_s1": False,  # Will be done in the Presto UDF itself!
                "presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt",  # NOQA
            },
        },
        "classification": {
            "classifier": CroplandClassifier,
            "parameters": {
                "classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx"  # NOQA
            },
        },
    },
    WorldCerealProduct.CROPTYPE: {
        "features": {
            "extractor": PrestoFeatureExtractor,
            "parameters": {
                "rescale_s1": False,  # Will be done in the Presto UDF itself!
                "presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt",  # NOQA
            },
        },
        "classification": {
            "classifier": CroptypeClassifier,
            "parameters": {
                "classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx"  # NOQA
            },
        },
    },
}


@dataclass
class InferenceResults:
"""Dataclass to store the results of the WorldCereal job.
Expand Down Expand Up @@ -69,6 +103,9 @@ def generate_map(

"""

if product_type not in PRODUCT_SETTINGS.keys():
raise ValueError(f"Product {product_type.value} not supported.")

# Connect to openeo
connection = openeo.connect(
"https://openeo.creo.vito.be/openeo/"
Expand All @@ -83,14 +120,10 @@ def generate_map(
)

# Run feature computer
presto_parameters = {
"rescale_s1": False, # Will be done in the Presto UDF itself!
}

features = apply_feature_extractor(
feature_extractor_class=PrestoFeatureExtractor,
feature_extractor_class=PRODUCT_SETTINGS[product_type]["features"]["extractor"],
cube=inputs,
parameters=presto_parameters,
parameters=PRODUCT_SETTINGS[product_type]["features"]["parameters"],
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
Expand All @@ -101,20 +134,15 @@ def generate_map(
],
)

if product_type == WorldCerealProduct.CROPLAND:
# initiate default cropland model
model_inference_class = CroplandClassifier
model_inference_parameters = {}
else:
raise ValueError(f"Product {product_type} not supported.")

if out_format not in ["GTiff", "NetCDF"]:
    # Report the rejected value itself; `format` would interpolate the Python
    # builtin function instead of the caller's argument.
    raise ValueError(f"Format {out_format} not supported.")

classes = apply_model_inference(
model_inference_class=model_inference_class,
model_inference_class=PRODUCT_SETTINGS[product_type]["classification"][
"classifier"
],
cube=features,
parameters=model_inference_parameters,
parameters=PRODUCT_SETTINGS[product_type]["classification"]["parameters"],
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
Expand All @@ -137,11 +165,11 @@ def generate_map(
out_format=out_format,
job_options={
"driver-memory": "4g",
"executor-memoryOverhead": "12g",
"executor-memoryOverhead": "6g",
"udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"],
},
)
# Should contain a single job as this is a single-job tile inference.

asset = job.get_results().get_assets()[0]

return InferenceResults(
Expand All @@ -150,3 +178,43 @@ def generate_map(
output_path=output_path,
product=product_type,
)


def collect_inputs(
    spatial_extent: BoundingBoxExtent,
    temporal_extent: TemporalContext,
    backend_context: BackendContext,
    output_path: Union[Path, str],
):
    """Download the preprocessed inputs used to generate WorldCereal products.

    Connects to the openEO backend with OIDC authentication, builds the
    preprocessed input cube for the requested extents and executes it as a
    batch job with NetCDF output, passing ``output_path`` as the output file.

    Args:
        spatial_extent (BoundingBoxExtent): spatial extent of the inputs
        temporal_extent (TemporalContext): temporal range to consider
        backend_context (BackendContext): backend to run the job on
        output_path (Union[Path, str]): path the NetCDF result is written to

    """

    # Open an authenticated session on the openEO backend
    openeo_session = openeo.connect("https://openeo.creo.vito.be/openeo/")
    connection = openeo_session.authenticate_oidc()

    # Same preprocessing as the one feeding the inference workflow
    preprocessed_cube = worldcereal_preprocessed_inputs_gfmap(
        connection=connection,
        backend_context=backend_context,
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
    )

    batch_job_options = {"driver-memory": "4g", "executor-memoryOverhead": "4g"}
    preprocessed_cube.execute_batch(
        outputfile=output_path,
        out_format="NetCDF",
        job_options=batch_job_options,
    )
86 changes: 86 additions & 0 deletions src/worldcereal/openeo/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,89 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray:
)

return classification


class CroptypeClassifier(ModelInference):
    """Multi-class crop type classifier using ONNX to load a catboost model.

    The classifier uses the embeddings computed by the Presto Feature
    Extractor and produces two output bands: the predicted class and the
    probability of that winning class.

    Interesting UDF parameters:
    - classifier_url: A public URL to the ONNX classification model. Defaults
      to ``CATBOOST_PATH``.
    """

    # Keeps `np` resolvable for the method annotations below.
    # NOTE(review): presumably done in the class body so the class stays
    # self-contained when shipped as a UDF -- confirm.
    import numpy as np

    CATBOOST_PATH = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx"  # NOQA

    def __init__(self):
        super().__init__()

        # Set by execute() before predict() is called.
        self.onnx_session = None

    def dependencies(self) -> list:
        return []  # Disable the dependencies from PIP install

    def output_labels(self) -> list:
        return ["classification", "probability"]

    def predict(self, features: np.ndarray) -> np.ndarray:
        """Predict class codes and winning-class probabilities.

        Args:
            features: array passed to the ONNX model as its "features" input.

        Returns:
            The integer class codes and the winning-class probabilities
            (as integer percentages) stacked along a new first axis.

        Raises:
            ValueError: if no ONNX session has been loaded yet.
            KeyError: if the model returns a label that is missing in the LUT.
        """
        import numpy as np

        if self.onnx_session is None:
            raise ValueError("Model has not been loaded. Please load a model first.")

        # Prepare input data for ONNX model
        outputs = self.onnx_session.run(None, {"features": features})

        # Apply LUT: TODO:this needs an update!
        LUT = {
            "barley": 1,
            "maize": 2,
            "millet_sorghum": 3,
            "other_crop": 4,
            "rapeseed_rape": 5,
            "soy_soybeans": 6,
            "sunflower": 7,
            "wheat": 8,
        }

        # Extract classes as INTs and probability of winning class values.
        # outputs[0] holds one label per sample; outputs[1] holds, per sample,
        # a mapping from label to probability.
        n_samples = len(outputs[0])
        labels = np.zeros((n_samples,), dtype=np.uint16)
        probabilities = np.zeros((n_samples,), dtype=np.uint8)
        for i, (label, prob) in enumerate(zip(outputs[0], outputs[1])):
            labels[i] = LUT[label]
            probabilities[i] = int(prob[label] * 100)

        return np.stack([labels, probabilities], axis=0)

    def execute(self, inarr: xr.DataArray) -> xr.DataArray:
        """Classify every pixel of ``inarr`` and return the result as a cube.

        Args:
            inarr: input array with "bands", "x" and "y" dimensions.

        Returns:
            A DataArray with dims ("bands", "x", "y") whose bands are
            "classification" and "probability", on the input x/y coordinates.
        """
        classifier_url = self._parameters.get("classifier_url", self.CATBOOST_PATH)

        # shape and indices for output ("xy", "bands")
        x_coords, y_coords = inarr.x.values, inarr.y.values
        inarr = inarr.transpose("bands", "x", "y").stack(xy=["x", "y"]).transpose()

        self.onnx_session = self.load_ort_session(classifier_url)

        # Run catboost classification
        self.logger.info("Catboost classification with input shape: %s", inarr.shape)
        classification = self.predict(inarr.values)
        # Log the shape of the prediction, not of the input array again.
        self.logger.info("Classification done with shape: %s", classification.shape)

        # "xy" was stacked with x first, so the flat axis unfolds to (x, y).
        classification = xr.DataArray(
            classification.reshape((2, len(x_coords), len(y_coords))),
            dims=["bands", "x", "y"],
            coords={
                "bands": ["classification", "probability"],
                "x": x_coords,
                "y": y_coords,
            },
        )

        return classification
1 change: 1 addition & 0 deletions src/worldcereal/openeo/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def precomposited_datacube_METEO(
temporal_extent=temporal_extent,
bands=["precipitation-flux", "temperature-mean"],
)
cube.result_node().update_arguments(featureflags={"tilesize": 1})
cube = cube.rename_labels(
dimension="bands", target=["AGERA5-PRECIP", "AGERA5-TMEAN"]
)
Expand Down
Loading