Skip to content

Commit

Permalink
refactor: usability
Browse files Browse the repository at this point in the history
aka The Cleanening, part 2

- Add clearer documentation and examples for more things.
- Rework string tensors by introducing `PrimitiveTensorElementType` for primitive (i.e. f32) types, and again re-implementing `IntoTensorElementType` for `String`. This allows string tensors to be used via `Tensor<String>` instead of exclusively via `DynTensor`. Additionally, string tensors no longer require an `Allocator` to be created (which didn't make sense, since string data in Rust can only ever be stored on the CPU anyway). This also now applies to `Map`s, since their data also needed to be on the CPU anyway. (`Sequence`s are currently unaffected because I think a custom allocator could be useful for them?)
- Rework the `IoBinding` interface, and add an example clarifying the intended usage of it (ref #209). Thanks to AAce from the pyke Discord for pointing out the mutability issue in the old interface, which should be addressed now.
- Refactor `OperatorDomain::add` from the slightly-nicer-looking-but-more-confusing `fn<T>(t: T)` to just `fn<T>()` to further enforce the fact that `Operator`s are zero-sized.
- Maps can now have `String` keys.
- Remove some unused errors.
  • Loading branch information
decahedron1 committed Jun 21, 2024
1 parent 19d66de commit c64b8ea
Show file tree
Hide file tree
Showing 20 changed files with 1,112 additions and 566 deletions.
2 changes: 1 addition & 1 deletion examples/custom-ops/examples/custom-ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl Kernel for CustomOpTwoKernel {

fn main() -> ort::Result<()> {
let session = Session::builder()?
.with_operators(OperatorDomain::new("test.customop")?.add(CustomOpOne)?.add(CustomOpTwo)?)?
.with_operators(OperatorDomain::new("test.customop")?.add::<CustomOpOne>()?.add::<CustomOpTwo>()?)?
.commit_from_file("tests/data/custom_op_test.onnx")?;

let values = session.run(ort::inputs![Array2::<f32>::zeros((3, 5)), Array2::<f32>::ones((3, 5))]?)?;
Expand Down
80 changes: 58 additions & 22 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub struct Environment {
}

impl Environment {
/// Loads the underlying [`ort_sys::OrtEnv`] pointer.
/// Returns the underlying [`ort_sys::OrtEnv`] pointer.
pub fn ptr(&self) -> *mut ort_sys::OrtEnv {
self.env_ptr.load(Ordering::Relaxed)
}
Expand All @@ -52,13 +52,14 @@ impl Drop for Environment {
}
}

/// Gets a reference to the global environment, creating one if an environment has been committed yet.
/// Gets a reference to the global environment, creating one if an environment has not been
/// [`commit`](EnvironmentBuilder::commit)ted yet.
pub fn get_environment() -> Result<&'static Arc<Environment>> {
if let Some(c) = unsafe { &*G_ENV.cell.get() } {
Ok(c)
} else {
debug!("Environment not yet initialized, creating a new one");
EnvironmentBuilder::default().commit()?;
EnvironmentBuilder::new().commit()?;

Ok(unsafe { (*G_ENV.cell.get()).as_ref().unwrap_unchecked() })
}
Expand All @@ -72,28 +73,26 @@ pub struct EnvironmentGlobalThreadPoolOptions {
pub intra_op_thread_affinity: Option<String>
}

/// Struct used to build an `Environment`.
/// Struct used to build an [`Environment`]; see [`crate::init`].
pub struct EnvironmentBuilder {
name: String,
telemetry: bool,
execution_providers: Vec<ExecutionProviderDispatch>,
global_thread_pool_options: Option<EnvironmentGlobalThreadPoolOptions>
}

impl Default for EnvironmentBuilder {
fn default() -> Self {
impl EnvironmentBuilder {
pub(crate) fn new() -> Self {
EnvironmentBuilder {
name: "default".to_string(),
telemetry: true,
execution_providers: vec![],
global_thread_pool_options: None
}
}
}

impl EnvironmentBuilder {
/// Configure the environment with a given name for logging purposes.
#[must_use]
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn with_name<S>(mut self, name: S) -> Self
where
S: Into<String>
Expand All @@ -102,7 +101,17 @@ impl EnvironmentBuilder {
self
}

#[must_use]
/// Enable or disable sending telemetry events to Microsoft.
///
/// Typically, only Windows builds of ONNX Runtime provided by Microsoft will have telemetry enabled.
/// Pre-built binaries provided by pyke, or binaries compiled from source, won't have telemetry enabled.
///
/// The exact kind of telemetry data sent can be found [here](https://github.com/microsoft/onnxruntime/blob/v1.18.0/onnxruntime/core/platform/windows/telemetry.cc).
/// Currently, this includes (but is not limited to): ONNX graph version, model producer name & version, whether or
/// not FP16 is used, operator domains & versions, model graph name & custom metadata, execution provider names,
/// error messages, and the total number & time of session inference runs. The ONNX Runtime team uses this data to
/// better understand how customers use ONNX Runtime and where performance can be improved.
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn with_telemetry(mut self, enable: bool) -> Self {
self.telemetry = enable;
self
Expand All @@ -116,14 +125,14 @@ impl EnvironmentBuilder {
/// Execution providers will only work if the corresponding Cargo feature is enabled and ONNX Runtime was built
/// with support for the corresponding execution provider. Execution providers that do not have their corresponding
/// feature enabled will emit a warning.
#[must_use]
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProviderDispatch]>) -> Self {
self.execution_providers = execution_providers.as_ref().to_vec();
self
}

/// Enables the global thread pool for this environment.
#[must_use]
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn with_global_thread_pool(mut self, options: EnvironmentGlobalThreadPoolOptions) -> Self {
self.global_thread_pool_options = Some(options);
self
Expand Down Expand Up @@ -158,14 +167,17 @@ impl EnvironmentBuilder {
ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr()) -> Error::CreateEnvironment];
}

ortsys![unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(
ortsys![
unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(
logging_function,
logger_param,
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
cname.as_ptr(),
thread_options,
&mut env_ptr
) -> Error::CreateEnvironment; nonNull(env_ptr)];
) -> Error::CreateEnvironment;
nonNull(env_ptr)
];
ortsys![unsafe ReleaseThreadingOptions(thread_options)];
(env_ptr, true)
} else {
Expand All @@ -174,13 +186,16 @@ impl EnvironmentBuilder {
// FIXME: What should go here?
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!());
ortsys![unsafe CreateEnvWithCustomLogger(
ortsys![
unsafe CreateEnvWithCustomLogger(
logging_function,
logger_param,
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
cname.as_ptr(),
&mut env_ptr
) -> Error::CreateEnvironment; nonNull(env_ptr)];
) -> Error::CreateEnvironment;
nonNull(env_ptr)
];
(env_ptr, false)
};
debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created");
Expand All @@ -205,31 +220,52 @@ impl EnvironmentBuilder {

/// Creates an ONNX Runtime environment.
///
/// ```
/// # use ort::CUDAExecutionProvider;
/// # fn main() -> ort::Result<()> {
/// ort::init()
/// .with_execution_providers([CUDAExecutionProvider::default().build()])
/// .commit()?;
/// # Ok(())
/// # }
/// ```
///
/// # Notes
/// - It is not required to call this function. If this is not called by the time any other `ort` APIs are used, a
/// default environment will be created.
/// - Library crates that use `ort` shouldn't create their own environment. Let downstream applications create it.
/// - **Library crates that use `ort` shouldn't create their own environment.** Let downstream applications create it.
/// - In order for environment settings to apply, this must be called **before** you use other APIs like
/// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function.
#[must_use]
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn init() -> EnvironmentBuilder {
EnvironmentBuilder::default()
EnvironmentBuilder::new()
}

/// Creates an ONNX Runtime environment, dynamically loading ONNX Runtime from the library file (`.dll`/`.so`/`.dylib`)
/// specified by `path`.
///
/// This must be called before any other `ort` APIs are used in order for the correct dynamic library to be loaded.
///
/// ```no_run
/// # use ort::CUDAExecutionProvider;
/// # fn main() -> ort::Result<()> {
/// let lib_path = std::env::current_exe().unwrap().parent().unwrap().join("lib");
/// ort::init_from(lib_path.join("onnxruntime.dll"))
/// .with_execution_providers([CUDAExecutionProvider::default().build()])
/// .commit()?;
/// # Ok(())
/// # }
/// ```
///
/// # Notes
/// - In order for environment settings to apply, this must be called **before** you use other APIs like
/// [`crate::Session`], and you *must* call `.commit()` on the builder returned by this function.
#[cfg(feature = "load-dynamic")]
#[cfg_attr(docsrs, doc(cfg(feature = "load-dynamic")))]
#[must_use]
#[must_use = "commit() must be called in order for the environment to take effect"]
pub fn init_from(path: impl ToString) -> EnvironmentBuilder {
let _ = G_ORT_DYLIB_PATH.set(Arc::new(path.to_string()));
EnvironmentBuilder::default()
EnvironmentBuilder::new()
}

/// ONNX's logger sends the code location where the log occurred, which will be parsed into this struct.
Expand Down Expand Up @@ -325,7 +361,7 @@ mod tests {
assert!(!is_env_initialized());
assert_eq!(env_ptr(), None);

EnvironmentBuilder::default().with_name("env_is_initialized").commit()?;
EnvironmentBuilder::new().with_name("env_is_initialized").commit()?;
assert!(is_env_initialized());
assert_ne!(env_ptr(), None);
Ok(())
Expand Down
61 changes: 17 additions & 44 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,6 @@ pub enum Error {
/// Error occurred when filling a tensor with string data
#[error("Failed to fill string tensor: {0}")]
FillStringTensor(ErrorInternal),
/// Error occurred when checking if a value is a tensor
#[error("Failed to check if value is a tensor: {0}")]
FailedTensorCheck(ErrorInternal),
/// Error occurred when getting tensor type and shape
#[error("Failed to get tensor type and shape: {0}")]
GetTensorTypeAndShape(ErrorInternal),
Expand Down Expand Up @@ -159,12 +156,6 @@ pub enum Error {
/// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models).
#[error("Failed to download ONNX model: {0}")]
DownloadError(#[from] FetchModelError),
/// Type of input data and the ONNX model do not match.
#[error("Data types do not match: expected {model:?}, got {input:?}")]
NonMatchingDataTypes { input: TensorElementType, model: TensorElementType },
/// Dimensions of input data and the ONNX model do not match.
#[error("Dimensions do not match: {0:?}")]
NonMatchingDimensions(NonMatchingDimensionsError),
/// File does not exist
#[error("File `{filename:?}` does not exist")]
FileDoesNotExist {
Expand All @@ -186,9 +177,6 @@ pub enum Error {
/// ORT pointer should not have been null
#[error("`{0}` should not be a null pointer")]
PointerShouldNotBeNull(&'static str),
/// The runtime type was undefined.
#[error("Undefined tensor element type")]
UndefinedTensorElementType,
/// Could not retrieve model metadata.
#[error("Failed to retrieve model metadata: {0}")]
GetModelMetadata(ErrorInternal),
Expand All @@ -208,8 +196,8 @@ pub enum Error {
ExecutionProviderNotRegistered(&'static str),
#[error("Expected tensor to be on CPU in order to get data, but had allocation device `{0}`.")]
TensorNotOnCpu(&'static str),
#[error("String tensors require the session's allocator to be provided through `Value::from_array`.")]
StringTensorRequiresAllocator,
#[error("Cannot extract scalar value from a {0}-dimensional tensor")]
TensorNot0Dimensional(usize),
#[error("Failed to create memory info: {0}")]
CreateMemoryInfo(ErrorInternal),
#[error("Could not get allocation device from `MemoryInfo`: {0}")]
Expand All @@ -222,10 +210,10 @@ pub enum Error {
BindInput(ErrorInternal),
#[error("Error when binding output: {0}")]
BindOutput(ErrorInternal),
#[error("Failed to clear IO binding: {0}")]
ClearBinding(ErrorInternal),
#[error("Error when retrieving session outputs from `IoBinding`: {0}")]
GetBoundOutputs(ErrorInternal),
#[error("Cannot use `extract_tensor` on a value that is {0:?}")]
NotTensor(ValueType),
#[error("Cannot use `extract_sequence` on a value that is {0:?}")]
NotSequence(ValueType),
#[error("Cannot use `extract_map` on a value that is {0:?}")]
Expand All @@ -252,6 +240,8 @@ pub enum Error {
GetOperatorInput(ErrorInternal),
#[error("Failed to get operator output: {0}")]
GetOperatorOutput(ErrorInternal),
#[error("Failed to retrieve GPU compute stream from kernel context: {0}")]
GetOperatorGPUComputeStream(ErrorInternal),
#[error("{0}")]
CustomError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("String tensors cannot be borrowed as mutable")]
Expand All @@ -266,37 +256,20 @@ pub enum Error {
GetDeviceId(ErrorInternal)
}

impl From<Infallible> for Error {
fn from(_: Infallible) -> Self {
Error::Infallible
impl Error {
/// Wrap a custom, user-provided error in an [`ort::Error`](Error). The resulting error will be the
/// [`Error::CustomError`] variant.
///
/// This can be used to return custom errors from e.g. training dataloaders or custom operators if a non-`ort`
/// related operation fails.
pub fn wrap<T: std::error::Error + Send + Sync + 'static>(err: T) -> Self {
Error::CustomError(Box::new(err) as Box<dyn std::error::Error + Send + Sync + 'static>)
}
}

/// Error used when the input dimensions defined in the model and passed from an inference call do not match.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum NonMatchingDimensionsError {
/// Number of inputs from model does not match the number of inputs from inference call.
#[error(
"Non-matching number of inputs: {inference_input_count:?} provided vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})"
)]
InputsCount {
/// Number of input dimensions used by inference call
inference_input_count: usize,
/// Number of input dimensions defined in model
model_input_count: usize,
/// Input dimensions used by inference call
inference_input: Vec<Vec<usize>>,
/// Input dimensions defined in model
model_input: Vec<Vec<Option<u32>>>
},
/// Inputs length from model does not match the expected input from inference call
#[error("Different input lengths; expected input: {model_input:?}, received input: {inference_input:?}")]
InputsLength {
/// Input dimensions used by inference call
inference_input: Vec<Vec<usize>>,
/// Input dimensions defined in model
model_input: Vec<Vec<Option<u32>>>
impl From<Infallible> for Error {
fn from(_: Infallible) -> Self {
Error::Infallible
}
}

Expand Down
Loading

0 comments on commit c64b8ea

Please sign in to comment.