From ed6033560a78718c0a6c9f081882c698f9dfcadb Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Mon, 7 Oct 2024 16:04:46 +0200 Subject: [PATCH 01/11] reintroduced ref_id into dataset and made cleaner logger message about samples with incoherent dates --- presto/utils.py | 7 +++---- train_finetuned.py | 2 ++ train_self_supervised.py | 4 +++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/presto/utils.py b/presto/utils.py index fdb9b6a..2c16cde 100644 --- a/presto/utils.py +++ b/presto/utils.py @@ -238,10 +238,9 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: if len(samples_after_end_date) > 0 or len(samples_before_start_date) > 0: logger.warning( f"""\ - Dataset {df["sample_id"].iloc[0].split("_")[0]}: removing {len(samples_after_end_date)}\ - samples with valid_date after the end_date\ - and {len(samples_before_start_date)} samples with valid_date before the start_date - """ +Dataset {df["ref_id"].iloc[0]}: removing {len(samples_after_end_date)}\ +samples with valid_date after the end_date\ +and {len(samples_before_start_date)} samples with valid_date before the start_date""" ) df = df[~df["sample_id"].isin(samples_before_start_date)] df = df[~df["sample_id"].isin(samples_after_end_date)] diff --git a/train_finetuned.py b/train_finetuned.py index efac6cb..bc394bc 100644 --- a/train_finetuned.py +++ b/train_finetuned.py @@ -159,6 +159,8 @@ df_list = [] for f in tqdm(files): _data = pd.read_parquet(f, engine="fastparquet") + _ref_id = f.split("/")[-2].split("=")[-1] + _data["ref_id"] = _ref_id _data_pivot = process_parquet(_data) _data_pivot.reset_index(inplace=True) df_list.append(_data_pivot) diff --git a/train_self_supervised.py b/train_self_supervised.py index 374e98e..a16f31f 100644 --- a/train_self_supervised.py +++ b/train_self_supervised.py @@ -149,10 +149,12 @@ logger.info("Loading data") -files = sorted(glob(f"{parquet_file}/**/*.parquet"))[10:20] +files = sorted(glob(f"{parquet_file}/**/*.parquet")) df_list = [] for f in tqdm(files): _data = pd.read_parquet(f, engine="fastparquet") + _ref_id = f.split("/")[-2].split("=")[-1] + _data["ref_id"] = _ref_id _data_pivot = process_parquet(_data) _data_pivot.reset_index(inplace=True) df_list.append(_data_pivot) From 373e872a2bf6871e89d8b91cc5c2bb761132fc32 Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Mon, 7 Oct 2024 16:06:29 +0200 Subject: [PATCH 02/11] fixing the number of available_timesteps --- presto/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto/utils.py b/presto/utils.py index 2c16cde..81eb213 100644 --- a/presto/utils.py +++ b/presto/utils.py @@ -323,7 +323,7 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: ) - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month) df_pivot["available_timesteps"] = ( df_pivot["end_date"].dt.year * 12 + df_pivot["end_date"].dt.month - ) - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month) + ) - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month) + 1 df_pivot["year"] = df_pivot["valid_date"].dt.year From f03649e5da8bdba0a507b97791d20117ca516f9f Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Mon, 7 Oct 2024 22:44:41 +0200 Subject: [PATCH 03/11] fixed available_timesteps computation for corner cases --- presto/utils.py | 48 +++++++++++++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/presto/utils.py b/presto/utils.py index 81eb213..b685a61 100644 --- a/presto/utils.py +++ b/presto/utils.py @@ -13,19 +13,9 @@ import torch import xarray as xr -from .dataops import ( - BANDS, - ERA5_BANDS, - MIN_EDGE_BUFFER, - NODATAVALUE, - NORMED_BANDS, - REMOVED_BANDS, - S1_BANDS, - S1_S2_ERA5_SRTM, - S2_BANDS, - SRTM_BANDS, - DynamicWorld2020_2021, -) +from .dataops import (BANDS, ERA5_BANDS, MIN_EDGE_BUFFER, NODATAVALUE, + NORMED_BANDS, REMOVED_BANDS, S1_BANDS, S1_S2_ERA5_SRTM, + S2_BANDS, SRTM_BANDS, DynamicWorld2020_2021) # plt = None @@ -238,14 +228,15 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: if len(samples_after_end_date) > 0 or len(samples_before_start_date) > 0: logger.warning( f"""\ -Dataset {df["ref_id"].iloc[0]}: removing {len(samples_after_end_date)}\ -samples with valid_date after the end_date\ +Dataset {df["ref_id"].iloc[0]}: removing {len(samples_after_end_date)} \ +samples with valid_date after the end_date \ and {len(samples_before_start_date)} samples with valid_date before the start_date""" ) df = df[~df["sample_id"].isin(samples_before_start_date)] df = df[~df["sample_id"].isin(samples_after_end_date)] # add timesteps before the start_date where needed + intermediate_dummy_df = pd.DataFrame() for n_ts_to_add in range(1, MIN_EDGE_BUFFER + 1): samples_to_add_ts_before_start = latest_obs_position[ (MIN_EDGE_BUFFER - latest_obs_position["valid_position"]) >= -n_ts_to_add @@ -257,9 +248,11 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: months=n_ts_to_add ) # type: ignore dummy_df[feature_columns] = NODATAVALUE - df = pd.concat([df, dummy_df]) + intermediate_dummy_df = pd.concat([intermediate_dummy_df, dummy_df]) + df = pd.concat([df, intermediate_dummy_df]) # add timesteps after the end_date where needed + intermediate_dummy_df = pd.DataFrame() for n_ts_to_add in range(1, MIN_EDGE_BUFFER + 1): samples_to_add_ts_after_end = latest_obs_position[ (MIN_EDGE_BUFFER - latest_obs_position["valid_position_diff"]) >= n_ts_to_add @@ -271,12 +264,17 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: months=n_ts_to_add ) # type: ignore dummy_df[feature_columns] = NODATAVALUE - df = pd.concat([df, dummy_df]) + intermediate_dummy_df = pd.concat([intermediate_dummy_df, dummy_df]) + df = pd.concat([df, intermediate_dummy_df]) # Now reassign start_date to the minimum timestamp new_start_date = df.groupby(["sample_id"])["timestamp"].min() df["start_date"] = df["sample_id"].map(new_start_date) + # Also reassign end_date to the maximum timestamp + new_end_date = df.groupby(["sample_id"])["timestamp"].max() + df["end_date"] = df["sample_id"].map(new_end_date) + # reinitialize timestep_ind df["timestamp_ind"] = (df["timestamp"].dt.year * 12 + df["timestamp"].dt.month) - ( df["start_date"].dt.year * 12 + df["start_date"].dt.month @@ -325,6 +323,22 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: df_pivot["end_date"].dt.year * 12 + df_pivot["end_date"].dt.month ) - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month) + 1 + from presto.dataops import NUM_TIMESTEPS + + min_center_point = np.maximum( + NUM_TIMESTEPS // 2, + df_pivot["valid_position"] + MIN_EDGE_BUFFER - NUM_TIMESTEPS // 2, + ) + max_center_point = np.minimum( + df_pivot["available_timesteps"] - NUM_TIMESTEPS // 2, + df_pivot["valid_position"] - MIN_EDGE_BUFFER + NUM_TIMESTEPS // 2, + ) + + faulty_samples = min_center_point > max_center_point + if faulty_samples.sum() > 0: + logger.warning(f"Dropping {faulty_samples.sum()} faulty samples.") + df_pivot = df_pivot[~faulty_samples] + df_pivot["year"] = df_pivot["valid_date"].dt.year df_pivot["start_date"] = df_pivot["start_date"].dt.date.astype(str) From c59c0664ecf56de36fe116ddf68b86dd9f1d68c8 Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Mon, 7 Oct 2024 22:46:35 +0200 Subject: [PATCH 04/11] cleanup --- train_finetuned.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/train_finetuned.py b/train_finetuned.py index bc394bc..e0ebcf7 100644 --- a/train_finetuned.py +++ b/train_finetuned.py @@ -10,25 +10,17 @@ import pandas as pd import requests +import torch import xarray as xr -from tqdm.auto import tqdm - from presto.dataops import NODATAVALUE from presto.dataset import WorldCerealBase, filter_remove_noncrops from presto.eval import WorldCerealEval from presto.presto import Presto -from presto.utils import ( - DEFAULT_SEED, - config_dir, - data_dir, - default_model_path, - device, - initialize_logging, - plot_spatial, - process_parquet, - seed_everything, - timestamp_dirname, -) +from presto.utils import (DEFAULT_SEED, config_dir, data_dir, + default_model_path, device, initialize_logging, + plot_spatial, process_parquet, seed_everything, + timestamp_dirname) +from tqdm.auto import tqdm logger = logging.getLogger("__main__") @@ -155,7 +147,7 @@ ] logger.info("Reading dataset") -files = sorted(glob(f"{parquet_file}/**/*.parquet"))[10:20] +files = sorted(glob(f"{parquet_file}/**/*.parquet")) df_list = [] for f in tqdm(files): _data = pd.read_parquet(f, engine="fastparquet") @@ -201,7 +193,8 @@ experiment_prefix = f"""\ {presto_model_description}_{task_type}_{finetune_classes}_\ -{compositing_window}_{test_type}_time-token={time_token}_balance={balance}\ +{compositing_window}_{test_type}_time-token={time_token}_balance={balance}_\ +augment={augment}\ """ finetuned_model_path = model_path / f"{experiment_prefix}.pt" @@ -268,7 +261,7 @@ results_df_combined, finetuned_model, sklearn_models_trained = full_eval.finetuning_results( model, sklearn_model_modes=model_modes ) - # torch.save(finetuned_model.state_dict(), finetuned_model_path) + torch.save(finetuned_model.state_dict(), finetuned_model_path) results_df_combined["presto_model_description"] = presto_model_description results_df_combined["compositing_window"] = compositing_window From 3ba98a3f8297e18394f71e73827316c7adca855b Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Mon, 7 Oct 2024 23:03:42 +0200 Subject: [PATCH 05/11] formatting --- presto/utils.py | 22 +++++++++++++++++----- train_finetuned.py | 19 ++++++++++++++----- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/presto/utils.py b/presto/utils.py index b685a61..07cddf6 100644 --- a/presto/utils.py +++ b/presto/utils.py @@ -13,9 +13,19 @@ import torch import xarray as xr -from .dataops import (BANDS, ERA5_BANDS, MIN_EDGE_BUFFER, NODATAVALUE, - NORMED_BANDS, REMOVED_BANDS, S1_BANDS, S1_S2_ERA5_SRTM, - S2_BANDS, SRTM_BANDS, DynamicWorld2020_2021) +from .dataops import ( + BANDS, + ERA5_BANDS, + MIN_EDGE_BUFFER, + NODATAVALUE, + NORMED_BANDS, + REMOVED_BANDS, + S1_BANDS, + S1_S2_ERA5_SRTM, + S2_BANDS, + SRTM_BANDS, + DynamicWorld2020_2021, +) # plt = None @@ -320,8 +330,10 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: df_pivot["valid_date"].dt.year * 12 + df_pivot["valid_date"].dt.month ) - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month) df_pivot["available_timesteps"] = ( - df_pivot["end_date"].dt.year * 12 + df_pivot["end_date"].dt.month - ) - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month) + 1 + (df_pivot["end_date"].dt.year * 12 + df_pivot["end_date"].dt.month) + - (df_pivot["start_date"].dt.year * 12 + df_pivot["start_date"].dt.month) + + 1 + ) from presto.dataops import NUM_TIMESTEPS diff --git a/train_finetuned.py b/train_finetuned.py index e0ebcf7..b4c3c92 100644 --- a/train_finetuned.py +++ b/train_finetuned.py @@ -12,15 +12,24 @@ import requests import torch import xarray as xr +from tqdm.auto import tqdm + from presto.dataops import NODATAVALUE from presto.dataset import WorldCerealBase, filter_remove_noncrops from presto.eval import WorldCerealEval from presto.presto import Presto -from presto.utils import (DEFAULT_SEED, config_dir, data_dir, - default_model_path, device, initialize_logging, - plot_spatial, process_parquet, seed_everything, - timestamp_dirname) -from tqdm.auto import tqdm +from presto.utils import ( + DEFAULT_SEED, + config_dir, + data_dir, + default_model_path, + device, + initialize_logging, + plot_spatial, + process_parquet, + seed_everything, + timestamp_dirname, +) logger = logging.getLogger("__main__") From e806c28033138fbd37b21e4a33f18844b004ef7a Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Tue, 8 Oct 2024 09:02:35 +0200 Subject: [PATCH 06/11] additional check on the available_timesteps + descr --- presto/utils.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/presto/utils.py b/presto/utils.py index 07cddf6..9657578 100644 --- a/presto/utils.py +++ b/presto/utils.py @@ -13,19 +13,9 @@ import torch import xarray as xr -from .dataops import ( - BANDS, - ERA5_BANDS, - MIN_EDGE_BUFFER, - NODATAVALUE, - NORMED_BANDS, - REMOVED_BANDS, - S1_BANDS, - S1_S2_ERA5_SRTM, - S2_BANDS, - SRTM_BANDS, - DynamicWorld2020_2021, -) +from .dataops import (BANDS, ERA5_BANDS, MIN_EDGE_BUFFER, NODATAVALUE, + NORMED_BANDS, REMOVED_BANDS, S1_BANDS, S1_S2_ERA5_SRTM, + S2_BANDS, SRTM_BANDS, DynamicWorld2020_2021) # plt = None @@ -117,13 +107,20 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: - adding dummy timesteps filled with NODATA values before the start_date or after the end_date for samples where valid_date is close to the edge of the timeseries; this closeness is defined by the globally defined parameter MIN_EDGE_BUFFER - - reinitializing the start_date and timestamp_ind to take into account + - reinitializing the start_date, end_date and timestamp_ind to take into account newly added timesteps - checking for missing timesteps in the middle of the timeseries and adding them with NODATA values - pivoting the DataFrame to wide format with columns for each band and timesteps as suffixes - assigning the correct suffixes to the band names + - computing the final valid_date position in the timeseries that takes + into account updated start_date + - computing the number of available timesteps in the timeseries that + takes into account updated start_date and end_date; available_timesteps + holds the absolute number of timesteps that for which observations are + available; it cannot be less than NUM_TIMESTEPS; if this is the case, + sample is considered faulty and is removed from the dataset - post-processing with prep_dataframe function Returns @@ -351,6 +348,12 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: logger.warning(f"Dropping {faulty_samples.sum()} faulty samples.") df_pivot = df_pivot[~faulty_samples] + samples_with_too_few_ts = df_pivot["available_timesteps"] 0: + logger.warning(f"Dropping {samples_with_too_few_ts.sum()} samples with \ +number of available timesteps less than {NUM_TIMESTEPS}.") + df_pivot = df_pivot[~samples_with_too_few_ts] + df_pivot["year"] = df_pivot["valid_date"].dt.year df_pivot["start_date"] = df_pivot["start_date"].dt.date.astype(str) From eee7dc5a678b7518bcfee621307c9fd2ea9fd9f5 Mon Sep 17 00:00:00 2001 From: Butsko Christina Date: Tue, 8 Oct 2024 11:13:03 +0200 Subject: [PATCH 07/11] isort fix, hopefully the correct version --- presto/utils.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/presto/utils.py b/presto/utils.py index 9657578..b6da3fa 100644 --- a/presto/utils.py +++ b/presto/utils.py @@ -13,9 +13,19 @@ import torch import xarray as xr -from .dataops import (BANDS, ERA5_BANDS, MIN_EDGE_BUFFER, NODATAVALUE, - NORMED_BANDS, REMOVED_BANDS, S1_BANDS, S1_S2_ERA5_SRTM, - S2_BANDS, SRTM_BANDS, DynamicWorld2020_2021) +from .dataops import ( + BANDS, + ERA5_BANDS, + MIN_EDGE_BUFFER, + NODATAVALUE, + NORMED_BANDS, + REMOVED_BANDS, + S1_BANDS, + S1_S2_ERA5_SRTM, + S2_BANDS, + SRTM_BANDS, + DynamicWorld2020_2021, +) # plt = None From f03c1cdc789b4ef81993ed875ba298cc5536ebc4 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Tue, 8 Oct 2024 18:51:34 +0200 Subject: [PATCH 08/11] Check nr of timesteps in inference --- presto/inference.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/presto/inference.py b/presto/inference.py index b51c1f1..c18ec26 100644 --- a/presto/inference.py +++ b/presto/inference.py @@ -214,6 +214,9 @@ def get_presto_features( return presto_extractor._get_encodings(dl) elif isinstance(inarr, xr.DataArray): + # Check if we have the expected 12 timesteps + if len(inarr.t) != 12: + raise ValueError("Can only run Presto on 12 timesteps, got: {len(inarr.t)}") return presto_extractor.extract_presto_features(inarr, epsg=epsg) else: From 865ab1e50eaa2743068eb17ef625b4f869f31c6a Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Tue, 8 Oct 2024 18:57:12 +0200 Subject: [PATCH 09/11] Attempt to auto-format --- presto/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/presto/utils.py b/presto/utils.py index b6da3fa..c18e745 100644 --- a/presto/utils.py +++ b/presto/utils.py @@ -358,10 +358,12 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: logger.warning(f"Dropping {faulty_samples.sum()} faulty samples.") df_pivot = df_pivot[~faulty_samples] - samples_with_too_few_ts = df_pivot["available_timesteps"] 0: - logger.warning(f"Dropping {samples_with_too_few_ts.sum()} samples with \ -number of available timesteps less than {NUM_TIMESTEPS}.") + logger.warning( + f"Dropping {samples_with_too_few_ts.sum()} samples with \ +number of available timesteps less than {NUM_TIMESTEPS}." + ) df_pivot = df_pivot[~samples_with_too_few_ts] df_pivot["year"] = df_pivot["valid_date"].dt.year From add376a3312550ce704982f1a0f18a5244ad27d9 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Wed, 9 Oct 2024 14:49:15 +0200 Subject: [PATCH 10/11] Should be f-string --- presto/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto/inference.py b/presto/inference.py index c18ec26..fae25df 100644 --- a/presto/inference.py +++ b/presto/inference.py @@ -216,7 +216,7 @@ def get_presto_features( elif isinstance(inarr, xr.DataArray): # Check if we have the expected 12 timesteps if len(inarr.t) != 12: - raise ValueError("Can only run Presto on 12 timesteps, got: {len(inarr.t)}") + raise ValueError(f"Can only run Presto on 12 timesteps, got: {len(inarr.t)}") return presto_extractor.extract_presto_features(inarr, epsg=epsg) else: From e5e01096ac859029737afc94bbe0f6f27aeef4c7 Mon Sep 17 00:00:00 2001 From: Kristof Van Tricht Date: Wed, 9 Oct 2024 14:52:26 +0200 Subject: [PATCH 11/11] Moved import to top --- presto/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/presto/utils.py b/presto/utils.py index c18e745..d03c274 100644 --- a/presto/utils.py +++ b/presto/utils.py @@ -13,6 +13,8 @@ import torch import xarray as xr +from presto.dataops import NUM_TIMESTEPS + from .dataops import ( BANDS, ERA5_BANDS, @@ -342,8 +344,6 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame: + 1 ) - from presto.dataops import NUM_TIMESTEPS - min_center_point = np.maximum( NUM_TIMESTEPS // 2, df_pivot["valid_position"] + MIN_EDGE_BUFFER - NUM_TIMESTEPS // 2,