-
Notifications
You must be signed in to change notification settings - Fork 228
/
run.py
99 lines (78 loc) · 2.62 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from lib.config import cfg, args
import numpy as np
import os
def run_dataset():
    """Iterate once over the whole test data loader, discarding every batch.

    Useful for checking that the dataset and its collation run end-to-end
    without building or running a network.
    """
    from lib.datasets import make_data_loader
    import tqdm

    # Force single-process loading so dataset errors surface in this process.
    # NOTE(review): this sets the *train* worker count while the loader below
    # is built with is_train=False; confirm make_data_loader reads this value
    # for the test loader.
    cfg.train.num_workers = 0
    loader = make_data_loader(cfg, is_train=False)

    progress = tqdm.tqdm(loader)
    for _ in progress:
        continue
def run_network():
    """Benchmark the mean forward-pass time of the network over the test set.

    Builds the network from ``cfg``, moves it to the GPU, and loads the
    weights for ``cfg.test.epoch`` from ``cfg.model_dir``. It then times
    ``network(batch['inp'])`` on every batch of the test data loader. The mean
    wall-clock time per batch, in seconds, is printed. If the loader yields
    no batches, a message is printed instead of dividing by zero.
    """
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    import time

    network = make_network(cfg).cuda()
    load_network(network, cfg.model_dir, epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    num_batches = len(data_loader)
    if num_batches == 0:
        # Previously this fell through to `total_time / 0`.
        print('run_network: data loader is empty, nothing to time')
        return

    total_time = 0.0
    for batch in tqdm.tqdm(data_loader):
        # Move every entry except 'meta' to the GPU. 'meta' presumably holds
        # non-tensor metadata (confirm against the dataset's collate output).
        for k in batch:
            if k != 'meta':
                batch[k] = batch[k].cuda()
        with torch.no_grad():
            # CUDA kernels run asynchronously. Synchronizing before and after
            # makes the interval cover the forward pass itself, not just
            # its kernel launches.
            torch.cuda.synchronize()
            start = time.perf_counter()  # monotonic, high-resolution timer
            network(batch['inp'])
            torch.cuda.synchronize()
            total_time += time.perf_counter() - start
    print(total_time / num_batches)
def run_evaluate():
    """Run the network over the test set and let the evaluator summarize it.

    The network is built from ``cfg``, placed on the GPU, and loaded with the
    weights for ``cfg.test.epoch`` from ``cfg.model_dir``. Each test batch's
    ``'inp'`` entry is sent through it without gradient tracking. The output
    and the original (host-side) batch are passed to the evaluator.
    ``evaluator.summarize()`` is called once at the end.
    """
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    import tqdm
    import torch
    from lib.networks import make_network
    from lib.utils.net_utils import load_network

    model = make_network(cfg).cuda()
    load_network(model, cfg.model_dir, epoch=cfg.test.epoch)
    model.eval()

    loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)

    for batch in tqdm.tqdm(loader):
        # Only the input goes to the GPU; the evaluator gets the batch as loaded.
        gpu_inp = batch['inp'].cuda()
        with torch.no_grad():
            prediction = model(gpu_inp)
        evaluator.evaluate(prediction, batch)

    evaluator.summarize()
def run_visualize():
    """Run the network over the test set and hand each result to the visualizer.

    Weights are loaded from ``cfg.model_dir`` for ``cfg.test.epoch``.
    ``cfg.resume`` is also forwarded to ``load_network``; its exact meaning is
    defined there, so confirm it in ``lib.utils.net_utils``. Every batch entry
    except ``'meta'`` is moved to the GPU before the forward pass.
    """
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.utils.net_utils import load_network
    import tqdm
    import torch
    from lib.visualizers import make_visualizer

    model = make_network(cfg).cuda()
    load_network(model, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)
    model.eval()

    loader = make_data_loader(cfg, is_train=False)
    visualizer = make_visualizer(cfg)

    for batch in tqdm.tqdm(loader):
        # Update entries in place, leaving 'meta' untouched.
        for key in batch:
            if key == 'meta':
                continue
            batch[key] = batch[key].cuda()
        with torch.no_grad():
            prediction = model(batch['inp'], batch)
        visualizer.visualize(prediction, batch)
def run_sbd():
    """Delegate to ``tools.convert_sbd.convert_sbd()``."""
    from tools import convert_sbd as sbd_tool
    sbd_tool.convert_sbd()
def run_demo():
    """Delegate to ``tools.demo.demo()``."""
    from tools import demo as demo_tool
    demo_tool.demo()
if __name__ == '__main__':
    # Dispatch to the module-level function named 'run_' + args.type,
    # e.g. args.type == 'evaluate' calls run_evaluate().
    runner = globals().get('run_' + args.type)
    if not callable(runner):
        # Report the valid choices instead of failing with a bare KeyError.
        choices = sorted(
            name[len('run_'):]
            for name, obj in list(globals().items())
            if name.startswith('run_') and callable(obj)
        )
        raise SystemExit('Unknown type {!r}; expected one of: {}'.format(
            args.type, ', '.join(choices)))
    runner()