Skip to content

Commit 2522f72

Browse files
committed
✅ [Pass] test in multiclass label&dynamic shape
1 parent 3092710 commit 2522f72

File tree

3 files changed

+14
-6
lines changed

3 files changed

+14
-6
lines changed

tests/test_tools/test_data_augmentation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_mosaic():
5454

5555
# Mock parent with image_size and get_more_data method
5656
class MockParent:
57-
image_size = (100, 100)
57+
base_size = 100
5858

5959
def get_more_data(self, num_images):
6060
return [(img, boxes) for _ in range(num_images)]

tests/test_tools/test_data_loader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@ def test_training_data_loader_correctness(train_dataloader: DataLoader):
4343
def test_validation_data_loader_correctness(validation_dataloader: DataLoader):
4444
batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
4545
assert batch_size == 4
46-
assert images.shape == (4, 3, 640, 640)
46+
assert images.shape == (4, 3, 512, 768)
4747
assert targets.shape == (4, 18, 5)
4848
assert reverse_tensors.shape == (4, 5)
4949
expected_paths = [
50-
Path("tests/data/images/val/000000151480.jpg"),
5150
Path("tests/data/images/val/000000284106.jpg"),
52-
Path("tests/data/images/val/000000323571.jpg"),
51+
Path("tests/data/images/val/000000151480.jpg"),
5352
Path("tests/data/images/val/000000570456.jpg"),
53+
Path("tests/data/images/val/000000323571.jpg"),
5454
]
5555
assert list(image_paths) == list(expected_paths)
5656

tests/test_utils/test_bounding_box_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def test_bbox_nms():
182182
dtype=float32,
183183
)
184184

185-
nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5)
185+
nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5, max_bbox=400)
186186

187187
# Batch 1:
188188
# - box 1 is kept with class 0 as it has a higher confidence than box 4 i.e. box 4 is filtered out
@@ -197,16 +197,24 @@ def test_bbox_nms():
197197
[
198198
[0.0, 0.0, 0.0, 160.0, 120.0, 0.6682],
199199
[1.0, 160.0, 120.0, 320.0, 240.0, 0.6457],
200+
[0.0, 160.0, 120.0, 320.0, 240.0, 0.5744],
201+
[2.0, 0.0, 0.0, 160.0, 120.0, 0.5498],
202+
[1.0, 16.0, 12.0, 176.0, 132.0, 0.5498],
203+
[2.0, 160.0, 120.0, 320.0, 240.0, 0.5250],
200204
],
201205
[
202206
[0.0, 16.0, 12.0, 176.0, 132.0, 0.6900],
203207
[2.0, 0.0, 120.0, 160.0, 240.0, 0.6570],
208+
[1.0, 0.0, 0.0, 160.0, 120.0, 0.5622],
209+
[2.0, 0.0, 0.0, 160.0, 120.0, 0.5498],
210+
[1.0, 0.0, 120.0, 160.0, 240.0, 0.5498],
211+
[0.0, 0.0, 120.0, 160.0, 240.0, 0.5374],
204212
],
205213
]
206214
)
207215

208216
output = bbox_nms(cls_dist, bbox, nms_cfg)
209-
217+
print(output)
210218
for out, exp in zip(output, expected_output):
211219
assert allclose(out, exp, atol=1e-4), f"Output: {out} Expected: {exp}"
212220

0 commit comments

Comments
 (0)