From 89c2dfc2b9e37482f21abdcbcef594f81ce0021f Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Thu, 9 Nov 2023 09:55:54 -0500 Subject: [PATCH] [REF] Improve function name --- curvlinops/kfac.py | 4 ++-- curvlinops/kfac_utils.py | 2 +- docs/rtd/internals.rst | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index b015810c..e31d1393 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -24,7 +24,7 @@ from torch.utils.hooks import RemovableHandle from curvlinops._base import _LinearOperator -from curvlinops.kfac_utils import hessian_matrix_sqrt +from curvlinops.kfac_utils import loss_hessian_matrix_sqrt class KFACLinearOperator(_LinearOperator): @@ -296,7 +296,7 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor): # Result has shape `(batch_size, num_classes, num_classes)` hessian_sqrts = stack( [ - hessian_matrix_sqrt(out.detach(), self._loss_func) + loss_hessian_matrix_sqrt(out.detach(), self._loss_func) for out in output.split(1) ] ) diff --git a/curvlinops/kfac_utils.py b/curvlinops/kfac_utils.py index b001180b..8a49d9d8 100644 --- a/curvlinops/kfac_utils.py +++ b/curvlinops/kfac_utils.py @@ -7,7 +7,7 @@ from torch.nn import CrossEntropyLoss, MSELoss -def hessian_matrix_sqrt( +def loss_hessian_matrix_sqrt( output_one_datum: Tensor, loss_func: Union[MSELoss, CrossEntropyLoss] ) -> Tensor: r"""Compute the loss function's matrix square root for a sample's output. diff --git a/docs/rtd/internals.rst b/docs/rtd/internals.rst index ef1da4e6..f6d9b4f2 100644 --- a/docs/rtd/internals.rst +++ b/docs/rtd/internals.rst @@ -8,4 +8,4 @@ details; because rendered LaTeX is easier to read than source code. KFAC-related ------------- -.. autofunction:: curvlinops.kfac_utils.hessian_matrix_sqrt +.. autofunction:: curvlinops.kfac_utils.loss_hessian_matrix_sqrt