From 6c23c6963527dcfc07f68684844256c73ce52794 Mon Sep 17 00:00:00 2001 From: Liuhaai Date: Fri, 15 Sep 2023 10:59:06 -0700 Subject: [PATCH] decouple config creation from VerifyingKey::read --- halo2_proofs/examples/serialization.rs | 8 +--- halo2_proofs/src/plonk.rs | 51 ++++---------------------- halo2_proofs/src/plonk/keygen.rs | 39 +++++++++++--------- 3 files changed, 31 insertions(+), 67 deletions(-) diff --git a/halo2_proofs/examples/serialization.rs b/halo2_proofs/examples/serialization.rs index 39b6b1192f..adc29a5478 100644 --- a/halo2_proofs/examples/serialization.rs +++ b/halo2_proofs/examples/serialization.rs @@ -143,13 +143,7 @@ fn main() { let f = File::open("serialization-test.pk").unwrap(); let mut reader = BufReader::new(f); #[allow(clippy::unit_arg)] - let pk = ProvingKey::::read::<_, StandardPlonk>( - &mut reader, - SerdeFormat::RawBytes, - #[cfg(feature = "circuit-params")] - circuit.params(), - ) - .unwrap(); + let pk = ProvingKey::::read::<_>(&mut reader, SerdeFormat::RawBytes).unwrap(); std::fs::remove_file("serialization-test.pk").unwrap(); diff --git a/halo2_proofs/src/plonk.rs b/halo2_proofs/src/plonk.rs index d05e7a4005..94a30c8b26 100644 --- a/halo2_proofs/src/plonk.rs +++ b/halo2_proofs/src/plonk.rs @@ -99,19 +99,11 @@ where /// Checks that field elements are less than modulus, and then checks that the point is on the curve. /// - `RawBytesUnchecked`: Reads an uncompressed curve element with coordinates in Montgomery form; /// does not perform any checks - pub fn read>( - reader: &mut R, - format: SerdeFormat, - #[cfg(feature = "circuit-params")] params: ConcreteCircuit::Params, - ) -> io::Result { + pub fn read(reader: &mut R, format: SerdeFormat) -> io::Result { let mut k = [0u8; 4]; reader.read_exact(&mut k)?; let k = u32::from_be_bytes(k); - let (domain, cs, _) = keygen::create_domain::( - k, - #[cfg(feature = "circuit-params")] - params, - ); + let (domain, cs) = keygen::create_domain::(k); let mut num_fixed_columns = [0u8; 4]; reader.read_exact(&mut num_fixed_columns)?; let num_fixed_columns = u32::from_be_bytes(num_fixed_columns); @@ -153,17 +145,8 @@ where } /// Reads a verification key from a slice of bytes using [`Self::read`]. - pub fn from_bytes>( - mut bytes: &[u8], - format: SerdeFormat, - #[cfg(feature = "circuit-params")] params: ConcreteCircuit::Params, - ) -> io::Result { - Self::read::<_, ConcreteCircuit>( - &mut bytes, - format, - #[cfg(feature = "circuit-params")] - params, - ) + pub fn from_bytes(mut bytes: &[u8], format: SerdeFormat) -> io::Result { + Self::read::<_>(&mut bytes, format) } } @@ -349,17 +332,8 @@ where /// Checks that field elements are less than modulus, and then checks that the point is on the curve. /// - `RawBytesUnchecked`: Reads an uncompressed curve element with coordinates in Montgomery form; /// does not perform any checks - pub fn read>( - reader: &mut R, - format: SerdeFormat, - #[cfg(feature = "circuit-params")] params: ConcreteCircuit::Params, - ) -> io::Result { - let vk = VerifyingKey::::read::( - reader, - format, - #[cfg(feature = "circuit-params")] - params, - )?; + pub fn read(reader: &mut R, format: SerdeFormat) -> io::Result { + let vk = VerifyingKey::::read::(reader, format)?; let l0 = Polynomial::read(reader, format)?; let l_last = Polynomial::read(reader, format)?; let l_active_row = Polynomial::read(reader, format)?; @@ -389,17 +363,8 @@ where } /// Reads a proving key from a slice of bytes using [`Self::read`]. - pub fn from_bytes>( - mut bytes: &[u8], - format: SerdeFormat, - #[cfg(feature = "circuit-params")] params: ConcreteCircuit::Params, - ) -> io::Result { - Self::read::<_, ConcreteCircuit>( - &mut bytes, - format, - #[cfg(feature = "circuit-params")] - params, - ) + pub fn from_bytes(mut bytes: &[u8], format: SerdeFormat) -> io::Result { + Self::read::<_>(&mut bytes, format) } } diff --git a/halo2_proofs/src/plonk/keygen.rs b/halo2_proofs/src/plonk/keygen.rs index bd48b7c96a..08458b65a0 100644 --- a/halo2_proofs/src/plonk/keygen.rs +++ b/halo2_proofs/src/plonk/keygen.rs @@ -23,29 +23,33 @@ use crate::{ }, }; -pub(crate) fn create_domain( - k: u32, - #[cfg(feature = "circuit-params")] params: ConcreteCircuit::Params, -) -> ( - EvaluationDomain, - ConstraintSystem, - ConcreteCircuit::Config, -) +pub(crate) fn create_domain(k: u32) -> (EvaluationDomain, ConstraintSystem) where C: CurveAffine, - ConcreteCircuit: Circuit, { - let mut cs = ConstraintSystem::default(); - #[cfg(feature = "circuit-params")] - let config = ConcreteCircuit::configure_with_params(&mut cs, params); - #[cfg(not(feature = "circuit-params"))] - let config = ConcreteCircuit::configure(&mut cs); + let cs = ConstraintSystem::default(); let degree = cs.degree(); let domain = EvaluationDomain::new(degree as u32, k); - (domain, cs, config) + (domain, cs) +} + +pub(crate) fn create_config( + cs: &mut ConstraintSystem, + #[cfg(feature = "circuit-params")] params: ConcreteCircuit::Params, +) -> ConcreteCircuit::Config +where + C: CurveAffine, + ConcreteCircuit: Circuit, +{ + #[cfg(feature = "circuit-params")] + let config = ConcreteCircuit::configure_with_params(cs, params); + #[cfg(not(feature = "circuit-params"))] + let config = ConcreteCircuit::configure(cs); + + config } /// Assembly to be used in circuit synthesis. @@ -213,8 +217,9 @@ where ConcreteCircuit: Circuit, C::Scalar: FromUniformBytes<64>, { - let (domain, cs, config) = create_domain::( - params.k(), + let (domain, mut cs) = create_domain::(params.k()); + let config = create_config::( + &mut cs, #[cfg(feature = "circuit-params")] circuit.params(), );