diff --git a/halo2_proofs/src/helpers.rs b/halo2_proofs/src/helpers.rs index df5259e6c..ee91b8d52 100644 --- a/halo2_proofs/src/helpers.rs +++ b/halo2_proofs/src/helpers.rs @@ -199,6 +199,14 @@ pub trait SerdeCurveAffine: CurveAffine + SerdeObject { _ => self.write_raw(writer), } } + + /// Byte length of an affine curve element according to `format`. + fn byte_length(format: SerdeFormat) -> usize { + match format { + SerdeFormat::Processed => Self::default().to_bytes().as_ref().len(), + _ => Self::Repr::default().as_ref().len() * 2, + } + } } impl SerdeCurveAffine for C {} diff --git a/halo2_proofs/src/plonk.rs b/halo2_proofs/src/plonk.rs index aa58eaaeb..875f09436 100644 --- a/halo2_proofs/src/plonk.rs +++ b/halo2_proofs/src/plonk.rs @@ -58,6 +58,8 @@ pub struct VerifyingKey { /// The representative of this `VerifyingKey` in transcripts. transcript_repr: C::Scalar, selectors: Vec>, + /// Whether selector compression is turned on or not. + compress_selectors: bool, } impl VerifyingKey @@ -74,13 +76,19 @@ where /// Writes a field element into raw bytes in its internal Montgomery representation, /// WITHOUT performing the expensive Montgomery reduction. pub fn write(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> { - writer.write_all(&self.domain.k().to_be_bytes())?; - writer.write_all(&(self.fixed_commitments.len() as u32).to_be_bytes())?; + // Version byte that will be checked on read. + writer.write_all(&[0x02])?; + writer.write_all(&self.domain.k().to_le_bytes())?; + writer.write_all(&[self.compress_selectors as u8])?; + writer.write_all(&(self.fixed_commitments.len() as u32).to_le_bytes())?; for commitment in &self.fixed_commitments { commitment.write(writer, format)?; } self.permutation.write(writer, format)?; + if !self.compress_selectors { + assert!(self.selectors.is_empty()); + } // write self.selectors for selector in &self.selectors { // since `selector` is filled with `bool`, we pack them 8 at a time into bytes and then write @@ -106,9 +114,26 @@ where format: SerdeFormat, #[cfg(feature = "circuit-params")] params: ConcreteCircuit::Params, ) -> io::Result { + let mut version_byte = [0u8; 1]; + reader.read_exact(&mut version_byte)?; + if 0x02 != version_byte[0] { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "unexpected version byte", + )); + } let mut k = [0u8; 4]; reader.read_exact(&mut k)?; - let k = u32::from_be_bytes(k); + let k = u32::from_le_bytes(k); + let mut compress_selectors = [0u8; 1]; + reader.read_exact(&mut compress_selectors)?; + if compress_selectors[0] != 0 && compress_selectors[0] != 1 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "unexpected compress_selectors not boolean", + )); + } + let compress_selectors = compress_selectors[0] == 1; let (domain, cs, _) = keygen::create_domain::( k, #[cfg(feature = "circuit-params")] @@ -116,7 +141,7 @@ where ); let mut num_fixed_columns = [0u8; 4]; reader.read_exact(&mut num_fixed_columns)?; - let num_fixed_columns = u32::from_be_bytes(num_fixed_columns); + let num_fixed_columns = u32::from_le_bytes(num_fixed_columns); let fixed_commitments: Vec<_> = (0..num_fixed_columns) .map(|_| C::read(reader, format)) @@ -124,19 +149,27 @@ where let permutation = permutation::VerifyingKey::read(reader, &cs.permutation, format)?; - // read selectors - let selectors: Vec> = vec![vec![false; 1 << k]; cs.num_selectors] - .into_iter() - .map(|mut selector| { - let mut selector_bytes = vec![0u8; (selector.len() + 7) / 8]; - reader.read_exact(&mut selector_bytes)?; - for (bits, byte) in selector.chunks_mut(8).zip(selector_bytes) { - crate::helpers::unpack(byte, bits); - } - Ok(selector) - }) - .collect::>()?; - let (cs, _) = cs.compress_selectors(selectors.clone()); + let (cs, selectors) = if compress_selectors { + // read selectors + let selectors: Vec> = vec![vec![false; 1 << k]; cs.num_selectors] + .into_iter() + .map(|mut selector| { + let mut selector_bytes = vec![0u8; (selector.len() + 7) / 8]; + reader.read_exact(&mut selector_bytes)?; + for (bits, byte) in selector.chunks_mut(8).zip(selector_bytes) { + crate::helpers::unpack(byte, bits); + } + Ok(selector) + }) + .collect::>()?; + let (cs, _) = cs.compress_selectors(selectors.clone()); + (cs, selectors) + } else { + // we still need to replace selectors with fixed Expressions in `cs` + let fake_selectors = vec![vec![]; cs.num_selectors]; + let (cs, _) = cs.directly_convert_selectors_to_fixed(fake_selectors); + (cs, vec![]) + }; Ok(Self::from_parts( domain, @@ -144,12 +177,13 @@ where permutation, cs, selectors, + compress_selectors, )) } /// Writes a verifying key to a vector of bytes using [`Self::write`]. pub fn to_bytes(&self, format: SerdeFormat) -> Vec { - let mut bytes = Vec::::with_capacity(self.bytes_length()); + let mut bytes = Vec::::with_capacity(self.bytes_length(format)); Self::write(self, &mut bytes, format).expect("Writing to vector should not fail"); bytes } @@ -169,13 +203,13 @@ where } } -impl VerifyingKey -where - C::ScalarExt: FromUniformBytes<64>, -{ - fn bytes_length(&self) -> usize { - 8 + (self.fixed_commitments.len() * C::default().to_bytes().as_ref().len()) - + self.permutation.bytes_length() +impl VerifyingKey { + fn bytes_length(&self, format: SerdeFormat) -> usize + where + C: SerdeCurveAffine, + { + 10 + (self.fixed_commitments.len() * C::byte_length(format)) + + self.permutation.bytes_length(format) + self.selectors.len() * (self .selectors @@ -190,7 +224,11 @@ where permutation: permutation::VerifyingKey, cs: ConstraintSystem, selectors: Vec>, - ) -> Self { + compress_selectors: bool, + ) -> Self + where + C::ScalarExt: FromUniformBytes<64>, + { // Compute cached values. let cs_degree = cs.degree(); @@ -203,6 +241,7 @@ where // Temporary, this is not pinned. transcript_repr: C::Scalar::ZERO, selectors, + compress_selectors, }; let mut hasher = Blake2bParams::new() @@ -303,9 +342,12 @@ where } /// Gets the total number of bytes in the serialization of `self` - fn bytes_length(&self) -> usize { + fn bytes_length(&self, format: SerdeFormat) -> usize + where + C: SerdeCurveAffine, + { let scalar_len = C::Scalar::default().to_repr().as_ref().len(); - self.vk.bytes_length() + self.vk.bytes_length(format) + 12 + scalar_len * (self.l0.len() + self.l_last.len() + self.l_active_row.len()) + polynomial_slice_byte_length(&self.fixed_values) @@ -386,7 +428,7 @@ where /// Writes a proving key to a vector of bytes using [`Self::write`]. pub fn to_bytes(&self, format: SerdeFormat) -> Vec { - let mut bytes = Vec::::with_capacity(self.bytes_length()); + let mut bytes = Vec::::with_capacity(self.bytes_length(format)); Self::write(self, &mut bytes, format).expect("Writing to vector should not fail"); bytes } diff --git a/halo2_proofs/src/plonk/circuit.rs b/halo2_proofs/src/plonk/circuit.rs index f0d3a83d1..575375df1 100644 --- a/halo2_proofs/src/plonk/circuit.rs +++ b/halo2_proofs/src/plonk/circuit.rs @@ -1986,7 +1986,45 @@ impl ConstraintSystem { .into_iter() .map(|a| a.unwrap()) .collect::>(); + self.replace_selectors_with_fixed(&selector_replacements); + (self, polys) + } + + /// Does not combine selectors and directly replaces them everywhere with fixed columns. + pub fn directly_convert_selectors_to_fixed( + mut self, + selectors: Vec>, + ) -> (Self, Vec>) { + // The number of provided selector assignments must be the number we + // counted for this constraint system. + assert_eq!(selectors.len(), self.num_selectors); + + let (polys, selector_replacements): (Vec<_>, Vec<_>) = selectors + .into_iter() + .map(|selector| { + let poly = selector + .iter() + .map(|b| if *b { F::ONE } else { F::ZERO }) + .collect::>(); + let column = self.fixed_column(); + let rotation = Rotation::cur(); + let expr = Expression::Fixed(FixedQuery { + index: Some(self.query_fixed_index(column, rotation)), + column_index: column.index, + rotation, + }); + (poly, expr) + }) + .unzip(); + + self.replace_selectors_with_fixed(&selector_replacements); + self.num_selectors = 0; + + (self, polys) + } + + fn replace_selectors_with_fixed(&mut self, selector_replacements: &[Expression]) { fn replace_selectors( expr: &mut Expression, selector_replacements: &[Expression], @@ -2017,7 +2055,7 @@ impl ConstraintSystem { // Substitute selectors for the real fixed columns in all gates for expr in self.gates.iter_mut().flat_map(|gate| gate.polys.iter_mut()) { - replace_selectors(expr, &selector_replacements, false); + replace_selectors(expr, selector_replacements, false); } // Substitute non-simple selectors for the real fixed columns in all @@ -2028,10 +2066,8 @@ impl ConstraintSystem { .iter_mut() .chain(lookup.table_expressions.iter_mut()) }) { - replace_selectors(expr, &selector_replacements, true); + replace_selectors(expr, selector_replacements, true); } - - (self, polys) } /// Allocate a new (simple) selector. Simple selectors cannot be added to diff --git a/halo2_proofs/src/plonk/keygen.rs b/halo2_proofs/src/plonk/keygen.rs index c8c53aa90..fc70574b6 100644 --- a/halo2_proofs/src/plonk/keygen.rs +++ b/halo2_proofs/src/plonk/keygen.rs @@ -196,10 +196,28 @@ impl Assignment for Assembly { } /// Generate a `VerifyingKey` from an instance of `Circuit`. +/// By default, selector compression is turned **off**. pub fn keygen_vk<'params, C, P, ConcreteCircuit>( params: &P, circuit: &ConcreteCircuit, ) -> Result, Error> +where + C: CurveAffine, + P: Params<'params, C>, + ConcreteCircuit: Circuit, + C::Scalar: FromUniformBytes<64>, +{ + keygen_vk_custom(params, circuit, true) +} + +/// Generate a `VerifyingKey` from an instance of `Circuit`. +/// +/// The selector compression optimization is turned on only if `compress_selectors` is `true`. +pub fn keygen_vk_custom<'params, C, P, ConcreteCircuit>( + params: &P, + circuit: &ConcreteCircuit, + compress_selectors: bool, +) -> Result, Error> where C: CurveAffine, P: Params<'params, C>, @@ -234,7 +252,13 @@ where )?; let mut fixed = batch_invert_assigned(assembly.fixed); - let (cs, selector_polys) = cs.compress_selectors(assembly.selectors.clone()); + let (cs, selector_polys) = if compress_selectors { + cs.compress_selectors(assembly.selectors.clone()) + } else { + // After this, the ConstraintSystem should not have any selectors: `verify` does not need them, and `keygen_pk` regenerates `cs` from scratch anyways. + let selectors = std::mem::take(&mut assembly.selectors); + cs.directly_convert_selectors_to_fixed(selectors) + }; fixed.extend( selector_polys .into_iter() @@ -256,6 +280,7 @@ where permutation_vk, cs, assembly.selectors, + compress_selectors, )) } @@ -300,7 +325,11 @@ where )?; let mut fixed = batch_invert_assigned(assembly.fixed); - let (cs, selector_polys) = cs.compress_selectors(assembly.selectors); + let (cs, selector_polys) = if vk.compress_selectors { + cs.compress_selectors(assembly.selectors) + } else { + cs.directly_convert_selectors_to_fixed(assembly.selectors) + }; fixed.extend( selector_polys .into_iter() diff --git a/halo2_proofs/src/plonk/permutation.rs b/halo2_proofs/src/plonk/permutation.rs index 1da8fb1e2..29a4a323a 100644 --- a/halo2_proofs/src/plonk/permutation.rs +++ b/halo2_proofs/src/plonk/permutation.rs @@ -118,8 +118,11 @@ impl VerifyingKey { Ok(VerifyingKey { commitments }) } - pub(crate) fn bytes_length(&self) -> usize { - self.commitments.len() * C::default().to_bytes().as_ref().len() + pub(crate) fn bytes_length(&self, format: SerdeFormat) -> usize + where + C: SerdeCurveAffine, + { + self.commitments.len() * C::byte_length(format) } }