Skip to content

Commit

Permalink
Merge pull request #85 from WorldCereal/cropland-masking
Browse files Browse the repository at this point in the history
Update crop type mapping workflow
  • Loading branch information
kvantricht authored Jun 27, 2024
2 parents 012d6b7 + 3437b23 commit 595099c
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 78 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:
- openeo=0.29.0
- pyarrow=16.1.0
- python=3.10.0
- pytorch=2.3.0
- pytorch=2.3.1
- rasterio=1.3.10
- rioxarray=0.15.5
- scikit-image=0.22.0
Expand Down
20 changes: 17 additions & 3 deletions scripts/inference/cropland_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="WC - Cropland Inference",
description="Cropland inference using GFMAP, Presto and WorldCereal classifiers",
prog="WC - Crop Mapping Inference",
description="Crop Mapping inference using GFMAP, Presto and WorldCereal classifiers",
)

parser.add_argument("minx", type=float, help="Minimum X coordinate (west)")
Expand All @@ -25,6 +25,11 @@
default=4326,
help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.",
)
parser.add_argument(
"product",
type=str,
help="Product to generate. One of ['cropland', 'croptype']",
)
parser.add_argument(
"start_date", type=str, help="Starting date for data extraction."
)
Expand All @@ -46,6 +51,15 @@
start_date = args.start_date
end_date = args.end_date

product = args.product

# minx, miny, maxx, maxy = (664000, 5611134, 665000, 5612134) # Small test
# minx, miny, maxx, maxy = (664000, 5611134, 684000, 5631134) # Large test
# epsg = 32631
# start_date = "2020-11-01"
# end_date = "2021-10-31"
# product = "croptype"

spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy, epsg)
temporal_extent = TemporalContext(start_date, end_date)

Expand All @@ -56,7 +70,7 @@
temporal_extent,
backend_context,
args.output_path,
product_type=WorldCerealProduct.CROPLAND,
product_type=WorldCerealProduct(product),
out_format="GTiff",
)
logger.success("Job finished:\n\t%s", job_results)
42 changes: 33 additions & 9 deletions scripts/inference/croptype_mapping_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from openeo_gfmap.inference.model_inference import apply_model_inference_local

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroptypeClassifier
from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier

TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc"
TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc"
Expand All @@ -40,29 +40,53 @@
.astype("uint16")
)

print("Running presto UDF locally")
features = apply_feature_extractor_local(
print("Get Presto cropland features")
cropland_features = apply_feature_extractor_local(
PrestoFeatureExtractor,
arr,
parameters={EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True},
)

print("Running cropland classification inference UDF locally")

cropland_classification = apply_model_inference_local(
CroplandClassifier,
cropland_features,
parameters={
EPSG_HARMONIZED_NAME: 32631,
"ignore_dependencies": True,
"presto_model_url": PRESTO_URL,
},
)

features.to_netcdf(Path.cwd() / "presto_test_features_croptype.nc")
print("Get Presto croptype features")
croptype_features = apply_feature_extractor_local(
PrestoFeatureExtractor,
arr,
parameters={
EPSG_HARMONIZED_NAME: 32631,
"ignore_dependencies": True,
"presto_model_url": PRESTO_URL,
},
)

print("Running classification inference UDF locally")
print("Running croptype classification inference UDF locally")

classification = apply_model_inference_local(
croptype_classification = apply_model_inference_local(
CroptypeClassifier,
features,
croptype_features,
parameters={
EPSG_HARMONIZED_NAME: 32631,
"ignore_dependencies": True,
"classifier_url": CATBOOST_URL,
},
)

classification.to_netcdf(Path.cwd() / "test_classification_croptype.nc")
# Apply cropland mask -> on the backend this is done with mask process
croptype_classification = croptype_classification.where(
cropland_classification.sel(bands="classification") == 1, 0
)

croptype_classification.to_netcdf(
Path("/vitodata/worldcereal/validation/internal_validation/")
/ "test_classification_croptype_local.nc"
)
Loading

0 comments on commit 595099c

Please sign in to comment.