Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 41 additions & 10 deletions crates/witness/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub struct RowMajorMatrix<T: Sized + Sync + Clone + Send + Copy> {
inner: p3::matrix::dense::RowMajorMatrix<T>,
// num_row is the real instance BEFORE padding
num_rows: usize,
log2_num_rotation: usize,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This design naturally skip num_rotation > 0 check, assure default rotation is 1

is_padded: bool,
padding_strategy: InstancePaddingStrategy,
}
Expand All @@ -53,13 +54,15 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
inner: p3::matrix::dense::RowMajorMatrix::rand(rng, num_row_padded, cols),
num_rows: rows,
is_padded: true,
log2_num_rotation: 0,
padding_strategy: InstancePaddingStrategy::Default,
}
}
pub fn empty() -> Self {
Self {
inner: p3::matrix::dense::RowMajorMatrix::new(vec![], 0),
num_rows: 0,
log2_num_rotation: 0,
is_padded: true,
padding_strategy: InstancePaddingStrategy::Default,
}
Expand All @@ -72,7 +75,8 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
mut self,
blowup_factor: Option<usize>,
) -> p3::matrix::dense::RowMajorMatrix<T> {
let padded_height = next_pow2_instance_padding(self.num_instances());
let padded_height = next_pow2_instance_padding(self.num_instances())
* Self::num_rotation(self.log2_num_rotation);
if let Some(blowup_factor) = blowup_factor {
if blowup_factor != CAPACITY_RESERVED_FACTOR {
tracing::warn!(
Expand All @@ -98,7 +102,21 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
num_cols: usize,
padding_strategy: InstancePaddingStrategy,
) -> Self {
let num_row_padded = next_pow2_instance_padding(num_rows);
Self::new_by_rotation(num_rows, 0, num_cols, padding_strategy)
}

/// rotation controls how many physical rows each logical row spans.
///
/// for example, if `log2_num_rotation = 2`, then each logical row
/// is expanded into `2^2 = 4` physical rows.
pub fn new_by_rotation(
num_rows: usize,
log2_num_rotation: usize,
num_cols: usize,
padding_strategy: InstancePaddingStrategy,
) -> Self {
let num_row_padded =
next_pow2_instance_padding(num_rows) * Self::num_rotation(log2_num_rotation);

let mut value = Vec::with_capacity(CAPACITY_RESERVED_FACTOR * num_row_padded * num_cols);
value.par_extend(
Expand All @@ -109,6 +127,7 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
RowMajorMatrix {
inner: p3::matrix::dense::RowMajorMatrix::new(value, num_cols),
num_rows,
log2_num_rotation,
is_padded: matches!(padding_strategy, InstancePaddingStrategy::Default),
padding_strategy,
}
Expand All @@ -126,6 +145,7 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
RowMajorMatrix {
inner: m,
num_rows,
log2_num_rotation: 0,
is_padded: matches!(padding_strategy, InstancePaddingStrategy::Default),
padding_strategy,
}
Expand All @@ -146,32 +166,42 @@ impl<T: Sized + Sync + Clone + Send + Copy + Default + FieldAlgebra> RowMajorMat
next_pow2_instance_padding(self.num_instances()) - self.num_instances()
}

// return raw num_instances without rotation
pub fn num_instances(&self) -> usize {
self.num_rows
}

fn num_rotation(log2_num_rotation: usize) -> usize {
1 << log2_num_rotation
}

pub fn iter_rows(&self) -> Chunks<'_, T> {
self.inner.values[..self.num_instances() * self.n_col()].chunks(self.inner.width)
let num_rotation = Self::num_rotation(self.log2_num_rotation);
self.inner.values[..self.num_instances() * num_rotation * self.n_col()]
.chunks(num_rotation * self.inner.width)
}

pub fn iter_mut(&mut self) -> ChunksMut<'_, T> {
let max_range = self.num_instances() * self.n_col();
self.inner.values[..max_range].chunks_mut(self.inner.width)
let num_rotation = Self::num_rotation(self.log2_num_rotation);
let max_range = self.num_instances() * num_rotation * self.n_col();
self.inner.values[..max_range].chunks_mut(num_rotation * self.inner.width)
}

pub fn par_batch_iter_mut(&mut self, num_rows: usize) -> rayon::slice::ChunksMut<'_, T> {
let max_range = self.num_instances() * self.n_col();
self.inner.values[..max_range].par_chunks_mut(num_rows * self.inner.width)
let num_rotation = Self::num_rotation(self.log2_num_rotation);
let max_range = self.num_instances() * self.n_col() * num_rotation;
self.inner.values[..max_range].par_chunks_mut(num_rows * num_rotation * self.inner.width)
}

pub fn padding_by_strategy(&mut self) {
let start_index = self.num_instances() * self.n_col();
let num_rotation = Self::num_rotation(self.log2_num_rotation);
let start_index = self.num_instances() * num_rotation * self.n_col();

match &self.padding_strategy {
InstancePaddingStrategy::Default => (),
InstancePaddingStrategy::Custom(fun) => {
self.inner.values[start_index..]
.par_chunks_mut(self.inner.width)
.par_chunks_mut(num_rotation * self.inner.width)
.enumerate()
.for_each(|(i, instance)| {
instance.iter_mut().enumerate().for_each(|(j, v)| {
Expand Down Expand Up @@ -277,8 +307,9 @@ impl<F: Sync + Send + Copy + FieldAlgebra> Index<usize> for RowMajorMatrix<F> {
type Output = [F];

fn index(&self, idx: usize) -> &Self::Output {
let num_rotation = Self::num_rotation(self.log2_num_rotation);
let num_col = self.n_col();
&self.inner.values[num_col * idx..][..num_col]
&self.inner.values[num_rotation * num_col * idx..][..num_rotation * num_col]
}
}

Expand Down