diff --git a/.gitignore b/.gitignore index b00a624..6ab7181 100644 --- a/.gitignore +++ b/.gitignore @@ -186,6 +186,7 @@ WixTools/ # ONNX Runtime downloaded models **/*.onnx **/*.ort +**/*.pbseq !examples/webassembly/**/*.ort !tests/data/*.onnx !tests/data/*.ort @@ -196,4 +197,8 @@ WixTools/ # Glassbench results /glassbench*.db +# Python virtual environment .venv* + +# Training checkpoints +tools/train-data/**/checkpoint diff --git a/Cargo.toml b/Cargo.toml index e9c3ccf..c1a70bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ 'examples/model-info', 'examples/yolov8', 'examples/modnet', + 'examples/training', 'examples/webassembly' ] default-members = [ @@ -45,13 +46,15 @@ strip = true codegen-units = 1 [package.metadata.docs.rs] -features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ] +features = [ "ndarray", "half", "training", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ] targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"] rustdoc-args = [ "--cfg", "docsrs" ] [features] default = [ "ndarray", "half", "download-binaries", "copy-dylibs" ] +training = [ "ort-sys/training" ] + operator-libraries = [ "libc", "winapi" ] fetch-models = [ "ureq" ] diff --git a/examples/training/Cargo.toml b/examples/training/Cargo.toml new file mode 100644 index 0000000..945f62e --- /dev/null +++ b/examples/training/Cargo.toml @@ -0,0 +1,18 @@ +[package] +publish = false +name = "example-training" +version = "0.0.0" +edition = "2021" + +[dependencies] +ort = { path = "../../", features = [ "training" ] } +ndarray = "0.15" +tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] } +rand = "0.8" +simd-json = "0.13" +kdam = "0.5" +tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] } + +[features] +load-dynamic = [ "ort/load-dynamic" ] +cuda = [ "ort/cuda" ] diff --git a/examples/training/README.md 
b/examples/training/README.md new file mode 100644 index 0000000..7c99d64 --- /dev/null +++ b/examples/training/README.md @@ -0,0 +1,26 @@ +# Training Examples + +## `train-clm` +This example trains a tiny causal language model on a small subset of pyke's [**OshiChats v2**](https://huggingface.co/datasets/pykeio/oshichats-v2), a dataset of live text chat messages collected from various [VTuber](https://en.wikipedia.org/wiki/VTuber) live streams. The model is not particularly useful or interesting (due to both the low-quality dataset and small model size), but it showcases that entire language models can be trained from scratch entirely in Rust on (almost) any device. + +To get started, create a Python virtual environment and install the following packages: +``` +pip install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT/pypi/simple/ onnxruntime-training-cpu==1.18.0 onnx~=1.17 torch~=2.3 +``` + +We're installing the CPU version of the `onnxruntime-training` & `torch` packages because we only need to use Python to *create* the initial graph which will be used for training. Run `python tools/train-data/mini-clm.py` from the root directory of the `ort` repo to create the training artifacts. + +Next, we need to convert our dataset into tokens to feed the model. This can be achieved by downloading the `oshichats-v2.jsonl` file from the OshiChats v2 dataset and running `cargo run -p example-training --example pretokenize -- ~/oshichats-v2.jsonl`, or if you (rightfully) don't wish to waste 30 GB worth of disk space and bandwidth on brainrot, you may download a [1 MB pre-tokenized subset of the dataset](https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_data/dataset.bin). Make sure `dataset.bin` is in the root of the `ort` repo. + +Finally, we can train our model! Run `cargo run -p example-training --example train-clm` to start training. 
If you have an NVIDIA GPU, add `--features cuda` to enable CUDA, though it's not required and you can train directly on CPU instead. **This will use ~8 GB of (V)RAM!** You can lower the memory usage by adjusting the `BATCH_SIZE` and `SEQUENCE_LENGTH` constants in `train-clm.rs`, though note that changing the batch size may require adjustments to the learning rate. + +While training, the progress bar will show the cross-entropy loss at each training step. At the end of training, the final trained model will be saved to `trained-clm.onnx`, and the program will use the model to generate a small snippet of text: +``` +100%|██████████████████████████████████████| 5000/5000 [06:29<00:00, 12.83it/s, loss=3.611] +I'm so much better than the game<|endoftext|>I think you can't see it<|endoftext|>I think you can't see it<|endoftext|>I think so it's a new game<|endoftext|>I think I'm sure you can't see what you can't see it<|endoftext|> +``` + +Not bad, considering the model & dataset size! This example can easily be scaled up to pre-train or fine-tune (both full-parameter and PEFT) larger language models like Llama/Phi, so long as you have enough compute. + +## `train-clm-simple` +This example is functionally identical to `train-clm`, except it uses ort's "simple" Trainer API instead of implementing a manual training loop. The simple API is more akin to 🤗 Transformer's [`Trainer`](https://huggingface.co/docs/transformers/en/main_classes/trainer) API or [PyTorch Lightning](https://lightning.ai/pytorch-lightning). With the simple API, all you have to do is pass a data loader & parameters, and let `ort` handle training for you! diff --git a/examples/training/build.rs b/examples/training/build.rs new file mode 100644 index 0000000..79d3a0b --- /dev/null +++ b/examples/training/build.rs @@ -0,0 +1,5 @@ +fn main() { + // Need this for CoreML. 
See: https://ort.pyke.io/perf/execution-providers#coreml + #[cfg(target_os = "macos")] + println!("cargo:rustc-link-arg=-fapple-link-rtlib"); +} diff --git a/examples/training/examples/pretokenize.rs b/examples/training/examples/pretokenize.rs new file mode 100644 index 0000000..79eee19 --- /dev/null +++ b/examples/training/examples/pretokenize.rs @@ -0,0 +1,44 @@ +use std::{ + env, + fs::File, + io::{BufRead, BufReader, BufWriter, Write}, + path::Path +}; + +use simd_json::derived::ValueObjectAccessAsScalar; +use tokenizers::Tokenizer; + +const MAX_TOKENS: usize = 500_000; + +fn main() { + let input = env::args().nth(1).expect("provide input jsonl"); + let output = env::args().nth(2).unwrap_or_else(|| "dataset.bin".into()); + + let input = BufReader::new(File::open(input).unwrap()); + let mut output = BufWriter::new(File::create(output).unwrap()); + + let tokenizer = Tokenizer::from_file( + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("gpt2") + .join("data") + .join("tokenizer.json") + ) + .unwrap(); + let mut bytes_written = 0; + + for line in input.lines() { + let line: simd_json::OwnedValue = unsafe { simd_json::from_str(&mut line.unwrap()).unwrap() }; + let tokenized = tokenizer + .encode(format!("<|endoftext|>{}", line.get_str("message").unwrap()), false) + .unwrap(); + let id_bytes: Vec = tokenized.get_ids().iter().flat_map(|c| (*c as u16).to_le_bytes()).collect(); + output.write_all(&id_bytes).unwrap(); + bytes_written += id_bytes.len(); + if bytes_written >= MAX_TOKENS * 2 { + output.flush().unwrap(); + break; + } + } +} diff --git a/examples/training/examples/train-clm-simple.rs b/examples/training/examples/train-clm-simple.rs new file mode 100644 index 0000000..0c3ac32 --- /dev/null +++ b/examples/training/examples/train-clm-simple.rs @@ -0,0 +1,118 @@ +use std::{ + fs::File, + io::{Read, Seek, SeekFrom, Write}, + path::Path +}; + +use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis}; +use ort::{Allocator, 
CUDAExecutionProvider, CheckpointStrategy, Session, SessionBuilder, Trainer, TrainingArguments}; +use rand::RngCore; +use tokenizers::Tokenizer; + +const BATCH_SIZE: usize = 16; +const SEQUENCE_LENGTH: usize = 256; + +fn main() -> ort::Result<()> { + tracing_subscriber::fmt::init(); + + ort::init().commit()?; + + let trainer = Trainer::new_from_artifacts( + SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?, + Allocator::default(), + "tools/train-data/mini-clm", + None + )?; + + let tokenizer = Tokenizer::from_file( + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("gpt2") + .join("data") + .join("tokenizer.json") + ) + .unwrap(); + + let mut dataset = File::open("dataset.bin").unwrap(); + let file_size = dataset.metadata().unwrap().len(); + let num_tokens = (file_size / 2) as usize; // 16-bit tokens + let mut rng = rand::thread_rng(); + let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let dataloader = move |_: usize| { + for batch in 0..BATCH_SIZE { + let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64; + dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + } + + Ok(( + ort::inputs![Array2::::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap()]?, + ort::inputs![Array1::::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as 
i64).collect()).unwrap()]? + )) + }; + + trainer.train( + TrainingArguments::new(dataloader) + .with_lr(7e-5) + .with_max_steps(5000) + .with_ckpt_strategy(CheckpointStrategy::Steps(500)) + )?; + + trainer.export("trained-clm.onnx", ["probs"])?; + + let session = Session::builder()?.commit_from_file("trained-clm.onnx")?; + + let mut stdout = std::io::stdout(); + + let tokens = tokenizer.encode("<|endoftext|>", false).unwrap(); + let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); + + let mut tokens = Array1::from_iter(tokens.iter().cloned()); + + for _ in 0..50 { + let array = tokens.view().insert_axis(Axis(0)); + let outputs = session.run(ort::inputs![array]?)?; + let generated_tokens: ArrayViewD = outputs["probs"].try_extract_tensor()?; + + let probabilities = &mut generated_tokens + .slice(s![-1, ..]) + .to_owned() + .iter() + .cloned() + .enumerate() + .collect::>(); + probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less)); + + let token = probabilities[0].0; + tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]]; + + let token_str = tokenizer.decode(&[token as _], false).unwrap(); + print!("{}", token_str); + stdout.flush().unwrap(); + } + + println!(); + Ok(()) +} diff --git a/examples/training/examples/train-clm.rs b/examples/training/examples/train-clm.rs new file mode 100644 index 0000000..9e46bf4 --- /dev/null +++ b/examples/training/examples/train-clm.rs @@ -0,0 +1,133 @@ +use std::{ + fs::File, + io::{Read, Seek, SeekFrom, Write}, + path::Path +}; + +use kdam::BarExt; +use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis}; +use ort::{Allocator, CUDAExecutionProvider, Checkpoint, Session, SessionBuilder, Trainer}; +use rand::RngCore; +use tokenizers::Tokenizer; + +const BATCH_SIZE: usize = 16; +const SEQUENCE_LENGTH: usize = 256; + +fn main() -> ort::Result<()> { + tracing_subscriber::fmt::init(); + + ort::init().commit()?; + + kdam::term::init(true); + 
let _ = kdam::term::hide_cursor(); + + let trainer = Trainer::new( + SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?, + Allocator::default(), + Checkpoint::load("tools/train-data/mini-clm/checkpoint")?, + "tools/train-data/mini-clm/training_model.onnx", + "tools/train-data/mini-clm/eval_model.onnx", + "tools/train-data/mini-clm/optimizer_model.onnx" + )?; + + let tokenizer = Tokenizer::from_file( + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("gpt2") + .join("data") + .join("tokenizer.json") + ) + .unwrap(); + + let optimizer = trainer.optimizer(); + optimizer.set_lr(7e-5)?; + + let mut dataset = File::open("dataset.bin").unwrap(); + let file_size = dataset.metadata().unwrap().len(); + let num_tokens = (file_size / 2) as usize; // 16-bit tokens + let mut rng = rand::thread_rng(); + + let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let mut pb = kdam::tqdm!(total = 5000); + for _ in 0..5000 { + for batch in 0..BATCH_SIZE { + let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64; + dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + } + + let inputs = Array2::::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap(); + let labels = Array1::::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap(); + + let outputs = 
trainer.step(ort::inputs![inputs.view()]?, ort::inputs![labels.view()]?)?; + let loss = outputs[0].try_extract_scalar::()?; + pb.set_postfix(format!("loss={loss:.3}")); + pb.update(1).unwrap(); + if loss.is_nan() { + return Ok(()); + } + optimizer.step()?; + optimizer.reset_grad()?; + } + + eprintln!(); + let _ = kdam::term::show_cursor(); + + trainer.export("trained-clm.onnx", ["probs"])?; + + let session = Session::builder()?.commit_from_file("trained-clm.onnx")?; + + let mut stdout = std::io::stdout(); + + let tokens = tokenizer.encode("<|endoftext|>", false).unwrap(); + let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); + + let mut tokens = Array1::from_iter(tokens.iter().cloned()); + + for _ in 0..50 { + let array = tokens.view().insert_axis(Axis(0)); + let outputs = session.run(ort::inputs![array]?)?; + let generated_tokens: ArrayViewD = outputs["probs"].try_extract_tensor()?; + + let probabilities = &mut generated_tokens + .slice(s![-1, ..]) + .to_owned() + .iter() + .cloned() + .enumerate() + .collect::>(); + probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less)); + + let token = probabilities[0].0; + tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]]; + + let token_str = tokenizer.decode(&[token as _], false).unwrap(); + print!("{}", token_str); + stdout.flush().unwrap(); + } + + println!(); + Ok(()) +} diff --git a/ort-sys/Cargo.toml b/ort-sys/Cargo.toml index 26e2aa8..015a465 100644 --- a/ort-sys/Cargo.toml +++ b/ort-sys/Cargo.toml @@ -16,6 +16,7 @@ include = [ "src/", "dist.txt", "build.rs", "LICENSE-APACHE", "LICENSE-MIT" ] [features] default = [] +training = [] download-binaries = [ "ureq", "tar", "flate2", "sha2" ] load-dynamic = [] copy-dylibs = [] diff --git a/ort-sys/build.rs b/ort-sys/build.rs index d1c5995..719a205 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -37,12 +37,12 @@ fn fetch_file(source_url: &str) -> Vec { buffer } -fn 
find_dist(target: &str, designator: &str) -> Option<(&'static str, &'static str)> { +fn find_dist(target: &str, feature_set: &str) -> Option<(&'static str, &'static str)> { DIST_TABLE .split('\n') .filter(|c| !c.is_empty() && !c.starts_with('#')) .map(|c| c.split('\t').collect::>()) - .find(|c| c[0] == designator && c[1] == target) + .find(|c| c[0] == feature_set && c[1] == target) .map(|c| (c[2], c[3])) } @@ -341,23 +341,31 @@ fn prepare_libort_dir() -> (PathBuf, bool) { #[cfg(feature = "download-binaries")] { let target = env::var("TARGET").unwrap().to_string(); - let designator = if cfg!(any(feature = "cuda", feature = "tensorrt")) { - if lib_exists("cudart64_12.dll") || lib_exists("libcudart.so.12") { "cu12" } else { "cu11" } + + let mut feature_set = Vec::new(); + if cfg!(feature = "training") { + feature_set.push("train"); + } + if cfg!(any(feature = "cuda", feature = "tensorrt")) { + if lib_exists("cudart64_11.dll") || lib_exists("libcudart.so.11") || env::var("ORT_DFBIN_FORCE_CUDA_VERSION").as_deref() == Ok("11") { + feature_set.push("cu11"); + } else { + feature_set.push("cu12"); + } } else if cfg!(feature = "rocm") { - "rocm" - } else { - "none" - }; - let mut dist = find_dist(&target, designator); - if dist.is_none() && designator != "none" { + feature_set.push("rocm"); + } + let feature_set = if !feature_set.is_empty() { feature_set.join(",") } else { "none".to_owned() }; + let mut dist = find_dist(&target, &feature_set); + if dist.is_none() && feature_set != "none" { dist = find_dist(&target, "none"); } if dist.is_none() { panic!( "downloaded binaries not available for target {target}{}\nyou may have to compile ONNX Runtime from source", - if designator != "none" { - format!(" (note: also requested `{designator}`)") + if feature_set != "none" { + format!(" (note: also requested features `{feature_set}`)") } else { String::new() } diff --git a/ort-sys/dist.txt b/ort-sys/dist.txt index 5ba397a..98e3f3f 100644 --- a/ort-sys/dist.txt +++ b/ort-sys/dist.txt 
@@ -4,12 +4,26 @@ cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/ rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_rocm-v1.18.1-x86_64-unknown-linux-gnu.tgz 84F74428E0BEC68C55B8E1E91B9282E984CD2866148A2584382B8CB3284214A3 none x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-unknown-linux-gnu.tgz 0A193706A95286853D792D7D9B2271CBEA35C57F249943FE811CED97E0E24862 +train aarch64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-aarch64-unknown-linux-gnu.tgz C04DBEAF19F2BCD3643F8F7D7FA01110A1AF429DFDD1C1DC7C5EDA2B1A8AA324 +train,cu12 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu12-v1.18.1-x86_64-unknown-linux-gnu.tgz A139D8AD8930930F5A61DF112C8275AAD1F0415FAFD08CE3031CEFFFC30F2445 +train,cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu11-v1.18.1-x86_64-unknown-linux-gnu.tgz 2DAA2E2CF44E9B9A96AB2E9C4271C35189C96BF264D1797DABCF1D6711730DE7 +train,rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,rocm-v1.18.1-x86_64-unknown-linux-gnu.tgz DD373BA6B251D21953223B2FBB64F4DF34CFE98A63C26D16607BEAC6BC788466 +train x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-unknown-linux-gnu.tgz 0E617970AE83ABE5FB9A3D5D69AAC9A67ED4C9494AD527B14A84FDC98CA9B924 + none aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-aarch64-pc-windows-msvc.tgz B2F962F0E75F17F3D657B3504CE891BAA6461B26AF65FBD9244B3CCA17FD79D4 cu12 x86_64-pc-windows-msvc 
https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu12-v1.18.1-x86_64-pc-windows-msvc.tgz CDBC2D87B202E1847900E94796D102EE4D5C19A9568BBD014838ECD1F5D5350B cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_cu11-v1.18.1-x86_64-pc-windows-msvc.tgz B514FC25453F955F8592100448B27F5E1762A344E8C2D57D41B908978EF2A126 none x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-pc-windows-msvc.tgz EB2BCD1778C5934437D4C5B17F67DEAF5F67D2E3C18C7298973EACD41113DC01 +train aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-aarch64-pc-windows-msvc.tgz 8CC1FFFD8AB5E526A076C29A767A650C436E31179D0C6E52C2EA936067B72566 +train,cu12 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu12-v1.18.1-x86_64-pc-windows-msvc.tgz 6AF64567E25B59AD1196D4953EF8C6A65795E8A4B864E10D8303A027AC50B2D0 +train,cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_dylib_train,cu11-v1.18.1-x86_64-pc-windows-msvc.tgz E14AA0F4FBBBCAF925AD4DB4F76B06402F654B36C5F221E00010D1005F47AE56 +train x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-pc-windows-msvc.tgz 84728438E5A950027EBBDC51463F4E5B99B4979087F0F127EA18BC604507E979 + none aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-aarch64-apple-darwin.tgz B42BE76AFB9495983A6D5D498D56D5E685B018F1011EF4C5B8C56124B192FD37 none x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static-v1.18.1-x86_64-apple-darwin.tgz 247F73A5B3665A6660DFB35213E6FEAAC6ED6CAC5816DD85A348DF790F60A30B +train 
aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-aarch64-apple-darwin.tgz 29DC09AFA5C3619CF3125F3D55DD64E5EE64451D6BD0044527776849AADEE344 +train x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-msort_static_train-v1.18.1-x86_64-apple-darwin.tgz 898EC9E3F852843ECDB618CF8E317F4C92BDEB33FC773038960857BCB37CB347 + none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.1/ortrs-pkort_static-v1.18.1-wasm32-unknown-unknown.tgz D1BF756F02A53C3BC254E3C2048BE617082905A89182A6B1BD18C229920228EF diff --git a/ort-sys/src/lib.rs b/ort-sys/src/lib.rs index f5d3b13..f7cb853 100644 --- a/ort-sys/src/lib.rs +++ b/ort-sys/src/lib.rs @@ -823,9 +823,117 @@ fn bindgen_test_layout_OrtOpenVINOProviderOptions() { } #[repr(C)] #[derive(Debug, Copy, Clone)] -pub struct OrtTrainingApi { +pub struct OrtTrainingSession { _unused: [u8; 0] } +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct OrtCheckpointState { + _unused: [u8; 0] +} +#[repr(i32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum OrtPropertyType { + OrtIntProperty = 0, + OrtFloatProperty = 1, + OrtStringProperty = 2 +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct OrtTrainingApi { + pub LoadCheckpoint: + ::std::option::Option<_system!(unsafe fn(checkpoint_path: *const ortchar, checkpoint_state: *mut *mut OrtCheckpointState) -> OrtStatusPtr)>, + pub SaveCheckpoint: ::std::option::Option< + _system!(unsafe fn(checkpoint_state: *mut OrtCheckpointState, checkpoint_path: *const ortchar, include_optimizer_state: bool) -> OrtStatusPtr) + >, + pub CreateTrainingSession: ::std::option::Option< + _system!( + unsafe fn( + env: *const OrtEnv, + options: *const OrtSessionOptions, + checkpoint_state: *mut OrtCheckpointState, + train_model_path: *const ortchar, + eval_model_path: *const ortchar, + optimizer_model_path: *const ortchar, + out: *mut *mut 
OrtTrainingSession + ) -> OrtStatusPtr + ) + >, + pub CreateTrainingSessionFromBuffer: ::std::option::Option< + _system!( + unsafe fn( + env: *const OrtEnv, + options: *const OrtSessionOptions, + checkpoint_state: *mut OrtCheckpointState, + train_model_data: *const (), + train_data_length: size_t, + eval_model_data: *const (), + eval_data_length: size_t, + optimizer_model_data: *const (), + optimizer_data_length: size_t, + out: *mut *mut OrtTrainingSession + ) -> OrtStatusPtr + ) + >, + pub TrainingSessionGetTrainingModelOutputCount: + ::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>, + pub TrainingSessionGetEvalModelOutputCount: ::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>, + pub TrainingSessionGetTrainingModelOutputName: ::std::option::Option< + _system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr) + >, + pub TrainingSessionGetEvalModelOutputName: ::std::option::Option< + _system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr) + >, + pub LazyResetGrad: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession) -> OrtStatusPtr)>, + pub TrainStep: ::std::option::Option< + _system!( + unsafe fn( + session: *mut OrtTrainingSession, + run_options: *const OrtRunOptions, + inputs_len: size_t, + inputs: *const *const OrtValue, + outputs_len: size_t, + outputs: *mut *mut OrtValue + ) -> OrtStatusPtr + ) + >, + pub EvalStep: ::std::option::Option< + _system!( + unsafe fn( + session: *mut OrtTrainingSession, + run_options: *const OrtRunOptions, + inputs_len: size_t, + inputs: *const *const OrtValue, + outputs_len: size_t, + outputs: *mut *mut OrtValue + ) -> OrtStatusPtr + ) + >, + pub SetLearningRate: ::std::option::Option<_system!(unsafe fn(session: *mut 
OrtTrainingSession, learning_rate: f32) -> OrtStatusPtr)>, + pub GetLearningRate: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, learning_rate: *mut f32) -> OrtStatusPtr)>, + pub OptimizerStep: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, run_options: *const OrtRunOptions) -> OrtStatusPtr)>, + pub RegisterLinearLRScheduler: ::std::option::Option< + _system!(unsafe fn(session: *mut OrtTrainingSession, warmup_step_count: i64, total_step_count: i64, initial_lr: f32) -> OrtStatusPtr) + >, + pub SchedulerStep: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession) -> OrtStatusPtr)>, + pub GetParametersSize: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, out: *mut size_t, trainable_only: bool) -> OrtStatusPtr)>, + pub CopyParametersToBuffer: + ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, parameters_buffer: *mut OrtValue, trainable_only: bool) -> OrtStatusPtr)>, + pub CopyBufferToParameters: + ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, parameters_buffer: *mut OrtValue, trainable_only: bool) -> OrtStatusPtr)>, + pub ReleaseTrainingSession: ::std::option::Option<_system!(unsafe fn(input: *mut OrtTrainingSession))>, + pub ReleaseCheckpointState: ::std::option::Option<_system!(unsafe fn(input: *mut OrtCheckpointState))>, + pub ExportModelForInferencing: ::std::option::Option< + _system!( + unsafe fn( + session: *mut OrtTrainingSession, + inference_model_path: *const ortchar, + graph_outputs_len: usize, + graph_output_names: *const *const c_char + ) -> OrtStatusPtr + ) + > +} #[doc = " \\brief The helper interface to get the right version of OrtApi\n\n Get a pointer to this structure through ::OrtGetApiBase"] #[repr(C)] #[derive(Debug, Copy, Clone)] diff --git a/src/error.rs b/src/error.rs index 661e332..2e84580 100644 --- a/src/error.rs +++ b/src/error.rs @@ -261,7 +261,9 @@ pub enum Error { #[error("Could't 
get `AllocatorType` from memory info: {0}")] GetAllocatorType(ErrorInternal), #[error("Could't get device ID from memory info: {0}")] - GetDeviceId(ErrorInternal) + GetDeviceId(ErrorInternal), + #[error("Training API is not enabled in this build of ONNX Runtime.")] + TrainingNotEnabled } impl Error { diff --git a/src/lib.rs b/src/lib.rs index 1dd8204..3455a60 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,9 @@ pub(crate) mod metadata; pub(crate) mod operator; pub(crate) mod session; pub(crate) mod tensor; +#[cfg(feature = "training")] +pub(crate) mod training; +pub(crate) mod util; pub(crate) mod value; #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))] #[cfg(target_arch = "wasm32")] @@ -66,6 +69,9 @@ pub use self::session::{ #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub use self::tensor::ArrayExtensions; pub use self::tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType, Utf8Data}; +#[cfg(feature = "training")] +#[cfg_attr(docsrs, doc(cfg(feature = "training")))] +pub use self::training::*; pub use self::value::{ DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, diff --git a/src/session/builder.rs b/src/session/builder.rs index 8632c57..458c6ad 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -1,9 +1,5 @@ #[cfg(any(feature = "operator-libraries", not(windows)))] use std::ffi::CString; -#[cfg(unix)] -use std::os::unix::ffi::OsStrExt; -#[cfg(target_family = "windows")] -use std::os::windows::ffi::OsStrExt; #[cfg(not(target_arch = "wasm32"))] use std::path::Path; #[cfg(feature = "fetch-models")] @@ -316,20 +312,7 @@ impl SessionBuilder { }); } - // Build an OsString, then a vector of bytes to pass to C - let model_path = 
std::ffi::OsString::from(model_filepath); - #[cfg(target_family = "windows")] - let model_path: Vec = model_path - .encode_wide() - .chain(std::iter::once(0)) // Make sure we have a null terminated string - .collect(); - #[cfg(not(target_family = "windows"))] - let model_path: Vec = model_path - .as_encoded_bytes() - .iter() - .chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string - .map(|b| *b as std::os::raw::c_char) - .collect(); + let model_path = crate::util::path_to_os_char(model_filepath); let env = get_environment()?; apply_execution_providers(&self, env.execution_providers.iter().cloned())?; diff --git a/src/training/mod.rs b/src/training/mod.rs new file mode 100644 index 0000000..d66db11 --- /dev/null +++ b/src/training/mod.rs @@ -0,0 +1,142 @@ +use std::{ + path::Path, + ptr::{self, NonNull}, + sync::{ + atomic::{AtomicPtr, Ordering}, + OnceLock + } +}; + +use crate::{ortsys, Error, Result, RunOptions}; + +mod simple; +mod trainer; + +pub use self::{ + simple::{iterable_data_loader, CheckpointStrategy, DataLoader, EvaluationStrategy, IterableDataLoader, TrainingArguments}, + trainer::Trainer +}; + +pub(crate) static TRAINING_API: OnceLock> = OnceLock::new(); + +/// Returns a pointer to the global [`ort_sys::OrtTrainingApi`] object, or errors if the Training API is not enabled. +/// +/// # Panics +/// May panic if: +/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime. +/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled. +pub fn training_api() -> Result> { + NonNull::new( + TRAINING_API + .get_or_init(|| { + let training_api = ortsys![unsafe GetTrainingApi(ort_sys::ORT_API_VERSION)]; + AtomicPtr::new(training_api.cast_mut()) + }) + .load(Ordering::Relaxed) + ) + .ok_or(Error::TrainingNotEnabled) +} + +macro_rules! 
trainsys { + ($method:ident) => { + $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null"))) + }; + (unsafe $method:ident) => { + unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null"))) } + }; + ($method:ident($($n:expr),+ $(,)?)) => { + $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) + }; + (unsafe $method:ident($($n:expr),+ $(,)?)) => { + unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) } + }; + ($method:ident($($n:expr),+ $(,)?).expect($e:expr)) => { + $crate::error::status_to_result($crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).expect($e) + }; + (unsafe $method:ident($($n:expr),+ $(,)?).expect($e:expr)) => { + $crate::error::status_to_result(unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).expect($e) + }; + ($method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => { + $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+); + $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + }; + (unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {{ + let _x = unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }; + $($crate::error::assert_non_null_pointer($check, stringify!($method)).unwrap();)+ + _x + }}; + ($method:ident($($n:expr),+ $(,)?) 
-> $err:expr$(;)?) => { + $crate::error::status_to_result($crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).map_err($err)?; + }; + (unsafe $method:ident($($n:expr),+ $(,)?) -> $err:expr$(;)?) => { + $crate::error::status_to_result(unsafe { $crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).map_err($err)?; + }; + ($method:ident($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => { + $crate::error::status_to_result($crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).map_err($err)?; + $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + }; + (unsafe $method:ident($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {{ + $crate::error::status_to_result(unsafe { $crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).map_err($err)?; + $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + }}; +} +pub(crate) use trainsys; + +#[derive(Debug)] +pub struct Checkpoint { + pub(crate) ptr: NonNull +} + +impl Checkpoint { + pub fn load(path: impl AsRef) -> Result { + let path = crate::util::path_to_os_char(path); + let mut ptr: *mut ort_sys::OrtCheckpointState = ptr::null_mut(); + trainsys![unsafe LoadCheckpoint(path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)]; + Ok(Checkpoint { + ptr: unsafe { NonNull::new_unchecked(ptr) } + }) + } + + pub fn save(&self, path: impl AsRef, include_optimizer_state: bool) -> Result<()> { + let path = crate::util::path_to_os_char(path); + trainsys![unsafe SaveCheckpoint(self.ptr.as_ptr(), path.as_ptr(), include_optimizer_state) -> Error::CreateSession]; + Ok(()) + } +} + +impl Drop for Checkpoint { + fn 
drop(&mut self) { + tracing::trace!("dropping checkpoint"); + trainsys![unsafe ReleaseCheckpointState(self.ptr.as_ptr())]; + } +} + +#[derive(Debug)] +pub struct Optimizer(NonNull); + +impl Optimizer { + pub fn reset_grad(&self) -> Result<()> { + trainsys![unsafe LazyResetGrad(self.0.as_ptr()) -> Error::CreateSession]; + Ok(()) + } + + pub fn lr(&self) -> Result { + let mut lr = f32::NAN; + trainsys![unsafe GetLearningRate(self.0.as_ptr(), &mut lr) -> Error::CreateSession]; + Ok(lr) + } + + pub fn set_lr(&self, lr: f32) -> Result<()> { + trainsys![unsafe SetLearningRate(self.0.as_ptr(), lr) -> Error::CreateSession]; + Ok(()) + } + + pub fn step(&self) -> Result<()> { + self.step_with_options(RunOptions::new()?) + } + + pub fn step_with_options(&self, options: RunOptions) -> Result<()> { + trainsys![unsafe OptimizerStep(self.0.as_ptr(), options.run_options_ptr.as_ptr()) -> Error::CreateSession]; + Ok(()) + } +} diff --git a/src/training/simple.rs b/src/training/simple.rs new file mode 100644 index 0000000..267f3c6 --- /dev/null +++ b/src/training/simple.rs @@ -0,0 +1,240 @@ +use std::{collections::VecDeque, fs, path::PathBuf}; + +use crate::{Result, SessionInputs}; + +#[allow(clippy::len_without_is_empty)] +pub trait DataLoader { + fn load(&mut self, idx: usize) -> Result<(I, L)>; + + fn len(&self) -> Option { + None + } +} + +pub struct IterableDataLoader Result<(I, L)>> { + items: Box<[T]>, + collator: C +} + +impl Result<(I, L)>> DataLoader for IterableDataLoader { + fn load(&mut self, idx: usize) -> Result<(I, L)> { + (self.collator)(&self.items[idx]) + } + + fn len(&self) -> Option { + Some(self.items.len()) + } +} + +pub fn iterable_data_loader Result<(I, L)>>(iterable: impl Iterator, collator: C) -> IterableDataLoader { + IterableDataLoader { + items: iterable.collect::>().into_boxed_slice(), + collator + } +} + +impl Result<(I, L)>> DataLoader for F { + fn load(&mut self, idx: usize) -> Result<(I, L)> { + (self)(idx) + } + + fn len(&self) -> Option { + None 
+ } +} + +pub enum EvaluationStrategy { + None, + Steps(usize), + Epochs(usize) +} + +impl EvaluationStrategy { + pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option) -> bool { + match self { + Self::None => false, + Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0, + Self::Epochs(epochs) => { + if let Some(dataloader_size) = dataloader_size { + iter_step > 0 && iter_step % (dataloader_size * epochs) == 0 + } else { + false + } + } + } + } +} + +pub enum CheckpointStrategy { + None, + Steps(usize), + Epochs(usize) +} + +impl CheckpointStrategy { + pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option) -> bool { + match self { + Self::None => false, + Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0, + Self::Epochs(epochs) => { + if let Some(dataloader_size) = dataloader_size { + iter_step > 0 && iter_step % (dataloader_size * epochs) == 0 + } else { + false + } + } + } + } +} + +pub struct TrainingArguments>, L: Into>, const NI: usize, const NL: usize> { + loader: Box>, + eval_loader: Option>>, + eval_strategy: EvaluationStrategy, + ckpt_strategy: CheckpointStrategy, + ckpt_path: PathBuf, + lr: f32, + max_saved_ckpts: usize, + gradient_accumulation_steps: usize, + max_steps: usize, + max_eval_steps: usize +} + +impl>, L: Into>, const NI: usize, const NL: usize> + TrainingArguments +{ + pub fn new + 'static>(train_loader: D) -> Self { + Self { + loader: Box::new(train_loader), + eval_loader: None, + eval_strategy: EvaluationStrategy::None, + ckpt_strategy: CheckpointStrategy::Epochs(1), + ckpt_path: PathBuf::from("checkpoints"), + lr: 1e-4, + gradient_accumulation_steps: 1, + max_saved_ckpts: 1, + max_steps: usize::MAX, + max_eval_steps: usize::MAX + } + } + + pub fn with_lr(mut self, lr: f32) -> Self { + self.lr = lr; + self + } + + pub fn with_max_steps(mut self, steps: usize) -> Self { + self.max_steps = steps; + self + } + + pub fn 
with_max_eval_steps(mut self, steps: usize) -> Self { + self.max_eval_steps = steps; + self + } + + pub fn with_gradient_accumulation(mut self, steps: usize) -> Self { + self.gradient_accumulation_steps = steps; + self + } + + pub fn with_ckpt_path(mut self, path: impl Into) -> Self { + self.ckpt_path = path.into(); + self + } + + pub fn with_ckpt_strategy(mut self, strategy: CheckpointStrategy) -> Self { + self.ckpt_strategy = strategy; + self + } + + pub fn with_max_saved_ckpts(mut self, max_ckpts: usize) -> Self { + self.max_saved_ckpts = max_ckpts; + self + } + + pub fn with_eval_loader + 'static>(mut self, eval_loader: D) -> Self { + self.eval_loader = Some(Box::new(eval_loader)); + self + } + + pub fn with_eval_strategy(mut self, strategy: EvaluationStrategy) -> Self { + self.eval_strategy = strategy; + self + } +} + +impl super::Trainer { + pub fn train>, L: Into>, const NI: usize, const NL: usize>( + &self, + mut args: TrainingArguments + ) -> crate::Result<()> { + let optimizer = self.optimizer(); + optimizer.set_lr(args.lr)?; + + let mut saved_ckpts = VecDeque::new(); + let mut global_step = 0; + for (iter_step, _) in (0..args.max_steps).enumerate() { + let epoch = iter_step / args.loader.len().unwrap_or(usize::MAX); + let (inputs, labels) = args.loader.load(iter_step)?; + let (inputs, labels) = (inputs.into(), labels.into()); + + let outputs = self.step(inputs, labels)?; + let loss = outputs[0].try_extract_scalar::()?; + println!("epoch={epoch} step={global_step} loss={loss}"); + + if iter_step % args.gradient_accumulation_steps == 0 { + optimizer.step()?; + optimizer.reset_grad()?; + global_step += 1; + } + + if args.ckpt_strategy.should_fire(global_step, iter_step, args.loader.len()) { + if !args.ckpt_path.exists() { + let _ = fs::create_dir_all(&args.ckpt_path); + } + + let ckpt_path = args.ckpt_path.join(format!("epoch={epoch},step={global_step}.ortckpt")); + self.checkpoint().save(&ckpt_path, true)?; + + saved_ckpts.push_front(ckpt_path.clone()); + 
while saved_ckpts.len() > args.max_saved_ckpts { + let Some(old_ckpt) = saved_ckpts.pop_back() else { + break; + }; + let _ = fs::remove_file(old_ckpt); + } + } + + if args + .eval_strategy + .should_fire(global_step, iter_step, args.eval_loader.as_ref().and_then(|d| d.len())) + { + let eval_loss = self.eval_inner(&mut args)?; + println!("eval_loss={eval_loss}"); + } + } + Ok(()) + } + + pub(crate) fn eval_inner>, L: Into>, const NI: usize, const NL: usize>( + &self, + args: &mut TrainingArguments + ) -> crate::Result { + let Some(eval_loader) = &mut args.eval_loader else { + return Ok(0.0); + }; + + let mut total_loss = 0.0; + for step in 0..args.max_eval_steps.min(eval_loader.len().unwrap_or(usize::MAX)) { + let (inputs, labels) = eval_loader.load(step)?; + let (inputs, labels) = (inputs.into(), labels.into()); + + let outputs = self.eval_step(inputs, labels)?; + let loss = outputs[0].try_extract_scalar::()?; + total_loss = (total_loss * (step as f32) + loss) / (step as f32 + 1.); + } + + Ok(total_loss) + } +} diff --git a/src/training/trainer.rs b/src/training/trainer.rs new file mode 100644 index 0000000..f7c7cb3 --- /dev/null +++ b/src/training/trainer.rs @@ -0,0 +1,235 @@ +use std::{ + ffi::CString, + path::Path, + ptr::{self, NonNull}, + sync::Arc +}; + +use ort_sys::c_char; + +use super::{trainsys, Checkpoint, Optimizer}; +use crate::{ + char_p_to_string, + error::{assert_non_null_pointer, status_to_result}, + Allocator, Error, Result, RunOptions, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, Value +}; + +#[derive(Debug)] +pub struct Trainer { + pub(crate) ptr: NonNull, + train_output_names: Vec, + optimizer: Optimizer, + ckpt: Checkpoint, + _allocator: Allocator +} + +impl Trainer { + pub fn new( + session_options: SessionBuilder, + allocator: Allocator, + ckpt: Checkpoint, + training_model_path: impl AsRef, + eval_model_path: impl AsRef, + optimizer_model_path: impl AsRef + ) -> Result { + let training_model_path = 
crate::util::path_to_os_char(training_model_path); + let eval_model_path = crate::util::path_to_os_char(eval_model_path); + let optimizer_model_path = crate::util::path_to_os_char(optimizer_model_path); + + let env = crate::get_environment()?; + + let mut ptr: *mut ort_sys::OrtTrainingSession = ptr::null_mut(); + trainsys![unsafe CreateTrainingSession(env.ptr(), session_options.session_options_ptr.as_ptr(), ckpt.ptr.as_ptr(), training_model_path.as_ptr(), eval_model_path.as_ptr(), optimizer_model_path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)]; + + let ptr = unsafe { NonNull::new_unchecked(ptr) }; + + let mut train_output_len = 0; + trainsys![unsafe TrainingSessionGetTrainingModelOutputCount(ptr.as_ptr(), &mut train_output_len) -> Error::CreateSession]; + let train_output_names = (0..train_output_len) + .map(|i| { + let mut name_bytes: *mut c_char = std::ptr::null_mut(); + trainsys![unsafe TrainingSessionGetTrainingModelOutputName(ptr.as_ptr(), i, allocator.ptr.as_ptr(), &mut name_bytes) -> Error::CreateSession]; + let name = match char_p_to_string(name_bytes) { + Ok(name) => name, + Err(e) => { + unsafe { allocator.free(name_bytes) }; + return Err(e); + } + }; + unsafe { allocator.free(name_bytes) }; + Ok(name) + }) + .collect::>>()?; + + Ok(Self { + ptr, + _allocator: allocator, + train_output_names, + optimizer: Optimizer(ptr), + ckpt + }) + } + + pub fn new_from_artifacts( + session_options: SessionBuilder, + allocator: Allocator, + base_dir: impl AsRef, + override_ckpt: Option + ) -> Result { + let base_dir = base_dir.as_ref(); + let ckpt = if let Some(ckpt) = override_ckpt { + ckpt + } else { + Checkpoint::load(base_dir.join("checkpoint"))? 
+ }; + Self::new( + session_options, + allocator, + ckpt, + base_dir.join("training_model.onnx"), + base_dir.join("eval_model.onnx"), + base_dir.join("optimizer_model.onnx") + ) + } + + pub fn step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( + &'s self, + inputs: impl Into>, + labels: impl Into> + ) -> Result> { + match inputs.into() { + SessionInputs::ValueSlice(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueArray(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + } + } + + fn step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + &'s self, + input_values: impl Iterator>, + run_options: Option<&'r RunOptions> + ) -> Result> { + let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; + + let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); + + let run_options_ptr = if let Some(run_options) = &run_options { + run_options.run_options_ptr.as_ptr() + } else { + std::ptr::null_mut() + }; + + trainsys![unsafe TrainStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun]; + + let outputs: Vec = output_tensor_ptrs + .into_iter() + 
.map(|tensor_ptr| unsafe { + // TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`. + // but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣 + Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None) + }) + .collect(); + + Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) + } + + pub fn eval_step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( + &'s self, + inputs: impl Into>, + labels: impl Into> + ) -> Result> { + match inputs.into() { + SessionInputs::ValueSlice(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueArray(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + } + } + + fn eval_step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + &'s self, + input_values: impl Iterator>, + run_options: Option<&'r RunOptions> + ) -> Result> { + let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; + + let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); + + let run_options_ptr = if let Some(run_options) = &run_options { + 
run_options.run_options_ptr.as_ptr() + } else { + std::ptr::null_mut() + }; + + trainsys![unsafe EvalStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun]; + + let outputs: Vec = output_tensor_ptrs + .into_iter() + .map(|tensor_ptr| unsafe { + // TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`. + // but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣 + Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None) + }) + .collect(); + + Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) + } + + pub fn export>(&self, out_path: impl AsRef, output_names: impl AsRef<[O]>) -> Result<()> { + let out_path = crate::util::path_to_os_char(out_path); + + let output_names_ptr: Vec<*const c_char> = output_names + .as_ref() + .iter() + .map(|output| CString::new(output.as_ref()).unwrap_or_else(|_| unreachable!())) + .map(|n| n.into_raw().cast_const()) + .collect(); + + let res = trainsys![unsafe ExportModelForInferencing(self.ptr.as_ptr(), out_path.as_ptr(), output_names_ptr.len(), output_names_ptr.as_ptr())]; + + // Reconvert name ptrs to CString so drop impl is called and memory is freed + drop( + output_names_ptr + .into_iter() + .map(|p| { + assert_non_null_pointer(p, "c_char for CString")?; + unsafe { Ok(CString::from_raw(p.cast_mut().cast())) } + }) + .collect::>>()? 
+ ); + + status_to_result(res).map_err(Error::CreateSession)?; + + Ok(()) + } + + pub fn optimizer(&self) -> &Optimizer { + &self.optimizer + } + + pub fn checkpoint(&self) -> &Checkpoint { + &self.ckpt + } +} + +impl Drop for Trainer { + fn drop(&mut self) { + tracing::trace!("dropping trainer"); + trainsys![unsafe ReleaseTrainingSession(self.ptr.as_ptr())]; + } +} diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..bfa11d9 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,26 @@ +#[cfg(not(target_family = "windows"))] +use std::os::raw::c_char; +#[cfg(unix)] +use std::os::unix::ffi::OsStrExt; +#[cfg(target_family = "windows")] +use std::os::windows::ffi::OsStrExt; +use std::{ffi::OsString, path::Path}; + +#[cfg(target_family = "windows")] +type OsCharArray = Vec; +#[cfg(not(target_family = "windows"))] +type OsCharArray = Vec; + +pub fn path_to_os_char(path: impl AsRef) -> OsCharArray { + let model_path = OsString::from(path.as_ref()); + #[cfg(target_family = "windows")] + let model_path: Vec = model_path.encode_wide().chain(std::iter::once(0)).collect(); + #[cfg(not(target_family = "windows"))] + let model_path: Vec = model_path + .as_encoded_bytes() + .iter() + .chain(std::iter::once(&b'\0')) + .map(|b| *b as c_char) + .collect(); + model_path +} diff --git a/tools/requirements.txt b/tools/requirements.txt new file mode 100644 index 0000000..d49cd91 --- /dev/null +++ b/tools/requirements.txt @@ -0,0 +1,4 @@ +torch~=2.3 +torch-ort~=1.17 +onnx~=1.16 +--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT/pypi/simple/ onnxruntime-training-cpu==1.18.0 diff --git a/tools/train-data/mini-clm.py b/tools/train-data/mini-clm.py new file mode 100644 index 0000000..6f06a70 --- /dev/null +++ b/tools/train-data/mini-clm.py @@ -0,0 +1,140 @@ +import math + +import onnx +from onnxruntime.training import artifacts +import torch +from torch import nn, Tensor +from torch.nn import functional as F + +class RMSNorm(nn.Module): + def 
__init__(self, dim: int, *, eps: float = 1e-6): + super().__init__() + + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + if x.dtype != torch.float32: + xf = x.to(dtype=torch.float32) + else: + xf = x + output = (xf * torch.sqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)) + if x.dtype != torch.float32: + output = output.to(dtype=x.dtype) + return output * self.weight + +class RoPE(nn.Module): + def __init__(self, embedding_dim: int, *, max_seq_length: int = 2048, base: float = 10000.0): + super().__init__() + + pe = torch.zeros(max_seq_length, embedding_dim) + position = torch.arange(0, max_seq_length, dtype=torch.float32).unsqueeze(1) + div_term = torch.exp(torch.arange(0, embedding_dim, step=2).float() * (-math.log(base) / embedding_dim)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe, persistent=False) + + @torch.no_grad() + def forward(self, x: Tensor) -> Tensor: + return x + self.pe[:, :x.shape[1], :] + +class Attention(nn.Module): + def __init__(self, embedding_dim: int, *, rope: RoPE, max_seq_length: int = 2048, n_heads: int = 4): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_heads = n_heads + self.qkv = nn.Linear(embedding_dim, embedding_dim * 3, bias=False) + self.proj = nn.Linear(embedding_dim, embedding_dim, bias=False) + self.rope = rope + self.register_buffer('bias', torch.tril(torch.ones(max_seq_length, max_seq_length))[None, None, :, :], persistent=False) + + def forward(self, x: Tensor) -> Tensor: + b, t, c = x.size() + + x = self.rope(x) + + q, k, v = self.qkv(x).split(self.embedding_dim, dim=2) + q = q.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2) + k = k.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2) + v = v.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2) + + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + 
att = att.masked_fill(self.bias[:, :, :t, :t] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + y = att @ v + y = y.transpose(1, 2).contiguous().view(b, t, c) + + return self.proj(y) + +class FFN(nn.Module): + def __init__(self, embedding_dim: int, intermediate_dim: int | None = None): + super().__init__() + + intermediate_dim = intermediate_dim or embedding_dim * 4 + + self.w1 = nn.Linear(embedding_dim, intermediate_dim * 2, bias=False) + self.w2 = nn.Linear(intermediate_dim, embedding_dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + x, gate = self.w1(x).chunk(2, dim=-1) + return self.w2(F.gelu(gate) * x) + +class Layer(nn.Module): + def __init__(self, embedding_dim: int, rope: RoPE): + super().__init__() + + self.attn = Attention(embedding_dim, rope=rope) + self.norm1 = RMSNorm(embedding_dim) + self.ffn = FFN(embedding_dim) + self.norm2 = RMSNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tensor: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + return x + +class CLM(nn.Module): + def __init__(self, embedding_dim: int, n_layers: int, *, vocab_size: int): + super().__init__() + + rope = RoPE(embedding_dim) + self.layers = nn.ModuleList([Layer(embedding_dim, rope=rope) for _ in range(n_layers)]) + self.word_embeddings = nn.Embedding(vocab_size, embedding_dim) + self.norm = RMSNorm(embedding_dim) + self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=False) + + def forward(self, x: Tensor) -> Tensor: + x = self.word_embeddings(x) + for layer in self.layers: + x = layer(x) + logits = self.lm_head(self.norm(x)) + return logits.view(-1, logits.size(-1)) + +lm = CLM(256, 4, vocab_size=50257) +torch.onnx.export( + lm, + torch.randint(0, 50256, (1, 64)), + f'tools/train-data/mini-clm/model.onnx', + input_names=['input_ids'], + output_names=['probs'], + dynamic_axes={ + 'input_ids': {0: 'batch', 1: 'seq'}, + 'probs': {0: 'batch_seq'} + }, + opset_version=14 +) + +onnx_model = 
onnx.load('tools/train-data/mini-clm/model.onnx') +requires_grad = [param.name for param in onnx_model.graph.initializer] + +artifacts.generate_artifacts( + onnx_model, + requires_grad=requires_grad, + frozen_params=[], + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory='tools/train-data/mini-clm' +)