Skip to content

Commit

Permalink
Fix AGC model bug
Browse files Browse the repository at this point in the history
  • Loading branch information
vballoli committed Mar 9, 2021
1 parent 5dcdba6 commit ad2dbdc
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
3 changes: 1 addition & 2 deletions nfnets/agc.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def __init__(self, params, optim: optim.Optimizer, clipping: float = 1e-2, eps:
module in model.named_modules() if name not in ignore_agc]

else:
params = [{"params": list(module.parameters())} for name,
module in model.named_modules()]
params = [{"params": params}]

self.agc_params = params
self.eps = eps
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
setup(
name = 'nfnets-pytorch',
packages = find_packages(),
version = '0.1.1',
version = '0.1.2',
license='MIT',
description = 'NFNets, PyTorch',
long_description=long_description,
Expand Down
11 changes: 10 additions & 1 deletion tests/test_agc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,20 @@

from nfnets import replace_conv, AGC

def test_agc():
def test_agc_model():
model = resnet18()
replace_conv(model)
optim = SGD(model.parameters(), 1e-3)
optim = AGC(model.parameters(), optim, model=model)
optim.zero_grad()
model(torch.randn(1,3,64,64)).sum().backward()
optim.step()

def test_agc():
model = resnet18()
replace_conv(model)
optim = SGD(model.parameters(), 1e-3)
optim = AGC(model.parameters(), optim)
optim.zero_grad()
model(torch.randn(1,3,64,64)).sum().backward()
optim.step()

0 comments on commit ad2dbdc

Please sign in to comment.