@@ -185,13 +185,15 @@ def test_bbox_nms():
185185 nms_cfg = NMSConfig (min_confidence = 0.5 , min_iou = 0.5 , max_bbox = 400 )
186186
187187 # Batch 1:
188- # - box 1 is kept with class 0 as it has a higher confidence than box 4 i.e. box 4 is filtered out
189- # - box 2 is kept with class 1
190- # - box 3 is rejected by the confidence filter
188+ # - box 1 is kept with classes 0 and 2 as it overlaps with box 4 and has a higher confidence for classes 0 and 2.
189+ # - box 2 is kept with classes 0, 1, 2 as it does not overlap with any other box.
190+ # - box 3 is rejected by the confidence filter.
191+ # - box 4 is kept with class 1 as it overlaps with box 1 and has a higher confidence for class 1.
191192 # Batch 2:
192- # - box 4 is kept with class 0 as it has a higher confidence than box 1 i.e. box 1 is filtered out
193- # - box 2 is rejected by the confidence filter
194- # - box 3 is kept with class 2
193+ # - box 1 is kept with classes 1 and 2 as it overlaps with box 1 and has a higher confidence for classes 1 and 2.
194+ # - box 2 is rejected by the confidence filter.
195+ # - box 3 is kept with classes 0, 1, 2 as it does not overlap with any other box.
196+ # - box 4 is kept with class 0 as it overlaps with box 1 and has a higher confidence for class 0.
195197 expected_output = torch .tensor (
196198 [
197199 [
@@ -214,7 +216,6 @@ def test_bbox_nms():
214216 )
215217
216218 output = bbox_nms (cls_dist , bbox , nms_cfg )
217- print (output )
218219 for out , exp in zip (output , expected_output ):
219220 assert allclose (out , exp , atol = 1e-4 ), f"Output: { out } Expected: { exp } "
220221
0 commit comments