From afc466f49b0a334fa9ca32fe5cd363a06630bc56 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 26 Jul 2023 18:06:45 -0600 Subject: [PATCH 1/5] feat: derive Serialize, Deserialize for ProvingKey, VerifyingKey --- halo2_proofs/Cargo.toml | 1 + halo2_proofs/src/dev.rs | 28 +++++++-------- halo2_proofs/src/dev/failure.rs | 4 +-- halo2_proofs/src/dev/gates.rs | 8 ++--- halo2_proofs/src/dev/metadata.rs | 27 ++++++++++----- halo2_proofs/src/plonk.rs | 13 +++++-- halo2_proofs/src/plonk/circuit.rs | 49 ++++++++++++++------------- halo2_proofs/src/plonk/evaluation.rs | 19 ++++++++--- halo2_proofs/src/plonk/lookup.rs | 7 ++-- halo2_proofs/src/plonk/permutation.rs | 11 ++++-- halo2_proofs/src/poly.rs | 5 +-- halo2_proofs/src/poly/domain.rs | 7 ++-- primitives/poseidon/src/poseidon.rs | 4 +++ primitives/poseidon/src/spec.rs | 7 ++-- 14 files changed, 116 insertions(+), 74 deletions(-) diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index c3b4219ccf..47021b265f 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -61,6 +61,7 @@ blake2b_simd = "1" rustc-hash = "1.1.0" sha3 = "0.9.1" ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } +serde = { version = "1.0", default-features = false, features = ["derive"] } # Developer tooling dependencies plotters = { version = "0.3.0", optional = true } diff --git a/halo2_proofs/src/dev.rs b/halo2_proofs/src/dev.rs index f9eb16af17..17b6732b41 100644 --- a/halo2_proofs/src/dev.rs +++ b/halo2_proofs/src/dev.rs @@ -834,9 +834,9 @@ impl MockProver { Value::Real(x) if x.is_zero_vartime() => None, Value::Real(_) => Some(VerifyFailure::ConstraintNotSatisfied { constraint: ( - (gate_index, gate.name()).into(), + (gate_index, gate.name().to_string()).into(), poly_index, - gate.constraint_name(poly_index), + gate.constraint_name(poly_index).to_string(), ) .into(), location: FailureLocation::find_expressions( @@ -860,9 +860,9 @@ impl MockProver { }), Value::Poison => Some(VerifyFailure::ConstraintPoisoned { constraint: ( - (gate_index, gate.name()).into(), + (gate_index, gate.name().to_string()).into(), poly_index, - gate.constraint_name(poly_index), + gate.constraint_name(poly_index).to_string(), ) .into(), }), @@ -995,7 +995,7 @@ impl MockProver { assert!(table.binary_search(input).is_err()); Some(VerifyFailure::Lookup { - name: lookup.name, + name: lookup.name.clone(), lookup_index, location: FailureLocation::find_expressions( &self.cs, @@ -1146,7 +1146,7 @@ impl MockProver { None } else { Some(VerifyFailure::CellNotAssigned { - gate: (gate_index, gate.name()).into(), + gate: (gate_index, gate.name().to_string()).into(), region: ( r_i, r.name.clone(), @@ -1225,9 +1225,9 @@ impl MockProver { Value::Real(x) if x.is_zero_vartime() => None, Value::Real(_) => Some(VerifyFailure::ConstraintNotSatisfied { constraint: ( - (gate_index, gate.name()).into(), + (gate_index, gate.name().to_string()).into(), poly_index, - gate.constraint_name(poly_index), + gate.constraint_name(poly_index).to_string(), ) .into(), location: FailureLocation::find_expressions( @@ -1251,9 +1251,9 @@ impl MockProver { }), Value::Poison => Some(VerifyFailure::ConstraintPoisoned { constraint: ( - (gate_index, gate.name()).into(), + (gate_index, gate.name().to_string()).into(), poly_index, - gate.constraint_name(poly_index), + gate.constraint_name(poly_index).to_string(), ) .into(), }), @@ -1374,7 +1374,7 @@ impl MockProver { .filter_map(move |(input, input_row)| { if table.binary_search(input).is_err() { Some(VerifyFailure::Lookup { - name: lookup.name, + name: lookup.name.clone(), lookup_index, location: FailureLocation::find_expressions( &self.cs, @@ -1781,7 +1781,7 @@ mod tests { assert_eq!( prover.verify(), Err(vec![VerifyFailure::Lookup { - name: "lookup", + name: "lookup".to_owned(), lookup_index: 0, location: FailureLocation::InRegion { region: (1, "Faulty synthesis").into(), @@ -1913,7 +1913,7 @@ mod tests { assert_eq!( prover.verify(), Err(vec![VerifyFailure::Lookup { - name: "lookup", + name: "lookup".to_owned(), lookup_index: 0, location: FailureLocation::InRegion { region: (2, "Faulty synthesis").into(), @@ -2030,7 +2030,7 @@ mod tests { assert_eq!( prover.verify(), Err(vec![VerifyFailure::ConstraintNotSatisfied { - constraint: ((0, "Equality check").into(), 0, "").into(), + constraint: ((0, "Equality check".to_owned()).into(), 0, "".to_owned()).into(), location: FailureLocation::InRegion { region: (1, "Wrong synthesis").into(), offset: 0, diff --git a/halo2_proofs/src/dev/failure.rs b/halo2_proofs/src/dev/failure.rs index 3d08a5812d..49e608f37a 100644 --- a/halo2_proofs/src/dev/failure.rs +++ b/halo2_proofs/src/dev/failure.rs @@ -162,7 +162,7 @@ pub enum VerifyFailure { /// A lookup input did not exist in its corresponding table. Lookup { /// The name of the lookup that is not satisfied. - name: &'static str, + name: String, /// The index of the lookup that is not satisfied. These indices are assigned in /// the order in which `ConstraintSystem::lookup` is called during /// `Circuit::configure`. @@ -280,7 +280,7 @@ impl Debug for VerifyFailure { }; let debug = ConstraintCaseDebug { - constraint: *constraint, + constraint: constraint.clone(), location: location.clone(), cell_values: cell_values .iter() diff --git a/halo2_proofs/src/dev/gates.rs b/halo2_proofs/src/dev/gates.rs index cfc71c021e..ef09fbf3c7 100644 --- a/halo2_proofs/src/dev/gates.rs +++ b/halo2_proofs/src/dev/gates.rs @@ -15,14 +15,14 @@ use crate::{ #[derive(Debug)] struct Constraint { - name: &'static str, + name: String, expression: String, queries: BTreeSet, } #[derive(Debug)] struct Gate { - name: &'static str, + name: String, constraints: Vec, } @@ -112,13 +112,13 @@ impl CircuitGates { .gates .iter() .map(|gate| Gate { - name: gate.name(), + name: gate.name().to_owned(), constraints: gate .polynomials() .iter() .enumerate() .map(|(i, constraint)| Constraint { - name: gate.constraint_name(i), + name: gate.constraint_name(i).to_owned(), expression: constraint.evaluate( &util::format_value, &|selector| format!("S{}", selector.0), diff --git a/halo2_proofs/src/dev/metadata.rs b/halo2_proofs/src/dev/metadata.rs index 5fd0835bad..8c626bd8e2 100644 --- a/halo2_proofs/src/dev/metadata.rs +++ b/halo2_proofs/src/dev/metadata.rs @@ -1,5 +1,7 @@ //! Metadata about circuits. +use serde::{Deserialize, Serialize}; + use super::metadata::Column as ColumnMetadata; use crate::plonk::{self, Any}; use std::{ @@ -7,7 +9,7 @@ use std::{ fmt::{self, Debug}, }; /// Metadata about a column within a circuit. -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub struct Column { /// The type of the column. pub(super) column_type: Any, @@ -149,14 +151,14 @@ impl fmt::Display for DebugVirtualCell { } /// Metadata about a configured gate within a circuit. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct Gate { /// The index of the active gate. These indices are assigned in the order in which /// `ConstraintSystem::create_gate` is called during `Circuit::configure`. pub(super) index: usize, /// The name of the active gate. These are specified by the gate creator (such as /// a chip implementation), and is not enforced to be unique. - pub(super) name: &'static str, + pub(super) name: String, } impl fmt::Display for Gate { @@ -165,14 +167,23 @@ impl fmt::Display for Gate { } } +impl From<(usize, String)> for Gate { + fn from((index, name): (usize, String)) -> Self { + Gate { index, name } + } +} + impl From<(usize, &'static str)> for Gate { fn from((index, name): (usize, &'static str)) -> Self { - Gate { index, name } + Gate { + index, + name: name.to_owned(), + } } } /// Metadata about a configured constraint within a circuit. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct Constraint { /// The gate containing the constraint. pub(super) gate: Gate, @@ -182,7 +193,7 @@ pub struct Constraint { pub(super) index: usize, /// The name of the constraint. This is specified by the gate creator (such as a chip /// implementation), and is not enforced to be unique. - pub(super) name: &'static str, + pub(super) name: String, } impl fmt::Display for Constraint { @@ -202,8 +213,8 @@ impl fmt::Display for Constraint { } } -impl From<(Gate, usize, &'static str)> for Constraint { - fn from((gate, index, name): (Gate, usize, &'static str)) -> Self { +impl From<(Gate, usize, String)> for Constraint { + fn from((gate, index, name): (Gate, usize, String)) -> Self { Constraint { gate, index, name } } } diff --git a/halo2_proofs/src/plonk.rs b/halo2_proofs/src/plonk.rs index e3bffc000e..344f98f6a5 100644 --- a/halo2_proofs/src/plonk.rs +++ b/halo2_proofs/src/plonk.rs @@ -8,6 +8,7 @@ use blake2b_simd::Params as Blake2bParams; use ff::PrimeField; use group::ff::Field; +use serde::{Deserialize, Serialize}; use crate::arithmetic::{CurveAffine, FieldExt}; use crate::helpers::{ @@ -150,7 +151,11 @@ pub fn log_info(msg: String) { /// This is a verifying key which allows for the verification of proofs for a /// particular circuit. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(bound( + serialize = "C::Scalar: Serialize, C: Serialize", + deserialize = "C::Scalar: Deserialize<'de>, C: Deserialize<'de>" +))] pub struct VerifyingKey { domain: EvaluationDomain, fixed_commitments: Vec, @@ -365,7 +370,11 @@ pub struct PinnedVerificationKey<'a, C: CurveAffine> { } /// This is a proving key which allows for the creation of proofs for a /// particular circuit. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(bound( + serialize = "C::Scalar: Serialize, C: Serialize", + deserialize = "C::Scalar: Deserialize<'de>, C: Deserialize<'de>" +))] pub struct ProvingKey { vk: VerifyingKey, l0: Polynomial, diff --git a/halo2_proofs/src/plonk/circuit.rs b/halo2_proofs/src/plonk/circuit.rs index 1a08107e5c..fd1f72900c 100644 --- a/halo2_proofs/src/plonk/circuit.rs +++ b/halo2_proofs/src/plonk/circuit.rs @@ -1,6 +1,7 @@ use core::cmp::max; use core::ops::{Add, Mul}; use ff::Field; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::env::var; use std::{ @@ -25,7 +26,7 @@ pub trait ColumnType: } /// A column with an index and type -#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)] pub struct Column { index: usize, column_type: C, @@ -69,8 +70,10 @@ impl PartialOrd for Column { pub(crate) mod sealed { use std::ops::Add; + use serde::{Deserialize, Serialize}; + /// Phase of advice column - #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] + #[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] pub struct Phase(pub(super) u8); impl Phase { @@ -134,7 +137,7 @@ impl SealedPhase for super::ThirdPhase { } /// An advice column -#[derive(Clone, Copy, Eq, PartialEq, Hash)] +#[derive(Clone, Copy, Eq, PartialEq, Hash, Serialize, Deserialize)] pub struct Advice { pub(crate) phase: sealed::Phase, } @@ -173,15 +176,15 @@ impl std::fmt::Debug for Advice { } /// A fixed column -#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)] pub struct Fixed; /// An instance column -#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)] pub struct Instance; /// An enum over the Advice, Fixed, Instance structs -#[derive(Clone, Copy, Eq, PartialEq, Hash)] +#[derive(Clone, Copy, Eq, PartialEq, Hash, Serialize, Deserialize)] pub enum Any { /// An Advice variant Advice(Advice), @@ -388,7 +391,7 @@ impl TryFrom> for Column { /// Ok(()) /// } /// ``` -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct Selector(pub(crate) usize, bool); impl Selector { @@ -405,7 +408,7 @@ impl Selector { } /// Query of fixed column at a certain relative location -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub struct FixedQuery { /// Query index pub(crate) index: usize, @@ -428,7 +431,7 @@ impl FixedQuery { } /// Query of advice column at a certain relative location -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub struct AdviceQuery { /// Query index pub(crate) index: usize, @@ -458,7 +461,7 @@ impl AdviceQuery { } /// Query of instance column at a certain relative location -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub struct InstanceQuery { /// Query index pub(crate) index: usize, @@ -509,7 +512,7 @@ impl TableColumn { } /// A challenge squeezed from transcript after advice columns at the phase have been committed. -#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)] pub struct Challenge { index: usize, pub(crate) phase: sealed::Phase, @@ -678,7 +681,7 @@ pub trait Circuit { } /// Low-degree expression representing an identity that must hold over the committed columns. -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub enum Expression { /// This is a constant polynomial Constant(F), @@ -1190,7 +1193,7 @@ pub(crate) struct PointIndex(pub usize); /// A "virtual cell" is a PLONK cell that has been queried at a particular relative offset /// within a custom gate. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct VirtualCell { pub(crate) column: Column, pub(crate) rotation: Rotation, @@ -1316,10 +1319,10 @@ impl>, Iter: IntoIterator> IntoIterato } /// Gate -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Gate { - name: &'static str, - constraint_names: Vec<&'static str>, + name: String, + constraint_names: Vec, polys: Vec>, /// We track queried selectors separately from other cells, so that we can use them to /// trigger debug checks on gates. @@ -1328,12 +1331,12 @@ pub struct Gate { } impl Gate { - pub(crate) fn name(&self) -> &'static str { - self.name + pub(crate) fn name(&self) -> &str { + &self.name } - pub(crate) fn constraint_name(&self, constraint_index: usize) -> &'static str { - self.constraint_names[constraint_index] + pub(crate) fn constraint_name(&self, constraint_index: usize) -> &str { + &self.constraint_names[constraint_index] } /// Returns constraints of this gate @@ -1352,7 +1355,7 @@ impl Gate { /// This is a description of the circuit environment, such as the gate, column and /// permutation arrangements. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConstraintSystem { pub(crate) num_fixed_columns: usize, pub(crate) num_advice_columns: usize, @@ -1699,7 +1702,7 @@ impl ConstraintSystem { let (constraint_names, polys): (_, Vec<_>) = constraints .into_iter() .map(|c| c.into()) - .map(|c| (c.name, c.poly)) + .map(|c| (c.name.to_string(), c.poly)) .unzip(); assert!( @@ -1708,7 +1711,7 @@ impl ConstraintSystem { ); self.gates.push(Gate { - name, + name: name.to_string(), constraint_names, polys, queried_selectors, diff --git a/halo2_proofs/src/plonk/evaluation.rs b/halo2_proofs/src/plonk/evaluation.rs index b3357c4684..ee6ef79447 100644 --- a/halo2_proofs/src/plonk/evaluation.rs +++ b/halo2_proofs/src/plonk/evaluation.rs @@ -16,6 +16,7 @@ use group::{ ff::{BatchInvert, Field}, Curve, }; +use serde::{Deserialize, Serialize}; use std::any::TypeId; use std::convert::TryInto; use std::num::ParseIntError; @@ -34,7 +35,7 @@ fn get_rotation_idx(idx: usize, rot: i32, rot_scale: i32, isize: i32) -> usize { } /// Value used in a calculation -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Serialize, Deserialize)] pub enum ValueSource { /// This is a constant value Constant(usize), @@ -106,7 +107,7 @@ impl ValueSource { } /// Calculation -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum Calculation { /// This is an addition Add(ValueSource, ValueSource), @@ -180,7 +181,11 @@ impl Calculation { } /// Evaluator -#[derive(Clone, Default, Debug)] +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +#[serde(bound( + serialize = "C::Scalar: Serialize", + deserialize = "C::Scalar: Deserialize<'de>" +))] pub struct Evaluator { /// Custom gates evalution pub custom_gates: GraphEvaluator, @@ -189,7 +194,11 @@ pub struct Evaluator { } /// GraphEvaluator -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(bound( + serialize = "C::Scalar: Serialize", + deserialize = "C::Scalar: Deserialize<'de>" +))] pub struct GraphEvaluator { /// Constants pub constants: Vec, @@ -211,7 +220,7 @@ pub struct EvaluationData { } /// CaluclationInfo -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct CalculationInfo { /// Calculation pub calculation: Calculation, diff --git a/halo2_proofs/src/plonk/lookup.rs b/halo2_proofs/src/plonk/lookup.rs index 68cda75d37..c211db6e90 100644 --- a/halo2_proofs/src/plonk/lookup.rs +++ b/halo2_proofs/src/plonk/lookup.rs @@ -1,13 +1,14 @@ use super::circuit::Expression; use ff::Field; +use serde::{Deserialize, Serialize}; use std::fmt::{self, Debug}; pub(crate) mod prover; pub(crate) mod verifier; -#[derive(Clone)] +#[derive(Clone, Serialize, Deserialize)] pub struct Argument { - pub(crate) name: &'static str, + pub(crate) name: String, pub(crate) input_expressions: Vec>, pub(crate) table_expressions: Vec>, } @@ -28,7 +29,7 @@ impl Argument { pub fn new(name: &'static str, table_map: Vec<(Expression, Expression)>) -> Self { let (input_expressions, table_expressions) = table_map.into_iter().unzip(); Argument { - name, + name: name.to_string(), input_expressions, table_expressions, } diff --git a/halo2_proofs/src/plonk/permutation.rs b/halo2_proofs/src/plonk/permutation.rs index 0eba912093..255c12a956 100644 --- a/halo2_proofs/src/plonk/permutation.rs +++ b/halo2_proofs/src/plonk/permutation.rs @@ -17,11 +17,12 @@ pub(crate) mod prover; pub(crate) mod verifier; pub use keygen::Assembly; +use serde::{Deserialize, Serialize}; use std::io; /// A permutation argument. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Argument { /// A sequence of columns involved in the argument. pub(super) columns: Vec>, @@ -83,7 +84,7 @@ impl Argument { } /// The verifying key for a single permutation argument. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct VerifyingKey { commitments: Vec, } @@ -123,7 +124,11 @@ impl VerifyingKey { } /// The proving key for a single permutation argument. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(bound( + serialize = "C::Scalar: Serialize", + deserialize = "C::Scalar: Deserialize<'de>" +))] pub(crate) struct ProvingKey { permutations: Vec>, polys: Vec>, diff --git a/halo2_proofs/src/poly.rs b/halo2_proofs/src/poly.rs index d7fd0fb3e8..268d023b24 100644 --- a/halo2_proofs/src/poly.rs +++ b/halo2_proofs/src/poly.rs @@ -10,6 +10,7 @@ use crate::SerdeFormat; use ff::PrimeField; use group::ff::{BatchInvert, Field}; use halo2curves::FieldExt; +use serde::{Deserialize, Serialize}; use std::fmt::Debug; use std::io; use std::marker::PhantomData; @@ -65,7 +66,7 @@ impl Basis for ExtendedLagrangeCoeff {} /// Represents a univariate polynomial defined over a field and a particular /// basis. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Polynomial { pub(crate) values: Vec, _marker: PhantomData, @@ -353,7 +354,7 @@ impl<'a, F: Field, B: Basis> Sub for &'a Polynomial { /// Describes the relative rotation of a vector. Negative numbers represent /// reverse (leftmost) rotations and positive numbers represent forward (rightmost) /// rotations. Zero represents no rotation. -#[derive(Copy, Clone, Debug, PartialEq)] +#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct Rotation(pub i32); impl Rotation { diff --git a/halo2_proofs/src/poly/domain.rs b/halo2_proofs/src/poly/domain.rs index 260e788900..c19ed920b7 100644 --- a/halo2_proofs/src/poly/domain.rs +++ b/halo2_proofs/src/poly/domain.rs @@ -10,6 +10,7 @@ use crate::{ use super::{Coeff, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, Rotation}; use group::ff::{BatchInvert, Field, PrimeField}; +use serde::{Deserialize, Serialize}; use std::{env::var, marker::PhantomData}; @@ -24,7 +25,7 @@ fn get_fft_mode() -> usize { } /// FFTStage -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct FFTStage { radix: usize, length: usize, @@ -65,7 +66,7 @@ pub fn get_stages(size: usize, radixes: Vec) -> Vec { } /// FFTData -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] struct FFTData { n: usize, @@ -409,7 +410,7 @@ fn recursive_fft(data: &FFTData, data_in: &mut Vec, inverse: /// This structure contains precomputed constants and other details needed for /// performing operations on an evaluation domain of size $2^k$ and an extended /// domain of size $2^{k} * j$ with $j \neq 0$. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct EvaluationDomain { n: u64, k: u32, diff --git a/primitives/poseidon/src/poseidon.rs b/primitives/poseidon/src/poseidon.rs index 90023e3da6..a46391f5ce 100644 --- a/primitives/poseidon/src/poseidon.rs +++ b/primitives/poseidon/src/poseidon.rs @@ -25,6 +25,7 @@ impl Poseidon { pub fn update(&mut self, elements: &[F]) { let mut input_elements = self.absorbing.clone(); input_elements.extend_from_slice(elements); + dbg!(&input_elements); for chunk in input_elements.chunks(RATE) { if chunk.len() < RATE { @@ -32,12 +33,14 @@ impl Poseidon { // absorbation line self.absorbing = chunk.to_vec(); } else { + dbg!(&self.state); // Add new chunk of inputs for the next permutation cycle. for (input_element, state) in chunk.iter().zip(self.state.0.iter_mut().skip(1)) { state.add_assign(input_element); } // Perform intermediate permutation self.spec.permute(&mut self.state); + dbg!(&self.state); // Flush the absorption line self.absorbing.clear(); } @@ -62,6 +65,7 @@ impl Poseidon { // Perform final permutation self.spec.permute(&mut self.state); + dbg!(&self.state); // Flush the absorption line self.absorbing.clear(); // Returns the challenge while preserving internal state diff --git a/primitives/poseidon/src/spec.rs b/primitives/poseidon/src/spec.rs index 6c9a9382ce..c96db7af23 100644 --- a/primitives/poseidon/src/spec.rs +++ b/primitives/poseidon/src/spec.rs @@ -12,7 +12,7 @@ impl Default for State { /// The capacity value is 2**64 + (o − 1) where o the output length. fn default() -> Self { let mut state = [F::zero(); T]; - state[0] = F::from_u128(1 << 64); + state[0] = F::from_u128(1u128 << 64); State(state) } } @@ -349,10 +349,7 @@ impl Spec { *optimized = tmp[0]; tmp[0] = F::zero(); - for ((acc, tmp), constant) in acc - .iter_mut() - .zip(tmp.into_iter()) - .zip(constants.iter()) + for ((acc, tmp), constant) in acc.iter_mut().zip(tmp.into_iter()).zip(constants.iter()) { *acc = tmp + constant } From 7d7527d8e5690a2187fc70ccf3ad7e1eae5fd102 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 26 Jul 2023 18:21:04 -0600 Subject: [PATCH 2/5] chore: clippy fix --- halo2_proofs/Cargo.toml | 1 + halo2_proofs/examples/shuffle.rs | 2 +- halo2_proofs/src/plonk/circuit.rs | 1 + halo2_proofs/src/plonk/evaluation.rs | 10 +++++++-- halo2_proofs/src/plonk/permutation/keygen.rs | 4 ++-- halo2_proofs/src/plonk/prover.rs | 2 +- halo2_proofs/src/poly/domain.rs | 23 ++++++++++++-------- halo2_proofs/src/transcript.rs | 20 ++++++++--------- halo2_proofs/tests/plonk_api.rs | 4 +++- 9 files changed, 41 insertions(+), 26 deletions(-) diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index 47021b265f..059ae37331 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -73,6 +73,7 @@ criterion = "0.3" gumdrop = "0.8" proptest = "1" rand_core = { version = "0.6", features = ["getrandom"] } +rand_chacha = "0.3.1" [target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dev-dependencies] getrandom = { version = "0.2", features = ["js"] } diff --git a/halo2_proofs/examples/shuffle.rs b/halo2_proofs/examples/shuffle.rs index 3075caea3a..34ca5c0d12 100644 --- a/halo2_proofs/examples/shuffle.rs +++ b/halo2_proofs/examples/shuffle.rs @@ -338,7 +338,7 @@ fn main() { K, circuit.clone(), Err(vec![( - ((1, "z should end with 1").into(), 0, "").into(), + ((1, "z should end with 1").into(), 0, "".to_owned()).into(), FailureLocation::InRegion { region: (0, "Shuffle original into shuffled").into(), offset: 32, diff --git a/halo2_proofs/src/plonk/circuit.rs b/halo2_proofs/src/plonk/circuit.rs index fd1f72900c..976bea0785 100644 --- a/halo2_proofs/src/plonk/circuit.rs +++ b/halo2_proofs/src/plonk/circuit.rs @@ -84,6 +84,7 @@ pub(crate) mod sealed { assert!(self.0 < 2, "The API only supports three phases"); Phase(self.0 + 1) } + #[allow(clippy::wrong_self_convention)] pub fn to_u8(&self) -> u8 { self.0 } diff --git a/halo2_proofs/src/plonk/evaluation.rs b/halo2_proofs/src/plonk/evaluation.rs index ee6ef79447..c5d2d6ae22 100644 --- a/halo2_proofs/src/plonk/evaluation.rs +++ b/halo2_proofs/src/plonk/evaluation.rs @@ -312,6 +312,7 @@ impl Evaluator { let p = &pk.vk.cs.permutation; // Calculate the advice and instance cosets + #[cfg(feature = "profile")] let start = start_measure("cosets", false); let advice: Vec>> = advice_polys .iter() @@ -331,7 +332,7 @@ impl Evaluator { .collect() }) .collect(); - stop_measure(start); + // stop_measure(start); let mut values = domain.empty_extended(); @@ -344,7 +345,7 @@ impl Evaluator { .zip(permutations.iter()) { // Custom gates - let start = start_measure("custom gates", false); + // let start = start_measure("custom gates", false); multicore::scope(|scope| { let chunk_size = (size + num_threads - 1) / num_threads; for (thread_idx, values) in values.chunks_mut(chunk_size).enumerate() { @@ -372,9 +373,11 @@ impl Evaluator { }); } }); + #[cfg(feature = "profile")] stop_measure(start); // Permutations + #[cfg(feature = "profile")] let start = start_measure("permutations", false); let sets = &permutation.sets; if !sets.is_empty() { @@ -456,9 +459,11 @@ impl Evaluator { } }); } + #[cfg(feature = "profile")] stop_measure(start); // Lookups + #[cfg(feature = "profile")] let start = start_measure("lookups", false); for (n, lookup) in lookups.iter().enumerate() { // Polynomials required for this lookup. @@ -529,6 +534,7 @@ impl Evaluator { } }); } + #[cfg(feature = "profile")] stop_measure(start); } values diff --git a/halo2_proofs/src/plonk/permutation/keygen.rs b/halo2_proofs/src/plonk/permutation/keygen.rs index c15f3130cc..a47b178854 100644 --- a/halo2_proofs/src/plonk/permutation/keygen.rs +++ b/halo2_proofs/src/plonk/permutation/keygen.rs @@ -175,7 +175,7 @@ impl Assembly { { let omega = domain.get_omega(); parallelize(&mut omega_powers, |o, start| { - let mut cur = omega.pow_vartime(&[start as u64]); + let mut cur = omega.pow_vartime([start as u64]); for v in o.iter_mut() { *v = cur; cur *= ω @@ -187,7 +187,7 @@ impl Assembly { let mut deltaomega = vec![omega_powers; p.columns.len()]; { parallelize(&mut deltaomega, |o, start| { - let mut cur = C::Scalar::DELTA.pow_vartime(&[start as u64]); + let mut cur = C::Scalar::DELTA.pow_vartime([start as u64]); for omega_powers in o.iter_mut() { for v in omega_powers { *v *= &cur; diff --git a/halo2_proofs/src/plonk/prover.rs b/halo2_proofs/src/plonk/prover.rs index ebc64016bc..b12743e4d4 100644 --- a/halo2_proofs/src/plonk/prover.rs +++ b/halo2_proofs/src/plonk/prover.rs @@ -11,7 +11,6 @@ use std::marker::PhantomData; use std::ops::RangeTo; use std::rc::Rc; use std::sync::atomic::AtomicUsize; -use std::time::Instant; use std::{collections::HashMap, iter, mem, sync::atomic::Ordering}; use super::{ @@ -753,6 +752,7 @@ where #[cfg(feature = "profile")] let multiopen_time = start_timer!(|| "Phase 5: multiopen"); let prover = P::new(params); + #[allow(clippy::let_and_return)] let multiopen_res = prover .create_proof(&mut rng, transcript, instances) .map_err(|_| Error::ConstraintSystemFailure); diff --git a/halo2_proofs/src/poly/domain.rs b/halo2_proofs/src/poly/domain.rs index c19ed920b7..6c5676fe54 100644 --- a/halo2_proofs/src/poly/domain.rs +++ b/halo2_proofs/src/poly/domain.rs @@ -80,7 +80,7 @@ struct FFTData { impl FFTData { /// Create FFT data pub fn new(n: usize, omega: F, omega_inv: F) -> Self { - let stages = get_stages(n as usize, vec![]); + let stages = get_stages(n, vec![]); let mut f_twiddles = vec![]; let mut inv_twiddles = vec![]; let mut scratch = vec![F::zero(); n]; @@ -100,7 +100,7 @@ impl FFTData { // Twiddles parallelize(twiddles, |twiddles, start| { let w_m = o; - let mut w = o.pow_vartime(&[start as u64, 0, 0, 0]); + let mut w = o.pow_vartime([start as u64]); for value in twiddles.iter_mut() { *value = w; w *= w_m; @@ -169,9 +169,9 @@ fn butterfly_2_parallel( num_threads: usize, ) { let n = out.len(); - let mut chunk = (n as usize) / num_threads; + let mut chunk = n / num_threads; if chunk < num_threads { - chunk = n as usize; + chunk = n; } multicore::scope(|scope| { @@ -258,9 +258,9 @@ pub fn butterfly_4_parallel( let j = twiddles[twiddles.len() - 1]; let n = out.len(); - let mut chunk = (n as usize) / num_threads; + let mut chunk = n / num_threads; if chunk < num_threads { - chunk = n as usize; + chunk = n; } multicore::scope(|scope| { //let mut parts: Vec<&mut [F]> = out.chunks_mut(4).collect(); @@ -1017,7 +1017,7 @@ fn test_fft() { } let k = get_degree() as u32; - let mut domain = EvaluationDomain::::new(1, k); + let domain = EvaluationDomain::::new(1, k); let n = domain.n as usize; let input = vec![Scalar::random(OsRng); n]; @@ -1026,19 +1026,24 @@ fn test_fft() { input[i] = Scalar::random(OsRng); }*/ + #[cfg(feature = "profile")] let num_threads = multicore::current_num_threads(); let mut a = input.clone(); + #[cfg(feature = "profile")] let start = start_measure(format!("best fft {} ({})", a.len(), num_threads), false); best_fft(&mut a, domain.omega, k); + #[cfg(feature = "profile")] stop_measure(start); - let mut b = input.clone(); + let mut b = input; + #[cfg(feature = "profile")] let start = start_measure( format!("recursive fft {} ({})", a.len(), num_threads), false, ); - recursive_fft(&mut domain.fft_data, &mut b, false); + recursive_fft(&domain.fft_data, &mut b, false); + #[cfg(feature = "profile")] stop_measure(start); for i in 0..n { diff --git a/halo2_proofs/src/transcript.rs b/halo2_proofs/src/transcript.rs index 45c08df95b..14c41768f5 100644 --- a/halo2_proofs/src/transcript.rs +++ b/halo2_proofs/src/transcript.rs @@ -239,12 +239,12 @@ impl Transcript> for Keccak256Read> { fn squeeze_challenge(&mut self) -> Challenge255 { - self.state.update(&[KECCAK256_PREFIX_CHALLENGE]); + self.state.update([KECCAK256_PREFIX_CHALLENGE]); let mut state_lo = self.state.clone(); let mut state_hi = self.state.clone(); - state_lo.update(&[KECCAK256_PREFIX_CHALLENGE_LO]); - state_hi.update(&[KECCAK256_PREFIX_CHALLENGE_HI]); + state_lo.update([KECCAK256_PREFIX_CHALLENGE_LO]); + state_hi.update([KECCAK256_PREFIX_CHALLENGE_HI]); let result_lo: [u8; 32] = state_lo.finalize().as_slice().try_into().unwrap(); let result_hi: [u8; 32] = state_hi.finalize().as_slice().try_into().unwrap(); @@ -256,7 +256,7 @@ impl Transcript> } fn common_point(&mut self, point: C) -> io::Result<()> { - self.state.update(&[KECCAK256_PREFIX_POINT]); + self.state.update([KECCAK256_PREFIX_POINT]); let coords: Coordinates = Option::from(point.coordinates()).ok_or_else(|| { io::Error::new( io::ErrorKind::Other, @@ -270,7 +270,7 @@ impl Transcript> } fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { - self.state.update(&[KECCAK256_PREFIX_SCALAR]); + self.state.update([KECCAK256_PREFIX_SCALAR]); self.state.update(scalar.to_repr().as_ref()); Ok(()) @@ -401,12 +401,12 @@ impl Transcript> for Keccak256Write> { fn squeeze_challenge(&mut self) -> Challenge255 { - self.state.update(&[KECCAK256_PREFIX_CHALLENGE]); + self.state.update([KECCAK256_PREFIX_CHALLENGE]); let mut state_lo = self.state.clone(); let mut state_hi = self.state.clone(); - state_lo.update(&[KECCAK256_PREFIX_CHALLENGE_LO]); - state_hi.update(&[KECCAK256_PREFIX_CHALLENGE_HI]); + state_lo.update([KECCAK256_PREFIX_CHALLENGE_LO]); + state_hi.update([KECCAK256_PREFIX_CHALLENGE_HI]); let result_lo: [u8; 32] = state_lo.finalize().as_slice().try_into().unwrap(); let result_hi: [u8; 32] = state_hi.finalize().as_slice().try_into().unwrap(); @@ -418,7 +418,7 @@ impl Transcript> } fn common_point(&mut self, point: C) -> io::Result<()> { - self.state.update(&[KECCAK256_PREFIX_POINT]); + self.state.update([KECCAK256_PREFIX_POINT]); let coords: Coordinates = Option::from(point.coordinates()).ok_or_else(|| { io::Error::new( io::ErrorKind::Other, @@ -432,7 +432,7 @@ impl Transcript> } fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { - self.state.update(&[KECCAK256_PREFIX_SCALAR]); + self.state.update([KECCAK256_PREFIX_SCALAR]); self.state.update(scalar.to_repr().as_ref()); Ok(()) diff --git a/halo2_proofs/tests/plonk_api.rs b/halo2_proofs/tests/plonk_api.rs index 511b6c78a9..5f0919f086 100644 --- a/halo2_proofs/tests/plonk_api.rs +++ b/halo2_proofs/tests/plonk_api.rs @@ -1,7 +1,9 @@ #![allow(clippy::many_single_char_names)] #![allow(clippy::op_ref)] +#![allow(unused_macros)] +#![allow(dead_code)] -use assert_matches::assert_matches; +// use assert_matches::assert_matches; use halo2_proofs::arithmetic::{Field, FieldExt}; use halo2_proofs::circuit::{Cell, Layouter, SimpleFloorPlanner, Value}; use halo2_proofs::dev::MockProver; From 2fac78bc5c0a3b0f74be2a1cf801e9813279c25b Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 26 Jul 2023 19:53:29 -0600 Subject: [PATCH 3/5] feat: update serialization example to compare bincode --- halo2_proofs/Cargo.toml | 1 + halo2_proofs/examples/serialization.rs | 134 +++++++++++++++---------- 2 files changed, 84 insertions(+), 51 deletions(-) diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index 059ae37331..a1982bf538 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -62,6 +62,7 @@ rustc-hash = "1.1.0" sha3 = "0.9.1" ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } serde = { version = "1.0", default-features = false, features = ["derive"] } +bincode = "1.3.3" # Developer tooling dependencies plotters = { version = "0.3.0", optional = true } diff --git a/halo2_proofs/examples/serialization.rs b/halo2_proofs/examples/serialization.rs index 8b925b7adf..052c492c78 100644 --- a/halo2_proofs/examples/serialization.rs +++ b/halo2_proofs/examples/serialization.rs @@ -1,5 +1,5 @@ use std::{ - fs::File, + fs::{self, File}, io::{BufReader, BufWriter, Write}, }; @@ -126,59 +126,91 @@ impl Circuit for StandardPlonk { } } -fn main() { - let k = 4; +fn main() -> std::io::Result<()> { + let k = 22; let circuit = StandardPlonk(Fr::random(OsRng)); let params = ParamsKZG::::setup(k, OsRng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let f = File::create("serialization-test.pk").unwrap(); - let mut writer = BufWriter::new(f); - pk.write(&mut writer, SerdeFormat::RawBytes).unwrap(); - writer.flush().unwrap(); - - let f = File::open("serialization-test.pk").unwrap(); - let mut reader = BufReader::new(f); - let pk = ProvingKey::::read::<_, StandardPlonk>(&mut reader, SerdeFormat::RawBytes) - .unwrap(); - - std::fs::remove_file("serialization-test.pk").unwrap(); - - let instances: &[&[Fr]] = &[&[circuit.0]]; - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverGWC<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >( - ¶ms, - &pk, - &[circuit], - &[instances], - OsRng, - &mut transcript, - ) - .expect("prover should not fail"); - let proof = transcript.finalize(); - - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< - KZGCommitmentScheme, - VerifierGWC<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >( - ¶ms, - pk.get_vk(), - strategy, - &[instances], - &mut transcript - ) - .is_ok()); + for buf_size in [1024, 1024 * 1024, 1024 * 1024 * 1024] { + println!("buf_size: {buf_size}"); + // Using halo2_proofs serde implementation + let f = File::create("serialization-test.pk")?; + let mut writer = BufWriter::with_capacity(buf_size, f); + let start = std::time::Instant::now(); + pk.write(&mut writer, SerdeFormat::RawBytes)?; + writer.flush().unwrap(); + println!("SerdeFormat::RawBytes pk write time: {:?}", start.elapsed()); + + let f = File::open("serialization-test.pk")?; + let mut reader = BufReader::with_capacity(buf_size, f); + let start = std::time::Instant::now(); + let pk = + ProvingKey::::read::<_, StandardPlonk>(&mut reader, SerdeFormat::RawBytes) + .unwrap(); + println!("SerdeFormat::RawBytes pk read time: {:?}", start.elapsed()); + + let metadata = fs::metadata("serialization-test.pk")?; + let file_size = metadata.len(); + println!("The size of the file is {} bytes", file_size); + std::fs::remove_file("serialization-test.pk")?; + + // Using bincode + let f = File::create("serialization-test.pk")?; + let mut writer = BufWriter::with_capacity(buf_size, f); + let start = std::time::Instant::now(); + bincode::serialize_into(&mut writer, &pk).unwrap(); + writer.flush().unwrap(); + println!("bincode pk write time: {:?}", start.elapsed()); + + let f = File::open("serialization-test.pk").unwrap(); + let mut reader = BufReader::with_capacity(buf_size, f); + let start = std::time::Instant::now(); + let pk: ProvingKey = bincode::deserialize_from(&mut reader).unwrap(); + println!("bincode pk read time: {:?}", start.elapsed()); + + let metadata = fs::metadata("serialization-test.pk")?; + let file_size = metadata.len(); + println!("The size of the file is {} bytes", file_size); + std::fs::remove_file("serialization-test.pk").unwrap(); + + let instances: &[&[Fr]] = &[&[circuit.clone().0]]; + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverGWC<'_, Bn256>, + Challenge255, + _, + Blake2bWrite, G1Affine, Challenge255<_>>, + _, + >( + ¶ms, + &pk, + &[circuit.clone()], + &[instances], + OsRng, + &mut transcript, + ) + .expect("prover should not fail"); + let proof = transcript.finalize(); + + let strategy = SingleStrategy::new(¶ms); + let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); + assert!(verify_proof::< + KZGCommitmentScheme, + VerifierGWC<'_, Bn256>, + Challenge255, + Blake2bRead<&[u8], G1Affine, Challenge255>, + SingleStrategy<'_, Bn256>, + >( + ¶ms, + pk.get_vk(), + strategy, + &[instances], + &mut transcript + ) + .is_ok()); + } + Ok(()) } From 782df4a857cd904487f8016330da5ce0bf7aad48 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 27 Jul 2023 00:49:59 -0600 Subject: [PATCH 4/5] chore: update example --- halo2_proofs/examples/serialization.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/halo2_proofs/examples/serialization.rs b/halo2_proofs/examples/serialization.rs index 052c492c78..fb10d8246e 100644 --- a/halo2_proofs/examples/serialization.rs +++ b/halo2_proofs/examples/serialization.rs @@ -133,7 +133,7 @@ fn main() -> std::io::Result<()> { let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - for buf_size in [1024, 1024 * 1024, 1024 * 1024 * 1024] { + for buf_size in [1024, 8 * 1024, 1024 * 1024, 1024 * 1024 * 1024] { println!("buf_size: {buf_size}"); // Using halo2_proofs serde implementation let f = File::create("serialization-test.pk")?; From 4beeb03d8f1cecb4f0f6caff99a86343038af1df Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Thu, 26 Oct 2023 19:42:49 +0000 Subject: [PATCH 5/5] Add multi-thread pk read/write --- halo2_proofs/Cargo.toml | 17 ++- halo2_proofs/examples/serialization.rs | 200 ++++++++++++++----------- halo2_proofs/src/helpers.rs | 69 ++++++++- halo2_proofs/src/plonk.rs | 141 ++++++++++++++++- halo2_proofs/src/plonk/permutation.rs | 68 ++++++++- halo2_proofs/tests/plonk_api.rs | 3 + 6 files changed, 402 insertions(+), 96 deletions(-) diff --git a/halo2_proofs/Cargo.toml b/halo2_proofs/Cargo.toml index a1982bf538..fa752f61c6 100644 --- a/halo2_proofs/Cargo.toml +++ b/halo2_proofs/Cargo.toml @@ -1,12 +1,7 @@ [package] name = "halo2_proofs" version = "0.2.0" -authors = [ - "Sean Bowe ", - "Ying Tong Lai ", - "Daira Hopwood ", - "Jack Grigg ", -] +authors = ["Sean Bowe ", "Ying Tong Lai ", "Daira Hopwood ", "Jack Grigg "] edition = "2021" rust-version = "1.56.1" description = """ @@ -55,7 +50,7 @@ ff = "0.12" group = "0.12" halo2curves = { path = "../arithmetic/curves" } rand = "0.8" -rand_core = { version = "0.6", default-features = false} +rand_core = { version = "0.6", default-features = false } tracing = "0.1" blake2b_simd = "1" rustc-hash = "1.1.0" @@ -63,6 +58,10 @@ sha3 = "0.9.1" ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } serde = { version = "1.0", default-features = false, features = ["derive"] } bincode = "1.3.3" +maybe-rayon = { version = "0.1.0", default-features = false } +itertools = "0.10" +tokio = { version = "1.33", features = ["full"] } +serde_json = { version = "1.0", default-features = false } # Developer tooling dependencies plotters = { version = "0.3.0", optional = true } @@ -80,8 +79,10 @@ rand_chacha = "0.3.1" getrandom = { version = "0.2", features = ["js"] } [features] -default = ["batch"] +default = ["batch", "multicore"] +multicore = ["maybe-rayon/threads"] dev-graph = ["plotters", "tabbycat"] +test-dev-graph = ["dev-graph", "plotters/bitmap_backend", "plotters/bitmap_encoder", "plotters/ttf"] gadget-traces = ["backtrace"] sanity-checks = [] batch = ["rand/getrandom"] diff --git a/halo2_proofs/examples/serialization.rs b/halo2_proofs/examples/serialization.rs index fb10d8246e..721301c515 100644 --- a/halo2_proofs/examples/serialization.rs +++ b/halo2_proofs/examples/serialization.rs @@ -126,91 +126,121 @@ impl Circuit for StandardPlonk { } } -fn main() -> std::io::Result<()> { +#[tokio::main(flavor = "multi_thread", worker_threads = 24)] +async fn main() -> std::io::Result<()> { let k = 22; - let circuit = StandardPlonk(Fr::random(OsRng)); - let params = ParamsKZG::::setup(k, OsRng); - let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); - let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - - for buf_size in [1024, 8 * 1024, 1024 * 1024, 1024 * 1024 * 1024] { - println!("buf_size: {buf_size}"); - // Using halo2_proofs serde implementation - let f = File::create("serialization-test.pk")?; - let mut writer = BufWriter::with_capacity(buf_size, f); - let start = std::time::Instant::now(); - pk.write(&mut writer, SerdeFormat::RawBytes)?; - writer.flush().unwrap(); - println!("SerdeFormat::RawBytes pk write time: {:?}", start.elapsed()); - - let f = File::open("serialization-test.pk")?; - let mut reader = BufReader::with_capacity(buf_size, f); - let start = std::time::Instant::now(); - let pk = - ProvingKey::::read::<_, StandardPlonk>(&mut reader, SerdeFormat::RawBytes) - .unwrap(); - println!("SerdeFormat::RawBytes pk read time: {:?}", start.elapsed()); - - let metadata = fs::metadata("serialization-test.pk")?; - let file_size = metadata.len(); - println!("The size of the file is {} bytes", file_size); - std::fs::remove_file("serialization-test.pk")?; - - // Using bincode - let f = File::create("serialization-test.pk")?; - let mut writer = BufWriter::with_capacity(buf_size, f); - let start = std::time::Instant::now(); - bincode::serialize_into(&mut writer, &pk).unwrap(); - writer.flush().unwrap(); - println!("bincode pk write time: {:?}", start.elapsed()); - - let f = File::open("serialization-test.pk").unwrap(); - let mut reader = BufReader::with_capacity(buf_size, f); - let start = std::time::Instant::now(); - let pk: ProvingKey = bincode::deserialize_from(&mut reader).unwrap(); - println!("bincode pk read time: {:?}", start.elapsed()); - - let metadata = fs::metadata("serialization-test.pk")?; - let file_size = metadata.len(); - println!("The size of the file is {} bytes", file_size); - std::fs::remove_file("serialization-test.pk").unwrap(); - - let instances: &[&[Fr]] = &[&[circuit.clone().0]]; - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverGWC<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >( - ¶ms, - &pk, - &[circuit.clone()], - &[instances], - OsRng, - &mut transcript, - ) - .expect("prover should not fail"); - let proof = transcript.finalize(); - - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< - KZGCommitmentScheme, - VerifierGWC<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >( - ¶ms, - pk.get_vk(), - strategy, - &[instances], - &mut transcript - ) - .is_ok()); - } + + let buf_size = 1024 * 1024; + + // let pk_path = "/home/ubuntu/playground/serialization-test.pk"; + // let f = File::open(pk_path)?; + // let mut reader = BufReader::with_capacity(buf_size, f); + // let start = std::time::Instant::now(); + // let pk = ProvingKey::::read::<_, StandardPlonk>(&mut reader, SerdeFormat::RawBytes) + // .unwrap(); + // println!("SerdeFormat::RawBytes pk read time: {:?}", start.elapsed()); + + // let pk_folder = "/home/ubuntu/playground/serialization-test/"; + let pk_folder = "/mnt/ramdisk/serialization-test"; + // let start = std::time::Instant::now(); + // pk.multi_thread_write(pk_folder, SerdeFormat::RawBytes)?; + // println!( + // "SerdeFormat::RawBytes pk multi thread write time: {:?}", + // start.elapsed() + // ); + + let start = std::time::Instant::now(); + ProvingKey::::multi_thread_read::(pk_folder, SerdeFormat::RawBytes) + .await?; + println!( + "SerdeFormat::RawBytes pk multi thread read time: {:?}", + start.elapsed() + ); + Ok(()) + // let circuit = StandardPlonk(Fr::random(OsRng)); + // let params = ParamsKZG::::setup(k, OsRng); + // let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); + // let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + + // for buf_size in [1024, 8 * 1024, 1024 * 1024, 1024 * 1024 * 1024] { + // println!("buf_size: {buf_size}"); + // // Using halo2_proofs serde implementation + // let f = File::create("serialization-test.pk")?; + // let mut writer = BufWriter::with_capacity(buf_size, f); + // let start = std::time::Instant::now(); + // pk.write(&mut writer, SerdeFormat::RawBytes)?; + // writer.flush().unwrap(); + // println!("SerdeFormat::RawBytes pk write time: {:?}", start.elapsed()); + + // let f = File::open("serialization-test.pk")?; + // let mut reader = BufReader::with_capacity(buf_size, f); + // let start = std::time::Instant::now(); + // let pk = + // ProvingKey::::read::<_, StandardPlonk>(&mut reader, SerdeFormat::RawBytes) + // .unwrap(); + // println!("SerdeFormat::RawBytes pk read time: {:?}", start.elapsed()); + + // let metadata = fs::metadata("serialization-test.pk")?; + // let file_size = metadata.len(); + // println!("The size of the file is {} bytes", file_size); + // std::fs::remove_file("serialization-test.pk")?; + + // // Using bincode + // let f = File::create("serialization-test.pk")?; + // let mut writer = BufWriter::with_capacity(buf_size, f); + // let start = std::time::Instant::now(); + // bincode::serialize_into(&mut writer, &pk).unwrap(); + // writer.flush().unwrap(); + // println!("bincode pk write time: {:?}", start.elapsed()); + + // let f = File::open("serialization-test.pk").unwrap(); + // let mut reader = BufReader::with_capacity(buf_size, f); + // let start = std::time::Instant::now(); + // let pk: ProvingKey = bincode::deserialize_from(&mut reader).unwrap(); + // println!("bincode pk read time: {:?}", start.elapsed()); + + // let metadata = fs::metadata("serialization-test.pk")?; + // let file_size = metadata.len(); + // println!("The size of the file is {} bytes", file_size); + // std::fs::remove_file("serialization-test.pk").unwrap(); + + // let instances: &[&[Fr]] = &[&[circuit.clone().0]]; + // let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + // create_proof::< + // KZGCommitmentScheme, + // ProverGWC<'_, Bn256>, + // Challenge255, + // _, + // Blake2bWrite, G1Affine, Challenge255<_>>, + // _, + // >( + // ¶ms, + // &pk, + // &[circuit.clone()], + // &[instances], + // OsRng, + // &mut transcript, + // ) + // .expect("prover should not fail"); + // let proof = transcript.finalize(); + + // let strategy = SingleStrategy::new(¶ms); + // let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); + // assert!(verify_proof::< + // KZGCommitmentScheme, + // VerifierGWC<'_, Bn256>, + // Challenge255, + // Blake2bRead<&[u8], G1Affine, Challenge255>, + // SingleStrategy<'_, Bn256>, + // >( + // ¶ms, + // pk.get_vk(), + // strategy, + // &[instances], + // &mut transcript + // ) + // .is_ok()); + // } + // Ok(()) } diff --git a/halo2_proofs/src/helpers.rs b/halo2_proofs/src/helpers.rs index 8eba57359b..37484e0ad1 100644 --- a/halo2_proofs/src/helpers.rs +++ b/halo2_proofs/src/helpers.rs @@ -1,7 +1,15 @@ use crate::poly::Polynomial; use ff::PrimeField; use halo2curves::{pairing::Engine, serde::SerdeObject, CurveAffine}; -use std::io; +use itertools::Itertools; +use maybe_rayon::prelude::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, +}; +use std::{ + fs::File, + io::{self, BufReader, BufWriter}, + path::Path, +}; /// This enum specifies how various types are serialized and deserialized. #[derive(Clone, Copy, Debug)] @@ -125,6 +133,35 @@ pub(crate) fn read_polynomial_vec( .collect() } +/// Reads a vector of polynomials from buffer +pub(crate) async fn multi_thread_read_polynomial_vec< + F: SerdePrimeField, + B: Send + Sync + 'static, +>( + pk_prefix_path: impl AsRef, + format: SerdeFormat, + n: usize, +) -> Vec> { + const BUFFER_SIZE: usize = 1024 * 1024; + let join_handles = (0..n) + .map(|i| { + let mut poly_path = pk_prefix_path + .as_ref() + .clone() + .to_path_buf() + .into_os_string(); + poly_path.push(format!("_{i}")); + let mut reader = BufReader::with_capacity(BUFFER_SIZE, File::open(poly_path).unwrap()); + tokio::spawn(async move { Polynomial::::read(&mut reader, format) }) + }) + .collect_vec(); + let mut ret = Vec::with_capacity(join_handles.len()); + for join_handle in join_handles { + ret.push(join_handle.await.unwrap()); + } + ret +} + /// Writes a slice of polynomials to buffer pub(crate) fn write_polynomial_slice( slice: &[Polynomial], @@ -139,6 +176,36 @@ pub(crate) fn write_polynomial_slice( } } +/// Writes a slice of polynomials to buffer +pub(crate) fn multi_thread_write_polynomial_slice( + slice: &[Polynomial], + pk_prefix_path: impl AsRef, + format: SerdeFormat, +) { + const BUFFER_SIZE: usize = 1024 * 1024; + let poly_path = slice + .iter() + .enumerate() + .map(|(i, _)| { + let mut poly_path = pk_prefix_path + .as_ref() + .clone() + .to_path_buf() + .into_os_string(); + poly_path.push(format!("_{i}")); + poly_path + }) + .collect_vec(); + slice + .par_iter() + .zip_eq(poly_path.par_iter()) + .for_each(|(poly, poly_path)| { + let mut writer = + BufWriter::with_capacity(BUFFER_SIZE, File::create(poly_path).unwrap()); + poly.write(&mut writer, format); + }); +} + /// Gets the total number of bytes of a slice of polynomials, assuming all polynomials are the same length pub(crate) fn polynomial_slice_byte_length(slice: &[Polynomial]) -> usize { let field_len = F::default().to_repr().as_ref().len(); diff --git a/halo2_proofs/src/plonk.rs b/halo2_proofs/src/plonk.rs index 344f98f6a5..e681170998 100644 --- a/halo2_proofs/src/plonk.rs +++ b/halo2_proofs/src/plonk.rs @@ -8,10 +8,12 @@ use blake2b_simd::Params as Blake2bParams; use ff::PrimeField; use group::ff::Field; +use rayon::prelude::{IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator}; use serde::{Deserialize, Serialize}; use crate::arithmetic::{CurveAffine, FieldExt}; use crate::helpers::{ + multi_thread_read_polynomial_vec, multi_thread_write_polynomial_slice, polynomial_slice_byte_length, read_polynomial_vec, write_polynomial_slice, SerdeCurveAffine, SerdePrimeField, }; @@ -45,7 +47,7 @@ pub use verifier::*; use evaluation::Evaluator; use std::env::var; -use std::io; +use std::io::{self, BufReader}; use std::time::Instant; /// Temp @@ -149,6 +151,10 @@ pub fn log_info(msg: String) { } } +use std::fs::File; +use std::io::BufWriter; +use std::path::{Path, PathBuf}; + /// This is a verifying key which allows for the verification of proofs for a /// particular circuit. #[derive(Clone, Debug, Serialize, Deserialize)] @@ -387,6 +393,13 @@ pub struct ProvingKey { ev: Evaluator, } +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ProvingKeyMetadata { + pub num_fixed_values: usize, + pub num_fixed_polys: usize, + pub num_fixed_cosets: usize, +} + impl ProvingKey { /// Get the underlying [`VerifyingKey`]. pub fn get_vk(&self) -> &VerifyingKey { @@ -432,6 +445,56 @@ where Ok(()) } + pub fn multi_thread_write( + &self, + path: impl AsRef, + format: SerdeFormat, + ) -> io::Result<()> { + const BUFFER_SIZE: usize = 1024 * 1024; + let create_path = |filename: &str| -> PathBuf { + let mut p = path.as_ref().to_path_buf(); + p.push(filename); + p + }; + let create_writer = |filename: &str| -> BufWriter { + let p = create_path(filename); + BufWriter::with_capacity(BUFFER_SIZE, File::create(p).unwrap()) + }; + let metadata = ProvingKeyMetadata { + num_fixed_values: self.fixed_values.len(), + num_fixed_polys: self.fixed_polys.len(), + num_fixed_cosets: self.fixed_cosets.len(), + }; + serde_json::to_writer(create_writer("metadata"), &metadata)?; + + let mut vk_writer = create_writer("vk"); + self.vk.write(&mut vk_writer, format)?; + + let l0_writer = create_writer("l0"); + let l_last_writer = create_writer("l_last"); + let l_active_writer = create_writer("l_active"); + + let mut v = [ + (l0_writer, &self.l0), + (l_last_writer, &self.l_last), + (l_active_writer, &self.l_last), + ]; + v.par_iter_mut().for_each(|(writer, p)| { + p.write(writer, format); + }); + + let values_path_prefix = create_path("values"); + multi_thread_write_polynomial_slice(&self.fixed_values, values_path_prefix, format); + let polys_path_prefix = create_path("polys"); + multi_thread_write_polynomial_slice(&self.fixed_polys, polys_path_prefix, format); + let cosets_path_prefix = create_path("cosets"); + multi_thread_write_polynomial_slice(&self.fixed_cosets, cosets_path_prefix, format); + + let perm_path = create_path("perm"); + self.permutation.multi_thread_write(perm_path, format); + Ok(()) + } + /// Reads a proving key from a buffer. /// Does so by reading verification key first, and then deserializing the rest of the file into the remaining proving key data. /// @@ -469,6 +532,82 @@ where }) } + pub async fn multi_thread_read>( + path: impl AsRef, + format: SerdeFormat, + ) -> io::Result { + const BUFFER_SIZE: usize = 1024 * 1024; + let create_path = |filename: &str| -> PathBuf { + let mut p = path.as_ref().to_path_buf(); + p.push(filename); + p + }; + let create_reader = |filename: &str| -> BufReader { + let p = create_path(filename); + BufReader::with_capacity(BUFFER_SIZE, File::open(p).unwrap()) + }; + let metadata: ProvingKeyMetadata = serde_json::from_reader(create_reader("metadata"))?; + + let vk_path = create_path("vk"); + let mut reader = BufReader::with_capacity(BUFFER_SIZE, File::open(vk_path).unwrap()); + let vk_promise = tokio::spawn(async move { + VerifyingKey::::read::<_, ConcreteCircuit>(&mut reader, format).unwrap() + }); + + let l0_path = create_path("l0"); + let mut reader = BufReader::with_capacity(BUFFER_SIZE, File::open(l0_path).unwrap()); + let l0_promise = tokio::spawn(async move { Polynomial::read(&mut reader, format) }); + + let l_last_path = create_path("l_last"); + let mut reader = BufReader::with_capacity(BUFFER_SIZE, File::open(l_last_path).unwrap()); + let l_last_promise = tokio::spawn(async move { Polynomial::read(&mut reader, format) }); + + let l_active_row_path = create_path("l_active"); + let mut reader = + BufReader::with_capacity(BUFFER_SIZE, File::open(l_active_row_path).unwrap()); + let l_active_row_promise = + tokio::spawn(async move { Polynomial::read(&mut reader, format) }); + + let values_path = create_path("values"); + let values_n = metadata.num_fixed_values; + let values_promise = tokio::spawn(multi_thread_read_polynomial_vec( + values_path, + format, + values_n, + )); + + let polys_path = create_path("polys"); + let polys_n = metadata.num_fixed_polys; + let polys_promise = tokio::spawn(multi_thread_read_polynomial_vec( + polys_path, format, polys_n, + )); + + let cosets_path = create_path("cosets"); + let cosets_n = metadata.num_fixed_cosets; + let cosets_promise = tokio::spawn(multi_thread_read_polynomial_vec( + cosets_path, + format, + cosets_n, + )); + + let perm_path = create_path("perm"); + let permutation_promise = + tokio::spawn(permutation::ProvingKey::async_read(perm_path, format)); + let vk = vk_promise.await?; + let ev = Evaluator::new(vk.cs()); + Ok(Self { + vk, + l0: l0_promise.await?, + l_last: l_last_promise.await?, + l_active_row: l_active_row_promise.await?, + fixed_values: values_promise.await?, + fixed_polys: polys_promise.await?, + fixed_cosets: cosets_promise.await?, + permutation: permutation_promise.await?, + ev, + }) + } + /// Writes a proving key to a vector of bytes using [`Self::write`]. pub fn to_bytes(&self, format: SerdeFormat) -> Vec { let mut bytes = Vec::::with_capacity(self.bytes_length()); diff --git a/halo2_proofs/src/plonk/permutation.rs b/halo2_proofs/src/plonk/permutation.rs index 255c12a956..fb549d0ca5 100644 --- a/halo2_proofs/src/plonk/permutation.rs +++ b/halo2_proofs/src/plonk/permutation.rs @@ -4,6 +4,7 @@ use super::circuit::{Any, Column}; use crate::{ arithmetic::CurveAffine, helpers::{ + multi_thread_read_polynomial_vec, multi_thread_write_polynomial_slice, polynomial_slice_byte_length, read_polynomial_vec, write_polynomial_slice, SerdeCurveAffine, SerdePrimeField, }, @@ -19,7 +20,7 @@ pub(crate) mod verifier; pub use keygen::Assembly; use serde::{Deserialize, Serialize}; -use std::io; +use std::{fs::File, io, path::Path}; /// A permutation argument. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -135,6 +136,13 @@ pub(crate) struct ProvingKey { pub(super) cosets: Vec>, } +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ProvingKeyMetadata { + pub num_perms: usize, + pub num_polys: usize, + pub num_cosets: usize, +} + impl ProvingKey where C::Scalar: SerdePrimeField, @@ -150,6 +158,41 @@ where cosets, } } + /// Async version of read + pub(super) async fn async_read(prefix: impl AsRef, format: SerdeFormat) -> Self { + let os_str = prefix.as_ref().clone().to_path_buf().into_os_string(); + let mut metadata_os_str = os_str.clone(); + metadata_os_str.push("_metadata"); + let metadata: ProvingKeyMetadata = + serde_json::from_reader(File::open(metadata_os_str).unwrap()).unwrap(); + + let mut perms_os_str = os_str.clone(); + perms_os_str.push("_perm"); + let permutations_promise = tokio::spawn(multi_thread_read_polynomial_vec( + perms_os_str, + format, + metadata.num_perms, + )); + let mut polys_os_str = os_str.clone(); + polys_os_str.push("_polys"); + let polys_promise = tokio::spawn(multi_thread_read_polynomial_vec( + polys_os_str, + format, + metadata.num_polys, + )); + let mut cosets_os_str = os_str.clone(); + cosets_os_str.push("_cosets"); + let cosets_promise = tokio::spawn(multi_thread_read_polynomial_vec( + cosets_os_str, + format, + metadata.num_cosets, + )); + ProvingKey { + permutations: permutations_promise.await.unwrap(), + polys: polys_promise.await.unwrap(), + cosets: cosets_promise.await.unwrap(), + } + } /// Writes proving key for a single permutation argument to buffer using `Polynomial::write`. pub(super) fn write(&self, writer: &mut W, format: SerdeFormat) { @@ -157,6 +200,29 @@ where write_polynomial_slice(&self.polys, writer, format); write_polynomial_slice(&self.cosets, writer, format); } + /// Multi thread write + pub(super) fn multi_thread_write(&self, prefix: impl AsRef, format: SerdeFormat) { + let os_str = prefix.as_ref().clone().to_path_buf().into_os_string(); + let metadata = ProvingKeyMetadata { + num_perms: self.permutations.len(), + num_polys: self.polys.len(), + num_cosets: self.cosets.len(), + }; + let mut metadata_os_str = os_str.clone(); + metadata_os_str.push("_metadata"); + serde_json::to_writer(File::create(metadata_os_str).unwrap(), &metadata).unwrap(); + + let permutations = &self.permutations; + let mut perm_os_str = os_str.clone(); + perm_os_str.push("_perm"); + multi_thread_write_polynomial_slice(permutations, perm_os_str, format); + let mut polys_os_str = os_str.clone(); + polys_os_str.push("_polys"); + multi_thread_write_polynomial_slice(&self.polys, polys_os_str, format); + let mut cosets_os_str = os_str.clone(); + cosets_os_str.push("_cosets"); + multi_thread_write_polynomial_slice(&self.cosets, cosets_os_str, format); + } } impl ProvingKey { diff --git a/halo2_proofs/tests/plonk_api.rs b/halo2_proofs/tests/plonk_api.rs index 5f0919f086..679163c7e7 100644 --- a/halo2_proofs/tests/plonk_api.rs +++ b/halo2_proofs/tests/plonk_api.rs @@ -19,8 +19,11 @@ use halo2_proofs::transcript::{ Blake2bRead, Blake2bWrite, Challenge255, EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer, }; +use halo2_proofs::SerdeFormat; use rand_core::{OsRng, RngCore}; +use std::fs::File; use std::hash::Hash; +use std::io; use std::marker::PhantomData; #[test]