Skip to content

Commit c71a47d

Browse files
committed
added heatmap visualizer
1 parent b23a2c4 commit c71a47d

File tree

2 files changed

+122
-0
lines changed

2 files changed

+122
-0
lines changed

imgs/peppers.png

526 KB
Loading

viz_heatmaps.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import pdb
2+
import os
3+
import sys
4+
import tqdm
5+
6+
import numpy as np
7+
import torch
8+
9+
from PIL import Image
10+
from matplotlib import pyplot as pl; pl.ion()
11+
from scipy.ndimage import uniform_filter
12+
smooth = lambda arr: uniform_filter(arr, 3)
13+
14+
def transparent(img, alpha, cmap, **kw):
15+
from matplotlib.colors import Normalize
16+
colored_img = cmap(Normalize(clip=True,**kw)(img))
17+
colored_img[:,:,-1] = alpha
18+
return colored_img
19+
20+
from tools import common
21+
from tools.dataloader import norm_RGB
22+
from nets.patchnet import *
23+
from extract import NonMaxSuppression
24+
25+
26+
if __name__ == '__main__':
27+
import argparse
28+
parser = argparse.ArgumentParser("Visualize the patch detector and descriptor")
29+
30+
parser.add_argument("--img", type=str, default="imgs/brooklyn.png")
31+
parser.add_argument("--resize", type=int, default=512)
32+
parser.add_argument("--out", type=str, default="viz.png")
33+
34+
parser.add_argument("--checkpoint", type=str, required=True, help='network path')
35+
parser.add_argument("--net", type=str, default="", help='network command')
36+
37+
parser.add_argument("--max-kpts", type=int, default=200)
38+
parser.add_argument("--reliability-thr", type=float, default=0.8)
39+
parser.add_argument("--repeatability-thr", type=float, default=0.7)
40+
parser.add_argument("--border", type=int, default=20,help='rm keypoints close to border')
41+
42+
parser.add_argument("--gpu", type=int, nargs='+', required=True, help='-1 for CPU')
43+
parser.add_argument("--dbg", type=str, nargs='+', default=(), help='debug options')
44+
45+
args = parser.parse_args()
46+
args.dbg = set(args.dbg)
47+
48+
iscuda = common.torch_set_gpu(args.gpu)
49+
device = torch.device('cuda' if iscuda else 'cpu')
50+
51+
# create network
52+
checkpoint = torch.load(args.checkpoint, lambda a,b:a)
53+
args.net = args.net or checkpoint['net']
54+
print("\n>> Creating net = " + args.net)
55+
net = eval(args.net)
56+
net.load_state_dict({k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()})
57+
if iscuda: net = net.cuda()
58+
print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )")
59+
60+
img = Image.open(args.img).convert('RGB')
61+
if args.resize: img.thumbnail((args.resize,args.resize))
62+
img = np.asarray(img)
63+
64+
detector = NonMaxSuppression(
65+
rel_thr = args.reliability_thr,
66+
rep_thr = args.repeatability_thr)
67+
68+
with torch.no_grad():
69+
print(">> computing features...")
70+
res = net(imgs=[norm_RGB(img).unsqueeze(0).to(device)])
71+
rela = res.get('reliability')
72+
repe = res.get('repeatability')
73+
kpts = detector(**res).T[:,[1,0]]
74+
kpts = kpts[repe[0][0,0][kpts[:,1],kpts[:,0]].argsort()[-args.max_kpts:]]
75+
76+
fig = pl.figure("viz")
77+
kw = dict(cmap=pl.cm.RdYlGn, vmax=1)
78+
crop = (slice(args.border,-args.border or 1),)*2
79+
80+
if 'reliability' in args.dbg:
81+
82+
ax1 = pl.subplot(131)
83+
pl.imshow(img[crop], cmap=pl.cm.gray)
84+
pl.xticks(()); pl.yticks(())
85+
86+
pl.subplot(132)
87+
pl.imshow(img[crop], cmap=pl.cm.gray, alpha=0)
88+
pl.xticks(()); pl.yticks(())
89+
90+
x,y = kpts[:,0:2].cpu().numpy().T - args.border
91+
pl.plot(x,y,'+',c=(0,1,0),ms=10, scalex=0, scaley=0)
92+
93+
ax1 = pl.subplot(133)
94+
rela = rela[0][0,0].cpu().numpy()
95+
pl.imshow(rela[crop], cmap=pl.cm.RdYlGn, vmax=1, vmin=0.9)
96+
pl.xticks(()); pl.yticks(())
97+
98+
else:
99+
ax1 = pl.subplot(131)
100+
pl.imshow(img[crop], cmap=pl.cm.gray)
101+
pl.xticks(()); pl.yticks(())
102+
103+
x,y = kpts[:,0:2].cpu().numpy().T - args.border
104+
pl.plot(x,y,'+',c=(0,1,0),ms=10, scalex=0, scaley=0)
105+
106+
pl.subplot(132)
107+
pl.imshow(img[crop], cmap=pl.cm.gray)
108+
pl.xticks(()); pl.yticks(())
109+
c = repe[0][0,0].cpu().numpy()
110+
pl.imshow(transparent(smooth(c)[crop], 0.5, vmin=0, **kw))
111+
112+
ax1 = pl.subplot(133)
113+
pl.imshow(img[crop], cmap=pl.cm.gray)
114+
pl.xticks(()); pl.yticks(())
115+
rela = rela[0][0,0].cpu().numpy()
116+
pl.imshow(transparent(rela[crop], 0.5, vmin=0.9, **kw))
117+
118+
pl.gcf().set_size_inches(9, 2.73)
119+
pl.subplots_adjust(0.01,0.01,0.99,0.99,hspace=0.1)
120+
pl.savefig(args.out)
121+
pdb.set_trace()
122+

0 commit comments

Comments
 (0)