-
Notifications
You must be signed in to change notification settings - Fork 13
/
model.py
77 lines (65 loc) · 2.7 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from __future__ import print_function
import mxnet as mx
from mxnet import nd, gluon
from mxnet.gluon import nn, rnn
class LSTNet(gluon.Block):
    """
    LSTNet block: 1-D convolution, GRU and skip-GRU combined with an
    auto-regressive highway.

    The input is passed through a ReLU-activated 1-D convolution followed by
    dropout. The convolution output feeds a GRU and, after being regrouped
    into ``skip`` interleaved sub-sequences, a second ("skip") GRU. The last
    outputs of both GRUs are concatenated and projected by a dense layer to
    ``num_series`` values. A linear auto-regressive term computed per channel
    from the last ``ar_window`` time steps of the raw input is added to that
    projection.

    :param int num_series: number of units of the final dense layer; expected
        to equal the number of input channels C, because the auto-regressive
        term that is added to it has one value per channel
    :param int conv_hid: number of convolution output channels
    :param int gru_hid: hidden size of the GRU
    :param int skip_gru_hid: hidden size of the skip GRU
    :param int skip: number of interleaved sub-sequences fed to the skip GRU
    :param int ar_window: number of trailing time steps used by the
        auto-regressive highway
    :param int kernel_size: width of the convolution kernel (default 6)
    :param float dropout_rate: rate of the dropout layer that is applied after
        the convolution and after the GRU (default 0.2)
    """

    def __init__(self, num_series, conv_hid, gru_hid, skip_gru_hid, skip, ar_window,
                 kernel_size=6, dropout_rate=0.2):
        super(LSTNet, self).__init__()
        self.skip = skip
        self.ar_window = ar_window
        # Kept only so forward() can report it in its error message.
        self.kernel_size = kernel_size
        with self.name_scope():
            self.conv = nn.Conv1D(conv_hid, kernel_size=kernel_size, layout='NCW', activation='relu')
            # A single Dropout layer is reused at both places in forward().
            self.dropout = nn.Dropout(dropout_rate)
            self.gru = rnn.GRU(gru_hid, layout='TNC')
            self.skip_gru = rnn.GRU(skip_gru_hid, layout='TNC')
            self.fc = nn.Dense(num_series)
            # One output per (sample, channel) row of the auto-regressive input.
            self.ar_fc = nn.Dense(1)

    def forward(self, x):
        """
        Compute the LSTNet output for a batch of sequences.

        :param nd.NDArray x: input data in NTC layout (N: batch-size, T: sequence len, C: channels)
        :return: output of LSTNet in NC layout
        :rtype: nd.NDArray
        :raises ValueError: if the convolution output is shorter than ``skip``
            time steps, so that not even one full skip period is available
        """
        # Convolution
        c = self.conv(x.transpose((0, 2, 1)))  # Transpose NTC to NCT (a.k.a NCW) before convolution
        c = self.dropout(c)

        # GRU
        r = self.gru(c.transpose((2, 0, 1)))  # Transpose NCT to TNC before GRU
        r = r[-1]  # Only keep the last output
        r = self.dropout(r)  # Now in NC layout

        # Skip GRU
        conv_len = c.shape[2]
        num_skips = conv_len // self.skip
        if num_skips < 1:
            # Without this check the slice below would start at -0 (i.e. keep
            # everything) and the reshape would fail with an opaque shape error.
            raise ValueError(
                "Convolution output has %d time steps, which is fewer than skip=%d; "
                "use a longer input sequence, a smaller kernel_size (%d) or a smaller skip"
                % (conv_len, self.skip, self.kernel_size))
        # Keep the trailing time steps that form a whole number of skip periods
        skip_c = c[:, :, -num_skips * self.skip:]
        skip_c = skip_c.reshape(c.shape[0], c.shape[1], -1, self.skip)  # Reshape to N x C x T x skip
        skip_c = skip_c.transpose((2, 0, 3, 1))  # Transpose to T x N x skip x C
        skip_c = skip_c.reshape(skip_c.shape[0], -1, skip_c.shape[3])  # Reshape to T x (N x skip) x C
        s = self.skip_gru(skip_c)
        s = s[-1]  # Only keep the last output (now in (N x skip) x C layout)
        s = s.reshape(x.shape[0], -1)  # Now in N x (skip x C) layout

        # FC layer
        fc = self.fc(nd.concat(r, s))  # NC layout

        # Autoregressive highway
        ar_x = x[:, -self.ar_window:, :]  # NTC layout
        ar_x = ar_x.transpose((0, 2, 1))  # NCT layout
        ar_x = ar_x.reshape(-1, ar_x.shape[2])  # (N x C) x T layout
        ar = self.ar_fc(ar_x)
        ar = ar.reshape(x.shape[0], -1)  # NC layout

        # Add autoregressive and fc outputs
        res = fc + ar
        return res
if __name__ == "__main__":
    # Run unit-test: push one random batch through a freshly initialised
    # network and check the shape of the result.
    batch_size, seq_len, num_series = 128, 1000, 321
    model = LSTNet(
        num_series=num_series,
        conv_hid=100,
        gru_hid=100,
        skip_gru_hid=5,
        skip=24,
        ar_window=24,
    )
    batch = nd.random.uniform(shape=(batch_size, seq_len, num_series))
    model.initialize()
    output = model(batch)
    assert output.shape == (batch_size, num_series)
    nd.waitall()
    print("Unit-test success!")