From 3c2c6528bdcf5725964537adcac2c43bc9733372 Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Mon, 19 Feb 2024 16:07:03 +0100 Subject: [PATCH] move ZAL to middleware --- halo2_backend/Cargo.toml | 2 +- halo2_backend/src/plonk/keygen.rs | 2 +- halo2_backend/src/plonk/lookup/prover.rs | 2 +- halo2_backend/src/plonk/permutation/keygen.rs | 2 +- halo2_backend/src/plonk/permutation/prover.rs | 2 +- halo2_backend/src/plonk/prover.rs | 2 +- halo2_backend/src/plonk/shuffle/prover.rs | 2 +- halo2_backend/src/plonk/vanishing/prover.rs | 2 +- halo2_backend/src/plonk/verifier.rs | 2 +- halo2_backend/src/plonk/verifier/batch.rs | 3 +- halo2_backend/src/poly/commitment.rs | 6 +- halo2_backend/src/poly/ipa/commitment.rs | 4 +- .../src/poly/ipa/commitment/prover.rs | 2 +- halo2_backend/src/poly/ipa/msm.rs | 4 +- .../src/poly/ipa/multiopen/prover.rs | 2 +- halo2_backend/src/poly/ipa/strategy.rs | 2 +- halo2_backend/src/poly/kzg/commitment.rs | 4 +- halo2_backend/src/poly/kzg/msm.rs | 2 +- .../src/poly/kzg/multiopen/gwc/prover.rs | 3 +- .../src/poly/kzg/multiopen/shplonk/prover.rs | 2 +- halo2_backend/src/poly/kzg/strategy.rs | 2 +- halo2_backend/src/poly/multiopen_test.rs | 2 +- halo2_common/Cargo.toml | 2 +- halo2_frontend/Cargo.toml | 2 +- halo2_middleware/Cargo.toml | 4 +- halo2_middleware/src/lib.rs | 1 + halo2_middleware/src/zal.rs | 233 ++++++++++++++++++ halo2_proofs/Cargo.toml | 2 +- halo2_proofs/benches/arithmetic.rs | 2 +- halo2_proofs/src/plonk/prover.rs | 2 +- halo2_proofs/tests/frontend_backend_split.rs | 2 +- halo2_proofs/tests/plonk_api.rs | 2 +- 32 files changed, 272 insertions(+), 36 deletions(-) create mode 100644 halo2_middleware/src/zal.rs diff --git a/halo2_backend/Cargo.toml b/halo2_backend/Cargo.toml index 5f0476d1ee..807436a131 100644 --- a/halo2_backend/Cargo.toml +++ b/halo2_backend/Cargo.toml @@ -28,7 +28,7 @@ rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"] backtrace = { version = "0.3", optional = true } ff = "0.13" group = "0.13" -halo2curves = { git = 'https://github.com/taikoxyz/halo2curves', branch = "pr-pse-exec-engine", default-features = false } +halo2curves = { version = "0.6.0", default-features = false } rand_core = { version = "0.6", default-features = false } tracing = "0.1" blake2b_simd = "1" # MSRV 1.66.0 diff --git a/halo2_backend/src/plonk/keygen.rs b/halo2_backend/src/plonk/keygen.rs index 864cbde91a..df1710eb7f 100644 --- a/halo2_backend/src/plonk/keygen.rs +++ b/halo2_backend/src/plonk/keygen.rs @@ -2,7 +2,7 @@ use group::Curve; use halo2_middleware::ff::{Field, FromUniformBytes}; -use halo2curves::zal::H2cEngine; +use halo2_middleware::zal::impls::H2cEngine; use super::{evaluation::Evaluator, permutation, Polynomial, ProvingKey, VerifyingKey}; use crate::{ diff --git a/halo2_backend/src/plonk/lookup/prover.rs b/halo2_backend/src/plonk/lookup/prover.rs index d9fc5a423d..4e6298b342 100644 --- a/halo2_backend/src/plonk/lookup/prover.rs +++ b/halo2_backend/src/plonk/lookup/prover.rs @@ -18,7 +18,7 @@ use halo2_common::plonk::{ }; use halo2_middleware::ff::WithSmallOrderMulGroup; use halo2_middleware::poly::Rotation; -use halo2curves::zal::MsmAccel; +use halo2_middleware::zal::traits::MsmAccel; use rand_core::RngCore; use std::{ collections::BTreeMap, diff --git a/halo2_backend/src/plonk/permutation/keygen.rs b/halo2_backend/src/plonk/permutation/keygen.rs index 00b6e38427..7c295ef445 100644 --- a/halo2_backend/src/plonk/permutation/keygen.rs +++ b/halo2_backend/src/plonk/permutation/keygen.rs @@ -1,6 +1,6 @@ use group::Curve; use halo2_middleware::ff::{Field, PrimeField}; -use halo2curves::zal::H2cEngine; +use halo2_middleware::zal::impls::H2cEngine; use super::{Argument, ProvingKey, VerifyingKey}; use crate::{ diff --git a/halo2_backend/src/plonk/permutation/prover.rs b/halo2_backend/src/plonk/permutation/prover.rs index e95af4be51..79155aae7c 100644 --- a/halo2_backend/src/plonk/permutation/prover.rs +++ b/halo2_backend/src/plonk/permutation/prover.rs @@ -3,7 +3,7 @@ use group::{ Curve, }; use halo2_middleware::ff::PrimeField; -use halo2curves::zal::MsmAccel; +use halo2_middleware::zal::traits::MsmAccel; use rand_core::RngCore; use std::iter::{self, ExactSizeIterator}; diff --git a/halo2_backend/src/plonk/prover.rs b/halo2_backend/src/plonk/prover.rs index d3bb0da765..ffa8fafb76 100644 --- a/halo2_backend/src/plonk/prover.rs +++ b/halo2_backend/src/plonk/prover.rs @@ -1,6 +1,6 @@ use group::Curve; use halo2_middleware::ff::{Field, FromUniformBytes, WithSmallOrderMulGroup}; -use halo2curves::zal::{H2cEngine, MsmAccel}; +use halo2_middleware::zal::{impls::H2cEngine, traits::MsmAccel}; use rand_core::RngCore; use std::collections::{BTreeSet, HashSet}; use std::{collections::HashMap, iter}; diff --git a/halo2_backend/src/plonk/shuffle/prover.rs b/halo2_backend/src/plonk/shuffle/prover.rs index bba149a96d..8ae4527ddb 100644 --- a/halo2_backend/src/plonk/shuffle/prover.rs +++ b/halo2_backend/src/plonk/shuffle/prover.rs @@ -13,7 +13,7 @@ use group::{ff::BatchInvert, Curve}; use halo2_common::plonk::{ChallengeGamma, ChallengeTheta, ChallengeX, Error, Expression}; use halo2_middleware::ff::WithSmallOrderMulGroup; use halo2_middleware::poly::Rotation; -use halo2curves::zal::MsmAccel; +use halo2_middleware::zal::traits::MsmAccel; use rand_core::RngCore; use std::{ iter, diff --git a/halo2_backend/src/plonk/vanishing/prover.rs b/halo2_backend/src/plonk/vanishing/prover.rs index f035637a6f..5de5626553 100644 --- a/halo2_backend/src/plonk/vanishing/prover.rs +++ b/halo2_backend/src/plonk/vanishing/prover.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, iter}; use group::Curve; use halo2_common::plonk::{ChallengeX, Error}; use halo2_middleware::ff::Field; -use halo2curves::zal::MsmAccel; +use halo2_middleware::zal::traits::MsmAccel; use rand_chacha::ChaCha20Rng; use rand_core::{RngCore, SeedableRng}; diff --git a/halo2_backend/src/plonk/verifier.rs b/halo2_backend/src/plonk/verifier.rs index cd93974d4e..cd10535571 100644 --- a/halo2_backend/src/plonk/verifier.rs +++ b/halo2_backend/src/plonk/verifier.rs @@ -3,7 +3,7 @@ use halo2_common::plonk::{ ChallengeBeta, ChallengeGamma, ChallengeTheta, ChallengeX, ChallengeY, Error, }; use halo2_middleware::ff::{Field, FromUniformBytes, WithSmallOrderMulGroup}; -use halo2curves::zal::H2cEngine; +use halo2_middleware::zal::impls::H2cEngine; use std::iter; use super::{vanishing, VerifyingKey}; diff --git a/halo2_backend/src/plonk/verifier/batch.rs b/halo2_backend/src/plonk/verifier/batch.rs index a5b2d91ece..f7d5ba6341 100644 --- a/halo2_backend/src/plonk/verifier/batch.rs +++ b/halo2_backend/src/plonk/verifier/batch.rs @@ -1,7 +1,8 @@ use group::ff::Field; use halo2_common::plonk::Error; use halo2_middleware::ff::FromUniformBytes; -use halo2curves::{zal::H2cEngine, CurveAffine}; +use halo2_middleware::zal::impls::H2cEngine; +use halo2curves::CurveAffine; use rand_core::OsRng; use super::{verify_proof, VerificationStrategy}; diff --git a/halo2_backend/src/poly/commitment.rs b/halo2_backend/src/poly/commitment.rs index 22345d1ee7..dc2d426f67 100644 --- a/halo2_backend/src/poly/commitment.rs +++ b/halo2_backend/src/poly/commitment.rs @@ -6,10 +6,8 @@ use super::{ use crate::poly::Error; use crate::transcript::{EncodedChallenge, TranscriptRead, TranscriptWrite}; use halo2_middleware::ff::Field; -use halo2curves::{ - zal::{H2cEngine, MsmAccel}, - CurveAffine, -}; +use halo2_middleware::zal::{impls::H2cEngine, traits::MsmAccel}; +use halo2curves::CurveAffine; use rand_core::RngCore; use std::{ fmt::Debug, diff --git a/halo2_backend/src/poly/ipa/commitment.rs b/halo2_backend/src/poly/ipa/commitment.rs index 8608d80336..47ce3dc36a 100644 --- a/halo2_backend/src/poly/ipa/commitment.rs +++ b/halo2_backend/src/poly/ipa/commitment.rs @@ -10,7 +10,7 @@ use crate::poly::ipa::msm::MSMIPA; use crate::poly::{Coeff, LagrangeCoeff, Polynomial}; use group::{Curve, Group}; -use halo2curves::zal::MsmAccel; +use halo2_middleware::zal::traits::MsmAccel; use std::marker::PhantomData; mod prover; @@ -243,7 +243,7 @@ mod test { use group::Curve; use halo2_middleware::ff::Field; - use halo2curves::zal::H2cEngine; + use halo2_middleware::zal::impls::H2cEngine; #[test] fn test_commit_lagrange_epaffine() { diff --git a/halo2_backend/src/poly/ipa/commitment/prover.rs b/halo2_backend/src/poly/ipa/commitment/prover.rs index 797928fbd8..d16dc5574a 100644 --- a/halo2_backend/src/poly/ipa/commitment/prover.rs +++ b/halo2_backend/src/poly/ipa/commitment/prover.rs @@ -1,5 +1,5 @@ use halo2_middleware::ff::Field; -use halo2curves::zal::MsmAccel; +use halo2_middleware::zal::traits::MsmAccel; use rand_core::RngCore; use super::ParamsIPA; diff --git a/halo2_backend/src/poly/ipa/msm.rs b/halo2_backend/src/poly/ipa/msm.rs index bf5d01ea4d..87eb94cbda 100644 --- a/halo2_backend/src/poly/ipa/msm.rs +++ b/halo2_backend/src/poly/ipa/msm.rs @@ -2,7 +2,7 @@ use crate::arithmetic::CurveAffine; use crate::poly::{commitment::MSM, ipa::commitment::ParamsVerifierIPA}; use group::Group; use halo2_middleware::ff::Field; -use halo2curves::zal::MsmAccel; +use halo2_middleware::zal::traits::MsmAccel; use std::collections::BTreeMap; /// A multiscalar multiplication in the polynomial commitment scheme @@ -222,9 +222,9 @@ mod tests { commitment::{ParamsProver, MSM}, ipa::{commitment::ParamsIPA, msm::MSMIPA}, }; + use halo2_middleware::zal::impls::H2cEngine; use halo2curves::{ pasta::{Ep, EpAffine, Fp, Fq}, - zal::H2cEngine, CurveAffine, }; diff --git a/halo2_backend/src/poly/ipa/multiopen/prover.rs b/halo2_backend/src/poly/ipa/multiopen/prover.rs index ee8f09f1d0..c6802c23ab 100644 --- a/halo2_backend/src/poly/ipa/multiopen/prover.rs +++ b/halo2_backend/src/poly/ipa/multiopen/prover.rs @@ -9,7 +9,7 @@ use crate::transcript::{EncodedChallenge, TranscriptWrite}; use group::Curve; use halo2_middleware::ff::Field; -use halo2curves::zal::MsmAccel; +use halo2_middleware::zal::traits::MsmAccel; use rand_core::RngCore; use std::io; use std::marker::PhantomData; diff --git a/halo2_backend/src/poly/ipa/strategy.rs b/halo2_backend/src/poly/ipa/strategy.rs index 93380e455c..eb0e62c380 100644 --- a/halo2_backend/src/poly/ipa/strategy.rs +++ b/halo2_backend/src/poly/ipa/strategy.rs @@ -10,7 +10,7 @@ use crate::{ }; use group::Curve; use halo2_middleware::ff::Field; -use halo2curves::zal::{H2cEngine, MsmAccel}; +use halo2_middleware::zal::{impls::H2cEngine, traits::MsmAccel}; use halo2curves::CurveAffine; use rand_core::OsRng; diff --git a/halo2_backend/src/poly/kzg/commitment.rs b/halo2_backend/src/poly/kzg/commitment.rs index 13ef1bf59f..4aec6521e1 100644 --- a/halo2_backend/src/poly/kzg/commitment.rs +++ b/halo2_backend/src/poly/kzg/commitment.rs @@ -6,8 +6,8 @@ use crate::SerdeFormat; use group::{prime::PrimeCurveAffine, Curve, Group}; use halo2_middleware::ff::{Field, PrimeField}; +use halo2_middleware::zal::traits::MsmAccel; use halo2curves::pairing::Engine; -use halo2curves::zal::MsmAccel; use rand_core::{OsRng, RngCore}; use std::fmt::Debug; use std::marker::PhantomData; @@ -376,7 +376,7 @@ mod test { use crate::poly::commitment::{Blind, Params}; use crate::poly::kzg::commitment::ParamsKZG; use halo2_middleware::ff::Field; - use halo2curves::zal::H2cEngine; + use halo2_middleware::zal::impls::H2cEngine; #[test] fn test_commit_lagrange() { diff --git a/halo2_backend/src/poly/kzg/msm.rs b/halo2_backend/src/poly/kzg/msm.rs index f4f6feb7c0..ad77a82c32 100644 --- a/halo2_backend/src/poly/kzg/msm.rs +++ b/halo2_backend/src/poly/kzg/msm.rs @@ -3,9 +3,9 @@ use std::fmt::Debug; use super::commitment::ParamsKZG; use crate::{arithmetic::parallelize, poly::commitment::MSM}; use group::{Curve, Group}; +use halo2_middleware::zal::traits::MsmAccel; use halo2curves::{ pairing::{Engine, MillerLoopResult, MultiMillerLoop}, - zal::MsmAccel, CurveAffine, CurveExt, }; diff --git a/halo2_backend/src/poly/kzg/multiopen/gwc/prover.rs b/halo2_backend/src/poly/kzg/multiopen/gwc/prover.rs index 183017f598..4b9cda2470 100644 --- a/halo2_backend/src/poly/kzg/multiopen/gwc/prover.rs +++ b/halo2_backend/src/poly/kzg/multiopen/gwc/prover.rs @@ -9,8 +9,9 @@ use crate::poly::{commitment::Blind, Polynomial}; use crate::transcript::{EncodedChallenge, TranscriptWrite}; use group::Curve; +use halo2_middleware::zal::traits::MsmAccel; +use halo2curves::pairing::Engine; use halo2curves::CurveExt; -use halo2curves::{pairing::Engine, zal::MsmAccel}; use rand_core::RngCore; use std::fmt::Debug; use std::io; diff --git a/halo2_backend/src/poly/kzg/multiopen/shplonk/prover.rs b/halo2_backend/src/poly/kzg/multiopen/shplonk/prover.rs index 07dc20264e..194215e6da 100644 --- a/halo2_backend/src/poly/kzg/multiopen/shplonk/prover.rs +++ b/halo2_backend/src/poly/kzg/multiopen/shplonk/prover.rs @@ -15,8 +15,8 @@ use crate::transcript::{EncodedChallenge, TranscriptWrite}; use crate::multicore::{IntoParallelIterator, ParallelIterator}; use group::Curve; use halo2_middleware::ff::Field; +use halo2_middleware::zal::traits::MsmAccel; use halo2curves::pairing::Engine; -use halo2curves::zal::MsmAccel; use halo2curves::CurveExt; use rand_core::RngCore; use std::fmt::Debug; diff --git a/halo2_backend/src/poly/kzg/strategy.rs b/halo2_backend/src/poly/kzg/strategy.rs index e3ea556c87..c9dfc4972a 100644 --- a/halo2_backend/src/poly/kzg/strategy.rs +++ b/halo2_backend/src/poly/kzg/strategy.rs @@ -11,9 +11,9 @@ use crate::{ }, }; use halo2_middleware::ff::Field; +use halo2_middleware::zal::impls::H2cEngine; use halo2curves::{ pairing::{Engine, MultiMillerLoop}, - zal::H2cEngine, CurveAffine, CurveExt, }; use rand_core::OsRng; diff --git a/halo2_backend/src/poly/multiopen_test.rs b/halo2_backend/src/poly/multiopen_test.rs index d6db2a1251..c37b809b7a 100644 --- a/halo2_backend/src/poly/multiopen_test.rs +++ b/halo2_backend/src/poly/multiopen_test.rs @@ -16,7 +16,7 @@ mod test { }; use group::Curve; use halo2_middleware::ff::WithSmallOrderMulGroup; - use halo2curves::zal::{H2cEngine, MsmAccel}; + use halo2_middleware::zal::{impls::H2cEngine, traits::MsmAccel}; use rand_core::OsRng; #[test] diff --git a/halo2_common/Cargo.toml b/halo2_common/Cargo.toml index cf5e26a698..4f60dbfa11 100644 --- a/halo2_common/Cargo.toml +++ b/halo2_common/Cargo.toml @@ -27,7 +27,7 @@ rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"] [dependencies] backtrace = { version = "0.3", optional = true } group = "0.13" -halo2curves = { git = 'https://github.com/taikoxyz/halo2curves', branch = "pr-pse-exec-engine", default-features = false } +halo2curves = { version = "0.6.0", default-features = false } rand_core = { version = "0.6", default-features = false } blake2b_simd = "1" # MSRV 1.66.0 sha3 = "0.9.1" diff --git a/halo2_frontend/Cargo.toml b/halo2_frontend/Cargo.toml index a8496aedc1..6a35d66652 100644 --- a/halo2_frontend/Cargo.toml +++ b/halo2_frontend/Cargo.toml @@ -28,7 +28,7 @@ rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"] backtrace = { version = "0.3", optional = true } ff = "0.13" group = "0.13" -halo2curves = { git = 'https://github.com/taikoxyz/halo2curves', branch = "pr-pse-exec-engine", default-features = false } +halo2curves = { version = "0.6.0", default-features = false } tracing = "0.1" blake2b_simd = "1" # MSRV 1.66.0 serde = { version = "1", optional = true, features = ["derive"] } diff --git a/halo2_middleware/Cargo.toml b/halo2_middleware/Cargo.toml index eee59a3458..604897b886 100644 --- a/halo2_middleware/Cargo.toml +++ b/halo2_middleware/Cargo.toml @@ -26,14 +26,16 @@ rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"] [dependencies] ff = "0.13" +halo2curves = { version = "0.6.0", default-features = false } serde = { version = "1", optional = true, features = ["derive"] } serde_derive = { version = "1", optional = true} rayon = "1.8" [dev-dependencies] +ark-std = { version = "0.3" } proptest = "1" group = "0.13" -halo2curves = { git = 'https://github.com/taikoxyz/halo2curves', branch = "pr-pse-exec-engine", default-features = false } +rand_core = { version = "0.6", default-features = false } [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies] getrandom = { version = "0.2", features = ["js"] } diff --git a/halo2_middleware/src/lib.rs b/halo2_middleware/src/lib.rs index db9734d819..70ff0c4af9 100644 --- a/halo2_middleware/src/lib.rs +++ b/halo2_middleware/src/lib.rs @@ -4,5 +4,6 @@ pub mod metadata; pub mod permutation; pub mod poly; pub mod shuffle; +pub mod zal; pub use ff; diff --git a/halo2_middleware/src/zal.rs b/halo2_middleware/src/zal.rs new file mode 100644 index 0000000000..475f22524c --- /dev/null +++ b/halo2_middleware/src/zal.rs @@ -0,0 +1,233 @@ +//! This module provides "ZK Acceleration Layer" traits +//! to abstract away the execution engine for performance-critical primitives. +//! +//! Terminology +//! ----------- +//! +//! We use the name Backend+Engine for concrete implementations of ZalEngine. +//! For example H2cEngine for pure Halo2curves implementation. +//! +//! Alternative names considered were Executor or Driver however +//! - executor is already used in Rust (and the name is long) +//! - driver will be confusing as we work quite low-level with GPUs and FPGAs. +//! +//! Unfortunately the "Engine" name is used in bn256 for pairings. +//! Fortunately a ZalEngine is only used in the prover (at least for now) +//! while "pairing engine" is only used in the verifier +//! +//! Initialization design space +//! --------------------------- +//! +//! It is recommended that ZAL backends provide: +//! - an initialization function: +//! - either "fn new() -> ZalEngine" for simple libraries +//! - or a builder pattern for complex initializations +//! - a shutdown function or document when it is not needed (when it's a global threadpool like Rayon for example). +//! +//! Backends might want to add as an option: +//! - The number of threads (CPU) +//! - The device(s) to run on (multi-sockets machines, multi-GPUs machines, ...) +//! - The curve (JIT-compiled backend) +//! +//! Descriptors +//! --------------------------- +//! +//! Descriptors enable providers to configure opaque details on data +//! when doing repeated computations with the same input(s). +//! For example: +//! - Pointer(s) caching to limit data movement between CPU and GPU, FPGAs +//! - Length of data +//! - data in layout: +//! - canonical or Montgomery fields, unsaturated representation, endianness +//! - jacobian or projective coordinates or maybe even Twisted Edwards for faster elliptic curve additions, +//! - FFT: canonical or bit-reversed permuted +//! - data out layout +//! - Device(s) ID +//! +//! They are required to be Plain Old Data (Copy trait), so no custom `Drop` is required. +//! If a specific resource is needed, it can be stored in the engine in a hashmap for example +//! and an integer ID or a pointer can be opaquely given as a descriptor. + +// The ZK Accel Layer API +// --------------------------------------------------- +pub mod traits { + use halo2curves::CurveAffine; + + pub trait MsmAccel { + fn msm(&self, coeffs: &[C::Scalar], base: &[C]) -> C::Curve; + + // Caching API + // ------------------------------------------------- + // From here we propose an extended API + // that allows reusing coeffs and/or the base points + // + // This is inspired by CuDNN API (Nvidia GPU) + // and oneDNN API (CPU, OpenCL) https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnn-ops-infer-so-opaque + // usage of descriptors + // + // https://github.com/oneapi-src/oneDNN/blob/master/doc/programming_model/basic_concepts.md + // + // Descriptors are opaque pointers that hold the input in a format suitable for the accelerator engine. + // They may be: + // - Input moved on accelerator device (only once for repeated calls) + // - Endianess conversion + // - Converting from Montgomery to Canonical form + // - Input changed from Projective to Jacobian coordinates or even to a Twisted Edwards curve. + // - other form of expensive preprocessing + type CoeffsDescriptor<'c>: Copy; + type BaseDescriptor<'b>: Copy; + + fn get_coeffs_descriptor<'c>(&self, coeffs: &'c [C::Scalar]) -> Self::CoeffsDescriptor<'c>; + fn get_base_descriptor<'b>(&self, base: &'b [C]) -> Self::BaseDescriptor<'b>; + + fn msm_with_cached_scalars( + &self, + coeffs: &Self::CoeffsDescriptor<'_>, + base: &[C], + ) -> C::Curve; + + fn msm_with_cached_base( + &self, + coeffs: &[C::Scalar], + base: &Self::BaseDescriptor<'_>, + ) -> C::Curve; + + fn msm_with_cached_inputs( + &self, + coeffs: &Self::CoeffsDescriptor<'_>, + base: &Self::BaseDescriptor<'_>, + ) -> C::Curve; + // Execute MSM according to descriptors + // Unsure of naming, msm_with_cached_inputs, msm_apply, msm_cached, msm_with_descriptors, ... + } +} + +// ZAL using Halo2curves as a backend +// --------------------------------------------------- + +pub mod impls { + use crate::zal::traits::MsmAccel; + use halo2curves::msm::best_multiexp; + use halo2curves::CurveAffine; + pub struct H2cEngine; + + #[derive(Clone, Copy)] + pub struct H2cMsmCoeffsDesc<'c, C: CurveAffine> { + raw: &'c [C::Scalar], + } + + #[derive(Clone, Copy)] + pub struct H2cMsmBaseDesc<'b, C: CurveAffine> { + raw: &'b [C], + } + + impl H2cEngine { + pub fn new() -> Self { + Self {} + } + } + + impl MsmAccel for H2cEngine { + fn msm(&self, coeffs: &[C::Scalar], bases: &[C]) -> C::Curve { + best_multiexp(coeffs, bases) + } + + // Caching API + // ------------------------------------------------- + + type CoeffsDescriptor<'c> = H2cMsmCoeffsDesc<'c, C>; + type BaseDescriptor<'b> = H2cMsmBaseDesc<'b, C>; + + fn get_coeffs_descriptor<'c>(&self, coeffs: &'c [C::Scalar]) -> Self::CoeffsDescriptor<'c> { + // Do expensive device/library specific preprocessing here + Self::CoeffsDescriptor { raw: coeffs } + } + fn get_base_descriptor<'b>(&self, base: &'b [C]) -> Self::BaseDescriptor<'b> { + Self::BaseDescriptor { raw: base } + } + + fn msm_with_cached_scalars( + &self, + coeffs: &Self::CoeffsDescriptor<'_>, + base: &[C], + ) -> C::Curve { + best_multiexp(coeffs.raw, base) + } + + fn msm_with_cached_base( + &self, + coeffs: &[C::Scalar], + base: &Self::BaseDescriptor<'_>, + ) -> C::Curve { + best_multiexp(coeffs, base.raw) + } + + fn msm_with_cached_inputs( + &self, + coeffs: &Self::CoeffsDescriptor<'_>, + base: &Self::BaseDescriptor<'_>, + ) -> C::Curve { + best_multiexp(coeffs.raw, base.raw) + } + } +} + +// Testing +// --------------------------------------------------- + +#[cfg(test)] +mod test { + use crate::zal::impls::H2cEngine; + use crate::zal::traits::MsmAccel; + use halo2curves::bn256::G1Affine; + use halo2curves::msm::best_multiexp; + use halo2curves::CurveAffine; + + use ark_std::{end_timer, start_timer}; + use ff::Field; + use group::{Curve, Group}; + use rand_core::OsRng; + + fn run_msm_zal(min_k: usize, max_k: usize) { + let points = (0..1 << max_k) + .map(|_| C::Curve::random(OsRng)) + .collect::>(); + let mut affine_points = vec![C::identity(); 1 << max_k]; + C::Curve::batch_normalize(&points[..], &mut affine_points[..]); + let points = affine_points; + + let scalars = (0..1 << max_k) + .map(|_| C::Scalar::random(OsRng)) + .collect::>(); + + for k in min_k..=max_k { + let points = &points[..1 << k]; + let scalars = &scalars[..1 << k]; + + let t0 = start_timer!(|| format!("freestanding msm k={}", k)); + let e0 = best_multiexp(scalars, points); + end_timer!(t0); + + let engine = H2cEngine::new(); + let t1 = start_timer!(|| format!("H2cEngine msm k={}", k)); + let e1 = engine.msm(scalars, points); + end_timer!(t1); + + assert_eq!(e0, e1); + + // Caching API + // ----------- + let t2 = start_timer!(|| format!("H2cEngine msm cached base k={}", k)); + let base_descriptor = engine.get_base_descriptor(points); + let e2 = engine.msm_with_cached_base(scalars, &base_descriptor); + end_timer!(t2); + + assert_eq!(e0, e2) + } + } + + #[test] + fn test_msm_zal() { + run_msm_zal::(3, 14); + } +} diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index e6c5ae6170..8bf059790b 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -53,7 +53,7 @@ halo2_middleware = { path = "../halo2_middleware" } halo2_common = { path = "../halo2_common" } halo2_backend = { path = "../halo2_backend" } halo2_frontend = { path = "../halo2_frontend" } -halo2curves = { git = 'https://github.com/taikoxyz/halo2curves', branch = "pr-pse-exec-engine", default-features = false } +halo2curves = { version = "0.6.0", default-features = false } rand_core = { version = "0.6", default-features = false, features = ["getrandom"] } plotters = { version = "0.3.0", default-features = false, optional = true } diff --git a/halo2_proofs/benches/arithmetic.rs b/halo2_proofs/benches/arithmetic.rs index 659caa10da..51f769c9e9 100644 --- a/halo2_proofs/benches/arithmetic.rs +++ b/halo2_proofs/benches/arithmetic.rs @@ -2,9 +2,9 @@ extern crate criterion; use group::ff::Field; +use halo2_middleware::zal::{impls::H2cEngine, traits::MsmAccel}; use halo2_proofs::*; use halo2curves::pasta::{EqAffine, Fp}; -use halo2curves::zal::{H2cEngine, MsmAccel}; use halo2_proofs::poly::{commitment::ParamsProver, ipa::commitment::ParamsIPA}; diff --git a/halo2_proofs/src/plonk/prover.rs b/halo2_proofs/src/plonk/prover.rs index 9c831c2694..3cd90a45c8 100644 --- a/halo2_proofs/src/plonk/prover.rs +++ b/halo2_proofs/src/plonk/prover.rs @@ -4,7 +4,7 @@ use halo2_common::plonk::{circuit::Circuit, Error}; use halo2_common::transcript::{EncodedChallenge, TranscriptWrite}; use halo2_frontend::circuit::{compile_circuit, WitnessCalculator}; use halo2_middleware::ff::{FromUniformBytes, WithSmallOrderMulGroup}; -use halo2curves::zal::{H2cEngine, MsmAccel}; +use halo2_middleware::zal::{impls::H2cEngine, traits::MsmAccel}; use rand_core::RngCore; use std::collections::HashMap; diff --git a/halo2_proofs/tests/frontend_backend_split.rs b/halo2_proofs/tests/frontend_backend_split.rs index 3171816bf5..5fd70f3895 100644 --- a/halo2_proofs/tests/frontend_backend_split.rs +++ b/halo2_proofs/tests/frontend_backend_split.rs @@ -548,7 +548,7 @@ fn test_mycircuit_full_legacy() { #[test] fn test_mycircuit_full_split() { - use halo2curves::zal::H2cEngine; + use halo2_middleware::zal::impls::H2cEngine; #[cfg(feature = "heap-profiling")] let _profiler = dhat::Profiler::new_heap(); diff --git a/halo2_proofs/tests/plonk_api.rs b/halo2_proofs/tests/plonk_api.rs index 858a9736ed..501d3e89d0 100644 --- a/halo2_proofs/tests/plonk_api.rs +++ b/halo2_proofs/tests/plonk_api.rs @@ -3,6 +3,7 @@ use assert_matches::assert_matches; use ff::{FromUniformBytes, WithSmallOrderMulGroup}; +use halo2_middleware::zal::{impls::H2cEngine, traits::MsmAccel}; use halo2_proofs::arithmetic::Field; use halo2_proofs::circuit::{Cell, Layouter, SimpleFloorPlanner, Value}; use halo2_proofs::dev::MockProver; @@ -18,7 +19,6 @@ use halo2_proofs::transcript::{ Blake2bRead, Blake2bWrite, Challenge255, EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer, }; -use halo2curves::zal::{H2cEngine, MsmAccel}; use rand_core::{OsRng, RngCore}; use std::marker::PhantomData;