Add KFAC with weight-sharing exactness test
runame committed Jan 9, 2024
1 parent 054fa43 commit daf8393
Showing 1 changed file with 94 additions and 13 deletions.
107 changes: 94 additions & 13 deletions test/test_kfac.py
@@ -3,11 +3,12 @@
from test.cases import DEVICES, DEVICES_IDS
from test.utils import (
Rearrange,
WeightShareModel,
classification_targets,
ggn_block_diagonal,
regression_targets,
)
from typing import Iterable, List, Tuple, Union
from typing import Dict, Iterable, List, Tuple, Union

from numpy import eye
from pytest import mark
@@ -37,7 +38,7 @@
)
@mark.parametrize("shuffle", [False, True], ids=["", "shuffled"])
def test_kfac_type2(
kfac_expand_exact_case: Tuple[
kfac_exact_case: Tuple[
Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
],
shuffle: bool,
@@ -47,7 +48,7 @@ def test_kfac_type2(
"""Test the KFAC implementation against the exact GGN.
Args:
kfac_expand_exact_case: A fixture that returns a model, loss function, list of
kfac_exact_case: A fixture that returns a model, loss function, list of
parameters, and data.
shuffle: Whether to shuffle the parameters before computing the KFAC matrix.
exclude: Which parameters to exclude. Can be ``'weight'``, ``'bias'``,
@@ -56,7 +57,7 @@ def test_kfac_type2(
the KFAC matrix.
"""
assert exclude in [None, "weight", "bias"]
model, loss_func, params, data = kfac_expand_exact_case
model, loss_func, params, data = kfac_exact_case

if exclude is not None:
names = {p.data_ptr(): name for name, p in model.named_parameters()}
@@ -90,21 +91,101 @@ def test_kfac_type2(
assert len(kfac._input_covariances) == 0


@mark.parametrize("setting", ["expand", "reduce"])
@mark.parametrize(
"separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"]
)
@mark.parametrize(
"exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"]
)
@mark.parametrize("shuffle", [False, True], ids=["", "shuffled"])
def test_kfac_type2_weight_sharing(
kfac_weight_sharing_exact_case: Tuple[
WeightShareModel,
MSELoss,
List[Parameter],
Dict[str, Iterable[Tuple[Tensor, Tensor]]],
],
setting: str,
shuffle: bool,
exclude: str,
separate_weight_and_bias: bool,
):
"""Test KFAC for linear weight-sharing layers against the exact GGN.
Args:
kfac_weight_sharing_exact_case: A fixture that returns a model, loss function, list of
parameters, and data.
setting: The weight-sharing setting to use. Can be ``'expand'`` or ``'reduce'``.
shuffle: Whether to shuffle the parameters before computing the KFAC matrix.
exclude: Which parameters to exclude. Can be ``'weight'``, ``'bias'``,
or ``None``.
separate_weight_and_bias: Whether to treat weight and bias as separate blocks in
the KFAC matrix.
"""
assert exclude in [None, "weight", "bias"]
model, loss_func, params, data = kfac_weight_sharing_exact_case
model.setting = setting
data = data[setting]

# set appropriate loss_average argument based on loss reduction and setting
if loss_func.reduction == "mean":
if setting == "expand":
loss_average = "batch+sequence"
else:
loss_average = "batch"
else:
loss_average = None

if exclude is not None:
names = {p.data_ptr(): name for name, p in model.named_parameters()}
params = [p for p in params if exclude not in names[p.data_ptr()]]

if shuffle:
permutation = randperm(len(params))
params = [params[i] for i in permutation]

ggn = ggn_block_diagonal(
model,
loss_func,
params,
data,
separate_weight_and_bias=separate_weight_and_bias,
)
kfac = KFACLinearOperator(
model,
loss_func,
params,
data,
fisher_type="type-2",
kfac_approx=setting, # choose KFAC approximation consistent with setting
loss_average=loss_average,
separate_weight_and_bias=separate_weight_and_bias,
)
kfac_mat = kfac @ eye(kfac.shape[1])

report_nonclose(ggn, kfac_mat)

# Check that input covariances were not computed
if exclude == "weight":
assert len(kfac._input_covariances) == 0


@mark.parametrize("shuffle", [False, True], ids=["", "shuffled"])
def test_kfac_mc(
kfac_expand_exact_case: Tuple[
kfac_exact_case: Tuple[
Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
],
shuffle: bool,
):
"""Test the KFAC implementation using MC samples against the exact GGN.
Args:
kfac_expand_exact_case: A fixture that returns a model, loss function, list of
kfac_exact_case: A fixture that returns a model, loss function, list of
parameters, and data.
shuffle: Whether to shuffle the parameters before computing the KFAC matrix.
"""
model, loss_func, params, data = kfac_expand_exact_case
model, loss_func, params, data = kfac_exact_case

if shuffle:
permutation = randperm(len(params))
@@ -122,11 +203,11 @@ def test_kfac_mc(


def test_kfac_one_datum(
kfac_expand_exact_one_datum_case: Tuple[
kfac_exact_one_datum_case: Tuple[
Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
]
):
model, loss_func, params, data = kfac_expand_exact_one_datum_case
model, loss_func, params, data = kfac_exact_one_datum_case

ggn = ggn_block_diagonal(model, loss_func, params, data)
kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="type-2")
@@ -136,11 +217,11 @@ def test_kfac_one_datum(


def test_kfac_mc_one_datum(
kfac_expand_exact_one_datum_case: Tuple[
kfac_exact_one_datum_case: Tuple[
Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
]
):
model, loss_func, params, data = kfac_expand_exact_one_datum_case
model, loss_func, params, data = kfac_exact_one_datum_case
ggn = ggn_block_diagonal(model, loss_func, params, data)

kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=10_000)
@@ -153,11 +234,11 @@ def test_kfac_mc_one_datum(


def test_kfac_ef_one_datum(
kfac_expand_exact_one_datum_case: Tuple[
kfac_exact_one_datum_case: Tuple[
Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]]
]
):
model, loss_func, params, data = kfac_expand_exact_one_datum_case
model, loss_func, params, data = kfac_exact_one_datum_case

ef_blocks = [] # list of per-parameter EFs
for param in params:
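
Below is a minimal, hypothetical usage sketch (not part of the commit) showing how the KFACLinearOperator arguments exercised by test_kfac_type2_weight_sharing fit together for the weight-sharing "expand" setting: a linear layer applied to inputs with a sequence axis, fisher_type="type-2", kfac_approx="expand", and loss_average="batch+sequence" for a mean-reduced loss, with the operator materialized via kfac @ eye(kfac.shape[1]) as in the tests. The toy model, shapes, and data are invented, and the curvlinops.kfac import path is assumed; only the constructor arguments and the materialization pattern come from the diff above.

# Hypothetical sketch, not part of the commit: the model, shapes, and data
# below are invented for illustration.
from numpy import eye
from torch import manual_seed, rand
from torch.nn import Linear, MSELoss, ReLU, Sequential

from curvlinops.kfac import KFACLinearOperator  # import path assumed

manual_seed(0)
batch_size, seq_len, d_in, d_out = 4, 7, 5, 3

# A Linear layer applied to (batch, sequence, feature) inputs shares its
# weight across the sequence axis -- the "expand" setting in the test above.
model = Sequential(Linear(d_in, 8), ReLU(), Linear(8, d_out))
loss_func = MSELoss(reduction="mean")
params = [p for p in model.parameters() if p.requires_grad]

# Two (input, target) batches that carry the shared/sequence dimension.
data = [
    (rand(batch_size, seq_len, d_in), rand(batch_size, seq_len, d_out))
    for _ in range(2)
]

kfac = KFACLinearOperator(
    model,
    loss_func,
    params,
    data,
    fisher_type="type-2",  # exact backpropagated Hessian square root, as in the tests
    kfac_approx="expand",  # KFAC-expand approximation for weight sharing
    loss_average="batch+sequence",  # mean reduction over batch and sequence
    separate_weight_and_bias=True,
)

# Materialize the block-diagonal KFAC matrix the same way the tests do.
kfac_mat = kfac @ eye(kfac.shape[1])
print(kfac_mat.shape)  # (num_params, num_params)

For the "reduce" setting, the sketch would instead pool the model output over the sequence axis before the loss and pass kfac_approx="reduce" with loss_average="batch", matching the corresponding branch of the loss_average logic in the new test.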