Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADD] Hutchinson-style matrix diagonal estimation #40

Merged
merged 5 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions curvlinops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""``curvlinops`` library API."""

from curvlinops.diagonal.hutchinson import HutchinsonDiagonalEstimator
from curvlinops.fisher import FisherMCLinearOperator
from curvlinops.ggn import GGNLinearOperator
from curvlinops.gradient_moments import EFLinearOperator
Expand Down Expand Up @@ -32,4 +33,5 @@
"LanczosApproximateLogSpectrumCached",
"HutchinsonTraceEstimator",
"HutchPPTraceEstimator",
"HutchinsonDiagonalEstimator",
]
1 change: 1 addition & 0 deletions curvlinops/diagonal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Matrix diagonal estimation methods."""
87 changes: 87 additions & 0 deletions curvlinops/diagonal/hutchinson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Hutchinson-style matrix diagonal estimation."""


from numpy import ndarray
from scipy.sparse.linalg import LinearOperator

from curvlinops.sampling import random_vector


class HutchinsonDiagonalEstimator:
r"""Class to perform diagonal estimation with Hutchinson's method.

For details, see

- Martens, J., Sutskever, I., & Swersky, K. (2012). Estimating the hessian by
back-propagating curvature. International Conference on Machine Learning (ICML).

Let :math:`\mathbf{A}` be a square linear operator. We can approximate its diagonal
:math:`\mathrm{diag}(\mathbf{A})` by drawing a random vector :math:`\mathbf{v}`
which satisfies :math:`\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}` and
sample from the estimator

.. math::
\mathbf{a}
:= \mathbf{v} \odot \mathbf{A} \mathbf{v}
\approx \mathrm{diag}(\mathbf{A})\,.

This estimator is unbiased,

.. math::
\mathbb{E}[a_i]
= \sum_j \mathbb{E}[v_i A_{i,j} v_j]
= \sum_j A_{i,j} \mathbb{E}[v_i v_j]
= \sum_j A_{i,j} \delta_{i, j}
= A_{i,i}\,.

Example:
>>> from numpy import diag, mean, round
>>> from numpy.random import rand, seed
>>> from numpy.linalg import norm
>>> seed(0) # make deterministic
>>> A = rand(10, 10)
>>> diag_A = diag(A) # exact diagonal as reference
>>> estimator = HutchinsonDiagonalEstimator(A)
>>> # one- and multi-sample approximations
>>> diag_A_low_precision = estimator.sample()
>>> samples = [estimator.sample() for _ in range(1_000)]
>>> diag_A_high_precision = mean(samples, axis=0)
>>> # compute residual norms
>>> error_low_precision = norm(diag_A - diag_A_low_precision)
>>> error_high_precision = norm(diag_A - diag_A_high_precision)
>>> assert error_low_precision > error_high_precision
>>> round(error_low_precision, 4), round(error_high_precision, 4)
(5.7268, 0.1525)
"""

def __init__(self, A: LinearOperator):
"""Store the linear operator whose diagonal will be estimated.

Args:
A: Linear square-shaped operator whose diagonal will be estimated.

Raises:
ValueError: If the operator is not square.
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
self._A = A

def sample(self, distribution: str = "rademacher") -> ndarray:
"""Draw a sample from the diagonal estimator.

Multiple samples can be combined into a more accurate diagonal estimation via
averaging.

Args:
distribution: Distribution of the vector along which the linear operator
will be evaluated. Either ``'rademacher'`` or ``'normal'``.
Default is ``'rademacher'``.

Returns:
A Sample from the diagonal estimator.
"""
dim = self._A.shape[1]
v = random_vector(dim, distribution)
Av = self._A @ v
return v * Av
51 changes: 51 additions & 0 deletions curvlinops/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Sampling methods for random vectors."""

from numpy import ndarray
from numpy.random import binomial, randn


def rademacher(dim: int) -> ndarray:
"""Draw a vector with i.i.d. Rademacher elements.

Args:
dim: Dimension of the vector.

Returns:
Vector with i.i.d. Rademacher elements and specified dimension.
"""
num_trials, success_prob = 1, 0.5
return binomial(num_trials, success_prob, size=dim).astype(float) * 2 - 1


def normal(dim: int) -> ndarray:
"""Draw a vector with i.i.d. standard normal elements.

Args:
dim: Dimension of the vector.

Returns:
Vector with i.i.d. standard normal elements and specified dimension.
"""
return randn(dim)


def random_vector(dim: int, distribution: str) -> ndarray:
"""Draw a vector with i.i.d. elements.

Args:
dim: Dimension of the vector.
distribution: Distribution of the vector's elements. Either ``'rademacher'`` or
``'normal'``.

Returns:
Vector with i.i.d. elements and specified dimension.

Raises:
ValueError: If the distribution is unknown.
"""
if distribution == "rademacher":
return rademacher(dim)
elif distribution == "normal":
return normal(dim)
else:
raise ValueError(f"Unknown distribution {distribution:!r}.")
49 changes: 23 additions & 26 deletions curvlinops/trace/hutchinson.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,39 @@
"""Vanilla Hutchinson trace estimation."""

from typing import Callable, Dict

from numpy import dot, ndarray
from numpy import dot
from scipy.sparse.linalg import LinearOperator

from curvlinops.trace.sampling import normal, rademacher
from curvlinops.sampling import random_vector


class HutchinsonTraceEstimator:
"""Class to perform trace estimation with Hutchinson's method.
r"""Class to perform trace estimation with Hutchinson's method.

For details, see

- Hutchinson, M. (1989). A stochastic estimator of the trace of the influence
matrix for laplacian smoothing splines. Communication in Statistics---Simulation
and Computation.

Let :math:`\mathbf{A}` be a square linear operator. We can approximate its trace
:math:`\mathrm{Tr}(\mathbf{A})` by drawing a random vector :math:`\mathbf{v}`
which satisfies :math:`\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}` and
sample from the estimator

.. math::
a
:= \mathbf{v}^\top \mathbf{A} \mathbf{v}
\approx \mathrm{Tr}(\mathbf{A})\,.

This estimator is unbiased,

.. math::
\mathbb{E}[a]
= \mathrm{Tr}(\mathbb{E}[\mathbf{v}^\top\mathbf{A} \mathbf{v}])
= \mathrm{Tr}(\mathbf{A} \mathbb{E}[\mathbf{v} \mathbf{v}^\top])
= \mathrm{Tr}(\mathbf{A} \mathbf{I})
= \mathrm{Tr}(\mathbf{A})\,.

Example:
>>> from numpy import trace, mean, round
>>> from numpy.random import rand, seed
Expand All @@ -30,17 +47,8 @@ class HutchinsonTraceEstimator:
>>> assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision)
>>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4)
(4.4575, 6.6796, 4.3886)

Attributes:
SUPPORTED_DISTRIBUTIONS: Dictionary mapping supported distributions to their
sampling functions.
"""

SUPPORTED_DISTRIBUTIONS: Dict[str, Callable[[int], ndarray]] = {
"rademacher": rademacher,
"normal": normal,
}

def __init__(self, A: LinearOperator):
"""Store the linear operator whose trace will be estimated.

Expand All @@ -67,19 +75,8 @@ def sample(self, distribution: str = "rademacher") -> float:

Returns:
Sample from the trace estimator.

Raises:
ValueError: If the distribution is not supported.
"""
dim = self._A.shape[1]

if distribution not in self.SUPPORTED_DISTRIBUTIONS:
raise ValueError(
f"Unsupported distribution {distribution:!r}. "
f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}."
)

v = self.SUPPORTED_DISTRIBUTIONS[distribution](dim)
v = random_vector(dim, distribution)
Av = self._A @ v

return dot(v, Av)
60 changes: 27 additions & 33 deletions curvlinops/trace/meyer2020hutch.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""Implementation of Hutch++ trace estimation from Meyer et al."""

from typing import Callable, Dict, Optional, Union
from typing import Optional, Union

from numpy import column_stack, dot, ndarray
from numpy.linalg import qr
from scipy.sparse.linalg import LinearOperator

from curvlinops.trace.sampling import normal, rademacher
from curvlinops.sampling import random_vector


class HutchPPTraceEstimator:
"""Class to perform trace estimation with the Huch++ method.
r"""Class to perform trace estimation with the Huch++ method.

In contrast to vanilla Hutchinson, Hutch++ has lower variance, but requires more
memory.
Expand All @@ -20,6 +20,26 @@ class HutchPPTraceEstimator:
- Meyer, R. A., Musco, C., Musco, C., & Woodruff, D. P. (2020). Hutch++:
optimal stochastic trace estimation.

Let :math:`\mathbf{A}` be a square linear operator whose trace we want to
approximate. First, we compute an orthonormal basis :math:`\mathbf{Q}` of a
sub-space spanned by :math:`\mathbf{A} \mathbf{S}` where :math:`\mathbf{S}` is a
tall random matrix with i.i.d. elements. Then, we compute the trace in the sub-space
and apply Hutchinson's estimator in the remaining space spanned by
:math:`\mathbf{I} - \mathbf{Q} \mathbf{Q}^\top`: We can draw a random vector
:math:`\mathbf{v}` which satisfies
:math:`\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}` and sample from the
estimator

.. math::
a
:= \mathrm{Tr}(\mathbf{Q}^\top \mathbf{A} \mathbf{Q})
+ \mathbf{v}^\top (\mathbf{I} - \mathbf{Q} \mathbf{Q}^\top)^\top
\mathbf{A} (\mathbf{I} - \mathbf{Q} \mathbf{Q}^\top) \mathbf{v}
\approx \mathrm{Tr}(\mathbf{A})\,.

This estimator is unbiased, :math:`\mathbb{E}[a] = \mathrm{Tr}(\mathbf{A})`, as the
first term is constant and the second part is Hutchinson's estimator in a sub-space.

Example:
>>> from numpy import trace, mean, round
>>> from numpy.random import rand, seed
Expand All @@ -33,17 +53,8 @@ class HutchPPTraceEstimator:
>>> # assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision)
>>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4)
(4.4575, 2.4085, 4.5791)

Attributes:
SUPPORTED_DISTRIBUTIONS: Dictionary mapping supported distributions to their
sampling functions.
"""

SUPPORTED_DISTRIBUTIONS: Dict[str, Callable[[int], ndarray]] = {
"rademacher": rademacher,
"normal": normal,
}

def __init__(
self,
A: LinearOperator,
Expand All @@ -64,8 +75,8 @@ def __init__(
``'rademacher'``.

Raises:
ValueError: If the operator is not square, the basis dimension is too
large, or the sampling distribution is not supported.
ValueError: If the operator is not square or the basis dimension is too
large.

Note:
If you are planning to perform a fair (i.e. same computation budget)
Expand All @@ -86,12 +97,6 @@ def __init__(
f"Basis dimension must be at most {self._A.shape[1]}. Got {basis_dim}."
)
self._basis_dim = basis_dim

if basis_distribution not in self.SUPPORTED_DISTRIBUTIONS:
raise ValueError(
f"Unsupported distribution {basis_distribution:!r}. "
f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}."
)
self._basis_distribution = basis_distribution

# When drawing the first sample, the basis and its subspace trace will be
Expand All @@ -116,25 +121,14 @@ def sample(self, distribution: str = "rademacher") -> float:

Returns:
Sample from the trace estimator.

Raises:
ValueError: If the distribution is not supported.
"""
self.maybe_compute_and_cache_subspace()

if distribution not in self.SUPPORTED_DISTRIBUTIONS:
raise ValueError(
f"Unsupported distribution {distribution:!r}. "
f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}."
)

dim = self._A.shape[1]
v = self.SUPPORTED_DISTRIBUTIONS[distribution](dim)
v = random_vector(dim, distribution)
# project out subspace
v -= self._Q @ (self._Q.T @ v)

Av = self._A @ v

return self._tr_QT_A_Q + dot(v, Av)

def maybe_compute_and_cache_subspace(self):
Expand All @@ -145,7 +139,7 @@ def maybe_compute_and_cache_subspace(self):
dim = self._A.shape[1]
AS = column_stack(
[
self._A @ self.SUPPORTED_DISTRIBUTIONS[self._basis_distribution](dim)
self._A @ random_vector(dim, self._basis_distribution)
for _ in range(self._basis_dim)
]
)
Expand Down
29 changes: 0 additions & 29 deletions curvlinops/trace/sampling.py

This file was deleted.

Loading