Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyArrayLike type introduced #383

Merged
merged 1 commit into from
Oct 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- Increase MSRV to 1.56 released in October 2021 and available in Debain 12, RHEL 9 and Alpine 3.17 following the same change for PyO3. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
- Add support for ASCII (`PyFixedString<N>`) and Unicode (`PyFixedUnicode<N>`) string arrays, i.e. dtypes `SN` and `UN` where `N` is the number of characters. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
- Add support for the `bfloat16` dtype by extending the optional integration with the `half` crate. Note that the `bfloat16` dtype is not part of NumPy itself so that usage requires third-party packages like Tensorflow. ([#381](https://github.com/PyO3/rust-numpy/pull/381))
- Add `PyArrayLike` type which extracts `PyReadonlyArray` if a NumPy array of the correct type is given and attempts a conversion using `numpy.asarray` otherwise. ([#383](https://github.com/PyO3/rust-numpy/pull/383))

- v0.19.0
- Add `PyUntypedArray` as an untyped base type for `PyArray` which can be used to inspect arguments before more targeted downcasts. This is accompanied by some methods like `dtype` and `shape` moving from `PyArray` to `PyUntypedArray`. They are still accessible though, as `PyArray` dereferences to `PyUntypedArray` via the `Deref` trait. ([#369](https://github.com/PyO3/rust-numpy/pull/369))
Expand Down
195 changes: 195 additions & 0 deletions src/array_like.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
use std::marker::PhantomData;
use std::ops::Deref;

use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
use pyo3::{intern, sync::GILOnceCell, types::PyDict, FromPyObject, Py, PyAny, PyResult};

use crate::sealed::Sealed;
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray};

pub trait Coerce: Sealed {
const VAL: bool;
}

/// Marker type to indicate that the element type received via [`PyArrayLike`] must match the specified type exactly.
#[derive(Debug)]
pub struct TypeMustMatch;

impl Sealed for TypeMustMatch {}

impl Coerce for TypeMustMatch {
const VAL: bool = false;
}

/// Marker type to indicate that the element type received via [`PyArrayLike`] can be cast to the specified type by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
#[derive(Debug)]
pub struct AllowTypeChange;

impl Sealed for AllowTypeChange {}

impl Coerce for AllowTypeChange {
const VAL: bool = true;
}

/// Receiver for arrays or array-like types.
///
/// When building API using NumPy in Python, it is common for functions to additionally accept any array-like type such as `list[float]` as arguments.
/// `PyArrayLike` enables the same pattern in Rust extensions, i.e. by taking this type as the argument of a `#[pyfunction]`,
/// one will always get access to a [`PyReadonlyArray`] that will either reference to the NumPy array originally passed into the function
/// or a temporary one created by converting the input type into a NumPy array.
///
/// Depending on whether [`TypeMustMatch`] or [`AllowTypeChange`] is used for the `C` type parameter,
/// the element type must either match the specific type `T` exactly or will be cast to it by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
///
/// # Example
///
/// `PyArrayLike1<'py, T, TypeMustMatch>` will enable you to receive both NumPy arrays and sequences
///
/// ```rust
/// # use pyo3::prelude::*;
/// use pyo3::py_run;
/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
///
/// #[pyfunction]
/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, f64, TypeMustMatch>) -> f64 {
/// array.as_array().sum()
/// }
///
/// Python::with_gil(|py| {
/// let np = get_array_module(py).unwrap();
/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
///
/// py_run!(py, np sum_up, r"assert sum_up(np.array([1., 2., 3.])) == 6.");
/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6.");
/// });
/// ```
///
/// but it will not cast the element type if that is required
///
/// ```rust,should_panic
/// use pyo3::prelude::*;
/// use pyo3::py_run;
/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
///
/// #[pyfunction]
/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, TypeMustMatch>) -> i32 {
/// array.as_array().sum()
/// }
///
/// Python::with_gil(|py| {
/// let np = get_array_module(py).unwrap();
/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
///
/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6");
/// });
/// ```
///
/// whereas `PyArrayLike1<'py, T, AllowTypeChange>` will do even at the cost loosing precision
///
/// ```rust
/// use pyo3::prelude::*;
/// use pyo3::py_run;
/// use numpy::{get_array_module, AllowTypeChange, PyArrayLike1};
///
/// #[pyfunction]
/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, AllowTypeChange>) -> i32 {
/// array.as_array().sum()
/// }
///
/// Python::with_gil(|py| {
/// let np = get_array_module(py).unwrap();
/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
///
/// py_run!(py, np sum_up, r"assert sum_up((1.5, 2.5)) == 3");
/// });
/// ```
#[derive(Debug)]
#[repr(transparent)]
pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>)
where
T: Element,
D: Dimension,
C: Coerce;

impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C>
where
T: Element,
D: Dimension,
C: Coerce,
{
type Target = PyReadonlyArray<'py, T, D>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<'py, T, D, C> FromPyObject<'py> for PyArrayLike<'py, T, D, C>
where
T: Element,
D: Dimension,
C: Coerce,
Vec<T>: FromPyObject<'py>,
{
fn extract(ob: &'py PyAny) -> PyResult<Self> {
if let Ok(array) = ob.downcast::<PyArray<T, D>>() {
return Ok(Self(array.readonly(), PhantomData));
}

let py = ob.py();

if matches!(D::NDIM, None | Some(1)) {
if let Ok(vec) = ob.extract::<Vec<T>>() {
let array = Array1::from(vec)
.into_dimensionality()
.expect("D being compatible to Ix1")
.into_pyarray(py)
.readonly();
return Ok(Self(array, PhantomData));
}
}

static AS_ARRAY: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

let as_array = AS_ARRAY
.get_or_try_init(py, || {
get_array_module(py)?.getattr("asarray").map(Into::into)
})?
.as_ref(py);

let kwargs = if C::VAL {
let kwargs = PyDict::new(py);
kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
Some(kwargs)
} else {
None
};

let array = as_array.call((ob,), kwargs)?.extract()?;
Ok(Self(array, PhantomData))
}
}

/// Receiver for zero-dimensional arrays or array-like types.
pub type PyArrayLike0<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix0, C>;

/// Receiver for one-dimensional arrays or array-like types.
pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>;

/// Receiver for two-dimensional arrays or array-like types.
pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>;

/// Receiver for three-dimensional arrays or array-like types.
pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>;

/// Receiver for four-dimensional arrays or array-like types.
pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>;

/// Receiver for five-dimensional arrays or array-like types.
pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>;

/// Receiver for six-dimensional arrays or array-like types.
pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>;

/// Receiver for arrays or array-like types whose dimensionality is determined at runtime.
pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>;
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ as well as the [`PyReadonlyArray::try_as_matrix`] and [`PyReadwriteArray::try_as
#![deny(missing_docs, missing_debug_implementations)]

pub mod array;
mod array_like;
pub mod borrow;
pub mod convert;
pub mod datetime;
Expand All @@ -94,6 +95,10 @@ pub use crate::array::{
get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5,
PyArray6, PyArrayDyn,
};
pub use crate::array_like::{
AllowTypeChange, PyArrayLike, PyArrayLike0, PyArrayLike1, PyArrayLike2, PyArrayLike3,
PyArrayLike4, PyArrayLike5, PyArrayLike6, PyArrayLikeDyn, TypeMustMatch,
};
pub use crate::borrow::{
PyReadonlyArray, PyReadonlyArray0, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3,
PyReadonlyArray4, PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn, PyReadwriteArray,
Expand Down
139 changes: 139 additions & 0 deletions tests/array_like.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use ndarray::array;
use numpy::{get_array_module, AllowTypeChange, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn};
use pyo3::{
types::{IntoPyDict, PyDict},
Python,
};

fn get_np_locals<'py>(py: Python<'py>) -> &'py PyDict {
[("np", get_array_module(py).unwrap())].into_py_dict(py)
}

#[test]
fn extract_reference() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_array = py
.eval(
"np.array([[1,2],[3,4]], dtype='float64')",
Some(locals),
None,
)
.unwrap();
let extracted_array = py_array.extract::<PyArrayLike2<'_, f64>>().unwrap();

assert_eq!(
array![[1_f64, 2_f64], [3_f64, 4_f64]],
extracted_array.as_array()
);
});
}

#[test]
fn convert_array_on_extract() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_array = py
.eval("np.array([[1,2],[3,4]], dtype='int32')", Some(locals), None)
.unwrap();
let extracted_array = py_array
.extract::<PyArrayLike2<'_, f64, AllowTypeChange>>()
.unwrap();

assert_eq!(
array![[1_f64, 2_f64], [3_f64, 4_f64]],
extracted_array.as_array()
);
});
}

#[test]
fn convert_list_on_extract() {
Python::with_gil(|py| {
let py_list = py.eval("[[1.0,2.0],[3.0,4.0]]", None, None).unwrap();
let extracted_array = py_list.extract::<PyArrayLike2<'_, f64>>().unwrap();

assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array());
});
}

#[test]
fn convert_array_in_list_on_extract() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_array = py
.eval("[np.array([1.0, 2.0]), [3.0, 4.0]]", Some(locals), None)
.unwrap();
let extracted_array = py_array.extract::<PyArrayLike2<'_, f64>>().unwrap();

assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array());
});
}

#[test]
fn convert_list_on_extract_dyn() {
Python::with_gil(|py| {
let py_list = py
.eval("[[[1,2],[3,4]],[[5,6],[7,8]]]", None, None)
.unwrap();
let extracted_array = py_list
.extract::<PyArrayLikeDyn<'_, i64, AllowTypeChange>>()
.unwrap();

assert_eq!(
array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
extracted_array.as_array()
);
});
}

#[test]
fn convert_1d_list_on_extract() {
Python::with_gil(|py| {
let py_list = py.eval("[1,2,3,4]", None, None).unwrap();
let extracted_array_1d = py_list.extract::<PyArrayLike1<'_, u32>>().unwrap();
let extracted_array_dyn = py_list.extract::<PyArrayLikeDyn<'_, f64>>().unwrap();

assert_eq!(array![1, 2, 3, 4], extracted_array_1d.as_array());
assert_eq!(
array![1_f64, 2_f64, 3_f64, 4_f64].into_dyn(),
extracted_array_dyn.as_array()
);
});
}

#[test]
fn unsafe_cast_shall_fail() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_list = py
.eval(
"np.array([1.1,2.2,3.3,4.4], dtype='float64')",
Some(locals),
None,
)
.unwrap();
let extracted_array = py_list.extract::<PyArrayLike1<'_, i32>>();

assert!(extracted_array.is_err());
});
}

#[test]
fn unsafe_cast_with_coerce_works() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_list = py
.eval(
"np.array([1.1,2.2,3.3,4.4], dtype='float64')",
Some(locals),
None,
)
.unwrap();
let extracted_array = py_list
.extract::<PyArrayLike1<'_, i32, AllowTypeChange>>()
.unwrap();

assert_eq!(array![1, 2, 3, 4], extracted_array.as_array());
});
}
Loading