7
7
// except according to those terms.
8
8
9
9
use crate :: dimension:: BroadcastShape ;
10
- use crate :: data_traits:: MaybeUninitSubst ;
11
10
use crate :: Zip ;
12
11
use num_complex:: Complex ;
13
12
@@ -68,8 +67,8 @@ impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
68
67
where
69
68
A : Clone + $trt<B , Output =A >,
70
69
B : Clone ,
71
- S : DataOwned <Elem =A > + DataMut + MaybeUninitSubst < A > ,
72
- < S as MaybeUninitSubst < A >> :: Output : DataMut ,
70
+ S : DataOwned <Elem =A > + DataMut ,
71
+ S :: MaybeUninit : DataMut ,
73
72
S2 : Data <Elem =B >,
74
73
D : Dimension + BroadcastShape <E >,
75
74
E : Dimension ,
@@ -96,38 +95,24 @@ impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
96
95
where
97
96
A : Clone + $trt<B , Output =A >,
98
97
B : Clone ,
99
- S : DataOwned <Elem =A > + DataMut + MaybeUninitSubst < A > ,
100
- < S as MaybeUninitSubst < A >> :: Output : DataMut ,
98
+ S : DataOwned <Elem =A > + DataMut ,
99
+ S :: MaybeUninit : DataMut ,
101
100
S2 : Data <Elem =B >,
102
101
D : Dimension + BroadcastShape <E >,
103
102
E : Dimension ,
104
103
{
105
104
type Output = ArrayBase <S , <D as BroadcastShape <E >>:: Output >;
106
105
fn $mth( self , rhs: & ArrayBase <S2 , E >) -> Self :: Output
107
106
{
108
- let shape = self . dim. broadcast_shape( & rhs. dim) . unwrap( ) ;
109
- if shape. slice( ) == self . dim. slice( ) {
107
+ if self . ndim( ) == rhs. ndim( ) && self . shape( ) == rhs. shape( ) {
110
108
let mut out = self . into_dimensionality:: <<D as BroadcastShape <E >>:: Output >( ) . unwrap( ) ;
111
- out. zip_mut_with( rhs, |x, y| {
112
- * x = x. clone( ) $operator y. clone( ) ;
113
- } ) ;
109
+ out. zip_mut_with_same_shape( rhs, clone_iopf( A :: $mth) ) ;
114
110
out
115
111
} else {
112
+ let shape = self . dim. broadcast_shape( & rhs. dim) . unwrap( ) ;
116
113
let lhs = self . broadcast( shape. clone( ) ) . unwrap( ) ;
117
- let rhs = rhs. broadcast( shape. clone( ) ) . unwrap( ) ;
118
- // SAFETY: Overwrite all the elements in the array after
119
- // it is created via `raw_view_mut`.
120
- unsafe {
121
- let mut out =ArrayBase :: <<S as MaybeUninitSubst <A >>:: Output , <D as BroadcastShape <E >>:: Output >:: maybe_uninit( shape. into_pattern( ) ) ;
122
- let output_view = out. raw_view_mut( ) . cast:: <A >( ) ;
123
- Zip :: from( & lhs) . and( & rhs)
124
- . and( output_view)
125
- . collect_with_partial( |x, y| {
126
- x. clone( ) $operator y. clone( )
127
- } )
128
- . release_ownership( ) ;
129
- out. assume_init( )
130
- }
114
+ let rhs = rhs. broadcast( shape) . unwrap( ) ;
115
+ Zip :: from( & lhs) . and( & rhs) . map_collect_owned( clone_opf( A :: $mth) )
131
116
}
132
117
}
133
118
}
@@ -148,38 +133,24 @@ where
148
133
A : Clone + $trt<B , Output =B >,
149
134
B : Clone ,
150
135
S : Data <Elem =A >,
151
- S2 : DataOwned <Elem =B > + DataMut + MaybeUninitSubst < B > ,
152
- < S2 as MaybeUninitSubst < B >> :: Output : DataMut ,
136
+ S2 : DataOwned <Elem =B > + DataMut ,
137
+ S2 :: MaybeUninit : DataMut ,
153
138
D : Dimension ,
154
139
E : Dimension + BroadcastShape <D >,
155
140
{
156
141
type Output = ArrayBase <S2 , <E as BroadcastShape <D >>:: Output >;
157
142
fn $mth( self , rhs: ArrayBase <S2 , E >) -> Self :: Output
158
143
where
159
144
{
160
- let shape = rhs. dim. broadcast_shape( & self . dim) . unwrap( ) ;
161
- if shape. slice( ) == rhs. dim. slice( ) {
145
+ if self . ndim( ) == rhs. ndim( ) && self . shape( ) == rhs. shape( ) {
162
146
let mut out = rhs. into_dimensionality:: <<E as BroadcastShape <D >>:: Output >( ) . unwrap( ) ;
163
- out. zip_mut_with( self , |x, y| {
164
- * x = y. clone( ) $operator x. clone( ) ;
165
- } ) ;
147
+ out. zip_mut_with_same_shape( self , clone_iopf_rev( A :: $mth) ) ;
166
148
out
167
149
} else {
150
+ let shape = rhs. dim. broadcast_shape( & self . dim) . unwrap( ) ;
168
151
let lhs = self . broadcast( shape. clone( ) ) . unwrap( ) ;
169
- let rhs = rhs. broadcast( shape. clone( ) ) . unwrap( ) ;
170
- // SAFETY: Overwrite all the elements in the array after
171
- // it is created via `raw_view_mut`.
172
- unsafe {
173
- let mut out =ArrayBase :: <<S2 as MaybeUninitSubst <B >>:: Output , <E as BroadcastShape <D >>:: Output >:: maybe_uninit( shape. into_pattern( ) ) ;
174
- let output_view = out. raw_view_mut( ) . cast:: <B >( ) ;
175
- Zip :: from( & lhs) . and( & rhs)
176
- . and( output_view)
177
- . collect_with_partial( |x, y| {
178
- x. clone( ) $operator y. clone( )
179
- } )
180
- . release_ownership( ) ;
181
- out. assume_init( )
182
- }
152
+ let rhs = rhs. broadcast( shape) . unwrap( ) ;
153
+ Zip :: from( & lhs) . and( & rhs) . map_collect_owned( clone_opf( A :: $mth) )
183
154
}
184
155
}
185
156
}
@@ -207,8 +178,7 @@ where
207
178
let shape = self . dim. broadcast_shape( & rhs. dim) . unwrap( ) ;
208
179
let lhs = self . broadcast( shape. clone( ) ) . unwrap( ) ;
209
180
let rhs = rhs. broadcast( shape) . unwrap( ) ;
210
- let out = Zip :: from( & lhs) . and( & rhs) . map_collect( |x, y| x. clone( ) $operator y. clone( ) ) ;
211
- out
181
+ Zip :: from( & lhs) . and( & rhs) . map_collect( clone_opf( A :: $mth) )
212
182
}
213
183
}
214
184
@@ -313,6 +283,18 @@ mod arithmetic_ops {
313
283
use num_complex:: Complex ;
314
284
use std:: ops:: * ;
315
285
286
+ fn clone_opf < A : Clone , B : Clone , C > ( f : impl Fn ( A , B ) -> C ) -> impl FnMut ( & A , & B ) -> C {
287
+ move |x, y| f ( x. clone ( ) , y. clone ( ) )
288
+ }
289
+
290
+ fn clone_iopf < A : Clone , B : Clone > ( f : impl Fn ( A , B ) -> A ) -> impl FnMut ( & mut A , & B ) {
291
+ move |x, y| * x = f ( x. clone ( ) , y. clone ( ) )
292
+ }
293
+
294
+ fn clone_iopf_rev < A : Clone , B : Clone > ( f : impl Fn ( A , B ) -> B ) -> impl FnMut ( & mut B , & A ) {
295
+ move |x, y| * x = f ( y. clone ( ) , x. clone ( ) )
296
+ }
297
+
316
298
impl_binary_op ! ( Add , +, add, +=, "addition" ) ;
317
299
impl_binary_op ! ( Sub , -, sub, -=, "subtraction" ) ;
318
300
impl_binary_op ! ( Mul , * , mul, *=, "multiplication" ) ;
0 commit comments