Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Centralize worldcereal generate_map method #67

Merged
merged 2 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 8 additions & 66 deletions scripts/inference/cropland_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,10 @@
import argparse
from pathlib import Path

import openeo
from openeo_gfmap import BoundingBoxExtent, TemporalContext
from openeo_gfmap.backend import Backend, BackendContext
from openeo_gfmap.features.feature_extractor import apply_feature_extractor
from openeo_gfmap.inference.model_inference import apply_model_inference
from openeo_gfmap.preprocessing.scaling import compress_uint8

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier
from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap
from worldcereal.job import generate_map

ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"

Expand Down Expand Up @@ -58,63 +52,11 @@

backend_context = BackendContext(Backend.FED)

connection = openeo.connect(
"https://openeo.creo.vito.be/openeo/"
).authenticate_oidc()

# Preparing the input cube for the inference
inputs = worldcereal_preprocessed_inputs_gfmap(
connection=connection,
backend_context=backend_context,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
)

# Test feature computer
presto_parameters = {
"rescale_s1": False, # Will be done in the Presto UDF itself!
}

features = apply_feature_extractor(
feature_extractor_class=PrestoFeatureExtractor,
cube=inputs,
parameters=presto_parameters,
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

catboost_parameters = {}

classes = apply_model_inference(
model_inference_class=CroplandClassifier,
cube=features,
parameters=catboost_parameters,
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
{"dimension": "t", "value": "P1D"},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

# Cast to uint8
classes = compress_uint8(classes)

classes.execute_batch(
outputfile=args.output_path,
out_format="GTiff",
job_options={
"driver-memory": "4g",
"executor-memoryOverhead": "12g",
"udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"],
},
generate_map(
spatial_extent,
temporal_extent,
backend_context,
args.output_path,
product="cropland",
format="GTiff",
)
108 changes: 108 additions & 0 deletions src/worldcereal/job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from pathlib import Path
from typing import Union

import openeo
from openeo_gfmap import BackendContext, BoundingBoxExtent, TemporalContext
from openeo_gfmap.features.feature_extractor import apply_feature_extractor
from openeo_gfmap.inference.model_inference import apply_model_inference
from openeo_gfmap.preprocessing.scaling import compress_uint8

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier
from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap

ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"


def generate_map(
spatial_extent: BoundingBoxExtent,
temporal_extent: TemporalContext,
backend_context: BackendContext,
output_path: Union[Path, str],
product: str = "cropland",
format: str = "GTiff",
):
"""Main function to generate a WorldCereal product.

Args:
spatial_extent (BoundingBoxExtent): spatial extent of the map
temporal_extent (TemporalContext): temporal range to consider
backend_context (BackendContext): backend to run the job on
output_path (Union[Path, str]): output path to download the product to
product (str, optional): product describer. Defaults to "cropland".
format (str, optional): Output format. Defaults to "GTiff".

Raises:
ValueError: if the product is not supported

"""

# Connect to openeo
connection = openeo.connect(
"https://openeo.creo.vito.be/openeo/"
).authenticate_oidc()

# Preparing the input cube for the inference
inputs = worldcereal_preprocessed_inputs_gfmap(
connection=connection,
backend_context=backend_context,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
)

# Run feature computer
presto_parameters = {
"rescale_s1": False, # Will be done in the Presto UDF itself!
}

features = apply_feature_extractor(
feature_extractor_class=PrestoFeatureExtractor,
cube=inputs,
parameters=presto_parameters,
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

if product == "cropland":
# initiate default cropland model
model_inference_class = CroplandClassifier
model_inference_parameters = {}
else:
raise ValueError(f"Product {product} not supported.")

if format not in ["GTiff", "NetCDF"]:
raise ValueError(f"Format {format} not supported.")

classes = apply_model_inference(
model_inference_class=model_inference_class,
cube=features,
parameters=model_inference_parameters,
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
{"dimension": "t", "value": "P1D"},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

# Cast to uint8
classes = compress_uint8(classes)

classes.execute_batch(
outputfile=output_path,
out_format=format,
job_options={
"driver-memory": "4g",
"executor-memoryOverhead": "12g",
"udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"],
},
)
Loading