Skip to content

Commit

Permalink
Migrate to new ZAL API
Browse files Browse the repository at this point in the history
Deprecate pre-ZAL API

Insert patch in `Cargo.toml` for `../halo2curves`

move ZAL to middleware

ZAL: introduce modular MSM/FFT per prover accelerators

ZAL: address clippy

halo2_common: best_fft and best_multiexp have been moved to halo2curves

zal: store engine at the prover level

zal: fix clippy

halo2_common::arithmetic -> halo2_backend::arithmetic sync #281

Remove Copy trait requirement from descriptors

run fmt

zal: remove Option from PlonkEngineConfig

zal: add default engine test
  • Loading branch information
mratsim committed Mar 29, 2024
1 parent 6d1dad3 commit 8278d44
Show file tree
Hide file tree
Showing 31 changed files with 821 additions and 388 deletions.
205 changes: 2 additions & 203 deletions halo2_backend/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,202 +3,13 @@

use group::{
ff::{BatchInvert, PrimeField},
Curve, Group, GroupOpsOwned, ScalarMulOwned,
Curve,
};
use halo2_common::multicore;
pub use halo2_middleware::ff::Field;

use halo2curves::msm::multiexp_serial;
pub use halo2curves::{fft::best_fft, msm::best_multiexp};
pub use halo2curves::{CurveAffine, CurveExt};

/// This represents an element of a group with basic operations that can be
/// performed. This allows an FFT implementation (for example) to operate
/// generically over either a field or elliptic curve group.
pub trait FftGroup<Scalar: Field>:
Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>
{
}

impl<T, Scalar> FftGroup<Scalar> for T
where
Scalar: Field,
T: Copy + Send + Sync + 'static + GroupOpsOwned + ScalarMulOwned<Scalar>,
{
}

/// Performs a small multi-exponentiation operation.
/// Uses the double-and-add algorithm with doublings shared across points.
pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
let mut acc = C::Curve::identity();

// for byte idx
for byte_idx in (0..((C::Scalar::NUM_BITS as usize + 7) / 8)).rev() {
// for bit idx
for bit_idx in (0..8).rev() {
acc = acc.double();
// for each coeff
for coeff_idx in 0..coeffs.len() {
let byte = coeffs[coeff_idx].as_ref()[byte_idx];
if ((byte >> bit_idx) & 1) != 0 {
acc += bases[coeff_idx];
}
}
}
}

acc
}

/// Performs a multi-exponentiation operation.
///
/// This function will panic if coeffs and bases have a different length.
///
/// This will use multithreading if beneficial.
pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
assert_eq!(coeffs.len(), bases.len());

let num_threads = multicore::current_num_threads();
if coeffs.len() > num_threads {
let chunk = coeffs.len() / num_threads;
let num_chunks = coeffs.chunks(chunk).len();
let mut results = vec![C::Curve::identity(); num_chunks];
multicore::scope(|scope| {
let chunk = coeffs.len() / num_threads;

for ((coeffs, bases), acc) in coeffs
.chunks(chunk)
.zip(bases.chunks(chunk))
.zip(results.iter_mut())
{
scope.spawn(move |_| {
multiexp_serial(coeffs, bases, acc);
});
}
});
results.iter().fold(C::Curve::identity(), |a, b| a + b)
} else {
let mut acc = C::Curve::identity();
multiexp_serial(coeffs, bases, &mut acc);
acc
}
}

/// Performs a radix-$2$ Fast-Fourier Transformation (FFT) on a vector of size
/// $n = 2^k$, when provided `log_n` = $k$ and an element of multiplicative
/// order $n$ called `omega` ($\omega$). The result is that the vector `a`, when
/// interpreted as the coefficients of a polynomial of degree $n - 1$, is
/// transformed into the evaluations of this polynomial at each of the $n$
/// distinct powers of $\omega$. This transformation is invertible by providing
/// $\omega^{-1}$ in place of $\omega$ and dividing each resulting field element
/// by $n$.
///
/// This will use multithreading if beneficial.
pub fn best_fft<Scalar: Field, G: FftGroup<Scalar>>(a: &mut [G], omega: Scalar, log_n: u32) {
fn bitreverse(mut n: usize, l: usize) -> usize {
let mut r = 0;
for _ in 0..l {
r = (r << 1) | (n & 1);
n >>= 1;
}
r
}

let threads = multicore::current_num_threads();
let log_threads = log2_floor(threads);
let n = a.len();
assert_eq!(n, 1 << log_n);

for k in 0..n {
let rk = bitreverse(k, log_n as usize);
if k < rk {
a.swap(rk, k);
}
}

// precompute twiddle factors
let twiddles: Vec<_> = (0..(n / 2))
.scan(Scalar::ONE, |w, _| {
let tw = *w;
*w *= &omega;
Some(tw)
})
.collect();

if log_n <= log_threads {
let mut chunk = 2_usize;
let mut twiddle_chunk = n / 2;
for _ in 0..log_n {
a.chunks_mut(chunk).for_each(|coeffs| {
let (left, right) = coeffs.split_at_mut(chunk / 2);

// case when twiddle factor is one
let (a, left) = left.split_at_mut(1);
let (b, right) = right.split_at_mut(1);
let t = b[0];
b[0] = a[0];
a[0] += &t;
b[0] -= &t;

left.iter_mut()
.zip(right.iter_mut())
.enumerate()
.for_each(|(i, (a, b))| {
let mut t = *b;
t *= &twiddles[(i + 1) * twiddle_chunk];
*b = *a;
*a += &t;
*b -= &t;
});
});
chunk *= 2;
twiddle_chunk /= 2;
}
} else {
recursive_butterfly_arithmetic(a, n, 1, &twiddles)
}
}

/// This perform recursive butterfly arithmetic
pub fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
a: &mut [G],
n: usize,
twiddle_chunk: usize,
twiddles: &[Scalar],
) {
if n == 2 {
let t = a[1];
a[1] = a[0];
a[0] += &t;
a[1] -= &t;
} else {
let (left, right) = a.split_at_mut(n / 2);
multicore::join(
|| recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
|| recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
);

// case when twiddle factor is one
let (a, left) = left.split_at_mut(1);
let (b, right) = right.split_at_mut(1);
let t = b[0];
b[0] = a[0];
a[0] += &t;
b[0] -= &t;

left.iter_mut()
.zip(right.iter_mut())
.enumerate()
.for_each(|(i, (a, b))| {
let mut t = *b;
t *= &twiddles[(i + 1) * twiddle_chunk];
*b = *a;
*a += &t;
*b -= &t;
});
}
}

/// Convert coefficient bases group elements to lagrange basis by inverse FFT.
pub fn g_to_lagrange<C: CurveAffine>(g_projective: Vec<C::Curve>, k: u32) -> Vec<C> {
let n_inv = C::Scalar::TWO_INV.pow_vartime([k as u64, 0, 0, 0]);
Expand Down Expand Up @@ -344,18 +155,6 @@ pub fn parallelize<T: Send, F: Fn(&mut [T], usize) + Send + Sync + Clone>(v: &mu
});
}

fn log2_floor(num: usize) -> u32 {
assert!(num > 0);

let mut pow = 0;

while (1 << (pow + 1)) <= num {
pow += 1;
}

pow
}

/// Returns coefficients of an n - 1 degree polynomial given a set of n points
/// and their evaluations. This function will panic if two values in `points`
/// are the same.
Expand Down
2 changes: 2 additions & 0 deletions halo2_backend/src/plonk/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use group::Curve;
use halo2_middleware::ff::{Field, FromUniformBytes};
use halo2_middleware::zal::impls::H2cEngine;

use super::{evaluation::Evaluator, permutation, Polynomial, ProvingKey, VerifyingKey};
use crate::{
Expand Down Expand Up @@ -72,6 +73,7 @@ where
.map(|poly| {
params
.commit_lagrange(
&H2cEngine::new(),
&Polynomial::new_lagrange_from_vec(poly.clone()),
Blind::default(),
)
Expand Down
14 changes: 12 additions & 2 deletions halo2_backend/src/plonk/lookup/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use halo2_common::plonk::{
};
use halo2_middleware::ff::WithSmallOrderMulGroup;
use halo2_middleware::poly::Rotation;
use halo2_middleware::zal::{impls::PlonkEngine, traits::MsmAccel};
use rand_core::RngCore;
use std::{
collections::BTreeMap,
Expand Down Expand Up @@ -70,7 +71,9 @@ pub(in crate::plonk) fn lookup_commit_permuted<
E: EncodedChallenge<C>,
R: RngCore,
T: TranscriptWrite<C, E>,
M: MsmAccel<C>,
>(
engine: &PlonkEngine<C, M>,
arg: &Argument<F>,
pk: &ProvingKey<C>,
params: &P,
Expand Down Expand Up @@ -128,7 +131,9 @@ where
let mut commit_values = |values: &Polynomial<C::Scalar, LagrangeCoeff>| {
let poly = pk.vk.domain.lagrange_to_coeff(values.clone());
let blind = Blind(C::Scalar::random(&mut rng));
let commitment = params.commit_lagrange(values, blind).to_affine();
let commitment = params
.commit_lagrange(&engine.msm_backend, values, blind)
.to_affine();
(poly, blind, commitment)
};

Expand Down Expand Up @@ -164,14 +169,17 @@ impl<C: CurveAffine> Permuted<C> {
/// grand product polynomial over the lookup. The grand product polynomial
/// is used to populate the Product<C> struct. The Product<C> struct is
/// added to the Lookup and finally returned by the method.
#[allow(clippy::too_many_arguments)]
pub(in crate::plonk) fn commit_product<
'params,
P: Params<'params, C>,
E: EncodedChallenge<C>,
R: RngCore,
T: TranscriptWrite<C, E>,
M: MsmAccel<C>,
>(
self,
engine: &PlonkEngine<C, M>,
pk: &ProvingKey<C>,
params: &P,
beta: ChallengeBeta<C>,
Expand Down Expand Up @@ -288,7 +296,9 @@ impl<C: CurveAffine> Permuted<C> {
}

let product_blind = Blind(C::Scalar::random(rng));
let product_commitment = params.commit_lagrange(&z, product_blind).to_affine();
let product_commitment = params
.commit_lagrange(&engine.msm_backend, &z, product_blind)
.to_affine();
let z = pk.vk.domain.lagrange_to_coeff(z);

// Hash product commitment
Expand Down
3 changes: 2 additions & 1 deletion halo2_backend/src/plonk/permutation/keygen.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use group::Curve;
use halo2_middleware::ff::{Field, PrimeField};
use halo2_middleware::zal::impls::H2cEngine;

use super::{Argument, ProvingKey, VerifyingKey};
use crate::{
Expand Down Expand Up @@ -477,7 +478,7 @@ pub(crate) fn build_vk<'params, C: CurveAffine, P: Params<'params, C>>(
// Compute commitment to permutation polynomial
commitments.push(
params
.commit_lagrange(permutation, Blind::default())
.commit_lagrange(&H2cEngine::new(), permutation, Blind::default())
.to_affine(),
);
}
Expand Down
8 changes: 6 additions & 2 deletions halo2_backend/src/plonk/permutation/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use group::{
ff::{BatchInvert, Field},
Curve,
};
use halo2_middleware::ff::PrimeField;
use halo2_middleware::zal::traits::MsmAccel;
use halo2_middleware::{ff::PrimeField, zal::impls::PlonkEngine};
use rand_core::RngCore;
use std::iter::{self, ExactSizeIterator};

Expand Down Expand Up @@ -54,7 +55,9 @@ pub(in crate::plonk) fn permutation_commit<
E: EncodedChallenge<C>,
R: RngCore,
T: TranscriptWrite<C, E>,
M: MsmAccel<C>,
>(
engine: &PlonkEngine<C, M>,
arg: &Argument,
params: &P,
pk: &plonk::ProvingKey<C>,
Expand Down Expand Up @@ -172,7 +175,8 @@ pub(in crate::plonk) fn permutation_commit<

let blind = Blind(C::Scalar::random(&mut rng));

let permutation_product_commitment_projective = params.commit_lagrange(&z, blind);
let permutation_product_commitment_projective =
params.commit_lagrange(&engine.msm_backend, &z, blind);
let permutation_product_blind = blind;
let z = domain.lagrange_to_coeff(z);
let permutation_product_poly = z.clone();
Expand Down
Loading

0 comments on commit 8278d44

Please sign in to comment.