Skip to content

Commit

Permalink
refactor!: remove redundant error checks; more useful downcast messages
Browse files Browse the repository at this point in the history
Many of the ORT APIs are hardcoded to *never* return a failure status, despite returning `*mut OrtStatus`, probably for consistency (which is appreciated!).
For `ort`, this means almost every API requires a `?` or `unwrap()`, even if it will never fail. Notably, this includes all getter functions on `MemoryInfo`.
This commit makes many of these functions return `T` instead of `Result<T>`.

Additionally, I found that all of the `downcast` methods on `Value` just panic if the type is not downcastable - that's bad! Now a nice `Display` is implemented for `ValueType`, and the downcast functions actually return errors. Extraction methods also have better error messages.

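Also in this commit: checking if memory is CPU accessible the *correct* way - by inspecting the `MemoryInfo`'s allocation device and memory type - instead of matching against a hardcoded list of `AllocationDevice`s.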
  • Loading branch information
decahedron1 committed Sep 14, 2024
1 parent b9ef6aa commit 359a051
Show file tree
Hide file tree
Showing 11 changed files with 275 additions and 180 deletions.
65 changes: 32 additions & 33 deletions src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,19 +247,6 @@ impl AllocationDevice {
pub fn as_str(&self) -> &'static str {
self.0
}

/// Returns `true` if this memory is accessible by the CPU; meaning that, if a value were allocated on this device,
/// it could be extracted to an `ndarray` or slice.
pub fn is_cpu_accessible(&self) -> bool {
self == &Self::CPU
|| self == &Self::CUDA_PINNED
|| self == &Self::CANN_PINNED
|| self == &Self::HIP_PINNED
|| self == &Self::OPENVINO_CPU
|| self == &Self::DIRECTML_CPU
|| self == &Self::XNNPACK
|| self == &Self::TVM
}
}

impl PartialEq<str> for AllocationDevice {
Expand Down Expand Up @@ -293,7 +280,7 @@ pub enum MemoryType {
CPUInput,
/// CPU-accessible memory output by a non-CPU execution provider, i.e. [`AllocationDevice::CUDA_PINNED`].
CPUOutput,
/// The default allocator for an execution provider.
/// The default (typically device memory) allocator for an execution provider.
#[default]
Default
}
Expand Down Expand Up @@ -323,8 +310,7 @@ impl From<ort_sys::OrtMemType> for MemoryType {
}
}

/// Structure describing a memory location - the device on which the memory resides, the type of allocator (device
/// default, or arena) used, and the type of memory allocated (device-only, or CPU accessible).
/// Describes allocation properties for value memory.
///
/// `MemoryInfo` is used in the creation of [`Session`]s, [`Allocator`]s, and [`crate::Value`]s to describe on which
/// device value data should reside, and how that data should be accessible with regard to the CPU (if a non-CPU device
Expand Down Expand Up @@ -371,52 +357,59 @@ impl MemoryInfo {
MemoryInfo { ptr, should_release }
}

// All getter functions are (at least currently) infallible - they simply just dereference the corresponding fields,
// and always return `nullptr` for the status; so none of these have to return `Result`s.
// https://github.com/microsoft/onnxruntime/blob/v1.19.2/onnxruntime/core/framework/allocator.cc#L166

/// Returns the [`MemoryType`] described by this struct.
/// ```
/// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
/// # fn main() -> ort::Result<()> {
/// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
/// assert_eq!(mem.memory_type()?, MemoryType::Default);
/// assert_eq!(mem.memory_type(), MemoryType::Default);
/// # Ok(())
/// # }
/// ```
pub fn memory_type(&self) -> Result<MemoryType> {
pub fn memory_type(&self) -> MemoryType {
let mut raw_type: ort_sys::OrtMemType = ort_sys::OrtMemType::OrtMemTypeDefault;
ortsys![unsafe MemoryInfoGetMemType(self.ptr.as_ptr(), &mut raw_type)?];
Ok(MemoryType::from(raw_type))
ortsys![unsafe MemoryInfoGetMemType(self.ptr.as_ptr(), &mut raw_type)];
MemoryType::from(raw_type)
}

/// Returns the [`AllocatorType`] described by this struct.
/// ```
/// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
/// # fn main() -> ort::Result<()> {
/// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
/// assert_eq!(mem.allocator_type()?, AllocatorType::Device);
/// assert_eq!(mem.allocator_type(), AllocatorType::Device);
/// # Ok(())
/// # }
/// ```
pub fn allocator_type(&self) -> Result<AllocatorType> {
pub fn allocator_type(&self) -> AllocatorType {
let mut raw_type: ort_sys::OrtAllocatorType = ort_sys::OrtAllocatorType::OrtInvalidAllocator;
ortsys![unsafe MemoryInfoGetType(self.ptr.as_ptr(), &mut raw_type)?];
Ok(match raw_type {
ortsys![unsafe MemoryInfoGetType(self.ptr.as_ptr(), &mut raw_type)];
match raw_type {
ort_sys::OrtAllocatorType::OrtArenaAllocator => AllocatorType::Arena,
ort_sys::OrtAllocatorType::OrtDeviceAllocator => AllocatorType::Device,
_ => unreachable!()
})
}
}

/// Returns the [`AllocationDevice`] this struct was created with.
/// ```
/// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
/// # fn main() -> ort::Result<()> {
/// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
/// assert_eq!(mem.allocation_device()?, AllocationDevice::CPU);
/// assert_eq!(mem.allocation_device(), AllocationDevice::CPU);
/// # Ok(())
/// # }
/// ```
pub fn allocation_device(&self) -> Result<AllocationDevice> {
pub fn allocation_device(&self) -> AllocationDevice {
let mut name_ptr: *const c_char = std::ptr::null_mut();
ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr)?; nonNull(name_ptr)];
ortsys![unsafe MemoryInfoGetName(self.ptr.as_ptr(), &mut name_ptr)];

// SAFETY: `name_ptr` can never be null - `CreateMemoryInfo` internally checks against builtin device names, erroring
// if a non-builtin device is passed, and ONNX Runtime will never supply a null pointer to the C++ constructor

let mut len = 0;
while unsafe { *name_ptr.add(len) } != 0x00 {
Expand All @@ -426,22 +419,28 @@ impl MemoryInfo {
// SAFETY: ONNX Runtime internally only ever defines allocation device names as ASCII. can't wait for this to blow up
// one day regardless
let name = unsafe { std::str::from_utf8_unchecked(std::slice::from_raw_parts(name_ptr.cast::<u8>(), len)) };
Ok(AllocationDevice(name))
AllocationDevice(name)
}

/// Returns the ID of the [`AllocationDevice`] described by this struct.
/// ```
/// # use ort::{MemoryInfo, MemoryType, AllocationDevice, AllocatorType};
/// # fn main() -> ort::Result<()> {
/// let mem = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::Default)?;
/// assert_eq!(mem.device_id()?, 0);
/// assert_eq!(mem.device_id(), 0);
/// # Ok(())
/// # }
/// ```
pub fn device_id(&self) -> Result<i32> {
pub fn device_id(&self) -> i32 {
let mut raw: ort_sys::c_int = 0;
ortsys![unsafe MemoryInfoGetId(self.ptr.as_ptr(), &mut raw)?];
Ok(raw as _)
ortsys![unsafe MemoryInfoGetId(self.ptr.as_ptr(), &mut raw)];
raw as _
}

/// Reports whether memory described by this struct can be read directly by the CPU - that is, whether a value
/// allocated with it could be extracted to an `ndarray` or slice.
///
/// This is the case when the allocation device is [`AllocationDevice::CPU`], or when the memory type is
/// [`MemoryType::CPUInput`] or [`MemoryType::CPUOutput`].
pub fn is_cpu_accessible(&self) -> bool {
    // Plain CPU memory is always accessible; the memory type only needs to be queried for other devices.
    if self.allocation_device() == AllocationDevice::CPU {
        return true;
    }

    let memory_type = self.memory_type();
    matches!(memory_type, MemoryType::CPUInput | MemoryType::CPUOutput)
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/operator/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl KernelAttributes {
.map_err(Error::wrap)?;
let mut type_info = ptr::null_mut();
ortsys![unsafe KernelInfo_GetInputTypeInfo(self.0.as_ptr(), idx as _, &mut type_info)?; nonNull(type_info)];
let input_type = ValueType::from_type_info(type_info)?;
let input_type = ValueType::from_type_info(type_info);
inputs.push(Input { name, input_type })
}
Ok(inputs)
Expand All @@ -75,7 +75,7 @@ impl KernelAttributes {
.map_err(Error::wrap)?;
let mut type_info = ptr::null_mut();
ortsys![unsafe KernelInfo_GetOutputTypeInfo(self.0.as_ptr(), idx as _, &mut type_info)?; nonNull(type_info)];
let output_type = ValueType::from_type_info(type_info)?;
let output_type = ValueType::from_type_info(type_info);
outputs.push(Output { name, output_type })
}
Ok(outputs)
Expand Down
6 changes: 3 additions & 3 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ impl Session {
/// # Ok(())
/// # }
/// ```
pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s self, input_values: impl Into<SessionInputs<'i, 'v, N>>) -> Result<SessionOutputs<'_, 's>> {
pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s self, input_values: impl Into<SessionInputs<'i, 'v, N>>) -> Result<SessionOutputs<'s, 's>> {
match input_values.into() {
SessionInputs::ValueSlice(input_values) => {
self.run_inner::<NoSelectedOutputs>(&self.inputs.iter().map(|input| input.name.as_str()).collect::<Vec<_>>(), input_values.iter(), None)
Expand Down Expand Up @@ -330,7 +330,7 @@ impl Session {
pub fn run_async<'s, 'i, 'v: 'i + 's, const N: usize>(
&'s self,
input_values: impl Into<SessionInputs<'i, 'v, N>> + 'static
) -> Result<InferenceFut<'s, '_, NoSelectedOutputs>> {
) -> Result<InferenceFut<'s, 's, NoSelectedOutputs>> {
match input_values.into() {
SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"),
SessionInputs::ValueArray(input_values) => {
Expand Down Expand Up @@ -549,6 +549,6 @@ mod dangerous {
status_to_result(status)?;
assert_non_null_pointer(typeinfo_ptr, "TypeInfo")?;

ValueType::from_type_info(typeinfo_ptr)
Ok(ValueType::from_type_info(typeinfo_ptr))
}
}
22 changes: 22 additions & 0 deletions src/tensor/types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::fmt;
#[cfg(feature = "ndarray")]
use std::ptr;

Expand Down Expand Up @@ -41,6 +42,27 @@ pub enum TensorElementType {
Bfloat16
}

impl fmt::Display for TensorElementType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
TensorElementType::Bfloat16 => "bf16",
TensorElementType::Bool => "bool",
TensorElementType::Float16 => "f16",
TensorElementType::Float32 => "f32",
TensorElementType::Float64 => "f64",
TensorElementType::Int16 => "i16",
TensorElementType::Int32 => "i32",
TensorElementType::Int64 => "i64",
TensorElementType::Int8 => "i8",
TensorElementType::String => "String",
TensorElementType::Uint16 => "u16",
TensorElementType::Uint32 => "u32",
TensorElementType::Uint64 => "u64",
TensorElementType::Uint8 => "u8"
})
}
}

impl From<TensorElementType> for ort_sys::ONNXTensorElementDataType {
fn from(val: TensorElementType) -> Self {
match val {
Expand Down
4 changes: 2 additions & 2 deletions src/training/trainer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl Trainer {
&'s self,
inputs: impl Into<SessionInputs<'i1, 'v1, N1>>,
labels: impl Into<SessionInputs<'i2, 'v2, N2>>
) -> Result<SessionOutputs<'_, 's>> {
) -> Result<SessionOutputs<'s, 's>> {
match inputs.into() {
SessionInputs::ValueSlice(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.step_inner(input_values.iter().chain(labels), None),
Expand Down Expand Up @@ -144,7 +144,7 @@ impl Trainer {
&'s self,
inputs: impl Into<SessionInputs<'i1, 'v1, N1>>,
labels: impl Into<SessionInputs<'i2, 'v2, N2>>
) -> Result<SessionOutputs<'_, 's>> {
) -> Result<SessionOutputs<'s, 's>> {
match inputs.into() {
SessionInputs::ValueSlice(input_values) => match labels.into() {
SessionInputs::ValueSlice(labels) => self.eval_step_inner(input_values.iter().chain(labels), None),
Expand Down
36 changes: 28 additions & 8 deletions src/value/impl_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ pub trait MapValueTypeMarker: ValueTypeMarker {
#[derive(Debug)]
pub struct DynMapValueType;
impl ValueTypeMarker for DynMapValueType {
/// Returns the display name of this value type marker, `DynMap`.
fn format() -> String {
    String::from("DynMap")
}

crate::private_impl!();
}
impl MapValueTypeMarker for DynMapValueType {
Expand All @@ -43,6 +47,10 @@ impl DowncastableTarget for DynMapValueType {
#[derive(Debug)]
pub struct MapValueType<K: IntoTensorElementType + Clone + Hash + Eq, V: IntoTensorElementType + Debug>(PhantomData<(K, V)>);
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> ValueTypeMarker for MapValueType<K, V> {
/// Returns the display name of this value type marker as `Map<K, V>`, where `K` and `V` are rendered as their
/// tensor element types.
fn format() -> String {
    let key_type = K::into_tensor_element_type();
    let value_type = V::into_tensor_element_type();
    format!("Map<{key_type}, {value_type}>")
}

crate::private_impl!();
}
impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementType + Debug> MapValueTypeMarker for MapValueType<K, V> {
Expand Down Expand Up @@ -70,7 +78,7 @@ pub type MapRefMut<'v, K, V> = ValueRefMut<'v, MapValueType<K, V>>;

impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
pub fn try_extract_map<K: IntoTensorElementType + Clone + Hash + Eq, V: PrimitiveTensorElementType + Clone>(&self) -> Result<HashMap<K, V>> {
match self.dtype()? {
match self.dtype() {
ValueType::Map { key, value } => {
let k_type = K::into_tensor_element_type();
if k_type != key {
Expand All @@ -80,7 +88,7 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
if v_type != value {
return Err(Error::new_with_code(
ErrorCode::InvalidArgument,
format!("Cannot extract Map<_, {:?}> (value has V type {:?})", v_type, value)
// Requested types first, then the key/value types the map value actually holds.
format!("Cannot extract Map<{}, {}> from Map<{}, {}>", k_type, v_type, key, value)
));
}

Expand All @@ -90,12 +98,15 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
ortsys![unsafe GetValue(self.ptr(), 0, allocator.ptr.as_ptr(), &mut key_tensor_ptr)?; nonNull(key_tensor_ptr)];
let key_value: DynTensor = unsafe { Value::from_ptr(NonNull::new_unchecked(key_tensor_ptr), None) };
if K::into_tensor_element_type() != TensorElementType::String {
let dtype = key_value.dtype()?;
let dtype = key_value.dtype();
let (key_tensor_shape, key_tensor) = match dtype {
ValueType::Tensor { ty, dimensions } => {
let device = key_value.memory_info()?.allocation_device()?;
if !device.is_cpu_accessible() {
return Err(Error::new(format!("Cannot extract from value on device `{}`, which is not CPU accessible", device.as_str())));
let mem = key_value.memory_info();
if !mem.is_cpu_accessible() {
return Err(Error::new(format!(
"Cannot extract from value on device `{}`, which is not CPU accessible",
mem.allocation_device().as_str()
)));
}

if ty == K::into_tensor_element_type() {
Expand All @@ -109,7 +120,13 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
} else {
return Err(Error::new_with_code(
ErrorCode::InvalidArgument,
format!("Cannot extract Map<{:?}, _> (value has K type {:?})", K::into_tensor_element_type(), ty)
// Requested types first, then the actual ones: `ty` is the element type of the key tensor that failed
// to match `K`; the value type was already checked against `V` earlier, so `v_type` is also the
// value's actual type.
format!(
    "Cannot extract Map<{}, {}> from Map<{}, {}>",
    K::into_tensor_element_type(),
    V::into_tensor_element_type(),
    ty,
    v_type
)
));
}
}
Expand Down Expand Up @@ -153,7 +170,10 @@ impl<Type: MapValueTypeMarker + ?Sized> Value<Type> {
Ok(vec.into_iter().collect())
}
}
t => Err(Error::new(format!("Cannot extract a Map from a value which is actually a {t:?}")))
t => Err(Error::new_with_code(
ErrorCode::InvalidArgument,
format!("Cannot extract Map<{}, {}> from {t}", K::into_tensor_element_type(), V::into_tensor_element_type())
))
}
}
}
Expand Down
18 changes: 13 additions & 5 deletions src/value/impl_sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ pub trait SequenceValueTypeMarker: ValueTypeMarker {
#[derive(Debug)]
pub struct DynSequenceValueType;
impl ValueTypeMarker for DynSequenceValueType {
/// Returns the display name of this value type marker, `DynSequence`.
fn format() -> String {
    String::from("DynSequence")
}

crate::private_impl!();
}
impl SequenceValueTypeMarker for DynSequenceValueType {
Expand All @@ -28,6 +32,10 @@ impl SequenceValueTypeMarker for DynSequenceValueType {
#[derive(Debug)]
pub struct SequenceValueType<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized>(PhantomData<T>);
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> ValueTypeMarker for SequenceValueType<T> {
/// Returns the display name of this value type marker as `Sequence<T>`, where `T` is the display name of the
/// element marker type.
fn format() -> String {
    let element = T::format();
    format!("Sequence<{element}>")
}

crate::private_impl!();
}
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> SequenceValueTypeMarker for SequenceValueType<T> {
Expand All @@ -47,7 +55,7 @@ impl<Type: SequenceValueTypeMarker + Sized> Value<Type> {
&'s self,
allocator: &Allocator
) -> Result<Vec<ValueRef<'s, OtherType>>> {
match self.dtype()? {
match self.dtype() {
ValueType::Sequence(_) => {
let mut len: ort_sys::size_t = 0;
ortsys![unsafe GetValueCount(self.ptr(), &mut len)?];
Expand All @@ -61,19 +69,19 @@ impl<Type: SequenceValueTypeMarker + Sized> Value<Type> {
inner: unsafe { Value::from_ptr(NonNull::new_unchecked(value_ptr), None) },
lifetime: PhantomData
};
let value_type = value.dtype()?;
if !OtherType::can_downcast(&value.dtype()?) {
let value_type = value.dtype();
// Reuse the type fetched just above instead of querying the value a second time.
if !OtherType::can_downcast(&value_type) {
return Err(Error::new_with_code(
ErrorCode::InvalidArgument,
format!("Cannot extract Sequence<T> (downcast of T from {value_type:?} failed)")
format!("Cannot extract Sequence<{}> from {value_type:?}", OtherType::format())
));
}

vec.push(value);
}
Ok(vec)
}
t => Err(Error::new(format!("Cannot extract a Sequence from a value which is actually a {t:?}")))
t => Err(Error::new(format!("Cannot extract Sequence<{}> from {t}", OtherType::format())))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/value/impl_tensor/create.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl<T: PrimitiveTensorElementType + Debug> Tensor<T> {
/// Raw data provided as an `Arc<Box<[T]>>`, `Box<[T]>`, or `Vec<T>` will never be copied. Raw data is expected to be
/// in standard, contiguous layout.
pub fn from_array(input: impl IntoValueTensor<Item = T>) -> Result<Tensor<T>> {
let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::Default)?;
let memory_info = MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Arena, MemoryType::CPUInput)?;

let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();

Expand Down
Loading

0 comments on commit 359a051

Please sign in to comment.