Skip to content

Commit

Permalink
Add PyArrayLike wrapper around PyReadonlyArray
Browse files Browse the repository at this point in the history
Extracts a read-only reference if the correct NumPy array type is given.
Tries to convert the input into the correct type using `numpy.asarray`
otherwise.
  • Loading branch information
124C41p authored and adamreichold committed Jul 9, 2023
1 parent 08510a3 commit a6b6a14
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- Increase MSRV to 1.56 released in October 2021 and available in Debain 12, RHEL 9 and Alpine 3.17 following the same change for PyO3. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
- Add support for ASCII (`PyFixedString<N>`) and Unicode (`PyFixedUnicode<N>`) string arrays, i.e. dtypes `SN` and `UN` where `N` is the number of characters. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
- Add support for the `bfloat16` dtype by extending the optional integration with the `half` crate. Note that the `bfloat16` dtype is not part of NumPy itself so that usage requires third-party packages like Tensorflow. ([#381](https://github.com/PyO3/rust-numpy/pull/381))
- Add `PyArrayLike` type which extracts `PyReadonlyArray` if a NumPy array of the correct type is given and attempts a conversion using `numpy.asarray` otherwise. ([#383](https://github.com/PyO3/rust-numpy/pull/383))

- v0.19.0
- Add `PyUntypedArray` as an untyped base type for `PyArray` which can be used to inspect arguments before more targeted downcasts. This is accompanied by some methods like `dtype` and `shape` moving from `PyArray` to `PyUntypedArray`. They are still accessible though, as `PyArray` dereferences to `PyUntypedArray` via the `Deref` trait. ([#369](https://github.com/PyO3/rust-numpy/pull/369))
Expand Down
115 changes: 115 additions & 0 deletions src/array_like.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use std::marker::PhantomData;
use std::ops::Deref;

use ndarray::{Array1, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
use pyo3::{intern, sync::GILOnceCell, types::PyDict, FromPyObject, Py, PyAny, PyResult};

use crate::sealed::Sealed;
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray};

pub trait Coerce: Sealed {
const VAL: bool;
}

/// TODO
#[derive(Debug)]
pub struct TypeMustMatch;

impl Sealed for TypeMustMatch {}

impl Coerce for TypeMustMatch {
const VAL: bool = false;
}

/// TODO
#[derive(Debug)]
pub struct AllowTypeChange;

impl Sealed for AllowTypeChange {}

impl Coerce for AllowTypeChange {
const VAL: bool = true;
}

/// TODO
#[derive(Debug)]
#[repr(transparent)]
pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>)
where
T: Element,
D: Dimension,
C: Coerce;

impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C>
where
T: Element,
D: Dimension,
C: Coerce,
{
type Target = PyReadonlyArray<'py, T, D>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<'py, T, D, C> FromPyObject<'py> for PyArrayLike<'py, T, D, C>
where
T: Element,
D: Dimension,
C: Coerce,
Vec<T>: FromPyObject<'py>,
{
fn extract(ob: &'py PyAny) -> PyResult<Self> {
if let Ok(array) = ob.downcast::<PyArray<T, D>>() {
return Ok(Self(array.readonly(), PhantomData));
}

let py = ob.py();

if matches!(D::NDIM, None | Some(1)) {
if let Ok(vec) = ob.extract::<Vec<T>>() {
let array = Array1::from(vec)
.into_dimensionality()
.expect("D being compatible to Ix1")
.into_pyarray(py)
.readonly();
return Ok(Self(array, PhantomData));
}
}

static AS_ARRAY: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

let as_array = AS_ARRAY
.get_or_try_init(py, || {
get_array_module(py)?.getattr("asarray").map(Into::into)
})?
.as_ref(py);

let kwargs = if C::VAL {
let kwargs = PyDict::new(py);
kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
Some(kwargs)
} else {
None
};

let array = as_array.call((ob,), kwargs)?.extract()?;
Ok(Self(array, PhantomData))
}
}

/// TODO
pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>;
/// TODO
pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>;
/// TODO
pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>;
/// TODO
pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>;
/// TODO
pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>;
/// TODO
pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>;
/// TODO
pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>;
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ as well as the [`PyReadonlyArray::try_as_matrix`] and [`PyReadwriteArray::try_as
#![allow(clippy::needless_lifetimes)]

pub mod array;
mod array_like;
pub mod borrow;
pub mod convert;
pub mod datetime;
Expand All @@ -96,6 +97,10 @@ pub use crate::array::{
get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5,
PyArray6, PyArrayDyn,
};
pub use crate::array_like::{
AllowTypeChange, PyArrayLike, PyArrayLike1, PyArrayLike2, PyArrayLike3, PyArrayLike4,
PyArrayLike5, PyArrayLike6, PyArrayLikeDyn, TypeMustMatch,
};
pub use crate::borrow::{
PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4,
PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn, PyReadwriteArray, PyReadwriteArray1,
Expand Down
139 changes: 139 additions & 0 deletions tests/array_like.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use ndarray::array;
use numpy::{get_array_module, AllowTypeChange, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn};
use pyo3::{
types::{IntoPyDict, PyDict},
Python,
};

fn get_np_locals(py: Python) -> &PyDict {
[("np", get_array_module(py).unwrap())].into_py_dict(py)
}

#[test]
fn extract_reference() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_array = py
.eval(
"np.array([[1,2],[3,4]], dtype='float64')",
Some(locals),
None,
)
.unwrap();
let extracted_array = py_array.extract::<PyArrayLike2<f64>>().unwrap();

assert_eq!(
array![[1_f64, 2_f64], [3_f64, 4_f64]],
extracted_array.as_array()
);
});
}

#[test]
fn convert_array_on_extract() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_array = py
.eval("np.array([[1,2],[3,4]], dtype='int32')", Some(locals), None)
.unwrap();
let extracted_array = py_array
.extract::<PyArrayLike2<f64, AllowTypeChange>>()
.unwrap();

assert_eq!(
array![[1_f64, 2_f64], [3_f64, 4_f64]],
extracted_array.as_array()
);
});
}

#[test]
fn convert_list_on_extract() {
Python::with_gil(|py| {
let py_list = py.eval("[[1.0,2.0],[3.0,4.0]]", None, None).unwrap();
let extracted_array = py_list.extract::<PyArrayLike2<f64>>().unwrap();

assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array());
});
}

#[test]
fn convert_array_in_list_on_extract() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_array = py
.eval("[np.array([1.0, 2.0]), [3.0, 4.0]]", Some(locals), None)
.unwrap();
let extracted_array = py_array.extract::<PyArrayLike2<f64>>().unwrap();

assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array());
});
}

#[test]
fn convert_list_on_extract_dyn() {
Python::with_gil(|py| {
let py_list = py
.eval("[[[1,2],[3,4]],[[5,6],[7,8]]]", None, None)
.unwrap();
let extracted_array = py_list
.extract::<PyArrayLikeDyn<i64, AllowTypeChange>>()
.unwrap();

assert_eq!(
array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
extracted_array.as_array()
);
});
}

#[test]
fn convert_1d_list_on_extract() {
Python::with_gil(|py| {
let py_list = py.eval("[1,2,3,4]", None, None).unwrap();
let extracted_array_1d = py_list.extract::<PyArrayLike1<u32>>().unwrap();
let extracted_array_dyn = py_list.extract::<PyArrayLikeDyn<f64>>().unwrap();

assert_eq!(array![1, 2, 3, 4], extracted_array_1d.as_array());
assert_eq!(
array![1_f64, 2_f64, 3_f64, 4_f64].into_dyn(),
extracted_array_dyn.as_array()
);
});
}

#[test]
fn unsafe_cast_shall_fail() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_list = py
.eval(
"np.array([1.1,2.2,3.3,4.4], dtype='float64')",
Some(locals),
None,
)
.unwrap();
let extracted_array = py_list.extract::<PyArrayLike1<i32>>();

assert!(extracted_array.is_err());
});
}

#[test]
fn unsafe_cast_with_coerce_works() {
Python::with_gil(|py| {
let locals = get_np_locals(py);
let py_list = py
.eval(
"np.array([1.1,2.2,3.3,4.4], dtype='float64')",
Some(locals),
None,
)
.unwrap();
let extracted_array = py_list
.extract::<PyArrayLike1<i32, AllowTypeChange>>()
.unwrap();

assert_eq!(array![1, 2, 3, 4], extracted_array.as_array());
});
}

0 comments on commit a6b6a14

Please sign in to comment.