Skip to content

Commit

Permalink
[ADD] Option to treat weight and bias of a layer jointly in KFAC (#57)
Browse files Browse the repository at this point in the history
* [ADD] Attempt supporting treating weights and biases jointly in KFAC

* [DOC] Update limitation

* [DEL] Remove `assert`s

* [FIX] Make tests work

* [REF] More verbose id
  • Loading branch information
f-dangel authored Nov 8, 2023
1 parent 2eb654f commit e842f31
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 27 deletions.
68 changes: 56 additions & 12 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from einops import rearrange
from numpy import ndarray
from torch import Generator, Tensor, einsum, randn
from torch import Generator, Tensor, cat, einsum, randn
from torch.nn import CrossEntropyLoss, Linear, Module, MSELoss, Parameter
from torch.utils.hooks import RemovableHandle

Expand Down Expand Up @@ -82,6 +82,7 @@ def __init__(
seed: int = 2147483647,
fisher_type: str = "mc",
mc_samples: int = 1,
separate_weight_and_bias: bool = True,
):
"""Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
Expand All @@ -96,7 +97,6 @@ def __init__(
This is an early proto-type with many limitations:
- Only linear layers are supported.
- Weights and biases are treated separately.
- No weight sharing is supported.
- Only the ``'expand'`` approximation is supported.
Expand Down Expand Up @@ -126,11 +126,12 @@ def __init__(
mc_samples: The number of Monte-Carlo samples to use per data point.
Will be ignored when ``fisher_type`` is not ``'mc'``.
Defaults to ``1``.
separate_weight_and_bias: Whether to treat weights and biases separately.
Defaults to ``True``.
Raises:
ValueError: If the loss function is not supported.
NotImplementedError: If any parameter cannot be identified with a supported
layer.
NotImplementedError: If a parameter is in an unsupported layer.
"""
if not isinstance(loss_func, self._SUPPORTED_LOSSES):
raise ValueError(
Expand All @@ -156,6 +157,7 @@ def __init__(

self._seed = seed
self._generator: Union[None, Generator] = None
self._separate_weight_and_bias = separate_weight_and_bias
self._fisher_type = fisher_type
self._mc_samples = mc_samples
self._input_covariances: Dict[str, Tensor] = {}
Expand Down Expand Up @@ -190,16 +192,30 @@ def _matvec(self, x: ndarray) -> ndarray:
for name in self.param_ids_to_hooked_modules.values():
mod = self._model_func.get_submodule(name)

if mod.weight.data_ptr() in self.param_ids:
idx = self.param_ids.index(mod.weight.data_ptr())
# bias and weights are treated jointly
if not self._separate_weight_and_bias and self.in_params(
mod.weight, mod.bias
):
w_pos, b_pos = self.param_pos(mod.weight), self.param_pos(mod.bias)
x_joint = cat([x_torch[w_pos], x_torch[b_pos].unsqueeze(-1)], dim=1)
aaT = self._input_covariances[name]
ggT = self._gradient_covariances[name]
x_torch[idx] = ggT @ x_torch[idx] @ aaT
x_joint = ggT @ x_joint @ aaT

if mod.bias is not None and mod.bias.data_ptr() in self.param_ids:
idx = self.param_ids.index(mod.bias.data_ptr())
ggT = self._gradient_covariances[name]
x_torch[idx] = ggT @ x_torch[idx]
w_cols = mod.weight.shape[1]
x_torch[w_pos], x_torch[b_pos] = x_joint.split([w_cols, 1], dim=1)

# for weights we need to multiply from the right with aaT
# for weights and biases we need to multiply from the left with ggT
else:
for p_name in ["weight", "bias"]:
p = getattr(mod, p_name)
if self.in_params(p):
pos = self.param_pos(p)
x_torch[pos] = self._gradient_covariances[name] @ x_torch[pos]

if p_name == "weight":
x_torch[pos] = x_torch[pos] @ self._input_covariances[name]

return super()._postprocess(x_torch)

Expand Down Expand Up @@ -228,7 +244,7 @@ def _compute_kfac(self):
module = self._model_func.get_submodule(name)

# input covariance only required for weights
if module.weight.data_ptr() in self.param_ids:
if self.in_params(module.weight):
hook_handles.append(
module.register_forward_pre_hook(
self._hook_accumulate_input_covariance
Expand Down Expand Up @@ -402,6 +418,12 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor
f"Only 2d inputs are supported for linear layers. Got {x.ndim}d."
)

if (
self.in_params(module.weight, module.bias)
and not self._separate_weight_and_bias
):
x = cat([x, x.new_ones(x.shape[0], 1)], dim=1)

covariance = einsum("bi,bj->ij", x, x).div_(self._N_data)
else:
# TODO Support convolutions
Expand All @@ -424,3 +446,25 @@ def get_module_name(self, module: Module) -> str:
"""
p_ids = tuple(p.data_ptr() for p in module.parameters())
return self.param_ids_to_hooked_modules[p_ids]

def in_params(self, *params: Union[Parameter, Tensor, None]) -> bool:
"""Check if all parameters are used in KFAC.
Args:
params: Parameters to check.
Returns:
Whether all parameters are used in KFAC.
"""
return all(p is not None and p.data_ptr() in self.param_ids for p in params)

def param_pos(self, param: Union[Parameter, Tensor]) -> int:
"""Get the position of a parameter in the list of parameters used in KFAC.
Args:
param: The parameter.
Returns:
The parameter's position in the parameter list.
"""
return self.param_ids.index(param.data_ptr())
38 changes: 24 additions & 14 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Contains tests for ``curvlinops.kfac``."""

from test.utils import ggn_block_diagonal
from typing import Iterable, List, Tuple

from numpy import eye
Expand All @@ -9,11 +10,13 @@
from torch.nn import Module, MSELoss, Parameter

from curvlinops.examples.utils import report_nonclose
from curvlinops.ggn import GGNLinearOperator
from curvlinops.gradient_moments import EFLinearOperator
from curvlinops.kfac import KFACLinearOperator


@mark.parametrize(
"separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"]
)
@mark.parametrize(
"exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"]
)
Expand All @@ -24,6 +27,7 @@ def test_kfac(
],
shuffle: bool,
exclude: str,
separate_weight_and_bias: bool,
):
"""Test the KFAC implementation against the exact GGN.
Expand All @@ -33,6 +37,8 @@ def test_kfac(
shuffle: Whether to shuffle the parameters before computing the KFAC matrix.
exclude: Which parameters to exclude. Can be ``'weight'``, ``'bias'``,
or ``None``.
separate_weight_and_bias: Whether to treat weight and bias as separate blocks in
the KFAC matrix.
"""
assert exclude in [None, "weight", "bias"]
model, loss_func, params, data = kfac_expand_exact_case
Expand All @@ -45,13 +51,22 @@ def test_kfac(
permutation = randperm(len(params))
params = [params[i] for i in permutation]

ggn_blocks = [] # list of per-parameter GGNs
for param in params:
ggn = GGNLinearOperator(model, loss_func, [param], data)
ggn_blocks.append(ggn @ eye(ggn.shape[1]))
ggn = block_diag(*ggn_blocks)

kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000)
ggn = ggn_block_diagonal(
model,
loss_func,
params,
data,
separate_weight_and_bias=separate_weight_and_bias,
)

kfac = KFACLinearOperator(
model,
loss_func,
params,
data,
mc_samples=2_000,
separate_weight_and_bias=separate_weight_and_bias,
)
kfac_mat = kfac @ eye(kfac.shape[1])

atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction]
Expand All @@ -70,12 +85,7 @@ def test_kfac_one_datum(
]
):
model, loss_func, params, data = kfac_expand_exact_one_datum_case

ggn_blocks = [] # list of per-parameter GGNs
for param in params:
ggn = GGNLinearOperator(model, loss_func, [param], data)
ggn_blocks.append(ggn @ eye(ggn.shape[1]))
ggn = block_diag(*ggn_blocks)
ggn = ggn_block_diagonal(model, loss_func, params, data)

kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=10_000)
kfac_mat = kfac @ eye(kfac.shape[1])
Expand Down
68 changes: 67 additions & 1 deletion test/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
"""Utility functions to test `curvlinops`."""

from torch import cuda, device, rand, randint
from itertools import product
from typing import Iterable, List, Tuple

from numpy import eye, ndarray
from torch import Tensor, cat, cuda, device, from_numpy, rand, randint
from torch.nn import Module, Parameter

from curvlinops import GGNLinearOperator


def get_available_devices():
Expand All @@ -25,3 +32,62 @@ def classification_targets(size, num_classes):
def regression_targets(size):
"""Create random targets for regression."""
return rand(*size)


def ggn_block_diagonal(
model: Module,
loss_func: Module,
params: List[Parameter],
data: Iterable[Tuple[Tensor, Tensor]],
separate_weight_and_bias: bool = True,
) -> ndarray:
"""Compute the block-diagonal GGN.
Args:
model: The neural network.
loss_func: The loss function.
params: The parameters w.r.t. which the GGN block-diagonals will be computed.
data: A data loader.
separate_weight_and_bias: Whether to treat weight and bias of a layer as
separate blocks in the block-diagonal GGN. Default: ``True``.
Returns:
The block-diagonal GGN.
"""
# compute the full GGN then zero out the off-diagonal blocks
ggn = GGNLinearOperator(model, loss_func, params, data)
ggn = from_numpy(ggn @ eye(ggn.shape[1]))
sizes = [p.numel() for p in params]
# ggn_blocks[i, j] corresponds to the block of (params[i], params[j])
ggn_blocks = [list(block.split(sizes, dim=1)) for block in ggn.split(sizes, dim=0)]

# find out which blocks to keep
num_params = len(params)
keep = [(i, i) for i in range(num_params)]
param_ids = [p.data_ptr() for p in params]

# keep blocks corresponding to jointly-treated weights and biases
if not separate_weight_and_bias:
# find all layers with weight and bias
has_weight_and_bias = [
mod
for mod in model.modules()
if hasattr(mod, "weight") and hasattr(mod, "bias") and mod.bias is not None
]
# only keep those whose parameters are included
has_weight_and_bias = [
mod
for mod in has_weight_and_bias
if mod.weight.data_ptr() in param_ids and mod.bias.data_ptr() in param_ids
]
for mod in has_weight_and_bias:
w_pos = param_ids.index(mod.weight.data_ptr())
b_pos = param_ids.index(mod.bias.data_ptr())
keep.extend([(w_pos, b_pos), (b_pos, w_pos)])

for i, j in product(range(num_params), range(num_params)):
if (i, j) not in keep:
ggn_blocks[i][j].zero_()

# concatenate all blocks
return cat([cat(row_blocks, dim=1) for row_blocks in ggn_blocks], dim=0).numpy()

0 comments on commit e842f31

Please sign in to comment.