diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 84abaf60..610f138d 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -293,11 +293,12 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor): if isinstance(self._loss_func, MSELoss): flat_logits = output.flatten(start_dim=1) out_dims = flat_logits.shape[1] + # Accounts for the reduction used in the loss function. scale = 1.0 / out_dims if reduction == "mean" else 1.0 for i in range(out_dims): # Mean or sum reduction over all loss terms. # Multiply by sqrt(scale * 2.0) since the MSELoss does - # not include the scale / 2.0 factor. + # not include the 1 / 2 factor. loss_i = sqrt(scale * 2.0) * reduction_fn(flat_logits[:, i]) loss_i.backward(retain_graph=i < out_dims - 1) elif isinstance(self._loss_func, CrossEntropyLoss):