Skip to content

Commit

Permalink
feat: allow setting chain ID in tester [APE-1253] (ApeWorX#1578)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Aug 2, 2023
1 parent fab0855 commit 04a6361
Show file tree
Hide file tree
Showing 14 changed files with 152 additions and 113 deletions.
13 changes: 9 additions & 4 deletions src/ape/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from evm_trace import TraceFrame as EvmTraceFrame
from pydantic import Field, root_validator, validator
from web3 import Web3
from web3.exceptions import BlockNotFound
from web3.exceptions import ContractLogicError as Web3ContractLogicError
from web3.exceptions import MethodUnavailable, TimeExhausted, TransactionNotFound
from web3.types import RPCEndpoint, TxParams
Expand Down Expand Up @@ -877,7 +876,7 @@ def get_block(self, block_id: BlockID) -> BlockAPI:

try:
block_data = dict(self.web3.eth.get_block(block_id))
except BlockNotFound as err:
except Exception as err:
raise BlockNotFoundError(block_id) from err

# Some nodes (like anvil) will not have a base fee if set to 0.
Expand Down Expand Up @@ -1590,8 +1589,14 @@ def connect(self):
if self.is_connected:
raise ProviderError("Cannot connect twice. Call disconnect before connecting again.")

# Register atexit handler to make sure disconnect is called for normal object lifecycle.
atexit.register(self.disconnect)
# Always disconnect after,
# unless running tests with `disconnect_providers_after: false`.
disconnect_after = (
self._test_runner is None
or self.config_manager.get_config("test").disconnect_provider_after
)
if disconnect_after:
atexit.register(self.disconnect)

# Register handlers to ensure atexit handlers are called when Python dies.
def _signal_handler(signum, frame):
Expand Down
16 changes: 12 additions & 4 deletions src/ape/pytest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from ape.api import ReceiptAPI, TestAccountAPI
from ape.exceptions import ChainError
from ape.exceptions import BlockNotFoundError, ChainError
from ape.logging import logger
from ape.managers.chain import ChainManager
from ape.managers.networks import NetworkManager
Expand Down Expand Up @@ -88,15 +88,23 @@ def _isolation(self) -> Iterator[None]:
When tracing support is available, will also assist in capturing receipts.
"""

snapshot_id = self._snapshot()
try:
snapshot_id = self._snapshot()
except BlockNotFoundError:
snapshot_id = None

if self._track_transactions:
with self.receipt_capture:
try:
with self.receipt_capture:
yield

except BlockNotFoundError:
yield

else:
yield

if snapshot_id:
if snapshot_id is not None:
self._restore(snapshot_id)

# isolation fixtures
Expand Down
2 changes: 2 additions & 0 deletions src/ape/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from ape.utils.testing import (
DEFAULT_HD_PATH,
DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
DEFAULT_TEST_CHAIN_ID,
DEFAULT_TEST_MNEMONIC,
GeneratedDevAccount,
generate_dev_accounts,
Expand All @@ -65,6 +66,7 @@
"cached_property",
"DEFAULT_LOCAL_TRANSACTION_ACCEPTANCE_TIMEOUT",
"DEFAULT_NUMBER_OF_TEST_ACCOUNTS",
"DEFAULT_TEST_CHAIN_ID",
"DEFAULT_TEST_MNEMONIC",
"DEFAULT_HD_PATH",
"DEFAULT_TRANSACTION_ACCEPTANCE_TIMEOUT",
Expand Down
1 change: 1 addition & 0 deletions src/ape/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
DEFAULT_NUMBER_OF_TEST_ACCOUNTS = 10
DEFAULT_TEST_MNEMONIC = "test test test test test test test test test test test junk"
DEFAULT_HD_PATH = "m/44'/60'/0'/{}"
DEFAULT_TEST_CHAIN_ID = 1337
GeneratedDevAccount = namedtuple("GeneratedDevAccount", ("address", "private_key"))
"""
An account key-pair generated from the test mnemonic. Set the test mnemonic
Expand Down
71 changes: 37 additions & 34 deletions src/ape_geth/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
get_calltree_from_geth_trace,
get_calltree_from_parity_trace,
)
from geth import LoggingMixin # type: ignore
from geth.accounts import ensure_account_exists # type: ignore
from geth.chain import initialize_chain # type: ignore
from geth.process import BaseGethProcess # type: ignore
Expand Down Expand Up @@ -49,18 +48,20 @@
from ape.types import CallTreeNode, SnapshotID, SourceTraceback, TraceFrame
from ape.utils import (
DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
DEFAULT_TEST_CHAIN_ID,
DEFAULT_TEST_MNEMONIC,
JoinableQueue,
generate_dev_accounts,
raises_not_implemented,
spawn,
)

DEFAULT_PORT = 8545
DEFAULT_HOSTNAME = "localhost"
DEFAULT_SETTINGS = {"uri": f"http://{DEFAULT_HOSTNAME}:{DEFAULT_PORT}"}
GETH_DEV_CHAIN_ID = 1337


class GethDevProcess(LoggingMixin, BaseGethProcess):
class GethDevProcess(BaseGethProcess):
"""
A developer-configured geth that only exists until disconnected.
"""
Expand All @@ -72,18 +73,21 @@ def __init__(
port: int = DEFAULT_PORT,
mnemonic: str = DEFAULT_TEST_MNEMONIC,
number_of_accounts: int = DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
chain_id: int = GETH_DEV_CHAIN_ID,
chain_id: int = DEFAULT_TEST_CHAIN_ID,
initial_balance: Union[str, int] = to_wei(10000, "ether"),
executable: Optional[str] = None,
auto_disconnect: bool = True,
):
if not shutil.which("geth"):
executable = executable or "geth"
if not shutil.which(executable):
raise GethNotInstalledError()

self.data_dir = data_dir
self._hostname = hostname
self._port = port
self.data_dir.mkdir(exist_ok=True, parents=True)
self.is_running = False
self._auto_disconnect = auto_disconnect

geth_kwargs = construct_test_chain_kwargs(
data_dir=self.data_dir,
Expand Down Expand Up @@ -129,23 +133,9 @@ def __init__(
"alloc": {a.address: {"balance": str(initial_balance)} for a in accounts},
}

def make_logs_paths(stream_name: str):
path = data_dir / "geth-logs" / f"{stream_name}_{self._port}"
path.parent.mkdir(exist_ok=True, parents=True)
return path

initialize_chain(genesis_data, **geth_kwargs)
self.proc: Optional[Popen] = None
super().__init__(
geth_kwargs,
stdout_logfile_path=make_logs_paths("stdout"),
stderr_logfile_path=make_logs_paths("stderr"),
)

if logger.level <= LogLevel.DEBUG:
# Show process output.
self.register_stdout_callback(lambda x: logger.debug)
self.register_stderr_callback(lambda x: logger.debug)
super().__init__(geth_kwargs)

@classmethod
def from_uri(cls, uri: str, data_folder: Path, **kwargs):
Expand All @@ -164,17 +154,19 @@ def from_uri(cls, uri: str, data_folder: Path, **kwargs):
mnemonic=mnemonic,
number_of_accounts=number_of_accounts,
executable=kwargs.get("executable"),
auto_disconnect=kwargs.get("auto_disconnect", True),
)

def connect(self):
def connect(self, timeout: int = 60):
home = str(Path.home())
ipc_path = self.ipc_path.replace(home, "$HOME")
logger.info(f"Starting geth (HTTP='{self._hostname}:{self._port}', IPC={ipc_path}).")
self.start()
self.wait_for_rpc(timeout=60)
self.wait_for_rpc(timeout=timeout)

# Register atexit handler to make sure disconnect is called for normal object lifecycle.
atexit.register(self.disconnect)
if self._auto_disconnect:
atexit.register(self.disconnect)

def start(self):
if self.is_running:
Expand Down Expand Up @@ -213,7 +205,7 @@ class GethNetworkConfig(PluginConfig):
goerli: dict = DEFAULT_SETTINGS.copy()
sepolia: dict = DEFAULT_SETTINGS.copy()
# Make sure to run via `geth --dev` (or similar)
local: dict = DEFAULT_SETTINGS.copy()
local: dict = {**DEFAULT_SETTINGS.copy(), "chain_id": DEFAULT_TEST_CHAIN_ID}


class GethConfig(PluginConfig):
Expand Down Expand Up @@ -376,7 +368,7 @@ def get_call_tree(self, txn_hash: str) -> CallTreeNode:
return self._get_geth_call_tree(txn_hash)

# Parity style works.
self.can_use_parity_traces = True
self._can_use_parity_traces = True
return tree

def _get_parity_call_tree(self, txn_hash: str) -> CallTreeNode:
Expand All @@ -394,7 +386,7 @@ def _get_geth_call_tree(self, txn_hash: str) -> CallTreeNode:
return self._create_call_tree_node(evm_call, txn_hash=txn_hash)

def _log_connection(self, client_name: str):
msg = f"Connecting to existing {client_name} node at "
msg = f"Connecting to existing {client_name.strip()} node at"
suffix = (
self.ipc_path.as_posix().replace(Path.home().as_posix(), "$HOME")
if self.ipc_path.exists()
Expand Down Expand Up @@ -438,37 +430,39 @@ def process_name(self) -> str:

@property
def chain_id(self) -> int:
return GETH_DEV_CHAIN_ID
return self.geth_config.ethereum.local.get("chain_id", DEFAULT_TEST_CHAIN_ID)

@property
def data_dir(self) -> Path:
# Overridden from BaseGeth class for placing debug logs in ape data folder.
return self.geth_config.data_dir or self.data_folder / self.name

def __repr__(self):
if self._process is None:
# Exclude chain ID when not connected
try:
return f"<geth chain_id={self.chain_id}>"
except Exception:
return "<geth>"

return super().__repr__()

def connect(self):
self._set_web3()
if not self.is_connected:
self._start_geth()
self.start()
else:
self._complete_connect()

def _start_geth(self):
def start(self, timeout: int = 20):
test_config = self.config_manager.get_config("test").dict()

# Allow configuring a custom executable besides your $PATH geth.
if self.geth_config.executable is not None:
test_config["executable"] = self.geth_config.executable

test_config["ipc_path"] = self.ipc_path
test_config["auto_disconnect"] = self._test_runner is None or test_config.get(
"disconnect_providers_after", True
)
process = GethDevProcess.from_uri(self.uri, self.data_dir, **test_config)
process.connect()
process.connect(timeout=timeout)
if not self.web3.is_connected():
process.disconnect()
raise ConnectionError("Unable to connect to locally running geth.")
Expand All @@ -479,8 +473,17 @@ def _start_geth(self):

# For subprocess-provider
if self._process is not None and (process := self._process.proc):
self.stderr_queue = JoinableQueue()
self.stdout_queue = JoinableQueue()

self.process = process

# Start listening to output.
spawn(self.produce_stdout_queue)
spawn(self.produce_stderr_queue)
spawn(self.consume_stdout_queue)
spawn(self.consume_stderr_queue)

def disconnect(self):
# Must disconnect process first.
if self._process is not None:
Expand Down
7 changes: 6 additions & 1 deletion src/ape_test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ape.utils import DEFAULT_HD_PATH, DEFAULT_NUMBER_OF_TEST_ACCOUNTS, DEFAULT_TEST_MNEMONIC

from .accounts import TestAccount, TestAccountContainer
from .provider import LocalProvider
from .provider import EthTesterProviderConfig, LocalProvider


class GasExclusion(PluginConfig):
Expand Down Expand Up @@ -125,6 +125,11 @@ class Config(PluginConfig):
The hd_path to use when generating the test accounts.
"""

provider: EthTesterProviderConfig = EthTesterProviderConfig()
"""
Settings for the provider.
"""


@plugins.register(plugins.Config)
def config_class():
Expand Down
39 changes: 24 additions & 15 deletions src/ape_test/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
from eth_utils import is_0x_prefixed
from eth_utils.exceptions import ValidationError
from ethpm_types import HexBytes
from lazyasd import LazyObject # type: ignore
from web3 import EthereumTesterProvider, Web3
from web3.exceptions import ContractPanicError
from web3.providers.eth_tester.defaults import API_ENDPOINTS
from web3.providers.eth_tester.defaults import API_ENDPOINTS, static_return
from web3.types import TxParams

from ape.api import ReceiptAPI, TestProviderAPI, TransactionAPI, Web3Provider
from ape.api import PluginConfig, ReceiptAPI, TestProviderAPI, TransactionAPI, Web3Provider
from ape.exceptions import (
ContractLogicError,
ProviderError,
Expand All @@ -24,9 +23,11 @@
VirtualMachineError,
)
from ape.types import SnapshotID
from ape.utils import gas_estimation_error_message
from ape.utils import DEFAULT_TEST_CHAIN_ID, gas_estimation_error_message

CHAIN_ID = LazyObject(lambda: API_ENDPOINTS["eth"]["chainId"](), globals(), "CHAIN_ID")

class EthTesterProviderConfig(PluginConfig):
chain_id: int = DEFAULT_TEST_CHAIN_ID


class LocalProvider(TestProviderAPI, Web3Provider):
Expand All @@ -44,14 +45,21 @@ def evm_backend(self) -> PyEVMBackend:
return self._evm_backend

def connect(self):
chain_id = self.provider_settings.get("chain_id", self.config.provider.chain_id)
if self._web3 is not None:
return
connected_chain_id = self.chain_id
if connected_chain_id == chain_id:
# Is already connected and settings have not changed.
return

self._evm_backend = PyEVMBackend.from_mnemonic(
mnemonic=self.config["mnemonic"],
num_accounts=self.config["number_of_accounts"],
)
self._web3 = Web3(EthereumTesterProvider(ethereum_tester=self._evm_backend))
endpoints = {**API_ENDPOINTS}
endpoints["eth"]["chainId"] = static_return(chain_id)
tester = EthereumTesterProvider(ethereum_tester=self._evm_backend, api_endpoints=endpoints)
self._web3 = Web3(tester)

def disconnect(self):
self.cached_chain_id = None
Expand Down Expand Up @@ -101,15 +109,16 @@ def estimate_gas_cost(self, txn: TransactionAPI, **kwargs) -> int:

@property
def chain_id(self) -> int:
if self.cached_chain_id is not None:
return self.cached_chain_id
elif hasattr(self.web3, "eth"):
chain_id = self.web3.eth.chain_id
else:
chain_id = CHAIN_ID # type: ignore
try:
if self.cached_chain_id:
return self.cached_chain_id

result = self._make_request("eth_chainId", [])
self.cached_chain_id = result
return result

self.cached_chain_id = chain_id
return chain_id
except ProviderNotConnectedError:
return self.provider_settings.get("chain_id", self.config.provider.chain_id)

@property
def gas_price(self) -> int:
Expand Down
Loading

0 comments on commit 04a6361

Please sign in to comment.