Releases: vballoli/nfnets-pytorch

Major CUDA bug fix for WSConv2d, Add generic AGC

15 Feb 17:09
  1. Incorporated @bfialkoff's fix #3.
  2. Added a generic AGC implementation.

Minor bug fix, improve replace_conv

15 Feb 08:38

Fixes replace_conv and replaces BatchNorm2d with Identity.

Initial release

14 Feb 19:21
b87d84a

PyTorch implementation of Normalizer-Free Networks and SGD - Adaptive Gradient Clipping

Paper: https://arxiv.org/abs/2102.06171.pdf
Original code: https://github.com/deepmind/deepmind-research/tree/master/nfnets

Installation

pip3 install git+https://github.com/vballoli/nfnets-pytorch

Usage

WSConv2d

Use WSConv2d like any other torch.nn.Conv2d.

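import torch
from torch import nn
from nfnets import WSConv2d

# Standard convolution: 3 input channels, 6 output channels, 3x3 kernel
conv = nn.Conv2d(3, 6, 3)
# Weight-standardized convolution, constructed with the same arguments
w_conv = WSConv2d(3, 6, 3)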
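
To illustrate what a weight-standardized convolution does, here is a minimal sketch in plain PyTorch. It is not the nfnets implementation. The class name SketchWSConv2d, the eps default, and the gain parameter are assumptions made for this example. The idea, following the paper, is to rescale each output channel's kernel to zero mean and variance 1/fan_in before the convolution is applied.

```python
import torch
from torch import nn
import torch.nn.functional as F


class SketchWSConv2d(nn.Conv2d):
    """Hypothetical illustration of weight standardization (not the nfnets class)."""

    def __init__(self, *args, eps=1e-4, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps  # assumed value, guards against division by zero
        # Learnable per-output-channel gain (assumed shape for this sketch)
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))

    def standardized_weight(self):
        w = self.weight
        fan_in = w[0].numel()  # (in_channels / groups) * kernel_h * kernel_w
        mean = w.mean(dim=(1, 2, 3), keepdim=True)
        var = w.var(dim=(1, 2, 3), keepdim=True, unbiased=False)
        # Each output channel ends up with zero mean and variance close to 1 / fan_in
        return self.gain * (w - mean) / torch.sqrt(var * fan_in + self.eps)

    def forward(self, x):
        # Same convolution as nn.Conv2d, but with the standardized kernel
        return F.conv2d(x, self.standardized_weight(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)
```

Because it subclasses nn.Conv2d, the sketch takes the same constructor arguments, for example SketchWSConv2d(3, 6, 3). It ignores non-default padding modes.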

SGD - Adaptive Gradient Clipping

Similarly, use SGD_AGC like torch.optim.SGD.

import torch
from torch import nn, optim
from nfnets import WSConv2d, SGD_AGC

conv = nn.Conv2d(3, 6, 3)
w_conv = WSConv2d(3, 6, 3)

# Distinct variable names keep the torch.optim module from being shadowed
sgd = optim.SGD(conv.parameters(), 1e-3)
sgd_agc = SGD_AGC(conv.parameters(), 1e-3)
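
The clipping rule behind AGC can be sketched independently of any optimizer. The snippet below is an illustration of the rule described in the paper, not the code inside SGD_AGC. The function names and the clipping and eps defaults are assumptions. For each output unit, the gradient is rescaled whenever its norm exceeds clipping times the norm of the corresponding weights.

```python
import torch


def unitwise_norm(x):
    """L2 norm per output unit (dim 0); scalars and vectors use the whole tensor."""
    if x.ndim <= 1:
        return x.norm(p=2)
    dims = tuple(range(1, x.ndim))
    return x.norm(p=2, dim=dims, keepdim=True)


@torch.no_grad()
def clip_grads_agc(parameters, clipping=1e-2, eps=1e-3):
    """Hypothetical helper: adaptively clip gradients in place."""
    for p in parameters:
        if p.grad is None:
            continue
        # eps keeps zero-initialized weights from having every gradient clipped to zero
        p_norm = unitwise_norm(p).clamp(min=eps)
        g_norm = unitwise_norm(p.grad)
        max_norm = clipping * p_norm
        # Scale is 1 where the gradient is small enough, below 1 where it is clipped
        scale = (max_norm / g_norm.clamp(min=1e-6)).clamp(max=1.0)
        p.grad.mul_(scale)
```

In a training loop, such a helper would be called after loss.backward() and before the step() of any optimizer, which is what makes the approach generic.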

Using it within any PyTorch model

import torch
from torch import nn
from torchvision.models import resnet18

from nfnets import replace_conv

model = resnet18()
replace_conv(model)
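
For readers curious how such a replacement can work, the following is a rough sketch of a recursive module swap. It is not the library's replace_conv. The name swap_conv2d, its conv_cls argument, and the drop_bn flag are hypothetical. It also creates freshly initialized layers rather than copying existing weights.

```python
from torch import nn


def swap_conv2d(module, conv_cls, drop_bn=False):
    """Hypothetical helper: recursively replace nn.Conv2d children with conv_cls."""
    # list() takes a snapshot so children can be reassigned while looping
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Conv2d) and not isinstance(child, conv_cls):
            # Rebuild the layer with the same configuration (weights are re-initialized)
            new = conv_cls(child.in_channels, child.out_channels, child.kernel_size,
                           stride=child.stride, padding=child.padding,
                           dilation=child.dilation, groups=child.groups,
                           bias=child.bias is not None)
            setattr(module, name, new)
        elif drop_bn and isinstance(child, nn.BatchNorm2d):
            # Optionally remove batch normalization by substituting a no-op
            setattr(module, name, nn.Identity())
        else:
            swap_conv2d(child, conv_cls, drop_bn)
    return module
```

Any class whose constructor mirrors nn.Conv2d can be passed as conv_cls. Assuming WSConv2d accepts the same keyword arguments, swap_conv2d(model, WSConv2d, drop_bn=True) would be one way to try it.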

Docs

Find the docs at readthedocs

TODO

  • WSConv2d
  • SGD - Adaptive Gradient Clipping
  • Function to automatically replace Convolutions in any module with WSConv2d
  • Documentation
  • NFNets
  • NF-ResNets

Cite Original Work

To cite the original paper, use:

@article{brock2021high,
  author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
  title={High-Performance Large-Scale Image Recognition Without Normalization},
  journal={arXiv preprint arXiv:},
  year={2021}
}