Skip to content

Commit

Permalink
Merge pull request #48 from WorldCereal/44-translate-cropno-crop-inference-udf-to-gfmap
Browse files Browse the repository at this point in the history

44 translate cropno crop inference udf to gfmap
  • Loading branch information
GriffinBabe authored Jun 12, 2024
2 parents 5d9ce61 + 20746f4 commit 77d5efc
Show file tree
Hide file tree
Showing 6 changed files with 371 additions and 244 deletions.
15 changes: 15 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,18 @@ notebooks/S1A_IW_GRDH_1SDV_20191026T153410_20191026T153444_029631_035FDA_2640.SA
scripts/classification/tenpercent_sparse/.nfs00000000c35c9cfd00000035
download.zip
catboost_info/catboost_training.json

*.cbm
*.pt
*.onnx
*.nc
*.7z
*.dmg
*.gz
*.iso
*.jar
*.rar
*.tar
*.zip

.notebook-tests/
93 changes: 70 additions & 23 deletions scripts/inference/cropland_mapping.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
"""Cropland mapping inference script, demonstrating the use of the GFMAP, Presto and WorldCereal classifiers in a first inference pipeline."""

import argparse
from pathlib import Path

import openeo
from openeo_gfmap import BoundingBoxExtent, TemporalContext
from openeo_gfmap.backend import Backend, BackendContext, cdse_connection
from openeo_gfmap.features.feature_extractor import PatchFeatureExtractor
from openeo_gfmap.backend import Backend, BackendContext
from openeo_gfmap.features.feature_extractor import apply_feature_extractor
from openeo_gfmap.inference.model_inference import apply_model_inference

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier
from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap


class PrestoFeatureExtractor(PatchFeatureExtractor):
def __init__(self):
pass

def extract(self, image):
pass

ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"

if __name__ == "__main__":
parser = argparse.ArgumentParser(
Expand All @@ -27,12 +25,20 @@ def extract(self, image):
parser.add_argument("miny", type=float, help="Minimum Y coordinate (south)")
parser.add_argument("maxx", type=float, help="Maximum X coordinate (east)")
parser.add_argument("maxy", type=float, help="Maximum Y coordinate (north)")
parser.add_argument(
"--epsg",
type=int,
default=4326,
help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.",
)
parser.add_argument(
"start_date", type=str, help="Starting date for data extraction."
)
parser.add_argument("end_date", type=str, help="Ending date for data extraction.")
parser.add_argument(
"output_folder", type=str, help="Path to folder where to save results."
"output_path",
type=Path,
help="Path to folder where to save the resulting NetCDF.",
)

args = parser.parse_args()
Expand All @@ -41,29 +47,70 @@ def extract(self, image):
miny = args.miny
maxx = args.maxx
maxy = args.maxy
epsg = args.epsg

start_date = args.start_date
end_date = args.end_date

spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy)
spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy, epsg)
temporal_extent = TemporalContext(start_date, end_date)

backend = BackendContext(Backend.CDSE)
backend_context = BackendContext(Backend.FED)

connection = openeo.connect(
"https://openeo.creo.vito.be/openeo/"
).authenticate_oidc()

# Preparing the input cube for the inference
input_cube = worldcereal_preprocessed_inputs_gfmap(
connection=cdse_connection(),
backend_context=backend,
inputs = worldcereal_preprocessed_inputs_gfmap(
connection=connection,
backend_context=backend_context,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
)

# Start the job and download
job = input_cube.create_job(
title=f"Cropland inference BBOX: {minx} {miny} {maxx} {maxy}",
description="Cropland inference using WorldCereal, Presto and GFMAP classifiers",
out_format="NetCDF",
# Test feature computer
presto_parameters = {
"rescale_s1": False, # Will be done in the Presto UDF itself!
}

features = apply_feature_extractor(
feature_extractor_class=PrestoFeatureExtractor,
cube=inputs,
parameters=presto_parameters,
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

catboost_parameters = {}

classes = apply_model_inference(
model_inference_class=CroplandClassifier,
cube=features,
parameters=catboost_parameters,
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
{"dimension": "t", "value": "P1D"},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

job.start_and_wait()
job.get_results().download_files(args.output_folder)
classes.execute_batch(
outputfile=args.output_path,
out_format="NetCDF",
job_options={
"driver-memory": "4g",
"executor-memoryOverhead": "12g",
"udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"],
},
)
113 changes: 113 additions & 0 deletions src/worldcereal/openeo/feature_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""Feature computer GFMAP compatible to compute Presto embeddings."""

import xarray as xr
from openeo.udf import XarrayDataCube
from openeo_gfmap.features.feature_extractor import PatchFeatureExtractor


class PrestoFeatureExtractor(PatchFeatureExtractor):
    """Feature extractor to use Presto model to compute per-pixel embeddings.

    This will generate a datacube with 128 bands, each band representing a
    feature from the Presto model.

    Interesting UDF parameters:
    - presto_model_url: A public URL to the Presto model file. A default
      Presto version is provided if the parameter is left undefined.
    - presto_wheel_url: A public URL to the Presto wheel file. A default
      wheel is provided if the parameter is left undefined.
    - rescale_s1: Is specifically disabled by default, as the presto
      dependencies already take care of the backscatter decompression. If
      specified, should be set as `False`.
    """

    PRESTO_MODEL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt"  # NOQA
    PRESTO_WHL_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/dependencies/presto_worldcereal-0.1.0-py3-none-any.whl"
    # Backward-compatible alias: the constant was originally published under
    # this misspelled name; keep it so external references don't break.
    PRESO_WHL_URL = PRESTO_WHL_URL
    BASE_URL = "https://s3.waw3-1.cloudferro.com/swift/v1/project_dependencies"  # NOQA
    DEPENDENCY_NAME = "worldcereal_deps.zip"

    # Maps the GFMAP band names of the input cube to the band names Presto
    # expects; unknown bands pass through unchanged (see `execute`).
    GFMAP_BAND_MAPPING = {
        "S2-L2A-B02": "B02",
        "S2-L2A-B03": "B03",
        "S2-L2A-B04": "B04",
        "S2-L2A-B05": "B05",
        "S2-L2A-B06": "B06",
        "S2-L2A-B07": "B07",
        "S2-L2A-B08": "B08",
        "S2-L2A-B8A": "B8A",
        "S2-L2A-B11": "B11",
        "S2-L2A-B12": "B12",
        "S1-SIGMA0-VH": "VH",
        "S1-SIGMA0-VV": "VV",
        "COP-DEM": "DEM",
        "AGERA5-TMEAN": "temperature-mean",
        "AGERA5-PRECIP": "precipitation-flux",
    }

    def unpack_presto_wheel(self, wheel_url: str, destination_dir: str) -> str:
        """Download the Presto wheel from `wheel_url` and extract it into
        `destination_dir`.

        Returns `destination_dir` (fixed annotation: the original declared
        `-> list` but always returned the directory string).
        """
        import urllib.request
        import zipfile
        from pathlib import Path

        # Downloads the wheel file into the current working directory,
        # keeping the file name from the URL.
        modelfile, _ = urllib.request.urlretrieve(
            wheel_url, filename=Path.cwd() / Path(wheel_url).name
        )
        with zipfile.ZipFile(modelfile, "r") as zip_ref:
            zip_ref.extractall(destination_dir)
        return destination_dir

    def output_labels(self) -> list:
        """Returns the output labels from this UDF, which is the output labels
        of the presto embeddings (128 feature bands)."""
        return [f"presto_ft_{i}" for i in range(128)]

    def execute(self, inarr: xr.DataArray) -> xr.DataArray:
        """Compute the Presto embeddings for the input patch.

        Renames bands to Presto conventions, fills NaNs, installs the Presto
        dependencies and wheel on the backend, then delegates to
        `presto.inference.get_presto_features`.

        Raises ValueError if the EPSG code was not initialized, as it is
        required by the Presto inference.
        """
        import sys
        from pathlib import Path

        if self.epsg is None:
            raise ValueError(
                "EPSG code is required for Presto feature extraction, but was "
                "not correctly initialized."
            )
        presto_model_url = self._parameters.get(
            "presto_model_url", self.PRESTO_MODEL_URL
        )
        # BUGFIX: this previously only read the misspelled key
        # "presot_wheel_url", silently ignoring a caller-supplied
        # "presto_wheel_url". Read the correct key first and keep the
        # misspelled one as a backward-compatible fallback.
        presto_wheel_url = self._parameters.get(
            "presto_wheel_url",
            self._parameters.get("presot_wheel_url", self.PRESTO_WHL_URL),
        )

        # The below is required to avoid flipping of the result
        # when running on OpenEO backend!
        inarr = inarr.transpose("bands", "t", "x", "y")

        # Change the band names to the ones Presto expects; bands not in the
        # mapping keep their original name.
        new_band_names = [
            self.GFMAP_BAND_MAPPING.get(b.item(), b.item()) for b in inarr.bands
        ]
        inarr = inarr.assign_coords(bands=new_band_names)

        # Handle NaN values in Presto compatible way
        inarr = inarr.fillna(65535)

        # Unzip the dependencies on the backend
        self.logger.info("Unzipping dependencies")
        deps_dir = self.extract_dependencies(self.BASE_URL, self.DEPENDENCY_NAME)
        self.logger.info("Unpacking presto wheel")
        deps_dir = self.unpack_presto_wheel(presto_wheel_url, deps_dir)

        self.logger.info("Appending dependencies")
        sys.path.append(str(deps_dir))

        # Debug, print the dependency directory
        self.logger.info("Dependency directory: %s", list(Path(deps_dir).iterdir()))

        # Imported here because the presto package only becomes importable
        # after the dependency archive has been appended to sys.path above.
        from presto.inference import (  # pylint: disable=import-outside-toplevel
            get_presto_features,
        )

        self.logger.info("Extracting presto features")
        features = get_presto_features(inarr, presto_model_url, self.epsg)
        return features

    def _execute(self, cube: XarrayDataCube, parameters: dict) -> XarrayDataCube:
        # Disable S1 rescaling (decompression) by default, as the Presto
        # dependencies already perform backscatter decompression themselves.
        if parameters.get("rescale_s1", None) is None:
            parameters.update({"rescale_s1": False})
        return super()._execute(cube, parameters)
Loading

0 comments on commit 77d5efc

Please sign in to comment.