Skip to content

Commit

Permalink
fix: improve test case inputs in selector table fuzz (#3625)
Browse files Browse the repository at this point in the history
this commit improves the fuzz examples for the selector table. the
nested `@given` tests too many "dumb" examples (ex. 0, 1, max_value)
when `max_examples` is not large enough. the nested `@given` strategy
can find falsifying inputs, but it requires the inner `max_examples` to
be set much higher, and the shrinking takes much longer. this setting
of `max_examples=125` with a single `@given` using the `@composite`
strategy in this commit finds the selector table bug (that was fixed in
823675a) after an average of 3 runs.
  • Loading branch information
charles-cooper authored Sep 28, 2023
1 parent aecd911 commit 4281780
Showing 1 changed file with 52 additions and 46 deletions.
98 changes: 52 additions & 46 deletions tests/parser/test_selector_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,66 +478,72 @@ def test_dense_jumptable_bucket_size(n_methods, seed):
assert n_buckets / n < 0.4 or n < 10


@st.composite
def generate_methods(draw, max_calldata_bytes):
max_default_args = draw(st.integers(min_value=0, max_value=4))
default_fn_mutability = draw(st.sampled_from(["", "@pure", "@view", "@nonpayable", "@payable"]))

return (
max_default_args,
default_fn_mutability,
draw(
st.lists(
st.tuples(
# function id:
st.integers(min_value=0),
# mutability:
st.sampled_from(["@pure", "@view", "@nonpayable", "@payable"]),
# n calldata words:
st.integers(min_value=0, max_value=max_calldata_bytes // 32),
# n bytes to strip from calldata
st.integers(min_value=1, max_value=4),
# n default args
st.integers(min_value=0, max_value=max_default_args),
),
unique_by=lambda x: x[0],
min_size=1,
max_size=100,
)
),
)


@pytest.mark.parametrize("opt_level", list(OptimizationLevel))
# dense selector table packing boundaries at 256 and 65336
@pytest.mark.parametrize("max_calldata_bytes", [255, 256, 65336])
@settings(max_examples=5, deadline=None)
@given(
seed=st.integers(min_value=0, max_value=2**64 - 1),
max_default_args=st.integers(min_value=0, max_value=4),
default_fn_mutability=st.sampled_from(["", "@pure", "@view", "@nonpayable", "@payable"]),
)
@pytest.mark.fuzzing
def test_selector_table_fuzz(
max_calldata_bytes,
seed,
max_default_args,
opt_level,
default_fn_mutability,
w3,
get_contract,
assert_tx_failed,
get_logs,
max_calldata_bytes, opt_level, w3, get_contract, assert_tx_failed, get_logs
):
def abi_sig(calldata_words, i, n_default_args):
args = [] if not calldata_words else [f"uint256[{calldata_words}]"]
args.extend(["uint256"] * n_default_args)
argstr = ",".join(args)
return f"foo{seed + i}({argstr})"
def abi_sig(func_id, calldata_words, n_default_args):
params = [] if not calldata_words else [f"uint256[{calldata_words}]"]
params.extend(["uint256"] * n_default_args)
paramstr = ",".join(params)
return f"foo{func_id}({paramstr})"

def generate_func_def(mutability, calldata_words, i, n_default_args):
def generate_func_def(func_id, mutability, calldata_words, n_default_args):
arglist = [] if not calldata_words else [f"x: uint256[{calldata_words}]"]
for j in range(n_default_args):
arglist.append(f"x{j}: uint256 = 0")
args = ", ".join(arglist)
_log_return = f"log _Return({i})" if mutability == "@payable" else ""
_log_return = f"log _Return({func_id})" if mutability == "@payable" else ""

return f"""
@external
{mutability}
def foo{seed + i}({args}) -> uint256:
def foo{func_id}({args}) -> uint256:
{_log_return}
return {i}
return {func_id}
"""

@given(
methods=st.lists(
st.tuples(
st.sampled_from(["@pure", "@view", "@nonpayable", "@payable"]),
st.integers(min_value=0, max_value=max_calldata_bytes // 32),
# n bytes to strip from calldata
st.integers(min_value=1, max_value=4),
# n default args
st.integers(min_value=0, max_value=max_default_args),
),
min_size=1,
max_size=100,
)
)
@settings(max_examples=25)
def _test(methods):
@given(_input=generate_methods(max_calldata_bytes))
@settings(max_examples=125, deadline=None)
def _test(_input):
max_default_args, default_fn_mutability, methods = _input

func_defs = "\n".join(
generate_func_def(m, s, i, d) for i, (m, s, _, d) in enumerate(methods)
generate_func_def(func_id, mutability, calldata_words, n_default_args)
for (func_id, mutability, calldata_words, _, n_default_args) in (methods)
)

if default_fn_mutability == "":
Expand Down Expand Up @@ -571,26 +577,26 @@ def __default__():

c = get_contract(code, override_opt_level=opt_level)

for i, (mutability, n_calldata_words, n_strip_bytes, n_default_args) in enumerate(methods):
funcname = f"foo{seed + i}"
for func_id, mutability, n_calldata_words, n_strip_bytes, n_default_args in methods:
funcname = f"foo{func_id}"
func = getattr(c, funcname)

for j in range(n_default_args + 1):
args = [[1] * n_calldata_words] if n_calldata_words else []
args.extend([1] * j)

# check the function returns as expected
assert func(*args) == i
assert func(*args) == func_id

method_id = utils.method_id(abi_sig(n_calldata_words, i, j))
method_id = utils.method_id(abi_sig(func_id, n_calldata_words, j))

argsdata = b"\x00" * (n_calldata_words * 32 + j * 32)

# do payable check
if mutability == "@payable":
tx = func(*args, transact={"value": 1})
(event,) = get_logs(tx, c, "_Return")
assert event.args.val == i
assert event.args.val == func_id
else:
hexstr = (method_id + argsdata).hex()
txdata = {"to": c.address, "data": hexstr, "value": 1}
Expand Down

0 comments on commit 4281780

Please sign in to comment.