Skip to content

Commit

Permalink
Filter against the externally contributed datasets in the CropHarvest task
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieltseng committed Aug 3, 2022
1 parent edb570d commit 8d00595
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
2 changes: 1 addition & 1 deletion benchmarks/dl/maml.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def _make_tasks(
if task.k >= min_task_k:
label_to_task[task.id] = task

for label in labels.classes_in_bbox(country_bbox):
for label in labels.classes_in_bbox(country_bbox, True):
if country in test_countries_to_crops:
if label in test_countries_to_crops[country]:
continue
Expand Down
28 changes: 21 additions & 7 deletions cropharvest/datasets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pathlib import Path
from xml.etree.ElementInclude import include
import geopandas
import numpy as np
import h5py
Expand Down Expand Up @@ -36,6 +37,7 @@ class Task:
balance_negative_crops: bool = False
test_identifier: Optional[str] = None
normalize: bool = True
include_externally_contributed_labels: bool = True

def __post_init__(self):
if self.target_label is None:
Expand Down Expand Up @@ -90,17 +92,27 @@ def as_geojson(self) -> geopandas.GeoDataFrame:
return self._labels

@staticmethod
def filter_geojson(
    gpdf: geopandas.GeoDataFrame,
    bounding_box: BBox,
    include_external_contributions: bool = True,
) -> geopandas.GeoDataFrame:
    """Return the rows of ``gpdf`` that lie inside ``bounding_box``.

    Args:
        gpdf: Labels table. The ``RequiredColumns.LAT`` and
            ``RequiredColumns.LON`` columns are always read;
            ``RequiredColumns.EXTERNALLY_CONTRIBUTED_DATASET`` is read only
            when ``include_external_contributions`` is False.
        bounding_box: Object whose ``contains(lat, lon)`` is evaluated once
            per row (assumed to return a truth value).
        include_external_contributions: When False, only rows whose
            ``EXTERNALLY_CONTRIBUTED_DATASET`` value equals ``False`` are
            kept. Defaults to True so that two-argument callers keep the
            bounding-box-only behaviour.

    Returns:
        The subset of ``gpdf`` selected by the combined row mask.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # warning: invalid value encountered in ? (vectorized)
        # otypes is given explicitly so an empty frame produces an empty
        # mask; without it np.vectorize raises ValueError on size-0 input.
        include_condition = np.vectorize(bounding_box.contains, otypes=[bool])(
            gpdf[RequiredColumns.LAT], gpdf[RequiredColumns.LON]
        )
    if not include_external_contributions:
        # Build a per-row boolean mask from the column itself. Indexing gpdf
        # with the comparison would yield a filtered DataFrame, which cannot
        # be AND-ed with the row mask.
        not_external = (
            gpdf[RequiredColumns.EXTERNALLY_CONTRIBUTED_DATASET] == False  # noqa: E712
        ).to_numpy()
        include_condition = include_condition & not_external
    return gpdf[include_condition]

def classes_in_bbox(
    self, bounding_box: BBox, include_external_contributions: bool = True
) -> List[str]:
    """List the distinct labels found inside ``bounding_box``.

    Args:
        bounding_box: Region passed through to ``filter_geojson``.
        include_external_contributions: Forwarded to ``filter_geojson``.
            Defaults to True so that one-argument callers keep working.

    Returns:
        The unique values of the filtered frame's ``label`` column, in the
        order ``unique()`` yields them, with ``None`` entries removed.
    """
    bbox_geojson = self.filter_geojson(
        self.as_geojson(), bounding_box, include_external_contributions
    )
    return [label for label in bbox_geojson.label.unique() if label is not None]

Expand All @@ -117,7 +129,9 @@ def construct_positive_and_negative_labels(
if filter_test:
gpdf = gpdf[gpdf[RequiredColumns.IS_TEST] == False]
if task.bounding_box is not None:
gpdf = self.filter_geojson(gpdf, task.bounding_box)
gpdf = self.filter_geojson(
gpdf, task.bounding_box, task.include_externally_contributed_labels
)

if len(gpdf) == 0:
raise NoDataForBoundingBoxError
Expand Down

0 comments on commit 8d00595

Please sign in to comment.