From d6720284ff8ac2c680b8a192c243f55fcbf1b79b Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Sat, 8 Jul 2023 23:31:48 +0200 Subject: [PATCH] Work around MSRV problems with const generics using sealed trait. --- src/array_like.rs | 62 +++++++++++++++++++++++++++++++++------------ src/lib.rs | 4 +-- tests/array_like.rs | 10 +++++--- 3 files changed, 55 insertions(+), 21 deletions(-) diff --git a/src/array_like.rs b/src/array_like.rs index 18c424883..46e302ac3 100644 --- a/src/array_like.rs +++ b/src/array_like.rs @@ -1,22 +1,50 @@ +use std::marker::PhantomData; use std::ops::Deref; use ndarray::{Array1, Dimension, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; use pyo3::{intern, sync::GILOnceCell, types::PyDict, FromPyObject, Py, PyAny, PyResult}; +use crate::sealed::Sealed; use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray}; +pub trait Coerce: Sealed { + const VAL: bool; +} + +/// TODO +#[derive(Debug)] +pub struct TypeMustMatch; + +impl Sealed for TypeMustMatch {} + +impl Coerce for TypeMustMatch { + const VAL: bool = false; +} + +/// TODO +#[derive(Debug)] +pub struct AllowTypeChange; + +impl Sealed for AllowTypeChange {} + +impl Coerce for AllowTypeChange { + const VAL: bool = true; +} + /// TODO #[derive(Debug)] #[repr(transparent)] -pub struct PyArrayLike<'py, T, D, const COERCE: bool = false>(PyReadonlyArray<'py, T, D>) +pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData) where T: Element, - D: Dimension; + D: Dimension, + C: Coerce; -impl<'py, T, D, const COERCE: bool> Deref for PyArrayLike<'py, T, D, COERCE> +impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C> where T: Element, D: Dimension, + C: Coerce, { type Target = PyReadonlyArray<'py, T, D>; @@ -25,27 +53,28 @@ where } } -impl<'py, T, D, const COERCE: bool> FromPyObject<'py> for PyArrayLike<'py, T, D, COERCE> +impl<'py, T, D, C> FromPyObject<'py> for PyArrayLike<'py, T, D, C> where T: Element, D: Dimension, + C: Coerce, Vec: FromPyObject<'py>, { fn extract(ob: &'py PyAny) -> PyResult { if let Ok(array) = ob.downcast::>() { - return Ok(Self(array.readonly())); + return Ok(Self(array.readonly(), PhantomData)); } let py = ob.py(); if matches!(D::NDIM, None | Some(1)) { if let Ok(vec) = ob.extract::>() { - let array = Array1::from_vec(vec) + let array = Array1::from(vec) .into_dimensionality() .expect("D being compatible to Ix1") .into_pyarray(py) .readonly(); - return Ok(Self(array)); + return Ok(Self(array, PhantomData)); } } @@ -57,7 +86,7 @@ where })? .as_ref(py); - let kwargs = if COERCE { + let kwargs = if C::VAL { let kwargs = PyDict::new(py); kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?; Some(kwargs) @@ -65,21 +94,22 @@ where None }; - as_array.call((ob,), kwargs)?.extract().map(Self) + let array = as_array.call((ob,), kwargs)?.extract()?; + Ok(Self(array, PhantomData)) } } /// TODO -pub type PyArrayLike1<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, Ix1, COERCE>; +pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>; /// TODO -pub type PyArrayLike2<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, Ix2, COERCE>; +pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>; /// TODO -pub type PyArrayLike3<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, Ix3, COERCE>; +pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>; /// TODO -pub type PyArrayLike4<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, Ix4, COERCE>; +pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>; /// TODO -pub type PyArrayLike5<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, Ix5, COERCE>; +pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>; /// TODO -pub type PyArrayLike6<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, Ix6, COERCE>; +pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>; /// TODO -pub type PyArrayLikeDyn<'py, T, const COERCE: bool = false> = PyArrayLike<'py, T, IxDyn, COERCE>; +pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>; diff --git a/src/lib.rs b/src/lib.rs index 342d0ab36..b28260178 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,8 +98,8 @@ pub use crate::array::{ PyArray6, PyArrayDyn, }; pub use crate::array_like::{ - PyArrayLike, PyArrayLike1, PyArrayLike2, PyArrayLike3, PyArrayLike4, PyArrayLike5, - PyArrayLike6, PyArrayLikeDyn, + AllowTypeChange, PyArrayLike, PyArrayLike1, PyArrayLike2, PyArrayLike3, PyArrayLike4, + PyArrayLike5, PyArrayLike6, PyArrayLikeDyn, TypeMustMatch, }; pub use crate::borrow::{ PyReadonlyArray, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4, diff --git a/tests/array_like.rs b/tests/array_like.rs index e0c43b591..1874cab0e 100644 --- a/tests/array_like.rs +++ b/tests/array_like.rs @@ -1,5 +1,5 @@ use ndarray::array; -use numpy::{get_array_module, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn}; +use numpy::{get_array_module, AllowTypeChange, PyArrayLike1, PyArrayLike2, PyArrayLikeDyn}; use pyo3::{ types::{IntoPyDict, PyDict}, Python, @@ -36,7 +36,9 @@ fn convert_array_on_extract() { let py_array = py .eval("np.array([[1,2],[3,4]], dtype='int32')", Some(locals), None) .unwrap(); - let extracted_array = py_array.extract::>().unwrap(); + let extracted_array = py_array + .extract::>() + .unwrap(); assert_eq!( array![[1_f64, 2_f64], [3_f64, 4_f64]], @@ -130,7 +132,9 @@ fn unsafe_cast_with_coerce_works() { None, ) .unwrap(); - let extracted_array = py_list.extract::>().unwrap(); + let extracted_array = py_list + .extract::>() + .unwrap(); assert_eq!(array![1, 2, 3, 4], extracted_array.as_array()); });