@@ -19,11 +19,12 @@
 from einops import rearrange
 from numpy import ndarray
-from torch import Generator, Tensor, cat, einsum, randn
+from torch import Generator, Tensor, cat, einsum, randn, stack
 from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter
 from torch.utils.hooks import RemovableHandle

 from curvlinops._base import _LinearOperator
+from curvlinops.kfac_utils import loss_hessian_matrix_sqrt


 class KFACLinearOperator(_LinearOperator):
@@ -125,7 +126,7 @@ def __init__(
                 used which corresponds to the uncentered gradient covariance, or
                 the empirical Fisher. Defaults to ``'mc'``.
             mc_samples: The number of Monte-Carlo samples to use per data point.
-                Will be ignored when ``fisher_type`` is not ``'mc'``.
+                Has to be set to ``1`` when ``fisher_type != 'mc'``.
                 Defaults to ``1``.
             separate_weight_and_bias: Whether to treat weights and biases separately.
                 Defaults to ``True``.
@@ -138,6 +139,11 @@ def __init__(
             raise ValueError(
                 f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
             )
+        if fisher_type != "mc" and mc_samples != 1:
+            raise ValueError(
+                f"Invalid mc_samples: {mc_samples}. "
+                "Only mc_samples=1 is supported for fisher_type != 'mc'."
+            )

         self.param_ids = [p.data_ptr() for p in params]
         # mapping from tuples of parameter data pointers in a module to its name
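Since the constructor's full signature is outside this diff, here is a minimal standalone sketch of the new validation logic; `_check_mc_samples` is a hypothetical helper that only mirrors the inlined check added to ``__init__`` above, not part of the curvlinops API:

```python
# Minimal standalone sketch; `_check_mc_samples` is hypothetical and only
# mirrors the validation inlined into `__init__` in the hunk above.
def _check_mc_samples(fisher_type: str, mc_samples: int) -> None:
    if fisher_type != "mc" and mc_samples != 1:
        raise ValueError(
            f"Invalid mc_samples: {mc_samples}. "
            "Only mc_samples=1 is supported for fisher_type != 'mc'."
        )

_check_mc_samples("mc", 8)      # OK: MC estimation may use several samples
_check_mc_samples("type-2", 1)  # OK: the exact Fisher needs no sampling
try:
    _check_mc_samples("empirical", 8)
except ValueError as e:
    print(e)  # Invalid mc_samples: 8. Only mc_samples=1 is supported ...
```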
@@ -231,13 +237,7 @@ def _adjoint(self) -> KFACLinearOperator:
         return self

     def _compute_kfac(self):
-        """Compute and cache KFAC's Kronecker factors for future ``matvec``s.
-
-        Raises:
-            NotImplementedError: If ``fisher_type == 'type-2'``.
-            ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
-                ``'empirical'``.
-        """
+        """Compute and cache KFAC's Kronecker factors for future ``matvec``s."""
         # install forward and backward hooks
         hook_handles: List[RemovableHandle] = []

@@ -266,31 +266,70 @@ def _compute_kfac(self):

         for X, y in self._loop_over_data(desc="KFAC matrices"):
             output = self._model_func(X)
-
-            if self._fisher_type == "type-2":
-                raise NotImplementedError(
-                    "Using the exact expectation for computing the KFAC "
-                    "approximation of the Fisher is not yet supported."
-                )
-            elif self._fisher_type == "mc":
-                for mc in range(self._mc_samples):
-                    y_sampled = self.draw_label(output)
-                    loss = self._loss_func(output, y_sampled)
-                    loss.backward(retain_graph=mc != self._mc_samples - 1)
-            elif self._fisher_type == "empirical":
-                loss = self._loss_func(output, y)
-                loss.backward()
-            else:
-                raise ValueError(
-                    f"Invalid fisher_type: {self._fisher_type}. "
-                    + "Supported: 'type-2', 'mc', 'empirical'."
-                )
+            self._compute_loss_and_backward(output, y)

         # clean up
         self._model_func.zero_grad()
         for handle in hook_handles:
             handle.remove()

+    def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
+        r"""Compute the loss and the backward pass(es) required for KFAC.
+
+        Args:
+            output: The model's prediction
+                :math:`\{f_\mathbf{\theta}(\mathbf{x}_n)\}_{n=1}^N`.
+            y: The labels :math:`\{\mathbf{y}_n\}_{n=1}^N`.
+
+        Raises:
+            ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
+                ``'empirical'``.
+            NotImplementedError: If ``fisher_type`` is ``'type-2'`` and the
+                output is not 2d.
+        """
+        if self._fisher_type == "type-2":
+            if output.ndim != 2:
+                raise NotImplementedError(
+                    "Type-2 Fisher not implemented for non-2d output."
+                )
+            # Compute the per-sample Hessian square roots, then stack them over
+            # samples. Result has shape `(batch_size, num_classes, num_classes)`.
+            hessian_sqrts = stack(
+                [
+                    loss_hessian_matrix_sqrt(out.detach(), self._loss_func)
+                    for out in output.split(1)
+                ]
+            )
+
+            # Fix the scaling caused by the batch dimension
+            batch_size = output.shape[0]
+            reduction = self._loss_func.reduction
+            scale = {"sum": 1.0, "mean": 1.0 / batch_size}[reduction]
+            hessian_sqrts.mul_(scale)
+
+            # Each column `c` of the matrix square root requires one backward
+            # pass, but all samples in the batch can be processed in parallel
+            num_cols = hessian_sqrts.shape[-1]
+            for c in range(num_cols):
+                batched_column = hessian_sqrts[:, :, c]
+                (output * batched_column).sum().backward(retain_graph=c < num_cols - 1)
+
+        elif self._fisher_type == "mc":
+            for mc in range(self._mc_samples):
+                y_sampled = self.draw_label(output)
+                loss = self._loss_func(output, y_sampled)
+                loss.backward(retain_graph=mc != self._mc_samples - 1)
+
+        elif self._fisher_type == "empirical":
+            loss = self._loss_func(output, y)
+            loss.backward()
+
+        else:
+            raise ValueError(
+                f"Invalid fisher_type: {self._fisher_type}. "
+                + "Supported: 'type-2', 'mc', 'empirical'."
+            )
+
     def draw_label(self, output: Tensor) -> Tensor:
         r"""Draw a sample from the model's predictive distribution.

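Why the column-wise backward passes in the new ``type-2`` branch recover the exact Fisher: if ``S`` is a matrix square root of the loss Hessian ``H`` w.r.t. the model's output (``S @ S.T == H``), then summing the outer products of the per-column backpropagated gradients ``J.T @ S[:, c]`` yields the generalized Gauss-Newton ``J.T @ H @ J``. A self-contained toy check, assuming a sum-reduced MSE loss (where ``H = 2 * I``) in place of the library's actual ``loss_hessian_matrix_sqrt``:

```python
import torch

torch.manual_seed(0)
W = torch.randn(2, 3, requires_grad=True)  # "model": a single linear layer
x = torch.randn(3)
output = (W @ x).unsqueeze(0)  # shape (1, 2): batch of one, two outputs

# For sum-reduced MSE, the Hessian w.r.t. the output is H = 2 * I,
# so sqrt(2) * I is a valid matrix square root.
S = (2.0**0.5) * torch.eye(2).unsqueeze(0)  # shape (1, 2, 2)

# One backward pass per column, mirroring the loop in the diff
grads = []
for c in range(S.shape[-1]):
    (g,) = torch.autograd.grad((output * S[:, :, c]).sum(), W, retain_graph=True)
    grads.append(g.flatten())
ggn = sum(torch.outer(g, g) for g in grads)

# Reference: J.T @ H @ J with J the Jacobian of `output` w.r.t. vec(W)
rows = [
    torch.autograd.grad(output[0, i], W, retain_graph=True)[0].flatten()
    for i in range(2)
]
J = torch.stack(rows)
assert torch.allclose(ggn, J.T @ (2.0 * torch.eye(2)) @ J, atol=1e-6)
```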
@@ -393,6 +432,7 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
             )

         batch_size = g.shape[0]
+        # self._mc_samples will be 1 if fisher_type != "mc"
         correction = {
             "sum": 1.0 / self._mc_samples,
             "mean": batch_size**2 / (self._N_data * self._mc_samples),