diff --git a/src/solvers/bfgs_solver.rs b/src/solvers/bfgs_solver.rs
index 76691bd..31f7f6e 100644
--- a/src/solvers/bfgs_solver.rs
+++ b/src/solvers/bfgs_solver.rs
@@ -1,8 +1,9 @@
-use nalgebra::{DMatrix, UniformNorm};
 use std::error::Error;
 
+use nalgebra::{DMatrix, UniformNorm};
+
 use crate::sketch::Sketch;
-use crate::solvers::line_search::line_search_wolfe;
+use crate::solvers::line_search::{line_search_wolfe, LineSearchError};
 
 use super::Solver;
 
@@ -44,6 +45,8 @@ impl Solver for BFGSSolver {
 
         let mut h = DMatrix::identity(n, n);
 
+        let mut recently_reset = false;
+
         while iterations < self.max_iterations {
             let loss = sketch.get_loss();
             if loss < self.min_loss {
@@ -64,7 +67,24 @@ impl Solver for BFGSSolver {
                 return Err("search direction contains non-finite values".into());
             }
 
-            let alpha = line_search_wolfe(sketch, &p, &gradient)?;
+            let alpha = match line_search_wolfe(sketch, &p, &gradient) {
+                Ok(alpha) => alpha,
+                Err(LineSearchError::SearchFailed) => {
+                    // If the line search could not find a suitable step size, the Hessian
+                    // approximation may be inaccurate. Resetting the Hessian to the identity matrix
+                    // will restart with a steepest descent step and hopefully build a better
+                    // approximation.
+                    if recently_reset {
+                        return Err("bfgs: line search failed twice in a row".into());
+                    }
+                    h = DMatrix::identity(n, n);
+                    recently_reset = true;
+                    continue;
+                }
+                Err(e) => return Err(e.into()),
+            };
+
+            recently_reset = false;
 
             let s = alpha * &p;
 
diff --git a/src/solvers/line_search.rs b/src/solvers/line_search.rs
index 19e109a..3024446 100644
--- a/src/solvers/line_search.rs
+++ b/src/solvers/line_search.rs
@@ -1,28 +1,37 @@
 use crate::sketch::Sketch;
 use nalgebra::DVector;
-use std::error::Error;
+use thiserror::Error;
+
+#[derive(Debug, Error)]
+pub enum LineSearchError {
+    #[error("line search failed: search direction is not a descent direction")]
+    NotDescentDirection,
+    #[error("line search failed: could not find a suitable step size")]
+    SearchFailed,
+}
 
 const WOLFE_C1: f64 = 1e-4;
 const WOLFE_C2: f64 = 0.9;
+const MAX_ITER: usize = 15;
 
 pub(crate) fn line_search_wolfe(
     sketch: &mut Sketch,
     direction: &DVector<f64>,
     gradient: &DVector<f64>,
-) -> Result<f64, Box<dyn Error>> {
+) -> Result<f64, LineSearchError> {
     let mut alpha = 1.0;
     let m = gradient.dot(direction);
     if m >= 0.0 {
-        return Err("line search failed: search direction is not a descent direction".into());
+        return Err(LineSearchError::NotDescentDirection);
     }
     let curvature_condition = WOLFE_C2 * m;
     let loss = sketch.get_loss();
     let x0 = sketch.get_data();
-    while alpha > 1e-16 {
+    for _i in 0..MAX_ITER {
         let data = &x0 + alpha * direction;
         sketch.set_data(data);
         let new_loss = sketch.get_loss();
-        // Sufficent decrease condition
+        // Sufficient decrease condition
         if new_loss <= loss + WOLFE_C1 * alpha * m {
             // Curvature condition
             let new_gradient = sketch.get_gradient();
@@ -35,5 +44,5 @@ pub(crate) fn line_search_wolfe(
             alpha *= 0.5;
         }
     }
-    Err("line search failed: alpha is too small".into())
+    Err(LineSearchError::SearchFailed)
 }