Skip to content

Commit

Permalink
Changes (#53)
Browse files Browse the repository at this point in the history
* fix bug in cutmix (again)

* support output_stride in BiFPN

* fixed cutmix bug connected to `prev_data` shape
  • Loading branch information
bonlime authored Mar 17, 2020
1 parent 0a7c4da commit 4b6e145
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pytorch_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.0"
__version__ = "0.1.1"

from . import fit_wrapper
from . import losses
Expand Down
19 changes: 10 additions & 9 deletions pytorch_tools/fit_wrapper/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,11 +526,11 @@ def mixup(self, data, target):
target_one_hot = target
if not self.state.is_train or np.random.rand() > self.prob:
return data, target_one_hot
prev_data, prev_target = data, target_one_hot if self.prev_input is None else self.prev_input
prev_data, prev_target = (data, target_one_hot) if self.prev_input is None else self.prev_input
self.prev_input = data, target_one_hot
c = self.tb.sample()
md = c * data + (1 - c) * prev_data
mt = c * target_one_hot + (1 - c) * prev_target
self.prev_input = data, target_one_hot
return md, mt


Expand Down Expand Up @@ -569,16 +569,17 @@ def cutmix(self, data, target):
target_one_hot = target
if not self.state.is_train or np.random.rand() > self.prob:
return data, target_one_hot
prev_data, prev_target = data, target_one_hot if self.prev_input is None else self.prev_input
_, _, H, W = data.size()
prev_data, prev_target = (data, target_one_hot) if self.prev_input is None else self.prev_input
self.prev_input = data, target_one_hot
# prev_data shape can be different from current. so need to take min
H, W = min(data.size(2), prev_data.size(2)), min(data.size(3), prev_data.size(3))
lam = self.tb.sample()
lam = min([lam, 1 - lam])
bbh1, bbw1, bbh2, bbw2 = self.rand_bbox(H, W, lam)
        # real lambda may be different from sampled. adjust for it
lam = (bbh2 - bbh1) * (bbw2 - bbw1) / (H * W)
data[:, bbh1:bbh2, bbw1:bbw2] = prev_data[:, bbh1:bbh2, bbw1:bbw2]
data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[:, :, bbh1:bbh2, bbw1:bbw2]
mixed_target = (1 - lam) * target_one_hot + lam * prev_target
self.prev_input = data, target_one_hot
return data, mixed_target

@staticmethod
Expand Down Expand Up @@ -607,12 +608,12 @@ def __init__(self, alpha=1.0, prob=0.5):
def cutmix(self, data, target):
if not self.state.is_train or np.random.rand() > self.prob:
return data, target
prev_data, prev_target = data, target if self.prev_input is None else self.prev_input
_, _, H, W = data.size()
prev_data, prev_target = (data, target) if self.prev_input is None else self.prev_input
self.prev_input = data, target
H, W = min(data.size(2), prev_data.size(2)), min(data.size(3), prev_data.size(3))
lam = self.tb.sample()
lam = min([lam, 1 - lam])
bbh1, bbw1, bbh2, bbw2 = self.rand_bbox(H, W, lam)
data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[:, :, bbh1:bbh2, bbw1:bbw2]
target[:, :, bbh1:bbh2, bbw1:bbw2] = prev_target[:, :, bbh1:bbh2, bbw1:bbw2]
self.prev_input = data, target
return data, target
11 changes: 7 additions & 4 deletions pytorch_tools/modules/bifpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,15 @@ class BiFPNLayer(nn.Module):
p_out: features processed by 1 layer of BiFPN
"""

def __init__(self, channels=64, upsample_mode="nearest", **bn_args):
def __init__(self, channels=64, output_stride=32, upsample_mode="nearest", **bn_args):
super(BiFPNLayer, self).__init__()

self.up = nn.Upsample(scale_factor=2, mode=upsample_mode)
self.first_up = self.up if output_stride == 32 else nn.Identity()
last_stride = 2 if output_stride == 32 else 1
self.down_p2 = DepthwiseSeparableConv(channels, channels, stride=2, **bn_args)
self.down_p3 = DepthwiseSeparableConv(channels, channels, stride=2, **bn_args)
self.down_p4 = DepthwiseSeparableConv(channels, channels, stride=2, **bn_args)
self.down_p4 = DepthwiseSeparableConv(channels, channels, stride=last_stride, **bn_args)

## TODO (jamil) 11.02.2020 Rewrite this using list comprehensions
self.fuse_p4_td = FastNormalizedFusion(in_nodes=2)
Expand Down Expand Up @@ -75,7 +77,7 @@ def forward(self, features):
p5_inp, p4_inp, p3_inp, p2_inp, p1_inp = features

# Top-down pathway
p4_td = self.p4_td(self.fuse_p4_td(p4_inp, self.up(p5_inp)))
p4_td = self.p4_td(self.fuse_p4_td(p4_inp, self.first_up(p5_inp)))
p3_td = self.p3_td(self.fuse_p3_td(p3_inp, self.up(p4_td)))
p2_out = self.p2_td(self.fuse_p2_td(p2_inp, self.up(p3_td)))

Expand Down Expand Up @@ -134,6 +136,7 @@ def __init__(
encoder_channels,
pyramid_channels=64,
num_layers=1,
output_stride=32,
**bn_args,
):
super(BiFPN, self).__init__()
Expand All @@ -142,7 +145,7 @@ def __init__(

bifpns = []
for _ in range(num_layers):
bifpns.append(BiFPNLayer(pyramid_channels, **bn_args))
bifpns.append(BiFPNLayer(pyramid_channels, output_stride, **bn_args))
self.bifpn = nn.Sequential(*bifpns)

def forward(self, features):
Expand Down
1 change: 1 addition & 0 deletions pytorch_tools/segmentation_models/segm_fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
self.encoder.out_shapes,
pyramid_channels=pyramid_channels,
num_layers=num_fpn_layers,
output_stride=output_stride,
**bn_args,
)

Expand Down

0 comments on commit 4b6e145

Please sign in to comment.