-
Notifications
You must be signed in to change notification settings - Fork 130
/
test.py
139 lines (115 loc) · 4.45 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python
import os
# The CUDA variables are set before `import torch` below — presumably so they
# are already in place when the CUDA runtime is initialised.
# PCI_BUS_ID orders devices by PCI bus ID; "0" exposes only the first device.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import torch
import pprint
import argparse
import importlib
import sys
import numpy as np
import matplotlib
# Select the non-interactive Agg backend right after importing matplotlib,
# so no display is needed.
matplotlib.use("Agg")
from config import system_configs
from nnet.py_factory import NetworkFactory
from db.datasets import datasets
from db.utils.evaluator import Evaluator
# Turn off the cuDNN benchmark (auto-tuning) mode for this process.
torch.backends.cudnn.benchmark = False
def parse_args(argv=None):
    """Parse the command-line options of the test script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script entry point relies on.

    Returns:
        An ``argparse.Namespace`` with the attributes ``cfg_file``,
        ``testiter``, ``split``, ``suffix``, ``debug``, ``modality``,
        ``image_root``, ``batch``, ``debugEnc`` and ``debugDec``.

    Raises:
        SystemExit: Raised by argparse on malformed arguments, including a
            ``--split`` or ``--modality`` value outside the accepted choices.
    """
    parser = argparse.ArgumentParser(description="Test CornerNet")
    parser.add_argument("cfg_file", help="config file", type=str)
    parser.add_argument("--testiter", dest="testiter",
                        help="test at iteration i",
                        default=None, type=int)
    # The choices mirror the keys of the split table in the entry point, so a
    # typo is reported by argparse instead of surfacing later as a KeyError.
    parser.add_argument("--split", dest="split",
                        help="which split to use",
                        choices=("training", "validation", "testing"),
                        default="validation", type=str)
    parser.add_argument("--suffix", dest="suffix", default=None, type=str)
    parser.add_argument("--debug", action="store_true")
    # test() only understands these two modalities; the default stays None so
    # omitting the flag behaves exactly as before.
    parser.add_argument("--modality", dest="modality",
                        choices=("eval", "images"),
                        default=None, type=str)
    parser.add_argument("--image_root", dest="image_root",
                        default=None, type=str)
    parser.add_argument("--batch", dest='batch',
                        help="select a value to maximum your FPS",
                        default=1, type=int)
    parser.add_argument("--debugEnc", action="store_true")
    parser.add_argument("--debugDec", action="store_true")
    return parser.parse_args(argv)
def make_dirs(directories):
    """Create every directory in ``directories``, including missing parents.

    Args:
        directories: Iterable of directory paths to create.

    Directories that already exist are left as they are.
    """
    for directory in directories:
        # exist_ok=True replaces a separate os.path.exists() check, which
        # could race with another process creating the same directory.
        os.makedirs(directory, exist_ok=True)
def test(db, split, testiter,
         debug=False, suffix=None, modality=None, image_root=None, batch=1,
         debugEnc=False, debugDec=False):
    """Load a trained network and run it in the requested test modality.

    Args:
        db: Dataset object handed to the evaluator and the testing routine.
        split: Name of the split; used as a component of the result directory.
        testiter: Iteration whose parameters are loaded. ``None`` selects
            ``system_configs.max_iter``.
        debug: Forwarded to the testing routine.
        suffix: Optional extra sub-directory appended to the result directory.
        modality: ``'eval'`` runs ``test.tusimple.testing``; ``'images'`` runs
            ``test.images.testing`` on ``image_root``.
        image_root: Image location, required when ``modality == 'images'``.
        batch: Forwarded as ``repeat`` to the ``'eval'`` testing routine.
        debugEnc: Forwarded to the ``'eval'`` testing routine.
        debugDec: Forwarded to the ``'eval'`` testing routine.

    Raises:
        ValueError: If ``modality`` is neither ``'eval'`` nor ``'images'``, or
            if ``modality`` is ``'images'`` and ``image_root`` is ``None``.
    """
    # Validate the arguments first so a bad invocation fails immediately,
    # before a result directory is created and the network is loaded.
    if modality not in ('eval', 'images'):
        raise ValueError('--modality must be one of eval/images, but now: {}'
                         .format(modality))
    if modality == 'images' and image_root is None:
        raise ValueError('--image_root is not defined!')

    # NOTE: the directory uses the raw ``testiter``; when it is None the path
    # component is the literal string "None" even though the parameters of
    # ``system_configs.max_iter`` are loaded below. Kept as-is so existing
    # result locations do not move.
    result_dir = os.path.join(system_configs.result_dir, str(testiter), split)
    if suffix is not None:
        result_dir = os.path.join(result_dir, suffix)
    make_dirs([result_dir])

    test_iter = system_configs.max_iter if testiter is None else testiter
    print("loading parameters at iteration: {}".format(test_iter))

    print("building neural network...")
    nnet = NetworkFactory()
    print("loading parameters...")
    nnet.load_params(test_iter)
    nnet.cuda()
    nnet.eval_mode()

    evaluator = Evaluator(db, result_dir)

    if modality == 'eval':
        print('static evaluating...')
        test_file = "test.tusimple"
        testing = importlib.import_module(test_file).testing
        testing(db, nnet, result_dir, debug=debug, evaluator=evaluator, repeat=batch,
                debugEnc=debugEnc, debugDec=debugDec)
    else:
        # Only 'images' can reach this branch thanks to the check above.
        print("processing [images]...")
        test_file = "test.images"
        image_testing = importlib.import_module(test_file).testing
        image_testing(db, nnet, image_root, debug=debug, evaluator=None)
if __name__ == "__main__":
    # Script entry point: read the JSON config named on the command line,
    # push its "system" section into the global configuration, build the
    # dataset for the requested split and hand everything to test().
    args = parse_args()

    # "<cfg>.json" by default, "<cfg>-<suffix>.json" when a suffix is given.
    cfg_name = (
        args.cfg_file + ".json"
        if args.suffix is None
        else args.cfg_file + "-{}.json".format(args.suffix)
    )
    cfg_file = os.path.join(system_configs.config_dir, cfg_name)
    print("cfg_file: {}".format(cfg_file))

    with open(cfg_file, "r") as cfg_handle:
        configs = json.load(cfg_handle)

    system_section = configs["system"]
    system_section["snapshot_name"] = args.cfg_file
    system_configs.update_config(system_section)

    # Translate the command-line split name into the configured split value.
    split_by_name = {
        "training": system_configs.train_split,
        "validation": system_configs.val_split,
        "testing": system_configs.test_split,
    }
    split = split_by_name[args.split]

    print("loading all datasets...")
    dataset_name = system_configs.dataset
    print("split: {}".format(split))  # test
    testing_db = datasets[dataset_name](configs["db"], split)

    # For debugging, system_configs.full and testing_db.configs can be dumped
    # here with pprint.pprint.

    test(
        testing_db,
        args.split,
        args.testiter,
        debug=args.debug,
        suffix=args.suffix,
        modality=args.modality,
        image_root=args.image_root,
        batch=args.batch,
        debugEnc=args.debugEnc,
        debugDec=args.debugDec,
    )