[REF] Make trace/diagonal estimators functional (#168)
* [REF] Make trace estimators functional

* [DOC] Add missing return statement

* [DOC] Improve docstrings

* [REF] Refactor Hutchinson diagonal estimator into function

* [REF] Make Frobenius norm estimator functional, add tests

* [DOC] Update changelog

* [DOC] Formatting

* [DOC] Fix darglint

* [DOC] Really fix darglint

* [DOC] Add missing exception

* [REF] Share test code between diagonal/trace/norm

* [DEL] Unused imports

* [DOC] Add Hutchinson diagonal to RTD

* [REF] Extract checks for square matrices and number of matvecs

* [REF] Avoid cyclic import due to type annotations
f-dangel authored Jan 16, 2025
1 parent c49580a commit 53267c5
Showing 20 changed files with 439 additions and 458 deletions.
16 changes: 13 additions & 3 deletions changelog.md
@@ -8,6 +8,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added/New

- **Backward-incompatible:** Refactor class-based trace and diagonal estimators
into functions ([PR](https://github.com/f-dangel/curvlinops/pull/168)):
- If you used `HutchinsonTraceEstimator`, switch to `hutchinson_trace`
- If you used `HutchPPTraceEstimator`, switch to `hutchpp_trace`
- If you used `HutchinsonDiagonalEstimator`, switch to `hutchinson_diag`
- If you used `HutchinsonSquaredFrobeniusNormEstimator`, switch to
`hutchinson_squared_fro`

- Add diagonal estimation with the XDiag algorithm
([paper](https://arxiv.org/pdf/2301.07825),
[PR](https://github.com/f-dangel/curvlinops/pull/167))
@@ -21,7 +29,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
that benchmarks the linear operators
([PR](https://github.com/f-dangel/curvlinops/pull/162))

- Make linear operators purely PyTorch with a SciPy export option
- **Backward-incompatible:** Make linear operators purely PyTorch with a SciPy
export option
- `GGNLinearOperator` ([PR](https://github.com/f-dangel/curvlinops/pull/146))
- `TransposedJacobianLinearOperator` and `JacobianLinearOperator`
([PR](https://github.com/f-dangel/curvlinops/pull/147))
@@ -45,8 +54,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([PR](https://github.com/f-dangel/curvlinops/pull/156))
- More test cases for `KFACInverseLinearOperator` and bug fix in
`.load_state_dict` ([PR](https://github.com/f-dangel/curvlinops/pull/158))
- Remove `.to_device` function of linear operators, always carry out deterministic
checks on the linear operator's device (previously always on CPU)
- **Backward-incompatible:** Remove `.to_device` function of linear operators,
always carry out deterministic checks on the linear operator's device
(previously always on CPU)
([PR](https://github.com/f-dangel/curvlinops/pull/160))
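
The migration from the class-based estimators to the functional ones (first entry under Added/New) can be sketched as follows. This is a self-contained NumPy illustration, not the library's code: `OldStyleTraceEstimator`, `new_style_trace`, and `rademacher` are hypothetical stand-ins that only mimic the two calling conventions. The point is that averaging `N` single samples by hand corresponds to one function call with `num_matvecs=N`.

```python
import numpy as np


def rademacher(dim, rng):
    """Draw a vector with i.i.d. entries from {-1, +1} (hypothetical helper)."""
    return rng.choice([-1.0, 1.0], size=dim)


class OldStyleTraceEstimator:
    """Hypothetical stand-in for the removed class-based interface."""

    def __init__(self, A, rng):
        self._A, self._rng = A, rng

    def sample(self):
        # one matrix-vector product per sample: v^T A v
        v = rademacher(self._A.shape[1], self._rng)
        return v @ (self._A @ v)


def new_style_trace(A, num_matvecs, rng):
    """Hypothetical stand-in for the functional interface."""
    V = np.column_stack([rademacher(A.shape[1], rng) for _ in range(num_matvecs)])
    # sum of v_n^T A v_n over all columns, divided by the number of columns
    return np.einsum("ij,ij->", V, A @ V) / num_matvecs


A = np.random.default_rng(0).random((40, 40))

# before: draw single samples and average them by hand
old = OldStyleTraceEstimator(A, np.random.default_rng(1))
trace_old = np.mean([old.sample() for _ in range(30)])

# after: one call that receives the full matrix-vector product budget
trace_new = new_style_trace(A, num_matvecs=30, rng=np.random.default_rng(1))
```

With identical random vectors, both conventions return the same estimate; the functional form additionally lets the implementation batch all matrix-vector products into a single `A @ V`.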

### Internal
16 changes: 8 additions & 8 deletions curvlinops/__init__.py
@@ -1,7 +1,7 @@
"""``curvlinops`` library API."""

from curvlinops.diagonal.epperly2024xtrace import xdiag
from curvlinops.diagonal.hutchinson import HutchinsonDiagonalEstimator
from curvlinops.diagonal.hutchinson import hutchinson_diag
from curvlinops.fisher import FisherMCLinearOperator
from curvlinops.ggn import GGNLinearOperator
from curvlinops.gradient_moments import EFLinearOperator
@@ -14,7 +14,7 @@
)
from curvlinops.jacobian import JacobianLinearOperator, TransposedJacobianLinearOperator
from curvlinops.kfac import FisherType, KFACLinearOperator, KFACType
from curvlinops.norm.hutchinson import HutchinsonSquaredFrobeniusNormEstimator
from curvlinops.norm.hutchinson import hutchinson_squared_fro
from curvlinops.papyan2020traces.spectrum import (
LanczosApproximateLogSpectrumCached,
LanczosApproximateSpectrumCached,
@@ -23,8 +23,8 @@
)
from curvlinops.submatrix import SubmatrixLinearOperator
from curvlinops.trace.epperly2024xtrace import xtrace
from curvlinops.trace.hutchinson import HutchinsonTraceEstimator
from curvlinops.trace.meyer2020hutch import HutchPPTraceEstimator
from curvlinops.trace.hutchinson import hutchinson_trace
from curvlinops.trace.meyer2020hutch import hutchpp_trace

__all__ = [
# linear operators
@@ -51,12 +51,12 @@
"LanczosApproximateSpectrumCached",
"LanczosApproximateLogSpectrumCached",
# trace estimation
"HutchinsonTraceEstimator",
"HutchPPTraceEstimator",
"hutchinson_trace",
"hutchpp_trace",
"xtrace",
# diagonal estimation
"HutchinsonDiagonalEstimator",
"hutchinson_diag",
"xdiag",
# norm estimation
"HutchinsonSquaredFrobeniusNormEstimator",
"hutchinson_squared_fro",
]
24 changes: 10 additions & 14 deletions curvlinops/diagonal/epperly2024xtrace.py
@@ -5,6 +5,11 @@
from scipy.sparse.linalg import LinearOperator

from curvlinops.sampling import random_vector
from curvlinops.utils import (
assert_divisible_by,
assert_is_square,
assert_matvecs_subseed_dim,
)


def xdiag(A: LinearOperator, num_matvecs: int) -> ndarray:
@@ -21,24 +26,15 @@ def xdiag(A: LinearOperator, num_matvecs: int) -> ndarray:
Args:
A: A square linear operator.
num_matvecs: Total number of matrix-vector products to use. Must be even and
less than the dimension of the linear operator.
less than the dimension of the linear operator (because otherwise one can
evaluate the true diagonal directly at the same cost).
Returns:
The estimated diagonal of the linear operator.
Raises:
ValueError: If the linear operator is not square or if the number of matrix-
vector products is not even or is greater than the dimension of the linear
operator.
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
dim = A.shape[1]
if num_matvecs % 2 != 0 or num_matvecs >= dim:
raise ValueError(
"num_matvecs must be even and less than the dimension of A.",
f" Got {num_matvecs}.",
)
dim = assert_is_square(A)
assert_matvecs_subseed_dim(A, num_matvecs)
assert_divisible_by(num_matvecs, 2, "num_matvecs")

# draw random vectors and compute their matrix-vector products
num_vecs = num_matvecs // 2
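
The three extracted checks replace the inlined validation shown in the removed lines. Their implementation in `curvlinops.utils` is not part of this excerpt, so the following is only a sketch of what such helpers might look like, inferred from the call sites (`assert_is_square` returns the dimension; the other two return nothing). All names below are hypothetical. Note that the removed code passed its error message to `ValueError` as two comma-separated arguments; the sketch builds a single string instead.

```python
import numpy as np


def check_is_square(A):
    """Return the dimension of A, or raise if A is not a square matrix."""
    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must be square. Got shape {A.shape}.")
    return A.shape[0]


def check_matvecs_below_dim(A, num_matvecs):
    """Raise if the matvec budget is not smaller than A's dimension."""
    dim = A.shape[1]
    if num_matvecs >= dim:
        raise ValueError(
            f"num_matvecs ({num_matvecs}) must be less than A's dimension ({dim})."
        )


def check_divisible_by(value, divisor, name):
    """Raise if value is not an integer multiple of divisor."""
    if value % divisor != 0:
        raise ValueError(f"{name} ({value}) must be divisible by {divisor}.")


def validate_xdiag_args(A, num_matvecs):
    """Run the three checks in the order used by the refactored functions."""
    dim = check_is_square(A)
    check_matvecs_below_dim(A, num_matvecs)
    check_divisible_by(num_matvecs, 2, "num_matvecs")
    return dim


dim = validate_xdiag_args(np.eye(10), num_matvecs=4)  # passes all checks
```

Splitting the checks this way lets each estimator combine only the conditions it needs, for example the Hutchinson estimators skip the divisibility check.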
85 changes: 36 additions & 49 deletions curvlinops/diagonal/hutchinson.py
@@ -1,27 +1,31 @@
"""Hutchinson-style matrix diagonal estimation."""

from numpy import ndarray
from numpy import column_stack, einsum, ndarray
from scipy.sparse.linalg import LinearOperator

from curvlinops.sampling import random_vector
from curvlinops.utils import assert_is_square, assert_matvecs_subseed_dim


class HutchinsonDiagonalEstimator:
r"""Class to perform diagonal estimation with Hutchinson's method.
def hutchinson_diag(
A: LinearOperator, num_matvecs: int, distribution: str = "rademacher"
) -> ndarray:
r"""Estimate a linear operator's diagonal using Hutchinson's method.
For details, see
- Martens, J., Sutskever, I., & Swersky, K. (2012). Estimating the hessian by
back-propagating curvature. International Conference on Machine Learning (ICML).
- Bekas, C., Kokiopoulou, E., & Saad, Y. (2007). An estimator for the diagonal
of a matrix. Applied Numerical Mathematics.
Let :math:`\mathbf{A}` be a square linear operator. We can approximate its diagonal
:math:`\mathrm{diag}(\mathbf{A})` by drawing a random vector :math:`\mathbf{v}`
which satisfies :math:`\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}` and
sample from the estimator
:math:`\mathrm{diag}(\mathbf{A})` by drawing :math:`N` random vectors
:math:`\mathbf{v}_n \sim \mathbf{v}` from a distribution :math:`\mathbf{v}` that
satisfies :math:`\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}`, and computing
the estimator
.. math::
\mathbf{a}
:= \mathbf{v} \odot \mathbf{A} \mathbf{v}
:= \frac{1}{N} \sum_{n=1}^N \mathbf{v}_n \odot \mathbf{A} \mathbf{v}_n
\approx \mathrm{diag}(\mathbf{A})\,.
This estimator is unbiased,
@@ -33,54 +37,37 @@ class HutchinsonDiagonalEstimator:
= \sum_j A_{i,j} \delta_{i, j}
= A_{i,i}\,.
Args:
A: A square linear operator whose diagonal is estimated.
num_matvecs: Total number of matrix-vector products to use. Must be smaller
than the dimension of the linear operator (because otherwise one can
evaluate the true diagonal directly at the same cost).
distribution: Distribution of the random vectors used for the diagonal
estimation. Can be either ``'rademacher'`` or ``'normal'``.
Default: ``'rademacher'``.
Returns:
The estimated diagonal of the linear operator.
Example:
>>> from numpy import diag, mean, round
>>> from numpy import diag
>>> from numpy.random import rand, seed
>>> from numpy.linalg import norm
>>> seed(0) # make deterministic
>>> A = rand(10, 10)
>>> A = rand(40, 40)
>>> diag_A = diag(A) # exact diagonal as reference
>>> estimator = HutchinsonDiagonalEstimator(A)
>>> # one- and multi-sample approximations
>>> diag_A_low_precision = estimator.sample()
>>> samples = [estimator.sample() for _ in range(1_000)]
>>> diag_A_high_precision = mean(samples, axis=0)
>>> diag_A_low_precision = hutchinson_diag(A, num_matvecs=1)
>>> diag_A_high_precision = hutchinson_diag(A, num_matvecs=30)
>>> # compute residual norms
>>> error_low_precision = norm(diag_A - diag_A_low_precision)
>>> error_high_precision = norm(diag_A - diag_A_high_precision)
>>> error_low_precision = norm(diag_A - diag_A_low_precision) / norm(diag_A)
>>> error_high_precision = norm(diag_A - diag_A_high_precision) / norm(diag_A)
>>> assert error_low_precision > error_high_precision
>>> round(error_low_precision, 4), round(error_high_precision, 4)
(5.7268, 0.1525)
(4.616, 1.2441)
"""
dim = assert_is_square(A)
assert_matvecs_subseed_dim(A, num_matvecs)
G = column_stack([random_vector(dim, distribution) for _ in range(num_matvecs)])

def __init__(self, A: LinearOperator):
"""Store the linear operator whose diagonal will be estimated.
Args:
A: Linear square-shaped operator whose diagonal will be estimated.
Raises:
ValueError: If the operator is not square.
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
self._A = A

def sample(self, distribution: str = "rademacher") -> ndarray:
"""Draw a sample from the diagonal estimator.
Multiple samples can be combined into a more accurate diagonal estimation via
averaging.
Args:
distribution: Distribution of the vector along which the linear operator
will be evaluated. Either ``'rademacher'`` or ``'normal'``.
Default is ``'rademacher'``.
Returns:
A Sample from the diagonal estimator.
"""
dim = self._A.shape[1]
v = random_vector(dim, distribution)
Av = self._A @ v
return v * Av
return einsum("ij,ij->i", G, A @ G) / num_matvecs
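
The batched estimator above can be reproduced in plain NumPy. The sketch below is an illustration under assumptions, not the library function: `hutchinson_diag_sketch` is a hypothetical name, it skips the input checks, and it draws Rademacher vectors directly with a NumPy `Generator` instead of `random_vector`. It also shows a useful sanity check: for a diagonal matrix and Rademacher vectors, every sample equals the exact diagonal, because `v_i * (D v)_i = d_i * v_i**2 = d_i`.

```python
import numpy as np


def hutchinson_diag_sketch(A, num_matvecs, rng):
    """Average v_n * (A v_n) over num_matvecs Rademacher vectors."""
    dim = A.shape[0]
    # columns are the random vectors v_1, ..., v_N
    V = rng.choice([-1.0, 1.0], size=(dim, num_matvecs))
    # row-wise sum of V * (A @ V), i.e. sum_n v_n * (A v_n), then average
    return np.einsum("ij,ij->i", V, A @ V) / num_matvecs


rng = np.random.default_rng(0)

# diagonal matrix: a single Rademacher sample is already exact
d = np.arange(1.0, 6.0)
exact = hutchinson_diag_sketch(np.diag(d), num_matvecs=1, rng=rng)

# dense matrix: the estimate is noisy and improves with more matvecs
A = rng.random((40, 40))
estimate = hutchinson_diag_sketch(A, num_matvecs=30, rng=rng)
```

Stacking the vectors into one matrix means the operator is applied once to a `dim x num_matvecs` block rather than once per sample, which is the main practical difference from the removed `sample()` loop.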
73 changes: 38 additions & 35 deletions curvlinops/norm/hutchinson.py
@@ -1,12 +1,14 @@
"""Hutchinson-style matrix norm estimation."""

from numpy import dot
from numpy import column_stack
from scipy.sparse.linalg import LinearOperator

from curvlinops.sampling import random_vector


class HutchinsonSquaredFrobeniusNormEstimator:
def hutchinson_squared_fro(
A: LinearOperator, num_matvecs: int, distribution: str = "rademacher"
) -> float:
r"""Estimate the squared Frobenius norm of a matrix using Hutchinson's method.
Let :math:`\mathbf{A} \in \mathbb{R}^{M \times N}` be some matrix. Its Frobenius
@@ -22,45 +24,46 @@ class HutchinsonSquaredFrobeniusNormEstimator:
Due to the last equality, we can use Hutchinson-style trace estimation to estimate
the squared Frobenius norm.
Args:
A: A matrix whose squared Frobenius norm is estimated.
num_matvecs: Total number of matrix-vector products to use. Must be smaller
than the minimum dimension of the matrix.
distribution: Distribution of the random vectors used for the trace estimation.
Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``.
Returns:
The estimated squared Frobenius norm of the matrix.
Raises:
ValueError: If the matrix is not two-dimensional or if the number of matrix-
vector products is not smaller than the minimum dimension of the matrix
(because then you can evaluate the true squared Frobenius norm directly
at the same cost).
Example:
>>> from numpy import mean, round
>>> from numpy.linalg import norm
>>> from numpy.random import rand, seed
>>> seed(0) # make deterministic
>>> A = rand(5, 5)
>>> A = rand(40, 40)
>>> fro2_A = norm(A, ord='fro')**2 # exact squared Frobenius norm as reference
>>> estimator = HutchinsonSquaredFrobeniusNormEstimator(A)
>>> # one- and multi-sample approximations
>>> fro2_A_low_prec = estimator.sample()
>>> fro2_A_high_prec = mean([estimator.sample() for _ in range(1_000)])
>>> fro2_A_low_prec = hutchinson_squared_fro(A, num_matvecs=1)
>>> fro2_A_high_prec = hutchinson_squared_fro(A, num_matvecs=30)
>>> assert abs(fro2_A - fro2_A_low_prec) > abs(fro2_A - fro2_A_high_prec)
>>> round(fro2_A, 4), round(fro2_A_low_prec, 4), round(fro2_A_high_prec, 4)
(10.7192, 8.3257, 10.6406)
>>> round(fro2_A, 1), round(fro2_A_low_prec, 1), round(fro2_A_high_prec, 1)
(546.0, 319.7, 645.2)
"""
if len(A.shape) != 2:
raise ValueError(f"A must be a matrix. Got shape {A.shape}.")
dim = min(A.shape)
if num_matvecs >= dim:
raise ValueError(
f"num_matvecs ({num_matvecs}) must be less than the minimum dimension of A."
)
# Instead of AT @ A, use A @ AT if the matrix is wider than tall
if A.shape[1] > A.shape[0]:
A = A.T

def __init__(self, A: LinearOperator):
"""Store the linear operator whose squared Frobenius norm will be estimated.
Args:
A: Linear operator whose squared Frobenius norm will be estimated.
"""
self._A = A

def sample(self, distribution: str = "rademacher") -> float:
"""Draw a sample from the squared Frobenius norm estimator.
Multiple samples can be combined into a more accurate squared Frobenius norm
estimation via averaging.
Args:
distribution: Distribution of the vector along which the linear operator
will be evaluated. Either ``'rademacher'`` or ``'normal'``.
Default is ``'rademacher'``.
Returns:
Sample from the squared Frobenius norm estimator.
"""
dim = self._A.shape[1]
v = random_vector(dim, distribution)
Av = self._A @ v
return dot(Av, Av)
G = column_stack([random_vector(dim, distribution) for _ in range(num_matvecs)])
AG = A @ G
return (AG**2 / num_matvecs).sum()
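
The same idea can be sketched in plain NumPy. This is an illustration under assumptions rather than the library function: `squared_fro_sketch` is a hypothetical name, input checks are omitted, and Rademacher vectors are drawn with a NumPy `Generator`. It relies on the identity that the squared Frobenius norm equals both `tr(A^T A)` and `tr(A A^T)`, so a wide matrix can be transposed to make the random vectors live in the smaller dimension.

```python
import numpy as np


def squared_fro_sketch(A, num_matvecs, rng):
    """Estimate ||A||_F^2 as the average of ||A v_n||^2 over Rademacher v_n."""
    # tr(A^T A) = tr(A A^T): work with the orientation whose input is smaller
    if A.shape[1] > A.shape[0]:
        A = A.T
    V = rng.choice([-1.0, 1.0], size=(A.shape[1], num_matvecs))
    AV = A @ V
    # each column contributes ||A v_n||^2 = v_n^T A^T A v_n
    return (AV**2).sum() / num_matvecs


rng = np.random.default_rng(0)

# tall matrix with orthogonal columns: A^T A = 9 * I, so every sample is exact
tall = 3.0 * np.eye(4)[:, :2]
fro2_tall = squared_fro_sketch(tall, num_matvecs=1, rng=rng)

# the wide transpose takes the A.T branch and gives the same value
fro2_wide = squared_fro_sketch(tall.T, num_matvecs=1, rng=rng)
```

For general matrices the off-diagonal entries of `A^T A` add zero-mean noise, so the estimate is only approximate and tightens as `num_matvecs` grows.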
24 changes: 10 additions & 14 deletions curvlinops/trace/epperly2024xtrace.py
@@ -5,6 +5,11 @@
from scipy.sparse.linalg import LinearOperator

from curvlinops.sampling import random_vector
from curvlinops.utils import (
assert_divisible_by,
assert_is_square,
assert_matvecs_subseed_dim,
)


def xtrace(
@@ -23,26 +28,17 @@ def xtrace(
Args:
A: A square linear operator.
num_matvecs: Total number of matrix-vector products to use. Must be even and
less than the dimension of the linear operator.
less than the dimension of the linear operator (because otherwise one can
evaluate the true trace directly at the same cost).
distribution: Distribution of the random vectors used for the trace estimation.
Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``.
Returns:
The estimated trace of the linear operator.
Raises:
ValueError: If the linear operator is not square or if the number of matrix-
vector products is not even or is greater than the dimension of the linear
operator.
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
dim = A.shape[1]
if num_matvecs % 2 != 0 or num_matvecs >= dim:
raise ValueError(
"num_matvecs must be even and less than the dimension of A.",
f" Got {num_matvecs}.",
)
dim = assert_is_square(A)
assert_matvecs_subseed_dim(A, num_matvecs)
assert_divisible_by(num_matvecs, 2, "num_matvecs")

# draw random vectors and compute their matrix-vector products
num_vecs = num_matvecs // 2