Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mpt fix #1

Merged
merged 4 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 68 additions & 33 deletions lib/mpt.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ from lib.rlp_little import (
assert_subset_in_key,
extract_nibble_from_key,
)
from lib.utils import felt_divmod, felt_divmod_8, word_reverse_endian_64, get_felt_bitlength
from lib.utils import felt_divmod, felt_divmod_8, word_reverse_endian_64, get_felt_bitlength_128

// Verify a Merkle Patricia Tree proof.
// params:
Expand All @@ -37,11 +37,12 @@ func verify_mpt_proof{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr:
pow2_array: felt*,
) -> (value: felt*, value_len: felt) {
alloc_locals;
%{ print(f"\n\nNode index {ids.node_index+1}/{ids.mpt_proof_len}") %}
%{ print(f"\n\nNode index {ids.node_index+1}/{ids.mpt_proof_len} \n \t {ids.n_nibbles_already_checked=}") %}
if (node_index == mpt_proof_len - 1) {
// Last node : item of interest is the value.
// Check that the hash of the last node is the expected one.
// Check that the final accumulated key is the expected one.
// Check the number of bytes in the key is equal to the number of bytes checked in the key.
let (node_hash: Uint256) = keccak(mpt_proof[node_index], mpt_proof_bytes_len[node_index]);
%{ print(f"node_hash : {hex(ids.node_hash.low + 2**128*ids.node_hash.high)}") %}
%{ print(f"hash_to_assert : {hex(ids.hash_to_assert.low + 2**128*ids.hash_to_assert.high)}") %}
Expand All @@ -56,7 +57,37 @@ func verify_mpt_proof{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr:
key_little=key_little,
n_nibbles_already_checked=n_nibbles_already_checked,
);
local key_bits;
with pow2_array {
if (key_little.high != 0) {
let key_bit_high = get_felt_bitlength_128(key_little.high);
assert key_bits = 128 + key_bit_high;
} else {
let key_bit_low = get_felt_bitlength_128(key_little.low);
assert key_bits = key_bit_low;
}
}
local n_bytes_in_key;
let (n_bytes_in_key_tmp, rem) = felt_divmod_8(key_bits);

if (n_bytes_in_key_tmp == 0) {
assert n_bytes_in_key = 1;
} else {
if (rem != 0) {
assert n_bytes_in_key = n_bytes_in_key_tmp + 1;
} else {
assert n_bytes_in_key = n_bytes_in_key_tmp;
}
}

local n_bytes_checked;
let (n_bytes_checked_tmp, rem) = felt_divmod(n_nibbles_checked, 2);
if (rem != 0) {
assert n_bytes_checked = n_bytes_checked_tmp + 1;
} else {
assert n_bytes_checked = n_bytes_checked_tmp;
}
assert n_bytes_in_key = n_bytes_checked;
return (item_of_interest, item_of_interest_len);
} else {
// Not last node : item of interest is the hash of the next node.
Expand Down Expand Up @@ -284,11 +315,11 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
// Ensure first_item_type is either 0 or 1.
assert (first_item_type - 1) * (first_item_type) = 0;

let first_item_prefix = extract_nibble_at_byte_pos(
let first_item_key_prefix = extract_nibble_at_byte_pos(
rlp[0], first_item_start_offset + first_item_type, 0, pow2_array
);
%{
prefix = ids.first_item_prefix
prefix = ids.first_item_key_prefix
if prefix == 0:
print("First item is an extension node, even number of nibbles")
elif prefix == 1:
Expand All @@ -301,10 +332,10 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
raise Exception(f"Unknown prefix {prefix} for MPT node with 2 items")
%}
local odd: felt;
if (first_item_prefix == 0) {
if (first_item_key_prefix == 0) {
assert odd = 0;
} else {
if (first_item_prefix == 2) {
if (first_item_key_prefix == 2) {
assert odd = 0;
} else {
// 1 & 3 case.
Expand All @@ -328,9 +359,13 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
);
%{ print(f"nibbles already checked: {ids.n_nibbles_already_checked}") %}

local range_check_ptr_f;
local bitwise_ptr_f: BitwiseBuiltin*;
local n_nibbles_already_checked_f;
local pow2_array_f: felt*;
if (first_item_type != 0) {
// If the first item is not a single byte, verify subset in key.
assert_subset_in_key(
let (n_nibbles_asserted) = assert_subset_in_key(
key_subset=extracted_key_subset,
key_subset_len=extracted_key_subset_len,
key_subset_nibble_len=n_nibbles_in_first_item,
Expand All @@ -339,36 +374,36 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
cut_nibble=odd,
pow2_array=pow2_array,
);
tempvar range_check_ptr = range_check_ptr;
tempvar bitwise_ptr = bitwise_ptr;
tempvar pow2_array = pow2_array;
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
assert n_nibbles_already_checked_f = n_nibbles_already_checked + n_nibbles_asserted;
assert pow2_array_f = pow2_array;
} else {
// if the first item is a single byte, skip subset verification and assert n_nibbles_already_checked == n_nibbles_in_key
local key_bits;
with pow2_array {
if (key_little.high != 0) {
let key_bit_high = get_felt_bitlength(key_little.high);
assert key_bits = 128 + key_bit_high;
} else {
let key_bit_low = get_felt_bitlength(key_little.low);
assert key_bits = key_bit_low;
}
}
local n_nibbles_in_key: felt; // <=> ceil(key_bits/4)
let (n_nibbles_in_key_tmp, remainder) = felt_divmod(key_bits, 4);
if (remainder != 0) {
assert n_nibbles_in_key = n_nibbles_in_key_tmp + 1;
// if the first item is a single byte

if (odd != 0) {
// If the first item has an odd number of nibbles, since there are two nibbles in one byte, the second nibble needs to be checked
let key_nibble = extract_nibble_from_key(
key_little, n_nibbles_already_checked, pow2_array
);
let (_, first_item_nibble) = felt_divmod(first_item_prefix, 2 ** 4);
assert key_nibble = first_item_nibble;
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
assert n_nibbles_already_checked_f = n_nibbles_already_checked + 1;
assert pow2_array_f = pow2_array;
} else {
assert n_nibbles_in_key = n_nibbles_in_key_tmp;
// If the first item has en even number of nibbles, since there are two nibbles in one byte, there is nothing to check.
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
assert n_nibbles_already_checked_f = n_nibbles_already_checked;
assert pow2_array_f = pow2_array;
}
assert n_nibbles_in_key = n_nibbles_already_checked;
tempvar range_check_ptr = range_check_ptr;
tempvar bitwise_ptr = bitwise_ptr;
tempvar pow2_array = pow2_array;
}
let range_check_ptr = range_check_ptr;
let bitwise_ptr = bitwise_ptr;
let pow2_array = pow2_array;
let range_check_ptr = range_check_ptr_f;
let bitwise_ptr = bitwise_ptr_f;
let pow2_array = pow2_array_f;
let n_nibbles_already_checked = n_nibbles_already_checked_f;

// Extract the hash or value.

Expand Down
33 changes: 26 additions & 7 deletions lib/rlp_little.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ from lib.utils import (
get_0xff_mask,
word_reverse_endian_64,
bitwise_divmod,
get_felt_bitlength_128,
)

// Takes a 64 bit word in little endian, returns the byte at a given position as it would be in big endian.
Expand Down Expand Up @@ -79,10 +80,11 @@ func key_subset_to_uint256(key_subset: felt*, key_subset_len: felt) -> Uint256 {
// params:
// key_subset : array of 64 bit words with little endian bytes, representing a subset of the key
// key_subset_len : length of the subset in number of 64 bit words
// key_subset_bytes_len : length of the subset in number of bytes
// key_subset_bytes_len : length of the subset in number of nibbles
// key subset is of the form [b7 b6 b5 b4 b3 b2 b1 b0, b15 b14 b13 b12 b11 b10 b9 b8, ...]
// key_little : 256 bit key in little endian
// key_little is of the form high = [b63, ..., b32] , low = [b31, ..., b0]
// returns the actual number of nibbles checked from key_subset within the actual key_little
func assert_subset_in_key{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
key_subset: felt*,
key_subset_len: felt,
Expand All @@ -91,7 +93,7 @@ func assert_subset_in_key{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
n_nibbles_already_checked: felt,
cut_nibble: felt,
pow2_array: felt*,
) -> () {
) -> (n_nibbles_checked: felt) {
alloc_locals;
let key_subset_256t = key_subset_to_uint256(key_subset, key_subset_len);
%{ print(f"key_susbet_uncut={hex(ids.key_subset_256t.low + ids.key_subset_256t.high*2**128)}") %}
Expand Down Expand Up @@ -122,8 +124,8 @@ func assert_subset_in_key{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
local key_shifted: Uint256;
local key_shifted_last_nibble: felt;
if (odd_checked_nibbles != 0) {
let (upow) = uint256_pow2(Uint256((n_nibbles_already_checked + 1) * 4, 0));
let (key_shiftedt, rem) = uint256_unsigned_div_rem(key_little, upow);
let (upow) = uint256_pow2(Uint256((n_nibbles_already_checked + 1) * 4, 0)); // p = 2**(n_nib_checked+1)
let (key_shiftedt, rem) = uint256_unsigned_div_rem(key_little, upow); //
let (upow_) = uint256_pow2(Uint256((n_nibbles_already_checked - 1) * 4, 0));
let (byte_u256, _) = uint256_unsigned_div_rem(rem, upow_);
let (_, nibble) = felt_divmod(byte_u256.low, 2 ** 4);
Expand Down Expand Up @@ -156,17 +158,34 @@ func assert_subset_in_key{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
print(f"\t final key high : {hex(ids.key_high)}")
print(f"\t key subset high : {hex(ids.key_subset_256.high)}")
%}

local key_subset_nibbles;
let key_subset_bits = get_felt_bitlength_128{pow2_array=pow2_array}(key_subset_256.high);
let (key_subset_nibbles_tmp, remainder) = felt_divmod(128 + key_subset_bits, 4);
if (remainder != 0) {
assert key_subset_nibbles = key_subset_nibbles_tmp + 1;
} else {
assert key_subset_nibbles = key_subset_nibbles_tmp;
}
assert key_subset_256.low = key_shifted.low;
assert key_subset_256.high = key_high;
assert key_subset_last_nibble = key_shifted_last_nibble;
return ();
return (n_nibbles_checked=key_subset_nibbles + cut_nibble);
} else {
let (_, key_low) = felt_divmod(key_shifted.low, pow2_array[4 * key_subset_nibble_len]);
assert key_subset_256.low = key_low;
assert key_subset_256.high = 0;
assert key_subset_last_nibble = key_shifted_last_nibble;
return ();
local key_subset_nibbles;
let key_subset_bits = get_felt_bitlength_128{pow2_array=pow2_array}(key_subset_256.low);
let (key_subset_nibbles_tmp, remainder) = felt_divmod(key_subset_bits, 4);

if (remainder != 0) {
assert key_subset_nibbles = key_subset_nibbles_tmp + 1;
} else {
assert key_subset_nibbles = key_subset_nibbles_tmp;
}

return (n_nibbles_checked=key_subset_nibbles + cut_nibble);
}
}

Expand Down
34 changes: 34 additions & 0 deletions lib/utils.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,40 @@ func get_felt_bitlength{range_check_ptr, pow2_array: felt*}(x: felt) -> felt {
return bit_length;
}

// Returns the number of bits in x.
// Implicits arguments:
// - pow2_array: felt* - A pointer such that pow2_array[i] = 2^i for i in [0, 128].
// Params:
// - x: felt - Input value.
// Assumptions for the caller:
// - 1 <= x < 2^128
// Returns:
// - bit_length: felt - Number of bits in x.
func get_felt_bitlength_128{range_check_ptr, pow2_array: felt*}(x: felt) -> felt {
alloc_locals;
local bit_length;
%{
x = ids.x
ids.bit_length = x.bit_length()
%}
if (bit_length == 128) {
assert [range_check_ptr] = x - 2 ** 127;
tempvar range_check_ptr = range_check_ptr + 1;
return bit_length;
} else {
// Computes N=2^bit_length and n=2^(bit_length-1)
// x is supposed to verify n = 2^(b-1) <= x < N = 2^bit_length <=> x has bit_length bits
tempvar N = pow2_array[bit_length];
tempvar n = pow2_array[bit_length - 1];
assert [range_check_ptr] = bit_length;
assert [range_check_ptr + 1] = 128 - bit_length;
assert [range_check_ptr + 2] = N - x - 1;
assert [range_check_ptr + 3] = x - n;
tempvar range_check_ptr = range_check_ptr + 4;
return bit_length;
}
}

// Computes x//y and x%y.
// Assumption: y must be a power of 2
// params:
Expand Down
Loading