Skip to content

Commit

Permalink
refactor: typestate for RunOptions that have selected outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Jul 1, 2024
1 parent 3b93e73 commit a127d0f
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 57 deletions.
6 changes: 3 additions & 3 deletions src/io_binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
ortsys,
session::{output::SessionOutputs, RunOptions},
value::{Value, ValueInner},
DynValue, Error, Result, Session, ValueTypeMarker
DynValue, Error, NoSelectedOutputs, Result, Session, ValueTypeMarker
};

/// Enables binding of session inputs and/or outputs to pre-allocated memory.
Expand Down Expand Up @@ -177,11 +177,11 @@ impl<'s> IoBinding<'s> {
}

/// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`].
pub fn run_with_options(&mut self, run_options: &RunOptions) -> Result<SessionOutputs<'_, 's>> {
pub fn run_with_options(&mut self, run_options: &RunOptions<NoSelectedOutputs>) -> Result<SessionOutputs<'_, 's>> {
self.run_inner(Some(run_options))
}

fn run_inner(&mut self, run_options: Option<&RunOptions>) -> Result<SessionOutputs<'_, 's>> {
fn run_inner(&mut self, run_options: Option<&RunOptions<NoSelectedOutputs>>) -> Result<SessionOutputs<'_, 's>> {
let run_options_ptr = if let Some(run_options) = run_options {
run_options.run_options_ptr.as_ptr()
} else {
Expand Down
8 changes: 4 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ pub use self::operator::{
InferShapeFn, Operator, OperatorDomain
};
pub use self::session::{
GraphOptimizationLevel, InMemorySession, Input, Output, OutputSelector, RunOptions, Session, SessionBuilder, SessionInputValue, SessionInputs,
SessionOutputs, SharedSessionInner
GraphOptimizationLevel, HasSelectedOutputs, InMemorySession, InferenceFut, Input, NoSelectedOutputs, Output, OutputSelector, RunOptions,
SelectedOutputMarker, Session, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, SharedSessionInner
};
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
Expand All @@ -69,8 +69,8 @@ pub use self::tensor::{IntoTensorElementType, PrimitiveTensorElementType, Tensor
pub use self::value::{
DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor,
DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence,
SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker, Value, ValueRef,
ValueRefMut, ValueType, ValueTypeMarker
SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker, Value,
ValueRef, ValueRefMut, ValueType, ValueTypeMarker
};

#[cfg(not(all(target_arch = "x86", target_os = "windows")))]
Expand Down
32 changes: 16 additions & 16 deletions src/session/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::{

use ort_sys::{c_void, OrtStatus};

use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SessionInputValue, SessionOutputs, SharedSessionInner, Value};
use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner, Value};

#[derive(Debug)]
pub(crate) struct InferenceFutInner<'r, 's> {
Expand Down Expand Up @@ -49,25 +49,25 @@ impl<'r, 's> InferenceFutInner<'r, 's> {
unsafe impl<'r, 's> Send for InferenceFutInner<'r, 's> {}
unsafe impl<'r, 's> Sync for InferenceFutInner<'r, 's> {}

pub enum RunOptionsRef<'r> {
Arc(Arc<RunOptions>),
Ref(&'r RunOptions)
pub enum RunOptionsRef<'r, O: SelectedOutputMarker> {
Arc(Arc<RunOptions<O>>),
Ref(&'r RunOptions<O>)
}

impl<'r> From<&Arc<RunOptions>> for RunOptionsRef<'r> {
fn from(value: &Arc<RunOptions>) -> Self {
impl<'r, O: SelectedOutputMarker> From<&Arc<RunOptions<O>>> for RunOptionsRef<'r, O> {
fn from(value: &Arc<RunOptions<O>>) -> Self {
Self::Arc(Arc::clone(value))
}
}

impl<'r> From<&'r RunOptions> for RunOptionsRef<'r> {
fn from(value: &'r RunOptions) -> Self {
impl<'r, O: SelectedOutputMarker> From<&'r RunOptions<O>> for RunOptionsRef<'r, O> {
fn from(value: &'r RunOptions<O>) -> Self {
Self::Ref(value)
}
}

impl<'r> Deref for RunOptionsRef<'r> {
type Target = RunOptions;
impl<'r, O: SelectedOutputMarker> Deref for RunOptionsRef<'r, O> {
type Target = RunOptions<O>;

fn deref(&self) -> &Self::Target {
match self {
Expand All @@ -77,14 +77,14 @@ impl<'r> Deref for RunOptionsRef<'r> {
}
}

pub struct InferenceFut<'s, 'r> {
pub struct InferenceFut<'s, 'r, O: SelectedOutputMarker> {
inner: Arc<InferenceFutInner<'r, 's>>,
run_options: RunOptionsRef<'r>,
run_options: RunOptionsRef<'r, O>,
did_receive: bool
}

impl<'s, 'r> InferenceFut<'s, 'r> {
pub(crate) fn new(inner: Arc<InferenceFutInner<'r, 's>>, run_options: RunOptionsRef<'r>) -> Self {
impl<'s, 'r, O: SelectedOutputMarker> InferenceFut<'s, 'r, O> {
pub(crate) fn new(inner: Arc<InferenceFutInner<'r, 's>>, run_options: RunOptionsRef<'r, O>) -> Self {
Self {
inner,
run_options,
Expand All @@ -93,7 +93,7 @@ impl<'s, 'r> InferenceFut<'s, 'r> {
}
}

impl<'s, 'r> Future for InferenceFut<'s, 'r> {
impl<'s, 'r, O: SelectedOutputMarker> Future for InferenceFut<'s, 'r, O> {
type Output = Result<SessionOutputs<'r, 's>>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Expand All @@ -109,7 +109,7 @@ impl<'s, 'r> Future for InferenceFut<'s, 'r> {
}
}

impl<'s, 'r> Drop for InferenceFut<'s, 'r> {
impl<'s, 'r, O: SelectedOutputMarker> Drop for InferenceFut<'s, 'r, O> {
fn drop(&mut self) {
if !self.did_receive {
let _ = self.run_options.terminate();
Expand Down
45 changes: 24 additions & 21 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

use std::{any::Any, ffi::CString, marker::PhantomData, ops::Deref, os::raw::c_char, ptr::NonNull, sync::Arc};

use r#async::RunOptionsRef;

use super::{
char_p_to_string,
environment::Environment,
Expand All @@ -21,13 +19,13 @@ pub(crate) mod builder;
pub(crate) mod input;
pub(crate) mod output;
mod run_options;
use self::r#async::{AsyncInferenceContext, InferenceFutInner};
use self::r#async::{AsyncInferenceContext, InferenceFutInner, RunOptionsRef};
pub use self::{
r#async::InferenceFut,
builder::{GraphOptimizationLevel, SessionBuilder},
input::{SessionInputValue, SessionInputs},
output::SessionOutputs,
run_options::{OutputSelector, RunOptions}
run_options::{HasSelectedOutputs, NoSelectedOutputs, OutputSelector, RunOptions, SelectedOutputMarker}
};

/// Holds onto an [`ort_sys::OrtSession`] pointer and its associated allocator.
Expand Down Expand Up @@ -164,14 +162,16 @@ impl Session {
pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s self, input_values: impl Into<SessionInputs<'i, 'v, N>>) -> Result<SessionOutputs<'_, 's>> {
match input_values.into() {
SessionInputs::ValueSlice(input_values) => {
self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::<Vec<_>>(), input_values.iter(), None)
self.run_inner::<NoSelectedOutputs>(&self.inputs.iter().map(|input| input.name.as_str()).collect::<Vec<_>>(), input_values.iter(), None)
}
SessionInputs::ValueArray(input_values) => {
self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::<Vec<_>>(), input_values.iter(), None)
}
SessionInputs::ValueMap(input_values) => {
self.run_inner(&input_values.iter().map(|(k, _)| k.as_ref()).collect::<Vec<_>>(), input_values.iter().map(|(_, v)| v), None)
self.run_inner::<NoSelectedOutputs>(&self.inputs.iter().map(|input| input.name.as_str()).collect::<Vec<_>>(), input_values.iter(), None)
}
SessionInputs::ValueMap(input_values) => self.run_inner::<NoSelectedOutputs>(
&input_values.iter().map(|(k, _)| k.as_ref()).collect::<Vec<_>>(),
input_values.iter().map(|(_, v)| v),
None
)
}
}

Expand Down Expand Up @@ -201,10 +201,10 @@ impl Session {
/// # Ok(())
/// # }
/// ```
pub fn run_with_options<'r, 's: 'r, 'i, 'v: 'i, const N: usize>(
pub fn run_with_options<'r, 's: 'r, 'i, 'v: 'i, O: SelectedOutputMarker, const N: usize>(
&'s self,
input_values: impl Into<SessionInputs<'i, 'v, N>>,
run_options: &'r RunOptions
run_options: &'r RunOptions<O>
) -> Result<SessionOutputs<'r, 's>> {
match input_values.into() {
SessionInputs::ValueSlice(input_values) => {
Expand All @@ -219,11 +219,11 @@ impl Session {
}
}

fn run_inner<'i, 'r, 's: 'r, 'v: 'i>(
fn run_inner<'i, 'r, 's: 'r, 'v: 'i, O: SelectedOutputMarker>(
&'s self,
input_names: &[&str],
input_values: impl Iterator<Item = &'i SessionInputValue<'v>>,
run_options: Option<&'r RunOptions>
run_options: Option<&'r RunOptions<O>>
) -> Result<SessionOutputs<'r, 's>> {
let input_names_ptr: Vec<*const c_char> = input_names
.iter()
Expand Down Expand Up @@ -321,7 +321,7 @@ impl Session {
pub fn run_async<'s, 'i, 'v: 'i + 's, const N: usize>(
&'s self,
input_values: impl Into<SessionInputs<'i, 'v, N>> + 'static
) -> Result<InferenceFut<'s, '_>> {
) -> Result<InferenceFut<'s, '_, NoSelectedOutputs>> {
match input_values.into() {
SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"),
SessionInputs::ValueArray(input_values) => {
Expand All @@ -335,11 +335,11 @@ impl Session {

/// Asynchronously run input data through the ONNX graph, performing inference, with the given [`RunOptions`].
/// See [`Session::run_with_options`] and [`Session::run_async`] for more details.
pub fn run_async_with_options<'s, 'i, 'v: 'i + 's, 'r, const N: usize>(
pub fn run_async_with_options<'s, 'i, 'v: 'i + 's, 'r, O: SelectedOutputMarker, const N: usize>(
&'s self,
input_values: impl Into<SessionInputs<'i, 'v, N>> + 'static,
run_options: &'r RunOptions
) -> Result<InferenceFut<'s, 'r>> {
run_options: &'r RunOptions<O>
) -> Result<InferenceFut<'s, 'r, O>> {
match input_values.into() {
SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"),
SessionInputs::ValueArray(input_values) => {
Expand All @@ -353,17 +353,20 @@ impl Session {
}
}

fn run_inner_async<'s, 'v: 's, 'r>(
fn run_inner_async<'s, 'v: 's, 'r, O: SelectedOutputMarker>(
&'s self,
input_names: &[String],
input_values: impl Iterator<Item = SessionInputValue<'v>>,
run_options: Option<&'r RunOptions>
) -> Result<InferenceFut<'s, 'r>> {
run_options: Option<&'r RunOptions<O>>
) -> Result<InferenceFut<'s, 'r, O>> {
let run_options = match run_options {
Some(r) => RunOptionsRef::Ref(r),
// create a `RunOptions` to pass to the future so that when it drops, it terminates inference - crucial
// (performance-wise) for routines involving `tokio::select!` or timeouts
None => RunOptionsRef::Arc(Arc::new(RunOptions::new()?))
None => RunOptionsRef::Arc(Arc::new(unsafe {
// SAFETY: transmuting from `RunOptions<NoSelectedOutputs>` to `RunOptions<O>`; safe because its just a marker
std::mem::transmute(RunOptions::new()?)
}))
};

let input_name_ptrs: Vec<*const c_char> = input_names
Expand Down
Loading

0 comments on commit a127d0f

Please sign in to comment.