generator.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class Generator(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, gpu=False, oracle_init=False):
        super(Generator, self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.gpu = gpu

        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim)
        self.gru2out = nn.Linear(hidden_dim, vocab_size)

        # Initialise the oracle network with N(0, 1); otherwise the variance of the default
        # initialisation is very small, giving a high NLL even for data sampled from the same model.
        if oracle_init:
            for p in self.parameters():
                init.normal_(p, 0, 1)  # normal_ replaces the deprecated init.normal

    def init_hidden(self, batch_size=1):
        # A plain tensor suffices here; autograd.Variable is deprecated in modern PyTorch.
        h = torch.zeros(1, batch_size, self.hidden_dim)

        if self.gpu:
            return h.cuda()
        else:
            return h

    def forward(self, inp, hidden):
        """
        Embeds the input and applies the GRU one token at a time (seq_len = 1).
        """
        emb = self.embeddings(inp)                          # inp: batch_size -> batch_size x embedding_dim
        emb = emb.view(1, -1, self.embedding_dim)           # 1 x batch_size x embedding_dim
        out, hidden = self.gru(emb, hidden)                 # out: 1 x batch_size x hidden_dim
        out = self.gru2out(out.view(-1, self.hidden_dim))   # batch_size x vocab_size
        out = F.log_softmax(out, dim=1)
        return out, hidden

    def sample(self, num_samples, start_letter=0):
        """
        Samples the network and returns num_samples samples of length max_seq_len.

        Outputs: samples
            - samples: num_samples x max_seq_len (a sampled sequence in each row)
        """
        samples = torch.zeros(num_samples, self.max_seq_len).long()
        h = self.init_hidden(num_samples)
        inp = torch.LongTensor([start_letter] * num_samples)

        if self.gpu:
            samples = samples.cuda()
            inp = inp.cuda()

        for i in range(self.max_seq_len):
            out, h = self.forward(inp, h)                # out: num_samples x vocab_size
            out = torch.multinomial(torch.exp(out), 1)   # num_samples x 1 (one draw from each row)
            samples[:, i] = out.view(-1)
            inp = out.view(-1)

        return samples

    def batchNLLLoss(self, inp, target):
        """
        Returns the NLL loss for predicting the target sequence.

        Inputs: inp, target
            - inp: batch_size x seq_len
            - target: batch_size x seq_len

            inp should be target with <s> (the start letter) prepended.
        """
        loss_fn = nn.NLLLoss()
        batch_size, seq_len = inp.size()
        inp = inp.permute(1, 0)            # seq_len x batch_size
        target = target.permute(1, 0)      # seq_len x batch_size
        h = self.init_hidden(batch_size)

        loss = 0
        for i in range(seq_len):
            out, h = self.forward(inp[i], h)
            loss += loss_fn(out, target[i])

        return loss  # per batch

    def batchPGLoss(self, inp, target, reward):
        """
        Returns a pseudo-loss that gives the corresponding policy gradients (on calling .backward()).
        Inspired by the example in http://karpathy.github.io/2016/05/31/rl/

        Inputs: inp, target, reward
            - inp: batch_size x seq_len
            - target: batch_size x seq_len
            - reward: batch_size (discriminator reward for each sentence, applied to every token
              of the corresponding sentence)

            inp should be target with <s> (the start letter) prepended.
        """
        batch_size, seq_len = inp.size()
        inp = inp.permute(1, 0)            # seq_len x batch_size
        target = target.permute(1, 0)      # seq_len x batch_size
        h = self.init_hidden(batch_size)

        loss = 0
        for i in range(seq_len):
            out, h = self.forward(inp[i], h)
            # TODO: should h be detached from the graph (.detach())?
            for j in range(batch_size):
                loss += -out[j][target[i][j]] * reward[j]  # log(P(y_t | y_{1:t-1})) * Q

        return loss / batch_size
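

# ----------------------------------------------------------------------------------------
# Minimal usage sketch (an addition, not part of the original module): demonstrates one
# sampling pass, an MLE step with batchNLLLoss, and a policy-gradient step with batchPGLoss.
# The toy hyperparameters and the uniform random "reward" below are illustrative
# assumptions; in SeqGAN the reward would come from the discriminator.
if __name__ == '__main__':
    VOCAB_SIZE, MAX_SEQ_LEN, BATCH_SIZE = 10, 5, 4  # assumed toy sizes

    gen = Generator(embedding_dim=8, hidden_dim=16, vocab_size=VOCAB_SIZE,
                    max_seq_len=MAX_SEQ_LEN)
    optimizer = torch.optim.Adam(gen.parameters(), lr=1e-2)

    # Sample sequences; each row is one generated sequence of token ids.
    samples = gen.sample(BATCH_SIZE)  # BATCH_SIZE x MAX_SEQ_LEN

    # MLE step: inp is the target shifted right with the start letter (0) prepended.
    target = samples
    inp = torch.cat([torch.zeros(BATCH_SIZE, 1).long(), target[:, :-1]], dim=1)
    optimizer.zero_grad()
    nll = gen.batchNLLLoss(inp, target)
    nll.backward()
    optimizer.step()

    # Policy-gradient step with a random stand-in reward, purely for illustration.
    reward = torch.rand(BATCH_SIZE)
    optimizer.zero_grad()
    pg_loss = gen.batchPGLoss(inp, target, reward)
    pg_loss.backward()
    optimizer.step()

    print('NLL: %.3f  PG pseudo-loss: %.3f' % (nll.item(), pg_loss.item()))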