discriminator.py
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Bidirectional 2-layer GRU discriminator that scores token sequences as real or generated."""

    def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, gpu=False, dropout=0.2):
        super(Discriminator, self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.max_seq_len = max_seq_len
        self.gpu = gpu

        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=2, bidirectional=True, dropout=dropout)
        # 2 layers x 2 directions = 4 final hidden states are concatenated before this projection
        self.gru2hidden = nn.Linear(2 * 2 * hidden_dim, hidden_dim)
        self.dropout_linear = nn.Dropout(p=dropout)
        self.hidden2out = nn.Linear(hidden_dim, 1)

    def init_hidden(self, batch_size):
        # (num_layers * num_directions) x batch_size x hidden_dim
        # (the autograd.Variable wrapper used in older PyTorch is no longer needed as of 0.4)
        h = torch.zeros(2 * 2 * 1, batch_size, self.hidden_dim)

        if self.gpu:
            return h.cuda()
        else:
            return h

    def forward(self, input, hidden):
        # input dim                                     # batch_size x seq_len
        emb = self.embeddings(input)                    # batch_size x seq_len x embedding_dim
        emb = emb.permute(1, 0, 2)                      # seq_len x batch_size x embedding_dim
        _, hidden = self.gru(emb, hidden)               # 4 x batch_size x hidden_dim
        hidden = hidden.permute(1, 0, 2).contiguous()   # batch_size x 4 x hidden_dim
        out = self.gru2hidden(hidden.view(-1, 4 * self.hidden_dim))  # batch_size x 4*hidden_dim -> batch_size x hidden_dim
        out = torch.tanh(out)
        out = self.dropout_linear(out)
        out = self.hidden2out(out)                      # batch_size x 1
        out = torch.sigmoid(out)
        return out

    def batchClassify(self, inp):
        """
        Classifies a batch of sequences.

        Inputs: inp
            - inp: batch_size x seq_len

        Returns: out
            - out: batch_size ([0,1] score)
        """
        h = self.init_hidden(inp.size()[0])
        out = self.forward(inp, h)
        return out.view(-1)

    def batchBCELoss(self, inp, target):
        """
        Returns Binary Cross Entropy Loss for discriminator.

        Inputs: inp, target
            - inp: batch_size x seq_len
            - target: batch_size (binary 1/0)
        """
        loss_fn = nn.BCELoss()
        h = self.init_hidden(inp.size()[0])
        out = self.forward(inp, h)
        # flatten batch_size x 1 -> batch_size so prediction and target shapes match
        return loss_fn(out.view(-1), target)
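

# Minimal smoke-test sketch for the Discriminator API above. The hyperparameters
# (vocab_size=5000, embedding_dim=64, hidden_dim=64, max_seq_len=20, batch_size=8)
# are illustrative assumptions, not values taken from the rest of the repository.
if __name__ == "__main__":
    vocab_size, embedding_dim, hidden_dim, max_seq_len, batch_size = 5000, 64, 64, 20, 8
    d = Discriminator(embedding_dim, hidden_dim, vocab_size, max_seq_len, gpu=False)

    # random token indices standing in for real/generated sequences: batch_size x seq_len
    inp = torch.randint(0, vocab_size, (batch_size, max_seq_len))
    scores = d.batchClassify(inp)        # batch_size scores in [0, 1]
    print(scores.shape)                  # torch.Size([8])

    # binary real/fake targets for the BCE objective
    target = torch.randint(0, 2, (batch_size,)).float()
    loss = d.batchBCELoss(inp, target)
    print(loss.item())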