@@ -38,6 +38,7 @@ pub struct RowMajorMatrix<T: Sized + Sync + Clone + Send + Copy> {
3838 inner : p3:: matrix:: dense:: RowMajorMatrix < T > ,
3939 // num_row is the real instance BEFORE padding
4040 num_rows : usize ,
41+ log2_num_rotation : usize ,
4142 is_padded : bool ,
4243 padding_strategy : InstancePaddingStrategy ,
4344}
@@ -53,13 +54,15 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
5354 inner : p3:: matrix:: dense:: RowMajorMatrix :: rand ( rng, num_row_padded, cols) ,
5455 num_rows : rows,
5556 is_padded : true ,
57+ log2_num_rotation : 0 ,
5658 padding_strategy : InstancePaddingStrategy :: Default ,
5759 }
5860 }
5961 pub fn empty ( ) -> Self {
6062 Self {
6163 inner : p3:: matrix:: dense:: RowMajorMatrix :: new ( vec ! [ ] , 0 ) ,
6264 num_rows : 0 ,
65+ log2_num_rotation : 0 ,
6366 is_padded : true ,
6467 padding_strategy : InstancePaddingStrategy :: Default ,
6568 }
@@ -72,7 +75,8 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
7275 mut self ,
7376 blowup_factor : Option < usize > ,
7477 ) -> p3:: matrix:: dense:: RowMajorMatrix < T > {
75- let padded_height = next_pow2_instance_padding ( self . num_instances ( ) ) ;
78+ let padded_height = next_pow2_instance_padding ( self . num_instances ( ) )
79+ * Self :: num_rotation ( self . log2_num_rotation ) ;
7680 if let Some ( blowup_factor) = blowup_factor {
7781 if blowup_factor != CAPACITY_RESERVED_FACTOR {
7882 tracing:: warn!(
@@ -98,7 +102,21 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
98102 num_cols : usize ,
99103 padding_strategy : InstancePaddingStrategy ,
100104 ) -> Self {
101- let num_row_padded = next_pow2_instance_padding ( num_rows) ;
105+ Self :: new_by_rotation ( num_rows, 0 , num_cols, padding_strategy)
106+ }
107+
108+ /// rotation controls how many physical rows each logical row spans.
109+ ///
110+ /// for example, if `log2_num_rotation = 2`, then each logical row
111+ /// is expanded into `2^2 = 4` physical rows.
112+ pub fn new_by_rotation (
113+ num_rows : usize ,
114+ log2_num_rotation : usize ,
115+ num_cols : usize ,
116+ padding_strategy : InstancePaddingStrategy ,
117+ ) -> Self {
118+ let num_row_padded =
119+ next_pow2_instance_padding ( num_rows) * Self :: num_rotation ( log2_num_rotation) ;
102120
103121 let mut value = Vec :: with_capacity ( CAPACITY_RESERVED_FACTOR * num_row_padded * num_cols) ;
104122 value. par_extend (
@@ -109,6 +127,7 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
109127 RowMajorMatrix {
110128 inner : p3:: matrix:: dense:: RowMajorMatrix :: new ( value, num_cols) ,
111129 num_rows,
130+ log2_num_rotation,
112131 is_padded : matches ! ( padding_strategy, InstancePaddingStrategy :: Default ) ,
113132 padding_strategy,
114133 }
@@ -126,6 +145,7 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
126145 RowMajorMatrix {
127146 inner : m,
128147 num_rows,
148+ log2_num_rotation : 0 ,
129149 is_padded : matches ! ( padding_strategy, InstancePaddingStrategy :: Default ) ,
130150 padding_strategy,
131151 }
@@ -146,32 +166,42 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
146166 next_pow2_instance_padding ( self . num_instances ( ) ) - self . num_instances ( )
147167 }
148168
169+ // return raw num_instances without rotation
149170 pub fn num_instances ( & self ) -> usize {
150171 self . num_rows
151172 }
152173
174+ fn num_rotation ( log2_num_rotation : usize ) -> usize {
175+ 1 << log2_num_rotation
176+ }
177+
153178 pub fn iter_rows ( & self ) -> Chunks < ' _ , T > {
154- self . inner . values [ ..self . num_instances ( ) * self . n_col ( ) ] . chunks ( self . inner . width )
179+ let num_rotation = Self :: num_rotation ( self . log2_num_rotation ) ;
180+ self . inner . values [ ..self . num_instances ( ) * num_rotation * self . n_col ( ) ]
181+ . chunks ( num_rotation * self . inner . width )
155182 }
156183
157184 pub fn iter_mut ( & mut self ) -> ChunksMut < ' _ , T > {
158- let max_range = self . num_instances ( ) * self . n_col ( ) ;
159- self . inner . values [ ..max_range] . chunks_mut ( self . inner . width )
185+ let num_rotation = Self :: num_rotation ( self . log2_num_rotation ) ;
186+ let max_range = self . num_instances ( ) * num_rotation * self . n_col ( ) ;
187+ self . inner . values [ ..max_range] . chunks_mut ( num_rotation * self . inner . width )
160188 }
161189
162190 pub fn par_batch_iter_mut ( & mut self , num_rows : usize ) -> rayon:: slice:: ChunksMut < ' _ , T > {
163- let max_range = self . num_instances ( ) * self . n_col ( ) ;
164- self . inner . values [ ..max_range] . par_chunks_mut ( num_rows * self . inner . width )
191+ let num_rotation = Self :: num_rotation ( self . log2_num_rotation ) ;
192+ let max_range = self . num_instances ( ) * self . n_col ( ) * num_rotation;
193+ self . inner . values [ ..max_range] . par_chunks_mut ( num_rows * num_rotation * self . inner . width )
165194 }
166195
167196 pub fn padding_by_strategy ( & mut self ) {
168- let start_index = self . num_instances ( ) * self . n_col ( ) ;
197+ let num_rotation = Self :: num_rotation ( self . log2_num_rotation ) ;
198+ let start_index = self . num_instances ( ) * num_rotation * self . n_col ( ) ;
169199
170200 match & self . padding_strategy {
171201 InstancePaddingStrategy :: Default => ( ) ,
172202 InstancePaddingStrategy :: Custom ( fun) => {
173203 self . inner . values [ start_index..]
174- . par_chunks_mut ( self . inner . width )
204+ . par_chunks_mut ( num_rotation * self . inner . width )
175205 . enumerate ( )
176206 . for_each ( |( i, instance) | {
177207 instance. iter_mut ( ) . enumerate ( ) . for_each ( |( j, v) | {
@@ -277,8 +307,9 @@ impl<F: Sync + Send + Copy + FieldAlgebra> Index<usize> for RowMajorMatrix<F> {
277307 type Output = [ F ] ;
278308
279309 fn index ( & self , idx : usize ) -> & Self :: Output {
310+ let num_rotation = Self :: num_rotation ( self . log2_num_rotation ) ;
280311 let num_col = self . n_col ( ) ;
281- & self . inner . values [ num_col * idx..] [ ..num_col]
312+ & self . inner . values [ num_rotation * num_col * idx..] [ ..num_rotation * num_col]
282313 }
283314}
284315
0 commit comments