Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: multiexp_serial skips doubling when all bits are zero #202

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions halo2_proofs/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! This module provides common utilities, traits and structures for group,
//! field and polynomial arithmetic.

use std::cmp;

use super::multicore;
pub use ff::Field;
use group::{
Expand All @@ -25,6 +27,7 @@ where
{
}

// ASSUMES C::Scalar::Repr is little endian
fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();

Expand All @@ -36,6 +39,7 @@ fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut
(f64::from(bases.len() as u32)).ln().ceil() as usize
};

// Group `bytes` into bits and take the `segment`th chunk of `c` bits
fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
let skip_bits = segment * c;
let skip_bytes = skip_bits / 8;
Expand All @@ -56,9 +60,35 @@ fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut
tmp as usize
}

let segments = (256 / c) + 1;
// Ideally `segments` should be calculated from the max number of bits among all scalars. But this requires a scan of all scalars, so we don't implement it for now.
let segments = (C::Scalar::NUM_BITS as usize + c - 1) / c;

// this can be optimized
let mut coeffs_in_segments = Vec::with_capacity(segments);
// track what is the last segment where we actually have nonzero bits, so we completely skip buckets where the scalar bits for all coeffs are 0
let mut max_nonzero_segment = None;
for current_segment in 0..segments {
let coeff_segments: Vec<_> = coeffs
.iter()
.map(|coeff| {
let c_bits = get_at::<C::Scalar>(current_segment, c, coeff);
if c_bits != 0 {
max_nonzero_segment = Some(current_segment);
}
c_bits
})
.collect();
coeffs_in_segments.push(coeff_segments);
}

for current_segment in (0..segments).rev() {
if max_nonzero_segment.is_none() {
return;
}
for coeffs_seg in coeffs_in_segments
.into_iter()
.take(max_nonzero_segment.unwrap() + 1)
.rev()
{
for _ in 0..c {
*acc = acc.double();
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be faster and easier to maintain to add a boolean guard if is_init here
that defaults to false and is set to true on first initialization of the accumulator.

It would avoid an extra pass over the data.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, nice suggestion.

But now I am tempted to just do a full scan ahead of time to do this bit column optimization...

}
Expand Down Expand Up @@ -96,9 +126,10 @@ fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut

let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];

for (coeff, base) in coeffs.iter().zip(bases.iter()) {
let coeff = get_at::<C::Scalar>(current_segment, c, coeff);
let mut max_bits = 0;
for (coeff, base) in coeffs_seg.into_iter().zip(bases.iter()) {
if coeff != 0 {
max_bits = cmp::max(max_bits, coeff);
buckets[coeff - 1].add_assign(base);
}
}
Expand All @@ -108,7 +139,7 @@ fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut
// (a) + b +
// ((a) + b) + c
let mut running_sum = C::Curve::identity();
for exp in buckets.into_iter().rev() {
for exp in buckets.into_iter().take(max_bits).rev() {
running_sum = exp.add(running_sum);
*acc = *acc + &running_sum;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll need a local accumulator initialized with the first non-zero running sum because we have no info on the global accumulator here.

}
Expand Down
Loading