Skip to content

Commit

Permalink
Do not elide lifetimes and try to be disciplined about naming the GIL…
Browse files Browse the repository at this point in the history
… lifetime 'py.
  • Loading branch information
adamreichold committed Sep 8, 2023
1 parent 652619d commit ac8190b
Show file tree
Hide file tree
Showing 19 changed files with 120 additions and 101 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn};
use pyo3::{pymodule, types::PyModule, PyResult, Python};

#[pymodule]
fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
fn rust_ext<'py>(_py: Python<'py>, m: &'py PyModule) -> PyResult<()> {
// example using immutable borrows producing a new array
fn axpy(a: f64, x: ArrayViewD<'_, f64>, y: ArrayViewD<'_, f64>) -> ArrayD<f64> {
a * &x + &y
Expand All @@ -65,8 +65,8 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
fn axpy_py<'py>(
py: Python<'py>,
a: f64,
x: PyReadonlyArrayDyn<f64>,
y: PyReadonlyArrayDyn<f64>,
x: PyReadonlyArrayDyn<'py, f64>,
y: PyReadonlyArrayDyn<'py, f64>,
) -> &'py PyArrayDyn<f64> {
let x = x.as_array();
let y = y.as_array();
Expand All @@ -77,7 +77,7 @@ fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
// wrapper of `mult`
#[pyfn(m)]
#[pyo3(name = "mult")]
fn mult_py(_py: Python<'_>, a: f64, x: &PyArrayDyn<f64>) {
fn mult_py<'py>(a: f64, x: &'py PyArrayDyn<f64>) {
let x = unsafe { x.as_array_mut() };
mult(a, x);
}
Expand Down
4 changes: 2 additions & 2 deletions examples/linalg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
use pyo3::{exceptions::PyRuntimeError, pymodule, types::PyModule, PyResult, Python};

#[pymodule]
fn rust_linalg(_py: Python, m: &PyModule) -> PyResult<()> {
fn rust_linalg<'py>(_py: Python<'py>, m: &'py PyModule) -> PyResult<()> {
#[pyfn(m)]
fn inv<'py>(py: Python<'py>, x: PyReadonlyArray2<f64>) -> PyResult<&'py PyArray2<f64>> {
fn inv<'py>(py: Python<'py>, x: PyReadonlyArray2<'py, f64>) -> PyResult<&'py PyArray2<f64>> {
let x = x.as_array();
let y = x
.inv()
Expand Down
6 changes: 3 additions & 3 deletions examples/parallel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, PyReadonlyArray2};
use pyo3::{pymodule, types::PyModule, PyResult, Python};

#[pymodule]
fn rust_parallel(_py: Python, m: &PyModule) -> PyResult<()> {
fn rust_parallel<'py>(_py: Python<'py>, m: &'py PyModule) -> PyResult<()> {
#[pyfn(m)]
fn rows_dot<'py>(
py: Python<'py>,
x: PyReadonlyArray2<f64>,
y: PyReadonlyArray1<f64>,
x: PyReadonlyArray2<'py, f64>,
y: PyReadonlyArray1<'py, f64>,
) -> &'py PyArray1<f64> {
let x = x.as_array();
let y = y.as_array();
Expand Down
31 changes: 17 additions & 14 deletions examples/simple/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,41 @@ use pyo3::{
};

#[pymodule]
fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
fn rust_ext<'py>(_py: Python<'py>, m: &'py PyModule) -> PyResult<()> {
// example using generic PyObject
fn head(x: ArrayViewD<PyObject>) -> PyResult<PyObject> {
fn head(x: ArrayViewD<'_, PyObject>) -> PyResult<PyObject> {
x.get(0)
.cloned()
.ok_or_else(|| PyIndexError::new_err("array index out of range"))
}

// example using immutable borrows producing a new array
fn axpy(a: f64, x: ArrayViewD<f64>, y: ArrayViewD<f64>) -> ArrayD<f64> {
fn axpy(a: f64, x: ArrayViewD<'_, f64>, y: ArrayViewD<'_, f64>) -> ArrayD<f64> {
a * &x + &y
}

// example using a mutable borrow to modify an array in-place
fn mult(a: f64, mut x: ArrayViewMutD<f64>) {
fn mult(a: f64, mut x: ArrayViewMutD<'_, f64>) {
x *= a;
}

// example using complex numbers
fn conj(x: ArrayViewD<Complex64>) -> ArrayD<Complex64> {
fn conj(x: ArrayViewD<'_, Complex64>) -> ArrayD<Complex64> {
x.map(|c| c.conj())
}

// example using generics
fn generic_add<T: Copy + Add<Output = T>>(x: ArrayView1<T>, y: ArrayView1<T>) -> Array1<T> {
fn generic_add<T: Copy + Add<Output = T>>(
x: ArrayView1<'_, T>,
y: ArrayView1<'_, T>,
) -> Array1<T> {
&x + &y
}

// wrapper of `head`
#[pyfn(m)]
#[pyo3(name = "head")]
fn head_py(_py: Python, x: PyReadonlyArrayDyn<PyObject>) -> PyResult<PyObject> {
fn head_py<'py>(x: PyReadonlyArrayDyn<'py, PyObject>) -> PyResult<PyObject> {
head(x.as_array())
}

Expand All @@ -55,8 +58,8 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
fn axpy_py<'py>(
py: Python<'py>,
a: f64,
x: PyReadonlyArrayDyn<f64>,
y: PyReadonlyArrayDyn<f64>,
x: PyReadonlyArrayDyn<'py, f64>,
y: PyReadonlyArrayDyn<'py, f64>,
) -> &'py PyArrayDyn<f64> {
let x = x.as_array();
let y = y.as_array();
Expand All @@ -67,7 +70,7 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
// wrapper of `mult`
#[pyfn(m)]
#[pyo3(name = "mult")]
fn mult_py(a: f64, mut x: PyReadwriteArrayDyn<f64>) {
fn mult_py<'py>(a: f64, mut x: PyReadwriteArrayDyn<'py, f64>) {
let x = x.as_array_mut();
mult(a, x);
}
Expand All @@ -77,7 +80,7 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {
#[pyo3(name = "conj")]
fn conj_py<'py>(
py: Python<'py>,
x: PyReadonlyArrayDyn<Complex64>,
x: PyReadonlyArrayDyn<'py, Complex64>,
) -> &'py PyArrayDyn<Complex64> {
conj(x.as_array()).into_pyarray(py)
}
Expand All @@ -96,9 +99,9 @@ fn rust_ext(_py: Python, m: &PyModule) -> PyResult<()> {

// example using timedelta64 array
#[pyfn(m)]
fn add_minutes_to_seconds(
mut x: PyReadwriteArray1<Timedelta<units::Seconds>>,
y: PyReadonlyArray1<Timedelta<units::Minutes>>,
fn add_minutes_to_seconds<'py>(
mut x: PyReadwriteArray1<'py, Timedelta<units::Seconds>>,
y: PyReadonlyArray1<'py, Timedelta<units::Minutes>>,
) {
#[allow(deprecated)]
Zip::from(x.as_array_mut())
Expand Down
26 changes: 13 additions & 13 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ pub type PyArray6<T> = PyArray<T, Ix6>;
pub type PyArrayDyn<T> = PyArray<T, IxDyn>;

/// Returns a handle to NumPy's multiarray module.
pub fn get_array_module(py: Python<'_>) -> PyResult<&PyModule> {
pub fn get_array_module<'py>(py: Python<'py>) -> PyResult<&PyModule> {
PyModule::import(py, npyffi::array::MOD_NAME)
}

Expand All @@ -128,7 +128,7 @@ unsafe impl<T: Element, D: Dimension> PyTypeInfo for PyArray<T, D> {
const NAME: &'static str = "PyArray<T, D>";
const MODULE: Option<&'static str> = Some("numpy");

fn type_object_raw(py: Python) -> *mut ffi::PyTypeObject {
fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
}

Expand Down Expand Up @@ -164,7 +164,7 @@ impl<T, D> AsPyPointer for PyArray<T, D> {

impl<T, D> IntoPy<Py<PyArray<T, D>>> for &'_ PyArray<T, D> {
#[inline]
fn into_py(self, py: Python<'_>) -> Py<PyArray<T, D>> {
fn into_py<'py>(self, py: Python<'py>) -> Py<PyArray<T, D>> {
unsafe { Py::from_borrowed_ptr(py, self.as_ptr()) }
}
}
Expand All @@ -183,7 +183,7 @@ impl<'a, T, D> From<&'a PyArray<T, D>> for &'a PyAny {
}

impl<T, D> IntoPy<PyObject> for PyArray<T, D> {
fn into_py(self, py: Python<'_>) -> PyObject {
fn into_py<'py>(self, py: Python<'py>) -> PyObject {
unsafe { PyObject::from_borrowed_ptr(py, self.as_ptr()) }
}
}
Expand Down Expand Up @@ -327,16 +327,16 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
/// assert_eq!(arr.shape(), &[4, 5, 6]);
/// });
/// ```
pub unsafe fn new<ID>(py: Python, dims: ID, is_fortran: bool) -> &Self
pub unsafe fn new<'py, ID>(py: Python<'py>, dims: ID, is_fortran: bool) -> &Self
where
ID: IntoDimension<Dim = D>,
{
let flags = c_int::from(is_fortran);
Self::new_uninit(py, dims, ptr::null_mut(), flags)
}

pub(crate) unsafe fn new_uninit<ID>(
py: Python,
pub(crate) unsafe fn new_uninit<'py, ID>(
py: Python<'py>,
dims: ID,
strides: *const npy_intp,
flag: c_int,
Expand Down Expand Up @@ -484,7 +484,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
///
/// [numpy-zeros]: https://numpy.org/doc/stable/reference/generated/numpy.zeros.html
/// [PyArray_Zeros]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Zeros
pub fn zeros<ID>(py: Python, dims: ID, is_fortran: bool) -> &Self
pub fn zeros<'py, ID>(py: Python<'py>, dims: ID, is_fortran: bool) -> &Self
where
ID: IntoDimension<Dim = D>,
{
Expand Down Expand Up @@ -989,7 +989,7 @@ where
#[doc(alias = "nalgebra")]
pub unsafe fn try_as_matrix<R, C, RStride, CStride>(
&self,
) -> Option<nalgebra::MatrixView<N, R, C, RStride, CStride>>
) -> Option<nalgebra::MatrixView<'_, N, R, C, RStride, CStride>>
where
R: nalgebra::Dim,
C: nalgebra::Dim,
Expand All @@ -1011,7 +1011,7 @@ where
#[doc(alias = "nalgebra")]
pub unsafe fn try_as_matrix_mut<R, C, RStride, CStride>(
&self,
) -> Option<nalgebra::MatrixViewMut<N, R, C, RStride, CStride>>
) -> Option<nalgebra::MatrixViewMut<'_, N, R, C, RStride, CStride>>
where
R: nalgebra::Dim,
C: nalgebra::Dim,
Expand Down Expand Up @@ -1086,7 +1086,7 @@ impl<T: Copy + Element> PyArray<T, Ix0> {
}

impl<T: Element> PyArray<T, Ix1> {
/// Construct a one-dimensional array from a [slice][std::slice].
/// Construct a one-dimensional array from a [mod@slice].
///
/// # Example
///
Expand Down Expand Up @@ -1144,7 +1144,7 @@ impl<T: Element> PyArray<T, Ix1> {
/// assert_eq!(pyarray.readonly().as_slice().unwrap(), &[97, 98, 99, 100, 101]);
/// });
/// ```
pub fn from_iter<I>(py: Python<'_>, iter: I) -> &Self
pub fn from_iter<'py, I>(py: Python<'py>, iter: I) -> &'py Self
where
I: IntoIterator<Item = T>,
{
Expand Down Expand Up @@ -1448,7 +1448,7 @@ impl<T: Element + AsPrimitive<f64>> PyArray<T, Ix1> {
///
/// [numpy.arange]: https://numpy.org/doc/stable/reference/generated/numpy.arange.html
/// [PyArray_Arange]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_Arange
pub fn arange(py: Python, start: T, stop: T, step: T) -> &Self {
pub fn arange<'py>(py: Python<'py>, start: T, stop: T, step: T) -> &Self {
unsafe {
let ptr = PY_ARRAY_API.PyArray_Arange(
py,
Expand Down
16 changes: 8 additions & 8 deletions src/borrow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ where

/// Provides an immutable array view of the interior of the NumPy array.
#[inline(always)]
pub fn as_array(&self) -> ArrayView<T, D> {
pub fn as_array(&self) -> ArrayView<'_, T, D> {
// SAFETY: Global borrow flags ensure aliasing discipline.
unsafe { self.array.as_array() }
}
Expand Down Expand Up @@ -278,7 +278,7 @@ where
#[doc(alias = "nalgebra")]
pub fn try_as_matrix<R, C, RStride, CStride>(
&self,
) -> Option<nalgebra::MatrixView<N, R, C, RStride, CStride>>
) -> Option<nalgebra::MatrixView<'_, N, R, C, RStride, CStride>>
where
R: nalgebra::Dim,
C: nalgebra::Dim,
Expand All @@ -300,7 +300,7 @@ where
///
/// Panics if the array has negative strides.
#[doc(alias = "nalgebra")]
pub fn as_matrix(&self) -> nalgebra::DMatrixView<N, nalgebra::Dyn, nalgebra::Dyn> {
pub fn as_matrix(&self) -> nalgebra::DMatrixView<'_, N, nalgebra::Dyn, nalgebra::Dyn> {
self.try_as_matrix().unwrap()
}
}
Expand All @@ -316,7 +316,7 @@ where
///
/// Panics if the array has negative strides.
#[doc(alias = "nalgebra")]
pub fn as_matrix(&self) -> nalgebra::DMatrixView<N, nalgebra::Dyn, nalgebra::Dyn> {
pub fn as_matrix(&self) -> nalgebra::DMatrixView<'_, N, nalgebra::Dyn, nalgebra::Dyn> {
self.try_as_matrix().unwrap()
}
}
Expand Down Expand Up @@ -428,7 +428,7 @@ where

/// Provides a mutable array view of the interior of the NumPy array.
#[inline(always)]
pub fn as_array_mut(&mut self) -> ArrayViewMut<T, D> {
pub fn as_array_mut(&mut self) -> ArrayViewMut<'_, T, D> {
// SAFETY: Global borrow flags ensure aliasing discipline.
unsafe { self.array.as_array_mut() }
}
Expand Down Expand Up @@ -460,7 +460,7 @@ where
#[doc(alias = "nalgebra")]
pub fn try_as_matrix_mut<R, C, RStride, CStride>(
&self,
) -> Option<nalgebra::MatrixViewMut<N, R, C, RStride, CStride>>
) -> Option<nalgebra::MatrixViewMut<'_, N, R, C, RStride, CStride>>
where
R: nalgebra::Dim,
C: nalgebra::Dim,
Expand All @@ -482,7 +482,7 @@ where
///
/// Panics if the array has negative strides.
#[doc(alias = "nalgebra")]
pub fn as_matrix_mut(&self) -> nalgebra::DMatrixViewMut<N, nalgebra::Dyn, nalgebra::Dyn> {
pub fn as_matrix_mut(&self) -> nalgebra::DMatrixViewMut<'_, N, nalgebra::Dyn, nalgebra::Dyn> {
self.try_as_matrix_mut().unwrap()
}
}
Expand All @@ -498,7 +498,7 @@ where
///
/// Panics if the array has negative strides.
#[doc(alias = "nalgebra")]
pub fn as_matrix_mut(&self) -> nalgebra::DMatrixViewMut<N, nalgebra::Dyn, nalgebra::Dyn> {
pub fn as_matrix_mut(&self) -> nalgebra::DMatrixViewMut<'_, N, nalgebra::Dyn, nalgebra::Dyn> {
self.try_as_matrix_mut().unwrap()
}
}
Expand Down
14 changes: 7 additions & 7 deletions src/borrow/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ fn get_or_insert_shared<'py>(py: Python<'py>) -> PyResult<&'py Shared> {
// immediately initialize the cache used access it from this extension.

#[cold]
fn insert_shared(py: Python) -> PyResult<*const Shared> {
fn insert_shared<'py>(py: Python<'py>) -> PyResult<*const Shared> {
let module = get_array_module(py)?;

let capsule: &PyCapsule = match module.getattr("_RUST_NUMPY_BORROW_CHECKING_API") {
Expand Down Expand Up @@ -170,7 +170,7 @@ fn insert_shared(py: Python) -> PyResult<*const Shared> {

// These entry points will be used to access the shared borrow checking API from this extension:

pub fn acquire(py: Python, array: *mut PyArrayObject) -> Result<(), BorrowError> {
pub fn acquire<'py>(py: Python<'py>, array: *mut PyArrayObject) -> Result<(), BorrowError> {
let shared = get_or_insert_shared(py).expect("Interal borrow checking API error");

let rc = unsafe { (shared.acquire)(shared.flags, array) };
Expand All @@ -182,7 +182,7 @@ pub fn acquire(py: Python, array: *mut PyArrayObject) -> Result<(), BorrowError>
}
}

pub fn acquire_mut(py: Python, array: *mut PyArrayObject) -> Result<(), BorrowError> {
pub fn acquire_mut<'py>(py: Python<'py>, array: *mut PyArrayObject) -> Result<(), BorrowError> {
let shared = get_or_insert_shared(py).expect("Interal borrow checking API error");

let rc = unsafe { (shared.acquire_mut)(shared.flags, array) };
Expand All @@ -195,15 +195,15 @@ pub fn acquire_mut(py: Python, array: *mut PyArrayObject) -> Result<(), BorrowEr
}
}

pub fn release(py: Python, array: *mut PyArrayObject) {
pub fn release<'py>(py: Python<'py>, array: *mut PyArrayObject) {
let shared = get_or_insert_shared(py).expect("Interal borrow checking API error");

unsafe {
(shared.release)(shared.flags, array);
}
}

pub fn release_mut(py: Python, array: *mut PyArrayObject) {
pub fn release_mut<'py>(py: Python<'py>, array: *mut PyArrayObject) {
let shared = get_or_insert_shared(py).expect("Interal borrow checking API error");

unsafe {
Expand Down Expand Up @@ -365,7 +365,7 @@ impl BorrowFlags {
}
}

fn base_address(py: Python, mut array: *mut PyArrayObject) -> *mut c_void {
fn base_address<'py>(py: Python<'py>, mut array: *mut PyArrayObject) -> *mut c_void {
loop {
let base = unsafe { (*array).base };

Expand Down Expand Up @@ -450,7 +450,7 @@ mod tests {
use crate::array::{PyArray, PyArray1, PyArray2, PyArray3};
use crate::convert::IntoPyArray;

fn get_borrow_flags<'py>(py: Python) -> &'py BorrowFlagsInner {
fn get_borrow_flags<'py>(py: Python<'py>) -> &'py BorrowFlagsInner {
let shared = get_or_insert_shared(py).unwrap();
assert_eq!(shared.version, 1);
unsafe { &(*(shared.flags as *mut BorrowFlags)).0 }
Expand Down
Loading

0 comments on commit ac8190b

Please sign in to comment.