def xtrace(
    A: LinearOperator, num_matvecs: int, distribution: str = "rademacher"
) -> float:
    """Estimate a linear operator's trace using the XTrace algorithm.

    The method is presented in `this paper <https://arxiv.org/pdf/2301.07825>`_:

    - Epperly, E. N., Tropp, J. A., & Webber, R. J. (2024). Xtrace: making the most
      of every sample in stochastic trace estimation. SIAM Journal on Matrix Analysis
      and Applications (SIMAX).

    It combines the variance reduction from Hutch++ with the exchangeability principle.

    Half of the matrix-vector products are spent on multiplying ``A`` onto random
    test vectors, the other half on multiplying ``A`` onto an orthonormal basis of
    the resulting sketch.

    Args:
        A: A square linear operator.
        num_matvecs: Total number of matrix-vector products to use. Must be positive,
            even, and less than the dimension of the linear operator.
        distribution: Distribution of the random vectors used for the trace estimation.
            Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``.

    Returns:
        The estimated trace of the linear operator.

    Raises:
        ValueError: If the linear operator is not square or if the number of matrix-
            vector products is not positive, not even, or not less than the dimension
            of the linear operator.
    """
    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must be square. Got shape {A.shape}.")
    dim = A.shape[1]
    if num_matvecs <= 0 or num_matvecs % 2 != 0 or num_matvecs >= dim:
        # Adjacent string literals are concatenated into ONE message. (A comma
        # between them would hand ``ValueError`` two arguments and the exception
        # would print as a tuple.)
        raise ValueError(
            "num_matvecs must be positive, even, and less than the dimension of A."
            f" Got {num_matvecs}."
        )

    # draw random vectors and compute their matrix-vector products
    num_vecs = num_matvecs // 2
    W = column_stack([random_vector(dim, distribution) for _ in range(num_vecs)])
    A_W = A @ W

    # compute the orthogonal basis for all test vectors, and its associated trace
    Q, R = qr(A_W)
    A_Q = A @ Q
    tr_QT_A_Q = einsum("ij,ij->", Q, A_Q)

    # compute the traces in the bases that would result had we left out the i-th
    # test vector in the QR decomposition
    RT_inv = inv(R.T)
    # ``D`` holds the reciprocal Euclidean norms of the columns of R^{-T}, so the
    # columns s_i of ``S`` have unit norm
    D = 1 / (RT_inv**2).sum(0) ** 0.5
    S = einsum("ij,j->ij", RT_inv, D)
    # entry j is s_jT (QT A Q) s_j
    tr_QT_i_A_Q_i = einsum("ij,ki,kl,lj->j", S, Q, A_Q, S, optimize="optimal")

    # Traces in the bases {Q_i}. This follows by writing Tr(QT_i A Q_i) = Tr(A Q_i QT_i)
    # then using the relation that Q_i QT_i = Q (I - s_i sT_i) QT. Further
    # simplification then leads to
    traces = tr_QT_A_Q - tr_QT_i_A_Q_i

    def deflate(v: ndarray, s: ndarray) -> ndarray:
        """Apply (I - s sT) to a vector.

        Args:
            v: Vector to deflate.
            s: Deflation vector.

        Returns:
            Deflated vector.
        """
        return v - dot(s, v) * s

    # estimate the trace on the complement of Q_i with vanilla Hutchinson using the
    # i-th test vector
    for i in range(num_vecs):
        w_i = W[:, i]
        s_i = S[:, i]
        A_w_i = A_W[:, i]

        # Compute (I - Q_i QT_i) A (I - Q_i QT_i) w_i
        #       = (I - Q_i QT_i) (Aw - AQ_i QT_i w_i)
        #         ( using that Q_i QT_i = Q (I - s_i sT_i) QT )
        #       = (I - Q_i QT_i) (Aw - AQ (I - s_i sT_i) QT w)
        #       = (I - Q (I - s_i sT_i) QT) (Aw - AQ (I - s_i sT_i) QT w)
        #                                   |--------- A_P_w_i ---------|
        #         |-------------------- PT_A_P_w_i ---------------------|
        A_P_w_i = A_w_i - A_Q @ deflate(Q.T @ w_i, s_i)
        PT_A_P_w_i = A_P_w_i - Q @ deflate(Q.T @ A_P_w_i, s_i)

        tr_w_i = dot(w_i, PT_A_P_w_i)
        traces[i] += tr_w_i

    return mean(traces)
# Numbers of matrix-vector products per estimate that the tests are parametrized
# over, and the matching human-readable pytest ids.
NUM_MATVECS = [4, 10]
NUM_MATVEC_IDS = [f"num_matvecs={num_matvecs}" for num_matvecs in NUM_MATVECS]


def xtrace_naive(
    A: LinearOperator, num_matvecs: int, distribution: str = "rademacher"
) -> float:
    """Naive reference implementation of XTrace.

    See Algorithm 1.2 in https://arxiv.org/pdf/2301.07825.

    For every test vector, a separate QR decomposition of the sketch without that
    vector is computed.

    Args:
        A: A square linear operator.
        num_matvecs: Total number of matrix-vector products to use. Must be positive,
            even, and less than the dimension of the linear operator.
        distribution: Distribution of the random vectors used for the trace estimation.
            Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``.

    Returns:
        The estimated trace of the linear operator.

    Raises:
        ValueError: If the linear operator is not square or if the number of matrix-
            vector products is not positive, not even, or not less than the dimension
            of the linear operator.
    """
    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
        raise ValueError(f"A must be square. Got shape {A.shape}.")
    dim = A.shape[1]
    if num_matvecs <= 0 or num_matvecs % 2 != 0 or num_matvecs >= dim:
        # Adjacent string literals are concatenated into ONE message. (A comma
        # between them would hand ``ValueError`` two arguments and the exception
        # would print as a tuple.)
        raise ValueError(
            "num_matvecs must be positive, even, and less than the dimension of A."
            f" Got {num_matvecs}."
        )
    num_vecs = num_matvecs // 2

    W = column_stack([random_vector(dim, distribution) for _ in range(num_vecs)])
    A_W = A @ W

    traces = []

    for i in range(num_vecs):
        # compute the exact trace in the basis spanned by the sketch matrix without
        # test vector i
        not_i = [j for j in range(num_vecs) if j != i]
        Q_i, _ = qr(A_W[:, not_i])
        A_Q_i = A @ Q_i
        tr_QT_i_A_Q_i = trace(Q_i.T @ A_Q_i)

        # apply vanilla Hutchinson in the complement, using test vector i
        w_i = W[:, i]
        A_w_i = A_W[:, i]
        A_P_w_i = A_w_i - A_Q_i @ (Q_i.T @ w_i)
        PT_A_P_w_i = A_P_w_i - Q_i @ (Q_i.T @ A_P_w_i)
        tr_w_i = dot(w_i, PT_A_P_w_i)

        traces.append(float(tr_QT_i_A_Q_i + tr_w_i))

    return mean(traces)


@mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS)
@mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS)
def test_xtrace(
    distribution: str,
    num_matvecs: int,
    max_total_matvecs: int = 10_000,
    check_every: int = 10,
    target_rel_error: float = 1e-3,
):
    """Test whether the XTrace estimator converges to the true trace.

    Estimates are accumulated and their running mean is compared against the exact
    trace of a fixed random matrix.

    Args:
        distribution: Distribution of the random vectors used for the trace estimation.
        num_matvecs: Number of matrix-vector multiplications used by one estimator.
        max_total_matvecs: Maximum number of matrix-vector multiplications to perform.
            Default: ``10_000``. If convergence has not been reached by then, the test
            will fail.
        check_every: Check for convergence every ``check_every`` estimates.
            Default: ``10``.
        target_rel_error: Target relative error for considering the estimator converged.
            Default: ``1e-3``.
    """
    seed(0)
    A = rand(50, 50)
    tr_A = trace(A)

    used_matvecs, converged = 0, False

    estimates = []
    while used_matvecs < max_total_matvecs and not converged:
        estimates.append(xtrace(A, num_matvecs, distribution=distribution))
        used_matvecs += num_matvecs

        # only evaluate the running mean every ``check_every`` estimates
        if len(estimates) % check_every == 0:
            rel_error = abs(tr_A - mean(estimates)) / abs(tr_A)
            print(f"Relative error after {used_matvecs} matvecs: {rel_error:.5f}.")
            converged = rel_error < target_rel_error

    assert converged


@mark.parametrize("num_matvecs", NUM_MATVECS, ids=NUM_MATVEC_IDS)
@mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS)
def test_xtrace_matches_naive(num_matvecs: int, distribution: str, num_seeds: int = 5):
    """Test whether the efficient implementation of XTrace matches the naive.

    Args:
        num_matvecs: Number of matrix-vector multiplications used by one estimator.
        distribution: Distribution of the random vectors used for the trace estimation.
        num_seeds: Number of different seeds to test the estimators with.
            Default: ``5``.
    """
    seed(0)
    A = rand(50, 50)

    # check for different seeds
    for i in range(num_seeds):
        # re-seed before each call so both implementations draw the same vectors
        seed(i)
        efficient = xtrace(A, num_matvecs, distribution=distribution)
        seed(i)
        naive = xtrace_naive(A, num_matvecs, distribution=distribution)
        assert isclose(efficient, naive)