diff --git a/crates/witness/src/lib.rs b/crates/witness/src/lib.rs index f7a40d3..a7619c1 100644 --- a/crates/witness/src/lib.rs +++ b/crates/witness/src/lib.rs @@ -38,6 +38,7 @@ pub struct RowMajorMatrix { inner: p3::matrix::dense::RowMajorMatrix, // num_row is the real instance BEFORE padding num_rows: usize, + log2_num_rotation: usize, is_padded: bool, padding_strategy: InstancePaddingStrategy, } @@ -53,6 +54,7 @@ impl RowMajorMat inner: p3::matrix::dense::RowMajorMatrix::rand(rng, num_row_padded, cols), num_rows: rows, is_padded: true, + log2_num_rotation: 0, padding_strategy: InstancePaddingStrategy::Default, } } @@ -60,6 +62,7 @@ impl RowMajorMat Self { inner: p3::matrix::dense::RowMajorMatrix::new(vec![], 0), num_rows: 0, + log2_num_rotation: 0, is_padded: true, padding_strategy: InstancePaddingStrategy::Default, } @@ -72,7 +75,8 @@ impl RowMajorMat mut self, blowup_factor: Option, ) -> p3::matrix::dense::RowMajorMatrix { - let padded_height = next_pow2_instance_padding(self.num_instances()); + let padded_height = next_pow2_instance_padding(self.num_instances()) + * Self::num_rotation(self.log2_num_rotation); if let Some(blowup_factor) = blowup_factor { if blowup_factor != CAPACITY_RESERVED_FACTOR { tracing::warn!( @@ -98,7 +102,21 @@ impl RowMajorMat num_cols: usize, padding_strategy: InstancePaddingStrategy, ) -> Self { - let num_row_padded = next_pow2_instance_padding(num_rows); + Self::new_by_rotation(num_rows, 0, num_cols, padding_strategy) + } + + /// rotation controls how many physical rows each logical row spans. + /// + /// for example, if `log2_num_rotation = 2`, then each logical row + /// is expanded into `2^2 = 4` physical rows. + pub fn new_by_rotation( + num_rows: usize, + log2_num_rotation: usize, + num_cols: usize, + padding_strategy: InstancePaddingStrategy, + ) -> Self { + let num_row_padded = + next_pow2_instance_padding(num_rows) * Self::num_rotation(log2_num_rotation); let mut value = Vec::with_capacity(CAPACITY_RESERVED_FACTOR * num_row_padded * num_cols); value.par_extend( @@ -109,6 +127,7 @@ impl RowMajorMat RowMajorMatrix { inner: p3::matrix::dense::RowMajorMatrix::new(value, num_cols), num_rows, + log2_num_rotation, is_padded: matches!(padding_strategy, InstancePaddingStrategy::Default), padding_strategy, } @@ -126,6 +145,7 @@ impl RowMajorMat RowMajorMatrix { inner: m, num_rows, + log2_num_rotation: 0, is_padded: matches!(padding_strategy, InstancePaddingStrategy::Default), padding_strategy, } @@ -146,32 +166,42 @@ impl RowMajorMat next_pow2_instance_padding(self.num_instances()) - self.num_instances() } + // return raw num_instances without rotation pub fn num_instances(&self) -> usize { self.num_rows } + fn num_rotation(log2_num_rotation: usize) -> usize { + 1 << log2_num_rotation + } + pub fn iter_rows(&self) -> Chunks<'_, T> { - self.inner.values[..self.num_instances() * self.n_col()].chunks(self.inner.width) + let num_rotation = Self::num_rotation(self.log2_num_rotation); + self.inner.values[..self.num_instances() * num_rotation * self.n_col()] + .chunks(num_rotation * self.inner.width) } pub fn iter_mut(&mut self) -> ChunksMut<'_, T> { - let max_range = self.num_instances() * self.n_col(); - self.inner.values[..max_range].chunks_mut(self.inner.width) + let num_rotation = Self::num_rotation(self.log2_num_rotation); + let max_range = self.num_instances() * num_rotation * self.n_col(); + self.inner.values[..max_range].chunks_mut(num_rotation * self.inner.width) } pub fn par_batch_iter_mut(&mut self, num_rows: usize) -> rayon::slice::ChunksMut<'_, T> { - let max_range = self.num_instances() * self.n_col(); - self.inner.values[..max_range].par_chunks_mut(num_rows * self.inner.width) + let num_rotation = Self::num_rotation(self.log2_num_rotation); + let max_range = self.num_instances() * self.n_col() * num_rotation; + self.inner.values[..max_range].par_chunks_mut(num_rows * num_rotation * self.inner.width) } pub fn padding_by_strategy(&mut self) { - let start_index = self.num_instances() * self.n_col(); + let num_rotation = Self::num_rotation(self.log2_num_rotation); + let start_index = self.num_instances() * num_rotation * self.n_col(); match &self.padding_strategy { InstancePaddingStrategy::Default => (), InstancePaddingStrategy::Custom(fun) => { self.inner.values[start_index..] - .par_chunks_mut(self.inner.width) + .par_chunks_mut(num_rotation * self.inner.width) .enumerate() .for_each(|(i, instance)| { instance.iter_mut().enumerate().for_each(|(j, v)| { @@ -277,8 +307,9 @@ impl Index for RowMajorMatrix { type Output = [F]; fn index(&self, idx: usize) -> &Self::Output { + let num_rotation = Self::num_rotation(self.log2_num_rotation); let num_col = self.n_col(); - &self.inner.values[num_col * idx..][..num_col] + &self.inner.values[num_rotation * num_col * idx..][..num_rotation * num_col] } }