Add support for KFAC-expand and reduce for linear modules
runame committed Jan 9, 2024
1 parent daf8393 commit 2a74ee0
Showing 1 changed file with 52 additions and 14 deletions.
66 changes: 52 additions & 14 deletions curvlinops/kfac.py
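The new `kfac_approx` argument is selected at construction time. Below is a minimal, hypothetical usage sketch; the leading `model`, `loss_func`, `params`, and `data` arguments and their order are assumptions based on the surrounding library and are not part of this diff.

```python
# Hypothetical usage sketch -- the positional arguments (model, loss_func,
# params, data) and their order are assumed, not shown in this diff.
import torch
from torch.nn import Linear, MSELoss

from curvlinops.kfac import KFACLinearOperator

model = Linear(5, 3)  # toy model applied to sequence-valued inputs (weight sharing)
loss_func = MSELoss(reduction="mean")
params = [p for p in model.parameters() if p.requires_grad]
data = [(torch.randn(8, 7, 5), torch.randn(8, 7, 3))]  # batch of 8, sequence length 7

kfac = KFACLinearOperator(
    model,
    loss_func,
    params,
    data,
    fisher_type="mc",
    mc_samples=1,
    kfac_approx="reduce",  # new argument; "expand" is the default
    loss_average="batch+sequence",
)
```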
@@ -13,11 +13,12 @@

from __future__ import annotations

import warnings
from functools import partial
from math import sqrt
from typing import Dict, Iterable, List, Set, Tuple, Union

from einops import rearrange, reduce
from numpy import ndarray
from torch import Generator, Tensor, cat, einsum, randn, stack
from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter
@@ -76,6 +77,7 @@ class KFACLinearOperator(_LinearOperator):
"batch",
"batch+sequence",
)
_SUPPORTED_KFAC_APPROX: Tuple[str, ...] = ("expand", "reduce")

def __init__(
self,
@@ -89,6 +91,7 @@ def __init__(
seed: int = 2147483647,
fisher_type: str = "mc",
mc_samples: int = 1,
kfac_approx: str = "expand",
loss_average: Union[None, str] = "batch",
separate_weight_and_bias: bool = True,
):
@@ -134,6 +137,13 @@ def __init__(
mc_samples: The number of Monte-Carlo samples to use per data point.
Has to be set to ``1`` when ``fisher_type != 'mc'``.
Defaults to ``1``.
kfac_approx: A string specifying the KFAC approximation that should
be used for linear weight-sharing layers, e.g. ``Conv2d`` modules
or ``Linear`` modules that process matrix- or higher-dimensional
features.
Possible values are ``'expand'`` and ``'reduce'``.
See [Eschenhagen et al., 2023](https://arxiv.org/abs/2311.00636)
for an explanation of the two approximations.
Defaults to ``'expand'``.
loss_average: Whether the loss function is a mean over per-sample
losses and if yes, over which dimensions the mean is taken.
If `"batch"`, the loss function is a mean over as many terms as
@@ -148,6 +158,10 @@ def __init__(
Raises:
ValueError: If the loss function is not supported.
ValueError: If the loss average is not supported.
ValueError: If the loss average is ``None`` and the loss function's
reduction is not ``'sum'``.
ValueError: If ``fisher_type != 'mc'`` and ``mc_samples != 1``.
NotImplementedError: If a parameter is in an unsupported layer.
"""
if not isinstance(loss_func, self._SUPPORTED_LOSSES):
@@ -164,11 +178,24 @@ def __init__(
f"Invalid loss_average: {loss_average}. "
f"Must be 'batch' or 'batch+sequence' if loss_func.reduction != 'sum'."
)
if loss_func.reduction == "sum" and loss_average is not None:
# the reduction used in the loss function overrides loss_average
warnings.warn(
f"Loss function uses reduction='sum', but loss_average={loss_average}."
" loss_average is set to None.",
stacklevel=2,
)
loss_average = None
if fisher_type != "mc" and mc_samples != 1:
raise ValueError(
f"Invalid mc_samples: {mc_samples}. "
"Only mc_samples=1 is supported for fisher_type != 'mc'."
)
if kfac_approx not in self._SUPPORTED_KFAC_APPROX:
raise ValueError(
f"Invalid kfac_approx: {kfac_approx}. "
f"Supported: {self._SUPPORTED_KFAC_APPROX}."
)

self.param_ids = [p.data_ptr() for p in params]
# mapping from tuples of parameter data pointers in a module to its name
@@ -192,6 +219,7 @@ def __init__(
self._separate_weight_and_bias = separate_weight_and_bias
self._fisher_type = fisher_type
self._mc_samples = mc_samples
self._kfac_approx = kfac_approx
self._loss_average = loss_average
self._input_covariances: Dict[str, Tensor] = {}
self._gradient_covariances: Dict[str, Tensor] = {}
@@ -443,22 +471,27 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
NotImplementedError: If the layer is not supported.
"""
g = grad_output.data.detach()
sequence_length = g.shape[1:-1].numel()
if self._loss_average is not None:
    num_loss_terms = g.shape[0]  # batch_size
    if self._loss_average == "batch+sequence":
        # Number of all loss terms = batch_size * sequence_length
        num_loss_terms *= sequence_length

if self._kfac_approx == "expand":
    # KFAC-expand approximation
    g = rearrange(g, "batch ... d_out -> (batch ...) d_out")
else:
    # KFAC-reduce approximation
    g = reduce(g, "batch ... d_out -> batch d_out", "sum")

if isinstance(module, Linear):
# self._mc_samples will be 1 if fisher_type != "mc"
correction = {
    None: 1.0 / self._mc_samples,
    "batch": num_loss_terms**2 / (self._N_data * self._mc_samples),
    "batch+sequence": num_loss_terms**2
    / (self._N_data * self._mc_samples * sequence_length),
}[self._loss_average]
covariance = einsum("bi,bj->ij", g, g).mul_(correction)
else:
# TODO Support convolutions
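To make the branch above concrete, here is a standalone sketch of the gradient-covariance factor for a `Linear` layer under both approximations, including the `'batch'` correction factor from the dictionary above. All shapes, `N_data`, and `mc_samples` are made up for illustration; this is not the class's internal code.

```python
import torch
from einops import rearrange, reduce
from torch import einsum

N_data, mc_samples = 100, 1  # assumed values for illustration
batch_size, seq_len, d_out = 4, 7, 3
g = torch.randn(batch_size, seq_len, d_out)  # output gradients of a Linear layer

num_loss_terms = batch_size  # loss_average == "batch"
correction = num_loss_terms**2 / (N_data * mc_samples)

# KFAC-expand: one outer-product contribution per (batch, sequence) position.
g_e = rearrange(g, "batch ... d_out -> (batch ...) d_out")
B_expand = einsum("bi,bj->ij", g_e, g_e) * correction  # (d_out, d_out)

# KFAC-reduce: sum the gradients over the shared dimension first.
g_r = reduce(g, "batch ... d_out -> batch d_out", "sum")
B_reduce = einsum("bi,bj->ij", g_r, g_r) * correction  # (d_out, d_out)
```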
@@ -490,18 +523,23 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor
if len(inputs) != 1:
raise ValueError("Modules with multiple inputs are not supported.")
x = inputs[0].data.detach()

if self._kfac_approx == "expand":
    # KFAC-expand approximation
    scale = x.shape[1:-1].numel()  # sequence_length
    x = rearrange(x, "batch ... d_in -> (batch ...) d_in")
else:
    # KFAC-reduce approximation
    scale = 1.0  # since we use a mean reduction
    x = reduce(x, "batch ... d_in -> batch d_in", "mean")

if isinstance(module, Linear):
if (
self.in_params(module.weight, module.bias)
and not self._separate_weight_and_bias
):
x = cat([x, x.new_ones(x.shape[0], 1)], dim=1)

covariance = einsum("bi,bj->ij", x, x).div_(self._N_data * sequence_length)
covariance = einsum("bi,bj->ij", x, x).div_(self._N_data * scale)
else:
# TODO Support convolutions
raise NotImplementedError(f"Layer of type {type(module)} is unsupported.")
