Merge branch 'development' into kfac-type-2
Conflicts:
	curvlinops/kfac.py
	test/test_kfac.py
runame committed Nov 8, 2023
2 parents ab33832 + e842f31 commit 6f7387a
Showing 3 changed files with 149 additions and 39 deletions.
68 changes: 56 additions & 12 deletions curvlinops/kfac.py
@@ -18,7 +18,7 @@

from einops import rearrange
from numpy import ndarray
from torch import Generator, Tensor, einsum
from torch import Generator, Tensor, cat, einsum
from torch import mean as torch_mean
from torch import no_grad, randn
from torch import sum as torch_sum
@@ -86,6 +86,7 @@ def __init__(
seed: int = 2147483647,
fisher_type: str = "mc",
mc_samples: int = 1,
separate_weight_and_bias: bool = True,
):
"""Kronecker-factored approximate curvature (KFAC) proxy of the Fisher/GGN.
@@ -100,7 +101,6 @@
This is an early prototype with many limitations:
- Only linear layers are supported.
- Weights and biases are treated separately.
- No weight sharing is supported.
- Only the ``'expand'`` approximation is supported.
@@ -130,11 +130,12 @@ def __init__(
mc_samples: The number of Monte-Carlo samples to use per data point.
Has to be set to ``1`` when ``fisher_type != 'mc'``.
Defaults to ``1``.
separate_weight_and_bias: Whether to treat weights and biases separately.
Defaults to ``True``.
Raises:
ValueError: If the loss function is not supported.
NotImplementedError: If any parameter cannot be identified with a supported
layer.
NotImplementedError: If a parameter is in an unsupported layer.
"""
if not isinstance(loss_func, self._SUPPORTED_LOSSES):
raise ValueError(
@@ -165,6 +166,7 @@ def __init__(

self._seed = seed
self._generator: Union[None, Generator] = None
self._separate_weight_and_bias = separate_weight_and_bias
self._fisher_type = fisher_type
self._mc_samples = mc_samples
self._input_covariances: Dict[str, Tensor] = {}
@@ -199,16 +201,30 @@ def _matvec(self, x: ndarray) -> ndarray:
for name in self.param_ids_to_hooked_modules.values():
mod = self._model_func.get_submodule(name)

if mod.weight.data_ptr() in self.param_ids:
idx = self.param_ids.index(mod.weight.data_ptr())
# bias and weights are treated jointly
if not self._separate_weight_and_bias and self.in_params(
mod.weight, mod.bias
):
w_pos, b_pos = self.param_pos(mod.weight), self.param_pos(mod.bias)
x_joint = cat([x_torch[w_pos], x_torch[b_pos].unsqueeze(-1)], dim=1)
aaT = self._input_covariances[name]
ggT = self._gradient_covariances[name]
x_torch[idx] = ggT @ x_torch[idx] @ aaT
x_joint = ggT @ x_joint @ aaT

if mod.bias is not None and mod.bias.data_ptr() in self.param_ids:
idx = self.param_ids.index(mod.bias.data_ptr())
ggT = self._gradient_covariances[name]
x_torch[idx] = ggT @ x_torch[idx]
w_cols = mod.weight.shape[1]
x_torch[w_pos], x_torch[b_pos] = x_joint.split([w_cols, 1], dim=1)

# for weights we need to multiply from the right with aaT
# for weights and biases we need to multiply from the left with ggT
else:
for p_name in ["weight", "bias"]:
p = getattr(mod, p_name)
if self.in_params(p):
pos = self.param_pos(p)
x_torch[pos] = self._gradient_covariances[name] @ x_torch[pos]

if p_name == "weight":
x_torch[pos] = x_torch[pos] @ self._input_covariances[name]

return super()._postprocess(x_torch)

@@ -231,7 +247,7 @@ def _compute_kfac(self):
module = self._model_func.get_submodule(name)

# input covariance only required for weights
if module.weight.data_ptr() in self.param_ids:
if self.in_params(module.weight):
hook_handles.append(
module.register_forward_pre_hook(
self._hook_accumulate_input_covariance
@@ -437,6 +453,12 @@ def _hook_accumulate_input_covariance(self, module: Module, inputs: Tuple[Tensor
f"Only 2d inputs are supported for linear layers. Got {x.ndim}d."
)

if (
self.in_params(module.weight, module.bias)
and not self._separate_weight_and_bias
):
x = cat([x, x.new_ones(x.shape[0], 1)], dim=1)

covariance = einsum("bi,bj->ij", x, x).div_(self._N_data)
else:
# TODO Support convolutions
@@ -459,3 +481,25 @@ def get_module_name(self, module: Module) -> str:
"""
p_ids = tuple(p.data_ptr() for p in module.parameters())
return self.param_ids_to_hooked_modules[p_ids]

def in_params(self, *params: Union[Parameter, Tensor, None]) -> bool:
"""Check if all parameters are used in KFAC.
Args:
params: Parameters to check.
Returns:
Whether all parameters are used in KFAC.
"""
return all(p is not None and p.data_ptr() in self.param_ids for p in params)

def param_pos(self, param: Union[Parameter, Tensor]) -> int:
"""Get the position of a parameter in the list of parameters used in KFAC.
Args:
param: The parameter.
Returns:
The parameter's position in the parameter list.
"""
return self.param_ids.index(param.data_ptr())
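
Note (not part of the commit): the `_matvec` above never materializes a Kronecker product. Multiplying by `ggT` from the left and `aaT` from the right is the matrix form of applying `kron(ggT, aaT)` to the row-major flattened weight, and the ones column appended in `_hook_accumulate_input_covariance` extends the same identity to a joint `[W | b]` block when `separate_weight_and_bias=False`. Below is a minimal, self-contained sketch of that logic; the toy factors, shapes, and variable names are assumptions for illustration.

# Minimal sketch (not from the commit): how a KFAC block acts on a weight,
# and how the appended ones column covers the bias jointly.
from torch import allclose, cat, kron, manual_seed, rand

manual_seed(0)
N, D_in, D_out = 16, 3, 2

a = rand(N, D_in)              # hypothetical layer inputs
g = rand(N, D_out)             # hypothetical gradients w.r.t. the layer output
aaT = a.T @ a / N              # input covariance (what the forward pre-hook accumulates)
ggT = g.T @ g / N              # output-gradient covariance

# 1) separate treatment: ggT @ W @ aaT equals kron(ggT, aaT) applied to the flattened W
W = rand(D_out, D_in)
via_matmul = ggT @ W @ aaT
via_kron = (kron(ggT, aaT) @ W.flatten()).reshape(D_out, D_in)
assert allclose(via_matmul, via_kron, atol=1e-6)

# 2) joint treatment (separate_weight_and_bias=False): append a ones column to
#    the inputs so the same identity covers the joint [W | b] block
b = rand(D_out)
a_joint = cat([a, a.new_ones(N, 1)], dim=1)
aaT_joint = a_joint.T @ a_joint / N            # (D_in + 1) x (D_in + 1)
Wb = cat([W, b.unsqueeze(-1)], dim=1)          # D_out x (D_in + 1)
result = ggT @ Wb @ aaT_joint
W_part, b_part = result.split([D_in, 1], dim=1)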
52 changes: 26 additions & 26 deletions test/test_kfac.py
@@ -1,5 +1,6 @@
"""Contains tests for ``curvlinops.kfac``."""

from test.utils import ggn_block_diagonal
from typing import Iterable, List, Tuple

from numpy import eye
@@ -9,11 +10,13 @@
from torch.nn import CrossEntropyLoss, Module, MSELoss, Parameter

from curvlinops.examples.utils import report_nonclose
from curvlinops.ggn import GGNLinearOperator
from curvlinops.gradient_moments import EFLinearOperator
from curvlinops.kfac import KFACLinearOperator


@mark.parametrize(
"separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"]
)
@mark.parametrize(
"exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"]
)
@@ -24,6 +27,7 @@ def test_kfac(
],
shuffle: bool,
exclude: str,
separate_weight_and_bias: bool,
):
"""Test the KFAC implementation against the exact GGN.
@@ -33,6 +37,8 @@
shuffle: Whether to shuffle the parameters before computing the KFAC matrix.
exclude: Which parameters to exclude. Can be ``'weight'``, ``'bias'``,
or ``None``.
separate_weight_and_bias: Whether to treat weight and bias as separate blocks in
the KFAC matrix.
"""
assert exclude in [None, "weight", "bias"]
model, loss_func, params, data = kfac_expand_exact_case
@@ -45,13 +51,21 @@
permutation = randperm(len(params))
params = [params[i] for i in permutation]

ggn_blocks = [] # list of per-parameter GGNs
for param in params:
ggn = GGNLinearOperator(model, loss_func, [param], data)
ggn_blocks.append(ggn @ eye(ggn.shape[1]))
ggn = block_diag(*ggn_blocks)

kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="type-2")
ggn = ggn_block_diagonal(
model,
loss_func,
params,
data,
separate_weight_and_bias=separate_weight_and_bias,
)
kfac = KFACLinearOperator(
model,
loss_func,
params,
data,
fisher_type="type-2",
separate_weight_and_bias=separate_weight_and_bias,
)
kfac_mat = kfac @ eye(kfac.shape[1])

report_nonclose(ggn, kfac_mat)
@@ -81,13 +95,9 @@ def test_kfac_mc(
permutation = randperm(len(params))
params = [params[i] for i in permutation]

ggn_blocks = [] # list of per-parameter GGNs
for param in params:
ggn = GGNLinearOperator(model, loss_func, [param], data)
ggn_blocks.append(ggn @ eye(ggn.shape[1]))
ggn = block_diag(*ggn_blocks)

ggn = ggn_block_diagonal(model, loss_func, params, data)
kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=2_000)

kfac_mat = kfac @ eye(kfac.shape[1])

atol = {"sum": 5e-1, "mean": 5e-3}[loss_func.reduction]
@@ -103,12 +113,7 @@ def test_kfac_one_datum(
):
model, loss_func, params, data = kfac_expand_exact_one_datum_case

ggn_blocks = [] # list of per-parameter GGNs
for param in params:
ggn = GGNLinearOperator(model, loss_func, [param], data)
ggn_blocks.append(ggn @ eye(ggn.shape[1]))
ggn = block_diag(*ggn_blocks)

ggn = ggn_block_diagonal(model, loss_func, params, data)
kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="type-2")
kfac_mat = kfac @ eye(kfac.shape[1])

@@ -121,12 +126,7 @@ def test_kfac_mc_one_datum(
]
):
model, loss_func, params, data = kfac_expand_exact_one_datum_case

ggn_blocks = [] # list of per-parameter GGNs
for param in params:
ggn = GGNLinearOperator(model, loss_func, [param], data)
ggn_blocks.append(ggn @ eye(ggn.shape[1]))
ggn = block_diag(*ggn_blocks)
ggn = ggn_block_diagonal(model, loss_func, params, data)

kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=10_000)
kfac_mat = kfac @ eye(kfac.shape[1])
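
As a standalone reference for the new constructor argument exercised by the tests above, here is a minimal usage sketch. The toy model, data, and sizes are assumptions for illustration and are not part of the commit.

from numpy import eye
from torch import manual_seed, rand
from torch.nn import Linear, MSELoss, Sequential

from curvlinops.kfac import KFACLinearOperator

manual_seed(0)
model = Sequential(Linear(4, 3), Linear(3, 2))   # only linear layers are supported
loss_func = MSELoss()
params = [p for p in model.parameters() if p.requires_grad]
data = [(rand(8, 4), rand(8, 2))]                # a single toy batch

# treat each layer's weight and bias as one joint Kronecker-factored block
kfac = KFACLinearOperator(
    model,
    loss_func,
    params,
    data,
    fisher_type="type-2",
    separate_weight_and_bias=False,
)
kfac_mat = kfac @ eye(kfac.shape[1])             # densify for inspection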
68 changes: 67 additions & 1 deletion test/utils.py
@@ -1,6 +1,13 @@
"""Utility functions to test `curvlinops`."""

from torch import cuda, device, rand, randint
from itertools import product
from typing import Iterable, List, Tuple

from numpy import eye, ndarray
from torch import Tensor, cat, cuda, device, from_numpy, rand, randint
from torch.nn import Module, Parameter

from curvlinops import GGNLinearOperator


def get_available_devices():
@@ -25,3 +32,62 @@ def classification_targets(size, num_classes):
def regression_targets(size):
"""Create random targets for regression."""
return rand(*size)


def ggn_block_diagonal(
model: Module,
loss_func: Module,
params: List[Parameter],
data: Iterable[Tuple[Tensor, Tensor]],
separate_weight_and_bias: bool = True,
) -> ndarray:
"""Compute the block-diagonal GGN.
Args:
model: The neural network.
loss_func: The loss function.
params: The parameters w.r.t. which the GGN block-diagonals will be computed.
data: A data loader.
separate_weight_and_bias: Whether to treat weight and bias of a layer as
separate blocks in the block-diagonal GGN. Default: ``True``.
Returns:
The block-diagonal GGN.
"""
# compute the full GGN then zero out the off-diagonal blocks
ggn = GGNLinearOperator(model, loss_func, params, data)
ggn = from_numpy(ggn @ eye(ggn.shape[1]))
sizes = [p.numel() for p in params]
# ggn_blocks[i, j] corresponds to the block of (params[i], params[j])
ggn_blocks = [list(block.split(sizes, dim=1)) for block in ggn.split(sizes, dim=0)]

# find out which blocks to keep
num_params = len(params)
keep = [(i, i) for i in range(num_params)]
param_ids = [p.data_ptr() for p in params]

# keep blocks corresponding to jointly-treated weights and biases
if not separate_weight_and_bias:
# find all layers with weight and bias
has_weight_and_bias = [
mod
for mod in model.modules()
if hasattr(mod, "weight") and hasattr(mod, "bias") and mod.bias is not None
]
# only keep those whose parameters are included
has_weight_and_bias = [
mod
for mod in has_weight_and_bias
if mod.weight.data_ptr() in param_ids and mod.bias.data_ptr() in param_ids
]
for mod in has_weight_and_bias:
w_pos = param_ids.index(mod.weight.data_ptr())
b_pos = param_ids.index(mod.bias.data_ptr())
keep.extend([(w_pos, b_pos), (b_pos, w_pos)])

for i, j in product(range(num_params), range(num_params)):
if (i, j) not in keep:
ggn_blocks[i][j].zero_()

# concatenate all blocks
return cat([cat(row_blocks, dim=1) for row_blocks in ggn_blocks], dim=0).numpy()
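
Design note: the helper computes the full GGN once and zeros out the unwanted off-diagonal blocks, rather than assembling one GGNLinearOperator per block as the tests previously did; this keeps the within-layer weight-bias cross blocks exact when they are treated jointly. A minimal usage sketch follows; the toy model, data, and sizes are assumptions for illustration, not part of the commit.

from torch import manual_seed, rand
from torch.nn import Linear, MSELoss, Sequential

from test.utils import ggn_block_diagonal

manual_seed(0)
model = Sequential(Linear(4, 3), Linear(3, 2))
loss_func = MSELoss()
params = [p for p in model.parameters() if p.requires_grad]
data = [(rand(8, 4), rand(8, 2))]

# four diagonal blocks: each layer's weight and bias treated separately
ggn_separate = ggn_block_diagonal(model, loss_func, params, data)

# two diagonal blocks: each layer's weight and bias form one joint block
ggn_joint = ggn_block_diagonal(
    model, loss_func, params, data, separate_weight_and_bias=False
)

# both are dense numpy arrays over all 23 parameter entries of the toy model
assert ggn_separate.shape == ggn_joint.shape == (23, 23)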
