Skip to content

Commit

Permalink
[DOC] Add usage example for activation Hessian
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Jan 12, 2024
1 parent 30cf504 commit 2fdee76
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions curvlinops/experimental/activation_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,31 @@ def __init__(
Raises:
ValueError: If ``data`` contains more than one batch.
Example:
>>> from numpy import eye, allclose
>>> from torch import rand
>>> from torch.nn import Linear, MSELoss, Sequential, ReLU
>>>
>>> loss_func = MSELoss()
>>> model = Sequential(Linear(4, 3), ReLU(), Linear(3, 2))
>>> [name for name, _ in model.named_modules()] # available layer names
['', '0', '1', '2']
>>> data = [(rand(10, 4), rand(10, 2))]
>>>
>>> hessian = ActivationHessianLinearOperator( # Hessian w.r.t. ReLU input
... model, loss_func, ("1", "input", 0), data
... )
>>> hessian.shape # batch size * feature dimension (10 * 3)
(30, 30)
>>>
>>> # The ReLU's input is the first Linear's output; check that both Hessians agree
>>> hessian2 = ActivationHessianLinearOperator( # Hessian w.r.t. first output
... model, loss_func, ("0", "output", 0), data
... )
>>> I = eye(hessian.shape[1])
>>> allclose(hessian @ I, hessian2 @ I)
True
"""
self._activation = activation

Expand Down

0 comments on commit 2fdee76

Please sign in to comment.