Speed up eigenvalue solving using power iteration/inverse iteration (#317)
* add `power_iteration` and `shifted_inverse_iteration` for faster eigen solves

* move to _eigs, allow v0 starting vector

* try no v0

* move in `_smallest` and `_largest` to `_eigs`. Add tests.

* rename module to `_eigenvalues`

* Use LU decomposition. rename to `eigenvalues`. Add docs

* remove comment `v0` arg, add local runtime table

* remove debug print

* remove example for docstring

* fix placement of comment
scottstanie authored Jun 21, 2024
1 parent b3a9b10 commit f5b9d8a
Showing 3 changed files with 302 additions and 70 deletions.
83 changes: 13 additions & 70 deletions src/dolphin/phase_link/_core.py
@@ -7,14 +7,15 @@

import jax.numpy as jnp
import numpy as np
-from jax import Array, jit, lax, vmap
-from jax.scipy.linalg import cho_factor, cho_solve, eigh
+from jax import Array, jit, lax
+from jax.scipy.linalg import cho_factor, cho_solve
from jax.typing import ArrayLike

from dolphin._types import HalfWindow, Strides
from dolphin.utils import take_looks

from . import covariance, metrics
+from ._eigenvalues import eigh_largest_stack, eigh_smallest_stack
from ._ps_filling import fill_ps_pixels

logger = logging.getLogger(__name__)
@@ -377,9 +378,11 @@ def process_coherence_matrices(
"""
rows, cols, n, _ = C_arrays.shape

if use_evd:
# EVD
eig_vals, eig_vecs = eigh_largest_stack(C_arrays)
evd_eig_vals, evd_eig_vecs = eigh_largest_stack(C_arrays)
eig_vals, eig_vecs = evd_eig_vals, evd_eig_vecs
estimator = jnp.zeros(eig_vals.shape, dtype=bool)
else:
# EMI
@@ -395,14 +398,18 @@
        if beta > 0:
            # Perform regularization
            Gamma = (1 - beta) * Gamma + beta * Id

+       # Attempt to invert Gamma
        cho, is_lower = cho_factor(Gamma)

        # Check: if the Cholesky factorization fails, Gamma is close to
        # singular and we should just fall back to EVD
        # Use the already-factored Gamma: solving Gamma @ X = I gives the inverse
        Gamma_inv = cho_solve((cho, is_lower), Id)
-       emi_eig_vals, emi_eig_vecs = eigh_smallest_stack(Gamma_inv * C_arrays)
+       # We're looking for the eigenvalue nearest to 1,
+       # so shift the inverse iteration by mu = 0.99
+       mu = 0.99
+       emi_eig_vals, emi_eig_vecs = eigh_smallest_stack(Gamma_inv * C_arrays, mu)
        # From the EMI paper, normalize the eigenvectors to have norm sqrt(n)
        emi_eig_vecs = (
            jnp.sqrt(n)
@@ -411,9 +418,6 @@
        )
        # Is the output the inverse of the eigenvectors? Or the inverse conjugate?

-       # For places where inverting |Gamma| failed: fall back to computing EVD
-       evd_eig_vals, evd_eig_vecs = eigh_largest_stack(C_arrays)
-
        # Use https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html
        # Note that `if` would fail the jit tracing
        # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#cond
@@ -422,6 +426,8 @@
        # Must broadcast the 2D boolean array so it's the same size as the outputs
        inv_has_nans_3d = jnp.tile(inv_has_nans[:, :, None], (1, 1, n))

+       # For EVD, or places where inverting |Gamma| failed: fall back to computing EVD
+       evd_eig_vals, evd_eig_vecs = eigh_largest_stack(C_arrays)
        eig_vecs = lax.select(
            inv_has_nans_3d,
            # Run this on True: EVD, since we failed to invert:
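
The JAX links in the comments above explain why the fallback is written this way: a Python `if` on a traced boolean fails under `jit`, so the code computes both candidates and uses `lax.select` to pick per pixel. A minimal standalone sketch of the pattern (the shapes here are illustrative, not dolphin's):

import jax.numpy as jnp
from jax import jit, lax


@jit
def choose(fallback_mask, evd_vecs, emi_vecs):
    # Both operands are computed; lax.select picks elementwise, taking the
    # EVD fallback where the mask is True and the EMI result elsewhere.
    # A Python `if` on the traced mask would fail during jit tracing.
    return lax.select(fallback_mask, evd_vecs, emi_vecs)


mask3d = jnp.tile(jnp.array([[True, False]])[:, :, None], (1, 1, 3))
print(choose(mask3d, jnp.ones((1, 2, 3)), jnp.zeros((1, 2, 3))))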
@@ -448,69 +454,6 @@
    return evd_estimate, eig_vals, estimator


-# The eigenvalues are in ascending order
-# Column j of `eig_vecs` is the normalized eigenvector corresponding
-# to eigenvalue `lam[j]`
-@jit
-def _get_smallest_eigenpair(C) -> tuple[Array, Array]:
-    lam, eig_vecs = eigh(C)
-    return lam[0], eig_vecs[:, 0]
-
-
-@jit
-def _get_largest_eigenpair(C) -> tuple[Array, Array]:
-    lam, eig_vecs = eigh(C)
-    return lam[-1], eig_vecs[:, -1]
-
-
-# We map over the first two dimensions, so now instead of one scalar eigenvalue,
-# we have (rows, cols) eigenvalues
-@jit
-def eigh_smallest_stack(C_arrays: ArrayLike) -> tuple[Array, Array]:
-    """Get the smallest (eigenvalue, eigenvector) for each pixel in a 3D stack.
-
-    Parameters
-    ----------
-    C_arrays : ArrayLike
-        The stack of coherence matrices.
-        Shape = (rows, cols, nslc, nslc)
-
-    Returns
-    -------
-    eigenvalues : Array
-        The smallest eigenvalue for each pixel's matrix
-        Shape = (rows, cols)
-    eigenvectors : Array
-        The normalized eigenvector corresponding to the smallest eigenvalue
-        Shape = (rows, cols, nslc)
-    """
-    return vmap(vmap(_get_smallest_eigenpair))(C_arrays)
-
-
-@jit
-def eigh_largest_stack(C_arrays: ArrayLike) -> tuple[Array, Array]:
-    """Get the largest (eigenvalue, eigenvector) for each pixel in a 3D stack.
-
-    Parameters
-    ----------
-    C_arrays : ArrayLike
-        The stack of coherence matrices.
-        Shape = (rows, cols, nslc, nslc)
-
-    Returns
-    -------
-    eigenvalues : Array
-        The largest eigenvalue for each pixel's matrix
-        Shape = (rows, cols)
-    eigenvectors : Array
-        The normalized eigenvector corresponding to the largest eigenvalue
-        Shape = (rows, cols, nslc)
-    """
-    return vmap(vmap(_get_largest_eigenpair))(C_arrays)


def decimate(arr: ArrayLike, strides: Strides) -> Array:
    """Decimate an array by strides in the x and y directions.
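
Both solvers in the new module below are classic iterative methods. Power iteration repeats v <- A v / ||A v|| and converges to the eigenvector of the largest-magnitude eigenvalue; shifted inverse iteration applies the same recurrence to (A - mu*I)^{-1}, whose eigenvalues are 1 / (lam_i - mu), so the lam_i nearest mu becomes dominant. A tiny numeric sketch of that spectral mapping (the eigenvalues here are made up; mu = 0.99 matches the EMI shift above):

import jax.numpy as jnp

lam = jnp.array([0.5, 1.0, 3.0])  # hypothetical eigenvalues of a matrix A
mu = 0.99  # the shift used for the EMI path above

# Eigenvalues of (A - mu*I)^{-1} are 1/(lam - mu). The eigenvalue nearest mu
# maps to the largest magnitude (~100 vs ~2), so plain power iteration on the
# shifted inverse converges to it quickly.
print(1.0 / (lam - mu))  # -> roughly [-2.04, 100.0, 0.50]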
205 changes: 205 additions & 0 deletions src/dolphin/phase_link/_eigenvalues.py
@@ -0,0 +1,205 @@
"""Eigenvalue solvers specialized for the phase linking use case.
| | EVD | | EMI | |
| --------------- | ----------------- | ------- | ----------------- | ------- |
| | Runtime (seconds) | Speedup | Runtime (seconds) | Speedup |
| `scipy.eigh` | 185 | - | 398 | - |
| Power iteration | 19 | 9.7 | 57 | 6.8 |
"""

from __future__ import annotations

from functools import partial

import jax
import jax.numpy as jnp
from jax import Array, jit, vmap
from jax.lax import while_loop
from jax.typing import ArrayLike

# For both largest and smallest eig, we map over the first two dimensions
# so now instead of one scalar eigenvalue, we have (rows, cols) eigenvalues


@partial(jit, static_argnames=("mu"))
def eigh_smallest_stack(
C_arrays: ArrayLike,
mu: float,
) -> tuple[Array, Array]:
"""Get the smallest (eigenvalue, eigenvector) for each pixel in a 3D stack.
Uses shift inverse iteration to find the eigenvalue closest to `mu`.
Pick `mu` to be slightly below the smallest eigenvalue for fastest convergence.
Returns real eigenvalues, assuming the arrays are Hermitian.
Parameters
----------
C_arrays : ArrayLike
The stack of coherence matrices.
Shape = (rows, cols, nslc, nslc)
mu : float
The value to use for the shift inverse iteration.
The eigenvalue closest to this value is returned.
v0 : ArrayLike, optional
The initial guess for the eigenvector.
If None, a vector of 1s is used.
Shape = (rows, cols, nslc)
Returns
-------
eigenvalues : Array
The smallest eigenvalue for each pixel's matrix
Shape = (rows, cols)
eigenvectors : Array
The normalized eigenvector corresponding to the smallest eigenvalue
Shape = (rows, cols, nslc)
"""
in_axes = (0, None)
eig_vals, eig_vecs = vmap(
vmap(inverse_iteration, in_axes=in_axes), in_axes=in_axes
)(C_arrays, mu)
return eig_vals.real, eig_vecs


@jit
def eigh_largest_stack(C_arrays: ArrayLike) -> tuple[Array, Array]:
    """Get the largest (eigenvalue, eigenvector) for each pixel in a 3D stack.

    Returns real eigenvalues, assuming the arrays are Hermitian.

    Parameters
    ----------
    C_arrays : ArrayLike
        The stack of coherence matrices.
        Shape = (rows, cols, nslc, nslc)

    Returns
    -------
    eigenvalues : Array
        The largest eigenvalue for each pixel's matrix.
        Shape = (rows, cols)
    eigenvectors : Array
        The normalized eigenvector corresponding to the largest eigenvalue.
        Shape = (rows, cols, nslc)

    """
    eig_vals, eig_vecs = vmap(vmap(power_iteration))(C_arrays)
    return eig_vals.real, eig_vecs


@partial(jit, static_argnames=("tol",))
def power_iteration(
A: jnp.array, tol: float = 1e-5, max_iters: int = 50
) -> tuple[jnp.array, jnp.array]:
"""Compute the dominant eigenpair of a matrix using power iteration.
Parameters
----------
A : jnp.array
The input matrix.
tol : float, optional
The tolerance for convergence (default is 1e-3).
max_iters : int, optional
The maximum number of iterations (default is 50).
Returns
-------
eigenvalue : jnp.array
The dominant eigenvalue of the matrix.
vk : jnp.array
The corresponding eigenvector of the dominant eigenvalue.
"""
n = A.shape[-1]
vk = jnp.ones(n, dtype=A.dtype)
vk = vk / jnp.linalg.norm(vk)

def body_fun(val):
vk, _, idx = val
A_times_vk = A @ vk
next_vk = A_times_vk / jnp.linalg.norm(A_times_vk)
diff = jnp.linalg.norm(next_vk - vk)
return next_vk, diff, idx + 1

def cond_fun(val):
_, diff, iters = val
return jnp.logical_and(diff > tol, iters < max_iters)

init_val = (vk, 1, 1)
vk, _, end_iters = while_loop(cond_fun, body_fun, init_val)

# vk is normalized to 1, so no need to divide by (vk.T @ vk)
eigenvalue = vk.conj() @ A @ vk
return eigenvalue, vk


@partial(jit, static_argnames=("mu", "tol", "max_iters"))
def inverse_iteration(
A: ArrayLike,
mu: float,
tol: float = 1e-5,
max_iters: int = 50,
) -> tuple[Array, Array]:
"""Compute the eigenvalue of the positive definite matrix `A` closest to `mu`.
Inverse iteration (or inverse power iteration) is an iterative method used to
find the eigenvalue of a matrix A that is closest to a given scalar mu.
Parameters
----------
A : ArrayLike
Square matrix of which we seek an eigenvalue and eigenvector.
mu : float
The shift value, around which we seek the closest eigenvalue of A.
tol : float, optional
Tolerance for convergence of the method. The default is 1e-5.
max_iters : int, optional
Maximum number of iterations to perform. The default is 50.
Returns
-------
eigenvalue : jnp.array
The eigenvalue of the matrix closest to `mu`.
vk : jnp.array
The corresponding eigenvector.
Notes
-----
This method may not converge with a poor guess of `mu`.
The close the guess of `mu` is to your eigenvalue, the quicker the convergence.
However, picking the exact eigenvalue may lead to numerical instability.
References
----------
[1] https://services.math.duke.edu/~jtwong/math361-2019/lectures/Lec10eigenvalues.pdf
"""
n = A.shape[0]
vk = jnp.ones(n, dtype=A.dtype)

vk = vk / jnp.linalg.norm(vk)
Id = jnp.eye(A.shape[0], dtype=A.dtype)
# Prefactor A - mu I to quickly solve each iteration
lu_and_pivots = jax.scipy.linalg.lu_factor(A - mu * Id)

def body_fun(val):
vk_cur, _, iters = val
# Perform power iteration on (A - mu I)^{-1} ,
# which is the same as solving (A - mu I)x = v_k
vk_new = jax.scipy.linalg.lu_solve(lu_and_pivots, vk_cur)
vk_new = vk_new / jnp.linalg.norm(vk_new)
diff = jnp.linalg.norm(vk_new - vk_cur)
return vk_new, diff, iters + 1

def cond_fun(val):
_, diff, iters = val
return jnp.logical_and(diff > tol, iters < max_iters)

init_val = (vk, 1.0, 1)
vk_sol, _, end_iters = while_loop(cond_fun, body_fun, init_val)

eigenvalue = vk_sol.conj() @ A @ vk_sol
return eigenvalue, vk_sol
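
To exercise the new module end to end, one can build a random Hermitian stack and compare both solvers against a dense reference. A sketch, assuming the import path from this commit; the shapes, seed, and tolerances are illustrative:

import jax
import jax.numpy as jnp

from dolphin.phase_link._eigenvalues import eigh_largest_stack, eigh_smallest_stack

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
shape = (4, 5, 8, 8)  # (rows, cols, nslc, nslc)
A = jax.random.normal(k1, shape) + 1j * jax.random.normal(k2, shape)
C = A @ jnp.conj(jnp.swapaxes(A, -1, -2))  # Hermitian positive definite stack

vals_hi, vecs_hi = eigh_largest_stack(C)
vals_lo, vecs_lo = eigh_smallest_stack(C, mu=0.0)  # shift just below the spectrum

ref = jnp.linalg.eigvalsh(C)  # dense reference: ascending eigenvalues per pixel
print(jnp.allclose(vals_hi, ref[..., -1], rtol=1e-2))
print(jnp.allclose(vals_lo, ref[..., 0], rtol=1e-2))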
