
Commit 0e5549f: add more described docs

SangbumChoi committed Sep 10, 2024
1 parent 22fc705

Showing 1 changed file with 188 additions and 25 deletions: docs/source/en/model_doc/vitpose.md

The original code can be found [here](https://github.com/ViTAE-Transformer/ViTPose).
- ViTPose is a so-called top-down keypoint detection model. This means that one first uses an object detector, like [RT-DETR](rt-detr), to detect people (or other instances) in an image. Next, ViTPose takes the cropped images as input and predicts the keypoints.

```py
import torch
import requests
import numpy as np
import cv2
import math

from PIL import Image
from transformers import RTDetrImageProcessor, RTDetrForObjectDetection
from transformers import VitPoseImageProcessor, VitPoseForPoseEstimation

url = "http://images.cocodataset.org/val2017/000000000139.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Stage 1. Run the object detector
# The detector can be replaced with any model that returns person bounding boxes
person_image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
person_model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
inputs = person_image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = person_model(**inputs)

results = person_image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.3)
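# `results` is a list with one dictionary per image, holding "scores", "labels" and "boxes"
# (boxes are in Pascal VOC format, i.e. top-left and bottom-right corners)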

def pascal_voc_to_coco(bboxes: np.ndarray) -> np.ndarray:
    """
    Converts bounding boxes from the Pascal VOC format to the COCO format.
    In other words, converts from (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format
    to (top_left_x, top_left_y, width, height).

    Args:
        bboxes (`np.ndarray` of shape `(batch_size, 4)`):
            Bounding boxes in Pascal VOC format.

    Returns:
        `np.ndarray` of shape `(batch_size, 4)` in COCO format.
    """
    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]

    return bboxes
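# e.g. the Pascal VOC box [10.0, 20.0, 30.0, 60.0] becomes the COCO box [10.0, 20.0, 20.0, 40.0]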

# In COCO, label 0 corresponds to the person class
boxes = results[0]["boxes"][results[0]["labels"] == 0]
boxes = [pascal_voc_to_coco(boxes.cpu().numpy())]
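# `boxes` is a list with one entry per image, each of shape (number_of_persons, 4) in COCO format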

image_processor = VitPoseImageProcessor.from_pretrained("nielsr/vitpose-base-simple")
model = VitPoseForPoseEstimation.from_pretrained("nielsr/vitpose-base-simple")

# Stage 2. Run ViTPose
pixel_values = image_processor(image, boxes=boxes, return_tensors="pt").pixel_values

with torch.no_grad():
    outputs = model(pixel_values)

pose_results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)[0]
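# `pose_results` holds one dictionary per detected person; each dictionary's
# "keypoints" entry is an array of (x, y, score) rows, one per predicted keypoint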

for pose_result in pose_results:
    for keypoint in pose_result["keypoints"]:
        x, y, score = keypoint
        print(f"coordinate : [{x}, {y}], score : {score}")

def visualize_keypoints(
    img,
    pose_result,
    skeleton=None,
    kpt_score_thr=0.3,
    pose_kpt_color=None,
    pose_link_color=None,
    radius=4,
    thickness=1,
    show_keypoint_weight=False,
):
    """Draw keypoints and links on an image.

    Args:
        img (np.ndarray): The image to draw poses on. A copy is drawn on and
            returned, so the input array is not modified.
        pose_result (list[kpts]): The poses to draw. Each element kpts is
            a set of K keypoints as a Kx3 numpy.ndarray, where each
            keypoint is represented as x, y, score.
        skeleton (list[tuple], optional): Pairs of keypoint indices defining
            the links to draw. If None, no links will be drawn.
        kpt_score_thr (float, optional): Minimum score of keypoints
            to be shown. Default: 0.3.
        pose_kpt_color (np.ndarray[Nx3], optional): Color of N keypoints. If None,
            the keypoints will not be drawn.
        pose_link_color (np.ndarray[Mx3], optional): Color of M links. If None,
            the links will not be drawn.
        radius (int): Radius of the keypoint circles.
        thickness (int): Thickness of lines.
        show_keypoint_weight (bool): Whether to use the keypoint score as the
            transparency when drawing.
    """
    img = img.copy()
    img_h, img_w, _ = img.shape

    for kpts in pose_result:
        kpts = np.array(kpts, copy=False)

        # draw each point on image
        if pose_kpt_color is not None:
            assert len(pose_kpt_color) == len(kpts)
            for kid, kpt in enumerate(kpts):
                x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
                if kpt_score > kpt_score_thr:
                    color = tuple(int(c) for c in pose_kpt_color[kid])
                    if show_keypoint_weight:
                        # draw on a copy and blend, using the score as opacity
                        img_copy = img.copy()
                        cv2.circle(img_copy, (int(x_coord), int(y_coord)), radius, color, -1)
                        transparency = max(0, min(1, kpt_score))
                        cv2.addWeighted(img_copy, transparency, img, 1 - transparency, 0, dst=img)
                    else:
                        cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)

        # draw links
        if skeleton is not None and pose_link_color is not None:
            assert len(pose_link_color) == len(skeleton)
            for sk_id, sk in enumerate(skeleton):
                pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
                pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
                if (
                    pos1[0] > 0 and pos1[0] < img_w
                    and pos1[1] > 0 and pos1[1] < img_h
                    and pos2[0] > 0 and pos2[0] < img_w
                    and pos2[1] > 0 and pos2[1] < img_h
                    and kpts[sk[0], 2] > kpt_score_thr
                    and kpts[sk[1], 2] > kpt_score_thr
                ):
                    color = tuple(int(c) for c in pose_link_color[sk_id])
                    if show_keypoint_weight:
                        # draw the link as a filled ellipse and blend it in, using
                        # the mean score of the two endpoints as opacity
                        img_copy = img.copy()
                        X = (pos1[0], pos2[0])
                        Y = (pos1[1], pos2[1])
                        mX = np.mean(X)
                        mY = np.mean(Y)
                        length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
                        angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
                        stickwidth = 2
                        polygon = cv2.ellipse2Poly(
                            (int(mX), int(mY)), (int(length / 2), int(stickwidth)), int(angle), 0, 360, 1
                        )
                        cv2.fillConvexPoly(img_copy, polygon, color)
                        transparency = max(0, min(1, 0.5 * (kpts[sk[0], 2] + kpts[sk[1], 2])))
                        cv2.addWeighted(img_copy, transparency, img, 1 - transparency, 0, dst=img)
                    else:
                        cv2.line(img, pos1, pos2, color, thickness=thickness)

    return img

# Note: skeleton and color palette are dataset-specific
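# COCO keypoint indices: 0 nose, 1 left eye, 2 right eye, 3 left ear, 4 right ear,
# 5 left shoulder, 6 right shoulder, 7 left elbow, 8 right elbow, 9 left wrist,
# 10 right wrist, 11 left hip, 12 right hip, 13 left knee, 14 right knee,
# 15 left ankle, 16 right ankle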
skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12],
[5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9],
[8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4],
[3, 5], [4, 6]]

palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102],
[230, 230, 0], [255, 153, 255], [153, 204, 255],
[255, 102, 255], [255, 51, 255], [102, 178, 255],
[51, 153, 255], [255, 153, 153], [255, 102, 102],
[255, 51, 51], [153, 255, 153], [102, 255, 102],
[51, 255, 51], [0, 255, 0], [0, 0, 255],
[255, 0, 0], [255, 255, 255]])

pose_link_color = palette[[
0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16
]]
pose_kpt_color = palette[[
16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0
]]

pose_results = [result["keypoints"] for result in pose_results]

result = visualize_keypoints(np.array(image), pose_results, skeleton=skeleton, kpt_score_thr=0.3,
pose_kpt_color=pose_kpt_color, pose_link_color=pose_link_color,
radius=4, thickness=1)

pose_image = Image.fromarray(result)
pose_image.show()
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/vitpose-coco.jpg" alt="drawing" width="600"/>
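Both stages can also be run on an accelerator. Below is a minimal sketch of the required changes for the detection stage, assuming a CUDA device is available; the ViTPose stage follows the same pattern (move `pixel_values` to the same device before the forward pass).

```py
device = "cuda" if torch.cuda.is_available() else "cpu"

# move both models to the accelerator once, before inference
person_model = person_model.to(device)
model = model.to(device)

# move the inputs to the same device as the model
inputs = person_image_processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = person_model(**inputs)
```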

