
Commit

Merge pull request #54 from bonlime/dev
Merge last month commits
bonlime authored Apr 10, 2020
2 parents 4b6e145 + 1e86b17 commit 218b5ac
Showing 35 changed files with 1,165 additions and 534 deletions.
2 changes: 1 addition & 1 deletion pytorch_tools/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.1.1"
+__version__ = "0.1.2"

from . import fit_wrapper
from . import losses
Empty file.
74 changes: 74 additions & 0 deletions pytorch_tools/detection_models/retinanet.py
@@ -0,0 +1,74 @@
# import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_tools.modules.fpn import FPN
# from pytorch_tools.modules.bifpn import BiFPN
from pytorch_tools.modules import bn_from_name
# from pytorch_tools.modules.residual import conv1x1
from pytorch_tools.modules.residual import conv3x3
# from pytorch_tools.modules.decoder import SegmentationUpsample
# from pytorch_tools.utils.misc import initialize
from pytorch_tools.segmentation_models.encoders import get_encoder


class RetinaNet(nn.Module):
def __init__(
self,
encoder_name="resnet34",
encoder_weights="imagenet",
pyramid_channels=256,
num_classes=80,
norm_layer="abn",
norm_act="relu",
**encoder_params,
):
super().__init__()
self.encoder = get_encoder(
encoder_name,
norm_layer=norm_layer,
norm_act=norm_act,
encoder_weights=encoder_weights,
**encoder_params,
)
norm_layer = bn_from_name(norm_layer)
self.pyramid6 = conv3x3(256, 256, 2, bias=True)
self.pyramid7 = conv3x3(256, 256, 2, bias=True)
self.fpn = FPN(
self.encoder.out_shapes[:-2],
pyramid_channels=pyramid_channels,
)

def make_head(out_size):
layers = []
for _ in range(4):
# some implementations don't use BN here but I think it's needed
# TODO: test how it affects results
layers += [nn.Conv2d(256, 256, 3, padding=1), norm_layer(256, activation=norm_act)]
# layers += [nn.Conv2d(256, 256, 3, padding=1), nn.ReLU()]

layers += [nn.Conv2d(256, out_size, 3, padding=1)]
return nn.Sequential(*layers)

self.ratios = [1.0, 2.0, 0.5]
self.scales = [4 * 2 ** (i / 3) for i in range(3)]
anchors = len(self.ratios) * len(self.scales) # 9
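        # with these defaults: scales = [4.00, 5.04, 6.35] (approx.), so 3 ratios x 3 scales -> 9 anchors per location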

self.cls_head = make_head(num_classes * anchors)
self.box_head = make_head(4 * anchors)

def forward(self, x):
# don't use p2 and p1
p5, p4, p3, _, _ = self.encoder(x)
# enhance features
p5, p4, p3 = self.fpn([p5, p4, p3])
# coarser FPN levels
p6 = self.pyramid6(p5)
p7 = self.pyramid7(F.relu(p6))
features = [p7, p6, p5, p4, p3]
# TODO: (18.03.20) TF implementation has additional BN here before class/box outputs
class_outputs = [self.cls_head(f) for f in features]
box_outputs = [self.box_head(f) for f in features]
return class_outputs, box_outputs
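
For orientation, a minimal usage sketch of the new model (not part of the commit; assumes `encoder_weights=None` gives a randomly initialized encoder so no pretrained download is needed):

import torch
from pytorch_tools.detection_models.retinanet import RetinaNet

model = RetinaNet(encoder_name="resnet34", encoder_weights=None, num_classes=80).eval()
with torch.no_grad():
    class_outputs, box_outputs = model(torch.randn(1, 3, 512, 512))
# one prediction map per pyramid level [p7, p6, p5, p4, p3]:
# class maps have num_classes * 9 channels, box maps have 4 * 9
assert len(class_outputs) == len(box_outputs) == 5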



96 changes: 71 additions & 25 deletions pytorch_tools/fit_wrapper/callbacks.py
@@ -271,38 +271,49 @@ class CheckpointSaver(Callback):
Args:
save_dir (str): path to folder where to save the model
save_name (str): name of the saved model. Can additionally
add epoch and metric to the model save name
monitor (str): quantity to monitor. Implicitly prefers validation metrics over train. One of:
`loss` or name of any metric passed to the runner.
mode (str): one of "min" or "max". Whether to decide to save based
on minimizing or maximizing the monitored quantity
include_optimizer (bool): if `True`, also saves the optimizer's state_dict.
This roughly doubles the checkpoint size.
verbose (bool): if `True`, reports each time a new best is found
"""

def __init__(
-        self, save_dir, save_name="model_{ep}_{metric:.2f}.chpn", mode="min", include_optimizer=False
+        self,
+        save_dir,
+        save_name="model_{ep}_{metric:.2f}.chpn",
+        monitor="loss",
+        mode="min",
+        include_optimizer=False,
+        verbose=True,
):
super().__init__()
self.save_dir = save_dir
self.save_name = save_name
-        self.mode = ReduceMode(mode)
-        self.best = float("inf") if self.mode == ReduceMode.MIN else -float("inf")
+        self.monitor = monitor
+        mode = ReduceMode(mode)
+        if mode == ReduceMode.MIN:
+            self.best = np.inf
+            self.monitor_op = np.less
+        elif mode == ReduceMode.MAX:
+            self.best = -np.inf
+            self.monitor_op = np.greater
self.include_optimizer = include_optimizer
+        self.verbose = verbose

def on_begin(self):
os.makedirs(self.save_dir, exist_ok=True)

def on_epoch_end(self):
-        # TODO zakirov(1.11.19) Add support for saving based on metric
-        if self.state.val_loss is not None:
-            current = self.state.val_loss.avg
-        else:
-            current = self.state.train_loss.avg
-        if (self.mode == ReduceMode.MIN and current < self.best) or (
-            self.mode == ReduceMode.MAX and current > self.best
-        ):
-            ep = self.state.epoch
-            # print(f"Epoch {ep}: best loss improved from {self.best:.4f} to {current:.4f}")
+        current = self.get_monitor_value()
+        if self.monitor_op(current, self.best):
+            ep = self.state.epoch_log
+            if self.verbose:
+                print(f"Epoch {ep:2d}: best {self.monitor} improved from {self.best:.4f} to {current:.4f}")
self.best = current
save_name = os.path.join(self.save_dir, self.save_name.format(ep=ep, metric=current))
self._save_checkpoint(save_name)
@@ -317,6 +328,18 @@ def _save_checkpoint(self, path):
save_dict["optimizer"] = self.state.optimizer.state_dict()
torch.save(save_dict, path)

def get_monitor_value(self):
value = None
if self.monitor == "loss":
value = self.state.loss_meter.avg
else:
for metric_meter in self.state.metric_meters:
if metric_meter.name == self.monitor:
value = metric_meter.avg
if value is None:
raise ValueError(f"CheckpointSaver can't find {self.monitor} value to monitor")
return value
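
A hypothetical usage example for the new `monitor` argument (the name must be "loss" or match one of the metrics passed to the runner):

saver = CheckpointSaver(
    save_dir="checkpoints",
    save_name="model_{ep}_{metric:.2f}.chpn",
    monitor="accuracy",  # assumed metric name; "loss" falls back to the loss meter
    mode="max",
)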


class TensorBoard(Callback):
"""
@@ -407,7 +430,7 @@ def on_batch_end(self):

def on_loader_end(self):
super().on_loader_end()
-        f = plot_confusion_matrix(self.cmap, self.class_names, show=False)
+        f = plot_confusion_matrix(self.cmap, self.class_names, normalize=True, show=False)
cm_img = render_figure_to_tensor(f)
if self.state.is_train:
self.train_cm_img = cm_img
@@ -527,10 +550,11 @@ def mixup(self, data, target):
if not self.state.is_train or np.random.rand() > self.prob:
return data, target_one_hot
prev_data, prev_target = (data, target_one_hot) if self.prev_input is None else self.prev_input
-        self.prev_input = data, target_one_hot
+        self.prev_input = data.clone(), target_one_hot.clone()
+        perm = torch.randperm(data.size(0)).cuda()
c = self.tb.sample()
-        md = c * data + (1 - c) * prev_data
-        mt = c * target_one_hot + (1 - c) * prev_target
+        md = c * data + (1 - c) * prev_data[perm]
+        mt = c * target_one_hot + (1 - c) * prev_target[perm]
return md, mt
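
Two things changed in this method: the cached previous batch is now stored as a `clone()` (the cutmix variants below write into `data` in place, so a bare reference to the previous batch could be silently corrupted), and the previous batch is indexed with a random permutation so that sample i is not always mixed with the partner at the same position. A standalone sketch of the resulting update (illustrative names, not from the diff):

import torch

def mixup_step(data, target, prev_data, prev_target, c):
    # target and prev_target are one-hot; c is sampled from a Beta distribution
    perm = torch.randperm(data.size(0))
    mixed_data = c * data + (1 - c) * prev_data[perm]
    mixed_target = c * target + (1 - c) * prev_target[perm]
    return mixed_data, mixed_target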


@@ -570,16 +594,17 @@ def cutmix(self, data, target):
if not self.state.is_train or np.random.rand() > self.prob:
return data, target_one_hot
prev_data, prev_target = (data, target_one_hot) if self.prev_input is None else self.prev_input
-        self.prev_input = data, target_one_hot
+        self.prev_input = data.clone(), target_one_hot.clone()
# prev_data shape can be different from the current one, so we need to take the min
H, W = min(data.size(2), prev_data.size(2)), min(data.size(3), prev_data.size(3))
+        perm = torch.randperm(data.size(0)).cuda()
lam = self.tb.sample()
lam = min([lam, 1 - lam])
bbh1, bbw1, bbh2, bbw2 = self.rand_bbox(H, W, lam)
# the realized lambda may differ from the sampled one; adjust for it
lam = (bbh2 - bbh1) * (bbw2 - bbw1) / (H * W)
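# e.g. H = W = 224 with a realized 120x125 box gives lam = 120 * 125 / 50176 ~= 0.299 (illustrative numbers)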
-        data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[:, :, bbh1:bbh2, bbw1:bbw2]
-        mixed_target = (1 - lam) * target_one_hot + lam * prev_target
+        data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[perm, :, bbh1:bbh2, bbw1:bbw2]
+        mixed_target = (1 - lam) * target_one_hot + lam * prev_target[perm]
return data, mixed_target

@staticmethod
@@ -609,11 +634,32 @@ def cutmix(self, data, target):
if not self.state.is_train or np.random.rand() > self.prob:
return data, target
prev_data, prev_target = (data, target) if self.prev_input is None else self.prev_input
-        self.prev_input = data, target
+        self.prev_input = data.clone(), target.clone()
H, W = min(data.size(2), prev_data.size(2)), min(data.size(3), prev_data.size(3))
+        perm = torch.randperm(data.size(0)).cuda()
lam = self.tb.sample()
lam = min([lam, 1 - lam])
bbh1, bbw1, bbh2, bbw2 = self.rand_bbox(H, W, lam)
-        data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[:, :, bbh1:bbh2, bbw1:bbw2]
-        target[:, :, bbh1:bbh2, bbw1:bbw2] = prev_target[:, :, bbh1:bbh2, bbw1:bbw2]
+        data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[perm, :, bbh1:bbh2, bbw1:bbw2]
+        target[:, :, bbh1:bbh2, bbw1:bbw2] = prev_target[perm, :, bbh1:bbh2, bbw1:bbw2]
return data, target


class ScheduledDropout(Callback):
def __init__(self, drop_rate=0.1, epochs=30, attr_name="dropout.p"):
"""
Slowly changes dropout value for `attr_name` each epoch.
Ref: https://arxiv.org/abs/1703.06229
Args:
drop_rate (float): max dropout rate
epochs (int): num epochs to max dropout to fully take effect
attr_name (str): name of dropout block in model
"""
super().__init__()
self.drop_rate = drop_rate
self.epochs = epochs
self.attr_name = attr_name

def on_epoch_end(self):
        current_rate = self.drop_rate * min(1, self.state.epoch / self.epochs)
        # attr_name may be a dotted path like "dropout.p"; walk it, since a plain
        # setattr on the model would only create a literal "dropout.p" attribute
        obj = self.state.model
        *parents, attr = self.attr_name.split(".")
        for name in parents:
            obj = getattr(obj, name)
        setattr(obj, attr, current_rate)
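
A hypothetical usage sketch, assuming the model exposes its dropout module as `model.dropout`:

dropout_cb = ScheduledDropout(drop_rate=0.2, epochs=30, attr_name="dropout.p")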
3 changes: 2 additions & 1 deletion pytorch_tools/fit_wrapper/wrapper.py
@@ -22,7 +22,8 @@ def __init__(self, model, optimizer, criterion, metrics=None, callbacks=ConsoleL
super().__init__()

if not hasattr(amp._amp_state, "opt_properties"):
-            model, optimizer = amp.initialize(model, optimizer, enabled=False)
+            model_optimizer = amp.initialize(model, optimizer, enabled=False)
+            model, optimizer = (model_optimizer, None) if optimizer is None else model_optimizer

self.state = RunnerState(model=model, optimizer=optimizer, criterion=criterion, metrics=metrics,)
self.callbacks = Callbacks(callbacks)
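
For context: NVIDIA apex's `amp.initialize` returns only the patched model when called without an optimizer, and a (model, optimizer) tuple otherwise; the branching above accounts for runners created without an optimizer (e.g. for inference). A sketch of the two call shapes, assuming apex is installed:

model = amp.initialize(model, enabled=False)  # model only
model, optimizer = amp.initialize(model, optimizer, enabled=False)  # tuple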
7 changes: 4 additions & 3 deletions pytorch_tools/losses/__init__.py
@@ -2,15 +2,16 @@
import torch.nn as nn

from .base import Loss
-from .focal import BinaryFocalLoss, FocalLoss
-from .dice_jaccard import DiceLoss, JaccardLoss
+from .focal import FocalLoss
+from .dice_jaccard import DiceLoss
+from .dice_jaccard import JaccardLoss
from .lovasz import LovaszLoss
from .wing_loss import WingLoss
from .vgg_loss import ContentLoss, StyleLoss
from .smooth import CrossEntropyLoss
from .hinge import BinaryHinge

-from .functional import sigmoid_focal_loss
+from .functional import focal_loss_with_logits
from .functional import soft_dice_score
from .functional import soft_jaccard_score
from .functional import wing_loss
6 changes: 5 additions & 1 deletion pytorch_tools/losses/base.py
@@ -8,7 +8,11 @@ class Mode(Enum):
MULTICLASS = "multiclass"
MULTILABEL = "multilabel"


class Reduction(Enum):
SUM = "sum"
MEAN = "mean"
NONE = "none"

class Loss(_Loss):
"""Loss which supports addition and multiplication"""

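A hedged sketch of how a `Reduction` value is typically applied (illustrative, not part of this diff):

import torch
from pytorch_tools.losses.base import Reduction

def reduce_loss(loss: torch.Tensor, reduction: Reduction) -> torch.Tensor:
    if reduction == Reduction.MEAN:
        return loss.mean()
    if reduction == Reduction.SUM:
        return loss.sum()
    return loss  # Reduction.NONE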
39 changes: 24 additions & 15 deletions pytorch_tools/losses/dice_jaccard.py
@@ -7,24 +7,22 @@ class DiceLoss(Loss):
"""
Implementation of Dice loss for image segmentation task.
It supports binary, multiclass and multilabel cases
+    Args:
+        mode (str): Target mode {'binary', 'multiclass', 'multilabel'}
+            'multilabel' - expects y_true of shape [N, C, H, W]
+            'multiclass', 'binary' - expects y_true of shape [N, H, W]
+        log_loss (bool): If True, loss computed as `-log(dice)`; otherwise `1 - dice`
+        from_logits (bool): If True assumes input is raw logits
+        eps (float): small epsilon for numerical stability
+    Shape:
+        y_pred: [N, C, H, W]
+        y_true: [N, C, H, W] or [N, H, W] depending on mode
"""

IOU_FUNCTION = soft_dice_score

def __init__(self, mode="binary", log_loss=False, from_logits=True, eps=1.):
"""
Args:
mode (str): Target mode {'binary', 'multiclass', 'multilabel'}
'multilabel' - expects y_true of shape [N, C, H, W]
'multiclass', 'binary' - expects y_true of shape [N, H, W]
log_loss (bool): If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard`
from_logits (bool): If True assumes input is raw logits
eps (float): small epsilon for numerical stability
Shape:
y_pred: [N, C, H, W]
y_true: [N, C, H, W] or [N, H, W] depending on mode
"""

super(DiceLoss, self).__init__()
self.mode = Mode(mode) # raises an error if not valid
self.log_loss = log_loss
@@ -34,9 +32,9 @@ def __init__(self, mode="binary", log_loss=False, from_logits=True, eps=1.):
def forward(self, y_pred, y_true):
if self.from_logits:
# Apply activations to get [0..1] class probabilities
-            if self.mode == Mode.BINARY:
+            if self.mode == Mode.BINARY or self.mode == Mode.MULTILABEL:
y_pred = y_pred.sigmoid()
-            else:
+            elif self.mode == Mode.MULTICLASS:
y_pred = y_pred.softmax(dim=1)

bs = y_true.size(0)
@@ -74,6 +72,17 @@ class JaccardLoss(DiceLoss):
"""
Implementation of Jaccard loss for image segmentation task.
It supports binary, multiclass and multilabel cases
Args:
mode (str): Target mode {'binary', 'multiclass', 'multilabel'}
'multilabel' - expects y_true of shape [N, C, H, W]
'multiclass', 'binary' - expects y_true of shape [N, H, W]
log_loss (bool): If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard`
from_logits (bool): If True assumes input is raw logits
eps (float): small epsilon for numerical stability
Shape:
y_pred: [N, C, H, W]
y_true: [N, C, H, W] or [N, H, W] depending on mode
"""

# the only difference is which function to use
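
For reference, a soft dice score consistent with the docstring would look roughly like this (a hedged sketch; the actual `soft_dice_score` in `pytorch_tools.losses.functional` may differ in details):

import torch

def soft_dice(y_pred: torch.Tensor, y_true: torch.Tensor, eps: float = 1.0) -> torch.Tensor:
    # expects probabilities flattened per class, e.g. shape [N, C, H*W]
    intersection = torch.sum(y_pred * y_true, dim=-1)
    cardinality = torch.sum(y_pred + y_true, dim=-1)
    return (2.0 * intersection + eps) / (cardinality + eps)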
(diffs for the remaining changed files are not shown)
