Commit: add batch tests
SangbumChoi committed Sep 4, 2024
1 parent 2f3d6df commit 3b04bc7
Showing 1 changed file with 43 additions and 2 deletions.
tests/models/vitpose/test_modeling_vitpose.py: 45 changes (43 additions, 2 deletions)
@@ -260,7 +260,7 @@ def test_inference_pose_estimation(self):

        assert torch.allclose(heatmaps[0, 0, :3, :3], expected_slice, atol=1e-4)

-        pose_results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)
+        pose_results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)[0]

        expected_bbox = torch.tensor([439.3250, 226.6150, 438.9719, 226.4776, 22320.4219, 0.0000]).to(torch_device)
        expected_keypoints = torch.tensor(
@@ -277,4 +277,45 @@ def test_inference_pose_estimation(self):

    @slow
    def test_batched_inference(self):
-        raise NotImplementedError("To do")
+        image_processor = self.default_image_processor
+        # TODO update organization
+        model = ViTPoseForPoseEstimation.from_pretrained("nielsr/vitpose-base-simple").to(torch_device)
+
+        image = prepare_img()
+        boxes = [  # two (x, y, w, h) person boxes per image
+            [[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]],
+            [[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]],
+        ]
+
+        inputs = image_processor(images=[image, image], boxes=boxes, return_tensors="pt").to(torch_device)
+
+        outputs = model(**inputs)
+        heatmaps = outputs.heatmaps
+
+        assert heatmaps.shape == (4, 17, 64, 48)  # 2 images x 2 boxes each = 4 person crops
+
+        expected_slice = torch.tensor(
+            [
+                [9.9330e-06, 9.9330e-06, 9.9330e-06],
+                [9.9330e-06, 9.9330e-06, 9.9330e-06],
+                [9.9330e-06, 9.9330e-06, 9.9330e-06],
+            ]
+        ).to(torch_device)
+
+        assert torch.allclose(heatmaps[0, 0, :3, :3], expected_slice, atol=1e-4)
+
+        pose_results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)
+
+        expected_bbox = torch.tensor([439.3250, 226.6150, 438.9719, 226.4776, 22320.4219, 0.0000]).to(torch_device)
+        expected_keypoints = torch.tensor(
+            [
+                [3.9813e02, 1.8184e02, 8.7529e-01],
+                [3.9828e02, 1.7981e02, 8.4315e-01],
+                [3.9596e02, 1.7948e02, 9.2678e-01],
+            ]
+        ).to(torch_device)
+
+        self.assertEqual(len(pose_results), 2)
+        self.assertEqual(len(pose_results[0]), 2)
+        self.assertTrue(torch.allclose(pose_results[0][0]["bbox"], expected_bbox, atol=1e-4))
+        self.assertTrue(torch.allclose(pose_results[0][0]["keypoints"], expected_keypoints, atol=1e-4))
