Skip to content

Commit 1c4fd07

Browse files
separate zipped_rw macro
1 parent 9a0f0a5 commit 1c4fd07

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+410
-460
lines changed

src/col/colmut.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::{
44
iter,
55
iter::chunks::ChunkPolicy,
66
row::{RowMut, RowRef},
7-
unzipped, zipped, Idx, IdxInc, Unbind,
7+
unzipped, zipped_rw, Idx, IdxInc, Unbind,
88
};
99

1010
/// Mutable view over a column vector, similar to a mutable reference to a strided [prim@slice].
@@ -542,7 +542,7 @@ impl<'a, E: Entity, R: Shape> ColMut<'a, E, R> {
542542
this: ColMut<'_, E, R>,
543543
other: ColRef<'_, ViewE, R>,
544544
) {
545-
zipped!(__rw, this, other)
545+
zipped_rw!(this, other)
546546
.for_each(|unzipped!(mut dst, src)| dst.write(src.read().canonicalize()));
547547
}
548548
implementation(self.rb_mut(), other.as_col_ref())
@@ -554,7 +554,7 @@ impl<'a, E: Entity, R: Shape> ColMut<'a, E, R> {
554554
where
555555
E: ComplexField,
556556
{
557-
zipped!(__rw, self.rb_mut()).for_each(
557+
zipped_rw!(self.rb_mut()).for_each(
558558
#[inline(always)]
559559
|unzipped!(mut x)| x.write(E::faer_zero()),
560560
);
@@ -563,7 +563,7 @@ impl<'a, E: Entity, R: Shape> ColMut<'a, E, R> {
563563
/// Fills the elements of `self` with copies of `constant`.
564564
#[track_caller]
565565
pub fn fill(&mut self, constant: E) {
566-
zipped!(__rw, (*self).rb_mut()).for_each(
566+
zipped_rw!((*self).rb_mut()).for_each(
567567
#[inline(always)]
568568
|unzipped!(mut x)| x.write(constant),
569569
);

src/col/colown.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,7 @@ impl<E: Entity, R: Shape> Clone for Col<E, R> {
11631163
}
11641164
fn clone_from(&mut self, other: &Self) {
11651165
if self.nrows() == other.nrows() {
1166-
crate::zipped!(__rw, self, other)
1166+
crate::zipped_rw!(self, other)
11671167
.for_each(|crate::unzipped!(mut dst, src)| dst.write(src.read()));
11681168
} else {
11691169
if !R::IS_BOUND {

src/lib.rs

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ impl Conj {
364364
///
365365
/// # Example
366366
/// ```
367-
/// use faer::{mat, unzipped, zipped, Mat};
367+
/// use faer::{mat, unzipped, zipped_rw, Mat};
368368
///
369369
/// let nrows = 2;
370370
/// let ncols = 3;
@@ -385,28 +385,34 @@ impl Conj {
385385
/// ```
386386
#[macro_export]
387387
macro_rules! zipped {
388-
(__rw, $head: expr $(,)?) => {
389-
$crate::linalg::zip::LastEq($crate::linalg::zip::ViewMut::view_mut(&mut { $head }))
388+
($head: expr $(,)?) => {
389+
$crate::linalg::zip::LastEq($crate::linalg::zip::RefWrapper($crate::linalg::zip::ViewMut::view_mut(&mut { $head })))
390390
};
391391

392-
(__rw, $head: expr, $($tail: expr),* $(,)?) => {
393-
$crate::linalg::zip::ZipEq::new($crate::linalg::zip::ViewMut::view_mut(&mut { $head }), $crate::zipped!(__rw, $($tail,)*))
392+
($head: expr, $($tail: expr),* $(,)?) => {
393+
$crate::linalg::zip::ZipEq::new($crate::linalg::zip::RefWrapper($crate::linalg::zip::ViewMut::view_mut(&mut { $head })), $crate::zipped!( $($tail,)*))
394394
};
395395

396+
}
397+
398+
/// Like the [`zipped!`] macro, but is compatible with potentially uninit values by not forming
399+
/// references.
400+
#[macro_export]
401+
macro_rules! zipped_rw {
396402
($head: expr $(,)?) => {
397-
$crate::linalg::zip::LastEq($crate::linalg::zip::RefWrapper($crate::linalg::zip::ViewMut::view_mut(&mut { $head })))
403+
$crate::linalg::zip::LastEq($crate::linalg::zip::ViewMut::view_mut(&mut { $head }))
398404
};
399405

400406
($head: expr, $($tail: expr),* $(,)?) => {
401-
$crate::linalg::zip::ZipEq::new($crate::linalg::zip::RefWrapper($crate::linalg::zip::ViewMut::view_mut(&mut { $head })), $crate::zipped!($($tail,)*))
407+
$crate::linalg::zip::ZipEq::new($crate::linalg::zip::ViewMut::view_mut(&mut { $head }), $crate::zipped_rw!($($tail,)*))
402408
};
403409
}
404410

405-
/// Used to undo the zipping by the [`zipped!`] macro.
411+
/// Used to undo the zipping by the [`zipped_rw!`] macro.
406412
///
407413
/// # Example
408414
/// ```
409-
/// use faer::{mat, unzipped, zipped, Mat};
415+
/// use faer::{mat, unzipped, zipped_rw, Mat};
410416
///
411417
/// let nrows = 2;
412418
/// let ncols = 3;
@@ -415,7 +421,7 @@ macro_rules! zipped {
415421
/// let b = mat![[7.0, 9.0, 11.0], [8.0, 10.0, 12.0]];
416422
/// let mut sum = Mat::<f64>::zeros(nrows, ncols);
417423
///
418-
/// zipped!(sum.as_mut(), a.as_ref(), b.as_ref()).for_each(|unzipped!(mut sum, a, b)| {
424+
/// zipped_rw!(sum.as_mut(), a.as_ref(), b.as_ref()).for_each(|unzipped!(mut sum, a, b)| {
419425
/// *sum = a + b;
420426
/// });
421427
///
@@ -985,7 +991,8 @@ pub mod prelude {
985991
pub use crate::{
986992
col,
987993
complex_native::{c32, c64},
988-
mat, row, unzipped, zipped, Col, ColMut, ColRef, Mat, MatMut, MatRef, Row, RowMut, RowRef,
994+
mat, row, unzipped, zipped_rw, Col, ColMut, ColRef, Mat, MatMut, MatRef, Row, RowMut,
995+
RowRef,
989996
};
990997

991998
pub use crate::linalg::solvers::{

src/linalg/cholesky/bunch_kaufman/mod.rs

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::{
1313
},
1414
},
1515
perm::{permute_rows, swap_cols_idx as swap_cols, swap_rows_idx as swap_rows, PermRef},
16-
unzipped, zipped, ColMut, ColRef, Conj, Index, MatMut, MatRef, Parallelism, SignedIndex,
16+
unzipped, zipped_rw, ColMut, ColRef, Conj, Index, MatMut, MatRef, Parallelism, SignedIndex,
1717
};
1818
use dyn_stack::{PodStack, SizeOverflow, StackReq};
1919
use faer_entity::{ComplexField, Entity, RealField};
@@ -235,8 +235,7 @@ pub mod compute {
235235
if abs_akk >= colmax.faer_mul(alpha) {
236236
kp = k;
237237
} else {
238-
zipped!(
239-
__rw,
238+
zipped_rw!(
240239
w.rb_mut().subrows_mut(k, imax - k).col_mut(k + 1),
241240
a.rb().row(imax).subcols(k, imax - k).transpose(),
242241
)
@@ -326,9 +325,9 @@ pub mod compute {
326325
let d11 = d11.faer_inv();
327326

328327
let x = a.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k);
329-
zipped!(__rw, x)
328+
zipped_rw!(x)
330329
.for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d11)));
331-
zipped!(__rw, w.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k))
330+
zipped_rw!(w.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k))
332331
.for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
333332
} else {
334333
let dd = w.read(k + 1, k).faer_abs();
@@ -406,13 +405,10 @@ pub mod compute {
406405
a.write(j, k + 1, wkp1);
407406
}
408407

409-
zipped!(__rw, w.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k))
408+
zipped_rw!(w.rb_mut().subrows_mut(k + 1, n - k - 1).col_mut(k))
409+
.for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
410+
zipped_rw!(w.rb_mut().subrows_mut(k + 2, n - k - 2).col_mut(k + 1))
410411
.for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
411-
zipped!(
412-
__rw,
413-
w.rb_mut().subrows_mut(k + 2, n - k - 2).col_mut(k + 1)
414-
)
415-
.for_each(|unzipped!(mut x)| x.write(x.read().faer_conj()));
416412
}
417413
}
418414

@@ -439,7 +435,7 @@ pub mod compute {
439435
parallelism,
440436
);
441437

442-
zipped!(__rw, a_right.diagonal_mut().column_vector_mut())
438+
zipped_rw!(a_right.diagonal_mut().column_vector_mut())
443439
.for_each(|unzipped!(mut x)| x.write(E::faer_from_real(x.read().faer_real())));
444440

445441
let mut j = k - 1;
@@ -596,7 +592,7 @@ pub mod compute {
596592
}
597593
make_real(trailing.rb_mut(), j, j);
598594
}
599-
zipped!(__rw, x)
595+
zipped_rw!(x)
600596
.for_each(|unzipped!(mut x)| x.write(x.read().faer_scale_real(d11)));
601597
} else {
602598
let d21 = a.read(k + 1, k).faer_abs();
@@ -1042,7 +1038,7 @@ mod tests {
10421038

10431039
let err = &a * &x - &rhs;
10441040
let mut max = 0.0;
1045-
zipped!(__rw, err.as_ref()).for_each(|unzipped!(err)| {
1041+
zipped_rw!(err.as_ref()).for_each(|unzipped!(err)| {
10461042
let err = err.read().abs();
10471043
if err > max {
10481044
max = err
@@ -1099,7 +1095,7 @@ mod tests {
10991095

11001096
let err = a.conjugate() * &x - &rhs;
11011097
let mut max = 0.0;
1102-
zipped!(__rw, err.as_ref()).for_each(|unzipped!(err)| {
1098+
zipped_rw!(err.as_ref()).for_each(|unzipped!(err)| {
11031099
let err = err.read().abs();
11041100
if err > max {
11051101
max = err

src/linalg/cholesky/ldlt_diagonal/compute.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::{
1010
},
1111
unzipped,
1212
utils::{simd::*, slice::*, DivCeil},
13-
zipped, ComplexField, MatMut, MatRef, Parallelism,
13+
zipped_rw, ComplexField, MatMut, MatRef, Parallelism,
1414
};
1515
use core::{convert::Infallible, marker::PhantomData};
1616
use dyn_stack::{PodStack, SizeOverflow, StackReq};
@@ -1426,7 +1426,7 @@ fn cholesky_in_place_impl<E: ComplexField>(
14261426
let a10_col = a10.rb_mut().col_mut(j);
14271427
let d0_elem = d0.read(j).faer_real().faer_inv();
14281428

1429-
zipped!(__rw, l10xd0_col, a10_col).for_each(
1429+
zipped_rw!(l10xd0_col, a10_col).for_each(
14301430
|unzipped!(mut l10xd0_elem, mut a10_elem)| {
14311431
let a10_elem_read = a10_elem.read();
14321432
a10_elem.write(a10_elem_read.faer_scale_real(d0_elem));

src/linalg/cholesky/ldlt_diagonal/solve.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::{
2-
assert, linalg::triangular_solve as solve, unzipped, zipped, ComplexField, Conj, Entity,
2+
assert, linalg::triangular_solve as solve, unzipped, zipped_rw, ComplexField, Conj, Entity,
33
MatMut, MatRef, Parallelism,
44
};
55
use dyn_stack::{PodStack, SizeOverflow, StackReq};
@@ -184,7 +184,7 @@ pub fn solve_transpose_with_conj<E: ComplexField>(
184184
stack: &mut PodStack,
185185
) {
186186
let mut dst = dst;
187-
zipped!(__rw, dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
187+
zipped_rw!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
188188
solve_transpose_in_place_with_conj(cholesky_factors, conj_lhs, dst, parallelism, stack)
189189
}
190190

@@ -216,6 +216,6 @@ pub fn solve_with_conj<E: ComplexField>(
216216
stack: &mut PodStack,
217217
) {
218218
let mut dst = dst;
219-
zipped!(__rw, dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
219+
zipped_rw!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
220220
solve_in_place_with_conj(cholesky_factors, conj_lhs, dst, parallelism, stack)
221221
}

src/linalg/cholesky/ldlt_diagonal/update.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::{
77
},
88
unzipped,
99
utils::{simd::*, slice::*},
10-
zipped, ColMut, MatMut, Parallelism,
10+
zipped_rw, ColMut, MatMut, Parallelism,
1111
};
1212
use core::iter::zip;
1313
use dyn_stack::{PodStack, SizeOverflow, StackReq};
@@ -422,7 +422,7 @@ fn rank_update_step_impl4<E: ComplexField>(
422422
let [p0, p1, p2, p3] = p_array;
423423
let [beta0, beta1, beta2, beta3] = beta_array;
424424

425-
zipped!(__rw, l_col, w0, w1, w2, w3).for_each(
425+
zipped_rw!(l_col, w0, w1, w2, w3).for_each(
426426
|unzipped!(mut l, mut w0, mut w1, mut w2, mut w3)| {
427427
let mut local_l = l.read();
428428
let mut local_w0 = w0.read();
@@ -482,7 +482,7 @@ fn rank_update_step_impl3<E: ComplexField>(
482482
let [p0, p1, p2] = p_array;
483483
let [beta0, beta1, beta2] = beta_array;
484484

485-
zipped!(__rw, l_col, w0, w1, w2).for_each(|unzipped!(mut l, mut w0, mut w1, mut w2)| {
485+
zipped_rw!(l_col, w0, w1, w2).for_each(|unzipped!(mut l, mut w0, mut w1, mut w2)| {
486486
let mut local_l = l.read();
487487
let mut local_w0 = w0.read();
488488
let mut local_w1 = w1.read();
@@ -532,7 +532,7 @@ fn rank_update_step_impl2<E: ComplexField>(
532532
let [p0, p1] = p_array;
533533
let [beta0, beta1] = beta_array;
534534

535-
zipped!(__rw, l_col, w0, w1).for_each(|unzipped!(mut l, mut w0, mut w1)| {
535+
zipped_rw!(l_col, w0, w1).for_each(|unzipped!(mut l, mut w0, mut w1)| {
536536
let mut local_l = l.read();
537537
let mut local_w0 = w0.read();
538538
let mut local_w1 = w1.read();
@@ -574,7 +574,7 @@ fn rank_update_step_impl1<E: ComplexField>(
574574
let [p0] = p_array;
575575
let [beta0] = beta_array;
576576

577-
zipped!(__rw, l_col, w0).for_each(|unzipped!(mut l, mut w0)| {
577+
zipped_rw!(l_col, w0).for_each(|unzipped!(mut l, mut w0)| {
578578
let mut local_l = l.read();
579579
let mut local_w0 = w0.read();
580580

src/linalg/cholesky/llt/reconstruct.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::{
55
temp_mat_req, temp_mat_uninit,
66
zip::Diag,
77
},
8-
unzipped, zipped, ComplexField, Entity, MatMut, MatRef, Parallelism,
8+
unzipped, zipped_rw, ComplexField, Entity, MatMut, MatRef, Parallelism,
99
};
1010
use dyn_stack::{PodStack, SizeOverflow, StackReq};
1111
use reborrow::*;
@@ -78,7 +78,7 @@ pub fn reconstruct_lower_in_place<E: ComplexField>(
7878
let (mut tmp, stack) = temp_mat_uninit::<E>(n, n, stack);
7979
let mut tmp = tmp.as_mut();
8080
reconstruct_lower(tmp.rb_mut(), cholesky_factor.rb(), parallelism, stack);
81-
zipped!(__rw, cholesky_factor, tmp.rb())
81+
zipped_rw!(cholesky_factor, tmp.rb())
8282
.for_each_triangular_lower(Diag::Include, |unzipped!(mut dst, src)| {
8383
dst.write(src.read())
8484
});

src/linalg/cholesky/llt/solve.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::{
2-
assert, linalg::triangular_solve as solve, unzipped, zipped, ComplexField, Conj, Entity,
2+
assert, linalg::triangular_solve as solve, unzipped, zipped_rw, ComplexField, Conj, Entity,
33
MatMut, MatRef, Parallelism,
44
};
55
use dyn_stack::{PodStack, SizeOverflow, StackReq};
@@ -135,7 +135,7 @@ pub fn solve_with_conj<E: ComplexField>(
135135
stack: &mut PodStack,
136136
) {
137137
let mut dst = dst;
138-
zipped!(__rw, dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
138+
zipped_rw!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
139139
solve_in_place_with_conj(cholesky_factor, conj_lhs, dst, parallelism, stack)
140140
}
141141

@@ -202,6 +202,6 @@ pub fn solve_transpose_with_conj<E: ComplexField>(
202202
stack: &mut PodStack,
203203
) {
204204
let mut dst = dst;
205-
zipped!(__rw, dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
205+
zipped_rw!(dst.rb_mut(), rhs).for_each(|unzipped!(mut dst, src)| dst.write(src.read()));
206206
solve_transpose_in_place_with_conj(cholesky_factor, conj_lhs, dst, parallelism, stack)
207207
}

src/linalg/cholesky/llt/update.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::{
1212
},
1313
unzipped,
1414
utils::{simd::*, slice::*},
15-
zipped, ColMut, MatMut, Parallelism,
15+
zipped_rw, ColMut, MatMut, Parallelism,
1616
};
1717
use core::iter::zip;
1818
use dyn_stack::{PodStack, SizeOverflow, StackReq};
@@ -533,7 +533,7 @@ fn rank_update_step_impl4<E: ComplexField>(
533533
let [alpha_wj_over_nljj0, alpha_wj_over_nljj1, alpha_wj_over_nljj2, alpha_wj_over_nljj3] =
534534
alpha_wj_over_nljj_array;
535535

536-
zipped!(__rw, l_col, w0, w1, w2, w3,).for_each(
536+
zipped_rw!(l_col, w0, w1, w2, w3,).for_each(
537537
|unzipped!(mut l, mut w0, mut w1, mut w2, mut w3)| {
538538
let mut local_l = l.read();
539539
let mut local_w0 = w0.read();
@@ -607,7 +607,7 @@ fn rank_update_step_impl3<E: ComplexField>(
607607
let [alpha_wj_over_nljj0, alpha_wj_over_nljj1, alpha_wj_over_nljj2] =
608608
alpha_wj_over_nljj_array;
609609

610-
zipped!(__rw, l_col, w0, w1, w2).for_each(|unzipped!(mut l, mut w0, mut w1, mut w2)| {
610+
zipped_rw!(l_col, w0, w1, w2).for_each(|unzipped!(mut l, mut w0, mut w1, mut w2)| {
611611
let mut local_l = l.read();
612612
let mut local_w0 = w0.read();
613613
let mut local_w1 = w1.read();
@@ -668,7 +668,7 @@ fn rank_update_step_impl2<E: ComplexField>(
668668
let [nljj_over_ljj0, nljj_over_ljj1] = nljj_over_ljj_array;
669669
let [alpha_wj_over_nljj0, alpha_wj_over_nljj1] = alpha_wj_over_nljj_array;
670670

671-
zipped!(__rw, l_col, w0, w1).for_each(|unzipped!(mut l, mut w0, mut w1)| {
671+
zipped_rw!(l_col, w0, w1).for_each(|unzipped!(mut l, mut w0, mut w1)| {
672672
let mut local_l = l.read();
673673
let mut local_w0 = w0.read();
674674
let mut local_w1 = w1.read();
@@ -719,7 +719,7 @@ fn rank_update_step_impl1<E: ComplexField>(
719719
let [nljj_over_ljj0] = nljj_over_ljj_array;
720720
let [alpha_wj_over_nljj0] = alpha_wj_over_nljj_array;
721721

722-
zipped!(__rw, l_col, w0).for_each(|unzipped!(mut l, mut w0)| {
722+
zipped_rw!(l_col, w0).for_each(|unzipped!(mut l, mut w0)| {
723723
let mut local_l = l.read();
724724
let mut local_w0 = w0.read();
725725

0 commit comments

Comments
 (0)