
Refactor num_loss_terms for readability
runame committed Jan 12, 2024
1 parent fa7eb0f commit 6e529fd
Showing 1 changed file with 7 additions and 5 deletions: curvlinops/kfac.py
@@ -152,7 +152,7 @@ def __init__(
                 or ``Linear`` modules that process matrix- or higher-dimensional
                 features.
                 Possible values are ``'expand'`` and ``'reduce'``.
-                See `Eschenhagen et al., 2023 <https://arxiv.org/abs/2311.00636>`_
+                See `Eschenhagen et al., 2023 <https://arxiv.org/abs/2311.00636>`_
                 for an explanation of the two approximations.
             loss_average: Whether the loss function is a mean over per-sample
                 losses and if yes, over which dimensions the mean is taken.
@@ -483,11 +483,13 @@ def _accumulate_gradient_covariance(self, module: Module, grad_output: Tensor):
             NotImplementedError: If the layer is not supported.
         """
         g = grad_output.data.detach()
-        num_loss_terms = g.shape[0]  # batch_size
+        batch_size = g.shape[0]
         sequence_length = g.shape[1:-1].numel()
-        if self._loss_average == "batch+sequence":
-            # Number of all loss terms = batch_size * sequence_length
-            num_loss_terms *= sequence_length
+        num_loss_terms = {
+            None: batch_size,
+            "batch": batch_size,
+            "batch+sequence": batch_size * sequence_length,
+        }[self._loss_average]

         if self._kfac_approx == "expand":
             # KFAC-expand approximation
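The refactor replaces an in-place `if`-branch with a dict lookup that enumerates all three supported `loss_average` modes. The equivalent logic can be sketched as a hypothetical standalone function (the name `num_loss_terms` and the shape-tuple argument are illustrative assumptions; the real code lives inside `_accumulate_gradient_covariance` and reads shapes from a PyTorch tensor):

```python
# Hypothetical standalone sketch of the refactored logic, assuming a
# gradient of shape (batch, *sequence_dims, features) passed as a tuple.
from math import prod
from typing import Optional, Tuple


def num_loss_terms(g_shape: Tuple[int, ...], loss_average: Optional[str]) -> int:
    """Return the number of loss terms implied by the averaging mode."""
    batch_size = g_shape[0]
    # Mirrors g.shape[1:-1].numel(): product of the sequence dimensions,
    # which is 1 when there are none (math.prod of an empty tuple is 1).
    sequence_length = prod(g_shape[1:-1])
    # Dict lookup instead of mutating num_loss_terms in a branch: every
    # supported mode is listed explicitly, and an unsupported value of
    # loss_average fails loudly with a KeyError.
    return {
        None: batch_size,
        "batch": batch_size,
        "batch+sequence": batch_size * sequence_length,
    }[loss_average]
```

One benefit of the lookup style is that the set of valid `loss_average` values is visible in a single place, whereas the previous version silently fell through to `batch_size` for any value other than `"batch+sequence"`.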
