Skip to content

Commit

Permalink
Feat/minimal workload for optimal retention (#165)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Mar 11, 2024
1 parent ec2eae5 commit 08771a8
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "0.5.1"
version = "0.5.2"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
Expand Down
77 changes: 50 additions & 27 deletions src/optimal_retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ impl From<Column> for SliceInfoElem {
}
}

const R_MIN: f64 = 0.75;
const R_MAX: f64 = 0.95;

#[derive(Debug, Clone)]
pub struct SimulatorConfig {
pub deck_size: usize,
Expand Down Expand Up @@ -107,7 +110,7 @@ fn simulate(
desired_retention: f64,
seed: Option<u64>,
existing_cards: Option<Vec<Card>>,
) -> (Array1<f64>, Array1<usize>, Array1<usize>) {
) -> (Array1<f64>, Array1<usize>, Array1<usize>, Array1<f64>) {
let SimulatorConfig {
deck_size,
learn_span,
Expand Down Expand Up @@ -142,6 +145,7 @@ fn simulate(
let mut review_cnt_per_day = Array1::<usize>::zeros(learn_span);
let mut learn_cnt_per_day = Array1::<usize>::zeros(learn_span);
let mut memorized_cnt_per_day = Array1::zeros(learn_span);
let mut cost_per_day = Array1::zeros(learn_span);

let first_rating_choices = [1, 2, 3, 4];
let first_rating_dist = WeightedIndex::new(first_rating_prob).unwrap();
Expand Down Expand Up @@ -388,9 +392,18 @@ fn simulate(
review_cnt_per_day[today] = true_review.iter().filter(|&&x| x).count();
learn_cnt_per_day[today] = true_learn.iter().filter(|&&x| x).count();
memorized_cnt_per_day[today] = retrievability.sum();
cost_per_day[today] = izip!(cost, &true_review, &true_learn)
.filter(|(_, &true_review_flag, &true_learn_flag)| true_review_flag || true_learn_flag)
.map(|(cost, ..)| cost)
.sum();
}

(memorized_cnt_per_day, review_cnt_per_day, learn_cnt_per_day)
(
memorized_cnt_per_day,
review_cnt_per_day,
learn_cnt_per_day,
cost_per_day,
)
}

fn sample<F>(
Expand All @@ -409,21 +422,22 @@ where
Ok((0..n)
.into_par_iter()
.map(|i| {
let memorization = simulate(
let (memorized_cnt_per_day, _, _, cost_per_day) = simulate(
config,
parameters,
desired_retention,
Some((i + 42).try_into().unwrap()),
None,
)
.0;
memorization[memorization.len() - 1]
);
let total_memorized = memorized_cnt_per_day[memorized_cnt_per_day.len() - 1];
let total_cost = cost_per_day.sum();
total_cost / total_memorized
})
.sum::<f64>()
/ n as f64)
}

const SAMPLE_SIZE: usize = 10;
const SAMPLE_SIZE: usize = 4;

/// https://github.com/scipy/scipy/blob/5e4a5e3785f79dd4e8930eed883da89958860db2/scipy/optimize/_optimize.py#L2894
fn bracket<F>(
Expand All @@ -436,21 +450,21 @@ fn bracket<F>(
where
F: FnMut() -> bool,
{
const U_LIM: f64 = 0.95;
const L_LIM: f64 = 0.75;
const U_LIM: f64 = R_MAX;
const L_LIM: f64 = R_MIN;
const GROW_LIMIT: f64 = 100f64;
const GOLD: f64 = 1.618_033_988_749_895; // wait for https://doc.rust-lang.org/std/f64/consts/constant.PHI.html
const MAXITER: i32 = 20;

let mut fa = -sample(config, parameters, xa, SAMPLE_SIZE, progress)?;
let mut fb = -sample(config, parameters, xb, SAMPLE_SIZE, progress)?;
let mut fa = sample(config, parameters, xa, SAMPLE_SIZE, progress)?;
let mut fb = sample(config, parameters, xb, SAMPLE_SIZE, progress)?;

if fa < fb {
(fa, fb) = (fb, fa);
(xa, xb) = (xb, xa);
}
let mut xc = GOLD.mul_add(xb - xa, xb).clamp(L_LIM, U_LIM);
let mut fc = -sample(config, parameters, xc, SAMPLE_SIZE, progress)?;
let mut fc = sample(config, parameters, xc, SAMPLE_SIZE, progress)?;

let mut iter = 0;
while fc < fb {
Expand All @@ -470,23 +484,23 @@ where
let mut fw: f64;

if (w - xc) * (xb - w) > 0.0 {
fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?;
fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?;
if fw < fc {
(xa, xb) = (xb.clamp(L_LIM, U_LIM), w.clamp(L_LIM, U_LIM));
(fa, fb) = (fb, fw);
break;
} else if fw > fb {
xc = w.clamp(L_LIM, U_LIM);
fc = -sample(config, parameters, xc, SAMPLE_SIZE, progress)?;
fc = sample(config, parameters, xc, SAMPLE_SIZE, progress)?;
break;
}
w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM);
fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?;
fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?;
} else if (w - wlim) * (wlim - xc) >= 0.0 {
w = wlim;
fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?;
fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?;
} else if (w - wlim) * (xc - w) > 0.0 {
fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?;
fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?;
if fw < fc {
(xb, xc, w) = (
xc.clamp(L_LIM, U_LIM),
Expand All @@ -496,12 +510,12 @@ where
(fb, fc, fw) = (
fc,
fw,
-sample(config, parameters, w, SAMPLE_SIZE, progress)?,
sample(config, parameters, w, SAMPLE_SIZE, progress)?,
);
}
} else {
w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM);
fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?;
fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?;
}
(xa, xb, xc) = (
xb.clamp(L_LIM, U_LIM),
Expand Down Expand Up @@ -562,7 +576,7 @@ impl<B: Backend> FSRS<B> {
let maxiter = 64;
let tol = 0.01f64;

let (xa, xb, xc, _fa, fb, _fc) = bracket(0.75, 0.95, config, parameters, &mut progress)?;
let (xa, xb, xc, _fa, fb, _fc) = bracket(R_MIN, R_MAX, config, parameters, &mut progress)?;

let (mut v, mut w, mut x) = (xb, xb, xb);
let (mut fx, mut fv, mut fw) = (fb, fb, fb);
Expand Down Expand Up @@ -620,7 +634,7 @@ impl<B: Backend> FSRS<B> {
rat
};
// calculate new output value
let fu = -sample(config, parameters, u, SAMPLE_SIZE, &mut progress)?;
let fu = sample(config, parameters, u, SAMPLE_SIZE, &mut progress)?;

// if it's bigger than current
if fu > fx {
Expand Down Expand Up @@ -649,7 +663,7 @@ impl<B: Backend> FSRS<B> {
iter += 1;
}
let xmin = x;
let success = iter < maxiter && (0.75..=0.95).contains(&xmin);
let success = iter < maxiter && (R_MIN..=R_MAX).contains(&xmin);

if success {
Ok(xmin)
Expand All @@ -669,15 +683,17 @@ mod tests {
#[test]
fn simulator() {
let config = SimulatorConfig::default();
let memorization = simulate(
let (memorized_cnt_per_day, _, _, _) = simulate(
&config,
&DEFAULT_PARAMETERS.iter().map(|v| *v as f64).collect_vec(),
0.9,
None,
None,
);
assert_eq!(
memorized_cnt_per_day[memorized_cnt_per_day.len() - 1],
3130.8465582271774
)
.0;
assert_eq!(memorization[memorization.len() - 1], 3130.8465582271774)
}

#[test]
Expand Down Expand Up @@ -744,10 +760,17 @@ mod tests {

#[test]
fn optimal_retention() -> Result<()> {
let config = SimulatorConfig::default();
let fsrs = FSRS::new(None)?;
let config = SimulatorConfig {
deck_size: 10000,
learn_span: 1000,
max_cost_perday: f64::INFINITY,
learn_limit: 10,
loss_aversion: 1.0,
..Default::default()
};
let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap();
assert_eq!(optimal_retention, 0.8468471175527587);
assert_eq!(optimal_retention, 0.7984864824748231);
assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err());
Ok(())
}
Expand Down

0 comments on commit 08771a8

Please sign in to comment.