Commit
Showing 24 changed files with 952 additions and 0 deletions.
@@ -0,0 +1,23 @@
import torch.nn as nn
import segmentation_models_pytorch as smp


class FPN(nn.Module):
    def __init__(self, num_classes):
        super(FPN, self).__init__()

        self.model = smp.FPN(
            encoder_name='resnet50',
            encoder_depth=5,
            encoder_weights=None,
            decoder_pyramid_channels=256,
            decoder_segmentation_channels=128,
            decoder_merge_policy='add',
            decoder_dropout=0.,
            in_channels=3,
            classes=num_classes
        )

    def forward(self, x):
        logits = self.model(x)
        return [logits]
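For context, a minimal usage sketch of the wrapper above. The batch size, class count, and input resolution are illustrative assumptions, not values taken from this commit.

import torch

# FPN as defined above
model = FPN(num_classes=4)                 # assumed number of classes
x = torch.randn(2, 3, 256, 256)            # batch of RGB images, size chosen for illustration
outputs = model(x)                         # list with a single logits tensor
print(outputs[0].shape)                    # (B, num_classes, H, W); smp's FPN head upsamples back to the input resolution by default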
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_toolbelt import losses as L

from model.losses.pseudo_ce_loss import PseudoCrossEntropyLoss


class LossFunction(nn.Module):
    def __init__(self):
        super(LossFunction, self).__init__()

        self.loss_func1 = nn.CrossEntropyLoss()
        self.loss_func2 = L.DiceLoss(mode='multiclass')

    def forward(self, logits, target):
        # Weighted sum of cross-entropy and Dice loss on the primary output.
        loss = self.loss_func1(logits[0], target) + 0.2 * self.loss_func2(logits[0], target)
        return loss


class SelfCorrectionLossFunction(nn.Module):
    def __init__(self, cycle=12):
        super(SelfCorrectionLossFunction, self).__init__()
        self.cycle = cycle

        self.sc_loss_func1 = PseudoCrossEntropyLoss()
        self.sc_loss_func2 = L.DiceLoss(mode='multiclass')

    def forward(self, predicts, target, soft_predict, cycle_n):
        with torch.no_grad():  # no gradients needed while building the soft target
            soft_predict = F.softmax(soft_predict, dim=1)
            # Blend the one-hot ground truth with the previous cycle's soft prediction.
            soft_predict = self.weighted(self.to_one_hot(target, soft_predict.size(1)), soft_predict,
                                         alpha=1. / (cycle_n + 1))
        loss1 = self.sc_loss_func1(predicts[0], soft_predict)
        loss2 = self.sc_loss_func2(predicts[0], target)  # primary output, consistent with loss1
        return loss1 + 0.2 * loss2

    @staticmethod
    def weighted(target_one_hot, soft_predict, alpha):
        soft_predict = alpha * target_one_hot + (1 - alpha) * soft_predict
        return soft_predict

    @staticmethod
    def to_one_hot(tensor, num_cls, dim=1, ignore_index=255):
        b, h, w = tensor.shape
        tensor = tensor.clone()  # avoid mutating the caller's label tensor in place
        tensor[tensor == ignore_index] = 0
        onehot_tensor = torch.zeros(b, num_cls, h, w).cuda()
        onehot_tensor.scatter_(dim, tensor.unsqueeze(dim), 1)
        return onehot_tensor
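A brief sketch of how these losses might be called during training. Shapes, class count, and the cycle index are illustrative assumptions; a CUDA device is assumed because to_one_hot allocates its one-hot tensor on the GPU.

import torch

# LossFunction and SelfCorrectionLossFunction as defined above
logits = [torch.randn(2, 4, 64, 64, device='cuda', requires_grad=True)]  # list output, as from the FPN wrapper
target = torch.randint(0, 4, (2, 64, 64), device='cuda')                 # integer class labels

criterion = LossFunction()
loss = criterion(logits, target)                                         # CE + 0.2 * Dice on the primary output

sc_criterion = SelfCorrectionLossFunction(cycle=12)
soft_predict = torch.randn(2, 4, 64, 64, device='cuda')                  # logits from the previous cycle's model
sc_loss = sc_criterion(logits, target, soft_predict, cycle_n=1)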
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,16 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor


class PseudoCrossEntropyLoss(nn.Module):
    def __init__(self, dim=1):
        super(PseudoCrossEntropyLoss, self).__init__()
        self.dim = dim

    def forward(self, input: Tensor, target: Tensor):
        # Cross-entropy against a soft (probability) target: mean over -sum(target * log_softmax(input)).
        input_log_prob = F.log_softmax(input, dim=self.dim)
        loss = torch.sum(-input_log_prob * target, dim=self.dim)
        return loss.mean()
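As a sanity check (not part of the commit), the loss above reduces to the standard cross-entropy when the soft target is an exact one-hot distribution:

import torch
import torch.nn as nn
import torch.nn.functional as F

# PseudoCrossEntropyLoss as defined above
pseudo_ce = PseudoCrossEntropyLoss()
logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))
one_hot = F.one_hot(labels, num_classes=5).float()

# Both values should agree up to floating-point error
print(pseudo_ce(logits, one_hot))
print(nn.CrossEntropyLoss()(logits, labels))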
@@ -0,0 +1,78 @@
import torch.optim as optim

from .radam import RAdam
from .lookahead import Lookahead
from .cyclicLR import CyclicCosAnnealingLR
from .warmup_scheduler import GradualWarmupScheduler


def get_optimizer(params, optimizer_cfg):
    if optimizer_cfg['mode'] == 'SGD':
        optimizer = optim.SGD(params, lr=optimizer_cfg['lr'], momentum=0.9,
                              weight_decay=optimizer_cfg['weight_decay'], nesterov=optimizer_cfg['nesterov'])
    elif optimizer_cfg['mode'] == 'RAdam':
        optimizer = RAdam(params, lr=optimizer_cfg['lr'], betas=(0.9, 0.999),
                          weight_decay=optimizer_cfg['weight_decay'])
    else:
        optimizer = optim.Adam(params, lr=optimizer_cfg['lr'], betas=(0.9, 0.999),
                               weight_decay=optimizer_cfg['weight_decay'])

    if optimizer_cfg['lookahead']:
        optimizer = Lookahead(optimizer, k=5, alpha=0.5)

    # todo: add split_weights.py

    return optimizer


def get_scheduler(optimizer, scheduler_cfg):
    MODE = scheduler_cfg['mode']

    if MODE == 'OneCycleLR':
        scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
                                                  max_lr=optimizer.param_groups[0]['lr'],
                                                  total_steps=scheduler_cfg['steps'],
                                                  pct_start=scheduler_cfg['pct_start'],
                                                  final_div_factor=scheduler_cfg['final_div_factor'],
                                                  cycle_momentum=scheduler_cfg['cycle_momentum'],
                                                  anneal_strategy=scheduler_cfg['anneal_strategy'])

    elif MODE == 'PolyLR':
        lr_lambda = lambda step: (1 - step / scheduler_cfg['steps']) ** scheduler_cfg['power']
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    elif MODE == 'CosineAnnealingLR':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_cfg['steps'],
                                                         eta_min=scheduler_cfg['eta_min'])

    elif MODE == 'MultiStepLR':
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   scheduler_cfg['milestones'],
                                                   gamma=scheduler_cfg['gamma'])

    elif MODE == 'CosineAnnealingWarmRestarts':
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                                   T_0=scheduler_cfg['T_0'],
                                                                   T_mult=scheduler_cfg['T_multi'],
                                                                   eta_min=scheduler_cfg['eta_min'])

    elif MODE == 'CyclicCosAnnealingLR':
        scheduler = CyclicCosAnnealingLR(optimizer,
                                         milestones=scheduler_cfg['milestones'],
                                         decay_milestones=scheduler_cfg['decay_milestones'],
                                         eta_min=scheduler_cfg['eta_min'],
                                         gamma=scheduler_cfg['gamma'])

    elif MODE == 'GradualWarmupScheduler':
        milestones = list(map(lambda x: x - scheduler_cfg['warmup_steps'], scheduler_cfg['milestones']))
        scheduler_steplr = optim.lr_scheduler.MultiStepLR(optimizer,
                                                          milestones=milestones,
                                                          gamma=scheduler_cfg['gamma'])
        scheduler = GradualWarmupScheduler(optimizer,
                                           multiplier=scheduler_cfg['multiplier'],  # scalar warmup multiplier (assumed config key)
                                           total_epoch=scheduler_cfg['warmup_steps'],
                                           after_scheduler=scheduler_steplr)
    else:
        raise ValueError(f'Unknown scheduler mode: {MODE}')

    return scheduler
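An illustrative configuration for these factories. The key values are assumptions for the sketch, not settings taken from this commit.

# `model` is any nn.Module
optimizer_cfg = {
    'mode': 'SGD',             # 'SGD', 'RAdam'; anything else falls back to Adam
    'lr': 1e-2,
    'weight_decay': 1e-4,
    'nesterov': True,
    'lookahead': False,
}
scheduler_cfg = {
    'mode': 'CosineAnnealingLR',
    'steps': 10000,            # total training steps
    'eta_min': 1e-6,
}

optimizer = get_optimizer(model.parameters(), optimizer_cfg)
scheduler = get_scheduler(optimizer, scheduler_cfg)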
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,125 @@
import math
from bisect import bisect_right
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer


class CyclicCosAnnealingLR(_LRScheduler):
    r"""
    Implements restarts on milestones, inspired by PyTorch's CosineAnnealingLR.
    Sets the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart:
    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))
    When last_epoch is past the last milestone, lr is set to \eta_{min}.
    Proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Unlike
    PyTorch's built-in CosineAnnealingLR, this variant restarts at each milestone.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list of ints): List of epoch indices. Must be increasing.
        decay_milestones (list of ints): List of increasing epoch indices. Ideally, decay points should overlap with milestone points.
        gamma (float): factor by which to decay the max learning rate at each decay milestone
        eta_min (float): Minimum learning rate. Default: 1e-6
        last_epoch (int): The index of last epoch. Default: -1.
    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, milestones, decay_milestones=None, gamma=0.5, eta_min=1e-6, last_epoch=-1):
        if not list(milestones) == sorted(milestones):
            raise ValueError('Milestones should be a list of'
                             ' increasing integers. Got {}'.format(milestones))
        self.eta_min = eta_min
        self.milestones = milestones
        self.milestones2 = decay_milestones

        self.gamma = gamma
        super(CyclicCosAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Past the final milestone: stay at the minimum learning rate.
        if self.last_epoch >= self.milestones[-1]:
            return [self.eta_min for base_lr in self.base_lrs]

        idx = bisect_right(self.milestones, self.last_epoch)

        left_barrier = 0 if idx == 0 else self.milestones[idx - 1]
        right_barrier = self.milestones[idx]

        width = right_barrier - left_barrier
        curr_pos = self.last_epoch - left_barrier

        if self.milestones2:
            return [self.eta_min + (
                    base_lr * self.gamma ** bisect_right(self.milestones2, self.last_epoch) - self.eta_min) *
                    (1 + math.cos(math.pi * curr_pos / width)) / 2
                    for base_lr in self.base_lrs]
        else:
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1 + math.cos(math.pi * curr_pos / width)) / 2
                    for base_lr in self.base_lrs]


class CyclicLinearLR(_LRScheduler):
    r"""
    Implements restarts on milestones with a linear learning-rate decay.
    Sets the learning rate of each parameter group using a linear decay
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart:
    .. math::
        \eta_t = \eta_{min} + (\eta_{max} - \eta_{min})(1 - \frac{T_{cur}}{T_{max}})
    When last_epoch is past the last milestone, lr is set to \eta_{min}.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list of ints): List of epoch indices. Must be increasing.
        decay_milestones (list of ints): List of increasing epoch indices. Ideally, decay points should overlap with milestone points.
        gamma (float): factor by which to decay the max learning rate at each decay milestone
        eta_min (float): Minimum learning rate. Default: 1e-6
        last_epoch (int): The index of last epoch. Default: -1.
    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, milestones, decay_milestones=None, gamma=0.5, eta_min=1e-6, last_epoch=-1):
        if not list(milestones) == sorted(milestones):
            raise ValueError('Milestones should be a list of'
                             ' increasing integers. Got {}'.format(milestones))
        self.eta_min = eta_min

        self.gamma = gamma
        self.milestones = milestones
        self.milestones2 = decay_milestones
        super(CyclicLinearLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Past the final milestone: stay at the minimum learning rate.
        if self.last_epoch >= self.milestones[-1]:
            return [self.eta_min for base_lr in self.base_lrs]

        idx = bisect_right(self.milestones, self.last_epoch)

        left_barrier = 0 if idx == 0 else self.milestones[idx - 1]
        right_barrier = self.milestones[idx]

        width = right_barrier - left_barrier
        curr_pos = self.last_epoch - left_barrier

        if self.milestones2:
            return [self.eta_min + (
                    base_lr * self.gamma ** bisect_right(self.milestones2, self.last_epoch) - self.eta_min) *
                    (1. - 1.0 * curr_pos / width)
                    for base_lr in self.base_lrs]

        else:
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1. - 1.0 * curr_pos / width)
                    for base_lr in self.base_lrs]
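A short sketch of the cyclic cosine schedule in isolation. The milestone values, learning rate, and epoch count are illustrative, not taken from this commit.

import torch
from torch import nn, optim

# CyclicCosAnnealingLR as defined above
params = [nn.Parameter(torch.zeros(1))]
optimizer = optim.SGD(params, lr=0.1)
scheduler = CyclicCosAnnealingLR(optimizer, milestones=[30, 60, 90],
                                 decay_milestones=[30, 60, 90], gamma=0.5, eta_min=1e-6)

for epoch in range(100):
    optimizer.step()
    scheduler.step()  # lr anneals to eta_min within each milestone segment, then restarts (decayed by gamma)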
@@ -0,0 +1,100 @@
import torch
from torch.optim import Optimizer
from collections import defaultdict


class Lookahead(Optimizer):
    '''
    PyTorch implementation of the Lookahead wrapper.
    Lookahead Optimizer: https://arxiv.org/abs/1907.08610
    '''

    def __init__(self, optimizer, alpha=0.5, k=6, pullback_momentum="none"):
        '''
        :param optimizer: inner optimizer
        :param k (int): number of lookahead steps
        :param alpha (float): linear interpolation factor. 1.0 recovers the inner optimizer.
        :param pullback_momentum (str): change to inner optimizer momentum on interpolation update
        '''
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        self.optimizer = optimizer
        self.param_groups = self.optimizer.param_groups
        self.alpha = alpha
        self.k = k
        self.step_counter = 0
        assert pullback_momentum in ["reset", "pullback", "none"]
        self.pullback_momentum = pullback_momentum
        self.state = defaultdict(dict)

        # Cache the current optimizer parameters (the "slow" weights)
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['cached_params'] = torch.zeros_like(p.data)
                param_state['cached_params'].copy_(p.data)
                if self.pullback_momentum == "pullback":
                    # Also cache a momentum buffer so the first pullback update has something to blend with
                    param_state['cached_mom'] = torch.zeros_like(p.data)

    def __getstate__(self):
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'alpha': self.alpha,
            'step_counter': self.step_counter,
            'k': self.k,
            'pullback_momentum': self.pullback_momentum
        }

    def zero_grad(self):
        self.optimizer.zero_grad()

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)

    def _backup_and_load_cache(self):
        """Useful for performing evaluation on the slow weights (which typically generalize better)."""
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['backup_params'] = torch.zeros_like(p.data)
                param_state['backup_params'].copy_(p.data)
                p.data.copy_(param_state['cached_params'])

    def _clear_and_load_backup(self):
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                p.data.copy_(param_state['backup_params'])
                del param_state['backup_params']

    def step(self, closure=None):
        """Performs a single Lookahead optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = self.optimizer.step(closure)
        self.step_counter += 1

        if self.step_counter >= self.k:
            self.step_counter = 0
            # Interpolate the fast weights toward the cached slow weights, then refresh the cache
            for group in self.optimizer.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    p.data.mul_(self.alpha).add_(param_state['cached_params'], alpha=1.0 - self.alpha)  # crucial line
                    param_state['cached_params'].copy_(p.data)
                    if self.pullback_momentum == "pullback":
                        internal_momentum = self.optimizer.state[p]["momentum_buffer"]
                        self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.alpha).add_(
                            param_state["cached_mom"], alpha=1.0 - self.alpha)
                        param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"]
                    elif self.pullback_momentum == "reset":
                        self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)

        return loss
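Finally, a minimal sketch of wrapping an inner optimizer with Lookahead. The model and hyper-parameters are placeholders, not values from this commit.

import torch
from torch import nn, optim

# Lookahead as defined above
model = nn.Linear(10, 2)                                 # placeholder model
base_optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = Lookahead(base_optimizer, alpha=0.5, k=6)

x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))
loss = nn.CrossEntropyLoss()(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()  # every k-th call also interpolates toward the cached slow weights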