|
| 1 | +import pdb |
| 2 | +import os |
| 3 | +import sys |
| 4 | +import tqdm |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import torch |
| 8 | + |
| 9 | +from PIL import Image |
| 10 | +from matplotlib import pyplot as pl; pl.ion() |
| 11 | +from scipy.ndimage import uniform_filter |
| 12 | +smooth = lambda arr: uniform_filter(arr, 3) |
| 13 | + |
| 14 | +def transparent(img, alpha, cmap, **kw): |
| 15 | + from matplotlib.colors import Normalize |
| 16 | + colored_img = cmap(Normalize(clip=True,**kw)(img)) |
| 17 | + colored_img[:,:,-1] = alpha |
| 18 | + return colored_img |
| 19 | + |
| 20 | +from tools import common |
| 21 | +from tools.dataloader import norm_RGB |
| 22 | +from nets.patchnet import * |
| 23 | +from extract import NonMaxSuppression |
| 24 | + |
| 25 | + |
| 26 | +if __name__ == '__main__': |
| 27 | + import argparse |
| 28 | + parser = argparse.ArgumentParser("Visualize the patch detector and descriptor") |
| 29 | + |
| 30 | + parser.add_argument("--img", type=str, default="imgs/brooklyn.png") |
| 31 | + parser.add_argument("--resize", type=int, default=512) |
| 32 | + parser.add_argument("--out", type=str, default="viz.png") |
| 33 | + |
| 34 | + parser.add_argument("--checkpoint", type=str, required=True, help='network path') |
| 35 | + parser.add_argument("--net", type=str, default="", help='network command') |
| 36 | + |
| 37 | + parser.add_argument("--max-kpts", type=int, default=200) |
| 38 | + parser.add_argument("--reliability-thr", type=float, default=0.8) |
| 39 | + parser.add_argument("--repeatability-thr", type=float, default=0.7) |
| 40 | + parser.add_argument("--border", type=int, default=20,help='rm keypoints close to border') |
| 41 | + |
| 42 | + parser.add_argument("--gpu", type=int, nargs='+', required=True, help='-1 for CPU') |
| 43 | + parser.add_argument("--dbg", type=str, nargs='+', default=(), help='debug options') |
| 44 | + |
| 45 | + args = parser.parse_args() |
| 46 | + args.dbg = set(args.dbg) |
| 47 | + |
| 48 | + iscuda = common.torch_set_gpu(args.gpu) |
| 49 | + device = torch.device('cuda' if iscuda else 'cpu') |
| 50 | + |
| 51 | + # create network |
| 52 | + checkpoint = torch.load(args.checkpoint, lambda a,b:a) |
| 53 | + args.net = args.net or checkpoint['net'] |
| 54 | + print("\n>> Creating net = " + args.net) |
| 55 | + net = eval(args.net) |
| 56 | + net.load_state_dict({k.replace('module.',''):v for k,v in checkpoint['state_dict'].items()}) |
| 57 | + if iscuda: net = net.cuda() |
| 58 | + print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )") |
| 59 | + |
| 60 | + img = Image.open(args.img).convert('RGB') |
| 61 | + if args.resize: img.thumbnail((args.resize,args.resize)) |
| 62 | + img = np.asarray(img) |
| 63 | + |
| 64 | + detector = NonMaxSuppression( |
| 65 | + rel_thr = args.reliability_thr, |
| 66 | + rep_thr = args.repeatability_thr) |
| 67 | + |
| 68 | + with torch.no_grad(): |
| 69 | + print(">> computing features...") |
| 70 | + res = net(imgs=[norm_RGB(img).unsqueeze(0).to(device)]) |
| 71 | + rela = res.get('reliability') |
| 72 | + repe = res.get('repeatability') |
| 73 | + kpts = detector(**res).T[:,[1,0]] |
| 74 | + kpts = kpts[repe[0][0,0][kpts[:,1],kpts[:,0]].argsort()[-args.max_kpts:]] |
| 75 | + |
| 76 | + fig = pl.figure("viz") |
| 77 | + kw = dict(cmap=pl.cm.RdYlGn, vmax=1) |
| 78 | + crop = (slice(args.border,-args.border or 1),)*2 |
| 79 | + |
| 80 | + if 'reliability' in args.dbg: |
| 81 | + |
| 82 | + ax1 = pl.subplot(131) |
| 83 | + pl.imshow(img[crop], cmap=pl.cm.gray) |
| 84 | + pl.xticks(()); pl.yticks(()) |
| 85 | + |
| 86 | + pl.subplot(132) |
| 87 | + pl.imshow(img[crop], cmap=pl.cm.gray, alpha=0) |
| 88 | + pl.xticks(()); pl.yticks(()) |
| 89 | + |
| 90 | + x,y = kpts[:,0:2].cpu().numpy().T - args.border |
| 91 | + pl.plot(x,y,'+',c=(0,1,0),ms=10, scalex=0, scaley=0) |
| 92 | + |
| 93 | + ax1 = pl.subplot(133) |
| 94 | + rela = rela[0][0,0].cpu().numpy() |
| 95 | + pl.imshow(rela[crop], cmap=pl.cm.RdYlGn, vmax=1, vmin=0.9) |
| 96 | + pl.xticks(()); pl.yticks(()) |
| 97 | + |
| 98 | + else: |
| 99 | + ax1 = pl.subplot(131) |
| 100 | + pl.imshow(img[crop], cmap=pl.cm.gray) |
| 101 | + pl.xticks(()); pl.yticks(()) |
| 102 | + |
| 103 | + x,y = kpts[:,0:2].cpu().numpy().T - args.border |
| 104 | + pl.plot(x,y,'+',c=(0,1,0),ms=10, scalex=0, scaley=0) |
| 105 | + |
| 106 | + pl.subplot(132) |
| 107 | + pl.imshow(img[crop], cmap=pl.cm.gray) |
| 108 | + pl.xticks(()); pl.yticks(()) |
| 109 | + c = repe[0][0,0].cpu().numpy() |
| 110 | + pl.imshow(transparent(smooth(c)[crop], 0.5, vmin=0, **kw)) |
| 111 | + |
| 112 | + ax1 = pl.subplot(133) |
| 113 | + pl.imshow(img[crop], cmap=pl.cm.gray) |
| 114 | + pl.xticks(()); pl.yticks(()) |
| 115 | + rela = rela[0][0,0].cpu().numpy() |
| 116 | + pl.imshow(transparent(rela[crop], 0.5, vmin=0.9, **kw)) |
| 117 | + |
| 118 | + pl.gcf().set_size_inches(9, 2.73) |
| 119 | + pl.subplots_adjust(0.01,0.01,0.99,0.99,hspace=0.1) |
| 120 | + pl.savefig(args.out) |
| 121 | + pdb.set_trace() |
| 122 | + |
0 commit comments