from torch import nn
class discriminator(nn.Module):
    '''
    Discriminator setup: a stack of strided Conv2d blocks that downsamples
    a 3 x 64 x 64 image to a single real/fake probability.
    '''
    def __init__(self, channels=3):
        super().__init__()
        self.channels = channels

        def down_sampling(input_channels=None, output_channels=None, stride=None,
                          kernel_size=None, normalize=None, activation=True):
            '''Builds one Conv2d -> (BatchNorm2d) -> (LeakyReLU) block as a list of layers.'''
            layers = list()
            # Conv2d inputs
            args = {
                'in_channels': input_channels,
                'out_channels': output_channels,
                'stride': stride,
                'kernel_size': kernel_size,
                'padding': 1
            }
            layers.append(nn.Conv2d(**args, bias=False))
            # normalize flag
            if normalize:
                layers.append(nn.BatchNorm2d(output_channels))
            if activation:
                layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        def get_layers(layer_config_list):
            '''Flattens a list of layer configs into a single list of nn layers.'''
            layers = list()
            for layer_config in layer_config_list:
                layers += down_sampling(**layer_config)
            return layers

        def config_layer(input_channels=3, output_channels=3, stride=2,
                         normalize=True, kernel_size=(4, 4), activation=True):
            '''Returns the arguments for one down-sampling block as a dict.'''
            return locals()

        self.discriminator_layers = list()
        # Discriminator layers
        self.discriminator_layers.append(config_layer(self.channels, 64, 2, normalize=False))  # layer 1, input: 3 x 64 x 64
        self.discriminator_layers.append(config_layer(64, 128, 2, True))                       # layer 2, input: 64 x 32 x 32
        self.discriminator_layers.append(config_layer(128, 256, 2, True))                      # layer 3, input: 128 x 16 x 16
        self.discriminator_layers.append(config_layer(256, 512, 2, True))                      # layer 4, input: 256 x 8 x 8
        # self.discriminator_layers.append(config_layer(512, 1, 2, False, activation=False))

        # Model
        self.model = nn.Sequential(
            *get_layers(self.discriminator_layers),
            nn.Conv2d(512, 1, kernel_size=(4, 4), bias=False),  # layer 5, input: 512 x 4 x 4
            nn.Sigmoid(),
            # output: 1 x 1 x 1
        )

    def forward(self, x):
        output = self.model(x)
        # print(output.shape)
        return output.view(-1)
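

# A minimal smoke test, not part of the original module: it assumes the
# intended input resolution is 64 x 64 RGB (as the layer comments above
# suggest), builds the discriminator, and runs a forward pass on a random
# batch to confirm that one probability per image comes out.
if __name__ == "__main__":
    import torch

    netD = discriminator(channels=3)
    fake_batch = torch.randn(8, 3, 64, 64)  # batch of 8 random "images"
    scores = netD(fake_batch)
    print(scores.shape)  # expected: torch.Size([8]), one score in (0, 1) per image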