diff --git a/src/arithmetics/mod.rs b/src/arithmetics/mod.rs index 8548ccc..776aeb0 100644 --- a/src/arithmetics/mod.rs +++ b/src/arithmetics/mod.rs @@ -7,10 +7,10 @@ use ff_ext::{BabyBearExt4, SmallField}; use openvm_native_compiler::prelude::*; use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::challenger::ChallengerVariable; -use p3_field::{FieldAlgebra, FieldExtensionAlgebra}; use openvm_native_recursion::challenger::{ duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, }; +use p3_field::{FieldAlgebra, FieldExtensionAlgebra}; type E = BabyBearExt4; const HASH_RATE: usize = 8; @@ -35,8 +35,14 @@ pub fn _print_usize_arr(builder: &mut Builder, arr: &Array(builder: &mut Builder, exts: &Array>) -> Array> { - assert!(matches!(exts, Array::Dyn(_, _)), "Expected dynamic array of Exts"); +pub unsafe fn exts_to_felts( + builder: &mut Builder, + exts: &Array>, +) -> Array> { + assert!( + matches!(exts, Array::Dyn(_, _)), + "Expected dynamic array of Exts" + ); let f_len: Usize = builder.eval(exts.len() * Usize::from(C::EF::D)); let f_arr: Array> = Array::Dyn(exts.ptr(), f_len); f_arr @@ -45,15 +51,22 @@ pub unsafe fn exts_to_felts(builder: &mut Builder, exts: &Array( builder: &mut Builder, challenger: &mut DuplexChallengerVariable, - arr: &Array> + arr: &Array>, ) { - let next_input_ptr = builder.poseidon2_multi_observe(&challenger.sponge_state, challenger.input_ptr, &arr); - builder.assign(&challenger.input_ptr, challenger.io_empty_ptr + next_input_ptr.clone()); - builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else(|builder| { - builder.assign(&challenger.output_ptr, challenger.io_empty_ptr); - }, |builder| { - builder.assign(&challenger.output_ptr, challenger.io_full_ptr); - }); + let next_input_ptr = + builder.poseidon2_multi_observe(&challenger.sponge_state, challenger.input_ptr, &arr); + builder.assign( + &challenger.input_ptr, + challenger.io_empty_ptr + next_input_ptr.clone(), + ); + builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else( + |builder| { + builder.assign(&challenger.output_ptr, challenger.io_empty_ptr); + }, + |builder| { + builder.assign(&challenger.output_ptr, challenger.io_full_ptr); + }, + ); } pub fn is_smaller_than( @@ -685,7 +698,7 @@ pub fn max_usize_arr( } pub struct UniPolyExtrapolator { - constants: [Ext; 12], // 0, 1, 2, 3, 4, -1, 1/2, -1/2, 1/6, -1/6, 1/4, 1/24 + constants: [Ext; 12], // 0, 1, 2, 3, 4, -1, 1/2, -1/2, 1/6, -1/6, 1/4, 1/24 } impl UniPolyExtrapolator { @@ -719,39 +732,62 @@ impl UniPolyExtrapolator { neg_six_inverse, four_inverse, twenty_four_inverse, - ] + ], } } - pub fn extrapolate_uni_poly(&mut self, builder: &mut Builder, p_i: &Array>, eval_at: Ext) -> Ext { + pub fn extrapolate_uni_poly( + &mut self, + builder: &mut Builder, + p_i: &Array>, + eval_at: Ext, + ) -> Ext { let res: Ext = builder.eval(self.constants[0] + self.constants[0]); - builder.if_eq(p_i.len(), Usize::from(4)).then_or_else(|builder| { - let ext = self.extrapolate_uni_poly_deg_3(builder, p_i, eval_at); - builder.assign(&res, ext); - }, |builder| { - builder.if_eq(p_i.len(), Usize::from(3)).then_or_else(|builder| { - let ext = self.extrapolate_uni_poly_deg_2(builder, p_i, eval_at); + builder.if_eq(p_i.len(), Usize::from(4)).then_or_else( + |builder| { + let ext = self.extrapolate_uni_poly_deg_3(builder, p_i, eval_at); builder.assign(&res, ext); - }, |builder| { - builder.if_eq(p_i.len(), Usize::from(2)).then_or_else(|builder| { - let ext = self.extrapolate_uni_poly_deg_1(builder, p_i, eval_at); - builder.assign(&res, ext); - }, |builder| { - builder.if_eq(p_i.len(), Usize::from(5)).then_or_else(|builder| { - let ext = self.extrapolate_uni_poly_deg_4(builder, p_i, eval_at); + }, + |builder| { + builder.if_eq(p_i.len(), Usize::from(3)).then_or_else( + |builder| { + let ext = self.extrapolate_uni_poly_deg_2(builder, p_i, eval_at); builder.assign(&res, ext); - }, |builder| { - builder.error(); - }); - }); - }); - }); + }, + |builder| { + builder.if_eq(p_i.len(), Usize::from(2)).then_or_else( + |builder| { + let ext = self.extrapolate_uni_poly_deg_1(builder, p_i, eval_at); + builder.assign(&res, ext); + }, + |builder| { + builder.if_eq(p_i.len(), Usize::from(5)).then_or_else( + |builder| { + let ext = + self.extrapolate_uni_poly_deg_4(builder, p_i, eval_at); + builder.assign(&res, ext); + }, + |builder| { + builder.error(); + }, + ); + }, + ); + }, + ); + }, + ); res } - fn extrapolate_uni_poly_deg_1(&self, builder: &mut Builder, p_i: &Array>, eval_at: Ext) -> Ext { + fn extrapolate_uni_poly_deg_1( + &self, + builder: &mut Builder, + p_i: &Array>, + eval_at: Ext, + ) -> Ext { // w0 = 1 / (0−1) = -1 // w1 = 1 / (1−0) = 1 let d0: Ext = builder.eval(eval_at - self.constants[0]); @@ -767,7 +803,12 @@ impl UniPolyExtrapolator { builder.eval(l * (t0 + t1)) } - fn extrapolate_uni_poly_deg_2(&self, builder: &mut Builder, p_i: &Array>, eval_at: Ext) -> Ext { + fn extrapolate_uni_poly_deg_2( + &self, + builder: &mut Builder, + p_i: &Array>, + eval_at: Ext, + ) -> Ext { // w0 = 1 / ((0−1)(0−2)) = 1/2 // w1 = 1 / ((1−0)(1−2)) = -1 // w2 = 1 / ((2−0)(2−1)) = 1/2 @@ -788,7 +829,12 @@ impl UniPolyExtrapolator { builder.eval(l * (t0 + t1 + t2)) } - fn extrapolate_uni_poly_deg_3(&self, builder: &mut Builder, p_i: &Array>, eval_at: Ext) -> Ext { + fn extrapolate_uni_poly_deg_3( + &self, + builder: &mut Builder, + p_i: &Array>, + eval_at: Ext, + ) -> Ext { // w0 = 1 / ((0−1)(0−2)(0−3)) = -1/6 // w1 = 1 / ((1−0)(1−2)(1−3)) = 1/2 // w2 = 1 / ((2−0)(2−1)(2−3)) = -1/2 @@ -813,7 +859,12 @@ impl UniPolyExtrapolator { builder.eval(l * (t0 + t1 + t2 + t3)) } - fn extrapolate_uni_poly_deg_4(&self, builder: &mut Builder, p_i: &Array>, eval_at: Ext) -> Ext { + fn extrapolate_uni_poly_deg_4( + &self, + builder: &mut Builder, + p_i: &Array>, + eval_at: Ext, + ) -> Ext { // w0 = 1 / ((0−1)(0−2)(0−3)(0−4)) = 1/24 // w1 = 1 / ((1−0)(1−2)(1−3)(1−4)) = -1/6 // w2 = 1 / ((2−0)(2−1)(2−3)(2−4)) = 1/4 diff --git a/src/e2e/mod.rs b/src/e2e/mod.rs index d54ecb2..101eacc 100644 --- a/src/e2e/mod.rs +++ b/src/e2e/mod.rs @@ -11,9 +11,9 @@ use mpcs::{Basefold, BasefoldRSParams}; use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; use openvm_native_circuit::{Native, NativeConfig}; use openvm_native_compiler::{ - asm::AsmBuilder, - conversion::{CompilerOptions, convert_program}, - prelude::AsmCompiler + asm::AsmBuilder, + conversion::{convert_program, CompilerOptions}, + prelude::AsmCompiler, }; use openvm_native_recursion::hints::Hintable; use openvm_stark_backend::config::StarkGenericConfig; @@ -458,13 +458,15 @@ pub fn test_zkvm_proof_verifier_from_bincode_exports() { let asm_code = compiler.code(); // _debug: print out assembly - /* + /* println!("=> AssemblyCode:"); println!("{asm_code}"); return (); */ - let program: Program,> = convert_program(asm_code, options); + let program: Program< + p3_monty_31::MontyField31, + > = convert_program(asm_code, options); let mut system_config = SystemConfig::default() .with_public_values(4) .with_max_segment_len((1 << 25) - 100); @@ -473,14 +475,9 @@ pub fn test_zkvm_proof_verifier_from_bincode_exports() { let executor = VmExecutor::::new(config); - let res = executor.execute_and_then( - program, - witness_stream, - |_, seg| { - Ok(seg) - }, - |err| err, - ).unwrap(); + let res = executor + .execute_and_then(program, witness_stream, |_, seg| Ok(seg), |err| err) + .unwrap(); for (i, seg) in res.iter().enumerate() { println!("=> segment {:?} metrics: {:?}", i, seg.metrics); diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 77984e2..9a37a6c 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -1,30 +1,13 @@ use crate::arithmetics::{challenger_multi_observe, exts_to_felts, print_felt_arr}; +use crate::e2e::SubcircuitParams; use crate::tower_verifier::binding::IOPProverMessage; +use crate::tower_verifier::program::verify_tower_proof; +use crate::transcript::transcript_observe_label; use crate::zkvm_verifier::binding::ZKVMProofInput; use crate::zkvm_verifier::binding::{ TowerProofInput, ZKVMOpcodeProofInput, ZKVMTableProofInput, E, F, }; use crate::zkvm_verifier::verifier::verify_zkvm_proof; -use ff_ext::BabyBearExt4; -use itertools::Itertools; -use mpcs::BasefoldCommitment; -use mpcs::{Basefold, BasefoldRSParams}; -use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; -use openvm_native_circuit::{Native, NativeConfig}; -use openvm_native_compiler::{asm::AsmBuilder, conversion::CompilerOptions}; -use openvm_native_recursion::challenger::{self, CanSampleVariable}; -use openvm_native_recursion::hints::Hintable; -use openvm_stark_backend::config::StarkGenericConfig; -use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, -}; -use openvm_native_compiler::conversion::convert_program; -use std::collections::HashMap; -use std::fs::File; -use std::marker::PhantomData; -use crate::e2e::SubcircuitParams; -use crate::tower_verifier::program::verify_tower_proof; -use crate::transcript::transcript_observe_label; use crate::{ arithmetics::{ build_eq_x_r_vec_sequential, ceil_log2, concat, dot_product as ext_dot_product, @@ -39,14 +22,31 @@ use crate::{ }; use ceno_zkvm::circuit_builder::SetTableSpec; use ceno_zkvm::{expression::StructuralWitIn, scheme::verifier::ZKVMVerifier}; +use ff_ext::BabyBearExt4; use itertools::interleave; use itertools::max; +use itertools::Itertools; +use mpcs::BasefoldCommitment; +use mpcs::{Basefold, BasefoldRSParams}; +use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; +use openvm_native_circuit::{Native, NativeConfig}; +use openvm_native_compiler::conversion::convert_program; use openvm_native_compiler::prelude::*; +use openvm_native_compiler::{asm::AsmBuilder, conversion::CompilerOptions}; use openvm_native_compiler_derive::iter_zip; +use openvm_native_recursion::challenger::{self, CanSampleVariable}; use openvm_native_recursion::challenger::{ duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, }; +use openvm_native_recursion::hints::Hintable; +use openvm_stark_backend::config::StarkGenericConfig; +use openvm_stark_sdk::{ + config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, +}; use p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra}; +use std::collections::HashMap; +use std::fs::File; +use std::marker::PhantomData; type Pcs = Basefold; const NUM_FANIN: usize = 2; @@ -84,27 +84,20 @@ pub fn test_native_multi_observe() { let config = NativeConfig::new(system_config, Native); let executor = VmExecutor::::new(config); - + // Alternative execution // executor.execute(program, witness_stream).unwrap(); - let res = executor.execute_and_then( - program, - witness_stream, - |_, seg| { - Ok(seg) - }, - |err| err, - ).unwrap(); + let res = executor + .execute_and_then(program, witness_stream, |_, seg| Ok(seg), |err| err) + .unwrap(); for (i, seg) in res.iter().enumerate() { println!("=> segment {:?} metrics: {:?}", i, seg.metrics); } } -fn vm_program( - builder: &mut Builder, -) { +fn vm_program(builder: &mut Builder) { let e1: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(16)); let e2: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(32)); let e3: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(64)); @@ -123,10 +116,10 @@ fn vm_program( let mut c1 = DuplexChallengerVariable::new(builder); let mut c2 = DuplexChallengerVariable::new(builder); - let f_arr1 = exts_to_felts(builder, &e_arr); + let f_arr1 = exts_to_felts(builder, &e_arr); let f_arr2 = f_arr1.clone(); - challenger_multi_observe(builder, &mut c1, &f_arr1); + challenger_multi_observe(builder, &mut c1, &f_arr1); let test_e1 = c1.sample(builder); c2.observe_slice(builder, f_arr2); diff --git a/src/transcript/mod.rs b/src/transcript/mod.rs index 05fd418..9a61b80 100644 --- a/src/transcript/mod.rs +++ b/src/transcript/mod.rs @@ -1,10 +1,10 @@ use ff_ext::{BabyBearExt4, ExtensionField as CenoExtensionField, SmallField}; use openvm_native_compiler::prelude::*; use openvm_native_recursion::challenger::ChallengerVariable; -use p3_field::FieldAlgebra; use openvm_native_recursion::challenger::{ duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, }; +use p3_field::FieldAlgebra; pub fn transcript_observe_label( builder: &mut Builder,