Support KFAC for >2d model outputs #62

Merged (6 commits) on Jan 11, 2024. Changes from all commits.
89 changes: 55 additions & 34 deletions curvlinops/kfac.py
@@ -71,6 +71,11 @@ class KFACLinearOperator(_LinearOperator):

_SUPPORTED_LOSSES = (MSELoss, CrossEntropyLoss)
_SUPPORTED_MODULES = (Linear,)
_SUPPORTED_LOSS_AVERAGE: Tuple[Union[None, str], ...] = (
None,
"batch",
"batch+sequence",
)

def __init__(
self,
@@ -84,6 +89,7 @@ def __init__(
seed: int = 2147483647,
fisher_type: str = "mc",
mc_samples: int = 1,
loss_average: Union[None, str] = "batch",
separate_weight_and_bias: bool = True,
):
"""Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
@@ -128,6 +134,15 @@ def __init__(
mc_samples: The number of Monte-Carlo samples to use per data point.
Has to be set to ``1`` when ``fisher_type != 'mc'``.
Defaults to ``1``.
loss_average: Whether the loss function is a mean over per-sample
losses and, if so, over which dimensions the mean is taken.
If ``'batch'``, the loss function is a mean over as many terms as
the size of the mini-batch. If ``'batch+sequence'``, the loss
function is a mean over as many terms as the size of the
mini-batch times the sequence length, e.g. in the case of
language modeling. If ``None``, the loss function is a sum. This
argument is used to ensure that the preconditioner is scaled
consistently with the loss and the gradient. Default: ``'batch'``.
separate_weight_and_bias: Whether to treat weights and biases separately.
Defaults to ``True``.
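
As a usage sketch (not part of this diff; the constructor call mirrors the new test further down, and the import path follows the module modified here), the new ``loss_average`` argument is meant to match the loss reduction, e.g. for a language-modeling-style setup where the mean is taken over batch and sequence dimensions:

from einops.layers.torch import Rearrange
from torch import rand, randint
from torch.nn import CrossEntropyLoss, Linear, Sequential

from curvlinops.kfac import KFACLinearOperator

# batch of 2 sequences of length 7, 5 input features, 3 classes;
# CrossEntropyLoss expects the class dimension second, hence the Rearrange
model = Sequential(
    Linear(5, 4),
    Linear(4, 3),
    Rearrange("batch ... c -> batch c ..."),
)
loss_func = CrossEntropyLoss(reduction="mean")  # mean over batch * sequence terms
data = [(rand(2, 7, 5), randint(0, 3, (2, 7)))]

kfac = KFACLinearOperator(
    model,
    loss_func,
    list(model.parameters()),
    data,
    loss_average="batch+sequence",  # consistent with the 'mean' reduction above
)

As in the new test, the operator can then be densified via ``kfac @ eye(kfac.shape[1])``.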

@@ -139,6 +154,16 @@ def __init__(
raise ValueError(
f"Invalid loss: {loss_func}. Supported: {self._SUPPORTED_LOSSES}."
)
if loss_average not in self._SUPPORTED_LOSS_AVERAGE:
raise ValueError(
f"Invalid loss_average: {loss_average}. "
f"Supported: {self._SUPPORTED_LOSS_AVERAGE}."
)
if loss_average is None and loss_func.reduction != "sum":
raise ValueError(
f"Invalid loss_average: {loss_average}. "
f"Must be 'batch' or 'batch+sequence' if loss_func.reduction != 'sum'."
)
if fisher_type != "mc" and mc_samples != 1:
raise ValueError(
f"Invalid mc_samples: {mc_samples}. "
@@ -167,6 +192,7 @@ def __init__(
self._separate_weight_and_bias = separate_weight_and_bias
self._fisher_type = fisher_type
self._mc_samples = mc_samples
self._loss_average = loss_average
self._input_covariances: Dict[str, Tensor] = {}
self._gradient_covariances: Dict[str, Tensor] = {}

@@ -284,14 +310,16 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
Raises:
ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or
``'empirical'``.
NotImplementedError: If ``fisher_type`` is ``'type-2'`` and the
output is not 2d.
"""
# if >2d output we convert to an equivalent 2d output
if isinstance(self._loss_func, CrossEntropyLoss):
output = rearrange(output, "batch c ... -> (batch ...) c")
y = rearrange(y, "batch ... -> (batch ...)")
else:
output = rearrange(output, "batch ... c -> (batch ...) c")
y = rearrange(y, "batch ... c -> (batch ...) c")

if self._fisher_type == "type-2":
if output.ndim != 2:
raise NotImplementedError(
"Type-2 Fisher not implemented for non-2d output."
)
# Compute per-sample Hessian square root, then concatenate over samples.
# Result has shape `(batch_size, num_classes, num_classes)`
hessian_sqrts = stack(
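
For intuition (illustrative only, not part of the diff): the ``rearrange`` calls at the top of ``_compute_loss_and_backward`` flatten a >2d output and its targets into the equivalent 2d/1d form the per-sample loss-Hessian code expects. For the ``CrossEntropyLoss`` layout, where the class dimension comes second:

from einops import rearrange
from torch import rand, randint

output = rand(2, 3, 7)      # (batch, classes, sequence)
y = randint(0, 3, (2, 7))   # (batch, sequence)

output_2d = rearrange(output, "batch c ... -> (batch ...) c")  # shape (14, 3)
y_1d = rearrange(y, "batch ... -> (batch ...)")                # shape (14,)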
Expand Down Expand Up @@ -349,12 +377,16 @@ def draw_label(self, output: Tensor) -> Tensor:
together with ``output``.

Raises:
ValueError: If the output is not 2d.
NotImplementedError: If the loss function is not supported.
"""
if output.ndim != 2:
raise ValueError("Only a 2d output is supported.")

if isinstance(self._loss_func, MSELoss):
std = {
"sum": sqrt(1.0 / 2.0),
"mean": sqrt(output.shape[1:].numel() / 2.0),
"mean": sqrt(output.shape[1] / 2.0),
}[self._loss_func.reduction]
perturbation = std * randn(
output.shape,
Expand All @@ -365,20 +397,11 @@ def draw_label(self, output: Tensor) -> Tensor:
return output.clone().detach() + perturbation

elif isinstance(self._loss_func, CrossEntropyLoss):
# TODO For output.ndim > 2, the scale of the 'would-be' gradient resulting
# from these labels might be off
if output.ndim != 2:
raise NotImplementedError(
"Only 2D output is supported for CrossEntropyLoss for now."
)
probs = output.softmax(dim=1)
# each row contains a vector describing a categorical
probs_as_mat = rearrange(probs, "n c ... -> (n ...) c")
labels = probs_as_mat.multinomial(
labels = probs.multinomial(
num_samples=1, generator=self._generator
).squeeze(-1)
label_shape = output.shape[:1] + output.shape[2:]
return labels.reshape(label_shape)
return labels

else:
raise NotImplementedError
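
Illustrative only (not part of the diff): since the output reaching ``draw_label`` is already 2d, the ``CrossEntropyLoss`` branch samples one label per row from the softmax distribution:

from torch import Generator, rand

output = rand(4, 3)                      # (n, c) logits
probs = output.softmax(dim=1)            # each row is a categorical distribution
generator = Generator().manual_seed(42)
labels = probs.multinomial(num_samples=1, generator=generator).squeeze(-1)
# labels has shape (4,): one sampled class index per row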
@@ -422,20 +445,21 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
NotImplementedError: If the layer is not supported.
"""
g = grad_output.data.detach()
sequence_length = g.shape[1:-1].numel()
if self._loss_average is not None:
num_loss_terms = g.shape[0] # batch_size
if self._loss_average == "batch+sequence":
# Number of all loss terms = batch_size * sequence_length
num_loss_terms *= sequence_length

if isinstance(module, Linear):
if g.ndim != 2:
# TODO Support weight sharing
raise NotImplementedError(
"Only 2d grad_outputs are supported for linear layers. "
+ f"Got {g.ndim}d."
)
g = rearrange(g, "batch ... d_out -> (batch ...) d_out")

batch_size = g.shape[0]
if isinstance(module, Linear):
# self._mc_samples will be 1 if fisher_type != "mc"
correction = {
"sum": 1.0 / self._mc_samples,
"mean": batch_size**2 / (self._N_data * self._mc_samples),
"mean": num_loss_terms**2
/ (self._N_data * self._mc_samples * sequence_length),
}[self._loss_func.reduction]
covariance = einsum("bi,bj->ij", g, g).mul_(correction)
else:
@@ -468,21 +492,18 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor
if len(inputs) != 1:
raise ValueError("Modules with multiple inputs are not supported.")
x = inputs[0].data.detach()
sequence_length = x.shape[1:-1].numel()

if isinstance(module, Linear):
if x.ndim != 2:
# TODO Support weight sharing
raise NotImplementedError(
f"Only 2d inputs are supported for linear layers. Got {x.ndim}d."
)
x = rearrange(x, "batch ... d_in -> (batch ...) d_in")

if isinstance(module, Linear):
if (
self.in_params(module.weight, module.bias)
and not self._separate_weight_and_bias
):
x = cat([x, x.new_ones(x.shape[0], 1)], dim=1)

covariance = einsum("bi,bj->ij", x, x).div_(self._N_data)
covariance = einsum("bi,bj->ij", x, x).div_(self._N_data * sequence_length)
else:
# TODO Support convolutions
raise NotImplementedError(f"Layer of type {type(module)} is unsupported.")
74 changes: 72 additions & 2 deletions test/test_kfac.py
@@ -1,15 +1,17 @@
"""Contains tests for ``curvlinops.kfac``."""

from test.cases import DEVICES, DEVICES_IDS
from test.utils import ggn_block_diagonal, regression_targets
from typing import Iterable, List, Tuple
from test.utils import classification_targets, ggn_block_diagonal, regression_targets
from typing import Iterable, List, Tuple, Union

from einops.layers.torch import Rearrange
from numpy import eye
from pytest import mark
from scipy.linalg import block_diag
from torch import Tensor, device, manual_seed, rand, randperm
from torch.nn import (
CrossEntropyLoss,
Flatten,
Linear,
Module,
MSELoss,
@@ -200,3 +202,71 @@ def test_kfac_inplace_activations(dev: device):
ggn_no_inplace = ggn_block_diagonal(model, loss_func, params, data)

report_nonclose(ggn, ggn_no_inplace)


@mark.parametrize("fisher_type", ["type-2", "mc", "empirical"])
@mark.parametrize("loss", [MSELoss, CrossEntropyLoss], ids=["mse", "ce"])
@mark.parametrize("reduction", ["mean", "sum"])
@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
def test_multi_dim_output(
fisher_type: str,
loss: Union[MSELoss, CrossEntropyLoss],
reduction: str,
dev: device,
):
"""Test the KFAC implementation for >2d outputs (using a 3d and 4d output).

Args:
fisher_type: The type of Fisher matrix to use.
loss: The loss function to use.
reduction: The reduction to use for the loss function.
dev: The device to run the test on.
"""
manual_seed(0)
# set up loss function, data, and model
loss_func = loss(reduction=reduction).to(dev)
if isinstance(loss_func, MSELoss):
data = [
(rand(2, 7, 5, 5), regression_targets((2, 7, 5, 3))),
(rand(4, 7, 5, 5), regression_targets((4, 7, 5, 3))),
]
manual_seed(711)
model = Sequential(Linear(5, 4), Linear(4, 3)).to(dev)
else:
data = [
(rand(2, 7, 5, 5), classification_targets((2, 7, 5), 3)),
(rand(4, 7, 5, 5), classification_targets((4, 7, 5), 3)),
]
manual_seed(711)
# rearrange is necessary to get the expected output shape for ce loss
model = Sequential(
Linear(5, 4),
Linear(4, 3),
Rearrange("batch ... c -> batch c ..."),
).to(dev)

# KFAC for deep linear network with 4d input and output
params = list(model.parameters())
kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type=fisher_type)
kfac_mat = kfac @ eye(kfac.shape[1])

# KFAC for deep linear network with 4d input and equivalent 2d output
manual_seed(711)
model_flat = Sequential(
Linear(5, 4),
Linear(4, 3),
Flatten(start_dim=0, end_dim=-2),
).to(dev)
params_flat = list(model_flat.parameters())
data_flat = [
(x, y.flatten(start_dim=0, end_dim=-2))
if isinstance(loss_func, MSELoss)
else (x, y.flatten(start_dim=0))
for x, y in data
]
kfac_flat = KFACLinearOperator(
model_flat, loss_func, params_flat, data_flat, fisher_type=fisher_type
)
kfac_flat_mat = kfac_flat @ eye(kfac_flat.shape[1])

report_nonclose(kfac_mat, kfac_flat_mat)
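
Illustrative only: the flattened comparison model relies on ``Flatten`` merging the batch and weight-sharing dimensions of the output, so that the 2d path sees an equivalent problem:

from torch import rand
from torch.nn import Flatten

out = rand(2, 7, 5, 3)                        # 4d model output
flat = Flatten(start_dim=0, end_dim=-2)(out)  # shape (70, 3)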
23 changes: 19 additions & 4 deletions test/utils.py
@@ -24,13 +24,28 @@ def get_available_devices():
return devices


def classification_targets(size, num_classes):
"""Create random targets for classes 0, ..., `num_classes - 1`."""
def classification_targets(size: Tuple[int], num_classes: int) -> Tensor:
"""Create random targets for classes ``0``, ..., ``num_classes - 1``.

Args:
size: Size of the targets to create.
num_classes: Number of classes.

Returns:
Random targets.
"""
return randint(size=size, low=0, high=num_classes)


def regression_targets(size):
"""Create random targets for regression."""
def regression_targets(size: Tuple[int]) -> Tensor:
"""Create random targets for regression.

Args:
size: Size of the targets to create.

Returns:
Random targets.
"""
return rand(*size)
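
Illustrative usage of these helpers, matching the calls in the new test (``CrossEntropyLoss`` requires integer class targets, hence the ``randint``-based helper):

targets = classification_targets((2, 7, 5), 3)  # int64 tensor, shape (2, 7, 5), classes 0-2
y = regression_targets((2, 7, 5, 3))            # float tensor, shape (2, 7, 5, 3)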

