-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ADD] Implement XTrace trace estimator
- Loading branch information
Showing
4 changed files
with
215 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
"""Implements the XTrace algorithm from Epperly 2024.""" | ||
|
||
from numpy import column_stack, dot, einsum, mean | ||
from numpy.linalg import inv, qr | ||
from scipy.sparse.linalg import LinearOperator | ||
|
||
from curvlinops.sampling import random_vector | ||
|
||
|
||
def xtrace( | ||
A: LinearOperator, num_matvecs: int, distribution: str = "rademacher" | ||
) -> float: | ||
"""Estimate a linear operator's trace using the XTrace algorithm. | ||
The method is presented in `this paper<https://arxiv.org/pdf/2301.07825>`_: | ||
- Epperly, E. N., Tropp, J. A., & Webber, R. J. (2024). Xtrace: making the most | ||
of every sample in stochastic trace estimation. SIAM Journal on Matrix Analysis | ||
and Applications (SIMAX). | ||
It combines the variance reduction from Hutch++ with the exchangeability principle. | ||
Args: | ||
A: A square linear operator. | ||
num_matvecs: Total number of matrix-vector products to use. Must be even and | ||
less than the dimension of the linear operator. | ||
distribution: Distribution of the random vectors used for the trace estimation. | ||
Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``. | ||
Returns: | ||
The estimated trace of the linear operator. | ||
Raises: | ||
ValueError: If the linear operator is not square or if the number of matrix- | ||
vector products is not even or is greater than the dimension of the linear | ||
operator. | ||
""" | ||
if len(A.shape) != 2 or A.shape[0] != A.shape[1]: | ||
raise ValueError(f"A must be square. Got shape {A.shape}.") | ||
dim = A.shape[1] | ||
if num_matvecs % 2 != 0 or num_matvecs >= dim: | ||
raise ValueError( | ||
"num_matvecs must be even and less than the dimension of A.", | ||
f" Got {num_matvecs}.", | ||
) | ||
|
||
# draw random vectors and compute their matrix-vector products | ||
num_vecs = num_matvecs // 2 | ||
W = column_stack([random_vector(dim, distribution) for _ in range(num_vecs)]) | ||
A_W = A @ W | ||
|
||
# compute the orthogonal basis for all test vectors, and its associated trace | ||
Q, R = qr(A_W) | ||
A_Q = A @ Q | ||
tr_QT_A_Q = einsum("ij,ji->", Q.T, A_Q) | ||
|
||
# compute the traces in the bases that would result had we left out the i-th | ||
# test vector in the QR decomposition | ||
RT_inv = inv(R.T) | ||
D = 1 / (RT_inv**2).sum(0) ** 0.5 | ||
S = einsum("ij,j->ij", RT_inv, D) | ||
tr_QT_i_A_Q_i = einsum("ij,jk,kl,li->i", S.T, Q.T, A_Q, S) | ||
|
||
# Traces in the bases {Q_i}. This follows by writing Tr(QT_i A Q_i) = Tr(A Q_i QT_i) | ||
# then using the relation that Q_i QT_i = Q (I - s_i sT_i) QT. Further | ||
# simplification then leads to | ||
traces = tr_QT_A_Q - tr_QT_i_A_Q_i | ||
|
||
# estimate the trace on the complement of Q_i with vanilla Hutchinson using the | ||
# i-th test vector | ||
for i in range(num_vecs): | ||
w_i = W[:, i] | ||
s_i = S[:, i] | ||
A_w_i = A_W[:, i] | ||
|
||
def deflate(v): | ||
"""Apply (I - s_i sT_i) to a vector.""" | ||
return v - dot(s_i, v) * s_i | ||
|
||
# Compute (I - Q_i QT_i) A (I - Q_i QT_i) w_i | ||
# = (I - Q_i QT_i) (Aw - AQ_i QT_i w_i) | ||
# ( using that Q_i QT_i = Q (I - s_i sT_i) QT ) | ||
# = (I - Q_i QT_i) (Aw - AQ (I - s_i sT_i) QT w) | ||
# = (I - Q (I - s_i sT_i) QT) (Aw - AQ (I - s_i sT_i) QT w) | ||
# |--------- A_p_w_i ---------| | ||
# |-------------------- PT_A_P_w_i----------------------| | ||
A_P_w_i = A_w_i - A_Q @ deflate(Q.T @ w_i) | ||
PT_A_P_w_i = A_P_w_i - Q @ deflate(Q.T @ A_P_w_i) | ||
|
||
tr_w_i = dot(w_i, PT_A_P_w_i) | ||
traces[i] += tr_w_i | ||
|
||
return mean(traces) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
"""Test ``curvlinops.trace.epperli2024xtrace.""" | ||
|
||
from test.trace import DISTRIBUTION_IDS, DISTRIBUTIONS, _test_convergence | ||
|
||
from numpy import column_stack, dot, isclose, mean, trace | ||
from numpy.linalg import qr | ||
from numpy.random import rand, seed | ||
from pytest import mark | ||
from scipy.sparse.linalg import LinearOperator | ||
|
||
from curvlinops import xtrace | ||
from curvlinops.sampling import random_vector | ||
|
||
NUM_MATVECS = [4, 10] | ||
NUM_MATVEC_IDS = [f"num_matvecs={num_matvecs}" for num_matvecs in NUM_MATVECS] | ||
|
||
|
||
def xtrace_naive( | ||
A: LinearOperator, num_matvecs: int, distribution: str = "rademacher" | ||
) -> float: | ||
"""Naive reference implementation of XTrace.""" | ||
if len(A.shape) != 2 or A.shape[0] != A.shape[1]: | ||
raise ValueError(f"A must be square. Got shape {A.shape}.") | ||
dim = A.shape[1] | ||
if num_matvecs % 2 != 0 or num_matvecs >= dim: | ||
raise ValueError( | ||
"num_matvecs must be even and less than the dimension of A.", | ||
f" Got {num_matvecs}.", | ||
) | ||
sketch_dim = num_matvecs // 2 | ||
|
||
W = column_stack([random_vector(dim, distribution) for _ in range(sketch_dim)]) | ||
A_W = A @ W | ||
|
||
traces = [] | ||
|
||
for i in range(sketch_dim): | ||
# compute the exact trace in the basis spanned by the sketch matrix without | ||
# test vector i | ||
not_i = [j for j in range(sketch_dim) if j != i] | ||
Q_i, _ = qr(A_W[:, not_i]) | ||
A_Q_i = A @ Q_i | ||
tr_QT_i_A_Q_i = trace(Q_i.T @ A_Q_i) | ||
|
||
# apply vanilla Hutchinson in the complement, using test vector i | ||
w_i = W[:, i] | ||
A_w_i = A_W[:, i] | ||
A_P_w_i = A_w_i - A_Q_i @ (Q_i.T @ w_i) | ||
PT_A_P_w_i = A_P_w_i - Q_i @ (Q_i.T @ A_P_w_i) | ||
tr_w_i = dot(w_i, PT_A_P_w_i) | ||
|
||
traces.append(float(tr_QT_i_A_Q_i + tr_w_i)) | ||
|
||
return mean(traces) | ||
|
||
|
||
@mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS) | ||
@mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS) | ||
def test_xtrace( | ||
distribution: str, | ||
num_matvecs: int, | ||
max_total_matvecs: int = 10_000, | ||
check_every: int = 10, | ||
target_rel_error: float = 1e-3, | ||
): | ||
"""Test whether the XTrace estimator converges to the true trace. | ||
Args: | ||
distribution: Distribution of the random vectors used for the trace estimation. | ||
num_matvecs: Number of matrix-vector multiplications used by one estimator. | ||
max_total_matvecs: Maximum number of matrix-vector multiplications to perform. | ||
Default: ``1_000``. If convergence has not been reached by then, the test | ||
will fail. | ||
check_every: Check for convergence every ``check_every`` estimates. | ||
Default: ``10``. | ||
target_rel_error: Target relative error for considering the estimator converged. | ||
Default: ``1e-3``. | ||
""" | ||
seed(0) | ||
A = rand(50, 50) | ||
tr_A = trace(A) | ||
|
||
used_matvecs, converged = 0, False | ||
|
||
estimates = [] | ||
while used_matvecs < max_total_matvecs and not converged: | ||
estimates.append(xtrace(A, num_matvecs, distribution=distribution)) | ||
used_matvecs += num_matvecs | ||
|
||
if len(estimates) % check_every == 0: | ||
rel_error = abs(tr_A - mean(estimates)) / abs(tr_A) | ||
print(f"Relative error after {used_matvecs} matvecs: {rel_error:.5f}.") | ||
converged = rel_error < target_rel_error | ||
|
||
assert converged | ||
|
||
|
||
@mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS) | ||
@mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS) | ||
def test_xtrace_matches_naive(num_matvecs: int, distribution: str, num_seeds: int = 5): | ||
"""Test whether the efficient implementation of XTrace matches the naive. | ||
Args: | ||
num_matvecs: Number of matrix-vector multiplications used by one estimator. | ||
distribution: Distribution of the random vectors used for the trace estimation. | ||
num_seeds: Number of different seeds to test the estimators with. | ||
Default: ``5``. | ||
""" | ||
seed(0) | ||
A = rand(50, 50) | ||
|
||
# check for different seeds | ||
for i in range(num_seeds): | ||
seed(i) | ||
efficient = xtrace(A, num_matvecs, distribution=distribution) | ||
seed(i) | ||
naive = xtrace_naive(A, num_matvecs, distribution=distribution) | ||
assert isclose(efficient, naive) |