Commit 9a45915

Format model files
1 parent 134fffd commit 9a45915

File tree

10 files changed: +1020 -564 lines changed


ptsemseg/models/__init__.py

Lines changed: 24 additions & 25 deletions
@@ -12,53 +12,52 @@
 def get_model(name, n_classes, version=None):
     model = _get_model_instance(name)
 
-    if name in ['frrnA', 'frrnB']:
+    if name in ["frrnA", "frrnB"]:
         model = model(n_classes, model_type=name[-1])
 
-    elif name in ['fcn32s', 'fcn16s', 'fcn8s']:
+    elif name in ["fcn32s", "fcn16s", "fcn8s"]:
         model = model(n_classes=n_classes)
         vgg16 = models.vgg16(pretrained=True)
         model.init_vgg16_params(vgg16)
 
-    elif name == 'segnet':
-        model = model(n_classes=n_classes,
-                      is_unpooling=True)
+    elif name == "segnet":
+        model = model(n_classes=n_classes, is_unpooling=True)
         vgg16 = models.vgg16(pretrained=True)
         model.init_vgg16_params(vgg16)
 
-    elif name == 'unet':
-        model = model(n_classes=n_classes,
-                      is_batchnorm=True,
-                      in_channels=3,
-                      is_deconv=True)
+    elif name == "unet":
+        model = model(
+            n_classes=n_classes, is_batchnorm=True, in_channels=3, is_deconv=True
+        )
 
-    elif name == 'pspnet':
+    elif name == "pspnet":
         model = model(n_classes=n_classes, version=version)
 
-    elif name == 'icnet':
+    elif name == "icnet":
         model = model(n_classes=n_classes, with_bn=False, version=version)
-    elif name == 'icnetBN':
+    elif name == "icnetBN":
         model = model(n_classes=n_classes, with_bn=True, version=version)
 
     else:
         model = model(n_classes=n_classes)
 
     return model
 
+
 def _get_model_instance(name):
     try:
         return {
-            'fcn32s': fcn32s,
-            'fcn8s': fcn8s,
-            'fcn16s': fcn16s,
-            'unet': unet,
-            'segnet': segnet,
-            'pspnet': pspnet,
-            'icnet': icnet,
-            'icnetBN': icnet,
-            'linknet': linknet,
-            'frrnA': frrn,
-            'frrnB': frrn,
+            "fcn32s": fcn32s,
+            "fcn8s": fcn8s,
+            "fcn16s": fcn16s,
+            "unet": unet,
+            "segnet": segnet,
+            "pspnet": pspnet,
+            "icnet": icnet,
+            "icnetBN": icnet,
+            "linknet": linknet,
+            "frrnA": frrn,
+            "frrnB": frrn,
         }[name]
     except:
-        print('Model {} not available'.format(name))
+        print("Model {} not available".format(name))
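For context, get_model is the package's factory entry point; the reformatting above does not change its behavior. A minimal usage sketch, assuming ptsemseg is importable and using an illustrative class count and input size (neither is fixed by this commit):

import torch

from ptsemseg.models import get_model

# "fcn8s" routes through the VGG16-initialized branch of get_model;
# this downloads torchvision's pretrained VGG16 weights on first use.
model = get_model("fcn8s", n_classes=21)

# pspnet/icnet additionally accept a `version` argument, e.g. (illustrative):
# model = get_model("pspnet", n_classes=19, version="cityscapes")

x = torch.randn(1, 3, 256, 256)  # dummy NCHW batch
out = model(x)                   # per-pixel class scores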

ptsemseg/models/fcn.py

Lines changed: 62 additions & 44 deletions
@@ -3,7 +3,6 @@
 
 # FCN32s
 class fcn32s(nn.Module):
-
     def __init__(self, n_classes=21, learned_billinear=False):
         super(fcn32s, self).__init__()
         self.learned_billinear = learned_billinear
@@ -14,14 +13,16 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.ReLU(inplace=True),
             nn.Conv2d(64, 64, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.conv_block2 = nn.Sequential(
             nn.Conv2d(64, 128, 3, padding=1),
             nn.ReLU(inplace=True),
             nn.Conv2d(128, 128, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.conv_block3 = nn.Sequential(
             nn.Conv2d(128, 256, 3, padding=1),
@@ -30,7 +31,8 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.ReLU(inplace=True),
             nn.Conv2d(256, 256, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.conv_block4 = nn.Sequential(
             nn.Conv2d(256, 512, 3, padding=1),
@@ -39,7 +41,8 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.ReLU(inplace=True),
             nn.Conv2d(512, 512, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.conv_block5 = nn.Sequential(
             nn.Conv2d(512, 512, 3, padding=1),
@@ -48,7 +51,8 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.ReLU(inplace=True),
             nn.Conv2d(512, 512, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.classifier = nn.Sequential(
             nn.Conv2d(512, 4096, 7),
@@ -57,15 +61,15 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.Conv2d(4096, 4096, 1),
             nn.ReLU(inplace=True),
             nn.Dropout2d(),
-            nn.Conv2d(4096, self.n_classes, 1),)
+            nn.Conv2d(4096, self.n_classes, 1),
+        )
 
         # TODO: Add support for learned upsampling
         if self.learned_billinear:
             raise NotImplementedError
         # upscore = nn.ConvTranspose2d(self.n_classes, self.n_classes, 64, stride=32, bias=False)
         # upscore.scale_factor = None
 
-
     def forward(self, x):
         conv1 = self.conv_block1(x)
         conv2 = self.conv_block2(conv1)
@@ -79,19 +83,20 @@ def forward(self, x):
 
         return out
 
-
     def init_vgg16_params(self, vgg16, copy_fc8=True):
-        blocks = [self.conv_block1,
-                  self.conv_block2,
-                  self.conv_block3,
-                  self.conv_block4,
-                  self.conv_block5]
+        blocks = [
+            self.conv_block1,
+            self.conv_block2,
+            self.conv_block3,
+            self.conv_block4,
+            self.conv_block5,
+        ]
 
         ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
         features = list(vgg16.features.children())
 
         for idx, conv_block in enumerate(blocks):
-            for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block):
+            for l1, l2 in zip(features[ranges[idx][0] : ranges[idx][1]], conv_block):
                 if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
                     # print(idx, l1, l2)
                     assert l1.weight.size() == l2.weight.size()
@@ -111,8 +116,8 @@ def init_vgg16_params(self, vgg16, copy_fc8=True):
                 l2.weight.data = l1.weight.data[:n_class, :].view(l2.weight.size())
                 l2.bias.data = l1.bias.data[:n_class]
 
-class fcn16s(nn.Module):
 
+class fcn16s(nn.Module):
     def __init__(self, n_classes=21, learned_billinear=False):
         super(fcn16s, self).__init__()
         self.learned_billinear = learned_billinear
@@ -123,14 +128,16 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.ReLU(inplace=True),
             nn.Conv2d(64, 64, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.conv_block2 = nn.Sequential(
             nn.Conv2d(64, 128, 3, padding=1),
             nn.ReLU(inplace=True),
             nn.Conv2d(128, 128, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.conv_block3 = nn.Sequential(
             nn.Conv2d(128, 256, 3, padding=1),
@@ -139,7 +146,8 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.ReLU(inplace=True),
             nn.Conv2d(256, 256, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.conv_block4 = nn.Sequential(
             nn.Conv2d(256, 512, 3, padding=1),
@@ -148,7 +156,8 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.ReLU(inplace=True),
             nn.Conv2d(512, 512, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.conv_block5 = nn.Sequential(
             nn.Conv2d(512, 512, 3, padding=1),
@@ -157,7 +166,8 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.ReLU(inplace=True),
             nn.Conv2d(512, 512, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.classifier = nn.Sequential(
             nn.Conv2d(512, 4096, 7),
@@ -166,7 +176,8 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.Conv2d(4096, 4096, 1),
             nn.ReLU(inplace=True),
             nn.Dropout2d(),
-            nn.Conv2d(4096, self.n_classes, 1),)
+            nn.Conv2d(4096, self.n_classes, 1),
+        )
 
         self.score_pool4 = nn.Conv2d(512, self.n_classes, 1)
 
@@ -176,7 +187,6 @@ def __init__(self, n_classes=21, learned_billinear=False):
         # upscore = nn.ConvTranspose2d(self.n_classes, self.n_classes, 64, stride=32, bias=False)
         # upscore.scale_factor = None
 
-
     def forward(self, x):
         conv1 = self.conv_block1(x)
         conv2 = self.conv_block2(conv1)
@@ -193,19 +203,20 @@ def forward(self, x):
 
         return out
 
-
     def init_vgg16_params(self, vgg16, copy_fc8=True):
-        blocks = [self.conv_block1,
-                  self.conv_block2,
-                  self.conv_block3,
-                  self.conv_block4,
-                  self.conv_block5]
+        blocks = [
+            self.conv_block1,
+            self.conv_block2,
+            self.conv_block3,
+            self.conv_block4,
+            self.conv_block5,
+        ]
 
         ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
         features = list(vgg16.features.children())
 
         for idx, conv_block in enumerate(blocks):
-            for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block):
+            for l1, l2 in zip(features[ranges[idx][0] : ranges[idx][1]], conv_block):
                 if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
                     # print(idx, l1, l2)
                     assert l1.weight.size() == l2.weight.size()
@@ -224,9 +235,9 @@ def init_vgg16_params(self, vgg16, copy_fc8=True):
                 l2.weight.data = l1.weight.data[:n_class, :].view(l2.weight.size())
                 l2.bias.data = l1.bias.data[:n_class]
 
+
 # FCN 8s
 class fcn8s(nn.Module):
-
     def __init__(self, n_classes=21, learned_billinear=False):
         super(fcn8s, self).__init__()
         self.learned_billinear = learned_billinear
@@ -237,14 +248,16 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.ReLU(inplace=True),
             nn.Conv2d(64, 64, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.conv_block2 = nn.Sequential(
             nn.Conv2d(64, 128, 3, padding=1),
             nn.ReLU(inplace=True),
             nn.Conv2d(128, 128, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.conv_block3 = nn.Sequential(
             nn.Conv2d(128, 256, 3, padding=1),
@@ -253,7 +266,8 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.ReLU(inplace=True),
             nn.Conv2d(256, 256, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.conv_block4 = nn.Sequential(
             nn.Conv2d(256, 512, 3, padding=1),
@@ -262,7 +276,8 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.ReLU(inplace=True),
             nn.Conv2d(512, 512, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.conv_block5 = nn.Sequential(
             nn.Conv2d(512, 512, 3, padding=1),
@@ -271,7 +286,8 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.ReLU(inplace=True),
             nn.Conv2d(512, 512, 3, padding=1),
             nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, stride=2, ceil_mode=True),)
+            nn.MaxPool2d(2, stride=2, ceil_mode=True),
+        )
 
         self.classifier = nn.Sequential(
             nn.Conv2d(512, 4096, 7),
@@ -280,7 +296,8 @@ def __init__(self, n_classes=21, learned_billinear=False):
             nn.Conv2d(4096, 4096, 1),
             nn.ReLU(inplace=True),
             nn.Dropout2d(),
-            nn.Conv2d(4096, self.n_classes, 1),)
+            nn.Conv2d(4096, self.n_classes, 1),
+        )
 
         self.score_pool4 = nn.Conv2d(512, self.n_classes, 1)
         self.score_pool3 = nn.Conv2d(256, self.n_classes, 1)
@@ -310,19 +327,20 @@ def forward(self, x):
 
         return out
 
-
     def init_vgg16_params(self, vgg16, copy_fc8=True):
-        blocks = [self.conv_block1,
-                  self.conv_block2,
-                  self.conv_block3,
-                  self.conv_block4,
-                  self.conv_block5]
+        blocks = [
+            self.conv_block1,
+            self.conv_block2,
+            self.conv_block3,
+            self.conv_block4,
+            self.conv_block5,
+        ]
 
         ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
         features = list(vgg16.features.children())
 
         for idx, conv_block in enumerate(blocks):
-            for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block):
+            for l1, l2 in zip(features[ranges[idx][0] : ranges[idx][1]], conv_block):
                 if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
                     assert l1.weight.size() == l2.weight.size()
                     assert l1.bias.size() == l2.bias.size()
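The init_vgg16_params methods touched above copy VGG16 conv weights block-by-block, using the `ranges` table to slice vgg16.features into the five conv blocks. A small sketch of the intended call sequence, assuming the repo layout shown here (the input size is illustrative; 21 classes is the constructors' default, matching Pascal VOC):

import torch
from torchvision import models

from ptsemseg.models.fcn import fcn32s

net = fcn32s(n_classes=21)            # default class count from the diff
vgg16 = models.vgg16(pretrained=True)
net.init_vgg16_params(vgg16)          # ranges[i] selects the vgg16.features slice for conv_block{i+1}

net.eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 224, 224))  # per-pixel class scores

This is the same sequence get_model performs internally for the fcn32s/fcn16s/fcn8s names.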
