-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathheat_map_utils.py
31 lines (25 loc) · 912 Bytes
/
heat_map_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.nn as nn
def find_keypoints_max(heatmaps):
    """Locate the peak of every heatmap channel.

    :param heatmaps: tensor of shape C x H x W (H and W need not be equal).
    :return: tensor of shape C x 3. Row c is (u, v, value), where u is the
        column index and v the row index of the maximum of heatmaps[c], and
        value is that maximum.
    """
    num_maps, height, width = heatmaps.shape
    # reshape (not view) accepts non-contiguous input; explicit sizes instead
    # of -1 keep an empty (C == 0) input working.
    heatmaps_flat = heatmaps.reshape(num_maps, height * width)
    max_val, max_ind = heatmaps_flat.max(1)
    # Row-major flattening gives flat_index = v * W + u, so BOTH coordinates
    # are derived from the width (the old code divided by the height).
    # Integer arithmetic keeps u and v exact and mutually consistent; turning
    # the index into float32 first can round it across a row boundary.
    max_u = (max_ind % width).float()
    max_v = (max_ind // width).float()
    return torch.cat((max_u.view(-1, 1), max_v.view(-1, 1), max_val.view(-1, 1)), 1)
def compute_uv_from_heatmaps(hm, resize_dim):
    """Upsample keypoint heatmaps and return the peak of each map.

    :param hm: tensor of shape B x K x H x W.
    :param resize_dim: int side length; every map is bilinearly resized to
        resize_dim x resize_dim before its peak is searched.
    :return: tensor of shape B x K x 3 holding the (u, v, value) rows returned
        by find_keypoints_max, with u and v in resize_dim pixel coordinates.
    """
    # align_corners is passed explicitly: the default for bilinear mode changed
    # after PyTorch 0.3.1, so relying on it makes the sampling grid
    # version-dependent. False matches the current default.
    resized_hm = nn.functional.interpolate(
        hm, size=resize_dim, mode='bilinear', align_corners=False)
    # Fold batch and keypoint axes together: (B * K) x resize_dim x resize_dim.
    resized_hm = resized_hm.view(-1, resize_dim, resize_dim)
    uv_confidence = find_keypoints_max(resized_hm)  # (B * K) x 3
    return uv_confidence.view(-1, hm.size(1), 3)