Skip to content

Commit

Permalink
Fix/next train set is train set after 1.0.0 (#201)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Jul 16, 2024
1 parent 5121d7b commit fb1d4c0
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "1.1.0"
version = "1.1.1"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
Expand Down
12 changes: 6 additions & 6 deletions src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ impl<B: Backend> FSRS<B> {
};

let average_recall = calculate_average_recall(&train_set);
let (pre_train_set, next_train_set) = prepare_training_data(train_set);
if pre_train_set.len() + next_train_set.len() < 8 {
let (pre_train_set, train_set) = prepare_training_data(train_set);
if train_set.len() < 8 {
finish_progress();
return Ok(DEFAULT_PARAMETERS.to_vec());
}
Expand All @@ -221,7 +221,7 @@ impl<B: Backend> FSRS<B> {
.into_iter()
.chain(DEFAULT_PARAMETERS[4..].iter().copied())
.collect();
if next_train_set.is_empty() || pre_train_set.len() + next_train_set.len() < 64 {
if train_set.len() == pre_train_set.len() || train_set.len() < 64 {
finish_progress();
return Ok(pretrained_parameters);
}
Expand All @@ -237,16 +237,16 @@ impl<B: Backend> FSRS<B> {
if let Some(progress) = &progress {
let progress_state = ProgressState {
epoch_total: config.num_epochs,
items_total: next_train_set.len(),
items_total: train_set.len(),
epoch: 0,
items_processed: 0,
};
progress.lock().unwrap().splits = vec![progress_state];
}

let model = train::<Autodiff<B>>(
next_train_set.clone(),
next_train_set,
train_set.clone(),
train_set,
&config,
self.device(),
progress.clone().map(|p| ProgressCollector::new(p, 0)),
Expand Down

0 comments on commit fb1d4c0

Please sign in to comment.