Skip to content

Commit 23342b1

Browse files
committed
rebase and use map_collect_owned in impl_ops.rs
1 parent 1055ebf commit 23342b1

File tree

4 files changed

+33
-68
lines changed

4 files changed

+33
-68
lines changed

src/data_traits.rs

+2-19
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
1111
use rawpointer::PointerExt;
1212

13-
use std::mem::{self, size_of};use std::mem::MaybeUninit;
13+
use std::mem::{self, size_of};
14+
use std::mem::MaybeUninit;
1415
use std::ptr::NonNull;
1516
use alloc::sync::Arc;
1617
use alloc::vec::Vec;
@@ -620,21 +621,3 @@ impl<'a, A: 'a, B: 'a> RawDataSubst<B> for ViewRepr<&'a mut A> {
620621
}
621622
}
622623

623-
/// Array representation trait.
624-
///
625-
/// The MaybeUninitSubst trait maps the MaybeUninit type of element, while
626-
/// mapping the MaybeUninit type back to origin element type.
627-
///
628-
/// For example, `MaybeUninitSubst` can map the type `OwnedRepr<A>` to `OwnedRepr<MaybeUninit<A>>`,
629-
/// and use `Output as RawDataSubst` to map `OwnedRepr<MaybeUninit<A>>` back to `OwnedRepr<A>`.
630-
pub trait MaybeUninitSubst<A>: DataOwned<Elem = A> {
631-
type Output: DataOwned<Elem = MaybeUninit<A>> + RawDataSubst<A, Output=Self, Elem = MaybeUninit<A>>;
632-
}
633-
634-
impl<A> MaybeUninitSubst<A> for OwnedRepr<A> {
635-
type Output = OwnedRepr<MaybeUninit<A>>;
636-
}
637-
638-
impl<A> MaybeUninitSubst<A> for OwnedArcRepr<A> {
639-
type Output = OwnedArcRepr<MaybeUninit<A>>;
640-
}

src/impl_methods.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1961,7 +1961,7 @@ where
19611961
self.unordered_foreach_mut(move |elt| *elt = x.clone());
19621962
}
19631963

1964-
fn zip_mut_with_same_shape<B, S2, E, F>(&mut self, rhs: &ArrayBase<S2, E>, mut f: F)
1964+
pub(crate) fn zip_mut_with_same_shape<B, S2, E, F>(&mut self, rhs: &ArrayBase<S2, E>, mut f: F)
19651965
where
19661966
S: DataMut,
19671967
S2: Data<Elem = B>,

src/impl_ops.rs

+29-47
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
// except according to those terms.
88

99
use crate::dimension::BroadcastShape;
10-
use crate::data_traits::MaybeUninitSubst;
1110
use crate::Zip;
1211
use num_complex::Complex;
1312

@@ -68,8 +67,8 @@ impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
6867
where
6968
A: Clone + $trt<B, Output=A>,
7069
B: Clone,
71-
S: DataOwned<Elem=A> + DataMut + MaybeUninitSubst<A>,
72-
<S as MaybeUninitSubst<A>>::Output: DataMut,
70+
S: DataOwned<Elem=A> + DataMut,
71+
S::MaybeUninit: DataMut,
7372
S2: Data<Elem=B>,
7473
D: Dimension + BroadcastShape<E>,
7574
E: Dimension,
@@ -96,38 +95,24 @@ impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
9695
where
9796
A: Clone + $trt<B, Output=A>,
9897
B: Clone,
99-
S: DataOwned<Elem=A> + DataMut + MaybeUninitSubst<A>,
100-
<S as MaybeUninitSubst<A>>::Output: DataMut,
98+
S: DataOwned<Elem=A> + DataMut,
99+
S::MaybeUninit: DataMut,
101100
S2: Data<Elem=B>,
102101
D: Dimension + BroadcastShape<E>,
103102
E: Dimension,
104103
{
105104
type Output = ArrayBase<S, <D as BroadcastShape<E>>::Output>;
106105
fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
107106
{
108-
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
109-
if shape.slice() == self.dim.slice() {
107+
if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
110108
let mut out = self.into_dimensionality::<<D as BroadcastShape<E>>::Output>().unwrap();
111-
out.zip_mut_with(rhs, |x, y| {
112-
*x = x.clone() $operator y.clone();
113-
});
109+
out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
114110
out
115111
} else {
112+
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
116113
let lhs = self.broadcast(shape.clone()).unwrap();
117-
let rhs = rhs.broadcast(shape.clone()).unwrap();
118-
// SAFETY: Overwrite all the elements in the array after
119-
// it is created via `raw_view_mut`.
120-
unsafe {
121-
let mut out =ArrayBase::<<S as MaybeUninitSubst<A>>::Output, <D as BroadcastShape<E>>::Output>::maybe_uninit(shape.into_pattern());
122-
let output_view = out.raw_view_mut().cast::<A>();
123-
Zip::from(&lhs).and(&rhs)
124-
.and(output_view)
125-
.collect_with_partial(|x, y| {
126-
x.clone() $operator y.clone()
127-
})
128-
.release_ownership();
129-
out.assume_init()
130-
}
114+
let rhs = rhs.broadcast(shape).unwrap();
115+
Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth))
131116
}
132117
}
133118
}
@@ -148,38 +133,24 @@ where
148133
A: Clone + $trt<B, Output=B>,
149134
B: Clone,
150135
S: Data<Elem=A>,
151-
S2: DataOwned<Elem=B> + DataMut + MaybeUninitSubst<B>,
152-
<S2 as MaybeUninitSubst<B>>::Output: DataMut,
136+
S2: DataOwned<Elem=B> + DataMut,
137+
S2::MaybeUninit: DataMut,
153138
D: Dimension,
154139
E: Dimension + BroadcastShape<D>,
155140
{
156141
type Output = ArrayBase<S2, <E as BroadcastShape<D>>::Output>;
157142
fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
158143
where
159144
{
160-
let shape = rhs.dim.broadcast_shape(&self.dim).unwrap();
161-
if shape.slice() == rhs.dim.slice() {
145+
if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
162146
let mut out = rhs.into_dimensionality::<<E as BroadcastShape<D>>::Output>().unwrap();
163-
out.zip_mut_with(self, |x, y| {
164-
*x = y.clone() $operator x.clone();
165-
});
147+
out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
166148
out
167149
} else {
150+
let shape = rhs.dim.broadcast_shape(&self.dim).unwrap();
168151
let lhs = self.broadcast(shape.clone()).unwrap();
169-
let rhs = rhs.broadcast(shape.clone()).unwrap();
170-
// SAFETY: Overwrite all the elements in the array after
171-
// it is created via `raw_view_mut`.
172-
unsafe {
173-
let mut out =ArrayBase::<<S2 as MaybeUninitSubst<B>>::Output, <E as BroadcastShape<D>>::Output>::maybe_uninit(shape.into_pattern());
174-
let output_view = out.raw_view_mut().cast::<B>();
175-
Zip::from(&lhs).and(&rhs)
176-
.and(output_view)
177-
.collect_with_partial(|x, y| {
178-
x.clone() $operator y.clone()
179-
})
180-
.release_ownership();
181-
out.assume_init()
182-
}
152+
let rhs = rhs.broadcast(shape).unwrap();
153+
Zip::from(&lhs).and(&rhs).map_collect_owned(clone_opf(A::$mth))
183154
}
184155
}
185156
}
@@ -207,8 +178,7 @@ where
207178
let shape = self.dim.broadcast_shape(&rhs.dim).unwrap();
208179
let lhs = self.broadcast(shape.clone()).unwrap();
209180
let rhs = rhs.broadcast(shape).unwrap();
210-
let out = Zip::from(&lhs).and(&rhs).map_collect(|x, y| x.clone() $operator y.clone());
211-
out
181+
Zip::from(&lhs).and(&rhs).map_collect(clone_opf(A::$mth))
212182
}
213183
}
214184

@@ -313,6 +283,18 @@ mod arithmetic_ops {
313283
use num_complex::Complex;
314284
use std::ops::*;
315285

286+
fn clone_opf<A: Clone, B: Clone, C>(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C {
287+
move |x, y| f(x.clone(), y.clone())
288+
}
289+
290+
fn clone_iopf<A: Clone, B: Clone>(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B) {
291+
move |x, y| *x = f(x.clone(), y.clone())
292+
}
293+
294+
fn clone_iopf_rev<A: Clone, B: Clone>(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A) {
295+
move |x, y| *x = f(y.clone(), x.clone())
296+
}
297+
316298
impl_binary_op!(Add, +, add, +=, "addition");
317299
impl_binary_op!(Sub, -, sub, -=, "subtraction");
318300
impl_binary_op!(Mul, *, mul, *=, "multiplication");

src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ pub use crate::aliases::*;
179179

180180
pub use crate::data_traits::{
181181
Data, DataMut, DataOwned, DataShared, RawData, RawDataClone, RawDataMut,
182-
RawDataSubst, MaybeUninitSubst,
182+
RawDataSubst,
183183
};
184184

185185
mod free_functions;

0 commit comments

Comments
 (0)