[ADD] Support for KFAC with type-2 Fisher (#56)
* Add test for MSELoss type-2 KFAC

* Implement type-2 KFAC for MSELoss

* Add docstrings

* Fix black

* Fix docstring

* Add test for CrossEntropyLoss type-2 KFAC

* Implement type-2 KFAC for CrossEntropyLoss

* Fix auto-merge issue

* Fix comment

* [REF] Refactor type-2 using Hessian matrix square root

* [DEL] Remove unused imports

* [FIX] Darglint

* [FIX] Function name in docs

* [REF] Improve function name

* [REF] Rename `num_classes` into `output_dim`

---------

Co-authored-by: Felix Dangel <[email protected]>
runame and f-dangel authored Nov 9, 2023
1 parent 10180ce commit 1141c50
Showing 6 changed files with 237 additions and 54 deletions.
96 changes: 68 additions & 28 deletions curvlinops/kfac.py
@@ -19,11 +19,12 @@

 from einops import rearrange
 from numpy import ndarray
-from torch import Generator, Tensor, cat, einsum, randn
+from torch import Generator, Tensor, cat, einsum, randn, stack
 from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter
 from torch.utils.hooks import RemovableHandle

 from curvlinops._base import _LinearOperator
+from curvlinops.kfac_utils import loss_hessian_matrix_sqrt


class KFACLinearOperator(_LinearOperator):
Expand Down Expand Up @@ -125,7 +126,7 @@ def __init__(
used which corresponds to the uncentered gradient covariance, or
the empirical Fisher. Defaults to ``'mc'``.
mc_samples: The number of Monte-Carlo samples to use per data point.
Will be ignored when ``fisher_type`` is not ``'mc'``.
Has to be set to ``1`` when ``fisher_type != 'mc'``.
Defaults to ``1``.
separate_weight_and_bias: Whether to treat weights and biases separately.
Defaults to ``True``.
@@ -138,6 +139,11 @@
raise ValueError(
f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
)
if fisher_type != "mc" and mc_samples != 1:
raise ValueError(
f"Invalid mc_samples: {mc_samples}. "
"Only mc_samples=1 is supported for fisher_type != 'mc'."
)

self.param_ids = [p.data_ptr() for p in params]
# mapping from tuples of parameter data pointers in a module to its name
@@ -231,13 +237,7 @@ def _adjoint(self) -> KFACLinearOperator:
         return self

     def _compute_kfac(self):
-        """Compute and cache KFAC's Kronecker factors for future ``matvec``s.
-
-        Raises:
-            NotImplementedError: If ``fisher_type == 'type-2'``.
-            ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
-                ``'empirical'``.
-        """
+        """Compute and cache KFAC's Kronecker factors for future ``matvec``s."""
# install forward and backward hooks
hook_handles: List[RemovableHandle] = []

@@ -266,31 +266,70 @@ def _compute_kfac(self):

         for X, y in self._loop_over_data(desc="KFAC matrices"):
             output = self._model_func(X)
-
-            if self._fisher_type == "type-2":
-                raise NotImplementedError(
-                    "Using the exact expectation for computing the KFAC "
-                    "approximation of the Fisher is not yet supported."
-                )
-            elif self._fisher_type == "mc":
-                for mc in range(self._mc_samples):
-                    y_sampled = self.draw_label(output)
-                    loss = self._loss_func(output, y_sampled)
-                    loss.backward(retain_graph=mc != self._mc_samples - 1)
-            elif self._fisher_type == "empirical":
-                loss = self._loss_func(output, y)
-                loss.backward()
-            else:
-                raise ValueError(
-                    f"Invalid fisher_type: {self._fisher_type}. "
-                    + "Supported: 'type-2', 'mc', 'empirical'."
-                )
+            self._compute_loss_and_backward(output, y)

         # clean up
         self._model_func.zero_grad()
         for handle in hook_handles:
             handle.remove()

+    def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
+        r"""Compute the loss and the backward pass(es) required for KFAC.
+
+        Args:
+            output: The model's prediction
+                :math:`\{f_\mathbf{\theta}(\mathbf{x}_n)\}_{n=1}^N`.
+            y: The labels :math:`\{\mathbf{y}_n\}_{n=1}^N`.
+
+        Raises:
+            ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
+                ``'empirical'``.
+            NotImplementedError: If ``fisher_type`` is ``'type-2'`` and the
+                output is not 2d.
+        """
+        if self._fisher_type == "type-2":
+            if output.ndim != 2:
+                raise NotImplementedError(
+                    "Type-2 Fisher not implemented for non-2d output."
+                )
+            # Compute per-sample Hessian square roots, then stack them over the
+            # batch. Result has shape `(batch_size, num_classes, num_classes)`
+            hessian_sqrts = stack(
+                [
+                    loss_hessian_matrix_sqrt(out.detach(), self._loss_func)
+                    for out in output.split(1)
+                ]
+            )
+
+            # Fix scaling caused by the batch dimension
+            batch_size = output.shape[0]
+            reduction = self._loss_func.reduction
+            scale = {"sum": 1.0, "mean": 1.0 / batch_size}[reduction]
+            hessian_sqrts.mul_(scale)
+
+            # For each column `c` of the matrix square root we need to
+            # backpropagate, but we can do this for all samples in parallel
+            num_cols = hessian_sqrts.shape[-1]
+            for c in range(num_cols):
+                batched_column = hessian_sqrts[:, :, c]
+                (output * batched_column).sum().backward(
+                    retain_graph=c < num_cols - 1
+                )
+
+        elif self._fisher_type == "mc":
+            for mc in range(self._mc_samples):
+                y_sampled = self.draw_label(output)
+                loss = self._loss_func(output, y_sampled)
+                loss.backward(retain_graph=mc != self._mc_samples - 1)
+
+        elif self._fisher_type == "empirical":
+            loss = self._loss_func(output, y)
+            loss.backward()
+
+        else:
+            raise ValueError(
+                f"Invalid fisher_type: {self._fisher_type}. "
+                + "Supported: 'type-2', 'mc', 'empirical'."
+            )
+
     def draw_label(self, output: Tensor) -> Tensor:
         r"""Draw a sample from the model's predictive distribution.
@@ -393,6 +432,7 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
)

batch_size = g.shape[0]
# self._mc_samples will be 1 if fisher_type != "mc"
correction = {
"sum": 1.0 / self._mc_samples,
"mean": batch_size**2 / (self._N_data * self._mc_samples),
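A note on the type-2 branch above: `(output * batched_column).sum().backward(...)` is a batched vector-Jacobian product. Below is a minimal sketch of why the sum trick backpropagates a vector — my own toy setup with a single linear layer, not part of the commit:

```python
# Sketch: for one sample, ``(output * v).sum().backward()`` accumulates the
# vector-Jacobian product v^T (d output / d theta) into ``.grad`` -- the same
# mechanism the type-2 branch uses with each column v = S[:, c] of the loss
# Hessian's matrix square root.
from torch import allclose, manual_seed, outer, rand
from torch.nn import Linear

manual_seed(0)
layer = Linear(4, 3, bias=False)  # toy layer, stands in for the model
x, v = rand(1, 4), rand(1, 3)

output = layer(x)
(output * v).sum().backward()

# For a linear layer y = W x, the VJP w.r.t. W is the outer product v x^T.
assert allclose(layer.weight.grad, outer(v.squeeze(), x.squeeze()))
```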
92 changes: 92 additions & 0 deletions curvlinops/kfac_utils.py
@@ -0,0 +1,92 @@
"""Utility functions related to KFAC."""

from math import sqrt
from typing import Union

from torch import Tensor, diag, einsum, eye
from torch.nn import CrossEntropyLoss, MSELoss


def loss_hessian_matrix_sqrt(
output_one_datum: Tensor, loss_func: Union[MSELoss, CrossEntropyLoss]
) -> Tensor:
r"""Compute the loss function's matrix square root for a sample's output.
Args:
output_one_datum: The model's prediction on a single datum. Has shape
``[1, C]`` where ``C`` is the number of classes (outputs of the neural
network).
loss_func: The loss function.
Returns:
The matrix square root
:math:`\mathbf{S}` of the Hessian. Has shape
``[C, C]`` and satisfies the relation
.. math::
\mathbf{S} \mathbf{S}^\top
=
\nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y})
\in \mathbb{R}^{C \times C}
where :math:`\mathbf{f} := f(\mathbf{x}) \in \mathbb{R}^C` is the model's
prediction on a single datum :math:`\mathbf{x}` and :math:`\mathbf{y}` is
the label.
Note:
For :class:`torch.nn.MSELoss` (with :math:`c = 1` for ``reduction='sum'``
and :math:`c = 1/C` for ``reduction='mean'``), we have:
.. math::
\ell(\mathbf{f}) &= c \sum_{i=1}^C (f_i - y_i)^2
\\
\nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) &= 2 c \mathbf{I}_C
\\
\mathbf{S} &= \sqrt{2 c} \mathbf{I}_C
Note:
For :class:`torch.nn.CrossEntropyLoss` (with :math:`c = 1` irrespective of the
reduction, :math:`\mathbf{p}:=\mathrm{softmax}(\mathbf{f}) \in \mathbb{R}^C`,
and the element-wise natural logarithm :math:`\log`) we have:
.. math::
\ell(\mathbf{f}, y) = - c \log(\mathbf{p})^\top \mathrm{onehot}(y)
\\
\nabla^2_{\mathbf{f}} \ell(\mathbf{f}, y)
=
c \left(
\mathrm{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^\top
\right)
\\
\mathbf{S} = \sqrt{c} \left(
\mathrm{diag}(\sqrt{\mathbf{p}}) - \sqrt{\mathbf{p}} \mathbf{p}^\top
\right)\,,
where the square root is applied element-wise. See for instance Example 5.1 of
`this thesis <https://d-nb.info/1280233206/34>`_ or equations (5) and (6) of
`this paper <https://arxiv.org/abs/1901.08244>`_.
Raises:
ValueError: If the batch size is not one, or the output is not 2d.
NotImplementedError: If the loss function is not supported.
"""
if output_one_datum.ndim != 2 or output_one_datum.shape[0] != 1:
raise ValueError(
f"Expected 'output_one_datum' to be 2d with shape [1, C], got "
f"{output_one_datum.shape}"
)
output = output_one_datum.squeeze(0)
output_dim = output.numel()

if isinstance(loss_func, MSELoss):
c = {"sum": 1.0, "mean": 1.0 / output_dim}[loss_func.reduction]
return eye(output_dim, device=output.device, dtype=output.dtype).mul_(
sqrt(2 * c)
)
elif isinstance(loss_func, CrossEntropyLoss):
c = 1.0
p = output_one_datum.softmax(dim=1).squeeze()
p_sqrt = p.sqrt()
return (diag(p_sqrt) - einsum("i,j->ij", p, p_sqrt)).mul_(sqrt(c))
else:
raise NotImplementedError(f"Loss function {loss_func} not supported.")
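A quick sanity check for `loss_hessian_matrix_sqrt` — my own sketch, not part of the commit: compare `S S^T` against an autograd Hessian of the loss with respect to the model's output (toy dimensions, assumed tolerance).

```python
# Sketch verifying S S^T equals the loss Hessian w.r.t. the model output
# for CrossEntropyLoss on one datum.
from torch import allclose, manual_seed, rand, randint
from torch.autograd.functional import hessian
from torch.nn import CrossEntropyLoss

from curvlinops.kfac_utils import loss_hessian_matrix_sqrt

manual_seed(0)
C = 4
output = rand(1, C)  # model prediction for a single datum
y = randint(0, C, (1,))
loss_func = CrossEntropyLoss()

S = loss_hessian_matrix_sqrt(output, loss_func)  # shape [C, C]
H = hessian(lambda f: loss_func(f.unsqueeze(0), y), output.squeeze(0))

assert allclose(S @ S.T, H, atol=1e-6)
```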
5 changes: 5 additions & 0 deletions docs/rtd/index.rst
@@ -40,3 +40,8 @@ Installation

linops
basic_usage/index

.. toctree::
:caption: Internals

internals
11 changes: 11 additions & 0 deletions docs/rtd/internals.rst
@@ -0,0 +1,11 @@
Internals
============

This section is intended for developers only. It documents internal functionality,
since rendered LaTeX is easier to read than source code.


KFAC-related
-------------

.. autofunction:: curvlinops.kfac_utils.loss_hessian_matrix_sqrt
14 changes: 0 additions & 14 deletions test/conftest.py
@@ -86,17 +86,3 @@ def kfac_expand_exact_one_datum_case(
     """
     case = request.param
     yield initialize_case(case)
-
-
-@fixture(params=KFAC_EXPAND_EXACT_ONE_DATUM_CASES)
-def kfac_ef_exact_one_datum_case(
-    request,
-) -> Tuple[Module, MSELoss, List[Tensor], Iterable[Tuple[Tensor, Tensor]],]:
-    """Prepare a test case with one datum for which KFAC with empirical gradients equals the EF.
-
-    Yields:
-        A neural network, the mean-squared error function, a list of parameters, and
-        a data set.
-    """
-    case = request.param
-    yield initialize_case(case)
73 changes: 61 additions & 12 deletions test/test_kfac.py
@@ -8,7 +8,15 @@
 from pytest import mark
 from scipy.linalg import block_diag
 from torch import Tensor, device, manual_seed, rand, randperm
-from torch.nn import Linear, Module, MSELoss, Parameter, ReLU, Sequential
+from torch.nn import (
+    CrossEntropyLoss,
+    Linear,
+    Module,
+    MSELoss,
+    Parameter,
+    ReLU,
+    Sequential,
+)

 from curvlinops.examples.utils import report_nonclose
 from curvlinops.gradient_moments import EFLinearOperator
@@ -22,7 +30,7 @@
     "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"]
 )
 @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"])
-def test_kfac(
+def test_kfac_type2(
     kfac_expand_exact_case: Tuple[
         Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
     ],
data,
separate_weight_and_bias=separate_weight_and_bias,
)

kfac = KFACLinearOperator(
model,
loss_func,
params,
data,
mc_samples=2_000,
fisher_type="type-2",
separate_weight_and_bias=separate_weight_and_bias,
)
kfac_mat = kfac @ eye(kfac.shape[1])

atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction]
rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction]

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)
report_nonclose(ggn, kfac_mat)

# Check that input covariances were not computed
if exclude == "weight":
assert len(kfac._input_covariances) == 0


@mark.parametrize("shuffle", [False, True], ids=["", "shuffled"])
def test_kfac_mc(
kfac_expand_exact_case: Tuple[
Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
],
shuffle: bool,
):
"""Test the KFAC implementation using MC samples against the exact GGN.
Args:
kfac_expand_exact_case: A fixture that returns a model, loss function, list of
parameters, and data.
shuffle: Whether to shuffle the parameters before computing the KFAC matrix.
"""
model, loss_func, params, data = kfac_expand_exact_case

if shuffle:
permutation = randperm(len(params))
params = [params[i] for i in permutation]

ggn = ggn_block_diagonal(model, loss_func, params, data)
kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000)

kfac_mat = kfac @ eye(kfac.shape[1])

atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction]
rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction]

report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)


def test_kfac_one_datum(
kfac_expand_exact_one_datum_case: Tuple[
Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
]
):
model, loss_func, params, data = kfac_expand_exact_one_datum_case

ggn = ggn_block_diagonal(model, loss_func, params, data)
kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="type-2")
kfac_mat = kfac @ eye(kfac.shape[1])

report_nonclose(ggn, kfac_mat)


def test_kfac_mc_one_datum(
kfac_expand_exact_one_datum_case: Tuple[
Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
]
):
model, loss_func, params, data = kfac_expand_exact_one_datum_case
@@ -98,11 +147,11 @@ def test_kfac_one_datum(


 def test_kfac_ef_one_datum(
-    kfac_ef_exact_one_datum_case: Tuple[
-        Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
+    kfac_expand_exact_one_datum_case: Tuple[
+        Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
     ]
 ):
-    model, loss_func, params, data = kfac_ef_exact_one_datum_case
+    model, loss_func, params, data = kfac_expand_exact_one_datum_case

     ef_blocks = []  # list of per-parameter EFs
     for param in params:
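For reference, a minimal usage sketch of the new option — my own, mirroring the tests above; it assumes `KFACLinearOperator` is importable from the package root, and the model, data, and dimensions are made up:

```python
from numpy import eye
from torch import manual_seed, rand, randint
from torch.nn import CrossEntropyLoss, Linear, ReLU, Sequential

from curvlinops import KFACLinearOperator

manual_seed(0)
N, D_in, C = 8, 5, 3  # toy dimensions
model = Sequential(Linear(D_in, 4), ReLU(), Linear(4, C))
loss_func = CrossEntropyLoss()
data = [(rand(N, D_in), randint(0, C, (N,)))]
params = [p for p in model.parameters() if p.requires_grad]

# type-2 uses the exact loss Hessian, so no MC sampling is involved
kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="type-2")
kfac_mat = kfac @ eye(kfac.shape[1])  # densify for inspection
print(kfac_mat.shape)  # (num_params, num_params)
```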
