[DOC] Describe KFAC and its limitations
f-dangel committed Oct 27, 2023
1 parent 8b377a7 commit 87b044c
Showing 4 changed files with 123 additions and 27 deletions.
2 changes: 2 additions & 0 deletions curvlinops/__init__.py
@@ -7,6 +7,7 @@
from curvlinops.hessian import HessianLinearOperator
from curvlinops.inverse import CGInverseLinearOperator, NeumannInverseLinearOperator
from curvlinops.jacobian import JacobianLinearOperator, TransposedJacobianLinearOperator
from curvlinops.kfac import KFACLinearOperator
from curvlinops.papyan2020traces.spectrum import (
LanczosApproximateLogSpectrumCached,
LanczosApproximateSpectrumCached,
@@ -22,6 +23,7 @@
"GGNLinearOperator",
"EFLinearOperator",
"FisherMCLinearOperator",
"KFACLinearOperator",
"JacobianLinearOperator",
"TransposedJacobianLinearOperator",
"CGInverseLinearOperator",
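The new entry makes the operator importable from the package root, e.g. (one-line usage sketch, assuming the package is installed):

from curvlinops import KFACLinearOperator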
141 changes: 116 additions & 25 deletions curvlinops/kfac.py
@@ -18,21 +18,61 @@

from numpy import ndarray
from torch import Generator, Tensor, einsum, randn
from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter
from torch.nn import Linear, Module, MSELoss, Parameter
from torch.utils.hooks import RemovableHandle

from curvlinops._base import _LinearOperator


class KFACLinearOperator(_LinearOperator):
"""Linear operator to multiply with Fisher/GGN's KFAC approximation."""
r"""Linear operator to multiply with the Fisher/GGN's KFAC approximation.
_SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss)
KFAC approximates the per-layer Fisher/GGN with a Kronecker product:
Consider a weight matrix :math:`\mathbf{W}` and a bias vector :math:`\mathbf{b}`
in a single layer. The layer's Fisher :math:`\mathbf{F}(\mathbf{\theta})` for
.. math::
\mathbf{\theta}
=
\begin{pmatrix}
\mathrm{vec}(\mathbf{W}) \\ \mathbf{b}
\end{pmatrix}
where :math:`\mathrm{vec}` denotes column-stacking is approximated as
.. math::
\mathbf{F}(\mathbf{\theta})
\approx
\mathbf{A}_{(\text{KFAC})} \otimes \mathbf{B}_{(\text{KFAC})}
(see :class:`curvlinops.FisherMCLinearOperator` for the Fisher's definition).
Loosely speaking, the first Kronecker factor is the un-centered covariance of the
inputs to a layer. The second Kronecker factor is the un-centered covariance of
'would-be' gradients w.r.t. the layer's output. Those 'would-be' gradients result
from sampling labels from the model's distribution and computing their gradients.
The basic version of KFAC for MLPs was introduced in
- Martens, J., & Grosse, R. (2015). Optimizing neural networks with
Kronecker-factored approximate curvature. ICML.
and later generalized to convolutions in
- Grosse, R., & Martens, J. (2016). A kronecker-factored approximate Fisher
matrix for convolution layers. ICML.
Attributes:
_SUPPORTED_LOSSES: Tuple of supported loss functions.
_SUPPORTED_MODULES: Tuple of supported layers.
"""

_SUPPORTED_LOSSES = (MSELoss,)
_SUPPORTED_MODULES = (Linear,)
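
To make the Kronecker structure concrete, here is a standalone sketch (not part of this commit; all names and shapes are illustrative) that multiplies with A ⊗ B without materializing the Kronecker product, using the column-stacking identity (A ⊗ B) vec(V) = vec(B V Aᵀ):

import torch

d_in, d_out = 3, 2
A = torch.randn(d_in, d_in)    # first factor: input covariance (d_in x d_in)
B = torch.randn(d_out, d_out)  # second factor: grad-output covariance
V = torch.randn(d_out, d_in)   # matrix view of the vector to be multiplied

v = V.T.reshape(-1)                      # column-stacking vec(V)
naive = torch.kron(A, B) @ v             # materializes a (d_in*d_out)^2 matrix
efficient = (B @ V @ A.T).T.reshape(-1)  # never forms the Kronecker product

print(torch.allclose(naive, efficient, atol=1e-6))  # True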

def __init__(
self,
model_func: Module,
loss_func: Union[MSELoss, CrossEntropyLoss],
loss_func: MSELoss,
params: List[Parameter],
data: Iterable[Tuple[Tensor, Tensor]],
progressbar: bool = False,
@@ -41,20 +81,71 @@ def __init__(
seed: int = 2147483647,
mc_samples: int = 1,
):
"""Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
Warning:
If the model's parameters change, e.g. during training, you need to
create a fresh instance of this object. This is because, for performance
reasons, the Kronecker factors are computed once and cached during the
first matrix-vector product. They will thus become outdated if the model
changes.
Warning:
This is an early proto-type with many limitations:
- Parameters must be in the same order as the model's parameters.
- Only linear layers with bias are supported.
- Weights and biases are treated separately.
- No weight sharing is supported.
- Only the Monte-Carlo sampled version is supported.
- Only the ``'expand'`` setting is supported.
Args:
model_func: The neural network. Must consist of modules.
loss_func: The loss function.
params: The parameters defining the Fisher/GGN that will be approximated
through KFAC.
data: A data loader containing the data of the Fisher/GGN.
progressbar: Whether to show a progress bar when computing the Kronecker
factors. Defaults to ``False``.
check_deterministic: Whether to check that the linear operator is
deterministic. Defaults to ``True``.
shape: The shape of the linear operator. If ``None``, it will be inferred
from the parameters. Defaults to ``None``.
seed: The seed for the random number generator used to draw labels
from the model's predictive distribution. Defaults to ``2147483647``.
mc_samples: The number of Monte-Carlo samples to use per data point.
Defaults to ``1``.
"""
if not isinstance(loss_func, self._SUPPORTED_LOSSES):
raise ValueError(
f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
)

# TODO Check for only linear layers
self.hooked_modules: List[str] = []
idx = 0
for mod in model_func.modules():
if len(list(mod.modules())) == 1 and list(mod.parameters()):
assert isinstance(mod, Linear)
assert mod.bias is not None
assert params[idx].data_ptr() == mod.weight.data_ptr()
assert params[idx + 1].data_ptr() == mod.bias.data_ptr()
for name, mod in model_func.named_modules():
if isinstance(mod, self._SUPPORTED_MODULES):
# TODO Support bias-free layers
if mod.bias is None:
raise NotImplementedError(
"Bias-free linear layers are not yet supported."
)
# TODO Support arbitrary orders and sub-sets of parameters
if (
params[idx].data_ptr() != mod.weight.data_ptr()
or params[idx + 1].data_ptr() != mod.bias.data_ptr()
):
raise NotImplementedError(
"KFAC parameters must be in the same order as model parameters "
+ "for now."
)
idx += 2
self.hooked_modules.append(name)
if idx != len(params):
raise NotImplementedError(
"Could not identify all parameters with supported layers."
)
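
The data_ptr() comparison above checks tensor identity (the passed parameter is the module's weight/bias, not merely equal in value). A standalone illustration:

from torch.nn import Linear

layer = Linear(3, 2)
params = [layer.weight, layer.bias]   # same order as the model
swapped = [layer.bias, layer.weight]  # violates the order requirement

print(params[0].data_ptr() == layer.weight.data_ptr())   # True
print(swapped[0].data_ptr() == layer.weight.data_ptr())  # False -> rejected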

self._seed = seed
self._generator: Union[None, Generator] = None
@@ -113,37 +204,37 @@ def _adjoint(self) -> KFACLinearOperator:
return self

def _compute_kfac(self):
# install forward and backward hooks on layers
"""Compute and cache KFAC's Kronecker factors for future ``matvec``s."""
# install forward and backward hooks
hook_handles: List[RemovableHandle] = []

modules = []
for mod in self._model_func.modules():
if len(list(mod.modules())) == 1 and list(mod.parameters()):
assert isinstance(mod, Linear)
modules.append(mod)
hook_handles.extend(
mod.register_forward_pre_hook(self._hook_accumulate_input_covariance)
for mod in modules
self._model_func.get_submodule(mod).register_forward_pre_hook(
self._hook_accumulate_input_covariance
)
for mod in self.hooked_modules
)
hook_handles.extend(
mod.register_full_backward_hook(self._hook_accumulate_gradient_covariance)
for mod in modules
self._model_func.get_submodule(mod).register_full_backward_hook(
self._hook_accumulate_gradient_covariance
)
for mod in self.hooked_modules
)

# loop over data set
# loop over data set, computing the Kronecker factors
if self._generator is None:
self._generator = Generator(device=self._device)
self._generator.manual_seed(self._seed)

for X, _ in self._loop_over_data(desc="Computing KFAC matrices"):
for X, _ in self._loop_over_data(desc="KFAC matrices"):
output = self._model_func(X)

for mc in range(self._mc_samples):
y_sampled = self.draw_label(output)
loss = self._loss_func(output, y_sampled)
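# keep the graph alive for all but the last Monte-Carlo sample so
# that backward() can run once per sample on the same forward pass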
loss.backward(retain_graph=mc != self._mc_samples - 1)

# remove hooks
# clean up
self._model_func.zero_grad()
for handle in hook_handles:
handle.remove()

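The hook mechanics above can be illustrated with a standalone sketch (not the library's code; names and shapes are made up) that accumulates the first Kronecker factor, the un-centered input covariance, via a forward pre-hook:

import torch
from torch.nn import Linear

inputs_cov = {}

def accumulate_input_covariance(module, inputs):
    (x,) = inputs  # forward pre-hooks receive the positional inputs
    inputs_cov[module] = inputs_cov.get(module, 0) + x.T @ x

layer = Linear(3, 2)
handle = layer.register_forward_pre_hook(accumulate_input_covariance)
layer(torch.randn(8, 3))  # triggers the hook
handle.remove()  # clean up, mirroring _compute_kfac

print(inputs_cov[layer].shape)  # torch.Size([3, 3])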
3 changes: 3 additions & 0 deletions docs/rtd/linops.rst
@@ -20,6 +20,9 @@ Fisher (approximate)
.. autoclass:: curvlinops.FisherMCLinearOperator
:members: __init__

.. autoclass:: curvlinops.KFACLinearOperator
:members: __init__

Uncentered gradient covariance (empirical Fisher)
-------------------------------------------------

4 changes: 2 additions & 2 deletions test/conftest.py
@@ -6,8 +6,8 @@

from numpy import random
from pytest import fixture
from torch import Module, Tensor, manual_seed
from torch.nn import MSELoss
from torch import Tensor, manual_seed
from torch.nn import Module, MSELoss


def initialize_case(
