Skip to content

Commit e3d0b68

Browse files
committed
add visualization of keypoint localization
1 parent bbd2583 commit e3d0b68

File tree

5 files changed

+544
-18
lines changed

5 files changed

+544
-18
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ If setup correctly, the output will look like
9595

9696
![cat](./assets/cat.png)
9797

98+
### Visualization of the voting procedure
99+
100+
We add a jupyter notebook [visualization.ipynb](./visualization.ipynb) for the keypoint detection pipeline of PVNet, aiming to make it easier for readers to understand our paper. Thanks for Kudlur, M 's suggestion.
101+
98102
## Training and testing
99103

100104
### Training on the LINEMOD

lib/ransac_voting_gpu_layer/ransac_voting_gpu.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,6 +980,60 @@ def ransac_motion_voting(mask, vertex):
980980

981981
return torch.cat(pts,0)
982982

983+
def generate_hypothesis(mask, vertex, round_hyp_num, inlier_thresh=0.999, confidence=0.99, max_iter=20,
984+
min_num=5, max_num=30000):
985+
'''
986+
:param mask: [b,h,w]
987+
:param vertex: [b,h,w,vn,2]
988+
:param round_hyp_num:
989+
:param inlier_thresh:
990+
:return: [b,vn,2]
991+
'''
992+
b, h, w, vn, _ = vertex.shape
993+
batch_hyp_pts = []
994+
batch_hyp_counts = []
995+
for bi in range(b):
996+
hyp_num = 0
997+
cur_mask = (mask[bi]).byte()
998+
foreground_num = torch.sum(cur_mask)
999+
1000+
# if too few points, just skip it
1001+
if foreground_num < min_num:
1002+
win_pts = torch.zeros([1, vn, 2], dtype=torch.float32, device=mask.device)
1003+
batch_win_pts.append(win_pts) # [1,vn,2]
1004+
continue
1005+
1006+
# if too many inliers, we randomly down sample it
1007+
if foreground_num > max_num:
1008+
selection = torch.zeros(cur_mask.shape, dtype=torch.float32, device=mask.device).uniform_(0, 1)
1009+
selected_mask = (selection < (max_num / foreground_num.float()))
1010+
cur_mask *= selected_mask
1011+
1012+
coords = torch.nonzero(cur_mask).float() # [tn,2]
1013+
coords = coords[:, [1, 0]]
1014+
direct = vertex[bi].masked_select(torch.unsqueeze(torch.unsqueeze(cur_mask, 2), 3)) # [tn,vn,2]
1015+
direct = direct.view([coords.shape[0], vn, 2])
1016+
tn = coords.shape[0]
1017+
idxs = torch.zeros([round_hyp_num, vn, 2], dtype=torch.int32, device=mask.device).random_(0, direct.shape[0])
1018+
all_win_ratio = torch.zeros([vn], dtype=torch.float32, device=mask.device)
1019+
all_win_pts = torch.zeros([vn, 2], dtype=torch.float32, device=mask.device)
1020+
1021+
# generate hypothesis
1022+
cur_hyp_pts = ransac_voting.generate_hypothesis(direct, coords, idxs) # [hn,vn,2]
1023+
1024+
# voting for hypothesis
1025+
cur_inlier = torch.zeros([round_hyp_num, vn, tn], dtype=torch.uint8, device=mask.device)
1026+
ransac_voting.voting_for_hypothesis(direct, coords, cur_hyp_pts, cur_inlier, inlier_thresh) # [hn,vn,tn]
1027+
1028+
# find max
1029+
cur_inlier_counts = torch.sum(cur_inlier, 2) # [hn,vn]
1030+
1031+
batch_hyp_pts.append(cur_hyp_pts)
1032+
batch_hyp_counts.append(cur_inlier_counts)
1033+
1034+
return torch.stack(batch_hyp_pts), torch.stack(batch_hyp_counts)
1035+
1036+
9831037

9841038
if __name__=="__main__":
9851039
from lib.datasets.linemod_dataset import LineModDatasetRealAug,VotingType

lib/utils/draw_utils.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -183,25 +183,26 @@ def visualize_voting_ellipse(rgb,mean,var,target,save=False, save_fn=None):
183183
:return:
184184
'''
185185
b,vn,_=mean.shape
186+
yellow=np.array([1.0,0.0,0.0])
187+
red=np.asarray([1.0,1.0,0.0])
188+
num=5
186189
for bi in range(b):
187-
_, ax = plt.subplots(1)
188-
189190
for vi in range(vn):
191+
_, ax = plt.subplots(1, figsize=(10, 8))
190192
cov=var[bi,vi]
191193
w,v=np.linalg.eig(cov)
192-
w*=50
193-
elp=patches.Ellipse(mean[bi,vi],w[0],w[1],np.arctan2(v[1,0],v[0,0])/np.pi*180,fill=False)
194-
ax.add_patch(elp)
195-
196-
ax.plot(target[bi,:,0],target[bi,:,1],'*')
197-
ax.scatter(mean[bi,:,0],mean[bi,:,1],c=np.arange(vn))
198-
ax.imshow(rgb[bi])
199-
if save:
200-
plt.savefig(save_fn.format(bi))
201-
else:
202-
plt.show()
203-
plt.close()
194+
for k in range(num):
195+
size=w*k*3
196+
elp = patches.Ellipse(mean[bi, vi], size[0], size[1], np.arctan2(v[1, 0], v[0, 0]) / np.pi * 180, fill=False, color=yellow/num*(num-k)+red/num*k)
197+
ax.add_patch(elp)
204198

199+
ax.scatter(mean[bi,vi,0],mean[bi,vi,1], marker='*', c=[yellow], s=8)
200+
ax.imshow(rgb[bi])
201+
if save:
202+
plt.savefig(save_fn.format(bi))
203+
else:
204+
plt.show()
205+
plt.close()
205206

206207

207208

tools/demo.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
sys.path.append('.')
44
sys.path.append('..')
55
from lib.networks.model_repository import *
6-
from lib.utils.arg_utils import args
76
from lib.utils.net_utils import smooth_l1_loss, load_model, compute_precision_recall
87
import torch
98
from lib.ransac_voting_gpu_layer.ransac_voting_gpu import ransac_voting_layer_v3
@@ -22,9 +21,9 @@
2221
import numpy as np
2322
import matplotlib.pyplot as plt
2423

25-
with open(args.cfg_file, 'r') as f:
24+
with open('configs/linemod_train.json', 'r') as f:
2625
train_cfg = json.load(f)
27-
train_cfg['model_name'] = '{}_{}'.format(args.linemod_cls, train_cfg['model_name'])
26+
train_cfg['model_name'] = '{}_{}'.format('cat', train_cfg['model_name'])
2827

2928
vote_num = 9
3029

@@ -104,18 +103,74 @@ def read_data():
104103
return data, points_3d, bb8_3d
105104

106105

106+
def visualize_mask(mask):
107+
plt.imshow(mask[0])
108+
plt.show()
109+
110+
111+
def visualize_vertex(vertex, vertex_weights):
112+
vertex = vertex * vertex_weights
113+
for i in range(9):
114+
_, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 8))
115+
ax1.imshow(vertex[0, 2*i])
116+
ax2.imshow(vertex[0, 2*i+1])
117+
plt.show()
118+
119+
120+
def visualize_hypothesis(image, seg_pred, vertex_pred, corner_target):
121+
from lib.ransac_voting_gpu_layer.ransac_voting_gpu import generate_hypothesis
122+
123+
vertex_pred = vertex_pred.permute(0, 2, 3, 1)
124+
b, h, w, vn_2 = vertex_pred.shape
125+
vertex_pred = vertex_pred.view(b, h, w, vn_2 // 2, 2)
126+
mask = torch.argmax(seg_pred, 1)
127+
hyp, hyp_counts = generate_hypothesis(mask, vertex_pred, 1024, inlier_thresh=0.99)
128+
129+
image = imagenet_to_uint8(image.detach().cpu().numpy())
130+
hyp = hyp.detach().cpu().numpy()
131+
hyp_counts = hyp_counts.detach().cpu().numpy()
132+
133+
from lib.utils.draw_utils import visualize_hypothesis
134+
visualize_hypothesis(image, hyp, hyp_counts, corner_target)
135+
136+
137+
def visualize_voting_ellipse(image, seg_pred, vertex_pred, corner_target):
138+
from lib.ransac_voting_gpu_layer.ransac_voting_gpu import estimate_voting_distribution_with_mean
139+
140+
vertex_pred = vertex_pred.permute(0, 2, 3, 1)
141+
b, h, w, vn_2 = vertex_pred.shape
142+
vertex_pred = vertex_pred.view(b, h, w, vn_2//2, 2)
143+
mask = torch.argmax(seg_pred, 1)
144+
mean = ransac_voting_layer_v3(mask, vertex_pred, 512, inlier_thresh=0.99)
145+
mean, var = estimate_voting_distribution_with_mean(mask, vertex_pred, mean)
146+
147+
image = imagenet_to_uint8(image.detach().cpu().numpy())
148+
mean = mean.detach().cpu().numpy()
149+
var = var.detach().cpu().numpy()
150+
corner_target = corner_target.detach().cpu().numpy()
151+
152+
from lib.utils.draw_utils import visualize_voting_ellipse
153+
visualize_voting_ellipse(image, mean, var, corner_target)
154+
155+
156+
107157
def demo():
108158
net = Resnet18_8s(ver_dim=vote_num * 2, seg_dim=2)
109159
net = NetWrapper(net).cuda()
110160
net = DataParallel(net)
111161

112162
optimizer = optim.Adam(net.parameters(), lr=train_cfg['lr'])
113163
model_dir = os.path.join(cfg.MODEL_DIR, "cat_demo")
114-
load_model(net.module.net, optimizer, model_dir, args.load_epoch)
164+
load_model(net.module.net, optimizer, model_dir, -1)
115165
data, points_3d, bb8_3d = read_data()
116166
image, mask, vertex, vertex_weights, pose, corner_target = [d.unsqueeze(0).cuda() for d in data]
117167
seg_pred, vertex_pred, loss_seg, loss_vertex, precision, recall = net(image, mask, vertex, vertex_weights)
118168

169+
# visualize_mask(mask)
170+
# visualize_vertex(vertex, vertex_weights)
171+
# visualize_hypothesis(image, seg_pred, vertex_pred, corner_target)
172+
# visualize_voting_ellipse(image, seg_pred, vertex_pred, corner_target)
173+
119174
eval_net = DataParallel(EvalWrapper().cuda())
120175
corner_pred = eval_net(seg_pred, vertex_pred).cpu().detach().numpy()[0]
121176
camera_matrix = np.array([[572.4114, 0., 325.2611],

visualization.ipynb

Lines changed: 412 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)