Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update crop type mapping workflow #85

Merged
merged 9 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:
- openeo=0.29.0
- pyarrow=16.1.0
- python=3.10.0
- pytorch=2.3.0
- pytorch=2.3.1
- rasterio=1.3.10
- rioxarray=0.15.5
- scikit-image=0.22.0
Expand Down
20 changes: 17 additions & 3 deletions scripts/inference/cropland_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="WC - Cropland Inference",
description="Cropland inference using GFMAP, Presto and WorldCereal classifiers",
prog="WC - Crop Mapping Inference",
description="Crop Mapping inference using GFMAP, Presto and WorldCereal classifiers",
)

parser.add_argument("minx", type=float, help="Minimum X coordinate (west)")
Expand All @@ -25,6 +25,11 @@
default=4326,
help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.",
)
parser.add_argument(
"product",
type=str,
help="Product to generate. One of ['cropland', 'croptype']",
)
parser.add_argument(
"start_date", type=str, help="Starting date for data extraction."
)
Expand All @@ -46,6 +51,15 @@
start_date = args.start_date
end_date = args.end_date

product = args.product

# minx, miny, maxx, maxy = (664000, 5611134, 665000, 5612134) # Small test
# minx, miny, maxx, maxy = (664000, 5611134, 684000, 5631134) # Large test
# epsg = 32631
# start_date = "2020-11-01"
# end_date = "2021-10-31"
# product = "croptype"

spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy, epsg)
temporal_extent = TemporalContext(start_date, end_date)

Expand All @@ -56,7 +70,7 @@
temporal_extent,
backend_context,
args.output_path,
product_type=WorldCerealProduct.CROPLAND,
product_type=WorldCerealProduct(product),
out_format="GTiff",
)
logger.success("Job finished:\n\t%s", job_results)
42 changes: 33 additions & 9 deletions scripts/inference/croptype_mapping_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from openeo_gfmap.inference.model_inference import apply_model_inference_local

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroptypeClassifier
from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier

TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc"
TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc"
Expand All @@ -40,29 +40,53 @@
.astype("uint16")
)

print("Running presto UDF locally")
features = apply_feature_extractor_local(
print("Get Presto cropland features")
cropland_features = apply_feature_extractor_local(
PrestoFeatureExtractor,
arr,
parameters={EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True},
)

print("Running cropland classification inference UDF locally")

cropland_classification = apply_model_inference_local(
CroplandClassifier,
cropland_features,
parameters={
EPSG_HARMONIZED_NAME: 32631,
"ignore_dependencies": True,
"presto_model_url": PRESTO_URL,
},
)

features.to_netcdf(Path.cwd() / "presto_test_features_croptype.nc")
print("Get Presto croptype features")
croptype_features = apply_feature_extractor_local(
PrestoFeatureExtractor,
arr,
parameters={
EPSG_HARMONIZED_NAME: 32631,
"ignore_dependencies": True,
"presto_model_url": PRESTO_URL,
},
)

print("Running classification inference UDF locally")
print("Running croptype classification inference UDF locally")

classification = apply_model_inference_local(
croptype_classification = apply_model_inference_local(
CroptypeClassifier,
features,
croptype_features,
parameters={
EPSG_HARMONIZED_NAME: 32631,
"ignore_dependencies": True,
"classifier_url": CATBOOST_URL,
},
)

classification.to_netcdf(Path.cwd() / "test_classification_croptype.nc")
# Apply cropland mask -> on the backend this is done with mask process
croptype_classification = croptype_classification.where(
cropland_classification.sel(bands="classification") == 1, 0
)

croptype_classification.to_netcdf(
Path("/vitodata/worldcereal/validation/internal_validation/")
/ "test_classification_croptype_local.nc"
)
190 changes: 147 additions & 43 deletions src/worldcereal/job.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Executing inference jobs on the OpenEO backend."""

from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional, Union

import openeo
from openeo import DataCube
from openeo_gfmap import BackendContext, BoundingBoxExtent, TemporalContext
from openeo_gfmap.features.feature_extractor import apply_feature_extractor
from openeo_gfmap.inference.model_inference import apply_model_inference
Expand Down Expand Up @@ -100,80 +102,65 @@ def generate_map(

Raises:
ValueError: if the product is not supported
ValueError: if the out_format is not supported
kvantricht marked this conversation as resolved.
Show resolved Hide resolved
ValueError: if a cropland mask is applied on a cropland workflow


"""

if product_type not in PRODUCT_SETTINGS.keys():
raise ValueError(f"Product {product_type.value} not supported.")

if out_format not in ["GTiff", "NetCDF"]:
raise ValueError(f"Format {format} not supported.")

# Connect to openeo
connection = openeo.connect(
"https://openeo.creo.vito.be/openeo/"
).authenticate_oidc()

# Preparing the input cube for the inference
# Preparing the input cube for inference
inputs = worldcereal_preprocessed_inputs_gfmap(
connection=connection,
backend_context=backend_context,
spatial_extent=spatial_extent,
temporal_extent=temporal_extent,
)

# Run feature computer
features = apply_feature_extractor(
feature_extractor_class=PRODUCT_SETTINGS[product_type]["features"]["extractor"],
cube=inputs,
parameters=PRODUCT_SETTINGS[product_type]["features"]["parameters"],
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

if out_format not in ["GTiff", "NetCDF"]:
raise ValueError(f"Format {format} not supported.")

classes = apply_model_inference(
model_inference_class=PRODUCT_SETTINGS[product_type]["classification"][
"classifier"
],
cube=features,
parameters=PRODUCT_SETTINGS[product_type]["classification"]["parameters"],
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
{"dimension": "t", "value": "P1D"},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)
# Explicit filtering again for bbox because of METEO low
# resolution causing issues
inputs = inputs.filter_bbox(dict(spatial_extent))

# Cast to uint8
# Construct the feature extraction and model inference pipeline
if product_type == WorldCerealProduct.CROPLAND:
classes = compress_uint8(classes)
else:
classes = compress_uint16(classes)

classes = _cropland_map(inputs)
elif product_type == WorldCerealProduct.CROPTYPE:
# First compute cropland map
cropland_mask = (
_cropland_map(inputs)
.filter_bands("classification")
.reduce_dimension(
dimension="t", reducer="mean"
) # Temporary fix to make this work as mask
)

classes = _croptype_map(inputs, cropland_mask=cropland_mask)

# Submit the job
job = classes.execute_batch(
outputfile=output_path,
out_format=out_format,
job_options={
"driver-memory": "4g",
"executor-memoryOverhead": "6g",
"executor-memoryOverhead": "4g",
"udf-dependency-archives": [f"{ONNX_DEPS_URL}#onnx_deps"],
},
)

asset = job.get_results().get_assets()[0]

return InferenceResults(
job_id=classes.job_id,
job_id=job.job_id,
product_url=asset.href,
output_path=output_path,
product=product_type,
Expand All @@ -185,7 +172,7 @@ def collect_inputs(
temporal_extent: TemporalContext,
backend_context: BackendContext,
output_path: Union[Path, str],
):
) -> DataCube:
"""Function to retrieve preprocessed inputs that are being
used in the generation of WorldCereal products.

Expand Down Expand Up @@ -218,3 +205,120 @@ def collect_inputs(
out_format="NetCDF",
job_options={"driver-memory": "4g", "executor-memoryOverhead": "4g"},
)


def _cropland_map(inputs: DataCube) -> DataCube:
jdegerickx marked this conversation as resolved.
Show resolved Hide resolved
"""Method to produce cropland map from preprocessed inputs, using
a Presto feature extractor and a CatBoost classifier.

Args:
inputs (DataCube): preprocessed input cube

Returns:
DataCube: binary labels and probability
"""

# Run feature computer
features = apply_feature_extractor(
feature_extractor_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND][
"features"
]["extractor"],
cube=inputs,
parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND]["features"][
"parameters"
],
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

# Run model inference on features
classes = apply_model_inference(
model_inference_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND][
"classification"
]["classifier"],
cube=features,
parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND]["classification"][
"parameters"
],
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
{"dimension": "t", "value": "P1D"},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

# Cast to uint8
classes = compress_uint8(classes)

return classes


def _croptype_map(inputs: DataCube, cropland_mask: DataCube = None) -> DataCube:
"""Method to produce croptype map from preprocessed inputs, using
a Presto feature extractor and a CatBoost classifier.

Args:
inputs (DataCube): preprocessed input cube
cropland_mask (DataCube): optional cropland mask

Returns:
DataCube: croptype labels and probability
"""

# Run feature computer
features = apply_feature_extractor(
feature_extractor_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE][
"features"
]["extractor"],
cube=inputs,
parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["features"][
"parameters"
],
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

# Run model inference on features
classes = apply_model_inference(
model_inference_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE][
"classification"
]["classifier"],
cube=features,
parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["classification"][
"parameters"
],
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
{"dimension": "t", "value": "P1D"},
],
overlap=[
{"dimension": "x", "unit": "px", "value": 0},
{"dimension": "y", "unit": "px", "value": 0},
],
)

# Mask cropland
if cropland_mask is not None:
classes = classes.mask(cropland_mask == 0, replacement=0)

# Cast to uint16
classes = compress_uint16(classes)
GriffinBabe marked this conversation as resolved.
Show resolved Hide resolved

return classes
Loading
Loading