diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index 8023b1cb..ceebd273 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -171,6 +171,8 @@ def __init__( ValueError: If the loss average is not supported. ValueError: If the loss average is ``None`` and the loss function's reduction is not ``'sum'``. + ValueError: If the loss average is not ``None`` and the loss function's + reduction is ``'sum'``. ValueError: If ``fisher_type != 'mc'`` and ``mc_samples != 1``. NotImplementedError: If a parameter is in an unsupported layer. """ @@ -189,13 +191,10 @@ def __init__( f"Must be 'batch' or 'batch+sequence' if loss_func.reduction != 'sum'." ) if loss_func.reduction == "sum" and loss_average is not None: - # reduction used in loss function will overwrite loss_average - warnings.warn( + raise ValueError( f"Loss function uses reduction='sum', but loss_average={loss_average}." - " loss_average is set to None.", - stacklevel=2, + " Set loss_average to None if you want to use reduction='sum'." ) - loss_average = None if fisher_type != "mc" and mc_samples != 1: raise ValueError( f"Invalid mc_samples: {mc_samples}. "