Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

44 translate cropno crop inference udf to gfmap #48

Merged
merged 38 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
de4540d
added mvp repo for minimal inference workflow
HansVRP May 6, 2024
c749633
minimal presto functionality
HansVRP May 6, 2024
1020ee4
hv remove pandas to xarray conversion
HansVRP May 7, 2024
bc9bd1a
Successful run, todo is fix apply metadata for bands
HansVRP May 7, 2024
2ec4ebd
rework UDF and include presto UDF
HansVRP May 8, 2024
f17e0a8
fix: resolve presto specific UDF and include udf_long which does not …
HansVRP May 16, 2024
6ae5da2
fix: test remote inference
HansVRP May 21, 2024
4a3b74b
fix: dynamic size
HansVRP May 21, 2024
f3f4b15
Work in xarray as much as possible
kvantricht May 21, 2024
e0c1d05
Fix typing errors
kvantricht May 21, 2024
433f001
fix: inference
HansVRP May 22, 2024
af151f7
fix: udf_long
HansVRP May 22, 2024
44f9651
Updated UDF (still flips result though!)
kvantricht May 23, 2024
e0ca616
use order="F" for reshaping fixes the flipping issue
kvantricht May 23, 2024
7968ba0
Avoid use of rearrange. Bug remains.
kvantricht May 24, 2024
a579be7
Avoid the use of np.swapaxes
kvantricht May 24, 2024
42218f0
Add a comment for clarification
kvantricht May 24, 2024
919391c
Updated inference notebook
kvantricht May 24, 2024
9f105e6
Merge branch 'kvt_mvp_inferenceUDF' of https://github.com/WorldCereal…
GriffinBabe May 27, 2024
b74ecad
Updated inference notebook
kvantricht May 27, 2024
29b3034
Merge branch 'hv_mvp_inferenceUDF' of github.com:WorldCereal/worldcer…
kvantricht May 27, 2024
3e03ab4
Merge pull request #46 from WorldCereal/kvt_mvp_inferenceUDF
HansVRP May 27, 2024
7915b93
Updating preprocessing to match better kristof's results
GriffinBabe May 28, 2024
005841a
Added feature extractor with GFMAP compatibility
GriffinBabe May 28, 2024
f7d09b9
fix: clean-up + updated dependencies
HansVRP May 29, 2024
63722e5
Added presto feature computer using GFMAP
GriffinBabe May 31, 2024
14ff604
Merge branch 'hv_mvp_inferenceUDF' into 44-translate-cropno-crop-infe…
GriffinBabe May 31, 2024
5ed426b
UDFs are passing and reformatting for repository
GriffinBabe May 31, 2024
b443e8b
Cleaned up more by deleting a few duplicate codes
GriffinBabe May 31, 2024
2add215
Merge branch '44-translate-cropno-crop-inference-udf-to-gfmap' of htt…
GriffinBabe May 31, 2024
3251919
Fixed conflicts
GriffinBabe May 31, 2024
7b7ca4d
Implemented changed request by kristof
GriffinBabe Jun 3, 2024
3faef72
make use of external dependency through whl
kvantricht Jun 3, 2024
aa423c6
Merge branch 'main' into 44-translate-cropno-crop-inference-udf-to-gfmap
kvantricht Jun 7, 2024
8723fae
Changed to work with new openeo way of handling dependencies
GriffinBabe Jun 7, 2024
34e4621
Merge branch '44-translate-cropno-crop-inference-udf-to-gfmap' of htt…
GriffinBabe Jun 7, 2024
df87509
Now working with dependency as zip file and presto code packed as whe…
GriffinBabe Jun 11, 2024
20746f4
Changed dependencies .zip file
GriffinBabe Jun 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,18 @@ notebooks/S1A_IW_GRDH_1SDV_20191026T153410_20191026T153444_029631_035FDA_2640.SA
scripts/classification/tenpercent_sparse/.nfs00000000c35c9cfd00000035
download.zip
catboost_info/catboost_training.json

*.cbm
*.pt
*.onnx
*.nc
*.7z
*.dmg
*.gz
*.iso
*.jar
*.rar
*.tar
*.zip

.notebook-tests/
91 changes: 91 additions & 0 deletions minimal_wc_presto/ONNX_conversion.py
GriffinBabe marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#%% Catboost
import catboost
from catboost.utils import convert_to_onnx_object
import onnx

# Read the trained CatBoost model from disk ...
catboost_model = catboost.CatBoost()
catboost_model.load_model('./model/catboost.cbm')

# ... and write its ONNX representation next to it.
onnx.save(convert_to_onnx_object(catboost_model), './model/wc_catboost.onnx')





#%% For the pytorch model we need to know the input shape

import torch
from presto.presto import Presto
from model_class import PrestoFeatureExtractor
import xarray as xr
import numpy as np

# Load the example data cube. The "crs" variable is dropped so that only the
# remaining variables are stacked along the new "bands" dimension.
# (`Dataset.drop` is deprecated in xarray; `drop_vars` is the replacement.)
ds = xr.open_dataset("./data/belgium_good_2020-12-01_2021-11-30.nc", engine='netcdf4')
arr = ds.drop_vars('crs').to_array(dim='bands')


# Load the Presto model
PRESTO_PATH = './model/presto.pt'
presto_model = Presto.load_pretrained(model_path=PRESTO_PATH, strict=False)
presto_extractor = PrestoFeatureExtractor(presto_model)

# Get the required Presto inputs through the feature extractor.
# The result is indexed below as:
#   0 = eo data (x), 1 = dynamic world, 2 = months, 3 = latlons, 4 = mask
# (named `presto_inputs` rather than `input` to avoid shadowing the builtin).
presto_inputs = presto_extractor.create_presto_input(arr)

# Take the first element of every input and re-add a leading axis of size 1,
# so that each sample tensor represents a batch containing a single element.
# NOTE(review): presumably these shapes match what the DataLoader yields per
# sample at training time - confirm.
x_sample = torch.tensor(np.expand_dims(presto_inputs[0][0], axis=0), dtype=torch.float32)
dw_sample = torch.tensor(np.expand_dims(presto_inputs[1][0], axis=0), dtype=torch.long)
month_sample = torch.tensor(np.expand_dims(presto_inputs[2][0], axis=0), dtype=torch.long)
latlons_sample = torch.tensor(np.expand_dims(presto_inputs[3][0], axis=0), dtype=torch.float32)
mask_sample = torch.tensor(np.expand_dims(presto_inputs[4][0], axis=0), dtype=torch.int)

encoder_model = presto_model.encoder


# Sanity check: run the encoder once on the sample batch without tracking
# gradients. The batch axis was already added above.
with torch.no_grad():
    encoder_output = encoder_model(
        x_sample,
        dynamic_world=dw_sample,
        mask=mask_sample,
        latlons=latlons_sample,
        month=month_sample,
    )

#%%

# Export the encoder model to ONNX - exactly once.
#
# `torch.onnx.export` passes the example inputs to the model positionally, so
# the tuple must follow the order of the encoder's forward() parameters, and
# `input_names` must list the names in that same order. Exporting a second
# time to the same path with `mask` and `month` swapped would overwrite this
# file with a conflicting input mapping.
# NOTE(review): (x, dynamic_world, latlons, mask, month) is assumed to be the
# positional order of the encoder's forward() - confirm against the Presto code.
torch.onnx.export(
    encoder_model,
    (x_sample, dw_sample, latlons_sample, mask_sample, month_sample),
    './model/wc_presto.onnx',
    input_names=["x", "dynamic_world", "latlons", "mask", "month"],
    output_names=["output"],
    # Only the leading (batch) axis is left dynamic; all other axes are fixed
    # to the shapes of the sample tensors.
    dynamic_axes={
        "x": {0: "batch_size"},
        "dynamic_world": {0: "batch_size"},
        "latlons": {0: "batch_size"},
        "mask": {0: "batch_size"},
        "month": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
)
414 changes: 414 additions & 0 deletions minimal_wc_presto/backend_inference_example_openeo.ipynb

Large diffs are not rendered by default.

Empty file.
165 changes: 165 additions & 0 deletions minimal_wc_presto/mvp_wc_presto/dataops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# This file contains many of the constants
# defined in presto/dataops
import warnings
from collections import OrderedDict
from typing import List
from typing import OrderedDict as OrderedDictType

import numpy as np
import torch

"""
For easier normalization of the band values (instead of needing to recompute
the normalization dict with the addition of new data), we provide maximum
values for each band
"""
# --- Sentinel-1 ---------------------------------------------------------------
S1_BANDS = ["VV", "VH"]
# EarthEngine estimates Sentinel-1 values range from -50 to 1
S1_SHIFT_VALUES = [25.0] * len(S1_BANDS)
S1_DIV_VALUES = [25.0] * len(S1_BANDS)

# --- Sentinel-2 ---------------------------------------------------------------
S2_BANDS = [
    "B1", "B2", "B3", "B4", "B5", "B6", "B7",
    "B8", "B8A", "B9", "B10", "B11", "B12",
]
S2_SHIFT_VALUES = [0.0] * len(S2_BANDS)
S2_DIV_VALUES = [1e4] * len(S2_BANDS)

# --- ERA5 ---------------------------------------------------------------------
ERA5_BANDS = ["temperature_2m", "total_precipitation"]
# Temperature: shift towards Celsius and then divide by 35, based on the
# notebook (values range from 37 to -22 degrees Celsius).
# Rainfall: divisor based on
# https://github.com/nasaharvest/lem/blob/main/notebooks/exploratory_data_analysis.ipynb
# NOTE(review): the Kelvin -> Celsius offset is 273.15, not 272.15. The value is
# reproduced unchanged; presumably it has to match the normalization used when
# the pretrained Presto weights were produced - confirm before changing it.
ERA5_SHIFT_VALUES = [-272.15, 0.0]
ERA5_DIV_VALUES = [35.0, 0.03]

# --- SRTM ---------------------------------------------------------------------
SRTM_BANDS = ["elevation", "slope"]
# Divisors: visually gauged 90th percentile from
# https://github.com/nasaharvest/lem/blob/main/notebooks/exploratory_data_analysis.ipynb
SRTM_SHIFT_VALUES = [0.0, 0.0]
SRTM_DIV_VALUES = [2000.0, 50.0]

# --- Dynamic (per-timestep) vs. static bands ------------------------------------
DYNAMIC_BANDS = S1_BANDS + S2_BANDS + ERA5_BANDS
DYNAMIC_BANDS_SHIFT = S1_SHIFT_VALUES + S2_SHIFT_VALUES + ERA5_SHIFT_VALUES
DYNAMIC_BANDS_DIV = S1_DIV_VALUES + S2_DIV_VALUES + ERA5_DIV_VALUES

STATIC_BANDS = SRTM_BANDS
STATIC_BANDS_SHIFT = SRTM_SHIFT_VALUES
STATIC_BANDS_DIV = SRTM_DIV_VALUES

# These bands are what is created by the Engineer. If the engineer changes, the
# bands here will need to change (and vice versa).
REMOVED_BANDS = ["B1", "B10"]
RAW_BANDS = DYNAMIC_BANDS + STATIC_BANDS

# Bands handed to normalization: the kept dynamic bands, the static bands and a
# trailing NDVI slot.
BANDS = (
    [band for band in DYNAMIC_BANDS if band not in REMOVED_BANDS]
    + STATIC_BANDS
    + ["NDVI"]
)

# Per-band shift / divisor, aligned with BANDS.
# NDVI is between 0 and 1, so it gets a shift of 0 and a divisor of 1.
ADD_BY = (
    [
        shift
        for band, shift in zip(DYNAMIC_BANDS, DYNAMIC_BANDS_SHIFT)
        if band not in REMOVED_BANDS
    ]
    + STATIC_BANDS_SHIFT
    + [0.0]
)
DIVIDE_BY = (
    [
        div
        for band, div in zip(DYNAMIC_BANDS, DYNAMIC_BANDS_DIV)
        if band not in REMOVED_BANDS
    ]
    + STATIC_BANDS_DIV
    + [1.0]
)

NUM_TIMESTEPS = 12
NUM_ORG_BANDS = len(BANDS)
TIMESTEPS_IDX = list(range(NUM_TIMESTEPS))

# B9 is dropped during normalization, so the normalized band list excludes it.
NORMED_BANDS = [band for band in BANDS if band != "B9"]
NUM_BANDS = len(NORMED_BANDS)
BANDS_IDX = list(range(NUM_BANDS))

# Channel groups: group name -> positions of its bands inside NORMED_BANDS.
BANDS_GROUPS_IDX: OrderedDictType[str, List[int]] = OrderedDict(
    (group, [NORMED_BANDS.index(band) for band in members])
    for group, members in (
        ("S1", S1_BANDS),
        ("S2_RGB", ["B2", "B3", "B4"]),
        ("S2_Red_Edge", ["B5", "B6", "B7"]),
        ("S2_NIR_10m", ["B8"]),
        ("S2_NIR_20m", ["B8A"]),
        ("S2_SWIR", ["B11", "B12"]),  # Include B10?
        ("ERA5", ERA5_BANDS),
        ("SRTM", SRTM_BANDS),
        ("NDVI", ["NDVI"]),
    )
)

# Number of bands per group, in group order, and the position of the SRTM group.
BAND_EXPANSION = list(map(len, BANDS_GROUPS_IDX.values()))
SRTM_INDEX = list(BANDS_GROUPS_IDX).index("SRTM")


class DynamicWorld2020_2021:
    """Constants and (no-op) normalization for the dynamic world input."""

    # presumably the number of distinct dynamic world class values - confirm
    class_amount = 9

    @classmethod
    def normalize(cls, x: np.ndarray) -> np.ndarray:
        """Return ``x`` unchanged; no normalization is applied to this input."""
        return x


class S1_S2_ERA5_SRTM:
    """Normalization helpers for the stacked S1 / S2 / ERA5 / SRTM / NDVI bands."""

    @staticmethod
    def calculate_ndvi(input_array):
        r"""
        Compute NDVI, (B8 - B4) / (B8 + B4), from an array of bands.

        Args:
            input_array: ``np.ndarray`` or ``torch.Tensor`` of shape
                [timesteps, bands] or [batches, timesteps, bands], whose last
                axis is ordered like ``NORMED_BANDS``.

        Returns:
            The NDVI values only, i.e. the input shape without the band axis
            ([timesteps] or [batches, timesteps]). Positions where
            B8 + B4 is not strictly positive are set to 0.

        Raises:
            ValueError: if ``input_array`` does not have 2 or 3 dimensions.
        """
        num_dims = len(input_array.shape)
        if num_dims not in (2, 3):
            raise ValueError(f"Expected num_dims to be 2 or 3 - got {num_dims}")

        # The band axis is always the last one, for both supported ranks.
        nir = input_array[..., NORMED_BANDS.index("B8")]
        red = input_array[..., NORMED_BANDS.index("B4")]
        band_sum = nir + red

        if isinstance(nir, np.ndarray):
            # The division is evaluated everywhere, including positions where
            # nir + red == 0; those results are discarded by the `where`
            # condition. np.errstate silences the resulting floating point
            # warnings without depending on the exact warning message text.
            with np.errstate(divide="ignore", invalid="ignore"):
                return np.where(band_sum > 0, (nir - red) / band_sum, 0)
        return torch.where(band_sum > 0, (nir - red) / band_sum, 0)

    @classmethod
    def normalize(cls, x):
        """
        Shift and scale the bands, drop B9 and fill in the NDVI band.

        Args:
            x: ``np.ndarray`` or ``torch.Tensor`` of shape [timesteps, bands]
                or [batches, timesteps, bands], whose last axis is ordered
                like ``BANDS`` (including the trailing NDVI slot, whose input
                values are overwritten).

        Returns:
            A new array/tensor whose last axis is ordered like
            ``NORMED_BANDS``: every band is computed as
            ``(x + ADD_BY) / DIVIDE_BY``, B9 is removed, and the NDVI band is
            recomputed from the normalized B8 and B4 bands. NumPy inputs are
            returned as float32.

        Raises:
            ValueError: if ``x`` does not have 2 or 3 dimensions
                (raised by ``calculate_ndvi``).
        """
        # remove the b9 band
        keep_indices = [idx for idx, val in enumerate(BANDS) if val != "B9"]
        if isinstance(x, np.ndarray):
            x = ((x + ADD_BY) / DIVIDE_BY).astype(np.float32)
        else:
            # Build the constants on the same device as x so that tensors
            # living on an accelerator can be normalized as well.
            add_by = torch.tensor(ADD_BY, device=x.device)
            divide_by = torch.tensor(DIVIDE_BY, device=x.device)
            x = (x + add_by) / divide_by

        # Indexing with a list copies the data, so the in-place NDVI
        # assignment below never writes into the caller's array.
        x = x[..., keep_indices]
        x[..., NORMED_BANDS.index("NDVI")] = cls.calculate_ndvi(x)
        return x
Loading
Loading