Skip to content

Commit c24008c

Browse files
authored
Merge pull request #86 from WorldCereal/move-prepdataframe
Isolated `prep_dataframe` method
2 parents 357207a + 46f31a2 commit c24008c

File tree

4 files changed

+29
-26
lines changed

4 files changed

+29
-26
lines changed

presto/eval.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
get_sinusoid_encoding_table,
3131
param_groups_lrd,
3232
)
33-
from .utils import DEFAULT_SEED, device
33+
from .utils import DEFAULT_SEED, device, prep_dataframe
3434

3535
logger = logging.getLogger("__main__")
3636

@@ -73,9 +73,9 @@ def __init__(
7373
self.target_function = target_function
7474

7575
train_data, val_data = WorldCerealLabelledDataset.split_df(train_data, val_size=val_size)
76-
self.train_df = self.prep_dataframe(train_data, filter_function, dekadal=dekadal)
77-
self.val_df = self.prep_dataframe(val_data, filter_function, dekadal=dekadal)
78-
self.test_df = self.prep_dataframe(test_data, filter_function, dekadal=dekadal)
76+
self.train_df = prep_dataframe(train_data, filter_function, dekadal=dekadal)
77+
self.val_df = prep_dataframe(val_data, filter_function, dekadal=dekadal)
78+
self.test_df = prep_dataframe(test_data, filter_function, dekadal=dekadal)
7979

8080
self.spatial_inference_savedir = spatial_inference_savedir
8181

@@ -90,23 +90,6 @@ def __init__(
9090
self.dekadal = dekadal
9191
self.ds_class = WorldCerealLabelled10DDataset if dekadal else WorldCerealLabelledDataset
9292

93-
@staticmethod
94-
def prep_dataframe(
95-
df: pd.DataFrame,
96-
filter_function: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
97-
dekadal: bool = False,
98-
):
99-
# SAR cannot equal 0.0 since we take the log of it
100-
cols = [f"SAR-{s}-ts{t}-20m" for s in ["VV", "VH"] for t in range(36 if dekadal else 12)]
101-
102-
df = df.drop_duplicates(subset=["sample_id", "lat", "lon", "end_date"])
103-
df = df[~pd.isna(df).any(axis=1)]
104-
df = df[~(df.loc[:, cols] == 0.0).any(axis=1)]
105-
df = df.set_index("sample_id")
106-
if filter_function is not None:
107-
df = filter_function(df)
108-
return df
109-
11093
def _construct_finetuning_model(self, pretrained_model: Presto) -> PrestoFineTuningModel:
11194
model: PrestoFineTuningModel = cast(Callable, pretrained_model.construct_finetuning_model)(
11295
num_outputs=self.num_outputs

presto/inference.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@
1616
DynamicWorld2020_2021,
1717
)
1818
from .dataset import WorldCerealBase
19-
from .eval import WorldCerealEval
2019
from .masking import BAND_EXPANSION
2120
from .presto import Presto
22-
from .utils import device
21+
from .utils import device, prep_dataframe
2322

2423
# Index to band groups mapping
2524
IDX_TO_BAND_GROUPS = {
@@ -402,6 +401,6 @@ def process_parquet(df: pd.DataFrame) -> pd.DataFrame:
402401
df_pivot["end_date"] = df_pivot["end_date"].dt.date.astype(str)
403402
df_pivot["valid_date"] = df_pivot["valid_date"].dt.date.astype(str)
404403

405-
df_pivot = WorldCerealEval.prep_dataframe(df_pivot)
404+
df_pivot = prep_dataframe(df_pivot)
406405

407406
return df_pivot

presto/utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,8 @@ def plot_for_group(grp_df):
235235
partial(plot_map, diff_country, vmin=-1, cmap="coolwarm"),
236236
),
237237
plot(
238-
f"{name} AEZ - CatBoost", partial(plot_map, diff_aez, vmin=-1, cmap="coolwarm")
238+
f"{name} AEZ - CatBoost",
239+
partial(plot_map, diff_aez, vmin=-1, cmap="coolwarm"),
239240
),
240241
plot(
241242
f"{name} Year - CatBoost",
@@ -308,3 +309,23 @@ def load_world_df() -> pd.DataFrame:
308309
world_df = gpd.read_file(data_dir / filename)
309310
world_df = world_df.drop(columns=["status", "color_code", "iso_3166_1_"])
310311
return world_df
312+
313+
314+
def prep_dataframe(
315+
df: pd.DataFrame,
316+
filter_function: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
317+
dekadal: bool = False,
318+
):
319+
"""Duplication from eval.py but otherwise we would need catboost during
320+
presto inference on OpenEO.
321+
"""
322+
# SAR cannot equal 0.0 since we take the log of it
323+
cols = [f"SAR-{s}-ts{t}-20m" for s in ["VV", "VH"] for t in range(36 if dekadal else 12)]
324+
325+
df = df.drop_duplicates(subset=["sample_id", "lat", "lon", "end_date"])
326+
df = df[~pd.isna(df).any(axis=1)]
327+
df = df[~(df.loc[:, cols] == 0.0).any(axis=1)]
328+
df = df.set_index("sample_id")
329+
if filter_function is not None:
330+
df = filter_function(df)
331+
return df

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def load_dependencies(tag: str) -> List[str]:
3232
long_description_content_type="text/markdown",
3333
author="Gabriel Tseng",
3434
author_email="[email protected]",
35-
version="0.1.1",
35+
version="0.1.2",
3636
classifiers=[
3737
"Programming Language :: Python :: 3",
3838
"License :: Other/Proprietary License",

0 commit comments

Comments
 (0)