-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ADD] Hutchinson-style matrix diagonal estimation (#40)
* [REF] Use same code in trace tests, extract random vector generation * [ADD] Hutchinson-style diagonal estimation * [DOC] Add diagonal estimator to documentation * [DOC] Short summary for each trace/diagonal estimation method
- Loading branch information
Showing
13 changed files
with
303 additions
and
134 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Matrix diagonal estimation methods.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
"""Hutchinson-style matrix diagonal estimation.""" | ||
|
||
|
||
from numpy import ndarray | ||
from scipy.sparse.linalg import LinearOperator | ||
|
||
from curvlinops.sampling import random_vector | ||
|
||
|
||
class HutchinsonDiagonalEstimator: | ||
r"""Class to perform diagonal estimation with Hutchinson's method. | ||
For details, see | ||
- Martens, J., Sutskever, I., & Swersky, K. (2012). Estimating the hessian by | ||
back-propagating curvature. International Conference on Machine Learning (ICML). | ||
Let :math:`\mathbf{A}` be a square linear operator. We can approximate its diagonal | ||
:math:`\mathrm{diag}(\mathbf{A})` by drawing a random vector :math:`\mathbf{v}` | ||
which satisfies :math:`\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}` and | ||
sample from the estimator | ||
.. math:: | ||
\mathbf{a} | ||
:= \mathbf{v} \odot \mathbf{A} \mathbf{v} | ||
\approx \mathrm{diag}(\mathbf{A})\,. | ||
This estimator is unbiased, | ||
.. math:: | ||
\mathbb{E}[a_i] | ||
= \sum_j \mathbb{E}[v_i A_{i,j} v_j] | ||
= \sum_j A_{i,j} \mathbb{E}[v_i v_j] | ||
= \sum_j A_{i,j} \delta_{i, j} | ||
= A_{i,i}\,. | ||
Example: | ||
>>> from numpy import diag, mean, round | ||
>>> from numpy.random import rand, seed | ||
>>> from numpy.linalg import norm | ||
>>> seed(0) # make deterministic | ||
>>> A = rand(10, 10) | ||
>>> diag_A = diag(A) # exact diagonal as reference | ||
>>> estimator = HutchinsonDiagonalEstimator(A) | ||
>>> # one- and multi-sample approximations | ||
>>> diag_A_low_precision = estimator.sample() | ||
>>> samples = [estimator.sample() for _ in range(1_000)] | ||
>>> diag_A_high_precision = mean(samples, axis=0) | ||
>>> # compute residual norms | ||
>>> error_low_precision = norm(diag_A - diag_A_low_precision) | ||
>>> error_high_precision = norm(diag_A - diag_A_high_precision) | ||
>>> assert error_low_precision > error_high_precision | ||
>>> round(error_low_precision, 4), round(error_high_precision, 4) | ||
(5.7268, 0.1525) | ||
""" | ||
|
||
def __init__(self, A: LinearOperator): | ||
"""Store the linear operator whose diagonal will be estimated. | ||
Args: | ||
A: Linear square-shaped operator whose diagonal will be estimated. | ||
Raises: | ||
ValueError: If the operator is not square. | ||
""" | ||
if len(A.shape) != 2 or A.shape[0] != A.shape[1]: | ||
raise ValueError(f"A must be square. Got shape {A.shape}.") | ||
self._A = A | ||
|
||
def sample(self, distribution: str = "rademacher") -> ndarray: | ||
"""Draw a sample from the diagonal estimator. | ||
Multiple samples can be combined into a more accurate diagonal estimation via | ||
averaging. | ||
Args: | ||
distribution: Distribution of the vector along which the linear operator | ||
will be evaluated. Either ``'rademacher'`` or ``'normal'``. | ||
Default is ``'rademacher'``. | ||
Returns: | ||
A Sample from the diagonal estimator. | ||
""" | ||
dim = self._A.shape[1] | ||
v = random_vector(dim, distribution) | ||
Av = self._A @ v | ||
return v * Av |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""Sampling methods for random vectors.""" | ||
|
||
from numpy import ndarray | ||
from numpy.random import binomial, randn | ||
|
||
|
||
def rademacher(dim: int) -> ndarray: | ||
"""Draw a vector with i.i.d. Rademacher elements. | ||
Args: | ||
dim: Dimension of the vector. | ||
Returns: | ||
Vector with i.i.d. Rademacher elements and specified dimension. | ||
""" | ||
num_trials, success_prob = 1, 0.5 | ||
return binomial(num_trials, success_prob, size=dim).astype(float) * 2 - 1 | ||
|
||
|
||
def normal(dim: int) -> ndarray: | ||
"""Draw a vector with i.i.d. standard normal elements. | ||
Args: | ||
dim: Dimension of the vector. | ||
Returns: | ||
Vector with i.i.d. standard normal elements and specified dimension. | ||
""" | ||
return randn(dim) | ||
|
||
|
||
def random_vector(dim: int, distribution: str) -> ndarray: | ||
"""Draw a vector with i.i.d. elements. | ||
Args: | ||
dim: Dimension of the vector. | ||
distribution: Distribution of the vector's elements. Either ``'rademacher'`` or | ||
``'normal'``. | ||
Returns: | ||
Vector with i.i.d. elements and specified dimension. | ||
Raises: | ||
ValueError: If the distribution is unknown. | ||
""" | ||
if distribution == "rademacher": | ||
return rademacher(dim) | ||
elif distribution == "normal": | ||
return normal(dim) | ||
else: | ||
raise ValueError(f"Unknown distribution {distribution:!r}.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.