Skip to content

Commit

Permalink
Merge pull request #109 from point-cloud-radar/108-skip-files-with-no…
Browse files Browse the repository at this point in the history
…t-enough-neighbours
  • Loading branch information
lyashevska authored Jul 3, 2023
2 parents cc78cce + c9a621d commit f95716a
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
5 changes: 5 additions & 0 deletions bird_cloud_gnn/radar_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ def _process_data(self, data, origin=""):
)
].index
).reset_index(drop=True)
if len(data) < self.num_neighbours:
print(
f"Warning: There are not enough points in {origin} to form neighbourhood of size {self.num_neighbours}"
)
return

data_xyz = data[xyz]
# remove the special features so they can be generated later
Expand Down
45 changes: 45 additions & 0 deletions tests/test_radar_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,48 @@ def test_no_graphs(tmp_path):
num_neighbours=8,
)
assert len(dataset) == 0


def test_not_enough_points_in_neighbourhood(tmp_path):
with open(
tmp_path / "two_clusters_one_nan_one_labeled.csv", "w", encoding="utf-8"
) as f:
f.write(
"""range,x,y,z,f1,target
10000,1,1,1,1,
10000,0,1,1,2,
10000,1,0,1,3,
10000,1,1,0,4,
10000,5,5,5,5,0
10000,6,5,5,6,1
10000,5,6,5,7,1
10000,5,5,6,8,1"""
)

dataset = RadarDataset(
tmp_path,
["x", "y", "z", "f1"],
"target",
num_neighbours=8,
max_edge_distance=2.0,
max_poi_per_label=10,
)
assert len(dataset) == 4
dataset = RadarDataset(
tmp_path,
["x", "y", "z", "f1"],
"target",
num_neighbours=9,
max_edge_distance=2.0,
max_poi_per_label=10,
)
assert len(dataset) == 0
dataset = RadarDataset(
pd.read_csv(os.path.join(tmp_path, "two_clusters_one_nan_one_labeled.csv")),
["x", "y", "z", "f1"],
"target",
num_neighbours=9,
max_edge_distance=2.0,
max_poi_per_label=10,
)
assert len(dataset) == 0

0 comments on commit f95716a

Please sign in to comment.