From c362dba9e26829f41756d79b928b8855f1373a31 Mon Sep 17 00:00:00 2001
From: runame
Date: Tue, 9 Jan 2024 15:57:45 +0100
Subject: [PATCH 1/6] Add Rearrange module util

---
 test/utils.py | 48 ++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 44 insertions(+), 4 deletions(-)

diff --git a/test/utils.py b/test/utils.py
index f343c7e7..880df4e1 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -3,6 +3,7 @@
 from itertools import product
 from typing import Iterable, List, Tuple
 
+from einops import rearrange
 from numpy import eye, ndarray
 from torch import Tensor, cat, cuda, device, from_numpy, rand, randint
 from torch.nn import Module, Parameter
@@ -24,13 +25,28 @@ def get_available_devices():
     return devices
 
 
-def classification_targets(size, num_classes):
-    """Create random targets for classes 0, ..., `num_classes - 1`."""
+def classification_targets(size: Tuple[int], num_classes: int) -> Tensor:
+    """Create random targets for classes 0, ..., `num_classes - 1`.
+    
+    Args:
+        size: Size of the targets to create.
+        num_classes: Number of classes.
+
+    Returns:
+        Random targets.
+    """
     return randint(size=size, low=0, high=num_classes)
 
 
-def regression_targets(size):
-    """Create random targets for regression."""
+def regression_targets(size: Tuple[int]) -> Tensor:
+    """Create random targets for regression.
+    
+    Args:
+        size: Size of the targets to create.
+    
+    Returns:
+        Random targets.
+    """
     return rand(*size)
 
 
@@ -91,3 +107,27 @@ def ggn_block_diagonal(
 
     # concatenate all blocks
     return cat([cat(row_blocks, dim=1) for row_blocks in ggn_blocks], dim=0).numpy()
+
+
+class Rearrange(Module):
+    """A module that rearranges the input tensor."""
+
+    def __init__(self, pattern: str):
+        """Initialize the module.
+
+        Args:
+            pattern: The rearrangement pattern.
+        """
+        super().__init__()
+        self.pattern = pattern
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Rearrange the input tensor.
+
+        Args:
+            x: The input tensor.
+
+        Returns:
+            The rearranged tensor.
+        """
+        return rearrange(x, self.pattern)

From 5ae955386f6e6a244a9d59e91c078c8e2abaee89 Mon Sep 17 00:00:00 2001
From: runame
Date: Tue, 9 Jan 2024 15:58:32 +0100
Subject: [PATCH 2/6] Add test for KFAC with >2d model outputs

---
 test/test_kfac.py | 78 +++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 76 insertions(+), 2 deletions(-)

diff --git a/test/test_kfac.py b/test/test_kfac.py
index e2d4b187..49b92386 100644
--- a/test/test_kfac.py
+++ b/test/test_kfac.py
@@ -1,8 +1,13 @@
 """Contains tests for ``curvlinops.kfac``."""
 
 from test.cases import DEVICES, DEVICES_IDS
-from test.utils import ggn_block_diagonal, regression_targets
-from typing import Iterable, List, Tuple
+from test.utils import (
+    Rearrange,
+    classification_targets,
+    ggn_block_diagonal,
+    regression_targets,
+)
+from typing import Iterable, List, Tuple, Union
 
 from numpy import eye
 from pytest import mark
@@ -10,6 +15,7 @@
 from torch import Tensor, device, manual_seed, rand, randperm
 from torch.nn import (
     CrossEntropyLoss,
+    Flatten,
     Linear,
     Module,
     MSELoss,
@@ -200,3 +206,71 @@ def test_kfac_inplace_activations(dev: device):
     ggn_no_inplace = ggn_block_diagonal(model, loss_func, params, data)
 
     report_nonclose(ggn, ggn_no_inplace)
+
+
+@mark.parametrize("fisher_type", ["type-2", "mc", "empirical"])
+@mark.parametrize("loss", [MSELoss, CrossEntropyLoss], ids=["mse", "ce"])
+@mark.parametrize("reduction", ["mean", "sum"])
+@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
+def test_multi_dim_output(
+    fisher_type: str,
+    loss: Union[MSELoss, CrossEntropyLoss],
+    reduction: str,
+    dev: device,
+):
+    """Test the KFAC implementation for >2d outputs (using a 3d and 4d output).
+
+    Args:
+        fisher_type: The type of Fisher matrix to use.
+        loss: The loss function to use.
+        reduction: The reduction to use for the loss function.
+        dev: The device to run the test on.
+    """
+    manual_seed(0)
+    # set up loss function, data, and model
+    loss_func = loss(reduction=reduction).to(dev)
+    if isinstance(loss_func, MSELoss):
+        data = [
+            (rand(2, 4, 5), regression_targets((2, 4, 3))),
+            (rand(4, 7, 5, 5), regression_targets((4, 7, 5, 3))),
+        ]
+        manual_seed(711)
+        model = Sequential(Linear(5, 4), Linear(4, 3)).to(dev)
+    else:
+        data = [
+            (rand(2, 4, 5), classification_targets((2, 4), 3)),
+            (rand(4, 7, 5, 5), classification_targets((4, 7, 5), 3)),
+        ]
+        manual_seed(711)
+        # rearrange is necessary to get the expected output shape for ce loss
+        model = Sequential(
+            Linear(5, 4),
+            Linear(4, 3),
+            Rearrange("batch ... c -> batch c ..."),
+        ).to(dev)
+
+    # KFAC for deep linear network with 3d/4d input and output
+    params = list(model.parameters())
+    kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type=fisher_type)
+    kfac_mat = kfac @ eye(kfac.shape[1])
+
+    # KFAC for deep linear network with 3d/4d input and equivalent 2d output
+    manual_seed(711)
+    model_flat = Sequential(
+        Linear(5, 4),
+        Linear(4, 3),
+        Flatten(start_dim=0, end_dim=-2),
+    ).to(dev)
+    params_flat = list(model_flat.parameters())
+    data_flat = [
+        (x, y.flatten(start_dim=0, end_dim=-2))
+        if isinstance(loss_func, MSELoss)
+        else (x, y.flatten(start_dim=0))
+        for x, y in data
+    ]
+    kfac_flat = KFACLinearOperator(
+        model_flat, loss_func, params_flat, data_flat, fisher_type=fisher_type
+    )
+    kfac_flat_mat = kfac_flat @ eye(kfac_flat.shape[1])
+
+    report_nonclose(kfac_mat, kfac_flat_mat)

From 2acd800fe62b4f52982ffbf25e853c90dccbc8fb Mon Sep 17 00:00:00 2001
From: runame
Date: Tue, 9 Jan 2024 16:00:27 +0100
Subject: [PATCH 3/6] Add support for >2d linear layer inputs (expand) and
 model outputs

---
 curvlinops/kfac.py | 85 ++++++++++++++++++++++++++++------------------
 1 file changed, 52 insertions(+), 33 deletions(-)

diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py
index 19ad6e7a..3fcc34bf 100644
--- a/curvlinops/kfac.py
+++ b/curvlinops/kfac.py
@@ -71,6 +71,11 @@ class KFACLinearOperator(_LinearOperator):
 
     _SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss)
     _SUPPORTED_MODULES = (Linear,)
+    _SUPPORTED_LOSS_AVERAGE: Tuple[Union[None, str], ...] = (
+        None,
+        "batch",
+        "batch+sequence",
+    )
 
     def __init__(
         self,
@@ -84,6 +89,7 @@ def __init__(
         seed: int = 2147483647,
         fisher_type: str = "mc",
         mc_samples: int = 1,
+        loss_average: Union[None, str] = "batch",
         separate_weight_and_bias: bool = True,
     ):
         """Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
@@ -128,6 +134,15 @@ def __init__(
             mc_samples: The number of Monte-Carlo samples to use per data point.
                 Has to be set to ``1`` when ``fisher_type != 'mc'``.
                 Defaults to ``1``.
+            loss_average: Whether the loss function is a mean over per-sample
+                losses and if yes, over which dimensions the mean is taken.
+                If `"batch"`, the loss function is a mean over as many terms as
+                the size of the mini-batch. If `"batch+sequence"`, the loss
+                function is a mean over as many terms as the size of the
+                mini-batch times the sequence length, e.g. in the case of
+                language modeling. If `None`, the loss function is a sum. This
+                argument is used to ensure that the preconditioner is scaled
+                consistently with the loss and the gradient. Default: `"batch"`.
             separate_weight_and_bias: Whether to treat weights and biases
                 separately. Defaults to ``True``.
 
@@ -139,6 +154,16 @@ def __init__(
             raise ValueError(
                 f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
             )
+        if loss_average not in self._SUPPORTED_LOSS_AVERAGE:
+            raise ValueError(
+                f"Invalid loss_average: {loss_average}. "
+                f"Supported: {self._SUPPORTED_LOSS_AVERAGE}."
+            )
+        if loss_average is None and loss_func.reduction != "sum":
+            raise ValueError(
+                f"Invalid loss_average: {loss_average}. "
+                f"Must be 'batch' or 'batch+sequence' if loss_func.reduction != 'sum'."
+            )
         if fisher_type != "mc" and mc_samples != 1:
             raise ValueError(
                 f"Invalid mc_samples: {mc_samples}. "
@@ -167,6 +192,7 @@ def __init__(
         self._separate_weight_and_bias = separate_weight_and_bias
         self._fisher_type = fisher_type
         self._mc_samples = mc_samples
+        self._loss_average = loss_average
 
         self._input_covariances: Dict[str, Tensor] = {}
         self._gradient_covariances: Dict[str, Tensor] = {}
@@ -284,14 +310,17 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
         Raises:
             ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
                 ``'empirical'``.
-            NotImplementedError: If ``fisher_type`` is ``'type-2'`` and the
-                output is not 2d.
         """
+        # if >2d output we convert to an equivalent 2d output
+        if output.ndim > 2:
+            if isinstance(self._loss_func, CrossEntropyLoss):
+                output = rearrange(output, "batch c ... -> (batch ...) c")
+                y = rearrange(y, "batch ... -> (batch ...)")
+            else:
+                output = rearrange(output, "batch ... c -> (batch ...) c")
+                y = rearrange(y, "batch ... c -> (batch ...) c")
+
         if self._fisher_type == "type-2":
-            if output.ndim != 2:
-                raise NotImplementedError(
-                    "Type-2 Fisher not implemented for non-2d output."
-                )
             # Compute per-sample Hessian square root, then concatenate over samples.
             # Result has shape `(batch_size, num_classes, num_classes)`
             hessian_sqrts = stack(
@@ -365,20 +394,12 @@ def draw_label(self, output: Tensor) -> Tensor:
             return output.clone().detach() + perturbation
 
         elif isinstance(self._loss_func, CrossEntropyLoss):
-            # TODO For output.ndim > 2, the scale of the 'would-be' gradient resulting
-            # from these labels might be off
-            if output.ndim != 2:
-                raise NotImplementedError(
-                    "Only 2D output is supported for CrossEntropyLoss for now."
-                )
-            probs = output.softmax(dim=1)
             # each row contains a vector describing a categorical
-            probs_as_mat = rearrange(probs, "n c ... -> (n ...) c")
-            labels = probs_as_mat.multinomial(
+            probs = output.softmax(dim=1)
+            labels = probs.multinomial(
                 num_samples=1, generator=self._generator
             ).squeeze(-1)
-            label_shape = output.shape[:1] + output.shape[2:]
-            return labels.reshape(label_shape)
+            return labels
 
         else:
             raise NotImplementedError
@@ -422,20 +443,21 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
             NotImplementedError: If the layer is not supported.
         """
         g = grad_output.data.detach()
+        sequence_length = g.shape[1:-1].numel()
+        if self._loss_average is not None:
+            num_loss_terms = g.shape[0]  # batch_size
+            if self._loss_average == "batch+sequence":
+                # Number of all loss terms = batch_size * sequence_length
+                num_loss_terms *= sequence_length
 
-        if isinstance(module, Linear):
-            if g.ndim != 2:
-                # TODO Support weight sharing
-                raise NotImplementedError(
-                    "Only 2d grad_outputs are supported for linear layers. "
-                    + f"Got {g.ndim}d."
-                )
+        g = rearrange(g, "batch ... d_out -> (batch ...) d_out")
 
-            batch_size = g.shape[0]
+        if isinstance(module, Linear):
             # self._mc_samples will be 1 if fisher_type != "mc"
             correction = {
                 "sum": 1.0 / self._mc_samples,
-                "mean": batch_size**2 / (self._N_data * self._mc_samples),
+                "mean": num_loss_terms**2
+                / (self._N_data * self._mc_samples * sequence_length),
             }[self._loss_func.reduction]
             covariance = einsum("bi,bj->ij", g, g).mul_(correction)
         else:
@@ -468,21 +490,18 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor
         if len(inputs) != 1:
             raise ValueError("Modules with multiple inputs are not supported.")
         x = inputs[0].data.detach()
+        sequence_length = x.shape[1:-1].numel()
 
-        if isinstance(module, Linear):
-            if x.ndim != 2:
-                # TODO Support weight sharing
-                raise NotImplementedError(
-                    f"Only 2d inputs are supported for linear layers. Got {x.ndim}d."
-                )
+        x = rearrange(x, "batch ... d_in -> (batch ...) d_in")
 
+        if isinstance(module, Linear):
             if (
                 self.in_params(module.weight, module.bias)
                 and not self._separate_weight_and_bias
             ):
                 x = cat([x, x.new_ones(x.shape[0], 1)], dim=1)
 
-            covariance = einsum("bi,bj->ij", x, x).div_(self._N_data)
+            covariance = einsum("bi,bj->ij", x, x).div_(self._N_data * sequence_length)
         else:
             # TODO Support convolutions
             raise NotImplementedError(f"Layer of type {type(module)} is unsupported.")

From 861ec7eaaca6e388458f867d59b811803bce5754 Mon Sep 17 00:00:00 2001
From: runame
Date: Tue, 9 Jan 2024 16:06:56 +0100
Subject: [PATCH 4/6] Fix black

---
 test/utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/utils.py b/test/utils.py
index 880df4e1..4d3363de 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -27,7 +27,7 @@ def get_available_devices():
 
 def classification_targets(size: Tuple[int], num_classes: int) -> Tensor:
     """Create random targets for classes 0, ..., `num_classes - 1`.
-    
+
     Args:
         size: Size of the targets to create.
         num_classes: Number of classes.
@@ -40,10 +40,10 @@ def classification_targets(size: Tuple[int], num_classes: int) -> Tensor:
 
 def regression_targets(size: Tuple[int]) -> Tensor:
     """Create random targets for regression.
-    
+
     Args:
         size: Size of the targets to create.
-    
+
     Returns:
         Random targets.
     """

From 33f7d502648571d5bc3ab18884d3ecff7efb9e9c Mon Sep 17 00:00:00 2001
From: runame
Date: Wed, 10 Jan 2024 22:52:57 +0100
Subject: [PATCH 5/6] Address review feedback

---
 curvlinops/kfac.py | 15 +++++++++------
 test/test_kfac.py  | 16 ++++++----------
 test/utils.py      | 27 +--------------------------
 3 files changed, 16 insertions(+), 42 deletions(-)

diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py
index 3fcc34bf..3f65f7e9 100644
--- a/curvlinops/kfac.py
+++ b/curvlinops/kfac.py
@@ -136,13 +136,13 @@ def __init__(
                 Defaults to ``1``.
             loss_average: Whether the loss function is a mean over per-sample
                 losses and if yes, over which dimensions the mean is taken.
-                If `"batch"`, the loss function is a mean over as many terms as
-                the size of the mini-batch. If `"batch+sequence"`, the loss
+                If ``'batch'``, the loss function is a mean over as many terms as
+                the size of the mini-batch. If ``'batch+sequence'``, the loss
                 function is a mean over as many terms as the size of the
                 mini-batch times the sequence length, e.g. in the case of
-                language modeling. If `None`, the loss function is a sum. This
+                language modeling. If ``None``, the loss function is a sum. This
                 argument is used to ensure that the preconditioner is scaled
-                consistently with the loss and the gradient. Default: `"batch"`.
+                consistently with the loss and the gradient. Default: ``'batch'``.
             separate_weight_and_bias: Whether to treat weights and biases
                 separately. Defaults to ``True``.
@@ -378,12 +378,16 @@ def draw_label(self, output: Tensor) -> Tensor:
             together with ``output``.
 
         Raises:
+            ValueError: If the output is not 2d.
             NotImplementedError: If the loss function is not supported.
         """
+        if output.ndim != 2:
+            raise ValueError("Only a 2d output is supported.")
+
         if isinstance(self._loss_func, MSELoss):
             std = {
                 "sum": sqrt(1.0 / 2.0),
-                "mean": sqrt(output.shape[1:].numel() / 2.0),
+                "mean": sqrt(output.shape[1] / 2.0),
             }[self._loss_func.reduction]
             perturbation = std * randn(
                 output.shape,
@@ -394,7 +398,6 @@ def draw_label(self, output: Tensor) -> Tensor:
             return output.clone().detach() + perturbation
 
         elif isinstance(self._loss_func, CrossEntropyLoss):
-            # each row contains a vector describing a categorical
             probs = output.softmax(dim=1)
             labels = probs.multinomial(
                 num_samples=1, generator=self._generator
diff --git a/test/test_kfac.py b/test/test_kfac.py
index 49b92386..b73b020d 100644
--- a/test/test_kfac.py
+++ b/test/test_kfac.py
@@ -1,14 +1,10 @@
 """Contains tests for ``curvlinops.kfac``."""
 
 from test.cases import DEVICES, DEVICES_IDS
-from test.utils import (
-    Rearrange,
-    classification_targets,
-    ggn_block_diagonal,
-    regression_targets,
-)
+from test.utils import classification_targets, ggn_block_diagonal, regression_targets
 from typing import Iterable, List, Tuple, Union
 
+from einops.layers.torch import Rearrange
 from numpy import eye
 from pytest import mark
 from scipy.linalg import block_diag
@@ -231,14 +227,14 @@ def test_multi_dim_output(
     loss_func = loss(reduction=reduction).to(dev)
     if isinstance(loss_func, MSELoss):
         data = [
-            (rand(2, 4, 5), regression_targets((2, 4, 3))),
+            (rand(2, 7, 5, 5), regression_targets((2, 7, 5, 3))),
             (rand(4, 7, 5, 5), regression_targets((4, 7, 5, 3))),
         ]
         manual_seed(711)
         model = Sequential(Linear(5, 4), Linear(4, 3)).to(dev)
     else:
         data = [
-            (rand(2, 4, 5), classification_targets((2, 4), 3)),
+            (rand(2, 7, 5, 5), classification_targets((2, 7, 5), 3)),
             (rand(4, 7, 5, 5), classification_targets((4, 7, 5), 3)),
         ]
         manual_seed(711)
@@ -249,12 +245,12 @@ def test_multi_dim_output(
             Rearrange("batch ... c -> batch c ..."),
         ).to(dev)
 
-    # KFAC for deep linear network with 3d/4d input and output
+    # KFAC for deep linear network with 4d input and output
     params = list(model.parameters())
     kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type=fisher_type)
     kfac_mat = kfac @ eye(kfac.shape[1])
 
-    # KFAC for deep linear network with 3d/4d input and equivalent 2d output
+    # KFAC for deep linear network with 4d input and equivalent 2d output
     manual_seed(711)
     model_flat = Sequential(
         Linear(5, 4),
         Linear(4, 3),
diff --git a/test/utils.py b/test/utils.py
index 4d3363de..a617ae78 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -3,7 +3,6 @@
 from itertools import product
 from typing import Iterable, List, Tuple
 
-from einops import rearrange
 from numpy import eye, ndarray
 from torch import Tensor, cat, cuda, device, from_numpy, rand, randint
 from torch.nn import Module, Parameter
@@ -26,7 +25,7 @@ def get_available_devices():
 
 
 def classification_targets(size: Tuple[int], num_classes: int) -> Tensor:
-    """Create random targets for classes 0, ..., `num_classes - 1`.
+    """Create random targets for classes ``0``, ..., ``num_classes - 1``.
 
     Args:
         size: Size of the targets to create.
@@ -107,27 +106,3 @@ def ggn_block_diagonal(
 
     # concatenate all blocks
     return cat([cat(row_blocks, dim=1) for row_blocks in ggn_blocks], dim=0).numpy()
-
-
-class Rearrange(Module):
-    """A module that rearranges the input tensor."""
-
-    def __init__(self, pattern: str):
-        """Initialize the module.
-
-        Args:
-            pattern: The rearrangement pattern.
-        """
-        super().__init__()
-        self.pattern = pattern
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Rearrange the input tensor.
-
-        Args:
-            x: The input tensor.
-
-        Returns:
-            The rearranged tensor.
-        """
-        return rearrange(x, self.pattern)

From ad2dc3f501d41ebe9c6df6b76eb938d5c4e1316b Mon Sep 17 00:00:00 2001
From: runame
Date: Thu, 11 Jan 2024 13:04:07 +0100
Subject: [PATCH 6/6] Remove superfluous if statement

---
 curvlinops/kfac.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py
index 3f65f7e9..999739df 100644
--- a/curvlinops/kfac.py
+++ b/curvlinops/kfac.py
@@ -312,13 +312,12 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
                 ``'empirical'``.
         """
         # if >2d output we convert to an equivalent 2d output
-        if output.ndim > 2:
-            if isinstance(self._loss_func, CrossEntropyLoss):
-                output = rearrange(output, "batch c ... -> (batch ...) c")
-                y = rearrange(y, "batch ... -> (batch ...)")
-            else:
-                output = rearrange(output, "batch ... c -> (batch ...) c")
-                y = rearrange(y, "batch ... c -> (batch ...) c")
+        if isinstance(self._loss_func, CrossEntropyLoss):
+            output = rearrange(output, "batch c ... -> (batch ...) c")
+            y = rearrange(y, "batch ... -> (batch ...)")
+        else:
+            output = rearrange(output, "batch ... c -> (batch ...) c")
+            y = rearrange(y, "batch ... c -> (batch ...) c")
 
         if self._fisher_type == "type-2":
             # Compute per-sample Hessian square root, then concatenate over samples.
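
A note on why the equivalence checked by test_multi_dim_output holds: flattening all weight-sharing dimensions into the batch dimension leaves the loss value unchanged, which is what lets _compute_loss_and_backward treat a >2d output as an equivalent 2d problem. The following sketch illustrates that invariant with made-up shapes; it only assumes that torch and einops are installed and is not part of the patches:

    from einops import rearrange
    from torch import manual_seed, rand, randint
    from torch.nn import CrossEntropyLoss, MSELoss

    manual_seed(0)

    # CrossEntropyLoss expects (batch, classes, ...) logits and (batch, ...) targets.
    ce = CrossEntropyLoss(reduction="mean")
    logits = rand(2, 3, 7)  # batch=2, classes=3, sequence=7
    targets = randint(low=0, high=3, size=(2, 7))
    # Flatten the sequence dimension into the batch dimension.
    logits_2d = rearrange(logits, "batch c ... -> (batch ...) c")
    targets_1d = rearrange(targets, "batch ... -> (batch ...)")
    assert ce(logits, targets).allclose(ce(logits_2d, targets_1d))

    # For MSELoss, the feature dimension comes last instead.
    mse = MSELoss(reduction="mean")
    output, y = rand(2, 7, 3), rand(2, 7, 3)
    output_2d = rearrange(output, "batch ... c -> (batch ...) c")
    y_2d = rearrange(y, "batch ... c -> (batch ...) c")
    assert mse(output, y).allclose(mse(output_2d, y_2d))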
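
The scale bookkeeping in _accumulate_gradient_covariance builds on the fact that g.shape[1:-1] is a torch.Size whose numel() is the product of all weight-sharing dimensions. A small sketch with illustrative numbers (the variable names mirror the patch; this is not a public API):

    from torch import rand

    g = rand(4, 7, 5, 3)  # (batch, *shared_dims, d_out) as seen by the backward hook
    batch_size = g.shape[0]  # 4
    sequence_length = g.shape[1:-1].numel()  # numel of torch.Size([7, 5]) -> 35

    # loss_average == "batch": the loss is a mean over batch_size terms, while
    # loss_average == "batch+sequence" means a mean over
    # batch_size * sequence_length terms, e.g. in language modeling.
    num_loss_terms = batch_size * sequence_length  # for "batch+sequence"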
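
Finally, the Rearrange helper that patch 1 adds is a thin wrapper around einops.rearrange; patch 5 replaces it with the equivalent layer that einops already ships, einops.layers.torch.Rearrange. A quick sketch of the pattern the test uses (assumes einops and torch are installed):

    from einops.layers.torch import Rearrange
    from torch import rand

    # Move the class dimension to position 1, as CrossEntropyLoss expects.
    layer = Rearrange("batch ... c -> batch c ...")
    x = rand(2, 7, 5, 3)  # (batch, *shared_dims, classes)
    assert layer(x).shape == (2, 3, 7, 5)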