diff --git a/src/session/builder/impl_options.rs b/src/session/builder/impl_options.rs index aad1873..85ed535 100644 --- a/src/session/builder/impl_options.rs +++ b/src/session/builder/impl_options.rs @@ -1,6 +1,4 @@ -#[cfg(not(windows))] -use std::ffi::CString; -use std::{path::Path, rc::Rc, sync::Arc}; +use std::{borrow::Cow, path::Path, rc::Rc, sync::Arc}; use super::SessionBuilder; use crate::{ @@ -8,7 +6,7 @@ use crate::{ execution_providers::{apply_execution_providers, ExecutionProviderDispatch}, ortsys, util::path_to_os_char, - MemoryInfo, OperatorDomain + DynValue, MemoryInfo, OperatorDomain }; impl SessionBuilder { @@ -79,22 +77,16 @@ impl SessionBuilder { /// newly optimized model to the given path (for 'offline' graph optimization). /// /// Note that the file will only be created after the model is committed. - pub fn with_optimized_model_path>(self, path: S) -> Result { - #[cfg(windows)] - let path = path.as_ref().encode_utf16().chain([0]).collect::>(); - #[cfg(not(windows))] - let path = CString::new(path.as_ref())?; + pub fn with_optimized_model_path>(self, path: S) -> Result { + let path = crate::util::path_to_os_char(path); ortsys![unsafe SetOptimizedModelFilePath(self.session_options_ptr.as_ptr(), path.as_ptr())?]; Ok(self) } /// Enables profiling. Profile information will be writen to `profiling_file` after profiling completes. /// See [`Session::end_profiling`]. - pub fn with_profiling>(self, profiling_file: S) -> Result { - #[cfg(windows)] - let profiling_file = profiling_file.as_ref().encode_utf16().chain([0]).collect::>(); - #[cfg(not(windows))] - let profiling_file = CString::new(profiling_file.as_ref())?; + pub fn with_profiling>(self, profiling_file: S) -> Result { + let profiling_file = crate::util::path_to_os_char(profiling_file); ortsys![unsafe EnableProfiling(self.session_options_ptr.as_ptr(), profiling_file.as_ptr())?]; Ok(self) } @@ -136,6 +128,35 @@ impl SessionBuilder { self.operator_domains.push(domain); Ok(self) } + + /// Enables/disables deterministic computation. + /// + /// The default (non-deterministic) kernels will typically use faster algorithms that may introduce slight variance. + /// Enabling deterministic compute will output reproducible results, but may come at a performance penalty. + pub fn with_deterministic_compute(self, enable: bool) -> Result { + ortsys![unsafe SetDeterministicCompute(self.session_options_ptr.as_ptr(), enable)?]; + Ok(self) + } + + pub fn with_external_initializer(mut self, name: impl AsRef, value: DynValue) -> Result { + let name = name.as_ref(); + let value = Rc::new(value); + ortsys![unsafe AddExternalInitializers(self.session_options_ptr.as_ptr(), &name.as_ptr().cast::(), &value.ptr().cast_const(), 1)?]; + self.external_initializers.push(value); + Ok(self) + } + + pub fn with_external_initializer_file(mut self, file_name: impl AsRef, buffer: Cow<'static, [u8]>) -> Result { + // We need to hold onto `buffer` until the session is actually committed. This means `buffer` must outlive 'self (if + // SessionBuilder were to have a lifetime). Adding a lifetime to SessionBuilder would be breaking, so right now we + // either accept a &'static [u8] or Vec via Cow<'_, [u8]>, which still allows users to use include_bytes!. + + let file_name = crate::util::path_to_os_char(file_name); + let sizes = [buffer.len() as ort_sys::size_t]; + ortsys![unsafe AddExternalInitializersFromMemory(self.session_options_ptr.as_ptr(), &file_name.as_ptr(), &buffer.as_ptr().cast::().cast_mut(), sizes.as_ptr(), 1)?]; + self.external_initializer_buffers.push(buffer); + Ok(self) + } } /// ONNX Runtime provides various graph optimizations to improve performance. Graph optimizations are essentially diff --git a/src/session/builder/mod.rs b/src/session/builder/mod.rs index 2a4e5c1..3f6c051 100644 --- a/src/session/builder/mod.rs +++ b/src/session/builder/mod.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, ffi::CString, ptr::{self, NonNull}, rc::Rc, @@ -9,7 +10,7 @@ use crate::{ error::{assert_non_null_pointer, status_to_result, Result}, memory::MemoryInfo, operator::OperatorDomain, - ortsys + ortsys, DynValue }; mod impl_commit; @@ -36,7 +37,9 @@ pub use self::impl_options::GraphOptimizationLevel; pub struct SessionBuilder { pub(crate) session_options_ptr: NonNull, memory_info: Option>, - operator_domains: Vec> + operator_domains: Vec>, + external_initializers: Vec>, + external_initializer_buffers: Vec> } impl Clone for SessionBuilder { @@ -48,7 +51,9 @@ impl Clone for SessionBuilder { Self { session_options_ptr: unsafe { NonNull::new_unchecked(session_options_ptr) }, memory_info: self.memory_info.clone(), - operator_domains: self.operator_domains.clone() + operator_domains: self.operator_domains.clone(), + external_initializers: self.external_initializers.clone(), + external_initializer_buffers: self.external_initializer_buffers.clone() } } } @@ -79,7 +84,9 @@ impl SessionBuilder { Ok(Self { session_options_ptr: unsafe { NonNull::new_unchecked(session_options_ptr) }, memory_info: None, - operator_domains: Vec::new() + operator_domains: Vec::new(), + external_initializers: Vec::new(), + external_initializer_buffers: Vec::new() }) }