Timestep position debugging #114

Merged
12 commits, merged on Oct 9, 2024
3 changes: 3 additions & 0 deletions presto/inference.py
@@ -214,6 +214,9 @@ def get_presto_features(
return presto_extractor._get_encodings(dl)

elif isinstance(inarr, xr.DataArray):
# Check if we have the expected 12 timesteps
if len(inarr.t) != 12:
raise ValueError(f"Can only run Presto on 12 timesteps, got: {len(inarr.t)}")
return presto_extractor.extract_presto_features(inarr, epsg=epsg)

else:
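For context, a minimal sketch of what the new guard does on the xarray path; the cube shape, band count and monthly coordinate below are illustrative assumptions, and only the length check on the `t` dimension mirrors the added code:

import numpy as np
import pandas as pd
import xarray as xr

# Toy cube with only 10 monthly timesteps (shape and band count are assumed
# purely for illustration).
cube = xr.DataArray(
    np.zeros((10, 2, 2, 17), dtype="float32"),
    dims=("t", "x", "y", "bands"),
    coords={"t": pd.date_range("2021-01-01", periods=10, freq="MS")},
)

# The same check that get_presto_features now applies to DataArray inputs:
try:
    if len(cube.t) != 12:
        raise ValueError(f"Can only run Presto on 12 timesteps, got: {len(cube.t)}")
except ValueError as err:
    print(err)  # Can only run Presto on 12 timesteps, got: 10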
58 changes: 49 additions & 9 deletions presto/utils.py
@@ -13,6 +13,8 @@
import torch
import xarray as xr

from presto.dataops import NUM_TIMESTEPS

from .dataops import (
BANDS,
ERA5_BANDS,
@@ -117,13 +119,20 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
- adding dummy timesteps filled with NODATA values before the start_date or after
the end_date for samples where valid_date is close to the edge of the timeseries;
this closeness is controlled by the globally defined parameter MIN_EDGE_BUFFER
- reinitializing the start_date and timestamp_ind to take into account
- reinitializing the start_date, end_date and timestamp_ind to take into account
newly added timesteps
- checking for missing timesteps in the middle of the timeseries and adding them
with NODATA values
- pivoting the DataFrame to wide format with columns for each band
and timesteps as suffixes
- assigning the correct suffixes to the band names
- computing the final valid_date position in the timeseries, taking into
account the updated start_date
- computing the number of available timesteps in the timeseries, taking
into account the updated start_date and end_date; available_timesteps
holds the absolute number of timesteps for which observations are
available; it must be at least NUM_TIMESTEPS, otherwise the sample is
considered faulty and is removed from the dataset
- post-processing with prep_dataframe function

Returns
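To make the month arithmetic behind valid_position and available_timesteps concrete, here is a small worked sketch; the dates and the NUM_TIMESTEPS value are illustrative assumptions, not values taken from any dataset:

import pandas as pd

NUM_TIMESTEPS = 12  # assumed value, as imported from presto.dataops above

start_date = pd.Timestamp("2020-10-01")
end_date = pd.Timestamp("2021-11-01")
valid_date = pd.Timestamp("2021-03-15")

# Month-index arithmetic used for timestamp_ind / valid_position:
valid_position = (valid_date.year * 12 + valid_date.month) - (
    start_date.year * 12 + start_date.month
)
print(valid_position)  # 5 -> valid_date falls in the sixth monthly timestep

# available_timesteps counts both end months, hence the +1 added later in this diff:
available_timesteps = (
    (end_date.year * 12 + end_date.month)
    - (start_date.year * 12 + start_date.month)
    + 1
)
print(available_timesteps)  # 14 monthly timesteps from 2020-10 through 2021-11
# With 14 >= NUM_TIMESTEPS available steps, this sample passes the minimum-length check.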
@@ -238,6 +247,15 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
if len(samples_after_end_date) > 0 or len(samples_before_start_date) > 0:
logger.warning(
f"""\
Dataset {df["sample_id"].iloc[0].split("_")[0]}: removing {len(samples_after_end_date)}\
samples with valid_date after the end_date\
and {len(samples_before_start_date)} samples with valid_date before the start_date
"""
Dataset {df["ref_id"].iloc[0]}: removing {len(samples_after_end_date)} \
samples with valid_date after the end_date \
and {len(samples_before_start_date)} samples with valid_date before the start_date"""
)
df = df[~df["sample_id"].isin(samples_before_start_date)]
df = df[~df["sample_id"].isin(samples_after_end_date)]

# add timesteps before the start_date where needed
intermediate_dummy_df = pd.DataFrame()
for n_ts_to_add in range(1, MIN_EDGE_BUFFER + 1):
samples_to_add_ts_before_start = latest_obs_position[
(MIN_EDGE_BUFFER - latest_obs_position["valid_position"]) >= -n_ts_to_add
@@ -258,9 +267,11 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
months=n_ts_to_add
) # type: ignore
dummy_df[feature_columns] = NODATAVALUE
df = pd.concat([df, dummy_df])
intermediate_dummy_df = pd.concat([intermediate_dummy_df, dummy_df])
df = pd.concat([df, intermediate_dummy_df])

# add timesteps after the end_date where needed
intermediate_dummy_df = pd.DataFrame()
for n_ts_to_add in range(1, MIN_EDGE_BUFFER + 1):
samples_to_add_ts_after_end = latest_obs_position[
(MIN_EDGE_BUFFER - latest_obs_position["valid_position_diff"]) >= n_ts_to_add
@@ -272,12 +283,17 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
months=n_ts_to_add
) # type: ignore
dummy_df[feature_columns] = NODATAVALUE
df = pd.concat([df, dummy_df])
intermediate_dummy_df = pd.concat([intermediate_dummy_df, dummy_df])
df = pd.concat([df, intermediate_dummy_df])
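# Dummy rows are accumulated in intermediate_dummy_df and appended to df once
# after each loop, so df itself is left untouched while the dummy frames for
# the individual offsets are being built.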

# Now reassign start_date to the minimum timestamp
new_start_date = df.groupby(["sample_id"])["timestamp"].min()
df["start_date"] = df["sample_id"].map(new_start_date)

# Also reassign end_date to the maximum timestamp
new_end_date = df.groupby(["sample_id"])["timestamp"].max()
df["end_date"] = df["sample_id"].map(new_end_date)

# reinitialize timestamp_ind
df["timestamp_ind"] = (df["timestamp"].dt.year * 12 + df["timestamp"].dt.month) - (
df["start_date"].dt.year * 12 + df["start_date"].dt.month
@@ -323,8 +339,32 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
df_pivot["valid_date"].dt.year * 12 + df_pivot["valid_date"].dt.month
) - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month)
df_pivot["available_timesteps"] = (
df_pivot["end_date"].dt.year * 12 + df_pivot["end_date"].dt.month
) - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month)
(df_pivot["end_date"].dt.year * 12 + df_pivot["end_date"].dt.month)
- (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month)
+ 1
)

min_center_point = np.maximum(
NUM_TIMESTEPS // 2,
df_pivot["valid_position"] + MIN_EDGE_BUFFER - NUM_TIMESTEPS // 2,
)
max_center_point = np.minimum(
df_pivot["available_timesteps"] - NUM_TIMESTEPS // 2,
df_pivot["valid_position"] - MIN_EDGE_BUFFER + NUM_TIMESTEPS // 2,
)

faulty_samples = min_center_point > max_center_point
if faulty_samples.sum() > 0:
logger.warning(f"Dropping {faulty_samples.sum()} faulty samples.")
df_pivot = df_pivot[~faulty_samples]

samples_with_too_few_ts = df_pivot["available_timesteps"] < NUM_TIMESTEPS
if samples_with_too_few_ts.sum() > 0:
logger.warning(
f"Dropping {samples_with_too_few_ts.sum()} samples with \
number of available timesteps less than {NUM_TIMESTEPS}."
)
df_pivot = df_pivot[~samples_with_too_few_ts]

df_pivot["year"] = df_pivot["valid_date"].dt.year

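A small sketch of the window-feasibility test introduced above, using assumed NUM_TIMESTEPS and MIN_EDGE_BUFFER values and two hypothetical samples; a sample is dropped as faulty when no centre index allows a NUM_TIMESTEPS-long window that both fits inside the available timesteps and keeps valid_position at least MIN_EDGE_BUFFER steps away from the window edges:

import numpy as np
import pandas as pd

NUM_TIMESTEPS = 12   # assumed, as in presto.dataops
MIN_EDGE_BUFFER = 2  # assumed value for illustration

# Two hypothetical samples: one with room to centre a 12-step window around
# valid_position, one where valid_position sits too close to the series end.
df_pivot = pd.DataFrame(
    {"available_timesteps": [18, 13], "valid_position": [9, 12]}
)

min_center_point = np.maximum(
    NUM_TIMESTEPS // 2,
    df_pivot["valid_position"] + MIN_EDGE_BUFFER - NUM_TIMESTEPS // 2,
)
max_center_point = np.minimum(
    df_pivot["available_timesteps"] - NUM_TIMESTEPS // 2,
    df_pivot["valid_position"] - MIN_EDGE_BUFFER + NUM_TIMESTEPS // 2,
)

faulty_samples = min_center_point > max_center_point
print(faulty_samples.tolist())  # [False, True] -> the second sample is dropped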
10 changes: 7 additions & 3 deletions train_finetuned.py
@@ -10,6 +10,7 @@

import pandas as pd
import requests
import torch
import xarray as xr
from tqdm.auto import tqdm

@@ -155,10 +156,12 @@
]

logger.info("Reading dataset")
files = sorted(glob(f"{parquet_file}/**/*.parquet"))[10:20]
files = sorted(glob(f"{parquet_file}/**/*.parquet"))
df_list = []
for f in tqdm(files):
_data = pd.read_parquet(f, engine="fastparquet")
_ref_id = f.split("/")[-2].split("=")[-1]
_data["ref_id"] = _ref_id
_data_pivot = process_parquet(_data)
_data_pivot.reset_index(inplace=True)
df_list.append(_data_pivot)
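The new ref_id column is recovered from the parquet path, which is assumed here to follow a hive-style ref_id=<value> partition layout; the path below is purely hypothetical:

# Hypothetical partitioned path; only the "ref_id=<value>" directory matters.
f = "data/worldcereal/ref_id=2021_EUR_DEMO/part-000.parquet"
_ref_id = f.split("/")[-2].split("=")[-1]
print(_ref_id)  # 2021_EUR_DEMO

This is the same ref_id that process_parquet now uses in its dataset-level warning message instead of parsing it out of sample_id.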
@@ -199,7 +202,8 @@

experiment_prefix = f"""\
{presto_model_description}_{task_type}_{finetune_classes}_\
{compositing_window}_{test_type}_time-token={time_token}_balance={balance}\
{compositing_window}_{test_type}_time-token={time_token}_balance={balance}_\
augment={augment}\
"""

finetuned_model_path = model_path / f"{experiment_prefix}.pt"
@@ -266,7 +270,7 @@
results_df_combined, finetuned_model, sklearn_models_trained = full_eval.finetuning_results(
model, sklearn_model_modes=model_modes
)
# torch.save(finetuned_model.state_dict(), finetuned_model_path)
torch.save(finetuned_model.state_dict(), finetuned_model_path)

results_df_combined["presto_model_description"] = presto_model_description
results_df_combined["compositing_window"] = compositing_window
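With the torch import added and the torch.save call re-enabled, the finetuned weights are written to finetuned_model_path; a minimal sketch of reloading them later, where build_finetuned_model is a hypothetical stand-in for however the model object is reconstructed in this codebase:

import torch

# Hypothetical reload of the checkpoint saved by the script above.
state_dict = torch.load(finetuned_model_path, map_location="cpu")
model = build_finetuned_model()  # assumed helper, not part of this PR
model.load_state_dict(state_dict)
model.eval()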
4 changes: 3 additions & 1 deletion train_self_supervised.py
@@ -149,10 +149,12 @@


logger.info("Loading data")
files = sorted(glob(f"{parquet_file}/**/*.parquet"))[10:20]
files = sorted(glob(f"{parquet_file}/**/*.parquet"))
df_list = []
for f in tqdm(files):
_data = pd.read_parquet(f, engine="fastparquet")
_ref_id = f.split("/")[-2].split("=")[-1]
_data["ref_id"] = _ref_id
_data_pivot = process_parquet(_data)
_data_pivot.reset_index(inplace=True)
df_list.append(_data_pivot)