Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Make selector compression optional #212

8 changes: 8 additions & 0 deletions halo2_proofs/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ pub trait SerdeCurveAffine: CurveAffine + SerdeObject {
_ => self.write_raw(writer),
}
}

/// Byte length of an affine curve element according to `format`.
fn byte_length(format: SerdeFormat) -> usize {
match format {
SerdeFormat::Processed => Self::default().to_bytes().as_ref().len(),
_ => Self::Repr::default().as_ref().len() * 2,
}
}
}
impl<C: CurveAffine + SerdeObject> SerdeCurveAffine for C {}

Expand Down
90 changes: 66 additions & 24 deletions halo2_proofs/src/plonk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ pub struct VerifyingKey<C: CurveAffine> {
/// The representative of this `VerifyingKey` in transcripts.
transcript_repr: C::Scalar,
selectors: Vec<Vec<bool>>,
/// Whether selector compression is turned on or not.
compress_selectors: bool,
}

impl<C: SerdeCurveAffine> VerifyingKey<C>
Expand All @@ -72,13 +74,19 @@ where
/// Writes a field element into raw bytes in its internal Montgomery representation,
/// WITHOUT performing the expensive Montgomery reduction.
pub fn write<W: io::Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> {
writer.write_all(&self.domain.k().to_be_bytes())?;
writer.write_all(&(self.fixed_commitments.len() as u32).to_be_bytes())?;
// Version byte that will be checked on read.
writer.write_all(&[0x02])?;
writer.write_all(&self.domain.k().to_le_bytes())?;
writer.write_all(&[self.compress_selectors as u8])?;
writer.write_all(&(self.fixed_commitments.len() as u32).to_le_bytes())?;
for commitment in &self.fixed_commitments {
commitment.write(writer, format)?;
}
self.permutation.write(writer, format)?;

if !self.compress_selectors {
assert!(self.selectors.is_empty());
}
// write self.selectors
for selector in &self.selectors {
// since `selector` is filled with `bool`, we pack them 8 at a time into bytes and then write
Expand All @@ -104,50 +112,76 @@ where
format: SerdeFormat,
#[cfg(feature = "circuit-params")] params: ConcreteCircuit::Params,
) -> io::Result<Self> {
let mut version_byte = [0u8; 1];
reader.read_exact(&mut version_byte)?;
if 0x02 != version_byte[0] {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"unexpected version byte",
));
}
let mut k = [0u8; 4];
reader.read_exact(&mut k)?;
let k = u32::from_be_bytes(k);
let k = u32::from_le_bytes(k);
let mut compress_selectors = [0u8; 1];
reader.read_exact(&mut compress_selectors)?;
if compress_selectors[0] != 0 && compress_selectors[0] != 1 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"unexpected compress_selectors not boolean",
));
}
let compress_selectors = compress_selectors[0] == 1;
let (domain, cs, _) = keygen::create_domain::<C, ConcreteCircuit>(
k,
#[cfg(feature = "circuit-params")]
params,
);
let mut num_fixed_columns = [0u8; 4];
reader.read_exact(&mut num_fixed_columns)?;
let num_fixed_columns = u32::from_be_bytes(num_fixed_columns);
let num_fixed_columns = u32::from_le_bytes(num_fixed_columns);

let fixed_commitments: Vec<_> = (0..num_fixed_columns)
.map(|_| C::read(reader, format))
.collect::<Result<_, _>>()?;

let permutation = permutation::VerifyingKey::read(reader, &cs.permutation, format)?;

// read selectors
let selectors: Vec<Vec<bool>> = vec![vec![false; 1 << k]; cs.num_selectors]
.into_iter()
.map(|mut selector| {
let mut selector_bytes = vec![0u8; (selector.len() + 7) / 8];
reader.read_exact(&mut selector_bytes)?;
for (bits, byte) in selector.chunks_mut(8).zip(selector_bytes) {
crate::helpers::unpack(byte, bits);
}
Ok(selector)
})
.collect::<io::Result<_>>()?;
let (cs, _) = cs.compress_selectors(selectors.clone());
let (cs, selectors) = if compress_selectors {
// read selectors
let selectors: Vec<Vec<bool>> = vec![vec![false; 1 << k]; cs.num_selectors]
.into_iter()
.map(|mut selector| {
let mut selector_bytes = vec![0u8; (selector.len() + 7) / 8];
reader.read_exact(&mut selector_bytes)?;
for (bits, byte) in selector.chunks_mut(8).zip(selector_bytes) {
crate::helpers::unpack(byte, bits);
}
Ok(selector)
})
.collect::<io::Result<_>>()?;
let (cs, _) = cs.compress_selectors(selectors.clone());
(cs, selectors)
} else {
// we still need to replace selectors with fixed Expressions in `cs`
let fake_selectors = vec![vec![]; cs.num_selectors];
let (cs, _) = cs.directly_convert_selectors_to_fixed(fake_selectors);
(cs, vec![])
};

Ok(Self::from_parts(
domain,
fixed_commitments,
permutation,
cs,
selectors,
compress_selectors,
))
}

/// Writes a verifying key to a vector of bytes using [`Self::write`].
pub fn to_bytes(&self, format: SerdeFormat) -> Vec<u8> {
let mut bytes = Vec::<u8>::with_capacity(self.bytes_length());
let mut bytes = Vec::<u8>::with_capacity(self.bytes_length(format));
Self::write(self, &mut bytes, format).expect("Writing to vector should not fail");
bytes
}
Expand All @@ -168,9 +202,12 @@ where
}

impl<C: CurveAffine> VerifyingKey<C> {
fn bytes_length(&self) -> usize {
8 + (self.fixed_commitments.len() * C::default().to_bytes().as_ref().len())
+ self.permutation.bytes_length()
fn bytes_length(&self, format: SerdeFormat) -> usize
where
C: SerdeCurveAffine,
{
10 + (self.fixed_commitments.len() * C::byte_length(format))
+ self.permutation.bytes_length(format)
+ self.selectors.len()
* (self
.selectors
Expand All @@ -185,6 +222,7 @@ impl<C: CurveAffine> VerifyingKey<C> {
permutation: permutation::VerifyingKey<C>,
cs: ConstraintSystem<C::Scalar>,
selectors: Vec<Vec<bool>>,
compress_selectors: bool,
) -> Self
where
C::ScalarExt: FromUniformBytes<64>,
Expand All @@ -201,6 +239,7 @@ impl<C: CurveAffine> VerifyingKey<C> {
// Temporary, this is not pinned.
transcript_repr: C::Scalar::ZERO,
selectors,
compress_selectors,
};

let mut hasher = Blake2bParams::new()
Expand Down Expand Up @@ -300,9 +339,12 @@ where
}

/// Gets the total number of bytes in the serialization of `self`
fn bytes_length(&self) -> usize {
fn bytes_length(&self, format: SerdeFormat) -> usize
where
C: SerdeCurveAffine,
{
let scalar_len = C::Scalar::default().to_repr().as_ref().len();
self.vk.bytes_length()
self.vk.bytes_length(format)
+ 12
+ scalar_len * (self.l0.len() + self.l_last.len() + self.l_active_row.len())
+ polynomial_slice_byte_length(&self.fixed_values)
Expand Down Expand Up @@ -383,7 +425,7 @@ where

/// Writes a proving key to a vector of bytes using [`Self::write`].
pub fn to_bytes(&self, format: SerdeFormat) -> Vec<u8> {
let mut bytes = Vec::<u8>::with_capacity(self.bytes_length());
let mut bytes = Vec::<u8>::with_capacity(self.bytes_length(format));
Self::write(self, &mut bytes, format).expect("Writing to vector should not fail");
bytes
}
Expand Down
46 changes: 41 additions & 5 deletions halo2_proofs/src/plonk/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2004,7 +2004,45 @@ impl<F: Field> ConstraintSystem<F> {
.into_iter()
.map(|a| a.unwrap())
.collect::<Vec<_>>();
self.replace_selectors_with_fixed(&selector_replacements);

(self, polys)
}

/// Does not combine selectors and directly replaces them everywhere with fixed columns.
pub fn directly_convert_selectors_to_fixed(
mut self,
selectors: Vec<Vec<bool>>,
) -> (Self, Vec<Vec<F>>) {
// The number of provided selector assignments must be the number we
// counted for this constraint system.
assert_eq!(selectors.len(), self.num_selectors);

let (polys, selector_replacements): (Vec<_>, Vec<_>) = selectors
.into_iter()
.map(|selector| {
let poly = selector
.iter()
.map(|b| if *b { F::ONE } else { F::ZERO })
.collect::<Vec<_>>();
let column = self.fixed_column();
let rotation = Rotation::cur();
let expr = Expression::Fixed(FixedQuery {
index: Some(self.query_fixed_index(column, rotation)),
column_index: column.index,
rotation,
});
(poly, expr)
})
.unzip();

self.replace_selectors_with_fixed(&selector_replacements);
self.num_selectors = 0;

(self, polys)
}

fn replace_selectors_with_fixed(&mut self, selector_replacements: &[Expression<F>]) {
fn replace_selectors<F: Field>(
expr: &mut Expression<F>,
selector_replacements: &[Expression<F>],
Expand Down Expand Up @@ -2035,7 +2073,7 @@ impl<F: Field> ConstraintSystem<F> {

// Substitute selectors for the real fixed columns in all gates
for expr in self.gates.iter_mut().flat_map(|gate| gate.polys.iter_mut()) {
replace_selectors(expr, &selector_replacements, false);
replace_selectors(expr, selector_replacements, false);
}

// Substitute non-simple selectors for the real fixed columns in all
Expand All @@ -2046,7 +2084,7 @@ impl<F: Field> ConstraintSystem<F> {
.iter_mut()
.chain(lookup.table_expressions.iter_mut())
}) {
replace_selectors(expr, &selector_replacements, true);
replace_selectors(expr, selector_replacements, true);
}

for expr in self.shuffles.iter_mut().flat_map(|shuffle| {
Expand All @@ -2055,10 +2093,8 @@ impl<F: Field> ConstraintSystem<F> {
.iter_mut()
.chain(shuffle.shuffle_expressions.iter_mut())
}) {
replace_selectors(expr, &selector_replacements, true);
replace_selectors(expr, selector_replacements, true);
}

(self, polys)
}

/// Allocate a new (simple) selector. Simple selectors cannot be added to
Expand Down
33 changes: 31 additions & 2 deletions halo2_proofs/src/plonk/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,28 @@ impl<F: Field> Assignment<F> for Assembly<F> {
}

/// Generate a `VerifyingKey` from an instance of `Circuit`.
/// By default, selector compression is turned **off**.
pub fn keygen_vk<'params, C, P, ConcreteCircuit>(
params: &P,
circuit: &ConcreteCircuit,
) -> Result<VerifyingKey<C>, Error>
where
C: CurveAffine,
P: Params<'params, C>,
ConcreteCircuit: Circuit<C::Scalar>,
C::Scalar: FromUniformBytes<64>,
{
keygen_vk_custom(params, circuit, true)
}

/// Generate a `VerifyingKey` from an instance of `Circuit`.
///
/// The selector compression optimization is turned on only if `compress_selectors` is `true`.
pub fn keygen_vk_custom<'params, C, P, ConcreteCircuit>(
params: &P,
circuit: &ConcreteCircuit,
compress_selectors: bool,
) -> Result<VerifyingKey<C>, Error>
where
C: CurveAffine,
P: Params<'params, C>,
Expand Down Expand Up @@ -241,7 +259,13 @@ where
)?;

let mut fixed = batch_invert_assigned(assembly.fixed);
let (cs, selector_polys) = cs.compress_selectors(assembly.selectors.clone());
let (cs, selector_polys) = if compress_selectors {
cs.compress_selectors(assembly.selectors.clone())
} else {
// After this, the ConstraintSystem should not have any selectors: `verify` does not need them, and `keygen_pk` regenerates `cs` from scratch anyways.
let selectors = std::mem::take(&mut assembly.selectors);
cs.directly_convert_selectors_to_fixed(selectors)
};
fixed.extend(
selector_polys
.into_iter()
Expand All @@ -263,6 +287,7 @@ where
permutation_vk,
cs,
assembly.selectors,
compress_selectors,
))
}

Expand Down Expand Up @@ -307,7 +332,11 @@ where
)?;

let mut fixed = batch_invert_assigned(assembly.fixed);
let (cs, selector_polys) = cs.compress_selectors(assembly.selectors);
let (cs, selector_polys) = if vk.compress_selectors {
cs.compress_selectors(assembly.selectors)
} else {
cs.directly_convert_selectors_to_fixed(assembly.selectors)
};
fixed.extend(
selector_polys
.into_iter()
Expand Down
7 changes: 5 additions & 2 deletions halo2_proofs/src/plonk/permutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,11 @@ impl<C: CurveAffine> VerifyingKey<C> {
Ok(VerifyingKey { commitments })
}

pub(crate) fn bytes_length(&self) -> usize {
self.commitments.len() * C::default().to_bytes().as_ref().len()
pub(crate) fn bytes_length(&self, format: SerdeFormat) -> usize
where
C: SerdeCurveAffine,
{
self.commitments.len() * C::byte_length(format)
}
}

Expand Down
Loading