Skip to content

Commit

Permalink
Some minor copy-editing of the FFI module.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreichold committed Apr 21, 2022
1 parent 8599efc commit e1e5385
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 16 deletions.
11 changes: 6 additions & 5 deletions src/npyffi/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,15 @@ impl PyArrayAPI {
impl_api![303; PyArray_SetWritebackIfCopyBase(arr: *mut PyArrayObject, base: *mut PyArrayObject) -> c_int];
}

// Define type objects that belongs to Numpy API
// Define type objects associated with the NumPy API
macro_rules! impl_array_type {
($(($offset:expr, $tname:ident)),*) => {
/// All type objects of numpy API.
/// All type objects exported by the NumPy API.
#[allow(non_camel_case_types)]
pub enum NpyTypes { $($tname),* }

impl PyArrayAPI {
/// Get the pointer of the type object that `self` refers.
/// Get a pointer of the type object assocaited with `ty`.
pub unsafe fn get_type_object(&self, py: Python, ty: NpyTypes) -> *mut PyTypeObject {
match ty {
$( NpyTypes::$tname => *(self.get(py, $offset)) as _ ),*
Expand Down Expand Up @@ -401,11 +402,11 @@ pub unsafe fn PyArray_CheckExact(py: Python, op: *mut PyObject) -> c_int {

#[cfg(test)]
mod tests {
use super::PY_ARRAY_API;
use super::*;

#[test]
fn call_api() {
pyo3::Python::with_gil(|py| unsafe {
Python::with_gil(|py| unsafe {
assert_eq!(
PY_ARRAY_API.PyArray_MultiplyIntList(py, [1, 2, 3].as_mut_ptr(), 3),
6
Expand Down
16 changes: 6 additions & 10 deletions src/npyffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,24 @@ fn get_numpy_api(_py: Python, module: &str, capsule: &str) -> *const *const c_vo
let module = CString::new(module).unwrap();
let capsule = CString::new(capsule).unwrap();
unsafe {
let numpy = ffi::PyImport_ImportModule(module.as_ptr());
assert!(!numpy.is_null(), "Failed to import numpy module");
let capsule = ffi::PyObject_GetAttrString(numpy as _, capsule.as_ptr());
assert!(!capsule.is_null(), "Failed to get numpy capsule API");
let module = ffi::PyImport_ImportModule(module.as_ptr());
assert!(!module.is_null(), "Failed to import NumPy module");
let capsule = ffi::PyObject_GetAttrString(module as _, capsule.as_ptr());
assert!(!capsule.is_null(), "Failed to get NumPy API capsule");
ffi::PyCapsule_GetPointer(capsule, null_mut()) as _
}
}

// Define Array&UFunc APIs
// Implements wrappers for NumPy's Array and UFunc API
macro_rules! impl_api {
[$offset: expr; $fname: ident ( $($arg: ident : $t: ty),* ) $( -> $ret: ty )* ] => {
[$offset: expr; $fname: ident ( $($arg: ident : $t: ty),* $(,)?) $( -> $ret: ty )* ] => {
#[allow(non_snake_case)]
pub unsafe fn $fname(&self, py: Python, $($arg : $t), *) $( -> $ret )* {
let fptr = self.get(py, $offset)
as *const extern fn ($($arg : $t), *) $( -> $ret )*;
(*fptr)($($arg), *)
}
};
// To allow fn a(b: type,) -> ret
[$offset: expr; $fname: ident ( $($arg: ident : $t:ty,)* ) $( -> $ret: ty )* ] => {
impl_api![$offset; $fname( $($arg: $t),*) $( -> $ret )*];
}
}

pub mod array;
Expand Down
2 changes: 1 addition & 1 deletion src/npyffi/ufunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ const MOD_NAME: &str = "numpy.core.umath";
const CAPSULE_NAME: &str = "_UFUNC_API";

/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
/// pointer to [Numpy UFunc API](https://numpy.org/doc/stable/reference/c-api/array.html).
/// pointer to [Numpy UFunc API](https://numpy.org/doc/stable/reference/c-api/ufunc.html).
pub static PY_UFUNC_API: PyUFuncAPI = PyUFuncAPI::new();

pub struct PyUFuncAPI {
Expand Down

0 comments on commit e1e5385

Please sign in to comment.