Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: BRP and G1 MSM #38

Merged
merged 1 commit into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/blob.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::{
bls::{FiniteFieldError, Fr, P1},
kzg::{Commitment, Polynomial, Proof, Setup},
math::BitReversalPermutation,
};

pub enum Error {
Expand Down Expand Up @@ -37,8 +36,7 @@ impl<const N: usize> Blob<N> {
}

pub(crate) fn commitment<const G2: usize>(&self, setup: &Setup<N, G2>) -> Commitment {
let g1_lagrange = BitReversalPermutation::new(setup.g1_lagrange.as_slice());
let lincomb = P1::lincomb(g1_lagrange.iter().zip(self.elements.iter()));
let lincomb = P1::lincomb(setup.g1_lagrange_brp.as_slice(), self.elements.as_slice());

Commitment::from(lincomb)
}
Expand Down
19 changes: 5 additions & 14 deletions src/bls.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
cmp,
mem::MaybeUninit,
ops::{Add, Div, Mul, Neg, Shl, ShlAssign, Shr, ShrAssign, Sub},
};
Expand Down Expand Up @@ -374,21 +375,11 @@ impl P1 {
}

// TODO: optimize w/ pippenger
pub fn lincomb<'a>(terms: impl Iterator<Item = (&'a Self, &'a Fr)>) -> Self {
pub fn lincomb(points: impl AsRef<[Self]>, scalars: impl AsRef<[Fr]>) -> Self {
let n = cmp::min(points.as_ref().len(), scalars.as_ref().len());
let mut lincomb = Self::INF;
for (point, scalar) in terms {
lincomb = lincomb + (point * scalar);
}

lincomb
}

// TODO: optimize w/ pippenger
// TODO: unify with `P1::lincomb`
pub fn lincomb_owned(terms: impl Iterator<Item = (Self, Fr)>) -> Self {
let mut lincomb = Self::INF;
for (point, scalar) in terms {
lincomb = lincomb + (point * scalar);
for i in 0..n {
lincomb = lincomb + (points.as_ref()[i] * scalars.as_ref()[i]);
}

lincomb
Expand Down
12 changes: 4 additions & 8 deletions src/kzg/poly.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use crate::{
bls::{Fr, P1},
math::BitReversalPermutation,
};
use crate::bls::{Fr, P1};

use super::{setup::Setup, Proof};

Expand All @@ -11,7 +8,7 @@ pub(crate) struct Polynomial<'a, const N: usize>(pub(crate) &'a [Fr; N]);
impl<'a, const N: usize> Polynomial<'a, N> {
/// evaluates the polynomial at `point`.
pub(crate) fn evaluate<const G2: usize>(&self, point: Fr, setup: &Setup<N, G2>) -> Fr {
let roots = BitReversalPermutation::new(setup.roots_of_unity.as_slice());
let roots = &setup.roots_of_unity_brp;

// if `point` is a root of a unity, then we have the evaluation available
for i in 0..N {
Expand All @@ -37,7 +34,7 @@ impl<'a, const N: usize> Polynomial<'a, N> {

/// returns a `Proof` for the evaluation of the polynomial at `point`.
pub(crate) fn prove<const G2: usize>(&self, point: Fr, setup: &Setup<N, G2>) -> (Fr, Proof) {
let roots = BitReversalPermutation::new(setup.roots_of_unity.as_slice());
let roots = &setup.roots_of_unity_brp;

let eval = self.evaluate(point, setup);

Expand Down Expand Up @@ -68,8 +65,7 @@ impl<'a, const N: usize> Polynomial<'a, N> {
quotient_poly.push(quotient);
}

let g1_lagrange = BitReversalPermutation::new(setup.g1_lagrange.as_slice());
let lincomb = P1::lincomb(g1_lagrange.iter().zip(quotient_poly.iter()));
let lincomb = P1::lincomb(setup.g1_lagrange_brp.as_slice(), quotient_poly);

(eval, lincomb)
}
Expand Down
49 changes: 24 additions & 25 deletions src/kzg/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ struct SetupUnchecked {

#[derive(Clone, Debug)]
pub struct Setup<const G1: usize, const G2: usize> {
pub(crate) g1_lagrange: Box<[P1; G1]>,
pub(crate) g1_lagrange_brp: Box<[P1; G1]>,
pub(crate) g2_monomial: Box<[P2; G2]>,
pub(crate) roots_of_unity: Box<[Fr; G1]>,
pub(crate) roots_of_unity_brp: Box<[Fr; G1]>,
}

impl<const G1: usize, const G2: usize> Setup<G1, G2> {
Expand Down Expand Up @@ -63,6 +63,7 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
P1::deserialize(point).map_err(|err| LoadSetupError::Bls(BlsError::from(err)))?;
g1_lagrange[i] = point;
}
let g1_lagrange_brp = math::bit_reversal_permutation_boxed_array(g1_lagrange.as_slice());

let mut g2_monomial: Box<[P2; G2]> = Box::new([P2::default(); G2]);
for (i, point) in setup.g2_monomial.iter().enumerate() {
Expand All @@ -78,13 +79,13 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
g2_monomial[i] = point;
}

let roots_of_unity = math::roots_of_unity();
let roots_of_unity = Box::new(roots_of_unity);
let roots_of_unity: [Fr; G1] = math::roots_of_unity();
let roots_of_unity_brp = math::bit_reversal_permutation_boxed_array(roots_of_unity);

Ok(Setup {
g1_lagrange,
g1_lagrange_brp,
g2_monomial,
roots_of_unity,
roots_of_unity_brp,
})
}

Expand All @@ -110,39 +111,37 @@ impl<const G1: usize, const G2: usize> Setup<G1, G2> {
assert_eq!(proofs.as_ref().len(), commitments.as_ref().len());
assert_eq!(commitments.as_ref().len(), points.as_ref().len());
assert_eq!(points.as_ref().len(), evals.as_ref().len());
let n = proofs.as_ref().len();

const DOMAIN: &[u8; 16] = b"RCKZGBATCH___V1_";
let degree = (G1 as u128).to_be_bytes();
let len = (proofs.as_ref().len() as u128).to_be_bytes();
let len = (n as u128).to_be_bytes();

let mut data = [0; 48];
data[..16].copy_from_slice(DOMAIN.as_slice());
data[16..32].copy_from_slice(&degree);
data[32..].copy_from_slice(&len);

let r = Fr::hash_to(data);
let mut rpowers = Vec::with_capacity(proofs.as_ref().len());
let mut rpowers = Vec::with_capacity(n);
let mut points_mul_rpowers = Vec::with_capacity(n);
let mut comms_minus_evals = Vec::with_capacity(n);
for i in 0..proofs.as_ref().len() {
rpowers.push(r.pow(&Fr::from(i as u64)));
let rpower = r.pow(&Fr::from(i as u64));
rpowers.push(rpower);

let point = points.as_ref()[i];
points_mul_rpowers.push(point * rpower);

let commitment = commitments.as_ref()[i];
let eval = evals.as_ref()[i];
comms_minus_evals.push(commitment + (P1::neg_generator() * eval));
}

let proof_lincomb = P1::lincomb(proofs.as_ref().iter().zip(rpowers.iter()));
let proof_z_lincomb = P1::lincomb_owned(
proofs.as_ref().iter().copied().zip(
points
.as_ref()
.iter()
.zip(rpowers.iter())
.map(|(point, pow)| point * pow),
),
);
let proof_lincomb = P1::lincomb(&proofs, &rpowers);
let proof_z_lincomb = P1::lincomb(proofs, points_mul_rpowers);

let comm_minus_eval = commitments
.as_ref()
.iter()
.zip(evals.as_ref().iter())
.map(|(comm, eval)| *comm + (P1::neg_generator() * eval));
let comm_minus_eval_lincomb = P1::lincomb_owned(comm_minus_eval.zip(rpowers));
let comm_minus_eval_lincomb = P1::lincomb(comms_minus_evals, rpowers);

bls::verify_pairings(
(proof_lincomb, self.g2_monomial[1]),
Expand Down
93 changes: 36 additions & 57 deletions src/math.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use core::{marker::PhantomData, ops::Index};
use std::fmt::Debug;

use crate::bls::Fr;

Expand Down Expand Up @@ -28,66 +28,45 @@ pub fn roots_of_unity<const ORDER: usize>() -> [Fr; ORDER] {
roots
}

pub struct BitReversalPermutation<T, S> {
elements: S,
phantom: PhantomData<T>,
}

impl<T, S> BitReversalPermutation<T, S>
/// # Panics
///
/// This function will panic if the length of `elements` is not a power of 2.
pub(crate) fn bit_reversal_permutation<T>(elements: impl AsRef<[T]>) -> Vec<T>
where
S: AsRef<[T]>,
T: Copy,
{
/// # Panics
///
/// This function will panic if the length of `elements` is not a power of 2.
pub fn new(elements: S) -> Self {
assert!(elements.as_ref().len().is_power_of_two());
Self {
elements,
phantom: PhantomData,
}
}

pub(crate) fn iter(&self) -> BitReversalPermutationIter<T> {
BitReversalPermutationIter {
inner: self.elements.as_ref(),
index: 0,
}
let n = elements.as_ref().len();
assert!(n.is_power_of_two());
let mut brp = Vec::with_capacity(n);
for i in 0..n {
let index = bit_reversal_permutation_index(i, n);
brp.push(elements.as_ref()[index]);
}
brp
}

impl<T, S> Index<usize> for BitReversalPermutation<T, S>
/// # Panics
///
/// This function will panic if the length of `elements` is not equal to `N`.
///
/// This function will panic if the length of `elements` is not a power of 2.
pub(crate) fn bit_reversal_permutation_boxed_array<T, const N: usize>(
elements: impl AsRef<[T]>,
) -> Box<[T; N]>
where
S: AsRef<[T]>,
T: Copy + Debug,
{
type Output = T;

fn index(&self, index: usize) -> &Self::Output {
let index = bit_reversal_permutation_index(index, self.elements.as_ref().len());
&self.elements.as_ref()[index]
}
}

pub struct BitReversalPermutationIter<'a, T> {
inner: &'a [T],
index: usize,
}

impl<'a, T> Iterator for BitReversalPermutationIter<'a, T> {
type Item = &'a T;
assert_eq!(elements.as_ref().len(), N);
assert!(N.is_power_of_two());

fn next(&mut self) -> Option<Self::Item> {
if self.index == self.inner.len() {
return None;
}

let index = bit_reversal_permutation_index(self.index, self.inner.len());
let next = &self.inner[index];
let brp = bit_reversal_permutation(elements);

self.index += 1;
// TODO: make sure the conversion does not cause a new allocation
let brp: Box<[T; N]> = brp
.try_into()
.expect("infallible conversion to equal len boxed array");

Some(next)
}
brp
}

fn bit_reversal_permutation_index(index: usize, len: usize) -> usize {
Expand All @@ -107,12 +86,12 @@ mod tests {
}

// since the permutation is an involution, the double application should be equal to the identity function
let permutation = BitReversalPermutation::new(elements.clone());
let permutation = super::bit_reversal_permutation(&elements);
let mut permuted = Vec::with_capacity(N);
for i in 0..N {
permuted.push(permutation[i]);
for element in permutation.iter().take(N) {
permuted.push(*element);
}
let double_permutation = BitReversalPermutation::new(permuted);
let double_permutation = super::bit_reversal_permutation(permuted);
for i in 0..N {
assert_eq!(double_permutation[i], elements[i]);
}
Expand All @@ -122,8 +101,8 @@ mod tests {
#[should_panic]
fn bit_reversal_permutation_non_power_of_two() {
const N: usize = (1 << 12) - 1;
let mut elements = vec![0u16; N];
BitReversalPermutation::new(&mut elements);
let elements = vec![0u16; N];
super::bit_reversal_permutation(elements);
}

#[test]
Expand Down
Loading