Skip to content

Commit

Permalink
perf: avoid duplicate account generation (ApeWorX#2192)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Jul 30, 2024
1 parent 34bee0a commit 7caaf4d
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 121 deletions.
33 changes: 24 additions & 9 deletions src/ape/api/accounts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from collections.abc import Iterator
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union

Expand All @@ -23,7 +24,7 @@
)
from ape.logging import logger
from ape.types import AddressType, MessageSignature, SignableMessage
from ape.utils import BaseInterfaceModel, abstractmethod
from ape.utils import BaseInterfaceModel, abstractmethod, raises_not_implemented

if TYPE_CHECKING:
from ape.contracts import ContractContainer, ContractInstance
Expand Down Expand Up @@ -443,11 +444,11 @@ def accounts(self) -> Iterator[AccountAPI]:
Iterator[:class:`~ape.api.accounts.AccountAPI`]
"""

@property
@cached_property
def data_folder(self) -> Path:
"""
The path to the account data files.
Defaults to ``$HOME/.ape/<plugin_name>`` unless overriden.
Defaults to ``$HOME/.ape/<plugin_name>`` unless overridden.
"""
path = self.config_manager.DATA_FOLDER / self.name
path.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -573,25 +574,39 @@ class TestAccountContainerAPI(AccountContainerAPI):
``AccountContainerAPI`` directly. Then, they show up in the ``accounts`` test fixture.
"""

@property
@cached_property
def data_folder(self) -> Path:
"""
**NOTE**: Test account containers do not touch
persistant data. By default and unless overriden,
persistent data. By default and unless overridden,
this property returns the path to ``/dev/null`` and
it is not used for anything.
"""
if os.name == "posix":
return Path("/dev/null")
return Path("/dev/null" if os.name == "posix" else "NUL")

@raises_not_implemented
def get_test_account(self, index: int) -> "TestAccountAPI": # type: ignore[empty-body]
"""
Get the test account at the given index.
return Path("NUL")
Args:
index (int): The index of the test account.
Returns:
:class:`~ape.api.accounts.TestAccountAPI`
"""

@abstractmethod
def generate_account(self) -> "TestAccountAPI":
def generate_account(self, index: Optional[int] = None) -> "TestAccountAPI":
"""
Generate a new test account.
"""

def reset(self):
"""
Reset the account container to an original state.
"""


class TestAccountAPI(AccountAPI):
"""
Expand Down
17 changes: 16 additions & 1 deletion src/ape/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pathlib import Path
from signal import SIGINT, SIGTERM, signal
from subprocess import DEVNULL, PIPE, Popen
from typing import Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union, cast

from eth_pydantic_types import HexBytes
from ethpm_types.abi import EventABI
Expand Down Expand Up @@ -42,6 +42,9 @@
raises_not_implemented,
)

if TYPE_CHECKING:
from ape.api.accounts import TestAccountAPI


class BlockAPI(BaseInterfaceModel):
"""
Expand Down Expand Up @@ -659,6 +662,18 @@ def set_balance(self, address: AddressType, amount: int):
amount (int): The balance to set in the address.
"""

@raises_not_implemented
def get_test_account(self, index: int) -> "TestAccountAPI": # type: ignore[empty-body]
"""
Retrieve one of the provider-generated test accounts.
Args:
index (int): The index of the test account in the HD-Path.
Returns:
:class:`~ape.api.accounts.TestAccountAPI`
"""

@log_instead_of_fail(default="<ProviderAPI>")
def __repr__(self) -> str:
return f"<{self.name.capitalize()} chain_id={self.chain_id}>"
Expand Down
54 changes: 37 additions & 17 deletions src/ape/managers/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class TestAccountManager(list, ManagerAccessMixin):
__test__ = False

_impersonated_accounts: dict[AddressType, ImpersonatedAccount] = {}
_accounts_by_index: dict[int, AccountAPI] = {}

@log_instead_of_fail(default="<TestAccountManager>")
def __repr__(self) -> str:
Expand All @@ -43,14 +44,13 @@ def __repr__(self) -> str:

@cached_property
def containers(self) -> dict[str, TestAccountContainerAPI]:
containers = {}
account_types = [
t for t in self.plugin_manager.account_types if issubclass(t[1][1], TestAccountAPI)
]
for plugin_name, (container_type, account_type) in account_types:
containers[plugin_name] = container_type(name=plugin_name, account_type=account_type)

return containers
account_types = filter(
lambda t: issubclass(t[1][1], TestAccountAPI), self.plugin_manager.account_types
)
return {
plugin_name: container_type(name=plugin_name, account_type=account_type)
for plugin_name, (container_type, account_type) in account_types
}

@property
def accounts(self) -> Iterator[AccountAPI]:
Expand All @@ -63,7 +63,7 @@ def aliases(self) -> Iterator[str]:
yield account.alias

def __len__(self) -> int:
return len(list(self.accounts))
return sum(len(c) for c in self.containers.values())

def __iter__(self) -> Iterator[AccountAPI]:
yield from self.accounts
Expand All @@ -74,13 +74,16 @@ def __getitem__(self, account_id):

@__getitem__.register
def __getitem_int(self, account_id: int):
if account_id in self._accounts_by_index:
return self._accounts_by_index[account_id]

original_account_id = account_id
if account_id < 0:
account_id = len(self) + account_id
for idx, account in enumerate(self.accounts):
if account_id == idx:
return account

raise IndexError(f"No account at index '{account_id}'.")
account = self.containers["test"].get_test_account(account_id)
self._accounts_by_index[original_account_id] = account
return account

@__getitem__.register
def __getitem_slice(self, account_id: slice):
Expand Down Expand Up @@ -136,6 +139,19 @@ def use_sender(self, account_id: Union[TestAccountAPI, AddressType, int]) -> Con
account = account_id if isinstance(account_id, TestAccountAPI) else self[account_id]
return _use_sender(account)

def init_test_account(
self, index: int, address: AddressType, private_key: str
) -> "TestAccountAPI":
container = self.containers["test"]
return container.init_test_account( # type: ignore[attr-defined]
index, address, private_key
)

def reset(self):
self._accounts_by_index = {}
for container in self.containers.values():
container.reset()


class AccountManager(BaseManager):
"""
Expand Down Expand Up @@ -168,7 +184,6 @@ def containers(self) -> dict[str, AccountContainerAPI]:
Returns:
dict[str, :class:`~ape.api.accounts.AccountContainerAPI`]
"""

containers = {}
data_folder = self.config_manager.DATA_FOLDER
data_folder.mkdir(exist_ok=True)
Expand Down Expand Up @@ -217,7 +232,6 @@ def __len__(self) -> int:
Returns:
int
"""

return sum(len(container) for container in self.containers.values())

def __iter__(self) -> Iterator[AccountAPI]:
Expand Down Expand Up @@ -291,7 +305,6 @@ def __getitem_int(self, account_id: int) -> AccountAPI:
Returns:
:class:`~ape.api.accounts.AccountAPI`
"""

if account_id < 0:
account_id = len(self) + account_id
for idx, account in enumerate(self):
Expand Down Expand Up @@ -366,7 +379,6 @@ def __contains__(self, address: AddressType) -> bool:
Returns:
bool: ``True`` when the given address is found.
"""

return (
any(address in container for container in self.containers.values())
or address in self.test_accounts
Expand All @@ -381,6 +393,14 @@ def use_sender(
account = self[account_id]
elif isinstance(account_id, str): # alias
account = self.load(account_id)
else:
raise TypeError(account_id)
else:
account = account_id

return _use_sender(account)

def init_test_account(
self, index: int, address: AddressType, private_key: str
) -> "TestAccountAPI":
return self.test_accounts.init_test_account(index, address, private_key)
2 changes: 2 additions & 0 deletions src/ape/managers/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -1887,6 +1887,8 @@ def reconfigure(self, **overrides):
self._config_override = overrides
_ = self.config

self.account_manager.test_accounts.reset()

def extract_manifest(self) -> PackageManifest:
return self.manifest

Expand Down
26 changes: 14 additions & 12 deletions src/ape/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,19 @@ def generate_dev_accounts(
list[:class:`~ape.utils.GeneratedDevAccount`]: List of development accounts.
"""
seed = Mnemonic.to_seed(mnemonic)
accounts = []
hd_path_format = (
hd_path if "{}" in hd_path or "{0}" in hd_path else f"{hd_path.rstrip('/')}/{{}}"
)
return [
_generate_dev_account(hd_path_format, i, seed)
for i in range(start_index, start_index + number_of_accounts)
]

if "{}" in hd_path or "{0}" in hd_path:
hd_path_format = hd_path
else:
hd_path_format = f"{hd_path.rstrip('/')}/{{}}"

for i in range(start_index, start_index + number_of_accounts):
hd_path_obj = HDPath(hd_path_format.format(i))
private_key = HexBytes(hd_path_obj.derive(seed)).hex()
address = Account.from_key(private_key).address
accounts.append(GeneratedDevAccount(address, private_key))

return accounts
def _generate_dev_account(hd_path, index: int, seed: bytes) -> GeneratedDevAccount:
return GeneratedDevAccount(
address=Account.from_key(
private_key := HexBytes(HDPath(hd_path.format(index)).derive(seed)).hex()
).address,
private_key=private_key,
)
17 changes: 14 additions & 3 deletions src/ape_node/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from web3.middleware import geth_poa_middleware
from yarl import URL

from ape.api import PluginConfig, SubprocessProvider, TestProviderAPI
from ape.api import PluginConfig, SubprocessProvider, TestAccountAPI, TestProviderAPI
from ape.logging import LogLevel, logger
from ape.types import SnapshotID
from ape.utils.misc import ZERO_ADDRESS, log_instead_of_fail, raises_not_implemented
Expand Down Expand Up @@ -130,10 +130,10 @@ def __init__(

geth_kwargs["dev_mode"] = True
hd_path = hd_path or DEFAULT_TEST_HD_PATH
accounts = generate_dev_accounts(
self._dev_accounts = generate_dev_accounts(
mnemonic, number_of_accounts=number_of_accounts, hd_path=hd_path
)
addresses = [a.address for a in accounts]
addresses = [a.address for a in self._dev_accounts]
addresses.extend(extra_funded_accounts or [])
bal_dict = {"balance": str(initial_balance)}
alloc = {a: bal_dict for a in addresses}
Expand Down Expand Up @@ -418,6 +418,17 @@ def mine(self, num_blocks: int = 1):
def build_command(self) -> list[str]:
return self._process.command if self._process else []

def get_test_account(self, index: int) -> "TestAccountAPI":
if self._process is None:
# Not managing the process. Use default approach.
test_container = self.account_manager.test_accounts.containers["test"]
return test_container.generate_account(index)

# perf: we avoid having to generate account keys twice by utilizing
# the accounts generated for geth's genesis.json.
account = self._process._dev_accounts[index]
return self.account_manager.init_test_account(index, account.address, account.private_key)


# NOTE: The default behavior of EthereumNodeBehavior assumes geth.
class Node(EthereumNodeProvider):
Expand Down
Loading

0 comments on commit 7caaf4d

Please sign in to comment.