# DenseNet_model
import torch
import torch.nn as nn


def init_weights(module):
    # Kaiming initialisation for conv layers, constant init for batch norm.
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight.detach(), mode='fan_out', nonlinearity='relu')
    elif isinstance(module, nn.BatchNorm2d):
        module.weight.detach().fill_(1.)
        module.bias.detach().fill_(0.)
def pool_position_check(layers_num):
    # Helper that places two pooling positions at roughly 1/3 and 2/3 of the
    # network depth (defined here but not used by the classes below).
    position_A = round(layers_num / 3)
    position_B = round(layers_num * 2 / 3)
    if position_B < position_A:
        position_A -= 1
        position_B += 1
    return position_A, position_B
class Denselayers(nn.Module):
    """One dense layer: BN-ReLU-1x1 conv bottleneck, then BN-ReLU-3x3 conv,
    with the input concatenated to the output along the channel axis."""

    def __init__(self, num, in_channel, drop_rate=0., growth_rate=None):
        super(Denselayers, self).__init__()
        self.num = num
        self.drop_rate = drop_rate
        self.growth_rate = growth_rate
        setattr(self, 'BN_{}_1'.format(self.num), nn.BatchNorm2d(in_channel))
        setattr(self, 'ReLU_{}_1'.format(self.num), nn.ReLU())
        # 1x1 bottleneck: in_channel -> num * growth_rate channels.
        setattr(self, 'conv1_{}'.format(self.num),
                nn.Conv2d(in_channel, self.num * self.growth_rate, kernel_size=1, padding=0, stride=1))
        if self.drop_rate > 0.:
            setattr(self, 'droprate_{}_1'.format(self.num), nn.Dropout(p=drop_rate))
        setattr(self, 'BN_{}_2'.format(self.num), nn.BatchNorm2d(self.num * self.growth_rate))
        setattr(self, 'ReLU_{}_2'.format(self.num), nn.ReLU())
        # 3x3 conv: num * growth_rate -> 3 * growth_rate new feature channels.
        setattr(self, 'conv3_{}'.format(self.num),
                nn.Conv2d(self.num * self.growth_rate, 3 * self.growth_rate, kernel_size=3, padding=1, stride=1))
        if self.drop_rate > 0.:
            setattr(self, 'droprate_{}_2'.format(self.num), nn.Dropout(p=drop_rate))
        self.apply(init_weights)

    def forward(self, input):
        _ = getattr(self, 'BN_{}_1'.format(self.num))(input)
        _ = getattr(self, 'ReLU_{}_1'.format(self.num))(_)
        _ = getattr(self, 'conv1_{}'.format(self.num))(_)
        if self.drop_rate > 0.:
            _ = getattr(self, 'droprate_{}_1'.format(self.num))(_)
        _ = getattr(self, 'BN_{}_2'.format(self.num))(_)
        _ = getattr(self, 'ReLU_{}_2'.format(self.num))(_)
        output = getattr(self, 'conv3_{}'.format(self.num))(_)
        if self.drop_rate > 0.:
            output = getattr(self, 'droprate_{}_2'.format(self.num))(output)
        # Dense connectivity: concatenate input and new features on channels.
        return torch.cat((input, output), dim=1)
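
# Shape sketch (illustrative values, not from the original file): with
# growth_rate=12, Denselayers(num=2, in_channel=60) maps (N, 60, H, W)
# -> 1x1 conv -> (N, 24, H, W) -> 3x3 conv -> (N, 36, H, W), and the
# final concatenation returns (N, 60 + 36, H, W) = (N, 96, H, W).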
class DenseNet(nn.Module):
    def __init__(self, layers_num, num_init_feature=False, growth_rate=None, drop_rate=0.):
        super(DenseNet, self).__init__()
        self.layers_num = layers_num
        # Initial channel count: num_init_feature if given, else 2 * growth_rate.
        in_channel = num_init_feature if num_init_feature else growth_rate * 2
        self.init_conv = nn.Conv2d(1, in_channel, kernel_size=3, stride=1, padding=1)
        self.init_pool = nn.MaxPool2d(2, 2)
        # Each dense layer adds 3 * growth_rate channels via concatenation.
        self.last_batchnorm = nn.BatchNorm2d(int(in_channel + (3 * growth_rate * self.layers_num)))
        self.last_avgpool = nn.AvgPool2d(kernel_size=8, stride=2, padding=0)
        self.classifier = nn.Linear(int(in_channel + (3 * growth_rate * self.layers_num)), 2)
        for num in range(self.layers_num):
            setattr(self, 'Dense_{}'.format(num + 1),
                    Denselayers(num + 1, in_channel, drop_rate=drop_rate, growth_rate=growth_rate))
            # if num % 2 == 1:
            setattr(self, 'avgpool_{}'.format(num + 1), nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
            in_channel += 3 * growth_rate
        self.apply(init_weights)

    def forward(self, input):
        _ = self.init_conv(input)
        _ = self.init_pool(_)
        for num in range(self.layers_num):
            _ = getattr(self, 'Dense_{}'.format(num + 1))(_)
            # if num % 2 == 1:
            _ = getattr(self, 'avgpool_{}'.format(num + 1))(_)
        _ = self.last_batchnorm(_)
        _ = self.last_avgpool(_)
        _ = torch.flatten(_, start_dim=1)
        output = self.classifier(_)
        return output
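
if __name__ == '__main__':
    # Minimal smoke test (an assumed configuration, not part of the original
    # file). With layers_num=3 the spatial size is halved four times
    # (init_pool plus one avgpool per dense layer), so a 128x128 grayscale
    # input reaches the 8x8 map that last_avgpool (kernel_size=8, stride=2)
    # reduces to 1x1, matching the classifier's expected input of
    # 2 * growth_rate + 3 * growth_rate * layers_num = 132 features.
    model = DenseNet(layers_num=3, growth_rate=12, drop_rate=0.2)
    dummy = torch.randn(4, 1, 128, 128)  # (batch, channels, height, width)
    logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([4, 2])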