Skip to content

Commit

Permalink
Merge pull request a16z#415 from a16z/sragss/multi-offset-eq
Browse files Browse the repository at this point in the history
Multi-Offset-Eq Constraints
  • Loading branch information
sragss authored Jul 23, 2024
2 parents 1f9ca2d + d3b30ca commit d5147f8
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 220 deletions.
234 changes: 120 additions & 114 deletions jolt-core/src/r1cs/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use rayon::prelude::*;
use std::{collections::HashMap, fmt::Debug};

use super::{
key::{NonUniformR1CS, SparseEqualityItem},
key::{NonUniformR1CS, NonUniformR1CSConstraint, SparseEqualityItem},
ops::{ConstraintInput, Term, Variable, LC},
special_polys::SparsePolynomial,
};
Expand Down Expand Up @@ -552,7 +552,7 @@ pub struct CombinedUniformBuilder<F: JoltField, I: ConstraintInput> {
/// Padded to the nearest power of 2
uniform_repeat: usize,

offset_equality_constraint: OffsetEqConstraint<I>,
offset_equality_constraints: Vec<OffsetEqConstraint<I>>,
}

#[tracing::instrument(skip_all, name = "batch_inputs")]
Expand All @@ -574,13 +574,13 @@ impl<F: JoltField, I: ConstraintInput> CombinedUniformBuilder<F, I> {
pub fn construct(
uniform_builder: R1CSBuilder<F, I>,
uniform_repeat: usize,
offset_equality_constraint: OffsetEqConstraint<I>,
offset_equality_constraints: Vec<OffsetEqConstraint<I>>,
) -> Self {
assert!(uniform_repeat.is_power_of_two());
Self {
uniform_builder,
uniform_repeat,
offset_equality_constraint,
offset_equality_constraints,
}
}

Expand Down Expand Up @@ -667,7 +667,7 @@ impl<F: JoltField, I: ConstraintInput> CombinedUniformBuilder<F, I> {
}

pub(super) fn offset_eq_constraint_rows(&self) -> usize {
self.uniform_repeat
self.uniform_repeat * self.offset_equality_constraints.len()
}

/// Total number of rows used across all repeated constraints. Not padded to nearest power of two.
Expand All @@ -684,71 +684,75 @@ impl<F: JoltField, I: ConstraintInput> CombinedUniformBuilder<F, I> {
self.uniform_builder.materialize()
}

/// Converts builder::OffsetEqConstraints into key::NonUniformR1CSConstraint
pub fn materialize_offset_eq(&self) -> NonUniformR1CS<F> {
// (a - b) * condition == 0
// A: a - b
// B: condition
// C: 0

let mut eq = SparseEqualityItem::<F>::empty();
let mut condition = SparseEqualityItem::<F>::empty();
let mut constraints = Vec::with_capacity(self.offset_equality_constraints.len());
for constraint in &self.offset_equality_constraints {
let mut eq = SparseEqualityItem::<F>::empty();
let mut condition = SparseEqualityItem::<F>::empty();

let constraint = &self.offset_equality_constraint;
constraint
.cond
.1
.terms()
.iter()
.filter(|term| matches!(term.0, Variable::Input(_) | Variable::Auxiliary(_)))
.for_each(|term| {
condition.offset_vars.push((
self.uniform_builder.variable_to_column(term.0),
constraint.cond.0,
F::from_i64(term.1),
))
});
if let Some(term) = constraint.cond.1.constant_term() {
condition.constant = F::from_i64(term.1);
}

constraint
.cond
.1
.terms()
.iter()
.filter(|term| matches!(term.0, Variable::Input(_) | Variable::Auxiliary(_)))
.for_each(|term| {
condition.offset_vars.push((
self.uniform_builder.variable_to_column(term.0),
constraint.cond.0,
F::from_i64(term.1),
))
});
if let Some(term) = constraint.cond.1.constant_term() {
condition.constant = F::from_i64(term.1);
}
// Can't simply combine like terms because of the offset
let lhs = constraint.a.1.clone();
let rhs = -constraint.b.1.clone();

// Can't simply combine like terms because of the offset
let lhs = constraint.a.1.clone();
let rhs = -constraint.b.1.clone();
lhs.terms()
.iter()
.filter(|term| matches!(term.0, Variable::Input(_) | Variable::Auxiliary(_)))
.for_each(|term| {
eq.offset_vars.push((
self.uniform_builder.variable_to_column(term.0),
constraint.a.0,
F::from_i64(term.1),
))
});
rhs.terms()
.iter()
.filter(|term| matches!(term.0, Variable::Input(_) | Variable::Auxiliary(_)))
.for_each(|term| {
eq.offset_vars.push((
self.uniform_builder.variable_to_column(term.0),
constraint.b.0,
F::from_i64(term.1),
))
});

lhs.terms()
.iter()
.filter(|term| matches!(term.0, Variable::Input(_) | Variable::Auxiliary(_)))
.for_each(|term| {
eq.offset_vars.push((
self.uniform_builder.variable_to_column(term.0),
constraint.a.0,
F::from_i64(term.1),
))
});
rhs.terms()
.iter()
.filter(|term| matches!(term.0, Variable::Input(_) | Variable::Auxiliary(_)))
.for_each(|term| {
eq.offset_vars.push((
self.uniform_builder.variable_to_column(term.0),
constraint.b.0,
F::from_i64(term.1),
))
// Handle constants
lhs.terms().iter().for_each(|term| {
assert!(
!matches!(term.0, Variable::Constant),
"Constants only supported in RHS"
)
});
if let Some(term) = rhs.constant_term() {
eq.constant = F::from_i64(term.1);
}

// Handle constants
lhs.terms().iter().for_each(|term| {
assert!(
!matches!(term.0, Variable::Constant),
"Constants only supported in RHS"
)
});
if let Some(term) = rhs.constant_term() {
eq.constant = F::from_i64(term.1);
constraints.push(NonUniformR1CSConstraint::new(eq, condition));
}

NonUniformR1CS::new(eq, condition)
NonUniformR1CS { constraints }
}

/// inputs should be of the format [[I::0, I::0, ...], [I::1, I::1, ...], ... [I::N, I::N]]
Expand Down Expand Up @@ -818,7 +822,7 @@ impl<F: JoltField, I: ConstraintInput> CombinedUniformBuilder<F, I> {
let (mut az_sparse, mut bz_sparse, cz_sparse) = par_flatten_triple(
uni_constraint_evals,
unsafe_allocate_sparse_zero_vec,
self.uniform_repeat, // Capacity overhead for offset_eq constraints.
self.offset_eq_constraint_rows(),
);

// offset_equality_constraints: Xz[uniform_constraint_rows..uniform_constraint_rows + 1]
Expand All @@ -827,48 +831,50 @@ impl<F: JoltField, I: ConstraintInput> CombinedUniformBuilder<F, I> {
let span = tracing::span!(tracing::Level::DEBUG, "offset_eq");
let _enter = span.enter();

let constr = &self.offset_equality_constraint;
let condition_evals = constr
.cond
.1
.evaluate_batch(&batch_inputs(&constr.cond.1), self.uniform_repeat);
let eq_a_evals = constr
.a
.1
.evaluate_batch(&batch_inputs(&constr.a.1), self.uniform_repeat);
let eq_b_evals = constr
.b
.1
.evaluate_batch(&batch_inputs(&constr.b.1), self.uniform_repeat);

(0..self.uniform_repeat).for_each(|step_index| {
// Write corresponding values, if outside the step range, only include the constant.
let a_step = step_index + constr.a.0 as usize;
let b_step = step_index + constr.b.0 as usize;
let a = eq_a_evals
.get(a_step)
.cloned()
.unwrap_or(constr.a.1.constant_term_field());
let b = eq_b_evals
.get(b_step)
.cloned()
.unwrap_or(constr.b.1.constant_term_field());
let az = a - b;

let global_index = uniform_constraint_rows + step_index;
if !az.is_zero() {
az_sparse.push((az, global_index));
}
for (constr_i, constr) in self.offset_equality_constraints.iter().enumerate() {
let condition_evals = constr
.cond
.1
.evaluate_batch(&batch_inputs(&constr.cond.1), self.uniform_repeat);
let eq_a_evals = constr
.a
.1
.evaluate_batch(&batch_inputs(&constr.a.1), self.uniform_repeat);
let eq_b_evals = constr
.b
.1
.evaluate_batch(&batch_inputs(&constr.b.1), self.uniform_repeat);

(0..self.uniform_repeat).for_each(|step_index| {
// Write corresponding values, if outside the step range, only include the constant.
let a_step = step_index + constr.a.0 as usize;
let b_step = step_index + constr.b.0 as usize;
let a = eq_a_evals
.get(a_step)
.cloned()
.unwrap_or(constr.a.1.constant_term_field());
let b = eq_b_evals
.get(b_step)
.cloned()
.unwrap_or(constr.b.1.constant_term_field());
let az = a - b;

let global_index =
uniform_constraint_rows + self.uniform_repeat * constr_i + step_index;
if !az.is_zero() {
az_sparse.push((az, global_index));
}

let condition_step = step_index + constr.cond.0 as usize;
let bz = condition_evals
.get(condition_step)
.cloned()
.unwrap_or(constr.cond.1.constant_term_field());
if !bz.is_zero() {
bz_sparse.push((bz, global_index));
}
});
let condition_step = step_index + constr.cond.0 as usize;
let bz = condition_evals
.get(condition_step)
.cloned()
.unwrap_or(constr.cond.1.constant_term_field());
if !bz.is_zero() {
bz_sparse.push((bz, global_index));
}
});
}
drop(_enter);

let num_vars = self.constraint_rows().next_power_of_two().log_2();
Expand Down Expand Up @@ -1296,11 +1302,8 @@ mod tests {
assert_eq!(uniform_builder.constraints.len(), 1);
assert_eq!(uniform_builder.next_aux, 1);
let num_steps = 2;
let combined_builder = CombinedUniformBuilder::construct(
uniform_builder,
num_steps,
OffsetEqConstraint::empty(),
);
let combined_builder =
CombinedUniformBuilder::construct(uniform_builder, num_steps, vec![]);

let mut inputs = vec![vec![Fr::zero(); num_steps]; TestInputs::COUNT];
inputs[TestInputs::OpFlags0 as usize][0] = Fr::from(5);
Expand Down Expand Up @@ -1340,11 +1343,8 @@ mod tests {
assert_eq!(uniform_builder.next_aux, 2);

let num_steps = 2;
let combined_builder = CombinedUniformBuilder::construct(
uniform_builder,
num_steps,
OffsetEqConstraint::empty(),
);
let combined_builder =
CombinedUniformBuilder::construct(uniform_builder, num_steps, vec![]);

let mut inputs = vec![vec![Fr::zero(); num_steps]; TestInputs::COUNT];
inputs[TestInputs::OpFlags0 as usize][0] = Fr::from(5);
Expand Down Expand Up @@ -1401,8 +1401,11 @@ mod tests {
(TestInputs::OpFlags0, false),
(TestInputs::OpFlags0, true),
);
let combined_builder =
CombinedUniformBuilder::construct(uniform_builder, num_steps, non_uniform_constraint);
let combined_builder = CombinedUniformBuilder::construct(
uniform_builder,
num_steps,
vec![non_uniform_constraint],
);

let mut inputs = vec![vec![Fr::zero(); num_steps]; TestInputs::COUNT];
inputs[TestInputs::OpFlags0 as usize][0] = Fr::from(5);
Expand Down Expand Up @@ -1445,8 +1448,11 @@ mod tests {
(TestInputs::OpFlags0, false),
(TestInputs::OpFlags0, true),
);
let combined_builder =
CombinedUniformBuilder::construct(uniform_builder, num_steps, non_uniform_constraint);
let combined_builder = CombinedUniformBuilder::construct(
uniform_builder,
num_steps,
vec![non_uniform_constraint],
);

let offset_eq = combined_builder.materialize_offset_eq();
let mut expected_condition = SparseEqualityItem::<Fr>::empty();
Expand All @@ -1458,8 +1464,8 @@ mod tests {
(TestInputs::OpFlags0 as usize, true, Fr::from_i64(-1)),
];

assert_eq!(offset_eq.condition, expected_condition);
assert_eq!(offset_eq.eq, expected_eq);
assert_eq!(offset_eq.constraints[0].condition, expected_condition);
assert_eq!(offset_eq.constraints[0].eq, expected_eq);
}

#[test]
Expand Down
18 changes: 8 additions & 10 deletions jolt-core/src/r1cs/jolt_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ pub fn construct_jolt_constraints<F: JoltField>(
(4 * JoltIn::PcIn + PC_START_ADDRESS, true),
);

CombinedUniformBuilder::construct(uniform_builder, padded_trace_length, non_uniform_constraint)
CombinedUniformBuilder::construct(
uniform_builder,
padded_trace_length,
vec![non_uniform_constraint],
)
}

// TODO(#377): Dedupe OpFlags / CircuitFlags
Expand Down Expand Up @@ -283,10 +287,7 @@ impl<F: JoltField> R1CSConstraintBuilder<F> for UniformJoltConstraints {
mod tests {
use super::*;

use crate::{
jolt::vm::rv32i_vm::RV32I,
r1cs::builder::{CombinedUniformBuilder, OffsetEqConstraint},
};
use crate::{jolt::vm::rv32i_vm::RV32I, r1cs::builder::CombinedUniformBuilder};

use ark_bn254::Fr;
use ark_std::Zero;
Expand All @@ -308,11 +309,8 @@ mod tests {
jolt_constraints.build_constraints(&mut uniform_builder);

let num_steps = 1;
let combined_builder = CombinedUniformBuilder::construct(
uniform_builder,
num_steps,
OffsetEqConstraint::empty(),
);
let combined_builder =
CombinedUniformBuilder::construct(uniform_builder, num_steps, vec![]);
let mut inputs = vec![vec![Fr::zero(); num_steps]; JoltIn::COUNT];

// ADD instruction
Expand Down
Loading

0 comments on commit d5147f8

Please sign in to comment.