Skip to content

Commit

Permalink
Feat/add rules for returning parameters (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Mar 31, 2024
1 parent 5cd2d8d commit 02230be
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "0.5.5"
version = "0.6.0"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
Expand Down
35 changes: 22 additions & 13 deletions src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ impl<B: Backend> FSRS<B> {
pub fn compute_parameters(
&self,
items: Vec<FSRSItem>,
pretrain_only: bool,
progress: Option<Arc<Mutex<CombinedProgressState>>>,
) -> Result<Vec<f32>> {
let finish_progress = || {
Expand All @@ -214,18 +213,24 @@ impl<B: Backend> FSRS<B> {
let n_splits = 5;
let average_recall = calculate_average_recall(&items);
let (pre_trainset, trainsets, testset) = split_data(items, n_splits);
let initial_stability = pretrain(pre_trainset, average_recall).map_err(|e| {
if pre_trainset.len() + testset.len() < 8 {
finish_progress();
return Ok(DEFAULT_PARAMETERS.to_vec());
}

let initial_stability = pretrain(pre_trainset.clone(), average_recall).map_err(|e| {
finish_progress();
e
})?;
if pretrain_only {
let pretrained_parameters: Vec<f32> = initial_stability
.into_iter()
.chain(DEFAULT_PARAMETERS[4..].iter().copied())
.collect();
if testset.is_empty() || pre_trainset.len() + testset.len() < 64 {
finish_progress();
let parameters = initial_stability
.into_iter()
.chain(DEFAULT_PARAMETERS[4..].iter().copied())
.collect();
return Ok(parameters);
return Ok(pretrained_parameters);
}

let config = TrainingConfig::new(
ModelConfig {
freeze_stability: true,
Expand Down Expand Up @@ -490,6 +495,8 @@ mod tests {
let items = anki21_sample_file_converted_to_fsrs();
let (pre_trainset, trainsets, testset) = split_data(items.clone(), n_splits);
let items = [pre_trainset.clone(), testset.clone()].concat();
dbg!(pre_trainset.len());
dbg!(testset.len());
let average_recall = calculate_average_recall(&items);
dbg!(average_recall);
let initial_stability = pretrain(pre_trainset, average_recall).unwrap();
Expand All @@ -506,7 +513,7 @@ mod tests {
thread::spawn(move || {
let mut finished = false;
while !finished {
thread::sleep(Duration::from_millis(10));
thread::sleep(Duration::from_millis(1000));
let guard = progress.lock().unwrap();
finished = guard.finished();
println!("progress: {}/{}", guard.current(), guard.total());
Expand Down Expand Up @@ -559,9 +566,11 @@ mod tests {
.map(|&sum| sum / n_splits as f32)
.collect();
dbg!(&average_parameters);

let fsrs = FSRS::new(Some(&average_parameters)).unwrap();
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
dbg!(&metrics);
let optimized_fsrs = FSRS::new(Some(&average_parameters)).unwrap();
let optimized_rmse = optimized_fsrs
.evaluate(testset, |_| true)
.unwrap()
.rmse_bins;
dbg!(optimized_rmse);
}
}

0 comments on commit 02230be

Please sign in to comment.