30
30
get_sinusoid_encoding_table ,
31
31
param_groups_lrd ,
32
32
)
33
- from .utils import DEFAULT_SEED , device
33
+ from .utils import DEFAULT_SEED , device , prep_dataframe
34
34
35
35
logger = logging .getLogger ("__main__" )
36
36
@@ -73,9 +73,9 @@ def __init__(
73
73
self .target_function = target_function
74
74
75
75
train_data , val_data = WorldCerealLabelledDataset .split_df (train_data , val_size = val_size )
76
- self .train_df = self . prep_dataframe (train_data , filter_function , dekadal = dekadal )
77
- self .val_df = self . prep_dataframe (val_data , filter_function , dekadal = dekadal )
78
- self .test_df = self . prep_dataframe (test_data , filter_function , dekadal = dekadal )
76
+ self .train_df = prep_dataframe (train_data , filter_function , dekadal = dekadal )
77
+ self .val_df = prep_dataframe (val_data , filter_function , dekadal = dekadal )
78
+ self .test_df = prep_dataframe (test_data , filter_function , dekadal = dekadal )
79
79
80
80
self .spatial_inference_savedir = spatial_inference_savedir
81
81
@@ -90,23 +90,6 @@ def __init__(
90
90
self .dekadal = dekadal
91
91
self .ds_class = WorldCerealLabelled10DDataset if dekadal else WorldCerealLabelledDataset
92
92
93
- @staticmethod
94
- def prep_dataframe (
95
- df : pd .DataFrame ,
96
- filter_function : Optional [Callable [[pd .DataFrame ], pd .DataFrame ]] = None ,
97
- dekadal : bool = False ,
98
- ):
99
- # SAR cannot equal 0.0 since we take the log of it
100
- cols = [f"SAR-{ s } -ts{ t } -20m" for s in ["VV" , "VH" ] for t in range (36 if dekadal else 12 )]
101
-
102
- df = df .drop_duplicates (subset = ["sample_id" , "lat" , "lon" , "end_date" ])
103
- df = df [~ pd .isna (df ).any (axis = 1 )]
104
- df = df [~ (df .loc [:, cols ] == 0.0 ).any (axis = 1 )]
105
- df = df .set_index ("sample_id" )
106
- if filter_function is not None :
107
- df = filter_function (df )
108
- return df
109
-
110
93
def _construct_finetuning_model (self , pretrained_model : Presto ) -> PrestoFineTuningModel :
111
94
model : PrestoFineTuningModel = cast (Callable , pretrained_model .construct_finetuning_model )(
112
95
num_outputs = self .num_outputs
0 commit comments