Skip to content

Commit

Permalink
Add shortcut state property to env
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielSchiavini committed Aug 13, 2024
1 parent 0b8860f commit 673d5e4
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 30 deletions.
4 changes: 4 additions & 0 deletions boa/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ def _hook_trace_computation(self, computation, contract=None):
child_contract = self._lookup_contract_fast(child.msg.code_address)
self._hook_trace_computation(child, child_contract)

@property
def state(self):
return self.evm.state

def get_code(self, address: _AddressType) -> bytes:
return self.evm.get_code(Address(address))

Expand Down
60 changes: 32 additions & 28 deletions boa/vm/py_evm.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,82 +368,86 @@ def _init_vm(self, account_db_class=AccountDB):

c: Type[titanoboa_computation] = type(
"TitanoboaComputation",
(titanoboa_computation, self.vm.state.computation_class),
(titanoboa_computation, self.state.computation_class),
{"env": self.env},
)

if self._fast_mode_enabled:
patch_pyevm_state_object(self.vm.state)
patch_pyevm_state_object(self.state)

self.vm.state.computation_class = c
self.state.computation_class = c

# patch in tracing opcodes
c.opcodes[0x20] = Sha3PreimageTracer(c.opcodes[0x20], self.env)
c.opcodes[0x55] = SstoreTracer(c.opcodes[0x55], self.env)

def enable_fast_mode(self, flag: bool = True):
if flag:
patch_pyevm_state_object(self.vm.state)
patch_pyevm_state_object(self.state)
else:
unpatch_pyevm_state_object(self.vm.state)
unpatch_pyevm_state_object(self.state)

def fork_rpc(self, rpc: RPC, block_identifier: str, **kwargs):
account_db_class = AccountDBFork.class_from_rpc(rpc, block_identifier, **kwargs)
self._init_vm(account_db_class)
block_info = self.vm.state._account_db._block_info
block_info = self.state._account_db._block_info

self.patch.timestamp = int(block_info["timestamp"], 16)
self.patch.block_number = int(block_info["number"], 16)
self.patch.chain_id = int(rpc.fetch("eth_chainId", []), 16)

self.vm.state._account_db._rpc._init_db()
self.state._account_db._rpc._init_db()

@property
def is_forked(self):
return issubclass(
self.vm.__class__._state_class.account_db_class, AccountDBFork
)

@property
def state(self):
return self.vm.state

def get_gas_meter_class(self):
return self.vm.state.computation_class._gas_meter_class
return self.state.computation_class._gas_meter_class

def set_gas_meter_class(self, cls: type):
self.vm.state.computation_class._gas_meter_class = cls
self.state.computation_class._gas_meter_class = cls

def get_balance(self, address: Address):
return self.vm.state.get_balance(address.canonical_address)
return self.state.get_balance(address.canonical_address)

def set_balance(self, address: Address, value):
self.vm.state.set_balance(address.canonical_address, value)
self.state.set_balance(address.canonical_address, value)

def get_code(self, address: Address) -> bytes:
return self.vm.state.get_code(address.canonical_address)
return self.state.get_code(address.canonical_address)

def set_code(self, address: Address, code: bytes) -> None:
self.vm.state.set_code(address.canonical_address, code)
self.state.set_code(address.canonical_address, code)

def get_storage(self, address: Address, slot: int) -> int:
return self.vm.state.get_storage(address.canonical_address, slot)
return self.state.get_storage(address.canonical_address, slot)

def set_storage(self, address: Address, slot: int, value: int) -> None:
self.vm.state.set_storage(address.canonical_address, slot, value)
self.state.set_storage(address.canonical_address, slot, value)

def get_gas_limit(self):
return self.vm.state.gas_limit
return self.state.gas_limit

# advanced: reset warm/cold counters for addresses and storage
def reset_access_counters(self):
self.vm.state._account_db._reset_access_counters()
self.state._account_db._reset_access_counters()

def snapshot(self) -> Any:
return self.vm.state.snapshot()
return self.state.snapshot()

def revert(self, snapshot_id: Any) -> None:
self.vm.state.revert(snapshot_id)
self.state.revert(snapshot_id)

def generate_create_address(self, sender: Address):
nonce = self.vm.state.get_nonce(sender.canonical_address)
self.vm.state.increment_nonce(sender.canonical_address)
nonce = self.state.get_nonce(sender.canonical_address)
self.state.increment_nonce(sender.canonical_address)
return Address(generate_contract_address(sender.canonical_address, nonce))

def deploy_code(
Expand All @@ -457,7 +461,7 @@ def deploy_code(
bytecode: bytes,
):
if gas is None:
gas = self.vm.state.gas_limit
gas = self.state.gas_limit

msg = Message(
to=constants.CREATE_CONTRACT_ADDRESS, # i.e., b""
Expand All @@ -470,13 +474,13 @@ def deploy_code(
)

if self.is_forked and self._fork_try_prefetch_state:
self.vm.state._account_db.try_prefetch_state(msg)
self.state._account_db.try_prefetch_state(msg)

tx_ctx = BaseTransactionContext(
origin=origin.canonical_address, gas_price=gas_price
)
return self.vm.state.computation_class.apply_create_message(
self.vm.state, msg, tx_ctx
return self.state.computation_class.apply_create_message(
self.state, msg, tx_ctx
)

def execute_code(
Expand Down Expand Up @@ -509,14 +513,14 @@ def execute_code(
)

if self.is_forked and self._fork_try_prefetch_state:
self.vm.state._account_db.try_prefetch_state(msg)
self.state._account_db.try_prefetch_state(msg)

origin = sender.canonical_address # XXX: consider making this parameterizable
tx_ctx = BaseTransactionContext(origin=origin, gas_price=gas_price)
return self.vm.state.computation_class.apply_message(self.vm.state, msg, tx_ctx)
return self.state.computation_class.apply_message(self.state, msg, tx_ctx)

def get_storage_slot(self, address: Address, slot: int) -> bytes:
data = self.vm.state._account_db.get_storage(address.canonical_address, slot)
data = self.state._account_db.get_storage(address.canonical_address, slot)
return data.to_bytes(32, "big")


Expand Down
4 changes: 2 additions & 2 deletions tests/integration/fork/test_from_etherscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def test_prefetch_state(rpc_url, fresh_env, crvusd_contract):
code=crvusd_contract._bytecode,
data=crvusd_contract.burn.prepare_calldata(0),
)
state = env.evm.vm.state
db = state._account_db

db = env.state._account_db
db.try_prefetch_state(msg)

# patch the RPC, so we make sure to use the cache
Expand Down

0 comments on commit 673d5e4

Please sign in to comment.