diff --git a/benchmarks/dl/maml.py b/benchmarks/dl/maml.py index c1f824b8..6bfc701e 100644 --- a/benchmarks/dl/maml.py +++ b/benchmarks/dl/maml.py @@ -393,7 +393,7 @@ def _make_tasks( if task.k >= min_task_k: label_to_task[task.id] = task - for label in labels.classes_in_bbox(country_bbox): + for label in labels.classes_in_bbox(country_bbox, True): if country in test_countries_to_crops: if label in test_countries_to_crops[country]: continue diff --git a/cropharvest/datasets.py b/cropharvest/datasets.py index ca082a97..ab606823 100644 --- a/cropharvest/datasets.py +++ b/cropharvest/datasets.py @@ -1,4 +1,5 @@ from pathlib import Path +from xml.etree.ElementInclude import include import geopandas import numpy as np import h5py @@ -36,6 +37,7 @@ class Task: balance_negative_crops: bool = False test_identifier: Optional[str] = None normalize: bool = True + include_externally_contributed_labels: bool = True def __post_init__(self): if self.target_label is None: @@ -90,17 +92,27 @@ def as_geojson(self) -> geopandas.GeoDataFrame: return self._labels @staticmethod - def filter_geojson(gpdf: geopandas.GeoDataFrame, bounding_box: BBox) -> geopandas.GeoDataFrame: + def filter_geojson( + gpdf: geopandas.GeoDataFrame, bounding_box: BBox, include_external_contributions: bool + ) -> geopandas.GeoDataFrame: with warnings.catch_warnings(): warnings.simplefilter("ignore") # warning: invalid value encountered in ? (vectorized) - in_bounding_box = np.vectorize(bounding_box.contains)( + include_condition = np.vectorize(bounding_box.contains)( gpdf[RequiredColumns.LAT], gpdf[RequiredColumns.LON] ) - return gpdf[in_bounding_box] - - def classes_in_bbox(self, bounding_box: BBox) -> List[str]: - bbox_geojson = self.filter_geojson(self.as_geojson(), bounding_box) + if not include_external_contributions: + include_condition &= gpdf[ + gpdf[RequiredColumns.EXTERNALLY_CONTRIBUTED_DATASET] == False + ] + return gpdf[include_condition] + + def classes_in_bbox( + self, bounding_box: BBox, include_external_contributions: bool + ) -> List[str]: + bbox_geojson = self.filter_geojson( + self.as_geojson(), bounding_box, include_external_contributions + ) unique_labels = [x for x in bbox_geojson.label.unique() if x is not None] return unique_labels @@ -117,7 +129,9 @@ def construct_positive_and_negative_labels( if filter_test: gpdf = gpdf[gpdf[RequiredColumns.IS_TEST] == False] if task.bounding_box is not None: - gpdf = self.filter_geojson(gpdf, task.bounding_box) + gpdf = self.filter_geojson( + gpdf, task.bounding_box, task.include_externally_contributed_labels + ) if len(gpdf) == 0: raise NoDataForBoundingBoxError