word_level_rnn.py
import torch
import torch.nn as nn


class WordLevelRNN(nn.Module):
    """Word-level encoder with attention for a hierarchical attention network.

    Encodes a sequence of word embeddings with a bidirectional GRU and
    collapses it into a single sentence vector via soft attention.
    """

    def __init__(self, config):
        super().__init__()
        dataset = config.dataset
        word_num_hidden = config.word_num_hidden
        words_num = config.words_num
        words_dim = config.words_dim
        self.mode = config.mode
        if self.mode == 'rand':
            rand_embed_init = torch.Tensor(words_num, words_dim).uniform_(-0.25, 0.25)
            self.embed = nn.Embedding.from_pretrained(rand_embed_init, freeze=False)
        elif self.mode == 'static':
            self.static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=True)
        elif self.mode == 'non-static':
            self.non_static_embed = nn.Embedding.from_pretrained(dataset.TEXT_FIELD.vocab.vectors, freeze=False)
        else:
            raise ValueError("Unsupported mode: {}".format(self.mode))
        # Word-level context vector u_w, learned jointly with the rest of the model
        self.word_context_weights = nn.Parameter(torch.rand(2 * word_num_hidden, 1))
        self.word_context_weights.data.uniform_(-0.25, 0.25)
        self.GRU = nn.GRU(words_dim, word_num_hidden, bidirectional=True)
        self.linear = nn.Linear(2 * word_num_hidden, 2 * word_num_hidden, bias=True)
        self.soft_word = nn.Softmax(dim=1)  # attention weights over the words of each sentence

    def forward(self, x):
        # x: (num_words, batch_size) tensor of word indices
        if self.mode == 'rand':
            x = self.embed(x)
        elif self.mode == 'static':
            x = self.static_embed(x)
        elif self.mode == 'non-static':
            x = self.non_static_embed(x)
        else:
            raise ValueError("Unsupported mode: {}".format(self.mode))
        # h: (num_words, batch_size, 2 * word_num_hidden) GRU hidden states
        h, _ = self.GRU(x)
        # Attention: u_t = tanh(W h_t + b), alpha_t = softmax(u_t . u_w)
        x = torch.tanh(self.linear(h))
        x = torch.matmul(x, self.word_context_weights)  # (num_words, batch_size, 1)
        x = x.squeeze(dim=2)                            # (num_words, batch_size)
        x = self.soft_word(x.transpose(1, 0))           # (batch_size, num_words)
        # Attention-weighted sum of hidden states -> one vector per sentence
        x = torch.mul(h.permute(2, 0, 1), x.transpose(1, 0))
        x = torch.sum(x, dim=1).transpose(1, 0).unsqueeze(0)  # (1, batch_size, 2 * word_num_hidden)
        return x
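
For reference, here is a minimal smoke test of the module in 'rand' mode, which sidesteps the dataset-supplied pretrained vectors. The config values below (vocabulary size, embedding and hidden dimensions) are illustrative assumptions, not values taken from this repository:

    from types import SimpleNamespace

    # Hypothetical config; 'rand' mode needs no dataset vocabulary vectors.
    config = SimpleNamespace(
        dataset=None,          # unused in 'rand' mode
        word_num_hidden=50,    # GRU hidden size (assumed)
        words_num=1000,        # vocabulary size (assumed)
        words_dim=300,         # embedding dimension (assumed)
        mode='rand',
    )
    model = WordLevelRNN(config)

    # Batch of 4 sentences, 12 words each, as word indices: (num_words, batch_size)
    tokens = torch.randint(0, 1000, (12, 4))
    out = model(tokens)
    print(out.shape)  # torch.Size([1, 4, 100]) -> (1, batch_size, 2 * word_num_hidden)

The leading dimension of size 1 makes the sentence vectors directly consumable as a one-step input sequence by a sentence-level RNN stacked on top.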