Skip to content

Commit f7dc023

Browse files
committed
rotation support
1 parent 08462f5 commit f7dc023

File tree

1 file changed

+41
-10
lines changed

1 file changed

+41
-10
lines changed

crates/witness/src/lib.rs

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub struct RowMajorMatrix<T: Sized + Sync + Clone + Send + Copy> {
3838
inner: p3::matrix::dense::RowMajorMatrix<T>,
3939
// num_row is the real instance BEFORE padding
4040
num_rows: usize,
41+
log2_num_rotation: usize,
4142
is_padded: bool,
4243
padding_strategy: InstancePaddingStrategy,
4344
}
@@ -53,13 +54,15 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
5354
inner: p3::matrix::dense::RowMajorMatrix::rand(rng, num_row_padded, cols),
5455
num_rows: rows,
5556
is_padded: true,
57+
log2_num_rotation: 0,
5658
padding_strategy: InstancePaddingStrategy::Default,
5759
}
5860
}
5961
pub fn empty() -> Self {
6062
Self {
6163
inner: p3::matrix::dense::RowMajorMatrix::new(vec![], 0),
6264
num_rows: 0,
65+
log2_num_rotation: 0,
6366
is_padded: true,
6467
padding_strategy: InstancePaddingStrategy::Default,
6568
}
@@ -72,7 +75,8 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
7275
mut self,
7376
blowup_factor: Option<usize>,
7477
) -> p3::matrix::dense::RowMajorMatrix<T> {
75-
let padded_height = next_pow2_instance_padding(self.num_instances());
78+
let padded_height = next_pow2_instance_padding(self.num_instances())
79+
* Self::num_rotation(self.log2_num_rotation);
7680
if let Some(blowup_factor) = blowup_factor {
7781
if blowup_factor != CAPACITY_RESERVED_FACTOR {
7882
tracing::warn!(
@@ -98,7 +102,21 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
98102
num_cols: usize,
99103
padding_strategy: InstancePaddingStrategy,
100104
) -> Self {
101-
let num_row_padded = next_pow2_instance_padding(num_rows);
105+
Self::new_by_rotation(num_rows, 0, num_cols, padding_strategy)
106+
}
107+
108+
/// rotation controls how many physical rows each logical row spans.
109+
///
110+
/// for example, if `log2_num_rotation = 2`, then each logical row
111+
/// is expanded into `2^2 = 4` physical rows.
112+
pub fn new_by_rotation(
113+
num_rows: usize,
114+
log2_num_rotation: usize,
115+
num_cols: usize,
116+
padding_strategy: InstancePaddingStrategy,
117+
) -> Self {
118+
let num_row_padded =
119+
next_pow2_instance_padding(num_rows) * Self::num_rotation(log2_num_rotation);
102120

103121
let mut value = Vec::with_capacity(CAPACITY_RESERVED_FACTOR * num_row_padded * num_cols);
104122
value.par_extend(
@@ -109,6 +127,7 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
109127
RowMajorMatrix {
110128
inner: p3::matrix::dense::RowMajorMatrix::new(value, num_cols),
111129
num_rows,
130+
log2_num_rotation,
112131
is_padded: matches!(padding_strategy, InstancePaddingStrategy::Default),
113132
padding_strategy,
114133
}
@@ -126,6 +145,7 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
126145
RowMajorMatrix {
127146
inner: m,
128147
num_rows,
148+
log2_num_rotation: 0,
129149
is_padded: matches!(padding_strategy, InstancePaddingStrategy::Default),
130150
padding_strategy,
131151
}
@@ -146,32 +166,42 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
146166
next_pow2_instance_padding(self.num_instances()) - self.num_instances()
147167
}
148168

169+
// return raw num_instances without rotation
149170
pub fn num_instances(&self) -> usize {
150171
self.num_rows
151172
}
152173

174+
fn num_rotation(log2_num_rotation: usize) -> usize {
175+
1 << log2_num_rotation
176+
}
177+
153178
pub fn iter_rows(&self) -> Chunks<'_, T> {
154-
self.inner.values[..self.num_instances() * self.n_col()].chunks(self.inner.width)
179+
let num_rotation = Self::num_rotation(self.log2_num_rotation);
180+
self.inner.values[..self.num_instances() * num_rotation * self.n_col()]
181+
.chunks(num_rotation * self.inner.width)
155182
}
156183

157184
pub fn iter_mut(&mut self) -> ChunksMut<'_, T> {
158-
let max_range = self.num_instances() * self.n_col();
159-
self.inner.values[..max_range].chunks_mut(self.inner.width)
185+
let num_rotation = Self::num_rotation(self.log2_num_rotation);
186+
let max_range = self.num_instances() * num_rotation * self.n_col();
187+
self.inner.values[..max_range].chunks_mut(num_rotation * self.inner.width)
160188
}
161189

162190
pub fn par_batch_iter_mut(&mut self, num_rows: usize) -> rayon::slice::ChunksMut<'_, T> {
163-
let max_range = self.num_instances() * self.n_col();
164-
self.inner.values[..max_range].par_chunks_mut(num_rows * self.inner.width)
191+
let num_rotation = Self::num_rotation(self.log2_num_rotation);
192+
let max_range = self.num_instances() * self.n_col() * num_rotation;
193+
self.inner.values[..max_range].par_chunks_mut(num_rows * num_rotation * self.inner.width)
165194
}
166195

167196
pub fn padding_by_strategy(&mut self) {
168-
let start_index = self.num_instances() * self.n_col();
197+
let num_rotation = Self::num_rotation(self.log2_num_rotation);
198+
let start_index = self.num_instances() * num_rotation * self.n_col();
169199

170200
match &self.padding_strategy {
171201
InstancePaddingStrategy::Default => (),
172202
InstancePaddingStrategy::Custom(fun) => {
173203
self.inner.values[start_index..]
174-
.par_chunks_mut(self.inner.width)
204+
.par_chunks_mut(num_rotation * self.inner.width)
175205
.enumerate()
176206
.for_each(|(i, instance)| {
177207
instance.iter_mut().enumerate().for_each(|(j, v)| {
@@ -277,8 +307,9 @@ impl<F: Sync + Send + Copy + FieldAlgebra> Index<usize> for RowMajorMatrix<F> {
277307
type Output = [F];
278308

279309
fn index(&self, idx: usize) -> &Self::Output {
310+
let num_rotation = Self::num_rotation(self.log2_num_rotation);
280311
let num_col = self.n_col();
281-
&self.inner.values[num_col * idx..][..num_col]
312+
&self.inner.values[num_rotation * num_col * idx..][..num_rotation * num_col]
282313
}
283314
}
284315

0 commit comments

Comments
 (0)