import torch
import torch.nn as nn
import numpy as np
from embed_regularize import embedded_dropout
from locked_dropout import LockedDropout
from weight_drop import WeightDrop
import scipy.stats

class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, dropouth=0.5, dropouti=0.5, dropoute=0.1, wdrop=0, tie_weights=False):
        super(RNNModel, self).__init__()
        self.lockdrop = LockedDropout()
        self.idrop = nn.Dropout(dropouti)
        self.hdrop = nn.Dropout(dropouth)
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        assert rnn_type in ['LSTM', 'QRNN', 'GRU'], 'RNN type is not supported'
        if rnn_type == 'LSTM':
            self.rnns = [torch.nn.LSTM(ninp if l == 0 else nhid, nhid if l != nlayers - 1 else (ninp if tie_weights else nhid), 1, dropout=0) for l in range(nlayers)]
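            # When tie_weights is set, the last layer's hidden size drops to ninp
            # so its output matches the embedding dimension and the decoder can
            # share the encoder's weight matrix (see the tying block below).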
            ## start edit: steps for MTS model bias assignments
            ## STEP 1: list the hidden size of each layer - needed for the init steps below
            hid_dim = [nhid if l != nlayers - 1 else (ninp if tie_weights else nhid) for l in range(nlayers)]
            ## STEP 2: create bias values depending on the type of init we want
            chrono_bias = [np.zeros(hid_dim[l]) for l in range(nlayers)]
            multi_timescale = True
            if multi_timescale:
                # layer 0: half the units at timescale 3, the other half at timescale 4
                half_length = int(0.5 * hid_dim[0])
                timescale_first_half, timescale_second_half = 3, 4
                # convert each timescale into a bias value and store it in the array
                chrono_bias[0][:half_length] = -1 * np.log(np.exp(1 / timescale_first_half) - 1)
                chrono_bias[0][half_length:] = -1 * np.log(np.exp(1 / timescale_second_half) - 1)
                # layer 1: timescales sampled from an inverse-gamma distribution
                # (1151 grid points minus the infinite first value give 1150
                # timescales, so this hard-codes the assumption nhid == 1150)
                timescale_invgamma = scipy.stats.invgamma.isf(np.linspace(0, 1, 1151), a=0.56, scale=1)[1:]
                # convert the timescales into bias values
                chrono_bias[1] = -1 * np.log(np.exp(1 / timescale_invgamma) - 1)
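            # Why -log(exp(1/T) - 1): for a forget-gate bias b = -log(exp(1/T) - 1),
            # sigmoid(b) = 1 / (1 + exp(1/T) - 1) = exp(-1/T), so the gate retains
            # state with characteristic timescale T (cf. the chrono initialization
            # of Tallec & Ollivier, 2018).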
            ## STEP 3: assign the bias values to every layer except the last -
            ## the first hid_dim[l] entries are the input-gate bias and the next
            ## hid_dim[l] the forget-gate bias, for both i-to-h and h-to-h
            for l in range(nlayers - 1):
                self.rnns[l].bias_ih_l0.data[0:hid_dim[l] * 2] = torch.tensor(np.zeros(hid_dim[l] * 2), dtype=torch.float)
                self.rnns[l].bias_hh_l0.data[0:hid_dim[l] * 2] = torch.from_numpy(np.hstack((-1 * chrono_bias[l], chrono_bias[l])).astype(np.float32))
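            # PyTorch packs LSTM biases as [b_i | b_f | b_g | b_o], each of length
            # hid_dim[l], and adds bias_ih and bias_hh elementwise. Zeroing bias_ih
            # and writing (-chrono_bias, +chrono_bias) into bias_hh therefore sets
            # the effective input-gate bias to -chrono_bias[l] and the effective
            # forget-gate bias to +chrono_bias[l].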
            ## STEP 4: fix the biases - freeze them rather than only initializing them
            fixed_weights = True
            if fixed_weights:
                for l in range(nlayers - 1):
                    print(l)
                    self.rnns[l].bias_ih_l0.requires_grad = False
                    self.rnns[l].bias_hh_l0.requires_grad = False
            ## end edit
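            # With requires_grad = False these biases receive no gradients and are
            # left unchanged during training; a common idiom is to build the
            # optimizer over trainable parameters only, e.g.
            # filter(lambda p: p.requires_grad, model.parameters()).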
            if wdrop:
                self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in self.rnns]
        if rnn_type == 'GRU':
            # size the last layer as (ninp if tie_weights else nhid) to stay
            # consistent with init_hidden below
            self.rnns = [torch.nn.GRU(ninp if l == 0 else nhid, nhid if l != nlayers - 1 else (ninp if tie_weights else nhid), 1, dropout=0) for l in range(nlayers)]
            if wdrop:
                self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in self.rnns]
        elif rnn_type == 'QRNN':
            from torchqrnn import QRNNLayer
            self.rnns = [QRNNLayer(input_size=ninp if l == 0 else nhid, hidden_size=nhid if l != nlayers - 1 else (ninp if tie_weights else nhid), save_prev_x=True, zoneout=0, window=2 if l == 0 else 1, output_gate=True) for l in range(nlayers)]
            for rnn in self.rnns:
                rnn.linear = WeightDrop(rnn.linear, ['weight'], dropout=wdrop)
        print(self.rnns)
        self.rnns = torch.nn.ModuleList(self.rnns)
        self.decoder = nn.Linear(nhid, ntoken)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            #if nhid != ninp:
            #    raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight

        self.init_weights()

        self.rnn_type = rnn_type
        self.ninp = ninp
        self.nhid = nhid
        self.nlayers = nlayers
        self.dropout = dropout
        self.dropouti = dropouti
        self.dropouth = dropouth
        self.dropoute = dropoute
        self.tie_weights = tie_weights
    def reset(self):
        if self.rnn_type == 'QRNN':
            [r.reset() for r in self.rnns]

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)
    def forward(self, input, hidden, return_h=False):
        emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if self.training else 0)
        #emb = self.idrop(emb)
        emb = self.lockdrop(emb, self.dropouti)

        raw_output = emb
        new_hidden = []
        #raw_output, hidden = self.rnn(emb, hidden)
        raw_outputs = []
        outputs = []
        for l, rnn in enumerate(self.rnns):
            current_input = raw_output
            raw_output, new_h = rnn(raw_output, hidden[l])
            new_hidden.append(new_h)
            raw_outputs.append(raw_output)
            if l != self.nlayers - 1:
                #self.hdrop(raw_output)
                raw_output = self.lockdrop(raw_output, self.dropouth)
                outputs.append(raw_output)
        hidden = new_hidden

        output = self.lockdrop(raw_output, self.dropout)
        outputs.append(output)

        result = output.view(output.size(0) * output.size(1), output.size(2))
        if return_h:
            return result, hidden, raw_outputs, outputs
        return result, hidden
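    # Note that forward() never applies self.decoder: it returns the flattened
    # RNN output of shape (seq_len * batch, last_hidden_size), and callers are
    # expected to project it to the vocabulary (e.g., inside the training loss).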
    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return [(weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_(),
                     weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_())
                    for l in range(self.nlayers)]
        elif self.rnn_type == 'QRNN' or self.rnn_type == 'GRU':
            return [weight.new(1, bsz, self.nhid if l != self.nlayers - 1 else (self.ninp if self.tie_weights else self.nhid)).zero_()
                    for l in range(self.nlayers)]
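
# Example usage: a minimal smoke-test sketch, not a definitive recipe. It assumes
# nhid == 1150 (required by the hard-coded inverse-gamma sample in __init__),
# nlayers == 3, tied weights, and that the AWD-LSTM companion modules imported at
# the top are available; the remaining hyperparameters are illustrative.
if __name__ == '__main__':
    ntoken, ninp, nhid, nlayers = 10000, 400, 1150, 3
    model = RNNModel('LSTM', ntoken, ninp, nhid, nlayers, tie_weights=True)
    hidden = model.init_hidden(bsz=20)
    data = torch.randint(0, ntoken, (35, 20))  # (seq_len, batch) of token ids
    result, hidden = model(data, hidden)
    # with tied weights the last layer is ninp-dimensional, so result is
    # (seq_len * batch, ninp)
    print(result.shape)  # torch.Size([700, 400])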