Commit — Showing 6 changed files with 305 additions and 0 deletions.
@@ -0,0 +1 @@
from .ibnunet import UnetIBN
@@ -0,0 +1 @@
from .ibnresnet import resnet50_ibn_a
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,244 @@
import os
import math

import torch
import torch.nn as nn


__all__ = ['ResNet_IBN', 'resnet50_ibn_a', 'resnet50_ibn_b']

# Assumed definition: the original file referenced PROJECT_DIR in
# resnet50_ibn_b without defining or importing it; adjust this to the
# actual repository root.
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Channel counts of the feature maps returned by ResNet_IBN.forward().
# Note: the first entry is informational only; the stem (conv1 below)
# actually takes a 4-channel input.
output_channels = {
    "resnet50_ibn_a": (3, 64, 256, 512, 1024, 2048),
    "resnet50_ibn_b": (3, 64, 256, 512, 1024, 2048),
}


class IBN(nn.Module):
    r"""Instance-Batch Normalization layer from
    `"Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net"
    <https://arxiv.org/pdf/1807.09441.pdf>`_.

    The first ``planes * ratio`` channels are instance-normalized, the
    remaining channels batch-normalized, and the two halves re-concatenated.

    Args:
        planes (int): Number of channels for the input tensor
        ratio (float): Ratio of instance normalization in the IBN layer
    """

    def __init__(self, planes, ratio=0.5):
        super(IBN, self).__init__()
        self.half = int(planes * ratio)
        self.IN = nn.InstanceNorm2d(self.half, affine=True)
        self.BN = nn.BatchNorm2d(planes - self.half)

    def forward(self, x):
        split = torch.split(x, self.half, 1)
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        out = torch.cat((out1, out2), 1)
        return out
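
# Usage sketch (illustrative, not part of the commit): with planes=64 and
# the default ratio=0.5, the first 32 channels go through InstanceNorm2d
# and the remaining 32 through BatchNorm2d before being re-concatenated:
#
#   ibn = IBN(64)
#   y = ibn(torch.randn(2, 64, 56, 56))   # y.shape == (2, 64, 56, 56)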


class BasicBlock_IBN(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None):
        super(BasicBlock_IBN, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        if ibn == 'a':
            self.bn1 = IBN(planes)
        else:
            self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.IN = nn.InstanceNorm2d(planes, affine=True) if ibn == 'b' else None
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        if self.IN is not None:
            out = self.IN(out)
        out = self.relu(out)

        return out


class Bottleneck_IBN(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None):
        super(Bottleneck_IBN, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        if ibn == 'a':
            self.bn1 = IBN(planes)
        else:
            self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.IN = nn.InstanceNorm2d(planes * self.expansion, affine=True) if ibn == 'b' else None
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        if self.IN is not None:
            out = self.IN(out)
        out = self.relu(out)

        return out


class ResNet_IBN(nn.Module):
    def __init__(self,
                 block,
                 layers,
                 ibn_cfg=('a', 'a', 'a', None)):
        self.inplanes = 64
        super(ResNet_IBN, self).__init__()
        # Stem: two 3x3 convolutions on a 4-channel input, replacing the
        # standard single 7x7 RGB stem.
        self.conv1 = nn.Sequential(
            nn.Conv2d(4, 64, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
        )
        if ibn_cfg[0] == 'b':
            self.bn1 = nn.InstanceNorm2d(64, affine=True)
        else:
            self.bn1 = nn.BatchNorm2d(64)

        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], ibn=ibn_cfg[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, ibn=ibn_cfg[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, ibn=ibn_cfg[2])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, ibn=ibn_cfg[3])

        self.out_channels = None

        # self.avgpool = nn.AvgPool2d(7)
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, ibn=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes,
                            None if ibn == 'b' else ibn,
                            stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            # For IBN-b, only the last block of the stage applies instance norm.
            layers.append(block(self.inplanes, planes,
                                None if (ibn == 'b' and i < blocks - 1) else ibn))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Collect the input and each stage's output for use as skip connections.
        outputs = [x]

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        outputs.append(x)

        x = self.maxpool(x)
        x = self.layer1(x)
        outputs.append(x)

        x = self.layer2(x)
        outputs.append(x)

        x = self.layer3(x)
        outputs.append(x)

        x = self.layer4(x)
        outputs.append(x)

        return outputs


def resnet50_ibn_a(pretrained=False):
    """Constructs a ResNet-50-IBN-a model.

    Args:
        pretrained (bool): If True, loads weights pre-trained on ImageNet
    """
    model = ResNet_IBN(block=Bottleneck_IBN,
                       layers=[3, 4, 6, 3],
                       ibn_cfg=('a', 'a', 'a', None))
    model.out_channels = output_channels['resnet50_ibn_a']
    if pretrained:
        weights_path = 'resnet50_ibn_a-d9d0bb7b.pth'
        print('=> loading pretrained model {}'.format(weights_path))
        pretrained_dict = torch.load(weights_path, map_location='cpu')
        model_dict = model.state_dict()
        # Keep only checkpoint tensors whose names and shapes match this
        # (modified-stem, headless) model; everything else is skipped.
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict and v.shape == model_dict[k].shape}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    return model


def resnet50_ibn_b(pretrained=False, **kwargs):
    """Constructs a ResNet-50-IBN-b model.

    Args:
        pretrained (bool): If True, loads weights pre-trained on ImageNet
    """
    model = ResNet_IBN(block=Bottleneck_IBN,
                       layers=[3, 4, 6, 3],
                       ibn_cfg=('b', 'b', None, None))
    model.out_channels = output_channels['resnet50_ibn_b']
    if pretrained:
        weights_path = os.path.join(PROJECT_DIR, 'external_data', 'resnet50_ibn_b-9ca61e85.pth')
        pretrained_state_dict = torch.load(weights_path, map_location='cpu')
        # Drop the classifier weights; this backbone has no fc layer.
        pretrained_state_dict.pop("fc.bias")
        pretrained_state_dict.pop("fc.weight")
        model.load_state_dict(pretrained_state_dict)
    return model
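For orientation, a minimal smoke test of the backbone (a sketch, not part of the commit): it assumes the model.backbone import path used by the U-Net file below, skips pretrained weights so no checkpoint file is needed, and feeds the 4-channel input the stem hard-codes.

import torch

from model.backbone import resnet50_ibn_a  # import path assumed from this commit's __init__ files

encoder = resnet50_ibn_a(pretrained=False)
encoder.eval()
x = torch.randn(1, 4, 256, 256)  # the stem expects a 4-channel input
with torch.no_grad():
    feats = encoder(x)
for f in feats:
    print(tuple(f.shape))
# (1, 4, 256, 256)    raw input
# (1, 64, 128, 128)   stem (stride 2)
# (1, 256, 64, 64)    layer1, after the stride-2 max pool
# (1, 512, 32, 32)    layer2
# (1, 1024, 16, 16)   layer3
# (1, 2048, 8, 8)     layer4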
@@ -0,0 +1,59 @@
import torch.nn as nn
import torch.nn.functional as F

from model.backbone import resnet50_ibn_a
from model.modules.basics import Conv2dBnAct
from model.modules.blocks.ibn import IBNaDecoderBlock


class CenterBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CenterBlock, self).__init__()
        self.conv = Conv2dBnAct(in_channels, out_channels, 3, 1, 1)

    def forward(self, x):
        return self.conv(x)


class UnetIBN(nn.Module):
    def __init__(self,
                 encoder_pretrained=True,
                 head_channels=512,
                 decoder_channels=[256, 128, 64, 32],
                 dropout=0.,
                 classes=10):
        super(UnetIBN, self).__init__()

        # ENCODER: ResNet-50-IBN-a returning multi-scale features;
        # drop the raw-input entry from its channel list.
        self.encoder = resnet50_ibn_a(pretrained=encoder_pretrained)
        encoder_channels = self.encoder.out_channels[1:]

        # CENTER BLOCK
        self.center_block = CenterBlock(encoder_channels[-1], head_channels)

        # DECODER: one block per skip connection, deepest skip first
        skip_channels = encoder_channels[:-1][::-1]
        input_channels = [head_channels] + decoder_channels[:-1]

        self.decoder_modules = nn.ModuleList()
        for in_ch, sk_ch, de_ch in zip(input_channels, skip_channels, decoder_channels):
            self.decoder_modules.append(IBNaDecoderBlock(in_ch + sk_ch, de_ch, use_attention=True))

        # PREDICTION HEAD
        self.pred_head = nn.Sequential(
            nn.Conv2d(decoder_channels[-1], 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Conv2d(32, classes, 1),
        )

    def forward(self, x):
        encoder_feats = self.encoder(x)[1:]
        decoder_feat = self.center_block(encoder_feats[-1])
        skip_feats = encoder_feats[:-1][::-1]
        for idx, decoder_module in enumerate(self.decoder_modules):
            decoder_feat = decoder_module(decoder_feat, skip_feats[idx])
        out = self.pred_head(decoder_feat)
        # Upsample the logits back to the input resolution.
        out = F.interpolate(out, size=x.shape[-2:])
        return out
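Finally, a hedged end-to-end sketch of driving the segmentation model (it assumes the package is importable as model, per the one-line __init__ diff above, and that IBNaDecoderBlock upsamples by 2x like a standard U-Net decoder block; that module is not part of this commit):

import torch

from model import UnetIBN  # import path inferred from the __init__ diff above

model = UnetIBN(encoder_pretrained=False, classes=10)
model.eval()
x = torch.randn(1, 4, 256, 256)  # matches the encoder's 4-channel stem
with torch.no_grad():
    logits = model(x)
print(tuple(logits.shape))  # expected (1, 10, 256, 256): per-pixel class logits at input resolution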