From 13f769a21086774d2c739af7da75520312d012dc Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Wed, 21 Aug 2024 17:06:08 +0200 Subject: [PATCH 1/6] Log a warning when a band is missing --- presto/inference.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/presto/inference.py b/presto/inference.py index 7f31f36..7cc89b5 100644 --- a/presto/inference.py +++ b/presto/inference.py @@ -1,3 +1,4 @@ +import logging from typing import Tuple, Union import numpy as np @@ -20,6 +21,9 @@ from .presto import Presto from .utils import device, prep_dataframe +logger = logging.getLogger(__name__) + + # Index to band groups mapping IDX_TO_BAND_GROUPS = { NORMED_BANDS[idx]: band_group_idx @@ -108,6 +112,11 @@ def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]: values = cls._preprocess_band_values(values, presto_band) eo_data[:, :, BANDS.index(presto_band)] = values mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid + else: + logger.warning(f"Band {org_band} not found in input data.") + eo_data[:, :, BANDS.index(presto_band)] = 0 + mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1 + return eo_data, mask From dfd490f649642f95234581d2f8e5b28b95eba35c Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Wed, 21 Aug 2024 17:06:56 +0200 Subject: [PATCH 2/6] Add DEM bands to inference --- presto/inference.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/presto/inference.py b/presto/inference.py index 7cc89b5..a1f1810 100644 --- a/presto/inference.py +++ b/presto/inference.py @@ -63,6 +63,8 @@ def __init__(self, model: Presto, batch_size: int = 8192): "temperature-mean": "temperature_2m", } + STATIC_BAND_MAPPING = {"DEM-alt-20m": "elevation", "DEM-slo-20m": "slope"} + @classmethod def _preprocess_band_values(cls, values: np.ndarray, presto_band: str) -> np.ndarray: """ @@ -117,6 +119,18 @@ def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]: eo_data[:, :, BANDS.index(presto_band)] = 0 mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1 + for org_band, presto_band in cls.STATIC_BAND_MAPPING.items(): + if org_band in inarr.coords["bands"]: + values = np.swapaxes( + inarr.sel(bands=org_band).values.reshape((num_timesteps, -1)), 0, 1 + ) + idx_valid = values != cls._NODATAVALUE + eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid + mask[:, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid + else: + logger.warning(f"Band {org_band} not found in input data.") + eo_data[:, :, BANDS.index(presto_band)] = 0 + mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1 return eo_data, mask From f1573cbec5ee39723390ab55724c864e73ad2ff5 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Wed, 21 Aug 2024 17:07:21 +0200 Subject: [PATCH 3/6] Put missing values to 0 --- presto/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto/inference.py b/presto/inference.py index a1f1810..ae5d3f6 100644 --- a/presto/inference.py +++ b/presto/inference.py @@ -112,7 +112,7 @@ def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]: ) idx_valid = values != cls._NODATAVALUE values = cls._preprocess_band_values(values, presto_band) - eo_data[:, :, BANDS.index(presto_band)] = values + eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid else: logger.warning(f"Band {org_band} not found in input data.") From 750930b0e20582354e70734b8ec99e38b357a136 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Wed, 21 Aug 2024 17:07:52 +0200 Subject: [PATCH 4/6] Fix start_date adjustment to real subset --- presto/inference.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/presto/inference.py b/presto/inference.py index ae5d3f6..d8f9d14 100644 --- a/presto/inference.py +++ b/presto/inference.py @@ -392,18 +392,24 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: ] bands100m = ["METEO-precipitation_flux", "METEO-temperature_mean"] + # ---------------------------------------------------------------------------- # PLACEHOLDER for substituting start_date with one derived from crop calendars # df['start_date'] = seasons.get_season_start(df[['lat','lon']]) + # For now, in absence of a relevant start_date, we get time difference with respect + # to end_date so we can take 12 months counted back from end_date df["valid_date_ind"] = ( (((df["timestamp"] - df["end_date"]).dt.days + 365) / 30).round().astype(int) ) - # once the start date is settled, we take 12 months from that as input to Presto df_pivot = df[(df["valid_date_ind"] >= 0) & (df["valid_date_ind"] < 12)].pivot( index=index_columns, columns="valid_date_ind", values=feature_columns ) + # Now reassign start_date to the actual subset counted back from end_date + df["start_date"] = df["end_date"] - pd.Timedelta(days=364) + # ---------------------------------------------------------------------------- + if df_pivot.empty: raise ValueError("Left with an empty DataFrame!") From 7c0e160c7b10714d1ec7a047b533d2266b4038c6 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Wed, 21 Aug 2024 17:26:30 +0200 Subject: [PATCH 5/6] Adapt start_date before pivoting :facepalm: --- presto/inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/presto/inference.py b/presto/inference.py index d8f9d14..87c3321 100644 --- a/presto/inference.py +++ b/presto/inference.py @@ -402,12 +402,13 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: (((df["timestamp"] - df["end_date"]).dt.days + 365) / 30).round().astype(int) ) + # Now reassign start_date to the actual subset counted back from end_date + df["start_date"] = df["end_date"] - pd.Timedelta(days=364) + df_pivot = df[(df["valid_date_ind"] >= 0) & (df["valid_date_ind"] < 12)].pivot( index=index_columns, columns="valid_date_ind", values=feature_columns ) - # Now reassign start_date to the actual subset counted back from end_date - df["start_date"] = df["end_date"] - pd.Timedelta(days=364) # ---------------------------------------------------------------------------- if df_pivot.empty: From 3f5cae08f1be3d037ee01d18b4a960c822be29f8 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Wed, 21 Aug 2024 17:48:47 +0200 Subject: [PATCH 6/6] More precise start_date determination --- presto/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto/inference.py b/presto/inference.py index 87c3321..122ece9 100644 --- a/presto/inference.py +++ b/presto/inference.py @@ -403,7 +403,7 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: ) # Now reassign start_date to the actual subset counted back from end_date - df["start_date"] = df["end_date"] - pd.Timedelta(days=364) + df["start_date"] = df["end_date"] - pd.DateOffset(years=1) + pd.DateOffset(days=1) df_pivot = df[(df["valid_date_ind"] >= 0) & (df["valid_date_ind"] < 12)].pivot( index=index_columns, columns="valid_date_ind", values=feature_columns