diff --git a/plonk-napi/src/circuit.rs b/plonk-napi/src/circuit.rs index 6a887e37c7..299a8954d2 100644 --- a/plonk-napi/src/circuit.rs +++ b/plonk-napi/src/circuit.rs @@ -1,4 +1,4 @@ -use crate::{build_info::report_native_call, pasta_fp_plonk_index::WasmPastaFpPlonkIndex}; +use crate::{build_info::report_native_call, pasta_fp_plonk_index::NapiPastaFpPlonkIndex}; use ark_ff::PrimeField; use kimchi::circuits::{constraints::ConstraintSystem, gate::CircuitGate}; use mina_curves::pasta::Fp; @@ -29,7 +29,7 @@ where } #[napi(js_name = "prover_to_json")] -pub fn prover_to_json(prover_index: &External) -> String { +pub fn prover_to_json(prover_index: &External) -> String { report_native_call(); let circuit: Circuit = prover_index.0.cs.as_ref().into(); diff --git a/plonk-napi/src/gate_vector.rs b/plonk-napi/src/gate_vector.rs index 7bbad67295..a85c82c246 100644 --- a/plonk-napi/src/gate_vector.rs +++ b/plonk-napi/src/gate_vector.rs @@ -1,12 +1,14 @@ +use ark_ff::PrimeField; use kimchi::circuits::{ gate::{Circuit, CircuitGate, GateType}, - wires::{GateWires, Wire as KimchiWire}, + wires::Wire, }; use mina_curves::pasta::{Fp, Fq}; use napi::bindgen_prelude::*; use napi_derive::napi; use o1_utils::hasher::CryptoDigest; use paste::paste; +use std::ops::Deref; use wasm_types::{FlatVector as WasmFlatVector, FlatVectorElem}; use crate::wrappers::{ @@ -14,8 +16,198 @@ use crate::wrappers::{ wires::NapiWire, }; +pub mod shared { + use super::*; + + /// Number of wires stored per gate. + pub const WIRE_COUNT: usize = 7; + + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + pub struct GateWires(pub [Wire; WIRE_COUNT]); + + impl GateWires { + pub fn new(wires: [Wire; WIRE_COUNT]) -> Self { + Self(wires) + } + + pub fn as_array(&self) -> &[Wire; WIRE_COUNT] { + &self.0 + } + + pub fn into_array(self) -> [Wire; WIRE_COUNT] { + self.0 + } + } + + impl From<[Wire; WIRE_COUNT]> for GateWires { + fn from(wires: [Wire; WIRE_COUNT]) -> Self { + GateWires::new(wires) + } + } + + impl From for [Wire; WIRE_COUNT] { + fn from(gw: GateWires) -> Self { + gw.into_array() + } + } + + #[derive(Clone, Debug)] + pub struct Gate { + pub typ: GateType, + pub wires: GateWires, + pub coeffs: Vec, + } + + impl From> for Gate + where + F: PrimeField, + { + fn from(cg: CircuitGate) -> Self { + Gate { + typ: cg.typ, + wires: GateWires::new([ + cg.wires[0], + cg.wires[1], + cg.wires[2], + cg.wires[3], + cg.wires[4], + cg.wires[5], + cg.wires[6], + ]), + coeffs: cg.coeffs, + } + } + } + + impl From<&CircuitGate> for Gate + where + F: PrimeField, + { + fn from(cg: &CircuitGate) -> Self { + Gate { + typ: cg.typ, + wires: GateWires::new([ + cg.wires[0], + cg.wires[1], + cg.wires[2], + cg.wires[3], + cg.wires[4], + cg.wires[5], + cg.wires[6], + ]), + coeffs: cg.coeffs.clone(), + } + } + } + + impl From> for CircuitGate + where + F: PrimeField, + { + fn from(gate: Gate) -> Self { + CircuitGate { + typ: gate.typ, + wires: gate.wires.into_array(), + coeffs: gate.coeffs, + } + } + } + + #[derive(Clone, Debug, Default)] + pub struct GateVector { + gates: Vec>, + } + + impl GateVector + where + F: PrimeField, + { + pub fn new() -> Self { + Self { gates: Vec::new() } + } + + pub fn from_vec(gates: Vec>) -> Self { + Self { gates } + } + + pub fn into_inner(self) -> Vec> { + self.gates + } + + pub fn as_slice(&self) -> &[CircuitGate] { + &self.gates + } + + pub fn iter(&self) -> core::slice::Iter<'_, CircuitGate> { + self.gates.iter() + } + + pub fn iter_mut(&mut self) -> core::slice::IterMut<'_, CircuitGate> { + self.gates.iter_mut() + } + + pub fn push_gate(&mut self, gate: CircuitGate) { + self.gates.push(gate); + } + + pub fn len(&self) -> usize { + self.gates.len() + } + + pub fn get_gate(&self, index: usize) -> Option> { + self.gates.get(index).map(Gate::from) + } + + pub fn wrap_wire(&mut self, target: Wire, replacement: Wire) { + if let Some(gate) = self.gates.get_mut(target.row) { + if target.col < gate.wires.len() { + gate.wires[target.col] = replacement; + } + } + } + + pub fn digest(&self, public_input_size: usize) -> Vec { + Circuit::new(public_input_size, self.as_slice()) + .digest() + .to_vec() + } + + pub fn serialize( + &self, + public_input_size: usize, + ) -> std::result::Result { + let circuit = Circuit::new(public_input_size, self.as_slice()); + serde_json::to_string(&circuit) + } + } + + impl From>> for GateVector + where + F: PrimeField, + { + fn from(gates: Vec>) -> Self { + GateVector::from_vec(gates) + } + } + + impl From> for Vec> + where + F: PrimeField, + { + fn from(vec: GateVector) -> Self { + vec.into_inner() + } + } +} + +pub use self::shared::{GateVector as CoreGateVector, GateWires as CoreGateWires}; + +fn gate_vector_error(context: &str, err: impl std::fmt::Display) -> Error { + Error::new(Status::GenericFailure, format!("{}: {}", context, err)) +} + #[napi(object)] -#[derive(Clone, Debug, Default)] +#[derive(Clone, Copy, Debug, Default)] pub struct NapiGateWires { pub w0: NapiWire, pub w1: NapiWire, @@ -26,31 +218,38 @@ pub struct NapiGateWires { pub w6: NapiWire, } -impl NapiGateWires { - fn into_inner(self) -> GateWires { - [ - KimchiWire::from(self.w0), - KimchiWire::from(self.w1), - KimchiWire::from(self.w2), - KimchiWire::from(self.w3), - KimchiWire::from(self.w4), - KimchiWire::from(self.w5), - KimchiWire::from(self.w6), - ] +impl From for NapiGateWires { + fn from(wires: CoreGateWires) -> Self { + let array = wires.into_array(); + NapiGateWires { + w0: array[0].into(), + w1: array[1].into(), + w2: array[2].into(), + w3: array[3].into(), + w4: array[4].into(), + w5: array[5].into(), + w6: array[6].into(), + } } } -impl From<&GateWires> for NapiGateWires { - fn from(value: &GateWires) -> Self { - Self { - w0: value[0].into(), - w1: value[1].into(), - w2: value[2].into(), - w3: value[3].into(), - w4: value[4].into(), - w5: value[5].into(), - w6: value[6].into(), - } +impl From for CoreGateWires { + fn from(wires: NapiGateWires) -> Self { + CoreGateWires::new(wires.into_inner()) + } +} + +impl NapiGateWires { + fn into_inner(self) -> [Wire; shared::WIRE_COUNT] { + [ + self.w0.into(), + self.w1.into(), + self.w2.into(), + self.w3.into(), + self.w4.into(), + self.w5.into(), + self.w6.into(), + ] } } @@ -132,24 +331,91 @@ macro_rules! impl_gate_support { .flat_map(|elem| elem.flatten()) .collect(); + let wires = CoreGateWires::new([ + value.wires[0], + value.wires[1], + value.wires[2], + value.wires[3], + value.wires[4], + value.wires[5], + value.wires[6], + ]); + Self { typ: gate_type_to_i32(value.typ), - wires: (&value.wires).into(), + wires: wires.into(), coeffs, } } } #[napi] - #[derive(Clone, Default, Debug)] + #[derive(Clone, Debug, Default)] pub struct []( - #[napi(skip)] pub Vec>, + #[napi(skip)] pub CoreGateVector<$F>, ); + impl Deref for [] { + type Target = CoreGateVector<$F>; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl From> for [] { + fn from(inner: CoreGateVector<$F>) -> Self { + Self(inner) + } + } + + impl From<[]> for CoreGateVector<$F> { + fn from(vector: []) -> Self { + vector.0 + } + } + + #[napi] + impl [] { + #[napi(constructor)] + pub fn new() -> Self { + CoreGateVector::new().into() + } + + #[napi] + pub fn serialize(&self) -> Result { + let bytes = rmp_serde::to_vec(self.0.as_slice()) + .map_err(|e| gate_vector_error("gate vector serialize failed", e))?; + Ok(Uint8Array::from(bytes)) + } + + #[napi(factory)] + pub fn deserialize(bytes: Uint8Array) -> Result { + let gates: Vec> = rmp_serde::from_slice(bytes.as_ref()) + .map_err(|e| gate_vector_error("gate vector deserialize failed", e))?; + Ok(CoreGateVector::from_vec(gates).into()) + } + + pub(crate) fn inner(&self) -> &CoreGateVector<$F> { + &self.0 + } + + pub(crate) fn inner_mut(&mut self) -> &mut CoreGateVector<$F> { + &mut self.0 + } + + pub(crate) fn as_slice(&self) -> &[CircuitGate<$F>] { + self.0.as_slice() + } + + pub(crate) fn to_vec(&self) -> Vec> { + self.0.as_slice().to_vec() + } + } + #[napi] pub fn []() -> [] { - println!("from native rust creating gate vector"); - [](Vec::new()) + []::new() } #[napi] @@ -157,8 +423,8 @@ macro_rules! impl_gate_support { vector: &mut [], gate: [], ) -> Result<()> { - println!("from native rust adding gate to vector"); - vector.0.push(gate.into_inner()?); + let gate = gate.into_inner()?; + vector.inner_mut().push_gate(gate); Ok(()) } @@ -167,16 +433,18 @@ macro_rules! impl_gate_support { vector: &[], index: i32, ) -> [] { - println!("from native rust getting gate from vector"); - []::from_inner(&vector.0[index as usize]) + let gate = vector + .as_slice() + .get(index as usize) + .expect("index out of bounds"); + []::from_inner(gate) } #[napi] pub fn []( vector: &[], ) -> i32 { - println!("from native rust getting gate vector length"); - vector.0.len() as i32 + vector.as_slice().len() as i32 } #[napi] @@ -185,19 +453,17 @@ macro_rules! impl_gate_support { target: NapiWire, head: NapiWire, ) { - println!("from native rust wrapping wire in gate vector"); - vector.0[target.row as usize].wires[target.col as usize] = KimchiWire::from(head); - } + let target: Wire = target.into(); + let head: Wire = head.into(); + vector.inner_mut().wrap_wire(target, head); + } #[napi] pub fn []( public_input_size: i32, vector: &[], ) -> Uint8Array { - println!("from native rust computing gate vector digest"); - let bytes = Circuit::new(public_input_size as usize, &vector.0) - .digest() - .to_vec(); + let bytes = vector.inner().digest(public_input_size as usize); Uint8Array::from(bytes) } @@ -206,14 +472,29 @@ macro_rules! impl_gate_support { public_input_size: i32, vector: &[], ) -> Result { - println!("from native rust serializing gate vector to json"); - let circuit = Circuit::new(public_input_size as usize, &vector.0); - serde_json::to_string(&circuit).map_err(|err| { - Error::new( - Status::GenericFailure, - format!("couldn't serialize constraints: {}", err), - ) - }) + vector + .inner() + .serialize(public_input_size as usize) + .map_err(|err| { + Error::new( + Status::GenericFailure, + format!("couldn't serialize constraints: {}", err), + ) + }) + } + + #[napi] + pub fn []( + vector: &[], + ) -> Result { + vector.serialize() + } + + #[napi] + pub fn []( + bytes: Uint8Array, + ) -> Result<[]> { + []::deserialize(bytes) } } }; diff --git a/plonk-napi/src/lib.rs b/plonk-napi/src/lib.rs index 21f54e985b..757489bf60 100644 --- a/plonk-napi/src/lib.rs +++ b/plonk-napi/src/lib.rs @@ -1,7 +1,9 @@ mod build_info; mod circuit; +mod gate_vector; mod pasta_fp_plonk_index; mod pasta_fq_plonk_index; +mod plonk_verifier_index; mod poly_comm; mod poseidon; mod srs; @@ -9,11 +11,27 @@ mod tables; mod vector; mod wrappers; +pub use circuit::prover_to_json; +pub use gate_vector::{ + caml_pasta_fp_plonk_circuit_serialize, caml_pasta_fp_plonk_gate_vector_add, + caml_pasta_fp_plonk_gate_vector_create, caml_pasta_fp_plonk_gate_vector_digest, + caml_pasta_fp_plonk_gate_vector_from_bytes, caml_pasta_fp_plonk_gate_vector_get, + caml_pasta_fp_plonk_gate_vector_len, caml_pasta_fp_plonk_gate_vector_to_bytes, + caml_pasta_fp_plonk_gate_vector_wrap, caml_pasta_fq_plonk_circuit_serialize, + caml_pasta_fq_plonk_gate_vector_add, caml_pasta_fq_plonk_gate_vector_create, + caml_pasta_fq_plonk_gate_vector_digest, caml_pasta_fq_plonk_gate_vector_from_bytes, + caml_pasta_fq_plonk_gate_vector_get, caml_pasta_fq_plonk_gate_vector_len, + caml_pasta_fq_plonk_gate_vector_to_bytes, caml_pasta_fq_plonk_gate_vector_wrap, + NapiFpGate as WasmFpGate, NapiFpGateVector as WasmFpGateVector, NapiFqGate as WasmFqGate, + NapiFqGateVector as WasmFqGateVector, +}; pub use pasta_fp_plonk_index::{ - prover_index_fp_from_bytes, prover_index_fp_to_bytes, WasmPastaFpPlonkIndex, + prover_index_fp_from_bytes, prover_index_fp_to_bytes, + NapiPastaFpPlonkIndex as WasmPastaFpPlonkIndex, }; pub use pasta_fq_plonk_index::{ - prover_index_fq_from_bytes, prover_index_fq_to_bytes, WasmPastaFqPlonkIndex, + prover_index_fq_from_bytes, prover_index_fq_to_bytes, + NapiPastaFqPlonkIndex as WasmPastaFqPlonkIndex, }; pub use plonk_verifier_index::{ caml_pasta_fp_plonk_verifier_index_shifts, caml_pasta_fq_plonk_verifier_index_shifts, @@ -22,7 +40,10 @@ pub use poly_comm::{ pallas::NapiFqPolyComm as WasmFqPolyComm, vesta::NapiFpPolyComm as WasmFpPolyComm, }; pub use poseidon::{caml_pasta_fp_poseidon_block_cipher, caml_pasta_fq_poseidon_block_cipher}; -pub use srs::{caml_fp_srs_from_bytes, caml_fp_srs_to_bytes, caml_fq_srs_from_bytes}; +pub use srs::{ + caml_fp_srs_from_bytes, caml_fp_srs_to_bytes, caml_fq_srs_from_bytes, + fp::NapiFpSrs as WasmFpSrs, fq::NapiFqSrs as WasmFqSrs, +}; pub use tables::{JsLookupTableFp, JsLookupTableFq, JsRuntimeTableCfgFp, JsRuntimeTableCfgFq}; pub use vector::{fp::NapiVecVecFp as WasmVecVecFp, fq::NapiVecVecFq as WasmVecVecFq}; pub use wrappers::{ diff --git a/plonk-napi/src/pasta_fp_plonk_index.rs b/plonk-napi/src/pasta_fp_plonk_index.rs index 26ab12c041..d42e24d110 100644 --- a/plonk-napi/src/pasta_fp_plonk_index.rs +++ b/plonk-napi/src/pasta_fp_plonk_index.rs @@ -1,3 +1,11 @@ +use crate::{ + build_info::report_native_call, + gate_vector::NapiFpGateVector, + srs::fp::NapiFpSrs, + tables::{ + lookup_table_fp_from_js, runtime_table_cfg_fp_from_js, JsLookupTableFp, JsRuntimeTableCfgFp, + }, +}; use ark_poly::EvaluationDomain; use kimchi::{ circuits::{ @@ -11,7 +19,6 @@ use mina_curves::pasta::{Fp, Pallas as GAffineOther, Vesta as GAffine, VestaPara use mina_poseidon::{constants::PlonkSpongeConstantsKimchi, sponge::DefaultFqSponge}; use napi::bindgen_prelude::{Error, External, Status, Uint8Array}; use napi_derive::napi; -use plonk_wasm::gate_vector::shared::GateVector; use poly_commitment::{ ipa::{OpeningProof, SRS as IPA_SRS}, SRS, @@ -22,12 +29,7 @@ use std::{ io::{BufReader, BufWriter, Cursor, Seek, SeekFrom::Start}, sync::Arc, }; - -use crate::tables::{ - lookup_table_fp_from_js, runtime_table_cfg_fp_from_js, JsLookupTableFp, JsRuntimeTableCfgFp, -}; -use plonk_wasm::srs::fp::WasmFpSrs as WasmSrs; -pub struct WasmPastaFpPlonkIndex(pub Box>>); +pub struct NapiPastaFpPlonkIndex(pub Box>>); #[derive(Serialize, Deserialize)] struct SerializedProverIndex { @@ -35,7 +37,7 @@ struct SerializedProverIndex { srs: Vec, } -impl WasmPastaFpPlonkIndex { +impl NapiPastaFpPlonkIndex { fn serialize_inner(&self) -> Result, String> { let prover_index = rmp_serde::to_vec(self.0.as_ref()).map_err(|e| e.to_string())?; @@ -75,23 +77,29 @@ impl WasmPastaFpPlonkIndex { DefaultFqSponge, >(); - Ok(WasmPastaFpPlonkIndex(Box::new(index))) + Ok(NapiPastaFpPlonkIndex(Box::new(index))) } } -#[napi] +// TOOD: remove incl all dependencies when no longer needed and we only pass napi objects around +#[napi(js_name = "prover_index_fp_from_bytes")] pub fn prover_index_fp_from_bytes( bytes: Uint8Array, -) -> napi::bindgen_prelude::Result> { - let index = WasmPastaFpPlonkIndex::deserialize_inner(bytes.as_ref()) +) -> napi::bindgen_prelude::Result> { + report_native_call(); + + let index = NapiPastaFpPlonkIndex::deserialize_inner(bytes.as_ref()) .map_err(|e| Error::new(Status::InvalidArg, e))?; Ok(External::new(index)) } -#[napi] +// TOOD: remove incl all dependencies when no longer needed and we only pass napi objects around +#[napi(js_name = "prover_index_fp_to_bytes")] pub fn prover_index_fp_to_bytes( - index: External, + index: &External, ) -> napi::bindgen_prelude::Result { + report_native_call(); + let bytes = index .serialize_inner() .map_err(|e| Error::new(Status::GenericFailure, e))?; @@ -99,41 +107,41 @@ pub fn prover_index_fp_to_bytes( } #[napi] -pub fn caml_pasta_fp_plonk_index_max_degree(index: External) -> i32 { +pub fn caml_pasta_fp_plonk_index_max_degree(index: &External) -> i32 { index.0.srs.max_poly_size() as i32 } #[napi] -pub fn caml_pasta_fp_plonk_index_public_inputs(index: External) -> i32 { +pub fn caml_pasta_fp_plonk_index_public_inputs(index: &External) -> i32 { index.0.cs.public as i32 } #[napi] -pub fn caml_pasta_fp_plonk_index_domain_d1_size(index: External) -> i32 { +pub fn caml_pasta_fp_plonk_index_domain_d1_size(index: &External) -> i32 { index.0.cs.domain.d1.size() as i32 } #[napi] -pub fn caml_pasta_fp_plonk_index_domain_d4_size(index: External) -> i32 { +pub fn caml_pasta_fp_plonk_index_domain_d4_size(index: &External) -> i32 { index.0.cs.domain.d4.size() as i32 } #[napi] -pub fn caml_pasta_fp_plonk_index_domain_d8_size(index: External) -> i32 { +pub fn caml_pasta_fp_plonk_index_domain_d8_size(index: &External) -> i32 { index.0.cs.domain.d8.size() as i32 } #[napi] pub fn caml_pasta_fp_plonk_index_create( - gates: External>, + gates: &NapiFpGateVector, public_: i32, lookup_tables: Vec, runtime_table_cfgs: Vec, prev_challenges: i32, - srs: External, + srs: &External, lazy_mode: bool, -) -> Result, Error> { - let gates: Vec<_> = gates.as_ref().as_slice().to_vec(); +) -> Result, Error> { + let gates: Vec<_> = gates.to_vec(); let runtime_cfgs: Vec> = runtime_table_cfgs .into_iter() @@ -178,14 +186,14 @@ pub fn caml_pasta_fp_plonk_index_create( ); index.compute_verifier_index_digest::>(); - Ok(External::new(WasmPastaFpPlonkIndex(Box::new(index)))) + Ok(External::new(NapiPastaFpPlonkIndex(Box::new(index)))) } #[napi] pub fn caml_pasta_fp_plonk_index_decode( bytes: &[u8], - srs: External, -) -> Result, Error> { + srs: &External, +) -> Result, Error> { let mut deserializer = rmp_serde::Deserializer::new(bytes); let mut index = ProverIndex::>::deserialize(&mut deserializer) .map_err(|e| { @@ -200,12 +208,12 @@ pub fn caml_pasta_fp_plonk_index_decode( index.linearization = linearization; index.powers_of_alpha = powers_of_alpha; - Ok(External::new(WasmPastaFpPlonkIndex(Box::new(index)))) + Ok(External::new(NapiPastaFpPlonkIndex(Box::new(index)))) } #[napi] pub fn caml_pasta_fp_plonk_index_encode( - index: External, + index: &External, ) -> Result, Error> { let mut buffer = Vec::new(); let mut serializer = rmp_serde::Serializer::new(&mut buffer); @@ -221,7 +229,7 @@ pub fn caml_pasta_fp_plonk_index_encode( #[napi] pub fn caml_pasta_fp_plonk_index_write( append: Option, - index: External, + index: &External, path: String, ) -> Result<(), Error> { let file = OpenOptions::new() @@ -243,9 +251,9 @@ pub fn caml_pasta_fp_plonk_index_write( #[napi] pub fn caml_pasta_fp_plonk_index_read( offset: Option, - srs: External, + srs: &External, path: String, -) -> Result, Error> { +) -> Result, Error> { // read from file let file = match File::open(path) { Err(_) => { @@ -284,5 +292,5 @@ pub fn caml_pasta_fp_plonk_index_read( t.powers_of_alpha = powers_of_alpha; // - Ok(External::new(WasmPastaFpPlonkIndex(Box::new(t)))) + Ok(External::new(NapiPastaFpPlonkIndex(Box::new(t)))) } diff --git a/plonk-napi/src/pasta_fq_plonk_index.rs b/plonk-napi/src/pasta_fq_plonk_index.rs index 4203bd4505..eae617fc16 100644 --- a/plonk-napi/src/pasta_fq_plonk_index.rs +++ b/plonk-napi/src/pasta_fq_plonk_index.rs @@ -1,3 +1,11 @@ +use crate::{ + build_info::report_native_call, + gate_vector::NapiFqGateVector, + srs::fq::NapiFqSrs, + tables::{ + lookup_table_fq_from_js, runtime_table_cfg_fq_from_js, JsLookupTableFq, JsRuntimeTableCfgFq, + }, +}; use ark_poly::EvaluationDomain; use kimchi::{ circuits::constraints::ConstraintSystem, linearization::expr_linearization, @@ -7,7 +15,6 @@ use mina_curves::pasta::{Fq, Pallas as GAffine, PallasParameters, Vesta as GAffi use mina_poseidon::{constants::PlonkSpongeConstantsKimchi, sponge::DefaultFqSponge}; use napi::bindgen_prelude::{Error, External, Status, Uint8Array}; use napi_derive::napi; -use plonk_wasm::gate_vector::shared::GateVector; use poly_commitment::{ ipa::{OpeningProof, SRS as IPA_SRS}, SRS, @@ -18,12 +25,7 @@ use std::{ io::{BufReader, BufWriter, Cursor, Seek, SeekFrom::Start}, sync::Arc, }; - -use crate::tables::{ - lookup_table_fq_from_js, runtime_table_cfg_fq_from_js, JsLookupTableFq, JsRuntimeTableCfgFq, -}; -use plonk_wasm::srs::fq::WasmFqSrs as WasmSrs; -pub struct WasmPastaFqPlonkIndex(pub Box>>); +pub struct NapiPastaFqPlonkIndex(pub Box>>); #[derive(Serialize, Deserialize)] struct SerializedProverIndex { @@ -31,7 +33,7 @@ struct SerializedProverIndex { srs: Vec, } -impl WasmPastaFqPlonkIndex { +impl NapiPastaFqPlonkIndex { fn serialize_inner(&self) -> Result, String> { let prover_index = rmp_serde::to_vec(self.0.as_ref()).map_err(|e| e.to_string())?; @@ -71,23 +73,29 @@ impl WasmPastaFqPlonkIndex { DefaultFqSponge, >(); - Ok(WasmPastaFqPlonkIndex(Box::new(index))) + Ok(NapiPastaFqPlonkIndex(Box::new(index))) } } -#[napi] +// TOOD: remove incl all dependencies when no longer needed and we only pass napi objects around +#[napi(js_name = "prover_index_fq_from_bytes")] pub fn prover_index_fq_from_bytes( bytes: Uint8Array, -) -> napi::bindgen_prelude::Result> { - let index = WasmPastaFqPlonkIndex::deserialize_inner(bytes.as_ref()) +) -> napi::bindgen_prelude::Result> { + report_native_call(); + + let index = NapiPastaFqPlonkIndex::deserialize_inner(bytes.as_ref()) .map_err(|e| Error::new(Status::InvalidArg, e))?; Ok(External::new(index)) } -#[napi] +// TOOD: remove incl all dependencies when no longer needed and we only pass napi objects around +#[napi(js_name = "prover_index_fq_to_bytes")] pub fn prover_index_fq_to_bytes( - index: External, + index: &External, ) -> napi::bindgen_prelude::Result { + report_native_call(); + let bytes = index .serialize_inner() .map_err(|e| Error::new(Status::GenericFailure, e))?; @@ -95,43 +103,43 @@ pub fn prover_index_fq_to_bytes( } #[napi] -pub fn caml_pasta_fq_plonk_index_max_degree(index: External) -> i32 { +pub fn caml_pasta_fq_plonk_index_max_degree(index: &External) -> i32 { index.0.srs.max_poly_size() as i32 } #[napi] -pub fn caml_pasta_fq_plonk_index_public_inputs(index: External) -> i32 { +pub fn caml_pasta_fq_plonk_index_public_inputs(index: &External) -> i32 { index.0.cs.public as i32 } #[napi] -pub fn caml_pasta_fq_plonk_index_domain_d1_size(index: External) -> i32 { +pub fn caml_pasta_fq_plonk_index_domain_d1_size(index: &External) -> i32 { index.0.cs.domain.d1.size() as i32 } #[napi] -pub fn caml_pasta_fq_plonk_index_domain_d4_size(index: External) -> i32 { +pub fn caml_pasta_fq_plonk_index_domain_d4_size(index: &External) -> i32 { index.0.cs.domain.d4.size() as i32 } #[napi] -pub fn caml_pasta_fq_plonk_index_domain_d8_size(index: External) -> i32 { +pub fn caml_pasta_fq_plonk_index_domain_d8_size(index: &External) -> i32 { index.0.cs.domain.d8.size() as i32 } #[napi] pub fn caml_pasta_fq_plonk_index_create( - gates: External>, + gates: &NapiFqGateVector, public_: i32, lookup_tables: Vec, runtime_table_cfgs: Vec, prev_challenges: i32, - srs: External, + srs: &External, lazy_mode: bool, -) -> Result, Error> { +) -> Result, Error> { // TODO: check if and how we run rayon threads automatically in napi - let gates: Vec<_> = gates.as_ref().as_slice().to_vec(); + let gates: Vec<_> = gates.to_vec(); let runtime_cfgs = runtime_table_cfgs .into_iter() @@ -176,14 +184,14 @@ pub fn caml_pasta_fq_plonk_index_create( ); index.compute_verifier_index_digest::>(); - Ok(External::new(WasmPastaFqPlonkIndex(Box::new(index)))) + Ok(External::new(NapiPastaFqPlonkIndex(Box::new(index)))) } #[napi] pub fn caml_pasta_fq_plonk_index_decode( bytes: &[u8], - srs: External, -) -> Result, Error> { + srs: &External, +) -> Result, Error> { let mut deserializer = rmp_serde::Deserializer::new(bytes); let mut index = ProverIndex::>::deserialize(&mut deserializer) .map_err(|e| { @@ -197,12 +205,12 @@ pub fn caml_pasta_fq_plonk_index_decode( index.linearization = linearization; index.powers_of_alpha = powers_of_alpha; - Ok(External::new(WasmPastaFqPlonkIndex(Box::new(index)))) + Ok(External::new(NapiPastaFqPlonkIndex(Box::new(index)))) } #[napi] pub fn caml_pasta_fq_plonk_index_encode( - index: External, + index: &External, ) -> Result, Error> { let mut buffer = Vec::new(); let mut serializer = rmp_serde::Serializer::new(&mut buffer); @@ -218,7 +226,7 @@ pub fn caml_pasta_fq_plonk_index_encode( #[napi] pub fn caml_pasta_fq_plonk_index_write( append: Option, - index: External, + index: &External, path: String, ) -> Result<(), Error> { let file = OpenOptions::new() @@ -240,9 +248,9 @@ pub fn caml_pasta_fq_plonk_index_write( #[napi] pub fn caml_pasta_fq_plonk_index_read( offset: Option, - srs: External, + srs: &External, path: String, -) -> Result, Error> { +) -> Result, Error> { // read from file let file = match File::open(path) { Err(_) => { @@ -281,5 +289,5 @@ pub fn caml_pasta_fq_plonk_index_read( t.powers_of_alpha = powers_of_alpha; // - Ok(External::new(WasmPastaFqPlonkIndex(Box::new(t)))) + Ok(External::new(NapiPastaFqPlonkIndex(Box::new(t)))) } diff --git a/plonk-napi/src/prover_index.rs b/plonk-napi/src/prover_index.rs deleted file mode 100644 index 106ebffe55..0000000000 --- a/plonk-napi/src/prover_index.rs +++ /dev/null @@ -1,25 +0,0 @@ -use napi::bindgen_prelude::*; -use napi_derive::napi; - -use crate::{build_info::report_native_call, types::WasmPastaFpPlonkIndex}; - -// TOOD: remove incl all dependencies when no longer needed and we only pass napi objects around -#[napi(js_name = "prover_index_from_bytes")] -pub fn prover_index_from_bytes(bytes: Uint8Array) -> Result> { - report_native_call(); - - let index = WasmPastaFpPlonkIndex::deserialize_inner(bytes.as_ref()) - .map_err(|e| Error::new(Status::InvalidArg, e))?; - Ok(External::new(index)) -} - -// TOOD: remove incl all dependencies when no longer needed and we only pass napi objects around -#[napi(js_name = "prover_index_to_bytes")] -pub fn prover_index_to_bytes(index: &External) -> Result { - report_native_call(); - - let bytes = index - .serialize_inner() - .map_err(|e| Error::new(Status::GenericFailure, e))?; - Ok(Uint8Array::from(bytes)) -} diff --git a/plonk-napi/src/srs.rs b/plonk-napi/src/srs.rs index 3f66c82ddd..f996bbb9c8 100644 --- a/plonk-napi/src/srs.rs +++ b/plonk-napi/src/srs.rs @@ -1,30 +1,365 @@ -use std::sync::Arc; - -use mina_curves::pasta::{Pallas as GAffineOther, Vesta as GAffine}; -use napi::bindgen_prelude::{Error, External, Result as NapiResult, Status, Uint8Array}; +use crate::vector::NapiVector; +use ark_poly::{univariate::DensePolynomial, DenseUVPolynomial, EvaluationDomain, Evaluations}; +use core::ops::Deref; +use napi::bindgen_prelude::{Error, Result, Status, Uint8Array}; use napi_derive::napi; -use plonk_wasm::srs::fp::WasmFpSrs as WasmSrsFp; -use plonk_wasm::srs::fq::WasmFqSrs as WasmSrsFq; +use paste::paste; +use poly_commitment::{ + commitment::b_poly_coefficients, hash_map_cache::HashMapCache, ipa::SRS, SRS as ISRS, +}; +use serde::{Deserialize, Serialize}; +use std::{ + fs::{File, OpenOptions}, + io::{BufReader, BufWriter, Seek, SeekFrom}, + sync::Arc, +}; +use wasm_types::FlatVector as WasmFlatVector; + +macro_rules! impl_srs { + ( + $name:ident, + $NapiF:ty, + $NapiG:ty, + $F:ty, + $G:ty, + $NapiPolyComm:ty, + ) => { + paste! { + + #[napi(js_name = [<"Wasm" $name:camel "Srs">])] + #[derive(Clone)] + pub struct [] ( + #[napi(skip)] pub Arc> + ); + + impl Deref for [] { + type Target = Arc>; + + fn deref(&self) -> &Self::Target { &self.0 } + } + + impl From>> for [] { + fn from(x: Arc>) -> Self { + [](x) + } + } + + impl From<&Arc>> for [] { + fn from(x: &Arc>) -> Self { + [](x.clone()) + } + } + + impl From<[]> for Arc> { + fn from(x: []) -> Self { + x.0 + } + } + + impl From<&[]> for Arc> { + fn from(x: &[]) -> Self { + x.0.clone() + } + } + + impl<'a> From<&'a []> for &'a Arc> { + fn from(x: &'a []) -> Self { + &x.0 + } + } + + fn invalid_domain_error() -> Error { + Error::new(Status::InvalidArg, "invalid domain size") + } + + fn map_error(context: &str, err: impl std::fmt::Display) -> Error { + Error::new(Status::GenericFailure, format!("{}: {}", context, err)) + } + + #[napi] + impl [] { + + #[napi] + pub fn serialize(&self) -> Result { + let mut buf = Vec::new(); + self.0 + .serialize(&mut rmp_serde::Serializer::new(&mut buf)) + .map_err(|e| map_error("srs_serialize", e))?; + Ok(Uint8Array::from(buf)) + } + + #[napi] + pub fn deserialize(bytes: Uint8Array) -> Result { + let srs: SRS<$G> = rmp_serde::from_slice(bytes.as_ref()) + .map_err(|e| map_error("srs_deserialize", e))?; + Ok(Arc::new(srs).into()) + } + + #[napi(factory, js_name = [<"caml_" $name:snake "_srs_create">])] + pub fn [](depth: i32) -> Result { + println!("Creating SRS with napi"); + Ok(Arc::new(SRS::<$G>::create(depth as usize)).into()) + } + + #[napi(factory, js_name = [<"caml_" $name:snake "_srs_create_parallel">])] + pub fn [](depth: i32) -> Result { + println!("Creating SRS in parallel with napi"); + Ok(Arc::new(SRS::<$G>::create_parallel( + depth as usize, + )).into()) + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_add_lagrange_basis">])] + pub fn [](srs: &[], log2_size: i32) -> Result<()> { + println!("Adding lagrange basis with napi"); + let size = 1usize << (log2_size as usize); + let domain = EvaluationDomain::<$F>::new(size).ok_or_else(invalid_domain_error)?; + srs.get_lagrange_basis(domain); + Ok(()) + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_write">])] + pub fn [](append: Option, srs: &[], path: String) -> Result<()> { + println!("Writing SRS to file with napi"); + let function_name = format!("caml_{0}_srs_write", stringify!($name).to_lowercase()); + let file = OpenOptions::new() + .append(append.unwrap_or(true)) + .open(&path) + .map_err(|err| map_error(&function_name, err))?; + let file = BufWriter::new(file); + srs.0.serialize(&mut rmp_serde::Serializer::new(file)) + .map_err(|err| map_error(&function_name, err)) + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_read">])] + pub fn [](offset: Option, path: String) -> Result> { + println!("Reading SRS from file with napi"); + let function_name = format!("caml_{0}_srs_read", stringify!($name).to_lowercase()); + let file = match File::open(&path) { + Ok(file) => file, + Err(err) => return Err(map_error(&function_name, err)), + }; + let mut reader = BufReader::new(file); + + if let Some(off) = offset { + reader + .seek(SeekFrom::Start(off as u64)) + .map_err(|err| map_error(&function_name, err))?; + } -use poly_commitment::ipa::SRS; + match SRS::<$G>::deserialize(&mut rmp_serde::Deserializer::new(reader)) { + Ok(srs) => Ok(Some(Arc::new(srs).into())), + Err(_) => Ok(None), + } + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_get">])] + pub fn [](srs: &[]) -> Vec<$NapiG> { + println!("Getting SRS with napi"); + let mut h_and_gs: Vec<$NapiG> = vec![srs.0.h.into()]; + h_and_gs.extend(srs.0.g.iter().cloned().map(Into::into)); + h_and_gs + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_set">])] + pub fn [](h_and_gs: Vec<$NapiG>) -> Result { + println!("Setting SRS with napi"); + let mut h_and_gs: Vec<$G> = h_and_gs.into_iter().map(Into::into).collect(); + if h_and_gs.is_empty() { + return Err(Error::new( + Status::InvalidArg, + "expected at least one element for SRS", + )); + } + let h = h_and_gs.remove(0); + let g = h_and_gs; + let srs = SRS::<$G> { h, g, lagrange_bases: HashMapCache::new() }; + Ok(Arc::new(srs).into()) + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_maybe_lagrange_commitment">])] + pub fn []( + srs: &[], + domain_size: i32, + i: i32, + ) -> Option<$NapiPolyComm> { + println!("Getting maybe lagrange commitment with napi"); + if !srs + .0 + .lagrange_bases + .contains_key(&(domain_size as usize)) + { + return None; + } + let basis = srs + .get_lagrange_basis_from_domain_size(domain_size as usize); + Some(basis[i as usize].clone().into()) + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_set_lagrange_basis">])] + pub fn [](srs: &[], + domain_size: i32, + input_bases: NapiVector<$NapiPolyComm>, + ) { + println!("Setting lagrange basis with napi"); + srs.0.lagrange_bases + .get_or_generate(domain_size as usize, || { input_bases.into_iter().map(Into::into).collect()}); + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_get_lagrange_basis">])] + pub fn [](srs: &[], + domain_size: i32, + ) -> Result> { + println!("Getting lagrange basis with napi"); + let domain = EvaluationDomain::<$F>::new(domain_size as usize) + .ok_or_else(invalid_domain_error)?; + let basis = srs.0.get_lagrange_basis(domain); + Ok(basis.iter().cloned().map(Into::into).collect()) + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_to_bytes">])] + pub fn [](srs: &[]) -> Result { + srs.serialize() + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_from_bytes">])] + pub fn [](bytes: Uint8Array) -> Result { + Self::deserialize(bytes) + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_commit_evaluations">])] + pub fn [](srs: &[], + domain_size: i32, + evals: Uint8Array, + ) -> Result<$NapiPolyComm> { + println!("Committing evaluations with napi"); + let elems: Vec<$F> = WasmFlatVector::<$NapiF>::from_bytes( + evals.as_ref().to_vec(), + ) + .into_iter() + .map(Into::into) + .collect(); + let x_domain = EvaluationDomain::<$F>::new(domain_size as usize) + .ok_or_else(invalid_domain_error)?; + let evals = elems.into_iter().map(Into::into).collect(); + let p = Evaluations::<$F>::from_vec_and_domain(evals, x_domain).interpolate(); + Ok(srs.commit_non_hiding(&p, 1).into()) + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_b_poly_commitment">])] + pub fn [](srs: &[], chals: Uint8Array) -> Result<$NapiPolyComm> { + println!("Computing b poly commitment with napi"); + let elements: Vec<$F> = WasmFlatVector::<$NapiF>::from_bytes( + chals.as_ref().to_vec(), + ) + .into_iter() + .map(Into::into) + .collect(); + let coeffs = b_poly_coefficients(&elements); + let p = DensePolynomial::<$F>::from_coefficients_vec(coeffs); + Ok(srs.commit_non_hiding(&p, 1).into()) + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_batch_accumulator_check">])] + pub fn []( + srs: &[], + comms: NapiVector<$NapiG>, + chals: Uint8Array, + ) -> Result { + println!("Performing batch accumulator check with napi"); + let comms: Vec<$G> = comms.into_iter().map(Into::into).collect(); + let chals: Vec<$F> = WasmFlatVector::<$NapiF>::from_bytes( + chals.as_ref().to_vec(), + ) + .into_iter() + .map(Into::into) + .collect(); + Ok(poly_commitment::utils::batch_dlog_accumulator_check( + &srs, + &comms, + &chals, + )) + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_batch_accumulator_generate">])] + pub fn []( + srs: &[], + comms: i32, + chals: Uint8Array, + ) -> Result> { + println!("Generating batch accumulator with napi"); + let chals: Vec<$F> = WasmFlatVector::<$NapiF>::from_bytes( + chals.as_ref().to_vec(), + ) + .into_iter() + .map(Into::into) + .collect(); + let points = poly_commitment::utils::batch_dlog_accumulator_generate::<$G>( + &srs, + comms as usize, + &chals, + ); + Ok(points.into_iter().map(Into::into).collect()) + } + + #[napi(js_name = [<"caml_" $name:snake "_srs_get_h">])] + pub fn h(srs: &[]) -> $NapiG { + println!("Getting h point with napi"); + srs.h.into() + } + } + } + } +} + +#[napi] +pub fn caml_fp_srs_to_bytes(srs: &fp::NapiFpSrs) -> Result { + srs.serialize() +} #[napi] -pub fn caml_fp_srs_to_bytes(srs: External) -> NapiResult { - let buffer = rmp_serde::to_vec(srs.as_ref().0.as_ref()) - .map_err(|e| Error::new(Status::GenericFailure, e.to_string()))?; - Ok(Uint8Array::from(buffer)) +pub fn caml_fp_srs_from_bytes(bytes: Uint8Array) -> Result { + fp::NapiFpSrs::deserialize(bytes) } #[napi] -pub fn caml_fp_srs_from_bytes(bytes: Uint8Array) -> NapiResult> { - let srs: SRS = rmp_serde::from_slice(bytes.as_ref()) - .map_err(|e| Error::new(Status::InvalidArg, e.to_string()))?; - Ok(External::new(Arc::new(srs).into())) +pub fn caml_fq_srs_to_bytes(srs: &fq::NapiFqSrs) -> Result { + srs.serialize() } #[napi] -pub fn caml_fq_srs_from_bytes(bytes: Uint8Array) -> NapiResult> { - let srs: SRS = rmp_serde::from_slice(bytes.as_ref()) - .map_err(|e: rmp_serde::decode::Error| Error::new(Status::InvalidArg, e.to_string()))?; - Ok(External::new(Arc::new(srs).into())) +pub fn caml_fq_srs_from_bytes(bytes: Uint8Array) -> Result { + fq::NapiFqSrs::deserialize(bytes) +} + +pub mod fp { + use super::*; + use crate::{ + poly_comm::vesta::NapiFpPolyComm, + wrappers::{field::NapiPastaFp, group::NapiGVesta}, + }; + impl_srs!( + fp, // field name + NapiPastaFp, // Napi field wrapper + NapiGVesta, // Napi group wrapper + mina_curves::pasta::Fp, // Actual Kimchi field + mina_curves::pasta::Vesta, // Actual kimchi group + NapiFpPolyComm, // Napi poly commitment type + ); +} + +pub mod fq { + use super::*; + use crate::{ + poly_comm::pallas::NapiFqPolyComm, + wrappers::{field::NapiPastaFq, group::NapiGPallas}, + }; + impl_srs!( + fq, // Field name + NapiPastaFq, // Napi field wrapper + NapiGPallas, // Napi group wrapper + mina_curves::pasta::Fq, // Actual Kimchi field + mina_curves::pasta::Pallas, // Actual kimchi group + NapiFqPolyComm, // Napi poly commitment type + ); } diff --git a/plonk-napi/src/tables.rs b/plonk-napi/src/tables.rs index fd00cb94ae..cd9f4f718b 100644 --- a/plonk-napi/src/tables.rs +++ b/plonk-napi/src/tables.rs @@ -1,8 +1,5 @@ use arkworks::{WasmPastaFp, WasmPastaFq}; -use kimchi::circuits::lookup::{ - runtime_tables::RuntimeTableCfg, - tables::LookupTable, -}; +use kimchi::circuits::lookup::{runtime_tables::RuntimeTableCfg, tables::LookupTable}; use mina_curves::pasta::{Fp, Fq}; use napi::bindgen_prelude::Uint8Array; use napi_derive::napi; @@ -66,18 +63,14 @@ pub fn lookup_table_fq_from_js(js: JsLookupTableFq) -> napi::Result napi::Result> { +pub fn runtime_table_cfg_fp_from_js(js: JsRuntimeTableCfgFp) -> napi::Result> { Ok(RuntimeTableCfg { id: js.id, first_column: bytes_to_fp_vec(typed_array_to_vec(&js.first_column)), }) } -pub fn runtime_table_cfg_fq_from_js( - js: JsRuntimeTableCfgFq, -) -> napi::Result> { +pub fn runtime_table_cfg_fq_from_js(js: JsRuntimeTableCfgFq) -> napi::Result> { Ok(RuntimeTableCfg { id: js.id, first_column: bytes_to_fq_vec(typed_array_to_vec(&js.first_column)), diff --git a/plonk-wasm/src/gate_vector.rs b/plonk-wasm/src/gate_vector.rs index 7928845af1..d78f9d82a7 100644 --- a/plonk-wasm/src/gate_vector.rs +++ b/plonk-wasm/src/gate_vector.rs @@ -60,7 +60,12 @@ pub mod shared { Gate { typ: cg.typ, wires: GateWires::new([ - cg.wires[0], cg.wires[1], cg.wires[2], cg.wires[3], cg.wires[4], cg.wires[5], + cg.wires[0], + cg.wires[1], + cg.wires[2], + cg.wires[3], + cg.wires[4], + cg.wires[5], cg.wires[6], ]), coeffs: cg.coeffs, @@ -76,7 +81,12 @@ pub mod shared { Gate { typ: cg.typ, wires: GateWires::new([ - cg.wires[0], cg.wires[1], cg.wires[2], cg.wires[3], cg.wires[4], cg.wires[5], + cg.wires[0], + cg.wires[1], + cg.wires[2], + cg.wires[3], + cg.wires[4], + cg.wires[5], cg.wires[6], ]), coeffs: cg.coeffs.clone(), @@ -181,7 +191,9 @@ pub mod shared { } } -pub use self::shared::{Gate as CoreGate, GateVector as CoreGateVector, GateWires as CoreGateWires}; +pub use self::shared::{ + Gate as CoreGate, GateVector as CoreGateVector, GateWires as CoreGateWires, +}; #[wasm_bindgen] #[derive(Clone, Copy, Debug)]