Skip to content

Commit 7fe6cfa

Browse files
committed
rotation support
1 parent 08462f5 commit 7fe6cfa

File tree

1 file changed

+38
-10
lines changed

1 file changed

+38
-10
lines changed

crates/witness/src/lib.rs

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub struct RowMajorMatrix<T: Sized + Sync + Clone + Send + Copy> {
3838
inner: p3::matrix::dense::RowMajorMatrix<T>,
3939
// num_row is the real instance BEFORE padding
4040
num_rows: usize,
41+
log2_num_rotation: usize,
4142
is_padded: bool,
4243
padding_strategy: InstancePaddingStrategy,
4344
}
@@ -53,13 +54,15 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
5354
inner: p3::matrix::dense::RowMajorMatrix::rand(rng, num_row_padded, cols),
5455
num_rows: rows,
5556
is_padded: true,
57+
log2_num_rotation: 0,
5658
padding_strategy: InstancePaddingStrategy::Default,
5759
}
5860
}
5961
pub fn empty() -> Self {
6062
Self {
6163
inner: p3::matrix::dense::RowMajorMatrix::new(vec![], 0),
6264
num_rows: 0,
65+
log2_num_rotation: 0,
6366
is_padded: true,
6467
padding_strategy: InstancePaddingStrategy::Default,
6568
}
@@ -72,7 +75,7 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
7275
mut self,
7376
blowup_factor: Option<usize>,
7477
) -> p3::matrix::dense::RowMajorMatrix<T> {
75-
let padded_height = next_pow2_instance_padding(self.num_instances());
78+
let padded_height = next_pow2_instance_padding(self.num_instances()) * Self::num_rotation(self.log2_num_rotation);
7679
if let Some(blowup_factor) = blowup_factor {
7780
if blowup_factor != CAPACITY_RESERVED_FACTOR {
7881
tracing::warn!(
@@ -98,7 +101,20 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
98101
num_cols: usize,
99102
padding_strategy: InstancePaddingStrategy,
100103
) -> Self {
101-
let num_row_padded = next_pow2_instance_padding(num_rows);
104+
Self::new_by_rotation(num_rows, 0,num_cols, padding_strategy)
105+
}
106+
107+
/// rotation controls how many physical rows each logical row spans.
108+
///
109+
/// for example, if `log2_num_rotation = 2`, then each logical row
110+
/// is expanded into `2^2 = 4` physical rows.
111+
pub fn new_by_rotation(
112+
num_rows: usize,
113+
log2_num_rotation: usize,
114+
num_cols: usize,
115+
padding_strategy: InstancePaddingStrategy,
116+
) -> Self {
117+
let num_row_padded = next_pow2_instance_padding(num_rows) * Self::num_rotation(log2_num_rotation);
102118

103119
let mut value = Vec::with_capacity(CAPACITY_RESERVED_FACTOR * num_row_padded * num_cols);
104120
value.par_extend(
@@ -109,6 +125,7 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
109125
RowMajorMatrix {
110126
inner: p3::matrix::dense::RowMajorMatrix::new(value, num_cols),
111127
num_rows,
128+
log2_num_rotation,
112129
is_padded: matches!(padding_strategy, InstancePaddingStrategy::Default),
113130
padding_strategy,
114131
}
@@ -126,6 +143,7 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
126143
RowMajorMatrix {
127144
inner: m,
128145
num_rows,
146+
log2_num_rotation: 0,
129147
is_padded: matches!(padding_strategy, InstancePaddingStrategy::Default),
130148
padding_strategy,
131149
}
@@ -146,32 +164,41 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
146164
next_pow2_instance_padding(self.num_instances()) - self.num_instances()
147165
}
148166

167+
// return raw num_instances without rotation
149168
pub fn num_instances(&self) -> usize {
150169
self.num_rows
151170
}
152171

172+
fn num_rotation(log2_num_rotation: usize) -> usize {
173+
1 << log2_num_rotation
174+
}
175+
153176
pub fn iter_rows(&self) -> Chunks<'_, T> {
154-
self.inner.values[..self.num_instances() * self.n_col()].chunks(self.inner.width)
177+
let num_rotation = Self::num_rotation(self.log2_num_rotation);
178+
self.inner.values[..self.num_instances() * num_rotation * self.n_col()].chunks(num_rotation * self.inner.width)
155179
}
156180

157181
pub fn iter_mut(&mut self) -> ChunksMut<'_, T> {
158-
let max_range = self.num_instances() * self.n_col();
159-
self.inner.values[..max_range].chunks_mut(self.inner.width)
182+
let num_rotation = Self::num_rotation(self.log2_num_rotation);
183+
let max_range = self.num_instances() *num_rotation * self.n_col();
184+
self.inner.values[..max_range].chunks_mut(num_rotation * self.inner.width)
160185
}
161186

162187
pub fn par_batch_iter_mut(&mut self, num_rows: usize) -> rayon::slice::ChunksMut<'_, T> {
163-
let max_range = self.num_instances() * self.n_col();
164-
self.inner.values[..max_range].par_chunks_mut(num_rows * self.inner.width)
188+
let num_rotation = Self::num_rotation(self.log2_num_rotation);
189+
let max_range = self.num_instances() * self.n_col() * num_rotation;
190+
self.inner.values[..max_range].par_chunks_mut(num_rows * num_rotation * self.inner.width)
165191
}
166192

167193
pub fn padding_by_strategy(&mut self) {
168-
let start_index = self.num_instances() * self.n_col();
194+
let num_rotation = Self::num_rotation(self.log2_num_rotation);
195+
let start_index = self.num_instances() * num_rotation * self.n_col();
169196

170197
match &self.padding_strategy {
171198
InstancePaddingStrategy::Default => (),
172199
InstancePaddingStrategy::Custom(fun) => {
173200
self.inner.values[start_index..]
174-
.par_chunks_mut(self.inner.width)
201+
.par_chunks_mut(num_rotation * self.inner.width)
175202
.enumerate()
176203
.for_each(|(i, instance)| {
177204
instance.iter_mut().enumerate().for_each(|(j, v)| {
@@ -277,8 +304,9 @@ impl<F: Sync + Send + Copy + FieldAlgebra> Index<usize> for RowMajorMatrix<F> {
277304
type Output = [F];
278305

279306
fn index(&self, idx: usize) -> &Self::Output {
307+
let num_rotation = Self::num_rotation(self.log2_num_rotation);
280308
let num_col = self.n_col();
281-
&self.inner.values[num_col * idx..][..num_col]
309+
&self.inner.values[num_rotation * num_col * idx..][..num_rotation * num_col]
282310
}
283311
}
284312

0 commit comments

Comments
 (0)