-
Notifications
You must be signed in to change notification settings - Fork 384
/
test.py
75 lines (53 loc) · 2.34 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import argparse
import logging
import os
import numpy as np
import scipy.io
import scipy.ndimage
from util.config import load_config
from dataset.factory import create as create_dataset
from dataset.pose_dataset import Batch
from nnet.predict import setup_pose_prediction, extract_cnn_output, argmax_pose_predict
from util import visualize
def _save_scoremaps(out_dir, im_path, scmap, locref, save_locref):
    """Write one image's network outputs to ``.mat`` files in ``out_dir``.

    Files are named after the image's base name without its extension:
    ``<name>.mat`` holds ``scoremaps``. When ``save_locref`` is true,
    ``<name>_locreg.mat`` holds ``locreg_pred``. Both arrays are stored
    as float32.
    """
    raw_name = os.path.splitext(os.path.basename(im_path))[0]
    scipy.io.savemat(os.path.join(out_dir, raw_name + '.mat'),
                     mdict={'scoremaps': scmap.astype('float32')})
    if save_locref:
        # Only build and write the location-refinement file when it is enabled;
        # locref is not touched otherwise.
        locref_fn = os.path.join(out_dir, raw_name + '_locreg' + '.mat')
        scipy.io.savemat(locref_fn,
                         mdict={'locreg_pred': locref.astype('float32')})


def test_net(visualise, cache_scoremaps):
    """Run pose prediction over the whole test set and save the results.

    The configuration comes from ``load_config()``. The dataset is iterated
    once in order, with shuffling disabled and test mode enabled. For every
    image, one pose is predicted with ``argmax_pose_predict``. Its first two
    columns (presumably the x/y coordinates) are divided by
    ``cfg.global_scale``. All poses are written to ``predictions.mat``
    (key ``joints``) in the current working directory.

    Args:
        visualise: if true, call ``visualize.show_heatmaps`` for every image
            and block on ``visualize.waitforbuttonpress()`` before continuing.
        cache_scoremaps: if true, save each image's score maps, plus the
            location-refinement maps when ``cfg.location_refinement`` is set,
            as ``.mat`` files under ``cfg.scoremap_dir``.
    """
    logging.basicConfig(level=logging.INFO)
    cfg = load_config()
    dataset = create_dataset(cfg)
    dataset.set_shuffle(False)
    dataset.set_test_mode(True)

    sess, inputs, outputs = setup_pose_prediction(cfg)
    try:
        if cache_scoremaps:
            out_dir = cfg.scoremap_dir
            # exist_ok avoids the check-then-create race of os.path.exists().
            os.makedirs(out_dir, exist_ok=True)

        num_images = dataset.num_images
        # A 1-D object array holds one pose array per image. savemat stores it
        # as a cell array. The builtin `object` is used because the `np.object`
        # alias was removed in NumPy 1.24.
        predictions = np.zeros((num_images,), dtype=object)

        for k in range(num_images):
            print('processing image {}/{}'.format(k, num_images-1))

            batch = dataset.next_batch()
            outputs_np = sess.run(outputs, feed_dict={inputs: batch[Batch.inputs]})
            scmap, locref, _ = extract_cnn_output(outputs_np, cfg)

            pose = argmax_pose_predict(scmap, locref, cfg.stride)

            # Rescale coordinates back to the reference scale, keeping the
            # unscaled `pose` for visualisation on the network input image.
            pose_refscale = np.copy(pose)
            pose_refscale[:, 0:2] /= cfg.global_scale
            predictions[k] = pose_refscale

            if visualise:
                img = np.squeeze(batch[Batch.inputs]).astype('uint8')
                visualize.show_heatmaps(cfg, img, scmap, pose)
                visualize.waitforbuttonpress()

            if cache_scoremaps:
                _save_scoremaps(out_dir, batch[Batch.data_item].im_path,
                                scmap, locref, cfg.location_refinement)

        scipy.io.savemat('predictions.mat', mdict={'joints': predictions})
    finally:
        # Release the session even if prediction or saving raises.
        sess.close()
if __name__ == '__main__':
    # Command-line entry point. Both switches default to off:
    #   --novis  turn off the per-image heat-map visualisation
    #   --cache  save the per-image score maps as .mat files
    # parse_known_args() is used, so unrecognised arguments are ignored
    # instead of being rejected.
    cli = argparse.ArgumentParser()
    for flag in ('--novis', '--cache'):
        cli.add_argument(flag, default=False, action='store_true')
    known, _ignored = cli.parse_known_args()
    show_heatmaps = not known.novis
    test_net(show_heatmaps, known.cache)