Skip to content

Commit

Permalink
[REF] Improve function name
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Nov 9, 2023
1 parent 5ebe927 commit 89c2dfc
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torch.utils.hooks import RemovableHandle

from curvlinops._base import _LinearOperator
from curvlinops.kfac_utils import hessian_matrix_sqrt
from curvlinops.kfac_utils import loss_hessian_matrix_sqrt


class KFACLinearOperator(_LinearOperator):
Expand Down Expand Up @@ -296,7 +296,7 @@ def _compute_loss_and_backward(self, output: Tensor, y: Tensor):
# Result has shape `(batch_size, num_classes, num_classes)`
hessian_sqrts = stack(
[
hessian_matrix_sqrt(out.detach(), self._loss_func)
loss_hessian_matrix_sqrt(out.detach(), self._loss_func)
for out in output.split(1)
]
)
Expand Down
2 changes: 1 addition & 1 deletion curvlinops/kfac_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.nn import CrossEntropyLoss, MSELoss


def hessian_matrix_sqrt(
def loss_hessian_matrix_sqrt(
output_one_datum: Tensor, loss_func: Union[MSELoss, CrossEntropyLoss]
) -> Tensor:
r"""Compute the loss function's matrix square root for a sample's output.
Expand Down
2 changes: 1 addition & 1 deletion docs/rtd/internals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ details; because rendered LaTeX is easier to read than source code.
KFAC-related
-------------

.. autofunction:: curvlinops.kfac_utils.hessian_matrix_sqrt
.. autofunction:: curvlinops.kfac_utils.loss_hessian_matrix_sqrt

0 comments on commit 89c2dfc

Please sign in to comment.