-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist_mx_gluon_mgpu.py
132 lines (117 loc) · 4.14 KB
/
mnist_mx_gluon_mgpu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import mxnet as mx
import numpy as np
import random
import time
from mxnet import autograd as ag
from mxnet.io import NDArrayIter
from mxnet.metric import Accuracy
from mxnet.optimizer import Adam
from mxnet.test_utils import get_mnist_iterator
from mxnet.gluon import Block, Trainer
from mxnet.gluon.loss import SoftmaxCrossEntropyLoss
from mxnet.gluon.nn import Conv2D, Dense, Dropout, Flatten, MaxPool2D, HybridBlock
from mxnet.gluon.utils import split_and_load
# Hyper-parameters for multi-GPU data-parallel training.
NUM_CLASSES = 10              # MNIST digits 0-9
EPOCHS = 10
GPU_COUNT = 2                 # number of data-parallel replicas (one per GPU)
BATCH_SIZE_PER_REPLICA = 512  # samples processed by each replica per step
# Global batch size: derived from GPU_COUNT instead of a hard-coded 2 so the
# replica count and the batch-size multiplier can never drift out of sync.
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * GPU_COUNT
class Model(HybridBlock):
    """Small MNIST CNN: two conv layers, max-pool + dropout, then a
    128-unit dense layer and a NUM_CLASSES-way output layer.

    Returns raw class scores (no softmax) -- intended to be paired with
    SoftmaxCrossEntropyLoss, which applies softmax internally.
    """

    def __init__(self, **kwargs):
        super(Model, self).__init__(**kwargs)
        # Child blocks are created in a fixed order inside name_scope so
        # auto-generated parameter names stay stable.
        with self.name_scope():
            # feature extractor
            self.conv1 = Conv2D(32, (3, 3))
            self.conv2 = Conv2D(64, (3, 3))
            self.pool = MaxPool2D(pool_size=(2, 2))
            self.dropout1 = Dropout(0.25)
            # classifier head
            self.flatten = Flatten()
            self.dense1 = Dense(128)
            self.dropout2 = Dropout(0.5)
            self.dense2 = Dense(NUM_CLASSES)

    def hybrid_forward(self, F, x):
        """Forward pass; F is mx.nd or mx.sym depending on hybridization."""
        features = self.dropout1(
            self.pool(F.relu(self.conv2(F.relu(self.conv1(x))))))
        hidden = self.dropout2(F.relu(self.dense1(self.flatten(features))))
        return self.dense2(hidden)
# Seed both the MXNet and Python RNGs for reproducible runs.
mx.random.seed(42)
random.seed(42)
# get data: MNIST iterators yielding (1, 28, 28) images in global batches
input_shape = (1, 28, 28)
train_data, test_data = get_mnist_iterator(input_shape=input_shape,
                                           batch_size=BATCH_SIZE)
# build model
model = Model()
# hybridize for speed: static_alloc/static_shape enable cached graph planning
model.hybridize(static_alloc=True, static_shape=True)
# pin GPUs: one context per data-parallel replica
ctx = [mx.gpu(i) for i in range(GPU_COUNT)]
# optimizer: Adam with its standard default hyper-parameters spelled out
opt_params={'learning_rate':0.001, 'beta1':0.9, 'beta2':0.999, 'epsilon':1e-08}
opt = mx.optimizer.create('adam', **opt_params)
# initialize parameters on every GPU context
model.initialize(force_reinit=True, ctx=ctx)
# fetch and broadcast parameters
params = model.collect_params()
# trainer: 'device' kvstore keeps gradient aggregation on the GPUs
trainer = Trainer(params=params,
                  optimizer=opt,
                  kvstore='device')
# loss function (applies softmax internally; the model outputs raw scores)
loss_fn = SoftmaxCrossEntropyLoss()
# use accuracy as the evaluation metric
metric = Accuracy()
# Wall-clock timer for the whole training run.
start = time.perf_counter()
for epoch in range(1, EPOCHS+1):
    # reset the train data iterator.
    train_data.reset()
    # loop over the train data iterator
    for i, batch in enumerate(train_data):
        if i == 0:
            # per-epoch throughput timer starts at the first batch
            tick_0 = time.time()
        # splits train data into multiple slices along batch_axis
        # copy each slice into a context
        data = split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        # splits train labels into multiple slices along batch_axis
        # copy each slice into a context
        label = split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        losses = []
        # inside training scope: record the graph so autograd can backprop
        with ag.record():
            for x, y in zip(data, label):
                z = model(x)
                # computes softmax cross entropy loss
                l = loss_fn(z, y)
                outputs.append(z)
                losses.append(l)
        # backpropagate the error for one iteration (one backward per replica)
        for l in losses:
            l.backward()
        # make one step of parameter update.
        # trainer needs to know the batch size of data
        # to normalize the gradient by 1/batch_size
        trainer.step(BATCH_SIZE)
        # updates internal evaluation
        metric.update(label, outputs)
    # NOTE(review): `i` and `tick_0` deliberately leak out of the batch loop;
    # this assumes the iterator yields at least one batch per epoch.
    str1 = 'Epoch [{}], Accuracy {:.4f}'.format(epoch, metric.get()[1])
    str2 = '~Samples/Sec {:.4f}'.format(BATCH_SIZE*(i+1)/(time.time()-tick_0))
    print('%s %s' % (str1, str2))
    # reset evaluation result to initial state.
    metric.reset()
elapsed = time.perf_counter() - start
print('elapsed: {:0.3f}'.format(elapsed))
# use Accuracy as the evaluation metric (fresh instance for validation)
metric = Accuracy()
for batch in test_data:
    # shard the test batch across the same GPU contexts used for training
    data = split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
    label = split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
    outputs = []
    for x in data:
        # forward pass only; no autograd recording needed for evaluation
        outputs.append(model(x))
    metric.update(label, outputs)
# metric.get() returns a (name, value) tuple, e.g. ('accuracy', 0.99)
print('validation %s=%f' % metric.get())