Skip to content

Commit

Permalink
Add test for KFAC with >2d model outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Jan 9, 2024
1 parent c362dba commit 5ae9553
Showing 1 changed file with 76 additions and 2 deletions.
78 changes: 76 additions & 2 deletions test/test_kfac.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
"""Contains tests for ``curvlinops.kfac``."""

from test.cases import DEVICES, DEVICES_IDS
from test.utils import ggn_block_diagonal, regression_targets
from typing import Iterable, List, Tuple
from test.utils import (
Rearrange,
classification_targets,
ggn_block_diagonal,
regression_targets,
)
from typing import Iterable, List, Tuple, Union

from numpy import eye
from pytest import mark
from scipy.linalg import block_diag
from torch import Tensor, device, manual_seed, rand, randperm
from torch.nn import (
CrossEntropyLoss,
Flatten,
Linear,
Module,
MSELoss,
Expand Down Expand Up @@ -200,3 +206,71 @@ def test_kfac_inplace_activations(dev: device):
ggn_no_inplace = ggn_block_diagonal(model, loss_func, params, data)

report_nonclose(ggn, ggn_no_inplace)


@mark.parametrize("fisher_type", ["type-2", "mc", "empirical"])
@mark.parametrize("loss", [MSELoss, CrossEntropyLoss], ids=["mse", "ce"])
@mark.parametrize("reduction", ["mean", "sum"])
@mark.parametrize("dev", DEVICES, ids=DEVICES_IDS)
def test_multi_dim_output(
    fisher_type: str,
    loss: Union[MSELoss, CrossEntropyLoss],
    reduction: str,
    dev: device,
):
    """Check that KFAC on 3d/4d model outputs matches KFAC on flattened outputs.

    Two batches (one with 3d inputs, one with 4d inputs) are fed to a two-layer
    linear network. The KFAC matrix of that network is compared against the
    KFAC matrix of a network built under the same seed whose output (and the
    corresponding targets) are flattened into one leading dimension.

    Args:
        fisher_type: The type of Fisher matrix to use.
        loss: The loss function class; it is instantiated with ``reduction``.
        reduction: The reduction to use for the loss function.
        dev: The device to run the test on.
    """
    manual_seed(0)
    loss_func = loss(reduction=reduction).to(dev)
    is_regression = isinstance(loss_func, MSELoss)

    # One 3d and one 4d input batch; the trailing axis holds the 5 features.
    num_outputs = 3
    input_shapes = [(2, 4, 5), (4, 7, 5, 5)]
    if is_regression:
        # Targets share the leading axes of the input and end in the output dim.
        data = [
            (rand(*shape), regression_targets((*shape[:-1], num_outputs)))
            for shape in input_shapes
        ]
    else:
        # One class index per position of the leading axes.
        data = [
            (rand(*shape), classification_targets(shape[:-1], num_outputs))
            for shape in input_shapes
        ]

    # Deep linear network producing 3d/4d outputs.
    manual_seed(711)
    layers = [Linear(5, 4), Linear(4, 3)]
    if not is_regression:
        # The rearrange is needed to get the output shape expected by the
        # cross-entropy loss (class axis directly after the batch axis).
        layers.append(Rearrange("batch ... c -> batch c ..."))
    model = Sequential(*layers).to(dev)

    params = list(model.parameters())
    kfac = KFACLinearOperator(model, loss_func, params, data, fisher_type=fisher_type)
    kfac_mat = kfac @ eye(kfac.shape[1])

    # Reference: same layers under the same seed, but with an equivalent 2d
    # output obtained by flattening all axes except the last one.
    manual_seed(711)
    model_flat = Sequential(
        Linear(5, 4),
        Linear(4, 3),
        Flatten(start_dim=0, end_dim=-2),
    ).to(dev)
    params_flat = list(model_flat.parameters())

    # Regression targets keep their trailing output axis; classification
    # targets are flattened completely.
    target_end_dim = -2 if is_regression else -1
    data_flat = [
        (x, y.flatten(start_dim=0, end_dim=target_end_dim)) for x, y in data
    ]
    kfac_flat = KFACLinearOperator(
        model_flat, loss_func, params_flat, data_flat, fisher_type=fisher_type
    )
    kfac_flat_mat = kfac_flat @ eye(kfac_flat.shape[1])

    report_nonclose(kfac_mat, kfac_flat_mat)

0 comments on commit 5ae9553

Please sign in to comment.