Skip to content

Commit

Permalink
Switch from warning to ValueError when inconsistent loss_average is used
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Jan 12, 2024
1 parent 27f9c72 commit 6553bcf
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def __init__(
ValueError: If the loss average is not supported.
ValueError: If the loss average is ``None`` and the loss function's
reduction is not ``'sum'``.
ValueError: If the loss average is not ``None`` and the loss function's
reduction is ``'sum'``.
ValueError: If ``fisher_type != 'mc'`` and ``mc_samples != 1``.
NotImplementedError: If a parameter is in an unsupported layer.
"""
Expand All @@ -189,13 +191,10 @@ def __init__(
f"Must be 'batch' or 'batch+sequence' if loss_func.reduction != 'sum'."
)
if loss_func.reduction == "sum" and loss_average is not None:
# reduction used in loss function will overwrite loss_average
warnings.warn(
raise ValueError(
f"Loss function uses reduction='sum', but loss_average={loss_average}."
" loss_average is set to None.",
stacklevel=2,
" Set loss_average to None if you want to use reduction='sum'."
)
loss_average = None
if fisher_type != "mc" and mc_samples != 1:
raise ValueError(
f"Invalid mc_samples: {mc_samples}. "
Expand Down

0 comments on commit 6553bcf

Please sign in to comment.