Skip to content

Commit

Permalink
feat: load external initializers from memory, closes #286
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Sep 23, 2024
1 parent 4da5700 commit c8b36f3
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 18 deletions.
49 changes: 35 additions & 14 deletions src/session/builder/impl_options.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
#[cfg(not(windows))]
use std::ffi::CString;
use std::{path::Path, rc::Rc, sync::Arc};
use std::{borrow::Cow, path::Path, rc::Rc, sync::Arc};

use super::SessionBuilder;
use crate::{
error::Result,
execution_providers::{apply_execution_providers, ExecutionProviderDispatch},
ortsys,
util::path_to_os_char,
MemoryInfo, OperatorDomain
DynValue, MemoryInfo, OperatorDomain
};

impl SessionBuilder {
Expand Down Expand Up @@ -79,22 +77,16 @@ impl SessionBuilder {
/// newly optimized model to the given path (for 'offline' graph optimization).
///
/// Note that the file will only be created after the model is committed.
pub fn with_optimized_model_path<S: AsRef<str>>(self, path: S) -> Result<Self> {
#[cfg(windows)]
let path = path.as_ref().encode_utf16().chain([0]).collect::<Vec<_>>();
#[cfg(not(windows))]
let path = CString::new(path.as_ref())?;
pub fn with_optimized_model_path<S: AsRef<Path>>(self, path: S) -> Result<Self> {
let path = crate::util::path_to_os_char(path);
ortsys![unsafe SetOptimizedModelFilePath(self.session_options_ptr.as_ptr(), path.as_ptr())?];
Ok(self)
}

/// Enables profiling. Profile information will be writen to `profiling_file` after profiling completes.
/// See [`Session::end_profiling`].
pub fn with_profiling<S: AsRef<str>>(self, profiling_file: S) -> Result<Self> {
#[cfg(windows)]
let profiling_file = profiling_file.as_ref().encode_utf16().chain([0]).collect::<Vec<_>>();
#[cfg(not(windows))]
let profiling_file = CString::new(profiling_file.as_ref())?;
pub fn with_profiling<S: AsRef<Path>>(self, profiling_file: S) -> Result<Self> {
let profiling_file = crate::util::path_to_os_char(profiling_file);
ortsys![unsafe EnableProfiling(self.session_options_ptr.as_ptr(), profiling_file.as_ptr())?];
Ok(self)
}
Expand Down Expand Up @@ -136,6 +128,35 @@ impl SessionBuilder {
self.operator_domains.push(domain);
Ok(self)
}

/// Enables/disables deterministic computation.
///
/// The default (non-deterministic) kernels will typically use faster algorithms that may introduce slight variance.
/// Enabling deterministic compute will output reproducible results, but may come at a performance penalty.
pub fn with_deterministic_compute(self, enable: bool) -> Result<Self> {
ortsys![unsafe SetDeterministicCompute(self.session_options_ptr.as_ptr(), enable)?];
Ok(self)
}

pub fn with_external_initializer(mut self, name: impl AsRef<str>, value: DynValue) -> Result<Self> {
let name = name.as_ref();
let value = Rc::new(value);
ortsys![unsafe AddExternalInitializers(self.session_options_ptr.as_ptr(), &name.as_ptr().cast::<i8>(), &value.ptr().cast_const(), 1)?];
self.external_initializers.push(value);
Ok(self)
}

pub fn with_external_initializer_file(mut self, file_name: impl AsRef<Path>, buffer: Cow<'static, [u8]>) -> Result<Self> {
// We need to hold onto `buffer` until the session is actually committed. This means `buffer` must outlive 'self (if
// SessionBuilder were to have a lifetime). Adding a lifetime to SessionBuilder would be breaking, so right now we
// either accept a &'static [u8] or Vec<u8> via Cow<'_, [u8]>, which still allows users to use include_bytes!.

let file_name = crate::util::path_to_os_char(file_name);
let sizes = [buffer.len() as ort_sys::size_t];
ortsys![unsafe AddExternalInitializersFromMemory(self.session_options_ptr.as_ptr(), &file_name.as_ptr(), &buffer.as_ptr().cast::<i8>().cast_mut(), sizes.as_ptr(), 1)?];
self.external_initializer_buffers.push(buffer);
Ok(self)
}
}

/// ONNX Runtime provides various graph optimizations to improve performance. Graph optimizations are essentially
Expand Down
15 changes: 11 additions & 4 deletions src/session/builder/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
borrow::Cow,
ffi::CString,
ptr::{self, NonNull},
rc::Rc,
Expand All @@ -9,7 +10,7 @@ use crate::{
error::{assert_non_null_pointer, status_to_result, Result},
memory::MemoryInfo,
operator::OperatorDomain,
ortsys
ortsys, DynValue
};

mod impl_commit;
Expand All @@ -36,7 +37,9 @@ pub use self::impl_options::GraphOptimizationLevel;
pub struct SessionBuilder {
pub(crate) session_options_ptr: NonNull<ort_sys::OrtSessionOptions>,
memory_info: Option<Rc<MemoryInfo>>,
operator_domains: Vec<Arc<OperatorDomain>>
operator_domains: Vec<Arc<OperatorDomain>>,
external_initializers: Vec<Rc<DynValue>>,
external_initializer_buffers: Vec<Cow<'static, [u8]>>
}

impl Clone for SessionBuilder {
Expand All @@ -48,7 +51,9 @@ impl Clone for SessionBuilder {
Self {
session_options_ptr: unsafe { NonNull::new_unchecked(session_options_ptr) },
memory_info: self.memory_info.clone(),
operator_domains: self.operator_domains.clone()
operator_domains: self.operator_domains.clone(),
external_initializers: self.external_initializers.clone(),
external_initializer_buffers: self.external_initializer_buffers.clone()
}
}
}
Expand Down Expand Up @@ -79,7 +84,9 @@ impl SessionBuilder {
Ok(Self {
session_options_ptr: unsafe { NonNull::new_unchecked(session_options_ptr) },
memory_info: None,
operator_domains: Vec::new()
operator_domains: Vec::new(),
external_initializers: Vec::new(),
external_initializer_buffers: Vec::new()
})
}

Expand Down

0 comments on commit c8b36f3

Please sign in to comment.