Skip to content

Commit

Permalink
Merge pull request #326 from PyO3/extraction-error
Browse files Browse the repository at this point in the history
Avoid the overhead of creating a PyErr for downcasting.
  • Loading branch information
adamreichold authored Apr 25, 2022
2 parents a8aac58 + 3795010 commit c1cc96f
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 29 deletions.
21 changes: 21 additions & 0 deletions benches/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,27 @@ fn extract_failure(bencher: &mut Bencher) {
});
}

#[bench]
fn downcast_success(bencher: &mut Bencher) {
Python::with_gil(|py| {
let any: &PyAny = PyArray2::<f64>::zeros(py, (10, 10), false);

bencher.iter(|| {
black_box(any).downcast::<PyArray2<f64>>().unwrap();
});
});
}

#[bench]
fn downcast_failure(bencher: &mut Bencher) {
Python::with_gil(|py| {
let any: &PyAny = PyArray2::<i32>::zeros(py, (10, 10), false);

bencher.iter(|| {
black_box(any).downcast::<PyArray2<f64>>().unwrap_err();
});
});
}
struct Iter(Range<usize>);

impl Iterator for Iter {
Expand Down
59 changes: 33 additions & 26 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::cold;
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
use crate::dtype::{Element, PyArrayDescr};
use crate::error::{
BorrowError, DimensionalityError, FromVecError, NotContiguousError, TypeError,
BorrowError, DimensionalityError, FromVecError, IgnoreError, NotContiguousError, TypeError,
DIMENSIONALITY_MISMATCH_ERR, MAX_DIMENSIONALITY_ERR,
};
use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
Expand Down Expand Up @@ -131,7 +131,7 @@ unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
}

fn is_type_of(ob: &PyAny) -> bool {
<&Self>::extract(ob).is_ok()
Self::extract::<IgnoreError>(ob).is_ok()
}
}

Expand All @@ -145,30 +145,7 @@ impl<T, D> IntoPy<PyObject> for PyArray<T, D> {

impl<'py, T: Element, D: Dimension> FromPyObject<'py> for &'py PyArray<T, D> {
fn extract(ob: &'py PyAny) -> PyResult<Self> {
// Check if the object is an array.
let array = unsafe {
if npyffi::PyArray_Check(ob.py(), ob.as_ptr()) == 0 {
return Err(PyDowncastError::new(ob, PyArray::<T, D>::NAME).into());
}
&*(ob as *const PyAny as *const PyArray<T, D>)
};

// Check if the dimensionality matches `D`.
let src_ndim = array.ndim();
if let Some(dst_ndim) = D::NDIM {
if src_ndim != dst_ndim {
return Err(DimensionalityError::new(src_ndim, dst_ndim).into());
}
}

// Check if the element type matches `T`.
let src_dtype = array.dtype();
let dst_dtype = T::get_dtype(ob.py());
if !src_dtype.is_equiv_to(dst_dtype) {
return Err(TypeError::new(src_dtype, dst_dtype).into());
}

Ok(array)
PyArray::extract(ob)
}
}

Expand Down Expand Up @@ -390,6 +367,36 @@ impl<T, D> PyArray<T, D> {
}

impl<T: Element, D: Dimension> PyArray<T, D> {
fn extract<'py, E>(ob: &'py PyAny) -> Result<&'py Self, E>
where
E: From<PyDowncastError<'py>> + From<DimensionalityError> + From<TypeError<'py>>,
{
// Check if the object is an array.
let array = unsafe {
if npyffi::PyArray_Check(ob.py(), ob.as_ptr()) == 0 {
return Err(PyDowncastError::new(ob, Self::NAME).into());
}
&*(ob as *const PyAny as *const Self)
};

// Check if the dimensionality matches `D`.
let src_ndim = array.ndim();
if let Some(dst_ndim) = D::NDIM {
if src_ndim != dst_ndim {
return Err(DimensionalityError::new(src_ndim, dst_ndim).into());
}
}

// Check if the element type matches `T`.
let src_dtype = array.dtype();
let dst_dtype = T::get_dtype(ob.py());
if !src_dtype.is_equiv_to(dst_dtype) {
return Err(TypeError::new(src_dtype, dst_dtype).into());
}

Ok(array)
}

/// Same as [`shape`][Self::shape], but returns `D` insead of `&[usize]`.
#[inline(always)]
pub fn dims(&self) -> D {
Expand Down
10 changes: 7 additions & 3 deletions src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,12 @@ impl PyArrayDescr {

/// Returns true if two type descriptors are equivalent.
pub fn is_equiv_to(&self, other: &Self) -> bool {
let self_ptr = self.as_dtype_ptr();
let other_ptr = other.as_dtype_ptr();

unsafe {
PY_ARRAY_API.PyArray_EquivTypes(self.py(), self.as_dtype_ptr(), other.as_dtype_ptr())
!= 0
self_ptr == other_ptr
|| PY_ARRAY_API.PyArray_EquivTypes(self.py(), self_ptr, other_ptr) != 0
}
}

Expand Down Expand Up @@ -413,7 +416,7 @@ fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {

fn npy_int_type<T: Bounded + Zero + Sized + PartialEq>() -> NPY_TYPES {
let is_unsigned = T::min_value() == T::zero();
let bit_width = size_of::<T>() << 3;
let bit_width = 8 * size_of::<T>();

match (is_unsigned, bit_width) {
(false, 8) => NPY_TYPES::NPY_BYTE,
Expand Down Expand Up @@ -449,6 +452,7 @@ macro_rules! impl_element_scalar {
$(#[$meta])*
unsafe impl Element for $ty {
const IS_COPY: bool = true;

fn get_dtype(py: Python) -> &PyArrayDescr {
PyArrayDescr::from_npy_type(py, $npy_type)
}
Expand Down
15 changes: 15 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,18 @@ impl fmt::Display for BorrowError {
}

impl_pyerr!(BorrowError);

/// An internal type used to ignore certain error conditions
///
/// This is beneficial when those errors will never reach a public API anyway
/// but dropping them will improve performance.
pub(crate) struct IgnoreError;

impl<E> From<E> for IgnoreError
where
PyErr: From<E>,
{
fn from(_err: E) -> Self {
Self
}
}

0 comments on commit c1cc96f

Please sign in to comment.