-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathweight_std.py
97 lines (75 loc) · 4.03 KB
/
weight_std.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
r"""
---
title: Weight Standardization
summary: >
A PyTorch implementation/tutorial of Weight Standardization.
---
# Weight Standardization
This is a [PyTorch](https://pytorch.org) implementation of Weight Standardization from the paper
[Micro-Batch Training with Batch-Channel Normalization and Weight Standardization](https://papers.labml.ai/paper/1903.10520).
We also have an [annotated implementation of Batch-Channel Normalization](../batch_channel_norm/index.html).
Batch normalization **gives a smooth loss landscape** and
**avoids elimination singularities**.
Elimination singularities are nodes of the network that become
useless (e.g. a ReLU that gives 0 all the time).
However, batch normalization doesn't work well when the batch size is too small,
which happens when training large networks because of device memory limitations.
The paper introduces Weight Standardization with Batch-Channel Normalization as
a better alternative.
Weight Standardization:
1. Normalizes the gradients
2. Smooths the loss landscape (reduces the Lipschitz constant)
3. Avoids elimination singularities
The Lipschitz constant is the maximum slope a function has between any two points.
That is, the Lipschitz constant $L$ is the smallest value that satisfies,
$\forall a,b \in A: \lVert f(a) - f(b) \rVert \le L \lVert a - b \rVert$
where $f: A \rightarrow \mathbb{R}^m, A \subseteq \mathbb{R}^n$.
Elimination singularities are avoided because weight standardization keeps the statistics of the outputs
similar to those of the inputs. So, as long as the inputs are normally distributed, the outputs remain close to normal.
This prevents node outputs from always falling outside the active range of the activation function
(e.g. a ReLU whose input is always negative).
*[Refer to the paper for proofs](https://papers.labml.ai/paper/1903.10520)*.
Here is [the training code](experiment.html) for training
a VGG network that uses weight standardization to classify CIFAR-10 data.
This uses a [2D-Convolution Layer with Weight Standardization](conv2d.html).
[Open in Colab](https://colab.research.google.com/github/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/normalization/weight_standardization/experiment.ipynb)
[View the training run on labml.ai](https://app.labml.ai/run/f4a783a2a7df11eb921d0242ac1c0002)
"""
import torch
def weight_standardization(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    r"""
    ## Weight Standardization

    $$\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}$$

    where,

    \begin{align}
    W &\in \mathbb{R}^{O \times I} \\
    \mu_{W_{i,\cdot}} &= \frac{1}{I} \sum_{j=1}^I W_{i,j} \\
    \sigma_{W_{i,\cdot}} &= \sqrt{\frac{1}{I} \sum_{j=1}^I W^2_{i,j} - \mu^2_{W_{i,\cdot}} + \epsilon} \\
    \end{align}

    for a 2D-convolution layer $O$ is the number of output channels ($O = C_{out}$)
    and $I$ is the number of input channels times the kernel size ($I = C_{in} \times k_H \times k_W$).

    The variance is the population ($\frac{1}{I}$) variance, exactly as in the formula above,
    computed in the equivalent centered form $\frac{1}{I} \sum_j (W_{i,j} - \mu_{W_{i,\cdot}})^2$
    for numerical stability. A row with a single element ($I = 1$) therefore
    standardizes to $0$ instead of NaN.

    Args:
        weight: tensor whose first dimension is $O$; every remaining dimension is
            flattened into $I$. Need not be contiguous.
        eps: $\epsilon$, added to the variance before the square root.

    Returns:
        The standardized weight, with the same shape as `weight`.
    """
    # Keep the original shape so the result can be restored to it
    shape = weight.shape
    # Reshape $W$ to $O \times I$; `reshape` (unlike `view`) also accepts non-contiguous tensors
    w = weight.reshape(shape[0], -1)
    # $\mu_{W_{i,\cdot}}$
    mean = w.mean(dim=1, keepdim=True)
    # $W_{i,j} - \mu_{W_{i,\cdot}}$
    centered = w - mean
    # $\sigma^2_{W_{i,\cdot}} = \frac{1}{I} \sum_j (W_{i,j} - \mu_{W_{i,\cdot}})^2$ (population variance)
    var = centered.pow(2).mean(dim=1, keepdim=True)
    # Normalize
    # $$\hat{W}_{i,j} = \frac{W_{i,j} - \mu_{W_{i,\cdot}}} {\sigma_{W_{i,\cdot}}}$$
    w = centered / torch.sqrt(var + eps)
    # Change back to the original shape and return
    return w.reshape(shape)
from torch.functional import F
from torch import nn
class Linear_wstd(nn.Linear):
    r"""
    ## Linear layer with Weight Standardization

    A drop-in replacement for `nn.Linear`. On every forward pass, its weight is passed
    through `weight_standardization` (defined in this module) before the affine map
    is applied. The stored `self.weight` parameter itself is not modified.

    Arguments:
        in_channels: number of input features (`in_features` of `nn.Linear`).
        out_channels: number of output features (`out_features` of `nn.Linear`).
        eps: $\epsilon$ forwarded to `weight_standardization`.
        bias: whether to learn an additive bias, as in `nn.Linear`.
    """

    def __init__(self, in_channels: int, out_channels: int, eps: float = 1e-5, bias: bool = True):
        super().__init__(in_channels, out_channels, bias=bias)
        # $\epsilon$ used when standardizing the weight in `forward`
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Use the public `nn.functional.linear` rather than the module-level `F`,
        # which is imported from `torch.functional` (not its documented home).
        return nn.functional.linear(x, weight_standardization(self.weight, self.eps), self.bias)

    def extra_repr(self) -> str:
        # Show eps in `repr(module)` next to nn.Linear's own fields
        return f"{super().extra_repr()}, eps={self.eps}"