Skip to content

Commit

Permalink
Solvers: Avoid non-terminating condition in line search
Browse files Browse the repository at this point in the history
This commit replaces the potentially infinite loop in the line search with a bounded number of iterations. In exact (infinite-precision) arithmetic, the loop is guaranteed to terminate. In floating-point arithmetic, however, the search direction can be nearly orthogonal to the gradient, in which case there may be no valid step.
  • Loading branch information
dwmunster committed Jun 9, 2024
1 parent 58b12c9 commit 6c53f8d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
26 changes: 23 additions & 3 deletions src/solvers/bfgs_solver.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use nalgebra::{DMatrix, UniformNorm};
use std::error::Error;

use nalgebra::{DMatrix, UniformNorm};

use crate::sketch::Sketch;
use crate::solvers::line_search::line_search_wolfe;
use crate::solvers::line_search::{line_search_wolfe, LineSearchError};

use super::Solver;

Expand Down Expand Up @@ -44,6 +45,8 @@ impl Solver for BFGSSolver {

let mut h = DMatrix::identity(n, n);

let mut recently_reset = false;

while iterations < self.max_iterations {
let loss = sketch.get_loss();
if loss < self.min_loss {
Expand All @@ -64,7 +67,24 @@ impl Solver for BFGSSolver {
return Err("search direction contains non-finite values".into());
}

let alpha = line_search_wolfe(sketch, &p, &gradient)?;
let alpha = match line_search_wolfe(sketch, &p, &gradient) {
Ok(alpha) => alpha,
Err(LineSearchError::SearchFailed) => {
// If the line search could not find a suitable step size, the Hessian
// approximation may be inaccurate. Resetting the Hessian to the identity matrix
// will restart with a steepest descent step and hopefully build a better
// approximation.
if recently_reset {
return Err("bfgs: line search failed twice in a row".into());
}
h = DMatrix::identity(n, n);
recently_reset = true;
continue;
}
Err(e) => return Err(e.into()),
};

recently_reset = false;

let s = alpha * &p;

Expand Down
21 changes: 15 additions & 6 deletions src/solvers/line_search.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,37 @@
use crate::sketch::Sketch;
use nalgebra::DVector;
use std::error::Error;
use thiserror::Error;

/// Failure modes reported by `line_search_wolfe`.
#[derive(Debug, Error)]
pub enum LineSearchError {
/// The dot product of the gradient and the search direction was non-negative,
/// so the direction was rejected before any step size was tried.
#[error("line search failed: search direction is not a descent direction")]
NotDescentDirection,
/// The search ran through its trial step sizes without accepting any of them.
#[error("line search failed: could not find a suitable step size")]
SearchFailed,
}

/// Coefficient of the sufficient-decrease test: a step `alpha` passes when
/// `new_loss <= loss + WOLFE_C1 * alpha * m`, where `m` is `gradient.dot(direction)`.
const WOLFE_C1: f64 = 1e-4;
/// Factor applied to the initial directional derivative `m` to form the
/// curvature-condition threshold (`WOLFE_C2 * m`).
const WOLFE_C2: f64 = 0.9;
/// Upper bound on the number of trial step sizes the line search evaluates
/// before reporting failure.
const MAX_ITER: usize = 15;

pub(crate) fn line_search_wolfe(
sketch: &mut Sketch,
direction: &DVector<f64>,
gradient: &DVector<f64>,
) -> Result<f64, Box<dyn Error>> {
) -> Result<f64, LineSearchError> {
let mut alpha = 1.0;
let m = gradient.dot(direction);
if m >= 0.0 {
return Err("line search failed: search direction is not a descent direction".into());
return Err(LineSearchError::NotDescentDirection);
}
let curvature_condition = WOLFE_C2 * m;
let loss = sketch.get_loss();
let x0 = sketch.get_data();
while alpha > 1e-16 {
for _i in 0..MAX_ITER {
let data = &x0 + alpha * direction;
sketch.set_data(data);
let new_loss = sketch.get_loss();
// Sufficent decrease condition
// Sufficient decrease condition
if new_loss <= loss + WOLFE_C1 * alpha * m {
// Curvature condition
let new_gradient = sketch.get_gradient();
Expand All @@ -35,5 +44,5 @@ pub(crate) fn line_search_wolfe(
alpha *= 0.5;
}
}
Err("line search failed: alpha is too small".into())
Err(LineSearchError::SearchFailed)
}

0 comments on commit 6c53f8d

Please sign in to comment.