-
Notifications
You must be signed in to change notification settings - Fork 111
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add PyArrayLike wrapper around PyReadonlyArray
Extracts a read-only reference if the correct NumPy array type is given. Tries to convert the input into the correct type using `numpy.asarray` otherwise.
- Loading branch information
1 parent
08510a3
commit a6b6a14
Showing
4 changed files
with
260 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
use std::marker::PhantomData; | ||
use std::ops::Deref; | ||
|
||
use ndarray::{Array1, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; | ||
use pyo3::{intern, sync::GILOnceCell, types::PyDict, FromPyObject, Py, PyAny, PyResult}; | ||
|
||
use crate::sealed::Sealed; | ||
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray}; | ||
|
||
pub trait Coerce: Sealed { | ||
const VAL: bool; | ||
} | ||
|
||
/// TODO | ||
#[derive(Debug)] | ||
pub struct TypeMustMatch; | ||
|
||
impl Sealed for TypeMustMatch {} | ||
|
||
impl Coerce for TypeMustMatch { | ||
const VAL: bool = false; | ||
} | ||
|
||
/// TODO | ||
#[derive(Debug)] | ||
pub struct AllowTypeChange; | ||
|
||
impl Sealed for AllowTypeChange {} | ||
|
||
impl Coerce for AllowTypeChange { | ||
const VAL: bool = true; | ||
} | ||
|
||
/// TODO | ||
#[derive(Debug)] | ||
#[repr(transparent)] | ||
pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>) | ||
where | ||
T: Element, | ||
D: Dimension, | ||
C: Coerce; | ||
|
||
impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C> | ||
where | ||
T: Element, | ||
D: Dimension, | ||
C: Coerce, | ||
{ | ||
type Target = PyReadonlyArray<'py, T, D>; | ||
|
||
fn deref(&self) -> &Self::Target { | ||
&self.0 | ||
} | ||
} | ||
|
||
impl<'py, T, D, C> FromPyObject<'py> for PyArrayLike<'py, T, D, C> | ||
where | ||
T: Element, | ||
D: Dimension, | ||
C: Coerce, | ||
Vec<T>: FromPyObject<'py>, | ||
{ | ||
fn extract(ob: &'py PyAny) -> PyResult<Self> { | ||
if let Ok(array) = ob.downcast::<PyArray<T, D>>() { | ||
return Ok(Self(array.readonly(), PhantomData)); | ||
} | ||
|
||
let py = ob.py(); | ||
|
||
if matches!(D::NDIM, None | Some(1)) { | ||
if let Ok(vec) = ob.extract::<Vec<T>>() { | ||
let array = Array1::from(vec) | ||
.into_dimensionality() | ||
.expect("D being compatible to Ix1") | ||
.into_pyarray(py) | ||
.readonly(); | ||
return Ok(Self(array, PhantomData)); | ||
} | ||
} | ||
|
||
static AS_ARRAY: GILOnceCell<Py<PyAny>> = GILOnceCell::new(); | ||
|
||
let as_array = AS_ARRAY | ||
.get_or_try_init(py, || { | ||
get_array_module(py)?.getattr("asarray").map(Into::into) | ||
})? | ||
.as_ref(py); | ||
|
||
let kwargs = if C::VAL { | ||
let kwargs = PyDict::new(py); | ||
kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?; | ||
Some(kwargs) | ||
} else { | ||
None | ||
}; | ||
|
||
let array = as_array.call((ob,), kwargs)?.extract()?; | ||
Ok(Self(array, PhantomData)) | ||
} | ||
} | ||
|
||
/// TODO | ||
pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>; | ||
/// TODO | ||
pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>; | ||
/// TODO | ||
pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>; | ||
/// TODO | ||
pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>; | ||
/// TODO | ||
pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>; | ||
/// TODO | ||
pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>; | ||
/// TODO | ||
pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
use ndarray::array; | ||
use numpy::{get_array_module, AllowTypeChange, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn}; | ||
use pyo3::{ | ||
types::{IntoPyDict, PyDict}, | ||
Python, | ||
}; | ||
|
||
fn get_np_locals(py: Python) -> &PyDict { | ||
[("np", get_array_module(py).unwrap())].into_py_dict(py) | ||
} | ||
|
||
#[test] | ||
fn extract_reference() { | ||
Python::with_gil(|py| { | ||
let locals = get_np_locals(py); | ||
let py_array = py | ||
.eval( | ||
"np.array([[1,2],[3,4]], dtype='float64')", | ||
Some(locals), | ||
None, | ||
) | ||
.unwrap(); | ||
let extracted_array = py_array.extract::<PyArrayLike2<f64>>().unwrap(); | ||
|
||
assert_eq!( | ||
array![[1_f64, 2_f64], [3_f64, 4_f64]], | ||
extracted_array.as_array() | ||
); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn convert_array_on_extract() { | ||
Python::with_gil(|py| { | ||
let locals = get_np_locals(py); | ||
let py_array = py | ||
.eval("np.array([[1,2],[3,4]], dtype='int32')", Some(locals), None) | ||
.unwrap(); | ||
let extracted_array = py_array | ||
.extract::<PyArrayLike2<f64, AllowTypeChange>>() | ||
.unwrap(); | ||
|
||
assert_eq!( | ||
array![[1_f64, 2_f64], [3_f64, 4_f64]], | ||
extracted_array.as_array() | ||
); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn convert_list_on_extract() { | ||
Python::with_gil(|py| { | ||
let py_list = py.eval("[[1.0,2.0],[3.0,4.0]]", None, None).unwrap(); | ||
let extracted_array = py_list.extract::<PyArrayLike2<f64>>().unwrap(); | ||
|
||
assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array()); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn convert_array_in_list_on_extract() { | ||
Python::with_gil(|py| { | ||
let locals = get_np_locals(py); | ||
let py_array = py | ||
.eval("[np.array([1.0, 2.0]), [3.0, 4.0]]", Some(locals), None) | ||
.unwrap(); | ||
let extracted_array = py_array.extract::<PyArrayLike2<f64>>().unwrap(); | ||
|
||
assert_eq!(array![[1.0, 2.0], [3.0, 4.0]], extracted_array.as_array()); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn convert_list_on_extract_dyn() { | ||
Python::with_gil(|py| { | ||
let py_list = py | ||
.eval("[[[1,2],[3,4]],[[5,6],[7,8]]]", None, None) | ||
.unwrap(); | ||
let extracted_array = py_list | ||
.extract::<PyArrayLikeDyn<i64, AllowTypeChange>>() | ||
.unwrap(); | ||
|
||
assert_eq!( | ||
array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]].into_dyn(), | ||
extracted_array.as_array() | ||
); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn convert_1d_list_on_extract() { | ||
Python::with_gil(|py| { | ||
let py_list = py.eval("[1,2,3,4]", None, None).unwrap(); | ||
let extracted_array_1d = py_list.extract::<PyArrayLike1<u32>>().unwrap(); | ||
let extracted_array_dyn = py_list.extract::<PyArrayLikeDyn<f64>>().unwrap(); | ||
|
||
assert_eq!(array![1, 2, 3, 4], extracted_array_1d.as_array()); | ||
assert_eq!( | ||
array![1_f64, 2_f64, 3_f64, 4_f64].into_dyn(), | ||
extracted_array_dyn.as_array() | ||
); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn unsafe_cast_shall_fail() { | ||
Python::with_gil(|py| { | ||
let locals = get_np_locals(py); | ||
let py_list = py | ||
.eval( | ||
"np.array([1.1,2.2,3.3,4.4], dtype='float64')", | ||
Some(locals), | ||
None, | ||
) | ||
.unwrap(); | ||
let extracted_array = py_list.extract::<PyArrayLike1<i32>>(); | ||
|
||
assert!(extracted_array.is_err()); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn unsafe_cast_with_coerce_works() { | ||
Python::with_gil(|py| { | ||
let locals = get_np_locals(py); | ||
let py_list = py | ||
.eval( | ||
"np.array([1.1,2.2,3.3,4.4], dtype='float64')", | ||
Some(locals), | ||
None, | ||
) | ||
.unwrap(); | ||
let extracted_array = py_list | ||
.extract::<PyArrayLike1<i32, AllowTypeChange>>() | ||
.unwrap(); | ||
|
||
assert_eq!(array![1, 2, 3, 4], extracted_array.as_array()); | ||
}); | ||
} |