diff --git a/presto/inference.py b/presto/inference.py index 7f31f36..122ece9 100644 --- a/presto/inference.py +++ b/presto/inference.py @@ -1,3 +1,4 @@ +import logging from typing import Tuple, Union import numpy as np @@ -20,6 +21,9 @@ from .presto import Presto from .utils import device, prep_dataframe +logger = logging.getLogger(__name__) + + # Index to band groups mapping IDX_TO_BAND_GROUPS = { NORMED_BANDS[idx]: band_group_idx @@ -59,6 +63,8 @@ def __init__(self, model: Presto, batch_size: int = 8192): "temperature-mean": "temperature_2m", } + STATIC_BAND_MAPPING = {"DEM-alt-20m": "elevation", "DEM-slo-20m": "slope"} + @classmethod def _preprocess_band_values(cls, values: np.ndarray, presto_band: str) -> np.ndarray: """ @@ -106,8 +112,25 @@ def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]: ) idx_valid = values != cls._NODATAVALUE values = cls._preprocess_band_values(values, presto_band) - eo_data[:, :, BANDS.index(presto_band)] = values + eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid + else: + logger.warning(f"Band {org_band} not found in input data.") + eo_data[:, :, BANDS.index(presto_band)] = 0 + mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1 + + for org_band, presto_band in cls.STATIC_BAND_MAPPING.items(): + if org_band in inarr.coords["bands"]: + values = np.swapaxes( + inarr.sel(bands=org_band).values.reshape((num_timesteps, -1)), 0, 1 + ) + idx_valid = values != cls._NODATAVALUE + eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid + mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid + else: + logger.warning(f"Band {org_band} not found in input data.") + eo_data[:, :, BANDS.index(presto_band)] = 0 + mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1 return eo_data, mask @@ -369,18 +392,25 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: ] bands100m = ["METEO-precipitation_flux", "METEO-temperature_mean"] + # 
---------------------------------------------------------------------------- # PLACEHOLDER for substituting start_date with one derived from crop calendars # df['start_date'] = seasons.get_season_start(df[['lat','lon']]) + # For now, in the absence of a relevant start_date, we get the time difference with respect + # to end_date so we can take 12 months counted back from end_date df["valid_date_ind"] = ( (((df["timestamp"] - df["end_date"]).dt.days + 365) / 30).round().astype(int) ) - # once the start date is settled, we take 12 months from that as input to Presto + # Now reassign start_date to the start of the subset counted back from end_date + df["start_date"] = df["end_date"] - pd.DateOffset(years=1) + pd.DateOffset(days=1) + df_pivot = df[(df["valid_date_ind"] >= 0) & (df["valid_date_ind"] < 12)].pivot( index=index_columns, columns="valid_date_ind", values=feature_columns ) + # ---------------------------------------------------------------------------- + if df_pivot.empty: raise ValueError("Left with an empty DataFrame!")