Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update crop type mapping workflow #85

Merged
merged 9 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:
- openeo=0.29.0
- pyarrow=16.1.0
- python=3.10.0
- pytorch=2.3.0
- pytorch=2.3.1
- rasterio=1.3.10
- rioxarray=0.15.5
- scikit-image=0.22.0
Expand Down
20 changes: 17 additions & 3 deletions scripts/inference/cropland_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="WC - Cropland Inference",
description="Cropland inference using GFMAP, Presto and WorldCereal classifiers",
prog="WC - Crop Mapping Inference",
description="Crop Mapping inference using GFMAP, Presto and WorldCereal classifiers",
)

parser.add_argument("minx", type=float, help="Minimum X coordinate (west)")
Expand All @@ -25,6 +25,11 @@
default=4326,
help="EPSG code of the input `minx`, `miny`, `maxx`, `maxy` parameters.",
)
parser.add_argument(
"product",
type=str,
help="Product to generate. One of ['cropland', 'croptype']",
)
parser.add_argument(
"start_date", type=str, help="Starting date for data extraction."
)
Expand All @@ -46,6 +51,15 @@
start_date = args.start_date
end_date = args.end_date

product = args.product

# minx, miny, maxx, maxy = (664000, 5611134, 665000, 5612134) # Small test
# minx, miny, maxx, maxy = (664000, 5611134, 684000, 5631134) # Large test
# epsg = 32631
# start_date = "2020-11-01"
# end_date = "2021-10-31"
# product = "croptype"

spatial_extent = BoundingBoxExtent(minx, miny, maxx, maxy, epsg)
temporal_extent = TemporalContext(start_date, end_date)

Expand All @@ -56,7 +70,7 @@
temporal_extent,
backend_context,
args.output_path,
product_type=WorldCerealProduct.CROPLAND,
product_type=WorldCerealProduct(product),
out_format="GTiff",
)
logger.success("Job finished:\n\t%s", job_results)
42 changes: 33 additions & 9 deletions scripts/inference/croptype_mapping_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from openeo_gfmap.inference.model_inference import apply_model_inference_local

from worldcereal.openeo.feature_extractor import PrestoFeatureExtractor
from worldcereal.openeo.inference import CroptypeClassifier
from worldcereal.openeo.inference import CroplandClassifier, CroptypeClassifier

TEST_FILE_URL = "https://artifactory.vgt.vito.be/artifactory/auxdata-public/worldcereal/presto/localtestdata/local_presto_inputs.nc"
TEST_FILE_PATH = Path.cwd() / "presto_test_inputs.nc"
Expand All @@ -40,29 +40,53 @@
.astype("uint16")
)

print("Running presto UDF locally")
features = apply_feature_extractor_local(
print("Get Presto cropland features")
cropland_features = apply_feature_extractor_local(
PrestoFeatureExtractor,
arr,
parameters={EPSG_HARMONIZED_NAME: 32631, "ignore_dependencies": True},
)

print("Running cropland classification inference UDF locally")

cropland_classification = apply_model_inference_local(
CroplandClassifier,
cropland_features,
parameters={
EPSG_HARMONIZED_NAME: 32631,
"ignore_dependencies": True,
"presto_model_url": PRESTO_URL,
},
)

features.to_netcdf(Path.cwd() / "presto_test_features_croptype.nc")
print("Get Presto croptype features")
croptype_features = apply_feature_extractor_local(
PrestoFeatureExtractor,
arr,
parameters={
EPSG_HARMONIZED_NAME: 32631,
"ignore_dependencies": True,
"presto_model_url": PRESTO_URL,
},
)

print("Running classification inference UDF locally")
print("Running croptype classification inference UDF locally")

classification = apply_model_inference_local(
croptype_classification = apply_model_inference_local(
CroptypeClassifier,
features,
croptype_features,
parameters={
EPSG_HARMONIZED_NAME: 32631,
"ignore_dependencies": True,
"classifier_url": CATBOOST_URL,
},
)

classification.to_netcdf(Path.cwd() / "test_classification_croptype.nc")
# Apply cropland mask -> on the backend this is done with mask process
croptype_classification = croptype_classification.where(
cropland_classification.sel(bands="classification") == 1, 0
)

croptype_classification.to_netcdf(
Path("/vitodata/worldcereal/validation/internal_validation/")
/ "test_classification_croptype_local.nc"
)
Loading
Loading