@@ -38,6 +38,7 @@ pub struct RowMajorMatrix<T: Sized + Sync + Clone + Send + Copy> {
3838 inner : p3:: matrix:: dense:: RowMajorMatrix < T > ,
3939 // num_row is the real instance BEFORE padding
4040 num_rows : usize ,
41+ log2_num_rotation : usize ,
4142 is_padded : bool ,
4243 padding_strategy : InstancePaddingStrategy ,
4344}
@@ -53,13 +54,15 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
5354 inner : p3:: matrix:: dense:: RowMajorMatrix :: rand ( rng, num_row_padded, cols) ,
5455 num_rows : rows,
5556 is_padded : true ,
57+ log2_num_rotation : 0 ,
5658 padding_strategy : InstancePaddingStrategy :: Default ,
5759 }
5860 }
5961 pub fn empty ( ) -> Self {
6062 Self {
6163 inner : p3:: matrix:: dense:: RowMajorMatrix :: new ( vec ! [ ] , 0 ) ,
6264 num_rows : 0 ,
65+ log2_num_rotation : 0 ,
6366 is_padded : true ,
6467 padding_strategy : InstancePaddingStrategy :: Default ,
6568 }
@@ -72,7 +75,7 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
7275 mut self ,
7376 blowup_factor : Option < usize > ,
7477 ) -> p3:: matrix:: dense:: RowMajorMatrix < T > {
75- let padded_height = next_pow2_instance_padding ( self . num_instances ( ) ) ;
78+ let padded_height = next_pow2_instance_padding ( self . num_instances ( ) ) * Self :: num_rotation ( self . log2_num_rotation ) ;
7679 if let Some ( blowup_factor) = blowup_factor {
7780 if blowup_factor != CAPACITY_RESERVED_FACTOR {
7881 tracing:: warn!(
@@ -98,7 +101,20 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
98101 num_cols : usize ,
99102 padding_strategy : InstancePaddingStrategy ,
100103 ) -> Self {
101- let num_row_padded = next_pow2_instance_padding ( num_rows) ;
104+ Self :: new_by_rotation ( num_rows, 0 , num_cols, padding_strategy)
105+ }
106+
107+ /// rotation controls how many physical rows each logical row spans.
108+ ///
109+ /// for example, if `log2_num_rotation = 2`, then each logical row
110+ /// is expanded into `2^2 = 4` physical rows.
111+ pub fn new_by_rotation (
112+ num_rows : usize ,
113+ log2_num_rotation : usize ,
114+ num_cols : usize ,
115+ padding_strategy : InstancePaddingStrategy ,
116+ ) -> Self {
117+ let num_row_padded = next_pow2_instance_padding ( num_rows) * Self :: num_rotation ( log2_num_rotation) ;
102118
103119 let mut value = Vec :: with_capacity ( CAPACITY_RESERVED_FACTOR * num_row_padded * num_cols) ;
104120 value. par_extend (
@@ -109,6 +125,7 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
109125 RowMajorMatrix {
110126 inner : p3:: matrix:: dense:: RowMajorMatrix :: new ( value, num_cols) ,
111127 num_rows,
128+ log2_num_rotation,
112129 is_padded : matches ! ( padding_strategy, InstancePaddingStrategy :: Default ) ,
113130 padding_strategy,
114131 }
@@ -126,6 +143,7 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
126143 RowMajorMatrix {
127144 inner : m,
128145 num_rows,
146+ log2_num_rotation : 0 ,
129147 is_padded : matches ! ( padding_strategy, InstancePaddingStrategy :: Default ) ,
130148 padding_strategy,
131149 }
@@ -146,32 +164,41 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
146164 next_pow2_instance_padding ( self . num_instances ( ) ) - self . num_instances ( )
147165 }
148166
167+ // return raw num_instances without rotation
149168 pub fn num_instances ( & self ) -> usize {
150169 self . num_rows
151170 }
152171
172+ fn num_rotation ( log2_num_rotation : usize ) -> usize {
173+ 1 << log2_num_rotation
174+ }
175+
153176 pub fn iter_rows ( & self ) -> Chunks < ' _ , T > {
154- self . inner . values [ ..self . num_instances ( ) * self . n_col ( ) ] . chunks ( self . inner . width )
177+ let num_rotation = Self :: num_rotation ( self . log2_num_rotation ) ;
178+ self . inner . values [ ..self . num_instances ( ) * num_rotation * self . n_col ( ) ] . chunks ( num_rotation * self . inner . width )
155179 }
156180
157181 pub fn iter_mut ( & mut self ) -> ChunksMut < ' _ , T > {
158- let max_range = self . num_instances ( ) * self . n_col ( ) ;
159- self . inner . values [ ..max_range] . chunks_mut ( self . inner . width )
182+ let num_rotation = Self :: num_rotation ( self . log2_num_rotation ) ;
183+ let max_range = self . num_instances ( ) * num_rotation * self . n_col ( ) ;
184+ self . inner . values [ ..max_range] . chunks_mut ( num_rotation * self . inner . width )
160185 }
161186
162187 pub fn par_batch_iter_mut ( & mut self , num_rows : usize ) -> rayon:: slice:: ChunksMut < ' _ , T > {
163- let max_range = self . num_instances ( ) * self . n_col ( ) ;
164- self . inner . values [ ..max_range] . par_chunks_mut ( num_rows * self . inner . width )
188+ let num_rotation = Self :: num_rotation ( self . log2_num_rotation ) ;
189+ let max_range = self . num_instances ( ) * self . n_col ( ) * num_rotation;
190+ self . inner . values [ ..max_range] . par_chunks_mut ( num_rows * num_rotation * self . inner . width )
165191 }
166192
167193 pub fn padding_by_strategy ( & mut self ) {
168- let start_index = self . num_instances ( ) * self . n_col ( ) ;
194+ let num_rotation = Self :: num_rotation ( self . log2_num_rotation ) ;
195+ let start_index = self . num_instances ( ) * num_rotation * self . n_col ( ) ;
169196
170197 match & self . padding_strategy {
171198 InstancePaddingStrategy :: Default => ( ) ,
172199 InstancePaddingStrategy :: Custom ( fun) => {
173200 self . inner . values [ start_index..]
174- . par_chunks_mut ( self . inner . width )
201+ . par_chunks_mut ( num_rotation * self . inner . width )
175202 . enumerate ( )
176203 . for_each ( |( i, instance) | {
177204 instance. iter_mut ( ) . enumerate ( ) . for_each ( |( j, v) | {
@@ -277,8 +304,9 @@ impl<F: Sync + Send + Copy + FieldAlgebra> Index<usize> for RowMajorMatrix<F> {
277304 type Output = [ F ] ;
278305
279306 fn index ( & self , idx : usize ) -> & Self :: Output {
307+ let num_rotation = Self :: num_rotation ( self . log2_num_rotation ) ;
280308 let num_col = self . n_col ( ) ;
281- & self . inner . values [ num_col * idx..] [ ..num_col]
309+ & self . inner . values [ num_rotation * num_col * idx..] [ ..num_rotation * num_col]
282310 }
283311}
284312
0 commit comments