[ADD] Implement XTrace trace estimator #166

Merged
merged 8 commits on Jan 8, 2025
Changes from 5 commits
2 changes: 2 additions & 0 deletions curvlinops/__init__.py
@@ -21,6 +21,7 @@
lanczos_approximate_spectrum,
)
from curvlinops.submatrix import SubmatrixLinearOperator
from curvlinops.trace.epperly2024xtrace import xtrace
from curvlinops.trace.hutchinson import HutchinsonTraceEstimator
from curvlinops.trace.meyer2020hutch import HutchPPTraceEstimator

@@ -51,6 +52,7 @@
# trace estimation
"HutchinsonTraceEstimator",
"HutchPPTraceEstimator",
"xtrace",
# diagonal estimation
"HutchinsonDiagonalEstimator",
# norm estimation
101 changes: 101 additions & 0 deletions curvlinops/trace/epperly2024xtrace.py
@@ -0,0 +1,101 @@
"""Implements the XTrace algorithm from Epperly 2024."""

from numpy import column_stack, dot, einsum, mean, ndarray
from numpy.linalg import inv, qr
from scipy.sparse.linalg import LinearOperator

from curvlinops.sampling import random_vector


def xtrace(
A: LinearOperator, num_matvecs: int, distribution: str = "rademacher"
) -> float:
"""Estimate a linear operator's trace using the XTrace algorithm.

The method is presented in `this paper <https://arxiv.org/pdf/2301.07825>`_:

- Epperly, E. N., Tropp, J. A., & Webber, R. J. (2024). XTrace: making the most
of every sample in stochastic trace estimation. SIAM Journal on Matrix Analysis
and Applications (SIMAX).

It combines the variance reduction from Hutch++ with the exchangeability principle.

Args:
A: A square linear operator.
num_matvecs: Total number of matrix-vector products to use. Must be even and
less than the dimension of the linear operator.
distribution: Distribution of the random vectors used for the trace estimation.
Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``.

Returns:
The estimated trace of the linear operator.

Raises:
ValueError: If the linear operator is not square or if the number of matrix-
vector products is not even or not less than the dimension of the linear
operator.
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
dim = A.shape[1]
if num_matvecs % 2 != 0 or num_matvecs >= dim:
raise ValueError(
"num_matvecs must be even and less than the dimension of A.",
f" Got {num_matvecs}.",
)

# draw random vectors and compute their matrix-vector products
num_vecs = num_matvecs // 2
W = column_stack([random_vector(dim, distribution) for _ in range(num_vecs)])
A_W = A @ W

# compute the orthogonal basis for all test vectors, and its associated trace
Q, R = qr(A_W)
A_Q = A @ Q
tr_QT_A_Q = einsum("ij,ji->", Q.T, A_Q)

# compute the traces in the bases that would result had we left out the i-th
# test vector in the QR decomposition
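# (the i-th column of S, denoted s_i below, is the unit-norm vector satisfying
# Q_i QT_i = Q (I - s_i sT_i) QT, where Q_i is the basis obtained without the
# i-th test vector)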
RT_inv = inv(R.T)
D = 1 / (RT_inv**2).sum(0) ** 0.5
S = einsum("ij,j->ij", RT_inv, D)
tr_QT_i_A_Q_i = einsum("ij,jk,kl,li->i", S.T, Q.T, A_Q, S)

# Traces in the bases {Q_i}. This follows by writing Tr(QT_i A Q_i) = Tr(A Q_i QT_i),
# then using the relation Q_i QT_i = Q (I - s_i sT_i) QT, which simplifies to
# Tr(QT_i A Q_i) = Tr(QT A Q) - sT_i (QT A Q) s_i, i.e.
traces = tr_QT_A_Q - tr_QT_i_A_Q_i

def deflate(v: ndarray, s: ndarray) -> ndarray:
"""Apply (I - s sT) to a vector.

Args:
v: Vector to deflate.
s: Deflation vector.

Returns:
Deflated vector.
"""
return v - dot(s, v) * s

# estimate the trace on the complement of Q_i with vanilla Hutchinson using the
# i-th test vector
for i in range(num_vecs):
w_i = W[:, i]
s_i = S[:, i]
A_w_i = A_W[:, i]

# Compute (I - Q_i QT_i) A (I - Q_i QT_i) w_i
# = (I - Q_i QT_i) (Aw - AQ_i QT_i w_i)
# ( using that Q_i QT_i = Q (I - s_i sT_i) QT )
# = (I - Q_i QT_i) (Aw - AQ (I - s_i sT_i) QT w)
# = (I - Q (I - s_i sT_i) QT) (Aw - AQ (I - s_i sT_i) QT w)
# |--------- A_P_w_i ---------|
# |-------------------- PT_A_P_w_i----------------------|
A_P_w_i = A_w_i - A_Q @ deflate(Q.T @ w_i, s_i)
PT_A_P_w_i = A_P_w_i - Q @ deflate(Q.T @ A_P_w_i, s_i)

tr_w_i = dot(w_i, PT_A_P_w_i)
traces[i] += tr_w_i

return mean(traces)
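A minimal usage sketch of the new estimator (illustrative only, not part of this diff): it wraps a random dense matrix with SciPy's aslinearoperator, which is one valid input besides curvlinops' own operators, and compares the XTrace estimate with the exact trace.

from numpy import trace
from numpy.random import rand, seed
from scipy.sparse.linalg import aslinearoperator

from curvlinops import xtrace

seed(0)
A = rand(100, 100)  # any square matrix or square LinearOperator works
A_linop = aslinearoperator(A)

# spend 20 matrix-vector products (must be even and smaller than the dimension)
estimate = xtrace(A_linop, num_matvecs=20, distribution="rademacher")
print(f"exact trace: {trace(A):.3f}, XTrace estimate: {estimate:.3f}")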
2 changes: 2 additions & 0 deletions docs/rtd/linops.rst
@@ -78,6 +78,8 @@ Trace approximation
.. autoclass:: curvlinops.HutchPPTraceEstimator
:members: __init__, sample

.. autofunction:: curvlinops.xtrace

Diagonal approximation
======================

136 changes: 136 additions & 0 deletions test/trace/test_epperly2024xtrace.py
@@ -0,0 +1,136 @@
"""Test ``curvlinops.trace.epperli2024xtrace."""

from test.trace import DISTRIBUTION_IDS, DISTRIBUTIONS

from numpy import column_stack, dot, isclose, mean, trace
from numpy.linalg import qr
from numpy.random import rand, seed
from pytest import mark
from scipy.sparse.linalg import LinearOperator

from curvlinops import xtrace
from curvlinops.sampling import random_vector

NUM_MATVECS = [4, 10]
NUM_MATVEC_IDS = [f"num_matvecs={num_matvecs}" for num_matvecs in NUM_MATVECS]


def xtrace_naive(
A: LinearOperator, num_matvecs: int, distribution: str = "rademacher"
) -> float:
"""Naive reference implementation of XTrace.

See Algorithm 1.2 in https://arxiv.org/pdf/2301.07825.

Args:
A: A square linear operator.
num_matvecs: Total number of matrix-vector products to use. Must be even and
less than the dimension of the linear operator.
distribution: Distribution of the random vectors used for the trace estimation.
Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``.

Returns:
The estimated trace of the linear operator.

Raises:
ValueError: If the linear operator is not square or if the number of matrix-
vector products is not even or not less than the dimension of the linear
operator.
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
dim = A.shape[1]
if num_matvecs % 2 != 0 or num_matvecs >= dim:
raise ValueError(
"num_matvecs must be even and less than the dimension of A.",
f" Got {num_matvecs}.",
)
num_vecs = num_matvecs // 2

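# draw the random test vectors and compute their matrix-vector products up front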
W = column_stack([random_vector(dim, distribution) for _ in range(num_vecs)])
A_W = A @ W

traces = []

for i in range(num_vecs):
# compute the exact trace in the basis spanned by the sketch matrix without
# test vector i
not_i = [j for j in range(num_vecs) if j != i]
Q_i, _ = qr(A_W[:, not_i])
A_Q_i = A @ Q_i
tr_QT_i_A_Q_i = trace(Q_i.T @ A_Q_i)

# apply vanilla Hutchinson in the complement, using test vector i
w_i = W[:, i]
A_w_i = A_W[:, i]
A_P_w_i = A_w_i - A_Q_i @ (Q_i.T @ w_i)
PT_A_P_w_i = A_P_w_i - Q_i @ (Q_i.T @ A_P_w_i)
tr_w_i = dot(w_i, PT_A_P_w_i)

traces.append(float(tr_QT_i_A_Q_i + tr_w_i))

return mean(traces)


@mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS)
@mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS)
def test_xtrace(
distribution: str,
num_matvecs: int,
max_total_matvecs: int = 10_000,
check_every: int = 10,
target_rel_error: float = 1e-3,
):
"""Test whether the XTrace estimator converges to the true trace.

Args:
distribution: Distribution of the random vectors used for the trace estimation.
num_matvecs: Number of matrix-vector multiplications used by one estimator.
max_total_matvecs: Maximum number of matrix-vector multiplications to perform.
Default: ``10_000``. If convergence has not been reached by then, the test
will fail.
check_every: Check for convergence every ``check_every`` estimates.
Default: ``10``.
target_rel_error: Target relative error for considering the estimator converged.
Default: ``1e-3``.
"""
seed(0)
A = rand(50, 50)
tr_A = trace(A)

used_matvecs, converged = 0, False

estimates = []
while used_matvecs < max_total_matvecs and not converged:
estimates.append(xtrace(A, num_matvecs, distribution=distribution))
used_matvecs += num_matvecs

if len(estimates) % check_every == 0:
rel_error = abs(tr_A - mean(estimates)) / abs(tr_A)
print(f"Relative error after {used_matvecs} matvecs: {rel_error:.5f}.")
converged = rel_error < target_rel_error

assert converged


@mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS)
@mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS)
def test_xtrace_matches_naive(num_matvecs: int, distribution: str, num_seeds: int = 5):
"""Test whether the efficient implementation of XTrace matches the naive.

Args:
num_matvecs: Number of matrix-vector multiplications used by one estimator.
distribution: Distribution of the random vectors used for the trace estimation.
num_seeds: Number of different seeds to test the estimators with.
Default: ``5``.
"""
seed(0)
A = rand(50, 50)

# check for different seeds
for i in range(num_seeds):
seed(i)
efficient = xtrace(A, num_matvecs, distribution=distribution)
seed(i)
naive = xtrace_naive(A, num_matvecs, distribution=distribution)
assert isclose(efficient, naive)