Skip to content

Commit

Permalink
Fix agc trigger
Browse files Browse the repository at this point in the history
  • Loading branch information
vballoli committed Mar 3, 2021
1 parent 2f88915 commit bf8ee3b
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion nfnets/agc.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def step(self, closure=None):
grad_norm = unitwise_norm(p.grad.detach())
max_norm = param_norm * group['clipping']

trigger = grad_norm < max_norm
trigger = grad_norm > max_norm

clipped_grad = p.grad * \
(max_norm / torch.max(grad_norm,
Expand Down
2 changes: 1 addition & 1 deletion nfnets/sgd_agc.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def step(self, closure=None):
grad_norm = unitwise_norm(p.grad.detach())
max_norm = param_norm * group['clipping']

trigger = grad_norm < max_norm
trigger = grad_norm > max_norm

clipped_grad = p.grad * \
(max_norm / torch.max(grad_norm,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
setup(
name = 'nfnets-pytorch',
packages = find_packages(),
version = '0.0.8',
version = '0.0.9',
license='MIT',
description = 'NFNets, PyTorch',
long_description=long_description,
Expand Down

0 comments on commit bf8ee3b

Please sign in to comment.