diff --git a/jolt-core/src/zkvm/bytecode/mod.rs b/jolt-core/src/zkvm/bytecode/mod.rs index 84fa31db7..c54e7d51a 100644 --- a/jolt-core/src/zkvm/bytecode/mod.rs +++ b/jolt-core/src/zkvm/bytecode/mod.rs @@ -1,8 +1,5 @@ -use crate::poly::opening_proof::{ - OpeningAccumulator, ProverOpeningAccumulator, SumcheckId, VerifierOpeningAccumulator, -}; +use crate::poly::opening_proof::{OpeningAccumulator, ProverOpeningAccumulator, SumcheckId}; use crate::subprotocols::sumcheck_prover::SumcheckInstanceProver; -use crate::subprotocols::sumcheck_verifier::SumcheckInstanceVerifier; use crate::subprotocols::{ BooleanitySumcheckParams, BooleanitySumcheckProver, BooleanitySumcheckVerifier, HammingWeightSumcheckParams, HammingWeightSumcheckProver, HammingWeightSumcheckVerifier, @@ -10,8 +7,8 @@ use crate::subprotocols::{ use crate::utils::math::Math; #[cfg(feature = "allocative")] use crate::utils::profiling::print_data_structure_heap_usage; -use crate::zkvm::bytecode::read_raf_checking::{ReadRafSumcheckProver, ReadRafSumcheckVerifier}; -use crate::zkvm::dag::stage::{SumcheckStagesProver, SumcheckStagesVerifier}; +use crate::zkvm::bytecode::read_raf_checking::ReadRafSumcheckProver; +use crate::zkvm::dag::stage::SumcheckStagesProver; use crate::zkvm::dag::state_manager::StateManager; use crate::zkvm::witness::{ compute_d_parameter, CommittedPolynomial, VirtualPolynomial, DTH_ROOT_OF_K, @@ -168,27 +165,6 @@ impl, T: Transcript> SumcheckStag } } -pub struct BytecodeDagVerifier; - -impl, T: Transcript> - SumcheckStagesVerifier for BytecodeDagVerifier -{ - fn stage6_instances( - &mut self, - sm: &mut StateManager<'_, F, PCS>, - opening_accumulator: &mut VerifierOpeningAccumulator, - transcript: &mut T, - ) -> Vec>> { - let read_checking = ReadRafSumcheckVerifier::gen(sm, opening_accumulator, transcript); - let (hamming_weight, booleanity) = new_ra_one_hot_verifiers(sm, transcript); - vec![ - Box::new(read_checking), - Box::new(hamming_weight), - Box::new(booleanity), - ] - } -} - fn gen_ra_one_hot_provers( state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, opening_accumulator: &ProverOpeningAccumulator, @@ -247,17 +223,15 @@ fn gen_ra_one_hot_provers( ) } -fn new_ra_one_hot_verifiers( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, +pub fn new_ra_one_hot_verifiers( + bytecode_preprocessing: &BytecodePreprocessing, + n_cycle_vars: usize, transcript: &mut impl Transcript, ) -> ( HammingWeightSumcheckVerifier, BooleanitySumcheckVerifier, ) { - let (preprocessing, _, T_val) = state_manager.get_verifier_data(); - let bytecode_preprocessing = &preprocessing.shared.bytecode; let d = bytecode_preprocessing.d; - let log_t = T_val.log_2(); let polynomial_types: Vec = (0..d).map(CommittedPolynomial::BytecodeRa).collect(); let hamming_weight_gamma_powers = transcript.challenge_scalar_powers(d); @@ -279,7 +253,7 @@ fn new_ra_one_hot_verifiers( let booleanity_params = BooleanitySumcheckParams { d, log_k_chunk: DTH_ROOT_OF_K.log_2(), - log_t, + log_t: n_cycle_vars, gammas: booleanity_gammas, r_address, r_cycle: Vec::new(), diff --git a/jolt-core/src/zkvm/bytecode/read_raf_checking.rs b/jolt-core/src/zkvm/bytecode/read_raf_checking.rs index f6f81de54..2cf948a5f 100644 --- a/jolt-core/src/zkvm/bytecode/read_raf_checking.rs +++ b/jolt-core/src/zkvm/bytecode/read_raf_checking.rs @@ -26,6 +26,7 @@ use crate::{ transcripts::Transcript, utils::{math::Math, small_scalar::SmallScalar, thread::unsafe_allocate_zero_vec}, zkvm::{ + bytecode::BytecodePreprocessing, dag::state_manager::StateManager, instruction::{ CircuitFlags, Flags, InstructionFlags, InstructionLookup, InterleavedBitsMarker, @@ -42,7 +43,7 @@ use common::constants::{REGISTER_COUNT, XLEN}; use itertools::{chain, zip_eq, Itertools}; use rayon::prelude::*; use strum::{EnumCount, IntoEnumIterator}; -use tracer::instruction::NormalizedInstruction; +use tracer::instruction::{Instruction, NormalizedInstruction}; /// Number of batched read-checking sumchecks bespokely const N_STAGES: usize = 5; @@ -127,7 +128,14 @@ impl ReadRafSumcheckProver { opening_accumulator: &ProverOpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { - let params = ReadRafSumcheckParams::gen(state_manager, opening_accumulator, transcript); + let (preprocessing, _, trace, _, _) = state_manager.get_prover_data(); + + let params = ReadRafSumcheckParams::gen( + &preprocessing.shared.bytecode, + trace.len().log_2(), + opening_accumulator, + transcript, + ); let claim_per_stage = [ params.rv_claims[0] + params.gamma_powers[5] * params.raf_claim, @@ -146,7 +154,6 @@ impl ReadRafSumcheckProver { .each_ref() .map(|r_cycle| EqPolynomial::evals(&r_cycle[..prefix_n_vars])); - let (preprocessing, _, trace, _, _) = state_manager.get_prover_data(); let F = trace .par_chunks(chunk_size) .enumerate() @@ -521,12 +528,18 @@ pub struct ReadRafSumcheckVerifier { impl ReadRafSumcheckVerifier { pub fn gen( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, + bytecode_preprocessing: &BytecodePreprocessing, + n_cycle_vars: usize, opening_accumulator: &VerifierOpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { Self { - params: ReadRafSumcheckParams::gen(state_manager, opening_accumulator, transcript), + params: ReadRafSumcheckParams::gen( + bytecode_preprocessing, + n_cycle_vars, + opening_accumulator, + transcript, + ), } } } @@ -643,43 +656,43 @@ struct ReadRafSumcheckParams { impl ReadRafSumcheckParams { fn gen( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, + bytecode_preprocessing: &BytecodePreprocessing, + n_cycle_vars: usize, opening_accumulator: &dyn OpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { - let preprocessing = state_manager.get_shared_preprocessing(); - let K = preprocessing.bytecode.code_size; + let K = bytecode_preprocessing.code_size; let log_K = K.log_2(); - let d = preprocessing.bytecode.d; - let log_T = state_manager.get_trace_len().log_2(); + let d = bytecode_preprocessing.d; let gamma_powers = transcript.challenge_scalar_powers(7); + let bytecode = &bytecode_preprocessing.bytecode; let (val_1, rv_claim_1) = Self::compute_val_rv( - state_manager, + bytecode, opening_accumulator, ReadCheckingValType::Stage1, transcript, ); let (val_2, rv_claim_2) = Self::compute_val_rv( - state_manager, + bytecode, opening_accumulator, ReadCheckingValType::Stage2, transcript, ); let (val_3, rv_claim_3) = Self::compute_val_rv( - state_manager, + bytecode, opening_accumulator, ReadCheckingValType::Stage3, transcript, ); let (val_4, rv_claim_4) = Self::compute_val_rv( - state_manager, + bytecode, opening_accumulator, ReadCheckingValType::Stage4, transcript, ); let (val_5, rv_claim_5) = Self::compute_val_rv( - state_manager, + bytecode, opening_accumulator, ReadCheckingValType::Stage5, transcript, @@ -747,7 +760,7 @@ impl ReadRafSumcheckParams { K, log_K, d, - log_T, + log_T: n_cycle_vars, val_polys, rv_claims, raf_claim, @@ -758,7 +771,7 @@ impl ReadRafSumcheckParams { } fn compute_val_rv( - sm: &mut StateManager>, + bytecode: &[Instruction], opening_accumulator: &dyn OpeningAccumulator, val_type: ReadCheckingValType, transcript: &mut impl Transcript, @@ -767,35 +780,35 @@ impl ReadRafSumcheckParams { ReadCheckingValType::Stage1 => { let gamma_powers = transcript.challenge_scalar_powers(2 + NUM_CIRCUIT_FLAGS); ( - Self::compute_val_1(sm, &gamma_powers), + Self::compute_val_1(bytecode, &gamma_powers), Self::compute_rv_claim_1(opening_accumulator, &gamma_powers), ) } ReadCheckingValType::Stage2 => { let gamma_powers = transcript.challenge_scalar_powers(4); ( - Self::compute_val_2(sm, &gamma_powers), + Self::compute_val_2(bytecode, &gamma_powers), Self::compute_rv_claim_2(opening_accumulator, &gamma_powers), ) } ReadCheckingValType::Stage3 => { let gamma_powers = transcript.challenge_scalar_powers(9); ( - Self::compute_val_3(sm, &gamma_powers), + Self::compute_val_3(bytecode, &gamma_powers), Self::compute_rv_claim_3(opening_accumulator, &gamma_powers), ) } ReadCheckingValType::Stage4 => { let gamma_powers = transcript.challenge_scalar_powers(3); ( - Self::compute_val_4(sm, opening_accumulator, &gamma_powers), + Self::compute_val_4(bytecode, opening_accumulator, &gamma_powers), Self::compute_rv_claim_4(opening_accumulator, &gamma_powers), ) } ReadCheckingValType::Stage5 => { let gamma_powers = transcript.challenge_scalar_powers(2 + NUM_LOOKUP_TABLES); ( - Self::compute_val_5(sm, opening_accumulator, &gamma_powers), + Self::compute_val_5(bytecode, opening_accumulator, &gamma_powers), Self::compute_rv_claim_5(opening_accumulator, &gamma_powers), ) } @@ -806,11 +819,8 @@ impl ReadRafSumcheckParams { /// Val(k) = unexpanded_pc(k) + gamma * imm(k) /// + gamma^2 * circuit_flags[0](k) + gamma^3 * circuit_flags[1](k) + ... /// This particular Val virtualizes claims output by Spartan's "outer" sumcheck - fn compute_val_1( - sm: &mut StateManager>, - gamma_powers: &[F], - ) -> Vec { - sm.get_bytecode() + fn compute_val_1(bytecode: &[Instruction], gamma_powers: &[F]) -> Vec { + bytecode .par_iter() .map(|instruction| { let NormalizedInstruction { @@ -879,11 +889,8 @@ impl ReadRafSumcheckParams { /// rd_addr(k) is the rd address for instruction k; /// write_lookup_output_to_rd_flag(k) = 1 if instruction k writes lookup output to rd, 0 otherwise. /// This Val matches the fused product sumcheck. - fn compute_val_2( - sm: &mut StateManager>, - gamma_powers: &[F], - ) -> Vec { - sm.get_bytecode() + fn compute_val_2(bytecode: &[Instruction], gamma_powers: &[F]) -> Vec { + bytecode .par_iter() .map(|instruction| { let flags = instruction.circuit_flags(); @@ -947,11 +954,8 @@ impl ReadRafSumcheckParams { /// + gamma^2 * left_operand_is_rs1_value(k) /// + gamma^3 * left_operand_is_pc(k) + ... /// This particular Val virtualizes claims output by the ShiftSumcheck. - fn compute_val_3( - sm: &mut StateManager>, - gamma_powers: &[F], - ) -> Vec { - sm.get_bytecode() + fn compute_val_3(bytecode: &[Instruction], gamma_powers: &[F]) -> Vec { + bytecode .par_iter() .map(|instruction| { let instr = instruction.normalize(); @@ -1067,7 +1071,7 @@ impl ReadRafSumcheckParams { /// and analogously for rs1(k, k') and rs2(k, k'). /// This particular Val virtualizes claims output by the registers read/write checking sumcheck. fn compute_val_4( - sm: &mut StateManager>, + bytecode: &[Instruction], opening_accumulator: &dyn OpeningAccumulator, gamma_powers: &[F], ) -> Vec { @@ -1082,7 +1086,7 @@ impl ReadRafSumcheckParams { let eq_r_register = EqPolynomial::::evals(r_register); debug_assert_eq!(eq_r_register.len(), REGISTER_COUNT as usize); - sm.get_bytecode() + bytecode .par_iter() .map(|instruction| { let instr = instruction.normalize(); @@ -1125,7 +1129,7 @@ impl ReadRafSumcheckParams { /// This particular Val virtualizes the claim output by the registers val-evaluation sumcheck /// and the instruction lookups sumcheck. fn compute_val_5( - sm: &mut StateManager>, + bytecode: &[Instruction], opening_accumulator: &dyn OpeningAccumulator, gamma_powers: &[F], ) -> Vec { @@ -1140,7 +1144,7 @@ impl ReadRafSumcheckParams { let eq_r_register = EqPolynomial::evals(&r_register); debug_assert_eq!(eq_r_register.len(), REGISTER_COUNT as usize); - sm.get_bytecode() + bytecode .par_iter() .map(|instruction| { let instr = instruction.normalize(); diff --git a/jolt-core/src/zkvm/dag/jolt_dag.rs b/jolt-core/src/zkvm/dag/jolt_dag.rs index d8f896043..4204d2f9b 100644 --- a/jolt-core/src/zkvm/dag/jolt_dag.rs +++ b/jolt-core/src/zkvm/dag/jolt_dag.rs @@ -8,7 +8,9 @@ use crate::poly::multilinear_polynomial::MultilinearPolynomial; use crate::poly::opening_proof::ProverOpeningAccumulator; use crate::poly::opening_proof::VerifierOpeningAccumulator; use crate::subprotocols::sumcheck::BatchedSumcheck; +use crate::subprotocols::sumcheck_verifier::SumcheckInstanceVerifier; use crate::transcripts::Transcript; +use crate::utils::math::Math; #[cfg(not(target_arch = "wasm32"))] use crate::utils::profiling::print_current_memory_usage; #[cfg(feature = "allocative")] @@ -17,22 +19,42 @@ use crate::utils::profiling::print_data_structure_heap_usage; use crate::utils::profiling::write_flamegraph_svg; use crate::utils::thread::drop_in_background_thread; use crate::utils::transpose; +use crate::zkvm::bytecode; +use crate::zkvm::bytecode::read_raf_checking::ReadRafSumcheckVerifier as BytecodeReadRafSumcheckVerifier; use crate::zkvm::bytecode::BytecodeDagProver; -use crate::zkvm::bytecode::BytecodeDagVerifier; use crate::zkvm::dag::proof_serialization::Claims; use crate::zkvm::dag::proof_serialization::JoltProof; use crate::zkvm::dag::stage::SumcheckStagesProver; -use crate::zkvm::dag::stage::SumcheckStagesVerifier; use crate::zkvm::dag::state_manager::StateManager; +use crate::zkvm::instruction_lookups; +use crate::zkvm::instruction_lookups::ra_virtual::RaSumcheckVerifier as LookupsRaSumcheckVerifier; +use crate::zkvm::instruction_lookups::read_raf_checking::ReadRafSumcheckVerifier as LookupsReadRafSumcheckVerifier; use crate::zkvm::instruction_lookups::LookupsDagProver; -use crate::zkvm::instruction_lookups::LookupsDagVerifier; +use crate::zkvm::r1cs::key::UniformSpartanKey; +use crate::zkvm::ram; +use crate::zkvm::ram::hamming_booleanity::HammingBooleanitySumcheckVerifier; +use crate::zkvm::ram::output_check::OutputSumcheckVerifier; +use crate::zkvm::ram::output_check::ValFinalSumcheckVerifier; +use crate::zkvm::ram::ra_virtual::RaSumcheckVerifier as RamRaSumcheckVerifier; +use crate::zkvm::ram::raf_evaluation::RafEvaluationSumcheckVerifier as RamRafEvaluationSumcheckVerifier; +use crate::zkvm::ram::read_write_checking::RamReadWriteCheckingVerifier; +use crate::zkvm::ram::val_evaluation::ValEvaluationSumcheckVerifier as RamValEvaluationSumcheckVerifier; +use crate::zkvm::ram::verifier_accumulate_advice; use crate::zkvm::ram::RamDagProver; -use crate::zkvm::ram::RamDagVerifier; +use crate::zkvm::registers::read_write_checking::RegistersReadWriteCheckingVerifier; +use crate::zkvm::registers::val_evaluation::ValEvaluationSumcheckVerifier as RegistersValEvaluationSumcheckVerifier; use crate::zkvm::registers::RegistersDagProver; -use crate::zkvm::registers::RegistersDagVerifier; +use crate::zkvm::spartan::inner::InnerSumcheckVerifier; +use crate::zkvm::spartan::instruction_input::InstructionInputSumcheckVerifier; +use crate::zkvm::spartan::outer::OuterRemainingSumcheckVerifier; +use crate::zkvm::spartan::product::ProductVirtualInnerVerifier; +use crate::zkvm::spartan::product::ProductVirtualRemainderVerifier; +use crate::zkvm::spartan::shift::ShiftSumcheckVerifier; +use crate::zkvm::spartan::verify_stage1_uni_skip; +use crate::zkvm::spartan::verify_stage2_uni_skip; use crate::zkvm::spartan::SpartanDagProver; -use crate::zkvm::spartan::SpartanDagVerifier; use crate::zkvm::witness::{AllCommittedPolynomials, CommittedPolynomial, DTH_ROOT_OF_K}; +use crate::zkvm::JoltVerifierPreprocessing; use crate::zkvm::ProverDebugInfo; #[cfg(feature = "allocative")] use allocative::FlameGraphBuilder; @@ -509,261 +531,356 @@ pub fn prove_jolt_dag< Ok((proof, debug_info)) } -#[tracing::instrument(skip_all)] -pub fn verify_jolt_dag< +pub struct DagVerifier< 'a, + 'b, + 'c, F: JoltField, ProofTranscript: Transcript, PCS: CommitmentScheme, ->( - proof: &JoltProof, - mut state_manager: StateManager<'a, F, PCS>, - mut opening_accumulator: VerifierOpeningAccumulator, - transcript: &mut ProofTranscript, -) -> Result<(), anyhow::Error> { - state_manager.fiat_shamir_preamble(transcript); +> { + pub state_manager: StateManager<'a, F, PCS>, + pub proof: JoltProof, + pub opening_accumulator: VerifierOpeningAccumulator, + pub transcript: &'b mut ProofTranscript, + pub preprocessing: &'c JoltVerifierPreprocessing, +} - let ram_K = state_manager.ram_K; - let bytecode_d = state_manager.get_verifier_data().0.shared.bytecode.d; - let _guard = AllCommittedPolynomials::initialize(ram_K, bytecode_d); +impl<'a, 'b, 'c, F: JoltField, ProofTranscript: Transcript, PCS: CommitmentScheme> + DagVerifier<'a, 'b, 'c, F, ProofTranscript, PCS> +{ + #[tracing::instrument(skip_all)] + pub fn verify(mut self) -> Result<(), anyhow::Error> { + self.state_manager.fiat_shamir_preamble(self.transcript); - // Append commitments to transcript - for commitment in &proof.commitments { - transcript.append_serializable(commitment); - } + let ram_K = self.proof.ram_K; + let bytecode_d = self.state_manager.get_verifier_data().0.shared.bytecode.d; + let _guard = AllCommittedPolynomials::initialize(ram_K, bytecode_d); - // Append untrusted advice commitment to transcript - if let Some(ref untrusted_advice_commitment) = state_manager.untrusted_advice_commitment { - transcript.append_serializable(untrusted_advice_commitment); - } - // Append trusted advice commitment to transcript - if let Some(ref trusted_advice_commitment) = state_manager.trusted_advice_commitment { - transcript.append_serializable(trusted_advice_commitment); - } - - // Initialize Dags - let (preprocessing, _, trace_length) = state_manager.get_verifier_data(); - let padded_trace_length = trace_length.next_power_of_two(); - let mut spartan_dag = SpartanDagVerifier::new(padded_trace_length); - let mut lookups_dag = LookupsDagVerifier; - let mut registers_dag = RegistersDagVerifier; - let mut ram_dag = RamDagVerifier::new(&state_manager); - let mut bytecode_dag = BytecodeDagVerifier; + // Append commitments to transcript + for commitment in &self.proof.commitments { + self.transcript.append_serializable(commitment); + } + // Append untrusted advice commitment to transcript + if let Some(ref untrusted_advice_commitment) = + self.state_manager.untrusted_advice_commitment + { + self.transcript + .append_serializable(untrusted_advice_commitment); + } + // Append trusted advice commitment to transcript + if let Some(ref trusted_advice_commitment) = self.state_manager.trusted_advice_commitment { + self.transcript + .append_serializable(trusted_advice_commitment); + } - // Stage 1: - spartan_dag - .stage1_uni_skip(&proof.stage1_uni_skip_first_round_proof, transcript) - .context("Stage 1 univariate skip first round")?; + let spartan_key = UniformSpartanKey::new(self.proof.trace_length.next_power_of_two()); - let stage1_remainder_instances: Vec<_> = spartan_dag - .stage1_instances(&mut state_manager, &mut opening_accumulator, transcript) - .into_iter() - .collect(); - let stage1_remainder_instances_ref = stage1_remainder_instances - .iter() - .map(|instance| &**instance as _) - .collect(); + self.verify_stage1(&spartan_key)?; + self.verify_stage2(&spartan_key)?; + self.verify_stage3()?; + self.verify_stage4()?; + self.verify_stage5()?; + self.verify_stage6()?; + self.verify_trusted_advice_opening_proofs()?; + self.verify_untrusted_advice_opening_proofs()?; + self.verify_stage7()?; - let _r_stage1 = BatchedSumcheck::verify( - &proof.stage1_sumcheck_proof, - stage1_remainder_instances_ref, - &mut opening_accumulator, - transcript, - ) - .context("Stage 1 remainder")?; + Ok(()) + } - // Stage 2: - // Stage 2a: Verify univariate-skip first round for product virtualization - spartan_dag - .stage2_uni_skip( - &proof.stage2_uni_skip_first_round_proof, - &mut opening_accumulator, - transcript, + fn verify_stage1(&mut self, spartan_key: &UniformSpartanKey) -> Result<(), anyhow::Error> { + let spartan_outer_uni_skip_state = verify_stage1_uni_skip( + &self.proof.stage1_uni_skip_first_round_proof, + spartan_key, + self.transcript, ) - .context("Stage 2 univariate skip first round")?; + .context("Stage 1 univariate skip first round")?; - let stage2_instances: Vec<_> = std::iter::empty() - .chain(spartan_dag.stage2_instances( - &mut state_manager, - &mut opening_accumulator, - transcript, - )) - .chain(registers_dag.stage2_instances( - &mut state_manager, - &mut opening_accumulator, - transcript, - )) - .chain(ram_dag.stage2_instances(&mut state_manager, &mut opening_accumulator, transcript)) - .chain(lookups_dag.stage2_instances( - &mut state_manager, - &mut opening_accumulator, - transcript, - )) - .chain(bytecode_dag.stage2_instances( - &mut state_manager, - &mut opening_accumulator, - transcript, - )) - .collect(); - let stage2_instances_ref = stage2_instances.iter().map(|inst| &**inst as _).collect(); + let n_cycle_vars = self.proof.trace_length.log_2(); + let spartan_outer_remaining = + OuterRemainingSumcheckVerifier::new(n_cycle_vars, &spartan_outer_uni_skip_state); - let _r_stage2 = BatchedSumcheck::verify( - &proof.stage2_sumcheck_proof, - stage2_instances_ref, - &mut opening_accumulator, - transcript, - ) - .context("Stage 2")?; + let _r_stage1 = BatchedSumcheck::verify( + &self.proof.stage1_sumcheck_proof, + vec![&spartan_outer_remaining], + &mut self.opening_accumulator, + self.transcript, + ) + .context("Stage 1")?; - // Stage 3: - let stage3_instances: Vec<_> = std::iter::empty() - .chain(spartan_dag.stage3_instances( - &mut state_manager, - &mut opening_accumulator, - transcript, - )) - .chain(lookups_dag.stage3_instances( - &mut state_manager, - &mut opening_accumulator, - transcript, - )) - .chain(ram_dag.stage3_instances(&mut state_manager, &mut opening_accumulator, transcript)) - .collect(); - let stage3_instances_ref = stage3_instances.iter().map(|inst| &**inst as _).collect(); + Ok(()) + } - let _r_stage3 = BatchedSumcheck::verify( - &proof.stage3_sumcheck_proof, - stage3_instances_ref, - &mut opening_accumulator, - transcript, - ) - .context("Stage 3")?; + fn verify_stage2(&mut self, spartan_key: &UniformSpartanKey) -> Result<(), anyhow::Error> { + let product_virtual_uni_skip_state = verify_stage2_uni_skip( + &self.proof.stage2_uni_skip_first_round_proof, + spartan_key, + &mut self.opening_accumulator, + self.transcript, + ) + .context("Stage 2 univariate skip first round")?; - // Stage 4: - let stage4_instances: Vec<_> = std::iter::empty() - .chain(registers_dag.stage4_instances( - &mut state_manager, - &mut opening_accumulator, - transcript, - )) - .chain(ram_dag.stage4_instances(&mut state_manager, &mut opening_accumulator, transcript)) - .collect(); - let stage4_instances_ref = stage4_instances - .iter() - .map(|instance| &**instance as _) - .collect(); + let spartan_inner = InnerSumcheckVerifier::new(spartan_key, self.transcript); + let spartan_product_virtual_remainder = ProductVirtualRemainderVerifier::new( + self.proof.trace_length.log_2(), + &product_virtual_uni_skip_state, + ); + let ram_raf_evaluation = RamRafEvaluationSumcheckVerifier::new( + &self.state_manager.program_io, + self.proof.ram_K, + &self.opening_accumulator, + ); + let ram_read_write_checking = RamReadWriteCheckingVerifier::new( + self.proof.ram_K, + self.proof.trace_length, + self.proof.twist_sumcheck_switch_index, + &self.opening_accumulator, + self.transcript, + ); + let ram_output_check = OutputSumcheckVerifier::new( + self.proof.ram_K, + &self.state_manager.program_io, + self.transcript, + ); + + let _r_stage2 = BatchedSumcheck::verify( + &self.proof.stage2_sumcheck_proof, + vec![ + &spartan_inner as &dyn SumcheckInstanceVerifier, + &spartan_product_virtual_remainder, + &ram_raf_evaluation, + &ram_read_write_checking, + &ram_output_check, + ], + &mut self.opening_accumulator, + self.transcript, + ) + .context("Stage 2")?; - let _r_stage4 = BatchedSumcheck::verify( - &proof.stage4_sumcheck_proof, - stage4_instances_ref, - &mut opening_accumulator, - transcript, - ) - .context("Stage 4")?; + Ok(()) + } - // Stage 5: - let stage5_instances: Vec<_> = std::iter::empty() - .chain(registers_dag.stage5_instances( - &mut state_manager, - &mut opening_accumulator, - transcript, - )) - .chain(ram_dag.stage5_instances(&mut state_manager, &mut opening_accumulator, transcript)) - .chain(lookups_dag.stage5_instances( - &mut state_manager, - &mut opening_accumulator, - transcript, - )) - .collect(); - let stage5_instances_ref = stage5_instances.iter().map(|inst| &**inst as _).collect(); + fn verify_stage3(&mut self) -> Result<(), anyhow::Error> { + let spartan_shift_sumcheck = ShiftSumcheckVerifier::new( + self.proof.trace_length.log_2(), + &self.opening_accumulator, + self.transcript, + ); + let spartan_instruction_input = + InstructionInputSumcheckVerifier::new(&self.opening_accumulator, self.transcript); + let spartan_product_virtual_claim_check = + ProductVirtualInnerVerifier::new(&self.opening_accumulator, self.transcript); + let lookups_ra_hamming_weight = + instruction_lookups::new_ra_hamming_weight_verifier(self.transcript); + + let _r_stage3 = BatchedSumcheck::verify( + &self.proof.stage3_sumcheck_proof, + vec![ + &spartan_shift_sumcheck as &dyn SumcheckInstanceVerifier, + &spartan_instruction_input, + &spartan_product_virtual_claim_check, + &lookups_ra_hamming_weight, + ], + &mut self.opening_accumulator, + self.transcript, + ) + .context("Stage 3")?; - let _r_stage5 = BatchedSumcheck::verify( - &proof.stage5_sumcheck_proof, - stage5_instances_ref, - &mut opening_accumulator, - transcript, - ) - .context("Stage 5")?; + Ok(()) + } - // Stage 6: - let stage6_instances: Vec<_> = std::iter::empty() - .chain(bytecode_dag.stage6_instances( - &mut state_manager, - &mut opening_accumulator, - transcript, - )) - .chain(ram_dag.stage6_instances(&mut state_manager, &mut opening_accumulator, transcript)) - .chain(lookups_dag.stage6_instances( - &mut state_manager, - &mut opening_accumulator, - transcript, - )) - .collect(); - let stage6_instances_ref = stage6_instances.iter().map(|inst| &**inst as _).collect(); + fn verify_stage4(&mut self) -> Result<(), anyhow::Error> { + let registers_read_write_checking = RegistersReadWriteCheckingVerifier::new( + self.proof.twist_sumcheck_switch_index, + self.proof.trace_length.log_2(), + &self.opening_accumulator, + self.transcript, + ); + verifier_accumulate_advice::( + self.proof.ram_K, + &self.state_manager.program_io, + self.state_manager.untrusted_advice_commitment.is_some(), + self.state_manager.trusted_advice_commitment.is_some(), + &mut self.opening_accumulator, + self.transcript, + ); + let ram_ra_booleanity = ram::new_ra_booleanity_verifier( + self.proof.ram_K, + self.proof.trace_length.log_2(), + self.transcript, + ); + let initial_ram_state = ram::gen_ram_initial_memory_state::( + self.proof.ram_K, + &self.preprocessing.shared.ram, + &self.state_manager.program_io, + ); + let ram_val_evaluation = RamValEvaluationSumcheckVerifier::new( + &initial_ram_state, + &self.state_manager.program_io, + self.proof.trace_length, + self.proof.ram_K, + &self.opening_accumulator, + ); + let ram_val_final = ValFinalSumcheckVerifier::new( + &initial_ram_state, + &self.state_manager.program_io, + self.proof.trace_length, + self.proof.ram_K, + &self.opening_accumulator, + ); + + let _r_stage4 = BatchedSumcheck::verify( + &self.proof.stage4_sumcheck_proof, + vec![ + ®isters_read_write_checking as &dyn SumcheckInstanceVerifier, + &ram_ra_booleanity, + &ram_val_evaluation, + &ram_val_final, + ], + &mut self.opening_accumulator, + self.transcript, + ) + .context("Stage 4")?; - let _r_stage6 = BatchedSumcheck::verify( - &proof.stage6_sumcheck_proof, - stage6_instances_ref, - &mut opening_accumulator, - transcript, - ) - .context("Stage 6")?; + Ok(()) + } - // Verify trusted_advice opening proofs - if let Some(ref commitment) = state_manager.trusted_advice_commitment { - let Some(ref proof) = proof.trusted_advice_proof else { - return Err(anyhow::anyhow!("Trusted advice proof not found")); - }; - let Some((point, eval)) = opening_accumulator.get_trusted_advice_opening() else { - return Err(anyhow::anyhow!("Trusted advice opening not found")); - }; - PCS::verify( - proof, - &preprocessing.generators, - transcript, - &point.r, - &eval, - commitment, + fn verify_stage5(&mut self) -> Result<(), anyhow::Error> { + let n_cycle_vars = self.proof.trace_length.log_2(); + let registers_val_evaluation = RegistersValEvaluationSumcheckVerifier::new(n_cycle_vars); + let ram_hamming_booleanity = HammingBooleanitySumcheckVerifier::new(n_cycle_vars); + let ram_ra_virtual = RamRaSumcheckVerifier::new( + self.proof.trace_length, + self.proof.ram_K, + &self.opening_accumulator, + self.transcript, + ); + let lookups_read_raf = LookupsReadRafSumcheckVerifier::new(n_cycle_vars, self.transcript); + + let _r_stage5 = BatchedSumcheck::verify( + &self.proof.stage5_sumcheck_proof, + vec![ + ®isters_val_evaluation as &dyn SumcheckInstanceVerifier, + &ram_hamming_booleanity, + &ram_ra_virtual, + &lookups_read_raf, + ], + &mut self.opening_accumulator, + self.transcript, ) - .map_err(|e| anyhow::anyhow!("Trusted advice opening proof verification failed: {e:?}"))?; + .context("Stage 5")?; + + Ok(()) } - // Verify untrusted_advice opening proofs - if let Some(ref commitment) = state_manager.untrusted_advice_commitment { - let Some(ref proof) = proof.untrusted_advice_proof else { - return Err(anyhow::anyhow!("Untrusted advice proof not found")); - }; - let Some((point, eval)) = opening_accumulator.get_untrusted_advice_opening() else { - return Err(anyhow::anyhow!("Untrusted advice opening not found")); - }; - PCS::verify( - proof, - &preprocessing.generators, - transcript, - &point.r, - &eval, - commitment, + fn verify_stage6(&mut self) -> Result<(), anyhow::Error> { + let n_cycle_vars = self.proof.trace_length.log_2(); + let bytecode_read_raf = BytecodeReadRafSumcheckVerifier::gen( + &self.preprocessing.shared.bytecode, + n_cycle_vars, + &self.opening_accumulator, + self.transcript, + ); + let (bytecode_hamming_weight, bytecode_booleanity) = bytecode::new_ra_one_hot_verifiers( + &self.preprocessing.shared.bytecode, + n_cycle_vars, + self.transcript, + ); + let ram_hamming_weight = + ram::new_ra_hamming_weight_verifier(self.proof.ram_K, self.transcript); + let lookups_ra_virtual = LookupsRaSumcheckVerifier::new(&self.opening_accumulator); + let lookups_booleanity = + instruction_lookups::new_ra_booleanity_verifier(n_cycle_vars, self.transcript); + + let _r_stage6 = BatchedSumcheck::verify( + &self.proof.stage6_sumcheck_proof, + vec![ + &bytecode_read_raf as &dyn SumcheckInstanceVerifier, + &bytecode_hamming_weight, + &bytecode_booleanity, + &ram_hamming_weight, + &lookups_ra_virtual, + &lookups_booleanity, + ], + &mut self.opening_accumulator, + self.transcript, ) - .map_err(|e| { - anyhow::anyhow!("Untrusted advice opening proof verification failed: {e:?}") - })?; + .context("Stage 6")?; + + Ok(()) } - // Batch-prove all openings (Stage 7) - let mut commitments_map = HashMap::new(); - for (polynomial, commitment) in AllCommittedPolynomials::iter().zip_eq(&proof.commitments) { - commitments_map.insert(*polynomial, commitment.clone()); + fn verify_trusted_advice_opening_proofs(&mut self) -> Result<(), anyhow::Error> { + if let Some(ref commitment) = self.state_manager.trusted_advice_commitment { + let Some(ref proof) = self.proof.trusted_advice_proof else { + return Err(anyhow::anyhow!("Trusted advice proof not found")); + }; + let Some((point, eval)) = self.opening_accumulator.get_trusted_advice_opening() else { + return Err(anyhow::anyhow!("Trusted advice opening not found")); + }; + PCS::verify( + proof, + &self.preprocessing.generators, + self.transcript, + &point.r, + &eval, + commitment, + ) + .map_err(|e| { + anyhow::anyhow!("Trusted advice opening proof verification failed: {e:?}") + })?; + } + + Ok(()) } - opening_accumulator - .reduce_and_verify( - &preprocessing.generators, - &mut commitments_map, - &proof.reduced_opening_proof, - transcript, - ) - .context("Stage 7")?; - Ok(()) + fn verify_untrusted_advice_opening_proofs(&mut self) -> Result<(), anyhow::Error> { + // Verify untrusted_advice opening proofs + if let Some(ref commitment) = self.state_manager.untrusted_advice_commitment { + let Some(ref proof) = self.proof.untrusted_advice_proof else { + return Err(anyhow::anyhow!("Untrusted advice proof not found")); + }; + let Some((point, eval)) = self.opening_accumulator.get_untrusted_advice_opening() + else { + return Err(anyhow::anyhow!("Untrusted advice opening not found")); + }; + PCS::verify( + proof, + &self.preprocessing.generators, + self.transcript, + &point.r, + &eval, + commitment, + ) + .map_err(|e| { + anyhow::anyhow!("Untrusted advice opening proof verification failed: {e:?}") + })?; + } + + Ok(()) + } + + fn verify_stage7(&mut self) -> Result<(), anyhow::Error> { + // Batch-prove all openings (Stage 7) + let mut commitments_map = HashMap::new(); + for (polynomial, commitment) in + AllCommittedPolynomials::iter().zip_eq(&self.proof.commitments) + { + commitments_map.insert(*polynomial, commitment.clone()); + } + + self.opening_accumulator + .reduce_and_verify( + &self.preprocessing.generators, + &mut commitments_map, + &self.proof.reduced_opening_proof, + self.transcript, + ) + .context("Stage 7")?; + + Ok(()) + } } // Prover utility to commit to all the polynomials for the PCS diff --git a/jolt-core/src/zkvm/instruction_lookups/mod.rs b/jolt-core/src/zkvm/instruction_lookups/mod.rs index eb49fdcf2..61699e157 100644 --- a/jolt-core/src/zkvm/instruction_lookups/mod.rs +++ b/jolt-core/src/zkvm/instruction_lookups/mod.rs @@ -5,26 +5,20 @@ use crate::{ poly::{ commitment::commitment_scheme::CommitmentScheme, eq_poly::EqPolynomial, - opening_proof::{ - OpeningAccumulator, ProverOpeningAccumulator, SumcheckId, VerifierOpeningAccumulator, - }, + opening_proof::{OpeningAccumulator, ProverOpeningAccumulator, SumcheckId}, }, subprotocols::{ - sumcheck_prover::SumcheckInstanceProver, sumcheck_verifier::SumcheckInstanceVerifier, - BooleanitySumcheckParams, BooleanitySumcheckProver, BooleanitySumcheckVerifier, - HammingWeightSumcheckParams, HammingWeightSumcheckProver, HammingWeightSumcheckVerifier, + sumcheck_prover::SumcheckInstanceProver, BooleanitySumcheckParams, + BooleanitySumcheckProver, BooleanitySumcheckVerifier, HammingWeightSumcheckParams, + HammingWeightSumcheckProver, HammingWeightSumcheckVerifier, }, transcripts::Transcript, utils::{math::Math, thread::unsafe_allocate_zero_vec}, zkvm::{ - dag::{ - stage::{SumcheckStagesProver, SumcheckStagesVerifier}, - state_manager::StateManager, - }, + dag::{stage::SumcheckStagesProver, state_manager::StateManager}, instruction::LookupQuery, instruction_lookups::{ - ra_virtual::{RaSumcheckProver, RaSumcheckVerifier}, - read_raf_checking::{ReadRafSumcheckProver, ReadRafSumcheckVerifier}, + ra_virtual::RaSumcheckProver, read_raf_checking::ReadRafSumcheckProver, }, witness::{CommittedPolynomial, VirtualPolynomial}, }, @@ -132,43 +126,6 @@ impl, T: Transcript> SumcheckStag } } -pub struct LookupsDagVerifier; - -impl, T: Transcript> - SumcheckStagesVerifier for LookupsDagVerifier -{ - fn stage3_instances( - &mut self, - _sm: &mut StateManager<'_, F, PCS>, - _opening_accumulator: &mut VerifierOpeningAccumulator, - transcript: &mut T, - ) -> Vec>> { - let hamming_weight = new_ra_hamming_weight_verifier(transcript); - vec![Box::new(hamming_weight)] - } - - fn stage5_instances( - &mut self, - sm: &mut StateManager<'_, F, PCS>, - _opening_accumulator: &mut VerifierOpeningAccumulator, - transcript: &mut T, - ) -> Vec>> { - let read_raf = ReadRafSumcheckVerifier::new(sm, transcript); - vec![Box::new(read_raf)] - } - - fn stage6_instances( - &mut self, - sm: &mut StateManager<'_, F, PCS>, - opening_accumulator: &mut VerifierOpeningAccumulator, - transcript: &mut T, - ) -> Vec>> { - let ra_virtual = RaSumcheckVerifier::new(opening_accumulator); - let booleanity = new_ra_booleanity_verifier(sm, transcript); - vec![Box::new(ra_virtual), Box::new(booleanity)] - } -} - fn gen_ra_booleanity_prover( state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, opening_accumulator: &ProverOpeningAccumulator, @@ -225,26 +182,19 @@ fn gen_ra_hamming_weight_prover( HammingWeightSumcheckProver::gen(params, ra_evals.to_vec()) } -fn new_ra_booleanity_verifier( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, +pub fn new_ra_booleanity_verifier( + n_cycle_vars: usize, transcript: &mut impl Transcript, ) -> BooleanitySumcheckVerifier { - let (_, _, T_val) = state_manager.get_verifier_data(); - let log_t = T_val.log_2(); - - let gammas: Vec = transcript.challenge_vector_optimized::(D); - - let r_address: Vec = transcript.challenge_vector_optimized::(LOG_K_CHUNK); - + let gammas = transcript.challenge_vector_optimized::(D); + let r_address = transcript.challenge_vector_optimized::(LOG_K_CHUNK); let r_cycle = Vec::new(); - let polynomial_types: Vec = (0..D).map(CommittedPolynomial::InstructionRa).collect(); - let params = BooleanitySumcheckParams { d: D, log_k_chunk: LOG_K_CHUNK, - log_t, + log_t: n_cycle_vars, gammas, r_address, r_cycle, @@ -256,7 +206,7 @@ fn new_ra_booleanity_verifier( BooleanitySumcheckVerifier::new(params) } -fn new_ra_hamming_weight_verifier( +pub fn new_ra_hamming_weight_verifier( transcript: &mut T, ) -> HammingWeightSumcheckVerifier { let gamma_powers = transcript.challenge_scalar_powers(D); diff --git a/jolt-core/src/zkvm/instruction_lookups/read_raf_checking.rs b/jolt-core/src/zkvm/instruction_lookups/read_raf_checking.rs index a927ed0cc..f97c9e84b 100644 --- a/jolt-core/src/zkvm/instruction_lookups/read_raf_checking.rs +++ b/jolt-core/src/zkvm/instruction_lookups/read_raf_checking.rs @@ -185,8 +185,9 @@ impl<'a, F: JoltField> ReadRafSumcheckProver { opening_accumulator: &ProverOpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { - let params = ReadRafSumcheckParams::new(sm, transcript); let trace = sm.get_prover_data().2; + let log_T = trace.len().log_2(); + let params = ReadRafSumcheckParams::new(log_T, transcript); let (r_branch, _) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::LookupOutput, SumcheckId::ProductVirtualization, @@ -1008,11 +1009,8 @@ pub struct ReadRafSumcheckVerifier { } impl ReadRafSumcheckVerifier { - pub fn new( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, - transcript: &mut impl Transcript, - ) -> Self { - let params = ReadRafSumcheckParams::new(state_manager, transcript); + pub fn new(n_cycle_vars: usize, transcript: &mut impl Transcript) -> Self { + let params = ReadRafSumcheckParams::new(n_cycle_vars, transcript); Self { params } } } @@ -1153,17 +1151,13 @@ struct ReadRafSumcheckParams { } impl ReadRafSumcheckParams { - fn new( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, - transcript: &mut impl Transcript, - ) -> Self { + fn new(n_cycle_vars: usize, transcript: &mut impl Transcript) -> Self { let gamma = transcript.challenge_scalar::(); let gamma_sqr = gamma.square(); - let log_T = state_manager.get_trace_len().log_2(); Self { gamma, gamma_sqr, - log_T, + log_T: n_cycle_vars, } } @@ -1212,7 +1206,7 @@ mod tests { poly::commitment::mock::MockCommitScheme, zkvm::{ bytecode::BytecodePreprocessing, ram::RAMPreprocessing, JoltProverPreprocessing, - JoltSharedPreprocessing, JoltVerifierPreprocessing, + JoltSharedPreprocessing, }, }; use ark_bn254::Fr; @@ -1319,11 +1313,6 @@ mod tests { shared: shared_preprocessing.clone(), }; - let verifier_preprocessing: JoltVerifierPreprocessing> = - JoltVerifierPreprocessing { - generators: (), - shared: shared_preprocessing, - }; let program_io = JoltDevice { memory_layout, untrusted_advice: vec![], @@ -1341,19 +1330,12 @@ mod tests { &prover_preprocessing, lazy_trace, trace.clone(), - program_io.clone(), + program_io, None, final_memory_state, ); let verifier_transcript = &mut Blake2bTranscript::new(&[]); let mut verifier_opening_accumulator = VerifierOpeningAccumulator::new(trace.len().log_2()); - let mut verifier_sm = StateManager::<'_, Fr, _>::new_verifier( - &verifier_preprocessing, - program_io, - trace.len(), - 1 << 8, - prover_sm.twist_sumcheck_switch_index, - ); let r_cycle: Vec<::Challenge> = prover_transcript.challenge_vector_optimized::(LOG_T); @@ -1466,7 +1448,7 @@ mod tests { ); let mut verifier_sumcheck = - ReadRafSumcheckVerifier::new(&mut verifier_sm, verifier_transcript); + ReadRafSumcheckVerifier::new(trace.len().log_2(), verifier_transcript); let r_sumcheck_verif = BatchedSumcheck::verify( &proof, diff --git a/jolt-core/src/zkvm/mod.rs b/jolt-core/src/zkvm/mod.rs index 768a7da6f..c59726b2f 100644 --- a/jolt-core/src/zkvm/mod.rs +++ b/jolt-core/src/zkvm/mod.rs @@ -14,9 +14,7 @@ use crate::{ utils::{errors::ProofVerifyError, math::Math}, zkvm::{ bytecode::BytecodePreprocessing, - dag::{ - jolt_dag::verify_jolt_dag, proof_serialization::JoltProof, state_manager::StateManager, - }, + dag::{jolt_dag::DagVerifier, proof_serialization::JoltProof, state_manager::StateManager}, ram::RAMPreprocessing, witness::DTH_ROOT_OF_K, }, @@ -435,8 +433,15 @@ pub trait Jolt, FS: Tran }), }; - verify_jolt_dag(&proof, state_manager, opening_accumulator, transcript) - .expect("Verification failed"); + DagVerifier { + state_manager, + proof, + opening_accumulator, + transcript, + preprocessing, + } + .verify() + .expect("Verification failed"); Ok(()) } diff --git a/jolt-core/src/zkvm/ram/hamming_booleanity.rs b/jolt-core/src/zkvm/ram/hamming_booleanity.rs index 5dd7e8e25..35191dfe3 100644 --- a/jolt-core/src/zkvm/ram/hamming_booleanity.rs +++ b/jolt-core/src/zkvm/ram/hamming_booleanity.rs @@ -198,9 +198,10 @@ pub struct HammingBooleanitySumcheckVerifier { } impl HammingBooleanitySumcheckVerifier { - pub fn new(state_manager: &mut StateManager>) -> Self { + pub fn new(n_cycle_vars: usize) -> Self { Self { - log_T: state_manager.get_trace_len().log_2(), + // TODO: Make the name for this consistent across the codebase. + log_T: n_cycle_vars, _phantom: PhantomData, } } diff --git a/jolt-core/src/zkvm/ram/mod.rs b/jolt-core/src/zkvm/ram/mod.rs index a9f95e9e0..b7a5bcafd 100644 --- a/jolt-core/src/zkvm/ram/mod.rs +++ b/jolt-core/src/zkvm/ram/mod.rs @@ -7,8 +7,7 @@ use crate::subprotocols::{ }; #[cfg(feature = "allocative")] use crate::utils::profiling::print_data_structure_heap_usage; -use crate::zkvm::dag::stage::{SumcheckStagesProver, SumcheckStagesVerifier}; -use crate::zkvm::ram::hamming_booleanity::HammingBooleanitySumcheckVerifier; +use crate::zkvm::dag::stage::SumcheckStagesProver; use crate::{ field::{self, JoltField}, poly::{ @@ -17,23 +16,18 @@ use crate::{ multilinear_polynomial::{MultilinearPolynomial, PolynomialEvaluation}, opening_proof::{OpeningAccumulator, OpeningPoint, SumcheckId, BIG_ENDIAN}, }, - subprotocols::{ - sumcheck_prover::SumcheckInstanceProver, sumcheck_verifier::SumcheckInstanceVerifier, - }, + subprotocols::sumcheck_prover::SumcheckInstanceProver, transcripts::Transcript, utils::{math::Math, thread::unsafe_allocate_zero_vec}, zkvm::{ dag::state_manager::StateManager, ram::{ hamming_booleanity::HammingBooleanitySumcheckProver, - output_check::{ - OutputSumcheckProver, OutputSumcheckVerifier, ValFinalSumcheckProver, - ValFinalSumcheckVerifier, - }, - ra_virtual::{RaSumcheckProver, RaSumcheckVerifier}, - raf_evaluation::{RafEvaluationSumcheckProver, RafEvaluationSumcheckVerifier}, - read_write_checking::{RamReadWriteCheckingProver, RamReadWriteCheckingVerifier}, - val_evaluation::{ValEvaluationSumcheckProver, ValEvaluationSumcheckVerifier}, + output_check::{OutputSumcheckProver, ValFinalSumcheckProver}, + ra_virtual::RaSumcheckProver, + raf_evaluation::RafEvaluationSumcheckProver, + read_write_checking::RamReadWriteCheckingProver, + val_evaluation::ValEvaluationSumcheckProver, }, witness::{compute_d_parameter, CommittedPolynomial, VirtualPolynomial, DTH_ROOT_OF_K}, }, @@ -47,6 +41,7 @@ use common::{ }; use rayon::prelude::*; use tracer::instruction::Cycle; +use tracer::JoltDevice; pub mod hamming_booleanity; pub mod output_check; @@ -328,21 +323,21 @@ pub fn prover_accumulate_advice( } /// Accumulates advice commitments into the verifier's accumulator. -pub fn verifier_accumulate_advice( - state_manager: &mut StateManager<'_, F, PCS>, +pub fn verifier_accumulate_advice( + ram_K: usize, + program_io: &JoltDevice, + has_untrusted_advice_commitment: bool, + has_trusted_advice_commitment: bool, opening_accumulator: &mut VerifierOpeningAccumulator, transcript: &mut impl Transcript, -) where - F: JoltField, - PCS: CommitmentScheme, -{ +) { let get_advice_point = |opening_accumulator: &VerifierOpeningAccumulator, max_advice_size: usize| { let (r, _) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::RamVal, SumcheckId::RamReadWriteChecking, ); - let (r_address, _) = r.split_at(state_manager.ram_K.log_2()); + let (r_address, _) = r.split_at(ram_K.log_2()); let total_vars = r_address.r.len(); let advice_variables = (max_advice_size / 8).next_power_of_two().log_2(); @@ -352,24 +347,18 @@ pub fn verifier_accumulate_advice( advice_point }; - if state_manager.untrusted_advice_commitment.is_some() { + if has_untrusted_advice_commitment { let point = get_advice_point( opening_accumulator, - state_manager - .program_io - .memory_layout - .max_untrusted_advice_size as usize, + program_io.memory_layout.max_untrusted_advice_size as usize, ); opening_accumulator.append_untrusted_advice(transcript, point); } - if state_manager.trusted_advice_commitment.is_some() { + if has_trusted_advice_commitment { let point = get_advice_point( opening_accumulator, - state_manager - .program_io - .memory_layout - .max_trusted_advice_size as usize, + program_io.memory_layout.max_trusted_advice_size as usize, ); opening_accumulator.append_trusted_advice(transcript, point); } @@ -570,126 +559,41 @@ where } } -pub struct RamDagVerifier { - initial_memory_state: Vec, -} - -impl RamDagVerifier { - pub fn new( - state_manager: &StateManager<'_, F, impl CommitmentScheme>, - ) -> Self { - let (preprocessing, program_io, _) = state_manager.get_verifier_data(); - let ram_preprocessing = &preprocessing.shared.ram; - - let K = state_manager.ram_K; - - let mut initial_memory_state = vec![0; K]; - // Copy bytecode - let mut index = remap_address( - ram_preprocessing.min_bytecode_address, - &program_io.memory_layout, - ) - .unwrap() as usize; - for word in &ram_preprocessing.bytecode_words { - initial_memory_state[index] = *word; - index += 1; - } - - index = remap_address( - program_io.memory_layout.input_start, - &program_io.memory_layout, - ) - .unwrap() as usize; - // Convert input bytes into words and populate - // `initial_memory_state` and `final_memory_state` - for chunk in program_io.inputs.chunks(8) { - let mut word = [0u8; 8]; - for (i, byte) in chunk.iter().enumerate() { - word[i] = *byte; - } - let word = u64::from_le_bytes(word); - initial_memory_state[index] = word; - index += 1; - } - - Self { - initial_memory_state, - } - } -} - -impl SumcheckStagesVerifier for RamDagVerifier -where - F: JoltField, - ProofTranscript: Transcript, - PCS: CommitmentScheme, -{ - fn stage2_instances( - &mut self, - state_manager: &mut StateManager<'_, F, PCS>, - opening_accumulator: &mut VerifierOpeningAccumulator, - transcript: &mut ProofTranscript, - ) -> Vec>> { - let raf_evaluation = RafEvaluationSumcheckVerifier::new(state_manager, opening_accumulator); - let read_write_checking = - RamReadWriteCheckingVerifier::new(state_manager, opening_accumulator, transcript); - let output_check = OutputSumcheckVerifier::new(state_manager, transcript); - - vec![ - Box::new(raf_evaluation), - Box::new(read_write_checking), - Box::new(output_check), - ] +pub fn gen_ram_initial_memory_state( + ram_K: usize, + ram_preprocessing: &RAMPreprocessing, + program_io: &JoltDevice, +) -> Vec { + let mut initial_memory_state = vec![0; ram_K]; + // Copy bytecode + let mut index = remap_address( + ram_preprocessing.min_bytecode_address, + &program_io.memory_layout, + ) + .unwrap() as usize; + for word in &ram_preprocessing.bytecode_words { + initial_memory_state[index] = *word; + index += 1; } - fn stage4_instances( - &mut self, - state_manager: &mut StateManager<'_, F, PCS>, - opening_accumulator: &mut VerifierOpeningAccumulator, - transcript: &mut ProofTranscript, - ) -> Vec>> { - // Accumulate advice commitments if present - verifier_accumulate_advice(state_manager, opening_accumulator, transcript); - let booleanity = new_ra_booleanity_verifier(state_manager, transcript); - - let val_evaluation = ValEvaluationSumcheckVerifier::new( - &self.initial_memory_state, - state_manager, - opening_accumulator, - ); - let val_final_evaluation = ValFinalSumcheckVerifier::new( - &self.initial_memory_state, - state_manager, - opening_accumulator, - ); - - vec![ - Box::new(booleanity), - Box::new(val_evaluation), - Box::new(val_final_evaluation), - ] - } - - fn stage5_instances( - &mut self, - state_manager: &mut StateManager<'_, F, PCS>, - opening_accumulator: &mut VerifierOpeningAccumulator, - transcript: &mut ProofTranscript, - ) -> Vec>> { - let hamming_booleanity = HammingBooleanitySumcheckVerifier::new(state_manager); - let ra_virtual = RaSumcheckVerifier::new(state_manager, opening_accumulator, transcript); - vec![Box::new(hamming_booleanity), Box::new(ra_virtual)] + index = remap_address( + program_io.memory_layout.input_start, + &program_io.memory_layout, + ) + .unwrap() as usize; + // Convert input bytes into words and populate + // `initial_memory_state` and `final_memory_state` + for chunk in program_io.inputs.chunks(8) { + let mut word = [0u8; 8]; + for (i, byte) in chunk.iter().enumerate() { + word[i] = *byte; + } + let word = u64::from_le_bytes(word); + initial_memory_state[index] = word; + index += 1; } - fn stage6_instances( - &mut self, - state_manager: &mut StateManager<'_, F, PCS>, - _opening_accumulator: &mut VerifierOpeningAccumulator, - transcript: &mut ProofTranscript, - ) -> Vec>> { - let hamming_weight = new_ra_hamming_weight_verifier(state_manager, transcript); - vec![Box::new(hamming_weight)] - } + initial_memory_state } fn gen_ra_booleanity_prover( @@ -768,17 +672,15 @@ fn gen_ra_hamming_weight_prover( HammingWeightSumcheckProver::gen(params, G) } -fn new_ra_booleanity_verifier( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, +pub fn new_ra_booleanity_verifier( + ram_K: usize, + n_cycle_vars: usize, transcript: &mut impl Transcript, ) -> BooleanitySumcheckVerifier { - let (_, _, T) = state_manager.get_verifier_data(); - let K = state_manager.ram_K; - let d = compute_d_parameter(K); + let d = compute_d_parameter(ram_K); let log_k_chunk = DTH_ROOT_OF_K.log_2(); - let log_t = T.log_2(); - let r_cycle = transcript.challenge_vector_optimized::(log_t); + let r_cycle = transcript.challenge_vector_optimized::(n_cycle_vars); let r_address = transcript.challenge_vector_optimized::(log_k_chunk); let gammas = transcript.challenge_vector_optimized::(d); @@ -789,7 +691,7 @@ fn new_ra_booleanity_verifier( let params = BooleanitySumcheckParams { d, log_k_chunk, - log_t, + log_t: n_cycle_vars, gammas, r_address, r_cycle, @@ -801,11 +703,11 @@ fn new_ra_booleanity_verifier( BooleanitySumcheckVerifier::new(params) } -fn new_ra_hamming_weight_verifier( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, +pub fn new_ra_hamming_weight_verifier( + ram_K: usize, transcript: &mut impl Transcript, ) -> HammingWeightSumcheckVerifier { - let d = compute_d_parameter(state_manager.ram_K); + let d = compute_d_parameter(ram_K); let num_rounds = DTH_ROOT_OF_K.log_2(); let gamma_powers = transcript.challenge_scalar_powers(d); diff --git a/jolt-core/src/zkvm/ram/output_check.rs b/jolt-core/src/zkvm/ram/output_check.rs index 761bad787..ecad6db98 100644 --- a/jolt-core/src/zkvm/ram/output_check.rs +++ b/jolt-core/src/zkvm/ram/output_check.rs @@ -91,7 +91,8 @@ impl OutputSumcheckProver { state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, transcript: &mut impl Transcript, ) -> Self { - let params = OutputSumcheckParams::new(state_manager, transcript); + let (_, _, _, program_io, _) = state_manager.get_prover_data(); + let params = OutputSumcheckParams::new(state_manager.ram_K, program_io, transcript); let K = final_ram_state.len(); debug_assert_eq!(initial_ram_state.len(), final_ram_state.len()); @@ -263,11 +264,8 @@ pub struct OutputSumcheckVerifier { } impl OutputSumcheckVerifier { - pub fn new( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, - transcript: &mut impl Transcript, - ) -> Self { - let params = OutputSumcheckParams::new(state_manager, transcript); + pub fn new(ram_K: usize, program_io: &JoltDevice, transcript: &mut impl Transcript) -> Self { + let params = OutputSumcheckParams::new(ram_K, program_io, transcript); Self { params } } } @@ -364,17 +362,12 @@ struct OutputSumcheckParams { } impl OutputSumcheckParams { - pub fn new( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, - transcript: &mut impl Transcript, - ) -> Self { - let program_io = state_manager.program_io.clone(); - let K = state_manager.ram_K; - let r_address = transcript.challenge_vector_optimized::(K.log_2()); + pub fn new(ram_K: usize, program_io: &JoltDevice, transcript: &mut impl Transcript) -> Self { + let r_address = transcript.challenge_vector_optimized::(ram_K.log_2()); Self { - K, + K: ram_K, r_address, - program_io, + program_io: program_io.clone(), } } @@ -568,11 +561,11 @@ pub struct ValFinalSumcheckVerifier { impl ValFinalSumcheckVerifier { pub fn new( initial_ram_state: &[u64], - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, + program_io: &JoltDevice, + trace_len: usize, + ram_K: usize, opening_accumulator: &VerifierOpeningAccumulator, ) -> Self { - let (_, program_io, T) = state_manager.get_verifier_data(); - let r_address = opening_accumulator .get_virtual_polynomial_opening( VirtualPolynomial::RamValFinal, @@ -581,6 +574,8 @@ impl ValFinalSumcheckVerifier { .0 .r; + let n_memory_vars = ram_K.log_2(); + { // Verify that val_evaluation and output_check use the same opening point for initial_ram_state. // This allows us to reuse a single untrusted_advice opening instead of providing two. @@ -588,12 +583,10 @@ impl ValFinalSumcheckVerifier { VirtualPolynomial::RamVal, SumcheckId::RamReadWriteChecking, ); - let (r_address_val_evaluation, _) = r.split_at(state_manager.ram_K.log_2()); + let (r_address_val_evaluation, _) = r.split_at(n_memory_vars); assert_eq!(r_address_val_evaluation.r, r_address); } - let total_memory_vars = state_manager.ram_K.log_2(); - let untrusted_advice_contribution = super::calculate_advice_memory_evaluation( opening_accumulator.get_untrusted_advice_opening(), (program_io.memory_layout.max_untrusted_advice_size as usize / 8) @@ -602,7 +595,7 @@ impl ValFinalSumcheckVerifier { program_io.memory_layout.untrusted_advice_start, &program_io.memory_layout, &r_address, - total_memory_vars, + n_memory_vars, ); let trusted_advice_contribution = super::calculate_advice_memory_evaluation( @@ -613,7 +606,7 @@ impl ValFinalSumcheckVerifier { program_io.memory_layout.trusted_advice_start, &program_io.memory_layout, &r_address, - total_memory_vars, + n_memory_vars, ); // Compute the public part of val_init evaluation @@ -625,7 +618,10 @@ impl ValFinalSumcheckVerifier { + trusted_advice_contribution + val_init_public.evaluate(&r_address); - let params = ValFinalSumcheckParams { T, val_init_eval }; + let params = ValFinalSumcheckParams { + T: trace_len, + val_init_eval, + }; Self { params } } diff --git a/jolt-core/src/zkvm/ram/ra_virtual.rs b/jolt-core/src/zkvm/ram/ra_virtual.rs index 435c07234..f979e1463 100644 --- a/jolt-core/src/zkvm/ram/ra_virtual.rs +++ b/jolt-core/src/zkvm/ram/ra_virtual.rs @@ -60,7 +60,14 @@ impl RaSumcheckProver { opening_accumulator: &ProverOpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { - let params = RaSumcheckParams::new(state_manager, opening_accumulator, transcript); + let (preprocessing, _, trace, _, _) = state_manager.get_prover_data(); + + let params = RaSumcheckParams::new( + trace.len(), + state_manager.ram_K, + opening_accumulator, + transcript, + ); // Precompute EQ tables for each chunk let eq_tables: Vec> = params @@ -78,7 +85,6 @@ impl RaSumcheckProver { DensePolynomial::linear_combination(&eq_polys.each_ref(), ¶ms.gamma_powers).Z, ); - let (preprocessing, _, trace, _, _) = state_manager.get_prover_data(); let ra_i_polys: Vec> = (0..params.d) .into_par_iter() .zip(eq_tables.into_par_iter()) @@ -211,11 +217,12 @@ pub struct RaSumcheckVerifier { impl RaSumcheckVerifier { pub fn new( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, + trace_len: usize, + ram_K: usize, opening_accumulator: &VerifierOpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { - let params = RaSumcheckParams::new(state_manager, opening_accumulator, transcript); + let params = RaSumcheckParams::new(trace_len, ram_K, opening_accumulator, transcript); Self { params } } } @@ -289,15 +296,14 @@ struct RaSumcheckParams { impl RaSumcheckParams { fn new( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, + trace_len: usize, + ram_K: usize, opening_accumulator: &dyn OpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { // Calculate d dynamically such that 2^8 = K^(1/D) - let d = compute_d_parameter(state_manager.ram_K); - let log_K = state_manager.ram_K.log_2(); - - let T = state_manager.get_trace_len(); + let d = compute_d_parameter(ram_K); + let log_K = ram_K.log_2(); // These two sumchecks have the same binding order and number of rounds, // and they're run in parallel, so the openings are the same. @@ -360,7 +366,7 @@ impl RaSumcheckParams { Self { gamma_powers, - T, + T: trace_len, d, r_cycle, r_address_chunks, diff --git a/jolt-core/src/zkvm/ram/raf_evaluation.rs b/jolt-core/src/zkvm/ram/raf_evaluation.rs index eb596cfe4..31b8d3a11 100644 --- a/jolt-core/src/zkvm/ram/raf_evaluation.rs +++ b/jolt-core/src/zkvm/ram/raf_evaluation.rs @@ -4,6 +4,7 @@ use allocative::Allocative; #[cfg(feature = "allocative")] use allocative::FlameGraphBuilder; use rayon::prelude::*; +use tracer::JoltDevice; use crate::{ field::JoltField, @@ -56,13 +57,13 @@ impl RafEvaluationSumcheckProver { state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, opening_accumulator: &ProverOpeningAccumulator, ) -> Self { - let params = RafEvaluationSumcheckParams::new(state_manager, opening_accumulator); - let (_, _, trace, program_io, _) = state_manager.get_prover_data(); let memory_layout = &program_io.memory_layout; let K = state_manager.ram_K; let T = trace.len(); + let params = RafEvaluationSumcheckParams::new(program_io, K, opening_accumulator); + let num_chunks = rayon::current_num_threads().next_power_of_two().min(T); let chunk_size = (T / num_chunks).max(1); @@ -187,10 +188,11 @@ pub struct RafEvaluationSumcheckVerifier { impl RafEvaluationSumcheckVerifier { pub fn new( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, + program_io: &JoltDevice, + ram_K: usize, opening_accumulator: &VerifierOpeningAccumulator, ) -> Self { - let params = RafEvaluationSumcheckParams::new(state_manager, opening_accumulator); + let params = RafEvaluationSumcheckParams::new(program_io, ram_K, opening_accumulator); Self { params } } } @@ -255,11 +257,12 @@ pub struct RafEvaluationSumcheckParams { impl RafEvaluationSumcheckParams { pub fn new( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, + program_io: &JoltDevice, + ram_K: usize, opening_accumulator: &dyn OpeningAccumulator, ) -> Self { - let start_address = state_manager.program_io.memory_layout.get_lowest_address(); - let log_K = state_manager.ram_K.log_2(); + let start_address = program_io.memory_layout.get_lowest_address(); + let log_K = ram_K.log_2(); let (r_cycle, _) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::RamAddress, SumcheckId::SpartanOuter, diff --git a/jolt-core/src/zkvm/ram/read_write_checking.rs b/jolt-core/src/zkvm/ram/read_write_checking.rs index f92af08ec..8a7f6655a 100644 --- a/jolt-core/src/zkvm/ram/read_write_checking.rs +++ b/jolt-core/src/zkvm/ram/read_write_checking.rs @@ -107,10 +107,16 @@ impl RamReadWriteCheckingProver { opening_accumulator: &ProverOpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { - let params = ReadWriteCheckingParams::new(state_manager, opening_accumulator, transcript); - let (preprocessing, _, trace, program_io, _) = state_manager.get_prover_data(); + let params = ReadWriteCheckingParams::new( + state_manager.ram_K, + trace.len(), + state_manager.twist_sumcheck_switch_index, + opening_accumulator, + transcript, + ); + let r_prime = opening_accumulator .get_virtual_polynomial_opening( VirtualPolynomial::RamReadValue, @@ -1002,12 +1008,20 @@ pub struct RamReadWriteCheckingVerifier { impl RamReadWriteCheckingVerifier { pub fn new( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, + ram_K: usize, + trace_len: usize, + twist_sumcheck_switch_index: usize, opening_accumulator: &dyn OpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { Self { - params: ReadWriteCheckingParams::new(state_manager, opening_accumulator, transcript), + params: ReadWriteCheckingParams::new( + ram_K, + trace_len, + twist_sumcheck_switch_index, + opening_accumulator, + transcript, + ), } } } @@ -1087,29 +1101,28 @@ struct ReadWriteCheckingParams { T: usize, gamma: F, r_cycle_stage_1: OpeningPoint, - sumcheck_switch_index: usize, + twist_sumcheck_switch_index: usize, } impl ReadWriteCheckingParams { pub fn new( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, + ram_K: usize, + trace_len: usize, + twist_sumcheck_switch_index: usize, opening_accumulator: &dyn OpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { - let K = state_manager.ram_K; - let T = state_manager.get_trace_len(); - let sumcheck_switch_index = state_manager.twist_sumcheck_switch_index; let gamma = transcript.challenge_scalar(); let (r_cycle_stage_1, _) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::RamReadValue, SumcheckId::SpartanOuter, ); Self { - K, - T, + K: ram_K, + T: trace_len, gamma, r_cycle_stage_1, - sumcheck_switch_index, + twist_sumcheck_switch_index, } } @@ -1133,7 +1146,7 @@ impl ReadWriteCheckingParams { &self, sumcheck_challenges: &[F::Challenge], ) -> OpeningPoint { - let sumcheck_switch_index = self.sumcheck_switch_index; + let sumcheck_switch_index = self.twist_sumcheck_switch_index; // The high-order cycle variables are bound after the switch let mut r_cycle = sumcheck_challenges[sumcheck_switch_index..self.T.log_2()].to_vec(); // First `sumcheck_switch_index` rounds bind cycle variables from low to high diff --git a/jolt-core/src/zkvm/ram/val_evaluation.rs b/jolt-core/src/zkvm/ram/val_evaluation.rs index 95fa14b83..0e4fb517a 100644 --- a/jolt-core/src/zkvm/ram/val_evaluation.rs +++ b/jolt-core/src/zkvm/ram/val_evaluation.rs @@ -1,6 +1,7 @@ use itertools::chain; use num_traits::Zero; use std::{array, iter::zip, sync::Arc}; +use tracer::JoltDevice; use crate::{ field::JoltField, @@ -226,19 +227,18 @@ pub struct ValEvaluationSumcheckVerifier { impl ValEvaluationSumcheckVerifier { pub fn new( initial_ram_state: &[u64], - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, + program_io: &JoltDevice, + trace_len: usize, + ram_K: usize, opening_accumulator: &VerifierOpeningAccumulator, ) -> Self { - let (_, program_io, T) = state_manager.get_verifier_data(); - let K = state_manager.ram_K; - let (r, _) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::RamVal, SumcheckId::RamReadWriteChecking, ); - let (r_address, _) = r.split_at(K.log_2()); + let (r_address, _) = r.split_at(ram_K.log_2()); - let total_memory_vars = K.log_2(); + let n_memory_vars = ram_K.log_2(); // Calculate untrusted advice contribution let untrusted_contribution = super::calculate_advice_memory_evaluation( @@ -249,7 +249,7 @@ impl ValEvaluationSumcheckVerifier { program_io.memory_layout.untrusted_advice_start, &program_io.memory_layout, &r_address.r, - total_memory_vars, + n_memory_vars, ); // Calculate trusted advice contribution @@ -261,7 +261,7 @@ impl ValEvaluationSumcheckVerifier { program_io.memory_layout.trusted_advice_start, &program_io.memory_layout, &r_address.r, - total_memory_vars, + n_memory_vars, ); // Compute the public part of val_init evaluation @@ -272,7 +272,11 @@ impl ValEvaluationSumcheckVerifier { let init_eval = untrusted_contribution + trusted_contribution + val_init_public.evaluate(&r_address.r); - let params = ValEvaluationSumcheckParams { init_eval, T, K }; + let params = ValEvaluationSumcheckParams { + init_eval, + T: trace_len, + K: ram_K, + }; Self { params } } diff --git a/jolt-core/src/zkvm/registers/mod.rs b/jolt-core/src/zkvm/registers/mod.rs index a52f714ba..b1e9b0117 100644 --- a/jolt-core/src/zkvm/registers/mod.rs +++ b/jolt-core/src/zkvm/registers/mod.rs @@ -3,23 +3,15 @@ use crate::utils::profiling::print_data_structure_heap_usage; use crate::{ field::JoltField, poly::{ - commitment::commitment_scheme::CommitmentScheme, - opening_proof::{ProverOpeningAccumulator, VerifierOpeningAccumulator}, - }, - subprotocols::{ - sumcheck_prover::SumcheckInstanceProver, sumcheck_verifier::SumcheckInstanceVerifier, + commitment::commitment_scheme::CommitmentScheme, opening_proof::ProverOpeningAccumulator, }, + subprotocols::sumcheck_prover::SumcheckInstanceProver, transcripts::Transcript, zkvm::{ - dag::{ - stage::{SumcheckStagesProver, SumcheckStagesVerifier}, - state_manager::StateManager, - }, + dag::{stage::SumcheckStagesProver, state_manager::StateManager}, registers::{ - read_write_checking::{ - RegistersReadWriteCheckingProver, RegistersReadWriteCheckingVerifier, - }, - val_evaluation::{ValEvaluationSumcheckProver, ValEvaluationSumcheckVerifier}, + read_write_checking::RegistersReadWriteCheckingProver, + val_evaluation::ValEvaluationSumcheckProver, }, }, }; @@ -60,30 +52,3 @@ impl vec![Box::new(val_evaluation)] } } - -pub struct RegistersDagVerifier; - -impl> - SumcheckStagesVerifier for RegistersDagVerifier -{ - fn stage4_instances( - &mut self, - state_manager: &mut StateManager<'_, F, PCS>, - opening_accumulator: &mut VerifierOpeningAccumulator, - transcript: &mut ProofTranscript, - ) -> Vec>> { - let read_write_checking = - RegistersReadWriteCheckingVerifier::new(state_manager, opening_accumulator, transcript); - vec![Box::new(read_write_checking)] - } - - fn stage5_instances( - &mut self, - state_manager: &mut StateManager<'_, F, PCS>, - _opening_accumulator: &mut VerifierOpeningAccumulator, - _transcript: &mut ProofTranscript, - ) -> Vec>> { - let val_evaluation = ValEvaluationSumcheckVerifier::new(state_manager); - vec![Box::new(val_evaluation)] - } -} diff --git a/jolt-core/src/zkvm/registers/read_write_checking.rs b/jolt-core/src/zkvm/registers/read_write_checking.rs index 290762d74..378021569 100644 --- a/jolt-core/src/zkvm/registers/read_write_checking.rs +++ b/jolt-core/src/zkvm/registers/read_write_checking.rs @@ -118,11 +118,15 @@ impl RegistersReadWriteCheckingProver { opening_accumulator: &ProverOpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { - let params = - RegistersReadWriteCheckingParams::new(state_manager, opening_accumulator, transcript); - let (preprocessing, _, trace, _, _) = state_manager.get_prover_data(); + let params = RegistersReadWriteCheckingParams::new( + state_manager.twist_sumcheck_switch_index, + trace.len().log_2(), + opening_accumulator, + transcript, + ); + let T = trace.len(); let num_chunks = rayon::current_num_threads().next_power_of_two().min(T); let chunk_size = T / num_chunks; @@ -1388,12 +1392,17 @@ pub struct RegistersReadWriteCheckingVerifier { impl RegistersReadWriteCheckingVerifier { pub fn new( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, + twist_sumcheck_switch_index: usize, + n_cycle_vars: usize, opening_accumulator: &VerifierOpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { - let params = - RegistersReadWriteCheckingParams::new(state_manager, opening_accumulator, transcript); + let params = RegistersReadWriteCheckingParams::new( + twist_sumcheck_switch_index, + n_cycle_vars, + opening_accumulator, + transcript, + ); Self { params } } } @@ -1502,7 +1511,7 @@ struct RegistersReadWriteCheckingParams { gamma: F, /// Equals `gamma^3`. gamma_cub: F, - sumcheck_switch_index: usize, + twist_sumcheck_switch_index: usize, n_cycle_vars: usize, // = log(T) r_cycle_stage_1: OpeningPoint, r_cycle_stage_3: OpeningPoint, @@ -1510,26 +1519,23 @@ struct RegistersReadWriteCheckingParams { impl RegistersReadWriteCheckingParams { pub fn new( - state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, + twist_sumcheck_switch_index: usize, + n_cycle_vars: usize, opening_accumulator: &dyn OpeningAccumulator, transcript: &mut impl Transcript, ) -> Self { let gamma = transcript.challenge_scalar::(); let gamma_cub = gamma.square() * gamma; - let sumcheck_switch_index = state_manager.twist_sumcheck_switch_index; - let n_cycle_vars = state_manager.get_trace_len().log_2(); - let (r_cycle_stage_1, _) = opening_accumulator .get_virtual_polynomial_opening(VirtualPolynomial::Rs1Value, SumcheckId::SpartanOuter); let (r_cycle_stage_3, _) = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::Rs1Value, SumcheckId::InstructionInputVirtualization, ); - Self { gamma, gamma_cub, - sumcheck_switch_index, + twist_sumcheck_switch_index, n_cycle_vars, r_cycle_stage_1, r_cycle_stage_3, @@ -1569,7 +1575,7 @@ impl RegistersReadWriteCheckingParams { &self, sumcheck_challenges: &[F::Challenge], ) -> OpeningPoint { - let sumcheck_switch_index = self.sumcheck_switch_index; + let sumcheck_switch_index = self.twist_sumcheck_switch_index; let n_cycle_vars = self.n_cycle_vars; // The high-order cycle variables are bound after the switch let mut r_cycle = sumcheck_challenges[sumcheck_switch_index..n_cycle_vars].to_vec(); diff --git a/jolt-core/src/zkvm/registers/val_evaluation.rs b/jolt-core/src/zkvm/registers/val_evaluation.rs index 63bfe3d60..705519606 100644 --- a/jolt-core/src/zkvm/registers/val_evaluation.rs +++ b/jolt-core/src/zkvm/registers/val_evaluation.rs @@ -66,8 +66,6 @@ impl ValEvaluationSumcheckProver { state_manager: &mut StateManager<'_, F, PCS>, opening_accumulator: &ProverOpeningAccumulator, ) -> Self { - let params = ValEvaluationSumcheckParams::new(state_manager); - // The opening point is r_address || r_cycle let registers_val_input_sample = opening_accumulator.get_virtual_polynomial_opening( VirtualPolynomial::RegistersVal, @@ -76,6 +74,7 @@ impl ValEvaluationSumcheckProver { let (r_address, r_cycle) = registers_val_input_sample.0.split_at(LOG_K); let (preprocessing, _, trace, _, _) = state_manager.get_prover_data(); + let params = ValEvaluationSumcheckParams::new(trace.len().log_2()); let inc = CommittedPolynomial::RdInc.generate_witness(preprocessing, trace, state_manager.ram_d); @@ -203,8 +202,8 @@ pub struct ValEvaluationSumcheckVerifier { } impl ValEvaluationSumcheckVerifier { - pub fn new(state_manager: &mut StateManager<'_, F, impl CommitmentScheme>) -> Self { - let params = ValEvaluationSumcheckParams::new(state_manager); + pub fn new(n_cycle_vars: usize) -> Self { + let params = ValEvaluationSumcheckParams::new(n_cycle_vars); Self { params } } } @@ -295,9 +294,9 @@ struct ValEvaluationSumcheckParams { } impl ValEvaluationSumcheckParams { - pub fn new(state_manager: &mut StateManager<'_, F, impl CommitmentScheme>) -> Self { + pub fn new(n_cycle_vars: usize) -> Self { Self { - n_cycle_vars: state_manager.get_trace_len().log_2(), + n_cycle_vars, _phantom: PhantomData, } } diff --git a/jolt-core/src/zkvm/spartan/inner.rs b/jolt-core/src/zkvm/spartan/inner.rs index 3073c54df..e91dd0652 100644 --- a/jolt-core/src/zkvm/spartan/inner.rs +++ b/jolt-core/src/zkvm/spartan/inner.rs @@ -182,19 +182,21 @@ impl SumcheckInstanceProver for InnerSumcheck /// where `segment_evals` are the cached witness openings at `r_cycle` from the outer stage. /// /// Then `expected = (eval_a + γ·eval_b) · eval_z`, and accept iff output_claim == expected. -pub struct InnerSumcheckVerifier { +pub struct InnerSumcheckVerifier<'a, F: JoltField> { params: InnerSumcheckParams, - key: Arc>, + key: &'a UniformSpartanKey, } -impl InnerSumcheckVerifier { - pub fn new(key: Arc>, transcript: &mut impl Transcript) -> Self { +impl<'a, F: JoltField> InnerSumcheckVerifier<'a, F> { + pub fn new(key: &'a UniformSpartanKey, transcript: &mut impl Transcript) -> Self { let params = InnerSumcheckParams::new(transcript); Self { params, key } } } -impl SumcheckInstanceVerifier for InnerSumcheckVerifier { +impl<'a, F: JoltField, T: Transcript> SumcheckInstanceVerifier + for InnerSumcheckVerifier<'a, F> +{ fn degree(&self) -> usize { DEGREE_BOUND } diff --git a/jolt-core/src/zkvm/spartan/mod.rs b/jolt-core/src/zkvm/spartan/mod.rs index 6ceb2e51f..817c89e1f 100644 --- a/jolt-core/src/zkvm/spartan/mod.rs +++ b/jolt-core/src/zkvm/spartan/mod.rs @@ -7,29 +7,23 @@ use crate::poly::opening_proof::{ }; use crate::subprotocols::sumcheck::UniSkipFirstRoundProof; use crate::subprotocols::sumcheck_prover::SumcheckInstanceProver; -use crate::subprotocols::sumcheck_verifier::SumcheckInstanceVerifier; use crate::subprotocols::univariate_skip::{prove_uniskip_round, UniSkipState}; use crate::transcripts::Transcript; #[cfg(feature = "allocative")] use crate::utils::profiling::print_data_structure_heap_usage; -use crate::zkvm::dag::stage::{SumcheckStagesProver, SumcheckStagesVerifier}; +use crate::zkvm::dag::stage::SumcheckStagesProver; use crate::zkvm::dag::state_manager::StateManager; use crate::zkvm::r1cs::constraints::{ OUTER_FIRST_ROUND_POLY_NUM_COEFFS, OUTER_UNIVARIATE_SKIP_DOMAIN_SIZE, }; use crate::zkvm::r1cs::key::UniformSpartanKey; -use crate::zkvm::spartan::inner::{InnerSumcheckProver, InnerSumcheckVerifier}; -use crate::zkvm::spartan::instruction_input::{ - InstructionInputSumcheckProver, InstructionInputSumcheckVerifier, -}; -use crate::zkvm::spartan::outer::{ - OuterRemainingSumcheckProver, OuterRemainingSumcheckVerifier, OuterUniSkipInstanceProver, -}; +use crate::zkvm::spartan::inner::InnerSumcheckProver; +use crate::zkvm::spartan::instruction_input::InstructionInputSumcheckProver; +use crate::zkvm::spartan::outer::{OuterRemainingSumcheckProver, OuterUniSkipInstanceProver}; use crate::zkvm::spartan::product::{ - ProductVirtualInnerProver, ProductVirtualInnerVerifier, ProductVirtualRemainderProver, - ProductVirtualRemainderVerifier, ProductVirtualUniSkipInstanceParams, + ProductVirtualInnerProver, ProductVirtualRemainderProver, ProductVirtualUniSkipInstanceParams, }; -use crate::zkvm::spartan::shift::{ShiftSumcheckProver, ShiftSumcheckVerifier}; +use crate::zkvm::spartan::shift::ShiftSumcheckProver; use crate::zkvm::witness::VirtualPolynomial; use product::{ @@ -44,13 +38,20 @@ pub mod product; pub mod shift; pub struct SpartanDagProver { - state: SpartanDagState, + /// Cached key to avoid recomputation across stages + key: Arc>, + /// Handoff state from univariate skip first round (shared by prover and verifier) + /// Consists of the `tau` vector for Lagrange / eq evals, the claim from univariate skip round, + /// and the challenge r0 from the univariate skip round + /// This is first used in stage 1 and then reused in stage 2 + uni_skip_state: Option>, } impl SpartanDagProver { pub fn new(padded_trace_length: usize) -> Self { Self { - state: SpartanDagState::new(padded_trace_length), + key: Arc::new(UniformSpartanKey::new(padded_trace_length)), + uni_skip_state: None, } } @@ -61,7 +62,7 @@ impl SpartanDagProver { _opening_accumulator: &mut ProverOpeningAccumulator, transcript: &mut T, ) -> UniSkipFirstRoundProof { - let num_rounds_x: usize = self.state.key.num_rows_bits(); + let num_rounds_x: usize = self.key.num_rows_bits(); // Transcript and tau let tau = transcript.challenge_vector_optimized::(num_rounds_x); @@ -71,7 +72,7 @@ impl SpartanDagProver { let (first_round_proof, r0, claim_after_first) = prove_uniskip_round(&mut uniskip_instance, transcript); - self.state.uni_skip_state = Some(UniSkipState { + self.uni_skip_state = Some(UniSkipState { claim_after_first, r0, tau, @@ -87,7 +88,7 @@ impl SpartanDagProver { opening_accumulator: &mut ProverOpeningAccumulator, transcript: &mut T, ) -> UniSkipFirstRoundProof { - let num_cycle_vars: usize = self.state.key.num_cycle_vars(); + let num_cycle_vars: usize = self.key.num_cycle_vars(); // Reuse r_cycle from Stage 1 (outer) for τ_low, and sample τ_high let r_cycle = opening_accumulator @@ -104,7 +105,7 @@ impl SpartanDagProver { let (first_round_proof, r0, claim_after_first) = prove_uniskip_round(&mut uniskip_instance, transcript); - self.state.uni_skip_state = Some(UniSkipState { + self.uni_skip_state = Some(UniSkipState { claim_after_first, r0, tau, @@ -127,9 +128,10 @@ where ) -> Vec>> { // Stage 1 remainder: outer-remaining let mut instances: Vec>> = Vec::new(); - if let Some(st) = self.state.uni_skip_state.take() { - let n_cycles = self.state.key.num_cycle_vars(); - let outer_remaining = OuterRemainingSumcheckProver::gen(state_manager, n_cycles, &st); + if let Some(st) = self.uni_skip_state.take() { + let n_cycle_vars = self.key.num_cycle_vars(); + let outer_remaining = + OuterRemainingSumcheckProver::gen(state_manager, n_cycle_vars, &st); instances.push(Box::new(outer_remaining)); } instances @@ -142,15 +144,14 @@ where transcript: &mut ProofTranscript, ) -> Vec>> { // Stage 2 remainder: inner + product remainder - let key = self.state.key.clone(); + let key = self.key.clone(); let inner_sumcheck = InnerSumcheckProver::gen(opening_accumulator, key, transcript); let st = self - .state .uni_skip_state .take() .expect("stage2_prover_uni_skip must run before stage2_prover_instances"); - let n_cycle_vars = self.state.key.num_cycle_vars(); + let n_cycle_vars = self.key.num_cycle_vars(); let product_virtual_remainder = ProductVirtualRemainderProver::gen(state_manager, n_cycle_vars, &st); @@ -174,9 +175,8 @@ where 1. NextUnexpandedPC(r_cycle) = \sum_t UnexpandedPC(t) * eq_plus_one(r_cycle, t) 2. NextPC(r_cycle) = \sum_t PC(t) * eq_plus_one(r_cycle, t) */ - let key = self.state.key.clone(); let shift_sumcheck = - ShiftSumcheckProver::gen(state_manager, opening_accumulator, key, transcript); + ShiftSumcheckProver::gen(state_manager, opening_accumulator, transcript); let instruction_input_sumcheck = InstructionInputSumcheckProver::gen(state_manager, opening_accumulator, transcript); let product_virtual_claim_check = @@ -199,168 +199,63 @@ where } } -pub struct SpartanDagVerifier { - state: SpartanDagState, -} - -impl SpartanDagVerifier { - pub fn new(padded_trace_length: usize) -> Self { - Self { - state: SpartanDagState::new(padded_trace_length), - } - } - - /// Stage 1a: Verify first round of Spartan outer sum-check with univariate skip - pub fn stage1_uni_skip( - &mut self, - proof: &UniSkipFirstRoundProof, - transcript: &mut T, - ) -> Result<(), anyhow::Error> { - let key = self.state.key.clone(); - let num_rounds_x = key.num_rows_bits(); - - let tau = transcript.challenge_vector_optimized::(num_rounds_x); - - let input_claim = F::zero(); - let (r0, claim_after_first) = proof - .verify::( - OUTER_FIRST_ROUND_POLY_NUM_COEFFS - 1, - input_claim, - transcript, - ) - .map_err(|_| anyhow::anyhow!("UniSkip first-round verification failed"))?; - - self.state.uni_skip_state = Some(UniSkipState { - claim_after_first, - r0, - tau, - }); - - Ok(()) - } - - pub fn stage2_uni_skip( - &mut self, - proof: &UniSkipFirstRoundProof, - opening_accumulator: &mut VerifierOpeningAccumulator, - transcript: &mut T, - ) -> Result<(), anyhow::Error> { - let num_cycle_vars: usize = self.state.key.num_cycle_vars(); - - // Reuse r_cycle from Stage 1 (outer) for τ_low, and sample τ_high - let r_cycle = opening_accumulator - .get_virtual_polynomial_opening(VirtualPolynomial::Product, SumcheckId::SpartanOuter) - .0 - .r; - debug_assert_eq!(r_cycle.len(), num_cycle_vars); - let tau_high: F::Challenge = transcript.challenge_scalar_optimized::(); - let mut tau: Vec = r_cycle; - tau.push(tau_high); - - let uniskip_params = ProductVirtualUniSkipInstanceParams::new(opening_accumulator, &tau); - let input_claim = uniskip_params.input_claim(); - let (r0, claim_after_first) = proof - .verify::( - PRODUCT_VIRTUAL_FIRST_ROUND_POLY_NUM_COEFFS - 1, - input_claim, - transcript, - ) - .map_err(|_| anyhow::anyhow!("ProductVirtual uni-skip first-round verification failed"))?; - - self.state.uni_skip_state = Some(UniSkipState { - claim_after_first, - r0, - tau, - }); - Ok(()) - } -} - -impl SumcheckStagesVerifier - for SpartanDagVerifier -where - F: JoltField, - ProofTranscript: Transcript, - PCS: CommitmentScheme, -{ - fn stage1_instances( - &mut self, - _state_manager: &mut StateManager<'_, F, PCS>, - _opening_accumulator: &mut VerifierOpeningAccumulator, - _transcript: &mut ProofTranscript, - ) -> Vec>> { - // Stage 1 remainder: outer-remaining (verifier) - let mut instances: Vec>> = Vec::new(); - if let Some(st) = self.state.uni_skip_state.take() { - let num_cycles_bits = self.state.key.num_steps.ilog2() as usize; - let outer_remaining = OuterRemainingSumcheckVerifier::new(num_cycles_bits, &st); - instances.push(Box::new(outer_remaining)); - } - instances - } - - fn stage2_instances( - &mut self, - _state_manager: &mut StateManager<'_, F, PCS>, - _opening_accumulator: &mut VerifierOpeningAccumulator, - transcript: &mut ProofTranscript, - ) -> Vec>> { - // Stage 2 remainder (verifier side) - let num_cycle_vars = self.state.key.num_cycle_vars(); - let inner_sumcheck = InnerSumcheckVerifier::new(self.state.key.clone(), transcript); - - let st = self - .state - .uni_skip_state - .take() - .expect("stage2_uni_skip must run before stage2_verifier_instances"); - let product_virtual_remainder = ProductVirtualRemainderVerifier::new(num_cycle_vars, &st); - - vec![ - Box::new(inner_sumcheck), - Box::new(product_virtual_remainder), - ] - } - - fn stage3_instances( - &mut self, - _state_manager: &mut StateManager<'_, F, PCS>, - opening_accumulator: &mut VerifierOpeningAccumulator, - transcript: &mut ProofTranscript, - ) -> Vec>> { - /* Sumcheck 3: Batched sumcheck for NextUnexpandedPC and NextPC verification - Verifies the batched constraint for both NextUnexpandedPC and NextPC - */ - let key = self.state.key.clone(); - let shift_sumcheck = ShiftSumcheckVerifier::new(opening_accumulator, key, transcript); - let instruction_input_sumcheck = - InstructionInputSumcheckVerifier::new(opening_accumulator, transcript); - let product_virtual_claim_check = - ProductVirtualInnerVerifier::new(opening_accumulator, transcript); - vec![ - Box::new(shift_sumcheck), - Box::new(instruction_input_sumcheck), - Box::new(product_virtual_claim_check), - ] - } -} - -struct SpartanDagState { - /// Cached key to avoid recomputation across stages - key: Arc>, - /// Handoff state from univariate skip first round (shared by prover and verifier) - /// Consists of the `tau` vector for Lagrange / eq evals, the claim from univariate skip round, - /// and the challenge r0 from the univariate skip round - /// This is first used in stage 1 and then reused in stage 2 - uni_skip_state: Option>, +/// Stage 1a: Verify first round of Spartan outer sum-check with univariate skip +pub fn verify_stage1_uni_skip( + proof: &UniSkipFirstRoundProof, + key: &UniformSpartanKey, + transcript: &mut T, +) -> Result, anyhow::Error> { + let num_rounds_x = key.num_rows_bits(); + + let tau = transcript.challenge_vector_optimized::(num_rounds_x); + + let input_claim = F::zero(); + let (r0, claim_after_first) = proof + .verify::( + OUTER_FIRST_ROUND_POLY_NUM_COEFFS - 1, + input_claim, + transcript, + ) + .map_err(|_| anyhow::anyhow!("UniSkip first-round verification failed"))?; + + Ok(UniSkipState { + claim_after_first, + r0, + tau, + }) } -impl SpartanDagState { - pub fn new(padded_trace_length: usize) -> Self { - let key = Arc::new(UniformSpartanKey::new(padded_trace_length)); - Self { - key, - uni_skip_state: None, - } - } +pub fn verify_stage2_uni_skip( + proof: &UniSkipFirstRoundProof, + key: &UniformSpartanKey, + opening_accumulator: &mut VerifierOpeningAccumulator, + transcript: &mut T, +) -> Result, anyhow::Error> { + let num_cycle_vars: usize = key.num_cycle_vars(); + + // Reuse r_cycle from Stage 1 (outer) for τ_low, and sample τ_high + let r_cycle = opening_accumulator + .get_virtual_polynomial_opening(VirtualPolynomial::Product, SumcheckId::SpartanOuter) + .0 + .r; + debug_assert_eq!(r_cycle.len(), num_cycle_vars); + let tau_high: F::Challenge = transcript.challenge_scalar_optimized::(); + let mut tau: Vec = r_cycle; + tau.push(tau_high); + + let uniskip_params = ProductVirtualUniSkipInstanceParams::new(opening_accumulator, &tau); + let input_claim = uniskip_params.input_claim(); + let (r0, claim_after_first) = proof + .verify::( + PRODUCT_VIRTUAL_FIRST_ROUND_POLY_NUM_COEFFS - 1, + input_claim, + transcript, + ) + .map_err(|_| anyhow::anyhow!("ProductVirtual uni-skip first-round verification failed"))?; + + Ok(UniSkipState { + claim_after_first, + r0, + tau, + }) } diff --git a/jolt-core/src/zkvm/spartan/shift.rs b/jolt-core/src/zkvm/spartan/shift.rs index 56a9319a9..fc03283b4 100644 --- a/jolt-core/src/zkvm/spartan/shift.rs +++ b/jolt-core/src/zkvm/spartan/shift.rs @@ -21,7 +21,6 @@ use crate::transcripts::Transcript; use crate::zkvm::bytecode::BytecodePreprocessing; use crate::zkvm::dag::state_manager::StateManager; use crate::zkvm::instruction::{CircuitFlags, Flags, InstructionFlags}; -use crate::zkvm::r1cs::key::UniformSpartanKey; use crate::zkvm::witness::VirtualPolynomial; use rayon::prelude::*; @@ -56,11 +55,11 @@ impl ShiftSumcheckProver { pub fn gen( state_manager: &mut StateManager<'_, F, impl CommitmentScheme>, opening_accumulator: &ProverOpeningAccumulator, - key: Arc>, transcript: &mut impl Transcript, ) -> Self { - let (preprocessing, _, _, _, _) = state_manager.get_prover_data(); - let params = ShiftSumcheckParams::new(opening_accumulator, key, transcript); + let (preprocessing, _, trace, _, _) = state_manager.get_prover_data(); + let n_cycle_vars = trace.len().ilog2() as usize; + let params = ShiftSumcheckParams::new(n_cycle_vars, opening_accumulator, transcript); let trace = state_manager.get_trace_arc(); let bytecode_preprocessing = preprocessing.shared.bytecode.clone(); Self::Phase1(Phase1Prover::gen(trace, bytecode_preprocessing, params)) @@ -168,11 +167,11 @@ pub struct ShiftSumcheckVerifier { impl ShiftSumcheckVerifier { pub fn new( + n_cycle_vars: usize, opening_accumulator: &VerifierOpeningAccumulator, - key: Arc>, transcript: &mut impl Transcript, ) -> Self { - let params = ShiftSumcheckParams::new(opening_accumulator, key, transcript); + let params = ShiftSumcheckParams::new(n_cycle_vars, opening_accumulator, transcript); Self { params } } } @@ -287,13 +286,12 @@ struct ShiftSumcheckParams { impl ShiftSumcheckParams { fn new( + n_cycle_vars: usize, opening_accumulator: &dyn OpeningAccumulator, - key: Arc>, transcript: &mut impl Transcript, ) -> Self { let gamma_powers = transcript.challenge_scalar_powers(5).try_into().unwrap(); - let n_cycle_vars = key.num_steps.ilog2() as usize; let (outer_sumcheck_r, _) = opening_accumulator .get_virtual_polynomial_opening(VirtualPolynomial::NextPC, SumcheckId::SpartanOuter); let (r_cycle, _rx_var) = outer_sumcheck_r.split_at(n_cycle_vars);