From c9a621d84284ebae39cba46f2cb72d77909f9f63 Mon Sep 17 00:00:00 2001 From: Abel Soares Siqueira Date: Tue, 27 Jun 2023 12:03:06 +0200 Subject: [PATCH] Skip files that don't have enough points to create a neighbourhood --- bird_cloud_gnn/radar_dataset.py | 5 ++++ tests/test_radar_dataset.py | 45 +++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/bird_cloud_gnn/radar_dataset.py b/bird_cloud_gnn/radar_dataset.py index 45927af..10ee60f 100644 --- a/bird_cloud_gnn/radar_dataset.py +++ b/bird_cloud_gnn/radar_dataset.py @@ -117,6 +117,11 @@ def _process_data(self, data, origin=""): ) ].index ).reset_index(drop=True) + if len(data) < self.num_neighbours: + print( + f"Warning: There are not enough points in {origin} to form neighbourhood of size {self.num_neighbours}" + ) + return data_xyz = data[xyz] # remove the special features so they can be generated later diff --git a/tests/test_radar_dataset.py b/tests/test_radar_dataset.py index 9cb04a4..0000c0d 100644 --- a/tests/test_radar_dataset.py +++ b/tests/test_radar_dataset.py @@ -278,3 +278,48 @@ def test_no_graphs(tmp_path): num_neighbours=8, ) assert len(dataset) == 0 + + +def test_not_enough_points_in_neighbourhood(tmp_path): + with open( + tmp_path / "two_clusters_one_nan_one_labeled.csv", "w", encoding="utf-8" + ) as f: + f.write( + """range,x,y,z,f1,target +10000,1,1,1,1, +10000,0,1,1,2, +10000,1,0,1,3, +10000,1,1,0,4, +10000,5,5,5,5,0 +10000,6,5,5,6,1 +10000,5,6,5,7,1 +10000,5,5,6,8,1""" + ) + + dataset = RadarDataset( + tmp_path, + ["x", "y", "z", "f1"], + "target", + num_neighbours=8, + max_edge_distance=2.0, + max_poi_per_label=10, + ) + assert len(dataset) == 4 + dataset = RadarDataset( + tmp_path, + ["x", "y", "z", "f1"], + "target", + num_neighbours=9, + max_edge_distance=2.0, + max_poi_per_label=10, + ) + assert len(dataset) == 0 + dataset = RadarDataset( + pd.read_csv(os.path.join(tmp_path, "two_clusters_one_nan_one_labeled.csv")), + ["x", "y", "z", "f1"], + "target", + num_neighbours=9, + max_edge_distance=2.0, + max_poi_per_label=10, + ) + assert len(dataset) == 0