diff --git a/benchmarks/train_on_subset_of_points.py b/benchmarks/train_on_subset_of_points.py
new file mode 100644
index 0000000..02ca967
--- /dev/null
+++ b/benchmarks/train_on_subset_of_points.py
@@ -0,0 +1,100 @@
+from pathlib import Path
+import pandas as pd
+import json
+
+from cropharvest.utils import DATAFOLDER_PATH
+from cropharvest.datasets import CropHarvest, Task, CropHarvestLabels
+from cropharvest.engineer import TestInstance
+from cropharvest.columns import RequiredColumns
+from cropharvest.bbox import BBox
+
+from sklearn.ensemble import RandomForestClassifier
+
+
+def select_points(bounding_box: BBox, all_labels: CropHarvestLabels) -> pd.DataFrame:
+    """
+    This is what participants would implement.
+    Given an evaluation dataset's bounding box, they would need
+    to implement some method of selecting points against
+    which a model will be trained.
+    """
+    # here, we manually select the points in the labels which fall within the bounding box.
+    # 1. Make a new geojson. We do it according to the bounding boxes, but this could be done
+    # in any way
+    filtered_geojson = all_labels.filter_geojson(
+        all_labels.as_geojson(),
+        bounding_box,
+        include_external_contributions=True,
+    )
+
+    # the returned dataframe contains the ids and datasets of the selected rows
+    return pd.DataFrame(filtered_geojson[[RequiredColumns.DATASET, RequiredColumns.INDEX]])
+
+
+def train_and_eval(
+    training_labels: pd.DataFrame, evaluation_dataset: CropHarvest, results_folder: Path
+):
+    # 1. we make a training dataset from the labels
+    labels = CropHarvestLabels(DATAFOLDER_PATH)
+    filtered_labels = labels.as_geojson().merge(
+        training_labels, on=[RequiredColumns.DATASET, RequiredColumns.INDEX]
+    )
+    labels._labels = filtered_labels
+    training_dataset = CropHarvest(
+        evaluation_dataset.root,
+        Task(
+            target_label=evaluation_dataset.task.target_label,
+            balance_negative_crops=evaluation_dataset.task.balance_negative_crops,
+        ),
+    )
+    training_dataset.update_labels(labels)
+
+    train_x, train_y = training_dataset.as_array(flatten_x=True)
+    # train a model
+    model = RandomForestClassifier()
+    model.fit(train_x, train_y)
+
+    json_suffix = f"{training_dataset.id}.json"
+    nc_suffix = f"{training_dataset.id}.nc"
+    for test_id, test_instance in evaluation_dataset.test_data(flatten_x=True, max_size=10000):
+        results_json = results_folder / f"{test_id}_{json_suffix}"
+        results_nc = results_folder / f"{test_id}_{nc_suffix}"
+
+        if results_json.exists():
+            print(f"Results already saved for {results_json} - skipping")
+            continue
+
+        preds = model.predict_proba(test_instance.x)[:, 1]
+
+        results = test_instance.evaluate_predictions(preds)
+
+        with Path(results_json).open("w") as f:
+            json.dump(results, f)
+
+        ds = test_instance.to_xarray(preds)
+        ds.to_netcdf(results_nc)
+
+    # finally, we want to get results when all the test instances are considered
+    # together
+    all_nc_files = list(results_folder.glob(f"*_{nc_suffix}"))
+    combined_instance, combined_preds = TestInstance.load_from_nc(all_nc_files)
+
+    combined_results = combined_instance.evaluate_predictions(combined_preds)
+
+    with (results_folder / f"combined_{json_suffix}").open("w") as f:
+        json.dump(combined_results, f)
+
+
+def main():
+    evaluation_datasets = CropHarvest.create_benchmark_datasets(DATAFOLDER_PATH)
+    all_labels = CropHarvestLabels(DATAFOLDER_PATH)
+    results_folder = DATAFOLDER_PATH / "data_centric_test"
+    results_folder.mkdir(exist_ok=True)
+
+    togo_eval = [x for x in evaluation_datasets if "Togo" in x.task.id][0]
+    training_points_df = select_points(togo_eval.task.bounding_box, all_labels)
+    train_and_eval(training_points_df, togo_eval, results_folder)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/cropharvest/datasets.py b/cropharvest/datasets.py
index 5a32e56..f37a70f 100644
--- a/cropharvest/datasets.py
+++ b/cropharvest/datasets.py
@@ -203,6 +203,15 @@ def from_labels(cls):
 class CropHarvest(BaseDataset):
     """Dataset consisting of satellite data and associated labels"""
 
+    filepaths: List[Path]
+    y_vals: List[int]
+    positive_indices: List[int]
+    negative_indices: List[int]
+    # used in the sample() function, to ensure filepaths are sampled without
+    # duplication as much as possible
+    sampled_positive_indices: List[int]
+    sampled_negative_indices: List[int]
+
     def __init__(
         self,
         root,
@@ -218,37 +227,40 @@ def __init__(
             print("Using the default task; crop vs. non crop globally")
             task = Task()
         self.task = task
+        self.is_val = is_val
+        self.val_ratio = val_ratio
 
         self.normalizing_dict = load_normalizing_dict(
             Path(root) / f"{FEATURES_DIR}/normalizing_dict.h5"
         )
+        self.update_labels(labels)
 
+    def paths_from_labels(self, labels: CropHarvestLabels) -> Tuple[List[Path], List[Path]]:
         positive_paths, negative_paths = labels.construct_positive_and_negative_labels(
-            task, filter_test=True
+            self.task, filter_test=True
         )
-        if val_ratio > 0.0:
+        if self.val_ratio > 0.0:
             # the fixed seed is to ensure the validation set is always
             # different from the training set
             positive_paths = deterministic_shuffle(positive_paths, seed=42)
             negative_paths = deterministic_shuffle(negative_paths, seed=42)
-            if is_val:
-                positive_paths = positive_paths[: int(len(positive_paths) * val_ratio)]
-                negative_paths = negative_paths[: int(len(negative_paths) * val_ratio)]
+            if self.is_val:
+                positive_paths = positive_paths[: int(len(positive_paths) * self.val_ratio)]
+                negative_paths = negative_paths[: int(len(negative_paths) * self.val_ratio)]
             else:
-                positive_paths = positive_paths[int(len(positive_paths) * val_ratio) :]
-                negative_paths = negative_paths[int(len(negative_paths) * val_ratio) :]
-
-        self.filepaths: List[Path] = positive_paths + negative_paths
-        self.y_vals: List[int] = [1] * len(positive_paths) + [0] * len(negative_paths)
+                positive_paths = positive_paths[int(len(positive_paths) * self.val_ratio) :]
+                negative_paths = negative_paths[int(len(negative_paths) * self.val_ratio) :]
+        return positive_paths, negative_paths
+
+    def update_labels(self, labels: CropHarvestLabels) -> None:
+        positive_paths, negative_paths = self.paths_from_labels(labels)
+        self.filepaths = positive_paths + negative_paths
+        self.y_vals = [1] * len(positive_paths) + [0] * len(negative_paths)
         self.positive_indices = list(range(len(positive_paths)))
         self.negative_indices = list(
             range(len(positive_paths), len(positive_paths) + len(negative_paths))
         )
-
-        # used in the sample() function, to ensure filepaths are sampled without
-        # duplication as much as possible
-        self.sampled_positive_indices: List[int] = []
-        self.sampled_negative_indices: List[int] = []
+        self.reset_sampled_indices()
 
     def reset_sampled_indices(self) -> None:
         self.sampled_positive_indices = []