Skip to content

Commit

Permalink
Work around MSRV problems with const generics using sealed trait.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamreichold committed Jul 8, 2023
1 parent 7c1f688 commit d672028
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 21 deletions.
62 changes: 46 additions & 16 deletions src/array_like.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,50 @@
use std::marker::PhantomData;
use std::ops::Deref;

use ndarray::{Array1, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
use pyo3::{intern, sync::GILOnceCell, types::PyDict, FromPyObject, Py, PyAny, PyResult};

use crate::sealed::Sealed;
use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray};

pub trait Coerce: Sealed {
const VAL: bool;
}

/// TODO
#[derive(Debug)]
pub struct TypeMustMatch;

impl Sealed for TypeMustMatch {}

impl Coerce for TypeMustMatch {
const VAL: bool = false;
}

/// TODO
#[derive(Debug)]
pub struct AllowTypeChange;

impl Sealed for AllowTypeChange {}

impl Coerce for AllowTypeChange {
const VAL: bool = true;
}

/// TODO
#[derive(Debug)]
#[repr(transparent)]
pub struct PyArrayLike<'py, T, D, const COERCE: bool = false>(PyReadonlyArray<'py, T, D>)
pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>)
where
T: Element,
D: Dimension;
D: Dimension,
C: Coerce;

impl<'py, T, D, const COERCE: bool> Deref for PyArrayLike<'py, T, D, COERCE>
impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C>
where
T: Element,
D: Dimension,
C: Coerce,
{
type Target = PyReadonlyArray<'py, T, D>;

Expand All @@ -25,27 +53,28 @@ where
}
}

impl<'py, T, D, const COERCE: bool> FromPyObject<'py> for PyArrayLike<'py, T, D, COERCE>
impl<'py, T, D, C> FromPyObject<'py> for PyArrayLike<'py, T, D, C>
where
T: Element,
D: Dimension,
C: Coerce,
Vec<T>: FromPyObject<'py>,
{
fn extract(ob: &'py PyAny) -> PyResult<Self> {
if let Ok(array) = ob.downcast::<PyArray<T, D>>() {
return Ok(Self(array.readonly()));
return Ok(Self(array.readonly(), PhantomData));
}

let py = ob.py();

if matches!(D::NDIM, None | Some(1)) {
if let Ok(vec) = ob.extract::<Vec<T>>() {
let array = Array1::from_vec(vec)
let array = Array1::from(vec)
.into_dimensionality()
.expect("D being compatible to Ix1")
.into_pyarray(py)
.readonly();
return Ok(Self(array));
return Ok(Self(array, PhantomData));
}
}

Expand All @@ -57,29 +86,30 @@ where
})?
.as_ref(py);

let kwargs = if COERCE {
let kwargs = if C::VAL {
let kwargs = PyDict::new(py);
kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
Some(kwargs)
} else {
None
};

as_array.call((ob,), kwargs)?.extract().map(Self)
let array = as_array.call((ob,), kwargs)?.extract()?;
Ok(Self(array, PhantomData))
}
}

/// TODO
pub type PyArrayLike1<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, Ix1, COERCE>;
pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>;
/// TODO
pub type PyArrayLike2<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, Ix2, COERCE>;
pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>;
/// TODO
pub type PyArrayLike3<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, Ix3, COERCE>;
pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>;
/// TODO
pub type PyArrayLike4<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, Ix4, COERCE>;
pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>;
/// TODO
pub type PyArrayLike5<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, Ix5, COERCE>;
pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>;
/// TODO
pub type PyArrayLike6<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, Ix6, COERCE>;
pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>;
/// TODO
pub type PyArrayLikeDyn<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, IxDyn, COERCE>;
pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>;
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ pub use crate::array::{
PyArray6, PyArrayDyn,
};
pub use crate::array_like::{
PyArrayLike, PyArrayLike1, PyArrayLike2, PyArrayLike3, PyArrayLike4, PyArrayLike5,
PyArrayLike6, PyArrayLikeDyn,
AllowTypeChange, PyArrayLike, PyArrayLike1, PyArrayLike2, PyArrayLike3, PyArrayLike4,
PyArrayLike5, PyArrayLike6, PyArrayLikeDyn, TypeMustMatch,
};
pub use crate::borrow::{
PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4,
Expand Down
10 changes: 7 additions & 3 deletions tests/array_like.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use ndarray::array;
use numpy::{get_array_module, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn};
use numpy::{get_array_module, AllowTypeChange, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn};
use pyo3::{
types::{IntoPyDict, PyDict},
Python,
Expand Down Expand Up @@ -36,7 +36,9 @@ fn convert_array_on_extract() {
let py_array = py
.eval("np.array([[1,2],[3,4]], dtype='int32')", Some(locals), None)
.unwrap();
let extracted_array = py_array.extract::<PyArrayLike2<f64, true>>().unwrap();
let extracted_array = py_array
.extract::<PyArrayLike2<f64, AllowTypeChange>>()
.unwrap();

assert_eq!(
array![[1_f64, 2_f64], [3_f64, 4_f64]],
Expand Down Expand Up @@ -130,7 +132,9 @@ fn unsafe_cast_with_coerce_works() {
None,
)
.unwrap();
let extracted_array = py_list.extract::<PyArrayLike1<i32, true>>().unwrap();
let extracted_array = py_list
.extract::<PyArrayLike1<i32, AllowTypeChange>>()
.unwrap();

assert_eq!(array![1, 2, 3, 4], extracted_array.as_array());
});
Expand Down

0 comments on commit d672028

Please sign in to comment.