Commit

change to accept target_size
SangbumChoi committed Sep 11, 2024
1 parent 12a7b8c commit 220859d
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions src/transformers/models/vitpose/image_processing_vitpose.py
@@ -591,7 +591,9 @@ def keypoints_from_heatmaps(

return preds, scores

def post_process_pose_estimation(self, outputs, boxes, kernel_size=11):
def post_process_pose_estimation(
self, outputs, boxes, kernel_size=11, target_sizes: Union[TensorType, List[Tuple]] = None
):
"""
Transform the heatmaps into keypoint predictions and transform them back to the image.
@@ -603,17 +605,31 @@ def post_process_pose_estimation(self, outputs, boxes, kernel_size=11):
box coordinates in COCO format (top_left_x, top_left_y, width, height).
kernel_size (`int`, *optional*, defaults to 11):
Gaussian kernel size (K) for modulation.
target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
`(height, width)` of each image in the batch. If unset, the boxes are assumed to already be in absolute pixel coordinates.
Returns:
`List[List[Dict]]`: A list of dictionaries, each dictionary containing the keypoints and boxes for an image
in the batch as predicted by the model.
"""

# First compute centers and scales for each bounding box
batch_size = len(outputs.heatmaps)

if target_sizes is not None:
if batch_size != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
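# e.g. with a batch of two heatmaps, `target_sizes` must hold exactly two
# (height, width) pairs, such as torch.tensor([[480, 640], [768, 1024]])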

centers = np.zeros((batch_size, 2), dtype=np.float32)
scales = np.zeros((batch_size, 2), dtype=np.float32)
flattened_boxes = list(itertools.chain(*boxes))
for i in range(batch_size):
if target_sizes is not None:
img_h, img_w = target_sizes[i][0], target_sizes[i][1]
scale_fct = np.array([img_w, img_h, img_w, img_h])
flattened_boxes[i] = flattened_boxes[i] * scale_fct
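# e.g. a normalized COCO box [0.1, 0.2, 0.5, 0.5] on a 480x640 (height x width)
# image becomes [64.0, 96.0, 320.0, 240.0] in absolute pixel coordinates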
width, height = self.size["width"], self.size["height"]
center, scale = box_to_center_and_scale(flattened_boxes[i], image_width=width, image_height=height)
centers[i, :] = center
@@ -630,7 +646,6 @@ def post_process_pose_estimation(self, outputs, boxes, kernel_size=11):
all_boxes[:, 4] = np.prod(scales * 200.0, axis=1)
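# `scales` is stored in units of 200 pixels (the standard used by
# box_to_center_and_scale), so scales * 200.0 recovers the box width and
# height in pixels; their product is the box area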

poses = torch.Tensor(all_preds)

bboxes_xyxy = torch.Tensor(coco_to_pascal_voc(all_boxes))

results: List[List[Dict[str, torch.Tensor]]] = []
@@ -646,4 +661,4 @@ def post_process_pose_estimation(self, outputs, boxes, kernel_size=11):
batch_results.append(pose_result)
results.append(batch_results)

return results
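For context, a minimal usage sketch of the new `target_sizes` parameter. Everything below is illustrative: `image_processor` is assumed to be an instance of this image processor, `outputs` an already-computed ViTPose forward pass whose `outputs.heatmaps` has batch size 2, and the box and size values are made up.

import torch

# Two images, one normalized COCO-format box (top_left_x, top_left_y, width, height) each.
boxes = [[[0.1, 0.2, 0.5, 0.5]], [[0.0, 0.0, 1.0, 1.0]]]

# Original (height, width) of each image. len(target_sizes) must match the
# heatmap batch dimension, otherwise a ValueError is raised.
target_sizes = torch.tensor([[480, 640], [768, 1024]])

results = image_processor.post_process_pose_estimation(
    outputs, boxes, target_sizes=target_sizes
)
first_person = results[0][0]  # dict holding the keypoints and box for one detection

With `target_sizes` set, each normalized box is rescaled to absolute pixel coordinates (multiplied by `[width, height, width, height]`) before the usual center/scale decoding; without it, boxes are used as-is.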
