Skip to content

Commit

Permalink
Merge pull request #68 from WorldCereal/hotfix-seasons
Browse files Browse the repository at this point in the history
Bugfix seasonality check and max season diff check
  • Loading branch information
kvantricht authored Jun 20, 2024
2 parents 77d14af + 07a172f commit af569d3
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 68 deletions.
74 changes: 8 additions & 66 deletions scripts/inference/cropland_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,10 @@
import argparse
from pathlib import Path

import openeo
from openeo_gfmap import BoundingBoxExtent, TemporalContext
from openeo_gfmap.backend import Backend, BackendContext
from openeo_gfmap.features.feature_extractor import apply_feature_extractor
from openeo_gfmap.inference.model_inference import apply_model_inference
from openeo_gfmap.preprocessing.scaling import compress_uint8

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier
from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap
from worldcereal.job import generate_map

ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"

Expand Down Expand Up @@ -58,63 +52,11 @@

backend_context = BackendContext(Backend.FED)

connection = openeo.connect(
"https://openeo.creo.vito.be/openeo/"
).authenticate_oidc()

# Preparing the input cube for the inference
inputs = worldcereal_preprocessed_inputs_gfmap(
connection=connection,
backend_context=backend_context,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
)

# Test feature computer
presto_parameters = {
"rescale_s1": False, # Will be done in the Presto UDF itself!
}

features = apply_feature_extractor(
feature_extractor_class=PrestoFeatureExtractor,
cube=inputs,
parameters=presto_parameters,
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

catboost_parameters = {}

classes = apply_model_inference(
model_inference_class=CroplandClassifier,
cube=features,
parameters=catboost_parameters,
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
{"dimension": "t", "value": "P1D"},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

# Cast to uint8
classes = compress_uint8(classes)

classes.execute_batch(
outputfile=args.output_path,
out_format="GTiff",
job_options={
"driver-memory": "4g",
"executor-memoryOverhead": "12g",
"udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"],
},
generate_map(
spatial_extent,
temporal_extent,
backend_context,
args.output_path,
product="cropland",
format="GTiff",
)
108 changes: 108 additions & 0 deletions src/worldcereal/job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from pathlib import Path
from typing import Union

import openeo
from openeo_gfmap import BackendContext, BoundingBoxExtent, TemporalContext
from openeo_gfmap.features.feature_extractor import apply_feature_extractor
from openeo_gfmap.inference.model_inference import apply_model_inference
from openeo_gfmap.preprocessing.scaling import compress_uint8

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier
from worldcereal.openeo.preprocessing import worldcereal_preprocessed_inputs_gfmap

# Zip archive with the ONNX dependencies (version 1.16.3 according to the file
# name). It is handed to the openEO batch job through the
# "udf-dependency-archives" job option, under the alias "onnx_deps".
ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"


def generate_map(
    spatial_extent: BoundingBoxExtent,
    temporal_extent: TemporalContext,
    backend_context: BackendContext,
    output_path: Union[Path, str],
    product: str = "cropland",
    format: str = "GTiff",
):
    """Generate a WorldCereal product and download it to ``output_path``.

    The function connects to the openEO backend, builds the preprocessed
    input cube, runs the Presto feature extractor followed by the product
    classifier, casts the result to uint8 and executes everything as a
    single batch job.

    Args:
        spatial_extent (BoundingBoxExtent): spatial extent of the map
        temporal_extent (TemporalContext): temporal range to consider
        backend_context (BackendContext): backend context forwarded to the
            input preprocessing
        output_path (Union[Path, str]): output path to download the product to
        product (str, optional): product describer. Only "cropland" is
            currently supported. Defaults to "cropland".
        format (str, optional): Output format, either "GTiff" or "NetCDF".
            Defaults to "GTiff".

    Raises:
        ValueError: if the product is not supported
        ValueError: if the output format is not supported
    """

    # Validate the purely local arguments up front, so that a typo in
    # `product` or `format` fails immediately instead of only after the
    # OIDC authentication and the construction of the processing graph.
    if product != "cropland":
        raise ValueError(f"Product {product} not supported.")

    if format not in ("GTiff", "NetCDF"):
        raise ValueError(f"Format {format} not supported.")

    # Connect to openeo
    connection = openeo.connect(
        "https://openeo.creo.vito.be/openeo/"
    ).authenticate_oidc()

    # Preparing the input cube for the inference
    inputs = worldcereal_preprocessed_inputs_gfmap(
        connection=connection,
        backend_context=backend_context,
        spatial_extent=spatial_extent,
        temporal_extent=temporal_extent,
    )

    # Run feature computer
    presto_parameters = {
        "rescale_s1": False,  # Will be done in the Presto UDF itself!
    }

    features = apply_feature_extractor(
        feature_extractor_class=PrestoFeatureExtractor,
        cube=inputs,
        parameters=presto_parameters,
        size=[
            {"dimension": "x", "unit": "px", "value": 100},
            {"dimension": "y", "unit": "px", "value": 100},
        ],
        overlap=[
            {"dimension": "x", "unit": "px", "value": 0},
            {"dimension": "y", "unit": "px", "value": 0},
        ],
    )

    # "cropland" is the only product that passes the validation above, so the
    # default cropland model is used without extra parameters.
    model_inference_class = CroplandClassifier
    model_inference_parameters = {}

    classes = apply_model_inference(
        model_inference_class=model_inference_class,
        cube=features,
        parameters=model_inference_parameters,
        size=[
            {"dimension": "x", "unit": "px", "value": 100},
            {"dimension": "y", "unit": "px", "value": 100},
            {"dimension": "t", "value": "P1D"},
        ],
        overlap=[
            {"dimension": "x", "unit": "px", "value": 0},
            {"dimension": "y", "unit": "px", "value": 0},
        ],
    )

    # Cast to uint8
    classes = compress_uint8(classes)

    classes.execute_batch(
        outputfile=output_path,
        out_format=format,
        job_options={
            "driver-memory": "4g",
            "executor-memoryOverhead": "12g",
            "udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"],
        },
    )
48 changes: 47 additions & 1 deletion src/worldcereal/seasons.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class NoSeasonError(Exception):
pass


class SeasonMaxDiffError(Exception):
    """Raised when the seasonality difference within an extent is too large."""


def doy_to_angle(day_of_year, total_days=365):
    """Map a day of year onto the circle, returning the angle in radians.

    A full year of ``total_days`` days corresponds to one full turn (2*pi).
    """
    year_fraction = day_of_year / total_days
    return 2 * math.pi * year_fraction

Expand All @@ -25,6 +29,32 @@ def angle_to_doy(angle, total_days=365):
return (angle / (2 * math.pi)) * total_days


def max_doy_difference(doy_array, total_days=365):
    """Return the largest circular difference in days between any two DOYs.

    The difference between two day-of-year values is measured along the
    shortest way around the year, so e.g. DOY 360 and DOY 5 are 10 days
    apart rather than 355.

    Args:
        doy_array: array-like of day-of-year values (any shape).
        total_days (int, optional): number of days in a year. Defaults to
            365, which is what the crop calendars use.

    Raises:
        ValueError: if `doy_array` is empty.

    Returns:
        Maximum, over all pairs of values, of the circular day difference.
    """

    values = np.asarray(doy_array)
    if values.size == 0:
        raise ValueError("Cannot compute the max DOY difference of an empty array")

    # Unsigned integers would wrap around in the subtractions below and
    # silently produce wrong differences, so move them to a signed type.
    if values.dtype.kind == "u":
        values = values.astype(np.int64)

    # Only the distinct values matter for the maximum pairwise difference.
    # Deduplicating keeps the pairwise matrix small (at most a few hundred
    # entries per side for integer DOYs) instead of n_pixels x n_pixels.
    unique_doys = np.unique(values)

    # Direct difference between every pair of distinct values
    direct_difference = np.abs(
        unique_doys[:, np.newaxis] - unique_doys[np.newaxis, :]
    )

    # Difference when going around the end of the year instead
    wrap_around_difference = total_days - direct_difference

    # The effective difference is the shorter of both ways; report the
    # largest one over all pairs
    return np.minimum(direct_difference, wrap_around_difference).max()


def circular_median_day_of_year(doy_array, total_days=365):
"""This function computes the median doy from a given array
taking into account its circular nature. Still has to be used with caution!
Expand Down Expand Up @@ -211,7 +241,10 @@ def season_doys_to_dates(


def get_processing_dates_for_extent(
extent: BoundingBoxExtent, year: int, season: str = "tc-annual"
extent: BoundingBoxExtent,
year: int,
season: str = "tc-annual",
max_seasonality_difference: int = 60,
):
"""Function to retrieve required temporal range of input products for a
given extent, season and year. Based on the requested season's end date
Expand All @@ -221,9 +254,12 @@ def get_processing_dates_for_extent(
extent (BoundingBoxExtent): extent for which to infer dates
year (int): year in which the end of season needs to be
season (str): season identifier for which to infer dates. Defaults to tc-annual
max_seasonality_difference (int): maximum difference in seasonality for all pixels
in extent before raising an exception. Defaults to 60.
Raises:
ValueError: invalid season specified
SeasonMaxDiffError: raised when seasonality difference is too large
Returns:
(start_date, end_date): tuple of date strings specifying
Expand All @@ -243,6 +279,16 @@ def get_processing_dates_for_extent(
if not np.isfinite(eos_doy).any():
raise NoSeasonError(f"No valid EOS DOY found for season `{season}`")

# Only consider valid seasonality pixels
eos_doy = eos_doy[np.isfinite(eos_doy)]

# Check max seasonality difference
seasonality_difference = max_doy_difference(eos_doy)
if seasonality_difference > max_seasonality_difference:
raise SeasonMaxDiffError(
f"Seasonality difference too large: {seasonality_difference} days"
)

# Compute median DOY
eos_doy_median = circular_median_day_of_year(eos_doy)

Expand Down
2 changes: 1 addition & 1 deletion tests/worldcerealtests/test_seasons.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_doy_to_date_after():
def test_get_processing_dates_for_extent():
# Test to check if we can infer processing dates for default season
# tc-annual
bounds = (167286, 553423, 943774, 997257)
bounds = (574680, 5621800, 575320, 5622440)
epsg = 32631
year = 2021
extent = BoundingBoxExtent(*bounds, epsg)
Expand Down

0 comments on commit af569d3

Please sign in to comment.