diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index e31d1393..19ad6e7a 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -284,7 +284,7 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor): Raises: ValueError: If ``fisher_type`` is not ``'type-2'``, ``'mc'``, or ``'empirical'``. - NotImplementedError: If ``fisher_type`` is ``'type-1'`` and the + NotImplementedError: If ``fisher_type`` is ``'type-2'`` and the output is not 2d. """ if self._fisher_type == "type-2":