Support KFAC for >2d model outputs (#62)
* Add Rearrange module util

* Add test for KFAC with >2d model outputs

* Add support for >2d linear layer inputs (expand) and model outputs

* Fix black

* Address review feedback

* Remove superfluous if statement
runame authored Jan 11, 2024
1 parent 43a4822 commit 1abed46
Showing 3 changed files with 146 additions and 40 deletions.
89 changes: 55 additions & 34 deletions curvlinops/kfac.py
@@ -71,6 +71,11 @@ class KFACLinearOperator(_LinearOperator):

_SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss)
_SUPPORTED_MODULES = (Linear,)
_SUPPORTED_LOSS_AVERAGE: Tuple[Union[None, str], ...] = (
None,
"batch",
"batch+sequence",
)

def __init__(
self,
@@ -84,6 +89,7 @@ def __init__(
seed: int = 2147483647,
fisher_type: str = "mc",
mc_samples: int = 1,
loss_average: Union[None, str] = "batch",
separate_weight_and_bias: bool = True,
):
"""Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
@@ -128,6 +134,15 @@ def __init__(
mc_samples: The number of Monte-Carlo samples to use per data point.
Has to be set to ``1`` when ``fisher_type != 'mc'``.
Defaults to ``1``.
loss_average: Whether the loss function is a mean over per-sample
losses and, if so, over which dimensions the mean is taken.
If ``'batch'``, the loss function is a mean over as many terms as
the size of the mini-batch. If ``'batch+sequence'``, the loss
function is a mean over as many terms as the size of the
mini-batch times the sequence length, e.g. in the case of
language modeling. If ``None``, the loss function is a sum. This
argument is used to ensure that the preconditioner is scaled
consistently with the loss and the gradient. Default: ``'batch'``.
separate_weight_and_bias: Whether to treat weights and biases separately.
Defaults to ``True``.
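
(As an aside, not part of the diff: for a language-modeling style setup with ``batch_size=4`` and ``sequence_length=7``, a loss with ``reduction='mean'`` that averages over both the batch and the sequence dimension corresponds to ``loss_average='batch+sequence'`` and ``4 * 7 = 28`` loss terms per mini-batch; with ``reduction='sum'`` one would pass ``loss_average=None``.)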
@@ -139,6 +154,16 @@ def __init__(
raise ValueError(
f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
)
if loss_average not in self._SUPPORTED_LOSS_AVERAGE:
raise ValueError(
f"Invalid loss_average: {loss_average}. "
f"Supported: {self._SUPPORTED_LOSS_AVERAGE}."
)
if loss_average is None and loss_func.reduction != "sum":
raise ValueError(
f"Invalid loss_average: {loss_average}. "
f"Must be 'batch' or 'batch+sequence' if loss_func.reduction != 'sum'."
)
if fisher_type != "mc" and mc_samples != 1:
raise ValueError(
f"Invalid mc_samples: {mc_samples}. "
@@ -167,6 +192,7 @@ def __init__(
self._separate_weight_and_bias = separate_weight_and_bias
self._fisher_type = fisher_type
self._mc_samples = mc_samples
self._loss_average = loss_average
self._input_covariances: Dict[str, Tensor] = {}
self._gradient_covariances: Dict[str, Tensor] = {}

@@ -284,14 +310,16 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
Raises:
ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
``'empirical'``.
NotImplementedError: If ``fisher_type`` is ``'type-2'`` and the
output is not 2d.
"""
# if >2d output we convert to an equivalent 2d output
if isinstance(self._loss_func, CrossEntropyLoss):
output = rearrange(output, "batch c ... -> (batch ...) c")
y = rearrange(y, "batch ... -> (batch ...)")
else:
output = rearrange(output, "batch ... c -> (batch ...) c")
y = rearrange(y, "batch ... c -> (batch ...) c")

if self._fisher_type == "type-2":
if output.ndim != 2:
raise NotImplementedError(
"Type-2 Fisher not implemented for non-2d output."
)
# Compute per-sample Hessian square root, then concatenate over samples.
# Result has shape `(batch_size, num_classes, num_classes)`
hessian_sqrts = stack(
@@ -349,12 +377,16 @@ def draw_label(self, output: Tensor) -> Tensor:
together with ``output``.
Raises:
ValueError: If the output is not 2d.
NotImplementedError: If the loss function is not supported.
"""
if output.ndim != 2:
raise ValueError("Only a 2d output is supported.")

if isinstance(self._loss_func, MSELoss):
std = {
"sum": sqrt(1.0 / 2.0),
"mean": sqrt(output.shape[1:].numel() / 2.0),
"mean": sqrt(output.shape[1] / 2.0),
}[self._loss_func.reduction]
perturbation = std * randn(
output.shape,
@@ -365,20 +397,11 @@ def draw_label(self, output: Tensor) -> Tensor:
return output.clone().detach() + perturbation

elif isinstance(self._loss_func, CrossEntropyLoss):
# TODO For output.ndim > 2, the scale of the 'would-be' gradient resulting
# from these labels might be off
if output.ndim != 2:
raise NotImplementedError(
"Only 2D output is supported for CrossEntropyLoss for now."
)
probs = output.softmax(dim=1)
# each row contains a vector describing a categorical
probs_as_mat = rearrange(probs, "n c ... -> (n ...) c")
labels = probs_as_mat.multinomial(
labels = probs.multinomial(
num_samples=1, generator=self._generator
).squeeze(-1)
label_shape = output.shape[:1] + output.shape[2:]
return labels.reshape(label_shape)
return labels

else:
raise NotImplementedError
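
A minimal standalone sketch (not part of the commit) of the label sampling in ``draw_label``, written for an output that has already been flattened to 2d as above; the shapes and seed are made up for illustration, and the generator argument is omitted:

from math import sqrt

from torch import manual_seed, rand, randn

manual_seed(0)
output = rand(8, 3)  # (batch * sequence, C), i.e. already flattened as above
C = output.shape[1]

# MSELoss: would-be targets are the outputs plus Gaussian noise, with the std
# chosen per reduction ('mean' shown here, i.e. sqrt(C / 2))
y_mse = output.detach() + sqrt(C / 2.0) * randn(output.shape)

# CrossEntropyLoss: would-be labels are drawn from the predictive distribution
y_ce = output.softmax(dim=1).multinomial(num_samples=1).squeeze(-1)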
@@ -422,20 +445,21 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
NotImplementedError: If the layer is not supported.
"""
g = grad_output.data.detach()
sequence_length = g.shape[1:-1].numel()
if self._loss_average is not None:
num_loss_terms = g.shape[0] # batch_size
if self._loss_average == "batch+sequence":
# Number of all loss terms = batch_size * sequence_length
num_loss_terms *= sequence_length

if isinstance(module, Linear):
if g.ndim != 2:
# TODO Support weight sharing
raise NotImplementedError(
"Only 2d grad_outputs are supported for linear layers. "
+ f"Got {g.ndim}d."
)
g = rearrange(g, "batch ... d_out -> (batch ...) d_out")

batch_size = g.shape[0]
if isinstance(module, Linear):
# self._mc_samples will be 1 if fisher_type != "mc"
correction = {
"sum": 1.0 / self._mc_samples,
"mean": batch_size**2 / (self._N_data * self._mc_samples),
"mean": num_loss_terms**2
/ (self._N_data * self._mc_samples * sequence_length),
}[self._loss_func.reduction]
covariance = einsum("bi,bj->ij", g, g).mul_(correction)
else:
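
A small numeric sketch (not part of the commit) of the ``'mean'`` correction factor assembled above, with made-up sizes and ``mc_samples=1``:

batch_size, sequence_length = 4, 7
N_data, mc_samples = 100, 1

# loss_average='batch+sequence': one loss term per token in the mini-batch
num_loss_terms = batch_size * sequence_length  # 28
correction = num_loss_terms**2 / (N_data * mc_samples * sequence_length)  # 784 / 700 = 1.12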
@@ -468,21 +492,18 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor
if len(inputs) != 1:
raise ValueError("Modules with multiple inputs are not supported.")
x = inputs[0].data.detach()
sequence_length = x.shape[1:-1].numel()

if isinstance(module, Linear):
if x.ndim != 2:
# TODO Support weight sharing
raise NotImplementedError(
f"Only 2d inputs are supported for linear layers. Got {x.ndim}d."
)
x = rearrange(x, "batch ... d_in -> (batch ...) d_in")

if isinstance(module, Linear):
if (
self.in_params(module.weight, module.bias)
and not self._separate_weight_and_bias
):
x = cat([x, x.new_ones(x.shape[0], 1)], dim=1)

covariance = einsum("bi,bj->ij", x, x).div_(self._N_data)
covariance = einsum("bi,bj->ij", x, x).div_(self._N_data * sequence_length)
else:
# TODO Support convolutions
raise NotImplementedError(f"Layer of type {type(module)} is unsupported.")
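
For orientation, a minimal end-to-end usage sketch (not part of the commit) of the new ``loss_average`` argument with a >2d model output, similar to the shapes used in the new test below but reduced to a 3d output; the exact shapes and seed are made up:

from numpy import eye
from torch import manual_seed, rand
from torch.nn import Linear, MSELoss, Sequential

from curvlinops.kfac import KFACLinearOperator

manual_seed(0)
model = Sequential(Linear(5, 4), Linear(4, 3))
loss_func = MSELoss(reduction="mean")
params = list(model.parameters())
# two mini-batches whose model outputs have shape (batch, sequence, 3)
data = [(rand(2, 7, 5), rand(2, 7, 3)), (rand(4, 7, 5), rand(4, 7, 3))]

kfac = KFACLinearOperator(
    model,
    loss_func,
    params,
    data,
    fisher_type="mc",
    loss_average="batch+sequence",
)
kfac_mat = kfac @ eye(kfac.shape[1])  # dense matrix representation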
74 changes: 72 additions & 2 deletions test/test_kfac.py
@@ -1,15 +1,17 @@
"""Contains tests for ``curvlinops.kfac``."""

from test.cases import DEVICES, DEVICES_IDS
from test.utils import ggn_block_diagonal, regression_targets
from typing import Iterable, List, Tuple
from test.utils import classification_targets, ggn_block_diagonal, regression_targets
from typing import Iterable, List, Tuple, Union

from einops.layers.torch import Rearrange
from numpy import eye
from pytest import mark
from scipy.linalg import block_diag
from torch import Tensor, device, manual_seed, rand, randperm
from torch.nn import (
CrossEntropyLoss,
Flatten,
Linear,
Module,
MSELoss,
@@ -200,3 +202,71 @@ def test_kfac_inplace_activations(dev: device):
ggn_no_inplace = ggn_block_diagonal(model, loss_func, params, data)

report_nonclose(ggn, ggn_no_inplace)


@mark.parametrize("fisher_type", ["type-2", "mc", "empirical"])
@mark.parametrize("loss", [MSELoss, CrossEntropyLoss], ids=["mse", "ce"])
@mark.parametrize("reduction", ["mean", "sum"])
@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
def test_multi_dim_output(
fisher_type: str,
loss: Union[MSELoss, CrossEntropyLoss],
reduction: str,
dev: device,
):
"""Test the KFAC implementation for >2d outputs (using a 3d and 4d output).
Args:
fisher_type: The type of Fisher matrix to use.
loss: The loss function to use.
reduction: The reduction to use for the loss function.
dev: The device to run the test on.
"""
manual_seed(0)
# set up loss function, data, and model
loss_func = loss(reduction=reduction).to(dev)
if isinstance(loss_func, MSELoss):
data = [
(rand(2, 7, 5, 5), regression_targets((2, 7, 5, 3))),
(rand(4, 7, 5, 5), regression_targets((4, 7, 5, 3))),
]
manual_seed(711)
model = Sequential(Linear(5, 4), Linear(4, 3)).to(dev)
else:
data = [
(rand(2, 7, 5, 5), classification_targets((2, 7, 5), 3)),
(rand(4, 7, 5, 5), classification_targets((4, 7, 5), 3)),
]
manual_seed(711)
# rearrange is necessary to get the expected output shape for ce loss
model = Sequential(
Linear(5, 4),
Linear(4, 3),
Rearrange("batch ... c -> batch c ..."),
).to(dev)

# KFAC for deep linear network with 4d input and output
params = list(model.parameters())
kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type=fisher_type)
kfac_mat = kfac @ eye(kfac.shape[1])

# KFAC for deep linear network with 4d input and equivalent 2d output
manual_seed(711)
model_flat = Sequential(
Linear(5, 4),
Linear(4, 3),
Flatten(start_dim=0, end_dim=-2),
).to(dev)
params_flat = list(model_flat.parameters())
data_flat = [
(x, y.flatten(start_dim=0, end_dim=-2))
if isinstance(loss_func, MSELoss)
else (x, y.flatten(start_dim=0))
for x, y in data
]
kfac_flat = KFACLinearOperator(
model_flat, loss_func, params_flat, data_flat, fisher_type=fisher_type
)
kfac_flat_mat = kfac_flat @ eye(kfac_flat.shape[1])

report_nonclose(kfac_mat, kfac_flat_mat)
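
A short standalone sketch (not part of the commit) of why the test's cross-entropy model ends with a ``Rearrange`` layer: for >2d outputs, PyTorch's ``CrossEntropyLoss`` expects the class dimension at index 1, while the ``Linear`` stack produces it last. The shapes are made up:

from einops.layers.torch import Rearrange
from torch import rand

logits = rand(2, 7, 5, 3)  # (batch, d1, d2, num_classes) from the Linear layers
move_classes = Rearrange("batch ... c -> batch c ...")
print(move_classes(logits).shape)  # torch.Size([2, 3, 7, 5])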
23 changes: 19 additions & 4 deletions test/utils.py
@@ -24,13 +24,28 @@ def get_available_devices():
return devices


def classification_targets(size, num_classes):
"""Create random targets for classes 0, ..., `num_classes - 1`."""
def classification_targets(size: Tuple[int], num_classes: int) -> Tensor:
"""Create random targets for classes ``0``, ..., ``num_classes - 1``.
Args:
size: Size of the targets to create.
num_classes: Number of classes.
Returns:
Random targets.
"""
return randint(size=size, low=0, high=num_classes)


def regression_targets(size):
"""Create random targets for regression."""
def regression_targets(size: Tuple[int]) -> Tensor:
"""Create random targets for regression.
Args:
size: Size of the targets to create.
Returns:
Random targets.
"""
return rand(*size)
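
A tiny usage sketch (not part of the commit) of the two helpers, assuming it is run from the repository root so that ``test.utils`` is importable:

from test.utils import classification_targets, regression_targets

y_reg = regression_targets((2, 7, 5, 3))      # float targets of shape (2, 7, 5, 3)
y_cls = classification_targets((2, 7, 5), 3)  # integer labels in {0, 1, 2}, shape (2, 7, 5)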

