diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index b015810c..e31d1393 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -24,7 +24,7 @@ from torch.utils.hooks import RemovableHandle from curvlinops._base import _LinearOperator -from curvlinops.kfac_utils import hessian_matrix_sqrt +from curvlinops.kfac_utils import loss_hessian_matrix_sqrt class KFACLinearOperator(_LinearOperator): @@ -296,7 +296,7 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor): # Result has shape `(batch_size, num_classes, num_classes)` hessian_sqrts = stack( [ - hessian_matrix_sqrt(out.detach(), self._loss_func) + loss_hessian_matrix_sqrt(out.detach(), self._loss_func) for out in output.split(1) ] ) diff --git a/curvlinops/kfac_utils.py b/curvlinops/kfac_utils.py index b001180b..8a49d9d8 100644 --- a/curvlinops/kfac_utils.py +++ b/curvlinops/kfac_utils.py @@ -7,7 +7,7 @@ from torch.nn import CrossEntropyLoss, MSELoss -def hessian_matrix_sqrt( +def loss_hessian_matrix_sqrt( output_one_datum: Tensor, loss_func: Union[MSELoss, CrossEntropyLoss] ) -> Tensor: r"""Compute the loss function's matrix square root for a sample's output. diff --git a/docs/rtd/internals.rst b/docs/rtd/internals.rst index ef1da4e6..f6d9b4f2 100644 --- a/docs/rtd/internals.rst +++ b/docs/rtd/internals.rst @@ -8,4 +8,4 @@ details; because rendered LaTeX is easier to read than source code. KFAC-related ------------- -.. autofunction:: curvlinops.kfac_utils.hessian_matrix_sqrt +.. autofunction:: curvlinops.kfac_utils.loss_hessian_matrix_sqrt