Skip to content

Commit a6b6a14

Browse files
124C41padamreichold
authored andcommitted
Add PyArrayLike wrapper around PyReadonlyArray
Extracts a read-only reference if the correct NumPy array type is given. Tries to convert the input into the correct type using `numpy.asarray` otherwise.
1 parent 08510a3 commit a6b6a14

File tree

4 files changed

+260
-0
lines changed

4 files changed

+260
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
- Increase MSRV to 1.56 released in October 2021 and available in Debain 12, RHEL 9 and Alpine 3.17 following the same change for PyO3. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
55
- Add support for ASCII (`PyFixedString<N>`) and Unicode (`PyFixedUnicode<N>`) string arrays, i.e. dtypes `SN` and `UN` where `N` is the number of characters. ([#378](https://github.com/PyO3/rust-numpy/pull/378))
66
- Add support for the `bfloat16` dtype by extending the optional integration with the `half` crate. Note that the `bfloat16` dtype is not part of NumPy itself so that usage requires third-party packages like Tensorflow. ([#381](https://github.com/PyO3/rust-numpy/pull/381))
7+
- Add `PyArrayLike` type which extracts `PyReadonlyArray` if a NumPy array of the correct type is given and attempts a conversion using `numpy.asarray` otherwise. ([#383](https://github.com/PyO3/rust-numpy/pull/383))
78

89
- v0.19.0
910
- Add `PyUntypedArray` as an untyped base type for `PyArray` which can be used to inspect arguments before more targeted downcasts. This is accompanied by some methods like `dtype` and `shape` moving from `PyArray` to `PyUntypedArray`. They are still accessible though, as `PyArray` dereferences to `PyUntypedArray` via the `Deref` trait. ([#369](https://github.com/PyO3/rust-numpy/pull/369))

src/array_like.rs

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
use std::marker::PhantomData;
2+
use std::ops::Deref;
3+
4+
use ndarray::{Array1, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5+
use pyo3::{intern, sync::GILOnceCell, types::PyDict, FromPyObject, Py, PyAny, PyResult};
6+
7+
use crate::sealed::Sealed;
8+
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray};
9+
10+
pub trait Coerce: Sealed {
11+
const VAL: bool;
12+
}
13+
14+
/// TODO
15+
#[derive(Debug)]
16+
pub struct TypeMustMatch;
17+
18+
impl Sealed for TypeMustMatch {}
19+
20+
impl Coerce for TypeMustMatch {
21+
const VAL: bool = false;
22+
}
23+
24+
/// TODO
25+
#[derive(Debug)]
26+
pub struct AllowTypeChange;
27+
28+
impl Sealed for AllowTypeChange {}
29+
30+
impl Coerce for AllowTypeChange {
31+
const VAL: bool = true;
32+
}
33+
34+
/// TODO
35+
#[derive(Debug)]
36+
#[repr(transparent)]
37+
pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>)
38+
where
39+
T: Element,
40+
D: Dimension,
41+
C: Coerce;
42+
43+
impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C>
44+
where
45+
T: Element,
46+
D: Dimension,
47+
C: Coerce,
48+
{
49+
type Target = PyReadonlyArray<'py, T, D>;
50+
51+
fn deref(&self) -> &Self::Target {
52+
&self.0
53+
}
54+
}
55+
56+
impl<'py, T, D, C> FromPyObject<'py> for PyArrayLike<'py, T, D, C>
57+
where
58+
T: Element,
59+
D: Dimension,
60+
C: Coerce,
61+
Vec<T>: FromPyObject<'py>,
62+
{
63+
fn extract(ob: &'py PyAny) -> PyResult<Self> {
64+
if let Ok(array) = ob.downcast::<PyArray<T, D>>() {
65+
return Ok(Self(array.readonly(), PhantomData));
66+
}
67+
68+
let py = ob.py();
69+
70+
if matches!(D::NDIM, None | Some(1)) {
71+
if let Ok(vec) = ob.extract::<Vec<T>>() {
72+
let array = Array1::from(vec)
73+
.into_dimensionality()
74+
.expect("D being compatible to Ix1")
75+
.into_pyarray(py)
76+
.readonly();
77+
return Ok(Self(array, PhantomData));
78+
}
79+
}
80+
81+
static AS_ARRAY: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
82+
83+
let as_array = AS_ARRAY
84+
.get_or_try_init(py, || {
85+
get_array_module(py)?.getattr("asarray").map(Into::into)
86+
})?
87+
.as_ref(py);
88+
89+
let kwargs = if C::VAL {
90+
let kwargs = PyDict::new(py);
91+
kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
92+
Some(kwargs)
93+
} else {
94+
None
95+
};
96+
97+
let array = as_array.call((ob,), kwargs)?.extract()?;
98+
Ok(Self(array, PhantomData))
99+
}
100+
}
101+
102+
/// TODO
103+
pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>;
104+
/// TODO
105+
pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>;
106+
/// TODO
107+
pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>;
108+
/// TODO
109+
pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>;
110+
/// TODO
111+
pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>;
112+
/// TODO
113+
pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>;
114+
/// TODO
115+
pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>;

src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ as well as the [`PyReadonlyArray::try_as_matrix`] and [`PyReadwriteArray::try_as
7575
#![allow(clippy::needless_lifetimes)]
7676

7777
pub mod array;
78+
mod array_like;
7879
pub mod borrow;
7980
pub mod convert;
8081
pub mod datetime;
@@ -96,6 +97,10 @@ pub use crate::array::{
9697
get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5,
9798
PyArray6, PyArrayDyn,
9899
};
100+
pub use crate::array_like::{
101+
AllowTypeChange, PyArrayLike, PyArrayLike1, PyArrayLike2, PyArrayLike3, PyArrayLike4,
102+
PyArrayLike5, PyArrayLike6, PyArrayLikeDyn, TypeMustMatch,
103+
};
99104
pub use crate::borrow::{
100105
PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4,
101106
PyReadonlyArray5, PyReadonlyArray6, PyReadonlyArrayDyn, PyReadwriteArray, PyReadwriteArray1,

tests/array_like.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
use ndarray::array;
2+
use numpy::{get_array_module, AllowTypeChange, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn};
3+
use pyo3::{
4+
types::{IntoPyDict, PyDict},
5+
Python,
6+
};
7+
8+
fn get_np_locals(py: Python) -> &PyDict {
9+
[("np", get_array_module(py).unwrap())].into_py_dict(py)
10+
}
11+
12+
#[test]
13+
fn extract_reference() {
14+
Python::with_gil(|py| {
15+
let locals = get_np_locals(py);
16+
let py_array = py
17+
.eval(
18+
"np.array([[1,2],[3,4]], dtype='float64')",
19+
Some(locals),
20+
None,
21+
)
22+
.unwrap();
23+
let extracted_array = py_array.extract::<PyArrayLike2<f64>>().unwrap();
24+
25+
assert_eq!(
26+
array![[1_f64, 2_f64], [3_f64, 4_f64]],
27+
extracted_array.as_array()
28+
);
29+
});
30+
}
31+
32+
#[test]
33+
fn convert_array_on_extract() {
34+
Python::with_gil(|py| {
35+
let locals = get_np_locals(py);
36+
let py_array = py
37+
.eval("np.array([[1,2],[3,4]], dtype='int32')", Some(locals), None)
38+
.unwrap();
39+
let extracted_array = py_array
40+
.extract::<PyArrayLike2<f64, AllowTypeChange>>()
41+
.unwrap();
42+
43+
assert_eq!(
44+
array![[1_f64, 2_f64], [3_f64, 4_f64]],
45+
extracted_array.as_array()
46+
);
47+
});
48+
}
49+
50+
#[test]
51+
fn convert_list_on_extract() {
52+
Python::with_gil(|py| {
53+
let py_list = py.eval("[[1.0,2.0],[3.0,4.0]]", None, None).unwrap();
54+
let extracted_array = py_list.extract::<PyArrayLike2<f64>>().unwrap();
55+
56+
assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array());
57+
});
58+
}
59+
60+
#[test]
61+
fn convert_array_in_list_on_extract() {
62+
Python::with_gil(|py| {
63+
let locals = get_np_locals(py);
64+
let py_array = py
65+
.eval("[np.array([1.0, 2.0]), [3.0, 4.0]]", Some(locals), None)
66+
.unwrap();
67+
let extracted_array = py_array.extract::<PyArrayLike2<f64>>().unwrap();
68+
69+
assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array());
70+
});
71+
}
72+
73+
#[test]
74+
fn convert_list_on_extract_dyn() {
75+
Python::with_gil(|py| {
76+
let py_list = py
77+
.eval("[[[1,2],[3,4]],[[5,6],[7,8]]]", None, None)
78+
.unwrap();
79+
let extracted_array = py_list
80+
.extract::<PyArrayLikeDyn<i64, AllowTypeChange>>()
81+
.unwrap();
82+
83+
assert_eq!(
84+
array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(),
85+
extracted_array.as_array()
86+
);
87+
});
88+
}
89+
90+
#[test]
91+
fn convert_1d_list_on_extract() {
92+
Python::with_gil(|py| {
93+
let py_list = py.eval("[1,2,3,4]", None, None).unwrap();
94+
let extracted_array_1d = py_list.extract::<PyArrayLike1<u32>>().unwrap();
95+
let extracted_array_dyn = py_list.extract::<PyArrayLikeDyn<f64>>().unwrap();
96+
97+
assert_eq!(array![1, 2, 3, 4], extracted_array_1d.as_array());
98+
assert_eq!(
99+
array![1_f64, 2_f64, 3_f64, 4_f64].into_dyn(),
100+
extracted_array_dyn.as_array()
101+
);
102+
});
103+
}
104+
105+
#[test]
106+
fn unsafe_cast_shall_fail() {
107+
Python::with_gil(|py| {
108+
let locals = get_np_locals(py);
109+
let py_list = py
110+
.eval(
111+
"np.array([1.1,2.2,3.3,4.4], dtype='float64')",
112+
Some(locals),
113+
None,
114+
)
115+
.unwrap();
116+
let extracted_array = py_list.extract::<PyArrayLike1<i32>>();
117+
118+
assert!(extracted_array.is_err());
119+
});
120+
}
121+
122+
#[test]
123+
fn unsafe_cast_with_coerce_works() {
124+
Python::with_gil(|py| {
125+
let locals = get_np_locals(py);
126+
let py_list = py
127+
.eval(
128+
"np.array([1.1,2.2,3.3,4.4], dtype='float64')",
129+
Some(locals),
130+
None,
131+
)
132+
.unwrap();
133+
let extracted_array = py_list
134+
.extract::<PyArrayLike1<i32, AllowTypeChange>>()
135+
.unwrap();
136+
137+
assert_eq!(array![1, 2, 3, 4], extracted_array.as_array());
138+
});
139+
}

0 commit comments

Comments
 (0)