From 6c53f8dffb1a54bdfb698d7d129cd6761df74911 Mon Sep 17 00:00:00 2001
From: Drayton Munster
Date: Sun, 9 Jun 2024 11:56:51 -0400
Subject: [PATCH] Solvers: Avoid non-terminating condition in line search

This commit replaces the potentially infinite loop in the line search
with a finite number of iterations. In infinite-precision arithmetic,
the loop is guaranteed to terminate. With floating-point arithmetic,
however, the search direction can be nearly orthogonal to the gradient,
in which case no valid step size may exist.
---
 src/solvers/bfgs_solver.rs | 26 +++++++++++++++++++++++---
 src/solvers/line_search.rs | 21 +++++++++++++++------
 2 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/src/solvers/bfgs_solver.rs b/src/solvers/bfgs_solver.rs
index 76691bd..31f7f6e 100644
--- a/src/solvers/bfgs_solver.rs
+++ b/src/solvers/bfgs_solver.rs
@@ -1,8 +1,9 @@
-use nalgebra::{DMatrix, UniformNorm};
 use std::error::Error;
 
+use nalgebra::{DMatrix, UniformNorm};
+
 use crate::sketch::Sketch;
-use crate::solvers::line_search::line_search_wolfe;
+use crate::solvers::line_search::{line_search_wolfe, LineSearchError};
 
 use super::Solver;
 
@@ -44,6 +45,8 @@ impl Solver for BFGSSolver {
 
         let mut h = DMatrix::identity(n, n);
 
+        let mut recently_reset = false;
+
         while iterations < self.max_iterations {
             let loss = sketch.get_loss();
             if loss < self.min_loss {
@@ -64,7 +67,24 @@ impl Solver for BFGSSolver {
                 return Err("search direction contains non-finite values".into());
             }
 
-            let alpha = line_search_wolfe(sketch, &p, &gradient)?;
+            let alpha = match line_search_wolfe(sketch, &p, &gradient) {
+                Ok(alpha) => alpha,
+                Err(LineSearchError::SearchFailed) => {
+                    // If the line search could not find a suitable step size, the Hessian
+                    // approximation may be inaccurate. Resetting the Hessian to the identity matrix
+                    // will restart with a steepest descent step and hopefully build a better
+                    // approximation.
+                    if recently_reset {
+                        return Err("bfgs: line search failed twice in a row".into());
+                    }
+                    h = DMatrix::identity(n, n);
+                    recently_reset = true;
+                    continue;
+                }
+                Err(e) => return Err(e.into()),
+            };
+
+            recently_reset = false;
 
             let s = alpha * &p;
 
diff --git a/src/solvers/line_search.rs b/src/solvers/line_search.rs
index 19e109a..3024446 100644
--- a/src/solvers/line_search.rs
+++ b/src/solvers/line_search.rs
@@ -1,28 +1,37 @@
 use crate::sketch::Sketch;
 use nalgebra::DVector;
-use std::error::Error;
+use thiserror::Error;
+
+#[derive(Debug, Error)]
+pub enum LineSearchError {
+    #[error("line search failed: search direction is not a descent direction")]
+    NotDescentDirection,
+    #[error("line search failed: could not find a suitable step size")]
+    SearchFailed,
+}
 
 const WOLFE_C1: f64 = 1e-4;
 const WOLFE_C2: f64 = 0.9;
+const MAX_ITER: usize = 15;
 
 pub(crate) fn line_search_wolfe(
     sketch: &mut Sketch,
     direction: &DVector<f64>,
     gradient: &DVector<f64>,
-) -> Result<f64, Box<dyn Error>> {
+) -> Result<f64, LineSearchError> {
     let mut alpha = 1.0;
     let m = gradient.dot(direction);
     if m >= 0.0 {
-        return Err("line search failed: search direction is not a descent direction".into());
+        return Err(LineSearchError::NotDescentDirection);
     }
     let curvature_condition = WOLFE_C2 * m;
     let loss = sketch.get_loss();
     let x0 = sketch.get_data();
-    while alpha > 1e-16 {
+    for _i in 0..MAX_ITER {
         let data = &x0 + alpha * direction;
         sketch.set_data(data);
         let new_loss = sketch.get_loss();
-        // Sufficent decrease condition
+        // Sufficient decrease condition
         if new_loss <= loss + WOLFE_C1 * alpha * m {
             // Curvature condition
             let new_gradient = sketch.get_gradient();
@@ -35,5 +44,5 @@ pub(crate) fn line_search_wolfe(
             alpha *= 0.5;
         }
     }
-    Err("line search failed: alpha is too small".into())
+    Err(LineSearchError::SearchFailed)
 }