
Commit 1141c50

runame and f-dangel authored
[ADD] Support for KFAC with type-2 Fisher (#56)
* Add test for MSELoss type-2 KFAC
* Implement type-2 KFAC for MSELoss
* Add docstrings
* Fix black
* Fix docstring
* Add test for CrossEntropyLoss type-2 KFAC
* Implement type-2 KFAC for CrossEntropyLoss
* Fix auto-merge issue
* Fix comment
* [REF] Refactor type-2 using Hessian matrix square root
* [DEL] Remove unused imports
* [FIX] Darglint
* [FIX] Function name in docs
* [REF] Improve function name
* [REF] Rename `num_classes` into `output_dim`

---------

Co-authored-by: Felix Dangel <[email protected]>
1 parent 10180ce commit 1141c50

6 files changed: 237 additions & 54 deletions


curvlinops/kfac.py

Lines changed: 68 additions & 28 deletions
@@ -19,11 +19,12 @@
 
 from einops import rearrange
 from numpy import ndarray
-from torch import Generator, Tensor, cat, einsum, randn
+from torch import Generator, Tensor, cat, einsum, randn, stack
 from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter
 from torch.utils.hooks import RemovableHandle
 
 from curvlinops._base import _LinearOperator
+from curvlinops.kfac_utils import loss_hessian_matrix_sqrt
 
 
 class KFACLinearOperator(_LinearOperator):
@@ -125,7 +126,7 @@ def __init__(
                 used which corresponds to the uncentered gradient covariance, or
                 the empirical Fisher. Defaults to ``'mc'``.
             mc_samples: The number of Monte-Carlo samples to use per data point.
-                Will be ignored when ``fisher_type`` is not ``'mc'``.
+                Has to be set to ``1`` when ``fisher_type != 'mc'``.
                 Defaults to ``1``.
             separate_weight_and_bias: Whether to treat weights and biases separately.
                 Defaults to ``True``.
@@ -138,6 +139,11 @@
             raise ValueError(
                 f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
             )
+        if fisher_type != "mc" and mc_samples != 1:
+            raise ValueError(
+                f"Invalid mc_samples: {mc_samples}. "
+                "Only mc_samples=1 is supported for fisher_type != 'mc'."
+            )
 
         self.param_ids = [p.data_ptr() for p in params]
         # mapping from tuples of parameter data pointers in a module to its name
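A minimal sketch of what the new validation does in practice (not part of the commit): it assumes ``KFACLinearOperator`` is importable from ``curvlinops.kfac``, that the positional argument order matches the tests further below, and that the check fires before any heavy setup; the toy model, data, and names are purely illustrative.

from pytest import raises
from torch import rand
from torch.nn import Linear, MSELoss

from curvlinops.kfac import KFACLinearOperator

# toy setup; shapes and names are placeholders
model, loss_func = Linear(4, 2), MSELoss()
params = [p for p in model.parameters() if p.requires_grad]
data = [(rand(8, 4), rand(8, 2))]

# mc_samples != 1 is now rejected whenever fisher_type is not 'mc'
with raises(ValueError):
    KFACLinearOperator(
        model, loss_func, params, data, fisher_type="type-2", mc_samples=2
    )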
@@ -231,13 +237,7 @@ def _adjoint(self) -> KFACLinearOperator:
         return self
 
     def _compute_kfac(self):
-        """Compute and cache KFAC's Kronecker factors for future ``matvec``s.
-
-        Raises:
-            NotImplementedError: If ``fisher_type == 'type-2'``.
-            ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
-                ``'empirical'``.
-        """
+        """Compute and cache KFAC's Kronecker factors for future ``matvec``s."""
         # install forward and backward hooks
         hook_handles: List[RemovableHandle] = []
 
@@ -266,31 +266,70 @@ def _compute_kfac(self):
 
         for X, y in self._loop_over_data(desc="KFAC matrices"):
             output = self._model_func(X)
-
-            if self._fisher_type == "type-2":
-                raise NotImplementedError(
-                    "Using the exact expectation for computing the KFAC "
-                    "approximation of the Fisher is not yet supported."
-                )
-            elif self._fisher_type == "mc":
-                for mc in range(self._mc_samples):
-                    y_sampled = self.draw_label(output)
-                    loss = self._loss_func(output, y_sampled)
-                    loss.backward(retain_graph=mc != self._mc_samples - 1)
-            elif self._fisher_type == "empirical":
-                loss = self._loss_func(output, y)
-                loss.backward()
-            else:
-                raise ValueError(
-                    f"Invalid fisher_type: {self._fisher_type}. "
-                    + "Supported: 'type-2', 'mc', 'empirical'."
-                )
+            self._compute_loss_and_backward(output, y)
 
         # clean up
         self._model_func.zero_grad()
         for handle in hook_handles:
             handle.remove()
 
+    def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
+        r"""Compute the loss and the backward pass(es) required for KFAC.
+
+        Args:
+            output: The model's prediction
+                :math:`\{f_\mathbf{\theta}(\mathbf{x}_n)\}_{n=1}^N`.
+            y: The labels :math:`\{\mathbf{y}_n\}_{n=1}^N`.
+
+        Raises:
+            ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
+                ``'empirical'``.
+            NotImplementedError: If ``fisher_type`` is ``'type-2'`` and the
+                output is not 2d.
+        """
+        if self._fisher_type == "type-2":
+            if output.ndim != 2:
+                raise NotImplementedError(
+                    "Type-2 Fisher not implemented for non-2d output."
+                )
+            # Compute per-sample Hessian square root, then stack over samples.
+            # Result has shape `(batch_size, num_classes, num_classes)`
+            hessian_sqrts = stack(
+                [
+                    loss_hessian_matrix_sqrt(out.detach(), self._loss_func)
+                    for out in output.split(1)
+                ]
+            )
+
+            # Fix scaling caused by the batch dimension
+            batch_size = output.shape[0]
+            reduction = self._loss_func.reduction
+            scale = {"sum": 1.0, "mean": 1.0 / batch_size}[reduction]
+            hessian_sqrts.mul_(scale)
+
+            # For each column `c` of the matrix square root we need to backpropagate,
+            # but we can do this for all samples in parallel
+            num_cols = hessian_sqrts.shape[-1]
+            for c in range(num_cols):
+                batched_column = hessian_sqrts[:, :, c]
+                (output * batched_column).sum().backward(retain_graph=c < num_cols - 1)
+
+        elif self._fisher_type == "mc":
+            for mc in range(self._mc_samples):
+                y_sampled = self.draw_label(output)
+                loss = self._loss_func(output, y_sampled)
+                loss.backward(retain_graph=mc != self._mc_samples - 1)
+
+        elif self._fisher_type == "empirical":
+            loss = self._loss_func(output, y)
+            loss.backward()
+
+        else:
+            raise ValueError(
+                f"Invalid fisher_type: {self._fisher_type}. "
+                + "Supported: 'type-2', 'mc', 'empirical'."
+            )
+
     def draw_label(self, output: Tensor) -> Tensor:
         r"""Draw a sample from the model's predictive distribution.
 
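A short aside on why the column-wise backward in the new ``_compute_loss_and_backward`` recovers the exact (type-2) Fisher; this reasoning is a sketch and not spelled out in the commit itself. Writing :math:`\mathbf{J}_n` for the Jacobian of the model output w.r.t. the parameters and :math:`\mathbf{S}_n` for the loss-Hessian square root of datum :math:`n`, backpropagating each column :math:`[\mathbf{S}_n]_{:,c}` and summing the resulting gradient outer products gives

.. math::
    \sum_c
    \left(\mathbf{J}_n^\top [\mathbf{S}_n]_{:,c}\right)
    \left(\mathbf{J}_n^\top [\mathbf{S}_n]_{:,c}\right)^\top
    =
    \mathbf{J}_n^\top \mathbf{S}_n \mathbf{S}_n^\top \mathbf{J}_n
    =
    \mathbf{J}_n^\top
    \nabla^2_{\mathbf{f}} \ell(\mathbf{f}_n, \mathbf{y}_n)
    \mathbf{J}_n\,,

i.e. the per-datum GGN, which the accumulated Kronecker factors then approximate.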
@@ -393,6 +432,7 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
         )
 
         batch_size = g.shape[0]
+        # self._mc_samples will be 1 if fisher_type != "mc"
         correction = {
            "sum": 1.0 / self._mc_samples,
            "mean": batch_size**2 / (self._N_data * self._mc_samples),

curvlinops/kfac_utils.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+"""Utility functions related to KFAC."""
+
+from math import sqrt
+from typing import Union
+
+from torch import Tensor, diag, einsum, eye
+from torch.nn import CrossEntropyLoss, MSELoss
+
+
+def loss_hessian_matrix_sqrt(
+    output_one_datum: Tensor, loss_func: Union[MSELoss, CrossEntropyLoss]
+) -> Tensor:
+    r"""Compute the loss function's matrix square root for a sample's output.
+
+    Args:
+        output_one_datum: The model's prediction on a single datum. Has shape
+            ``[1, C]`` where ``C`` is the number of classes (outputs of the neural
+            network).
+        loss_func: The loss function.
+
+    Returns:
+        The matrix square root
+        :math:`\mathbf{S}` of the Hessian. Has shape
+        ``[C, C]`` and satisfies the relation
+
+        .. math::
+            \mathbf{S} \mathbf{S}^\top
+            =
+            \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y})
+            \in \mathbb{R}^{C \times C}
+
+        where :math:`\mathbf{f} := f(\mathbf{x}) \in \mathbb{R}^C` is the model's
+        prediction on a single datum :math:`\mathbf{x}` and :math:`\mathbf{y}` is
+        the label.
+
+    Note:
+        For :class:`torch.nn.MSELoss` (with :math:`c = 1` for ``reduction='sum'``
+        and :math:`c = 1/C` for ``reduction='mean'``), we have:
+
+        .. math::
+            \ell(\mathbf{f}) &= c \sum_{i=1}^C (f_i - y_i)^2
+            \\
+            \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, \mathbf{y}) &= 2 c \mathbf{I}_C
+            \\
+            \mathbf{S} &= \sqrt{2 c} \mathbf{I}_C
+
+    Note:
+        For :class:`torch.nn.CrossEntropyLoss` (with :math:`c = 1` irrespective of the
+        reduction, :math:`\mathbf{p}:=\mathrm{softmax}(\mathbf{f}) \in \mathbb{R}^C`,
+        and the element-wise natural logarithm :math:`\log`) we have:
+
+        .. math::
+            \ell(\mathbf{f}, y) = - c \log(\mathbf{p})^\top \mathrm{onehot}(y)
+            \\
+            \nabla^2_{\mathbf{f}} \ell(\mathbf{f}, y)
+            =
+            c \left(
+            \mathrm{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^\top
+            \right)
+            \\
+            \mathbf{S} = \sqrt{c} \left(
+            \mathrm{diag}(\sqrt{\mathbf{p}}) - \mathbf{p} \sqrt{\mathbf{p}}^\top
+            \right)\,,
+
+        where the square root is applied element-wise. See for instance Example 5.1 of
+        `this thesis <https://d-nb.info/1280233206/34>`_ or equations (5) and (6) of
+        `this paper <https://arxiv.org/abs/1901.08244>`_.
+
+    Raises:
+        ValueError: If the batch size is not one, or the output is not 2d.
+        NotImplementedError: If the loss function is not supported.
+    """
+    if output_one_datum.ndim != 2 or output_one_datum.shape[0] != 1:
+        raise ValueError(
+            f"Expected 'output_one_datum' to be 2d with shape [1, C], got "
+            f"{output_one_datum.shape}"
+        )
+    output = output_one_datum.squeeze(0)
+    output_dim = output.numel()
+
+    if isinstance(loss_func, MSELoss):
+        c = {"sum": 1.0, "mean": 1.0 / output_dim}[loss_func.reduction]
+        return eye(output_dim, device=output.device, dtype=output.dtype).mul_(
+            sqrt(2 * c)
+        )
+    elif isinstance(loss_func, CrossEntropyLoss):
+        c = 1.0
+        p = output_one_datum.softmax(dim=1).squeeze()
+        p_sqrt = p.sqrt()
+        return (diag(p_sqrt) - einsum("i,j->ij", p, p_sqrt)).mul_(sqrt(c))
+    else:
+        raise NotImplementedError(f"Loss function {loss_func} not supported.")
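A standalone check (not part of the commit) of the relation promised in the docstring, comparing ``loss_hessian_matrix_sqrt`` against an autograd Hessian of ``CrossEntropyLoss`` for a single datum; the values and names are illustrative.

from torch import allclose, manual_seed, rand, tensor
from torch.autograd.functional import hessian
from torch.nn import CrossEntropyLoss

from curvlinops.kfac_utils import loss_hessian_matrix_sqrt

manual_seed(0)
loss_func = CrossEntropyLoss()
output = rand(1, 4)  # one datum, C = 4 classes
y = tensor([2])

S = loss_hessian_matrix_sqrt(output, loss_func)  # shape [4, 4]

# Hessian of the loss w.r.t. the single datum's logits
H = hessian(lambda f: loss_func(f.unsqueeze(0), y), output.squeeze(0))

assert allclose(S @ S.T, H, atol=1e-6)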

docs/rtd/index.rst

Lines changed: 5 additions & 0 deletions
@@ -40,3 +40,8 @@ Installation
 
    linops
    basic_usage/index
+
+.. toctree::
+   :caption: Internals
+
+   internals

docs/rtd/internals.rst

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+Internals
+============
+
+This section is for internal purposes only and serves to inform developers about
+details, because rendered LaTeX is easier to read than source code.
+
+
+KFAC-related
+-------------
+
+.. autofunction:: curvlinops.kfac_utils.loss_hessian_matrix_sqrt

test/conftest.py

Lines changed: 0 additions & 14 deletions
@@ -86,17 +86,3 @@ def kfac_expand_exact_one_datum_case(
     """
     case = request.param
     yield initialize_case(case)
-
-
-@fixture(params=KFAC_EXPAND_EXACT_ONE_DATUM_CASES)
-def kfac_ef_exact_one_datum_case(
-    request,
-) -> Tuple[Module, MSELoss, List[Tensor], Iterable[Tuple[Tensor, Tensor]],]:
-    """Prepare a test case with one datum for which KFAC with empirical gradients equals the EF.
-
-    Yields:
-        A neural network, the mean-squared error function, a list of parameters, and
-        a data set.
-    """
-    case = request.param
-    yield initialize_case(case)

test/test_kfac.py

Lines changed: 61 additions & 12 deletions
@@ -8,7 +8,15 @@
 from pytest import mark
 from scipy.linalg import block_diag
 from torch import Tensor, device, manual_seed, rand, randperm
-from torch.nn import Linear, Module, MSELoss, Parameter, ReLU, Sequential
+from torch.nn import (
+    CrossEntropyLoss,
+    Linear,
+    Module,
+    MSELoss,
+    Parameter,
+    ReLU,
+    Sequential,
+)
 
 from curvlinops.examples.utils import report_nonclose
 from curvlinops.gradient_moments import EFLinearOperator
@@ -22,7 +30,7 @@
     "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"]
 )
 @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"])
-def test_kfac(
+def test_kfac_type2(
     kfac_expand_exact_case: Tuple[
         Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
     ],
@@ -59,30 +67,71 @@ def test_kfac(
         data,
         separate_weight_and_bias=separate_weight_and_bias,
     )
-
     kfac = KFACLinearOperator(
         model,
         loss_func,
         params,
         data,
-        mc_samples=2_000,
+        fisher_type="type-2",
         separate_weight_and_bias=separate_weight_and_bias,
     )
     kfac_mat = kfac @ eye(kfac.shape[1])
 
-    atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction]
-    rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction]
-
-    report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)
+    report_nonclose(ggn, kfac_mat)
 
     # Check that input covariances were not computed
     if exclude == "weight":
         assert len(kfac._input_covariances) == 0
 
 
+@mark.parametrize("shuffle", [False, True], ids=["", "shuffled"])
+def test_kfac_mc(
+    kfac_expand_exact_case: Tuple[
+        Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
+    ],
+    shuffle: bool,
+):
+    """Test the KFAC implementation using MC samples against the exact GGN.
+
+    Args:
+        kfac_expand_exact_case: A fixture that returns a model, loss function, list of
+            parameters, and data.
+        shuffle: Whether to shuffle the parameters before computing the KFAC matrix.
+    """
+    model, loss_func, params, data = kfac_expand_exact_case
+
+    if shuffle:
+        permutation = randperm(len(params))
+        params = [params[i] for i in permutation]
+
+    ggn = ggn_block_diagonal(model, loss_func, params, data)
+    kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000)
+
+    kfac_mat = kfac @ eye(kfac.shape[1])
+
+    atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction]
+    rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction]
+
+    report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)
+
+
 def test_kfac_one_datum(
     kfac_expand_exact_one_datum_case: Tuple[
-        Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
+        Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
+    ]
+):
+    model, loss_func, params, data = kfac_expand_exact_one_datum_case
+
+    ggn = ggn_block_diagonal(model, loss_func, params, data)
+    kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="type-2")
+    kfac_mat = kfac @ eye(kfac.shape[1])
+
+    report_nonclose(ggn, kfac_mat)
+
+
+def test_kfac_mc_one_datum(
+    kfac_expand_exact_one_datum_case: Tuple[
+        Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
     ]
 ):
     model, loss_func, params, data = kfac_expand_exact_one_datum_case
@@ -98,11 +147,11 @@ def test_kfac_one_datum(
 
 
 def test_kfac_ef_one_datum(
-    kfac_ef_exact_one_datum_case: Tuple[
-        Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
+    kfac_expand_exact_one_datum_case: Tuple[
+        Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
     ]
 ):
-    model, loss_func, params, data = kfac_ef_exact_one_datum_case
+    model, loss_func, params, data = kfac_expand_exact_one_datum_case
 
     ef_blocks = []  # list of per-parameter EFs
     for param in params:
