-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpytorch_se_resnet_densenet.py
158 lines (127 loc) · 5.08 KB
/
pytorch_se_resnet_densenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import math
import torch
from torchvision.models import ResNet
from torchvision import models
from torch import nn
import torch.nn.functional as F
class SELayer(nn.Module):
    """Squeeze-and-Excitation channel attention with an identity blend.

    Computes per-channel gates via global average pooling followed by a
    two-layer bottleneck MLP with sigmoid activation, then rescales the
    input. ``alpha`` blends between the identity (alpha=0) and the fully
    SE-gated output (alpha=1): ``x * (alpha * gate + (1 - alpha))``.

    Args:
        channel: number of input channels C.
        reduction: bottleneck reduction ratio of the gating MLP.
        alpha: blend factor in [0, 1].

    Raises:
        ValueError: if ``alpha`` is outside [0, 1].
    """
    def __init__(self, channel, reduction=16, alpha=0.):
        super(SELayer, self).__init__()
        # Validate with a real exception: ``assert`` is stripped under
        # ``python -O`` and would silently accept bad values.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("alpha must be in [0, 1], got {!r}".format(alpha))
        self.alpha = alpha
        # Squeeze: (B, C, H, W) -> (B, C, 1, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excite: bottleneck MLP producing a sigmoid gate per channel.
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        # Blend: identity when alpha=0, full SE rescaling when alpha=1.
        return x * (self.alpha * y + (1 - self.alpha))
class SEBottleneck(nn.Module):
    """ResNet bottleneck block augmented with a Squeeze-and-Excitation gate.

    Layout: 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand (x4)
    -> SE channel gate, plus the standard residual shortcut.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
        super(SEBottleneck, self).__init__()
        width = planes * 4  # expanded channel count of the block output
        # 1x1 channel reduction
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 spatial convolution; this is where the block's stride lives
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 channel expansion back to 4*planes
        self.conv3 = nn.Conv2d(planes, width, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu = nn.ReLU(inplace=True)
        # Channel attention applied to the expanded features
        self.se = SELayer(width, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut branch, projected when the shape changes.
        identity = x if self.downsample is None else self.downsample(x)
        # Main branch: conv-bn-relu twice, then conv-bn followed by SE.
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.se(self.bn3(self.conv3(y)))
        y = y + identity
        return self.relu(y)
def se_resnet50(num_classes, path_to_model=None):
    """Construct an SE-ResNet-50.

    Args:
        num_classes: size of the final classification layer.
        path_to_model: optional path to a ``state_dict`` checkpoint to load.
            (The original docstring documented a nonexistent ``pretrained``
            argument; fixed.)

    Returns:
        A torchvision ``ResNet`` built from ``SEBottleneck`` blocks with an
        adaptive global average pool, so any input resolution works.
    """
    model = ResNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes)
    # Adaptive pooling decouples the classifier from the input resolution.
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    if path_to_model is not None:
        # map_location="cpu" lets GPU-trained checkpoints load on CPU-only
        # hosts; the caller moves the model to its target device afterwards.
        model.load_state_dict(torch.load(path_to_model, map_location="cpu"))
    return model
class SE_Resnet50(nn.Module):
    """SE-ResNet-50 backbone with one extra per-sample feature concatenated
    before a small MLP classifier head.

    Args:
        num_classes: number of output classes of the head.
        pretrained: when True, initialise the backbone from torchvision's
            ImageNet ResNet-50 weights. SE layers have no ImageNet
            counterpart and stay randomly initialised (hence
            ``strict=False``).
    """
    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.bottleneck_conf = [3, 4, 6, 3]
        self.net = ResNet(SEBottleneck, self.bottleneck_conf, num_classes=1000)
        self.net.avgpool = nn.AdaptiveAvgPool2d(1)
        # Bug fix: the original ignored the ``pretrained`` flag and always
        # loaded ImageNet weights; honour the flag instead.
        if pretrained:
            self.net.load_state_dict(
                models.resnet50(pretrained=True).state_dict(), strict=False)
        # Head input is the pooled feature vector plus one extra scalar
        # (the ``O`` argument of ``forward``), hence ``in_features + 1``.
        self.fc = nn.Sequential(
            nn.Linear(self.net.fc.in_features + 1, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
        # Replace the backbone's own classifier with a pass-through so
        # ``self.net(x)`` returns pooled features (Dropout(p=0) is identity).
        self.net.fc = nn.Dropout(0.0)

    def set_SE_alpha(self, alpha):
        """Set the SE identity-blend factor on every bottleneck block.

        Raises:
            ValueError: if ``alpha`` is outside [0, 1]. (Was an ``assert``,
            which disappears under ``python -O``.)
        """
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("alpha must be in [0, 1], got {!r}".format(alpha))
        stages = [self.net.layer1, self.net.layer2, self.net.layer3, self.net.layer4]
        for idx, block in enumerate(stages):
            for i in range(self.bottleneck_conf[idx]):
                block[i].se.alpha = alpha

    def forward(self, x, O):
        """Run the backbone, append feature ``O``, return class probabilities.

        NOTE(review): returns softmax probabilities — pair with an NLL-style
        loss on the log of this output, not ``nn.CrossEntropyLoss`` (which
        expects raw logits).
        """
        out = self.net(x)
        out = out.view(out.size(0), -1)
        out = torch.cat([out, O], 1)
        return F.softmax(self.fc(out), dim=1)
class AvgPool(nn.Module):
    """Global average pooling: collapses the full spatial extent to 1x1."""

    def forward(self, x):
        # Pool with a window equal to the whole feature map, i.e. a global
        # mean per channel; output shape is (B, C, 1, 1).
        spatial = (x.size(2), x.size(3))
        return F.avg_pool2d(x, spatial)
class DenseNetWithManip(nn.Module):
    """DenseNet-201 backbone with one extra per-sample feature (the ``O``
    input, presumably a manipulation flag — TODO confirm with callers)
    concatenated before the classifier.

    Expects NHWC input tensors; ``forward`` transposes to NCHW internally.
    """
    finetune = True

    def __init__(self, num_classes, two_layer=True, densenet=None, freeze=True):
        super().__init__()
        self.net = models.densenet201(num_classes=1000)
        # NOTE(review): ``forward`` below pools manually and never uses
        # ``self.net.avgpool`` — this assignment looks vestigial; kept as-is
        # (AvgPool has no parameters, so it is harmless).
        self.net.avgpool = AvgPool()
        self.relu = torch.nn.ReLU(inplace=True)
        if densenet is not None:
            # Copy weights from a donor model; strict=False tolerates the
            # classifier-shape mismatch introduced below.
            self.load_state_dict(densenet.state_dict(), strict=False)
        if freeze:
            # Freeze the backbone; the classifier replaced below is created
            # afterwards and therefore remains trainable.
            self.set_trainable(trainable=False)
        if two_layer:
            mid_channels = 512
            self.net.classifier = nn.Sequential(
                nn.Linear(self.net.classifier.in_features + 1, mid_channels),
                nn.Dropout(p=0.1),
                nn.Linear(mid_channels, num_classes))
        else:
            self.net.classifier = nn.Linear(
                self.net.classifier.in_features + 1, num_classes)

    def set_trainable(self, trainable):
        """Toggle ``requires_grad`` on every parameter of this module.

        Bug fix: the original filtered on ``p.requires_grad is not None``,
        which is always True (``requires_grad`` is a bool), so the filter
        was a misleading no-op; iterate all parameters directly.
        """
        for param in self.parameters():
            param.requires_grad = trainable

    def forward(self, x, O):
        """Classify NHWC images ``x`` with the extra feature ``O`` appended.

        The two transposes convert (B, H, W, C) -> (B, C, H, W).
        """
        x = torch.transpose(x, 1, 3)  # (B, C, W, H)
        x = torch.transpose(x, 2, 3)  # (B, C, H, W)
        net = self.net
        features = net.features(x)
        out = self.relu(features)
        # Global average pool over the full spatial extent, then flatten.
        out = torch.nn.functional.avg_pool2d(out, (out.size(2), out.size(3)))
        out = out.view(features.size(0), -1)
        out = torch.cat([out, O], 1)
        out = net.classifier(out)
        return out