diff --git a/presto/inference.py b/presto/inference.py
index b51c1f1..fae25df 100644
--- a/presto/inference.py
+++ b/presto/inference.py
@@ -214,6 +214,9 @@ def get_presto_features(
         return presto_extractor._get_encodings(dl)
 
     elif isinstance(inarr, xr.DataArray):
+        # Check if we have the expected 12 timesteps
+        if len(inarr.t) != 12:
+            raise ValueError(f"Can only run Presto on 12 timesteps, got: {len(inarr.t)}")
         return presto_extractor.extract_presto_features(inarr, epsg=epsg)
 
     else:
diff --git a/presto/utils.py b/presto/utils.py
index fdb9b6a..d03c274 100644
--- a/presto/utils.py
+++ b/presto/utils.py
@@ -13,6 +13,8 @@
 import torch
 import xarray as xr
 
+from presto.dataops import NUM_TIMESTEPS
+
 from .dataops import (
     BANDS,
     ERA5_BANDS,
@@ -117,13 +119,20 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
     - adding dummy timesteps filled with NODATA values before the start_date
       or after the end_date for samples where valid_date is close to the edge
       of the timeseries; this closeness is defined by the globally defined
      parameter MIN_EDGE_BUFFER
-    - reinitializing the start_date and timestamp_ind to take into account
+    - reinitializing the start_date, end_date and timestamp_ind to take into account
      newly added timesteps
    - checking for missing timesteps in the middle of the timeseries and
      adding them with NODATA values
    - pivoting the DataFrame to wide format with columns for each band and
      timesteps as suffixes
    - assigning the correct suffixes to the band names
+   - computing the final valid_date position in the timeseries that takes
+     into account the updated start_date
+   - computing the number of available timesteps in the timeseries that
+     takes into account the updated start_date and end_date; available_timesteps
+     holds the absolute number of timesteps for which observations are
+     available; it cannot be less than NUM_TIMESTEPS; if it is, the
+     sample is considered faulty and is removed from the dataset
    - post-processing with prep_dataframe function
 
     Returns
@@ -238,15 +247,15 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
     if len(samples_after_end_date) > 0 or len(samples_before_start_date) > 0:
         logger.warning(
             f"""\
-            Dataset {df["sample_id"].iloc[0].split("_")[0]}: removing {len(samples_after_end_date)}\
-            samples with valid_date after the end_date\
-            and {len(samples_before_start_date)} samples with valid_date before the start_date
-            """
+Dataset {df["ref_id"].iloc[0]}: removing {len(samples_after_end_date)} \
+samples with valid_date after the end_date \
+and {len(samples_before_start_date)} samples with valid_date before the start_date"""
         )
         df = df[~df["sample_id"].isin(samples_before_start_date)]
         df = df[~df["sample_id"].isin(samples_after_end_date)]
 
     # add timesteps before the start_date where needed
+    intermediate_dummy_df = pd.DataFrame()
     for n_ts_to_add in range(1, MIN_EDGE_BUFFER + 1):
         samples_to_add_ts_before_start = latest_obs_position[
             (MIN_EDGE_BUFFER - latest_obs_position["valid_position"]) >= -n_ts_to_add
@@ -258,9 +267,11 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
             months=n_ts_to_add
         )  # type: ignore
         dummy_df[feature_columns] = NODATAVALUE
-        df = pd.concat([df, dummy_df])
+        intermediate_dummy_df = pd.concat([intermediate_dummy_df, dummy_df])
+    df = pd.concat([df, intermediate_dummy_df])
 
     # add timesteps after the end_date where needed
+    intermediate_dummy_df = pd.DataFrame()
     for n_ts_to_add in range(1, MIN_EDGE_BUFFER + 1):
         samples_to_add_ts_after_end = latest_obs_position[
             (MIN_EDGE_BUFFER - latest_obs_position["valid_position_diff"]) >= n_ts_to_add
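The two padding loops above collect dummy rows in `intermediate_dummy_df` and concatenate into `df` once per direction, so later loop iterations never re-scan rows that were just added and the frame is reallocated once instead of `MIN_EDGE_BUFFER` times. A minimal sketch of the pattern on a toy frame; `NODATAVALUE` and `MIN_EDGE_BUFFER` are assumed stand-ins for the module-level constants, and unlike the real code (which selects edge samples via `latest_obs_position`) this pads every sample:

```python
import pandas as pd

NODATAVALUE = 65535      # assumed sentinel value
MIN_EDGE_BUFFER = 2      # assumed buffer size

def pad_before_start(df: pd.DataFrame, feature_columns: list) -> pd.DataFrame:
    # Accumulate dummy timesteps here instead of growing df inside the loop.
    intermediate_dummy_df = pd.DataFrame()
    for n_ts_to_add in range(1, MIN_EDGE_BUFFER + 1):
        # One NODATA row per sample, n_ts_to_add months before its start_date.
        dummy_df = df.groupby("sample_id", as_index=False).first()
        dummy_df["timestamp"] = dummy_df["start_date"] - pd.DateOffset(months=n_ts_to_add)
        dummy_df[feature_columns] = NODATAVALUE
        intermediate_dummy_df = pd.concat([intermediate_dummy_df, dummy_df])
    # Single concat at the end, mirroring the hunks above.
    return pd.concat([df, intermediate_dummy_df], ignore_index=True)

ts = pd.date_range("2021-01-01", periods=3, freq="MS")
toy = pd.DataFrame(
    {"sample_id": "a", "start_date": ts[0], "timestamp": ts, "B02": [0.1, 0.2, 0.3]}
)
print(pad_before_start(toy, ["B02"]).sort_values("timestamp"))
```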
@@ -272,12 +283,17 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
             months=n_ts_to_add
         )  # type: ignore
         dummy_df[feature_columns] = NODATAVALUE
-        df = pd.concat([df, dummy_df])
+        intermediate_dummy_df = pd.concat([intermediate_dummy_df, dummy_df])
+    df = pd.concat([df, intermediate_dummy_df])
 
     # Now reassign start_date to the minimum timestamp
     new_start_date = df.groupby(["sample_id"])["timestamp"].min()
     df["start_date"] = df["sample_id"].map(new_start_date)
 
+    # Also reassign end_date to the maximum timestamp
+    new_end_date = df.groupby(["sample_id"])["timestamp"].max()
+    df["end_date"] = df["sample_id"].map(new_end_date)
+
     # reinitialize timestep_ind
     df["timestamp_ind"] = (df["timestamp"].dt.year * 12 + df["timestamp"].dt.month) - (
         df["start_date"].dt.year * 12 + df["start_date"].dt.month
@@ -323,8 +339,32 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
         df_pivot["valid_date"].dt.year * 12 + df_pivot["valid_date"].dt.month
     ) - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month)
     df_pivot["available_timesteps"] = (
-        df_pivot["end_date"].dt.year * 12 + df_pivot["end_date"].dt.month
-    ) - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month)
+        (df_pivot["end_date"].dt.year * 12 + df_pivot["end_date"].dt.month)
+        - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month)
+        + 1
+    )
+
+    min_center_point = np.maximum(
+        NUM_TIMESTEPS // 2,
+        df_pivot["valid_position"] + MIN_EDGE_BUFFER - NUM_TIMESTEPS // 2,
+    )
+    max_center_point = np.minimum(
+        df_pivot["available_timesteps"] - NUM_TIMESTEPS // 2,
+        df_pivot["valid_position"] - MIN_EDGE_BUFFER + NUM_TIMESTEPS // 2,
+    )
+
+    faulty_samples = min_center_point > max_center_point
+    if faulty_samples.sum() > 0:
+        logger.warning(f"Dropping {faulty_samples.sum()} faulty samples.")
+    df_pivot = df_pivot[~faulty_samples]
+
+    samples_with_too_few_ts = df_pivot["available_timesteps"] < NUM_TIMESTEPS
+    if samples_with_too_few_ts.sum() > 0:
+        logger.warning(
+            f"Dropping {samples_with_too_few_ts.sum()} samples with \
+number of available timesteps less than {NUM_TIMESTEPS}."
+        )
+    df_pivot = df_pivot[~samples_with_too_few_ts]
 
     df_pivot["year"] = df_pivot["valid_date"].dt.year
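The faulty-sample check above encodes a window-feasibility condition: a training window of `NUM_TIMESTEPS` steps must fit entirely inside the available timesteps while keeping `valid_position` at least `MIN_EDGE_BUFFER` steps away from both window edges. `min_center_point` and `max_center_point` bound the feasible window centers, and an empty range means no valid window exists. A standalone sketch of the same arithmetic, with the two constants hard-coded to assumed values (12 matches the timestep check in presto/inference.py, 2 is illustrative):

```python
import numpy as np

NUM_TIMESTEPS = 12     # assumed, matching the 12-timestep check in inference.py
MIN_EDGE_BUFFER = 2    # assumed buffer size

def is_faulty(available_timesteps: int, valid_position: int) -> bool:
    """True if no NUM_TIMESTEPS-long window keeps valid_position at least
    MIN_EDGE_BUFFER steps away from both window edges."""
    min_center = np.maximum(
        NUM_TIMESTEPS // 2,
        valid_position + MIN_EDGE_BUFFER - NUM_TIMESTEPS // 2,
    )
    max_center = np.minimum(
        available_timesteps - NUM_TIMESTEPS // 2,
        valid_position - MIN_EDGE_BUFFER + NUM_TIMESTEPS // 2,
    )
    return min_center > max_center

print(is_faulty(available_timesteps=24, valid_position=12))  # False: centers 8..16 work
print(is_faulty(available_timesteps=24, valid_position=1))   # True: valid_date too close to the start
```

For the first case, `min_center = max(6, 8) = 8` and `max_center = min(18, 16) = 16`, so any center in 8..16 is valid; for the second, `min_center = 6 > max_center = 5`, so the sample is dropped.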
diff --git a/train_finetuned.py b/train_finetuned.py
index efac6cb..b4c3c92 100644
--- a/train_finetuned.py
+++ b/train_finetuned.py
@@ -10,6 +10,7 @@
 import pandas as pd
 import requests
+import torch
 import xarray as xr
 from tqdm.auto import tqdm
@@ -155,10 +156,12 @@
 ]
 
 logger.info("Reading dataset")
-files = sorted(glob(f"{parquet_file}/**/*.parquet"))[10:20]
+files = sorted(glob(f"{parquet_file}/**/*.parquet"))
 df_list = []
 for f in tqdm(files):
     _data = pd.read_parquet(f, engine="fastparquet")
+    _ref_id = f.split("/")[-2].split("=")[-1]
+    _data["ref_id"] = _ref_id
     _data_pivot = process_parquet(_data)
     _data_pivot.reset_index(inplace=True)
     df_list.append(_data_pivot)
@@ -199,7 +202,8 @@
 experiment_prefix = f"""\
 {presto_model_description}_{task_type}_{finetune_classes}_\
-{compositing_window}_{test_type}_time-token={time_token}_balance={balance}\
+{compositing_window}_{test_type}_time-token={time_token}_balance={balance}_\
+augment={augment}\
 """
 finetuned_model_path = model_path / f"{experiment_prefix}.pt"
@@ -266,7 +270,7 @@
     results_df_combined, finetuned_model, sklearn_models_trained = full_eval.finetuning_results(
         model, sklearn_model_modes=model_modes
     )
-    # torch.save(finetuned_model.state_dict(), finetuned_model_path)
+    torch.save(finetuned_model.state_dict(), finetuned_model_path)
 
     results_df_combined["presto_model_description"] = presto_model_description
     results_df_combined["compositing_window"] = compositing_window
diff --git a/train_self_supervised.py b/train_self_supervised.py
index 374e98e..a16f31f 100644
--- a/train_self_supervised.py
+++ b/train_self_supervised.py
@@ -149,10 +149,12 @@
 logger.info("Loading data")
 
-files = sorted(glob(f"{parquet_file}/**/*.parquet"))[10:20]
+files = sorted(glob(f"{parquet_file}/**/*.parquet"))
 df_list = []
 for f in tqdm(files):
     _data = pd.read_parquet(f, engine="fastparquet")
+    _ref_id = f.split("/")[-2].split("=")[-1]
+    _data["ref_id"] = _ref_id
     _data_pivot = process_parquet(_data)
     _data_pivot.reset_index(inplace=True)
     df_list.append(_data_pivot)
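Both training scripts now read every partition (the `[10:20]` debugging slice is gone) and tag each frame with a `ref_id` parsed from its parent directory name. The parsing assumes a hive-style layout where the parquet file sits directly under a directory named `ref_id=<value>`; a quick illustration with a made-up path:

```python
# Hypothetical path following the assumed ref_id=<value> partition layout.
f = "data/worldcereal/ref_id=2021_EUR_dataset/part-00000.parquet"

# Same expression as in both training scripts:
# take the parent directory, then the part after "=".
_ref_id = f.split("/")[-2].split("=")[-1]
print(_ref_id)  # -> "2021_EUR_dataset"
```

This column is also what lets the `process_parquet` warning report `df["ref_id"].iloc[0]` directly instead of parsing the dataset name out of `sample_id`.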