diff --git a/Cargo.lock b/Cargo.lock index 0a57e4c..2666cc3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1055,7 +1055,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "0.5.3" +version = "0.5.4" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 1d5ce61..ca7f50f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "0.5.3" +version = "0.5.4" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/README.md b/README.md index 6f98b94..784a34c 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ add #!/bin/sh cargo fmt cargo clippy -- -D warnings +git add . ``` to `.git/hooks/pre-commit`, then `chmod +x .git/hooks/pre-commit` diff --git a/benches/benchmark.rs b/benches/benchmark.rs index f3bcb6f..0bc0802 100644 --- a/benches/benchmark.rs +++ b/benches/benchmark.rs @@ -9,6 +9,7 @@ use criterion::criterion_main; use criterion::Criterion; use fsrs::FSRSReview; use fsrs::NextStates; +use fsrs::SimulatorConfig; use fsrs::FSRS; use fsrs::{FSRSItem, MemoryState}; use itertools::Itertools; @@ -34,6 +35,10 @@ pub(crate) fn next_states(inf: &FSRS) -> NextStates { .unwrap() } +pub(crate) fn optimal_retention(inf: &FSRS, config: &SimulatorConfig) -> f64 { + inf.optimal_retention(config, &[], |_v| true).unwrap() +} + pub fn criterion_benchmark(c: &mut Criterion) { let fsrs = FSRS::new(Some(&[ 0.81497127, @@ -55,9 +60,19 @@ pub fn criterion_benchmark(c: &mut Criterion) { 2.6646678, ])) .unwrap(); - + let config = SimulatorConfig { + deck_size: 3650, + learn_span: 365, + max_cost_perday: f64::INFINITY, + learn_limit: 10, + loss_aversion: 1.0, + ..Default::default() + }; c.bench_function("calc_mem", |b| b.iter(|| black_box(calc_mem(&fsrs, 100)))); c.bench_function("next_states", |b| b.iter(|| black_box(next_states(&fsrs)))); + c.bench_function("optimal_retention", |b| { + b.iter(|| black_box(optimal_retention(&fsrs, &config))) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 7f9c614..68f9d2d 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -439,94 +439,6 @@ where const SAMPLE_SIZE: usize = 4; -/// https://github.com/scipy/scipy/blob/5e4a5e3785f79dd4e8930eed883da89958860db2/scipy/optimize/_optimize.py#L2894 -fn bracket( - mut xa: f64, - mut xb: f64, - config: &SimulatorConfig, - parameters: &[f64], - progress: &mut F, -) -> Result<(f64, f64, f64, f64, f64, f64)> -where - F: FnMut() -> bool, -{ - const U_LIM: f64 = R_MAX; - const L_LIM: f64 = R_MIN; - const GROW_LIMIT: f64 = 100f64; - const GOLD: f64 = 1.618_033_988_749_895; // wait for https://doc.rust-lang.org/std/f64/consts/constant.PHI.html - const MAXITER: i32 = 20; - - let mut fa = sample(config, parameters, xa, SAMPLE_SIZE, progress)?; - let mut fb = sample(config, parameters, xb, SAMPLE_SIZE, progress)?; - - if fa < fb { - (fa, fb) = (fb, fa); - (xa, xb) = (xb, xa); - } - let mut xc = GOLD.mul_add(xb - xa, xb).clamp(L_LIM, U_LIM); - let mut fc = sample(config, parameters, xc, SAMPLE_SIZE, progress)?; - - let mut iter = 0; - while fc < fb { - let tmp1 = (xb - xa) * (fb - fc); - let tmp2 = (xb - xc) * (fb - fa); - let val = tmp2 - tmp1; - let denom = 2.0 * val.clamp(1e-20, 1e20); - let mut w = (xb - (xb - xc).mul_add(tmp2, (xa - xb) * tmp1) / denom).clamp(L_LIM, U_LIM); - let wlim = GROW_LIMIT.mul_add(xc - xb, xb).clamp(L_LIM, U_LIM); - - if iter >= MAXITER { - break; - } - - iter += 1; - - let mut fw: f64; - - if (w - xc) * (xb - w) > 0.0 { - fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?; - if fw < fc { - (xa, xb) = (xb.clamp(L_LIM, U_LIM), w.clamp(L_LIM, U_LIM)); - (fa, fb) = (fb, fw); - break; - } else if fw > fb { - xc = w.clamp(L_LIM, U_LIM); - fc = sample(config, parameters, xc, SAMPLE_SIZE, progress)?; - break; - } - w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM); - fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?; - } else if (w - wlim) * (wlim - xc) >= 0.0 { - w = wlim; - fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?; - } else if (w - wlim) * (xc - w) > 0.0 { - fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?; - if fw < fc { - (xb, xc, w) = ( - xc.clamp(L_LIM, U_LIM), - w.clamp(L_LIM, U_LIM), - GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM), - ); - (fb, fc, fw) = ( - fc, - fw, - sample(config, parameters, w, SAMPLE_SIZE, progress)?, - ); - } - } else { - w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM); - fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?; - } - (xa, xb, xc) = ( - xb.clamp(L_LIM, U_LIM), - xc.clamp(L_LIM, U_LIM), - w.clamp(L_LIM, U_LIM), - ); - (fa, fb, fc) = (fb, fc, fw); - } - Ok((xa, xb, xc, fa, fb, fc)) -} - impl FSRS { /// For the given simulator parameters and parameters, determine the suggested `desired_retention` /// value. @@ -576,11 +488,13 @@ impl FSRS { let maxiter = 64; let tol = 0.01f64; - let (xa, xb, xc, _fa, fb, _fc) = bracket(R_MIN, R_MAX, config, parameters, &mut progress)?; - - let (mut v, mut w, mut x) = (xb, xb, xb); + let (xb, fb) = ( + R_MIN, + sample(config, parameters, R_MIN, SAMPLE_SIZE, &mut progress)?, + ); + let (mut x, mut v, mut w) = (xb, xb, xb); let (mut fx, mut fv, mut fw) = (fb, fb, fb); - let (mut a, mut b) = (xa.min(xc), xa.max(xc)); + let (mut a, mut b) = (R_MIN, R_MAX); let mut deltax: f64 = 0.0; let mut iter = 0; let mut rat = 0.0; @@ -664,6 +578,7 @@ impl FSRS { } let xmin = x; let success = iter < maxiter && (R_MIN..=R_MAX).contains(&xmin); + dbg!(iter); if success { Ok(xmin) @@ -760,17 +675,19 @@ mod tests { #[test] fn optimal_retention() -> Result<()> { + let learn_span = 1000; + let learn_limit = 10; let fsrs = FSRS::new(None)?; let config = SimulatorConfig { - deck_size: 10000, - learn_span: 1000, + deck_size: learn_span * learn_limit, + learn_span, max_cost_perday: f64::INFINITY, - learn_limit: 10, - loss_aversion: 1.0, + learn_limit, + loss_aversion: 2.5, ..Default::default() }; let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap(); - assert_eq!(optimal_retention, 0.7984864824748231); + assert_eq!(optimal_retention, 0.8263932); assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err()); Ok(()) }