InvertedResidual.py

import torch
from torch import nn


def conv_dbl(in_dim, out_dim, stride):
    # Standard 3x3 conv -> BatchNorm -> ReLU6 block.
    return nn.Sequential(
        nn.Conv2d(in_dim, out_dim, 3, stride, 1, bias=False),
        nn.BatchNorm2d(out_dim),
        nn.ReLU6(True)
    )


def conv_3_1(dim):
    # Depthwise 3x3 conv, then a 1x1 conv that doubles the channel count.
    return nn.Sequential(
        nn.Conv2d(dim, dim, 3, 1, 1, groups=dim, bias=False),
        nn.BatchNorm2d(dim),
        nn.ReLU6(True),
        nn.Conv2d(dim, dim * 2, 1, 1, 0, bias=False),
        nn.BatchNorm2d(dim * 2),
        nn.ReLU6(True),
    )


def conv_1x1(in_dim, out_dim):
    # Pointwise 1x1 conv -> BatchNorm -> ReLU6 block.
    return nn.Sequential(
        nn.Conv2d(in_dim, out_dim, 1, 1, 0, bias=False),
        nn.BatchNorm2d(out_dim),
        nn.ReLU6(True),
    )


def extend_layers(in_dim, inter_dim):
    # Alternates 1x1 projections with depthwise 3x3 blocks; each conv_3_1
    # doubles the channels, so the following conv_1x1 takes inter_dim * 2.
    return nn.Sequential(
        conv_1x1(in_dim, inter_dim),
        conv_3_1(inter_dim),
        conv_1x1(inter_dim * 2, inter_dim),
        conv_3_1(inter_dim),
        conv_1x1(inter_dim * 2, inter_dim),
    )


def output_layers(inter_dim, out_dim):
    # Final head: depthwise block, then a plain 1x1 conv (with bias, since
    # no BatchNorm follows) producing out_dim channels.
    return nn.Sequential(
        conv_3_1(inter_dim),
        nn.Conv2d(inter_dim * 2, out_dim, 1, 1, 0)
    )
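

# A minimal usage sketch (the dims below are illustrative assumptions, not
# taken from the original repo): extend_layers keeps the spatial size and
# returns inter_dim channels, which output_layers maps to the final count.
#   head = nn.Sequential(extend_layers(256, 128), output_layers(128, 18))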


class InvertedResidual(nn.Module):
    # MobileNetV2-style inverted residual block: 1x1 expand -> depthwise 3x3
    # -> linear 1x1 project, with an identity shortcut when shapes allow.
    def __init__(self, in_dim, out_dim, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        hidden_dim = round(expand_ratio * in_dim)
        # The shortcut is valid only when the block keeps both the spatial
        # size (stride 1) and the channel count.
        self.use_res = stride == 1 and in_dim == out_dim
        if expand_ratio == 1:
            # No expansion: depthwise 3x3, then a linear 1x1 projection.
            self.conv = nn.Sequential(
                nn.Conv2d(in_dim, in_dim, 3, stride, 1, groups=in_dim, bias=False),
                nn.BatchNorm2d(in_dim),
                nn.ReLU6(True),
                nn.Conv2d(in_dim, out_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(out_dim)
            )
        else:
            self.conv = nn.Sequential(
                # 1x1 expansion to hidden_dim channels.
                nn.Conv2d(in_dim, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(True),
                # Depthwise 3x3 convolution.
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(True),
                # Linear 1x1 projection (no activation before the residual add).
                nn.Conv2d(hidden_dim, out_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(out_dim)
            )

    def forward(self, x):
        # Apply the identity shortcut only when use_res is set.
        if self.use_res:
            return x + self.conv(x)
        return self.conv(x)
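

if __name__ == "__main__":
    # Quick shape check (a sketch, not part of the original file). With
    # stride 1 and matching channels the block uses the residual shortcut
    # and preserves the input shape.
    block = InvertedResidual(32, 32, stride=1, expand_ratio=6)
    x = torch.randn(1, 32, 56, 56)
    assert block(x).shape == x.shape
    # With stride 2 the spatial dims halve and there is no shortcut.
    down = InvertedResidual(32, 64, stride=2, expand_ratio=6)
    assert down(x).shape == (1, 64, 28, 28)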