From c950792fe01fa8dfc448c2949d442f21f7bc333b Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sat, 18 May 2024 15:01:14 -0500 Subject: [PATCH 01/12] feat: add skeleton training API --- Cargo.toml | 4 +- ort-sys/src/lib.rs | 103 +++++++++++++++++++++++++++++++++++++++++++- src/error.rs | 4 +- src/lib.rs | 5 +++ src/training/mod.rs | 29 +++++++++++++ 5 files changed, 142 insertions(+), 3 deletions(-) create mode 100644 src/training/mod.rs diff --git a/Cargo.toml b/Cargo.toml index e9c3ccf..d098f5a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,13 +45,15 @@ strip = true codegen-units = 1 [package.metadata.docs.rs] -features = [ "ndarray", "half", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ] +features = [ "ndarray", "half", "training", "operator-libraries", "fetch-models", "load-dynamic", "copy-dylibs" ] targets = ["x86_64-unknown-linux-gnu", "wasm32-unknown-unknown"] rustdoc-args = [ "--cfg", "docsrs" ] [features] default = [ "ndarray", "half", "download-binaries", "copy-dylibs" ] +training = [] + operator-libraries = [ "libc", "winapi" ] fetch-models = [ "ureq" ] diff --git a/ort-sys/src/lib.rs b/ort-sys/src/lib.rs index ee08e63..215a866 100644 --- a/ort-sys/src/lib.rs +++ b/ort-sys/src/lib.rs @@ -821,9 +821,110 @@ fn bindgen_test_layout_OrtOpenVINOProviderOptions() { } #[repr(C)] #[derive(Debug, Copy, Clone)] -pub struct OrtTrainingApi { +pub struct OrtTrainingSession { + _unused: [u8; 0] +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct OrtCheckpointState { _unused: [u8; 0] } +#[repr(i32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum OrtPropertyType { + OrtIntProperty = 0, + OrtFloatProperty = 1, + OrtStringProperty = 2 +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct OrtTrainingApi { + pub LoadCheckpoint: ::std::option::Option< + _system!(unsafe fn(checkpoint_path: *const ::std::os::raw::c_char, checkpoint_state: *mut *mut OrtCheckpointState) -> OrtStatusPtr) + >, + pub SaveCheckpoint: ::std::option::Option< + _system!( + unsafe fn(checkpoint_state: *mut OrtCheckpointState, checkpoint_path: *const ::std::os::raw::c_char, include_optimizer_state: bool) -> OrtStatusPtr + ) + >, + pub CreateTrainingSession: ::std::option::Option< + _system!( + unsafe fn( + env: *const OrtEnv, + options: *const OrtSessionOptions, + checkpoint_state: *mut OrtCheckpointState, + train_model_path: *const ortchar, + eval_model_path: *const ortchar, + optimizer_model_path: *const ortchar, + out: *mut *mut OrtTrainingSession + ) -> OrtStatusPtr + ) + >, + pub CreateTrainingSessionFromBuffer: ::std::option::Option< + _system!( + unsafe fn( + env: *const OrtEnv, + options: *const OrtSessionOptions, + checkpoint_state: *mut OrtCheckpointState, + train_model_data: *const (), + train_data_length: size_t, + eval_model_data: *const (), + eval_data_length: size_t, + optimizer_model_data: *const (), + optimizer_data_length: size_t, + out: *mut *mut OrtTrainingSession + ) -> OrtStatusPtr + ) + >, + pub TrainingSessionGetTrainingModelOutputCount: + ::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>, + pub TrainingSessionGetEvalModelOutputCount: ::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>, + pub TrainingSessionGetTrainingModelOutputName: ::std::option::Option< + _system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut char) -> OrtStatusPtr) + >, + pub TrainingSessionGetEvalModelOutputName: ::std::option::Option< + _system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut char) -> OrtStatusPtr) + >, + pub LazyResetGrad: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession) -> OrtStatusPtr)>, + pub TrainStep: ::std::option::Option< + _system!( + unsafe fn( + session: *mut OrtTrainingSession, + run_options: *const OrtRunOptions, + inputs_len: size_t, + inputs: *const *const OrtValue, + outputs_len: size_t, + outputs: *mut *mut OrtValue + ) -> OrtStatusPtr + ) + >, + pub EvalStep: ::std::option::Option< + _system!( + unsafe fn( + session: *mut OrtTrainingSession, + run_options: *const OrtRunOptions, + inputs_len: size_t, + inputs: *const *const OrtValue, + outputs_len: size_t, + outputs: *mut *mut OrtValue + ) -> OrtStatusPtr + ) + >, + pub SetLearningRate: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, learning_rate: f32) -> OrtStatusPtr)>, + pub GetLearningRate: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, learning_rate: *mut f32) -> OrtStatusPtr)>, + pub OptimizerStep: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, run_options: *const OrtRunOptions) -> OrtStatusPtr)>, + pub RegisterLinearLRScheduler: ::std::option::Option< + _system!(unsafe fn(session: *mut OrtTrainingSession, warmup_step_count: i64, total_step_count: i64, initial_lr: f32) -> OrtStatusPtr) + >, + pub SchedulerStep: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession) -> OrtStatusPtr)>, + pub GetParametersSize: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, out: *mut size_t, trainable_only: bool) -> OrtStatusPtr)>, + pub CopyParametersToBuffer: + ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, parameters_buffer: *mut OrtValue, trainable_only: bool) -> OrtStatusPtr)>, + pub CopyBufferToParameters: + ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, parameters_buffer: *mut OrtValue, trainable_only: bool) -> OrtStatusPtr)>, + pub ReleaseTrainingSession: ::std::option::Option<_system!(unsafe fn(input: *mut OrtTrainingSession))>, + pub ReleaseCheckpointState: ::std::option::Option<_system!(unsafe fn(input: *mut OrtCheckpointState))> +} #[doc = " \\brief The helper interface to get the right version of OrtApi\n\n Get a pointer to this structure through ::OrtGetApiBase"] #[repr(C)] #[derive(Debug, Copy, Clone)] diff --git a/src/error.rs b/src/error.rs index 0737520..ee193c4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -267,7 +267,9 @@ pub enum Error { #[error("Could't get `AllocatorType` from memory info: {0}")] GetAllocatorType(ErrorInternal), #[error("Could't get device ID from memory info: {0}")] - GetDeviceId(ErrorInternal) + GetDeviceId(ErrorInternal), + #[error("Training API is not enabled in this build of ONNX Runtime.")] + TrainingNotEnabled } impl From for Error { diff --git a/src/lib.rs b/src/lib.rs index ed5a2ad..b69f54e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,8 @@ pub(crate) mod metadata; pub(crate) mod operator; pub(crate) mod session; pub(crate) mod tensor; +#[cfg(feature = "training")] +pub(crate) mod training; pub(crate) mod value; #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))] #[cfg(target_arch = "wasm32")] @@ -66,6 +68,9 @@ pub use self::session::{ #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] pub use self::tensor::ArrayExtensions; pub use self::tensor::{IntoTensorElementType, TensorElementType}; +#[cfg(feature = "training")] +#[cfg_attr(docsrs, doc(cfg(feature = "training")))] +pub use self::training::*; pub use self::value::{ DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, diff --git a/src/training/mod.rs b/src/training/mod.rs new file mode 100644 index 0000000..719c955 --- /dev/null +++ b/src/training/mod.rs @@ -0,0 +1,29 @@ +use std::{ + ptr::NonNull, + sync::{ + atomic::{AtomicPtr, Ordering}, + OnceLock + } +}; + +use crate::{ortsys, Error, Result}; + +pub(crate) static TRAINING_API: OnceLock> = OnceLock::new(); + +/// Returns a pointer to the global [`ort_sys::OrtTrainingApi`] object, or errors if the Training API is not enabled. +/// +/// # Panics +/// May panic if: +/// - Getting the `OrtApi` struct fails, due to `ort` loading an unsupported version of ONNX Runtime. +/// - Loading the ONNX Runtime dynamic library fails if the `load-dynamic` feature is enabled. +pub fn training_api() -> Result> { + NonNull::new( + TRAINING_API + .get_or_init(|| { + let training_api = ortsys![unsafe GetTrainingApi(ort_sys::ORT_API_VERSION)]; + AtomicPtr::new(training_api.cast_mut()) + }) + .load(Ordering::Relaxed) + ) + .ok_or(Error::TrainingNotEnabled) +} From d46d5198ec600a57f8ab7363d08206b498778e1f Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sat, 18 May 2024 20:48:49 -0500 Subject: [PATCH 02/12] wip --- .gitignore | 4 + ort-sys/src/lib.rs | 9 +-- src/lib.rs | 1 + src/session/builder.rs | 15 +--- src/training/mod.rs | 116 ++++++++++++++++++++++++++++- src/util.rs | 29 ++++++++ tools/requirements.txt | 4 + tools/train-data/mini-clm.py | 139 +++++++++++++++++++++++++++++++++++ 8 files changed, 295 insertions(+), 22 deletions(-) create mode 100644 src/util.rs create mode 100644 tools/requirements.txt create mode 100644 tools/train-data/mini-clm.py diff --git a/.gitignore b/.gitignore index bf1af90..c85c507 100644 --- a/.gitignore +++ b/.gitignore @@ -186,6 +186,7 @@ WixTools/ # ONNX Runtime downloaded models **/*.onnx **/*.ort +**/*.pbseq !examples/webassembly/**/*.ort !tests/data/*.onnx !tests/data/*.ort @@ -195,3 +196,6 @@ WixTools/ # Glassbench results /glassbench*.db + +# Python virtual environment +.venv diff --git a/ort-sys/src/lib.rs b/ort-sys/src/lib.rs index 215a866..dbd47d9 100644 --- a/ort-sys/src/lib.rs +++ b/ort-sys/src/lib.rs @@ -839,13 +839,10 @@ pub enum OrtPropertyType { #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct OrtTrainingApi { - pub LoadCheckpoint: ::std::option::Option< - _system!(unsafe fn(checkpoint_path: *const ::std::os::raw::c_char, checkpoint_state: *mut *mut OrtCheckpointState) -> OrtStatusPtr) - >, + pub LoadCheckpoint: + ::std::option::Option<_system!(unsafe fn(checkpoint_path: *const ortchar, checkpoint_state: *mut *mut OrtCheckpointState) -> OrtStatusPtr)>, pub SaveCheckpoint: ::std::option::Option< - _system!( - unsafe fn(checkpoint_state: *mut OrtCheckpointState, checkpoint_path: *const ::std::os::raw::c_char, include_optimizer_state: bool) -> OrtStatusPtr - ) + _system!(unsafe fn(checkpoint_state: *mut OrtCheckpointState, checkpoint_path: *const ortchar, include_optimizer_state: bool) -> OrtStatusPtr) >, pub CreateTrainingSession: ::std::option::Option< _system!( diff --git a/src/lib.rs b/src/lib.rs index b69f54e..9495cf5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,6 +25,7 @@ pub(crate) mod session; pub(crate) mod tensor; #[cfg(feature = "training")] pub(crate) mod training; +pub(crate) mod util; pub(crate) mod value; #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))] #[cfg(target_arch = "wasm32")] diff --git a/src/session/builder.rs b/src/session/builder.rs index 60f716e..e43c961 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -313,20 +313,7 @@ impl SessionBuilder { }); } - // Build an OsString, then a vector of bytes to pass to C - let model_path = std::ffi::OsString::from(model_filepath); - #[cfg(target_family = "windows")] - let model_path: Vec = model_path - .encode_wide() - .chain(std::iter::once(0)) // Make sure we have a null terminated string - .collect(); - #[cfg(not(target_family = "windows"))] - let model_path: Vec = model_path - .as_encoded_bytes() - .iter() - .chain(std::iter::once(&b'\0')) // Make sure we have a null terminated string - .map(|b| *b as std::os::raw::c_char) - .collect(); + let model_path = crate::util::path_to_os_char(model_filepath); let env = get_environment()?; apply_execution_providers(&self, env.execution_providers.iter().cloned()); diff --git a/src/training/mod.rs b/src/training/mod.rs index 719c955..8f78420 100644 --- a/src/training/mod.rs +++ b/src/training/mod.rs @@ -1,12 +1,13 @@ use std::{ - ptr::NonNull, + path::Path, + ptr::{self, NonNull}, sync::{ atomic::{AtomicPtr, Ordering}, OnceLock } }; -use crate::{ortsys, Error, Result}; +use crate::{ortsys, Error, Result, SessionBuilder}; pub(crate) static TRAINING_API: OnceLock> = OnceLock::new(); @@ -27,3 +28,114 @@ pub fn training_api() -> Result> { ) .ok_or(Error::TrainingNotEnabled) } + +macro_rules! trainsys { + ($method:ident) => { + $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null"))) + }; + (unsafe $method:ident) => { + unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null"))) } + }; + ($method:ident($($n:expr),+ $(,)?)) => { + $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) + }; + (unsafe $method:ident($($n:expr),+ $(,)?)) => { + unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) } + }; + ($method:ident($($n:expr),+ $(,)?).expect($e:expr)) => { + $crate::error::status_to_result($crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).expect($e) + }; + (unsafe $method:ident($($n:expr),+ $(,)?).expect($e:expr)) => { + $crate::error::status_to_result(unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).expect($e) + }; + ($method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => { + $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+); + $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + }; + (unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {{ + let _x = unsafe { $crate::training_api().unwrap().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }; + $($crate::error::assert_non_null_pointer($check, stringify!($method)).unwrap();)+ + _x + }}; + ($method:ident($($n:expr),+ $(,)?) -> $err:expr$(;)?) => { + $crate::error::status_to_result($crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).map_err($err)?; + }; + (unsafe $method:ident($($n:expr),+ $(,)?) -> $err:expr$(;)?) => { + $crate::error::status_to_result(unsafe { $crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).map_err($err)?; + }; + ($method:ident($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => { + $crate::error::status_to_result($crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).map_err($err)?; + $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + }; + (unsafe $method:ident($($n:expr),+ $(,)?) -> $err:expr; nonNull($($check:expr),+ $(,)?)$(;)?) => {{ + $crate::error::status_to_result(unsafe { $crate::training_api()?.as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).map_err($err)?; + $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ + }}; +} + +#[derive(Debug)] +pub struct Checkpoint { + pub(crate) ptr: NonNull +} + +impl Checkpoint { + pub fn load(path: impl AsRef) -> Result { + let path = crate::util::path_to_os_char(path); + let mut ptr: *mut ort_sys::OrtCheckpointState = ptr::null_mut(); + trainsys![unsafe LoadCheckpoint(path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)]; + Ok(Checkpoint { + ptr: unsafe { NonNull::new_unchecked(ptr) } + }) + } + + pub fn save(&self, path: impl AsRef, include_optimizer_state: bool) -> Result<()> { + let path = crate::util::path_to_os_char(path); + trainsys![unsafe SaveCheckpoint(self.ptr.as_ptr(), path.as_ptr(), include_optimizer_state) -> Error::CreateSession]; + Ok(()) + } +} + +impl Drop for Checkpoint { + fn drop(&mut self) { + tracing::trace!("dropping checkpoint"); + trainsys![unsafe ReleaseCheckpointState(self.ptr.as_ptr())]; + } +} + +#[derive(Debug)] +pub struct Trainer { + pub(crate) ptr: NonNull, + ckpt: Checkpoint +} + +impl Trainer { + pub fn new( + session_options: SessionBuilder, + ckpt: Checkpoint, + training_model_path: impl AsRef, + eval_model_path: impl AsRef, + optimizer_model_path: impl AsRef + ) -> Result { + let training_model_path = crate::util::path_to_os_char(training_model_path); + let eval_model_path = crate::util::path_to_os_char(eval_model_path); + let optimizer_model_path = crate::util::path_to_os_char(optimizer_model_path); + let mut ptr: *mut ort_sys::OrtTrainingSession = ptr::null_mut(); + let env = crate::get_environment()?; + trainsys![unsafe CreateTrainingSession(env.ptr(), session_options.session_options_ptr.as_ptr(), ckpt.ptr.as_ptr(), training_model_path.as_ptr(), eval_model_path.as_ptr(), optimizer_model_path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)]; + Ok(Self { + ptr: unsafe { NonNull::new_unchecked(ptr) }, + ckpt + }) + } + + pub fn ckpt(&self) -> &Checkpoint { + &self.ckpt + } +} + +impl Drop for Trainer { + fn drop(&mut self) { + tracing::trace!("dropping trainer"); + trainsys![unsafe ReleaseTrainingSession(self.ptr.as_ptr())]; + } +} diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..9fa4c0d --- /dev/null +++ b/src/util.rs @@ -0,0 +1,29 @@ +#[cfg(not(target_family = "windows"))] +use std::os::raw::c_char; +#[cfg(unix)] +use std::os::unix::ffi::OsStrExt; +#[cfg(target_family = "windows")] +use std::os::windows::ffi::OsStrExt; +use std::{ffi::OsString, path::Path}; + +#[cfg(target_family = "windows")] +type OsCharArray = Vec; +#[cfg(not(target_family = "windows"))] +type OsCharArray = Vec; + +pub fn path_to_os_char(path: impl AsRef) -> Vec { + let model_path = OsString::from(path.as_ref()); + #[cfg(target_family = "windows")] + let model_path: Vec = model_path + .encode_wide() + .chain(std::iter::once(0)) + .collect(); + #[cfg(not(target_family = "windows"))] + let model_path: Vec = model_path + .as_encoded_bytes() + .iter() + .chain(std::iter::once(&b'\0')) + .map(|b| *b as c_char) + .collect(); + model_path +} diff --git a/tools/requirements.txt b/tools/requirements.txt new file mode 100644 index 0000000..1d17d54 --- /dev/null +++ b/tools/requirements.txt @@ -0,0 +1,4 @@ +torch~=2.3 +torch-ort~=1.17 +onnx~=1.16 +onnxruntime~=1.17 diff --git a/tools/train-data/mini-clm.py b/tools/train-data/mini-clm.py new file mode 100644 index 0000000..e82d934 --- /dev/null +++ b/tools/train-data/mini-clm.py @@ -0,0 +1,139 @@ +import math + +import onnx +from onnxruntime.training import artifacts +import torch +from torch import nn, Tensor +from torch.nn import functional as F + +class RMSNorm(nn.Module): + def __init__(self, dim: int, *, eps: float = 1e-6): + super().__init__() + + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + if x.dtype != torch.float32: + xf = x.to(dtype=torch.float32) + else: + xf = x + output = (xf * torch.sqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)) + if x.dtype != torch.float32: + output = output.to(dtype=x.dtype) + return output * self.weight + +class RoPE(nn.Module): + def __init__(self, embedding_dim: int, *, max_seq_length: int = 2048, base: float = 10000.0): + super().__init__() + + pe = torch.zeros(max_seq_length, embedding_dim) + position = torch.arange(0, max_seq_length, dtype=torch.float32).unsqueeze(1) + div_term = torch.exp(torch.arange(0, embedding_dim, step=2).float() * (-math.log(base) / embedding_dim)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe, persistent=False) + + @torch.no_grad() + def forward(self, x: Tensor) -> Tensor: + return x + self.pe[:, :x.shape[1], :] + +class Attention(nn.Module): + def __init__(self, embedding_dim: int, *, rope: RoPE, max_seq_length: int = 2048, n_heads: int = 4): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_heads = n_heads + self.qkv = nn.Linear(embedding_dim, embedding_dim * 3, bias=False) + self.proj = nn.Linear(embedding_dim, embedding_dim, bias=False) + self.rope = rope + self.register_buffer('bias', torch.tril(torch.ones(max_seq_length, max_seq_length))[None, None, :, :], persistent=False) + + def forward(self, x: Tensor) -> Tensor: + b, t, c = x.size() + + x = self.rope(x) + + q, k, v = self.qkv(x).split(self.embedding_dim, dim=2) + q = q.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2) + k = k.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2) + v = v.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2) + + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :t, :t] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + y = att @ v + y = y.transpose(1, 2).contiguous().view(b, t, c) + + return self.proj(y) + +class FFN(nn.Module): + def __init__(self, embedding_dim: int, intermediate_dim: int | None = None): + super().__init__() + + intermediate_dim = intermediate_dim or embedding_dim * 4 + + self.w1 = nn.Linear(embedding_dim, intermediate_dim * 2, bias=False) + self.w2 = nn.Linear(intermediate_dim, embedding_dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + x, gate = self.w1(x).chunk(2, dim=-1) + return self.w2(F.gelu(gate) * x) + +class Layer(nn.Module): + def __init__(self, embedding_dim: int, rope: RoPE): + super().__init__() + + self.attn = Attention(embedding_dim, rope=rope) + self.norm1 = RMSNorm(embedding_dim) + self.ffn = FFN(embedding_dim) + self.norm2 = RMSNorm(embedding_dim) + + def forward(self, x: Tensor) -> Tensor: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + return x + +class CLM(nn.Module): + def __init__(self, embedding_dim: int, n_layers: int, *, vocab_size: int): + super().__init__() + + rope = RoPE(embedding_dim) + self.layers = nn.ModuleList([Layer(embedding_dim, rope=rope) for _ in range(n_layers)]) + self.word_embeddings = nn.Embedding(vocab_size, embedding_dim) + self.norm = RMSNorm(embedding_dim) + self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=False) + + def forward(self, x: Tensor) -> Tensor: + x = self.word_embeddings(x) + for layer in self.layers: + x = layer(x) + return self.lm_head(self.norm(x)) + +lm = CLM(256, 4, vocab_size=32000) +torch.onnx.export( + lm, + torch.randint(0, 32000, (1, 64)), + f'tools/train-data/mini-clm/model.onnx', + input_names=['input_ids'], + output_names=['probs'], + dynamic_axes={ + 'input_ids': {0: 'batch', 1: 'seq'}, + 'probs': {0: 'batch', 1: 'seq'} + }, + opset_version=14 +) + +onnx_model = onnx.load('tools/train-data/mini-clm/model.onnx') +requires_grad = [param.name for param in onnx_model.graph.initializer] + +artifacts.generate_artifacts( + onnx_model, + requires_grad=requires_grad, + frozen_params=[], + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory='tools/train-data/mini-clm' +) From 53e69b497b0f42de5d0507bd846fd059a0671acf Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 19 May 2024 21:15:13 -0500 Subject: [PATCH 03/12] training working --- .gitignore | 5 +- Cargo.toml | 1 + examples/training/Cargo.toml | 16 +++ examples/training/build.rs | 5 + examples/training/examples/train-clm.rs | 31 ++++++ ort-sys/src/lib.rs | 4 +- src/training/mod.rs | 125 +++++++++++++++++++++++- tools/train-data/mini-clm.py | 5 +- 8 files changed, 182 insertions(+), 10 deletions(-) create mode 100644 examples/training/Cargo.toml create mode 100644 examples/training/build.rs create mode 100644 examples/training/examples/train-clm.rs diff --git a/.gitignore b/.gitignore index c85c507..6ab7181 100644 --- a/.gitignore +++ b/.gitignore @@ -198,4 +198,7 @@ WixTools/ /glassbench*.db # Python virtual environment -.venv +.venv* + +# Training checkpoints +tools/train-data/**/checkpoint diff --git a/Cargo.toml b/Cargo.toml index d098f5a..872f431 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ 'examples/model-info', 'examples/yolov8', 'examples/modnet', + 'examples/training', 'examples/webassembly' ] default-members = [ diff --git a/examples/training/Cargo.toml b/examples/training/Cargo.toml new file mode 100644 index 0000000..acacc21 --- /dev/null +++ b/examples/training/Cargo.toml @@ -0,0 +1,16 @@ +[package] +publish = false +name = "example-training" +version = "0.0.0" +edition = "2021" + +[dependencies] +ort = { path = "../../", features = [ "training" ] } +ndarray = "0.15" +tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] } +rand = "0.8" +tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] } + +[features] +load-dynamic = [ "ort/load-dynamic" ] +cuda = [ "ort/cuda" ] diff --git a/examples/training/build.rs b/examples/training/build.rs new file mode 100644 index 0000000..79d3a0b --- /dev/null +++ b/examples/training/build.rs @@ -0,0 +1,5 @@ +fn main() { + // Need this for CoreML. See: https://ort.pyke.io/perf/execution-providers#coreml + #[cfg(target_os = "macos")] + println!("cargo:rustc-link-arg=-fapple-link-rtlib"); +} diff --git a/examples/training/examples/train-clm.rs b/examples/training/examples/train-clm.rs new file mode 100644 index 0000000..1956dac --- /dev/null +++ b/examples/training/examples/train-clm.rs @@ -0,0 +1,31 @@ +use ndarray::{ArrayView0, Array1, Array2}; +use ort::{Allocator, Checkpoint, SessionBuilder, Trainer}; + +fn main() -> ort::Result<()> { + ort::init().commit()?; + + let trainer = Trainer::new( + SessionBuilder::new()?, + Allocator::default(), + Checkpoint::load("tools/train-data/mini-clm/checkpoint")?, + "tools/train-data/mini-clm/training_model.onnx", + "tools/train-data/mini-clm/eval_model.onnx", + "tools/train-data/mini-clm/optimizer_model.onnx" + )?; + + let optimizer = trainer.optimizer(); + optimizer.set_lr(1e-4)?; + + let inputs = Array2::::from_shape_vec([1, 5], vec![0, 1, 2, 3, 4]).unwrap(); + let labels = Array1::::from_shape_vec([5], vec![1, 2, 3, 4, 5]).unwrap(); + + for _ in 0..50 { + let outputs = trainer.step(ort::inputs![inputs.view()]?, ort::inputs![labels.view()]?)?; + let loss: ArrayView0 = outputs[0].try_extract_tensor::()?.into_dimensionality().unwrap(); + println!("{}", loss.into_scalar()); + optimizer.step()?; + optimizer.reset_grad()?; + } + + Ok(()) +} diff --git a/ort-sys/src/lib.rs b/ort-sys/src/lib.rs index dbd47d9..b40edb2 100644 --- a/ort-sys/src/lib.rs +++ b/ort-sys/src/lib.rs @@ -877,10 +877,10 @@ pub struct OrtTrainingApi { ::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>, pub TrainingSessionGetEvalModelOutputCount: ::std::option::Option<_system!(unsafe fn(sess: *const OrtTrainingSession, out: *mut size_t) -> OrtStatusPtr)>, pub TrainingSessionGetTrainingModelOutputName: ::std::option::Option< - _system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut char) -> OrtStatusPtr) + _system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr) >, pub TrainingSessionGetEvalModelOutputName: ::std::option::Option< - _system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut char) -> OrtStatusPtr) + _system!(unsafe fn(sess: *const OrtTrainingSession, index: size_t, allocator: *mut OrtAllocator, output: *mut *mut c_char) -> OrtStatusPtr) >, pub LazyResetGrad: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession) -> OrtStatusPtr)>, pub TrainStep: ::std::option::Option< diff --git a/src/training/mod.rs b/src/training/mod.rs index 8f78420..93aa5da 100644 --- a/src/training/mod.rs +++ b/src/training/mod.rs @@ -3,11 +3,13 @@ use std::{ ptr::{self, NonNull}, sync::{ atomic::{AtomicPtr, Ordering}, - OnceLock + Arc, OnceLock } }; -use crate::{ortsys, Error, Result, SessionBuilder}; +use ort_sys::c_char; + +use crate::{char_p_to_string, ortsys, Allocator, Error, Result, RunOptions, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, Value}; pub(crate) static TRAINING_API: OnceLock> = OnceLock::new(); @@ -102,15 +104,49 @@ impl Drop for Checkpoint { } } +#[derive(Debug)] +pub struct Optimizer(NonNull); + +impl Optimizer { + pub fn reset_grad(&self) -> Result<()> { + trainsys![unsafe LazyResetGrad(self.0.as_ptr()) -> Error::CreateSession]; + Ok(()) + } + + pub fn lr(&self) -> Result { + let mut lr = f32::NAN; + trainsys![unsafe GetLearningRate(self.0.as_ptr(), &mut lr) -> Error::CreateSession]; + Ok(lr) + } + + pub fn set_lr(&self, lr: f32) -> Result<()> { + trainsys![unsafe SetLearningRate(self.0.as_ptr(), lr) -> Error::CreateSession]; + Ok(()) + } + + pub fn step(&self) -> Result<()> { + self.step_with_options(RunOptions::new()?) + } + + pub fn step_with_options(&self, options: RunOptions) -> Result<()> { + trainsys![unsafe OptimizerStep(self.0.as_ptr(), options.run_options_ptr.as_ptr()) -> Error::CreateSession]; + Ok(()) + } +} + #[derive(Debug)] pub struct Trainer { pub(crate) ptr: NonNull, - ckpt: Checkpoint + train_output_names: Vec, + optimizer: Optimizer, + ckpt: Checkpoint, + _allocator: Allocator } impl Trainer { pub fn new( session_options: SessionBuilder, + allocator: Allocator, ckpt: Checkpoint, training_model_path: impl AsRef, eval_model_path: impl AsRef, @@ -119,15 +155,94 @@ impl Trainer { let training_model_path = crate::util::path_to_os_char(training_model_path); let eval_model_path = crate::util::path_to_os_char(eval_model_path); let optimizer_model_path = crate::util::path_to_os_char(optimizer_model_path); - let mut ptr: *mut ort_sys::OrtTrainingSession = ptr::null_mut(); + let env = crate::get_environment()?; + + let mut ptr: *mut ort_sys::OrtTrainingSession = ptr::null_mut(); trainsys![unsafe CreateTrainingSession(env.ptr(), session_options.session_options_ptr.as_ptr(), ckpt.ptr.as_ptr(), training_model_path.as_ptr(), eval_model_path.as_ptr(), optimizer_model_path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)]; + + let ptr = unsafe { NonNull::new_unchecked(ptr) }; + + let mut train_output_len = 0; + trainsys![unsafe TrainingSessionGetTrainingModelOutputCount(ptr.as_ptr(), &mut train_output_len) -> Error::CreateSession]; + let train_output_names = (0..train_output_len) + .map(|i| { + let mut name_bytes: *mut c_char = std::ptr::null_mut(); + trainsys![unsafe TrainingSessionGetTrainingModelOutputName(ptr.as_ptr(), i, allocator.ptr.as_ptr(), &mut name_bytes) -> Error::CreateSession]; + let name = match char_p_to_string(name_bytes) { + Ok(name) => name, + Err(e) => { + unsafe { allocator.free(name_bytes) }; + return Err(e); + } + }; + unsafe { allocator.free(name_bytes) }; + Ok(name) + }) + .collect::>>()?; + Ok(Self { - ptr: unsafe { NonNull::new_unchecked(ptr) }, + ptr, + _allocator: allocator, + train_output_names, + optimizer: Optimizer(ptr), ckpt }) } + pub fn step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( + &'s self, + inputs: impl Into>, + labels: impl Into> + ) -> Result> { + match inputs.into() { + SessionInputs::ValueSlice(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueArray(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + } + } + + fn step_inner<'s, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + &'s self, + input_values: impl Iterator>, + run_options: Option> + ) -> Result> { + let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; + + let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); + + let run_options_ptr = if let Some(run_options) = &run_options { + run_options.run_options_ptr.as_ptr() + } else { + std::ptr::null_mut() + }; + + trainsys![unsafe TrainStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun]; + + let outputs: Vec = output_tensor_ptrs + .into_iter() + .map(|tensor_ptr| unsafe { + // TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`. + // but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣 + Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None) + }) + .collect(); + + Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) + } + + pub fn optimizer(&self) -> &Optimizer { + &self.optimizer + } + pub fn ckpt(&self) -> &Checkpoint { &self.ckpt } diff --git a/tools/train-data/mini-clm.py b/tools/train-data/mini-clm.py index e82d934..d7a7991 100644 --- a/tools/train-data/mini-clm.py +++ b/tools/train-data/mini-clm.py @@ -110,7 +110,8 @@ def forward(self, x: Tensor) -> Tensor: x = self.word_embeddings(x) for layer in self.layers: x = layer(x) - return self.lm_head(self.norm(x)) + logits = self.lm_head(self.norm(x)) + return logits.view(-1, logits.size(-1)) lm = CLM(256, 4, vocab_size=32000) torch.onnx.export( @@ -121,7 +122,7 @@ def forward(self, x: Tensor) -> Tensor: output_names=['probs'], dynamic_axes={ 'input_ids': {0: 'batch', 1: 'seq'}, - 'probs': {0: 'batch', 1: 'seq'} + 'probs': {0: 'batch_seq'} }, opset_version=14 ) From 2aaadda7ea9f69938b3705b9e6eface9538a93ba Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 19 May 2024 23:11:10 -0500 Subject: [PATCH 04/12] fix non-windows platforms --- src/session/builder.rs | 4 ---- src/util.rs | 7 ++----- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/session/builder.rs b/src/session/builder.rs index e43c961..90580ec 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -1,7 +1,3 @@ -#[cfg(unix)] -use std::os::unix::ffi::OsStrExt; -#[cfg(target_family = "windows")] -use std::os::windows::ffi::OsStrExt; #[cfg(feature = "fetch-models")] use std::path::PathBuf; use std::{ diff --git a/src/util.rs b/src/util.rs index 9fa4c0d..bfa11d9 100644 --- a/src/util.rs +++ b/src/util.rs @@ -11,13 +11,10 @@ type OsCharArray = Vec; #[cfg(not(target_family = "windows"))] type OsCharArray = Vec; -pub fn path_to_os_char(path: impl AsRef) -> Vec { +pub fn path_to_os_char(path: impl AsRef) -> OsCharArray { let model_path = OsString::from(path.as_ref()); #[cfg(target_family = "windows")] - let model_path: Vec = model_path - .encode_wide() - .chain(std::iter::once(0)) - .collect(); + let model_path: Vec = model_path.encode_wide().chain(std::iter::once(0)).collect(); #[cfg(not(target_family = "windows"))] let model_path: Vec = model_path .as_encoded_bytes() From f778d7afb95d0ff5a59cbe993267ad7137d1ccaa Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Wed, 29 May 2024 16:27:21 -0500 Subject: [PATCH 05/12] add training dists --- Cargo.toml | 2 +- examples/training/examples/train-clm.rs | 6 +++-- ort-sys/Cargo.toml | 2 +- ort-sys/build.rs | 32 +++++++++++++++---------- ort-sys/dist.txt | 17 +++++++++++++ tools/requirements.txt | 2 +- 6 files changed, 44 insertions(+), 17 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 872f431..c1a70bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,7 +53,7 @@ rustdoc-args = [ "--cfg", "docsrs" ] [features] default = [ "ndarray", "half", "download-binaries", "copy-dylibs" ] -training = [] +training = [ "ort-sys/training" ] operator-libraries = [ "libc", "winapi" ] diff --git a/examples/training/examples/train-clm.rs b/examples/training/examples/train-clm.rs index 1956dac..4d3c9ab 100644 --- a/examples/training/examples/train-clm.rs +++ b/examples/training/examples/train-clm.rs @@ -1,11 +1,13 @@ use ndarray::{ArrayView0, Array1, Array2}; -use ort::{Allocator, Checkpoint, SessionBuilder, Trainer}; +use ort::{Allocator, Checkpoint, CUDAExecutionProvider, SessionBuilder, Trainer}; fn main() -> ort::Result<()> { + tracing_subscriber::fmt::init(); + ort::init().commit()?; let trainer = Trainer::new( - SessionBuilder::new()?, + SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?, Allocator::default(), Checkpoint::load("tools/train-data/mini-clm/checkpoint")?, "tools/train-data/mini-clm/training_model.onnx", diff --git a/ort-sys/Cargo.toml b/ort-sys/Cargo.toml index 9a6a3cf..abe657f 100644 --- a/ort-sys/Cargo.toml +++ b/ort-sys/Cargo.toml @@ -16,6 +16,7 @@ include = [ "src/", "dist.txt", "build.rs", "LICENSE-APACHE", "LICENSE-MIT" ] [features] default = [] +training = [] download-binaries = [ "ureq", "tar", "flate2", "sha2" ] load-dynamic = [] copy-dylibs = [] @@ -38,7 +39,6 @@ vitis = [] cann = [] qnn = [] - [build-dependencies] ureq = { version = "2.1", optional = true, default-features = false, features = [ "tls" ] } tar = { version = "0.4", optional = true } diff --git a/ort-sys/build.rs b/ort-sys/build.rs index 440023c..29f1920 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -37,12 +37,12 @@ fn fetch_file(source_url: &str) -> Vec { buffer } -fn find_dist(target: &str, designator: &str) -> Option<(&'static str, &'static str)> { +fn find_dist(target: &str, feature_set: &str) -> Option<(&'static str, &'static str)> { DIST_TABLE .split('\n') .filter(|c| !c.is_empty() && !c.starts_with('#')) .map(|c| c.split('\t').collect::>()) - .find(|c| c[0] == designator && c[1] == target) + .find(|c| c[0] == feature_set && c[1] == target) .map(|c| (c[2], c[3])) } @@ -317,23 +317,31 @@ fn prepare_libort_dir() -> (PathBuf, bool) { #[cfg(feature = "download-binaries")] { let target = env::var("TARGET").unwrap().to_string(); - let designator = if cfg!(any(feature = "cuda", feature = "tensorrt")) { - if lib_exists("cudart64_12.dll") || lib_exists("libcudart.so.12") { "cu12" } else { "cu11" } + + let mut feature_set = Vec::new(); + if cfg!(feature = "training") { + feature_set.push("train"); + } + if cfg!(any(feature = "cuda", feature = "tensorrt")) { + if lib_exists("cudart64_11.dll") || lib_exists("libcudart.so.11") || env::var("ORT_DFBIN_FORCE_CUDA_VERSION").as_deref() == Ok("11") { + feature_set.push("cu11"); + } else { + feature_set.push("cu12"); + } } else if cfg!(feature = "rocm") { - "rocm" - } else { - "none" - }; - let mut dist = find_dist(&target, designator); - if dist.is_none() && designator != "none" { + feature_set.push("rocm"); + } + let feature_set = if !feature_set.is_empty() { feature_set.join(",") } else { "none".to_owned() }; + let mut dist = find_dist(&target, &feature_set); + if dist.is_none() && feature_set != "none" { dist = find_dist(&target, "none"); } if dist.is_none() { panic!( "downloaded binaries not available for target {target}{}\nyou may have to compile ONNX Runtime from source", - if designator != "none" { - format!(" (note: also requested `{designator}`)") + if feature_set != "none" { + format!(" (note: also requested features `{feature_set}`)") } else { String::new() } diff --git a/ort-sys/dist.txt b/ort-sys/dist.txt index 066a2b5..ffddef9 100644 --- a/ort-sys/dist.txt +++ b/ort-sys/dist.txt @@ -4,12 +4,29 @@ cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/ rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_rocm-v1.18.0-x86_64-unknown-linux-gnu.tgz D6113A895DEB0BCBC28FD7E23A201DE4C5FBA6BADEB49F3190A084A36C24B43D none x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-unknown-linux-gnu.tgz F486F4B9F040FF533DCD6B26E074BEB5F9092E8E4C67F72D08696D9EB4C9C082 +train aarch64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static_train-v1.18.0-aarch64-unknown-linux-gnu.tgz 3E86449F62C3E775CD5B9681FB27A299E9801221D2BF4827FF0E951C649A78BC +train,cu12 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_train,cu12-v1.18.0-x86_64-unknown-linux-gnu.tgz 34B76361A6679D7407ABDE0AA17863B57198D36B662CD0CE8AC783B9272EB4D6 +train,cu11 x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_train,cu11-v1.18.0-x86_64-unknown-linux-gnu.tgz D5DACFBF7F27EFEB887227AA48FFD10527A92090031680648620B2241080A41B +train,rocm x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_train,rocm-v1.18.0-x86_64-unknown-linux-gnu.tgz 467B3B5D05A307CEE73B9A159C479178CAD993F8CB043A7008C4847ADBC32258 +train x86_64-unknown-linux-gnu https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static_train-v1.18.0-x86_64-unknown-linux-gnu.tgz 91ADC91DC9D6E04F1A1C7B97934D6D27ABE8869A1BB329A0BE31F97A2364750D + none aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-aarch64-pc-windows-msvc.tgz 9A1BF23A73D680290B52C22AAD039B490AC5AAA66FC21C06343A41369747B514 cu12 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu12-v1.18.0-x86_64-pc-windows-msvc.tgz A9457AC9AC5D6BE1F98B3BEE3D6AF5C074C9984F7CC7D1E660EA8082EBF65D48 cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_cu11-v1.18.0-x86_64-pc-windows-msvc.tgz C5C62263BDD82B58ED15A6467D0729B21F26E78EA0E49E1E5197ECBA80783903 none x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-pc-windows-msvc.tgz 08A22E94EBA56BF30ECBB2DC9DD9F90A4583C8372BAFC7FE3DAB6C28A06544CE +train aarch64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static_train-v1.18.0-aarch64-pc-windows-msvc.tgz 35E39BD9D874EAAB35A58F21E7EA57F752AB297EFB05D9F9448E86AA9BDB0E31 +train,cu12 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_train,cu12-v1.18.0-x86_64-pc-windows-msvc.tgz BDD34F48A26F9DF4366325B911DB39660D222D71BC8F7EF582E894C48E4791AF +train,cu11 x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_dylib_train,cu11-v1.18.0-x86_64-pc-windows-msvc.tgz 95B987F4384D3EE398E6CE871CB31B2062D82BCBB877657F461AA4B2CAC8BEEB +train x86_64-pc-windows-msvc https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static_train-v1.18.0-x86_64-pc-windows-msvc.tgz EBA42354420E4E87AD77BE81B473597316841C0A3707DF231D4381771E76F1E7 + none aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-aarch64-apple-darwin.tgz F8DB068DFACFE3B00B9F0181B79780C6971CD1A6EAEB9D9A7FC2129CEB8413A5 none x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static-v1.18.0-x86_64-apple-darwin.tgz E6E0457CB9C727DBA818D10245D3A2A29203CB037546B39C217E4CC9FB61ABE8 +train aarch64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static_train-v1.18.0-aarch64-apple-darwin.tgz 793B996C4CAFDCB294ECEA3E602D28F8272164EE0C240EAC9AE8D3160913BFF0 +train x86_64-apple-darwin https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-msort_static_train-v1.18.0-x86_64-apple-darwin.tgz 3A49CBA8FC7388E556F5E16BFD1DAB7A32F4FD1712C53B93013DE5C1C870A2F2 + none wasm32-unknown-unknown https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-pkort_static-v1.18.0-wasm32-unknown-unknown.tgz 8AB76874E977961A1CFA9714973521AA1B85F0F40D31EF38492CCA659BE58BF5 +none wasm32-wasi https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-pkort_static-v1.18.0-wasm32-unknown-unknown.tgz 8AB76874E977961A1CFA9714973521AA1B85F0F40D31EF38492CCA659BE58BF5 +none wasm32-wasi-preview1 https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-pkort_static-v1.18.0-wasm32-unknown-unknown.tgz 8AB76874E977961A1CFA9714973521AA1B85F0F40D31EF38492CCA659BE58BF5 +none wasm32-wasi-preview2 https://parcel.pyke.io/v2/delivery/ortrs/packages/msort-binary/1.18.0/ortrs-pkort_static-v1.18.0-wasm32-unknown-unknown.tgz 8AB76874E977961A1CFA9714973521AA1B85F0F40D31EF38492CCA659BE58BF5 diff --git a/tools/requirements.txt b/tools/requirements.txt index 1d17d54..d49cd91 100644 --- a/tools/requirements.txt +++ b/tools/requirements.txt @@ -1,4 +1,4 @@ torch~=2.3 torch-ort~=1.17 onnx~=1.16 -onnxruntime~=1.17 +--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT/pypi/simple/ onnxruntime-training-cpu==1.18.0 From 71692bd43e790ea2928d1ec59d69bcf21cfa400d Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Wed, 29 May 2024 20:58:17 -0500 Subject: [PATCH 06/12] slightly more useful train-clm example --- examples/training/Cargo.toml | 1 + examples/training/README.md | 32 +++++++ examples/training/examples/pretokenize.rs | 44 +++++++++ examples/training/examples/train-clm.rs | 111 ++++++++++++++++++++-- ort-sys/src/lib.rs | 12 ++- src/training/mod.rs | 84 +++++++++++++++- tools/train-data/mini-clm.py | 4 +- 7 files changed, 276 insertions(+), 12 deletions(-) create mode 100644 examples/training/README.md create mode 100644 examples/training/examples/pretokenize.rs diff --git a/examples/training/Cargo.toml b/examples/training/Cargo.toml index acacc21..5bb177e 100644 --- a/examples/training/Cargo.toml +++ b/examples/training/Cargo.toml @@ -9,6 +9,7 @@ ort = { path = "../../", features = [ "training" ] } ndarray = "0.15" tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] } rand = "0.8" +simd-json = "0.13" tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] } [features] diff --git a/examples/training/README.md b/examples/training/README.md new file mode 100644 index 0000000..64e1724 --- /dev/null +++ b/examples/training/README.md @@ -0,0 +1,32 @@ +# Training Examples + +## `train-clm` +This example trains a tiny causal language model on a small subset of pyke's [**OshiChats v2**](https://huggingface.co/datasets/pykeio/oshichats-v2), a dataset of live text chat messages collected from various [VTuber](https://en.wikipedia.org/wiki/VTuber) live streams. The model is not particularly useful or interesting (due to both the low-quality dataset and small model size), but it showcases that entire language models can be trained from scratch entirely in Rust on (almost) any device. + +To get started, create a Python virtual environment and install the following packages: +``` +pip install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT/pypi/simple/ onnxruntime-training-cpu==1.18.0 onnx~=1.17 torch~=2.3 +``` + +We're installing the CPU version of the `onnxruntime-training` & `torch` packages because we only need to use Python to *create* the initial graph which will be used for training. Run `python tools/train-data/mini-clm.py` from the root directory of the `ort` repo to create the training artifacts. + +Next, we need to convert our dataset into tokens to feed the model. This can be achieved by downloading the `oshicats-v2.jsonl` file from the OshiChats v2 dataset and running `cargo run -p example-training --example pretokenize -- ~/oshichats-v2.jsonl`, or if you (rightfully) don't wish to waste 30 GB worth of disk space and bandwidth on brainrot, you may download a [1 MB pre-tokenized subset of the dataset](https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_data/dataset.bin). Make sure `dataset.bin` is in the root of the `ort` repo. + +Finally, we can train our model! Run `cargo run -p example-training --example train-clm` to start training. If you have an NVIDIA GPU, add `--features cuda` to enable CUDA, though it's not required and you can train directly on CPU instead. **This will use ~8 GB of (V)RAM!** You can lower the memory usage by adjusting the `BATCH_SIZE` and `SEQUENCE_LENGTH` constants in `train-clm.rs`, though note that changing the batch size may require adjustments to the learning rate. + +While training, the program prints the cross-entropy loss at each training step. At the end of training, the final trained model will be saved to `trained-clm.onnx`, and the program will use the model to generate a small snippet of text: +``` +2.5692816 +2.718276 +2.4533236 +2.776122 +2.7698023 +2.5013578 +2.4010067 +2.7219558 +2.8342185 +2.660532 +I'm so much better than the game<|endoftext|>I think you can't see it<|endoftext|>I think you can't see it<|endoftext|>I think so it's a new game<|endoftext|>I think I'm sure you can't see what you can't see it<|endoftext|> +``` + +Not bad, considering the model & dataset size! This example can easily be scaled up to pre-train or fine-tune (both full-parameter and PEFT) larger language models like Llama/Phi, so long as you have enough compute. diff --git a/examples/training/examples/pretokenize.rs b/examples/training/examples/pretokenize.rs new file mode 100644 index 0000000..79eee19 --- /dev/null +++ b/examples/training/examples/pretokenize.rs @@ -0,0 +1,44 @@ +use std::{ + env, + fs::File, + io::{BufRead, BufReader, BufWriter, Write}, + path::Path +}; + +use simd_json::derived::ValueObjectAccessAsScalar; +use tokenizers::Tokenizer; + +const MAX_TOKENS: usize = 500_000; + +fn main() { + let input = env::args().nth(1).expect("provide input jsonl"); + let output = env::args().nth(2).unwrap_or_else(|| "dataset.bin".into()); + + let input = BufReader::new(File::open(input).unwrap()); + let mut output = BufWriter::new(File::create(output).unwrap()); + + let tokenizer = Tokenizer::from_file( + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("gpt2") + .join("data") + .join("tokenizer.json") + ) + .unwrap(); + let mut bytes_written = 0; + + for line in input.lines() { + let line: simd_json::OwnedValue = unsafe { simd_json::from_str(&mut line.unwrap()).unwrap() }; + let tokenized = tokenizer + .encode(format!("<|endoftext|>{}", line.get_str("message").unwrap()), false) + .unwrap(); + let id_bytes: Vec = tokenized.get_ids().iter().flat_map(|c| (*c as u16).to_le_bytes()).collect(); + output.write_all(&id_bytes).unwrap(); + bytes_written += id_bytes.len(); + if bytes_written >= MAX_TOKENS * 2 { + output.flush().unwrap(); + break; + } + } +} diff --git a/examples/training/examples/train-clm.rs b/examples/training/examples/train-clm.rs index 4d3c9ab..9b23c32 100644 --- a/examples/training/examples/train-clm.rs +++ b/examples/training/examples/train-clm.rs @@ -1,5 +1,16 @@ -use ndarray::{ArrayView0, Array1, Array2}; -use ort::{Allocator, Checkpoint, CUDAExecutionProvider, SessionBuilder, Trainer}; +use std::{ + fs::File, + io::{Read, Seek, SeekFrom, Write}, + path::Path +}; + +use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis, Ix0}; +use ort::{Allocator, CUDAExecutionProvider, Checkpoint, Session, SessionBuilder, Trainer}; +use rand::RngCore; +use tokenizers::Tokenizer; + +const BATCH_SIZE: usize = 16; +const SEQUENCE_LENGTH: usize = 256; fn main() -> ort::Result<()> { tracing_subscriber::fmt::init(); @@ -15,19 +26,103 @@ fn main() -> ort::Result<()> { "tools/train-data/mini-clm/optimizer_model.onnx" )?; + let tokenizer = Tokenizer::from_file( + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("gpt2") + .join("data") + .join("tokenizer.json") + ) + .unwrap(); + let optimizer = trainer.optimizer(); - optimizer.set_lr(1e-4)?; + optimizer.set_lr(7e-5)?; - let inputs = Array2::::from_shape_vec([1, 5], vec![0, 1, 2, 3, 4]).unwrap(); - let labels = Array1::::from_shape_vec([5], vec![1, 2, 3, 4, 5]).unwrap(); + let mut dataset = File::open("dataset.bin").unwrap(); + let file_size = dataset.metadata().unwrap().len(); + let num_tokens = (file_size / 2) as usize; // 16-bit tokens + let mut rng = rand::thread_rng(); + + let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + for _ in 0..5000 { + for batch in 0..BATCH_SIZE { + let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64; + dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + } + + let inputs = Array2::::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap(); + let labels = Array1::::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap(); - for _ in 0..50 { let outputs = trainer.step(ort::inputs![inputs.view()]?, ort::inputs![labels.view()]?)?; - let loss: ArrayView0 = outputs[0].try_extract_tensor::()?.into_dimensionality().unwrap(); - println!("{}", loss.into_scalar()); + let loss = outputs[0] + .try_extract_tensor::()? + .into_dimensionality::() + .unwrap() + .into_scalar(); + println!("{}", loss); + if loss.is_nan() { + return Ok(()); + } optimizer.step()?; optimizer.reset_grad()?; } + trainer.export("trained-clm.onnx", ["probs"])?; + + let session = Session::builder()?.commit_from_file("trained-clm.onnx")?; + + let mut stdout = std::io::stdout(); + + let tokens = tokenizer.encode("<|endoftext|>", false).unwrap(); + let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); + + let mut tokens = Array1::from_iter(tokens.iter().cloned()); + + for _ in 0..50 { + let array = tokens.view().insert_axis(Axis(0)); + let outputs = session.run(ort::inputs![array]?)?; + let generated_tokens: ArrayViewD = outputs["probs"].try_extract_tensor()?; + + let probabilities = &mut generated_tokens + .slice(s![-1, ..]) + .to_owned() + .iter() + .cloned() + .enumerate() + .collect::>(); + probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less)); + + let token = probabilities[0].0; + tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]]; + + let token_str = tokenizer.decode(&[token as _], false).unwrap(); + print!("{}", token_str); + stdout.flush().unwrap(); + } + + println!(); Ok(()) } diff --git a/ort-sys/src/lib.rs b/ort-sys/src/lib.rs index f570005..f7cb853 100644 --- a/ort-sys/src/lib.rs +++ b/ort-sys/src/lib.rs @@ -922,7 +922,17 @@ pub struct OrtTrainingApi { pub CopyBufferToParameters: ::std::option::Option<_system!(unsafe fn(session: *mut OrtTrainingSession, parameters_buffer: *mut OrtValue, trainable_only: bool) -> OrtStatusPtr)>, pub ReleaseTrainingSession: ::std::option::Option<_system!(unsafe fn(input: *mut OrtTrainingSession))>, - pub ReleaseCheckpointState: ::std::option::Option<_system!(unsafe fn(input: *mut OrtCheckpointState))> + pub ReleaseCheckpointState: ::std::option::Option<_system!(unsafe fn(input: *mut OrtCheckpointState))>, + pub ExportModelForInferencing: ::std::option::Option< + _system!( + unsafe fn( + session: *mut OrtTrainingSession, + inference_model_path: *const ortchar, + graph_outputs_len: usize, + graph_output_names: *const *const c_char + ) -> OrtStatusPtr + ) + > } #[doc = " \\brief The helper interface to get the right version of OrtApi\n\n Get a pointer to this structure through ::OrtGetApiBase"] #[repr(C)] diff --git a/src/training/mod.rs b/src/training/mod.rs index 93aa5da..a951715 100644 --- a/src/training/mod.rs +++ b/src/training/mod.rs @@ -1,4 +1,5 @@ use std::{ + ffi::CString, path::Path, ptr::{self, NonNull}, sync::{ @@ -9,7 +10,11 @@ use std::{ use ort_sys::c_char; -use crate::{char_p_to_string, ortsys, Allocator, Error, Result, RunOptions, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, Value}; +use crate::{ + char_p_to_string, + error::{assert_non_null_pointer, status_to_result}, + ortsys, Allocator, Error, Result, RunOptions, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, Value +}; pub(crate) static TRAINING_API: OnceLock> = OnceLock::new(); @@ -239,6 +244,83 @@ impl Trainer { Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) } + pub fn eval<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( + &'s self, + inputs: impl Into>, + labels: impl Into> + ) -> Result> { + match inputs.into() { + SessionInputs::ValueSlice(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.eval_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.eval_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueArray(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.eval_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.eval_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + } + } + + fn eval_inner<'s, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + &'s self, + input_values: impl Iterator>, + run_options: Option> + ) -> Result> { + let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; + + let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); + + let run_options_ptr = if let Some(run_options) = &run_options { + run_options.run_options_ptr.as_ptr() + } else { + std::ptr::null_mut() + }; + + trainsys![unsafe EvalStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun]; + + let outputs: Vec = output_tensor_ptrs + .into_iter() + .map(|tensor_ptr| unsafe { + // TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`. + // but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣 + Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None) + }) + .collect(); + + Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) + } + + pub fn export>(&self, out_path: impl AsRef, output_names: impl AsRef<[O]>) -> Result<()> { + let out_path = crate::util::path_to_os_char(out_path); + + let output_names_ptr: Vec<*const c_char> = output_names + .as_ref() + .iter() + .map(|output| CString::new(output.as_ref()).unwrap_or_else(|_| unreachable!())) + .map(|n| n.into_raw().cast_const()) + .collect(); + + let res = trainsys![unsafe ExportModelForInferencing(self.ptr.as_ptr(), out_path.as_ptr(), output_names_ptr.len(), output_names_ptr.as_ptr())]; + + // Reconvert name ptrs to CString so drop impl is called and memory is freed + drop( + output_names_ptr + .into_iter() + .map(|p| { + assert_non_null_pointer(p, "c_char for CString")?; + unsafe { Ok(CString::from_raw(p.cast_mut().cast())) } + }) + .collect::>>()? + ); + + status_to_result(res).map_err(Error::CreateSession)?; + + Ok(()) + } + pub fn optimizer(&self) -> &Optimizer { &self.optimizer } diff --git a/tools/train-data/mini-clm.py b/tools/train-data/mini-clm.py index d7a7991..6f06a70 100644 --- a/tools/train-data/mini-clm.py +++ b/tools/train-data/mini-clm.py @@ -113,10 +113,10 @@ def forward(self, x: Tensor) -> Tensor: logits = self.lm_head(self.norm(x)) return logits.view(-1, logits.size(-1)) -lm = CLM(256, 4, vocab_size=32000) +lm = CLM(256, 4, vocab_size=50257) torch.onnx.export( lm, - torch.randint(0, 32000, (1, 64)), + torch.randint(0, 50256, (1, 64)), f'tools/train-data/mini-clm/model.onnx', input_names=['input_ids'], output_names=['probs'], From 4aa2301a47bf74307d55e04b46be156164d0d802 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Thu, 30 May 2024 10:03:08 -0500 Subject: [PATCH 07/12] make train-clm example a little nicer --- examples/training/Cargo.toml | 1 + examples/training/README.md | 13 ++----------- examples/training/examples/train-clm.rs | 11 ++++++++++- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/training/Cargo.toml b/examples/training/Cargo.toml index 5bb177e..945f62e 100644 --- a/examples/training/Cargo.toml +++ b/examples/training/Cargo.toml @@ -10,6 +10,7 @@ ndarray = "0.15" tokenizers = { version = ">=0.13.4", default-features = false, features = [ "onig" ] } rand = "0.8" simd-json = "0.13" +kdam = "0.5" tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] } [features] diff --git a/examples/training/README.md b/examples/training/README.md index 64e1724..7ae0fef 100644 --- a/examples/training/README.md +++ b/examples/training/README.md @@ -14,18 +14,9 @@ Next, we need to convert our dataset into tokens to feed the model. This can be Finally, we can train our model! Run `cargo run -p example-training --example train-clm` to start training. If you have an NVIDIA GPU, add `--features cuda` to enable CUDA, though it's not required and you can train directly on CPU instead. **This will use ~8 GB of (V)RAM!** You can lower the memory usage by adjusting the `BATCH_SIZE` and `SEQUENCE_LENGTH` constants in `train-clm.rs`, though note that changing the batch size may require adjustments to the learning rate. -While training, the program prints the cross-entropy loss at each training step. At the end of training, the final trained model will be saved to `trained-clm.onnx`, and the program will use the model to generate a small snippet of text: +While training, the progress bar will show the cross-entropy loss at each training step. At the end of training, the final trained model will be saved to `trained-clm.onnx`, and the program will use the model to generate a small snippet of text: ``` -2.5692816 -2.718276 -2.4533236 -2.776122 -2.7698023 -2.5013578 -2.4010067 -2.7219558 -2.8342185 -2.660532 +100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5000/5000 [06:29<00:00, 12.83it/s, loss=3.611] I'm so much better than the game<|endoftext|>I think you can't see it<|endoftext|>I think you can't see it<|endoftext|>I think so it's a new game<|endoftext|>I think I'm sure you can't see what you can't see it<|endoftext|> ``` diff --git a/examples/training/examples/train-clm.rs b/examples/training/examples/train-clm.rs index 9b23c32..54547dc 100644 --- a/examples/training/examples/train-clm.rs +++ b/examples/training/examples/train-clm.rs @@ -4,6 +4,7 @@ use std::{ path::Path }; +use kdam::BarExt; use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis, Ix0}; use ort::{Allocator, CUDAExecutionProvider, Checkpoint, Session, SessionBuilder, Trainer}; use rand::RngCore; @@ -17,6 +18,9 @@ fn main() -> ort::Result<()> { ort::init().commit()?; + kdam::term::init(true); + let _ = kdam::term::hide_cursor(); + let trainer = Trainer::new( SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?, Allocator::default(), @@ -46,6 +50,7 @@ fn main() -> ort::Result<()> { let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let mut pb = kdam::tqdm!(total = 5000); for _ in 0..5000 { for batch in 0..BATCH_SIZE { let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64; @@ -82,7 +87,8 @@ fn main() -> ort::Result<()> { .into_dimensionality::() .unwrap() .into_scalar(); - println!("{}", loss); + pb.set_postfix(format!("loss={loss:.3}")); + pb.update(1).unwrap(); if loss.is_nan() { return Ok(()); } @@ -90,6 +96,9 @@ fn main() -> ort::Result<()> { optimizer.reset_grad()?; } + eprintln!(); + let _ = kdam::term::show_cursor(); + trainer.export("trained-clm.onnx", ["probs"])?; let session = Session::builder()?.commit_from_file("trained-clm.onnx")?; From 692ffbb9cd2ed70435d45195ff20d49a0c539692 Mon Sep 17 00:00:00 2001 From: Florian Kasischke Date: Wed, 5 Jun 2024 19:08:46 +0200 Subject: [PATCH 08/12] feat(sys): support `pkg-config` Co-authored-by: Florian Kasischke --- ort-sys/Cargo.toml | 1 + ort-sys/build.rs | 43 ++++++++++++++++++++++++++++++++++++------- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/ort-sys/Cargo.toml b/ort-sys/Cargo.toml index abe657f..f1f0a28 100644 --- a/ort-sys/Cargo.toml +++ b/ort-sys/Cargo.toml @@ -44,3 +44,4 @@ ureq = { version = "2.1", optional = true, default-features = false, features = tar = { version = "0.4", optional = true } flate2 = { version = "1.0", optional = true } sha2 = { version = "0.10", optional = true } +pkg-config = "0.3.30" \ No newline at end of file diff --git a/ort-sys/build.rs b/ort-sys/build.rs index 29f1920..e357a03 100644 --- a/ort-sys/build.rs +++ b/ort-sys/build.rs @@ -384,6 +384,29 @@ fn prepare_libort_dir() -> (PathBuf, bool) { } } +fn try_setup_with_pkg_config() -> bool { + match pkg_config::Config::new().probe("libonnxruntime") { + Ok(lib) => { + // Setting the link paths + for path in lib.link_paths { + println!("cargo:rustc-link-search=native={}", path.display()); + } + + // Setting the libraries to link against + for lib in lib.libs { + println!("cargo:rustc-link-lib={}", lib); + } + + println!("Using onnxruntime found by pkg-config."); + true + } + Err(_) => { + println!("onnxruntime not found using pkg-config, falling back to manual setup."); + false + } + } +} + fn real_main(link: bool) { println!("cargo:rerun-if-env-changed={}", ORT_ENV_SYSTEM_LIB_LOCATION); println!("cargo:rerun-if-env-changed={}", ORT_ENV_SYSTEM_LIB_PROFILE); @@ -408,14 +431,20 @@ fn main() { } if cfg!(feature = "load-dynamic") { - // we only need to execute the real main step if we are using the download strategy... - if cfg!(feature = "download-binaries") && env::var(ORT_ENV_SYSTEM_LIB_LOCATION).is_err() { - // but we don't need to link to the binaries we download (so all we are doing is downloading them and placing them in - // the output directory) - real_main(false); + if !try_setup_with_pkg_config() { + // Only execute the real main step if pkg-config fails and if we are using the download + // strategy + if cfg!(feature = "download-binaries") && env::var(ORT_ENV_SYSTEM_LIB_LOCATION).is_err() { + // but we don't need to link to the binaries we download (so all we are doing is + // downloading them and placing them in the output directory) + real_main(false); // but we don't need to link to the binaries we download + } } } else { - // if we are not using the load-dynamic feature then we need to link to dylibs. - real_main(true); + // If pkg-config setup was successful, we don't need further action + // Otherwise, if we are not using the load-dynamic feature, we need to link to the dylibs. + if !try_setup_with_pkg_config() { + real_main(true); + } } } From a60bda0b7080b8b8becc2bdfb6e10ab715ae69eb Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sat, 8 Jun 2024 13:19:34 -0500 Subject: [PATCH 09/12] docs: nuke mintlify --- docs/mint.json | 102 - docs/next-env.d.ts | 5 + docs/next.config.mjs | 11 + docs/package.json | 23 + docs/pages/_app.mdx | 5 + docs/pages/_meta.json | 37 + docs/{introduction.mdx => pages/index.mdx} | 97 +- docs/{ => pages}/migrating/opsets.mdx | 0 docs/{ => pages}/migrating/v2.mdx | 7 - .../{ => pages}/migrating/version-mapping.mdx | 0 docs/{ => pages}/perf/execution-providers.mdx | 34 +- docs/{ => pages}/perf/io-binding.mdx | 0 docs/{ => pages}/setup/cargo-features.mdx | 0 docs/pages/setup/linking.mdx | 109 + docs/{ => pages}/setup/platforms.mdx | 18 +- docs/{ => pages}/setup/webassembly.mdx | 8 +- .../{ => pages}/troubleshooting/compiling.mdx | 0 .../troubleshooting/performance.mdx | 81 +- docs/pnpm-lock.yaml | 3200 +++++++++++++++++ docs/{ => public}/assets/banner.png | Bin docs/{ => public}/assets/icon.png | Bin .../{ => public}/assets/sample-onnx-graph.png | Bin docs/{ => public}/assets/trend-banner.png | Bin docs/setup/linking.mdx | 106 - docs/theme.config.jsx | 33 + docs/tsconfig.json | 28 + 26 files changed, 3575 insertions(+), 329 deletions(-) delete mode 100644 docs/mint.json create mode 100644 docs/next-env.d.ts create mode 100644 docs/next.config.mjs create mode 100644 docs/package.json create mode 100644 docs/pages/_app.mdx create mode 100644 docs/pages/_meta.json rename docs/{introduction.mdx => pages/index.mdx} (56%) rename docs/{ => pages}/migrating/opsets.mdx (100%) rename docs/{ => pages}/migrating/v2.mdx (98%) rename docs/{ => pages}/migrating/version-mapping.mdx (100%) rename docs/{ => pages}/perf/execution-providers.mdx (94%) rename docs/{ => pages}/perf/io-binding.mdx (100%) rename docs/{ => pages}/setup/cargo-features.mdx (100%) create mode 100644 docs/pages/setup/linking.mdx rename docs/{ => pages}/setup/platforms.mdx (63%) rename docs/{ => pages}/setup/webassembly.mdx (83%) rename docs/{ => pages}/troubleshooting/compiling.mdx (100%) rename docs/{ => pages}/troubleshooting/performance.mdx (55%) create mode 100644 docs/pnpm-lock.yaml rename docs/{ => public}/assets/banner.png (100%) rename docs/{ => public}/assets/icon.png (100%) rename docs/{ => public}/assets/sample-onnx-graph.png (100%) rename docs/{ => public}/assets/trend-banner.png (100%) delete mode 100644 docs/setup/linking.mdx create mode 100644 docs/theme.config.jsx create mode 100644 docs/tsconfig.json diff --git a/docs/mint.json b/docs/mint.json deleted file mode 100644 index 8a6237b..0000000 --- a/docs/mint.json +++ /dev/null @@ -1,102 +0,0 @@ -{ - "$schema": "https://mintlify.com/schema.json", - "name": "ort", - "logo": { - "dark": "/assets/banner.png", - "light": "/assets/banner.png" - }, - "favicon": "/assets/icon.png", - "colors": { - "primary": "#F74C00", - "light": "#F74C00", - "background": { - "light": "#FFFFFF", - "dark": "#000000" - }, - "dark": "#F74C00", - "anchors": { - "from": "#F74C00", - "to": "#eb8e65" - } - }, - "tabs": [ - { - "name": "API Reference", - "url": "https://docs.rs/ort/2.0.0-rc.2/ort/" - } - ], - "anchors": [ - { - "name": "Sponsor", - "icon": "hand-holding-heart", - "url": "https://opencollective.com/pyke-osai" - }, - { - "name": "Crates.io", - "icon": "rust", - "url": "https://crates.io/crates/ort" - }, - { - "name": "GitHub", - "icon": "github", - "url": "https://github.com/pykeio/ort" - }, - { - "name": "Discord", - "icon": "discord", - "url": "https://discord.gg/uQtsNu2xMa" - } - ], - "navigation": [ - { - "group": "Get Started", - "pages": [ - "introduction" - ] - }, - { - "group": "Setup", - "pages": [ - "setup/platforms", - "setup/webassembly", - "setup/linking", - "setup/cargo-features" - ] - }, - { - "group": "Fundamentals", - "pages": [ - "fundamentals/environment", - "fundamentals/session", - "fundamentals/value" - ] - }, - { - "group": "Performance", - "pages": [ - "perf/execution-providers", - "perf/io-binding" - ] - }, - { - "group": "Troubleshooting", - "pages": [ - "troubleshooting/precision", - "troubleshooting/performance", - "troubleshooting/compiling" - ] - }, - { - "group": "Migration & versioning", - "pages": [ - "migrating/version-mapping", - "migrating/v2" - ] - } - ], - "footerSocials": { - "website": "https://pyke.io/", - "github": "https://github.com/pykeio/ort", - "discord": "https://discord.gg/uQtsNu2xMa" - } -} diff --git a/docs/next-env.d.ts b/docs/next-env.d.ts new file mode 100644 index 0000000..4f11a03 --- /dev/null +++ b/docs/next-env.d.ts @@ -0,0 +1,5 @@ +/// +/// + +// NOTE: This file should not be edited +// see https://nextjs.org/docs/basic-features/typescript for more information. diff --git a/docs/next.config.mjs b/docs/next.config.mjs new file mode 100644 index 0000000..47e0f5e --- /dev/null +++ b/docs/next.config.mjs @@ -0,0 +1,11 @@ +import nextra from 'nextra'; + +export default nextra({ + theme: 'nextra-theme-docs', + themeConfig: './theme.config.jsx' +})({ + output: 'export', + images: { + unoptimized: true + } +}); diff --git a/docs/package.json b/docs/package.json new file mode 100644 index 0000000..36fdb98 --- /dev/null +++ b/docs/package.json @@ -0,0 +1,23 @@ +{ + "private": true, + "name": "ort-docs", + "version": "0.0.0", + "scripts": { + "dev": "next dev", + "build": "next build", + "start": "next start" + }, + "dependencies": { + "next": "^14.2.3", + "nextra": "^2.13.4", + "nextra-theme-docs": "^2.13.4", + "react": "^18.3.1", + "react-dom": "^18.3.1" + }, + "devDependencies": { + "@types/node": "20.14.2", + "@types/react": "^18.3.3", + "@types/react-dom": "^18.3.0", + "typescript": "^5.4.5" + } +} diff --git a/docs/pages/_app.mdx b/docs/pages/_app.mdx new file mode 100644 index 0000000..c466f98 --- /dev/null +++ b/docs/pages/_app.mdx @@ -0,0 +1,5 @@ +import font from 'next/font/google'; + +export default function App({ Component, pageProps }) { + return ; +} diff --git a/docs/pages/_meta.json b/docs/pages/_meta.json new file mode 100644 index 0000000..14840b8 --- /dev/null +++ b/docs/pages/_meta.json @@ -0,0 +1,37 @@ +{ + "-- Links": { + "type": "separator", + "title": "Links" + }, + "link-oc": { + "title": "Sponsor β†—", + "href": "https://opencollective.com/pyke-osai", + "newWindow": true + }, + "link-api": { + "title": "API Reference β†—", + "href": "https://docs.rs/ort/2.0.0-rc.2/ort" + }, + "link-crates": { + "title": "Crates.io β†—", + "href": "https://crates.io/crates/ort", + "newWindow": true + }, + "-- Docs": { + "type": "separator", + "title": "Docs" + }, + "index": "Introduction", + "setup": { + "title": "Setup" + }, + "perf": { + "title": "Performance" + }, + "troubleshooting": { + "title": "Troubleshooting" + }, + "migrating": { + "title": "Migration & versioning" + } +} diff --git a/docs/introduction.mdx b/docs/pages/index.mdx similarity index 56% rename from docs/introduction.mdx rename to docs/pages/index.mdx index 0ff676b..97ab9f5 100644 --- a/docs/introduction.mdx +++ b/docs/pages/index.mdx @@ -2,14 +2,17 @@ title: Introduction --- +import Image from 'next/image'; +import { Callout, Card, Cards, Steps } from 'nextra/components'; +

ort is an open-source Rust binding for ONNX Runtime.

- + These docs are for the latest alpha version of `ort`, `2.0.0-rc.2`. This version is production-ready (just not API stable) and we recommend new & existing projects use it. - + `ort` makes it easy to deploy your machine learning models to production via [ONNX Runtime](https://onnxruntime.ai/), a hardware-accelerated inference engine. With `ort` + ONNX Runtime, you can run almost any ML model (including ResNet, YOLOv8, BERT, LLaMA) on almost any hardware, often far faster than PyTorch, and with the added bonus of Rust's efficiency. @@ -29,52 +32,54 @@ Converting a neural network to a graph representation like ONNX opens the door t # Getting started - - If you have a [supported platform](/setup/platforms) (and you probably do), installing `ort` couldn't be any simpler! Just add it to your Cargo dependencies: - ```toml - [dependencies] - ort = "2.0.0-rc.2" - ``` - - - Your model will need to be converted to an ONNX graph before you can use it. - - The awesome folks at Hugging Face have [a guide](https://huggingface.co/docs/transformers/serialization) to export πŸ€— Transformers models to ONNX with πŸ€— Optimum. - - For any PyTorch model: [`torch.onnx`](https://pytorch.org/docs/stable/onnx.html) - - For `scikit-learn` models: [`sklearn-onnx`](https://onnx.ai/sklearn-onnx/) - - For TensorFlow, Keras, TFlite, & TensorFlow.js: [`tf2onnx`](https://github.com/onnx/tensorflow-onnx) - - For PaddlePaddle: [`Paddle2ONNX`](https://github.com/PaddlePaddle/Paddle2ONNX) - - - Once you've got a model, load it via `ort` by creating a [`Session`](/fundamentals/session): - - ```rust - use ort::{GraphOptimizationLevel, Session}; - - let model = Session::builder()? - .with_optimization_level(GraphOptimizationLevel::Level3)? - .with_intra_threads(4)? - .commit_from_file("yolov8m.onnx")?; - ``` - - - Preprocess your inputs, then `run()` the session to perform inference. - - ```rust - let outputs = model.run(ort::inputs!["image" => image]?)?; - let predictions = outputs["output0"].try_extract_tensor::()?; - ... - ``` - - There are some more useful examples [in the `ort` repo](https://github.com/pykeio/ort/tree/main/examples)! - + +### Add ort to your Cargo.toml +If you have a [supported platform](/setup/platforms) (and you probably do), installing `ort` couldn't be any simpler! Just add it to your Cargo dependencies: +```toml +[dependencies] +ort = "2.0.0-rc.2" +``` + +### Convert your model +Your model will need to be converted to an ONNX graph before you can use it. +- The awesome folks at Hugging Face have [a guide](https://huggingface.co/docs/transformers/serialization) to export πŸ€— Transformers models to ONNX with πŸ€— Optimum. +- For any PyTorch model: [`torch.onnx`](https://pytorch.org/docs/stable/onnx.html) +- For `scikit-learn` models: [`sklearn-onnx`](https://onnx.ai/sklearn-onnx/) +- For TensorFlow, Keras, TFlite, & TensorFlow.js: [`tf2onnx`](https://github.com/onnx/tensorflow-onnx) +- For PaddlePaddle: [`Paddle2ONNX`](https://github.com/PaddlePaddle/Paddle2ONNX) + +### Load your model +Once you've got a model, load it via `ort` by creating a [`Session`](/fundamentals/session): + +```rust +use ort::{GraphOptimizationLevel, Session}; + +let model = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_intra_threads(4)? + .commit_from_file("yolov8m.onnx")?; +``` + +### Perform inference +Preprocess your inputs, then `run()` the session to perform inference. + +```rust +let outputs = model.run(ort::inputs!["image" => image]?)?; +let predictions = outputs["output0"].try_extract_tensor::()?; +... +``` + +There are some more useful examples [in the `ort` repo](https://github.com/pykeio/ort/tree/main/examples)! + # Next steps - - Use [execution providers](/perf/execution-providers) to enable hardware acceleration in your app and unlock the full power of your GPU or NPU. - - - We'd love to see what you've made with `ort`! Show off your project in [GitHub Discussions](https://github.com/pykeio/ort/discussions/categories/show-and-tell) or on our [Discord](https://discord.gg/uQtsNu2xMa). - + +### Unlock more performance with EPs +Use [execution providers](/perf/execution-providers) to enable hardware acceleration in your app and unlock the full power of your GPU or NPU. + +### Show off your project! +We'd love to see what you've made with `ort`! Show off your project in [GitHub Discussions](https://github.com/pykeio/ort/discussions/categories/show-and-tell) or on our [Discord](https://discord.gg/uQtsNu2xMa). + diff --git a/docs/migrating/opsets.mdx b/docs/pages/migrating/opsets.mdx similarity index 100% rename from docs/migrating/opsets.mdx rename to docs/pages/migrating/opsets.mdx diff --git a/docs/migrating/v2.mdx b/docs/pages/migrating/v2.mdx similarity index 98% rename from docs/migrating/v2.mdx rename to docs/pages/migrating/v2.mdx index a2f2020..a3c1afd 100644 --- a/docs/migrating/v2.mdx +++ b/docs/pages/migrating/v2.mdx @@ -141,13 +141,6 @@ let noise_pred = unet.run(ort::inputs![ ]?)?; ``` -You can also supply `ort::inputs!` your `IoBinding` by specifying `bind =`: -```rust -let binding = model.create_binding()?; -... -let outputs = model.run(ort::inputs![bind = binding]?)?; -``` - ### Tensor creation no longer requires the session's allocator In previous versions, `Value::from_array` took an allocator parameter. The allocator was only used because the string data contained in string tensors had to be cloned into ONNX Runtime-managed memory. However, 99% of users only ever use primitive tensors, so the extra parameter served little purpose. The new `Tensor::from_array` function now takes only an array, and the logic for converting string arrays has been moved to a new function, `DynTensor::from_string_array`. diff --git a/docs/migrating/version-mapping.mdx b/docs/pages/migrating/version-mapping.mdx similarity index 100% rename from docs/migrating/version-mapping.mdx rename to docs/pages/migrating/version-mapping.mdx diff --git a/docs/perf/execution-providers.mdx b/docs/pages/perf/execution-providers.mdx similarity index 94% rename from docs/perf/execution-providers.mdx rename to docs/pages/perf/execution-providers.mdx index 03447b1..f2e7b9e 100644 --- a/docs/perf/execution-providers.mdx +++ b/docs/pages/perf/execution-providers.mdx @@ -3,6 +3,8 @@ title: Execution providers description: Learn how to enable execution providers to leverage hardware acceleration. --- +import { Callout, Tabs } from 'nextra/components'; + Execution providers (EPs) enable ONNX Runtime to execute ONNX graphs with hardware acceleration. If you have specialized hardware like a GPU or NPU, execution providers can provide a massive performance boost to your `ort` applications. For more information on the intricacies of execution providers, see the [ONNX Runtime docs](https://onnxruntime.ai/docs/execution-providers/). ONNX Runtime must be compiled with support for each execution provider. pyke provides precompiled binaries for some of the most common EPs, so you won't need to compile ONNX Runtime from source. Below is a table showing available EPs, their support in `ort`, and their binary availability status. @@ -28,12 +30,12 @@ ONNX Runtime must be compiled with support for each execution provider. pyke pro | Microsoft Azure | ❌ | ❌ | ❓ | | Rockchip RKNPU | ❌ | ❌ | ❓ | - + Some EPs supported by ONNX Runtime are not supported by `ort` due to a lack of hardware for testing. If your preferred EP is missing support and you've got the hardware, please [open an issue](https://github.com/pykeio/ort/issues/new)! - + ## Registering execution providers - + To use an execution provider with `ort`, you'll need to enable its respective Cargo feature, e.g. the `cuda` feature to use CUDA, or the `coreml` feature to use CoreML. ```toml Cargo.toml @@ -42,7 +44,7 @@ ONNX Runtime must be compiled with support for each execution provider. pyke pro ``` See [Cargo features](/setup/cargo-features) for the full list of features. - + In order to configure sessions to use certain execution providers, you must **register** them when creating an environment or session. You can do this via the `SessionBuilder::with_execution_providers` method. For example, to register the CUDA execution provider for a session: @@ -167,9 +169,9 @@ fn main() -> anyhow::Result<()> { } ``` - + `ort::init` must come before you create any sessions, otherwise the configuration will not take effect! - + Sessions configured with their own execution providers will *extend* the execution provider defaults, rather than overriding them. @@ -181,32 +183,32 @@ If it seems like the execution provider is not registering properly, or you are ### CoreML Statically linking to CoreML (the default behavior when using downloaded binaries + the `coreml` Cargo feature) requires an additional Rust flag in order to link properly. You'll need to provide the flag `-C link-arg=-fapple-link-rtlib` to `rustc`. You can do this via an entry in [`.cargo/config.toml`](https://doc.rust-lang.org/cargo/reference/config.html#hierarchical-structure), in a build script, or in an environment variable. - - + + See [Configuration: Hierarchical structure](https://doc.rust-lang.org/cargo/reference/config.html#hierarchical-structure) for more information on where the configuration file can be placed. - ```toml .cargo/config.toml + ```toml filename=".cargo/config.toml" copy [target.aarch64-apple-darwin] rustflags = ["-Clink-arg=-fapple-link-rtlib"] [target.x86_64-apple-darwin] rustflags = ["-Clink-arg=-fapple-link-rtlib"] ``` - - + + Add the following to the `build.rs` script of any **binary** crate that uses `ort`. - ```rust build.rs + ```rust filename="build.rs" copy fn main() { println!("cargo:rustc-link-arg=-fapple-link-rtlib"); } ``` Library crates do not need this flag, and the usage of it in a library crate will not transitively apply to any binary crates dependent on it. - - - ```shell + + + ```shell copy $ RUSTFLAGS="-Clink-arg=-fapple-link-rtlib" cargo build ``` - + diff --git a/docs/perf/io-binding.mdx b/docs/pages/perf/io-binding.mdx similarity index 100% rename from docs/perf/io-binding.mdx rename to docs/pages/perf/io-binding.mdx diff --git a/docs/setup/cargo-features.mdx b/docs/pages/setup/cargo-features.mdx similarity index 100% rename from docs/setup/cargo-features.mdx rename to docs/pages/setup/cargo-features.mdx diff --git a/docs/pages/setup/linking.mdx b/docs/pages/setup/linking.mdx new file mode 100644 index 0000000..bdb4923 --- /dev/null +++ b/docs/pages/setup/linking.mdx @@ -0,0 +1,109 @@ +--- +title: Linking +description: Here's how `ort` links to ONNX Runtime, and how to configure its behavior. +--- + +import { Callout, Tabs, Steps } from 'nextra/components'; + +`ort` provides its own builds of ONNX Runtime to make your experience as painless as possible, but in some cases, you'll want to use a custom build of ONNX Runtime with `ort`. Luckily, we make this very easy by handling all of the linking configuration automagically. Just point `ort` to the output of ONNX Runtime's build pipeline and it'll Just Workβ„’. + +## Static linking +Most ONNX Runtime compile configurations will support static linking - just run `build.sh` without the `--build_shared_lib` argument. You should prefer static linking if your execution providers support it, as it avoids many issues and follows de facto Rust practices. If you compile both static libraries and dynamic libraries, `ort` will prefer linking to the static libraries. + +To direct `ort` to your statically built binaries, use the `ORT_LIB_LOCATION` environment variable when running `cargo build`. Point it to the location where the static libraries (`.a`/`.lib` files) are compiled to. This will typically be `onnxruntime/build/`. For example: +```shell +$ ORT_LIB_LOCATION=~/onnxruntime/build/Linux cargo build +``` + +For iOS (or for other platforms if you are compiling multiple profiles at once), you'll need to manually specify the profile with the `ORT_LIB_PROFILE` environment variable. If not specified, `ort` will prefer `Release` over `RelWithDebInfo` over `MinSizeRel` over `Debug`. + +## Dynamic linking +Some execution providers unfortunately only support dynamic linking. Dynamic linking doesn't play well with the Rust ecosystem, though `ort` tries to alleviate the pain as much as possible. + +When it comes to dynamic linking, there are two options: `load-dynamic`, or standard compile-time dynamic linking. We recommend `load-dynamic` as it gives more control and is often far less troublesome to work with. + +### Runtime loading with `load-dynamic` +The `load-dynamic` Cargo feature solves a few of the issues with dynamic linking by **loading the library at runtime** rather than **linking at compile time**. This means that the path to the ONNX Runtime library can be configured at runtime, and the executable will not just completely fail to start if the binary couldn't be found. + +To use `load-dynamic`: + + +#### Enable the feature in Cargo.toml +```toml filename="Cargo.toml" +[dependencies] +ort = { version = "2", features = [ "load-dynamic" ] } +``` + +### Point ort to the dylib + + + ```rust main.rs + fn main() -> anyhow::Result<()> { + // Find our custom ONNX Runtime dylib path somehow + // (i.e. resolving it from the root of our program's install folder) + let dylib_path = crate::internal::find_onnxruntime_dylib()?; + // The path should point to the `libonnxruntime` binary, which looks like: + // - on Unix: /etc/.../libonnxruntime.so + // - on Windows: C:\Program Files\...\onnxruntime.dll + + // Initialize ort with the path to the dylib. This **must** be called before any usage of `ort`! + // `init_from` returns an `EnvironmentBuilder` which you can use to further configure the environment + // before `.commit()`ing; see the Environment docs for more information on what you can configure. + ort::init_from(dylib_path).commit()?; + + Ok(()) + } + ``` + + + Set the `ORT_DYLIB_PATH` environment variable to the path to `libonnxruntime.so`/`onnxruntime.dll`. + + ```shell + $ ORT_DYLIB_PATH=../onnxruntime-build/linux-x64/libonnxruntime.so ./mirai + ``` + + + + + +`ORT_DYLIB_PATH` is relative to the executable. Cargo examples and tests are compiled to a different directory than binary crates: `target//examples` and `target//deps` respectively. Keep this in mind if you're going to use relative paths. + +### Compile-time dynamic linking +For compile-time dynamic linking, you'll need to configure your environment in the exact same way as if you were [statically linking](#static-linking). + +Note that the dylibs then have to be placed in a certain location for them to be found by the executable. For Windows, this is either somewhere on the `PATH`, or in the same folder as the executable. On macOS and Linux, they have to be placed somewhere in the `LD_LIBRARY_PATH`, or you can use rpath to configure the executable to search for dylibs in its parent folder. We've had the least issues with rpath, but YMMV. + +To configure rpath, you'll need to: + +#### Enable rpath in Cargo.toml +```toml filename="Cargo.toml" copy +[profile.dev] +rpath = true + +[profile.release] +rpath = true + +# do this for any other profiles +``` + +### Configure the path in the linker args in .cargo/config.toml to be relative to the executable + + + ```toml filename="~/.cargo/config.toml" copy + [target.x86_64-unknown-linux-gnu] + rustflags = [ "-Clink-args=-Wl,-rpath,\\$ORIGIN" ] + + # do this for any other Linux targets as well + ``` + + + ```toml filename="~/.cargo/config.toml" copy + [target.x86_64-apple-darwin] + rustflags = [ "-Clink-args=-Wl,-rpath,@loader_path" ] + + # do this for any other macOS targets as well + ``` + + + + diff --git a/docs/setup/platforms.mdx b/docs/pages/setup/platforms.mdx similarity index 63% rename from docs/setup/platforms.mdx rename to docs/pages/setup/platforms.mdx index 0244345..f83d131 100644 --- a/docs/setup/platforms.mdx +++ b/docs/pages/setup/platforms.mdx @@ -3,6 +3,8 @@ title: Platform support description: ONNX Runtime, and by extension `ort`, supports a wide variety of platforms. For most desktop users, pre-built binaries are available, so setting up `ort` is as simple as adding it to your `Cargo.toml`! --- +import { Callout } from 'nextra/components'; + Here are the supported platforms and binary availability status, as of v2.0.0-rc.2. * 🟒 - Supported. Dynamic & static binaries provided by pyke. @@ -19,14 +21,18 @@ Here are the supported platforms and binary availability status, as of v2.0.0-rc | **Android** | ❌ | ❌ | β­• | β­• | ❌ | | **Web** | ❌ | ❌ | ❌ | ❌ | πŸ”·ΒΆ | -\* Recent version of Windows 10/11 required for pyke binaries.
-† glibc β‰₯ 2.31 (Ubuntu β‰₯ 20.04) required for pyke binaries.
-‑ glibc β‰₯ 2.35 (Ubuntu β‰₯ 22.04) required for pyke binaries.
-Β§ macOS β‰₯ 10.15 required.
-ΒΆ WASM supports a limited subset of ONNX Runtime features. For more info, see [the docs on WebAssembly support](/setup/webassembly). +
+

\* Recent version of Windows 10/11 required for pyke binaries.

+

† glibc β‰₯ 2.31 (Ubuntu β‰₯ 20.04) required for pyke binaries.

+

‑ glibc β‰₯ 2.35 (Ubuntu β‰₯ 22.04) required for pyke binaries.

+

Β§ macOS β‰₯ 10.15 required.

+

ΒΆ WASM supports a limited subset of ONNX Runtime features. For more info, see [the docs on WebAssembly support](/setup/webassembly).

+
If your platform is marked as 🟒 or πŸ”·, you're in luck! Almost no setup will be required to get `ort` up and running. For platforms marked as β­•, you'll need to [compile ONNX Runtime from source](https://onnxruntime.ai/docs/build/) and then [link `ort` to your custom binaries](/setup/linking) (but don't worry, we made this setup as simple as possible!) -Certain execution providers may not have binaries available. You can check EP binary support in the [Execution providers](/perf/execution-providers) documentation. + + Certain execution providers may not have binaries available. You can check EP binary support in the [Execution providers](/perf/execution-providers) documentation. + diff --git a/docs/setup/webassembly.mdx b/docs/pages/setup/webassembly.mdx similarity index 83% rename from docs/setup/webassembly.mdx rename to docs/pages/setup/webassembly.mdx index 87b1ccf..3e5e5cc 100644 --- a/docs/setup/webassembly.mdx +++ b/docs/pages/setup/webassembly.mdx @@ -5,19 +5,13 @@ description: Deploy ONNX models to the web WebAssembly support in `ort` is currently experimental. If you experience any issues using `ort` in WebAssembly, please [open an issue](https://github.com/pykeio/ort/issues/new). -Development of WASM support is done in a separate branch for now, so you'll have to add `ort` as a Git dependency: -```toml Cargo.toml -[dependencies] -ort = { git = "https://github.com/pykeio/ort.git", branch = "wasm32-unknown-unknown" } -``` - By nature, some features of ONNX Runtime are not available in the web. These include: - **Support for `.onnx` models.** You instead need to [convert `.onnx` models to the `.ort` format](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html). - **Runtime graph optimizations**, aka `SessionBuilder::with_optimization_level`. You can statically optimize the graph using the `.ort` conversion tool, though. - **Loading models with `commit_from_file`/`commit_from_url`.** You can create models from a slice of bytes in memory with `SessionBuilder::commit_from_memory` or `SessionBuilder::commit_from_memory_directly`. Additionally, you'll need to call `ort::wasm::initialize()` at the earliest possible point in your code, before you use any `ort` APIs: -```rust main.rs +```rust filename="main.rs" copy use ort::Session; static MODEL_BYTES: &[u8] = include_bytes!("../model.ort"); diff --git a/docs/troubleshooting/compiling.mdx b/docs/pages/troubleshooting/compiling.mdx similarity index 100% rename from docs/troubleshooting/compiling.mdx rename to docs/pages/troubleshooting/compiling.mdx diff --git a/docs/troubleshooting/performance.mdx b/docs/pages/troubleshooting/performance.mdx similarity index 55% rename from docs/troubleshooting/performance.mdx rename to docs/pages/troubleshooting/performance.mdx index 6bf4112..e407895 100644 --- a/docs/troubleshooting/performance.mdx +++ b/docs/pages/troubleshooting/performance.mdx @@ -2,53 +2,56 @@ title: 'Troubleshoot: Performance' --- +import { Callout, Tabs, Steps } from 'nextra/components'; + ## Execution providers don't seem to register `ort` is designed to fail gracefully when an execution provider is not available. It logs failure events through [`tracing`](https://crates.io/crates/tracing), thus you'll need a library that subscribes to `tracing` events to see the logs. The simplest way to do this is to use [`tracing-subscriber`](https://crates.io/crates/tracing-subscriber). - - ```toml Cargo.toml - [dependencies] - tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] } + +### Add tracing-subscriber to your dependencies +```toml Cargo.toml +[dependencies] +tracing-subscriber = { version = "0.3", features = [ "env-filter", "fmt" ] } +``` + +### Initialize the subscriber in the main function +```rust main.rs +fn main() { + tracing_subscriber::fmt::init(); +} +``` + +### Show debug messages from ort +Set the environment variable `RUST_LOG` to `ort=debug` to see all debug messages from `ort`. + + + ```powershell + $env:RUST_LOG = 'ort=debug'; + cargo run ``` - - - ```rust main.rs - fn main() { - tracing_subscriber::fmt::init(); - } + + + ```cmd + set RUST_LOG=ort=debug + cargo run ``` - - - Set the environment variable `RUST_LOG` to `ort=debug` to see all debug messages from `ort`. - - - ```powershell - $env:RUST_LOG = 'ort=debug'; - cargo run - ``` - - - ```cmd - set RUST_LOG=ort=debug - cargo run - ``` - - - ```shell - RUST_LOG="ort=debug" cargo run - ``` - - - ```shell - RUST_LOG="ort=debug" cargo run - ``` - - - + + + ```shell + RUST_LOG="ort=debug" cargo run + ``` + + + ```shell + RUST_LOG="ort=debug" cargo run + ``` + + + -You can also detect EP regsitration failures programmatically. See [Execution providers: Fallback behavior](/perf/execution-providers#fallback-behavior) for more info. +You can also detect EP regsitration failures programmatically. See [Execution providers: Fallback behavior](/perf/execution-providers#fallback-behavior) for more info. ## Inference is slower than expected There are a few things you could try to improve performance: diff --git a/docs/pnpm-lock.yaml b/docs/pnpm-lock.yaml new file mode 100644 index 0000000..4bebc4f --- /dev/null +++ b/docs/pnpm-lock.yaml @@ -0,0 +1,3200 @@ +lockfileVersion: '9.0' + +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + +importers: + + .: + dependencies: + next: + specifier: ^14.2.3 + version: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + nextra: + specifier: ^2.13.4 + version: 2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + nextra-theme-docs: + specifier: ^2.13.4 + version: 2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(nextra@2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: + specifier: ^18.3.1 + version: 18.3.1 + react-dom: + specifier: ^18.3.1 + version: 18.3.1(react@18.3.1) + devDependencies: + '@types/node': + specifier: 20.14.2 + version: 20.14.2 + '@types/react': + specifier: ^18.3.3 + version: 18.3.3 + '@types/react-dom': + specifier: ^18.3.0 + version: 18.3.0 + typescript: + specifier: ^5.4.5 + version: 5.4.5 + +packages: + + '@babel/runtime@7.24.7': + resolution: {integrity: sha512-UwgBRMjJP+xv857DCngvqXI3Iq6J4v0wXmwc6sapg+zyhbwmQX67LUEFrkK5tbyJ30jGuG3ZvWpBiB9LCy1kWw==} + engines: {node: '>=6.9.0'} + + '@braintree/sanitize-url@6.0.4': + resolution: {integrity: sha512-s3jaWicZd0pkP0jf5ysyHUI/RE7MHos6qlToFcGWXVp+ykHOy77OUMrfbgJ9it2C5bow7OIQwYYaHjk9XlBQ2A==} + + '@headlessui/react@1.7.19': + resolution: {integrity: sha512-Ll+8q3OlMJfJbAKM/+/Y2q6PPYbryqNTXDbryx7SXLIDamkF6iQFbriYHga0dY44PvDhvvBWCx1Xj4U5+G4hOw==} + engines: {node: '>=10'} + peerDependencies: + react: ^16 || ^17 || ^18 + react-dom: ^16 || ^17 || ^18 + + '@mdx-js/mdx@2.3.0': + resolution: {integrity: sha512-jLuwRlz8DQfQNiUCJR50Y09CGPq3fLtmtUQfVrj79E0JWu3dvsVcxVIcfhR5h0iXu+/z++zDrYeiJqifRynJkA==} + + '@mdx-js/react@2.3.0': + resolution: {integrity: sha512-zQH//gdOmuu7nt2oJR29vFhDv88oGPmVw6BggmrHeMI+xgEkp1B2dX9/bMBSYtK0dyLX/aOmesKS09g222K1/g==} + peerDependencies: + react: '>=16' + + '@napi-rs/simple-git-android-arm-eabi@0.1.16': + resolution: {integrity: sha512-dbrCL0Pl5KZG7x7tXdtVsA5CO6At5ohDX3myf5xIYn9kN4jDFxsocl8bNt6Vb/hZQoJd8fI+k5VlJt+rFhbdVw==} + engines: {node: '>= 10'} + cpu: [arm] + os: [android] + + '@napi-rs/simple-git-android-arm64@0.1.16': + resolution: {integrity: sha512-xYz+TW5J09iK8SuTAKK2D5MMIsBUXVSs8nYp7HcMi8q6FCRO7yJj96YfP9PvKsc/k64hOyqGmL5DhCzY9Cu1FQ==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [android] + + '@napi-rs/simple-git-darwin-arm64@0.1.16': + resolution: {integrity: sha512-XfgsYqxhUE022MJobeiX563TJqyQyX4FmYCnqrtJwAfivESVeAJiH6bQIum8dDEYMHXCsG7nL8Ok0Dp8k2m42g==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [darwin] + + '@napi-rs/simple-git-darwin-x64@0.1.16': + resolution: {integrity: sha512-tkEVBhD6vgRCbeWsaAQqM3bTfpIVGeitamPPRVSbsq8qgzJ5Dx6ZedH27R7KSsA/uao7mZ3dsrNLXbu1Wy5MzA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [darwin] + + '@napi-rs/simple-git-linux-arm-gnueabihf@0.1.16': + resolution: {integrity: sha512-R6VAyNnp/yRaT7DV1Ao3r67SqTWDa+fNq2LrNy0Z8gXk2wB9ZKlrxFtLPE1WSpWknWtyRDLpRlsorh7Evk7+7w==} + engines: {node: '>= 10'} + cpu: [arm] + os: [linux] + + '@napi-rs/simple-git-linux-arm64-gnu@0.1.16': + resolution: {integrity: sha512-LAGI0opFKw/HBMCV2qIBK3uWSEW9h4xd2ireZKLJy8DBPymX6NrWIamuxYNyCuACnFdPRxR4LaRFy4J5ZwuMdw==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@napi-rs/simple-git-linux-arm64-musl@0.1.16': + resolution: {integrity: sha512-I57Ph0F0Yn2KW93ep+V1EzKhACqX0x49vvSiapqIsdDA2PifdEWLc1LJarBolmK7NKoPqKmf6lAKKO9lhiZzkg==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@napi-rs/simple-git-linux-x64-gnu@0.1.16': + resolution: {integrity: sha512-AZYYFY2V7hlcQASPEOWyOa3e1skzTct9QPzz0LiDM3f/hCFY/wBaU2M6NC5iG3d2Kr38heuyFS/+JqxLm5WaKA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@napi-rs/simple-git-linux-x64-musl@0.1.16': + resolution: {integrity: sha512-9TyMcYSBJwjT8jwjY9m24BZbu7ozyWTjsmYBYNtK3B0Um1Ov6jthSNneLVvouQ6x+k3Ow+00TiFh6bvmT00r8g==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@napi-rs/simple-git-win32-arm64-msvc@0.1.16': + resolution: {integrity: sha512-uslJ1WuAHCYJWui6xjsyT47SjX6KOHDtClmNO8hqKz1pmDSNY7AjyUY8HxvD1lK9bDnWwc4JYhikS9cxCqHybw==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [win32] + + '@napi-rs/simple-git-win32-x64-msvc@0.1.16': + resolution: {integrity: sha512-SoEaVeCZCDF1MP+M9bMSXsZWgEjk4On9GWADO5JOulvzR1bKjk0s9PMHwe/YztR9F0sJzrCxwtvBZowhSJsQPg==} + engines: {node: '>= 10'} + cpu: [x64] + os: [win32] + + '@napi-rs/simple-git@0.1.16': + resolution: {integrity: sha512-C5wRPw9waqL2jk3jEDeJv+f7ScuO3N0a39HVdyFLkwKxHH4Sya4ZbzZsu2JLi6eEqe7RuHipHL6mC7B2OfYZZw==} + engines: {node: '>= 10'} + + '@next/env@14.2.3': + resolution: {integrity: sha512-W7fd7IbkfmeeY2gXrzJYDx8D2lWKbVoTIj1o1ScPHNzvp30s1AuoEFSdr39bC5sjxJaxTtq3OTCZboNp0lNWHA==} + + '@next/swc-darwin-arm64@14.2.3': + resolution: {integrity: sha512-3pEYo/RaGqPP0YzwnlmPN2puaF2WMLM3apt5jLW2fFdXD9+pqcoTzRk+iZsf8ta7+quAe4Q6Ms0nR0SFGFdS1A==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [darwin] + + '@next/swc-darwin-x64@14.2.3': + resolution: {integrity: sha512-6adp7waE6P1TYFSXpY366xwsOnEXM+y1kgRpjSRVI2CBDOcbRjsJ67Z6EgKIqWIue52d2q/Mx8g9MszARj8IEA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [darwin] + + '@next/swc-linux-arm64-gnu@14.2.3': + resolution: {integrity: sha512-cuzCE/1G0ZSnTAHJPUT1rPgQx1w5tzSX7POXSLaS7w2nIUJUD+e25QoXD/hMfxbsT9rslEXugWypJMILBj/QsA==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@next/swc-linux-arm64-musl@14.2.3': + resolution: {integrity: sha512-0D4/oMM2Y9Ta3nGuCcQN8jjJjmDPYpHX9OJzqk42NZGJocU2MqhBq5tWkJrUQOQY9N+In9xOdymzapM09GeiZw==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@next/swc-linux-x64-gnu@14.2.3': + resolution: {integrity: sha512-ENPiNnBNDInBLyUU5ii8PMQh+4XLr4pG51tOp6aJ9xqFQ2iRI6IH0Ds2yJkAzNV1CfyagcyzPfROMViS2wOZ9w==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@next/swc-linux-x64-musl@14.2.3': + resolution: {integrity: sha512-BTAbq0LnCbF5MtoM7I/9UeUu/8ZBY0i8SFjUMCbPDOLv+un67e2JgyN4pmgfXBwy/I+RHu8q+k+MCkDN6P9ViQ==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@next/swc-win32-arm64-msvc@14.2.3': + resolution: {integrity: sha512-AEHIw/dhAMLNFJFJIJIyOFDzrzI5bAjI9J26gbO5xhAKHYTZ9Or04BesFPXiAYXDNdrwTP2dQceYA4dL1geu8A==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [win32] + + '@next/swc-win32-ia32-msvc@14.2.3': + resolution: {integrity: sha512-vga40n1q6aYb0CLrM+eEmisfKCR45ixQYXuBXxOOmmoV8sYST9k7E3US32FsY+CkkF7NtzdcebiFT4CHuMSyZw==} + engines: {node: '>= 10'} + cpu: [ia32] + os: [win32] + + '@next/swc-win32-x64-msvc@14.2.3': + resolution: {integrity: sha512-Q1/zm43RWynxrO7lW4ehciQVj+5ePBhOK+/K2P7pLFX3JaJ/IZVC69SHidrmZSOkqz7ECIOhhy7XhAFG4JYyHA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [win32] + + '@popperjs/core@2.11.8': + resolution: {integrity: sha512-P1st0aksCrn9sGZhp8GMYwBnQsbvAWsZAX44oXNNvLHGqAOcoVxmjZiohstwQ7SqKnbR47akdNi+uleWD8+g6A==} + + '@swc/counter@0.1.3': + resolution: {integrity: sha512-e2BR4lsJkkRlKZ/qCHPw9ZaSxc0MVUd7gtbtaB7aMvHeJVYe8sOB8DBZkP2DtISHGSku9sCK6T6cnY0CtXrOCQ==} + + '@swc/helpers@0.5.5': + resolution: {integrity: sha512-KGYxvIOXcceOAbEk4bi/dVLEK9z8sZ0uBB3Il5b1rhfClSpcX0yfRO0KmTkqR2cnQDymwLB+25ZyMzICg/cm/A==} + + '@tanstack/react-virtual@3.5.1': + resolution: {integrity: sha512-jIsuhfgy8GqA67PdWqg73ZB2LFE+HD9hjWL1L6ifEIZVyZVAKpYmgUG4WsKQ005aEyImJmbuimPiEvc57IY0Aw==} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 + + '@tanstack/virtual-core@3.5.1': + resolution: {integrity: sha512-046+AUSiDru/V9pajE1du8WayvBKeCvJ2NmKPy/mR8/SbKKrqmSbj7LJBfXE+nSq4f5TBXvnCzu0kcYebI9WdQ==} + + '@theguild/remark-mermaid@0.0.5': + resolution: {integrity: sha512-e+ZIyJkEv9jabI4m7q29wZtZv+2iwPGsXJ2d46Zi7e+QcFudiyuqhLhHG/3gX3ZEB+hxTch+fpItyMS8jwbIcw==} + peerDependencies: + react: ^18.2.0 + + '@theguild/remark-npm2yarn@0.2.1': + resolution: {integrity: sha512-jUTFWwDxtLEFtGZh/TW/w30ySaDJ8atKWH8dq2/IiQF61dPrGfETpl0WxD0VdBfuLOeU14/kop466oBSRO/5CA==} + + '@types/acorn@4.0.6': + resolution: {integrity: sha512-veQTnWP+1D/xbxVrPC3zHnCZRjSrKfhbMUlEA43iMZLu7EsnTtkJklIuwrCPbOi8YkvDQAiW05VQQFvvz9oieQ==} + + '@types/d3-scale-chromatic@3.0.3': + resolution: {integrity: sha512-laXM4+1o5ImZv3RpFAsTRn3TEkzqkytiOY0Dz0sq5cnd1dtNlk6sHLon4OvqaiJb28T0S/TdsBI3Sjsy+keJrw==} + + '@types/d3-scale@4.0.8': + resolution: {integrity: sha512-gkK1VVTr5iNiYJ7vWDI+yUFFlszhNMtVeneJ6lUTKPjprsvLLI9/tgEGiXJOnlINJA8FyA88gfnQsHbybVZrYQ==} + + '@types/d3-time@3.0.3': + resolution: {integrity: sha512-2p6olUZ4w3s+07q3Tm2dbiMZy5pCDfYwtLXXHUnVzXgQlZ/OyPtUz6OL382BkOuGlLXqfT+wqv8Fw2v8/0geBw==} + + '@types/debug@4.1.12': + resolution: {integrity: sha512-vIChWdVG3LG1SMxEvI/AK+FWJthlrqlTu7fbrlywTkkaONwk/UAGaULXRlf8vkzFBLVm0zkMdCquhL5aOjhXPQ==} + + '@types/estree-jsx@1.0.5': + resolution: {integrity: sha512-52CcUVNFyfb1A2ALocQw/Dd1BQFNmSdkuC3BkZ6iqhdMfQz7JWOFRuJFloOzjk+6WijU56m9oKXFAXc7o3Towg==} + + '@types/estree@1.0.5': + resolution: {integrity: sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==} + + '@types/hast@2.3.10': + resolution: {integrity: sha512-McWspRw8xx8J9HurkVBfYj0xKoE25tOFlHGdx4MJ5xORQrMGZNqJhVQWaIbm6Oyla5kYOXtDiopzKRJzEOkwJw==} + + '@types/hast@3.0.4': + resolution: {integrity: sha512-WPs+bbQw5aCj+x6laNGWLH3wviHtoCv/P3+otBhbOhJgG8qtpdAMlTCxLtsTWA7LH1Oh/bFCHsBn0TPS5m30EQ==} + + '@types/js-yaml@4.0.9': + resolution: {integrity: sha512-k4MGaQl5TGo/iipqb2UDG2UwjXziSWkh0uysQelTlJpX1qGlpUZYm8PnO4DxG1qBomtJUdYJ6qR6xdIah10JLg==} + + '@types/katex@0.16.7': + resolution: {integrity: sha512-HMwFiRujE5PjrgwHQ25+bsLJgowjGjm5Z8FVSf0N6PwgJrwxH0QxzHYDcKsTfV3wva0vzrpqMTJS2jXPr5BMEQ==} + + '@types/mdast@3.0.15': + resolution: {integrity: sha512-LnwD+mUEfxWMa1QpDraczIn6k0Ee3SMicuYSSzS6ZYl2gKS09EClnJYGd8Du6rfc5r/GZEk5o1mRb8TaTj03sQ==} + + '@types/mdast@4.0.4': + resolution: {integrity: sha512-kGaNbPh1k7AFzgpud/gMdvIm5xuECykRR+JnWKQno9TAXVa6WIVCGTPvYGekIDL4uwCZQSYbUxNBSb1aUo79oA==} + + '@types/mdx@2.0.13': + resolution: {integrity: sha512-+OWZQfAYyio6YkJb3HLxDrvnx6SWWDbC0zVPfBRzUk0/nqoDyf6dNxQi3eArPe8rJ473nobTMQ/8Zk+LxJ+Yuw==} + + '@types/ms@0.7.34': + resolution: {integrity: sha512-nG96G3Wp6acyAgJqGasjODb+acrI7KltPiRxzHPXnP3NgI28bpQDRv53olbqGXbfcgF5aiiHmO3xpwEpS5Ld9g==} + + '@types/node@20.14.2': + resolution: {integrity: sha512-xyu6WAMVwv6AKFLB+e/7ySZVr/0zLCzOa7rSpq6jNwpqOrUbcACDWC+53d4n2QHOnDou0fbIsg8wZu/sxrnI4Q==} + + '@types/prop-types@15.7.12': + resolution: {integrity: sha512-5zvhXYtRNRluoE/jAp4GVsSduVUzNWKkOZrCDBWYtE7biZywwdC2AcEzg+cSMLFRfVgeAFqpfNabiPjxFddV1Q==} + + '@types/react-dom@18.3.0': + resolution: {integrity: sha512-EhwApuTmMBmXuFOikhQLIBUn6uFg81SwLMOAUgodJF14SOBOCMdU04gDoYi0WOJJHD144TL32z4yDqCW3dnkQg==} + + '@types/react@18.3.3': + resolution: {integrity: sha512-hti/R0pS0q1/xx+TsI73XIqk26eBsISZ2R0wUijXIngRK9R/e7Xw/cXVxQK7R5JjW+SV4zGcn5hXjudkN/pLIw==} + + '@types/unist@2.0.10': + resolution: {integrity: sha512-IfYcSBWE3hLpBg8+X2SEa8LVkJdJEkT2Ese2aaLs3ptGdVtABxndrMaxuFlQ1qdFf9Q5rDvDpxI3WwgvKFAsQA==} + + '@types/unist@3.0.2': + resolution: {integrity: sha512-dqId9J8K/vGi5Zr7oo212BGii5m3q5Hxlkwy3WpYuKPklmBEvsbMYYyLxAQpSffdLl/gdW0XUpKWFvYmyoWCoQ==} + + '@ungap/structured-clone@1.2.0': + resolution: {integrity: sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ==} + + acorn-jsx@5.3.2: + resolution: {integrity: sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==} + peerDependencies: + acorn: ^6.0.0 || ^7.0.0 || ^8.0.0 + + acorn@8.11.3: + resolution: {integrity: sha512-Y9rRfJG5jcKOE0CLisYbojUjIrIEE7AGMzA/Sm4BslANhbS+cDMpgBdcPT91oJ7OuJ9hYJBx59RjbhxVnrF8Xg==} + engines: {node: '>=0.4.0'} + hasBin: true + + ansi-sequence-parser@1.1.1: + resolution: {integrity: sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==} + + ansi-styles@3.2.1: + resolution: {integrity: sha512-VT0ZI6kZRdTh8YyJw3SMbYm/u+NqfsAxEpWO0Pf9sq8/e94WxxOpPKx9FR1FlyCtOVDNOQ+8ntlqFxiRc+r5qA==} + engines: {node: '>=4'} + + arch@2.2.0: + resolution: {integrity: sha512-Of/R0wqp83cgHozfIYLbBMnej79U/SVGOOyuB3VVFv1NRM/PSFMK12x9KVtiYzJqmnU5WR2qp0Z5rHb7sWGnFQ==} + + arg@1.0.0: + resolution: {integrity: sha512-Wk7TEzl1KqvTGs/uyhmHO/3XLd3t1UeU4IstvPXVzGPM522cTjqjNZ99esCkcL52sjqjo8e8CTBcWhkxvGzoAw==} + + argparse@1.0.10: + resolution: {integrity: sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg==} + + argparse@2.0.1: + resolution: {integrity: sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==} + + astring@1.8.6: + resolution: {integrity: sha512-ISvCdHdlTDlH5IpxQJIex7BWBywFWgjJSVdwst+/iQCoEYnyOaQ95+X1JGshuBjGp6nxKUy1jMgE3zPqN7fQdg==} + hasBin: true + + bail@2.0.2: + resolution: {integrity: sha512-0xO6mYd7JB2YesxDKplafRpsiOzPt9V02ddPCLbY1xYGPOX24NTyN50qnUxgCPcSoYMhKpAuBTjQoRZCAkUDRw==} + + busboy@1.6.0: + resolution: {integrity: sha512-8SFQbg/0hQ9xy3UNTB0YEnsNBbWfhf7RtnzpL7TkBiTBRfrQ9Fxcnz7VJsleJpyp6rVLvXiuORqjlHi5q+PYuA==} + engines: {node: '>=10.16.0'} + + caniuse-lite@1.0.30001629: + resolution: {integrity: sha512-c3dl911slnQhmxUIT4HhYzT7wnBK/XYpGnYLOj4nJBaRiw52Ibe7YxlDaAeRECvA786zCuExhxIUJ2K7nHMrBw==} + + ccount@2.0.1: + resolution: {integrity: sha512-eyrF0jiFpY+3drT6383f1qhkbGsLSifNAjA61IUjZjmLCWjItY6LB9ft9YhoDgwfmclB2zhu51Lc7+95b8NRAg==} + + chalk@2.3.0: + resolution: {integrity: sha512-Az5zJR2CBujap2rqXGaJKaPHyJ0IrUimvYNX+ncCy8PJP4ltOGTrHUIo097ZaL2zMeKYpiCdqDvS6zdrTFok3Q==} + engines: {node: '>=4'} + + character-entities-html4@2.1.0: + resolution: {integrity: sha512-1v7fgQRj6hnSwFpq1Eu0ynr/CDEw0rXo2B61qXrLNdHZmPKgb7fqS1a2JwF0rISo9q77jDI8VMEHoApn8qDoZA==} + + character-entities-legacy@3.0.0: + resolution: {integrity: sha512-RpPp0asT/6ufRm//AJVwpViZbGM/MkjQFxJccQRHmISF/22NBtsHqAWmL+/pmkPWoIUJdWyeVleTl1wydHATVQ==} + + character-entities@2.0.2: + resolution: {integrity: sha512-shx7oQ0Awen/BRIdkjkvz54PnEEI/EjwXDSIZp86/KKdbafHh1Df/RYGBhn4hbe2+uKC9FnT5UCEdyPz3ai9hQ==} + + character-reference-invalid@2.0.1: + resolution: {integrity: sha512-iBZ4F4wRbyORVsu0jPV7gXkOsGYjGHPmAyv+HiHG8gi5PtC9KI2j1+v8/tlibRvjoWX027ypmG/n0HtO5t7unw==} + + client-only@0.0.1: + resolution: {integrity: sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA==} + + clipboardy@1.2.2: + resolution: {integrity: sha512-16KrBOV7bHmHdxcQiCvfUFYVFyEah4FI8vYT1Fr7CGSA4G+xBWMEfUEQJS1hxeHGtI9ju1Bzs9uXSbj5HZKArw==} + engines: {node: '>=4'} + + clsx@2.1.1: + resolution: {integrity: sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==} + engines: {node: '>=6'} + + color-convert@1.9.3: + resolution: {integrity: sha512-QfAUtd+vFdAtFQcC8CCyYt1fYWxSqAiK2cSD6zDB8N3cpsEBAvRxp9zOGg6G/SHHJYAT88/az/IuDGALsNVbGg==} + + color-name@1.1.3: + resolution: {integrity: sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw==} + + comma-separated-tokens@2.0.3: + resolution: {integrity: sha512-Fu4hJdvzeylCfQPp9SGWidpzrMs7tTrlu6Vb8XGaRGck8QSNZJJp538Wrb60Lax4fPwR64ViY468OIUTbRlGZg==} + + commander@7.2.0: + resolution: {integrity: sha512-QrWXB+ZQSVPmIWIhtEO9H+gwHaMGYiF5ChvoJ+K9ZGHG/sVsa6yiesAD1GC/x46sET00Xlwo1u49RVVVzvcSkw==} + engines: {node: '>= 10'} + + commander@8.3.0: + resolution: {integrity: sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==} + engines: {node: '>= 12'} + + compute-scroll-into-view@3.1.0: + resolution: {integrity: sha512-rj8l8pD4bJ1nx+dAkMhV1xB5RuZEyVysfxJqB1pRchh1KVvwOv9b7CGB8ZfjTImVv2oF+sYMUkMZq6Na5Ftmbg==} + + cose-base@1.0.3: + resolution: {integrity: sha512-s9whTXInMSgAp/NVXVNuVxVKzGH2qck3aQlVHxDCdAEPgtMKwc4Wq6/QKhgdEdgbLSi9rBTAcPoRa6JpiG4ksg==} + + cross-spawn@5.1.0: + resolution: {integrity: sha512-pTgQJ5KC0d2hcY8eyL1IzlBPYjTkyH72XRZPnLyKus2mBfNjQs3klqbJU2VILqZryAZUt9JOb3h/mWMy23/f5A==} + + csstype@3.1.3: + resolution: {integrity: sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==} + + cytoscape-cose-bilkent@4.1.0: + resolution: {integrity: sha512-wgQlVIUJF13Quxiv5e1gstZ08rnZj2XaLHGoFMYXz7SkNfCDOOteKBE6SYRfA9WxxI/iBc3ajfDoc6hb/MRAHQ==} + peerDependencies: + cytoscape: ^3.2.0 + + cytoscape@3.29.2: + resolution: {integrity: sha512-2G1ycU28Nh7OHT9rkXRLpCDP30MKH1dXJORZuBhtEhEW7pKwgPi77ImqlCWinouyE1PNepIOGZBOrE84DG7LyQ==} + engines: {node: '>=0.10'} + + d3-array@2.12.1: + resolution: {integrity: sha512-B0ErZK/66mHtEsR1TkPEEkwdy+WDesimkM5gpZr5Dsg54BiTA5RXtYW5qTLIAcekaS9xfZrzBLF/OAkB3Qn1YQ==} + + d3-array@3.2.4: + resolution: {integrity: sha512-tdQAmyA18i4J7wprpYq8ClcxZy3SC31QMeByyCFyRt7BVHdREQZ5lpzoe5mFEYZUWe+oq8HBvk9JjpibyEV4Jg==} + engines: {node: '>=12'} + + d3-axis@3.0.0: + resolution: {integrity: sha512-IH5tgjV4jE/GhHkRV0HiVYPDtvfjHQlQfJHs0usq7M30XcSBvOotpmH1IgkcXsO/5gEQZD43B//fc7SRT5S+xw==} + engines: {node: '>=12'} + + d3-brush@3.0.0: + resolution: {integrity: sha512-ALnjWlVYkXsVIGlOsuWH1+3udkYFI48Ljihfnh8FZPF2QS9o+PzGLBslO0PjzVoHLZ2KCVgAM8NVkXPJB2aNnQ==} + engines: {node: '>=12'} + + d3-chord@3.0.1: + resolution: {integrity: sha512-VE5S6TNa+j8msksl7HwjxMHDM2yNK3XCkusIlpX5kwauBfXuyLAtNg9jCp/iHH61tgI4sb6R/EIMWCqEIdjT/g==} + engines: {node: '>=12'} + + d3-color@3.1.0: + resolution: {integrity: sha512-zg/chbXyeBtMQ1LbD/WSoW2DpC3I0mpmPdW+ynRTj/x2DAWYrIY7qeZIHidozwV24m4iavr15lNwIwLxRmOxhA==} + engines: {node: '>=12'} + + d3-contour@4.0.2: + resolution: {integrity: sha512-4EzFTRIikzs47RGmdxbeUvLWtGedDUNkTcmzoeyg4sP/dvCexO47AaQL7VKy/gul85TOxw+IBgA8US2xwbToNA==} + engines: {node: '>=12'} + + d3-delaunay@6.0.4: + resolution: {integrity: sha512-mdjtIZ1XLAM8bm/hx3WwjfHt6Sggek7qH043O8KEjDXN40xi3vx/6pYSVTwLjEgiXQTbvaouWKynLBiUZ6SK6A==} + engines: {node: '>=12'} + + d3-dispatch@3.0.1: + resolution: {integrity: sha512-rzUyPU/S7rwUflMyLc1ETDeBj0NRuHKKAcvukozwhshr6g6c5d8zh4c2gQjY2bZ0dXeGLWc1PF174P2tVvKhfg==} + engines: {node: '>=12'} + + d3-drag@3.0.0: + resolution: {integrity: sha512-pWbUJLdETVA8lQNJecMxoXfH6x+mO2UQo8rSmZ+QqxcbyA3hfeprFgIT//HW2nlHChWeIIMwS2Fq+gEARkhTkg==} + engines: {node: '>=12'} + + d3-dsv@3.0.1: + resolution: {integrity: sha512-UG6OvdI5afDIFP9w4G0mNq50dSOsXHJaRE8arAS5o9ApWnIElp8GZw1Dun8vP8OyHOZ/QJUKUJwxiiCCnUwm+Q==} + engines: {node: '>=12'} + hasBin: true + + d3-ease@3.0.1: + resolution: {integrity: sha512-wR/XK3D3XcLIZwpbvQwQ5fK+8Ykds1ip7A2Txe0yxncXSdq1L9skcG7blcedkOX+ZcgxGAmLX1FrRGbADwzi0w==} + engines: {node: '>=12'} + + d3-fetch@3.0.1: + resolution: {integrity: sha512-kpkQIM20n3oLVBKGg6oHrUchHM3xODkTzjMoj7aWQFq5QEM+R6E4WkzT5+tojDY7yjez8KgCBRoj4aEr99Fdqw==} + engines: {node: '>=12'} + + d3-force@3.0.0: + resolution: {integrity: sha512-zxV/SsA+U4yte8051P4ECydjD/S+qeYtnaIyAs9tgHCqfguma/aAQDjo85A9Z6EKhBirHRJHXIgJUlffT4wdLg==} + engines: {node: '>=12'} + + d3-format@3.1.0: + resolution: {integrity: sha512-YyUI6AEuY/Wpt8KWLgZHsIU86atmikuoOmCfommt0LYHiQSPjvX2AcFc38PX0CBpr2RCyZhjex+NS/LPOv6YqA==} + engines: {node: '>=12'} + + d3-geo@3.1.1: + resolution: {integrity: sha512-637ln3gXKXOwhalDzinUgY83KzNWZRKbYubaG+fGVuc/dxO64RRljtCTnf5ecMyE1RIdtqpkVcq0IbtU2S8j2Q==} + engines: {node: '>=12'} + + d3-hierarchy@3.1.2: + resolution: {integrity: sha512-FX/9frcub54beBdugHjDCdikxThEqjnR93Qt7PvQTOHxyiNCAlvMrHhclk3cD5VeAaq9fxmfRp+CnWw9rEMBuA==} + engines: {node: '>=12'} + + d3-interpolate@3.0.1: + resolution: {integrity: sha512-3bYs1rOD33uo8aqJfKP3JWPAibgw8Zm2+L9vBKEHJ2Rg+viTR7o5Mmv5mZcieN+FRYaAOWX5SJATX6k1PWz72g==} + engines: {node: '>=12'} + + d3-path@1.0.9: + resolution: {integrity: sha512-VLaYcn81dtHVTjEHd8B+pbe9yHWpXKZUC87PzoFmsFrJqgFwDe/qxfp5MlfsfM1V5E/iVt0MmEbWQ7FVIXh/bg==} + + d3-path@3.1.0: + resolution: {integrity: sha512-p3KP5HCf/bvjBSSKuXid6Zqijx7wIfNW+J/maPs+iwR35at5JCbLUT0LzF1cnjbCHWhqzQTIN2Jpe8pRebIEFQ==} + engines: {node: '>=12'} + + d3-polygon@3.0.1: + resolution: {integrity: sha512-3vbA7vXYwfe1SYhED++fPUQlWSYTTGmFmQiany/gdbiWgU/iEyQzyymwL9SkJjFFuCS4902BSzewVGsHHmHtXg==} + engines: {node: '>=12'} + + d3-quadtree@3.0.1: + resolution: {integrity: sha512-04xDrxQTDTCFwP5H6hRhsRcb9xxv2RzkcsygFzmkSIOJy3PeRJP7sNk3VRIbKXcog561P9oU0/rVH6vDROAgUw==} + engines: {node: '>=12'} + + d3-random@3.0.1: + resolution: {integrity: sha512-FXMe9GfxTxqd5D6jFsQ+DJ8BJS4E/fT5mqqdjovykEB2oFbTMDVdg1MGFxfQW+FBOGoB++k8swBrgwSHT1cUXQ==} + engines: {node: '>=12'} + + d3-sankey@0.12.3: + resolution: {integrity: sha512-nQhsBRmM19Ax5xEIPLMY9ZmJ/cDvd1BG3UVvt5h3WRxKg5zGRbvnteTyWAbzeSvlh3tW7ZEmq4VwR5mB3tutmQ==} + + d3-scale-chromatic@3.1.0: + resolution: {integrity: sha512-A3s5PWiZ9YCXFye1o246KoscMWqf8BsD9eRiJ3He7C9OBaxKhAd5TFCdEx/7VbKtxxTsu//1mMJFrEt572cEyQ==} + engines: {node: '>=12'} + + d3-scale@4.0.2: + resolution: {integrity: sha512-GZW464g1SH7ag3Y7hXjf8RoUuAFIqklOAq3MRl4OaWabTFJY9PN/E1YklhXLh+OQ3fM9yS2nOkCoS+WLZ6kvxQ==} + engines: {node: '>=12'} + + d3-selection@3.0.0: + resolution: {integrity: sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==} + engines: {node: '>=12'} + + d3-shape@1.3.7: + resolution: {integrity: sha512-EUkvKjqPFUAZyOlhY5gzCxCeI0Aep04LwIRpsZ/mLFelJiUfnK56jo5JMDSE7yyP2kLSb6LtF+S5chMk7uqPqw==} + + d3-shape@3.2.0: + resolution: {integrity: sha512-SaLBuwGm3MOViRq2ABk3eLoxwZELpH6zhl3FbAoJ7Vm1gofKx6El1Ib5z23NUEhF9AsGl7y+dzLe5Cw2AArGTA==} + engines: {node: '>=12'} + + d3-time-format@4.1.0: + resolution: {integrity: sha512-dJxPBlzC7NugB2PDLwo9Q8JiTR3M3e4/XANkreKSUxF8vvXKqm1Yfq4Q5dl8budlunRVlUUaDUgFt7eA8D6NLg==} + engines: {node: '>=12'} + + d3-time@3.1.0: + resolution: {integrity: sha512-VqKjzBLejbSMT4IgbmVgDjpkYrNWUYJnbCGo874u7MMKIWsILRX+OpX/gTk8MqjpT1A/c6HY2dCA77ZN0lkQ2Q==} + engines: {node: '>=12'} + + d3-timer@3.0.1: + resolution: {integrity: sha512-ndfJ/JxxMd3nw31uyKoY2naivF+r29V+Lc0svZxe1JvvIRmi8hUsrMvdOwgS1o6uBHmiz91geQ0ylPP0aj1VUA==} + engines: {node: '>=12'} + + d3-transition@3.0.1: + resolution: {integrity: sha512-ApKvfjsSR6tg06xrL434C0WydLr7JewBB3V+/39RMHsaXTOG0zmt/OAXeng5M5LBm0ojmxJrpomQVZ1aPvBL4w==} + engines: {node: '>=12'} + peerDependencies: + d3-selection: 2 - 3 + + d3-zoom@3.0.0: + resolution: {integrity: sha512-b8AmV3kfQaqWAuacbPuNbL6vahnOJflOhexLzMMNLga62+/nh0JzvJ0aO/5a5MVgUFGS7Hu1P9P03o3fJkDCyw==} + engines: {node: '>=12'} + + d3@7.9.0: + resolution: {integrity: sha512-e1U46jVP+w7Iut8Jt8ri1YsPOvFpg46k+K8TpCb0P+zjCkjkPnV7WzfDJzMHy1LnA+wj5pLT1wjO901gLXeEhA==} + engines: {node: '>=12'} + + dagre-d3-es@7.0.10: + resolution: {integrity: sha512-qTCQmEhcynucuaZgY5/+ti3X/rnszKZhEQH/ZdWdtP1tA/y3VoHJzcVrO9pjjJCNpigfscAtoUB5ONcd2wNn0A==} + + dayjs@1.11.11: + resolution: {integrity: sha512-okzr3f11N6WuqYtZSvm+F776mB41wRZMhKP+hc34YdW+KmtYYK9iqvHSwo2k9FEH3fhGXvOPV6yz2IcSrfRUDg==} + + debug@4.3.5: + resolution: {integrity: sha512-pt0bNEmneDIvdL1Xsd9oDQ/wrQRkXDT4AUWlNZNPKvW5x/jyO9VFXkJUP07vQ2upmw5PlaITaPKc31jK13V+jg==} + engines: {node: '>=6.0'} + peerDependencies: + supports-color: '*' + peerDependenciesMeta: + supports-color: + optional: true + + decode-named-character-reference@1.0.2: + resolution: {integrity: sha512-O8x12RzrUF8xyVcY0KJowWsmaJxQbmy0/EtnNtHRpsOcT7dFk5W598coHqBVpmWo1oQQfsCqfCmkZN5DJrZVdg==} + + delaunator@5.0.1: + resolution: {integrity: sha512-8nvh+XBe96aCESrGOqMp/84b13H9cdKbG5P2ejQCh4d4sK9RL4371qou9drQjMhvnPmhWl5hnmqbEE0fXr9Xnw==} + + dequal@2.0.3: + resolution: {integrity: sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==} + engines: {node: '>=6'} + + devlop@1.1.0: + resolution: {integrity: sha512-RWmIqhcFf1lRYBvNmr7qTNuyCt/7/ns2jbpp1+PalgE/rDQcBT0fioSMUpJ93irlUhC5hrg4cYqe6U+0ImW0rA==} + + diff@5.2.0: + resolution: {integrity: sha512-uIFDxqpRZGZ6ThOk84hEfqWoHx2devRFvpTZcTHur85vImfaxUbTW9Ryh4CpCuDnToOP1CEtXKIgytHBPVff5A==} + engines: {node: '>=0.3.1'} + + dompurify@3.1.5: + resolution: {integrity: sha512-lwG+n5h8QNpxtyrJW/gJWckL+1/DQiYMX8f7t8Z2AZTPw1esVrqjI63i7Zc2Gz0aKzLVMYC1V1PL/ky+aY/NgA==} + + elkjs@0.9.3: + resolution: {integrity: sha512-f/ZeWvW/BCXbhGEf1Ujp29EASo/lk1FDnETgNKwJrsVvGZhUWCZyg3xLJjAsxfOmt8KjswHmI5EwCQcPMpOYhQ==} + + entities@4.5.0: + resolution: {integrity: sha512-V0hjH4dGPh9Ao5p0MoRY6BVqtwCjhz6vI5LT8AJ55H+4g9/4vbHx1I54fS0XuclLhDHArPQCiMjDxjaL8fPxhw==} + engines: {node: '>=0.12'} + + escape-string-regexp@1.0.5: + resolution: {integrity: sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg==} + engines: {node: '>=0.8.0'} + + escape-string-regexp@5.0.0: + resolution: {integrity: sha512-/veY75JbMK4j1yjvuUxuVsiS/hr/4iHs9FTT6cgTexxdE0Ly/glccBAkloH/DofkjRbZU3bnoj38mOmhkZ0lHw==} + engines: {node: '>=12'} + + esprima@4.0.1: + resolution: {integrity: sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==} + engines: {node: '>=4'} + hasBin: true + + estree-util-attach-comments@2.1.1: + resolution: {integrity: sha512-+5Ba/xGGS6mnwFbXIuQiDPTbuTxuMCooq3arVv7gPZtYpjp+VXH/NkHAP35OOefPhNG/UGqU3vt/LTABwcHX0w==} + + estree-util-build-jsx@2.2.2: + resolution: {integrity: sha512-m56vOXcOBuaF+Igpb9OPAy7f9w9OIkb5yhjsZuaPm7HoGi4oTOQi0h2+yZ+AtKklYFZ+rPC4n0wYCJCEU1ONqg==} + + estree-util-is-identifier-name@2.1.0: + resolution: {integrity: sha512-bEN9VHRyXAUOjkKVQVvArFym08BTWB0aJPppZZr0UNyAqWsLaVfAqP7hbaTJjzHifmB5ebnR8Wm7r7yGN/HonQ==} + + estree-util-to-js@1.2.0: + resolution: {integrity: sha512-IzU74r1PK5IMMGZXUVZbmiu4A1uhiPgW5hm1GjcOfr4ZzHaMPpLNJjR7HjXiIOzi25nZDrgFTobHTkV5Q6ITjA==} + + estree-util-value-to-estree@1.3.0: + resolution: {integrity: sha512-Y+ughcF9jSUJvncXwqRageavjrNPAI+1M/L3BI3PyLp1nmgYTGUXU6t5z1Y7OWuThoDdhPME07bQU+d5LxdJqw==} + engines: {node: '>=12.0.0'} + + estree-util-visit@1.2.1: + resolution: {integrity: sha512-xbgqcrkIVbIG+lI/gzbvd9SGTJL4zqJKBFttUl5pP27KhAjtMKbX/mQXJ7qgyXpMgVy/zvpm0xoQQaGL8OloOw==} + + estree-walker@3.0.3: + resolution: {integrity: sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==} + + execa@0.8.0: + resolution: {integrity: sha512-zDWS+Rb1E8BlqqhALSt9kUhss8Qq4nN3iof3gsOdyINksElaPyNBtKUMTR62qhvgVWR0CqCX7sdnKe4MnUbFEA==} + engines: {node: '>=4'} + + extend-shallow@2.0.1: + resolution: {integrity: sha512-zCnTtlxNoAiDc3gqY2aYAWFx7XWWiasuF2K8Me5WbN8otHKTUKBwjPtNpRs/rbUZm7KxWAaNj7P1a/p52GbVug==} + engines: {node: '>=0.10.0'} + + extend@3.0.2: + resolution: {integrity: sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g==} + + flexsearch@0.7.43: + resolution: {integrity: sha512-c5o/+Um8aqCSOXGcZoqZOm+NqtVwNsvVpWv6lfmSclU954O3wvQKxxK8zj74fPaSJbXpSLTs4PRhh+wnoCXnKg==} + + focus-visible@5.2.0: + resolution: {integrity: sha512-Rwix9pBtC1Nuy5wysTmKy+UjbDJpIfg8eHjw0rjZ1mX4GNLz1Bmd16uDpI3Gk1i70Fgcs8Csg2lPm8HULFg9DQ==} + + get-stream@3.0.0: + resolution: {integrity: sha512-GlhdIUuVakc8SJ6kK0zAFbiGzRFzNnY4jUuEbV9UROo4Y+0Ny4fjvcZFVTeDA4odpFyOQzaw6hXukJSq/f28sQ==} + engines: {node: '>=4'} + + git-up@7.0.0: + resolution: {integrity: sha512-ONdIrbBCFusq1Oy0sC71F5azx8bVkvtZtMJAsv+a6lz5YAmbNnLD6HAB4gptHZVLPR8S2/kVN6Gab7lryq5+lQ==} + + git-url-parse@13.1.1: + resolution: {integrity: sha512-PCFJyeSSdtnbfhSNRw9Wk96dDCNx+sogTe4YNXeXSJxt7xz5hvXekuRn9JX7m+Mf4OscCu8h+mtAl3+h5Fo8lQ==} + + github-slugger@2.0.0: + resolution: {integrity: sha512-IaOQ9puYtjrkq7Y0Ygl9KDZnrf/aiUJYUpVf89y8kyaxbRG7Y1SrX/jaumrv81vc61+kiMempujsM3Yw7w5qcw==} + + graceful-fs@4.2.11: + resolution: {integrity: sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==} + + gray-matter@4.0.3: + resolution: {integrity: sha512-5v6yZd4JK3eMI3FqqCouswVqwugaA9r4dNZB1wwcmrD02QkV5H0y7XBQW8QwQqEaZY1pM9aqORSORhJRdNK44Q==} + engines: {node: '>=6.0'} + + has-flag@2.0.0: + resolution: {integrity: sha512-P+1n3MnwjR/Epg9BBo1KT8qbye2g2Ou4sFumihwt6I4tsUX7jnLcX4BTOSKg/B1ZrIYMN9FcEnG4x5a7NB8Eng==} + engines: {node: '>=0.10.0'} + + hash-obj@4.0.0: + resolution: {integrity: sha512-FwO1BUVWkyHasWDW4S8o0ssQXjvyghLV2rfVhnN36b2bbcj45eGiuzdn9XOvOpjV3TKQD7Gm2BWNXdE9V4KKYg==} + engines: {node: '>=12'} + + hast-util-from-dom@5.0.0: + resolution: {integrity: sha512-d6235voAp/XR3Hh5uy7aGLbM3S4KamdW0WEgOaU1YoewnuYw4HXb5eRtv9g65m/RFGEfUY1Mw4UqCc5Y8L4Stg==} + + hast-util-from-html-isomorphic@2.0.0: + resolution: {integrity: sha512-zJfpXq44yff2hmE0XmwEOzdWin5xwH+QIhMLOScpX91e/NSGPsAzNCvLQDIEPyO2TXi+lBmU6hjLIhV8MwP2kw==} + + hast-util-from-html@2.0.1: + resolution: {integrity: sha512-RXQBLMl9kjKVNkJTIO6bZyb2n+cUH8LFaSSzo82jiLT6Tfc+Pt7VQCS+/h3YwG4jaNE2TA2sdJisGWR+aJrp0g==} + + hast-util-from-parse5@8.0.1: + resolution: {integrity: sha512-Er/Iixbc7IEa7r/XLtuG52zoqn/b3Xng/w6aZQ0xGVxzhw5xUFxcRqdPzP6yFi/4HBYRaifaI5fQ1RH8n0ZeOQ==} + + hast-util-is-element@3.0.0: + resolution: {integrity: sha512-Val9mnv2IWpLbNPqc/pUem+a7Ipj2aHacCwgNfTiK0vJKl0LF+4Ba4+v1oPHFpf3bLYmreq0/l3Gud9S5OH42g==} + + hast-util-parse-selector@4.0.0: + resolution: {integrity: sha512-wkQCkSYoOGCRKERFWcxMVMOcYE2K1AaNLU8DXS9arxnLOUEWbOXKXiJUNzEpqZ3JOKpnha3jkFrumEjVliDe7A==} + + hast-util-raw@9.0.3: + resolution: {integrity: sha512-ICWvVOF2fq4+7CMmtCPD5CM4QKjPbHpPotE6+8tDooV0ZuyJVUzHsrNX+O5NaRbieTf0F7FfeBOMAwi6Td0+yQ==} + + hast-util-to-estree@2.3.3: + resolution: {integrity: sha512-ihhPIUPxN0v0w6M5+IiAZZrn0LH2uZomeWwhn7uP7avZC6TE7lIiEh2yBMPr5+zi1aUCXq6VoYRgs2Bw9xmycQ==} + + hast-util-to-parse5@8.0.0: + resolution: {integrity: sha512-3KKrV5ZVI8if87DVSi1vDeByYrkGzg4mEfeu4alwgmmIeARiBLKCZS2uw5Gb6nU9x9Yufyj3iudm6i7nl52PFw==} + + hast-util-to-text@4.0.2: + resolution: {integrity: sha512-KK6y/BN8lbaq654j7JgBydev7wuNMcID54lkRav1P0CaE1e47P72AWWPiGKXTJU271ooYzcvTAn/Zt0REnvc7A==} + + hast-util-whitespace@2.0.1: + resolution: {integrity: sha512-nAxA0v8+vXSBDt3AnRUNjyRIQ0rD+ntpbAp4LnPkumc5M9yUbSMa4XDU9Q6etY4f1Wp4bNgvc1yjiZtsTTrSng==} + + hastscript@8.0.0: + resolution: {integrity: sha512-dMOtzCEd3ABUeSIISmrETiKuyydk1w0pa+gE/uormcTpSYuaNJPbX1NU3JLyscSLjwAQM8bWMhhIlnCqnRvDTw==} + + html-void-elements@3.0.0: + resolution: {integrity: sha512-bEqo66MRXsUGxWHV5IP0PUiAWwoEjba4VCzg0LjFJBpchPaTfyfCKTG6bc5F8ucKec3q5y6qOdGyYTSBEvhCrg==} + + iconv-lite@0.6.3: + resolution: {integrity: sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==} + engines: {node: '>=0.10.0'} + + inline-style-parser@0.1.1: + resolution: {integrity: sha512-7NXolsK4CAS5+xvdj5OMMbI962hU/wvwoxk+LWR9Ek9bVtyuuYScDN6eS0rUm6TxApFpw7CX1o4uJzcd4AyD3Q==} + + internmap@1.0.1: + resolution: {integrity: sha512-lDB5YccMydFBtasVtxnZ3MRBHuaoE8GKsppq+EchKL2U4nK/DmEpPHNH8MZe5HkMtpSiTSOZwfN0tzYjO/lJEw==} + + internmap@2.0.3: + resolution: {integrity: sha512-5Hh7Y1wQbvY5ooGgPbDaL5iYLAPzMTUrjMulskHLH6wnv/A+1q5rgEaiuqEjB+oxGXIVZs1FF+R/KPN3ZSQYYg==} + engines: {node: '>=12'} + + intersection-observer@0.12.2: + resolution: {integrity: sha512-7m1vEcPCxXYI8HqnL8CKI6siDyD+eIWSwgB3DZA+ZTogxk9I4CDnj4wilt9x/+/QbHI4YG5YZNmC6458/e9Ktg==} + + is-alphabetical@2.0.1: + resolution: {integrity: sha512-FWyyY60MeTNyeSRpkM2Iry0G9hpr7/9kD40mD/cGQEuilcZYS4okz8SN2Q6rLCJ8gbCt6fN+rC+6tMGS99LaxQ==} + + is-alphanumerical@2.0.1: + resolution: {integrity: sha512-hmbYhX/9MUMF5uh7tOXyK/n0ZvWpad5caBA17GsC6vyuCqaWliRG5K1qS9inmUhEMaOBIW7/whAnSwveW/LtZw==} + + is-buffer@2.0.5: + resolution: {integrity: sha512-i2R6zNFDwgEHJyQUtJEk0XFi1i0dPFn/oqjK3/vPCcDeJvW5NQ83V8QbicfF1SupOaB0h8ntgBC2YiE7dfyctQ==} + engines: {node: '>=4'} + + is-decimal@2.0.1: + resolution: {integrity: sha512-AAB9hiomQs5DXWcRB1rqsxGUstbRroFOPPVAomNk/3XHR5JyEZChOyTWe2oayKnsSsr/kcGqF+z6yuH6HHpN0A==} + + is-extendable@0.1.1: + resolution: {integrity: sha512-5BMULNob1vgFX6EjQw5izWDxrecWK9AM72rugNr0TFldMOi0fj6Jk+zeKIt0xGj4cEfQIJth4w3OKWOJ4f+AFw==} + engines: {node: '>=0.10.0'} + + is-hexadecimal@2.0.1: + resolution: {integrity: sha512-DgZQp241c8oO6cA1SbTEWiXeoxV42vlcJxgH+B3hi1AiqqKruZR3ZGF8In3fj4+/y/7rHvlOZLZtgJ/4ttYGZg==} + + is-obj@3.0.0: + resolution: {integrity: sha512-IlsXEHOjtKhpN8r/tRFj2nDyTmHvcfNeu/nrRIcXE17ROeatXchkojffa1SpdqW4cr/Fj6QkEf/Gn4zf6KKvEQ==} + engines: {node: '>=12'} + + is-plain-obj@3.0.0: + resolution: {integrity: sha512-gwsOE28k+23GP1B6vFl1oVh/WOzmawBrKwo5Ev6wMKzPkaXaCDIQKzLnvsA42DRlbVTWorkgTKIviAKCWkfUwA==} + engines: {node: '>=10'} + + is-plain-obj@4.1.0: + resolution: {integrity: sha512-+Pgi+vMuUNkJyExiMBt5IlFoMyKnr5zhJ4Uspz58WOhBF5QoIZkFyNHIbBAtHwzVAgk5RtndVNsDRN61/mmDqg==} + engines: {node: '>=12'} + + is-reference@3.0.2: + resolution: {integrity: sha512-v3rht/LgVcsdZa3O2Nqs+NMowLOxeOm7Ay9+/ARQ2F+qEoANRcqrjAZKGN0v8ymUetZGgkp26LTnGT7H0Qo9Pg==} + + is-ssh@1.4.0: + resolution: {integrity: sha512-x7+VxdxOdlV3CYpjvRLBv5Lo9OJerlYanjwFrPR9fuGPjCiNiCzFgAWpiLAohSbsnH4ZAys3SBh+hq5rJosxUQ==} + + is-stream@1.1.0: + resolution: {integrity: sha512-uQPm8kcs47jx38atAcWTVxyltQYoPT68y9aWYdV6yWXSyW8mzSat0TL6CiWdZeCdF3KrAvpVtnHbTv4RN+rqdQ==} + engines: {node: '>=0.10.0'} + + isexe@2.0.0: + resolution: {integrity: sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==} + + js-tokens@4.0.0: + resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} + + js-yaml@3.14.1: + resolution: {integrity: sha512-okMH7OXXJ7YrN9Ok3/SXrnu4iX9yOk+25nqX4imS2npuvTYDmo/QEZoqwZkYaIDk3jVvBOTOIEgEhaLOynBS9g==} + hasBin: true + + js-yaml@4.1.0: + resolution: {integrity: sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==} + hasBin: true + + jsonc-parser@3.2.1: + resolution: {integrity: sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==} + + katex@0.16.10: + resolution: {integrity: sha512-ZiqaC04tp2O5utMsl2TEZTXxa6WSC4yo0fv5ML++D3QZv/vx2Mct0mTlRx3O+uUkjfuAgOkzsCmq5MiUEsDDdA==} + hasBin: true + + khroma@2.1.0: + resolution: {integrity: sha512-Ls993zuzfayK269Svk9hzpeGUKob/sIgZzyHYdjQoAdQetRKpOLj+k/QQQ/6Qi0Yz65mlROrfd+Ev+1+7dz9Kw==} + + kind-of@6.0.3: + resolution: {integrity: sha512-dcS1ul+9tmeD95T+x28/ehLgd9mENa3LsvDTtzm3vyBEO7RPptvAD+t44WVXaUjTBRcrpFeFlC8WCruUR456hw==} + engines: {node: '>=0.10.0'} + + kleur@4.1.5: + resolution: {integrity: sha512-o+NO+8WrRiQEE4/7nwRJhN1HWpVmJm511pBHUxPLtp0BUISzlBplORYSmTclCnJvQq2tKu/sgl3xVpkc7ZWuQQ==} + engines: {node: '>=6'} + + layout-base@1.0.2: + resolution: {integrity: sha512-8h2oVEZNktL4BH2JCOI90iD1yXwL6iNW7KcCKT2QZgQJR2vbqDsldCTPRU9NifTCqHZci57XvQQ15YTu+sTYPg==} + + lodash-es@4.17.21: + resolution: {integrity: sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==} + + lodash.get@4.4.2: + resolution: {integrity: sha512-z+Uw/vLuy6gQe8cfaFWD7p0wVv8fJl3mbzXh33RS+0oW2wvUqiRXiQ69gLWSLpgB5/6sU+r6BlQR0MBILadqTQ==} + + longest-streak@3.1.0: + resolution: {integrity: sha512-9Ri+o0JYgehTaVBBDoMqIl8GXtbWg711O3srftcHhZ0dqnETqLaoIK0x17fUw9rFSlK/0NlsKe0Ahhyl5pXE2g==} + + loose-envify@1.4.0: + resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} + hasBin: true + + lru-cache@4.1.5: + resolution: {integrity: sha512-sWZlbEP2OsHNkXrMl5GYk/jKk70MBng6UU4YI/qGDYbgf6YbP4EvmqISbXCoJiRKs+1bSpFHVgQxvJ17F2li5g==} + + markdown-extensions@1.1.1: + resolution: {integrity: sha512-WWC0ZuMzCyDHYCasEGs4IPvLyTGftYwh6wIEOULOF0HXcqZlhwRzrK0w2VUlxWA98xnvb/jszw4ZSkJ6ADpM6Q==} + engines: {node: '>=0.10.0'} + + markdown-table@3.0.3: + resolution: {integrity: sha512-Z1NL3Tb1M9wH4XESsCDEksWoKTdlUafKc4pt0GRwjUyXaCFZ+dc3g2erqB6zm3szA2IUSi7VnPI+o/9jnxh9hw==} + + match-sorter@6.3.4: + resolution: {integrity: sha512-jfZW7cWS5y/1xswZo8VBOdudUiSd9nifYRWphc9M5D/ee4w4AoXLgBEdRbgVaxbMuagBPeUC5y2Hi8DO6o9aDg==} + + mdast-util-definitions@5.1.2: + resolution: {integrity: sha512-8SVPMuHqlPME/z3gqVwWY4zVXn8lqKv/pAhC57FuJ40ImXyBpmO5ukh98zB2v7Blql2FiHjHv9LVztSIqjY+MA==} + + mdast-util-find-and-replace@2.2.2: + resolution: {integrity: sha512-MTtdFRz/eMDHXzeK6W3dO7mXUlF82Gom4y0oOgvHhh/HXZAGvIQDUvQ0SuUx+j2tv44b8xTHOm8K/9OoRFnXKw==} + + mdast-util-from-markdown@1.3.1: + resolution: {integrity: sha512-4xTO/M8c82qBcnQc1tgpNtubGUW/Y1tBQ1B0i5CtSoelOLKFYlElIr3bvgREYYO5iRqbMY1YuqZng0GVOI8Qww==} + + mdast-util-gfm-autolink-literal@1.0.3: + resolution: {integrity: sha512-My8KJ57FYEy2W2LyNom4n3E7hKTuQk/0SES0u16tjA9Z3oFkF4RrC/hPAPgjlSpezsOvI8ObcXcElo92wn5IGA==} + + mdast-util-gfm-footnote@1.0.2: + resolution: {integrity: sha512-56D19KOGbE00uKVj3sgIykpwKL179QsVFwx/DCW0u/0+URsryacI4MAdNJl0dh+u2PSsD9FtxPFbHCzJ78qJFQ==} + + mdast-util-gfm-strikethrough@1.0.3: + resolution: {integrity: sha512-DAPhYzTYrRcXdMjUtUjKvW9z/FNAMTdU0ORyMcbmkwYNbKocDpdk+PX1L1dQgOID/+vVs1uBQ7ElrBQfZ0cuiQ==} + + mdast-util-gfm-table@1.0.7: + resolution: {integrity: sha512-jjcpmNnQvrmN5Vx7y7lEc2iIOEytYv7rTvu+MeyAsSHTASGCCRA79Igg2uKssgOs1i1po8s3plW0sTu1wkkLGg==} + + mdast-util-gfm-task-list-item@1.0.2: + resolution: {integrity: sha512-PFTA1gzfp1B1UaiJVyhJZA1rm0+Tzn690frc/L8vNX1Jop4STZgOE6bxUhnzdVSB+vm2GU1tIsuQcA9bxTQpMQ==} + + mdast-util-gfm@2.0.2: + resolution: {integrity: sha512-qvZ608nBppZ4icQlhQQIAdc6S3Ffj9RGmzwUKUWuEICFnd1LVkN3EktF7ZHAgfcEdvZB5owU9tQgt99e2TlLjg==} + + mdast-util-math@2.0.2: + resolution: {integrity: sha512-8gmkKVp9v6+Tgjtq6SYx9kGPpTf6FVYRa53/DLh479aldR9AyP48qeVOgNZ5X7QUK7nOy4yw7vg6mbiGcs9jWQ==} + + mdast-util-mdx-expression@1.3.2: + resolution: {integrity: sha512-xIPmR5ReJDu/DHH1OoIT1HkuybIfRGYRywC+gJtI7qHjCJp/M9jrmBEJW22O8lskDWm562BX2W8TiAwRTb0rKA==} + + mdast-util-mdx-jsx@2.1.4: + resolution: {integrity: sha512-DtMn9CmVhVzZx3f+optVDF8yFgQVt7FghCRNdlIaS3X5Bnym3hZwPbg/XW86vdpKjlc1PVj26SpnLGeJBXD3JA==} + + mdast-util-mdx@2.0.1: + resolution: {integrity: sha512-38w5y+r8nyKlGvNjSEqWrhG0w5PmnRA+wnBvm+ulYCct7nsGYhFVb0lljS9bQav4psDAS1eGkP2LMVcZBi/aqw==} + + mdast-util-mdxjs-esm@1.3.1: + resolution: {integrity: sha512-SXqglS0HrEvSdUEfoXFtcg7DRl7S2cwOXc7jkuusG472Mmjag34DUDeOJUZtl+BVnyeO1frIgVpHlNRWc2gk/w==} + + mdast-util-phrasing@3.0.1: + resolution: {integrity: sha512-WmI1gTXUBJo4/ZmSk79Wcb2HcjPJBzM1nlI/OUWA8yk2X9ik3ffNbBGsU+09BFmXaL1IBb9fiuvq6/KMiNycSg==} + + mdast-util-to-hast@12.3.0: + resolution: {integrity: sha512-pits93r8PhnIoU4Vy9bjW39M2jJ6/tdHyja9rrot9uujkN7UTU9SDnE6WNJz/IGyQk3XHX6yNNtrBH6cQzm8Hw==} + + mdast-util-to-hast@13.1.0: + resolution: {integrity: sha512-/e2l/6+OdGp/FB+ctrJ9Avz71AN/GRH3oi/3KAx/kMnoUsD6q0woXlDT8lLEeViVKE7oZxE7RXzvO3T8kF2/sA==} + + mdast-util-to-markdown@1.5.0: + resolution: {integrity: sha512-bbv7TPv/WC49thZPg3jXuqzuvI45IL2EVAr/KxF0BSdHsU0ceFHOmwQn6evxAh1GaoK/6GQ1wp4R4oW2+LFL/A==} + + mdast-util-to-string@3.2.0: + resolution: {integrity: sha512-V4Zn/ncyN1QNSqSBxTrMOLpjr+IKdHl2v3KVLoWmDPscP4r9GcCi71gjgvUV1SFSKh92AjAG4peFuBl2/YgCJg==} + + mermaid@10.9.1: + resolution: {integrity: sha512-Mx45Obds5W1UkW1nv/7dHRsbfMM1aOKA2+Pxs/IGHNonygDHwmng8xTHyS9z4KWVi0rbko8gjiBmuwwXQ7tiNA==} + + micromark-core-commonmark@1.1.0: + resolution: {integrity: sha512-BgHO1aRbolh2hcrzL2d1La37V0Aoz73ymF8rAcKnohLy93titmv62E0gP8Hrx9PKcKrqCZ1BbLGbP3bEhoXYlw==} + + micromark-extension-gfm-autolink-literal@1.0.5: + resolution: {integrity: sha512-z3wJSLrDf8kRDOh2qBtoTRD53vJ+CWIyo7uyZuxf/JAbNJjiHsOpG1y5wxk8drtv3ETAHutCu6N3thkOOgueWg==} + + micromark-extension-gfm-footnote@1.1.2: + resolution: {integrity: sha512-Yxn7z7SxgyGWRNa4wzf8AhYYWNrwl5q1Z8ii+CSTTIqVkmGZF1CElX2JI8g5yGoM3GAman9/PVCUFUSJ0kB/8Q==} + + micromark-extension-gfm-strikethrough@1.0.7: + resolution: {integrity: sha512-sX0FawVE1o3abGk3vRjOH50L5TTLr3b5XMqnP9YDRb34M0v5OoZhG+OHFz1OffZ9dlwgpTBKaT4XW/AsUVnSDw==} + + micromark-extension-gfm-table@1.0.7: + resolution: {integrity: sha512-3ZORTHtcSnMQEKtAOsBQ9/oHp9096pI/UvdPtN7ehKvrmZZ2+bbWhi0ln+I9drmwXMt5boocn6OlwQzNXeVeqw==} + + micromark-extension-gfm-tagfilter@1.0.2: + resolution: {integrity: sha512-5XWB9GbAUSHTn8VPU8/1DBXMuKYT5uOgEjJb8gN3mW0PNW5OPHpSdojoqf+iq1xo7vWzw/P8bAHY0n6ijpXF7g==} + + micromark-extension-gfm-task-list-item@1.0.5: + resolution: {integrity: sha512-RMFXl2uQ0pNQy6Lun2YBYT9g9INXtWJULgbt01D/x8/6yJ2qpKyzdZD3pi6UIkzF++Da49xAelVKUeUMqd5eIQ==} + + micromark-extension-gfm@2.0.3: + resolution: {integrity: sha512-vb9OoHqrhCmbRidQv/2+Bc6pkP0FrtlhurxZofvOEy5o8RtuuvTq+RQ1Vw5ZDNrVraQZu3HixESqbG+0iKk/MQ==} + + micromark-extension-math@2.1.2: + resolution: {integrity: sha512-es0CcOV89VNS9wFmyn+wyFTKweXGW4CEvdaAca6SWRWPyYCbBisnjaHLjWO4Nszuiud84jCpkHsqAJoa768Pvg==} + + micromark-extension-mdx-expression@1.0.8: + resolution: {integrity: sha512-zZpeQtc5wfWKdzDsHRBY003H2Smg+PUi2REhqgIhdzAa5xonhP03FcXxqFSerFiNUr5AWmHpaNPQTBVOS4lrXw==} + + micromark-extension-mdx-jsx@1.0.5: + resolution: {integrity: sha512-gPH+9ZdmDflbu19Xkb8+gheqEDqkSpdCEubQyxuz/Hn8DOXiXvrXeikOoBA71+e8Pfi0/UYmU3wW3H58kr7akA==} + + micromark-extension-mdx-md@1.0.1: + resolution: {integrity: sha512-7MSuj2S7xjOQXAjjkbjBsHkMtb+mDGVW6uI2dBL9snOBCbZmoNgDAeZ0nSn9j3T42UE/g2xVNMn18PJxZvkBEA==} + + micromark-extension-mdxjs-esm@1.0.5: + resolution: {integrity: sha512-xNRBw4aoURcyz/S69B19WnZAkWJMxHMT5hE36GtDAyhoyn/8TuAeqjFJQlwk+MKQsUD7b3l7kFX+vlfVWgcX1w==} + + micromark-extension-mdxjs@1.0.1: + resolution: {integrity: sha512-7YA7hF6i5eKOfFUzZ+0z6avRG52GpWR8DL+kN47y3f2KhxbBZMhmxe7auOeaTBrW2DenbbZTf1ea9tA2hDpC2Q==} + + micromark-factory-destination@1.1.0: + resolution: {integrity: sha512-XaNDROBgx9SgSChd69pjiGKbV+nfHGDPVYFs5dOoDd7ZnMAE+Cuu91BCpsY8RT2NP9vo/B8pds2VQNCLiu0zhg==} + + micromark-factory-label@1.1.0: + resolution: {integrity: sha512-OLtyez4vZo/1NjxGhcpDSbHQ+m0IIGnT8BoPamh+7jVlzLJBH98zzuCoUeMxvM6WsNeh8wx8cKvqLiPHEACn0w==} + + micromark-factory-mdx-expression@1.0.9: + resolution: {integrity: sha512-jGIWzSmNfdnkJq05c7b0+Wv0Kfz3NJ3N4cBjnbO4zjXIlxJr+f8lk+5ZmwFvqdAbUy2q6B5rCY//g0QAAaXDWA==} + + micromark-factory-space@1.1.0: + resolution: {integrity: sha512-cRzEj7c0OL4Mw2v6nwzttyOZe8XY/Z8G0rzmWQZTBi/jjwyw/U4uqKtUORXQrR5bAZZnbTI/feRV/R7hc4jQYQ==} + + micromark-factory-title@1.1.0: + resolution: {integrity: sha512-J7n9R3vMmgjDOCY8NPw55jiyaQnH5kBdV2/UXCtZIpnHH3P6nHUKaH7XXEYuWwx/xUJcawa8plLBEjMPU24HzQ==} + + micromark-factory-whitespace@1.1.0: + resolution: {integrity: sha512-v2WlmiymVSp5oMg+1Q0N1Lxmt6pMhIHD457whWM7/GUlEks1hI9xj5w3zbc4uuMKXGisksZk8DzP2UyGbGqNsQ==} + + micromark-util-character@1.2.0: + resolution: {integrity: sha512-lXraTwcX3yH/vMDaFWCQJP1uIszLVebzUa3ZHdrgxr7KEU/9mL4mVgCpGbyhvNLNlauROiNUq7WN5u7ndbY6xg==} + + micromark-util-character@2.1.0: + resolution: {integrity: sha512-KvOVV+X1yLBfs9dCBSopq/+G1PcgT3lAK07mC4BzXi5E7ahzMAF8oIupDDJ6mievI6F+lAATkbQQlQixJfT3aQ==} + + micromark-util-chunked@1.1.0: + resolution: {integrity: sha512-Ye01HXpkZPNcV6FiyoW2fGZDUw4Yc7vT0E9Sad83+bEDiCJ1uXu0S3mr8WLpsz3HaG3x2q0HM6CTuPdcZcluFQ==} + + micromark-util-classify-character@1.1.0: + resolution: {integrity: sha512-SL0wLxtKSnklKSUplok1WQFoGhUdWYKggKUiqhX+Swala+BtptGCu5iPRc+xvzJ4PXE/hwM3FNXsfEVgoZsWbw==} + + micromark-util-combine-extensions@1.1.0: + resolution: {integrity: sha512-Q20sp4mfNf9yEqDL50WwuWZHUrCO4fEyeDCnMGmG5Pr0Cz15Uo7KBs6jq+dq0EgX4DPwwrh9m0X+zPV1ypFvUA==} + + micromark-util-decode-numeric-character-reference@1.1.0: + resolution: {integrity: sha512-m9V0ExGv0jB1OT21mrWcuf4QhP46pH1KkfWy9ZEezqHKAxkj4mPCy3nIH1rkbdMlChLHX531eOrymlwyZIf2iw==} + + micromark-util-decode-string@1.1.0: + resolution: {integrity: sha512-YphLGCK8gM1tG1bd54azwyrQRjCFcmgj2S2GoJDNnh4vYtnL38JS8M4gpxzOPNyHdNEpheyWXCTnnTDY3N+NVQ==} + + micromark-util-encode@1.1.0: + resolution: {integrity: sha512-EuEzTWSTAj9PA5GOAs992GzNh2dGQO52UvAbtSOMvXTxv3Criqb6IOzJUBCmEqrrXSblJIJBbFFv6zPxpreiJw==} + + micromark-util-encode@2.0.0: + resolution: {integrity: sha512-pS+ROfCXAGLWCOc8egcBvT0kf27GoWMqtdarNfDcjb6YLuV5cM3ioG45Ys2qOVqeqSbjaKg72vU+Wby3eddPsA==} + + micromark-util-events-to-acorn@1.2.3: + resolution: {integrity: sha512-ij4X7Wuc4fED6UoLWkmo0xJQhsktfNh1J0m8g4PbIMPlx+ek/4YdW5mvbye8z/aZvAPUoxgXHrwVlXAPKMRp1w==} + + micromark-util-html-tag-name@1.2.0: + resolution: {integrity: sha512-VTQzcuQgFUD7yYztuQFKXT49KghjtETQ+Wv/zUjGSGBioZnkA4P1XXZPT1FHeJA6RwRXSF47yvJ1tsJdoxwO+Q==} + + micromark-util-normalize-identifier@1.1.0: + resolution: {integrity: sha512-N+w5vhqrBihhjdpM8+5Xsxy71QWqGn7HYNUvch71iV2PM7+E3uWGox1Qp90loa1ephtCxG2ftRV/Conitc6P2Q==} + + micromark-util-resolve-all@1.1.0: + resolution: {integrity: sha512-b/G6BTMSg+bX+xVCshPTPyAu2tmA0E4X98NSR7eIbeC6ycCqCeE7wjfDIgzEbkzdEVJXRtOG4FbEm/uGbCRouA==} + + micromark-util-sanitize-uri@1.2.0: + resolution: {integrity: sha512-QO4GXv0XZfWey4pYFndLUKEAktKkG5kZTdUNaTAkzbuJxn2tNBOr+QtxR2XpWaMhbImT2dPzyLrPXLlPhph34A==} + + micromark-util-sanitize-uri@2.0.0: + resolution: {integrity: sha512-WhYv5UEcZrbAtlsnPuChHUAsu/iBPOVaEVsntLBIdpibO0ddy8OzavZz3iL2xVvBZOpolujSliP65Kq0/7KIYw==} + + micromark-util-subtokenize@1.1.0: + resolution: {integrity: sha512-kUQHyzRoxvZO2PuLzMt2P/dwVsTiivCK8icYTeR+3WgbuPqfHgPPy7nFKbeqRivBvn/3N3GBiNC+JRTMSxEC7A==} + + micromark-util-symbol@1.1.0: + resolution: {integrity: sha512-uEjpEYY6KMs1g7QfJ2eX1SQEV+ZT4rUD3UcF6l57acZvLNK7PBZL+ty82Z1qhK1/yXIY4bdx04FKMgR0g4IAag==} + + micromark-util-symbol@2.0.0: + resolution: {integrity: sha512-8JZt9ElZ5kyTnO94muPxIGS8oyElRJaiJO8EzV6ZSyGQ1Is8xwl4Q45qU5UOg+bGH4AikWziz0iN4sFLWs8PGw==} + + micromark-util-types@1.1.0: + resolution: {integrity: sha512-ukRBgie8TIAcacscVHSiddHjO4k/q3pnedmzMQ4iwDcK0FtFCohKOlFbaOL/mPgfnPsL3C1ZyxJa4sbWrBl3jg==} + + micromark-util-types@2.0.0: + resolution: {integrity: sha512-oNh6S2WMHWRZrmutsRmDDfkzKtxF+bc2VxLC9dvtrDIRFln627VsFP6fLMgTryGDljgLPjkrzQSDcPrjPyDJ5w==} + + micromark@3.2.0: + resolution: {integrity: sha512-uD66tJj54JLYq0De10AhWycZWGQNUvDI55xPgk2sQM5kn1JYlhbCMTtEeT27+vAhW2FBQxLlOmS3pmA7/2z4aA==} + + mri@1.2.0: + resolution: {integrity: sha512-tzzskb3bG8LvYGFF/mDTpq3jpI6Q9wc3LEmBaghu+DdCssd1FakN7Bc0hVNmEyGq1bq3RgfkCb3cmQLpNPOroA==} + engines: {node: '>=4'} + + ms@2.1.2: + resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==} + + nanoid@3.3.7: + resolution: {integrity: sha512-eSRppjcPIatRIMC1U6UngP8XFcz8MQWGQdt1MTBQ7NaAmvXDfvNxbvWV3x2y6CdEUciCSsDHDQZbhYaB8QEo2g==} + engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} + hasBin: true + + next-mdx-remote@4.4.1: + resolution: {integrity: sha512-1BvyXaIou6xy3XoNF4yaMZUCb6vD2GTAa5ciOa6WoO+gAUTYsb1K4rI/HSC2ogAWLrb/7VSV52skz07vOzmqIQ==} + engines: {node: '>=14', npm: '>=7'} + peerDependencies: + react: '>=16.x <=18.x' + react-dom: '>=16.x <=18.x' + + next-seo@6.5.0: + resolution: {integrity: sha512-MfzUeWTN/x/rsKp/1n0213eojO97lIl0unxqbeCY+6pAucViHDA8GSLRRcXpgjsSmBxfCFdfpu7LXbt4ANQoNQ==} + peerDependencies: + next: ^8.1.1-canary.54 || >=9.0.0 + react: '>=16.0.0' + react-dom: '>=16.0.0' + + next-themes@0.2.1: + resolution: {integrity: sha512-B+AKNfYNIzh0vqQQKqQItTS8evEouKD7H5Hj3kmuPERwddR2TxvDSFZuTj6T7Jfn1oyeUyJMydPl1Bkxkh0W7A==} + peerDependencies: + next: '*' + react: '*' + react-dom: '*' + + next@14.2.3: + resolution: {integrity: sha512-dowFkFTR8v79NPJO4QsBUtxv0g9BrS/phluVpMAt2ku7H+cbcBJlopXjkWlwxrk/xGqMemr7JkGPGemPrLLX7A==} + engines: {node: '>=18.17.0'} + hasBin: true + peerDependencies: + '@opentelemetry/api': ^1.1.0 + '@playwright/test': ^1.41.2 + react: ^18.2.0 + react-dom: ^18.2.0 + sass: ^1.3.0 + peerDependenciesMeta: + '@opentelemetry/api': + optional: true + '@playwright/test': + optional: true + sass: + optional: true + + nextra-theme-docs@2.13.4: + resolution: {integrity: sha512-2XOoMfwBCTYBt8ds4ZHftt9Wyf2XsykiNo02eir/XEYB+sGeUoE77kzqfidjEOKCSzOHYbK9BDMcg2+B/2vYRw==} + peerDependencies: + next: '>=9.5.3' + nextra: 2.13.4 + react: '>=16.13.1' + react-dom: '>=16.13.1' + + nextra@2.13.4: + resolution: {integrity: sha512-7of2rSBxuUa3+lbMmZwG9cqgftcoNOVQLTT6Rxf3EhBR9t1EI7b43dted8YoqSNaigdE3j1CoyNkX8N/ZzlEpw==} + engines: {node: '>=16'} + peerDependencies: + next: '>=9.5.3' + react: '>=16.13.1' + react-dom: '>=16.13.1' + + non-layered-tidy-tree-layout@2.0.2: + resolution: {integrity: sha512-gkXMxRzUH+PB0ax9dUN0yYF0S25BqeAYqhgMaLUFmpXLEk7Fcu8f4emJuOAY0V8kjDICxROIKsTAKsV/v355xw==} + + npm-run-path@2.0.2: + resolution: {integrity: sha512-lJxZYlT4DW/bRUtFh1MQIWqmLwQfAxnqWG4HhEdjMlkrJYnJn0Jrr2u3mgxqaWsdiBc76TYkTG/mhrnYTuzfHw==} + engines: {node: '>=4'} + + npm-to-yarn@2.2.1: + resolution: {integrity: sha512-O/j/ROyX0KGLG7O6Ieut/seQ0oiTpHF2tXAcFbpdTLQFiaNtkyTXXocM1fwpaa60dg1qpWj0nHlbNhx6qwuENQ==} + engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} + + p-finally@1.0.0: + resolution: {integrity: sha512-LICb2p9CB7FS+0eR1oqWnHhp0FljGLZCWBE9aix0Uye9W8LTQPwMTYVGWQWIw9RdQiDg4+epXQODwIYJtSJaow==} + engines: {node: '>=4'} + + p-limit@3.1.0: + resolution: {integrity: sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==} + engines: {node: '>=10'} + + parse-entities@4.0.1: + resolution: {integrity: sha512-SWzvYcSJh4d/SGLIOQfZ/CoNv6BTlI6YEQ7Nj82oDVnRpwe/Z/F1EMx42x3JAOwGBlCjeCH0BRJQbQ/opHL17w==} + + parse-numeric-range@1.3.0: + resolution: {integrity: sha512-twN+njEipszzlMJd4ONUYgSfZPDxgHhT9Ahed5uTigpQn90FggW4SA/AIPq/6a149fTbE9qBEcSwE3FAEp6wQQ==} + + parse-path@7.0.0: + resolution: {integrity: sha512-Euf9GG8WT9CdqwuWJGdf3RkUcTBArppHABkO7Lm8IzRQp0e2r/kkFnmhu4TSK30Wcu5rVAZLmfPKSBBi9tWFog==} + + parse-url@8.1.0: + resolution: {integrity: sha512-xDvOoLU5XRrcOZvnI6b8zA6n9O9ejNk/GExuz1yBuWUGn9KA97GI6HTs6u02wKara1CeVmZhH+0TZFdWScR89w==} + + parse5@7.1.2: + resolution: {integrity: sha512-Czj1WaSVpaoj0wbhMzLmWD69anp2WH7FXMB9n1Sy8/ZFF9jolSQVMu1Ij5WIyGmcBmhk7EOndpO4mIpihVqAXw==} + + path-key@2.0.1: + resolution: {integrity: sha512-fEHGKCSmUSDPv4uoj8AlD+joPlq3peND+HRYyxFz4KPw4z926S/b8rIuFs2FYJg3BwsxJf6A9/3eIdLaYC+9Dw==} + engines: {node: '>=4'} + + periscopic@3.1.0: + resolution: {integrity: sha512-vKiQ8RRtkl9P+r/+oefh25C3fhybptkHKCZSPlcXiJux2tJF55GnEj3BVn4A5gKfq9NWWXXrxkHBwVPUfH0opw==} + + picocolors@1.0.1: + resolution: {integrity: sha512-anP1Z8qwhkbmu7MFP5iTt+wQKXgwzf7zTyGlcdzabySa9vd0Xt392U0rVmz9poOaBj0uHJKyyo9/upk0HrEQew==} + + postcss@8.4.31: + resolution: {integrity: sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==} + engines: {node: ^10 || ^12 || >=14} + + property-information@6.5.0: + resolution: {integrity: sha512-PgTgs/BlvHxOu8QuEN7wi5A0OmXaBcHpmCSTehcs6Uuu9IkDIEo13Hy7n898RHfrQ49vKCoGeWZSaAK01nwVig==} + + protocols@2.0.1: + resolution: {integrity: sha512-/XJ368cyBJ7fzLMwLKv1e4vLxOju2MNAIokcr7meSaNcVbWz/CPcW22cP04mwxOErdA5mwjA8Q6w/cdAQxVn7Q==} + + pseudomap@1.0.2: + resolution: {integrity: sha512-b/YwNhb8lk1Zz2+bXXpS/LK9OisiZZ1SNsSLxN1x2OXVEhW2Ckr/7mWE5vrC1ZTiJlD9g19jWszTmJsB+oEpFQ==} + + react-dom@18.3.1: + resolution: {integrity: sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==} + peerDependencies: + react: ^18.3.1 + + react@18.3.1: + resolution: {integrity: sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==} + engines: {node: '>=0.10.0'} + + reading-time@1.5.0: + resolution: {integrity: sha512-onYyVhBNr4CmAxFsKS7bz+uTLRakypIe4R+5A824vBSkQy/hB3fZepoVEf8OVAxzLvK+H/jm9TzpI3ETSm64Kg==} + + regenerator-runtime@0.14.1: + resolution: {integrity: sha512-dYnhHh0nJoMfnkZs6GmmhFknAGRrLznOu5nc9ML+EJxGvrx6H7teuevqVqCuPcPK//3eDrrjQhehXVx9cnkGdw==} + + rehype-katex@7.0.0: + resolution: {integrity: sha512-h8FPkGE00r2XKU+/acgqwWUlyzve1IiOKwsEkg4pDL3k48PiE0Pt+/uLtVHDVkN1yA4iurZN6UES8ivHVEQV6Q==} + + rehype-pretty-code@0.9.11: + resolution: {integrity: sha512-Eq90eCYXQJISktfRZ8PPtwc5SUyH6fJcxS8XOMnHPUQZBtC6RYo67gGlley9X2nR8vlniPj0/7oCDEYHKQa/oA==} + engines: {node: '>=16'} + peerDependencies: + shiki: '*' + + rehype-raw@7.0.0: + resolution: {integrity: sha512-/aE8hCfKlQeA8LmyeyQvQF3eBiLRGNlfBJEvWH7ivp9sBqs7TNqBL5X3v157rM4IFETqDnIOO+z5M/biZbo9Ww==} + + remark-gfm@3.0.1: + resolution: {integrity: sha512-lEFDoi2PICJyNrACFOfDD3JlLkuSbOa5Wd8EPt06HUdptv8Gn0bxYTdbU/XXQ3swAPkEaGxxPN9cbnMHvVu1Ig==} + + remark-math@5.1.1: + resolution: {integrity: sha512-cE5T2R/xLVtfFI4cCePtiRn+e6jKMtFDR3P8V3qpv8wpKjwvHoBA4eJzvX+nVrnlNy0911bdGmuspCSwetfYHw==} + + remark-mdx@2.3.0: + resolution: {integrity: sha512-g53hMkpM0I98MU266IzDFMrTD980gNF3BJnkyFcmN+dD873mQeD5rdMO3Y2X+x8umQfbSE0PcoEDl7ledSA+2g==} + + remark-parse@10.0.2: + resolution: {integrity: sha512-3ydxgHa/ZQzG8LvC7jTXccARYDcRld3VfcgIIFs7bI6vbRSxJJmzgLEIIoYKyrfhaY+ujuWaf/PJiMZXoiCXgw==} + + remark-reading-time@2.0.1: + resolution: {integrity: sha512-fy4BKy9SRhtYbEHvp6AItbRTnrhiDGbqLQTSYVbQPGuRCncU1ubSsh9p/W5QZSxtYcUXv8KGL0xBgPLyNJA1xw==} + + remark-rehype@10.1.0: + resolution: {integrity: sha512-EFmR5zppdBp0WQeDVZ/b66CWJipB2q2VLNFMabzDSGR66Z2fQii83G5gTBbgGEnEEA0QRussvrFHxk1HWGJskw==} + + remove-accents@0.5.0: + resolution: {integrity: sha512-8g3/Otx1eJaVD12e31UbJj1YzdtVvzH85HV7t+9MJYk/u3XmkOUJ5Ys9wQrf9PCPK8+xn4ymzqYCiZl6QWKn+A==} + + robust-predicates@3.0.2: + resolution: {integrity: sha512-IXgzBWvWQwE6PrDI05OvmXUIruQTcoMDzRsOd5CDvHCVLcLHMTSYvOK5Cm46kWqlV3yAbuSpBZdJ5oP5OUoStg==} + + rw@1.3.3: + resolution: {integrity: sha512-PdhdWy89SiZogBLaw42zdeqtRJ//zFd2PgQavcICDUgJT5oW10QCRKbJ6bg4r0/UY2M6BWd5tkxuGFRvCkgfHQ==} + + sade@1.8.1: + resolution: {integrity: sha512-xal3CZX1Xlo/k4ApwCFrHVACi9fBqJ7V+mwhBsuf/1IOKbBy098Fex+Wa/5QMubw09pSZ/u8EY8PWgevJsXp1A==} + engines: {node: '>=6'} + + safer-buffer@2.1.2: + resolution: {integrity: sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==} + + scheduler@0.23.2: + resolution: {integrity: sha512-UOShsPwz7NrMUqhR6t0hWjFduvOzbtv7toDH1/hIrfRNIDBnnBWd0CwJTGvTpngVlmwGCdP9/Zl/tVrDqcuYzQ==} + + scroll-into-view-if-needed@3.1.0: + resolution: {integrity: sha512-49oNpRjWRvnU8NyGVmUaYG4jtTkNonFZI86MmGRDqBphEK2EXT9gdEUoQPZhuBM8yWHxCWbobltqYO5M4XrUvQ==} + + section-matter@1.0.0: + resolution: {integrity: sha512-vfD3pmTzGpufjScBh50YHKzEu2lxBWhVEHsNGoEXmCmn2hKGfeNLYMzCJpe8cD7gqX7TJluOVpBkAequ6dgMmA==} + engines: {node: '>=4'} + + shebang-command@1.2.0: + resolution: {integrity: sha512-EV3L1+UQWGor21OmnvojK36mhg+TyIKDh3iFBKBohr5xeXIhNBcx8oWdgkTEEQ+BEFFYdLRuqMfd5L84N1V5Vg==} + engines: {node: '>=0.10.0'} + + shebang-regex@1.0.0: + resolution: {integrity: sha512-wpoSFAxys6b2a2wHZ1XpDSgD7N9iVjg29Ph9uV/uaP9Ex/KXlkTZTeddxDPSYQpgvzKLGJke2UU0AzoGCjNIvQ==} + engines: {node: '>=0.10.0'} + + shiki@0.14.7: + resolution: {integrity: sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==} + + signal-exit@3.0.7: + resolution: {integrity: sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==} + + slash@3.0.0: + resolution: {integrity: sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==} + engines: {node: '>=8'} + + sort-keys@5.0.0: + resolution: {integrity: sha512-Pdz01AvCAottHTPQGzndktFNdbRA75BgOfeT1hH+AMnJFv8lynkPi42rfeEhpx1saTEI3YNMWxfqu0sFD1G8pw==} + engines: {node: '>=12'} + + source-map-js@1.2.0: + resolution: {integrity: sha512-itJW8lvSA0TXEphiRoawsCksnlf8SyvmFzIhltqAHluXd88pkCd+cXJVHTDwdCr0IzwptSm035IHQktUu1QUMg==} + engines: {node: '>=0.10.0'} + + source-map@0.7.4: + resolution: {integrity: sha512-l3BikUxvPOcn5E74dZiq5BGsTb5yEwhaTSzccU6t4sDOH8NWJCstKO5QT2CvtFoK6F0saL7p9xHAqHOlCPJygA==} + engines: {node: '>= 8'} + + space-separated-tokens@2.0.2: + resolution: {integrity: sha512-PEGlAwrG8yXGXRjW32fGbg66JAlOAwbObuqVoJpv/mRgoWDQfgH1wDPvtzWyUSNAXBGSk8h755YDbbcEy3SH2Q==} + + sprintf-js@1.0.3: + resolution: {integrity: sha512-D9cPgkvLlV3t3IzL0D0YLvGA9Ahk4PcvVwUbN0dSGr1aP0Nrt4AEnTUbuGvquEC0mA64Gqt1fzirlRs5ibXx8g==} + + streamsearch@1.1.0: + resolution: {integrity: sha512-Mcc5wHehp9aXz1ax6bZUyY5afg9u2rv5cqQI3mRrYkGC8rW2hM02jWuwjtL++LS5qinSyhj2QfLyNsuc+VsExg==} + engines: {node: '>=10.0.0'} + + stringify-entities@4.0.4: + resolution: {integrity: sha512-IwfBptatlO+QCJUo19AqvrPNqlVMpW9YEL2LIVY+Rpv2qsjCGxaDLNRgeGsQWJhfItebuJhsGSLjaBbNSQ+ieg==} + + strip-bom-string@1.0.0: + resolution: {integrity: sha512-uCC2VHvQRYu+lMh4My/sFNmF2klFymLX1wHJeXnbEJERpV/ZsVuonzerjfrGpIGF7LBVa1O7i9kjiWvJiFck8g==} + engines: {node: '>=0.10.0'} + + strip-eof@1.0.0: + resolution: {integrity: sha512-7FCwGGmx8mD5xQd3RPUvnSpUXHM3BWuzjtpD4TXsfcZ9EL4azvVVUscFYwD9nx8Kh+uCBC00XBtAykoMHwTh8Q==} + engines: {node: '>=0.10.0'} + + style-to-object@0.4.4: + resolution: {integrity: sha512-HYNoHZa2GorYNyqiCaBgsxvcJIn7OHq6inEga+E6Ke3m5JkoqpQbnFssk4jwe+K7AhGa2fcha4wSOf1Kn01dMg==} + + styled-jsx@5.1.1: + resolution: {integrity: sha512-pW7uC1l4mBZ8ugbiZrcIsiIvVx1UmTfw7UkC3Um2tmfUq9Bhk8IiyEIPl6F8agHgjzku6j0xQEZbfA5uSgSaCw==} + engines: {node: '>= 12.0.0'} + peerDependencies: + '@babel/core': '*' + babel-plugin-macros: '*' + react: '>= 16.8.0 || 17.x.x || ^18.0.0-0' + peerDependenciesMeta: + '@babel/core': + optional: true + babel-plugin-macros: + optional: true + + stylis@4.3.2: + resolution: {integrity: sha512-bhtUjWd/z6ltJiQwg0dUfxEJ+W+jdqQd8TbWLWyeIJHlnsqmGLRFFd8e5mA0AZi/zx90smXRlN66YMTcaSFifg==} + + supports-color@4.5.0: + resolution: {integrity: sha512-ycQR/UbvI9xIlEdQT1TQqwoXtEldExbCEAJgRo5YXlmSKjv6ThHnP9/vwGa1gr19Gfw+LkFd7KqYMhzrRC5JYw==} + engines: {node: '>=4'} + + title@3.5.3: + resolution: {integrity: sha512-20JyowYglSEeCvZv3EZ0nZ046vLarO37prvV0mbtQV7C8DJPGgN967r8SJkqd3XK3K3lD3/Iyfp3avjfil8Q2Q==} + hasBin: true + + titleize@1.0.0: + resolution: {integrity: sha512-TARUb7z1pGvlLxgPk++7wJ6aycXF3GJ0sNSBTAsTuJrQG5QuZlkUQP+zl+nbjAh4gMX9yDw9ZYklMd7vAfJKEw==} + engines: {node: '>=0.10.0'} + + trim-lines@3.0.1: + resolution: {integrity: sha512-kRj8B+YHZCc9kQYdWfJB2/oUl9rA99qbowYYBtr4ui4mZyAQ2JpvVBd/6U2YloATfqBhBTSMhTpgBHtU0Mf3Rg==} + + trough@2.2.0: + resolution: {integrity: sha512-tmMpK00BjZiUyVyvrBK7knerNgmgvcV/KLVyuma/SC+TQN167GrMRciANTz09+k3zW8L8t60jWO1GpfkZdjTaw==} + + ts-dedent@2.2.0: + resolution: {integrity: sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ==} + engines: {node: '>=6.10'} + + tslib@2.6.3: + resolution: {integrity: sha512-xNvxJEOUiWPGhUuUdQgAJPKOOJfGnIyKySOc09XkKsgdUV/3E2zvwZYdejjmRgPCgcym1juLH3226yA7sEFJKQ==} + + type-fest@1.4.0: + resolution: {integrity: sha512-yGSza74xk0UG8k+pLh5oeoYirvIiWo5t0/o3zHHAO2tRDiZcxWP7fywNlXhqb6/r6sWvwi+RsyQMWhVLe4BVuA==} + engines: {node: '>=10'} + + typescript@5.4.5: + resolution: {integrity: sha512-vcI4UpRgg81oIRUFwR0WSIHKt11nJ7SAVlYNIu+QpqeyXP+gpQJy/Z4+F0aGxSE4MqwjyXvW/TzgkLAx2AGHwQ==} + engines: {node: '>=14.17'} + hasBin: true + + undici-types@5.26.5: + resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} + + unified@10.1.2: + resolution: {integrity: sha512-pUSWAi/RAnVy1Pif2kAoeWNBa3JVrx0MId2LASj8G+7AiHWoKZNTomq6LG326T68U7/e263X6fTdcXIy7XnF7Q==} + + unist-util-find-after@5.0.0: + resolution: {integrity: sha512-amQa0Ep2m6hE2g72AugUItjbuM8X8cGQnFoHk0pGfrFeT9GZhzN5SW8nRsiGKK7Aif4CrACPENkA6P/Lw6fHGQ==} + + unist-util-generated@2.0.1: + resolution: {integrity: sha512-qF72kLmPxAw0oN2fwpWIqbXAVyEqUzDHMsbtPvOudIlUzXYFIeQIuxXQCRCFh22B7cixvU0MG7m3MW8FTq/S+A==} + + unist-util-is@5.2.1: + resolution: {integrity: sha512-u9njyyfEh43npf1M+yGKDGVPbY/JWEemg5nH05ncKPfi+kBbKBJoTdsogMu33uhytuLlv9y0O7GH7fEdwLdLQw==} + + unist-util-is@6.0.0: + resolution: {integrity: sha512-2qCTHimwdxLfz+YzdGfkqNlH0tLi9xjTnHddPmJwtIG9MGsdbutfTc4P+haPD7l7Cjxf/WZj+we5qfVPvvxfYw==} + + unist-util-position-from-estree@1.1.2: + resolution: {integrity: sha512-poZa0eXpS+/XpoQwGwl79UUdea4ol2ZuCYguVaJS4qzIOMDzbqz8a3erUCOmubSZkaOuGamb3tX790iwOIROww==} + + unist-util-position@4.0.4: + resolution: {integrity: sha512-kUBE91efOWfIVBo8xzh/uZQ7p9ffYRtUbMRZBNFYwf0RK8koUMx6dGUfwylLOKmaT2cs4wSW96QoYUSXAyEtpg==} + + unist-util-position@5.0.0: + resolution: {integrity: sha512-fucsC7HjXvkB5R3kTCO7kUjRdrS0BJt3M/FPxmHMBOm8JQi2BsHAHFsy27E0EolP8rp0NzXsJ+jNPyDWvOJZPA==} + + unist-util-remove-position@4.0.2: + resolution: {integrity: sha512-TkBb0HABNmxzAcfLf4qsIbFbaPDvMO6wa3b3j4VcEzFVaw1LBKwnW4/sRJ/atSLSzoIg41JWEdnE7N6DIhGDGQ==} + + unist-util-remove-position@5.0.0: + resolution: {integrity: sha512-Hp5Kh3wLxv0PHj9m2yZhhLt58KzPtEYKQQ4yxfYFEO7EvHwzyDYnduhHnY1mDxoqr7VUwVuHXk9RXKIiYS1N8Q==} + + unist-util-remove@4.0.0: + resolution: {integrity: sha512-b4gokeGId57UVRX/eVKej5gXqGlc9+trkORhFJpu9raqZkZhU0zm8Doi05+HaiBsMEIJowL+2WtQ5ItjsngPXg==} + + unist-util-stringify-position@3.0.3: + resolution: {integrity: sha512-k5GzIBZ/QatR8N5X2y+drfpWG8IDBzdnVj6OInRNWm1oXrzydiaAT2OQiA8DPRRZyAKb9b6I2a6PxYklZD0gKg==} + + unist-util-stringify-position@4.0.0: + resolution: {integrity: sha512-0ASV06AAoKCDkS2+xw5RXJywruurpbC4JZSm7nr7MOt1ojAzvyyaO+UxZf18j8FCF6kmzCZKcAgN/yu2gm2XgQ==} + + unist-util-visit-parents@4.1.1: + resolution: {integrity: sha512-1xAFJXAKpnnJl8G7K5KgU7FY55y3GcLIXqkzUj5QF/QVP7biUm0K0O2oqVkYsdjzJKifYeWn9+o6piAK2hGSHw==} + + unist-util-visit-parents@5.1.3: + resolution: {integrity: sha512-x6+y8g7wWMyQhL1iZfhIPhDAs7Xwbn9nRosDXl7qoPTSCy0yNxnKc+hWokFifWQIDGi154rdUqKvbCa4+1kLhg==} + + unist-util-visit-parents@6.0.1: + resolution: {integrity: sha512-L/PqWzfTP9lzzEa6CKs0k2nARxTdZduw3zyh8d2NVBnsyvHjSX4TWse388YrrQKbvI8w20fGjGlhgT96WwKykw==} + + unist-util-visit@3.1.0: + resolution: {integrity: sha512-Szoh+R/Ll68QWAyQyZZpQzZQm2UPbxibDvaY8Xc9SUtYgPsDzx5AWSk++UUt2hJuow8mvwR+rG+LQLw+KsuAKA==} + + unist-util-visit@4.1.2: + resolution: {integrity: sha512-MSd8OUGISqHdVvfY9TPhyK2VdUrPgxkUtWSuMHF6XAAFuL4LokseigBnZtPnJMu+FbynTkFNnFlyjxpVKujMRg==} + + unist-util-visit@5.0.0: + resolution: {integrity: sha512-MR04uvD+07cwl/yhVuVWAtw+3GOR/knlL55Nd/wAdblk27GCVt3lqpTivy/tkJcZoNPzTwS1Y+KMojlLDhoTzg==} + + uuid@9.0.1: + resolution: {integrity: sha512-b+1eJOlsR9K8HJpow9Ok3fiWOWSIcIzXodvv0rQjVoOVNpWMpxf1wZNpt4y9h10odCNrqnYp1OBzRktckBe3sA==} + hasBin: true + + uvu@0.5.6: + resolution: {integrity: sha512-+g8ENReyr8YsOc6fv/NVJs2vFdHBnBNdfE49rshrTzDWOlUx4Gq7KOS2GD8eqhy2j+Ejq29+SbKH8yjkAqXqoA==} + engines: {node: '>=8'} + hasBin: true + + vfile-location@5.0.2: + resolution: {integrity: sha512-NXPYyxyBSH7zB5U6+3uDdd6Nybz6o6/od9rk8bp9H8GR3L+cm/fC0uUTbqBmUTnMCUDslAGBOIKNfvvb+gGlDg==} + + vfile-matter@3.0.1: + resolution: {integrity: sha512-CAAIDwnh6ZdtrqAuxdElUqQRQDQgbbIrYtDYI8gCjXS1qQ+1XdLoK8FIZWxJwn0/I+BkSSZpar3SOgjemQz4fg==} + + vfile-message@3.1.4: + resolution: {integrity: sha512-fa0Z6P8HUrQN4BZaX05SIVXic+7kE3b05PWAtPuYP9QLHsLKYR7/AlLW3NtOrpXRLeawpDLMsVkmk5DG0NXgWw==} + + vfile-message@4.0.2: + resolution: {integrity: sha512-jRDZ1IMLttGj41KcZvlrYAaI3CfqpLpfpf+Mfig13viT6NKvRzWZ+lXz0Y5D60w6uJIBAOGq9mSHf0gktF0duw==} + + vfile@5.3.7: + resolution: {integrity: sha512-r7qlzkgErKjobAmyNIkkSpizsFPYiUPuJb5pNW1RB4JcYVZhs4lIbVqk8XPk033CV/1z8ss5pkax8SuhGpcG8g==} + + vfile@6.0.1: + resolution: {integrity: sha512-1bYqc7pt6NIADBJ98UiG0Bn/CHIVOoZ/IyEkqIruLg0mE1BKzkOXY2D6CSqQIcKqgadppE5lrxgWXJmXd7zZJw==} + + vscode-oniguruma@1.7.0: + resolution: {integrity: sha512-L9WMGRfrjOhgHSdOYgCt/yRMsXzLDJSL7BPrOZt73gU0iWO4mpqzqQzOz5srxqTvMBaR0XZTSrVWo4j55Rc6cA==} + + vscode-textmate@8.0.0: + resolution: {integrity: sha512-AFbieoL7a5LMqcnOF04ji+rpXadgOXnZsxQr//r83kLPr7biP7am3g9zbaZIaBGwBRWeSvoMD4mgPdX3e4NWBg==} + + web-namespaces@2.0.1: + resolution: {integrity: sha512-bKr1DkiNa2krS7qxNtdrtHAmzuYGFQLiQ13TsorsdT6ULTkPLKuu5+GsFpDlg6JFjUTwX2DyhMPG2be8uPrqsQ==} + + web-worker@1.3.0: + resolution: {integrity: sha512-BSR9wyRsy/KOValMgd5kMyr3JzpdeoR9KVId8u5GVlTTAtNChlsE4yTxeY7zMdNSyOmoKBv8NH2qeRY9Tg+IaA==} + + which@1.3.1: + resolution: {integrity: sha512-HxJdYWq1MTIQbJ3nw0cqssHoTNU267KlrDuGZ1WYlxDStUtKUhOaJmh112/TZmHxxUfuJqPXSOm7tDyas0OSIQ==} + hasBin: true + + yallist@2.1.2: + resolution: {integrity: sha512-ncTzHV7NvsQZkYe1DW7cbDLm0YpzHmZF5r/iyP3ZnQtMiJ+pjzisCiMNI+Sj+xQF5pXhSHxSB3uDbsBTzY/c2A==} + + yocto-queue@0.1.0: + resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==} + engines: {node: '>=10'} + + zod@3.23.8: + resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==} + + zwitch@2.0.4: + resolution: {integrity: sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==} + +snapshots: + + '@babel/runtime@7.24.7': + dependencies: + regenerator-runtime: 0.14.1 + + '@braintree/sanitize-url@6.0.4': {} + + '@headlessui/react@1.7.19(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@tanstack/react-virtual': 3.5.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + client-only: 0.0.1 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@mdx-js/mdx@2.3.0': + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/mdx': 2.0.13 + estree-util-build-jsx: 2.2.2 + estree-util-is-identifier-name: 2.1.0 + estree-util-to-js: 1.2.0 + estree-walker: 3.0.3 + hast-util-to-estree: 2.3.3 + markdown-extensions: 1.1.1 + periscopic: 3.1.0 + remark-mdx: 2.3.0 + remark-parse: 10.0.2 + remark-rehype: 10.1.0 + unified: 10.1.2 + unist-util-position-from-estree: 1.1.2 + unist-util-stringify-position: 3.0.3 + unist-util-visit: 4.1.2 + vfile: 5.3.7 + transitivePeerDependencies: + - supports-color + + '@mdx-js/react@2.3.0(react@18.3.1)': + dependencies: + '@types/mdx': 2.0.13 + '@types/react': 18.3.3 + react: 18.3.1 + + '@napi-rs/simple-git-android-arm-eabi@0.1.16': + optional: true + + '@napi-rs/simple-git-android-arm64@0.1.16': + optional: true + + '@napi-rs/simple-git-darwin-arm64@0.1.16': + optional: true + + '@napi-rs/simple-git-darwin-x64@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-arm-gnueabihf@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-arm64-gnu@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-arm64-musl@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-x64-gnu@0.1.16': + optional: true + + '@napi-rs/simple-git-linux-x64-musl@0.1.16': + optional: true + + '@napi-rs/simple-git-win32-arm64-msvc@0.1.16': + optional: true + + '@napi-rs/simple-git-win32-x64-msvc@0.1.16': + optional: true + + '@napi-rs/simple-git@0.1.16': + optionalDependencies: + '@napi-rs/simple-git-android-arm-eabi': 0.1.16 + '@napi-rs/simple-git-android-arm64': 0.1.16 + '@napi-rs/simple-git-darwin-arm64': 0.1.16 + '@napi-rs/simple-git-darwin-x64': 0.1.16 + '@napi-rs/simple-git-linux-arm-gnueabihf': 0.1.16 + '@napi-rs/simple-git-linux-arm64-gnu': 0.1.16 + '@napi-rs/simple-git-linux-arm64-musl': 0.1.16 + '@napi-rs/simple-git-linux-x64-gnu': 0.1.16 + '@napi-rs/simple-git-linux-x64-musl': 0.1.16 + '@napi-rs/simple-git-win32-arm64-msvc': 0.1.16 + '@napi-rs/simple-git-win32-x64-msvc': 0.1.16 + + '@next/env@14.2.3': {} + + '@next/swc-darwin-arm64@14.2.3': + optional: true + + '@next/swc-darwin-x64@14.2.3': + optional: true + + '@next/swc-linux-arm64-gnu@14.2.3': + optional: true + + '@next/swc-linux-arm64-musl@14.2.3': + optional: true + + '@next/swc-linux-x64-gnu@14.2.3': + optional: true + + '@next/swc-linux-x64-musl@14.2.3': + optional: true + + '@next/swc-win32-arm64-msvc@14.2.3': + optional: true + + '@next/swc-win32-ia32-msvc@14.2.3': + optional: true + + '@next/swc-win32-x64-msvc@14.2.3': + optional: true + + '@popperjs/core@2.11.8': {} + + '@swc/counter@0.1.3': {} + + '@swc/helpers@0.5.5': + dependencies: + '@swc/counter': 0.1.3 + tslib: 2.6.3 + + '@tanstack/react-virtual@3.5.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@tanstack/virtual-core': 3.5.1 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@tanstack/virtual-core@3.5.1': {} + + '@theguild/remark-mermaid@0.0.5(react@18.3.1)': + dependencies: + mermaid: 10.9.1 + react: 18.3.1 + unist-util-visit: 5.0.0 + transitivePeerDependencies: + - supports-color + + '@theguild/remark-npm2yarn@0.2.1': + dependencies: + npm-to-yarn: 2.2.1 + unist-util-visit: 5.0.0 + + '@types/acorn@4.0.6': + dependencies: + '@types/estree': 1.0.5 + + '@types/d3-scale-chromatic@3.0.3': {} + + '@types/d3-scale@4.0.8': + dependencies: + '@types/d3-time': 3.0.3 + + '@types/d3-time@3.0.3': {} + + '@types/debug@4.1.12': + dependencies: + '@types/ms': 0.7.34 + + '@types/estree-jsx@1.0.5': + dependencies: + '@types/estree': 1.0.5 + + '@types/estree@1.0.5': {} + + '@types/hast@2.3.10': + dependencies: + '@types/unist': 2.0.10 + + '@types/hast@3.0.4': + dependencies: + '@types/unist': 3.0.2 + + '@types/js-yaml@4.0.9': {} + + '@types/katex@0.16.7': {} + + '@types/mdast@3.0.15': + dependencies: + '@types/unist': 2.0.10 + + '@types/mdast@4.0.4': + dependencies: + '@types/unist': 3.0.2 + + '@types/mdx@2.0.13': {} + + '@types/ms@0.7.34': {} + + '@types/node@20.14.2': + dependencies: + undici-types: 5.26.5 + + '@types/prop-types@15.7.12': {} + + '@types/react-dom@18.3.0': + dependencies: + '@types/react': 18.3.3 + + '@types/react@18.3.3': + dependencies: + '@types/prop-types': 15.7.12 + csstype: 3.1.3 + + '@types/unist@2.0.10': {} + + '@types/unist@3.0.2': {} + + '@ungap/structured-clone@1.2.0': {} + + acorn-jsx@5.3.2(acorn@8.11.3): + dependencies: + acorn: 8.11.3 + + acorn@8.11.3: {} + + ansi-sequence-parser@1.1.1: {} + + ansi-styles@3.2.1: + dependencies: + color-convert: 1.9.3 + + arch@2.2.0: {} + + arg@1.0.0: {} + + argparse@1.0.10: + dependencies: + sprintf-js: 1.0.3 + + argparse@2.0.1: {} + + astring@1.8.6: {} + + bail@2.0.2: {} + + busboy@1.6.0: + dependencies: + streamsearch: 1.1.0 + + caniuse-lite@1.0.30001629: {} + + ccount@2.0.1: {} + + chalk@2.3.0: + dependencies: + ansi-styles: 3.2.1 + escape-string-regexp: 1.0.5 + supports-color: 4.5.0 + + character-entities-html4@2.1.0: {} + + character-entities-legacy@3.0.0: {} + + character-entities@2.0.2: {} + + character-reference-invalid@2.0.1: {} + + client-only@0.0.1: {} + + clipboardy@1.2.2: + dependencies: + arch: 2.2.0 + execa: 0.8.0 + + clsx@2.1.1: {} + + color-convert@1.9.3: + dependencies: + color-name: 1.1.3 + + color-name@1.1.3: {} + + comma-separated-tokens@2.0.3: {} + + commander@7.2.0: {} + + commander@8.3.0: {} + + compute-scroll-into-view@3.1.0: {} + + cose-base@1.0.3: + dependencies: + layout-base: 1.0.2 + + cross-spawn@5.1.0: + dependencies: + lru-cache: 4.1.5 + shebang-command: 1.2.0 + which: 1.3.1 + + csstype@3.1.3: {} + + cytoscape-cose-bilkent@4.1.0(cytoscape@3.29.2): + dependencies: + cose-base: 1.0.3 + cytoscape: 3.29.2 + + cytoscape@3.29.2: {} + + d3-array@2.12.1: + dependencies: + internmap: 1.0.1 + + d3-array@3.2.4: + dependencies: + internmap: 2.0.3 + + d3-axis@3.0.0: {} + + d3-brush@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-drag: 3.0.0 + d3-interpolate: 3.0.1 + d3-selection: 3.0.0 + d3-transition: 3.0.1(d3-selection@3.0.0) + + d3-chord@3.0.1: + dependencies: + d3-path: 3.1.0 + + d3-color@3.1.0: {} + + d3-contour@4.0.2: + dependencies: + d3-array: 3.2.4 + + d3-delaunay@6.0.4: + dependencies: + delaunator: 5.0.1 + + d3-dispatch@3.0.1: {} + + d3-drag@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-selection: 3.0.0 + + d3-dsv@3.0.1: + dependencies: + commander: 7.2.0 + iconv-lite: 0.6.3 + rw: 1.3.3 + + d3-ease@3.0.1: {} + + d3-fetch@3.0.1: + dependencies: + d3-dsv: 3.0.1 + + d3-force@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-quadtree: 3.0.1 + d3-timer: 3.0.1 + + d3-format@3.1.0: {} + + d3-geo@3.1.1: + dependencies: + d3-array: 3.2.4 + + d3-hierarchy@3.1.2: {} + + d3-interpolate@3.0.1: + dependencies: + d3-color: 3.1.0 + + d3-path@1.0.9: {} + + d3-path@3.1.0: {} + + d3-polygon@3.0.1: {} + + d3-quadtree@3.0.1: {} + + d3-random@3.0.1: {} + + d3-sankey@0.12.3: + dependencies: + d3-array: 2.12.1 + d3-shape: 1.3.7 + + d3-scale-chromatic@3.1.0: + dependencies: + d3-color: 3.1.0 + d3-interpolate: 3.0.1 + + d3-scale@4.0.2: + dependencies: + d3-array: 3.2.4 + d3-format: 3.1.0 + d3-interpolate: 3.0.1 + d3-time: 3.1.0 + d3-time-format: 4.1.0 + + d3-selection@3.0.0: {} + + d3-shape@1.3.7: + dependencies: + d3-path: 1.0.9 + + d3-shape@3.2.0: + dependencies: + d3-path: 3.1.0 + + d3-time-format@4.1.0: + dependencies: + d3-time: 3.1.0 + + d3-time@3.1.0: + dependencies: + d3-array: 3.2.4 + + d3-timer@3.0.1: {} + + d3-transition@3.0.1(d3-selection@3.0.0): + dependencies: + d3-color: 3.1.0 + d3-dispatch: 3.0.1 + d3-ease: 3.0.1 + d3-interpolate: 3.0.1 + d3-selection: 3.0.0 + d3-timer: 3.0.1 + + d3-zoom@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-drag: 3.0.0 + d3-interpolate: 3.0.1 + d3-selection: 3.0.0 + d3-transition: 3.0.1(d3-selection@3.0.0) + + d3@7.9.0: + dependencies: + d3-array: 3.2.4 + d3-axis: 3.0.0 + d3-brush: 3.0.0 + d3-chord: 3.0.1 + d3-color: 3.1.0 + d3-contour: 4.0.2 + d3-delaunay: 6.0.4 + d3-dispatch: 3.0.1 + d3-drag: 3.0.0 + d3-dsv: 3.0.1 + d3-ease: 3.0.1 + d3-fetch: 3.0.1 + d3-force: 3.0.0 + d3-format: 3.1.0 + d3-geo: 3.1.1 + d3-hierarchy: 3.1.2 + d3-interpolate: 3.0.1 + d3-path: 3.1.0 + d3-polygon: 3.0.1 + d3-quadtree: 3.0.1 + d3-random: 3.0.1 + d3-scale: 4.0.2 + d3-scale-chromatic: 3.1.0 + d3-selection: 3.0.0 + d3-shape: 3.2.0 + d3-time: 3.1.0 + d3-time-format: 4.1.0 + d3-timer: 3.0.1 + d3-transition: 3.0.1(d3-selection@3.0.0) + d3-zoom: 3.0.0 + + dagre-d3-es@7.0.10: + dependencies: + d3: 7.9.0 + lodash-es: 4.17.21 + + dayjs@1.11.11: {} + + debug@4.3.5: + dependencies: + ms: 2.1.2 + + decode-named-character-reference@1.0.2: + dependencies: + character-entities: 2.0.2 + + delaunator@5.0.1: + dependencies: + robust-predicates: 3.0.2 + + dequal@2.0.3: {} + + devlop@1.1.0: + dependencies: + dequal: 2.0.3 + + diff@5.2.0: {} + + dompurify@3.1.5: {} + + elkjs@0.9.3: {} + + entities@4.5.0: {} + + escape-string-regexp@1.0.5: {} + + escape-string-regexp@5.0.0: {} + + esprima@4.0.1: {} + + estree-util-attach-comments@2.1.1: + dependencies: + '@types/estree': 1.0.5 + + estree-util-build-jsx@2.2.2: + dependencies: + '@types/estree-jsx': 1.0.5 + estree-util-is-identifier-name: 2.1.0 + estree-walker: 3.0.3 + + estree-util-is-identifier-name@2.1.0: {} + + estree-util-to-js@1.2.0: + dependencies: + '@types/estree-jsx': 1.0.5 + astring: 1.8.6 + source-map: 0.7.4 + + estree-util-value-to-estree@1.3.0: + dependencies: + is-plain-obj: 3.0.0 + + estree-util-visit@1.2.1: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/unist': 2.0.10 + + estree-walker@3.0.3: + dependencies: + '@types/estree': 1.0.5 + + execa@0.8.0: + dependencies: + cross-spawn: 5.1.0 + get-stream: 3.0.0 + is-stream: 1.1.0 + npm-run-path: 2.0.2 + p-finally: 1.0.0 + signal-exit: 3.0.7 + strip-eof: 1.0.0 + + extend-shallow@2.0.1: + dependencies: + is-extendable: 0.1.1 + + extend@3.0.2: {} + + flexsearch@0.7.43: {} + + focus-visible@5.2.0: {} + + get-stream@3.0.0: {} + + git-up@7.0.0: + dependencies: + is-ssh: 1.4.0 + parse-url: 8.1.0 + + git-url-parse@13.1.1: + dependencies: + git-up: 7.0.0 + + github-slugger@2.0.0: {} + + graceful-fs@4.2.11: {} + + gray-matter@4.0.3: + dependencies: + js-yaml: 3.14.1 + kind-of: 6.0.3 + section-matter: 1.0.0 + strip-bom-string: 1.0.0 + + has-flag@2.0.0: {} + + hash-obj@4.0.0: + dependencies: + is-obj: 3.0.0 + sort-keys: 5.0.0 + type-fest: 1.4.0 + + hast-util-from-dom@5.0.0: + dependencies: + '@types/hast': 3.0.4 + hastscript: 8.0.0 + web-namespaces: 2.0.1 + + hast-util-from-html-isomorphic@2.0.0: + dependencies: + '@types/hast': 3.0.4 + hast-util-from-dom: 5.0.0 + hast-util-from-html: 2.0.1 + unist-util-remove-position: 5.0.0 + + hast-util-from-html@2.0.1: + dependencies: + '@types/hast': 3.0.4 + devlop: 1.1.0 + hast-util-from-parse5: 8.0.1 + parse5: 7.1.2 + vfile: 6.0.1 + vfile-message: 4.0.2 + + hast-util-from-parse5@8.0.1: + dependencies: + '@types/hast': 3.0.4 + '@types/unist': 3.0.2 + devlop: 1.1.0 + hastscript: 8.0.0 + property-information: 6.5.0 + vfile: 6.0.1 + vfile-location: 5.0.2 + web-namespaces: 2.0.1 + + hast-util-is-element@3.0.0: + dependencies: + '@types/hast': 3.0.4 + + hast-util-parse-selector@4.0.0: + dependencies: + '@types/hast': 3.0.4 + + hast-util-raw@9.0.3: + dependencies: + '@types/hast': 3.0.4 + '@types/unist': 3.0.2 + '@ungap/structured-clone': 1.2.0 + hast-util-from-parse5: 8.0.1 + hast-util-to-parse5: 8.0.0 + html-void-elements: 3.0.0 + mdast-util-to-hast: 13.1.0 + parse5: 7.1.2 + unist-util-position: 5.0.0 + unist-util-visit: 5.0.0 + vfile: 6.0.1 + web-namespaces: 2.0.1 + zwitch: 2.0.4 + + hast-util-to-estree@2.3.3: + dependencies: + '@types/estree': 1.0.5 + '@types/estree-jsx': 1.0.5 + '@types/hast': 2.3.10 + '@types/unist': 2.0.10 + comma-separated-tokens: 2.0.3 + estree-util-attach-comments: 2.1.1 + estree-util-is-identifier-name: 2.1.0 + hast-util-whitespace: 2.0.1 + mdast-util-mdx-expression: 1.3.2 + mdast-util-mdxjs-esm: 1.3.1 + property-information: 6.5.0 + space-separated-tokens: 2.0.2 + style-to-object: 0.4.4 + unist-util-position: 4.0.4 + zwitch: 2.0.4 + transitivePeerDependencies: + - supports-color + + hast-util-to-parse5@8.0.0: + dependencies: + '@types/hast': 3.0.4 + comma-separated-tokens: 2.0.3 + devlop: 1.1.0 + property-information: 6.5.0 + space-separated-tokens: 2.0.2 + web-namespaces: 2.0.1 + zwitch: 2.0.4 + + hast-util-to-text@4.0.2: + dependencies: + '@types/hast': 3.0.4 + '@types/unist': 3.0.2 + hast-util-is-element: 3.0.0 + unist-util-find-after: 5.0.0 + + hast-util-whitespace@2.0.1: {} + + hastscript@8.0.0: + dependencies: + '@types/hast': 3.0.4 + comma-separated-tokens: 2.0.3 + hast-util-parse-selector: 4.0.0 + property-information: 6.5.0 + space-separated-tokens: 2.0.2 + + html-void-elements@3.0.0: {} + + iconv-lite@0.6.3: + dependencies: + safer-buffer: 2.1.2 + + inline-style-parser@0.1.1: {} + + internmap@1.0.1: {} + + internmap@2.0.3: {} + + intersection-observer@0.12.2: {} + + is-alphabetical@2.0.1: {} + + is-alphanumerical@2.0.1: + dependencies: + is-alphabetical: 2.0.1 + is-decimal: 2.0.1 + + is-buffer@2.0.5: {} + + is-decimal@2.0.1: {} + + is-extendable@0.1.1: {} + + is-hexadecimal@2.0.1: {} + + is-obj@3.0.0: {} + + is-plain-obj@3.0.0: {} + + is-plain-obj@4.1.0: {} + + is-reference@3.0.2: + dependencies: + '@types/estree': 1.0.5 + + is-ssh@1.4.0: + dependencies: + protocols: 2.0.1 + + is-stream@1.1.0: {} + + isexe@2.0.0: {} + + js-tokens@4.0.0: {} + + js-yaml@3.14.1: + dependencies: + argparse: 1.0.10 + esprima: 4.0.1 + + js-yaml@4.1.0: + dependencies: + argparse: 2.0.1 + + jsonc-parser@3.2.1: {} + + katex@0.16.10: + dependencies: + commander: 8.3.0 + + khroma@2.1.0: {} + + kind-of@6.0.3: {} + + kleur@4.1.5: {} + + layout-base@1.0.2: {} + + lodash-es@4.17.21: {} + + lodash.get@4.4.2: {} + + longest-streak@3.1.0: {} + + loose-envify@1.4.0: + dependencies: + js-tokens: 4.0.0 + + lru-cache@4.1.5: + dependencies: + pseudomap: 1.0.2 + yallist: 2.1.2 + + markdown-extensions@1.1.1: {} + + markdown-table@3.0.3: {} + + match-sorter@6.3.4: + dependencies: + '@babel/runtime': 7.24.7 + remove-accents: 0.5.0 + + mdast-util-definitions@5.1.2: + dependencies: + '@types/mdast': 3.0.15 + '@types/unist': 2.0.10 + unist-util-visit: 4.1.2 + + mdast-util-find-and-replace@2.2.2: + dependencies: + '@types/mdast': 3.0.15 + escape-string-regexp: 5.0.0 + unist-util-is: 5.2.1 + unist-util-visit-parents: 5.1.3 + + mdast-util-from-markdown@1.3.1: + dependencies: + '@types/mdast': 3.0.15 + '@types/unist': 2.0.10 + decode-named-character-reference: 1.0.2 + mdast-util-to-string: 3.2.0 + micromark: 3.2.0 + micromark-util-decode-numeric-character-reference: 1.1.0 + micromark-util-decode-string: 1.1.0 + micromark-util-normalize-identifier: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + unist-util-stringify-position: 3.0.3 + uvu: 0.5.6 + transitivePeerDependencies: + - supports-color + + mdast-util-gfm-autolink-literal@1.0.3: + dependencies: + '@types/mdast': 3.0.15 + ccount: 2.0.1 + mdast-util-find-and-replace: 2.2.2 + micromark-util-character: 1.2.0 + + mdast-util-gfm-footnote@1.0.2: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-to-markdown: 1.5.0 + micromark-util-normalize-identifier: 1.1.0 + + mdast-util-gfm-strikethrough@1.0.3: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-to-markdown: 1.5.0 + + mdast-util-gfm-table@1.0.7: + dependencies: + '@types/mdast': 3.0.15 + markdown-table: 3.0.3 + mdast-util-from-markdown: 1.3.1 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-gfm-task-list-item@1.0.2: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-to-markdown: 1.5.0 + + mdast-util-gfm@2.0.2: + dependencies: + mdast-util-from-markdown: 1.3.1 + mdast-util-gfm-autolink-literal: 1.0.3 + mdast-util-gfm-footnote: 1.0.2 + mdast-util-gfm-strikethrough: 1.0.3 + mdast-util-gfm-table: 1.0.7 + mdast-util-gfm-task-list-item: 1.0.2 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-math@2.0.2: + dependencies: + '@types/mdast': 3.0.15 + longest-streak: 3.1.0 + mdast-util-to-markdown: 1.5.0 + + mdast-util-mdx-expression@1.3.2: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + mdast-util-from-markdown: 1.3.1 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-mdx-jsx@2.1.4: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + '@types/unist': 2.0.10 + ccount: 2.0.1 + mdast-util-from-markdown: 1.3.1 + mdast-util-to-markdown: 1.5.0 + parse-entities: 4.0.1 + stringify-entities: 4.0.4 + unist-util-remove-position: 4.0.2 + unist-util-stringify-position: 3.0.3 + vfile-message: 3.1.4 + transitivePeerDependencies: + - supports-color + + mdast-util-mdx@2.0.1: + dependencies: + mdast-util-from-markdown: 1.3.1 + mdast-util-mdx-expression: 1.3.2 + mdast-util-mdx-jsx: 2.1.4 + mdast-util-mdxjs-esm: 1.3.1 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-mdxjs-esm@1.3.1: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + mdast-util-from-markdown: 1.3.1 + mdast-util-to-markdown: 1.5.0 + transitivePeerDependencies: + - supports-color + + mdast-util-phrasing@3.0.1: + dependencies: + '@types/mdast': 3.0.15 + unist-util-is: 5.2.1 + + mdast-util-to-hast@12.3.0: + dependencies: + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + mdast-util-definitions: 5.1.2 + micromark-util-sanitize-uri: 1.2.0 + trim-lines: 3.0.1 + unist-util-generated: 2.0.1 + unist-util-position: 4.0.4 + unist-util-visit: 4.1.2 + + mdast-util-to-hast@13.1.0: + dependencies: + '@types/hast': 3.0.4 + '@types/mdast': 4.0.4 + '@ungap/structured-clone': 1.2.0 + devlop: 1.1.0 + micromark-util-sanitize-uri: 2.0.0 + trim-lines: 3.0.1 + unist-util-position: 5.0.0 + unist-util-visit: 5.0.0 + vfile: 6.0.1 + + mdast-util-to-markdown@1.5.0: + dependencies: + '@types/mdast': 3.0.15 + '@types/unist': 2.0.10 + longest-streak: 3.1.0 + mdast-util-phrasing: 3.0.1 + mdast-util-to-string: 3.2.0 + micromark-util-decode-string: 1.1.0 + unist-util-visit: 4.1.2 + zwitch: 2.0.4 + + mdast-util-to-string@3.2.0: + dependencies: + '@types/mdast': 3.0.15 + + mermaid@10.9.1: + dependencies: + '@braintree/sanitize-url': 6.0.4 + '@types/d3-scale': 4.0.8 + '@types/d3-scale-chromatic': 3.0.3 + cytoscape: 3.29.2 + cytoscape-cose-bilkent: 4.1.0(cytoscape@3.29.2) + d3: 7.9.0 + d3-sankey: 0.12.3 + dagre-d3-es: 7.0.10 + dayjs: 1.11.11 + dompurify: 3.1.5 + elkjs: 0.9.3 + katex: 0.16.10 + khroma: 2.1.0 + lodash-es: 4.17.21 + mdast-util-from-markdown: 1.3.1 + non-layered-tidy-tree-layout: 2.0.2 + stylis: 4.3.2 + ts-dedent: 2.2.0 + uuid: 9.0.1 + web-worker: 1.3.0 + transitivePeerDependencies: + - supports-color + + micromark-core-commonmark@1.1.0: + dependencies: + decode-named-character-reference: 1.0.2 + micromark-factory-destination: 1.1.0 + micromark-factory-label: 1.1.0 + micromark-factory-space: 1.1.0 + micromark-factory-title: 1.1.0 + micromark-factory-whitespace: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-chunked: 1.1.0 + micromark-util-classify-character: 1.1.0 + micromark-util-html-tag-name: 1.2.0 + micromark-util-normalize-identifier: 1.1.0 + micromark-util-resolve-all: 1.1.0 + micromark-util-subtokenize: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm-autolink-literal@1.0.5: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-sanitize-uri: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-extension-gfm-footnote@1.1.2: + dependencies: + micromark-core-commonmark: 1.1.0 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-normalize-identifier: 1.1.0 + micromark-util-sanitize-uri: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm-strikethrough@1.0.7: + dependencies: + micromark-util-chunked: 1.1.0 + micromark-util-classify-character: 1.1.0 + micromark-util-resolve-all: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm-table@1.0.7: + dependencies: + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm-tagfilter@1.0.2: + dependencies: + micromark-util-types: 1.1.0 + + micromark-extension-gfm-task-list-item@1.0.5: + dependencies: + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-gfm@2.0.3: + dependencies: + micromark-extension-gfm-autolink-literal: 1.0.5 + micromark-extension-gfm-footnote: 1.1.2 + micromark-extension-gfm-strikethrough: 1.0.7 + micromark-extension-gfm-table: 1.0.7 + micromark-extension-gfm-tagfilter: 1.0.2 + micromark-extension-gfm-task-list-item: 1.0.5 + micromark-util-combine-extensions: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-extension-math@2.1.2: + dependencies: + '@types/katex': 0.16.7 + katex: 0.16.10 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-mdx-expression@1.0.8: + dependencies: + '@types/estree': 1.0.5 + micromark-factory-mdx-expression: 1.0.9 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-events-to-acorn: 1.2.3 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-extension-mdx-jsx@1.0.5: + dependencies: + '@types/acorn': 4.0.6 + '@types/estree': 1.0.5 + estree-util-is-identifier-name: 2.1.0 + micromark-factory-mdx-expression: 1.0.9 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + vfile-message: 3.1.4 + + micromark-extension-mdx-md@1.0.1: + dependencies: + micromark-util-types: 1.1.0 + + micromark-extension-mdxjs-esm@1.0.5: + dependencies: + '@types/estree': 1.0.5 + micromark-core-commonmark: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-events-to-acorn: 1.2.3 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + unist-util-position-from-estree: 1.1.2 + uvu: 0.5.6 + vfile-message: 3.1.4 + + micromark-extension-mdxjs@1.0.1: + dependencies: + acorn: 8.11.3 + acorn-jsx: 5.3.2(acorn@8.11.3) + micromark-extension-mdx-expression: 1.0.8 + micromark-extension-mdx-jsx: 1.0.5 + micromark-extension-mdx-md: 1.0.1 + micromark-extension-mdxjs-esm: 1.0.5 + micromark-util-combine-extensions: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-factory-destination@1.1.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-factory-label@1.1.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-factory-mdx-expression@1.0.9: + dependencies: + '@types/estree': 1.0.5 + micromark-util-character: 1.2.0 + micromark-util-events-to-acorn: 1.2.3 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + unist-util-position-from-estree: 1.1.2 + uvu: 0.5.6 + vfile-message: 3.1.4 + + micromark-factory-space@1.1.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-types: 1.1.0 + + micromark-factory-title@1.1.0: + dependencies: + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-factory-whitespace@1.1.0: + dependencies: + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-util-character@1.2.0: + dependencies: + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-util-character@2.1.0: + dependencies: + micromark-util-symbol: 2.0.0 + micromark-util-types: 2.0.0 + + micromark-util-chunked@1.1.0: + dependencies: + micromark-util-symbol: 1.1.0 + + micromark-util-classify-character@1.1.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-util-combine-extensions@1.1.0: + dependencies: + micromark-util-chunked: 1.1.0 + micromark-util-types: 1.1.0 + + micromark-util-decode-numeric-character-reference@1.1.0: + dependencies: + micromark-util-symbol: 1.1.0 + + micromark-util-decode-string@1.1.0: + dependencies: + decode-named-character-reference: 1.0.2 + micromark-util-character: 1.2.0 + micromark-util-decode-numeric-character-reference: 1.1.0 + micromark-util-symbol: 1.1.0 + + micromark-util-encode@1.1.0: {} + + micromark-util-encode@2.0.0: {} + + micromark-util-events-to-acorn@1.2.3: + dependencies: + '@types/acorn': 4.0.6 + '@types/estree': 1.0.5 + '@types/unist': 2.0.10 + estree-util-visit: 1.2.1 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + vfile-message: 3.1.4 + + micromark-util-html-tag-name@1.2.0: {} + + micromark-util-normalize-identifier@1.1.0: + dependencies: + micromark-util-symbol: 1.1.0 + + micromark-util-resolve-all@1.1.0: + dependencies: + micromark-util-types: 1.1.0 + + micromark-util-sanitize-uri@1.2.0: + dependencies: + micromark-util-character: 1.2.0 + micromark-util-encode: 1.1.0 + micromark-util-symbol: 1.1.0 + + micromark-util-sanitize-uri@2.0.0: + dependencies: + micromark-util-character: 2.1.0 + micromark-util-encode: 2.0.0 + micromark-util-symbol: 2.0.0 + + micromark-util-subtokenize@1.1.0: + dependencies: + micromark-util-chunked: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + + micromark-util-symbol@1.1.0: {} + + micromark-util-symbol@2.0.0: {} + + micromark-util-types@1.1.0: {} + + micromark-util-types@2.0.0: {} + + micromark@3.2.0: + dependencies: + '@types/debug': 4.1.12 + debug: 4.3.5 + decode-named-character-reference: 1.0.2 + micromark-core-commonmark: 1.1.0 + micromark-factory-space: 1.1.0 + micromark-util-character: 1.2.0 + micromark-util-chunked: 1.1.0 + micromark-util-combine-extensions: 1.1.0 + micromark-util-decode-numeric-character-reference: 1.1.0 + micromark-util-encode: 1.1.0 + micromark-util-normalize-identifier: 1.1.0 + micromark-util-resolve-all: 1.1.0 + micromark-util-sanitize-uri: 1.2.0 + micromark-util-subtokenize: 1.1.0 + micromark-util-symbol: 1.1.0 + micromark-util-types: 1.1.0 + uvu: 0.5.6 + transitivePeerDependencies: + - supports-color + + mri@1.2.0: {} + + ms@2.1.2: {} + + nanoid@3.3.7: {} + + next-mdx-remote@4.4.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@mdx-js/mdx': 2.3.0 + '@mdx-js/react': 2.3.0(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + vfile: 5.3.7 + vfile-matter: 3.0.1 + transitivePeerDependencies: + - supports-color + + next-seo@6.5.0(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + next: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + next-themes@0.2.1(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + next: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@next/env': 14.2.3 + '@swc/helpers': 0.5.5 + busboy: 1.6.0 + caniuse-lite: 1.0.30001629 + graceful-fs: 4.2.11 + postcss: 8.4.31 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + styled-jsx: 5.1.1(react@18.3.1) + optionalDependencies: + '@next/swc-darwin-arm64': 14.2.3 + '@next/swc-darwin-x64': 14.2.3 + '@next/swc-linux-arm64-gnu': 14.2.3 + '@next/swc-linux-arm64-musl': 14.2.3 + '@next/swc-linux-x64-gnu': 14.2.3 + '@next/swc-linux-x64-musl': 14.2.3 + '@next/swc-win32-arm64-msvc': 14.2.3 + '@next/swc-win32-ia32-msvc': 14.2.3 + '@next/swc-win32-x64-msvc': 14.2.3 + transitivePeerDependencies: + - '@babel/core' + - babel-plugin-macros + + nextra-theme-docs@2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(nextra@2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@headlessui/react': 1.7.19(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@popperjs/core': 2.11.8 + clsx: 2.1.1 + escape-string-regexp: 5.0.0 + flexsearch: 0.7.43 + focus-visible: 5.2.0 + git-url-parse: 13.1.1 + intersection-observer: 0.12.2 + match-sorter: 6.3.4 + next: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next-seo: 6.5.0(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next-themes: 0.2.1(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + nextra: 2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + scroll-into-view-if-needed: 3.1.0 + zod: 3.23.8 + + nextra@2.13.4(next@14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@headlessui/react': 1.7.19(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@mdx-js/mdx': 2.3.0 + '@mdx-js/react': 2.3.0(react@18.3.1) + '@napi-rs/simple-git': 0.1.16 + '@theguild/remark-mermaid': 0.0.5(react@18.3.1) + '@theguild/remark-npm2yarn': 0.2.1 + clsx: 2.1.1 + github-slugger: 2.0.0 + graceful-fs: 4.2.11 + gray-matter: 4.0.3 + katex: 0.16.10 + lodash.get: 4.4.2 + next: 14.2.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + next-mdx-remote: 4.4.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + p-limit: 3.1.0 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + rehype-katex: 7.0.0 + rehype-pretty-code: 0.9.11(shiki@0.14.7) + rehype-raw: 7.0.0 + remark-gfm: 3.0.1 + remark-math: 5.1.1 + remark-reading-time: 2.0.1 + shiki: 0.14.7 + slash: 3.0.0 + title: 3.5.3 + unist-util-remove: 4.0.0 + unist-util-visit: 5.0.0 + zod: 3.23.8 + transitivePeerDependencies: + - supports-color + + non-layered-tidy-tree-layout@2.0.2: {} + + npm-run-path@2.0.2: + dependencies: + path-key: 2.0.1 + + npm-to-yarn@2.2.1: {} + + p-finally@1.0.0: {} + + p-limit@3.1.0: + dependencies: + yocto-queue: 0.1.0 + + parse-entities@4.0.1: + dependencies: + '@types/unist': 2.0.10 + character-entities: 2.0.2 + character-entities-legacy: 3.0.0 + character-reference-invalid: 2.0.1 + decode-named-character-reference: 1.0.2 + is-alphanumerical: 2.0.1 + is-decimal: 2.0.1 + is-hexadecimal: 2.0.1 + + parse-numeric-range@1.3.0: {} + + parse-path@7.0.0: + dependencies: + protocols: 2.0.1 + + parse-url@8.1.0: + dependencies: + parse-path: 7.0.0 + + parse5@7.1.2: + dependencies: + entities: 4.5.0 + + path-key@2.0.1: {} + + periscopic@3.1.0: + dependencies: + '@types/estree': 1.0.5 + estree-walker: 3.0.3 + is-reference: 3.0.2 + + picocolors@1.0.1: {} + + postcss@8.4.31: + dependencies: + nanoid: 3.3.7 + picocolors: 1.0.1 + source-map-js: 1.2.0 + + property-information@6.5.0: {} + + protocols@2.0.1: {} + + pseudomap@1.0.2: {} + + react-dom@18.3.1(react@18.3.1): + dependencies: + loose-envify: 1.4.0 + react: 18.3.1 + scheduler: 0.23.2 + + react@18.3.1: + dependencies: + loose-envify: 1.4.0 + + reading-time@1.5.0: {} + + regenerator-runtime@0.14.1: {} + + rehype-katex@7.0.0: + dependencies: + '@types/hast': 3.0.4 + '@types/katex': 0.16.7 + hast-util-from-html-isomorphic: 2.0.0 + hast-util-to-text: 4.0.2 + katex: 0.16.10 + unist-util-visit-parents: 6.0.1 + vfile: 6.0.1 + + rehype-pretty-code@0.9.11(shiki@0.14.7): + dependencies: + '@types/hast': 2.3.10 + hash-obj: 4.0.0 + parse-numeric-range: 1.3.0 + shiki: 0.14.7 + + rehype-raw@7.0.0: + dependencies: + '@types/hast': 3.0.4 + hast-util-raw: 9.0.3 + vfile: 6.0.1 + + remark-gfm@3.0.1: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-gfm: 2.0.2 + micromark-extension-gfm: 2.0.3 + unified: 10.1.2 + transitivePeerDependencies: + - supports-color + + remark-math@5.1.1: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-math: 2.0.2 + micromark-extension-math: 2.1.2 + unified: 10.1.2 + + remark-mdx@2.3.0: + dependencies: + mdast-util-mdx: 2.0.1 + micromark-extension-mdxjs: 1.0.1 + transitivePeerDependencies: + - supports-color + + remark-parse@10.0.2: + dependencies: + '@types/mdast': 3.0.15 + mdast-util-from-markdown: 1.3.1 + unified: 10.1.2 + transitivePeerDependencies: + - supports-color + + remark-reading-time@2.0.1: + dependencies: + estree-util-is-identifier-name: 2.1.0 + estree-util-value-to-estree: 1.3.0 + reading-time: 1.5.0 + unist-util-visit: 3.1.0 + + remark-rehype@10.1.0: + dependencies: + '@types/hast': 2.3.10 + '@types/mdast': 3.0.15 + mdast-util-to-hast: 12.3.0 + unified: 10.1.2 + + remove-accents@0.5.0: {} + + robust-predicates@3.0.2: {} + + rw@1.3.3: {} + + sade@1.8.1: + dependencies: + mri: 1.2.0 + + safer-buffer@2.1.2: {} + + scheduler@0.23.2: + dependencies: + loose-envify: 1.4.0 + + scroll-into-view-if-needed@3.1.0: + dependencies: + compute-scroll-into-view: 3.1.0 + + section-matter@1.0.0: + dependencies: + extend-shallow: 2.0.1 + kind-of: 6.0.3 + + shebang-command@1.2.0: + dependencies: + shebang-regex: 1.0.0 + + shebang-regex@1.0.0: {} + + shiki@0.14.7: + dependencies: + ansi-sequence-parser: 1.1.1 + jsonc-parser: 3.2.1 + vscode-oniguruma: 1.7.0 + vscode-textmate: 8.0.0 + + signal-exit@3.0.7: {} + + slash@3.0.0: {} + + sort-keys@5.0.0: + dependencies: + is-plain-obj: 4.1.0 + + source-map-js@1.2.0: {} + + source-map@0.7.4: {} + + space-separated-tokens@2.0.2: {} + + sprintf-js@1.0.3: {} + + streamsearch@1.1.0: {} + + stringify-entities@4.0.4: + dependencies: + character-entities-html4: 2.1.0 + character-entities-legacy: 3.0.0 + + strip-bom-string@1.0.0: {} + + strip-eof@1.0.0: {} + + style-to-object@0.4.4: + dependencies: + inline-style-parser: 0.1.1 + + styled-jsx@5.1.1(react@18.3.1): + dependencies: + client-only: 0.0.1 + react: 18.3.1 + + stylis@4.3.2: {} + + supports-color@4.5.0: + dependencies: + has-flag: 2.0.0 + + title@3.5.3: + dependencies: + arg: 1.0.0 + chalk: 2.3.0 + clipboardy: 1.2.2 + titleize: 1.0.0 + + titleize@1.0.0: {} + + trim-lines@3.0.1: {} + + trough@2.2.0: {} + + ts-dedent@2.2.0: {} + + tslib@2.6.3: {} + + type-fest@1.4.0: {} + + typescript@5.4.5: {} + + undici-types@5.26.5: {} + + unified@10.1.2: + dependencies: + '@types/unist': 2.0.10 + bail: 2.0.2 + extend: 3.0.2 + is-buffer: 2.0.5 + is-plain-obj: 4.1.0 + trough: 2.2.0 + vfile: 5.3.7 + + unist-util-find-after@5.0.0: + dependencies: + '@types/unist': 3.0.2 + unist-util-is: 6.0.0 + + unist-util-generated@2.0.1: {} + + unist-util-is@5.2.1: + dependencies: + '@types/unist': 2.0.10 + + unist-util-is@6.0.0: + dependencies: + '@types/unist': 3.0.2 + + unist-util-position-from-estree@1.1.2: + dependencies: + '@types/unist': 2.0.10 + + unist-util-position@4.0.4: + dependencies: + '@types/unist': 2.0.10 + + unist-util-position@5.0.0: + dependencies: + '@types/unist': 3.0.2 + + unist-util-remove-position@4.0.2: + dependencies: + '@types/unist': 2.0.10 + unist-util-visit: 4.1.2 + + unist-util-remove-position@5.0.0: + dependencies: + '@types/unist': 3.0.2 + unist-util-visit: 5.0.0 + + unist-util-remove@4.0.0: + dependencies: + '@types/unist': 3.0.2 + unist-util-is: 6.0.0 + unist-util-visit-parents: 6.0.1 + + unist-util-stringify-position@3.0.3: + dependencies: + '@types/unist': 2.0.10 + + unist-util-stringify-position@4.0.0: + dependencies: + '@types/unist': 3.0.2 + + unist-util-visit-parents@4.1.1: + dependencies: + '@types/unist': 2.0.10 + unist-util-is: 5.2.1 + + unist-util-visit-parents@5.1.3: + dependencies: + '@types/unist': 2.0.10 + unist-util-is: 5.2.1 + + unist-util-visit-parents@6.0.1: + dependencies: + '@types/unist': 3.0.2 + unist-util-is: 6.0.0 + + unist-util-visit@3.1.0: + dependencies: + '@types/unist': 2.0.10 + unist-util-is: 5.2.1 + unist-util-visit-parents: 4.1.1 + + unist-util-visit@4.1.2: + dependencies: + '@types/unist': 2.0.10 + unist-util-is: 5.2.1 + unist-util-visit-parents: 5.1.3 + + unist-util-visit@5.0.0: + dependencies: + '@types/unist': 3.0.2 + unist-util-is: 6.0.0 + unist-util-visit-parents: 6.0.1 + + uuid@9.0.1: {} + + uvu@0.5.6: + dependencies: + dequal: 2.0.3 + diff: 5.2.0 + kleur: 4.1.5 + sade: 1.8.1 + + vfile-location@5.0.2: + dependencies: + '@types/unist': 3.0.2 + vfile: 6.0.1 + + vfile-matter@3.0.1: + dependencies: + '@types/js-yaml': 4.0.9 + is-buffer: 2.0.5 + js-yaml: 4.1.0 + + vfile-message@3.1.4: + dependencies: + '@types/unist': 2.0.10 + unist-util-stringify-position: 3.0.3 + + vfile-message@4.0.2: + dependencies: + '@types/unist': 3.0.2 + unist-util-stringify-position: 4.0.0 + + vfile@5.3.7: + dependencies: + '@types/unist': 2.0.10 + is-buffer: 2.0.5 + unist-util-stringify-position: 3.0.3 + vfile-message: 3.1.4 + + vfile@6.0.1: + dependencies: + '@types/unist': 3.0.2 + unist-util-stringify-position: 4.0.0 + vfile-message: 4.0.2 + + vscode-oniguruma@1.7.0: {} + + vscode-textmate@8.0.0: {} + + web-namespaces@2.0.1: {} + + web-worker@1.3.0: {} + + which@1.3.1: + dependencies: + isexe: 2.0.0 + + yallist@2.1.2: {} + + yocto-queue@0.1.0: {} + + zod@3.23.8: {} + + zwitch@2.0.4: {} diff --git a/docs/assets/banner.png b/docs/public/assets/banner.png similarity index 100% rename from docs/assets/banner.png rename to docs/public/assets/banner.png diff --git a/docs/assets/icon.png b/docs/public/assets/icon.png similarity index 100% rename from docs/assets/icon.png rename to docs/public/assets/icon.png diff --git a/docs/assets/sample-onnx-graph.png b/docs/public/assets/sample-onnx-graph.png similarity index 100% rename from docs/assets/sample-onnx-graph.png rename to docs/public/assets/sample-onnx-graph.png diff --git a/docs/assets/trend-banner.png b/docs/public/assets/trend-banner.png similarity index 100% rename from docs/assets/trend-banner.png rename to docs/public/assets/trend-banner.png diff --git a/docs/setup/linking.mdx b/docs/setup/linking.mdx deleted file mode 100644 index ecba449..0000000 --- a/docs/setup/linking.mdx +++ /dev/null @@ -1,106 +0,0 @@ ---- -title: Linking -description: Here's how `ort` links to ONNX Runtime, and how to configure its behavior. ---- - -In some cases, you'll want to use a custom build of ONNX Runtime with `ort`. Luckily, we make this very easy by handling all of the linking configuration automagically. Just point `ort` to the output of ONNX Runtime's build pipeline and it'll Just Workβ„’. - -## Static linking -Most ONNX Runtime compile configurations will support static linking - just run `build.sh` without the `--build_shared_lib` argument. You should prefer static linking if your execution providers support it, as it avoids many issues and follows de facto Rust practices. If you compile both static libraries and dynamic libraries, `ort` will prefer linking to the static libraries. - -To direct `ort` to your statically built binaries, use the `ORT_LIB_LOCATION` environment variable when running `cargo build`. Point it to the location where the static libraries (`.a`/`.lib` files) are compiled to. This will typically be `onnxruntime/build/`. For example: -```shell -$ ORT_LIB_LOCATION=~/onnxruntime/build/Linux cargo build -``` - -For iOS (or for other platforms if you are compiling multiple profiles at once), you'll need to manually specify the profile with the `ORT_LIB_PROFILE` environment variable. If not specified, `ort` will prefer `Release` over `RelWithDebInfo` over `MinSizeRel` over `Debug`. - -## Dynamic linking -Some execution providers unfortunately only support dynamic linking. Dynamic linking doesn't play well with the Rust ecosystem, though `ort` tries to alleviate the pain as much as possible. - -When it comes to dynamic linking, there are two options: `load-dynamic`, or standard compile-time dynamic linking. We recommend `load-dynamic` as it gives more control and is often far less troublesome to work with. - -### Runtime loading with `load-dynamic` -The `load-dynamic` Cargo feature solves a few of the issues with dynamic linking by **loading the library at runtime** rather than **linking at compile time**. This means that the path to the ONNX Runtime library can be configured at runtime, and the executable will not just completely fail to start if the binary couldn't be found. - -To use `load-dynamic`: - - - ```toml Cargo.toml - [dependencies] - ort = { version = "2", features = [ "load-dynamic" ] } - ``` - - - - - ```rust main.rs - fn main() -> anyhow::Result<()> { - // Find our custom ONNX Runtime dylib path somehow - // (i.e. resolving it from the root of our program's install folder) - let dylib_path = crate::internal::find_onnxruntime_dylib()?; - // The path should point to the `libonnxruntime` binary, which looks like: - // - on Unix: /etc/.../libonnxruntime.so - // - on Windows: C:\Program Files\...\onnxruntime.dll - - // Initialize ort with the path to the dylib. This **must** be called before any usage of `ort`! - // `init_from` returns an `EnvironmentBuilder` which you can use to further configure the environment - // before `.commit()`ing; see the Environment docs for more information on what you can configure. - ort::init_from(dylib_path).commit()?; - - Ok(()) - } - ``` - - - Set the `ORT_DYLIB_PATH` environment variable to the path to `libonnxruntime.so`/`onnxruntime.dll`. - - ```shell - $ ORT_DYLIB_PATH=../onnxruntime-build/linux-x64/libonnxruntime.so ./mirai - ``` - - - - - -`ORT_DYLIB_PATH` is relative to the executable. Cargo examples and tests are compiled to a different directory than binary crates: `target//examples` and `target//deps` respectively. Keep this in mind if you're going to use relative paths. - -### Compile-time dynamic linking -For compile-time dynamic linking, you'll need to configure your environment in the exact same way as if you were [statically linking](#static-linking). - -Note that the dylibs then have to be placed in a certain location for them to be found by the executable. For Windows, this is either somewhere on the `PATH`, or in the same folder as the executable. On macOS and Linux, they have to be placed somewhere in the `LD_LIBRARY_PATH`, or you can use rpath to configure the executable to search for dylibs in its parent folder. We've had the least issues with rpath, but YMMV. - -To configure rpath, you'll need to: - - - ```toml - [profile.dev] - rpath = true - - [profile.release] - rpath = true - - # do this for any other profiles - ``` - - - - - ```toml - [target.x86_64-unknown-linux-gnu] - rustflags = [ "-Clink-args=-Wl,-rpath,\\$ORIGIN" ] - - # do this for any other Linux targets as well - ``` - - - ```toml - [target.x86_64-apple-darwin] - rustflags = [ "-Clink-args=-Wl,-rpath,@loader_path" ] - - # do this for any other macOS targets as well - ``` - - - - diff --git a/docs/theme.config.jsx b/docs/theme.config.jsx new file mode 100644 index 0000000..ef71c4e --- /dev/null +++ b/docs/theme.config.jsx @@ -0,0 +1,33 @@ +import Image from 'next/image'; + +/** @type {import('nextra-theme-docs').DocsThemeConfig} */ +const config = { + project: { + link: 'https://github.com/pykeio/ort' + }, + chat: { + link: 'https://discord.gg/uQtsNu2xMa' + }, + docsRepositoryBase: 'https://github.com/pykeio/ort/blob/main/docs', + useNextSeoProps() { + return { + titleTemplate: '%s | ort' + } + }, + logo: , + darkMode: true, + nextThemes: { + defaultTheme: 'system' + }, + footer: { + text:
+

made with πŸ’œ by pyke β€’ sponsor

+
+ }, + primaryHue: 20, + primarySaturation: 100, + toc: { + float: true + } +}; +export default config; diff --git a/docs/tsconfig.json b/docs/tsconfig.json new file mode 100644 index 0000000..19deeff --- /dev/null +++ b/docs/tsconfig.json @@ -0,0 +1,28 @@ +{ + "compilerOptions": { + "lib": [ + "dom", + "dom.iterable", + "esnext" + ], + "allowJs": true, + "skipLibCheck": true, + "strict": false, + "noEmit": true, + "incremental": true, + "module": "esnext", + "esModuleInterop": true, + "moduleResolution": "node", + "resolveJsonModule": true, + "isolatedModules": true, + "jsx": "preserve" + }, + "include": [ + "next-env.d.ts", + "**/*.ts", + "**/*.tsx" +, "pages/_app.mdx" ], + "exclude": [ + "node_modules" + ] +} From a3ffe92572655de579c1c4a5d1a38d743731d39a Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Wed, 12 Jun 2024 11:56:53 -0500 Subject: [PATCH 10/12] feat: simple trainer API --- examples/training/README.md | 3 + .../training/examples/train-clm-simple.rs | 118 +++++++++ src/training/mod.rs | 214 +-------------- src/training/simple.rs | 250 ++++++++++++++++++ src/training/trainer.rs | 235 ++++++++++++++++ 5 files changed, 615 insertions(+), 205 deletions(-) create mode 100644 examples/training/examples/train-clm-simple.rs create mode 100644 src/training/simple.rs create mode 100644 src/training/trainer.rs diff --git a/examples/training/README.md b/examples/training/README.md index 7ae0fef..7c99d64 100644 --- a/examples/training/README.md +++ b/examples/training/README.md @@ -21,3 +21,6 @@ I'm so much better than the game<|endoftext|>I think you can't see it<|endoftext ``` Not bad, considering the model & dataset size! This example can easily be scaled up to pre-train or fine-tune (both full-parameter and PEFT) larger language models like Llama/Phi, so long as you have enough compute. + +## `train-clm-simple` +This example is functionally identical to `train-clm`, except it uses ort's "simple" Trainer API instead of implementing a manual training loop. The simple API is more akin to πŸ€— Transformer's [`Trainer`](https://huggingface.co/docs/transformers/en/main_classes/trainer) API or [PyTorch Lightning](https://lightning.ai/pytorch-lightning). With the simple API, all you have to do is pass a data loader & parameters, and let `ort` handle training for you! diff --git a/examples/training/examples/train-clm-simple.rs b/examples/training/examples/train-clm-simple.rs new file mode 100644 index 0000000..0c3ac32 --- /dev/null +++ b/examples/training/examples/train-clm-simple.rs @@ -0,0 +1,118 @@ +use std::{ + fs::File, + io::{Read, Seek, SeekFrom, Write}, + path::Path +}; + +use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis}; +use ort::{Allocator, CUDAExecutionProvider, CheckpointStrategy, Session, SessionBuilder, Trainer, TrainingArguments}; +use rand::RngCore; +use tokenizers::Tokenizer; + +const BATCH_SIZE: usize = 16; +const SEQUENCE_LENGTH: usize = 256; + +fn main() -> ort::Result<()> { + tracing_subscriber::fmt::init(); + + ort::init().commit()?; + + let trainer = Trainer::new_from_artifacts( + SessionBuilder::new()?.with_execution_providers([CUDAExecutionProvider::default().build()])?, + Allocator::default(), + "tools/train-data/mini-clm", + None + )?; + + let tokenizer = Tokenizer::from_file( + Path::new(env!("CARGO_MANIFEST_DIR")) + .parent() + .unwrap() + .join("gpt2") + .join("data") + .join("tokenizer.json") + ) + .unwrap(); + + let mut dataset = File::open("dataset.bin").unwrap(); + let file_size = dataset.metadata().unwrap().len(); + let num_tokens = (file_size / 2) as usize; // 16-bit tokens + let mut rng = rand::thread_rng(); + let mut input_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let mut label_buffer = vec![0u16; SEQUENCE_LENGTH * BATCH_SIZE]; + let dataloader = move |_: usize| { + for batch in 0..BATCH_SIZE { + let start_idx = rng.next_u64() % (num_tokens - SEQUENCE_LENGTH - 1) as u64; + dataset.seek(SeekFrom::Start(start_idx * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + input_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + dataset.seek(SeekFrom::Start((start_idx + 1) * 2)).unwrap(); + dataset + .read_exact(unsafe { + std::slice::from_raw_parts_mut( + label_buffer[batch * SEQUENCE_LENGTH..(batch + 1) * SEQUENCE_LENGTH] + .as_mut_ptr() + .cast::(), + SEQUENCE_LENGTH * 2 + ) + }) + .unwrap(); + } + + Ok(( + ort::inputs![Array2::::from_shape_vec([BATCH_SIZE, SEQUENCE_LENGTH], input_buffer.iter().map(|c| *c as i64).collect()).unwrap()]?, + ort::inputs![Array1::::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap()]? + )) + }; + + trainer.train( + TrainingArguments::new(dataloader) + .with_lr(7e-5) + .with_max_steps(5000) + .with_ckpt_strategy(CheckpointStrategy::Steps(500)) + )?; + + trainer.export("trained-clm.onnx", ["probs"])?; + + let session = Session::builder()?.commit_from_file("trained-clm.onnx")?; + + let mut stdout = std::io::stdout(); + + let tokens = tokenizer.encode("<|endoftext|>", false).unwrap(); + let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::>(); + + let mut tokens = Array1::from_iter(tokens.iter().cloned()); + + for _ in 0..50 { + let array = tokens.view().insert_axis(Axis(0)); + let outputs = session.run(ort::inputs![array]?)?; + let generated_tokens: ArrayViewD = outputs["probs"].try_extract_tensor()?; + + let probabilities = &mut generated_tokens + .slice(s![-1, ..]) + .to_owned() + .iter() + .cloned() + .enumerate() + .collect::>(); + probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less)); + + let token = probabilities[0].0; + tokens = concatenate![Axis(0), tokens, ndarray::array![token.try_into().unwrap()]]; + + let token_str = tokenizer.decode(&[token as _], false).unwrap(); + print!("{}", token_str); + stdout.flush().unwrap(); + } + + println!(); + Ok(()) +} diff --git a/src/training/mod.rs b/src/training/mod.rs index a951715..d66db11 100644 --- a/src/training/mod.rs +++ b/src/training/mod.rs @@ -1,19 +1,20 @@ use std::{ - ffi::CString, path::Path, ptr::{self, NonNull}, sync::{ atomic::{AtomicPtr, Ordering}, - Arc, OnceLock + OnceLock } }; -use ort_sys::c_char; +use crate::{ortsys, Error, Result, RunOptions}; -use crate::{ - char_p_to_string, - error::{assert_non_null_pointer, status_to_result}, - ortsys, Allocator, Error, Result, RunOptions, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, Value +mod simple; +mod trainer; + +pub use self::{ + simple::{iterable_data_loader, CheckpointStrategy, DataLoader, EvaluationStrategy, IterableDataLoader, TrainingArguments}, + trainer::Trainer }; pub(crate) static TRAINING_API: OnceLock> = OnceLock::new(); @@ -79,6 +80,7 @@ macro_rules! trainsys { $($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+ }}; } +pub(crate) use trainsys; #[derive(Debug)] pub struct Checkpoint { @@ -138,201 +140,3 @@ impl Optimizer { Ok(()) } } - -#[derive(Debug)] -pub struct Trainer { - pub(crate) ptr: NonNull, - train_output_names: Vec, - optimizer: Optimizer, - ckpt: Checkpoint, - _allocator: Allocator -} - -impl Trainer { - pub fn new( - session_options: SessionBuilder, - allocator: Allocator, - ckpt: Checkpoint, - training_model_path: impl AsRef, - eval_model_path: impl AsRef, - optimizer_model_path: impl AsRef - ) -> Result { - let training_model_path = crate::util::path_to_os_char(training_model_path); - let eval_model_path = crate::util::path_to_os_char(eval_model_path); - let optimizer_model_path = crate::util::path_to_os_char(optimizer_model_path); - - let env = crate::get_environment()?; - - let mut ptr: *mut ort_sys::OrtTrainingSession = ptr::null_mut(); - trainsys![unsafe CreateTrainingSession(env.ptr(), session_options.session_options_ptr.as_ptr(), ckpt.ptr.as_ptr(), training_model_path.as_ptr(), eval_model_path.as_ptr(), optimizer_model_path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)]; - - let ptr = unsafe { NonNull::new_unchecked(ptr) }; - - let mut train_output_len = 0; - trainsys![unsafe TrainingSessionGetTrainingModelOutputCount(ptr.as_ptr(), &mut train_output_len) -> Error::CreateSession]; - let train_output_names = (0..train_output_len) - .map(|i| { - let mut name_bytes: *mut c_char = std::ptr::null_mut(); - trainsys![unsafe TrainingSessionGetTrainingModelOutputName(ptr.as_ptr(), i, allocator.ptr.as_ptr(), &mut name_bytes) -> Error::CreateSession]; - let name = match char_p_to_string(name_bytes) { - Ok(name) => name, - Err(e) => { - unsafe { allocator.free(name_bytes) }; - return Err(e); - } - }; - unsafe { allocator.free(name_bytes) }; - Ok(name) - }) - .collect::>>()?; - - Ok(Self { - ptr, - _allocator: allocator, - train_output_names, - optimizer: Optimizer(ptr), - ckpt - }) - } - - pub fn step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( - &'s self, - inputs: impl Into>, - labels: impl Into> - ) -> Result> { - match inputs.into() { - SessionInputs::ValueSlice(input_values) => match labels.into() { - SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), - SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None), - SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") - }, - SessionInputs::ValueArray(input_values) => match labels.into() { - SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), - SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None), - SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") - }, - SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") - } - } - - fn step_inner<'s, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( - &'s self, - input_values: impl Iterator>, - run_options: Option> - ) -> Result> { - let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; - - let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); - - let run_options_ptr = if let Some(run_options) = &run_options { - run_options.run_options_ptr.as_ptr() - } else { - std::ptr::null_mut() - }; - - trainsys![unsafe TrainStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun]; - - let outputs: Vec = output_tensor_ptrs - .into_iter() - .map(|tensor_ptr| unsafe { - // TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`. - // but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣 - Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None) - }) - .collect(); - - Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) - } - - pub fn eval<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( - &'s self, - inputs: impl Into>, - labels: impl Into> - ) -> Result> { - match inputs.into() { - SessionInputs::ValueSlice(input_values) => match labels.into() { - SessionInputs::ValueSlice(labels) => self.eval_inner(input_values.iter().chain(labels), None), - SessionInputs::ValueArray(labels) => self.eval_inner(input_values.iter().chain(labels.iter()), None), - SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") - }, - SessionInputs::ValueArray(input_values) => match labels.into() { - SessionInputs::ValueSlice(labels) => self.eval_inner(input_values.iter().chain(labels), None), - SessionInputs::ValueArray(labels) => self.eval_inner(input_values.iter().chain(labels.iter()), None), - SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") - }, - SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") - } - } - - fn eval_inner<'s, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( - &'s self, - input_values: impl Iterator>, - run_options: Option> - ) -> Result> { - let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; - - let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); - - let run_options_ptr = if let Some(run_options) = &run_options { - run_options.run_options_ptr.as_ptr() - } else { - std::ptr::null_mut() - }; - - trainsys![unsafe EvalStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun]; - - let outputs: Vec = output_tensor_ptrs - .into_iter() - .map(|tensor_ptr| unsafe { - // TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`. - // but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣 - Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None) - }) - .collect(); - - Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) - } - - pub fn export>(&self, out_path: impl AsRef, output_names: impl AsRef<[O]>) -> Result<()> { - let out_path = crate::util::path_to_os_char(out_path); - - let output_names_ptr: Vec<*const c_char> = output_names - .as_ref() - .iter() - .map(|output| CString::new(output.as_ref()).unwrap_or_else(|_| unreachable!())) - .map(|n| n.into_raw().cast_const()) - .collect(); - - let res = trainsys![unsafe ExportModelForInferencing(self.ptr.as_ptr(), out_path.as_ptr(), output_names_ptr.len(), output_names_ptr.as_ptr())]; - - // Reconvert name ptrs to CString so drop impl is called and memory is freed - drop( - output_names_ptr - .into_iter() - .map(|p| { - assert_non_null_pointer(p, "c_char for CString")?; - unsafe { Ok(CString::from_raw(p.cast_mut().cast())) } - }) - .collect::>>()? - ); - - status_to_result(res).map_err(Error::CreateSession)?; - - Ok(()) - } - - pub fn optimizer(&self) -> &Optimizer { - &self.optimizer - } - - pub fn ckpt(&self) -> &Checkpoint { - &self.ckpt - } -} - -impl Drop for Trainer { - fn drop(&mut self) { - tracing::trace!("dropping trainer"); - trainsys![unsafe ReleaseTrainingSession(self.ptr.as_ptr())]; - } -} diff --git a/src/training/simple.rs b/src/training/simple.rs new file mode 100644 index 0000000..205bb18 --- /dev/null +++ b/src/training/simple.rs @@ -0,0 +1,250 @@ +use std::{collections::VecDeque, fs, path::PathBuf}; + +use ndarray::Ix0; + +use crate::{Result, SessionInputs}; + +#[allow(clippy::len_without_is_empty)] +pub trait DataLoader { + fn load(&mut self, idx: usize) -> Result<(I, L)>; + + fn len(&self) -> Option { + None + } +} + +pub struct IterableDataLoader Result<(I, L)>> { + items: Box<[T]>, + collator: C +} + +impl Result<(I, L)>> DataLoader for IterableDataLoader { + fn load(&mut self, idx: usize) -> Result<(I, L)> { + (self.collator)(&self.items[idx]) + } + + fn len(&self) -> Option { + Some(self.items.len()) + } +} + +pub fn iterable_data_loader Result<(I, L)>>(iterable: impl Iterator, collator: C) -> IterableDataLoader { + IterableDataLoader { + items: iterable.collect::>().into_boxed_slice(), + collator + } +} + +impl Result<(I, L)>> DataLoader for F { + fn load(&mut self, idx: usize) -> Result<(I, L)> { + (self)(idx) + } + + fn len(&self) -> Option { + None + } +} + +pub enum EvaluationStrategy { + None, + Steps(usize), + Epochs(usize) +} + +impl EvaluationStrategy { + pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option) -> bool { + match self { + Self::None => false, + Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0, + Self::Epochs(epochs) => { + if let Some(dataloader_size) = dataloader_size { + iter_step > 0 && iter_step % (dataloader_size * epochs) == 0 + } else { + false + } + } + } + } +} + +pub enum CheckpointStrategy { + None, + Steps(usize), + Epochs(usize) +} + +impl CheckpointStrategy { + pub(crate) fn should_fire(&self, _global_step: usize, iter_step: usize, dataloader_size: Option) -> bool { + match self { + Self::None => false, + Self::Steps(steps) => iter_step > 0 && iter_step % steps == 0, + Self::Epochs(epochs) => { + if let Some(dataloader_size) = dataloader_size { + iter_step > 0 && iter_step % (dataloader_size * epochs) == 0 + } else { + false + } + } + } + } +} + +pub struct TrainingArguments>, L: Into>, const NI: usize, const NL: usize> { + loader: Box>, + eval_loader: Option>>, + eval_strategy: EvaluationStrategy, + ckpt_strategy: CheckpointStrategy, + ckpt_path: PathBuf, + lr: f32, + max_saved_ckpts: usize, + gradient_accumulation_steps: usize, + max_steps: usize, + max_eval_steps: usize +} + +impl>, L: Into>, const NI: usize, const NL: usize> + TrainingArguments +{ + pub fn new + 'static>(train_loader: D) -> Self { + Self { + loader: Box::new(train_loader), + eval_loader: None, + eval_strategy: EvaluationStrategy::None, + ckpt_strategy: CheckpointStrategy::Epochs(1), + ckpt_path: PathBuf::from("checkpoints"), + lr: 1e-4, + gradient_accumulation_steps: 1, + max_saved_ckpts: 1, + max_steps: usize::MAX, + max_eval_steps: usize::MAX + } + } + + pub fn with_lr(mut self, lr: f32) -> Self { + self.lr = lr; + self + } + + pub fn with_max_steps(mut self, steps: usize) -> Self { + self.max_steps = steps; + self + } + + pub fn with_max_eval_steps(mut self, steps: usize) -> Self { + self.max_eval_steps = steps; + self + } + + pub fn with_gradient_accumulation(mut self, steps: usize) -> Self { + self.gradient_accumulation_steps = steps; + self + } + + pub fn with_ckpt_path(mut self, path: impl Into) -> Self { + self.ckpt_path = path.into(); + self + } + + pub fn with_ckpt_strategy(mut self, strategy: CheckpointStrategy) -> Self { + self.ckpt_strategy = strategy; + self + } + + pub fn with_max_saved_ckpts(mut self, max_ckpts: usize) -> Self { + self.max_saved_ckpts = max_ckpts; + self + } + + pub fn with_eval_loader + 'static>(mut self, eval_loader: D) -> Self { + self.eval_loader = Some(Box::new(eval_loader)); + self + } + + pub fn with_eval_strategy(mut self, strategy: EvaluationStrategy) -> Self { + self.eval_strategy = strategy; + self + } +} + +impl super::Trainer { + pub fn train>, L: Into>, const NI: usize, const NL: usize>( + &self, + mut args: TrainingArguments + ) -> crate::Result<()> { + let optimizer = self.optimizer(); + optimizer.set_lr(args.lr)?; + + let mut saved_ckpts = VecDeque::new(); + let mut global_step = 0; + for (iter_step, _) in (0..args.max_steps).enumerate() { + let epoch = iter_step / args.loader.len().unwrap_or(usize::MAX); + let (inputs, labels) = args.loader.load(iter_step)?; + let (inputs, labels) = (inputs.into(), labels.into()); + + let outputs = self.step(inputs, labels)?; + let loss = outputs[0] + .try_extract_tensor::()? + .into_dimensionality::() + .expect("first output should be the 0-dimensional loss tensor") + .into_scalar(); + println!("epoch={epoch} step={global_step} loss={loss}"); + + if iter_step % args.gradient_accumulation_steps == 0 { + optimizer.step()?; + optimizer.reset_grad()?; + global_step += 1; + } + + if args.ckpt_strategy.should_fire(global_step, iter_step, args.loader.len()) { + if !args.ckpt_path.exists() { + let _ = fs::create_dir_all(&args.ckpt_path); + } + + let ckpt_path = args.ckpt_path.join(format!("epoch={epoch},step={global_step}.ortckpt")); + self.checkpoint().save(&ckpt_path, true)?; + + saved_ckpts.push_front(ckpt_path.clone()); + while saved_ckpts.len() > args.max_saved_ckpts { + let Some(old_ckpt) = saved_ckpts.pop_back() else { + break; + }; + let _ = fs::remove_file(old_ckpt); + } + } + + if args + .eval_strategy + .should_fire(global_step, iter_step, args.eval_loader.as_ref().and_then(|d| d.len())) + { + let eval_loss = self.eval_inner(&mut args)?; + println!("eval_loss={eval_loss}"); + } + } + Ok(()) + } + + pub(crate) fn eval_inner>, L: Into>, const NI: usize, const NL: usize>( + &self, + args: &mut TrainingArguments + ) -> crate::Result { + let Some(eval_loader) = &mut args.eval_loader else { + return Ok(0.0); + }; + + let mut total_loss = 0.0; + for step in 0..args.max_eval_steps.min(eval_loader.len().unwrap_or(usize::MAX)) { + let (inputs, labels) = eval_loader.load(step)?; + let (inputs, labels) = (inputs.into(), labels.into()); + + let outputs = self.eval_step(inputs, labels)?; + let loss = outputs[0] + .try_extract_tensor::()? + .into_dimensionality::() + .expect("first output should be the 0-dimensional loss tensor") + .into_scalar(); + total_loss = (total_loss * (step as f32) + loss) / (step as f32 + 1.); + } + + Ok(total_loss) + } +} diff --git a/src/training/trainer.rs b/src/training/trainer.rs new file mode 100644 index 0000000..6ef4c92 --- /dev/null +++ b/src/training/trainer.rs @@ -0,0 +1,235 @@ +use std::{ + ffi::CString, + path::Path, + ptr::{self, NonNull}, + sync::Arc +}; + +use ort_sys::c_char; + +use super::{trainsys, Checkpoint, Optimizer}; +use crate::{ + char_p_to_string, + error::{assert_non_null_pointer, status_to_result}, + Allocator, Error, Result, RunOptions, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, Value +}; + +#[derive(Debug)] +pub struct Trainer { + pub(crate) ptr: NonNull, + train_output_names: Vec, + optimizer: Optimizer, + ckpt: Checkpoint, + _allocator: Allocator +} + +impl Trainer { + pub fn new( + session_options: SessionBuilder, + allocator: Allocator, + ckpt: Checkpoint, + training_model_path: impl AsRef, + eval_model_path: impl AsRef, + optimizer_model_path: impl AsRef + ) -> Result { + let training_model_path = crate::util::path_to_os_char(training_model_path); + let eval_model_path = crate::util::path_to_os_char(eval_model_path); + let optimizer_model_path = crate::util::path_to_os_char(optimizer_model_path); + + let env = crate::get_environment()?; + + let mut ptr: *mut ort_sys::OrtTrainingSession = ptr::null_mut(); + trainsys![unsafe CreateTrainingSession(env.ptr(), session_options.session_options_ptr.as_ptr(), ckpt.ptr.as_ptr(), training_model_path.as_ptr(), eval_model_path.as_ptr(), optimizer_model_path.as_ptr(), &mut ptr) -> Error::CreateSession; nonNull(ptr)]; + + let ptr = unsafe { NonNull::new_unchecked(ptr) }; + + let mut train_output_len = 0; + trainsys![unsafe TrainingSessionGetTrainingModelOutputCount(ptr.as_ptr(), &mut train_output_len) -> Error::CreateSession]; + let train_output_names = (0..train_output_len) + .map(|i| { + let mut name_bytes: *mut c_char = std::ptr::null_mut(); + trainsys![unsafe TrainingSessionGetTrainingModelOutputName(ptr.as_ptr(), i, allocator.ptr.as_ptr(), &mut name_bytes) -> Error::CreateSession]; + let name = match char_p_to_string(name_bytes) { + Ok(name) => name, + Err(e) => { + unsafe { allocator.free(name_bytes) }; + return Err(e); + } + }; + unsafe { allocator.free(name_bytes) }; + Ok(name) + }) + .collect::>>()?; + + Ok(Self { + ptr, + _allocator: allocator, + train_output_names, + optimizer: Optimizer(ptr), + ckpt + }) + } + + pub fn new_from_artifacts( + session_options: SessionBuilder, + allocator: Allocator, + base_dir: impl AsRef, + override_ckpt: Option + ) -> Result { + let base_dir = base_dir.as_ref(); + let ckpt = if let Some(ckpt) = override_ckpt { + ckpt + } else { + Checkpoint::load(base_dir.join("checkpoint"))? + }; + Self::new( + session_options, + allocator, + ckpt, + base_dir.join("training_model.onnx"), + base_dir.join("eval_model.onnx"), + base_dir.join("optimizer_model.onnx") + ) + } + + pub fn step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( + &'s self, + inputs: impl Into>, + labels: impl Into> + ) -> Result> { + match inputs.into() { + SessionInputs::ValueSlice(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueArray(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + } + } + + fn step_inner<'s, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + &'s self, + input_values: impl Iterator>, + run_options: Option> + ) -> Result> { + let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; + + let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); + + let run_options_ptr = if let Some(run_options) = &run_options { + run_options.run_options_ptr.as_ptr() + } else { + std::ptr::null_mut() + }; + + trainsys![unsafe TrainStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun]; + + let outputs: Vec = output_tensor_ptrs + .into_iter() + .map(|tensor_ptr| unsafe { + // TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`. + // but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣 + Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None) + }) + .collect(); + + Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) + } + + pub fn eval_step<'s, 'i1, 'v1: 'i1, 'i2: 'i1, 'v2: 'i2 + 'i1, const N1: usize, const N2: usize>( + &'s self, + inputs: impl Into>, + labels: impl Into> + ) -> Result> { + match inputs.into() { + SessionInputs::ValueSlice(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueArray(input_values) => match labels.into() { + SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None), + SessionInputs::ValueArray(labels) => self.eval_step_inner(input_values.iter().chain(labels.iter()), None), + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + }, + SessionInputs::ValueMap(_) => unimplemented!("named values not supported?") + } + } + + fn eval_step_inner<'s, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + &'s self, + input_values: impl Iterator>, + run_options: Option> + ) -> Result> { + let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; + + let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); + + let run_options_ptr = if let Some(run_options) = &run_options { + run_options.run_options_ptr.as_ptr() + } else { + std::ptr::null_mut() + }; + + trainsys![unsafe EvalStep(self.ptr.as_ptr(), run_options_ptr, input_ort_values.len(), input_ort_values.as_ptr(), output_tensor_ptrs.len(), output_tensor_ptrs.as_mut_ptr()) -> Error::SessionRun]; + + let outputs: Vec = output_tensor_ptrs + .into_iter() + .map(|tensor_ptr| unsafe { + // TODO: `Value` should absolutely be refactored to accept a different backing pointer than `SharedSessionInner`. + // but for now, nobody should be using the loss tensor past the lifetime of the trainer... right...? 😣 + Value::from_ptr(NonNull::new(tensor_ptr).expect("OrtValue ptr returned from session Run should not be null"), None) + }) + .collect(); + + Ok(SessionOutputs::new(self.train_output_names.iter().map(|o| o.as_str()), outputs)) + } + + pub fn export>(&self, out_path: impl AsRef, output_names: impl AsRef<[O]>) -> Result<()> { + let out_path = crate::util::path_to_os_char(out_path); + + let output_names_ptr: Vec<*const c_char> = output_names + .as_ref() + .iter() + .map(|output| CString::new(output.as_ref()).unwrap_or_else(|_| unreachable!())) + .map(|n| n.into_raw().cast_const()) + .collect(); + + let res = trainsys![unsafe ExportModelForInferencing(self.ptr.as_ptr(), out_path.as_ptr(), output_names_ptr.len(), output_names_ptr.as_ptr())]; + + // Reconvert name ptrs to CString so drop impl is called and memory is freed + drop( + output_names_ptr + .into_iter() + .map(|p| { + assert_non_null_pointer(p, "c_char for CString")?; + unsafe { Ok(CString::from_raw(p.cast_mut().cast())) } + }) + .collect::>>()? + ); + + status_to_result(res).map_err(Error::CreateSession)?; + + Ok(()) + } + + pub fn optimizer(&self) -> &Optimizer { + &self.optimizer + } + + pub fn checkpoint(&self) -> &Checkpoint { + &self.ckpt + } +} + +impl Drop for Trainer { + fn drop(&mut self) { + tracing::trace!("dropping trainer"); + trainsys![unsafe ReleaseTrainingSession(self.ptr.as_ptr())]; + } +} From b47ae603dee27fad8067673c3e57b214cde79ce3 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Mon, 24 Jun 2024 20:37:22 -0500 Subject: [PATCH 11/12] refactor: use `try_extract_scalar` on loss output --- examples/training/examples/train-clm.rs | 8 ++------ src/training/simple.rs | 14 ++------------ 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/examples/training/examples/train-clm.rs b/examples/training/examples/train-clm.rs index 54547dc..9e46bf4 100644 --- a/examples/training/examples/train-clm.rs +++ b/examples/training/examples/train-clm.rs @@ -5,7 +5,7 @@ use std::{ }; use kdam::BarExt; -use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis, Ix0}; +use ndarray::{concatenate, s, Array1, Array2, ArrayViewD, Axis}; use ort::{Allocator, CUDAExecutionProvider, Checkpoint, Session, SessionBuilder, Trainer}; use rand::RngCore; use tokenizers::Tokenizer; @@ -82,11 +82,7 @@ fn main() -> ort::Result<()> { let labels = Array1::::from_shape_vec([BATCH_SIZE * SEQUENCE_LENGTH], label_buffer.iter().map(|c| *c as i64).collect()).unwrap(); let outputs = trainer.step(ort::inputs![inputs.view()]?, ort::inputs![labels.view()]?)?; - let loss = outputs[0] - .try_extract_tensor::()? - .into_dimensionality::() - .unwrap() - .into_scalar(); + let loss = outputs[0].try_extract_scalar::()?; pb.set_postfix(format!("loss={loss:.3}")); pb.update(1).unwrap(); if loss.is_nan() { diff --git a/src/training/simple.rs b/src/training/simple.rs index 205bb18..267f3c6 100644 --- a/src/training/simple.rs +++ b/src/training/simple.rs @@ -1,7 +1,5 @@ use std::{collections::VecDeque, fs, path::PathBuf}; -use ndarray::Ix0; - use crate::{Result, SessionInputs}; #[allow(clippy::len_without_is_empty)] @@ -182,11 +180,7 @@ impl super::Trainer { let (inputs, labels) = (inputs.into(), labels.into()); let outputs = self.step(inputs, labels)?; - let loss = outputs[0] - .try_extract_tensor::()? - .into_dimensionality::() - .expect("first output should be the 0-dimensional loss tensor") - .into_scalar(); + let loss = outputs[0].try_extract_scalar::()?; println!("epoch={epoch} step={global_step} loss={loss}"); if iter_step % args.gradient_accumulation_steps == 0 { @@ -237,11 +231,7 @@ impl super::Trainer { let (inputs, labels) = (inputs.into(), labels.into()); let outputs = self.eval_step(inputs, labels)?; - let loss = outputs[0] - .try_extract_tensor::()? - .into_dimensionality::() - .expect("first output should be the 0-dimensional loss tensor") - .into_scalar(); + let loss = outputs[0].try_extract_scalar::()?; total_loss = (total_loss * (step as f32) + loss) / (step as f32 + 1.); } From 27676fcd7875c3347b140f4810ae68e4dcc7d641 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 30 Jun 2024 12:36:51 -0500 Subject: [PATCH 12/12] fix: take `RunOptions` by reference --- src/training/trainer.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/training/trainer.rs b/src/training/trainer.rs index 6ef4c92..f7c7cb3 100644 --- a/src/training/trainer.rs +++ b/src/training/trainer.rs @@ -96,7 +96,7 @@ impl Trainer { &'s self, inputs: impl Into>, labels: impl Into> - ) -> Result> { + ) -> Result> { match inputs.into() { SessionInputs::ValueSlice(input_values) => match labels.into() { SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None), @@ -112,11 +112,11 @@ impl Trainer { } } - fn step_inner<'s, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + fn step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( &'s self, input_values: impl Iterator>, - run_options: Option> - ) -> Result> { + run_options: Option<&'r RunOptions> + ) -> Result> { let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect(); @@ -145,7 +145,7 @@ impl Trainer { &'s self, inputs: impl Into>, labels: impl Into> - ) -> Result> { + ) -> Result> { match inputs.into() { SessionInputs::ValueSlice(input_values) => match labels.into() { SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None), @@ -161,11 +161,11 @@ impl Trainer { } } - fn eval_step_inner<'s, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( + fn eval_step_inner<'r, 's: 'r, 'i1, 'v1: 'i1, 'i2, 'v2: 'i2>( &'s self, input_values: impl Iterator>, - run_options: Option> - ) -> Result> { + run_options: Option<&'r RunOptions> + ) -> Result> { let mut output_tensor_ptrs: Vec<*mut ort_sys::OrtValue> = vec![std::ptr::null_mut(); self.train_output_names.len()]; let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.map(|input_array_ort| input_array_ort.ptr().cast_const()).collect();