Skip to content

Commit

Permalink
Merge branch 'main' into croptype-inference
Browse files Browse the repository at this point in the history
Conflicts:
	src/worldcereal/job.py
  • Loading branch information
GriffinBabe committed Jun 24, 2024
2 parents 15ef9ff + 3b70eee commit db9025a
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 23 deletions.
10 changes: 6 additions & 4 deletions scripts/inference/cropland_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import argparse
from pathlib import Path

from loguru import logger
from openeo_gfmap import BoundingBoxExtent, TemporalContext
from openeo_gfmap.backend import Backend, BackendContext

from worldcereal.job import generate_map
from worldcereal.job import WorldCerealProduct, generate_map

if __name__ == "__main__":
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -50,11 +51,12 @@

backend_context = BackendContext(Backend.FED)

generate_map(
job_results = generate_map(
spatial_extent,
temporal_extent,
backend_context,
args.output_path,
product="cropland",
format="GTiff",
product_type=WorldCerealProduct.CROPLAND,
out_format="GTiff",
)
logger.success("Job finished:\n\t%s", job_results)
81 changes: 62 additions & 19 deletions src/worldcereal/job.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Executing inference jobs on the OpenEO backend."""
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Union
from typing import Optional, Union

import openeo
from openeo_gfmap import BackendContext, BoundingBoxExtent, TemporalContext
Expand All @@ -8,14 +11,21 @@
from openeo_gfmap.preprocessing.scaling import compress_uint8, compress_uint16

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier
from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier
from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap

ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"

class WorldCerealProduct(Enum):
"""Enum to define the different WorldCereal products."""

CROPLAND = "cropland"
CROPTYPE = "croptype"


ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"

PRODUCT_SETTINGS = {
"cropland": {
WorldCerealProduct.CROPLAND: {
"features": {
"extractor": PrestoFeatureExtractor,
"parameters": {
Expand All @@ -30,7 +40,7 @@
},
},
},
"croptype": {
WorldCerealProduct.CROPTYPE: {
"features": {
"extractor": PrestoFeatureExtractor,
"parameters": {
Expand All @@ -39,7 +49,7 @@
},
},
"classification": {
"classifier": CroplandClassifier, # TODO: update to croptype classifier
"classifier": CroptypeClassifier,
"parameters": {
"classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA
},
Expand All @@ -48,13 +58,35 @@
}


@dataclass
class InferenceResults:
"""Dataclass to store the results of the WorldCereal job.
Attributes
----------
job_id : str
Job ID of the finished OpenEO job.
product_url : str
Public URL to the product accessible of the resulting OpenEO job.
output_path : Optional[Path]
Path to the output file, if it was downloaded locally.
product : WorldCerealProduct
Product that was generated.
"""

job_id: str
product_url: str
output_path: Optional[Path]
product: WorldCerealProduct


def generate_map(
spatial_extent: BoundingBoxExtent,
temporal_extent: TemporalContext,
backend_context: BackendContext,
output_path: Union[Path, str],
product: str = "cropland",
format: str = "GTiff",
output_path: Optional[Union[Path, str]],
product_type: WorldCerealProduct = WorldCerealProduct.CROPLAND,
out_format: str = "GTiff",
):
"""Main function to generate a WorldCereal product.
Expand All @@ -71,8 +103,8 @@ def generate_map(
"""

if product not in PRODUCT_SETTINGS.keys():
raise ValueError(f"Product {product} not supported.")
if product_type not in PRODUCT_SETTINGS.keys():
raise ValueError(f"Product {product_type.value} not supported.")

# Connect to openeo
connection = openeo.connect(
Expand All @@ -89,9 +121,9 @@ def generate_map(

# Run feature computer
features = apply_feature_extractor(
feature_extractor_class=PRODUCT_SETTINGS[product]["features"]["extractor"],
feature_extractor_class=PRODUCT_SETTINGS[product_type]["features"]["extractor"],
cube=inputs,
parameters=PRODUCT_SETTINGS[product]["features"]["parameters"],
parameters=PRODUCT_SETTINGS[product_type]["features"]["parameters"],
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
Expand All @@ -102,13 +134,15 @@ def generate_map(
],
)

if format not in ["GTiff", "NetCDF"]:
if out_format not in ["GTiff", "NetCDF"]:
raise ValueError(f"Format {format} not supported.")

classes = apply_model_inference(
model_inference_class=PRODUCT_SETTINGS[product]["classification"]["classifier"],
model_inference_class=PRODUCT_SETTINGS[product_type]["classification"][
"classifier"
],
cube=features,
parameters=PRODUCT_SETTINGS[product]["classification"]["parameters"],
parameters=PRODUCT_SETTINGS[product_type]["classification"]["parameters"],
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
Expand All @@ -121,21 +155,30 @@ def generate_map(
)

# Cast to uint8
if product == "cropland":
if product_type == WorldCerealProduct.CROPLAND:
classes = compress_uint8(classes)
else:
classes = compress_uint16(classes)

classes.execute_batch(
job = classes.execute_batch(
outputfile=output_path,
out_format=format,
out_format=out_format,
job_options={
"driver-memory": "4g",
"executor-memoryOverhead": "6g",
"udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"],
},
)

asset = job.get_results().get_assets()[0]

return InferenceResults(
job_id=classes.job_id,
product_url=asset.href,
output_path=output_path,
product=product_type,
)


def collect_inputs(
spatial_extent: BoundingBoxExtent,
Expand Down

0 comments on commit db9025a

Please sign in to comment.