Merge branch 'development' into kfac-conv2d
Conflicts:
	curvlinops/kfac.py
	test/kfac_cases.py
	test/test_kfac.py
	test/utils.py
runame committed Jan 12, 2024
2 parents 7039087 + 5f60692 commit 9878905
Showing 4 changed files with 112 additions and 95 deletions.
86 changes: 49 additions & 37 deletions curvlinops/kfac.py
@@ -1,19 +1,23 @@
"""Linear operator for the Fisher/GGN's Kronecker-factored approximation.
Kronecker-factored approximate curvature was originally introduced for MLPs in
Kronecker-Factored Approximate Curvature (KFAC) was originally introduced for MLPs in
- Martens, J., & Grosse, R. (2015). Optimizing neural networks with Kronecker-factored
approximate curvature. International Conference on Machine Learning (ICML).
approximate curvature. International Conference on Machine Learning (ICML),
and extended to CNNs in
extended to CNNs in
- Grosse, R., & Martens, J. (2016). A kronecker-factored approximate Fisher matrix for
convolution layers. International Conference on Machine Learning (ICML).
convolution layers. International Conference on Machine Learning (ICML),
and generalized to all linear layers with weight sharing in
- Eschenhagen, R., Immer, A., Turner, R. E., Schneider, F., Hennig, P. (2023).
Kronecker-Factored Approximate Curvature for Modern Neural Network Architectures (NeurIPS).
"""

from __future__ import annotations

import warnings
from functools import partial
from math import sqrt
from typing import Dict, Iterable, List, Set, Tuple, Union
@@ -59,15 +63,20 @@ class KFACLinearOperator(_LinearOperator):
'would-be' gradients w.r.t. the layer's output. Those 'would-be' gradients result
from sampling labels from the model's distribution and computing their gradients.
The basic version of KFAC for MLPs was introduced in
Kronecker-Factored Approximate Curvature (KFAC) was originally introduced for MLPs in
- Martens, J., & Grosse, R. (2015). Optimizing neural networks with Kronecker-factored
approximate curvature. International Conference on Machine Learning (ICML),
- Martens, J., & Grosse, R. (2015). Optimizing neural networks with
Kronecker-factored approximate curvature. ICML.
extended to CNNs in
and later generalized to convolutions in
- Grosse, R., & Martens, J. (2016). A kronecker-factored approximate Fisher matrix for
convolution layers. International Conference on Machine Learning (ICML),
- Grosse, R., & Martens, J. (2016). A kronecker-factored approximate Fisher
matrix for convolution layers. ICML.
and generalized to all linear layers with weight sharing in
- Eschenhagen, R., Immer, A., Turner, R. E., Schneider, F., Hennig, P. (2023).
Kronecker-Factored Approximate Curvature for Modern Neural Network Architectures (NeurIPS).
Attributes:
_SUPPORTED_LOSSES: Tuple of supported loss functions.
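As a purely schematic illustration of the Kronecker structure described in this docstring (not the operator's actual code path; normalization, bias handling, and the sampling of 'would-be' labels are omitted):

```python
from torch import kron, rand

# One layer's KFAC block: Kronecker product of the input covariance A and the
# covariance B of the ('would-be') gradients w.r.t. the layer's output.
n, d_in, d_out = 16, 5, 4
a = rand(n, d_in)   # layer inputs
g = rand(n, d_out)  # sampled-label gradients w.r.t. the layer output
A = a.T @ a / n
B = g.T @ g / n
# With the weight flattened row-major (PyTorch's default), the block is B ⊗ A;
# the factor order flips under a column-major vec convention.
kfac_block = kron(B, A)  # shape (d_out * d_in, d_out * d_in)
```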
@@ -139,21 +148,21 @@ def __init__(
Has to be set to ``1`` when ``fisher_type != 'mc'``.
Defaults to ``1``.
kfac_approx: A string specifying the KFAC approximation that should
be used for linear weight-sharing layers, e.g. `Conv2d` modules
or `Linear` modules that process matrix- or higher-dimensional
be used for linear weight-sharing layers, e.g. ``Conv2d`` modules
or ``Linear`` modules that process matrix- or higher-dimensional
features.
Possible values are ``'expand'`` and ``'reduce'``.
See [Eschenhagen et al., 2023](https://arxiv.org/abs/2311.00636)
See `Eschenhagen et al., 2023 <https://arxiv.org/abs/2311.00636>`_
for an explanation of the two approximations.
loss_average: Whether the loss function is a mean over per-sample
losses and if yes, over which dimensions the mean is taken.
If `"batch"`, the loss function is a mean over as many terms as
the size of the mini-batch. If `"batch+sequence"`, the loss
If ``"batch"``, the loss function is a mean over as many terms as
the size of the mini-batch. If ``"batch+sequence"``, the loss
function is a mean over as many terms as the size of the
mini-batch times the sequence length, e.g. in the case of
language modeling. If `None`, the loss function is a sum. This
language modeling. If ``None``, the loss function is a sum. This
argument is used to ensure that the preconditioner is scaled
consistently with the loss and the gradient. Default: `"batch"`.
consistently with the loss and the gradient. Default: ``"batch"``.
separate_weight_and_bias: Whether to treat weights and biases separately.
Defaults to ``True``.
@@ -162,6 +171,8 @@ def __init__(
ValueError: If the loss average is not supported.
ValueError: If the loss average is ``None`` and the loss function's
reduction is not ``'sum'``.
ValueError: If the loss average is not ``None`` and the loss function's
reduction is ``'sum'``.
ValueError: If ``fisher_type != 'mc'`` and ``mc_samples != 1``.
NotImplementedError: If a parameter is in an unsupported layer.
"""
@@ -180,13 +191,10 @@ def __init__(
f"Must be 'batch' or 'batch+sequence' if loss_func.reduction != 'sum'."
)
if loss_func.reduction == "sum" and loss_average is not None:
# reduction used in loss function will overwrite loss_average
warnings.warn(
raise ValueError(
f"Loss function uses reduction='sum', but loss_average={loss_average}."
" loss_average is set to None.",
stacklevel=2,
" Set loss_average to None if you want to use reduction='sum'."
)
loss_average = None
if fisher_type != "mc" and mc_samples != 1:
raise ValueError(
f"Invalid mc_samples: {mc_samples}. "
@@ -344,13 +352,12 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
``'empirical'``.
"""
# if >2d output we convert to an equivalent 2d output
if output.ndim > 2:
if isinstance(self._loss_func, CrossEntropyLoss):
output = rearrange(output, "batch c ... -> (batch ...) c")
y = rearrange(y, "batch ... -> (batch ...)")
else:
output = rearrange(output, "batch ... c -> (batch ...) c")
y = rearrange(y, "batch ... c -> (batch ...) c")
if isinstance(self._loss_func, CrossEntropyLoss):
output = rearrange(output, "batch c ... -> (batch ...) c")
y = rearrange(y, "batch ... -> (batch ...)")
else:
output = rearrange(output, "batch ... c -> (batch ...) c")
y = rearrange(y, "batch ... c -> (batch ...) c")

if self._fisher_type == "type-2":
# Compute per-sample Hessian square root, then concatenate over samples.
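For reference, a standalone sketch of the flattening applied a few lines above: CrossEntropyLoss expects the class dimension right after the batch dimension, while the other supported losses keep it last; for 2d outputs both rearrange patterns are no-ops, so the former output.ndim > 2 guard was redundant.

```python
from einops import rearrange
from torch import rand, randint

# CrossEntropyLoss convention: (batch, classes, extra dims...)
output_ce = rand(2, 3, 7)     # (batch, classes, sequence)
y_ce = randint(0, 3, (2, 7))  # (batch, sequence)
output_ce = rearrange(output_ce, "batch c ... -> (batch ...) c")  # (14, 3)
y_ce = rearrange(y_ce, "batch ... -> (batch ...)")                # (14,)

# Other losses (e.g. MSELoss): (batch, extra dims..., features)
output_mse = rand(2, 7, 3)
y_mse = rand(2, 7, 3)
output_mse = rearrange(output_mse, "batch ... c -> (batch ...) c")  # (14, 3)
y_mse = rearrange(y_mse, "batch ... c -> (batch ...) c")            # (14, 3)
```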
@@ -410,12 +417,16 @@ def draw_label(self, output: Tensor) -> Tensor:
together with ``output``.
Raises:
ValueError: If the output is not 2d.
NotImplementedError: If the loss function is not supported.
"""
if output.ndim != 2:
raise ValueError("Only a 2d output is supported.")

if isinstance(self._loss_func, MSELoss):
std = {
"sum": sqrt(1.0 / 2.0),
"mean": sqrt(output.shape[1:].numel() / 2.0),
"mean": sqrt(output.shape[1] / 2.0),
}[self._loss_func.reduction]
perturbation = std * randn(
output.shape,
@@ -426,7 +437,6 @@ def draw_label(self, output: Tensor) -> Tensor:
return output.clone().detach() + perturbation

elif isinstance(self._loss_func, CrossEntropyLoss):
# each row contains a vector describing a categorical
probs = output.softmax(dim=1)
labels = probs.multinomial(
num_samples=1, generator=self._generator
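A condensed sketch of how draw_label produces the 'would-be' labels used for the MC Fisher (generator, device, and dtype handling omitted; the MSE noise scale mirrors the dictionary above, which is valid because the output is now guaranteed to be 2d):

```python
from math import sqrt
from torch import rand, randn

output = rand(4, 3)  # 2d model output: (batch, features/classes)

# MSELoss: perturb the prediction with Gaussian noise whose std matches the loss scaling
std = {"sum": sqrt(1.0 / 2.0), "mean": sqrt(output.shape[1] / 2.0)}["mean"]
y_mse = output.clone().detach() + std * randn(output.shape)

# CrossEntropyLoss: sample class labels from the model's softmax distribution
probs = output.softmax(dim=1)
y_ce = probs.multinomial(num_samples=1).squeeze(1)  # (batch,)
```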
@@ -471,13 +481,15 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
grad_output: The gradient w.r.t. the output.
"""
g = grad_output.data.detach()
batch_size = g.shape[0]
if isinstance(module, Conv2d):
g = rearrange(g, "batch c o1 o2 -> batch o1 o2 c")
num_loss_terms = g.shape[0] # batch_size
sequence_length = g.shape[1:-1].numel()
if self._loss_average == "batch+sequence":
# Number of all loss terms = batch_size * sequence_length
num_loss_terms *= sequence_length
num_loss_terms = {
None: batch_size,
"batch": batch_size,
"batch+sequence": batch_size * sequence_length,
}[self._loss_average]

if self._kfac_approx == "expand":
# KFAC-expand approximation
@@ -536,7 +548,7 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor
# KFAC-expand approximation
scale = x.shape[1:-1].numel() # sequence_length
x = rearrange(x, "batch ... d_in -> (batch ...) d_in")
elif self._kfac_approx == "reduce":
else:
# KFAC-reduce approximation
scale = 1.0 # since we use a mean reduction
x = reduce(x, "batch ... d_in -> batch d_in", "mean")
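Schematic illustration of the two input treatments in this hook (scaling by the sequence length, bias augmentation, and accumulation across batches are omitted):

```python
from einops import rearrange, reduce
from torch import rand

x = rand(2, 7, 5)  # (batch, weight-sharing/sequence dim, d_in)

# KFAC-expand: every weight-sharing location contributes its own row
x_expand = rearrange(x, "batch ... d_in -> (batch ...) d_in")  # (14, 5)

# KFAC-reduce: average over the weight-sharing dimensions first
x_reduce = reduce(x, "batch ... d_in -> batch d_in", "mean")   # (2, 5)

# either way, the input factor is accumulated as an outer product
A_expand = x_expand.T @ x_expand
A_reduce = x_reduce.T @ x_reduce
```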
12 changes: 4 additions & 8 deletions test/kfac_cases.py
@@ -60,17 +60,13 @@
{
"model_func": lambda: WeightShareModel(Linear(5, 4), Linear(4, 3)),
"data": lambda: {
# exactness with expand requires that the product of all
# weight-sharing dimension sizes is the same in all batches
# (e.g. 32 = 4 * 8)
"expand": [
(rand(2, 32, 5), regression_targets((2, 32, 3))),
(rand(2, 4, 8, 5), regression_targets((2, 4, 8, 3))),
(rand(7, 4, 8, 5), regression_targets((7, 4, 8, 3))),
],
# for reduce it works with arbitrary weight-sharing dims per batch
"reduce": [
(rand(1, 4, 5), regression_targets((1, 3))),
(rand(8, 4, 8, 5), regression_targets((8, 3))),
(rand(1, 4, 8, 5), regression_targets((1, 3))),
(rand(7, 4, 8, 5), regression_targets((7, 3))),
],
},
"seed": 0,
@@ -85,7 +81,7 @@
],
"reduce": [
(rand(1, 3, 16, 16), regression_targets((1, 2))),
(rand(8, 3, 32, 32), regression_targets((8, 2))),
(rand(8, 3, 16, 16), regression_targets((8, 2))),
],
},
"seed": 0,
58 changes: 45 additions & 13 deletions test/test_kfac.py
@@ -3,14 +3,14 @@
from test.cases import DEVICES, DEVICES_IDS
from test.utils import (
Conv2dModel,
Rearrange,
WeightShareModel,
classification_targets,
ggn_block_diagonal,
regression_targets,
)
from typing import Dict, Iterable, List, Tuple, Union

from einops.layers.torch import Rearrange
from numpy import eye
from pytest import mark
from scipy.linalg import block_diag
@@ -59,6 +59,7 @@ def test_kfac_type2(
"""
assert exclude in [None, "weight", "bias"]
model, loss_func, params, data = kfac_exact_case
loss_average = None if loss_func.reduction == "sum" else "batch"

if exclude is not None:
names = {p.data_ptr(): name for name, p in model.named_parameters()}
@@ -81,6 +82,7 @@
params,
data,
fisher_type="type-2",
loss_average=loss_average,
separate_weight_and_bias=separate_weight_and_bias,
)
kfac_mat = kfac @ eye(kfac.shape[1])
@@ -102,7 +104,7 @@
@mark.parametrize("shuffle", [False, True], ids=["", "shuffled"])
def test_kfac_type2_weight_sharing(
kfac_weight_sharing_exact_case: Tuple[
WeightShareModel,
Union[WeightShareModel, Conv2dModel],
MSELoss,
List[Parameter],
Dict[str, Iterable[Tuple[Tensor, Tensor]]],
@@ -190,13 +192,16 @@ def test_kfac_mc(
shuffle: Whether to shuffle the parameters before computing the KFAC matrix.
"""
model, loss_func, params, data = kfac_exact_case
loss_average = None if loss_func.reduction == "sum" else "batch"

if shuffle:
permutation = randperm(len(params))
params = [params[i] for i in permutation]

ggn = ggn_block_diagonal(model, loss_func, params, data)
kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000)
kfac = KFACLinearOperator(
model, loss_func, params, data, mc_samples=2_000, loss_average=loss_average
)

kfac_mat = kfac @ eye(kfac.shape[1])

@@ -212,9 +217,12 @@ def test_kfac_one_datum(
]
):
model, loss_func, params, data = kfac_exact_one_datum_case
loss_average = None if loss_func.reduction == "sum" else "batch"

ggn = ggn_block_diagonal(model, loss_func, params, data)
kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="type-2")
kfac = KFACLinearOperator(
model, loss_func, params, data, fisher_type="type-2", loss_average=loss_average
)
kfac_mat = kfac @ eye(kfac.shape[1])

report_nonclose(ggn, kfac_mat)
@@ -226,9 +234,12 @@ def test_kfac_mc_one_datum(
]
):
model, loss_func, params, data = kfac_exact_one_datum_case
ggn = ggn_block_diagonal(model, loss_func, params, data)
loss_average = None if loss_func.reduction == "sum" else "batch"

kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=10_000)
ggn = ggn_block_diagonal(model, loss_func, params, data)
kfac = KFACLinearOperator(
model, loss_func, params, data, mc_samples=10_000, loss_average=loss_average
)
kfac_mat = kfac @ eye(kfac.shape[1])

atol = {"sum": 1e-3, "mean": 1e-3}[loss_func.reduction]
@@ -243,14 +254,22 @@ def test_kfac_ef_one_datum(
]
):
model, loss_func, params, data = kfac_exact_one_datum_case
loss_average = None if loss_func.reduction == "sum" else "batch"

ef_blocks = [] # list of per-parameter EFs
for param in params:
ef = EFLinearOperator(model, loss_func, [param], data)
ef_blocks.append(ef @ eye(ef.shape[1]))
ef = block_diag(*ef_blocks)

kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="empirical")
kfac = KFACLinearOperator(
model,
loss_func,
params,
data,
fisher_type="empirical",
loss_average=loss_average,
)
kfac_mat = kfac @ eye(kfac.shape[1])

report_nonclose(ef, kfac_mat)
@@ -314,16 +333,17 @@ def test_multi_dim_output(
manual_seed(0)
# set up loss function, data, and model
loss_func = loss(reduction=reduction).to(dev)
loss_average = None if reduction == "sum" else "batch+sequence"
if isinstance(loss_func, MSELoss):
data = [
(rand(2, 4, 5), regression_targets((2, 4, 3))),
(rand(2, 7, 5, 5), regression_targets((2, 7, 5, 3))),
(rand(4, 7, 5, 5), regression_targets((4, 7, 5, 3))),
]
manual_seed(711)
model = Sequential(Linear(5, 4), Linear(4, 3)).to(dev)
else:
data = [
(rand(2, 4, 5), classification_targets((2, 4), 3)),
(rand(2, 7, 5, 5), classification_targets((2, 7, 5), 3)),
(rand(4, 7, 5, 5), classification_targets((4, 7, 5), 3)),
]
manual_seed(711)
@@ -334,12 +354,19 @@
Rearrange("batch ... c -> batch c ..."),
).to(dev)

# KFAC for deep linear network with 3d/4d input and output
# KFAC for deep linear network with 4d input and output
params = list(model.parameters())
kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type=fisher_type)
kfac = KFACLinearOperator(
model,
loss_func,
params,
data,
fisher_type=fisher_type,
loss_average=loss_average,
)
kfac_mat = kfac @ eye(kfac.shape[1])

# KFAC for deep linear network with 3d/4d input and equivalent 2d output
# KFAC for deep linear network with 4d input and equivalent 2d output
manual_seed(711)
model_flat = Sequential(
Linear(5, 4),
@@ -354,7 +381,12 @@
for x, y in data
]
kfac_flat = KFACLinearOperator(
model_flat, loss_func, params_flat, data_flat, fisher_type=fisher_type
model_flat,
loss_func,
params_flat,
data_flat,
fisher_type=fisher_type,
loss_average=loss_average,
)
kfac_flat_mat = kfac_flat @ eye(kfac_flat.shape[1])
