Skip to content

Commit

Permalink
Merge pull request #325 from PyO3/repr-transparent
Browse files Browse the repository at this point in the history
Use repr(transparent) to enforce layout compatibility with PyAny
  • Loading branch information
adamreichold authored Apr 21, 2022
2 parents 8b2679a + e1e5385 commit a8aac58
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 29 deletions.
11 changes: 4 additions & 7 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ use ndarray::{
};
use num_traits::AsPrimitive;
use pyo3::{
ffi, pyobject_native_type_named, type_object, types::PyModule, AsPyPointer, FromPyObject,
IntoPy, Py, PyAny, PyClassInitializer, PyDowncastError, PyErr, PyNativeType, PyObject,
PyResult, PyTypeInfo, Python, ToPyObject,
ffi, pyobject_native_type_named, types::PyModule, AsPyPointer, FromPyObject, IntoPy, Py, PyAny,
PyClassInitializer, PyDowncastError, PyErr, PyNativeType, PyObject, PyResult, PyTypeInfo,
Python, ToPyObject,
};

use crate::borrow::{PyReadonlyArray, PyReadwriteArray};
Expand Down Expand Up @@ -95,6 +95,7 @@ use crate::slice_container::PySliceContainer;
///
/// [ndarray]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html
/// [pyo3-memory]: https://pyo3.rs/main/memory.html
#[repr(transparent)]
pub struct PyArray<T, D>(PyAny, PhantomData<T>, PhantomData<D>);

/// Zero-dimensional array.
Expand All @@ -119,10 +120,6 @@ pub fn get_array_module(py: Python<'_>) -> PyResult<&PyModule> {
PyModule::import(py, npyffi::array::MOD_NAME)
}

unsafe impl<T, D> type_object::PyLayout<PyArray<T, D>> for npyffi::PyArrayObject {}

impl<T, D> type_object::PySizedLayout<PyArray<T, D>> for npyffi::PyArrayObject {}

unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
type AsRefTarget = Self;

Expand Down
8 changes: 2 additions & 6 deletions src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub use num_complex::{Complex32, Complex64};
/// ```
///
/// [dtype]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.html
#[repr(transparent)]
pub struct PyArrayDescr(PyAny);

pyobject_native_type_named!(PyArrayDescr);
Expand All @@ -61,12 +62,7 @@ unsafe impl PyTypeInfo for PyArrayDescr {
}

fn is_type_of(ob: &PyAny) -> bool {
unsafe {
ffi::PyObject_TypeCheck(
ob.as_ptr(),
PY_ARRAY_API.get_type_object(ob.py(), NpyTypes::PyArrayDescr_Type),
) > 0
}
unsafe { ffi::PyObject_TypeCheck(ob.as_ptr(), Self::type_object_raw(ob.py())) > 0 }
}
}

Expand Down
11 changes: 6 additions & 5 deletions src/npyffi/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,15 @@ impl PyArrayAPI {
impl_api![303; PyArray_SetWritebackIfCopyBase(arr: *mut PyArrayObject, base: *mut PyArrayObject) -> c_int];
}

// Define type objects that belongs to Numpy API
// Define type objects associated with the NumPy API
macro_rules! impl_array_type {
($(($offset:expr, $tname:ident)),*) => {
/// All type objects of numpy API.
/// All type objects exported by the NumPy API.
#[allow(non_camel_case_types)]
pub enum NpyTypes { $($tname),* }

impl PyArrayAPI {
/// Get the pointer of the type object that `self` refers.
/// Get a pointer of the type object assocaited with `ty`.
pub unsafe fn get_type_object(&self, py: Python, ty: NpyTypes) -> *mut PyTypeObject {
match ty {
$( NpyTypes::$tname => *(self.get(py, $offset)) as _ ),*
Expand Down Expand Up @@ -401,11 +402,11 @@ pub unsafe fn PyArray_CheckExact(py: Python, op: *mut PyObject) -> c_int {

#[cfg(test)]
mod tests {
use super::PY_ARRAY_API;
use super::*;

#[test]
fn call_api() {
pyo3::Python::with_gil(|py| unsafe {
Python::with_gil(|py| unsafe {
assert_eq!(
PY_ARRAY_API.PyArray_MultiplyIntList(py, [1, 2, 3].as_mut_ptr(), 3),
6
Expand Down
16 changes: 6 additions & 10 deletions src/npyffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,24 @@ fn get_numpy_api(_py: Python, module: &str, capsule: &str) -> *const *const c_vo
let module = CString::new(module).unwrap();
let capsule = CString::new(capsule).unwrap();
unsafe {
let numpy = ffi::PyImport_ImportModule(module.as_ptr());
assert!(!numpy.is_null(), "Failed to import numpy module");
let capsule = ffi::PyObject_GetAttrString(numpy as _, capsule.as_ptr());
assert!(!capsule.is_null(), "Failed to get numpy capsule API");
let module = ffi::PyImport_ImportModule(module.as_ptr());
assert!(!module.is_null(), "Failed to import NumPy module");
let capsule = ffi::PyObject_GetAttrString(module as _, capsule.as_ptr());
assert!(!capsule.is_null(), "Failed to get NumPy API capsule");
ffi::PyCapsule_GetPointer(capsule, null_mut()) as _
}
}

// Define Array&UFunc APIs
// Implements wrappers for NumPy's Array and UFunc API
macro_rules! impl_api {
[$offset: expr; $fname: ident ( $($arg: ident : $t: ty),* ) $( -> $ret: ty )* ] => {
[$offset: expr; $fname: ident ( $($arg: ident : $t: ty),* $(,)?) $( -> $ret: ty )* ] => {
#[allow(non_snake_case)]
pub unsafe fn $fname(&self, py: Python, $($arg : $t), *) $( -> $ret )* {
let fptr = self.get(py, $offset)
as *const extern fn ($($arg : $t), *) $( -> $ret )*;
(*fptr)($($arg), *)
}
};
// To allow fn a(b: type,) -> ret
[$offset: expr; $fname: ident ( $($arg: ident : $t:ty,)* ) $( -> $ret: ty )* ] => {
impl_api![$offset; $fname( $($arg: $t),*) $( -> $ret )*];
}
}

pub mod array;
Expand Down
2 changes: 1 addition & 1 deletion src/npyffi/ufunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ const MOD_NAME: &str = "numpy.core.umath";
const CAPSULE_NAME: &str = "_UFUNC_API";

/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
/// pointer to [Numpy UFunc API](https://numpy.org/doc/stable/reference/c-api/array.html).
/// pointer to [Numpy UFunc API](https://numpy.org/doc/stable/reference/c-api/ufunc.html).
pub static PY_UFUNC_API: PyUFuncAPI = PyUFuncAPI::new();

pub struct PyUFuncAPI {
Expand Down

0 comments on commit a8aac58

Please sign in to comment.