
Commit

Add files via upload
reacher-l authored Apr 8, 2021
1 parent 722f679 commit ef6f62d
Showing 24 changed files with 952 additions and 0 deletions.
23 changes: 23 additions & 0 deletions model/fpn.py
@@ -0,0 +1,23 @@
import torch.nn as nn
import segmentation_models_pytorch as smp


class FPN(nn.Module):
def __init__(self, num_classes):
super(FPN, self).__init__()

self.model = smp.FPN(
encoder_name='resnet50',
encoder_depth=5,
encoder_weights=None,
decoder_pyramid_channels=256,
decoder_segmentation_channels=128,
decoder_merge_policy='add',
decoder_dropout=0.,
in_channels=3,
classes=num_classes
)

def forward(self, x):
logits = self.model(x)
return [logits]
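
A minimal usage sketch for the wrapper above (illustrative, not part of the commit), assuming segmentation_models_pytorch is installed; batch size and resolution are placeholders, and spatial dimensions should be divisible by 32 for the resnet50 encoder.

import torch
from model.fpn import FPN

model = FPN(num_classes=4).eval()
x = torch.randn(2, 3, 256, 256)      # dummy batch of RGB images
with torch.no_grad():
    outputs = model(x)               # forward returns a single-element list of logits
print(outputs[0].shape)              # expected: torch.Size([2, 4, 256, 256])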
50 changes: 50 additions & 0 deletions model/losses/__init__.py
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_toolbelt import losses as L

from model.losses.pseudo_ce_loss import PseudoCrossEntropyLoss


class LossFunction(nn.Module):
def __init__(self):
super(LossFunction, self).__init__()

self.loss_func1 = nn.CrossEntropyLoss()
self.loss_func2 = L.DiceLoss(mode='multiclass')

def forward(self, logits, target):
loss = self.loss_func1(logits[0], target) + 0.2 * self.loss_func2(logits[0], target)
return loss


class SelfCorrectionLossFunction(nn.Module):
def __init__(self, cycle=12):
super(SelfCorrectionLossFunction, self).__init__()
self.cycle = cycle

self.sc_loss_func1 = PseudoCrossEntropyLoss()
self.sc_loss_func2 = L.DiceLoss(mode='multiclass')

def forward(self, predicts, target, soft_predict, cycle_n):
        with torch.no_grad():
            soft_predict = F.softmax(soft_predict, dim=1)
            soft_predict = self.weighted(self.to_one_hot(target, soft_predict.size(1)), soft_predict,
                                         alpha=1. / (cycle_n + 1))
        loss1 = self.sc_loss_func1(predicts[0], soft_predict)
        # predicts is a list of logits (see FPN.forward), so index it for the Dice term as well
        loss2 = self.sc_loss_func2(predicts[0], target)
return loss1 + 0.2 * loss2

@staticmethod
def weighted(target_one_hot, soft_predict, alpha):
soft_predict = alpha * target_one_hot + (1 - alpha) * soft_predict
return soft_predict

@staticmethod
    def to_one_hot(tensor, num_cls, dim=1, ignore_index=255):
        b, h, w = tensor.shape
        # Work on a copy so the caller's target (still needed for the Dice term) is not modified in place
        tensor = tensor.clone()
        tensor[tensor == ignore_index] = 0
        # Allocate on the same device as the target instead of assuming CUDA
        onehot_tensor = torch.zeros(b, num_cls, h, w, device=tensor.device)
        onehot_tensor.scatter_(dim, tensor.unsqueeze(dim), 1)
        return onehot_tensor
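
A short sketch (illustrative, not part of the commit) of how the two criteria above would be called, assuming pytorch_toolbelt is installed; shapes and the cycle index are placeholders. The self-correction target is the blend alpha * one_hot(label) + (1 - alpha) * softmax(previous logits) with alpha = 1 / (cycle_n + 1), so later cycles trust the network's own predictions more.

import torch
from model.losses import LossFunction, SelfCorrectionLossFunction

logits = [torch.randn(2, 4, 64, 64, requires_grad=True)]   # single-element list, matching FPN.forward
target = torch.randint(0, 4, (2, 64, 64))

loss = LossFunction()(logits, target)                       # CE + 0.2 * Dice

sc_criterion = SelfCorrectionLossFunction()
prev_logits = torch.randn(2, 4, 64, 64)                     # logits from the previous cycle
sc_loss = sc_criterion(logits, target, soft_predict=prev_logits, cycle_n=1)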
Binary file added model/losses/__pycache__/__init__.cpython-36.pyc
16 changes: 16 additions & 0 deletions model/losses/pseudo_ce_loss.py
@@ -0,0 +1,16 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor


class PseudoCrossEntropyLoss(nn.Module):
def __init__(self, dim=1):
super(PseudoCrossEntropyLoss, self).__init__()
self.dim = dim

def forward(self, input: Tensor, target: Tensor):
input_log_prob = F.log_softmax(input, dim=self.dim)
loss = torch.sum(-input_log_prob * target, dim=self.dim)
return loss.mean()
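
A quick sanity check (illustrative, not part of the commit): for hard one-hot targets, this soft-target cross-entropy should reduce to the standard nn.CrossEntropyLoss; shapes are arbitrary.

import torch
import torch.nn.functional as F
from model.losses.pseudo_ce_loss import PseudoCrossEntropyLoss

logits = torch.randn(2, 4, 8, 8)
hard = torch.randint(0, 4, (2, 8, 8))
soft = F.one_hot(hard, num_classes=4).permute(0, 3, 1, 2).float()   # (B, C, H, W) probabilities

pseudo_ce = PseudoCrossEntropyLoss()
print(torch.allclose(pseudo_ce(logits, soft), F.cross_entropy(logits, hard), atol=1e-6))  # True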
78 changes: 78 additions & 0 deletions model/optim/__init__.py
@@ -0,0 +1,78 @@
import torch.optim as optim

from .radam import RAdam
from .lookahead import Lookahead
from .cyclicLR import CyclicCosAnnealingLR
from .warmup_scheduler import GradualWarmupScheduler


def get_optimizer(params, optimizer_cfg):
if optimizer_cfg['mode'] == 'SGD':
optimizer = optim.SGD(params, lr=optimizer_cfg['lr'], momentum=0.9,
weight_decay=optimizer_cfg['weight_decay'], nesterov=optimizer_cfg['nesterov'])
elif optimizer_cfg['mode'] == 'RAdam':
optimizer = RAdam(params, lr=optimizer_cfg['lr'], betas=(0.9, 0.999),
weight_decay=optimizer_cfg['weight_decay'])
else:
optimizer = optim.Adam(params, lr=optimizer_cfg['lr'], betas=(0.9, 0.999),
weight_decay=optimizer_cfg['weight_decay'])

if optimizer_cfg['lookahead']:
optimizer = Lookahead(optimizer, k=5, alpha=0.5)

# todo: add split_weights.py

return optimizer


def get_scheduler(optimizer, scheduler_cfg):
MODE = scheduler_cfg['mode']

if MODE == 'OneCycleLR':
scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
max_lr=optimizer.param_groups[0]['lr'],
total_steps=scheduler_cfg['steps'],
pct_start=scheduler_cfg['pct_start'],
final_div_factor=scheduler_cfg['final_div_factor'],
cycle_momentum=scheduler_cfg['cycle_momentum'],
anneal_strategy=scheduler_cfg['anneal_strategy'])

elif MODE == 'PolyLR':
lr_lambda = lambda step: (1 - step / scheduler_cfg['steps']) ** scheduler_cfg['power']
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

elif MODE == 'CosineAnnealingLR':
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=scheduler_cfg['steps'],
eta_min=scheduler_cfg['eta_min'])

elif MODE == 'MultiStepLR':
scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
scheduler_cfg['milestones'],
gamma=scheduler_cfg['gamma'])

elif MODE == 'CosineAnnealingWarmRestarts':
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
T_0=scheduler_cfg['T_0'],
T_mult=scheduler_cfg['T_multi'],
eta_min=scheduler_cfg['eta_min'])

elif MODE == 'CyclicCosAnnealingLR':
scheduler = CyclicCosAnnealingLR(optimizer,
milestones=scheduler_cfg['milestones'],
decay_milestones=scheduler_cfg['decay_milestones'],
eta_min=scheduler_cfg['eta_min'],
gamma=scheduler_cfg['gamma'])

    elif MODE == 'GradualWarmupScheduler':
        milestones = list(map(lambda x: x - scheduler_cfg['warmup_steps'], scheduler_cfg['milestones']))
        scheduler_steplr = optim.lr_scheduler.MultiStepLR(optimizer,
                                                          milestones=milestones,
                                                          gamma=scheduler_cfg['gamma'])
        # Assumes the config provides a scalar 'multiplier' for the warmup peak; the original code
        # passed the milestones list here, which GradualWarmupScheduler does not accept.
        scheduler = GradualWarmupScheduler(optimizer,
                                           multiplier=scheduler_cfg['multiplier'],
                                           total_epoch=scheduler_cfg['warmup_steps'],
                                           after_scheduler=scheduler_steplr)
    else:
        raise ValueError(f"Unsupported scheduler mode: {MODE}")

return scheduler
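
An illustrative configuration (not part of the commit) for the two factories above; the dictionary keys mirror the ones read in the code, while the values themselves are placeholders.

import torch.nn as nn
from model.optim import get_optimizer, get_scheduler

model = nn.Conv2d(3, 4, kernel_size=3)

optimizer_cfg = {'mode': 'SGD', 'lr': 1e-2, 'weight_decay': 1e-4,
                 'nesterov': True, 'lookahead': False}   # set 'lookahead': True to wrap in Lookahead
scheduler_cfg = {'mode': 'PolyLR', 'steps': 10000, 'power': 0.9}

optimizer = get_optimizer(model.parameters(), optimizer_cfg)
scheduler = get_scheduler(optimizer, scheduler_cfg)
# Typical loop: optimizer.step() followed by scheduler.step() once per iteration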
Binary file added model/optim/__pycache__/__init__.cpython-36.pyc
Binary file added model/optim/__pycache__/cyclicLR.cpython-36.pyc
Binary file added model/optim/__pycache__/lookahead.cpython-36.pyc
Binary file added model/optim/__pycache__/radam.cpython-36.pyc
125 changes: 125 additions & 0 deletions model/optim/cyclicLR.py
@@ -0,0 +1,125 @@
import math
from bisect import bisect_right
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer


class CyclicCosAnnealingLR(_LRScheduler):
r"""
Implements reset on milestones inspired from CosineAnnealingLR pytorch
Set the learning rate of each parameter group using a cosine annealing
schedule, where :math:`\eta_{max}` is set to the initial lr and
:math:`T_{cur}` is the number of epochs since the last restart in SGDR:
.. math::
\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
\cos(\frac{T_{cur}}{T_{max}}\pi))
When last_epoch > last set milestone, lr is automatically set to \eta_{min}
It has been proposed in
`SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
implements the cosine annealing part of SGDR, and not the restarts.
Args:
optimizer (Optimizer): Wrapped optimizer.
milestones (list of ints): List of epoch indices. Must be increasing.
decay_milestones(list of ints):List of increasing epoch indices. Ideally,decay values should overlap with milestone points
gamma (float): factor by which to decay the max learning rate at each decay milestone
eta_min (float): Minimum learning rate. Default: 1e-6
last_epoch (int): The index of last epoch. Default: -1.
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
"""

def __init__(self, optimizer, milestones, decay_milestones=None, gamma=0.5, eta_min=1e-6, last_epoch=-1):
        if list(milestones) != sorted(milestones):
            raise ValueError('Milestones should be a list of '
                             'increasing integers. Got {}'.format(milestones))
self.eta_min = eta_min
self.milestones = milestones
self.milestones2 = decay_milestones

self.gamma = gamma
super(CyclicCosAnnealingLR, self).__init__(optimizer, last_epoch)

def get_lr(self):

if self.last_epoch >= self.milestones[-1]:
return [self.eta_min for base_lr in self.base_lrs]

idx = bisect_right(self.milestones, self.last_epoch)

left_barrier = 0 if idx == 0 else self.milestones[idx - 1]
right_barrier = self.milestones[idx]

width = right_barrier - left_barrier
curr_pos = self.last_epoch - left_barrier

if self.milestones2:
return [self.eta_min + (
base_lr * self.gamma ** bisect_right(self.milestones2, self.last_epoch) - self.eta_min) *
(1 + math.cos(math.pi * curr_pos / width)) / 2
for base_lr in self.base_lrs]
else:
return [self.eta_min + (base_lr - self.eta_min) *
(1 + math.cos(math.pi * curr_pos / width)) / 2
for base_lr in self.base_lrs]


class CyclicLinearLR(_LRScheduler):
r"""
Implements reset on milestones inspired from Linear learning rate decay
Set the learning rate of each parameter group using a linear decay
schedule, where :math:`\eta_{max}` is set to the initial lr and
:math:`T_{cur}` is the number of epochs since the last restart:
.. math::
\eta_t = \eta_{min} + (\eta_{max} - \eta_{min})(1 -\frac{T_{cur}}{T_{max}})
When last_epoch > last set milestone, lr is automatically set to \eta_{min}
Args:
optimizer (Optimizer): Wrapped optimizer.
milestones (list of ints): List of epoch indices. Must be increasing.
decay_milestones(list of ints):List of increasing epoch indices. Ideally,decay values should overlap with milestone points
gamma (float): factor by which to decay the max learning rate at each decay milestone
eta_min (float): Minimum learning rate. Default: 1e-6
last_epoch (int): The index of last epoch. Default: -1.
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
"""

def __init__(self, optimizer, milestones, decay_milestones=None, gamma=0.5, eta_min=1e-6, last_epoch=-1):
        if list(milestones) != sorted(milestones):
            raise ValueError('Milestones should be a list of '
                             'increasing integers. Got {}'.format(milestones))
self.eta_min = eta_min

self.gamma = gamma
self.milestones = milestones
self.milestones2 = decay_milestones
super(CyclicLinearLR, self).__init__(optimizer, last_epoch)

def get_lr(self):

if self.last_epoch >= self.milestones[-1]:
return [self.eta_min for base_lr in self.base_lrs]

idx = bisect_right(self.milestones, self.last_epoch)

left_barrier = 0 if idx == 0 else self.milestones[idx - 1]
right_barrier = self.milestones[idx]

width = right_barrier - left_barrier
curr_pos = self.last_epoch - left_barrier

if self.milestones2:
return [self.eta_min + (
base_lr * self.gamma ** bisect_right(self.milestones2, self.last_epoch) - self.eta_min) *
(1. - 1.0 * curr_pos / width)
for base_lr in self.base_lrs]

else:
return [self.eta_min + (base_lr - self.eta_min) *
(1. - 1.0 * curr_pos / width)
for base_lr in self.base_lrs]
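
A small sketch (illustrative, not part of the commit) of the cyclic schedule above; the milestone values are arbitrary. Within each interval the lr anneals from the (possibly decayed) base lr down to eta_min, restarts at the next milestone, and is pinned to eta_min once the last milestone is passed.

import torch
from model.optim.cyclicLR import CyclicCosAnnealingLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = CyclicCosAnnealingLR(optimizer, milestones=[30, 60, 90],
                                 decay_milestones=[30, 60], gamma=0.5, eta_min=1e-6)

for epoch in range(100):
    optimizer.step()
    scheduler.step()
    # Roughly: epochs 0-29 anneal from 0.1, epochs 30-59 from 0.05, epochs 60-89 from 0.025,
    # and from epoch 90 onward the lr stays at eta_min.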
100 changes: 100 additions & 0 deletions model/optim/lookahead.py
@@ -0,0 +1,100 @@
import torch
from torch.optim import Optimizer
from collections import defaultdict


class Lookahead(Optimizer):
'''
PyTorch implementation of the lookahead wrapper.
Lookahead Optimizer: https://arxiv.org/abs/1907.08610
'''

def __init__(self, optimizer, alpha=0.5, k=6, pullback_momentum="none"):
        '''
        :param optimizer: inner optimizer
        :param alpha (float): linear interpolation factor. 1.0 recovers the inner optimizer.
        :param k (int): number of lookahead steps
        :param pullback_momentum (str): change to inner optimizer momentum on interpolation update
                                        ("none", "reset" or "pullback")
        '''
if not 0.0 <= alpha <= 1.0:
raise ValueError(f'Invalid slow update rate: {alpha}')
if not 1 <= k:
raise ValueError(f'Invalid lookahead steps: {k}')
self.optimizer = optimizer
self.param_groups = self.optimizer.param_groups
self.alpha = alpha
self.k = k
self.step_counter = 0
assert pullback_momentum in ["reset", "pullback", "none"]
self.pullback_momentum = pullback_momentum
self.state = defaultdict(dict)

        # Cache the current optimizer parameters as the slow weights
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['cached_params'] = torch.zeros_like(p.data)
                param_state['cached_params'].copy_(p.data)
                if self.pullback_momentum == "pullback":
                    # Pre-populate the momentum cache so the "pullback" branch in step()
                    # does not hit a missing key on the first lookahead update.
                    param_state['cached_mom'] = torch.zeros_like(p.data)

def __getstate__(self):
return {
'state': self.state,
'optimizer': self.optimizer,
'alpha': self.alpha,
'step_counter': self.step_counter,
'k': self.k,
'pullback_momentum': self.pullback_momentum
}

def zero_grad(self):
self.optimizer.zero_grad()

def state_dict(self):
return self.optimizer.state_dict()

def load_state_dict(self, state_dict):
self.optimizer.load_state_dict(state_dict)

def _backup_and_load_cache(self):
"""Useful for performing evaluation on the slow weights (which typically generalize better)
"""
for group in self.optimizer.param_groups:
for p in group['params']:
param_state = self.state[p]
param_state['backup_params'] = torch.zeros_like(p.data)
param_state['backup_params'].copy_(p.data)
p.data.copy_(param_state['cached_params'])

def _clear_and_load_backup(self):
for group in self.optimizer.param_groups:
for p in group['params']:
param_state = self.state[p]
p.data.copy_(param_state['backup_params'])
del param_state['backup_params']

def step(self, closure=None):
"""Performs a single Lookahead optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = self.optimizer.step(closure)
self.step_counter += 1

if self.step_counter >= self.k:
self.step_counter = 0
# Lookahead and cache the current optimizer parameters
for group in self.optimizer.param_groups:
for p in group['params']:
param_state = self.state[p]
                    p.data.mul_(self.alpha).add_(param_state['cached_params'], alpha=1.0 - self.alpha)  # crucial line
                    param_state['cached_params'].copy_(p.data)
                    if self.pullback_momentum == "pullback":
                        internal_momentum = self.optimizer.state[p]["momentum_buffer"]
                        self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.alpha).add_(
                            param_state["cached_mom"], alpha=1.0 - self.alpha)
                        param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"]
                    elif self.pullback_momentum == "reset":
                        self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)

return loss
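
A usage sketch (illustrative, not part of the commit) wrapping a plain SGD optimizer; the model and data are dummies. Every k inner steps the fast weights are pulled toward the cached slow weights by linear interpolation with factor alpha.

import torch
import torch.nn.functional as F
from model.optim.lookahead import Lookahead

model = torch.nn.Linear(10, 2)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = Lookahead(base_optimizer, alpha=0.5, k=5)

x, y = torch.randn(8, 10), torch.randn(8, 2)
for _ in range(20):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()   # every 5th call also updates the slow (cached) weights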