From 220859d1eebae98a9f4ddcf47f803d3873950a57 Mon Sep 17 00:00:00 2001
From: sangbumchoi
Date: Wed, 11 Sep 2024 03:21:43 +0000
Subject: [PATCH] change to accept target_size

---
 .../vitpose/image_processing_vitpose.py       | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/vitpose/image_processing_vitpose.py b/src/transformers/models/vitpose/image_processing_vitpose.py
index 33bf39cdd22553..89357f681b88c7 100644
--- a/src/transformers/models/vitpose/image_processing_vitpose.py
+++ b/src/transformers/models/vitpose/image_processing_vitpose.py
@@ -591,7 +591,9 @@ def keypoints_from_heatmaps(
 
         return preds, scores
 
-    def post_process_pose_estimation(self, outputs, boxes, kernel_size=11):
+    def post_process_pose_estimation(
+        self, outputs, boxes, kernel_size=11, target_sizes: Union[TensorType, List[Tuple]] = None
+    ):
         """
         Transform the heatmaps into keypoint predictions and transform them back to the image.
 
@@ -603,6 +605,9 @@ def post_process_pose_estimation(self, outputs, boxes, kernel_size=11):
                 box coordinates in COCO format (top_left_x, top_left_y, width, height).
             kernel_size (`int`, *optional*, defaults to 11):
                 Gaussian kernel size (K) for modulation.
+            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
+                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+                `(height, width)` of each image in the batch. If unset, the boxes are used as-is and are not rescaled.
         Returns:
             `List[List[Dict]]`: A list of dictionaries, each dictionary containing the keypoints and boxes for an image
             in the batch as predicted by the model.
@@ -610,10 +615,21 @@ def post_process_pose_estimation(self, outputs, boxes, kernel_size=11):
         """
         # First compute centers and scales for each bounding box
         batch_size = len(outputs.heatmaps)
+
+        if target_sizes is not None:
+            if batch_size != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
         centers = np.zeros((batch_size, 2), dtype=np.float32)
         scales = np.zeros((batch_size, 2), dtype=np.float32)
         flattened_boxes = list(itertools.chain(*boxes))
         for i in range(batch_size):
+            if target_sizes is not None:
+                img_h, img_w = target_sizes[i][0], target_sizes[i][1]
+                scale_fct = np.array([img_w, img_h, img_w, img_h])
+                flattened_boxes[i] = flattened_boxes[i] * scale_fct
             width, height = self.size["width"], self.size["height"]
             center, scale = box_to_center_and_scale(flattened_boxes[i], image_width=width, image_height=height)
             centers[i, :] = center
@@ -630,7 +646,6 @@ def post_process_pose_estimation(self, outputs, boxes, kernel_size=11):
         all_boxes[:, 4] = np.prod(scales * 200.0, axis=1)
 
         poses = torch.Tensor(all_preds)
-
         bboxes_xyxy = torch.Tensor(coco_to_pascal_voc(all_boxes))
 
         results: List[List[Dict[str, torch.Tensor]]] = []
@@ -646,4 +661,4 @@ def post_process_pose_estimation(self, outputs, boxes, kernel_size=11):
             batch_results.append(pose_result)
             results.append(batch_results)
 
-        return results
+        return results
\ No newline at end of file
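
Below is a minimal usage sketch for the new `target_sizes` argument, not part of the patch itself. The class names, checkpoint name, file path, and result key are assumptions for illustration; only the `target_sizes` keyword of `post_process_pose_estimation` comes from this change. The sketch assumes the input boxes are normalized to [0, 1], which is what multiplying them by `scale_fct` in the patch implies.

# Usage sketch. Everything except the target_sizes argument is an assumption.
import torch
from PIL import Image

from transformers import VitPoseForPoseEstimation, VitPoseImageProcessor  # assumed class names

image = Image.open("person.jpg")          # hypothetical input image
boxes = [[[0.25, 0.10, 0.50, 0.80]]]      # one normalized COCO-format box (x, y, w, h) per image

# assumed checkpoint name
processor = VitPoseImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")

inputs = processor(image, boxes=boxes, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# target_sizes carries each original image size as (height, width); the normalized boxes are
# rescaled to pixel coordinates before the keypoints are mapped back onto the image.
results = processor.post_process_pose_estimation(
    outputs, boxes=boxes, target_sizes=[(image.height, image.width)]
)
keypoints = results[0][0]["keypoints"]  # key name assumed from the docstring ("keypoints and boxes")

Passing the original image size at post-processing time keeps the heatmap decoding independent of the resolution the processor resized to, which is the point of accepting `target_sizes` here.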