Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT MERGE] fix(cheatcode): expect revert only for calls with greater depth than test #9537

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/cheatcodes/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ pub struct CheatsConfig {
pub assertions_revert: bool,
/// Optional seed for the RNG algorithm.
pub seed: Option<U256>,
/// Whether to allow `expectRevert` to work for internal calls.
pub internal_expect_revert: bool,
}

impl CheatsConfig {
Expand Down Expand Up @@ -98,6 +100,7 @@ impl CheatsConfig {
running_version,
assertions_revert: config.assertions_revert,
seed: config.fuzz.seed,
internal_expect_revert: config.allow_internal_expect_revert,
}
}

Expand Down Expand Up @@ -239,6 +242,7 @@ impl Default for CheatsConfig {
running_version: Default::default(),
assertions_revert: true,
seed: None,
internal_expect_revert: false,
}
}
}
Expand Down
34 changes: 19 additions & 15 deletions crates/cheatcodes/src/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ use revm::{
};
use serde_json::Value;
use std::{
cmp::max,
collections::{BTreeMap, VecDeque},
fs::File,
io::BufReader,
Expand Down Expand Up @@ -758,6 +759,7 @@ where {
let handler_result = expect::handle_expect_revert(
false,
true,
self.config.internal_expect_revert,
&mut expected_revert,
outcome.result.result,
outcome.result.output.clone(),
Expand Down Expand Up @@ -1176,6 +1178,11 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes {
if self.gas_metering.paused {
self.gas_metering.paused_frames.push(interpreter.gas);
}

// `expectRevert`: track the max call depth during `expectRevert`
if let Some(ref mut expected) = self.expected_revert {
expected.max_depth = max(ecx.journaled_state.depth(), expected.max_depth);
}
}

#[inline]
Expand Down Expand Up @@ -1302,21 +1309,17 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes {

// Handle expected reverts.
if let Some(expected_revert) = &mut self.expected_revert {
// Record current reverter address before processing the expect revert if call reverted,
// expect revert is set with expected reverter address and no actual reverter set yet.
if outcome.result.is_revert() &&
expected_revert.reverter.is_some() &&
expected_revert.reverted_by.is_none()
{
expected_revert.reverted_by = Some(call.target_address);
} else if outcome.result.is_revert() &&
expected_revert.reverter.is_some() &&
expected_revert.reverted_by.is_some() &&
expected_revert.count > 1
{
// If we're expecting more than one revert, we need to reset the reverted_by address
// to latest reverter.
expected_revert.reverted_by = Some(call.target_address);
// Record current reverter address and call scheme before processing the expect revert
// if call reverted.
if outcome.result.is_revert() {
// Record current reverter address if expect revert is set with expected reverter
// address and no actual reverter was set yet or if we're expecting more than one
// revert.
if expected_revert.reverter.is_some() &&
(expected_revert.reverted_by.is_none() || expected_revert.count > 1)
{
expected_revert.reverted_by = Some(call.target_address);
}
}

if ecx.journaled_state.depth() <= expected_revert.depth {
Expand All @@ -1337,6 +1340,7 @@ impl Inspector<&mut dyn DatabaseExt> for Cheatcodes {
let handler_result = expect::handle_expect_revert(
cheatcode_call,
false,
self.config.internal_expect_revert,
&mut expected_revert,
outcome.result.result,
outcome.result.output.clone(),
Expand Down
12 changes: 12 additions & 0 deletions crates/cheatcodes/src/test/expect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ pub struct ExpectedRevert {
pub reverter: Option<Address>,
/// Actual reverter of the call.
pub reverted_by: Option<Address>,
/// Max call depth reached during next call execution.
pub max_depth: u64,
/// Number of times this revert is expected.
pub count: u64,
/// Actual number of times this revert has been seen.
Expand Down Expand Up @@ -774,6 +776,7 @@ fn expect_revert(
partial_match,
reverter,
reverted_by: None,
max_depth: depth,
count,
actual_count: 0,
});
Expand All @@ -783,6 +786,7 @@ fn expect_revert(
pub(crate) fn handle_expect_revert(
is_cheatcode: bool,
is_create: bool,
internal_expect_revert: bool,
expected_revert: &mut ExpectedRevert,
status: InstructionResult,
retdata: Bytes,
Expand All @@ -806,6 +810,14 @@ pub(crate) fn handle_expect_revert(
hex::encode_prefixed(data)
};

// Check depths if it's not an expect cheatcode call and if internal expect reverts not enabled.
if !is_cheatcode && !internal_expect_revert {
ensure!(
expected_revert.max_depth > expected_revert.depth,
"call didn't revert at a lower depth than cheatcode call depth"
);
}

if expected_revert.count == 0 {
if expected_revert.reverter.is_none() && expected_revert.reason.is_none() {
ensure!(
Expand Down
3 changes: 3 additions & 0 deletions crates/config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,8 @@ pub struct Config {
pub invariant: InvariantConfig,
/// Whether to allow ffi cheatcodes in test
pub ffi: bool,
/// Whether to allow `expectRevert` for internal functions.
pub allow_internal_expect_revert: bool,
/// Use the create 2 factory in all cases including tests and non-broadcasting scripts.
pub always_use_create_2_factory: bool,
/// Sets a timeout in seconds for vm.prompt cheatcodes
Expand Down Expand Up @@ -2310,6 +2312,7 @@ impl Default for Config {
invariant: InvariantConfig::new("cache/invariant".into()),
always_use_create_2_factory: false,
ffi: false,
allow_internal_expect_revert: false,
prompt_timeout: 120,
sender: Self::DEFAULT_SENDER,
tx_origin: Self::DEFAULT_SENDER,
Expand Down
1 change: 1 addition & 0 deletions crates/forge/tests/cli/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ forgetest!(can_extract_config_values, |prj, cmd| {
..Default::default()
},
ffi: true,
allow_internal_expect_revert: false,
always_use_create_2_factory: false,
prompt_timeout: 0,
sender: "00a329c0648769A73afAc7F9381D08FB43dBEA72".parse().unwrap(),
Expand Down
3 changes: 3 additions & 0 deletions crates/forge/tests/it/repros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,6 @@ test_repro!(8971; |config| {

// https://github.com/foundry-rs/foundry/issues/8639
test_repro!(8639);

// https://github.com/foundry-rs/foundry/issues/7238
test_repro!(7238);
1 change: 1 addition & 0 deletions crates/test-utils/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ impl ExtTester {
test_cmd.env("FOUNDRY_FORK_BLOCK_NUMBER", fork_block.to_string());
}
test_cmd.env("FOUNDRY_INVARIANT_DEPTH", "15");
test_cmd.env("FOUNDRY_ALLOW_INTERNAL_EXPECT_REVERT", "true");

test_cmd.assert_success();
}
Expand Down
3 changes: 2 additions & 1 deletion testdata/default/cheats/AttachDelegation.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ contract AttachDelegationTest is DSTest {
assertEq(token.balanceOf(bob), 200);
}

/// forge-config: default.allow_internal_expect_revert = true
function testAttachDelegationRevertInvalidSignature() public {
Vm.SignedDelegation memory signedDelegation = vm.signDelegation(address(implementation), alice_pk);
// change v from 1 to 0
Expand All @@ -109,7 +110,7 @@ contract AttachDelegationTest is DSTest {
// send tx to increment alice's nonce
token.mint(1, bob);

vm.expectRevert("vm.attachDelegation: invalid nonce");
vm._expectCheatcodeRevert("vm.attachDelegation: invalid nonce");
vm.attachDelegation(signedDelegation);
}

Expand Down
4 changes: 2 additions & 2 deletions testdata/default/cheats/BroadcastRawTransaction.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ contract BroadcastRawTransactionTest is DSTest {
assertEq(address(from).balance, balance - (gasPrice * 21_000) - amountSent);
assertEq(address(to).balance, amountSent);

vm.expectRevert();
assert(3 == 4);
vm._expectCheatcodeRevert();
vm.assertFalse(true);
}

function test_execute_multiple_signed_tx() public {
Expand Down
2 changes: 2 additions & 0 deletions testdata/default/cheats/MemSafety.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ contract MemSafetyTest is DSTest {

/// @dev Tests that expanding memory outside of the range given to `expectSafeMemory`
/// will cause the test to fail while using the `MLOAD` opcode.
/// forge-config: default.allow_internal_expect_revert = true
function testExpectSafeMemory_MLOAD_REVERT() public {
vm.expectSafeMemory(0x80, 0x100);

Expand Down Expand Up @@ -504,6 +505,7 @@ contract MemSafetyTest is DSTest {

/// @dev Tests that expanding memory outside of the range given to `expectSafeMemory`
/// will cause the test to fail while using the `LOG0` opcode.
/// forge-config: default.allow_internal_expect_revert = true
function testExpectSafeMemory_LOG0_REVERT() public {
vm.expectSafeMemory(0x80, 0x100);
vm.expectRevert();
Expand Down
56 changes: 40 additions & 16 deletions testdata/default/cheats/MockCall.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,11 @@ contract MockCallRevertTest is DSTest {

// post-mock
assertEq(target.numberA(), 1);
vm.expectRevert();
target.numberB();
try target.numberB() {
revert();
} catch (bytes memory err) {
require(keccak256(err) == keccak256(ERROR_MESSAGE));
}
}

function testMockRevertWithCustomError() public {
Expand All @@ -216,8 +219,11 @@ contract MockCallRevertTest is DSTest {
vm.mockCallRevert(address(target), abi.encodeWithSelector(target.numberB.selector), customError);

assertEq(target.numberA(), 1);
vm.expectRevert(customError);
target.numberB();
try target.numberB() {
revert();
} catch (bytes memory err) {
require(keccak256(err) == keccak256(customError));
}
}

function testMockNestedRevert() public {
Expand All @@ -228,8 +234,11 @@ contract MockCallRevertTest is DSTest {

vm.mockCallRevert(address(inner), abi.encodeWithSelector(inner.numberB.selector), ERROR_MESSAGE);

vm.expectRevert(ERROR_MESSAGE);
target.sum();
try target.sum() {
revert();
} catch (bytes memory err) {
require(keccak256(err) == keccak256(ERROR_MESSAGE));
}
}

function testMockCalldataRevert() public {
Expand All @@ -241,8 +250,11 @@ contract MockCallRevertTest is DSTest {

assertEq(target.add(6, 4), 10);

vm.expectRevert(ERROR_MESSAGE);
target.add(5, 5);
try target.add(5, 5) {
revert();
} catch (bytes memory err) {
require(keccak256(err) == keccak256(ERROR_MESSAGE));
}
}

function testClearMockRevertedCalls() public {
Expand All @@ -263,8 +275,11 @@ contract MockCallRevertTest is DSTest {

assertEq(mock.add(1, 2), 3);

vm.expectRevert(ERROR_MESSAGE);
mock.add(2, 3);
try mock.add(2, 3) {
revert();
} catch (bytes memory err) {
require(keccak256(err) == keccak256(ERROR_MESSAGE));
}
}

function testMockCallRevertWithValue() public {
Expand All @@ -275,8 +290,11 @@ contract MockCallRevertTest is DSTest {
assertEq(mock.pay(1), 1);
assertEq(mock.pay(2), 2);

vm.expectRevert(ERROR_MESSAGE);
mock.pay{value: 10}(1);
try mock.pay{value: 10}(1) {
revert();
} catch (bytes memory err) {
require(keccak256(err) == keccak256(ERROR_MESSAGE));
}
}

function testMockCallResetsMockCallRevert() public {
Expand All @@ -296,8 +314,11 @@ contract MockCallRevertTest is DSTest {

vm.mockCallRevert(address(mock), abi.encodeWithSelector(mock.add.selector), ERROR_MESSAGE);

vm.expectRevert(ERROR_MESSAGE);
mock.add(2, 3);
try mock.add(2, 3) {
revert();
} catch (bytes memory err) {
require(keccak256(err) == keccak256(ERROR_MESSAGE));
}
}

function testMockCallRevertWithCall() public {
Expand All @@ -317,7 +338,10 @@ contract MockCallRevertTest is DSTest {

vm.mockCallRevert(address(mock), abi.encodeWithSelector(mock.add.selector), ERROR_MESSAGE);

vm.expectRevert(ERROR_MESSAGE);
mock.add(1, 2);
try mock.add(2, 3) {
revert();
} catch (bytes memory err) {
require(keccak256(err) == keccak256(ERROR_MESSAGE));
}
}
}
6 changes: 3 additions & 3 deletions testdata/default/cheats/RandomCheatcodes.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ contract RandomCheatcodesTest is DSTest {
int128 constant max = 170141183460469231731687303715884105727;

function test_int128() public {
vm.expectRevert("vm.randomInt: number of bits cannot exceed 256");
vm._expectCheatcodeRevert("vm.randomInt: number of bits cannot exceed 256");
int256 val = vm.randomInt(type(uint256).max);

val = vm.randomInt(128);
Expand All @@ -31,7 +31,7 @@ contract RandomCheatcodesTest is DSTest {
}

function test_randomUintLimit() public {
vm.expectRevert("vm.randomUint: number of bits cannot exceed 256");
vm._expectCheatcodeRevert("vm.randomUint: number of bits cannot exceed 256");
uint256 val = vm.randomUint(type(uint256).max);
}

Expand Down Expand Up @@ -67,7 +67,7 @@ contract RandomBytesTest is DSTest {
}

function test_symbolic_bytes_revert() public {
vm.expectRevert();
vm._expectCheatcodeRevert();
bytes memory val = vm.randomBytes(type(uint256).max);
}

Expand Down
Loading
Loading