@@ -146,23 +146,64 @@ def test_anc2box_autoanchor(inference_v7_cfg: Config):
146146
147147
148148def test_bbox_nms ():
149- cls_dist = tensor (
150- [[[0.1 , 0.7 , 0.2 ], [0.6 , 0.3 , 0.1 ]], [[0.4 , 0.4 , 0.2 ], [0.5 , 0.4 , 0.1 ]]] # Example class distribution
149+ cls_dist = torch .tensor (
150+ [
151+ [
152+ [0.7 , 0.1 , 0.2 ], # High confidence, class 0
153+ [0.3 , 0.6 , 0.1 ], # High confidence, class 1
154+ [- 3.0 , - 2.0 , - 1.0 ], # low confidence, class 2
155+ [0.6 , 0.2 , 0.2 ], # Medium confidence, class 0
156+ ],
157+ [
158+ [0.55 , 0.25 , 0.2 ], # Medium confidence, class 0
159+ [- 4.0 , - 0.5 , - 2.0 ], # low confidence, class 1
160+ [0.15 , 0.2 , 0.65 ], # Medium confidence, class 2
161+ [0.8 , 0.1 , 0.1 ], # High confidence, class 0
162+ ],
163+ ],
164+ dtype = float32 ,
151165 )
152- bbox = tensor (
153- [[[50 , 50 , 100 , 100 ], [60 , 60 , 110 , 110 ]], [[40 , 40 , 90 , 90 ], [70 , 70 , 120 , 120 ]]], # Example bounding boxes
166+
167+ bbox = torch .tensor (
168+ [
169+ [
170+ [0 , 0 , 160 , 120 ], # Overlaps with box 4
171+ [160 , 120 , 320 , 240 ],
172+ [0 , 120 , 160 , 240 ],
173+ [16 , 12 , 176 , 132 ],
174+ ],
175+ [
176+ [0 , 0 , 160 , 120 ], # Overlaps with box 4
177+ [160 , 120 , 320 , 240 ],
178+ [0 , 120 , 160 , 240 ],
179+ [16 , 12 , 176 , 132 ],
180+ ],
181+ ],
154182 dtype = float32 ,
155183 )
184+
156185 nms_cfg = NMSConfig (min_confidence = 0.5 , min_iou = 0.5 )
157186
158- expected_output = [
159- tensor (
187+ # Batch 1:
188+ # - box 1 is kept with class 0 as it has a higher confidence than box 4 i.e. box 4 is filtered out
189+ # - box 2 is kept with class 1
190+ # - box 3 is rejected by the confidence filter
191+ # Batch 2:
192+ # - box 4 is kept with class 0 as it has a higher confidence than box 1 i.e. box 1 is filtered out
193+ # - box 2 is rejected by the confidence filter
194+ # - box 3 is kept with class 2
195+ expected_output = torch .tensor (
196+ [
160197 [
161- [1.0000 , 50.0000 , 50.0000 , 100.0000 , 100.0000 , 0.6682 ],
162- [0.0000 , 60.0000 , 60.0000 , 110.0000 , 110.0000 , 0.6457 ],
163- ]
164- )
165- ]
198+ [0.0 , 0.0 , 0.0 , 160.0 , 120.0 , 0.6682 ],
199+ [1.0 , 160.0 , 120.0 , 320.0 , 240.0 , 0.6457 ],
200+ ],
201+ [
202+ [0.0 , 16.0 , 12.0 , 176.0 , 132.0 , 0.6900 ],
203+ [2.0 , 0.0 , 120.0 , 160.0 , 240.0 , 0.6570 ],
204+ ],
205+ ]
206+ )
166207
167208 output = bbox_nms (cls_dist , bbox , nms_cfg )
168209
0 commit comments