Skip to content

Commit

Permalink
Count exact number of leading zeroes in assert_subset_in_key.
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime committed May 6, 2024
1 parent 4b4def1 commit f37ebf8
Showing 1 changed file with 185 additions and 57 deletions.
242 changes: 185 additions & 57 deletions lib/rlp_little.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from lib.utils import (
get_uint256_bit_length,
get_felt_n_nibbles,
n_bits_to_n_nibbles,
count_trailing_zeroes_128,
)

func n_nibbles_in_key{range_check_ptr}(key: Uint256, pow2_array: felt*) -> (res: felt) {
Expand All @@ -28,6 +29,127 @@ func n_nibbles_in_key{range_check_ptr}(key: Uint256, pow2_array: felt*) -> (res:
return (res=num_nibbles_in_key);
}

// From a Uint256 number in little endian bytes representation,
// predict the number of leading zeroes nibbles before the number is converted to BE representation.
// Parameters:
// - x: the liitle endian representation of the number.
// - n_nibbles_after_reversion: the fixed # of nibbles in the number after reversion. This is known from RLP decoding.
// - cut_nibble: if 1, takes into account that the leftmost nibble in BE representation will be cut.
// - pow2_array: array of powers of 2.
// Example 1:
// LE input : 0x ab 0d 0f 00 : cut_nibble = 0
// BE reverted : 0x 00 0f 0d ab -> 3 leading zeroes.
// Example 2:
// LE input : 0x ab 0d 0f e0 : cut_nibble = 1
// BE reverted: 0x e0 0d 0f ab
// BE reverted + cutted ("e" removed) : 0x 00 d0 fa b -> 2 leading zeroes.
func count_leading_zeroes_from_uint256_le_before_reversion{bitwise_ptr: BitwiseBuiltin*}(
x: Uint256, n_nibbles_after_reversion: felt, cut_nibble: felt, pow2_array: felt*
) -> (res: felt) {
alloc_locals;
%{
from tools.py.utils import parse_int_to_bytes, count_leading_zero_nibbles_from_hex
input_ = ids.x.low + 2**128*ids.x.high
input_bytes = parse_int_to_bytes(input_)
#print(f"input hex {input_bytes.hex()}")
reversed_bytes = input_bytes[::-1]
#print("reversed bytes", reversed_bytes)
reversed_hex = reversed_bytes.hex()
#print("reversed hex", reversed_hex)
if ids.cut_nibble == 1:
reversed_hex = reversed_hex[1:]
#print(f"Reversed hex final : {reversed_hex}")
expected_leading_zeroes = count_leading_zero_nibbles_from_hex(reversed_hex)
#print(f"Expected leading zeroes {expected_leading_zeroes}")
%}
local x_f: Uint256;
local first_nibble_is_zero;
assert x_f.high = x.high;
if (cut_nibble != 0) {
assert bitwise_ptr.x = x.low;
assert bitwise_ptr.y = 0xffffffffffffffffffffffffffffff0f;
assert bitwise_ptr[1].x = x.low;
assert bitwise_ptr[1].y = 0xf;
let xf_l = bitwise_ptr.x_and_y;
assert x_f.low = xf_l;
let first_nibble = bitwise_ptr[1].x_and_y;
if (first_nibble == 0) {
assert first_nibble_is_zero = 1;
tempvar bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE;
} else {
assert first_nibble_is_zero = 0;
tempvar bitwise_ptr = bitwise_ptr + 2 * BitwiseBuiltin.SIZE;
}
} else {
assert x_f.low = x.low;
assert bitwise_ptr.x = x.low;
assert bitwise_ptr.y = 0xf0;
let first_nibble = bitwise_ptr.x_and_y / 2 ** 4;
if (first_nibble == 0) {
assert first_nibble_is_zero = 1;
tempvar bitwise_ptr = bitwise_ptr + BitwiseBuiltin.SIZE;
} else {
assert first_nibble_is_zero = 0;
tempvar bitwise_ptr = bitwise_ptr + BitwiseBuiltin.SIZE;
}
}
let (trailing_zeroes_low) = count_trailing_zeroes_128(x_f.low, pow2_array);
if (trailing_zeroes_low == 16) {
// The low part if full of zeroes bytes.
// Need to analyze the high part.
let (trailing_zeroes_high) = count_trailing_zeroes_128(x_f.high, pow2_array);
if (trailing_zeroes_high == 16) {
// The high part is also full of zeroes bytes.
// The number of leading zeroes is then precisely the number of nibbles after reversion.
return (res=n_nibbles_after_reversion);
} else {
// Need to analyse the first nibble after reversion.
let first_non_zero_byte = extract_byte_at_pos(
x_f.high, trailing_zeroes_high, pow2_array
);
let (first_nibble_after_reversion, _) = bitwise_divmod(first_non_zero_byte, 2 ** 4);
if (first_nibble_after_reversion == 0) {
let res = 32 + 2 * trailing_zeroes_high - cut_nibble + 1;
%{ assert ids.res == expected_leading_zeroes, f"Expected {expected_leading_zeroes} but got {ids.res}" %}
return (res=res);
} else {
let res = 32 + 2 * trailing_zeroes_high - cut_nibble;
%{ assert ids.res == expected_leading_zeroes, f"Expected {expected_leading_zeroes} but got {ids.res}" %}

return (res=res);
}
}
} else {
// Trailing zeroes bytes between [0, 15].
if (trailing_zeroes_low == 0) {
let res = first_nibble_is_zero;
%{ assert ids.res == expected_leading_zeroes, f"Expected {expected_leading_zeroes} but got {ids.res}" %}
return (res=res);
} else {
// Trailing zeroes bytes between [1, 15].

// Need to check the first nibble after reversion.
let first_non_zero_byte = extract_byte_at_pos(x_f.low, trailing_zeroes_low, pow2_array);
// %{ print(f"{hex(ids.first_non_zero_byte)=}") %}
local first_nibble_after_reversion;

let (first_nibble_after_reversion, _) = bitwise_divmod(first_non_zero_byte, 2 ** 4);
// %{ print(f"{hex(ids.first_nibble_after_reversion)=}") %}
if (first_nibble_after_reversion == 0) {
let res = 2 * trailing_zeroes_low - cut_nibble + 1;
%{ assert ids.res == expected_leading_zeroes, f"Expected {expected_leading_zeroes} but got {ids.res}" %}

return (res=res);
} else {
let res = 2 * trailing_zeroes_low - cut_nibble;
%{ assert ids.res == expected_leading_zeroes, f"Expected {expected_leading_zeroes} but got {ids.res}" %}

return (res=res);
}
}
}
}

// Takes a 64 bit word in little endian, returns the byte at a given position as it would be in big endian.
// Ie: word = b7 b6 b5 b4 b3 b2 b1 b0
// returns bi such that i = byte_position
Expand Down Expand Up @@ -124,6 +246,7 @@ func assert_subset_in_key_be{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
pow2_array: felt*,
) {
alloc_locals;

// Get the little endian 256 bit number from the extracted 64 bit le words array :
let key_subset_256_le = key_subset_to_uint256(key_subset, key_subset_len);
%{
Expand Down Expand Up @@ -165,84 +288,91 @@ func assert_subset_in_key_be{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
assert key_subset_be.high = key_subset_be_tmp.high;
assert bitwise_ptr_f = bitwise_ptr;
}
// Right pad with 0's if nibble lens don't match.
let bitwise_ptr = bitwise_ptr_f;
// %{
// key_subset_cut = hex(ids.key_subset_be.low + ids.key_subset_be.high*2**128)[2:]
// print(f"Key subset cut: {key_subset_cut}, n_nibbles={len(key_subset_cut)}")
// %}
local bitwise_ptr_f: BitwiseBuiltin*;
local range_check_ptr_f;
local key_subset_be_final: Uint256;
let (key_subset_bits) = get_uint256_bit_length(key_subset_be, pow2_array);
let (key_subset_nibbles) = n_bits_to_n_nibbles(key_subset_bits);

// Remove n_nibbles_already_checked nibbles from the left part of the key
// %{ print(f"Remove {ids.n_nibbles_already_checked} nibbles from the left part of the key") %}
let (u256_power) = uint256_pow2(
Uint256((key_be_nibbles + key_be_leading_zeroes_nibbles - n_nibbles_already_checked) * 4, 0)
);
let (_, key_shifted) = uint256_unsigned_div_rem(key_be, u256_power);
// %{ print(f"Key shifted: {hex(ids.key_shifted.low + ids.key_shifted.high*2**128)}") %}

// Remove rightmost part of the key, keep only key_subset_nibble_len nibbles on the left
// %{
// print(f"Remove rightmost part of the key, keep only {ids.key_subset_nibble_len} nibbles on the left")
// power = ids.key_be_nibbles + ids.key_be_leading_zeroes_nibbles - ids.n_nibbles_already_checked - ids.key_subset_nibble_len
// print(f"Computing 2**({power}) = {power/4} nibbles = {power/8} bytes")
// %}
let (u256_power) = uint256_pow2(
Uint256(
4 * (
key_be_nibbles +
key_be_leading_zeroes_nibbles -
n_nibbles_already_checked -
key_subset_nibble_len
),
0,
),
);
let (key_shifted, _) = uint256_unsigned_div_rem(key_shifted, u256_power);
// %{ print(f"Key shifted final: {hex(ids.key_shifted.low + ids.key_shifted.high*2**128)}") %}

if (key_subset_nibbles != key_subset_nibble_len) {
// Nibbles lens don't match.
%{ print(f"Nibbles lens don't match: {ids.key_subset_nibbles=} != {ids.key_subset_nibble_len=}") %}
// This either come from the first byte being of the form 0x0n (Nothing to do).
// Or from the last nibbles being 0. Need to right pad with 0's until the expected nibble len.
let (_, first_byte) = felt_divmod(key_subset[0], 2 ** 8);
%{ print(f"First word {hex(memory[ids.key_subset])}, first_byte={hex(ids.first_byte)}") %}
local first_nibble;
let (q, r) = felt_divmod(first_byte, 2 ** 4);
if (cut_nibble != 0) {
// If nibble needed to be cut, the first nibble is actually the second one
assert first_nibble = r;
} else {
// If nibble not needed to be cut, the first nibble is the first one
assert first_nibble = q;
}
%{ print(f"First nibble: {hex(ids.first_nibble)}") %}
if (first_nibble != 0) {
// Right pad with 0's
%{ print(f"Right pad with 0's") %}
// %{ print(f"Nibbles lens don't match: {ids.key_subset_nibbles=} != {ids.key_subset_nibble_len=}") %}
// This either come from :
// 1. the leftmost nibbles of the BE key (right most nibbles of the LE key) being 0's
// 2. the the rightmost nibbles of the BE key being 0's.
// Handle 1. : count leftfmost nibbles of the BE key from the rightmost nibbles of the LE key:
let (n_leading_zeroes_nibbles) = count_leading_zeroes_from_uint256_le_before_reversion(
key_subset_256_le, key_subset_nibble_len, cut_nibble, pow2_array
);
// %{ print(f"n_leading_zeroes_nibbles: {ids.n_leading_zeroes_nibbles}") %}
if (key_subset_nibble_len - (key_subset_nibbles + n_leading_zeroes_nibbles) != 0) {
// Handle 2. : Right pad the BE key with 0's until the expected length.
// %{ print(f"Right pad with {ids.key_subset_nibble_len - (ids.key_subset_nibbles + ids.n_leading_zeroes_nibbles)} 0's") %}
let (u256_pow) = uint256_pow2(
Uint256((key_subset_nibble_len - key_subset_nibbles) * 4, 0)
Uint256(
(key_subset_nibble_len - (key_subset_nibbles + n_leading_zeroes_nibbles)) * 4, 0
),
);
let (res_tmp, _) = uint256_mul(key_subset_be, u256_pow);
assert key_subset_be_final.low = res_tmp.low;
assert key_subset_be_final.high = res_tmp.high;
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
} else {
%{ print(f"Do nothing") %}
// If the very first nibble is 0, do nothing. Assertions will pass.
// %{ print(f"Do nothing. Nibble lens including leading zeroes match") %}
// Handle 1. Nothing to do. Nibble lens including leading zeroes match.
assert key_subset_be_final.low = key_subset_be.low;
assert key_subset_be_final.high = key_subset_be.high;
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
}
} else {
%{ print(f"Do nothing. Nibble lens match") %}
// Do nothing if the nibble lens match. Assertions will pass.
// %{ print(f"Do nothing. Nibble lens match") %}
// Do nothing if the nibble lens already match. Assertions will pass.
assert key_subset_be_final.low = key_subset_be.low;
assert key_subset_be_final.high = key_subset_be.high;
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
}
%{ print(f"Key subset final: {hex(ids.key_subset_be_final.low + ids.key_subset_be_final.high*2**128)}") %}
let bitwise_ptr = bitwise_ptr_f;
let range_check_ptr = range_check_ptr_f;
// Remove n_nibbles_already_checked nibbles from the left part of the key
%{ print(f"Remove {ids.n_nibbles_already_checked} nibbles from the left part of the key") %}
let (u256_power) = uint256_pow2(
Uint256((key_be_nibbles + key_be_leading_zeroes_nibbles - n_nibbles_already_checked) * 4, 0)
);
let (_, key_shifted) = uint256_unsigned_div_rem(key_be, u256_power);
%{ print(f"Key shifted: {hex(ids.key_shifted.low + ids.key_shifted.high*2**128)}") %}
// %{ print(f"Key subset final: {hex(ids.key_subset_be_final.low + ids.key_subset_be_final.high*2**128)}") %}

// Remove rightmost part of the key, keep only key_subset_nibble_len nibbles on the left
%{
print(f"Remove rightmost part of the key, keep only {ids.key_subset_nibble_len} nibbles on the left")
power = ids.key_be_nibbles + ids.key_be_leading_zeroes_nibbles - ids.n_nibbles_already_checked - ids.key_subset_nibble_len
print(f"Computing 2**({power}) = {power/4} nibbles = {power/8} bytes")
%}
let (u256_power) = uint256_pow2(
Uint256(
4 * (
key_be_nibbles +
key_be_leading_zeroes_nibbles -
n_nibbles_already_checked -
key_subset_nibble_len
),
0,
),
);
let (key_shifted, _) = uint256_unsigned_div_rem(key_shifted, u256_power);
%{ print(f"Key shifted final: {hex(ids.key_shifted.low + ids.key_shifted.high*2**128)}") %}
%{ print(f"key subset expect: {hex(ids.key_subset_be_final.low + ids.key_subset_be_final.high*2**128)}") %}
// %{ print(f"key subset expect: {hex(ids.key_shifted.low + ids.key_shifted.high*2**128)}") %}
assert key_subset_be_final.low = key_shifted.low;
assert key_subset_be_final.high = key_shifted.high;
return ();
Expand Down Expand Up @@ -347,13 +477,13 @@ func extract_nibble_from_key_be{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
func extract_n_bytes_at_pos{bitwise_ptr: BitwiseBuiltin*}(
word_64_little: felt, pos: felt, n: felt, pow2_array: felt*
) -> felt {
%{ print(f"extracting {ids.n} bytes at pos {ids.pos} from {hex(ids.word_64_little)}") %}
// %{ print(f"extracting {ids.n} bytes at pos {ids.pos} from {hex(ids.word_64_little)}") %}
let x_mask = get_0xff_mask(n);
%{ print(f"x_mask for len {ids.n}: {hex(ids.x_mask)}") %}
// %{ print(f"x_mask for len {ids.n}: {hex(ids.x_mask)}") %}
assert bitwise_ptr[0].x = word_64_little;
assert bitwise_ptr[0].y = x_mask * pow2_array[8 * (pos)];
tempvar res = bitwise_ptr[0].x_and_y;
%{ print(f"tmp : {hex(ids.res)}") %}
// %{ print(f"tmp : {hex(ids.res)}") %}
tempvar extracted_bytes = bitwise_ptr[0].x_and_y / pow2_array[8 * pos];
tempvar bitwise_ptr = bitwise_ptr + BitwiseBuiltin.SIZE;
return extracted_bytes;
Expand Down Expand Up @@ -529,8 +659,6 @@ func extract_n_bytes_from_le_64_chunks_array{range_check_ptr}(
return (res, n_words_handled);
}

// func extract_n_bytes_from_le_64_chunks_array_inner(array:felt*, current_word:felt, n_words_handled:felt,

func array_copy(src: felt*, dst: felt*, n: felt, index: felt) {
if (index == n) {
return ();
Expand Down

0 comments on commit f37ebf8

Please sign in to comment.