Skip to content

Commit

Permalink
refactor: remove lifetime bound for IoBinding
Browse files Browse the repository at this point in the history
And implement `Send` for it. yw audioxd
  • Loading branch information
decahedron1 committed Sep 23, 2024
1 parent bd3c891 commit cf1be86
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 17 deletions.
58 changes: 42 additions & 16 deletions src/io_binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ use crate::{
memory::MemoryInfo,
ortsys,
session::{output::SessionOutputs, NoSelectedOutputs, RunOptions, Session},
value::{DynValue, Value, ValueInner, ValueTypeMarker}
value::{DynValue, Value, ValueInner, ValueTypeMarker},
SharedSessionInner
};

/// Enables binding of session inputs and/or outputs to pre-allocated memory.
Expand Down Expand Up @@ -86,21 +87,21 @@ use crate::{
/// of `unet.run()`, and this copying can come with significant latency & overhead. With [`IoBinding`], the `condition`
/// tensor is only copied to the device once instead of 20 times.
#[derive(Debug)]
pub struct IoBinding<'s> {
pub struct IoBinding {
pub(crate) ptr: NonNull<ort_sys::OrtIoBinding>,
session: &'s Session,
held_inputs: HashMap<String, Arc<ValueInner>>,
output_names: Vec<String>,
output_values: HashMap<String, DynValue>
output_values: HashMap<String, DynValue>,
session: Arc<SharedSessionInner>
}

impl<'s> IoBinding<'s> {
pub(crate) fn new(session: &'s Session) -> Result<Self> {
impl IoBinding {
pub(crate) fn new(session: &Session) -> Result<Self> {
let mut ptr: *mut ort_sys::OrtIoBinding = ptr::null_mut();
ortsys![unsafe CreateIoBinding(session.inner.session_ptr.as_ptr(), &mut ptr)?; nonNull(ptr)];
Ok(Self {
ptr: unsafe { NonNull::new_unchecked(ptr) },
session,
session: session.inner(),
held_inputs: HashMap::new(),
output_names: Vec::new(),
output_values: HashMap::new()
Expand Down Expand Up @@ -177,24 +178,23 @@ impl<'s> IoBinding<'s> {
}

/// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`].
pub fn run_with_options(&mut self, run_options: &RunOptions<NoSelectedOutputs>) -> Result<SessionOutputs<'_, 's>> {
pub fn run_with_options(&mut self, run_options: &RunOptions<NoSelectedOutputs>) -> Result<SessionOutputs<'_, '_>> {
self.run_inner(Some(run_options))
}

fn run_inner(&mut self, run_options: Option<&RunOptions<NoSelectedOutputs>>) -> Result<SessionOutputs<'_, 's>> {
fn run_inner(&mut self, run_options: Option<&RunOptions<NoSelectedOutputs>>) -> Result<SessionOutputs<'_, '_>> {
let run_options_ptr = if let Some(run_options) = run_options {
run_options.run_options_ptr.as_ptr()
} else {
std::ptr::null_mut()
};
ortsys![unsafe RunWithBinding(self.session.inner.session_ptr.as_ptr(), run_options_ptr, self.ptr.as_ptr())?];
ortsys![unsafe RunWithBinding(self.session.session_ptr.as_ptr(), run_options_ptr, self.ptr.as_ptr())?];

let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Arc<ValueInner>> = self.output_values.values().map(|c| (c.ptr(), &c.inner)).collect();
let mut count = self.output_names.len() as ort_sys::size_t;
if count > 0 {
let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut();
let allocator = self.session.allocator();
ortsys![unsafe GetBoundOutputValues(self.ptr.as_ptr(), allocator.ptr.as_ptr(), &mut output_values_ptr, &mut count)?; nonNull(output_values_ptr)];
ortsys![unsafe GetBoundOutputValues(self.ptr.as_ptr(), self.session.allocator.ptr.as_ptr(), &mut output_values_ptr, &mut count)?; nonNull(output_values_ptr)];

let output_values = unsafe { std::slice::from_raw_parts(output_values_ptr, count as _).to_vec() }
.into_iter()
Expand All @@ -207,21 +207,23 @@ impl<'s> IoBinding<'s> {
} else {
DynValue::from_ptr(
NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"),
Some(Arc::clone(&self.session.inner))
Some(Arc::clone(&self.session))
)
}
});

// output values will be freed when the `Value`s in `SessionOutputs` drop

Ok(SessionOutputs::new_backed(self.output_names.iter().map(String::as_str), output_values, allocator, output_values_ptr.cast()))
Ok(SessionOutputs::new_backed(self.output_names.iter().map(String::as_str), output_values, &self.session.allocator, output_values_ptr.cast()))
} else {
Ok(SessionOutputs::new_empty())
}
}
}

impl<'s> Drop for IoBinding<'s> {
// SAFETY: `IoBinding::ptr` is a `NonNull<ort_sys::OrtIoBinding>`, which is `!Send`, so the compiler will not
// derive `Send` for this type automatically and it has to be asserted by hand.
// NOTE(review): this impl assumes the native binding handle and every other field of `IoBinding` may be moved to,
// used from, and dropped on a thread other than the one that created them — confirm against ONNX Runtime's
// threading guarantees.
unsafe impl Send for IoBinding {}

impl Drop for IoBinding {
// Releases the native `OrtIoBinding` handle when the `IoBinding` goes out of scope.
fn drop(&mut self) {
// `self.ptr` is the handle produced by `CreateIoBinding` in `IoBinding::new`.
ortsys![unsafe ReleaseIoBinding(self.ptr.as_ptr())];
}
Expand Down Expand Up @@ -295,7 +297,31 @@ mod tests {
}

#[test]
fn test_mnist_clears_bound() -> Result<()> {
fn test_send_iobinding() -> Result<()> {
    // Checks that an `IoBinding` created on one thread can be moved to another thread and used for inference there.
    let session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?;

    let array = get_image();

    // The output is bound on the creating thread, before the binding is handed off.
    let mut binding = session.create_binding()?;
    let output = Array2::from_shape_simple_fn((1, 10), || 0.0_f32);
    binding.bind_output(&session.outputs[0].name, Tensor::from_array(output)?)?;

    // `move` transfers `binding`, `session` and `array` into the spawned thread; `std::thread::spawn` requires the
    // closure to be `Send`, so this test fails to compile if `IoBinding` stops being `Send`.
    let probabilities = std::thread::spawn(move || {
        binding.bind_input(&session.inputs[0].name, &Tensor::from_array(array)?)?;
        let outputs = binding.run()?;
        let probabilities = extract_probabilities(&outputs[0])?;
        Ok::<Vec<(usize, f32)>, crate::Error>(probabilities)
    })
    .join()
    // `join` only errs if the worker panicked; say so instead of panicking with an empty message.
    .expect("inference thread panicked")?;

    // The first entry's class index is expected to be 5 for the test image.
    assert_eq!(probabilities[0].0, 5);

    Ok(())
}

#[test]
fn test_mnist_clear_bounds() -> Result<()> {
let session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?;

let array = get_image();
Expand Down
2 changes: 1 addition & 1 deletion src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub use self::{
#[derive(Debug)]
pub struct SharedSessionInner {
pub(crate) session_ptr: NonNull<ort_sys::OrtSession>,
allocator: Allocator,
pub(crate) allocator: Allocator,
/// Additional things we may need to hold onto for the duration of this session, like [`crate::OperatorDomain`]s and
/// DLL handles for operator libraries.
_extras: Vec<Box<dyn Any>>,
Expand Down

0 comments on commit cf1be86

Please sign in to comment.