[scripts] implement max-change within customized SGD optimizer
aadps committed Apr 11, 2020
1 parent d211d33 commit 5ccb456
Showing 2 changed files with 151 additions and 9 deletions.
136 changes: 136 additions & 0 deletions egs/aishell/s10/chain/sgd_max_change.py
@@ -0,0 +1,136 @@
import torch
from torch.optim.optimizer import Optimizer, required


class SgdMaxChange(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum and max
change).
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
        max_change_per_layer (float, optional): maximum change, measured in l2
            norm, allowed in the parameters of any single layer on any given
            batch (default: 0.75)
        max_change (float, optional): maximum change, measured in l2 norm,
            allowed in the parameters of the whole model on any given batch,
            applied after the per-layer constraint (default: 1.5)
Example:
        >>> optimizer = SgdMaxChange(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
__ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
.. note::
The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
.. math::
\begin{aligned}
v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
\end{aligned}
where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
parameters, gradient, velocity, and momentum respectively.
        This is in contrast to Sutskever et al. and
other frameworks which employ an update of the form
.. math::
\begin{aligned}
v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
p_{t+1} & = p_{t} - v_{t+1}.
\end{aligned}
The Nesterov version is analogously modified.
"""

def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False, max_change_per_layer=0.75, max_change=1.5):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if max_change_per_layer < 0.01:
raise ValueError("Invalid max_change_per_layer value: {}".format(max_change_per_layer))
if max_change < 0.01:
raise ValueError("Invalid max_change value: {}".format(max_change))

defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov,
max_change_per_layer=max_change_per_layer,
max_change=max_change)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(SgdMaxChange, self).__init__(params, defaults)

def __setstate__(self, state):
super(SgdMaxChange, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)

@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        Returns:
            A tuple ``(loss, change)``, where ``change`` is the l2 norm of the
            proposed update (after the per-layer constraint, before the global
            one) scaled by the learning rate, summed over parameter groups.
        """
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
change = 0

for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
max_change_per_layer = group['max_change_per_layer']
max_change = group['max_change']

delta = []
total_norm = 0

for i in range(len(group['params'])):
p = group['params'][i]
if p.grad is None:
continue
d_p = p.grad
if weight_decay != 0:
d_p = d_p.add(p, alpha=weight_decay)
if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
if nesterov:
d_p = d_p.add(buf, alpha=momentum)
else:
d_p = buf
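                # Per-layer max-change: if lr * ||d_p||_2 would exceed
                # max_change_per_layer, scale this layer's update down.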
norm = d_p.norm(2).item()
if norm * group['lr'] > max_change_per_layer:
d_p.mul_(max_change_per_layer / (norm * group['lr']))
delta.append(d_p)
total_norm += d_p.norm(2).item() ** 2.

total_norm = total_norm ** 0.5

            # Apply the updates. ``delta`` only holds entries for parameters
            # that actually received gradients, hence the separate index j.
            j = 0
            for i in range(len(group['params'])):
                p = group['params'][i]
                if p.grad is None:
                    continue
                # Global max-change: if lr * total_norm exceeds max_change,
                # scale the whole step down accordingly.
                if total_norm * group['lr'] > max_change:
                    p.add_(delta[j], alpha=-group['lr'] * max_change / (total_norm * group['lr']))
                else:
                    p.add_(delta[j], alpha=-group['lr'])
                j += 1

            # change reports the size of the step this group proposed (after
            # the per-layer constraint, before the global one), scaled by lr.
            change += total_norm * group['lr']

return loss, change
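To make the behaviour above concrete, here is a minimal, self-contained sketch of how the optimizer might be exercised outside the recipe. The toy network, the random data, and the deliberately large learning rate are made up for illustration; only `SgdMaxChange` itself comes from this commit, and the script assumes it runs from `egs/aishell/s10/chain` (or with that directory on `PYTHONPATH`).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from sgd_max_change import SgdMaxChange

# Toy setup, purely illustrative.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 4))

# The large lr makes the raw SGD step big, so the per-layer and global
# max-change constraints are likely to kick in on this batch.
optimizer = SgdMaxChange(model.parameters(), lr=5.0, momentum=0.9,
                         max_change_per_layer=0.75, max_change=1.5)

x, target = torch.randn(8, 10), torch.randn(8, 4)
params_before = [p.detach().clone() for p in model.parameters()]

optimizer.zero_grad()
F.mse_loss(model(x), target).backward()
# Unlike torch.optim.SGD, step() returns a (loss, change) pair; train.py
# unpacks it and logs `change` per minibatch.
_, change = optimizer.step()

# The l2 norm of the actual parameter change should never exceed max_change.
actual = sum((p.detach() - q).norm(2).item() ** 2
             for p, q in zip(model.parameters(), params_before)) ** 0.5
print('reported change: {:.4f}, actual change: {:.4f}'.format(change, actual))
assert actual <= 1.5 + 1e-4
```

Because `change` is accumulated before the global constraint is applied, a logged value above `max_change` indicates that the whole-model clipping kicked in for that minibatch.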
24 changes: 15 additions & 9 deletions egs/aishell/s10/chain/train.py
@@ -34,6 +34,7 @@
from libs.nnet3.train.dropout_schedule import _get_dropout_proportions
from model import get_chain_model
from options import get_args
from sgd_max_change import SgdMaxChange

def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None, dropout=0.):
feature, supervision = batch
@@ -67,20 +68,20 @@ def get_objf(batch, model, device, criterion, opts, den_graph, training, optimiz
supervision, nnet_output,
xent_output)
objf = objf_l2_term_weight[0]
change = 0
if training:
optimizer.zero_grad()
objf.backward()
-        clip_grad_value_(model.parameters(), 5.0)
-        optimizer.step()
+        # clip_grad_value_(model.parameters(), 5.0)
+        _, change = optimizer.step()

objf_l2_term_weight = objf_l2_term_weight.detach().cpu()

total_objf = objf_l2_term_weight[0].item()
total_weight = objf_l2_term_weight[2].item()
total_frames = nnet_output.shape[0]

-    return total_objf, total_weight, total_frames

+    return total_objf, total_weight, total_frames, change

def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
total_objf = 0.
@@ -90,7 +91,7 @@ def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
model.eval()

for batch_idx, (pseudo_epoch, batch) in enumerate(dataloader):
-        objf, weight, frames = get_objf(
+        objf, weight, frames, _ = get_objf(
batch, model, device, criterion, opts, den_graph, False)
total_objf += objf
total_weight += weight
@@ -116,7 +117,7 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
len(dataloader)) / (len(dataloader) * num_epochs)
_, dropout = _get_dropout_proportions(
dropout_schedule, data_fraction)[0]
-        curr_batch_objf, curr_batch_weight, curr_batch_frames = get_objf(
+        curr_batch_objf, curr_batch_weight, curr_batch_frames, curr_batch_change = get_objf(
batch, model, device, criterion, opts, den_graph, True, optimizer, dropout=dropout)

total_objf += curr_batch_objf
@@ -127,13 +128,13 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
logging.info(
'Device ({}) processing batch {}, current pseudo-epoch is {}/{}({:.6f}%), '
'global average objf: {:.6f} over {} '
-                'frames, current batch average objf: {:.6f} over {} frames, epoch {}'
+                'frames, current batch average objf: {:.6f} over {} frames, minibatch change: {:.6f}, epoch {}'
.format(
device.index, batch_idx, pseudo_epoch, len(dataloader),
float(pseudo_epoch) / len(dataloader) * 100,
total_objf / total_weight, total_frames,
curr_batch_objf / curr_batch_weight,
-                    curr_batch_frames, current_epoch))
+                    curr_batch_frames, curr_batch_change, current_epoch))

if valid_dataloader and batch_idx % 1000 == 0:
total_valid_objf, total_valid_weight, total_valid_frames = get_validation_objf(
@@ -167,6 +168,11 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
dropout,
pseudo_epoch + current_epoch * len(dataloader))

tf_writer.add_scalar(
'train/current_batch_change',
curr_batch_change,
pseudo_epoch + current_epoch * len(dataloader))

state_dict = model.state_dict()
for key, value in state_dict.items():
# skip batchnorm parameters
@@ -301,7 +307,7 @@ def process_job(learning_rate, device_id=None, local_rank=None):
else:
valid_dataloader = None

-    optimizer = optim.Adam(model.parameters(),
+    optimizer = SgdMaxChange(model.parameters(),
lr=learning_rate,
weight_decay=5e-4)
