Skip to content

Commit

Permalink
Speed up optimal_retention() (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Mar 12, 2024
1 parent 994c814 commit 0634cb0
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 100 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "0.5.3"
version = "0.5.4"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ add
#!/bin/sh
cargo fmt
cargo clippy -- -D warnings
git add .
```

to `.git/hooks/pre-commit`, then `chmod +x .git/hooks/pre-commit`
17 changes: 16 additions & 1 deletion benches/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use criterion::criterion_main;
use criterion::Criterion;
use fsrs::FSRSReview;
use fsrs::NextStates;
use fsrs::SimulatorConfig;
use fsrs::FSRS;
use fsrs::{FSRSItem, MemoryState};
use itertools::Itertools;
Expand All @@ -34,6 +35,10 @@ pub(crate) fn next_states(inf: &FSRS) -> NextStates {
.unwrap()
}

pub(crate) fn optimal_retention(inf: &FSRS, config: &SimulatorConfig) -> f64 {
inf.optimal_retention(config, &[], |_v| true).unwrap()
}

pub fn criterion_benchmark(c: &mut Criterion) {
let fsrs = FSRS::new(Some(&[
0.81497127,
Expand All @@ -55,9 +60,19 @@ pub fn criterion_benchmark(c: &mut Criterion) {
2.6646678,
]))
.unwrap();

let config = SimulatorConfig {
deck_size: 3650,
learn_span: 365,
max_cost_perday: f64::INFINITY,
learn_limit: 10,
loss_aversion: 1.0,
..Default::default()
};
c.bench_function("calc_mem", |b| b.iter(|| black_box(calc_mem(&fsrs, 100))));
c.bench_function("next_states", |b| b.iter(|| black_box(next_states(&fsrs))));
c.bench_function("optimal_retention", |b| {
b.iter(|| black_box(optimal_retention(&fsrs, &config)))
});
}

criterion_group!(benches, criterion_benchmark);
Expand Down
111 changes: 14 additions & 97 deletions src/optimal_retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,94 +439,6 @@ where

const SAMPLE_SIZE: usize = 4;

/// https://github.com/scipy/scipy/blob/5e4a5e3785f79dd4e8930eed883da89958860db2/scipy/optimize/_optimize.py#L2894
fn bracket<F>(
mut xa: f64,
mut xb: f64,
config: &SimulatorConfig,
parameters: &[f64],
progress: &mut F,
) -> Result<(f64, f64, f64, f64, f64, f64)>
where
F: FnMut() -> bool,
{
const U_LIM: f64 = R_MAX;
const L_LIM: f64 = R_MIN;
const GROW_LIMIT: f64 = 100f64;
const GOLD: f64 = 1.618_033_988_749_895; // wait for https://doc.rust-lang.org/std/f64/consts/constant.PHI.html
const MAXITER: i32 = 20;

let mut fa = sample(config, parameters, xa, SAMPLE_SIZE, progress)?;
let mut fb = sample(config, parameters, xb, SAMPLE_SIZE, progress)?;

if fa < fb {
(fa, fb) = (fb, fa);
(xa, xb) = (xb, xa);
}
let mut xc = GOLD.mul_add(xb - xa, xb).clamp(L_LIM, U_LIM);
let mut fc = sample(config, parameters, xc, SAMPLE_SIZE, progress)?;

let mut iter = 0;
while fc < fb {
let tmp1 = (xb - xa) * (fb - fc);
let tmp2 = (xb - xc) * (fb - fa);
let val = tmp2 - tmp1;
let denom = 2.0 * val.clamp(1e-20, 1e20);
let mut w = (xb - (xb - xc).mul_add(tmp2, (xa - xb) * tmp1) / denom).clamp(L_LIM, U_LIM);
let wlim = GROW_LIMIT.mul_add(xc - xb, xb).clamp(L_LIM, U_LIM);

if iter >= MAXITER {
break;
}

iter += 1;

let mut fw: f64;

if (w - xc) * (xb - w) > 0.0 {
fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?;
if fw < fc {
(xa, xb) = (xb.clamp(L_LIM, U_LIM), w.clamp(L_LIM, U_LIM));
(fa, fb) = (fb, fw);
break;
} else if fw > fb {
xc = w.clamp(L_LIM, U_LIM);
fc = sample(config, parameters, xc, SAMPLE_SIZE, progress)?;
break;
}
w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM);
fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?;
} else if (w - wlim) * (wlim - xc) >= 0.0 {
w = wlim;
fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?;
} else if (w - wlim) * (xc - w) > 0.0 {
fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?;
if fw < fc {
(xb, xc, w) = (
xc.clamp(L_LIM, U_LIM),
w.clamp(L_LIM, U_LIM),
GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM),
);
(fb, fc, fw) = (
fc,
fw,
sample(config, parameters, w, SAMPLE_SIZE, progress)?,
);
}
} else {
w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM);
fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?;
}
(xa, xb, xc) = (
xb.clamp(L_LIM, U_LIM),
xc.clamp(L_LIM, U_LIM),
w.clamp(L_LIM, U_LIM),
);
(fa, fb, fc) = (fb, fc, fw);
}
Ok((xa, xb, xc, fa, fb, fc))
}

impl<B: Backend> FSRS<B> {
/// For the given simulator parameters and parameters, determine the suggested `desired_retention`
/// value.
Expand Down Expand Up @@ -576,11 +488,13 @@ impl<B: Backend> FSRS<B> {
let maxiter = 64;
let tol = 0.01f64;

let (xa, xb, xc, _fa, fb, _fc) = bracket(R_MIN, R_MAX, config, parameters, &mut progress)?;

let (mut v, mut w, mut x) = (xb, xb, xb);
let (xb, fb) = (
R_MIN,
sample(config, parameters, R_MIN, SAMPLE_SIZE, &mut progress)?,
);
let (mut x, mut v, mut w) = (xb, xb, xb);
let (mut fx, mut fv, mut fw) = (fb, fb, fb);
let (mut a, mut b) = (xa.min(xc), xa.max(xc));
let (mut a, mut b) = (R_MIN, R_MAX);
let mut deltax: f64 = 0.0;
let mut iter = 0;
let mut rat = 0.0;
Expand Down Expand Up @@ -664,6 +578,7 @@ impl<B: Backend> FSRS<B> {
}
let xmin = x;
let success = iter < maxiter && (R_MIN..=R_MAX).contains(&xmin);
dbg!(iter);

if success {
Ok(xmin)
Expand Down Expand Up @@ -760,17 +675,19 @@ mod tests {

#[test]
fn optimal_retention() -> Result<()> {
let learn_span = 1000;
let learn_limit = 10;
let fsrs = FSRS::new(None)?;
let config = SimulatorConfig {
deck_size: 10000,
learn_span: 1000,
deck_size: learn_span * learn_limit,
learn_span,
max_cost_perday: f64::INFINITY,
learn_limit: 10,
loss_aversion: 1.0,
learn_limit,
loss_aversion: 2.5,
..Default::default()
};
let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap();
assert_eq!(optimal_retention, 0.7984864824748231);
assert_eq!(optimal_retention, 0.8263932);
assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err());
Ok(())
}
Expand Down

0 comments on commit 0634cb0

Please sign in to comment.