diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py
index 37976a7e..6ca86c0b 100644
--- a/curvlinops/kfac.py
+++ b/curvlinops/kfac.py
@@ -13,6 +13,7 @@
 from __future__ import annotations

+from functools import partial
 from math import sqrt
 from typing import Dict, Iterable, List, Set, Tuple, Union

@@ -253,8 +254,8 @@ def _compute_kfac(self):

             # gradient covariance required for weights and biases
             hook_handles.append(
-                module.register_full_backward_hook(
-                    self._hook_accumulate_gradient_covariance
+                module.register_forward_hook(
+                    self._register_tensor_hook_on_output_to_accumulate_gradient_covariance
                 )
             )

@@ -343,28 +344,45 @@ def draw_label(self, output: Tensor) -> Tensor:
         else:
             raise NotImplementedError

-    def _hook_accumulate_gradient_covariance(
-        self, module: Module, grad_input: Tuple[Tensor], grad_output: Tuple[Tensor]
+    def _register_tensor_hook_on_output_to_accumulate_gradient_covariance(
+        self, module: Module, inputs: Tuple[Tensor], output: Tensor
     ):
-        """Backward hook that accumulates the output-gradient covariance of a layer.
+        """Register tensor hook on layer's output to accumulate the grad. covariance.
+
+        Note:
+            The easier way to compute the gradient covariance would be via a full
+            backward hook on the module itself which performs the computation.
+            However, this approach breaks down if the output of a layer feeds into an
+            activation with `inplace=True` (see
+            https://github.com/pytorch/pytorch/issues/61519). Hence we use the
+            workaround
+            https://github.com/pytorch/pytorch/issues/61519#issuecomment-883524237, and
+            install a module hook which installs a tensor hook on the module's output
+            tensor, which performs the accumulation of the gradient covariance.
+
+        Args:
+            module: Layer onto whose output a tensor hook to accumulate the gradient
+                covariance will be installed.
+            inputs: The layer's input tensors.
+            output: The layer's output tensor.
+        """
+        tensor_hook = partial(self._accumulate_gradient_covariance, module)
+        output.register_hook(tensor_hook)
+
+    def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
+        """Accumulate the gradient covariance for a layer's output.

         Updates ``self._gradient_covariances``.

         Args:
-            module: The layer on which the hook is called.
-            grad_input: The gradient of the loss w.r.t. the layer's inputs.
-            grad_output: The gradient of the loss w.r.t. the layer's outputs.
+            module: The layer whose output's gradient covariance will be accumulated.
+            grad_output: The gradient w.r.t. the output.

         Raises:
-            ValueError: If ``grad_output`` is not a 1-tuple.
             NotImplementedError: If a layer uses weight sharing.
             NotImplementedError: If the layer is not supported.
         """
-        if len(grad_output) != 1:
-            raise ValueError(
-                f"Expected grad_output to be a 1-tuple, got {len(grad_output)}."
-            )
-        g = grad_output[0].data.detach()
+        g = grad_output.data.detach()

         if isinstance(module, Linear):
             if g.ndim != 2:
diff --git a/test/test_kfac.py b/test/test_kfac.py
index 4fa95ef5..0769645f 100644
--- a/test/test_kfac.py
+++ b/test/test_kfac.py
@@ -1,13 +1,14 @@
 """Contains tests for ``curvlinops.kfac``."""

-from test.utils import ggn_block_diagonal
+from test.cases import DEVICES, DEVICES_IDS
+from test.utils import ggn_block_diagonal, regression_targets
 from typing import Iterable, List, Tuple

 from numpy import eye
 from pytest import mark
 from scipy.linalg import block_diag
-from torch import Tensor, randperm
-from torch.nn import Module, MSELoss, Parameter
+from torch import Tensor, device, manual_seed, rand, randperm
+from torch.nn import Linear, Module, MSELoss, Parameter, ReLU, Sequential

 from curvlinops.examples.utils import report_nonclose
 from curvlinops.gradient_moments import EFLinearOperator
@@ -113,3 +114,40 @@ def test_kfac_ef_one_datum(

     kfac_mat = kfac @ eye(kfac.shape[1])
     report_nonclose(ef, kfac_mat)
+
+
+@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
+def test_kfac_inplace_activations(dev: device):
+    """Test that KFAC works if the network has in-place activations.
+
+    We use a test case with a single datum as KFAC becomes exact as the number of
+    MC samples increases.
+
+    Args:
+        dev: The device to run the test on.
+    """
+    manual_seed(0)
+    model = Sequential(Linear(6, 3), ReLU(inplace=True), Linear(3, 2)).to(dev)
+    loss_func = MSELoss().to(dev)
+    batch_size = 1
+    data = [(rand(batch_size, 6), regression_targets((batch_size, 2)))]
+    params = list(model.parameters())
+
+    # 1) compare KFAC and GGN
+    ggn = ggn_block_diagonal(model, loss_func, params, data)
+
+    kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000)
+    kfac_mat = kfac @ eye(kfac.shape[1])
+
+    atol = {"sum": 5e-1, "mean": 2e-3}[loss_func.reduction]
+    rtol = {"sum": 2e-2, "mean": 2e-2}[loss_func.reduction]
+
+    report_nonclose(ggn, kfac_mat, rtol=rtol, atol=atol)
+
+    # 2) compare GGN (inplace=True) and GGN (inplace=False)
+    for mod in model.modules():
+        if hasattr(mod, "inplace"):
+            mod.inplace = False
+    ggn_no_inplace = ggn_block_diagonal(model, loss_func, params, data)
+
+    report_nonclose(ggn, ggn_no_inplace)
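
For readers unfamiliar with the workaround this patch adopts, below is a minimal, self-contained sketch of the pattern: a forward hook that installs a tensor hook on the layer's output, so output-gradient statistics survive a subsequent `inplace=True` activation. All names and the toy covariance bookkeeping are illustrative only; this is not the curvlinops implementation, just the hook mechanics under the same PyTorch APIs.

# Sketch of the forward-hook + tensor-hook pattern (pytorch/pytorch#61519
# workaround). Illustrative names, not curvlinops code.
from functools import partial

from torch import Tensor, einsum, manual_seed, rand
from torch.nn import Linear, Module, ReLU, Sequential

manual_seed(0)
model = Sequential(Linear(6, 3), ReLU(inplace=True), Linear(3, 2))
gradient_covariances = {}


def accumulate_gradient_covariance(module: Module, grad_output: Tensor):
    # Tensor hook: receives the gradient w.r.t. the layer's output and
    # accumulates its (uncentered) covariance, summed over the batch.
    g = grad_output.detach()
    cov = einsum("bi,bj->ij", g, g)
    key = id(module)
    gradient_covariances[key] = gradient_covariances.get(key, 0) + cov


def install_output_tensor_hook(module: Module, inputs: tuple, output: Tensor):
    # Forward hook: instead of a full backward hook on the module (which
    # breaks when the output feeds an inplace=True activation, see
    # pytorch/pytorch#61519), register a hook on the output tensor itself.
    output.register_hook(partial(accumulate_gradient_covariance, module))


handles = [
    mod.register_forward_hook(install_output_tensor_hook)
    for mod in model
    if isinstance(mod, Linear)
]

loss = model(rand(4, 6)).square().sum()
loss.backward()  # fires the tensor hooks despite the in-place ReLU

for handle in handles:
    handle.remove()

print({key: cov.shape for key, cov in gradient_covariances.items()})

The tensor hook fires with the gradient at that node of the autograd graph, which for an in-place ReLU coincides with the gradient w.r.t. the linear layer's output, exactly the quantity KFAC's gradient covariance needs; a full backward hook on the module would instead raise an error for this graph.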