Skip to content

Commit 0dbeaf3

Browse files
committed
Modify the docs and visibility of broadcast_with
1 parent 6c40d61 commit 0dbeaf3

File tree

2 files changed

+3
-42
lines changed

2 files changed

+3
-42
lines changed

src/impl_methods.rs

+3-12
Original file line numberDiff line numberDiff line change
@@ -1707,20 +1707,11 @@ where
17071707
unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) }
17081708
}
17091709

1710-
/// Calculate the views of two ArrayBases after broadcasting each other, if possible.
1710+
/// For two arrays or views, find their common shape if possible and
1711+
/// broadcast them as array views into that shape.
17111712
///
17121713
/// Return `ShapeError` if their shapes can not be broadcast together.
1713-
///
1714-
/// ```
1715-
/// use ndarray::{arr1, arr2};
1716-
///
1717-
/// let a = arr2(&[[2], [3], [4]]);
1718-
/// let b = arr1(&[5, 6, 7]);
1719-
/// let (a1, b1) = a.broadcast_with(&b).unwrap();
1720-
/// assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]]));
1721-
/// assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]]));
1722-
/// ```
1723-
pub fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase<S2, E>) ->
1714+
pub(crate) fn broadcast_with<'a, 'b, B, S2, E>(&'a self, other: &'b ArrayBase<S2, E>) ->
17241715
Result<(ArrayView<'a, A, <D as BroadcastShape<E>>::Output>, ArrayView<'b, B, <D as BroadcastShape<E>>::Output>), ShapeError>
17251716
where
17261717
S: Data<Elem=A>,

tests/broadcast.rs

-30
Original file line numberDiff line numberDiff line change
@@ -82,33 +82,3 @@ fn test_broadcast_1d() {
8282
println!("b2=\n{:?}", b2);
8383
assert_eq!(b0, b2);
8484
}
85-
86-
#[test]
87-
fn test_broadcast_with() {
88-
let a = arr2(&[[1., 2.], [3., 4.]]);
89-
let b = aview0(&1.);
90-
let (a1, b1) = a.broadcast_with(&b).unwrap();
91-
assert_eq!(a1, arr2(&[[1.0, 2.0], [3.0, 4.0]]));
92-
assert_eq!(b1, arr2(&[[1.0, 1.0], [1.0, 1.0]]));
93-
94-
let a = arr2(&[[2], [3], [4]]);
95-
let b = arr1(&[5, 6, 7]);
96-
let (a1, b1) = a.broadcast_with(&b).unwrap();
97-
assert_eq!(a1, arr2(&[[2, 2, 2], [3, 3, 3], [4, 4, 4]]));
98-
assert_eq!(b1, arr2(&[[5, 6, 7], [5, 6, 7], [5, 6, 7]]));
99-
100-
// Negative strides and non-contiguous memory
101-
let s = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
102-
let s = Array3::from_shape_vec((2, 3, 2).strides((1, 4, 2)), s.to_vec()).unwrap();
103-
let a = s.slice(s![..;-1,..;2,..]);
104-
let b = s.slice(s![..2, -1, ..]);
105-
let (a1, b1) = a.broadcast_with(&b).unwrap();
106-
assert_eq!(a1, arr3(&[[[2, 4], [10, 12]], [[1, 3], [9, 11]]]));
107-
assert_eq!(b1, arr3(&[[[9, 11], [10, 12]], [[9, 11], [10, 12]]]));
108-
109-
// ShapeError
110-
let a = arr2(&[[2, 2], [3, 3], [4, 4]]);
111-
let b = arr1(&[5, 6, 7]);
112-
let e = a.broadcast_with(&b);
113-
assert_eq!(e, Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)));
114-
}

0 commit comments

Comments
 (0)