Skip to content

Commit 7f83f04

Browse files
authored
[ADD] Squared Frobenius norm estimator (#80)
* [ADD] Squared Frobenius norm estimator * [FIX] Typo * [FIX] Math typo
1 parent 73ee0eb commit 7f83f04

File tree

4 files changed

+78
-0
lines changed

4 files changed

+78
-0
lines changed

curvlinops/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313
from curvlinops.jacobian import JacobianLinearOperator, TransposedJacobianLinearOperator
1414
from curvlinops.kfac import KFACLinearOperator
15+
from curvlinops.norm.hutchinson import HutchinsonSquaredFrobeniusNormEstimator
1516
from curvlinops.papyan2020traces.spectrum import (
1617
LanczosApproximateLogSpectrumCached,
1718
LanczosApproximateSpectrumCached,
@@ -23,22 +24,30 @@
2324
from curvlinops.trace.meyer2020hutch import HutchPPTraceEstimator
2425

2526
__all__ = [
27+
# linear operators
2628
"HessianLinearOperator",
2729
"GGNLinearOperator",
2830
"EFLinearOperator",
2931
"FisherMCLinearOperator",
3032
"KFACLinearOperator",
3133
"JacobianLinearOperator",
3234
"TransposedJacobianLinearOperator",
35+
# inversion
3336
"CGInverseLinearOperator",
3437
"NeumannInverseLinearOperator",
3538
"KFACInverseLinearOperator",
39+
# slicing
3640
"SubmatrixLinearOperator",
41+
# spectral properties
3742
"lanczos_approximate_spectrum",
3843
"lanczos_approximate_log_spectrum",
3944
"LanczosApproximateSpectrumCached",
4045
"LanczosApproximateLogSpectrumCached",
46+
# trace estimation
4147
"HutchinsonTraceEstimator",
4248
"HutchPPTraceEstimator",
49+
# diagonal estimation
4350
"HutchinsonDiagonalEstimator",
51+
# norm estimation
52+
"HutchinsonSquaredFrobeniusNormEstimator",
4453
]

curvlinops/norm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Matrix norm estimation methods."""

curvlinops/norm/hutchinson.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Hutchinson-style matrix norm estimation."""
2+
3+
from scipy.sparse.linalg import LinearOperator
4+
5+
from curvlinops.trace.hutchinson import HutchinsonTraceEstimator
6+
7+
8+
class HutchinsonSquaredFrobeniusNormEstimator:
9+
r"""Estimate the squared Frobenius norm of a matrix using Hutchinson's method.
10+
11+
Let :math:`\mathbf{A} \in \mathbb{R}^{M \times N}` be some matrix. It's Frobenius
12+
norm :math:`\lVert\mathbf{A}\rVert_\text{F}` is defined via:
13+
14+
.. math::
15+
\lVert\mathbf{A}\rVert_\text{F}^2
16+
=
17+
\sum_{m=1}^M \sum_{n=1}^N \mathbf{A}_{n,m}^2
18+
=
19+
\text{Tr}(\mathbf{A}^\top \mathbf{A}).
20+
21+
Due to the last equality, we can use Hutchinson-style trace estimation to estimate
22+
the squared Frobenius norm.
23+
24+
Example:
25+
>>> from numpy import mean, round
26+
>>> from numpy.linalg import norm
27+
>>> from numpy.random import rand, seed
28+
>>> seed(0) # make deterministic
29+
>>> A = rand(5, 5)
30+
>>> fro2_A = norm(A, ord='fro')**2 # exact squared Frobenius norm as reference
31+
>>> estimator = HutchinsonSquaredFrobeniusNormEstimator(A)
32+
>>> # one- and multi-sample approximations
33+
>>> fro2_A_low_prec = estimator.sample()
34+
>>> fro2_A_high_prec = mean([estimator.sample() for _ in range(1_000)])
35+
>>> assert abs(fro2_A - fro2_A_low_prec) > abs(fro2_A - fro2_A_high_prec)
36+
>>> round(fro2_A, 4), round(fro2_A_low_prec, 4), round(fro2_A_high_prec, 4)
37+
(10.7192, 8.3257, 10.6406)
38+
"""
39+
40+
def __init__(self, A: LinearOperator):
41+
"""Store the linear operator whose squared Frobenius norm will be estimated.
42+
43+
Args:
44+
A: Linear operator whose squared Frobenius norm will be estimated.
45+
"""
46+
self._trace_estimator = HutchinsonTraceEstimator(A.T @ A)
47+
48+
def sample(self, distribution: str = "rademacher") -> float:
49+
"""Draw a sample from the squared Frobenius norm estimator.
50+
51+
Multiple samples can be combined into a more accurate squared Frobenius norm
52+
estimation via averaging.
53+
54+
Args:
55+
distribution: Distribution of the vector along which the linear operator
56+
will be evaluated. Either ``'rademacher'`` or ``'normal'``.
57+
Default is ``'rademacher'``.
58+
59+
Returns:
60+
Sample from the squared Frobenius norm estimator.
61+
"""
62+
return self._trace_estimator.sample(distribution=distribution)

docs/rtd/linops.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ Diagonal approximation
7878
.. autoclass:: curvlinops.HutchinsonDiagonalEstimator
7979
:members: __init__, sample
8080

81+
Frobenius norm approximation
82+
============================
83+
84+
.. autoclass:: curvlinops.HutchinsonSquaredFrobeniusNormEstimator
85+
:members: __init__, sample
86+
8187
Experimental
8288
============
8389

0 commit comments

Comments
 (0)