Skip to content

Commit

Permalink
refactor: update MerkleProof
Browse files Browse the repository at this point in the history
  • Loading branch information
enricobottazzi committed Nov 21, 2023
1 parent 68ab601 commit eda4ad2
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 70 deletions.
13 changes: 3 additions & 10 deletions zk_prover/benches/full_solvency_flow.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
#![feature(generic_const_exprs)]
use criterion::{criterion_group, criterion_main, Criterion};
use halo2_proofs::{
halo2curves::bn256::Fr as Fp,
plonk::{keygen_pk, keygen_vk},
};
use halo2_proofs::plonk::{keygen_pk, keygen_vk};
use snark_verifier_sdk::CircuitExt;
use summa_solvency::{
circuits::merkle_sum_tree::MstInclusionCircuit,
Expand Down Expand Up @@ -112,10 +109,8 @@ fn generate_zk_proof_mst_inclusion_circuit(_c: &mut Criterion) {
let user_index = 0;

let merkle_proof = merkle_sum_tree.generate_proof(user_index).unwrap();
let user_entry = merkle_sum_tree.get_entry(user_index);

let circuit =
MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof, user_entry.clone());
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof);

let bench_name = format!(
"generate zk proof - tree of 2 power of {} entries with {} assets mst inclusion circuit",
Expand Down Expand Up @@ -147,10 +142,8 @@ fn verify_zk_proof_mst_inclusion_circuit(_c: &mut Criterion) {
let user_index = 0;

let merkle_proof = merkle_sum_tree.generate_proof(user_index).unwrap();
let user_entry = merkle_sum_tree.get_entry(user_index);

let circuit =
MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof, user_entry.clone());
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof);

let proof = full_prover(&params, &pk, circuit.clone(), circuit.instances());

Expand Down
4 changes: 1 addition & 3 deletions zk_prover/examples/gen_inclusion_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,9 @@ fn main() {
let user_index = 0;

let merkle_proof = merkle_sum_tree.generate_proof(user_index).unwrap();
let user_entry = merkle_sum_tree.get_entry(user_index);

// Generate the circuit with the actual inputs
let circuit =
MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof, user_entry.clone());
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof);

let instances = circuit.instances();

Expand Down
Binary file modified zk_prover/prints/mst-inclusion-layout.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 2 additions & 6 deletions zk_prover/src/circuits/merkle_sum_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ where
}

/// Initializes the circuit with the merkle proof and the entry of the user of which the inclusion is to be verified.
pub fn init(merkle_proof: MerkleProof<N_ASSETS, N_BYTES>, entry: Entry<N_ASSETS>) -> Self
pub fn init(merkle_proof: MerkleProof<N_ASSETS, N_BYTES>) -> Self
where
[usize; N_ASSETS + 1]: Sized,
[usize; N_ASSETS + 2]: Sized,
Expand All @@ -92,12 +92,8 @@ where
merkle_proof.sibling_middle_node_hash_preimages.len(),
LEVELS - 1
);

// assert that the entry leaf hash matches the leaf hash in the merkle proof
assert_eq!(merkle_proof.leaf.hash, entry.compute_leaf().hash);

Self {
entry,
entry: merkle_proof.entry,
path_indices: merkle_proof.path_indices,
sibling_leaf_node_hash_preimage: merkle_proof.sibling_leaf_node_hash_preimage,
sibling_middle_node_hash_preimages: merkle_proof.sibling_middle_node_hash_preimages,
Expand Down
53 changes: 9 additions & 44 deletions zk_prover/src/circuits/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,8 @@ mod test {
for user_index in 0..16 {
// get proof for entry ˆuser_indexˆ
let merkle_proof = merkle_sum_tree.generate_proof(user_index).unwrap();
let user_entry = merkle_sum_tree.get_entry(user_index);

let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(
merkle_proof,
user_entry.clone(),
);
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof);

let valid_prover = MockProver::run(K, &circuit, circuit.instances()).unwrap();

Expand Down Expand Up @@ -69,10 +65,7 @@ mod test {
let user_entry = merkle_sum_tree.get_entry(user_index);

// Only now we can instantiate the circuit with the actual inputs
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(
merkle_proof,
user_entry.clone(),
);
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof);

// Generate the proof
let proof = full_prover(&params, &pk, circuit.clone(), circuit.instances());
Expand Down Expand Up @@ -105,12 +98,8 @@ mod test {
let user_index = 0;

let merkle_proof = merkle_sum_tree.generate_proof(user_index).unwrap();
let user_entry = merkle_sum_tree.get_entry(user_index);

let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(
merkle_proof,
user_entry.clone(),
);
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof);

let mut instances = circuit.instances();
let invalid_root_hash = Fp::from(1000u64);
Expand Down Expand Up @@ -150,13 +139,9 @@ mod test {
let user_index = 0;

let merkle_proof = merkle_sum_tree.generate_proof(user_index).unwrap();
let user_entry = merkle_sum_tree.get_entry(user_index);

// Only now we can instantiate the circuit with the actual inputs
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(
merkle_proof,
user_entry.clone(),
);
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof);

let invalid_root_hash = Fp::from(1000u64);

Expand All @@ -183,13 +168,9 @@ mod test {
let user_index = 0;

let merkle_proof = merkle_sum_tree.generate_proof(user_index).unwrap();
let user_entry = merkle_sum_tree.get_entry(user_index);

// Only now we can instantiate the circuit with the actual inputs
let mut circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(
merkle_proof,
user_entry.clone(),
);
let mut circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof);

let instances = circuit.instances();

Expand Down Expand Up @@ -263,13 +244,9 @@ mod test {
let user_index = 0;

let merkle_proof = merkle_sum_tree.generate_proof(user_index).unwrap();
let user_entry = merkle_sum_tree.get_entry(user_index);

// Only now we can instantiate the circuit with the actual inputs
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(
merkle_proof,
user_entry.clone(),
);
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof);

let mut instances = circuit.instances();
let invalid_leaf_hash = Fp::from(1000u64);
Expand Down Expand Up @@ -305,13 +282,9 @@ mod test {
let user_index = 0;

let merkle_proof = merkle_sum_tree.generate_proof(user_index).unwrap();
let user_entry = merkle_sum_tree.get_entry(user_index);

// Only now we can instantiate the circuit with the actual inputs
let mut circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(
merkle_proof,
user_entry.clone(),
);
let mut circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof);

let instances = circuit.instances();

Expand Down Expand Up @@ -430,13 +403,9 @@ mod test {
let user_index = 0;

let merkle_proof = merkle_sum_tree.generate_proof(user_index).unwrap();
let user_entry = merkle_sum_tree.get_entry(user_index);

// Only now we can instantiate the circuit with the actual inputs
let mut circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(
merkle_proof,
user_entry.clone(),
);
let mut circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof);

let instances = circuit.instances();

Expand Down Expand Up @@ -475,12 +444,8 @@ mod test {
let user_index = 0;

let merkle_proof = merkle_sum_tree.generate_proof(user_index).unwrap();
let user_entry = merkle_sum_tree.get_entry(user_index);

let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(
merkle_proof,
user_entry.clone(),
);
let circuit = MstInclusionCircuit::<LEVELS, N_ASSETS, N_BYTES>::init(merkle_proof);

let root = BitMapBackend::new("prints/mst-inclusion-layout.png", (2048, 32768))
.into_drawing_area();
Expand Down
2 changes: 1 addition & 1 deletion zk_prover/src/merkle_sum_tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ where
[usize; N_ASSETS + 1]: Sized,
[usize; N_ASSETS + 2]: Sized,
{
pub leaf: Node<N_ASSETS>,
pub entry: Entry<N_ASSETS>,
pub root: Node<N_ASSETS>,
pub sibling_leaf_node_hash_preimage: [Fp; N_ASSETS + 1],
pub sibling_middle_node_hash_preimages: Vec<[Fp; N_ASSETS + 2]>,
Expand Down
4 changes: 2 additions & 2 deletions zk_prover/src/merkle_sum_tree/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ mod test {
[35479.to_biguint().unwrap(), 35479.to_biguint().unwrap()],
)
.unwrap();
let invalid_leaf = invalid_entry.compute_leaf();
let invalid_entry = invalid_entry;
let mut proof_invalid_1 = proof.clone();
proof_invalid_1.leaf = invalid_leaf;
proof_invalid_1.entry = invalid_entry;
assert!(!merkle_tree.verify_proof(&proof_invalid_1));

// shouldn't verify a proof with a wrong root hash
Expand Down
8 changes: 4 additions & 4 deletions zk_prover/src/merkle_sum_tree/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,6 @@ pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
let mut path_indices = vec![Fp::zero(); depth];
let mut current_index = index;

let leaf = &nodes[0][index];

for level in 0..depth {
let position = current_index % 2;
let sibling_index = current_index - position + (1 - position);
Expand All @@ -129,8 +127,10 @@ pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
current_index /= 2;
}

let entry = self.get_entry(index).clone();

Ok(MerkleProof {
leaf: leaf.clone(),
entry,
root: root.clone(),
sibling_leaf_node_hash_preimage,
sibling_middle_node_hash_preimages,
Expand All @@ -144,7 +144,7 @@ pub trait Tree<const N_ASSETS: usize, const N_BYTES: usize> {
[usize; N_ASSETS + 1]: Sized,
[usize; N_ASSETS + 2]: Sized,
{
let mut node = proof.leaf.clone();
let mut node = proof.entry.compute_leaf();

let sibling_leaf_node_balances = proof.sibling_leaf_node_hash_preimage[1..]
.try_into()
Expand Down

0 comments on commit eda4ad2

Please sign in to comment.