[REF] Remove for-loop over samples in MC Fisher
f-dangel committed Mar 7, 2024
1 parent 5852711 commit 1ae6bc8
Showing 1 changed file with 33 additions and 29 deletions.
curvlinops/fisher.py
@@ -5,16 +5,17 @@
 from math import sqrt
 from typing import Callable, Iterable, List, Optional, Tuple, Union
 
+from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist
 from einops import einsum
 from numpy import ndarray
 from torch import (
     Generator,
     Tensor,
     as_tensor,
-    autograd,
     multinomial,
     normal,
     softmax,
+    zeros,
     zeros_like,
 )
 from torch.nn import CrossEntropyLoss, MSELoss, Parameter
@@ -210,37 +211,40 @@ def _matmat_batch(
             ``M_list``, i.e. each tensor in the list has the shape of a parameter and a
             leading dimension of matrix columns.
         """
-        N = X.shape[0]
-        normalization = {"mean": N, "sum": 1.0}[self._loss_func.reduction]
+        # compute ∂ℓₙ(yₙₘ)/∂fₙ where fₙ is the prediction for datum n and
+        # yₙₘ is the m-th sampled label for datum n
+        output = self._model_func(X)
+        grad_output = zeros(
+            self._mc_samples, *output.shape, device=output.device, dtype=output.dtype
+        )
+        for n, output_n in enumerate(output.split(1)):
+            for m in range(self._mc_samples):
+                grad_output[m, n].add_(self.sample_grad_output(output_n).squeeze(0))
+
+        # Compute the pseudo-loss L' := 0.5 / (M * c) ∑ₙ ∑ₘ fₙᵀ (gₙₘ gₙₘᵀ) fₙ where
+        # gₙₘ = ∂ℓₙ(yₙₘ)/∂fₙ (detached) and M is the number of MC samples.
+        # The GGN of L' linearized at fₙ is the MC Fisher.
+        # We can thus multiply with it by computing the GGN-vector products of L'.
+        normalization = {"mean": 1.0 / X.shape[0], "sum": 1.0}[
+            self._loss_func.reduction
+        ]
+        loss = (
+            0.5
+            * normalization
+            / self._mc_samples
+            * (einsum(output, grad_output, "n ..., m n ... -> m n") ** 2).sum()
+        )
 
         # Multiply the MC Fisher onto each vector in the input matrix
         result_list = [zeros_like(M) for M in M_list]
-
-        for n in range(N):
-            X_n = X[n].unsqueeze(0)
-            output_n = self._model_func(X_n)
-
-            for m in range(self._mc_samples):
-                grad_output_sampled_n = self.sample_grad_output(output_n)
-
-                retain_graph = m != self._mc_samples - 1
-                grad_sampled_n = autograd.grad(
-                    output_n,
-                    self._params,
-                    grad_outputs=grad_output_sampled_n,
-                    retain_graph=retain_graph,
-                )
-                # coefficients for each matrix-vector product
-                c = (
-                    sum(
-                        einsum(g, M, "..., col ... -> col")
-                        for g, M in zip(grad_sampled_n, M_list)
-                    )
-                    / self._mc_samples
-                    / normalization
-                )
-
-                for idx, g in enumerate(grad_sampled_n):
-                    result_list[idx].add_(einsum(c, g, "col, ... -> col ..."))
+        num_vectors = M_list[0].shape[0]
+        for v in range(num_vectors):
+            for idx, ggnvp in enumerate(
+                ggn_vector_product_from_plist(
+                    loss, output, self._params, [M[v] for M in M_list]
+                )
+            ):
+                result_list[idx][v].add_(ggnvp.detach())
 
         return tuple(result_list)
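
Why the refactor is valid (a derivation for reference, not part of the commit; c denotes the loss-reduction constant encoded by the `normalization` dict, i.e. c = N for "mean" and c = 1 for "sum"): since the sampled gradients gₙₘ are detached, the pseudo-loss and its Hessian w.r.t. the predictions are

L' = \frac{1}{2Mc} \sum_{n=1}^{N} \sum_{m=1}^{M} \left( g_{nm}^\top f_n \right)^2,
\qquad
\nabla^2_{f_n} L' = \frac{1}{Mc} \sum_{m=1}^{M} g_{nm} g_{nm}^\top.

The GGN of L' linearized at the predictions, with Jacobians J_n = \partial f_n / \partial \theta, is therefore

\mathrm{GGN}(L') = \sum_{n=1}^{N} J_n^\top \left( \nabla^2_{f_n} L' \right) J_n
= \frac{1}{Mc} \sum_{n=1}^{N} \sum_{m=1}^{M} \left( J_n^\top g_{nm} \right) \left( J_n^\top g_{nm} \right)^\top,

which is exactly the MC Fisher that the removed loop accumulated one (n, m) pair at a time. One GGN-vector product per matrix column thus replaces the N · M per-sample backward passes.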

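A minimal standalone sketch of the equivalence (an illustration, not the library's code: a toy Linear model, a single vector instead of a matrix of columns, random stand-ins for sample_grad_output, and the GGN-vector product that BackPACK's ggn_vector_product_from_plist provides spelled out by hand via double backprop):

from torch import allclose, autograd, manual_seed, rand, randn, zeros_like
from torch.nn import Linear

manual_seed(0)

model = Linear(4, 3)  # toy stand-in for self._model_func
params = list(model.parameters())
N, M = 8, 5  # batch size and number of MC samples
c = N  # loss-reduction constant of a "mean" reduction
X = rand(N, 4)

output = model(X)  # predictions fₙ, shape [N, 3]
# stand-in for sample_grad_output: sampled gradients gₙₘ = ∂ℓₙ(yₙₘ)/∂fₙ (detached)
grad_output = randn(M, N, 3)
v = [randn(*p.shape) for p in params]  # vector to multiply the MC Fisher onto

# removed approach: accumulate (uᵀv) u / (M c) over samples, with u = Jₙᵀ gₙₘ
old = [zeros_like(p) for p in params]
for n in range(N):
    for m in range(M):
        u = autograd.grad(
            output[n], params, grad_outputs=grad_output[m, n], retain_graph=True
        )
        coeff = sum((u_p * v_p).sum() for u_p, v_p in zip(u, v)) / (M * c)
        for res, u_p in zip(old, u):
            res.add_(coeff * u_p)

# new approach: one GGN-vector product of the pseudo-loss L' via double backprop
loss = 0.5 / (M * c) * ((grad_output * output).sum(-1) ** 2).sum()
# 1) Jv through the double-grad trick (a Jacobian-vector product)
dummy = zeros_like(output, requires_grad=True)
vjp = autograd.grad(output, params, grad_outputs=dummy, create_graph=True)
(Jv,) = autograd.grad(vjp, dummy, grad_outputs=v, retain_graph=True)
# 2) apply the Hessian of L' w.r.t. the output to Jv
(grad_f,) = autograd.grad(loss, output, create_graph=True)
(HJv,) = autograd.grad(grad_f, output, grad_outputs=Jv, retain_graph=True)
# 3) pull back to parameter space: Jᵀ (∇²_f L') J v
new = autograd.grad(output, params, grad_outputs=HJv)

assert all(allclose(a, b, atol=1e-5) for a, b in zip(old, new))

The design win: the removed variant needs one backward pass per (datum, sample) pair, while the pseudo-loss variant builds L' once and then pays one Jacobian-vector, one Hessian-vector, and one vector-Jacobian product per column of the matrix being multiplied.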
