From daf839374771d68e41e12e495c135877ccf1969d Mon Sep 17 00:00:00 2001 From: runame Date: Tue, 9 Jan 2024 22:39:22 +0100 Subject: [PATCH] Add KFAC with weight-sharing exactness test --- test/test_kfac.py | 107 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 94 insertions(+), 13 deletions(-) diff --git a/test/test_kfac.py b/test/test_kfac.py index 49b92386..c86e43d0 100644 --- a/test/test_kfac.py +++ b/test/test_kfac.py @@ -3,11 +3,12 @@ from test.cases import DEVICES, DEVICES_IDS from test.utils import ( Rearrange, + WeightShareModel, classification_targets, ggn_block_diagonal, regression_targets, ) -from typing import Iterable, List, Tuple, Union +from typing import Dict, Iterable, List, Tuple, Union from numpy import eye from pytest import mark @@ -37,7 +38,7 @@ ) @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) def test_kfac_type2( - kfac_expand_exact_case: Tuple[ + kfac_exact_case: Tuple[ Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] ], shuffle: bool, @@ -47,7 +48,7 @@ def test_kfac_type2( """Test the KFAC implementation against the exact GGN. Args: - kfac_expand_exact_case: A fixture that returns a model, loss function, list of + kfac_exact_case: A fixture that returns a model, loss function, list of parameters, and data. shuffle: Whether to shuffle the parameters before computing the KFAC matrix. exclude: Which parameters to exclude. Can be ``'weight'``, ``'bias'``, @@ -56,7 +57,7 @@ def test_kfac_type2( the KFAC matrix. 
""" assert exclude in [None, "weight", "bias"] - model, loss_func, params, data = kfac_expand_exact_case + model, loss_func, params, data = kfac_exact_case if exclude is not None: names = {p.data_ptr(): name for name, p in model.named_parameters()} @@ -90,9 +91,89 @@ def test_kfac_type2( assert len(kfac._input_covariances) == 0 +@mark.parametrize("setting", ["expand", "reduce"]) +@mark.parametrize( + "separate_weight_and_bias", [True, False], ids=["separate_bias", "joint_bias"] +) +@mark.parametrize( + "exclude", [None, "weight", "bias"], ids=["all", "no_weights", "no_biases"] +) +@mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) +def test_kfac_type2_weight_sharing( + kfac_weight_sharing_exact_case: Tuple[ + WeightShareModel, + MSELoss, + List[Parameter], + Dict[str, Iterable[Tuple[Tensor, Tensor]]], + ], + setting: str, + shuffle: bool, + exclude: str, + separate_weight_and_bias: bool, +): + """Test KFAC for linear weight-sharing layers against the exact GGN. + + Args: + kfac_weight_sharing_exact_case: A fixture that returns a model, loss function, list of + parameters, and a dictionary mapping each setting to its data. + setting: The weight-sharing setting to use. Can be ``'expand'`` or ``'reduce'``. + shuffle: Whether to shuffle the parameters before computing the KFAC matrix. + exclude: Which parameters to exclude. Can be ``'weight'``, ``'bias'``, + or ``None``. + separate_weight_and_bias: Whether to treat weight and bias as separate blocks in + the KFAC matrix. 
+ """ + assert exclude in [None, "weight", "bias"] + model, loss_func, params, data = kfac_weight_sharing_exact_case + model.setting = setting + data = data[setting] + + # set appropriate loss_average argument based on loss reduction and setting + if loss_func.reduction == "mean": + if setting == "expand": + loss_average = "batch+sequence" + else: + loss_average = "batch" + else: + loss_average = None + + if exclude is not None: + names = {p.data_ptr(): name for name, p in model.named_parameters()} + params = [p for p in params if exclude not in names[p.data_ptr()]] + + if shuffle: + permutation = randperm(len(params)) + params = [params[i] for i in permutation] + + ggn = ggn_block_diagonal( + model, + loss_func, + params, + data, + separate_weight_and_bias=separate_weight_and_bias, + ) + kfac = KFACLinearOperator( + model, + loss_func, + params, + data, + fisher_type="type-2", + kfac_approx=setting, # choose KFAC approximation consistent with setting + loss_average=loss_average, + separate_weight_and_bias=separate_weight_and_bias, + ) + kfac_mat = kfac @ eye(kfac.shape[1]) + + report_nonclose(ggn, kfac_mat) + + # Check that input covariances were not computed + if exclude == "weight": + assert len(kfac._input_covariances) == 0 + + @mark.parametrize("shuffle", [False, True], ids=["", "shuffled"]) def test_kfac_mc( - kfac_expand_exact_case: Tuple[ + kfac_exact_case: Tuple[ Module, MSELoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] ], shuffle: bool, @@ -100,11 +181,11 @@ def test_kfac_mc( """Test the KFAC implementation using MC samples against the exact GGN. Args: - kfac_expand_exact_case: A fixture that returns a model, loss function, list of + kfac_exact_case: A fixture that returns a model, loss function, list of parameters, and data. shuffle: Whether to shuffle the parameters before computing the KFAC matrix. 
""" - model, loss_func, params, data = kfac_expand_exact_case + model, loss_func, params, data = kfac_exact_case if shuffle: permutation = randperm(len(params)) @@ -122,11 +203,11 @@ def test_kfac_mc( def test_kfac_one_datum( - kfac_expand_exact_one_datum_case: Tuple[ + kfac_exact_one_datum_case: Tuple[ Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] ] ): - model, loss_func, params, data = kfac_expand_exact_one_datum_case + model, loss_func, params, data = kfac_exact_one_datum_case ggn = ggn_block_diagonal(model, loss_func, params, data) kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type="type-2") @@ -136,11 +217,11 @@ def test_kfac_one_datum( def test_kfac_mc_one_datum( - kfac_expand_exact_one_datum_case: Tuple[ + kfac_exact_one_datum_case: Tuple[ Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] ] ): - model, loss_func, params, data = kfac_expand_exact_one_datum_case + model, loss_func, params, data = kfac_exact_one_datum_case ggn = ggn_block_diagonal(model, loss_func, params, data) kfac = KFACLinearOperator(model, loss_func, params, data, mc_samples=10_000) @@ -153,11 +234,11 @@ def test_kfac_mc_one_datum( def test_kfac_ef_one_datum( - kfac_expand_exact_one_datum_case: Tuple[ + kfac_exact_one_datum_case: Tuple[ Module, CrossEntropyLoss, List[Parameter], Iterable[Tuple[Tensor, Tensor]] ] ): - model, loss_func, params, data = kfac_expand_exact_one_datum_case + model, loss_func, params, data = kfac_exact_one_datum_case ef_blocks = [] # list of per-parameter EFs for param in params: