-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
81 lines (70 loc) · 2.71 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
Train and validate a model
Usage:
./run.py -w </path/to/data> -e 8 -r 0
"""
import os
import numpy as np
from neon import logger as neon_logger
from neon.layers import GeneralizedCost
from neon.optimizers import Adadelta
from neon.transforms import LogLoss, CrossEntropyBinary
from neon.callbacks.callbacks import Callbacks
from neon.util.argparser import NeonArgparser
from neon.initializers import Kaiming
from neon.layers import Conv, Dropout, Pooling, Affine
from neon.transforms import Rectlin, Softmax
from neon.models import Model
from data import ChunkLoader
import settings
import video
def create_network():
    """Build the 3-D convolutional classifier and wrap it in a neon ``Model``.

    Layer stack, in order:
      * a 9x9x9 stem convolution (16 filters, padded, stride 2) with ReLU and
        no batch norm;
      * two dilated convolutions (5x5x5/32 and 3x3x3/64) with batch norm + ReLU;
      * a 2x2x2 pooling layer (padded, stride 2);
      * eight 2x2x2 convolutions with batch norm + ReLU whose filter counts run
        128, 128, 128, 256, 1024, 4096, 2048, 1024;
      * dropout, then a 2-unit fully connected layer with batch norm and
        softmax output.

    Returns:
        neon.models.Model: the untrained network.
    """
    kaiming = Kaiming()
    pad_all = dict(pad_d=1, pad_h=1, pad_w=1)
    stride_all = dict(str_d=2, str_h=2, str_w=2)
    dilate_all = dict(dil_d=2, dil_h=2, dil_w=2)
    # One shared set of kwargs (and one shared Rectlin instance) for every
    # batch-normalised conv layer.
    bn_relu = dict(init=kaiming, batch_norm=True, activation=Rectlin())

    # Stem: no batch norm on the first layer.
    stack = [Conv((9, 9, 9, 16), padding=pad_all, strides=stride_all,
                  init=kaiming, activation=Rectlin())]

    # Dilated convolutions: (kernel edge, filter count).
    for edge, n_filters in ((5, 32), (3, 64)):
        stack.append(Conv((edge, edge, edge, n_filters), dilation=dilate_all, **bn_relu))

    stack.append(Pooling((2, 2, 2), padding=pad_all, strides=stride_all))

    # Deep 2x2x2 conv tower.
    for n_filters in (128, 128, 128, 256, 1024, 4096, 2048, 1024):
        stack.append(Conv((2, 2, 2, n_filters), **bn_relu))

    stack.append(Dropout())
    stack.append(Affine(2, init=Kaiming(local=False), batch_norm=True, activation=Softmax()))
    return Model(layers=stack)
# Parse the command line arguments
parser = NeonArgparser(__doc__)
parser.add_argument('-tm', '--test_mode', action='store_true',
                    help='make predictions on test data')
args = parser.parse_args()

# Create model
model = create_network()

# Setup data providers.
# The loaders' test_mode follows the --test_mode flag, so training runs do not
# get loaders configured for test mode.
# NOTE(review): confirm in data.ChunkLoader exactly what test_mode toggles.
repo_dir = args.data_dir
common = dict(datum_dtype=np.uint8, repo_dir=repo_dir, test_mode=args.test_mode)

# Augment only while training. When dumping predictions the inputs must be
# unperturbed so the saved outputs are reproducible. During training both
# splits are augmented, as before.
augment = not args.test_mode
train = ChunkLoader(set_name='train', augment=augment, **common)
test = ChunkLoader(set_name='val', augment=augment, **common)

if args.test_mode:
    # Inference only: load trained weights, then save column 1 of the 2-way
    # softmax output for each split as <repo_dir>/<set_name>-pred.npy.
    # parser.error (unlike assert) is not stripped under `python -O`.
    if args.model_file is None:
        parser.error('a trained model file (--model_file) is required with --test_mode')
    model.load_params(args.model_file)
    for dataset in (train, test):
        pred = model.get_outputs(dataset)
        np.save(os.path.join(repo_dir, dataset.set_name + '-pred.npy'), pred[:, 1])
else:
    # Setup callbacks; the validation split is used as the eval set.
    callbacks = Callbacks(model, eval_set=test, **args.callback_args)
    # Train model
    # NOTE(review): the network ends in a 2-unit Softmax while the cost is
    # CrossEntropyBinary; confirm this pairing is intended.
    opt = Adadelta()
    cost = GeneralizedCost(costfunc=CrossEntropyBinary())
    model.fit(train, optimizer=opt, num_epochs=args.epochs, cost=cost, callbacks=callbacks)
    # Output metrics
    neon_logger.display('Test Logloss = %.4f' % (model.eval(test, metric=LogLoss())))