diff --git a/src/array.rs b/src/array.rs index 878c11ae3..f6907ea1e 100644 --- a/src/array.rs +++ b/src/array.rs @@ -16,9 +16,9 @@ use ndarray::{ }; use num_traits::AsPrimitive; use pyo3::{ - ffi, pyobject_native_type_named, type_object, types::PyModule, AsPyPointer, FromPyObject, - IntoPy, Py, PyAny, PyClassInitializer, PyDowncastError, PyErr, PyNativeType, PyObject, - PyResult, PyTypeInfo, Python, ToPyObject, + ffi, pyobject_native_type_named, types::PyModule, AsPyPointer, FromPyObject, IntoPy, Py, PyAny, + PyClassInitializer, PyDowncastError, PyErr, PyNativeType, PyObject, PyResult, PyTypeInfo, + Python, ToPyObject, }; use crate::borrow::{PyReadonlyArray, PyReadwriteArray}; @@ -95,6 +95,7 @@ use crate::slice_container::PySliceContainer; /// /// [ndarray]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html /// [pyo3-memory]: https://pyo3.rs/main/memory.html +#[repr(transparent)] pub struct PyArray(PyAny, PhantomData, PhantomData); /// Zero-dimensional array. @@ -119,10 +120,6 @@ pub fn get_array_module(py: Python<'_>) -> PyResult<&PyModule> { PyModule::import(py, npyffi::array::MOD_NAME) } -unsafe impl type_object::PyLayout> for npyffi::PyArrayObject {} - -impl type_object::PySizedLayout> for npyffi::PyArrayObject {} - unsafe impl PyTypeInfo for PyArray { type AsRefTarget = Self; diff --git a/src/dtype.rs b/src/dtype.rs index f7bc20bcd..9c008e213 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -45,6 +45,7 @@ pub use num_complex::{Complex32, Complex64}; /// ``` /// /// [dtype]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.html +#[repr(transparent)] pub struct PyArrayDescr(PyAny); pyobject_native_type_named!(PyArrayDescr); @@ -61,12 +62,7 @@ unsafe impl PyTypeInfo for PyArrayDescr { } fn is_type_of(ob: &PyAny) -> bool { - unsafe { - ffi::PyObject_TypeCheck( - ob.as_ptr(), - PY_ARRAY_API.get_type_object(ob.py(), NpyTypes::PyArrayDescr_Type), - ) > 0 - } + unsafe { ffi::PyObject_TypeCheck(ob.as_ptr(), Self::type_object_raw(ob.py())) > 0 } } } diff --git a/src/npyffi/array.rs b/src/npyffi/array.rs index 1470f8aac..5be9e0020 100644 --- a/src/npyffi/array.rs +++ b/src/npyffi/array.rs @@ -328,14 +328,15 @@ impl PyArrayAPI { impl_api![303; PyArray_SetWritebackIfCopyBase(arr: *mut PyArrayObject, base: *mut PyArrayObject) -> c_int]; } -// Define type objects that belongs to Numpy API +// Define type objects associated with the NumPy API macro_rules! impl_array_type { ($(($offset:expr, $tname:ident)),*) => { - /// All type objects of numpy API. + /// All type objects exported by the NumPy API. #[allow(non_camel_case_types)] pub enum NpyTypes { $($tname),* } + impl PyArrayAPI { - /// Get the pointer of the type object that `self` refers. + /// Get a pointer of the type object assocaited with `ty`. pub unsafe fn get_type_object(&self, py: Python, ty: NpyTypes) -> *mut PyTypeObject { match ty { $( NpyTypes::$tname => *(self.get(py, $offset)) as _ ),* @@ -401,11 +402,11 @@ pub unsafe fn PyArray_CheckExact(py: Python, op: *mut PyObject) -> c_int { #[cfg(test)] mod tests { - use super::PY_ARRAY_API; + use super::*; #[test] fn call_api() { - pyo3::Python::with_gil(|py| unsafe { + Python::with_gil(|py| unsafe { assert_eq!( PY_ARRAY_API.PyArray_MultiplyIntList(py, [1, 2, 3].as_mut_ptr(), 3), 6 diff --git a/src/npyffi/mod.rs b/src/npyffi/mod.rs index aff5d1aba..d9026b22f 100644 --- a/src/npyffi/mod.rs +++ b/src/npyffi/mod.rs @@ -18,17 +18,17 @@ fn get_numpy_api(_py: Python, module: &str, capsule: &str) -> *const *const c_vo let module = CString::new(module).unwrap(); let capsule = CString::new(capsule).unwrap(); unsafe { - let numpy = ffi::PyImport_ImportModule(module.as_ptr()); - assert!(!numpy.is_null(), "Failed to import numpy module"); - let capsule = ffi::PyObject_GetAttrString(numpy as _, capsule.as_ptr()); - assert!(!capsule.is_null(), "Failed to get numpy capsule API"); + let module = ffi::PyImport_ImportModule(module.as_ptr()); + assert!(!module.is_null(), "Failed to import NumPy module"); + let capsule = ffi::PyObject_GetAttrString(module as _, capsule.as_ptr()); + assert!(!capsule.is_null(), "Failed to get NumPy API capsule"); ffi::PyCapsule_GetPointer(capsule, null_mut()) as _ } } -// Define Array&UFunc APIs +// Implements wrappers for NumPy's Array and UFunc API macro_rules! impl_api { - [$offset: expr; $fname: ident ( $($arg: ident : $t: ty),* ) $( -> $ret: ty )* ] => { + [$offset: expr; $fname: ident ( $($arg: ident : $t: ty),* $(,)?) $( -> $ret: ty )* ] => { #[allow(non_snake_case)] pub unsafe fn $fname(&self, py: Python, $($arg : $t), *) $( -> $ret )* { let fptr = self.get(py, $offset) @@ -36,10 +36,6 @@ macro_rules! impl_api { (*fptr)($($arg), *) } }; - // To allow fn a(b: type,) -> ret - [$offset: expr; $fname: ident ( $($arg: ident : $t:ty,)* ) $( -> $ret: ty )* ] => { - impl_api![$offset; $fname( $($arg: $t),*) $( -> $ret )*]; - } } pub mod array; diff --git a/src/npyffi/ufunc.rs b/src/npyffi/ufunc.rs index 50eb1d7c5..414538068 100644 --- a/src/npyffi/ufunc.rs +++ b/src/npyffi/ufunc.rs @@ -12,7 +12,7 @@ const MOD_NAME: &str = "numpy.core.umath"; const CAPSULE_NAME: &str = "_UFUNC_API"; /// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html) -/// pointer to [Numpy UFunc API](https://numpy.org/doc/stable/reference/c-api/array.html). +/// pointer to [Numpy UFunc API](https://numpy.org/doc/stable/reference/c-api/ufunc.html). pub static PY_UFUNC_API: PyUFuncAPI = PyUFuncAPI::new(); pub struct PyUFuncAPI {