Skip to content

Commit

Permalink
Added configuration data classes for cropland/croptype mapping workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
GriffinBabe committed Jul 2, 2024
1 parent ab9293e commit 7f25b68
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 73 deletions.
16 changes: 8 additions & 8 deletions scripts/inference/cropland_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,25 @@
parser.add_argument("maxx", type=float, help="Maximum X coordinate (east)")
parser.add_argument("maxy", type=float, help="Maximum Y coordinate (north)")
parser.add_argument(
"--epsg",
type=int,
default=4326,
help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.",
"start_date", type=str, help="Starting date for data extraction."
)
parser.add_argument("end_date", type=str, help="Ending date for data extraction.")
parser.add_argument(
"product",
type=str,
help="Product to generate. One of ['cropland', 'croptype']",
)
parser.add_argument(
"start_date", type=str, help="Starting date for data extraction."
)
parser.add_argument("end_date", type=str, help="Ending date for data extraction.")
parser.add_argument(
"output_path",
type=Path,
help="Path to folder where to save the resulting GeoTiff.",
)
parser.add_argument(
"--epsg",
type=int,
default=4326,
help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.",
)

args = parser.parse_args()

Expand Down
191 changes: 126 additions & 65 deletions src/worldcereal/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from openeo_gfmap.features.feature_extractor import apply_feature_extractor
from openeo_gfmap.inference.model_inference import apply_model_inference
from openeo_gfmap.preprocessing.scaling import compress_uint8, compress_uint16
from pydantic import BaseModel

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier
Expand All @@ -24,44 +25,108 @@ class WorldCerealProduct(Enum):
CROPTYPE = "croptype"


ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"
class FeaturesParameters(BaseModel):
"""Parameters for the feature extraction UDFs. Types are enforced by
Pydantic.
Attributes
----------
rescale_s1 : bool (default=False)
Whether to rescale Sentinel-1 bands before feature extraction. Should be
left to False, as this is done in the Presto UDF itself.
presto_model_url : str
Public URL to the Presto model used for feature extraction. The file
should be a PyTorch serialized model.
"""

rescale_s1: bool
presto_model_url: str


class ClassifierParameters(BaseModel):
"""Parameters for the classifier. Types are enforced by Pydantic.
Attributes
----------
classifier_url : str
Public URL to the classifier model. Te file should be an ONNX accepting
a `features` field for input data and returning two output probability
arrays `true` and `false`.
"""

classifier_url: str


class CropLandParameters(BaseModel):
"""Parameters for the cropland product inference pipeline. Types are
enforced by Pydantic.
Attributes
----------
feature_extractor : PrestoFeatureExtractor
Feature extractor to use for the inference. This class must be a
subclass of GFMAP's `PatchFeatureExtractor` and returns float32
features.
features_parameters : FeaturesParameters
Parameters for the feature extraction UDF. Will be serialized into a
dictionnary and passed in the process graph.
classifier : CroplandClassifier
Classifier to use for the inference. This class must be a subclass of
GFMAP's `ModelInference` and returns predictions/probabilities for
cropland.
classifier_parameters : ClassifierParameters
Parameters for the classifier UDF. Will be serialized into a dictionnary
and passed in the process graph.
"""

feature_extractor: PrestoFeatureExtractor = PrestoFeatureExtractor
features_parameters: FeaturesParameters = FeaturesParameters(
rescale_s1=False,
presto_model_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt", # NOQA
)
classifier: CroplandClassifier = CroplandClassifier
classifier_parameters: ClassifierParameters = ClassifierParameters(
classifier_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA
)

PRODUCT_SETTINGS = {
WorldCerealProduct.CROPLAND: {
"features": {
"extractor": PrestoFeatureExtractor,
"parameters": {
"rescale_s1": False, # Will be done in the Presto UDF itself!
"presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/presto.pt", # NOQA
},
},
"classification": {
"classifier": CroplandClassifier,
"parameters": {
"classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal-minimal-inference/wc_catboost.onnx" # NOQA
},
},
},
WorldCerealProduct.CROPTYPE: {
"features": {
"extractor": PrestoFeatureExtractor,
"parameters": {
"rescale_s1": False, # Will be done in the Presto UDF itself!
"presto_model_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt", # NOQA
},
},
"classification": {
"classifier": CroptypeClassifier,
"parameters": {
"classifier_url": "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA
},
},
},
}

class CropTypeParameters(BaseModel):
"""Parameters for the croptype product inference pipeline. Types are
enforced by Pydantic.
@dataclass
class InferenceResults:
Attributes
----------
feature_extractor : PrestoFeatureExtractor
Feature extractor to use for the inference. This class must be a
subclass of GFMAP's `PatchFeatureExtractor` and returns float32
features.
features_parameters : FeaturesParameters
Parameters for the feature extraction UDF. Will be serialized into a
dictionnary and passed in the process graph.
classifier : CroptypeClassifier
Classifier to use for the inference. This class must be a subclass of
GFMAP's `ModelInference` and returns predictions/probabilities for
cropland classes.
classifier_parameters : ClassifierParameters
Parameters for the classifier UDF. Will be serialized into a dictionnary
and passed in the process graph.
"""

feature_extractor: PrestoFeatureExtractor = PrestoFeatureExtractor
feature_parameters: FeaturesParameters = FeaturesParameters(
rescale_s1=False,
presto_model_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test.pt", # NOQA
)
classifier: CroptypeClassifier = CroptypeClassifier
classifier_parameters: ClassifierParameters = ClassifierParameters(
classifier_url="https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/models/PhaseII/presto-ss-wc-ft-ct-30D_test_CROPTYPE9.onnx" # NOQA
)


ONNX_DEPS_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/openeo/onnx_dependencies_1.16.3.zip"


class InferenceResults(BaseModel):
"""Dataclass to store the results of the WorldCereal job.
Attributes
Expand All @@ -88,6 +153,8 @@ def generate_map(
backend_context: BackendContext,
output_path: Optional[Union[Path, str]],
product_type: WorldCerealProduct = WorldCerealProduct.CROPLAND,
cropland_parameters: CropLandParameters = CropLandParameters(),
croptype_parameters: CropTypeParameters = CropTypeParameters(),
out_format: str = "GTiff",
) -> InferenceResults:
"""Main function to generate a WorldCereal product.
Expand Down Expand Up @@ -120,7 +187,7 @@ def generate_map(
if the out_format is not supported
"""

if product_type not in PRODUCT_SETTINGS.keys():
if product_type not in WorldCerealProduct:
raise ValueError(f"Product {product_type.value} not supported.")

if out_format not in ["GTiff", "NetCDF"]:
Expand All @@ -145,18 +212,22 @@ def generate_map(

# Construct the feature extraction and model inference pipeline
if product_type == WorldCerealProduct.CROPLAND:
classes = _cropland_map(inputs)
classes = _cropland_map(inputs, cropland_parameters=cropland_parameters)
elif product_type == WorldCerealProduct.CROPTYPE:
# First compute cropland map
cropland_mask = (
_cropland_map(inputs)
_cropland_map(inputs, cropland_parameters=cropland_parameters)
.filter_bands("classification")
.reduce_dimension(
dimension="t", reducer="mean"
) # Temporary fix to make this work as mask
)

classes = _croptype_map(inputs, cropland_mask=cropland_mask)
classes = _croptype_map(
inputs,
croptype_parameters=croptype_parameters,
cropland_mask=cropland_mask,
)

# Submit the job
job = classes.execute_batch(
Expand Down Expand Up @@ -221,7 +292,9 @@ def collect_inputs(
)


def _cropland_map(inputs: DataCube) -> DataCube:
def _cropland_map(
inputs: DataCube, cropland_parameters: CropLandParameters
) -> DataCube:
"""Method to produce cropland map from preprocessed inputs, using
a Presto feature extractor and a CatBoost classifier.
Expand All @@ -238,13 +311,9 @@ def _cropland_map(inputs: DataCube) -> DataCube:

# Run feature computer
features = apply_feature_extractor(
feature_extractor_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND][
"features"
]["extractor"],
feature_extractor_class=cropland_parameters.feature_extractor,
cube=inputs,
parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND]["features"][
"parameters"
],
parameters=cropland_parameters.features_parameters.dict(),
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
Expand All @@ -257,13 +326,9 @@ def _cropland_map(inputs: DataCube) -> DataCube:

# Run model inference on features
classes = apply_model_inference(
model_inference_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND][
"classification"
]["classifier"],
model_inference_class=cropland_parameters.classifier,
cube=features,
parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPLAND]["classification"][
"parameters"
],
parameters=cropland_parameters.classifier_parameters.dict(),
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
Expand All @@ -281,7 +346,11 @@ def _cropland_map(inputs: DataCube) -> DataCube:
return classes


def _croptype_map(inputs: DataCube, cropland_mask: DataCube = None) -> DataCube:
def _croptype_map(
inputs: DataCube,
croptype_parameters: CropTypeParameters,
cropland_mask: DataCube = None,
) -> DataCube:
"""Method to produce croptype map from preprocessed inputs, using
a Presto feature extractor and a CatBoost classifier.
Expand All @@ -300,13 +369,9 @@ def _croptype_map(inputs: DataCube, cropland_mask: DataCube = None) -> DataCube:

# Run feature computer
features = apply_feature_extractor(
feature_extractor_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE][
"features"
]["extractor"],
feature_extractor_class=croptype_parameters.feature_extractor,
cube=inputs,
parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["features"][
"parameters"
],
parameters=croptype_parameters.feature_parameters.dict(),
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
Expand All @@ -319,13 +384,9 @@ def _croptype_map(inputs: DataCube, cropland_mask: DataCube = None) -> DataCube:

# Run model inference on features
classes = apply_model_inference(
model_inference_class=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE][
"classification"
]["classifier"],
model_inference_class=croptype_parameters.classifier,
cube=features,
parameters=PRODUCT_SETTINGS[WorldCerealProduct.CROPTYPE]["classification"][
"parameters"
],
parameters=croptype_parameters.classifier_parameters.dict(),
size=[
{"dimension": "x", "unit": "px", "value": 100},
{"dimension": "y", "unit": "px", "value": 100},
Expand Down

0 comments on commit 7f25b68

Please sign in to comment.