Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 7 additions & 33 deletions jolt-core/src/zkvm/bytecode/mod.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
use crate::poly::opening_proof::{
OpeningAccumulator, ProverOpeningAccumulator, SumcheckId, VerifierOpeningAccumulator,
};
use crate::poly::opening_proof::{OpeningAccumulator, ProverOpeningAccumulator, SumcheckId};
use crate::subprotocols::sumcheck_prover::SumcheckInstanceProver;
use crate::subprotocols::sumcheck_verifier::SumcheckInstanceVerifier;
use crate::subprotocols::{
BooleanitySumcheckParams, BooleanitySumcheckProver, BooleanitySumcheckVerifier,
HammingWeightSumcheckParams, HammingWeightSumcheckProver, HammingWeightSumcheckVerifier,
};
use crate::utils::math::Math;
#[cfg(feature = "allocative")]
use crate::utils::profiling::print_data_structure_heap_usage;
use crate::zkvm::bytecode::read_raf_checking::{ReadRafSumcheckProver, ReadRafSumcheckVerifier};
use crate::zkvm::dag::stage::{SumcheckStagesProver, SumcheckStagesVerifier};
use crate::zkvm::bytecode::read_raf_checking::ReadRafSumcheckProver;
use crate::zkvm::dag::stage::SumcheckStagesProver;
use crate::zkvm::dag::state_manager::StateManager;
use crate::zkvm::witness::{
compute_d_parameter, CommittedPolynomial, VirtualPolynomial, DTH_ROOT_OF_K,
Expand Down Expand Up @@ -168,27 +165,6 @@ impl<F: JoltField, PCS: CommitmentScheme<Field = F>, T: Transcript> SumcheckStag
}
}

pub struct BytecodeDagVerifier;

impl<F: JoltField, PCS: CommitmentScheme<Field = F>, T: Transcript>
SumcheckStagesVerifier<F, T, PCS> for BytecodeDagVerifier
{
fn stage6_instances(
&mut self,
sm: &mut StateManager<'_, F, PCS>,
opening_accumulator: &mut VerifierOpeningAccumulator<F>,
transcript: &mut T,
) -> Vec<Box<dyn SumcheckInstanceVerifier<F, T>>> {
let read_checking = ReadRafSumcheckVerifier::gen(sm, opening_accumulator, transcript);
let (hamming_weight, booleanity) = new_ra_one_hot_verifiers(sm, transcript);
vec![
Box::new(read_checking),
Box::new(hamming_weight),
Box::new(booleanity),
]
}
}

fn gen_ra_one_hot_provers<F: JoltField>(
state_manager: &mut StateManager<'_, F, impl CommitmentScheme<Field = F>>,
opening_accumulator: &ProverOpeningAccumulator<F>,
Expand Down Expand Up @@ -247,17 +223,15 @@ fn gen_ra_one_hot_provers<F: JoltField>(
)
}

fn new_ra_one_hot_verifiers<F: JoltField>(
state_manager: &mut StateManager<'_, F, impl CommitmentScheme<Field = F>>,
pub fn new_ra_one_hot_verifiers<F: JoltField>(
bytecode_preprocessing: &BytecodePreprocessing,
n_cycle_vars: usize,
transcript: &mut impl Transcript,
) -> (
HammingWeightSumcheckVerifier<F>,
BooleanitySumcheckVerifier<F>,
) {
let (preprocessing, _, T_val) = state_manager.get_verifier_data();
let bytecode_preprocessing = &preprocessing.shared.bytecode;
let d = bytecode_preprocessing.d;
let log_t = T_val.log_2();
let polynomial_types: Vec<CommittedPolynomial> =
(0..d).map(CommittedPolynomial::BytecodeRa).collect();
let hamming_weight_gamma_powers = transcript.challenge_scalar_powers(d);
Expand All @@ -279,7 +253,7 @@ fn new_ra_one_hot_verifiers<F: JoltField>(
let booleanity_params = BooleanitySumcheckParams {
d,
log_k_chunk: DTH_ROOT_OF_K.log_2(),
log_t,
log_t: n_cycle_vars,
gammas: booleanity_gammas,
r_address,
r_cycle: Vec::new(),
Expand Down
86 changes: 45 additions & 41 deletions jolt-core/src/zkvm/bytecode/read_raf_checking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::{
transcripts::Transcript,
utils::{math::Math, small_scalar::SmallScalar, thread::unsafe_allocate_zero_vec},
zkvm::{
bytecode::BytecodePreprocessing,
dag::state_manager::StateManager,
instruction::{
CircuitFlags, Flags, InstructionFlags, InstructionLookup, InterleavedBitsMarker,
Expand All @@ -42,7 +43,7 @@ use common::constants::{REGISTER_COUNT, XLEN};
use itertools::{chain, zip_eq, Itertools};
use rayon::prelude::*;
use strum::{EnumCount, IntoEnumIterator};
use tracer::instruction::NormalizedInstruction;
use tracer::instruction::{Instruction, NormalizedInstruction};

/// Number of batched read-checking sumchecks bespokely
const N_STAGES: usize = 5;
Expand Down Expand Up @@ -127,7 +128,14 @@ impl<F: JoltField> ReadRafSumcheckProver<F> {
opening_accumulator: &ProverOpeningAccumulator<F>,
transcript: &mut impl Transcript,
) -> Self {
let params = ReadRafSumcheckParams::gen(state_manager, opening_accumulator, transcript);
let (preprocessing, _, trace, _, _) = state_manager.get_prover_data();

let params = ReadRafSumcheckParams::gen(
&preprocessing.shared.bytecode,
trace.len().log_2(),
opening_accumulator,
transcript,
);

let claim_per_stage = [
params.rv_claims[0] + params.gamma_powers[5] * params.raf_claim,
Expand All @@ -146,7 +154,6 @@ impl<F: JoltField> ReadRafSumcheckProver<F> {
.each_ref()
.map(|r_cycle| EqPolynomial::evals(&r_cycle[..prefix_n_vars]));

let (preprocessing, _, trace, _, _) = state_manager.get_prover_data();
let F = trace
.par_chunks(chunk_size)
.enumerate()
Expand Down Expand Up @@ -521,12 +528,18 @@ pub struct ReadRafSumcheckVerifier<F: JoltField> {

impl<F: JoltField> ReadRafSumcheckVerifier<F> {
pub fn gen(
state_manager: &mut StateManager<'_, F, impl CommitmentScheme<Field = F>>,
bytecode_preprocessing: &BytecodePreprocessing,
n_cycle_vars: usize,
opening_accumulator: &VerifierOpeningAccumulator<F>,
transcript: &mut impl Transcript,
) -> Self {
Self {
params: ReadRafSumcheckParams::gen(state_manager, opening_accumulator, transcript),
params: ReadRafSumcheckParams::gen(
bytecode_preprocessing,
n_cycle_vars,
opening_accumulator,
transcript,
),
}
}
}
Expand Down Expand Up @@ -643,43 +656,43 @@ struct ReadRafSumcheckParams<F: JoltField> {

impl<F: JoltField> ReadRafSumcheckParams<F> {
fn gen(
state_manager: &mut StateManager<'_, F, impl CommitmentScheme<Field = F>>,
bytecode_preprocessing: &BytecodePreprocessing,
n_cycle_vars: usize,
opening_accumulator: &dyn OpeningAccumulator<F>,
transcript: &mut impl Transcript,
) -> Self {
let preprocessing = state_manager.get_shared_preprocessing();
let K = preprocessing.bytecode.code_size;
let K = bytecode_preprocessing.code_size;
let log_K = K.log_2();
let d = preprocessing.bytecode.d;
let log_T = state_manager.get_trace_len().log_2();
let d = bytecode_preprocessing.d;
let gamma_powers = transcript.challenge_scalar_powers(7);

let bytecode = &bytecode_preprocessing.bytecode;
let (val_1, rv_claim_1) = Self::compute_val_rv(
state_manager,
bytecode,
opening_accumulator,
ReadCheckingValType::Stage1,
transcript,
);
let (val_2, rv_claim_2) = Self::compute_val_rv(
state_manager,
bytecode,
opening_accumulator,
ReadCheckingValType::Stage2,
transcript,
);
let (val_3, rv_claim_3) = Self::compute_val_rv(
state_manager,
bytecode,
opening_accumulator,
ReadCheckingValType::Stage3,
transcript,
);
let (val_4, rv_claim_4) = Self::compute_val_rv(
state_manager,
bytecode,
opening_accumulator,
ReadCheckingValType::Stage4,
transcript,
);
let (val_5, rv_claim_5) = Self::compute_val_rv(
state_manager,
bytecode,
opening_accumulator,
ReadCheckingValType::Stage5,
transcript,
Expand Down Expand Up @@ -747,7 +760,7 @@ impl<F: JoltField> ReadRafSumcheckParams<F> {
K,
log_K,
d,
log_T,
log_T: n_cycle_vars,
val_polys,
rv_claims,
raf_claim,
Expand All @@ -758,7 +771,7 @@ impl<F: JoltField> ReadRafSumcheckParams<F> {
}

fn compute_val_rv(
sm: &mut StateManager<F, impl CommitmentScheme<Field = F>>,
bytecode: &[Instruction],
opening_accumulator: &dyn OpeningAccumulator<F>,
val_type: ReadCheckingValType,
transcript: &mut impl Transcript,
Expand All @@ -767,35 +780,35 @@ impl<F: JoltField> ReadRafSumcheckParams<F> {
ReadCheckingValType::Stage1 => {
let gamma_powers = transcript.challenge_scalar_powers(2 + NUM_CIRCUIT_FLAGS);
(
Self::compute_val_1(sm, &gamma_powers),
Self::compute_val_1(bytecode, &gamma_powers),
Self::compute_rv_claim_1(opening_accumulator, &gamma_powers),
)
}
ReadCheckingValType::Stage2 => {
let gamma_powers = transcript.challenge_scalar_powers(4);
(
Self::compute_val_2(sm, &gamma_powers),
Self::compute_val_2(bytecode, &gamma_powers),
Self::compute_rv_claim_2(opening_accumulator, &gamma_powers),
)
}
ReadCheckingValType::Stage3 => {
let gamma_powers = transcript.challenge_scalar_powers(9);
(
Self::compute_val_3(sm, &gamma_powers),
Self::compute_val_3(bytecode, &gamma_powers),
Self::compute_rv_claim_3(opening_accumulator, &gamma_powers),
)
}
ReadCheckingValType::Stage4 => {
let gamma_powers = transcript.challenge_scalar_powers(3);
(
Self::compute_val_4(sm, opening_accumulator, &gamma_powers),
Self::compute_val_4(bytecode, opening_accumulator, &gamma_powers),
Self::compute_rv_claim_4(opening_accumulator, &gamma_powers),
)
}
ReadCheckingValType::Stage5 => {
let gamma_powers = transcript.challenge_scalar_powers(2 + NUM_LOOKUP_TABLES);
(
Self::compute_val_5(sm, opening_accumulator, &gamma_powers),
Self::compute_val_5(bytecode, opening_accumulator, &gamma_powers),
Self::compute_rv_claim_5(opening_accumulator, &gamma_powers),
)
}
Expand All @@ -806,11 +819,8 @@ impl<F: JoltField> ReadRafSumcheckParams<F> {
/// Val(k) = unexpanded_pc(k) + gamma * imm(k)
/// + gamma^2 * circuit_flags[0](k) + gamma^3 * circuit_flags[1](k) + ...
/// This particular Val virtualizes claims output by Spartan's "outer" sumcheck
fn compute_val_1(
sm: &mut StateManager<F, impl CommitmentScheme<Field = F>>,
gamma_powers: &[F],
) -> Vec<F> {
sm.get_bytecode()
fn compute_val_1(bytecode: &[Instruction], gamma_powers: &[F]) -> Vec<F> {
bytecode
.par_iter()
.map(|instruction| {
let NormalizedInstruction {
Expand Down Expand Up @@ -879,11 +889,8 @@ impl<F: JoltField> ReadRafSumcheckParams<F> {
/// rd_addr(k) is the rd address for instruction k;
/// write_lookup_output_to_rd_flag(k) = 1 if instruction k writes lookup output to rd, 0 otherwise.
/// This Val matches the fused product sumcheck.
fn compute_val_2(
sm: &mut StateManager<F, impl CommitmentScheme<Field = F>>,
gamma_powers: &[F],
) -> Vec<F> {
sm.get_bytecode()
fn compute_val_2(bytecode: &[Instruction], gamma_powers: &[F]) -> Vec<F> {
bytecode
.par_iter()
.map(|instruction| {
let flags = instruction.circuit_flags();
Expand Down Expand Up @@ -947,11 +954,8 @@ impl<F: JoltField> ReadRafSumcheckParams<F> {
/// + gamma^2 * left_operand_is_rs1_value(k)
/// + gamma^3 * left_operand_is_pc(k) + ...
/// This particular Val virtualizes claims output by the ShiftSumcheck.
fn compute_val_3(
sm: &mut StateManager<F, impl CommitmentScheme<Field = F>>,
gamma_powers: &[F],
) -> Vec<F> {
sm.get_bytecode()
fn compute_val_3(bytecode: &[Instruction], gamma_powers: &[F]) -> Vec<F> {
bytecode
.par_iter()
.map(|instruction| {
let instr = instruction.normalize();
Expand Down Expand Up @@ -1067,7 +1071,7 @@ impl<F: JoltField> ReadRafSumcheckParams<F> {
/// and analogously for rs1(k, k') and rs2(k, k').
/// This particular Val virtualizes claims output by the registers read/write checking sumcheck.
fn compute_val_4(
sm: &mut StateManager<F, impl CommitmentScheme<Field = F>>,
bytecode: &[Instruction],
opening_accumulator: &dyn OpeningAccumulator<F>,
gamma_powers: &[F],
) -> Vec<F> {
Expand All @@ -1082,7 +1086,7 @@ impl<F: JoltField> ReadRafSumcheckParams<F> {
let eq_r_register = EqPolynomial::<F>::evals(r_register);
debug_assert_eq!(eq_r_register.len(), REGISTER_COUNT as usize);

sm.get_bytecode()
bytecode
.par_iter()
.map(|instruction| {
let instr = instruction.normalize();
Expand Down Expand Up @@ -1125,7 +1129,7 @@ impl<F: JoltField> ReadRafSumcheckParams<F> {
/// This particular Val virtualizes the claim output by the registers val-evaluation sumcheck
/// and the instruction lookups sumcheck.
fn compute_val_5(
sm: &mut StateManager<F, impl CommitmentScheme<Field = F>>,
bytecode: &[Instruction],
opening_accumulator: &dyn OpeningAccumulator<F>,
gamma_powers: &[F],
) -> Vec<F> {
Expand All @@ -1140,7 +1144,7 @@ impl<F: JoltField> ReadRafSumcheckParams<F> {
let eq_r_register = EqPolynomial::evals(&r_register);
debug_assert_eq!(eq_r_register.len(), REGISTER_COUNT as usize);

sm.get_bytecode()
bytecode
.par_iter()
.map(|instruction| {
let instr = instruction.normalize();
Expand Down
Loading