From b8e458e8af4146e058fbe3dafa73973b23d403bd Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 10 Jan 2023 03:03:47 -0800 Subject: [PATCH] Parallelize SHPLONK multi-open prover (#114) * feat: parallelize (cpu) shplonk prover * shplonk: improve `construct_intermediate_sets` using `BTreeSet` and `BTreeMap` more aggressively * shplonk: add `Send` and `Sync` to `Query` trait for more parallelization * fix: ensure the order of the collection of rotation sets is independent of the values of the opening points Co-authored-by: Jonathan Wang --- halo2_proofs/src/poly/commitment.rs | 2 +- .../src/poly/kzg/multiopen/shplonk.rs | 84 ++++++++----------- .../src/poly/kzg/multiopen/shplonk/prover.rs | 51 ++++++----- halo2_proofs/src/poly/query.rs | 4 +- 4 files changed, 69 insertions(+), 72 deletions(-) diff --git a/halo2_proofs/src/poly/commitment.rs b/halo2_proofs/src/poly/commitment.rs index 91fa244280..e82515cfff 100644 --- a/halo2_proofs/src/poly/commitment.rs +++ b/halo2_proofs/src/poly/commitment.rs @@ -99,7 +99,7 @@ pub trait ParamsProver<'params, C: CurveAffine>: Params<'params, C> { pub trait ParamsVerifier<'params, C: CurveAffine>: Params<'params, C> {} /// Multi scalar multiplication engine -pub trait MSM: Clone + Debug { +pub trait MSM: Clone + Debug + Send + Sync { /// Add arbitrary term (the scalar and the point) fn append_term(&mut self, scalar: C::Scalar, point: C::CurveExt); diff --git a/halo2_proofs/src/poly/kzg/multiopen/shplonk.rs b/halo2_proofs/src/poly/kzg/multiopen/shplonk.rs index 125936229e..0b1a2492b1 100644 --- a/halo2_proofs/src/poly/kzg/multiopen/shplonk.rs +++ b/halo2_proofs/src/poly/kzg/multiopen/shplonk.rs @@ -9,10 +9,11 @@ use crate::{ poly::{query::Query, Coeff, Polynomial}, transcript::ChallengeScalar, }; - +use rayon::prelude::*; use std::{ - collections::{btree_map::Entry, BTreeMap, BTreeSet}, + collections::{btree_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}, marker::PhantomData, + sync::Arc, }; #[derive(Clone, Copy, Debug)] @@ -49,7 +50,7 @@ struct RotationSet { #[derive(Debug, PartialEq)] struct IntermediateSets> { rotation_sets: Vec>, - super_point_set: Vec, + super_point_set: BTreeSet, } fn construct_intermediate_sets>( @@ -69,18 +70,8 @@ where .get_eval() }; - // Order points according to their rotation - let mut rotation_point_map = BTreeMap::new(); - for query in queries.clone() { - let point = rotation_point_map - .entry(query.get_point()) - .or_insert_with(|| query.get_point()); - - // Assert rotation point matching consistency - assert_eq!(*point, query.get_point()); - } - // All points appear in queries - let super_point_set: Vec = rotation_point_map.values().cloned().collect(); + // All points that appear in queries + let mut super_point_set = BTreeSet::new(); // Collect rotation sets for each commitment // Example elements in the vector: @@ -89,19 +80,21 @@ where // (C_2, {r_2, r_3, r_4}), // (C_3, {r_2, r_3, r_4}), // ... - let mut commitment_rotation_set_map: Vec<(Q::Commitment, Vec)> = vec![]; - for query in queries.clone() { + let mut commitment_rotation_set_map: Vec<(Q::Commitment, BTreeSet)> = vec![]; + for query in queries.iter() { let rotation = query.get_point(); - if let Some(pos) = commitment_rotation_set_map - .iter() - .position(|(commitment, _)| *commitment == query.get_commitment()) + super_point_set.insert(rotation); + if let Some(commitment_rotation_set) = commitment_rotation_set_map + .iter_mut() + .find(|(commitment, _)| *commitment == query.get_commitment()) { - let (_, rotation_set) = &mut commitment_rotation_set_map[pos]; - if !rotation_set.contains(&rotation) { - rotation_set.push(rotation); - } + let (_, rotation_set) = commitment_rotation_set; + rotation_set.insert(rotation); } else { - commitment_rotation_set_map.push((query.get_commitment(), vec![rotation])); + commitment_rotation_set_map.push(( + query.get_commitment(), + BTreeSet::from_iter(std::iter::once(rotation)), + )); }; } @@ -111,41 +104,38 @@ where // {r_1, r_2, r_3} : [C_1] // {r_2, r_3, r_4} : [C_2, C_3], // ... - let mut rotation_set_commitment_map = Vec::<(Vec<_>, Vec)>::new(); - for (commitment, rotation_set) in commitment_rotation_set_map.iter() { - if let Some(pos) = rotation_set_commitment_map.iter().position(|(set, _)| { - BTreeSet::::from_iter(set.iter().cloned()) - == BTreeSet::::from_iter(rotation_set.iter().cloned()) - }) { - let (_, commitments) = &mut rotation_set_commitment_map[pos]; - if !commitments.contains(commitment) { - commitments.push(*commitment); - } + // NOTE: we want to make the order of the collection of rotation sets independent of the opening points, to ease the verifier computation + let mut rotation_set_commitment_map: Vec<(BTreeSet, Vec)> = vec![]; + for (commitment, rotation_set) in commitment_rotation_set_map.into_iter() { + if let Some(rotation_set_commitment) = rotation_set_commitment_map + .iter_mut() + .find(|(set, _)| set == &rotation_set) + { + let (_, commitments) = rotation_set_commitment; + commitments.push(commitment); } else { - rotation_set_commitment_map.push((rotation_set.clone(), vec![*commitment])) - } + rotation_set_commitment_map.push((rotation_set, vec![commitment])); + }; } let rotation_sets = rotation_set_commitment_map - .into_iter() + .into_par_iter() .map(|(rotations, commitments)| { + let rotations_vec = rotations.iter().collect::>(); let commitments: Vec> = commitments - .iter() + .into_par_iter() .map(|commitment| { - let evals: Vec = rotations - .iter() - .map(|rotation| get_eval(*commitment, *rotation)) + let evals: Vec = rotations_vec + .par_iter() + .map(|&&rotation| get_eval(commitment, rotation)) .collect(); - Commitment((*commitment, evals)) + Commitment((commitment, evals)) }) .collect(); RotationSet { commitments, - points: rotations - .iter() - .map(|rotation| *rotation_point_map.get(rotation).unwrap()) - .collect(), + points: rotations.into_iter().collect(), } }) .collect::>>(); diff --git a/halo2_proofs/src/poly/kzg/multiopen/shplonk/prover.rs b/halo2_proofs/src/poly/kzg/multiopen/shplonk/prover.rs index 2585d9ab69..b916228857 100644 --- a/halo2_proofs/src/poly/kzg/multiopen/shplonk/prover.rs +++ b/halo2_proofs/src/poly/kzg/multiopen/shplonk/prover.rs @@ -17,6 +17,7 @@ use ff::Field; use group::Curve; use halo2curves::pairing::Engine; use rand_core::RngCore; +use rayon::prelude::*; use std::fmt::Debug; use std::io::{self, Write}; use std::marker::PhantomData; @@ -36,8 +37,8 @@ struct CommitmentExtension<'a, C: CurveAffine> { } impl<'a, C: CurveAffine> Commitment> { - fn extend(&self, points: Vec) -> CommitmentExtension<'a, C> { - let poly = lagrange_interpolate(&points[..], &self.evals()[..]); + fn extend(&self, points: &[C::Scalar]) -> CommitmentExtension<'a, C> { + let poly = lagrange_interpolate(points, &self.evals()[..]); let low_degree_equivalent = Polynomial { values: poly, @@ -79,10 +80,10 @@ struct RotationSetExtension<'a, C: CurveAffine> { } impl<'a, C: CurveAffine> RotationSet> { - fn extend(&self, commitments: Vec>) -> RotationSetExtension<'a, C> { + fn extend(self, commitments: Vec>) -> RotationSetExtension<'a, C> { RotationSetExtension { commitments, - points: self.points.clone(), + points: self.points, } } } @@ -136,8 +137,9 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme> // [P_i_0(X) - R_i_0(X), P_i_1(X) - R_i_1(X), ... ] let numerators = rotation_set .commitments - .iter() - .map(|commitment| commitment.quotient_contribution()); + .par_iter() + .map(|commitment| commitment.quotient_contribution()) + .collect::>(); // define numerator polynomial as // N_i_j(X) = (P_i_j(X) - R_i_j(X)) @@ -145,6 +147,7 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme> // N_i(X) = linear_combinination(y, N_i_j(X)) // where y is random scalar to combine numerator polynomials let n_x = numerators + .into_iter() .zip(powers(*y)) .map(|(numerator, power_of_y)| numerator * power_of_y) .reduce(|acc, numerator| acc + &numerator) @@ -171,12 +174,12 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme> ); let rotation_sets: Vec> = rotation_sets - .iter() + .into_par_iter() .map(|rotation_set| { let commitments: Vec> = rotation_set .commitments - .iter() - .map(|commitment_data| commitment_data.extend(rotation_set.points.clone())) + .par_iter() + .map(|commitment_data| commitment_data.extend(&rotation_set.points)) .collect(); rotation_set.extend(commitments) }) @@ -184,9 +187,13 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme> let v: ChallengeV<_> = transcript.squeeze_challenge_scalar(); - let quotient_polynomials = rotation_sets.iter().map(quotient_contribution); + let quotient_polynomials = rotation_sets + .par_iter() + .map(quotient_contribution) + .collect::>(); let h_x: Polynomial = quotient_polynomials + .into_iter() .zip(powers(*v)) .map(|(poly, power_of_v)| poly * power_of_v) .reduce(|acc, poly| acc + &poly) @@ -196,18 +203,15 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme> transcript.write_point(h)?; let u: ChallengeU<_> = transcript.squeeze_challenge_scalar(); - let zt_eval = evaluate_vanishing_polynomial(&super_point_set[..], *u); - let linearisation_contribution = |rotation_set: RotationSetExtension| -> (Polynomial, E::Scalar) { - let diffs: Vec = super_point_set - .iter() - .filter(|point| !rotation_set.points.contains(point)) - .copied() - .collect(); + let mut diffs = super_point_set.clone(); + for point in rotation_set.points.iter() { + diffs.remove(point); + } + let diffs = diffs.into_iter().collect::>(); // calculate difference vanishing polynomial evaluation - let z_i = evaluate_vanishing_polynomial(&diffs[..], *u); // inner linearisation contibutions are @@ -216,15 +220,15 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme> // where u is random evaluation point let inner_contributions = rotation_set .commitments - .iter() - .map(|commitment| commitment.linearisation_contribution(*u)); + .par_iter() + .map(|commitment| commitment.linearisation_contribution(*u)).collect::>(); // define inner contributor polynomial as // L_i_j(X) = (P_i_j(X) - r_i_j) // and combine polynomials with same evaluation point set // L_i(X) = linear_combinination(y, L_i_j(X)) // where y is random scalar to combine inner contibutors - let l_x: Polynomial = inner_contributions.zip(powers(*y)).map(|(poly, power_of_y)| poly * power_of_y).reduce(|acc, poly| acc + &poly).unwrap(); + let l_x: Polynomial = inner_contributions.into_iter().zip(powers(*y)).map(|(poly, power_of_y)| poly * power_of_y).reduce(|acc, poly| acc + &poly).unwrap(); // finally scale l_x by difference vanishing polynomial evaluation z_i (l_x * z_i, z_i) @@ -235,7 +239,7 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme> Vec>, Vec, ) = rotation_sets - .into_iter() + .into_par_iter() .map(linearisation_contribution) .unzip(); @@ -246,9 +250,12 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme> .reduce(|acc, poly| acc + &poly) .unwrap(); + let super_point_set = super_point_set.into_iter().collect::>(); + let zt_eval = evaluate_vanishing_polynomial(&super_point_set[..], *u); let l_x = l_x - &(h_x * zt_eval); // sanity check + #[cfg(debug_assertions)] { let must_be_zero = eval_polynomial(&l_x.values[..], *u); assert_eq!(must_be_zero, E::Scalar::zero()); diff --git a/halo2_proofs/src/poly/query.rs b/halo2_proofs/src/poly/query.rs index f13cc25a89..c596e6a71c 100644 --- a/halo2_proofs/src/poly/query.rs +++ b/halo2_proofs/src/poly/query.rs @@ -8,8 +8,8 @@ use crate::{ use ff::Field; use halo2curves::CurveAffine; -pub trait Query: Sized + Clone { - type Commitment: PartialEq + Copy; +pub trait Query: Sized + Clone + Send + Sync { + type Commitment: PartialEq + Copy + Send + Sync; type Eval: Clone + Default + Debug; fn get_point(&self) -> F;