diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index c8810b5f..6ca86c0b 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -19,7 +19,7 @@ from einops import rearrange from numpy import ndarray -from torch import Generator, Tensor, einsum, randn +from torch import Generator, Tensor, cat, einsum, randn from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter from torch.utils.hooks import RemovableHandle @@ -83,6 +83,7 @@ def __init__( seed: int = 2147483647, fisher_type: str = "mc", mc_samples: int = 1, + separate_weight_and_bias: bool = True, ): """Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN. @@ -97,7 +98,6 @@ def __init__( This is an early proto-type with many limitations: - Only linear layers are supported. - - Weights and biases are treated separately. - No weight sharing is supported. - Only the ``'expand'`` approximation is supported. @@ -127,11 +127,12 @@ def __init__( mc_samples: The number of Monte-Carlo samples to use per data point. Will be ignored when ``fisher_type`` is not ``'mc'``. Defaults to ``1``. + separate_weight_and_bias: Whether to treat weights and biases separately. + Defaults to ``True``. Raises: ValueError: If the loss function is not supported. - NotImplementedError: If any parameter cannot be identified with a supported - layer. + NotImplementedError: If a parameter is in an unsupported layer. """ if not isinstance(loss_func, self._SUPPORTED_LOSSES): raise ValueError( @@ -157,6 +158,7 @@ def __init__( self._seed = seed self._generator: Union[None, Generator] = None + self._separate_weight_and_bias = separate_weight_and_bias self._fisher_type = fisher_type self._mc_samples = mc_samples self._input_covariances: Dict[str, Tensor] = {} @@ -191,16 +193,30 @@ def _matvec(self, x: ndarray) -> ndarray: for name in self.param_ids_to_hooked_modules.values(): mod = self._model_func.get_submodule(name) - if mod.weight.data_ptr() in self.param_ids: - idx = self.param_ids.index(mod.weight.data_ptr()) + # bias and weights are treated jointly + if not self._separate_weight_and_bias and self.in_params( + mod.weight, mod.bias + ): + w_pos, b_pos = self.param_pos(mod.weight), self.param_pos(mod.bias) + x_joint = cat([x_torch[w_pos], x_torch[b_pos].unsqueeze(-1)], dim=1) aaT = self._input_covariances[name] ggT = self._gradient_covariances[name] - x_torch[idx] = ggT @ x_torch[idx] @ aaT + x_joint = ggT @ x_joint @ aaT - if mod.bias is not None and mod.bias.data_ptr() in self.param_ids: - idx = self.param_ids.index(mod.bias.data_ptr()) - ggT = self._gradient_covariances[name] - x_torch[idx] = ggT @ x_torch[idx] + w_cols = mod.weight.shape[1] + x_torch[w_pos], x_torch[b_pos] = x_joint.split([w_cols, 1], dim=1) + + # for weights we need to multiply from the right with aaT + # for weights and biases we need to multiply from the left with ggT + else: + for p_name in ["weight", "bias"]: + p = getattr(mod, p_name) + if self.in_params(p): + pos = self.param_pos(p) + x_torch[pos] = self._gradient_covariances[name] @ x_torch[pos] + + if p_name == "weight": + x_torch[pos] = x_torch[pos] @ self._input_covariances[name] return super()._postprocess(x_torch) @@ -229,7 +245,7 @@ def _compute_kfac(self): module = self._model_func.get_submodule(name) # input covariance only required for weights - if module.weight.data_ptr() in self.param_ids: + if self.in_params(module.weight): hook_handles.append( module.register_forward_pre_hook( self._hook_accumulate_input_covariance @@ -420,6 +436,12 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor f"Only 2d inputs are supported for linear layers. Got {x.ndim}d." ) + if ( + self.in_params(module.weight, module.bias) + and not self._separate_weight_and_bias + ): + x = cat([x, x.new_ones(x.shape[0], 1)], dim=1) + covariance = einsum("bi,bj->ij", x, x).div_(self._N_data) else: # TODO Support convolutions @@ -442,3 +464,25 @@ def get_module_name(self, module: Module) -> str: """ p_ids = tuple(p.data_ptr() for p in module.parameters()) return self.param_ids_to_hooked_modules[p_ids] + + def in_params(self, *params: Union[Parameter, Tensor, None]) -> bool: + """Check if all parameters are used in KFAC. + + Args: + params: Parameters to check. + + Returns: + Whether all parameters are used in KFAC. + """ + return all(p is not None and p.data_ptr() in self.param_ids for p in params) + + def param_pos(self, param: Union[Parameter, Tensor]) -> int: + """Get the position of a parameter in the list of parameters used in KFAC. + + Args: + param: The parameter. + + Returns: + The parameter's position in the parameter list. + """ + return self.param_ids.index(param.data_ptr()) diff --git a/test/test_kfac.py b/test/test_kfac.py index e4cd67fb..0769645f 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -1,7 +1,7 @@ """Contains tests for ``curvlinops.kfac``.""" from test.cases import DEVICES, DEVICES_IDS -from test.utils import regression_targets +from test.utils import ggn_block_diagonal, regression_targets from typing import Iterable, List, Tuple from numpy import eye @@ -11,11 +11,13 @@ from torch.nn import Linear, Module, MSELoss, Parameter, ReLU, Sequential from curvlinops.examples.utils import report_nonclose -from curvlinops.ggn import GGNLinearOperator from curvlinops.gradient_moments import EFLinearOperator from curvlinops.kfac import KFACLinearOperator +@mark.parametrize( + "separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"] +) @mark.parametrize( "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] ) @@ -26,6 +28,7 @@ def test_kfac( ], shuffle: bool, exclude: str, + separate_weight_and_bias: bool, ): """Test the KFAC implementation against the exact GGN. @@ -35,6 +38,8 @@ def test_kfac( shuffle: Whether to shuffle the parameters before computing the KFAC matrix. exclude: Which parameters to exclude. Can be ``'weight'``, ``'bias'``, or ``None``. + separate_weight_and_bias: Whether to treat weight and bias as separate blocks in + the KFAC matrix. """ assert exclude in [None, "weight", "bias"] model, loss_func, params, data = kfac_expand_exact_case @@ -47,13 +52,22 @@ def test_kfac( permutation = randperm(len(params)) params = [params[i] for i in permutation] - ggn_blocks = [] # list of per-parameter GGNs - for param in params: - ggn = GGNLinearOperator(model, loss_func, [param], data) - ggn_blocks.append(ggn @ eye(ggn.shape[1])) - ggn = block_diag(*ggn_blocks) - - kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000) + ggn = ggn_block_diagonal( + model, + loss_func, + params, + data, + separate_weight_and_bias=separate_weight_and_bias, + ) + + kfac = KFACLinearOperator( + model, + loss_func, + params, + data, + mc_samples=2_000, + separate_weight_and_bias=separate_weight_and_bias, + ) kfac_mat = kfac @ eye(kfac.shape[1]) atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction] @@ -72,12 +86,7 @@ def test_kfac_one_datum( ] ): model, loss_func, params, data = kfac_expand_exact_one_datum_case - - ggn_blocks = [] # list of per-parameter GGNs - for param in params: - ggn = GGNLinearOperator(model, loss_func, [param], data) - ggn_blocks.append(ggn @ eye(ggn.shape[1])) - ggn = block_diag(*ggn_blocks) + ggn = ggn_block_diagonal(model, loss_func, params, data) kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=10_000) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -125,11 +134,7 @@ def test_kfac_inplace_activations(dev: device): params = list(model.parameters()) # 1) compare KFAC and GGN - ggn_blocks = [] # list of per-parameter GGNs - for param in params: - ggn = GGNLinearOperator(model, loss_func, [param], data) - ggn_blocks.append(ggn @ eye(ggn.shape[1])) - ggn = block_diag(*ggn_blocks) + ggn = ggn_block_diagonal(model, loss_func, params, data) kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000) kfac_mat = kfac @ eye(kfac.shape[1]) @@ -143,11 +148,6 @@ def test_kfac_inplace_activations(dev: device): for mod in model.modules(): if hasattr(mod, "inplace"): mod.inplace = False + ggn_no_inplace = ggn_block_diagonal(model, loss_func, params, data) - ggn2_blocks = [] # list of per-parameter GGNs - for param in params: - ggn2 = GGNLinearOperator(model, loss_func, [param], data) - ggn2_blocks.append(ggn2 @ eye(ggn2.shape[1])) - ggn2 = block_diag(*ggn2_blocks) - - report_nonclose(ggn, ggn2) + report_nonclose(ggn, ggn_no_inplace) diff --git a/test/utils.py b/test/utils.py index e099667a..f343c7e7 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1,6 +1,13 @@ """Utility functions to test `curvlinops`.""" -from torch import cuda, device, rand, randint +from itertools import product +from typing import Iterable, List, Tuple + +from numpy import eye, ndarray +from torch import Tensor, cat, cuda, device, from_numpy, rand, randint +from torch.nn import Module, Parameter + +from curvlinops import GGNLinearOperator def get_available_devices(): @@ -25,3 +32,62 @@ def classification_targets(size, num_classes): def regression_targets(size): """Create random targets for regression.""" return rand(*size) + + +def ggn_block_diagonal( + model: Module, + loss_func: Module, + params: List[Parameter], + data: Iterable[Tuple[Tensor, Tensor]], + separate_weight_and_bias: bool = True, +) -> ndarray: + """Compute the block-diagonal GGN. + + Args: + model: The neural network. + loss_func: The loss function. + params: The parameters w.r.t. which the GGN block-diagonals will be computed. + data: A data loader. + separate_weight_and_bias: Whether to treat weight and bias of a layer as + separate blocks in the block-diagonal GGN. Default: ``True``. + + Returns: + The block-diagonal GGN. + """ + # compute the full GGN then zero out the off-diagonal blocks + ggn = GGNLinearOperator(model, loss_func, params, data) + ggn = from_numpy(ggn @ eye(ggn.shape[1])) + sizes = [p.numel() for p in params] + # ggn_blocks[i, j] corresponds to the block of (params[i], params[j]) + ggn_blocks = [list(block.split(sizes, dim=1)) for block in ggn.split(sizes, dim=0)] + + # find out which blocks to keep + num_params = len(params) + keep = [(i, i) for i in range(num_params)] + param_ids = [p.data_ptr() for p in params] + + # keep blocks corresponding to jointly-treated weights and biases + if not separate_weight_and_bias: + # find all layers with weight and bias + has_weight_and_bias = [ + mod + for mod in model.modules() + if hasattr(mod, "weight") and hasattr(mod, "bias") and mod.bias is not None + ] + # only keep those whose parameters are included + has_weight_and_bias = [ + mod + for mod in has_weight_and_bias + if mod.weight.data_ptr() in param_ids and mod.bias.data_ptr() in param_ids + ] + for mod in has_weight_and_bias: + w_pos = param_ids.index(mod.weight.data_ptr()) + b_pos = param_ids.index(mod.bias.data_ptr()) + keep.extend([(w_pos, b_pos), (b_pos, w_pos)]) + + for i, j in product(range(num_params), range(num_params)): + if (i, j) not in keep: + ggn_blocks[i][j].zero_() + + # concatenate all blocks + return cat([cat(row_blocks, dim=1) for row_blocks in ggn_blocks], dim=0).numpy()