-
Notifications
You must be signed in to change notification settings - Fork 0
/
vgg11.py
78 lines (68 loc) · 2.44 KB
/
vgg11.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch.nn as nn
import torchvision
class VGG(nn.Module):
    """FCN-style segmentation network: frozen VGG11-BN encoder + learned decoder.

    The encoder is the ``features`` part of torchvision's ``vgg11_bn`` loaded
    with pretrained weights. Its five 2x2/stride-2 max-pools downsample by 32,
    and it ends with 512 channels. The decoder is five stride-2 transposed
    convolutions, each doubling the spatial size and followed by BatchNorm and
    ReLU. A final 1x1 convolution then produces per-pixel class scores.

    Args:
        n_class: Number of output classes (channels of the returned score map).
    """

    def __init__(self, n_class):
        super().__init__()
        # VGG11 feature-extractor architecture (the _bn variant adds BatchNorm
        # after every conv):
        #   conv3-64-1
        #   max 2x2-2
        #   conv3-128-1
        #   max 2x2-2
        #   conv3-256-1
        #   conv3-256-1
        #   max 2x2-2
        #   conv3-512-1
        #   conv3-512-1
        #   max 2x2-2
        #   conv3-512-1
        #   conv3-512-1
        #   max 2x2-2
        # The original classifier head (FC-4096, FC-4096, FC-1000, soft-max)
        # is discarded below.
        mod = torchvision.models.vgg11_bn(pretrained=True)
        # Take only the feature portion of the model (no avg pool or classifier).
        self.mod = mod.features
        # Freeze the pre-trained weights. BatchNorm running statistics are kept
        # frozen separately by the train() override.
        for param in self.mod.parameters():
            param.requires_grad = False
        self.n_class = n_class
        # Stateless, so one instance is shared by every decoder stage.
        self.relu = nn.ReLU(inplace=True)
        # With kernel 3, stride 2, padding 1, output_padding 1 each transposed
        # conv maps spatial size s -> (s - 1) * 2 - 2 + 3 + 1 = 2 * s.
        self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
        self.bn5 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, self.n_class, kernel_size=1)

    def train(self, mode=True):
        """Set training mode for the decoder while keeping the encoder in eval mode.

        ``requires_grad = False`` only stops gradient updates to the encoder
        weights. In training mode its BatchNorm layers would still normalize
        with per-batch statistics and overwrite their pretrained running
        mean/var. Forcing ``self.mod`` back to eval mode makes the encoder
        genuinely frozen. ``eval()`` routes through this method as well.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.mod.eval()
        return self

    def forward(self, x):
        """Return per-pixel class scores for a batch of images.

        Args:
            x: Input image batch. Presumably (N, 3, H, W) images normalized
                as the pretrained VGG expects (TODO: confirm against the data
                pipeline). H and W must be at least 32.

        Returns:
            Score tensor of shape (N, n_class, 32 * (H // 32), 32 * (W // 32)).
            This equals the input resolution only when H and W are multiples
            of 32, because the encoder's max-pools floor odd sizes.
        """
        out = self.mod(x)
        # Apply the registered decoder stages directly instead of rebuilding an
        # nn.Sequential on every call; the computation is identical.
        stages = (
            (self.deconv1, self.bn1),
            (self.deconv2, self.bn2),
            (self.deconv3, self.bn3),
            (self.deconv4, self.bn4),
            (self.deconv5, self.bn5),
        )
        for deconv, bn in stages:
            out = self.relu(bn(deconv(out)))
        score = self.classifier(out)
        return score  # size=(N, n_class, 32*(x.H//32), 32*(x.W//32))