diff --git a/Cargo.lock b/Cargo.lock index a63302a..51eb8d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1055,7 +1055,7 @@ dependencies = [ [[package]] name = "fsrs" -version = "0.5.1" +version = "0.5.2" dependencies = [ "burn", "chrono", diff --git a/Cargo.toml b/Cargo.toml index bc90171..d9d6bf3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "fsrs" -version = "0.5.1" +version = "0.5.2" authors = ["Open Spaced Repetition"] categories = ["algorithms", "science"] edition = "2021" diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 2822215..7f9c614 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -43,6 +43,9 @@ impl From for SliceInfoElem { } } +const R_MIN: f64 = 0.75; +const R_MAX: f64 = 0.95; + #[derive(Debug, Clone)] pub struct SimulatorConfig { pub deck_size: usize, @@ -107,7 +110,7 @@ fn simulate( desired_retention: f64, seed: Option, existing_cards: Option>, -) -> (Array1, Array1, Array1) { +) -> (Array1, Array1, Array1, Array1) { let SimulatorConfig { deck_size, learn_span, @@ -142,6 +145,7 @@ fn simulate( let mut review_cnt_per_day = Array1::::zeros(learn_span); let mut learn_cnt_per_day = Array1::::zeros(learn_span); let mut memorized_cnt_per_day = Array1::zeros(learn_span); + let mut cost_per_day = Array1::zeros(learn_span); let first_rating_choices = [1, 2, 3, 4]; let first_rating_dist = WeightedIndex::new(first_rating_prob).unwrap(); @@ -388,9 +392,18 @@ fn simulate( review_cnt_per_day[today] = true_review.iter().filter(|&&x| x).count(); learn_cnt_per_day[today] = true_learn.iter().filter(|&&x| x).count(); memorized_cnt_per_day[today] = retrievability.sum(); + cost_per_day[today] = izip!(cost, &true_review, &true_learn) + .filter(|(_, &true_review_flag, &true_learn_flag)| true_review_flag || true_learn_flag) + .map(|(cost, ..)| cost) + .sum(); } - (memorized_cnt_per_day, review_cnt_per_day, learn_cnt_per_day) + ( + memorized_cnt_per_day, + review_cnt_per_day, + learn_cnt_per_day, + cost_per_day, + ) } fn sample( @@ -409,21 +422,22 @@ where Ok((0..n) .into_par_iter() .map(|i| { - let memorization = simulate( + let (memorized_cnt_per_day, _, _, cost_per_day) = simulate( config, parameters, desired_retention, Some((i + 42).try_into().unwrap()), None, - ) - .0; - memorization[memorization.len() - 1] + ); + let total_memorized = memorized_cnt_per_day[memorized_cnt_per_day.len() - 1]; + let total_cost = cost_per_day.sum(); + total_cost / total_memorized }) .sum::() / n as f64) } -const SAMPLE_SIZE: usize = 10; +const SAMPLE_SIZE: usize = 4; /// https://github.com/scipy/scipy/blob/5e4a5e3785f79dd4e8930eed883da89958860db2/scipy/optimize/_optimize.py#L2894 fn bracket( @@ -436,21 +450,21 @@ fn bracket( where F: FnMut() -> bool, { - const U_LIM: f64 = 0.95; - const L_LIM: f64 = 0.75; + const U_LIM: f64 = R_MAX; + const L_LIM: f64 = R_MIN; const GROW_LIMIT: f64 = 100f64; const GOLD: f64 = 1.618_033_988_749_895; // wait for https://doc.rust-lang.org/std/f64/consts/constant.PHI.html const MAXITER: i32 = 20; - let mut fa = -sample(config, parameters, xa, SAMPLE_SIZE, progress)?; - let mut fb = -sample(config, parameters, xb, SAMPLE_SIZE, progress)?; + let mut fa = sample(config, parameters, xa, SAMPLE_SIZE, progress)?; + let mut fb = sample(config, parameters, xb, SAMPLE_SIZE, progress)?; if fa < fb { (fa, fb) = (fb, fa); (xa, xb) = (xb, xa); } let mut xc = GOLD.mul_add(xb - xa, xb).clamp(L_LIM, U_LIM); - let mut fc = -sample(config, parameters, xc, SAMPLE_SIZE, progress)?; + let mut fc = sample(config, parameters, xc, SAMPLE_SIZE, progress)?; let mut iter = 0; while fc < fb { @@ -470,23 +484,23 @@ where let mut fw: f64; if (w - xc) * (xb - w) > 0.0 { - fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?; + fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?; if fw < fc { (xa, xb) = (xb.clamp(L_LIM, U_LIM), w.clamp(L_LIM, U_LIM)); (fa, fb) = (fb, fw); break; } else if fw > fb { xc = w.clamp(L_LIM, U_LIM); - fc = -sample(config, parameters, xc, SAMPLE_SIZE, progress)?; + fc = sample(config, parameters, xc, SAMPLE_SIZE, progress)?; break; } w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM); - fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?; + fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?; } else if (w - wlim) * (wlim - xc) >= 0.0 { w = wlim; - fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?; + fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?; } else if (w - wlim) * (xc - w) > 0.0 { - fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?; + fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?; if fw < fc { (xb, xc, w) = ( xc.clamp(L_LIM, U_LIM), @@ -496,12 +510,12 @@ where (fb, fc, fw) = ( fc, fw, - -sample(config, parameters, w, SAMPLE_SIZE, progress)?, + sample(config, parameters, w, SAMPLE_SIZE, progress)?, ); } } else { w = GOLD.mul_add(xc - xb, xc).clamp(L_LIM, U_LIM); - fw = -sample(config, parameters, w, SAMPLE_SIZE, progress)?; + fw = sample(config, parameters, w, SAMPLE_SIZE, progress)?; } (xa, xb, xc) = ( xb.clamp(L_LIM, U_LIM), @@ -562,7 +576,7 @@ impl FSRS { let maxiter = 64; let tol = 0.01f64; - let (xa, xb, xc, _fa, fb, _fc) = bracket(0.75, 0.95, config, parameters, &mut progress)?; + let (xa, xb, xc, _fa, fb, _fc) = bracket(R_MIN, R_MAX, config, parameters, &mut progress)?; let (mut v, mut w, mut x) = (xb, xb, xb); let (mut fx, mut fv, mut fw) = (fb, fb, fb); @@ -620,7 +634,7 @@ impl FSRS { rat }; // calculate new output value - let fu = -sample(config, parameters, u, SAMPLE_SIZE, &mut progress)?; + let fu = sample(config, parameters, u, SAMPLE_SIZE, &mut progress)?; // if it's bigger than current if fu > fx { @@ -649,7 +663,7 @@ impl FSRS { iter += 1; } let xmin = x; - let success = iter < maxiter && (0.75..=0.95).contains(&xmin); + let success = iter < maxiter && (R_MIN..=R_MAX).contains(&xmin); if success { Ok(xmin) @@ -669,15 +683,17 @@ mod tests { #[test] fn simulator() { let config = SimulatorConfig::default(); - let memorization = simulate( + let (memorized_cnt_per_day, _, _, _) = simulate( &config, &DEFAULT_PARAMETERS.iter().map(|v| *v as f64).collect_vec(), 0.9, None, None, + ); + assert_eq!( + memorized_cnt_per_day[memorized_cnt_per_day.len() - 1], + 3130.8465582271774 ) - .0; - assert_eq!(memorization[memorization.len() - 1], 3130.8465582271774) } #[test] @@ -744,10 +760,17 @@ mod tests { #[test] fn optimal_retention() -> Result<()> { - let config = SimulatorConfig::default(); let fsrs = FSRS::new(None)?; + let config = SimulatorConfig { + deck_size: 10000, + learn_span: 1000, + max_cost_perday: f64::INFINITY, + learn_limit: 10, + loss_aversion: 1.0, + ..Default::default() + }; let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap(); - assert_eq!(optimal_retention, 0.8468471175527587); + assert_eq!(optimal_retention, 0.7984864824748231); assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err()); Ok(()) }