
[REF] Parallelize grad_output computation in empirical Fisher
f-dangel committed Mar 12, 2024
1 parent 325c86c commit 08f6f2b
Showing 1 changed file with 8 additions and 12 deletions.
curvlinops/gradient_moments.py
@@ -6,7 +6,8 @@
 
 from backpack.hessianfree.ggnvp import ggn_vector_product_from_plist
 from einops import einsum
-from torch import Tensor, autograd, cat, zeros_like
+from torch import Tensor, zeros_like
+from torch.autograd import grad
 
 from curvlinops._base import _LinearOperator

@@ -60,24 +61,19 @@ def _matmat_batch(
         ``M_list``, i.e. each tensor in the list has the shape of a parameter and a
         leading dimension of matrix columns.
         """
-        # compute ∂ℓₙ/∂fₙ without reduction factor of L
         output = self._model_func(X)
-        grad_output = []
-        for output_n, y_n in zip(output.split(1), y.split(1)):
-            loss_n = self._loss_func(output_n, y_n)
-            (grad_output_n,) = autograd.grad(loss_n, output_n)
-            grad_output.append(grad_output_n.detach())
-        grad_output = cat(grad_output)
+        reduction_factor = {"mean": X.shape[0], "sum": 1.0}[self._loss_func.reduction]
+
+        # compute ∂ℓₙ/∂fₙ without reduction factor of L
+        (grad_output,) = grad(self._loss_func(output, y), output)
+        grad_output = grad_output.detach() * reduction_factor
 
         # Compute the pseudo-loss L' := 0.5 / c ∑ₙ fₙᵀ (gₙ gₙᵀ) fₙ where gₙ = ∂ℓₙ/∂fₙ
         # (detached). The GGN of L' linearized at fₙ is the empirical Fisher.
         # We can thus multiply with the EF by computing the GGN-vector products of L'.
-        normalization = {"mean": 1.0 / X.shape[0], "sum": 1.0}[
-            self._loss_func.reduction
-        ]
         loss = (
             0.5
-            * normalization
+            / reduction_factor
             * (einsum(output, grad_output, "n ..., n ... -> n") ** 2).sum()
         )
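
The refactor replaces the per-sample `autograd.grad` loop with a single differentiated call on the reduced loss, rescaled by the reduction factor of L. A minimal sketch of why the two computations agree (not from the repository; `model_func` and `loss_func` are illustrative stand-ins for `self._model_func` and `self._loss_func`):

import torch

torch.manual_seed(0)
X, y = torch.randn(4, 3), torch.randn(4, 2)
model_func = torch.nn.Linear(3, 2)              # stand-in for self._model_func
loss_func = torch.nn.MSELoss(reduction="mean")  # stand-in for self._loss_func

output = model_func(X)

# old approach: one backward pass per sample
grad_output_loop = []
for output_n, y_n in zip(output.split(1), y.split(1)):
    loss_n = loss_func(output_n, y_n)
    (grad_output_n,) = torch.autograd.grad(loss_n, output_n)
    grad_output_loop.append(grad_output_n.detach())
grad_output_loop = torch.cat(grad_output_loop)

# new approach: a single backward pass, undoing the reduction factor of L
reduction_factor = {"mean": X.shape[0], "sum": 1.0}[loss_func.reduction]
(grad_output,) = torch.autograd.grad(loss_func(output, y), output)
grad_output = grad_output.detach() * reduction_factor

print(torch.allclose(grad_output, grad_output_loop))  # True

For "mean" reduction, multiplying the full-batch gradient by the batch size undoes the averaging and recovers each ∂ℓₙ/∂fₙ in one backward pass; for "sum" reduction the factor is 1.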

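Why GGN-vector products with the pseudo-loss L' multiply by the empirical Fisher can be checked on a linear model, where the GGN coincides with the Hessian. A hypothetical self-check under that assumption, using plain double backprop in place of the repository's `ggn_vector_product_from_plist`:

import torch

torch.manual_seed(0)
N, D = 6, 3
X, y = torch.randn(N, D), torch.randn(N, 1)
w = torch.randn(D, requires_grad=True)  # linear model f(x) = wᵀx, so GGN == Hessian
loss_func = torch.nn.MSELoss(reduction="mean")

output = (X @ w).unsqueeze(1)  # (N, 1)

# reference: empirical Fisher 1/N ∑ₙ ∇ℓₙ ∇ℓₙᵀ from per-sample parameter gradients
ef = torch.zeros(D, D)
for n in range(N):
    (g_n,) = torch.autograd.grad(
        loss_func(output[n : n + 1], y[n : n + 1]), w, retain_graph=True
    )
    ef += g_n.outer(g_n) / N

# pseudo-loss L' = 0.5 / c ∑ₙ (fₙᵀ gₙ)² with gₙ = ∂ℓₙ/∂fₙ detached
reduction_factor = {"mean": N, "sum": 1.0}[loss_func.reduction]
(grad_output,) = torch.autograd.grad(loss_func(output, y), output, retain_graph=True)
grad_output = grad_output.detach() * reduction_factor
loss_prime = 0.5 / reduction_factor * ((output * grad_output).sum(1) ** 2).sum()

# Hessian-vector product with L' (== GGN-vector product for a linear model)
v = torch.randn(D)
(grad_w,) = torch.autograd.grad(loss_prime, w, create_graph=True)
(ef_v,) = torch.autograd.grad(grad_w @ v, w)

print(torch.allclose(ef @ v, ef_v, atol=1e-5))  # True

Because gₙ is detached, curvature only flows through the linearized network outputs, so differentiating L' twice yields 1/c ∑ₙ Jₙᵀ gₙ gₙᵀ Jₙ, i.e. the empirical Fisher.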

