Skip to content

Commit

Permalink
apply check for core vs _core at runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Oct 11, 2024
1 parent 7e702be commit d07dbf5
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ pub type PyArrayDyn<T> = PyArray<T, IxDyn>;

/// Returns a handle to NumPy's multiarray module.
pub fn get_array_module<'py>(py: Python<'py>) -> PyResult<Bound<'_, PyModule>> {
PyModule::import_bound(py, npyffi::array::MOD_NAME)
PyModule::import_bound(py, npyffi::array::mod_name(py)?)
}

impl<T, D> DerefToPyAny for PyArray<T, D> {}
Expand Down
40 changes: 38 additions & 2 deletions src/npyffi/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,43 @@ use pyo3::{

use crate::npyffi::*;

pub(crate) const MOD_NAME: &str = "numpy._core.multiarray";
pub(crate) fn numpy_core_name(py: Python<'_>) -> PyResult<&'static str> {
static MOD_NAME: GILOnceCell<&'static str> = GILOnceCell::new();

MOD_NAME
.get_or_try_init(py, || {
// numpy 2 renamed to numpy._core

// strategy mirrored from https://github.com/pybind/pybind11/blob/af67e87393b0f867ccffc2702885eea12de063fc/include/pybind11/numpy.h#L175-L195

let numpy = PyModule::import_bound(py, "numpy")?;
let version_string = numpy.getattr("__version__")?;

let numpy_lib = PyModule::import_bound(py, "numpy.lib")?;
let numpy_version = numpy_lib
.getattr("NumpyVersion")?
.call1((version_string,))?;
let major_version: u8 = numpy_version.getattr("major")?.extract()?;

Ok(if major_version >= 2 {
"numpy._core"
} else {
"numpy.core"
})
})
.copied()
}

pub(crate) fn mod_name(py: Python<'_>) -> PyResult<&'static str> {
static MOD_NAME: GILOnceCell<String> = GILOnceCell::new();
MOD_NAME
.get_or_try_init(py, || {
let numpy_core = numpy_core_name(py)?;
Ok(format!("{}.multiarray", numpy_core))
})
.map(String::as_str)
}

const CAPSULE_NAME: &str = "_ARRAY_API";

/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
Expand Down Expand Up @@ -49,7 +85,7 @@ impl PyArrayAPI {
unsafe fn get<'py>(&self, py: Python<'py>, offset: isize) -> *const *const c_void {
let api = self
.0
.get_or_try_init(py, || get_numpy_api(py, MOD_NAME, CAPSULE_NAME))
.get_or_try_init(py, || get_numpy_api(py, mod_name(py)?, CAPSULE_NAME))
.expect("Failed to access NumPy array API capsule");

api.offset(offset)
Expand Down
13 changes: 11 additions & 2 deletions src/npyffi/ufunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@ use pyo3::{ffi::PyObject, sync::GILOnceCell};

use crate::npyffi::*;

const MOD_NAME: &str = "numpy.core.umath";
fn mod_name(py: Python<'_>) -> PyResult<&'static str> {
static MOD_NAME: GILOnceCell<String> = GILOnceCell::new();
MOD_NAME
.get_or_try_init(py, || {
let numpy_core = super::array::numpy_core_name(py)?;
Ok(format!("{}.umath", numpy_core))
})
.map(String::as_str)
}

const CAPSULE_NAME: &str = "_UFUNC_API";

/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
Expand All @@ -23,7 +32,7 @@ impl PyUFuncAPI {
unsafe fn get<'py>(&self, py: Python<'py>, offset: isize) -> *const *const c_void {
let api = self
.0
.get_or_try_init(py, || get_numpy_api(py, MOD_NAME, CAPSULE_NAME))
.get_or_try_init(py, || get_numpy_api(py, mod_name(py)?, CAPSULE_NAME))
.expect("Failed to access NumPy ufunc API capsule");

api.offset(offset)
Expand Down

0 comments on commit d07dbf5

Please sign in to comment.