Skip to content

Commit

Permalink
feat: training (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 authored Jul 6, 2024
1 parent 0407adb commit 0a43482
Show file tree
Hide file tree
Showing 21 changed files with 1,294 additions and 33 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ WixTools/
# ONNX Runtime downloaded models
**/*.onnx
**/*.ort
**/*.pbseq
!examples/webassembly/**/*.ort
!tests/data/*.onnx
!tests/data/*.ort
Expand All @@ -196,4 +197,8 @@ WixTools/
# Glassbench results
/glassbench*.db

# Python virtual environment
.venv*

# Training checkpoints
tools/train-data/**/checkpoint
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ members = [
'examples/model-info',
'examples/yolov8',
'examples/modnet',
'examples/training',
'examples/webassembly'
]
default-members = [
Expand Down Expand Up @@ -45,13 +46,15 @@ strip = true
codegen-units = 1

[package.metadata.docs.rs]
features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ]
features = [ "ndarray", "half", "training", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ]
targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"]
rustdoc-args = [ "--cfg", "docsrs" ]

[features]
default = [ "ndarray", "half", "download-binaries", "copy-dylibs" ]

training = [ "ort-sys/training" ]

operator-libraries = [ "libc", "winapi" ]

fetch-models = [ "ureq" ]
Expand Down
18 changes: 18 additions & 0 deletions examples/training/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[package]
publish = false
name = "example-training"
version = "0.0.0"
edition = "2021"

[dependencies]
ort = { path = "../../", features = [ "training" ] }
ndarray = "0.15"
tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] }
rand = "0.8"
simd-json = "0.13"
kdam = "0.5"
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }

[features]
load-dynamic = [ "ort/load-dynamic" ]
cuda = [ "ort/cuda" ]
26 changes: 26 additions & 0 deletions examples/training/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Training Examples

## `train-clm`
This example trains a tiny causal language model on a small subset of pyke's [**OshiChats v2**](https://huggingface.co/datasets/pykeio/oshichats-v2), a dataset of live text chat messages collected from various [VTuber](https://en.wikipedia.org/wiki/VTuber) live streams. The model is not particularly useful or interesting (due to both the low-quality dataset and small model size), but it showcases that entire language models can be trained from scratch entirely in Rust on (almost) any device.

To get started, create a Python virtual environment and install the following packages:
```
pip install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT/pypi/simple/ onnxruntime-training-cpu==1.18.0 onnx~=1.17 torch~=2.3
```

We're installing the CPU version of the `onnxruntime-training` & `torch` packages because we only need to use Python to *create* the initial graph which will be used for training. Run `python tools/train-data/mini-clm.py` from the root directory of the `ort` repo to create the training artifacts.

Next, we need to convert our dataset into tokens to feed the model. This can be achieved by downloading the `oshicats-v2.jsonl` file from the OshiChats v2 dataset and running `cargo run -p example-training --example pretokenize -- ~/oshichats-v2.jsonl`, or if you (rightfully) don't wish to waste 30 GB worth of disk space and bandwidth on brainrot, you may download a [1 MB pre-tokenized subset of the dataset](https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_data/dataset.bin). Make sure `dataset.bin` is in the root of the `ort` repo.

Finally, we can train our model! Run `cargo run -p example-training --example train-clm` to start training. If you have an NVIDIA GPU, add `--features cuda` to enable CUDA, though it's not required and you can train directly on CPU instead. **This will use ~8 GB of (V)RAM!** You can lower the memory usage by adjusting the `BATCH_SIZE` and `SEQUENCE_LENGTH` constants in `train-clm.rs`, though note that changing the batch size may require adjustments to the learning rate.

While training, the progress bar will show the cross-entropy loss at each training step. At the end of training, the final trained model will be saved to `trained-clm.onnx`, and the program will use the model to generate a small snippet of text:
```
100%|██████████████████████████████████████| 5000/5000 [06:29<00:00, 12.83it/s, loss=3.611]
I'm so much better than the game<|endoftext|>I think you can't see it<|endoftext|>I think you can't see it<|endoftext|>I think so it's a new game<|endoftext|>I think I'm sure you can't see what you can't see it<|endoftext|>
```

Not bad, considering the model & dataset size! This example can easily be scaled up to pre-train or fine-tune (both full-parameter and PEFT) larger language models like Llama/Phi, so long as you have enough compute.

## `train-clm-simple`
This example is functionally identical to `train-clm`, except it uses ort's "simple" Trainer API instead of implementing a manual training loop. The simple API is more akin to 🤗 Transformer's [`Trainer`](https://huggingface.co/docs/transformers/en/main_classes/trainer) API or [PyTorch Lightning](https://lightning.ai/pytorch-lightning). With the simple API, all you have to do is pass a data loader & parameters, and let `ort` handle training for you!
5 changes: 5 additions & 0 deletions examples/training/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fn main() {
// Need this for CoreML. See: https://ort.pyke.io/perf/execution-providers#coreml
#[cfg(target_os = "macos")]
println!("cargo:rustc-link-arg=-fapple-link-rtlib");
}
44 changes: 44 additions & 0 deletions examples/training/examples/pretokenize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use std::{
env,
fs::File,
io::{BufRead, BufReader, BufWriter, Write},
path::Path
};

use simd_json::derived::ValueObjectAccessAsScalar;
use tokenizers::Tokenizer;

const MAX_TOKENS: usize = 500_000;

fn main() {
let input = env::args().nth(1).expect("provide input jsonl");
let output = env::args().nth(2).unwrap_or_else(|| "dataset.bin".into());

let input = BufReader::new(File::open(input).unwrap());
let mut output = BufWriter::new(File::create(output).unwrap());

let tokenizer = Tokenizer::from_file(
Path::new(env!("CARGO_MANIFEST_DIR"))
.parent()
.unwrap()
.join("gpt2")
.join("data")
.join("tokenizer.json")
)
.unwrap();
let mut bytes_written = 0;

for line in input.lines() {
let line: simd_json::OwnedValue = unsafe { simd_json::from_str(&mut line.unwrap()).unwrap() };
let tokenized = tokenizer
.encode(format!("<|endoftext|>{}", line.get_str("message").unwrap()), false)
.unwrap();
let id_bytes: Vec<u8> = tokenized.get_ids().iter().flat_map(|c| (*c as u16).to_le_bytes()).collect();
output.write_all(&id_bytes).unwrap();
bytes_written += id_bytes.len();
if bytes_written >= MAX_TOKENS * 2 {
output.flush().unwrap();
break;
}
}
}
118 changes: 118 additions & 0 deletions examples/training/examples/train-clm-simple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use std::{
fs::File,
io::{Read, Seek, SeekFrom, Write},
path::Path
};

use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis};
use ort::{Allocator, CUDAExecutionProvider, CheckpointStrategy, Session, SessionBuilder, Trainer, TrainingArguments};
use rand::RngCore;
use tokenizers::Tokenizer;

const BATCH_SIZE: usize = 16;
const SEQUENCE_LENGTH: usize = 256;

fn main() -> ort::Result<()> {
tracing_subscriber::fmt::init();

ort::init().commit()?;

let trainer = Trainer::new_from_artifacts(
SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?,
Allocator::default(),
"tools/train-data/mini-clm",
None
)?;

let tokenizer = Tokenizer::from_file(
Path::new(env!("CARGO_MANIFEST_DIR"))
.parent()
.unwrap()
.join("gpt2")
.join("data")
.join("tokenizer.json")
)
.unwrap();

let mut dataset = File::open("dataset.bin").unwrap();
let file_size = dataset.metadata().unwrap().len();
let num_tokens = (file_size / 2) as usize; // 16-bit tokens
let mut rng = rand::thread_rng();
let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];
let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];
let dataloader = move |_: usize| {
for batch in 0..BATCH_SIZE {
let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64;
dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap();
dataset
.read_exact(unsafe {
std::slice::from_raw_parts_mut(
input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
.as_mut_ptr()
.cast::<u8>(),
SEQUENCE_LENGTH * 2
)
})
.unwrap();
dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap();
dataset
.read_exact(unsafe {
std::slice::from_raw_parts_mut(
label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
.as_mut_ptr()
.cast::<u8>(),
SEQUENCE_LENGTH * 2
)
})
.unwrap();
}

Ok((
ort::inputs![Array2::<i64>::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap()]?,
ort::inputs![Array1::<i64>::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap()]?
))
};

trainer.train(
TrainingArguments::new(dataloader)
.with_lr(7e-5)
.with_max_steps(5000)
.with_ckpt_strategy(CheckpointStrategy::Steps(500))
)?;

trainer.export("trained-clm.onnx", ["probs"])?;

let session = Session::builder()?.commit_from_file("trained-clm.onnx")?;

let mut stdout = std::io::stdout();

let tokens = tokenizer.encode("<|endoftext|>", false).unwrap();
let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();

let mut tokens = Array1::from_iter(tokens.iter().cloned());

for _ in 0..50 {
let array = tokens.view().insert_axis(Axis(0));
let outputs = session.run(ort::inputs![array]?)?;
let generated_tokens: ArrayViewD<f32> = outputs["probs"].try_extract_tensor()?;

let probabilities = &mut generated_tokens
.slice(s![-1, ..])
.to_owned()
.iter()
.cloned()
.enumerate()
.collect::<Vec<_>>();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));

let token = probabilities[0].0;
tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]];

let token_str = tokenizer.decode(&[token as _], false).unwrap();
print!("{}", token_str);
stdout.flush().unwrap();
}

println!();
Ok(())
}
133 changes: 133 additions & 0 deletions examples/training/examples/train-clm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
use std::{
fs::File,
io::{Read, Seek, SeekFrom, Write},
path::Path
};

use kdam::BarExt;
use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis};
use ort::{Allocator, CUDAExecutionProvider, Checkpoint, Session, SessionBuilder, Trainer};
use rand::RngCore;
use tokenizers::Tokenizer;

const BATCH_SIZE: usize = 16;
const SEQUENCE_LENGTH: usize = 256;

fn main() -> ort::Result<()> {
tracing_subscriber::fmt::init();

ort::init().commit()?;

kdam::term::init(true);
let _ = kdam::term::hide_cursor();

let trainer = Trainer::new(
SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?,
Allocator::default(),
Checkpoint::load("tools/train-data/mini-clm/checkpoint")?,
"tools/train-data/mini-clm/training_model.onnx",
"tools/train-data/mini-clm/eval_model.onnx",
"tools/train-data/mini-clm/optimizer_model.onnx"
)?;

let tokenizer = Tokenizer::from_file(
Path::new(env!("CARGO_MANIFEST_DIR"))
.parent()
.unwrap()
.join("gpt2")
.join("data")
.join("tokenizer.json")
)
.unwrap();

let optimizer = trainer.optimizer();
optimizer.set_lr(7e-5)?;

let mut dataset = File::open("dataset.bin").unwrap();
let file_size = dataset.metadata().unwrap().len();
let num_tokens = (file_size / 2) as usize; // 16-bit tokens
let mut rng = rand::thread_rng();

let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];
let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE];
let mut pb = kdam::tqdm!(total = 5000);
for _ in 0..5000 {
for batch in 0..BATCH_SIZE {
let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64;
dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap();
dataset
.read_exact(unsafe {
std::slice::from_raw_parts_mut(
input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
.as_mut_ptr()
.cast::<u8>(),
SEQUENCE_LENGTH * 2
)
})
.unwrap();
dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap();
dataset
.read_exact(unsafe {
std::slice::from_raw_parts_mut(
label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH]
.as_mut_ptr()
.cast::<u8>(),
SEQUENCE_LENGTH * 2
)
})
.unwrap();
}

let inputs = Array2::<i64>::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap();
let labels = Array1::<i64>::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap();

let outputs = trainer.step(ort::inputs![inputs.view()]?, ort::inputs![labels.view()]?)?;
let loss = outputs[0].try_extract_scalar::<f32>()?;
pb.set_postfix(format!("loss={loss:.3}"));
pb.update(1).unwrap();
if loss.is_nan() {
return Ok(());
}
optimizer.step()?;
optimizer.reset_grad()?;
}

eprintln!();
let _ = kdam::term::show_cursor();

trainer.export("trained-clm.onnx", ["probs"])?;

let session = Session::builder()?.commit_from_file("trained-clm.onnx")?;

let mut stdout = std::io::stdout();

let tokens = tokenizer.encode("<|endoftext|>", false).unwrap();
let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();

let mut tokens = Array1::from_iter(tokens.iter().cloned());

for _ in 0..50 {
let array = tokens.view().insert_axis(Axis(0));
let outputs = session.run(ort::inputs![array]?)?;
let generated_tokens: ArrayViewD<f32> = outputs["probs"].try_extract_tensor()?;

let probabilities = &mut generated_tokens
.slice(s![-1, ..])
.to_owned()
.iter()
.cloned()
.enumerate()
.collect::<Vec<_>>();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));

let token = probabilities[0].0;
tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]];

let token_str = tokenizer.decode(&[token as _], false).unwrap();
print!("{}", token_str);
stdout.flush().unwrap();
}

println!();
Ok(())
}
1 change: 1 addition & 0 deletions ort-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ include = [ "src/", "dist.txt", "build.rs", "LICENSE-APACHE", "LICENSE-MIT" ]

[features]
default = []
training = []
download-binaries = [ "ureq", "tar", "flate2", "sha2" ]
load-dynamic = []
copy-dylibs = []
Expand Down
Loading

0 comments on commit 0a43482

Please sign in to comment.