Remove stratified kfold (#173)
L-M-Sherlock committed Apr 2, 2024
1 parent 02230be commit 7d27226
Showing 4 changed files with 58 additions and 166 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock


2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "fsrs"
-version = "0.6.0"
+version = "0.6.1"
 authors = ["Open Spaced Repetition"]
 categories = ["algorithms", "science"]
 edition = "2021"
22 changes: 2 additions & 20 deletions src/dataset.rs
@@ -209,31 +209,13 @@ pub fn filter_outlier(
     (filtered_items, trainset)
 }
 
-fn stratified_kfold(mut trainset: Vec<FSRSItem>, n_splits: usize) -> Vec<Vec<FSRSItem>> {
-    trainset.sort_by(|a, b| a.reviews.len().cmp(&b.reviews.len()));
-    (0..n_splits)
-        .cycle() // cycle to evenly distribute
-        .zip(trainset)
-        .fold(vec![vec![]; n_splits], |mut acc, (i, item)| {
-            acc[i].push(item);
-            acc
-        })
-}
-
-pub fn split_data(
-    items: Vec<FSRSItem>,
-    n_splits: usize,
-) -> (Vec<FSRSItem>, Vec<Vec<FSRSItem>>, Vec<FSRSItem>) {
+pub fn split_filter_data(items: Vec<FSRSItem>) -> (Vec<FSRSItem>, Vec<FSRSItem>) {
     let (mut pretrainset, mut trainset) =
         items.into_iter().partition(|item| item.reviews.len() == 2);
     if std::env::var("FSRS_NO_OUTLIER").is_err() {
         (pretrainset, trainset) = filter_outlier(pretrainset, trainset);
     }
-    (
-        pretrainset,
-        stratified_kfold(trainset.clone(), n_splits),
-        trainset,
-    )
+    (pretrainset, trainset)
 }
 
 #[cfg(test)]
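
A note on what was removed: stratified_kfold sorted the training items by review-history length and dealt them round-robin into n_splits buckets, so each fold received a similar mix of short and long histories. A minimal, self-contained sketch of that dealing pattern, with plain integers standing in for FSRSItem review counts (round_robin_split is an illustrative name, not part of this crate):

fn round_robin_split(mut lengths: Vec<usize>, n_splits: usize) -> Vec<Vec<usize>> {
    // Sort first so consecutive deals spread each size class across folds.
    lengths.sort();
    (0..n_splits)
        .cycle() // 0, 1, ..., n_splits - 1, 0, 1, ... deals items evenly
        .zip(lengths)
        .fold(vec![vec![]; n_splits], |mut folds, (i, len)| {
            folds[i].push(len);
            folds
        })
}

fn main() {
    // Sorted input 2 3 4 5 6 7 9 is dealt to folds 0 1 2 0 1 2 0:
    // [[2, 5, 9], [3, 6], [4, 7]]
    println!("{:?}", round_robin_split(vec![5, 2, 9, 3, 7, 4, 6], 3));
}

After this change, split_filter_data keeps only the filtering half of the old job: items with exactly two reviews go to the pretraining set, the optional outlier filter runs unless FSRS_NO_OUTLIER is set, and everything else is returned as a single training set.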
198 changes: 54 additions & 144 deletions src/training.rs
@@ -1,6 +1,6 @@
 use crate::batch_shuffle::BatchShuffledDataLoaderBuilder;
 use crate::cosine_annealing::CosineAnnealingLR;
-use crate::dataset::{split_data, FSRSBatcher, FSRSDataset, FSRSItem};
+use crate::dataset::{split_filter_data, FSRSBatcher, FSRSDataset, FSRSItem};
 use crate::error::Result;
 use crate::model::{Model, ModelConfig};
 use crate::pre_training::pretrain;
@@ -22,10 +22,6 @@ use burn::{config::Config, module::Param, tensor::backend::AutodiffBackend};
 use core::marker::PhantomData;
 use log::info;
 
-use rayon::prelude::{
-    IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
-};
-
 use std::sync::{Arc, Mutex};
 
 pub struct BCELoss<B: Backend> {
@@ -197,7 +193,7 @@ impl<B: Backend> FSRS<B> {
     /// Calculate appropriate parameters for the provided review history.
     pub fn compute_parameters(
         &self,
-        items: Vec<FSRSItem>,
+        train_set: Vec<FSRSItem>,
         progress: Option<Arc<Mutex<CombinedProgressState>>>,
     ) -> Result<Vec<f32>> {
         let finish_progress = || {
@@ -210,23 +206,22 @@ impl<B: Backend> FSRS<B> {
             }
         };
 
-        let n_splits = 5;
-        let average_recall = calculate_average_recall(&items);
-        let (pre_trainset, trainsets, testset) = split_data(items, n_splits);
-        if pre_trainset.len() + testset.len() < 8 {
+        let average_recall = calculate_average_recall(&train_set);
+        let (pre_train_set, next_train_set) = split_filter_data(train_set);
+        if pre_train_set.len() + next_train_set.len() < 8 {
             finish_progress();
             return Ok(DEFAULT_PARAMETERS.to_vec());
         }
 
-        let initial_stability = pretrain(pre_trainset.clone(), average_recall).map_err(|e| {
+        let initial_stability = pretrain(pre_train_set.clone(), average_recall).map_err(|e| {
             finish_progress();
             e
         })?;
         let pretrained_parameters: Vec<f32> = initial_stability
             .into_iter()
             .chain(DEFAULT_PARAMETERS[4..].iter().copied())
             .collect();
-        if testset.is_empty() || pre_trainset.len() + testset.len() < 64 {
+        if next_train_set.is_empty() || pre_train_set.len() + next_train_set.len() < 64 {
             finish_progress();
             return Ok(pretrained_parameters);
         }
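
In the hunk above, the pretrained parameter vector is spliced together from the four learned initial stabilities followed by the tail of the defaults. A standalone sketch of that chain expression (splice and the numeric values below are illustrative, not crate code):

fn splice(initial_stability: [f32; 4], defaults: &[f32]) -> Vec<f32> {
    // The first four slots are the pretrained initial stabilities; the
    // rest of the vector keeps the library defaults unchanged.
    initial_stability
        .into_iter()
        .chain(defaults[4..].iter().copied())
        .collect()
}

fn main() {
    let defaults: [f32; 9] = [0.4, 0.6, 2.4, 5.8, 4.9, 0.9, 0.9, 0.0, 1.5];
    let spliced = splice([0.5, 1.2, 3.1, 15.7], &defaults);
    // [0.5, 1.2, 3.1, 15.7, 4.9, 0.9, 0.9, 0.0, 1.5]
    println!("{spliced:?}");
}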
@@ -239,70 +234,48 @@ impl<B: Backend> FSRS<B> {
             AdamConfig::new(),
         );
 
-        let trainsets: Vec<Vec<FSRSItem>> = (0..n_splits)
-            .into_par_iter()
-            .map(|i| {
-                trainsets
-                    .iter()
-                    .enumerate()
-                    .filter(|&(j, _)| j != i)
-                    .flat_map(|(_, trainset)| trainset.clone())
-                    .collect()
-            })
-            .collect();
-
         if let Some(progress) = &progress {
-            let mut progress_states = vec![ProgressState::default(); n_splits];
-            for (i, progress_state) in progress_states.iter_mut().enumerate() {
-                progress_state.epoch_total = config.num_epochs;
-                progress_state.items_total = trainsets[i].len();
-            }
-            progress.lock().unwrap().splits = progress_states
+            let progress_state = ProgressState {
+                epoch_total: config.num_epochs,
+                items_total: next_train_set.len(),
+                epoch: 0,
+                items_processed: 0,
+            };
+            progress.lock().unwrap().splits = vec![progress_state];
         }
 
-        let weight_sets: Result<Vec<Vec<f32>>> = trainsets
-            .into_par_iter()
-            .enumerate()
-            .map(|(idx, trainset)| {
-                let model = train::<Autodiff<B>>(
-                    trainset,
-                    testset.clone(),
-                    &config,
-                    self.device(),
-                    progress.clone().map(|p| ProgressCollector::new(p, idx)),
-                );
-                Ok(model
-                    .map_err(|e| {
-                        finish_progress();
-                        e
-                    })?
-                    .w
-                    .val()
-                    .to_data()
-                    .convert()
-                    .value)
-            })
-            .collect();
+        let model = train::<Autodiff<B>>(
+            next_train_set.clone(),
+            next_train_set,
+            &config,
+            self.device(),
+            progress.clone().map(|p| ProgressCollector::new(p, 0)),
+        );
+
+        let optimized_parameters = model
+            .map_err(|e| {
+                finish_progress();
+                e
+            })?
+            .w
+            .val()
+            .to_data()
+            .convert()
+            .value;
+
         finish_progress();
 
-        let weight_sets = weight_sets?;
-        let average_parameters: Vec<f32> = weight_sets
+        if optimized_parameters
             .iter()
-            .fold(vec![0.0; weight_sets[0].len()], |sum, parameters| {
-                sum.par_iter().zip(parameters).map(|(a, b)| a + b).collect()
-            })
-            .par_iter()
-            .map(|&sum| sum / n_splits as f32)
-            .collect();
-
-        if average_parameters.iter().any(|weight| weight.is_infinite()) {
+            .any(|weight: &f32| weight.is_infinite())
+        {
             return Err(FSRSError::InvalidInput);
         }
 
-        Ok(average_parameters)
+        Ok(optimized_parameters)
     }
 
-    pub fn benchmark(&self, train_set: Vec<FSRSItem>, test_set: Vec<FSRSItem>) -> Vec<f32> {
+    pub fn benchmark(&self, train_set: Vec<FSRSItem>) -> Vec<f32> {
         let average_recall = calculate_average_recall(&train_set);
         let (pre_train_set, next_train_set) = train_set
             .into_iter()
@@ -315,34 +288,40 @@ impl<B: Backend> FSRS<B> {
             },
             AdamConfig::new(),
         );
-        let model = train::<Autodiff<B>>(next_train_set, test_set, &config, self.device(), None);
+        let model = train::<Autodiff<B>>(
+            next_train_set.clone(),
+            next_train_set,
+            &config,
+            self.device(),
+            None,
+        );
         let parameters: Vec<f32> = model.unwrap().w.val().to_data().convert().value;
         parameters
     }
 }
 
 fn train<B: AutodiffBackend>(
-    trainset: Vec<FSRSItem>,
-    testset: Vec<FSRSItem>,
+    train_set: Vec<FSRSItem>,
+    test_set: Vec<FSRSItem>,
     config: &TrainingConfig,
     device: B::Device,
     progress: Option<ProgressCollector>,
 ) -> Result<Model<B>> {
     B::seed(config.seed);
 
     // Training data
-    let iterations = (trainset.len() / config.batch_size + 1) * config.num_epochs;
+    let iterations = (train_set.len() / config.batch_size + 1) * config.num_epochs;
     let batcher_train = FSRSBatcher::<B>::new(device.clone());
     let dataloader_train = BatchShuffledDataLoaderBuilder::new(batcher_train).build(
-        FSRSDataset::from(trainset),
+        FSRSDataset::from(train_set),
         config.batch_size,
         config.seed,
     );
 
     let batcher_valid = FSRSBatcher::new(device);
     let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
         .batch_size(config.batch_size)
-        .build(FSRSDataset::from(testset.clone()));
+        .build(FSRSDataset::from(test_set.clone()));
 
     let mut lr_scheduler = CosineAnnealingLR::init(iterations as f64, config.learning_rate);
     let interrupter = TrainingInterrupter::new();
@@ -414,7 +393,7 @@ fn train<B: AutodiffBackend>(
                 break;
             }
         }
-        loss_valid /= testset.len() as f64;
+        loss_valid /= test_set.len() as f64;
         info!("epoch: {:?} loss: {:?}", epoch, loss_valid);
         if loss_valid < best_loss {
             best_loss = loss_valid;
@@ -452,9 +431,6 @@ mod tests {
 
     use super::*;
    use crate::convertor_tests::anki21_sample_file_converted_to_fsrs;
-    use crate::pre_training::pretrain;
-    use crate::test_helpers::NdArrayAutodiff;
-    use burn::backend::ndarray::NdArrayDevice;
     use log::LevelFilter;
 
     #[test]
@@ -490,24 +466,7 @@
             .apply()
             .unwrap();
         }
-        let n_splits = 5;
-        let device = NdArrayDevice::Cpu;
         let items = anki21_sample_file_converted_to_fsrs();
-        let (pre_trainset, trainsets, testset) = split_data(items.clone(), n_splits);
-        let items = [pre_trainset.clone(), testset.clone()].concat();
-        dbg!(pre_trainset.len());
-        dbg!(testset.len());
-        let average_recall = calculate_average_recall(&items);
-        dbg!(average_recall);
-        let initial_stability = pretrain(pre_trainset, average_recall).unwrap();
-        dbg!(initial_stability);
-        let config = TrainingConfig::new(
-            ModelConfig {
-                freeze_stability: true,
-                initial_stability: Some(initial_stability),
-            },
-            AdamConfig::new(),
-        );
         let progress = CombinedProgressState::new_shared();
         let progress2 = Some(progress.clone());
         thread::spawn(move || {
@@ … @@
             }
         });
 
-        let trainsets: Vec<Vec<FSRSItem>> = (0..n_splits)
-            .into_par_iter()
-            .map(|i| {
-                trainsets
-                    .par_iter()
-                    .enumerate()
-                    .filter(|&(j, _)| j != i)
-                    .flat_map(|(_, trainset)| trainset.clone())
-                    .collect()
-            })
-            .collect();
-
-        if let Some(progress2) = &progress2 {
-            let mut progress_states = vec![ProgressState::default(); n_splits];
-            for (i, progress_state) in progress_states.iter_mut().enumerate() {
-                progress_state.epoch_total = config.num_epochs;
-                progress_state.items_total = trainsets[i].len();
-            }
-            progress2.lock().unwrap().splits = progress_states
-        }
-
-        let parameters_sets: Vec<Vec<f32>> = (0..n_splits)
-            .into_par_iter()
-            .map(|i| {
-                let model = train::<NdArrayAutodiff>(
-                    trainsets[i].clone(),
-                    items.clone(),
-                    &config,
-                    device,
-                    progress2.clone().map(|p| ProgressCollector::new(p, i)),
-                );
-                model.unwrap().w.val().to_data().convert().value
-            })
-            .collect();
-
-        dbg!(&parameters_sets);
-
-        let average_parameters: Vec<f32> = parameters_sets
-            .iter()
-            .fold(vec![0.0; parameters_sets[0].len()], |sum, parameters| {
-                sum.par_iter().zip(parameters).map(|(a, b)| a + b).collect()
-            })
-            .par_iter()
-            .map(|&sum| sum / n_splits as f32)
-            .collect();
-        dbg!(&average_parameters);
-        let optimized_fsrs = FSRS::new(Some(&average_parameters)).unwrap();
-        let optimized_rmse = optimized_fsrs
-            .evaluate(testset, |_| true)
-            .unwrap()
-            .rmse_bins;
-        dbg!(optimized_rmse);
+        let fsrs = FSRS::new(Some(&[])).unwrap();
+        let parameters = fsrs.compute_parameters(items, progress2).unwrap();
+        dbg!(&parameters);
     }
 }
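
With the folds gone, parameter optimization is a single training run rather than an average of five cross-validated models, and train now receives next_train_set as both its training and validation data, which is why the diff clones it. A sketch of the caller's side after this commit, modeled on the updated test above (the use path and the optimize wrapper are illustrative assumptions, not code from this commit; error handling is elided with unwrap):

use fsrs::{FSRS, FSRSItem};

fn optimize(items: Vec<FSRSItem>) -> Vec<f32> {
    // An empty parameter slice starts from the default model.
    let fsrs = FSRS::new(Some(&[])).unwrap();
    // Pass Some(CombinedProgressState::new_shared()) instead of None to
    // observe training progress for the single split (index 0).
    fsrs.compute_parameters(items, None).unwrap()
}

Progress reporting likewise shrinks to a single ProgressState at split index 0, and the rayon-based parallel fold training disappears from training.rs.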
