Skip to content

Commit

Permalink
Merge pull request #94 from WorldCereal/inference-bugfixes
Browse files Browse the repository at this point in the history
Inference-bugfixes
  • Loading branch information
kvantricht authored Aug 21, 2024
2 parents efaa356 + 3f5cae0 commit e8d5bbc
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions presto/inference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Tuple, Union

import numpy as np
Expand All @@ -20,6 +21,9 @@
from .presto import Presto
from .utils import device, prep_dataframe

logger = logging.getLogger(__name__)


# Index to band groups mapping
IDX_TO_BAND_GROUPS = {
NORMED_BANDS[idx]: band_group_idx
Expand Down Expand Up @@ -59,6 +63,8 @@ def __init__(self, model: Presto, batch_size: int = 8192):
"temperature-mean": "temperature_2m",
}

STATIC_BAND_MAPPING = {"DEM-alt-20m": "elevation", "DEM-slo-20m": "slope"}

@classmethod
def _preprocess_band_values(cls, values: np.ndarray, presto_band: str) -> np.ndarray:
"""
Expand Down Expand Up @@ -106,8 +112,25 @@ def _extract_eo_data(cls, inarr: xr.DataArray) -> Tuple[np.ndarray, np.ndarray]:
)
idx_valid = values != cls._NODATAVALUE
values = cls._preprocess_band_values(values, presto_band)
eo_data[:, :, BANDS.index(presto_band)] = values
eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid
mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid
else:
logger.warning(f"Band {org_band} not found in input data.")
eo_data[:, :, BANDS.index(presto_band)] = 0
mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1

# Static bands (elevation, slope): copied into eo_data without the
# band-specific preprocessing step; nodata entries are zeroed and masked.
for org_band, presto_band in cls.STATIC_BAND_MAPPING.items():
    if org_band in inarr.coords["bands"]:
        # Reshape to (num_timesteps, -1) and swap axes so time is the second
        # axis -- presumably giving (num_pixels, num_timesteps); TODO confirm.
        values = np.swapaxes(
            inarr.sel(bands=org_band).values.reshape((num_timesteps, -1)), 0, 1
        )
        idx_valid = values != cls._NODATAVALUE
        # Multiplying by the boolean validity array sets nodata entries to 0.
        eo_data[:, :, BANDS.index(presto_band)] = values * idx_valid
        # The band-group index belongs on the LAST axis of mask, as in every
        # other write to mask in this method. Indexing with only two axes
        # would put it on the second axis and mismatch the shape of idx_valid.
        mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] += ~idx_valid
    else:
        # Missing band: fill with zeros and mark the whole band group as masked.
        logger.warning(f"Band {org_band} not found in input data.")
        eo_data[:, :, BANDS.index(presto_band)] = 0
        mask[:, :, IDX_TO_BAND_GROUPS[presto_band]] = 1

return eo_data, mask

Expand Down Expand Up @@ -369,18 +392,25 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
]
bands100m = ["METEO-precipitation_flux", "METEO-temperature_mean"]

# ----------------------------------------------------------------------------
# PLACEHOLDER for substituting start_date with one derived from crop calendars
# df['start_date'] = seasons.get_season_start(df[['lat','lon']])

# For now, in the absence of a relevant start_date, we compute each timestamp's
# offset from end_date (in ~30-day steps) so we can keep the 12 months leading up to end_date
df["valid_date_ind"] = (
(((df["timestamp"] - df["end_date"]).dt.days + 365) / 30).round().astype(int)
)

# Once the start date is settled, we will take 12 months from that as input to Presto.
# For now, reassign start_date to the first day of the one-year window ending at end_date.
df["start_date"] = df["end_date"] - pd.DateOffset(years=1) + pd.DateOffset(days=1)

df_pivot = df[(df["valid_date_ind"] >= 0) & (df["valid_date_ind"] < 12)].pivot(
index=index_columns, columns="valid_date_ind", values=feature_columns
)

# ----------------------------------------------------------------------------

if df_pivot.empty:
raise ValueError("Left with an empty DataFrame!")

Expand Down

0 comments on commit e8d5bbc

Please sign in to comment.