Skip to content

Commit

Permalink
[REF] Extract checks for square matrices and number of matvecs
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Jan 15, 2025
1 parent e4eccbe commit 6b97fe0
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 72 deletions.
24 changes: 10 additions & 14 deletions curvlinops/diagonal/epperly2024xtrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from scipy.sparse.linalg import LinearOperator

from curvlinops.sampling import random_vector
from curvlinops.utils import (
assert_divisible_by,
assert_is_square,
assert_matvecs_subseed_dim,
)


def xdiag(A: LinearOperator, num_matvecs: int) -> ndarray:
Expand All @@ -21,24 +26,15 @@ def xdiag(A: LinearOperator, num_matvecs: int) -> ndarray:
Args:
A: A square linear operator.
num_matvecs: Total number of matrix-vector products to use. Must be even and
less than the dimension of the linear operator.
less than the dimension of the linear operator (because otherwise one can
evaluate the true diagonal directly at the same cost).
Returns:
The estimated diagonal of the linear operator.
Raises:
ValueError: If the linear operator is not square or if the number of matrix-
vector products is not even or is greater than the dimension of the linear
operator.
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
dim = A.shape[1]
if num_matvecs % 2 != 0 or num_matvecs >= dim:
raise ValueError(
"num_matvecs must be even and less than the dimension of A.",
f" Got {num_matvecs}.",
)
dim = assert_is_square(A)
assert_matvecs_subseed_dim(A, num_matvecs)
assert_divisible_by(num_matvecs, 2, "num_matvecs")

# draw random vectors and compute their matrix-vector products
num_vecs = num_matvecs // 2
Expand Down
22 changes: 7 additions & 15 deletions curvlinops/diagonal/hutchinson.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from scipy.sparse.linalg import LinearOperator

from curvlinops.sampling import random_vector
from curvlinops.utils import assert_is_square, assert_matvecs_subseed_dim


def hutchinson_diag(
Expand Down Expand Up @@ -39,19 +40,15 @@ def hutchinson_diag(
Args:
A: A square linear operator whose diagonal is estimated.
num_matvecs: Total number of matrix-vector products to use. Must be smaller
than the dimension of the linear operator.
than the dimension of the linear operator (because otherwise one can
evaluate the true diagonal directly at the same cost).
distribution: Distribution of the random vectors used for the diagonal
estimation. Can be either ``'rademacher'`` or ``'normal'``.
Default: ``'rademacher'``.
Returns:
The estimated diagonal of the linear operator.
Raises:
ValueError: If the linear operator is not square or if the number of matrix-
vector products is greater than the dimension of the linear operator
(because then you can evaluate the true diagonal directly at the same cost).
Example:
>>> from numpy import diag
>>> from numpy.random import rand, seed
Expand All @@ -69,13 +66,8 @@ def hutchinson_diag(
>>> round(error_low_precision, 4), round(error_high_precision, 4)
(4.616, 1.2441)
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
dim = A.shape[1]
if num_matvecs >= dim:
raise ValueError(
f"num_matvecs ({num_matvecs}) must be less than A's size ({dim})."
)
dim = assert_is_square(A)
assert_matvecs_subseed_dim(A, num_matvecs)
G = column_stack([random_vector(dim, distribution) for _ in range(num_matvecs)])
AG = A @ G
return einsum("ij,ij->i", G, AG) / num_matvecs

return einsum("ij,ij->i", G, A @ G) / num_matvecs
24 changes: 10 additions & 14 deletions curvlinops/trace/epperly2024xtrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from scipy.sparse.linalg import LinearOperator

from curvlinops.sampling import random_vector
from curvlinops.utils import (
assert_divisible_by,
assert_is_square,
assert_matvecs_subseed_dim,
)


def xtrace(
Expand All @@ -23,26 +28,17 @@ def xtrace(
Args:
A: A square linear operator.
num_matvecs: Total number of matrix-vector products to use. Must be even and
less than the dimension of the linear operator.
less than the dimension of the linear operator (because otherwise one can
evaluate the true trace directly at the same cost).
distribution: Distribution of the random vectors used for the trace estimation.
Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``.
Returns:
The estimated trace of the linear operator.
Raises:
ValueError: If the linear operator is not square or if the number of matrix-
vector products is not even or is greater than the dimension of the linear
operator.
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
dim = A.shape[1]
if num_matvecs % 2 != 0 or num_matvecs >= dim:
raise ValueError(
"num_matvecs must be even and less than the dimension of A.",
f" Got {num_matvecs}.",
)
dim = assert_is_square(A)
assert_matvecs_subseed_dim(A, num_matvecs)
assert_divisible_by(num_matvecs, 2, "num_matvecs")

# draw random vectors and compute their matrix-vector products
num_vecs = num_matvecs // 2
Expand Down
18 changes: 5 additions & 13 deletions curvlinops/trace/hutchinson.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from scipy.sparse.linalg import LinearOperator

from curvlinops.sampling import random_vector
from curvlinops.utils import assert_is_square, assert_matvecs_subseed_dim


def hutchinson_trace(
Expand Down Expand Up @@ -40,18 +41,14 @@ def hutchinson_trace(
Args:
A: A square linear operator whose trace is estimated.
num_matvecs: Total number of matrix-vector products to use. Must be smaller
than the dimension of the linear operator.
than the dimension of the linear operator (because otherwise one can
evaluate the true trace directly at the same cost).
distribution: Distribution of the random vectors used for the trace estimation.
Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``.
Returns:
The estimated trace of the linear operator.
Raises:
ValueError: If the linear operator is not square or if the number of matrix-
vector products is greater than the dimension of the linear operator
(because then you can evaluate the true trace directly at the same cost).
Example:
>>> from numpy import trace, mean
>>> from numpy.random import rand, seed
Expand All @@ -68,13 +65,8 @@ def hutchinson_trace(
>>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4)
(25.7342, 59.7307, 20.033)
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
dim = A.shape[1]
if num_matvecs >= dim:
raise ValueError(
f"num_matvecs ({num_matvecs}) must be less than A's size ({dim})."
)
dim = assert_is_square(A)
assert_matvecs_subseed_dim(A, num_matvecs)
G = column_stack([random_vector(dim, distribution) for _ in range(num_matvecs)])

return einsum("ij,ij", G, A @ G) / num_matvecs
26 changes: 10 additions & 16 deletions curvlinops/trace/meyer2020hutch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from scipy.sparse.linalg import LinearOperator

from curvlinops.sampling import random_vector
from curvlinops.utils import (
assert_divisible_by,
assert_is_square,
assert_matvecs_subseed_dim,
)


def hutchpp_trace(
Expand Down Expand Up @@ -47,19 +52,14 @@ def hutchpp_trace(
Args:
A: A square linear operator whose trace is estimated.
num_matvecs: Total number of matrix-vector products to use. Must be smaller
than the dimension of the linear operator, and divisible by 3.
than the dimension of the linear operator (because otherwise one can
evaluate the true trace directly at the same cost), and divisible by 3.
distribution: Distribution of the random vectors used for the trace estimation.
Can be either ``'rademacher'`` or ``'normal'``. Default: ``'rademacher'``.
Returns:
The estimated trace of the linear operator.
Raises:
ValueError: If the linear operator is not square or if the number of matrix-
vector products is greater than the dimension of the linear operator
(because then you can evaluate the true trace directly at the same cost)
or not divisible by 3.
Example:
>>> from numpy import trace, mean
>>> from numpy.random import rand, seed
Expand All @@ -76,15 +76,9 @@ def hutchpp_trace(
>>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4)
(25.7342, 50.3488, 26.052)
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
dim = A.shape[1]
if num_matvecs >= dim or num_matvecs % 3 != 0:
raise ValueError(
f"num_matvecs ({num_matvecs}) must be less than A's size ({dim})"
" and divisible by 3."
)

dim = assert_is_square(A)
assert_matvecs_subseed_dim(A, num_matvecs)
assert_divisible_by(num_matvecs, 3, "num_matvecs")
N = num_matvecs // 3

# compute the orthogonal basis for the subspace spanned by AS, and evaluate the
Expand Down
53 changes: 53 additions & 0 deletions curvlinops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from numpy import cumsum
from torch import Tensor

from curvlinops._torch_base import PyTorchLinearOperator


def split_list(x: Union[List, Tuple], sizes: List[int]) -> List[List]:
"""Split a list into multiple lists of specified size.
Expand Down Expand Up @@ -66,3 +68,54 @@ def allclose_report(
print(f"rtol = {rtol}, atol = {atol}.")

return close


def assert_is_square(A: Union[Tensor, PyTorchLinearOperator]) -> int:
"""Assert that a matrix or linear operator is square.
Args:
A: Matrix or linear operator to be checked.
Returns:
The dimension of the square matrix.
Raises:
ValueError: If the matrix is not square.
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"Operator must be square. Got shape {A.shape}.")
(dim,) = set(A.shape)
return dim


def assert_matvecs_subseed_dim(
A: Union[Tensor, PyTorchLinearOperator], num_matvecs: int
):
"""Assert that the number of matrix-vector products is smaller than the dimension.
Args:
A: Matrix or linear operator to be checked.
num_matvecs: Number of matrix-vector products.
Raises:
ValueError: If the number of matrix-vector products is greater than the dimension.
"""
if any(num_matvecs >= d for d in A.shape):
raise ValueError(
f"num_matvecs ({num_matvecs}) must be less than A's size ({A.shape})."
)


def assert_divisible_by(num: int, divisor: int, name: str):
"""Assert that a number is divisible by another number.
Args:
num: Number to be checked.
divisor: Divisor.
name: Name of the number.
Raises:
ValueError: If the number is not divisible by the divisor.
"""
if num % divisor != 0:
raise ValueError(f"{name} ({num}) must be divisible by {divisor}.")

0 comments on commit 6b97fe0

Please sign in to comment.