diff --git a/curvlinops/kfac_utils.py b/curvlinops/kfac_utils.py index 8a49d9d8..5263246a 100644 --- a/curvlinops/kfac_utils.py +++ b/curvlinops/kfac_utils.py @@ -76,11 +76,11 @@ def loss_hessian_matrix_sqrt( f"{output_one_datum.shape}" ) output = output_one_datum.squeeze(0) - num_classes = output.numel() + output_dim = output.numel() if isinstance(loss_func, MSELoss): - c = {"sum": 1.0, "mean": 1.0 / num_classes}[loss_func.reduction] - return eye(num_classes, device=output.device, dtype=output.dtype).mul_( + c = {"sum": 1.0, "mean": 1.0 / output_dim}[loss_func.reduction] + return eye(output_dim, device=output.device, dtype=output.dtype).mul_( sqrt(2 * c) ) elif isinstance(loss_func, CrossEntropyLoss):