Skip to content

Commit 7ae478c

Browse files
committed
Add range implementations of linspace and logspace
1 parent dc833d8 commit 7ae478c

File tree

5 files changed

+86
-78
lines changed

5 files changed

+86
-78
lines changed

benches/construct.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ fn zeros_f64(bench: &mut Bencher)
2121
#[bench]
2222
fn map_regular(bench: &mut test::Bencher)
2323
{
24-
let a = Array::linspace(0., 127., 128)
24+
let a = Array::linspace(0.0..=127.0, 128)
2525
.into_shape_with_order((8, 16))
2626
.unwrap();
2727
bench.iter(|| a.map(|&x| 2. * x));
@@ -31,7 +31,7 @@ fn map_regular(bench: &mut test::Bencher)
3131
#[bench]
3232
fn map_stride(bench: &mut test::Bencher)
3333
{
34-
let a = Array::linspace(0., 127., 256)
34+
let a = Array::linspace(0.0..=127.0, 256)
3535
.into_shape_with_order((8, 32))
3636
.unwrap();
3737
let av = a.slice(s![.., ..;2]);

src/impl_constructors.rs

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ use rawpointer::PointerExt;
4444
///
4545
/// ## Constructor methods for one-dimensional arrays.
4646
impl<S, A> ArrayBase<S, Ix1>
47-
where S: DataOwned<Elem = A>
47+
where
48+
S: DataOwned<Elem = A>,
4849
{
4950
/// Create a one-dimensional array from a vector (no copying needed).
5051
///
@@ -55,13 +56,9 @@ where S: DataOwned<Elem = A>
5556
///
5657
/// let array = Array::from_vec(vec![1., 2., 3., 4.]);
5758
/// ```
58-
pub fn from_vec(v: Vec<A>) -> Self
59-
{
59+
pub fn from_vec(v: Vec<A>) -> Self {
6060
if mem::size_of::<A>() == 0 {
61-
assert!(
62-
v.len() <= isize::MAX as usize,
63-
"Length must fit in `isize`.",
64-
);
61+
assert!(v.len() <= isize::MAX as usize, "Length must fit in `isize`.",);
6562
}
6663
unsafe { Self::from_shape_vec_unchecked(v.len() as Ix, v) }
6764
}
@@ -76,8 +73,7 @@ where S: DataOwned<Elem = A>
7673
/// let array = Array::from_iter(0..10);
7774
/// ```
7875
#[allow(clippy::should_implement_trait)]
79-
pub fn from_iter<I: IntoIterator<Item = A>>(iterable: I) -> Self
80-
{
76+
pub fn from_iter<I: IntoIterator<Item = A>>(iterable: I) -> Self {
8177
Self::from_vec(iterable.into_iter().collect())
8278
}
8379

@@ -99,10 +95,12 @@ where S: DataOwned<Elem = A>
9995
/// assert!(array == arr1(&[0.0, 0.25, 0.5, 0.75, 1.0]))
10096
/// ```
10197
#[cfg(feature = "std")]
102-
pub fn linspace(start: A, end: A, n: usize) -> Self
103-
where A: Float
98+
pub fn linspace<R>(range: R, n: usize) -> Self
99+
where
100+
R: std::ops::RangeBounds<A>,
101+
A: Float,
104102
{
105-
Self::from(to_vec(linspace::linspace(start, end, n)))
103+
Self::from(to_vec(linspace::linspace(range, n)))
106104
}
107105

108106
/// Create a one-dimensional array with elements from `start` to `end`
@@ -118,7 +116,8 @@ where S: DataOwned<Elem = A>
118116
/// ```
119117
#[cfg(feature = "std")]
120118
pub fn range(start: A, end: A, step: A) -> Self
121-
where A: Float
119+
where
120+
A: Float,
122121
{
123122
Self::from(to_vec(linspace::range(start, end, step)))
124123
}
@@ -145,10 +144,12 @@ where S: DataOwned<Elem = A>
145144
/// # }
146145
/// ```
147146
#[cfg(feature = "std")]
148-
pub fn logspace(base: A, start: A, end: A, n: usize) -> Self
149-
where A: Float
147+
pub fn logspace<R>(base: A, range: R, n: usize) -> Self
148+
where
149+
R: std::ops::RangeBounds<A>,
150+
A: Float,
150151
{
151-
Self::from(to_vec(logspace::logspace(base, start, end, n)))
152+
Self::from(to_vec(logspace::logspace(base, range, n)))
152153
}
153154

154155
/// Create a one-dimensional array with `n` geometrically spaced elements
@@ -180,15 +181,17 @@ where S: DataOwned<Elem = A>
180181
/// ```
181182
#[cfg(feature = "std")]
182183
pub fn geomspace(start: A, end: A, n: usize) -> Option<Self>
183-
where A: Float
184+
where
185+
A: Float,
184186
{
185187
Some(Self::from(to_vec(geomspace::geomspace(start, end, n)?)))
186188
}
187189
}
188190

189191
/// ## Constructor methods for two-dimensional arrays.
190192
impl<S, A> ArrayBase<S, Ix2>
191-
where S: DataOwned<Elem = A>
193+
where
194+
S: DataOwned<Elem = A>,
192195
{
193196
/// Create an identity matrix of size `n` (square 2D array).
194197
///
@@ -470,14 +473,14 @@ where
470473
/// );
471474
/// ```
472475
pub fn from_shape_vec<Sh>(shape: Sh, v: Vec<A>) -> Result<Self, ShapeError>
473-
where Sh: Into<StrideShape<D>>
476+
where
477+
Sh: Into<StrideShape<D>>,
474478
{
475479
// eliminate the type parameter Sh as soon as possible
476480
Self::from_shape_vec_impl(shape.into(), v)
477481
}
478482

479-
fn from_shape_vec_impl(shape: StrideShape<D>, v: Vec<A>) -> Result<Self, ShapeError>
480-
{
483+
fn from_shape_vec_impl(shape: StrideShape<D>, v: Vec<A>) -> Result<Self, ShapeError> {
481484
let dim = shape.dim;
482485
let is_custom = shape.strides.is_custom();
483486
dimension::can_index_slice_with_strides(&v, &dim, &shape.strides, dimension::CanIndexCheckMode::OwnedMutable)?;
@@ -513,16 +516,16 @@ where
513516
/// 5. The strides must not allow any element to be referenced by two different
514517
/// indices.
515518
pub unsafe fn from_shape_vec_unchecked<Sh>(shape: Sh, v: Vec<A>) -> Self
516-
where Sh: Into<StrideShape<D>>
519+
where
520+
Sh: Into<StrideShape<D>>,
517521
{
518522
let shape = shape.into();
519523
let dim = shape.dim;
520524
let strides = shape.strides.strides_for_dim(&dim);
521525
Self::from_vec_dim_stride_unchecked(dim, strides, v)
522526
}
523527

524-
unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec<A>) -> Self
525-
{
528+
unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec<A>) -> Self {
526529
// debug check for issues that indicates wrong use of this constructor
527530
debug_assert!(dimension::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok());
528531

@@ -595,7 +598,8 @@ where
595598
/// # let _ = shift_by_two;
596599
/// ```
597600
pub fn uninit<Sh>(shape: Sh) -> ArrayBase<S::MaybeUninit, D>
598-
where Sh: ShapeBuilder<Dim = D>
601+
where
602+
Sh: ShapeBuilder<Dim = D>,
599603
{
600604
unsafe {
601605
let shape = shape.into_shape_with_order();

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ mod layout;
201201
mod linalg_traits;
202202
mod linspace;
203203
#[cfg(feature = "std")]
204-
pub use crate::linspace::{linspace, linspace_exclusive, range, Linspace};
204+
pub use crate::linspace::{linspace, range, Linspace};
205205
mod logspace;
206206
#[cfg(feature = "std")]
207207
pub use crate::logspace::{logspace, Logspace};

src/linspace.rs

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,29 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88
#![cfg(feature = "std")]
9+
10+
use std::ops::{Bound, RangeBounds};
11+
912
use num_traits::Float;
1013

1114
/// An iterator of a sequence of evenly spaced floats.
1215
///
1316
/// Iterator element type is `F`.
14-
pub struct Linspace<F>
15-
{
17+
pub struct Linspace<F> {
1618
start: F,
1719
step: F,
1820
index: usize,
1921
len: usize,
2022
}
2123

2224
impl<F> Iterator for Linspace<F>
23-
where F: Float
25+
where
26+
F: Float,
2427
{
2528
type Item = F;
2629

2730
#[inline]
28-
fn next(&mut self) -> Option<F>
29-
{
31+
fn next(&mut self) -> Option<F> {
3032
if self.index >= self.len {
3133
None
3234
} else {
@@ -38,19 +40,18 @@ where F: Float
3840
}
3941

4042
#[inline]
41-
fn size_hint(&self) -> (usize, Option<usize>)
42-
{
43+
fn size_hint(&self) -> (usize, Option<usize>) {
4344
let n = self.len - self.index;
4445
(n, Some(n))
4546
}
4647
}
4748

4849
impl<F> DoubleEndedIterator for Linspace<F>
49-
where F: Float
50+
where
51+
F: Float,
5052
{
5153
#[inline]
52-
fn next_back(&mut self) -> Option<F>
53-
{
54+
fn next_back(&mut self) -> Option<F> {
5455
if self.index >= self.len {
5556
None
5657
} else {
@@ -71,43 +72,31 @@ impl<F> ExactSizeIterator for Linspace<F> where Linspace<F>: Iterator {}
7172
/// The iterator element type is `F`, where `F` must implement [`Float`], e.g.
7273
/// [`f32`] or [`f64`].
7374
///
74-
/// **Panics** if converting `n - 1` to type `F` fails.
75+
/// ## Panics
76+
/// - If called with a range type other than `a..b` or `a..=b`.
77+
/// - If converting `n` to type `F` fails.
7578
#[inline]
76-
pub fn linspace<F>(a: F, b: F, n: usize) -> Linspace<F>
77-
where F: Float
79+
pub fn linspace<R, F>(range: R, n: usize) -> Linspace<F>
80+
where
81+
R: RangeBounds<F>,
82+
F: Float,
7883
{
79-
let step = if n > 1 {
80-
let num_steps = F::from(n - 1).expect("Converting number of steps to `A` must not fail.");
81-
(b - a) / num_steps
82-
} else {
83-
F::zero()
84+
let (a, b, num_steps) = match (range.start_bound(), range.end_bound()) {
85+
(Bound::Included(a), Bound::Included(b)) => {
86+
(*a, *b, F::from(n - 1).expect("Converting number of steps to `A` must not fail."))
87+
}
88+
(Bound::Included(a), Bound::Excluded(b)) => {
89+
(*a, *b, F::from(n).expect("Converting number of steps to `A` must not fail."))
90+
}
91+
_ => panic!("Only a..b and a..=b ranges are supported."),
8492
};
85-
Linspace {
86-
start: a,
87-
step,
88-
index: 0,
89-
len: n,
90-
}
91-
}
9293

93-
/// Return an iterator of evenly spaced floats.
94-
///
95-
/// The `Linspace` has `n` elements from `a` to `b` (exclusive).
96-
///
97-
/// The iterator element type is `F`, where `F` must implement [`Float`], e.g.
98-
/// [`f32`] or [`f64`].
99-
///
100-
/// **Panics** if converting `n` to type `F` fails.
101-
#[inline]
102-
pub fn linspace_exclusive<F>(a: F, b: F, n: usize) -> Linspace<F>
103-
where F: Float
104-
{
105-
let step = if n > 1 {
106-
let num_steps = F::from(n).expect("Converting number of steps to `A` must not fail.");
94+
let step = if num_steps > F::zero() {
10795
(b - a) / num_steps
10896
} else {
10997
F::zero()
11098
};
99+
111100
Linspace {
112101
start: a,
113102
step,
@@ -127,7 +116,8 @@ where F: Float
127116
/// **Panics** if converting `((b - a) / step).ceil()` to type `F` fails.
128117
#[inline]
129118
pub fn range<F>(a: F, b: F, step: F) -> Linspace<F>
130-
where F: Float
119+
where
120+
F: Float,
131121
{
132122
let len = b - a;
133123
let steps = F::ceil(len / step);

src/logspace.rs

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
// option. This file may not be copied, modified, or distributed
77
// except according to those terms.
88
#![cfg(feature = "std")]
9+
10+
use std::ops::{Bound, RangeBounds};
911
use num_traits::Float;
1012

1113
/// An iterator of a sequence of logarithmically spaced number.
@@ -79,15 +81,27 @@ impl<F> ExactSizeIterator for Logspace<F> where Logspace<F>: Iterator {}
7981
///
8082
/// **Panics** if converting `n - 1` to type `F` fails.
8183
#[inline]
82-
pub fn logspace<F>(base: F, a: F, b: F, n: usize) -> Logspace<F>
83-
where F: Float
84+
pub fn logspace<R, F>(base: F, range: R, n: usize) -> Logspace<F>
85+
where
86+
R: RangeBounds<F>,
87+
F: Float,
8488
{
85-
let step = if n > 1 {
86-
let num_steps = F::from(n - 1).expect("Converting number of steps to `A` must not fail.");
89+
let (a, b, num_steps) = match (range.start_bound(), range.end_bound()) {
90+
(Bound::Included(a), Bound::Included(b)) => {
91+
(*a, *b, F::from(n - 1).expect("Converting number of steps to `A` must not fail."))
92+
}
93+
(Bound::Included(a), Bound::Excluded(b)) => {
94+
(*a, *b, F::from(n).expect("Converting number of steps to `A` must not fail."))
95+
}
96+
_ => panic!("Only a..b and a..=b ranges are supported."),
97+
};
98+
99+
let step = if num_steps > F::zero() {
87100
(b - a) / num_steps
88101
} else {
89102
F::zero()
90103
};
104+
91105
Logspace {
92106
sign: base.signum(),
93107
base: base.abs(),
@@ -110,23 +124,23 @@ mod tests
110124
use crate::{arr1, Array1};
111125
use approx::assert_abs_diff_eq;
112126

113-
let array: Array1<_> = logspace(10.0, 0.0, 3.0, 4).collect();
127+
let array: Array1<_> = logspace(10.0, 0.0..=3.0, 4).collect();
114128
assert_abs_diff_eq!(array, arr1(&[1e0, 1e1, 1e2, 1e3]), epsilon = 1e-12);
115129

116-
let array: Array1<_> = logspace(10.0, 3.0, 0.0, 4).collect();
130+
let array: Array1<_> = logspace(10.0, 3.0..=0.0, 4).collect();
117131
assert_abs_diff_eq!(array, arr1(&[1e3, 1e2, 1e1, 1e0]), epsilon = 1e-12);
118132

119-
let array: Array1<_> = logspace(-10.0, 3.0, 0.0, 4).collect();
133+
let array: Array1<_> = logspace(-10.0, 3.0..=0.0, 4).collect();
120134
assert_abs_diff_eq!(array, arr1(&[-1e3, -1e2, -1e1, -1e0]), epsilon = 1e-12);
121135

122-
let array: Array1<_> = logspace(-10.0, 0.0, 3.0, 4).collect();
136+
let array: Array1<_> = logspace(-10.0, 0.0..=3.0, 4).collect();
123137
assert_abs_diff_eq!(array, arr1(&[-1e0, -1e1, -1e2, -1e3]), epsilon = 1e-12);
124138
}
125139

126140
#[test]
127141
fn iter_forward()
128142
{
129-
let mut iter = logspace(10.0f64, 0.0, 3.0, 4);
143+
let mut iter = logspace(10.0f64, 0.0..=3.0, 4);
130144

131145
assert!(iter.size_hint() == (4, Some(4)));
132146

@@ -142,7 +156,7 @@ mod tests
142156
#[test]
143157
fn iter_backward()
144158
{
145-
let mut iter = logspace(10.0f64, 0.0, 3.0, 4);
159+
let mut iter = logspace(10.0f64, 0.0..=3.0, 4);
146160

147161
assert!(iter.size_hint() == (4, Some(4)));
148162

0 commit comments

Comments
 (0)