[scripts] implement max-change within customized SGD optimizer
aadps committed Apr 11, 2020
1 parent ce5f93f commit 6b58186
Showing 2 changed files with 22 additions and 8 deletions.
@@ -2,8 +2,9 @@
from torch.optim.optimizer import Optimizer, required


class SGD_MC(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum).
class SgdMaxChange(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum and max
change).
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
Args:
@@ -14,6 +15,10 @@ class SGD_MC(Optimizer):
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
max_change_per_layer (float, optional): maximum change in parameters
allowed for any given layer, on any given batch, measured in l2 norm
max_change (float, optional): maximum change in parameters allowed for
the model as a whole, applied after the per-layer constraint
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
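
To make the two new arguments concrete, here is a minimal usage sketch. The class name, argument names, and module name come from this commit; the tiny model and the numeric cap values are illustrative assumptions, not values taken from the repository.

import torch
from sgd_max_change import SgdMaxChange  # module name as imported in train.py below

model = torch.nn.Linear(10, 2)  # stand-in model for illustration only
# The cap values below are assumed for illustration; the commit's defaults are not shown in this hunk.
optimizer = SgdMaxChange(model.parameters(), lr=0.1, momentum=0.9,
                         max_change_per_layer=0.75, max_change=1.5)

optimizer.zero_grad()
loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()

In the commit itself, train.py (the second changed file below) constructs the optimizer without passing either argument, so the class defaults apply.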
@@ -49,17 +54,21 @@ def __init__(self, params, lr=required, momentum=0, dampening=0,
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if max_change_per_layer < 0.01:
raise ValueError("Invalid max_change_per_layer value: {}".format(max_change_per_layer))
if max_change < 0.01:
raise ValueError("Invalid max_change value: {}".format(max_change))

defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov,
max_change_per_layer=max_change_per_layer,
max_change=max_change)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(SGD_MC, self).__init__(params, defaults)
super(SgdMaxChange, self).__init__(params, defaults)

def __setstate__(self, state):
super(SGD_MC, self).__setstate__(state)
super(SgdMaxChange, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)

@@ -107,7 +116,7 @@ def step(self, closure=None):
d_p = buf
norm = d_p.norm(2).item()
if norm * group['lr'] > max_change_per_layer:
d_p.mul_(max_change_per_layer / norm)
d_p.mul_(max_change_per_layer / (norm * group['lr']))
delta.append(d_p)
total_norm += d_p.norm(2).item() ** 2.

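The substance of this hunk is the denominator of the scale factor. Dividing by norm alone caps the l2 norm of d_p itself, whereas dividing by norm * group['lr'] caps the norm of the step that actually gets applied, lr * d_p, which is what max_change_per_layer is meant to bound. A small self-contained check of that behaviour (the learning rate, cap, and tensor values are illustrative):

import torch

lr = 0.1
max_change_per_layer = 0.75

d_p = torch.ones(100)             # l2 norm 10, so the raw step lr * d_p has norm 1.0
norm = d_p.norm(2).item()
if norm * lr > max_change_per_layer:
    d_p.mul_(max_change_per_layer / (norm * lr))

print((lr * d_p).norm(2).item())  # ~0.75: the applied change is capped, not d_p itself
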
@@ -118,7 +127,7 @@ def step(self, closure=None):
if p.grad is None:
continue
if total_norm * group['lr'] > max_change:
p.add_(delta[i], alpha=-group['lr'] * max_change / total_norm)
p.add_(delta[i], alpha=-group['lr'] * max_change / (total_norm * group['lr']))
else:
p.add_(delta[i], alpha=-group['lr'])

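The whole-model constraint mirrors the per-layer one. The earlier hunk only accumulates squared per-layer norms into total_norm, so this sketch assumes the collapsed lines between the two hunks take a square root, leaving total_norm as the l2 norm of the combined update; under that assumption the corrected alpha caps the model-wide applied change at max_change. A self-contained restatement (parameter shapes and values are illustrative):

import torch

lr = 0.1
max_change = 1.5

params = [torch.zeros(50), torch.zeros(50)]
delta = [torch.full((50,), 3.0), torch.full((50,), 3.0)]       # combined l2 norm = 30
total_norm = sum(d.norm(2).item() ** 2. for d in delta) ** 0.5

for p, d in zip(params, delta):
    if total_norm * lr > max_change:
        # Same cancellation as in the commit: alpha reduces to -max_change / total_norm.
        p.add_(d, alpha=-lr * max_change / (total_norm * lr))
    else:
        p.add_(d, alpha=-lr)

applied = sum(p.norm(2).item() ** 2. for p in params) ** 0.5
print(applied)  # ~1.5: the change across the whole model is capped at max_change
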
egs/aishell/s10/chain/train.py: 9 changes (7 additions & 2 deletions)
@@ -34,7 +34,7 @@
from libs.nnet3.train.dropout_schedule import _get_dropout_proportions
from model import get_chain_model
from options import get_args
from sgd_mc import SGD_MC
from sgd_max_change import SgdMaxChange

def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None, dropout=0.):
feature, supervision = batch
@@ -168,6 +168,11 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
dropout,
pseudo_epoch + current_epoch * len(dataloader))

tf_writer.add_scalar(
'train/current_batch_change',
curr_batch_change,
pseudo_epoch + current_epoch * len(dataloader))

state_dict = model.state_dict()
for key, value in state_dict.items():
# skip batchnorm parameters
@@ -302,7 +307,7 @@ def process_job(learning_rate, device_id=None, local_rank=None):
else:
valid_dataloader = None

optimizer = SGD_MC(model.parameters(),
optimizer = SgdMaxChange(model.parameters(),
lr=learning_rate,
weight_decay=5e-4)
