-
Notifications
You must be signed in to change notification settings - Fork 384
/
test_multiperson.py
131 lines (102 loc) · 4.56 KB
/
test_multiperson.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import argparse
import logging
import os
import numpy as np
import scipy.io
import scipy.ndimage
import json
from json import encoder
# NOTE(review): meant to make json.dump write floats with two decimals.
# Python 3's json encoder does not consult json.encoder.FLOAT_REPR, so this
# assignment likely has no effect there -- verify before relying on rounding.
encoder.FLOAT_REPR = lambda o: format(o, '.2f')
from util.config import load_config
from dataset.factory import create as create_dataset
from dataset.pose_dataset import Batch
from util.mscoco_util import pose_predict_with_gt_segm
from nnet.predict import *
from util import visualize
from multiperson.detections import extract_detections
from multiperson.predict import SpatialModel, eval_graph, get_person_conf_multicut
from multiperson.visualize import PersonDraw, visualize_detections
import matplotlib.pyplot as plt
def _load_cached_scoremaps(cfg, cache_name):
    """Load one image's cached network outputs from ``cfg.cached_scoremaps``.

    Returns the ``(scoremaps, locreg_pred, pairwise_diff)`` arrays stored in
    ``<cfg.cached_scoremaps>/<cache_name>`` (a .mat file read via scipy).
    """
    mat = scipy.io.loadmat(os.path.join(cfg.cached_scoremaps, cache_name))
    return mat["scoremaps"], mat["locreg_pred"], mat["pairwise_diff"]


def _save_scoremaps(out_fn, scmap, locref, pairwise_diff):
    """Write one image's network outputs to ``out_fn`` as float32 .mat arrays.

    Uses the same keys that ``_load_cached_scoremaps`` reads back.
    """
    scipy.io.savemat(out_fn, mdict={
        'scoremaps': scmap.astype('float32'),
        'locreg_pred': locref.astype('float32'),
        'pairwise_diff': pairwise_diff.astype('float32'),
    })


def test_net(visualise, cache_scoremaps, development):
    """Run multi-person pose inference over the configured dataset.

    Network outputs (scoremaps, location refinement, pairwise differences)
    are either computed with a TensorFlow session or, when the config has a
    ``cached_scoremaps`` entry, loaded from ``<coco_id>.mat`` files in that
    directory. Part detections are then grouped into people with the
    spatial model.

    Args:
        visualise: if True, show per-image plots and wait for a button press.
        cache_scoremaps: if True and outputs come from the network, save them
            to ``cfg.scoremap_dir`` as ``<coco_id>.mat`` and skip the
            multi-person step for that image. Combined with ``visualise``,
            the predicted arrows are shown instead and nothing is saved.
        development: if True, process at most the first 10 images.

    Side effects:
        When ``cfg.use_gt_segm`` is set, writes ``predictions_with_segm.json``
        to the current working directory.
    """
    logging.basicConfig(level=logging.INFO)

    cfg = load_config()
    dataset = create_dataset(cfg)
    dataset.set_shuffle(False)

    sm = SpatialModel(cfg)
    sm.load()

    draw_multi = PersonDraw()

    # A configured scoremap cache replaces the network entirely; in that case
    # no TensorFlow session is ever created.
    from_cache = "cached_scoremaps" in cfg
    sess = None
    if not from_cache:
        sess, inputs, outputs = setup_pose_prediction(cfg)

    try:
        if cache_scoremaps:
            out_dir = cfg.scoremap_dir
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)

        pairwise_stats = dataset.pairwise_stats
        num_images = min(10, dataset.num_images) if development else dataset.num_images
        coco_results = []
        for k in range(num_images):
            print('processing image {}/{}'.format(k, num_images - 1))

            batch = dataset.next_batch()
            data_item = batch[Batch.data_item]
            cache_name = "{}.mat".format(data_item.coco_id)

            if from_cache:
                scmap, locref, pairwise_diff = _load_cached_scoremaps(cfg, cache_name)
            else:
                outputs_np = sess.run(outputs, feed_dict={inputs: batch[Batch.inputs]})
                scmap, locref, pairwise_diff = extract_cnn_output(outputs_np, cfg, pairwise_stats)

                if cache_scoremaps:
                    if visualise:
                        # Debug mode: display the predicted pairwise arrows
                        # instead of writing the cache file for this image.
                        img = np.squeeze(batch[Batch.inputs]).astype('uint8')
                        pose = argmax_pose_predict(scmap, locref, cfg.stride)
                        arrows = argmax_arrows_predict(scmap, locref, pairwise_diff, cfg.stride)
                        visualize.show_arrows(cfg, img, pose, arrows)
                        visualize.waitforbuttonpress()
                    else:
                        _save_scoremaps(os.path.join(out_dir, cache_name),
                                        scmap, locref, pairwise_diff)
                    # Caching mode ends after the network pass for this image.
                    continue

            detections = extract_detections(cfg, scmap, locref, pairwise_diff)
            un_lab, pos_array, unary_array, _, _ = eval_graph(sm, detections)
            person_conf_multi = get_person_conf_multicut(sm, un_lab, unary_array, pos_array)

            if visualise:
                # Overlay the assembled people on the input image. To inspect
                # per-part detections after NMS instead, plot
                # visualize_detections(cfg, img, detections).
                img = np.squeeze(batch[Batch.inputs]).astype('uint8')
                visim_multi = img.copy()
                draw_multi.draw(visim_multi, dataset, person_conf_multi)
                plt.imshow(visim_multi)
                plt.show()
                visualize.waitforbuttonpress()

            if cfg.use_gt_segm:
                coco_img_results = pose_predict_with_gt_segm(
                    scmap, locref, cfg.stride, data_item.gt_segm, data_item.coco_id)
                coco_results += coco_img_results
                if len(coco_img_results):
                    dataset.visualize_coco(coco_img_results, data_item.visibilities)

        if cfg.use_gt_segm:
            with open('predictions_with_segm.json', 'w') as outfile:
                json.dump(coco_results, outfile)
    finally:
        # Close the session even on error; there is none in cached mode.
        if sess is not None:
            sess.close()
if __name__ == '__main__':
    # Command-line entry point. Unrecognised arguments are tolerated
    # (parse_known_args) rather than rejected.
    parser = argparse.ArgumentParser(
        description='Run multi-person pose inference over the configured dataset.')
    parser.add_argument('--novis', action='store_true',
                        help='disable interactive visualisation of each image')
    parser.add_argument('--cache', action='store_true',
                        help='when scoremaps are computed by the network, save them '
                             'as .mat files in cfg.scoremap_dir instead of running '
                             'multi-person inference (files are only written '
                             'together with --novis)')
    parser.add_argument('--dev', action='store_true',
                        help='development mode: process at most 10 images')
    args, _unparsed = parser.parse_known_args()
    test_net(not args.novis, args.cache, args.dev)