Skip to content

Commit

Permalink
Merge pull request #128 from nasaharvest/train_on_subset
Browse files Browse the repository at this point in the history
Add script to train on subset of points
  • Loading branch information
gabrieltseng authored Dec 4, 2023
2 parents 731a4ba + 99d89d5 commit f7b3b8e
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 15 deletions.
99 changes: 99 additions & 0 deletions benchmarks/train_on_subset_of_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from pathlib import Path
import pandas as pd
import json

from cropharvest.utils import DATAFOLDER_PATH
from cropharvest.datasets import CropHarvest, Task, CropHarvestLabels
from cropharvest.engineer import TestInstance
from cropharvest.columns import RequiredColumns
from cropharvest.bbox import BBox

from sklearn.ensemble import RandomForestClassifier


def select_points(bounding_box: BBox, all_labels: CropHarvestLabels) -> pd.DataFrame:
    """
    Placeholder for the selection strategy a participant would write.

    Given the bounding box of an evaluation dataset, pick the labelled
    points a model should be trained on. This reference implementation
    keeps the labels which fall inside the bounding box, but any
    selection strategy could be substituted.

    Returns a dataframe with the dataset and index of each selected row.
    """
    # the two columns which identify a selected row
    id_columns = [RequiredColumns.DATASET, RequiredColumns.INDEX]

    # filter the full label geojson down to the evaluation bounding box
    labels_geojson = all_labels.as_geojson()
    within_bbox = all_labels.filter_geojson(
        labels_geojson,
        bounding_box,
        include_external_contributions=True,
    )

    return pd.DataFrame(within_bbox[id_columns])


def train_and_eval(
    training_labels: pd.DataFrame, evaluation_dataset: CropHarvest, results_folder: Path
) -> None:
    """
    Train a random forest on the selected labels, then evaluate it
    against every test instance of the evaluation dataset.

    training_labels: the dataset and index (RequiredColumns.DATASET and
        RequiredColumns.INDEX) of each row to train on
    evaluation_dataset: supplies the root folder, the task definition
        and the test instances
    results_folder: where the per-instance json / netcdf results and
        the combined json results are written

    Per-instance results which already exist on disk are not recomputed.
    The filenames only depend on the test id and the training dataset id,
    so stale files must be removed from results_folder to force a
    recomputation (e.g. after changing the selected points).
    """
    # 1. we make a training dataset from the labels
    labels = CropHarvestLabels(DATAFOLDER_PATH)
    filtered_labels = labels.as_geojson().merge(
        training_labels, on=[RequiredColumns.DATASET, RequiredColumns.INDEX]
    )
    labels._labels = filtered_labels
    training_dataset = CropHarvest(
        evaluation_dataset.root,
        Task(
            target_label=evaluation_dataset.task.target_label,
            balance_negative_crops=evaluation_dataset.task.balance_negative_crops,
        ),
    )
    # swap the default labels for the filtered ones
    training_dataset.update_labels(labels)

    train_x, train_y = training_dataset.as_array(flatten_x=True)
    # train a model
    model = RandomForestClassifier()
    model.fit(train_x, train_y)

    json_suffix = f"{training_dataset.id}.json"
    nc_suffix = f"{training_dataset.id}.nc"
    for test_id, test_instance in evaluation_dataset.test_data(flatten_x=True, max_size=10000):
        results_json = results_folder / f"{test_id}_{json_suffix}"
        results_nc = results_folder / f"{test_id}_{nc_suffix}"

        # the netcdf file is needed for the combined evaluation below, so
        # only skip when both outputs are present
        if results_json.exists() and results_nc.exists():
            print(f"Results already saved for {results_json} - skipping")
            continue

        # probability of the positive class
        preds = model.predict_proba(test_instance.x)[:, 1]

        results = test_instance.evaluate_predictions(preds)

        with results_json.open("w") as f:
            json.dump(results, f)

        ds = test_instance.to_xarray(preds)
        ds.to_netcdf(results_nc)

    # finally, we want to get results when all the test instances are considered
    # together
    all_nc_files = list(results_folder.glob(f"*_{nc_suffix}"))
    combined_instance, combined_preds = TestInstance.load_from_nc(all_nc_files)

    combined_results = combined_instance.evaluate_predictions(combined_preds)

    with (results_folder / f"combined_{json_suffix}").open("w") as f:
        json.dump(combined_results, f)


def main():
    """
    Run the subset-of-points benchmark against the Togo evaluation
    dataset, writing results to the "data_centric_test" folder inside
    the data folder.
    """
    benchmark_datasets = CropHarvest.create_benchmark_datasets(DATAFOLDER_PATH)
    labels = CropHarvestLabels(DATAFOLDER_PATH)

    output_folder = DATAFOLDER_PATH / "data_centric_test"
    output_folder.mkdir(exist_ok=True)

    # use the first benchmark dataset whose task id mentions Togo
    togo_candidates = [dataset for dataset in benchmark_datasets if "Togo" in dataset.task.id]
    togo_dataset = togo_candidates[0]

    selected_points = select_points(togo_dataset.task.bounding_box, labels)
    train_and_eval(selected_points, togo_dataset, output_folder)


# only run the benchmark when this file is executed as a script
if __name__ == "__main__":
    main()
42 changes: 27 additions & 15 deletions cropharvest/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,15 @@ def from_labels(cls):
class CropHarvest(BaseDataset):
"""Dataset consisting of satellite data and associated labels"""

filepaths: List[Path]
y_vals: List[int]
positive_indices: List[int]
negative_indices: List[int]
# used in the sample() function, to ensure filepaths are sampled without
# duplication as much as possible
sampled_positive_indices: List[int]
sampled_negative_indices: List[int]

def __init__(
self,
root,
Expand All @@ -218,37 +227,40 @@ def __init__(
print("Using the default task; crop vs. non crop globally")
task = Task()
self.task = task
self.is_val = is_val
self.val_ratio = val_ratio

self.normalizing_dict = load_normalizing_dict(
Path(root) / f"{FEATURES_DIR}/normalizing_dict.h5"
)
self.update_labels(labels)

def paths_from_labels(self, labels: CropHarvestLabels) -> Tuple[List[Path], List[Path]]:
positive_paths, negative_paths = labels.construct_positive_and_negative_labels(
task, filter_test=True
self.task, filter_test=True
)
if val_ratio > 0.0:
if self.val_ratio > 0.0:
# the fixed seed is to ensure the validation set is always
# different from the training set
positive_paths = deterministic_shuffle(positive_paths, seed=42)
negative_paths = deterministic_shuffle(negative_paths, seed=42)
if is_val:
positive_paths = positive_paths[: int(len(positive_paths) * val_ratio)]
negative_paths = negative_paths[: int(len(negative_paths) * val_ratio)]
if self.is_val:
positive_paths = positive_paths[: int(len(positive_paths) * self.val_ratio)]
negative_paths = negative_paths[: int(len(negative_paths) * self.val_ratio)]
else:
positive_paths = positive_paths[int(len(positive_paths) * val_ratio) :]
negative_paths = negative_paths[int(len(negative_paths) * val_ratio) :]

self.filepaths: List[Path] = positive_paths + negative_paths
self.y_vals: List[int] = [1] * len(positive_paths) + [0] * len(negative_paths)
positive_paths = positive_paths[int(len(positive_paths) * self.val_ratio) :]
negative_paths = negative_paths[int(len(negative_paths) * self.val_ratio) :]
return positive_paths, negative_paths

def update_labels(self, labels: CropHarvestLabels) -> None:
positive_paths, negative_paths = self.paths_from_labels(labels)
self.filepaths = positive_paths + negative_paths
self.y_vals = [1] * len(positive_paths) + [0] * len(negative_paths)
self.positive_indices = list(range(len(positive_paths)))
self.negative_indices = list(
range(len(positive_paths), len(positive_paths) + len(negative_paths))
)

# used in the sample() function, to ensure filepaths are sampled without
# duplication as much as possible
self.sampled_positive_indices: List[int] = []
self.sampled_negative_indices: List[int] = []
self.reset_sampled_indices()

def reset_sampled_indices(self) -> None:
self.sampled_positive_indices = []
Expand Down

0 comments on commit f7b3b8e

Please sign in to comment.