Minor changes for Release 0.0.7
vballoli committed Feb 21, 2021
1 parent c54093f commit 244319a
Showing 7 changed files with 85 additions and 33 deletions.
35 changes: 21 additions & 14 deletions README.md
@@ -23,12 +23,12 @@ or install the latest code using:
# Usage
## WSConv2d

Use `WSConv2d` and `WSConvTranspose2d` like any other `torch.nn.Conv2d` or `torch.nn.ConvTranspose2d` modules.
Use `WSConv1d`, `WSConv2d`, `ScaledStdConv2d` (from timm) and `WSConvTranspose2d` like any other `torch.nn.Conv2d` or `torch.nn.ConvTranspose2d` modules.

```python
import torch
from torch import nn
from nfnets import WSConv2d
from nfnets import WSConv2d, WSConvTranspose2d, ScaledStdConv2d

conv = nn.Conv2d(3,6,3)
w_conv = WSConv2d(3,6,3)
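
# Hypothetical sanity check (not part of the original snippet): both layers
# accept the same input and return outputs of the same shape.
x = torch.randn(1, 3, 32, 32)
assert conv(x).shape == w_conv(x).shape  # torch.Size([1, 6, 30, 30])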
@@ -75,31 +75,29 @@ optim_agc = SGD_AGC(conv.parameters(), 1e-3)

## Using it within any PyTorch model

`replace_conv` replaces every convolution in your model with the given convolution class and replaces every batch norm with an identity. While the identity is not ideal, it shouldn't cause a major difference in latency; a quick sanity check is sketched after the example below.
```python
import torch
from torch import nn
from torchvision.models import resnet18

from nfnets import replace_conv
from nfnets import replace_conv, WSConv2d, ScaledStdConv2d

model = resnet18()
replace_conv(model)
replace_conv(model, WSConv2d) # This repo's original implementation
replace_conv(model, ScaledStdConv2d) # From timm

"""
class YourCustomClass(nn.Conv2d):
...
replace_conv(model, YourCustomClass)
"""
```
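
A minimal sketch of such a sanity check (assuming the `resnet18` model, `replace_conv` and `WSConv2d` from the snippet above; the exact layer counts depend on the architecture):

```python
import torch
from torch import nn
from torchvision.models import resnet18

from nfnets import replace_conv, WSConv2d

model = resnet18()
replace_conv(model, WSConv2d)

# After the swap there should be no BatchNorm2d layers left (each is replaced
# by nn.Identity) and the former Conv2d layers should now be WSConv2d.
remaining_bn = sum(isinstance(m, nn.BatchNorm2d) for m in model.modules())
ws_convs = sum(isinstance(m, WSConv2d) for m in model.modules())
print(remaining_bn, ws_convs)  # expected: 0, and one WSConv2d per original Conv2d
```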

# Docs

Find the docs at [readthedocs](https://nfnets-pytorch.readthedocs.io/en/latest/)

# TODO
- [x] WSConv2d
- [x] SGD - Adaptive Gradient Clipping
- [x] Function to automatically replace Convolutions in any module with WSConv2d
- [x] Documentation
- [x] Generic AGC wrapper (see [this comment](https://github.com/vballoli/nfnets-pytorch/issues/1#issuecomment-778853439) for a reference implementation; needs testing for now)
- [x] WSConvTranspose2d
- [ ] NFNets
- [ ] NF-ResNets

# Cite Original Work

To cite the original paper, use:
@@ -111,3 +109,12 @@ To cite the original paper, use:
year={2021}
}
```

# TODO
- [x] WSConv2d
- [x] SGD - Adaptive Gradient Clipping
- [x] Function to automatically replace Convolutions in any module with WSConv2d
- [x] Documentation
- [x] Generic AGC wrapper (see [this comment](https://github.com/vballoli/nfnets-pytorch/issues/1#issuecomment-778853439) for a reference implementation; needs testing for now)
- [x] WSConvTranspose2d
- [x] WSConv1d (thanks to [@shi27feng](https://github.com/shi27feng))
8 changes: 5 additions & 3 deletions docs/index.rst
@@ -41,11 +41,13 @@ Sample usage
from torch import nn
from torchvision.models import resnet18
from nfnets import replace_conv, SGD_AGC
from nfnets import replace_conv, AGC, WSConv2d, ScaledStdConv2d
model = resnet18()
replace_conv(model)
optim = SGD_AGC(model.parameters(), 1e-3)
replace_conv(model, WSConv2d) # This repo's original implementation
replace_conv(model, ScaledStdConv2d) # From timm
optim = torch.optim.SGD(model.parameters(), 1e-3) # Or any other optimizer of your choice
optim = AGC(model.parameters(), optim)
.. toctree::
:maxdepth: 2
2 changes: 1 addition & 1 deletion nfnets/__init__.py
@@ -1,4 +1,4 @@
from .base import WSConv2d, WSConvTranspose2d
from .base import WSConv1d, WSConv2d, WSConvTranspose2d, ScaledStdConv2d
from .sgd_agc import SGD_AGC
from .agc import AGC
from .utils import replace_conv
61 changes: 51 additions & 10 deletions nfnets/base.py
@@ -1,7 +1,6 @@
import torch
from torch import nn
from torch.functional import F
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch import Tensor

from typing import Optional, List, Tuple
@@ -106,7 +105,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)

nn.init.kaiming_normal_(self.weight)
self.gain = nn.Parameter(torch.ones(self.weight.size()[0], requires_grad=True))
self.gain = nn.Parameter(torch.ones(
self.weight.size()[0], requires_grad=True))

def standardize_weight(self, eps):
var, mean = torch.var_mean(self.weight, dim=(1, 2), keepdims=True)
@@ -262,7 +262,8 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)

nn.init.kaiming_normal_(self.weight)
self.gain = nn.Parameter(torch.ones(self.weight.size(0), requires_grad=True))
self.gain = nn.Parameter(torch.ones(
self.weight.size(0), requires_grad=True))

def standardize_weight(self, eps):
var, mean = torch.var_mean(self.weight, dim=(1, 2, 3), keepdims=True)
@@ -413,10 +414,10 @@ class WSConvTranspose2d(nn.ConvTranspose2d):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: _size_2_t = 0,
output_padding: _size_2_t = 0,
kernel_size,
stride=1,
padding=0,
output_padding=0,
groups: int = 1,
bias: bool = True,
dilation: int = 1,
@@ -425,7 +426,8 @@ def __init__(self,
output_padding=output_padding, groups=groups, bias=bias, dilation=dilation, padding_mode=padding_mode)

nn.init.kaiming_normal_(self.weight)
self.gain = nn.Parameter(torch.ones(self.weight.size(0), requires_grad=True))
self.gain = nn.Parameter(torch.ones(
self.weight.size(0), requires_grad=True))

def standardize_weight(self, eps):
var, mean = torch.var_mean(self.weight, dim=(1, 2, 3), keepdims=True)
@@ -436,6 +438,45 @@ def standardize_weight(self, eps):
shift = mean * scale
return self.weight * scale - shift

def forward(self, input: Tensor, output_size: Optional[List[int]] = None, eps: float=1e-4) -> Tensor:
weight = self.standardize_weight()
def forward(self, input: Tensor, output_size: Optional[List[int]] = None, eps: float = 1e-4) -> Tensor:
weight = self.standardize_weight(eps)
return F.conv_transpose2d(input, weight, self.bias, self.stride, self.padding, self.output_padding, self.groups, self.dilation)  # use the standardized weight


class ScaledStdConv2d(nn.Conv2d):
"""Conv2d layer with Scaled Weight Standardization.
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
https://arxiv.org/abs/2101.08692
Adapted from timm: https://github.com/rwightman/pytorch-image-models/blob/4ea593196414684d2074cbb81d762f3847738484/timm/models/layers/std_conv.py
"""

def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
self.gain = nn.Parameter(torch.ones(
self.out_channels, 1, 1, 1)) if gain else None
# gamma * 1 / sqrt(fan-in)
self.scale = gamma * self.weight[0].numel() ** -0.5
self.eps = eps ** 2 if use_layernorm else eps
# experimental, slightly faster/less GPU memory use
self.use_layernorm = use_layernorm

def get_weight(self):
if self.use_layernorm:
weight = self.scale * \
F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
else:
std, mean = torch.std_mean(
self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = self.scale * (self.weight - mean) / (std + self.eps)
if self.gain is not None:
weight = weight * self.gain
return weight

def forward(self, x):
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
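
For illustration, a minimal (hypothetical) usage sketch of `ScaledStdConv2d` on its own; `padding` is given explicitly here so the example does not rely on the `get_padding` helper:

```python
import torch
from nfnets import ScaledStdConv2d

layer = ScaledStdConv2d(3, 6, kernel_size=3, padding=1)
x = torch.randn(2, 3, 32, 32)

y = layer(x)
print(y.shape)  # torch.Size([2, 6, 32, 32]); padding=1 preserves the spatial size

# get_weight() returns the standardized (and gain-scaled) weight used by forward();
# it has the same shape as the raw parameter.
print(layer.get_weight().shape)  # torch.Size([6, 3, 3, 3])
```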
7 changes: 4 additions & 3 deletions nfnets/utils.py
@@ -4,17 +4,18 @@
from nfnets import WSConv2d


def replace_conv(module: nn.Module):
def replace_conv(module: nn.Module, conv_class=WSConv2d):
"""Recursively replaces every convolution with WSConv2d.
Usage: replace_conv(model) #(In-line replacement)
Args:
module(nn.Module): target's model whose convolutions must be replaced.
module (nn.Module): the target model whose convolutions must be replaced.
conv_class (Class): the convolution class to use (WSConv2d or ScaledStdConv2d).
"""
for name, mod in module.named_children():
target_mod = getattr(module, name)
if type(mod) == torch.nn.Conv2d:
setattr(module, name, WSConv2d(target_mod.in_channels, target_mod.out_channels, target_mod.kernel_size,
setattr(module, name, conv_class(target_mod.in_channels, target_mod.out_channels, target_mod.kernel_size,
target_mod.stride, target_mod.padding, target_mod.dilation, target_mod.groups, target_mod.bias))

if type(mod) == torch.nn.BatchNorm2d:
1 change: 1 addition & 0 deletions requirements.txt
@@ -0,0 +1 @@
torch>=1.5.0
4 changes: 2 additions & 2 deletions setup.py
@@ -8,7 +8,7 @@
setup(
name = 'nfnets-pytorch',
packages = find_packages(),
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'NFNets, PyTorch',
long_description=long_description,
@@ -23,7 +23,7 @@
'adaptive gradient clipping'
],
install_requires=[
'torch',
'torch>=1.5.0',
'torchvision',
],
classifiers=[
