From b40c1c018dc1e6378415df00e9d240286dd42447 Mon Sep 17 00:00:00 2001 From: antazoey Date: Wed, 28 Aug 2024 18:39:43 -0500 Subject: [PATCH] feat: support `request_header` driven from configs (#2252) --- docs/userguides/config.md | 33 ++++++++++ docs/userguides/networks.md | 37 +++++++++++ src/ape/api/config.py | 5 ++ src/ape/api/networks.py | 43 ++++++++++-- src/ape/api/providers.py | 15 ++++- src/ape/managers/__init__.py | 2 +- src/ape/managers/config.py | 9 +++ src/ape/managers/networks.py | 23 +++++-- src/ape/pytest/fixtures.py | 2 +- src/ape/utils/__init__.py | 5 +- src/ape/utils/_github.py | 3 +- src/ape/utils/misc.py | 64 +----------------- src/ape/utils/rpc.py | 91 ++++++++++++++++++++++++++ src/ape_ethereum/ecosystem.py | 13 +++- src/ape_ethereum/provider.py | 27 ++++++-- src/ape_node/provider.py | 5 ++ tests/functional/geth/conftest.py | 1 - tests/functional/geth/test_provider.py | 59 ++++++++++++++++- tests/functional/test_ecosystem.py | 18 ++++- tests/functional/test_provider.py | 3 +- tests/functional/utils/test_rpc.py | 45 +++++++++++++ 21 files changed, 412 insertions(+), 91 deletions(-) create mode 100644 src/ape/utils/rpc.py create mode 100644 tests/functional/utils/test_rpc.py diff --git a/docs/userguides/config.md b/docs/userguides/config.md index 0c84ff3f04..447349c824 100644 --- a/docs/userguides/config.md +++ b/docs/userguides/config.md @@ -179,6 +179,39 @@ Install these plugins by running command: ape plugins install . ``` +## Request Headers + +For Ape's HTTP usage, such as requests made via `web3.py`, optionally specify extra request headers. + +```yaml +request_headers: + # NOTE: Only using Content-Type as an example; can be any header key/value. + Content-Type: application/json +``` + +You can also specify request headers at the ecosystem, network, and provider levels: + +```yaml +# NOTE: All the headers are the same only for demo purposes. +# You can use headers you want for any of these config locations. +ethereum: + # Apply to all requests made to ethereum networks. + request_headers: + Content-Type: application/json + + mainnet: + # Apply to all requests made to ethereum:mainnet (using any provider) + request_headers: + Content-Type: application/json + +node: + # Apply to any request using the `node` provider. + request_headers: + Content-Type: application/json +``` + +To learn more about how request headers work in Ape, see [this section of the Networking guide](./networks.html#request-headers). + ## Testing Configure your test accounts: diff --git a/docs/userguides/networks.md b/docs/userguides/networks.md index 232573090d..ae6237b0fb 100644 --- a/docs/userguides/networks.md +++ b/docs/userguides/networks.md @@ -340,6 +340,43 @@ You may use one of: For the local network configuration, the default is `"max"`. Otherwise, it is `"auto"`. +## Request Headers + +There are several layers of request-header configuration. +They get merged into each-other in this order, with the exception being `User-Agent`, which has an append-behavior. + +- Default Ape headers (includes `User-Agent`) +- Top-level configuration for headers (using `request_headers:` key) +- Per-ecosystem configuration +- Per-network configuration +- Per-provider configuration + +Use the top-level `request_headers:` config to specify headers for every request. +Use ecosystem-level specification for only requests made when connected to that ecosystem. +Network and provider configurations work similarly; they are only used when connecting to that network or provider. + +Here is an example using each layer: + +```yaml +request_headers: + Top-Level: "UseThisOnEveryRequest" + +ethereum: + request_headers: + Ecosystem-Level: "UseThisOnEveryEthereumRequest" + + mainnet: + request_headers: + Network-Level: "UseThisOnAllRequestsToEthereumMainnet" + +node: + request_headers: + Provider-Level: "UseThisOnAllRequestsUsingNodeProvider" +``` + +When using `User-Agent`, it will not override Ape's default `User-Agent` nor will each layer override each-other's. +Instead, they are carefully appended to each other, allowing you to have a very customizable `User-Agent`. + ## Local Network The default network in Ape is the local network (keyword `"local"`). diff --git a/src/ape/api/config.py b/src/ape/api/config.py index 7df4a4b669..05f67897f8 100644 --- a/src/ape/api/config.py +++ b/src/ape/api/config.py @@ -333,6 +333,11 @@ def __init__(self, *args, **kwargs): The name of the project. """ + request_headers: dict = {} + """ + Extra request headers for all HTTP requests. + """ + version: str = "" """ The version of the project. diff --git a/src/ape/api/networks.py b/src/ape/api/networks.py index 52bde2ff03..798e60e391 100644 --- a/src/ape/api/networks.py +++ b/src/ape/api/networks.py @@ -13,6 +13,7 @@ from eth_utils import keccak, to_int from ethpm_types import BaseModel, ContractType from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI +from pydantic import model_validator from ape.exceptions import ( CustomError, @@ -31,6 +32,7 @@ ExtraAttributesMixin, ExtraModelAttributes, ManagerAccessMixin, + RPCHeaders, abstractmethod, cached_property, log_instead_of_fail, @@ -68,7 +70,9 @@ class EcosystemAPI(ExtraAttributesMixin, BaseInterfaceModel): The name of the ecosystem. This should be set the same name as the plugin. """ - request_header: dict + # TODO: In 0.9, make @property that returns value from config, + # and use REQUEST_HEADER as plugin-defined constants. + request_header: dict = {} """A shareable HTTP header for network requests.""" fee_token_symbol: str @@ -80,6 +84,14 @@ class EcosystemAPI(ExtraAttributesMixin, BaseInterfaceModel): _default_network: Optional[str] = None """The default network of the ecosystem, such as ``local``.""" + @model_validator(mode="after") + @classmethod + def _validate_ecosystem(cls, model): + headers = RPCHeaders(**model.request_header) + headers["User-Agent"] = f"ape-{model.name}" + model.request_header = dict(**headers) + return model + @log_instead_of_fail(default="") def __repr__(self) -> str: return f"<{self.name}>" @@ -289,9 +301,7 @@ def networks(self) -> dict[str, "NetworkAPI"]: @cached_property def _networks_from_plugins(self) -> dict[str, "NetworkAPI"]: return { - network_name: network_class( - name=network_name, ecosystem=self, request_header=self.request_header - ) + network_name: network_class(name=network_name, ecosystem=self) for _, (ecosystem_name, network_name, network_class) in self.plugin_manager.networks if ecosystem_name == self.name } @@ -647,6 +657,16 @@ def decode_custom_error( Optional[CustomError]: If it able to decode one, else ``None``. """ + def _get_request_headers(self) -> RPCHeaders: + # Internal helper method called by NetworkManager + headers = RPCHeaders(**self.request_header) + # Have to do it this way to avoid "multiple-keys" error. + configured_headers: dict = self.config.get("request_headers", {}) + for key, value in configured_headers.items(): + headers[key] = value + + return headers + class ProviderContextManager(ManagerAccessMixin): """ @@ -809,7 +829,9 @@ class NetworkAPI(BaseInterfaceModel): ecosystem: EcosystemAPI """The ecosystem of the network.""" - request_header: dict + # TODO: In 0.9, make @property that returns value from config, + # and use REQUEST_HEADER as plugin-defined constants. + request_header: dict = {} """A shareable network HTTP header.""" # See ``.default_provider`` which is the proper field. @@ -1043,7 +1065,6 @@ def providers(self): # -> dict[str, Partial[ProviderAPI]] provider_class, name=provider_name, network=self, - request_header=self.request_header, ) return providers @@ -1285,6 +1306,16 @@ def verify_chain_id(self, chain_id: int): if self.name not in ("custom", LOCAL_NETWORK_NAME) and self.chain_id != chain_id: raise NetworkMismatchError(chain_id, self) + def _get_request_headers(self) -> RPCHeaders: + # Internal helper method called by NetworkManager + headers = RPCHeaders(**self.request_header) + # Have to do it this way to avoid multiple-keys error. + configured_headers: dict = self.config.get("request_headers", {}) + for key, value in configured_headers.items(): + headers[key] = value + + return headers + class ForkedNetworkAPI(NetworkAPI): @property diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index c64d4637b0..abde585077 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -41,6 +41,7 @@ log_instead_of_fail, raises_not_implemented, ) +from ape.utils.rpc import RPCHeaders if TYPE_CHECKING: from ape.api.accounts import TestAccountAPI @@ -177,7 +178,9 @@ class ProviderAPI(BaseInterfaceModel): provider_settings: dict = {} """The settings for the provider, as overrides to the configuration.""" - request_header: dict + # TODO: In 0.9, make @property that returns value from config, + # and use REQUEST_HEADER as plugin-defined constants. + request_header: dict = {} """A header to set on HTTP/RPC requests.""" block_page_size: int = 100 @@ -845,6 +848,16 @@ def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMa """ return VirtualMachineError(base_err=exception, **kwargs) + def _get_request_headers(self) -> RPCHeaders: + # Internal helper method called by NetworkManager + headers = RPCHeaders(**self.request_header) + # Have to do it this way to avoid "multiple-keys" error. + configured_headers: dict = self.config.get("request_headers", {}) + for key, value in configured_headers.items(): + headers[key] = value + + return headers + class TestProviderAPI(ProviderAPI): """ diff --git a/src/ape/managers/__init__.py b/src/ape/managers/__init__.py index 1b72168f61..98ead4e335 100644 --- a/src/ape/managers/__init__.py +++ b/src/ape/managers/__init__.py @@ -14,7 +14,7 @@ ManagerAccessMixin.plugin_manager = PluginManager() ManagerAccessMixin.config_manager = ConfigManager( - request_header={"User-Agent": USER_AGENT}, + request_header={"User-Agent": USER_AGENT, "Content-Type": "application/json"}, ) ManagerAccessMixin.compiler_manager = CompilerManager() ManagerAccessMixin.network_manager = NetworkManager() diff --git a/src/ape/managers/config.py b/src/ape/managers/config.py index ff1a182287..d6a5a357a0 100644 --- a/src/ape/managers/config.py +++ b/src/ape/managers/config.py @@ -17,6 +17,7 @@ get_item_with_extras, only_raise_attribute_error, ) +from ape.utils.rpc import RPCHeaders CONFIG_FILE_NAME = "ape-config.yaml" @@ -125,6 +126,14 @@ def isolate_data_folder(self) -> Iterator[Path]: finally: self.DATA_FOLDER = original_data_folder + def _get_request_headers(self) -> RPCHeaders: + # Avoid multiple keys error by not initializing with both dicts. + headers = RPCHeaders(**self.REQUEST_HEADER) + for key, value in self.request_headers.items(): + headers[key] = value + + return headers + def merge_configs(*cfgs: dict) -> dict: if len(cfgs) == 0: diff --git a/src/ape/managers/networks.py b/src/ape/managers/networks.py index 68d33287e1..e247d1528a 100644 --- a/src/ape/managers/networks.py +++ b/src/ape/managers/networks.py @@ -6,6 +6,7 @@ from ape.api.networks import NetworkAPI from ape.exceptions import EcosystemNotFoundError, NetworkError, NetworkNotFoundError from ape.managers.base import BaseManager +from ape.utils import RPCHeaders from ape.utils.basemodel import ( ExtraAttributesMixin, ExtraModelAttributes, @@ -85,6 +86,22 @@ def ecosystem(self) -> EcosystemAPI: """ return self.network.ecosystem + def get_request_headers( + self, ecosystem_name: str, network_name: str, provider_name: str + ) -> RPCHeaders: + """ + All request headers to be used when connecting to this network. + """ + ecosystem = self.get_ecosystem(ecosystem_name) + network = ecosystem.get_network(network_name) + provider = network.get_provider(provider_name) + headers = self.config_manager._get_request_headers() + for obj in (ecosystem, network, provider): + for key, value in obj._get_request_headers().items(): + headers[key] = value + + return headers + def fork( self, provider_name: Optional[str] = None, @@ -225,12 +242,9 @@ def ecosystems(self) -> dict[str, EcosystemAPI]: @cached_property def _plugin_ecosystems(self) -> dict[str, EcosystemAPI]: - def to_kwargs(name: str) -> dict: - return {"name": name, "request_header": self.config_manager.REQUEST_HEADER} - # Load plugins. plugins = self.plugin_manager.ecosystems - return {n: cls(**to_kwargs(n)) for n, cls in plugins} # type: ignore[operator] + return {n: cls(name=n) for n, cls in plugins} # type: ignore[operator] def create_custom_provider( self, @@ -287,7 +301,6 @@ def create_custom_provider( network=network, provider_settings=provider_settings, data_folder=self.ethereum.data_folder / name, - request_header=network.request_header, ) def __iter__(self) -> Iterator[str]: diff --git a/src/ape/pytest/fixtures.py b/src/ape/pytest/fixtures.py index dd9ce8f4aa..9f4e5587ea 100644 --- a/src/ape/pytest/fixtures.py +++ b/src/ape/pytest/fixtures.py @@ -16,7 +16,7 @@ from ape.pytest.config import ConfigWrapper from ape.types import SnapshotID from ape.utils.basemodel import ManagerAccessMixin -from ape.utils.misc import allow_disconnected +from ape.utils.rpc import allow_disconnected class PytestApeFixtures(ManagerAccessMixin): diff --git a/src/ape/utils/__init__.py b/src/ape/utils/__init__.py index 822776c4e4..ac4733b542 100644 --- a/src/ape/utils/__init__.py +++ b/src/ape/utils/__init__.py @@ -25,10 +25,8 @@ DEFAULT_TRANSACTION_ACCEPTANCE_TIMEOUT, EMPTY_BYTES32, SOURCE_EXCLUDE_PATTERNS, - USER_AGENT, ZERO_ADDRESS, add_padding_to_strings, - allow_disconnected, as_our_module, cached_property, extract_nested_value, @@ -44,7 +42,6 @@ raises_not_implemented, run_until_complete, singledispatchmethod, - stream_response, to_int, ) from ape.utils.os import ( @@ -61,6 +58,7 @@ use_temp_sys_path, ) from ape.utils.process import JoinableQueue, spawn +from ape.utils.rpc import USER_AGENT, RPCHeaders, allow_disconnected, stream_response from ape.utils.testing import ( DEFAULT_NUMBER_OF_TEST_ACCOUNTS, DEFAULT_TEST_ACCOUNT_BALANCE, @@ -125,6 +123,7 @@ "path_match", "raises_not_implemented", "returns_array", + "RPCHeaders", "run_in_tempdir", "run_until_complete", "singledispatchmethod", diff --git a/src/ape/utils/_github.py b/src/ape/utils/_github.py index 63cfe15b45..b5da772cc4 100644 --- a/src/ape/utils/_github.py +++ b/src/ape/utils/_github.py @@ -14,7 +14,8 @@ from ape.exceptions import CompilerError, ProjectError, UnknownVersionError from ape.logging import logger -from ape.utils.misc import USER_AGENT, cached_property, stream_response +from ape.utils.misc import cached_property +from ape.utils.rpc import USER_AGENT, stream_response class GitProcessWrapper: diff --git a/src/ape/utils/misc.py b/src/ape/utils/misc.py index 08eefd8cf2..81e8b33c56 100644 --- a/src/ape/utils/misc.py +++ b/src/ape/utils/misc.py @@ -4,7 +4,7 @@ import json import sys from asyncio import gather -from collections.abc import Callable, Coroutine, Mapping +from collections.abc import Coroutine, Mapping from datetime import datetime, timezone from functools import cached_property, lru_cache, singledispatchmethod, wraps from importlib.metadata import PackageNotFoundError, distributions @@ -12,14 +12,12 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast -import requests import yaml from eth_pydantic_types import HexBytes from eth_utils import is_0x_prefixed from packaging.specifiers import SpecifierSet -from tqdm.auto import tqdm # type: ignore -from ape.exceptions import APINotImplementedError, ProviderNotConnectedError +from ape.exceptions import APINotImplementedError from ape.logging import logger from ape.utils.os import expand_environment_variables @@ -187,7 +185,6 @@ def get_package_version(obj: Any) -> str: __version__ = get_package_version(__name__) -USER_AGENT = f"Ape/{__version__} (Python/{_python_version})" def load_config(path: Path, expand_envars=True, must_exist=False) -> dict: @@ -308,33 +305,6 @@ def add_padding_to_strings( return spaced_items -def stream_response(download_url: str, progress_bar_description: str = "Downloading") -> bytes: - """ - Download HTTP content by streaming and returning the bytes. - Progress bar will be displayed in the CLI. - - Args: - download_url (str): String to get files to download. - progress_bar_description (str): Downloading word. - - Returns: - bytes: Content in bytes to show the progress. - """ - response = requests.get(download_url, stream=True) - response.raise_for_status() - - total_size = int(response.headers.get("content-length", 0)) - progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True, leave=False) - progress_bar.set_description(progress_bar_description) - content = b"" - for data in response.iter_content(1024, decode_unicode=True): - progress_bar.update(len(data)) - content += data - - progress_bar.close() - return content - - def raises_not_implemented(fn): """ Decorator for raising helpful not implemented error. @@ -400,33 +370,6 @@ def run_until_complete(*item: Any) -> Any: return result -def allow_disconnected(fn: Callable): - """ - A decorator that instead of raising :class:`~ape.exceptions.ProviderNotConnectedError` - warns and returns ``None``. - - Usage example:: - - from typing import Optional - from ape.types import SnapshotID - from ape.utils import return_none_when_disconnected - - @allow_disconnected - def try_snapshot(self) -> Optional[SnapshotID]: - return self.chain.snapshot() - - """ - - def inner(*args, **kwargs): - try: - return fn(*args, **kwargs) - except ProviderNotConnectedError: - logger.warning("Provider is not connected.") - return None - - return inner - - def nonreentrant(key_fn): def inner(f): locks = set() @@ -568,7 +511,6 @@ def as_our_module(cls_or_def: _MOD_T, doc_str: Optional[str] = None) -> _MOD_T: __all__ = [ - "allow_disconnected", "cached_property", "_dict_overlay", "extract_nested_value", @@ -584,7 +526,5 @@ def as_our_module(cls_or_def: _MOD_T, doc_str: Optional[str] = None) -> _MOD_T: "raises_not_implemented", "run_until_complete", "singledispatchmethod", - "stream_response", "to_int", - "USER_AGENT", ] diff --git a/src/ape/utils/rpc.py b/src/ape/utils/rpc.py new file mode 100644 index 0000000000..3cfa7b54e2 --- /dev/null +++ b/src/ape/utils/rpc.py @@ -0,0 +1,91 @@ +from collections.abc import Callable + +import requests +from requests.models import CaseInsensitiveDict +from tqdm import tqdm # type: ignore + +from ape.exceptions import ProviderNotConnectedError +from ape.logging import logger +from ape.utils.misc import __version__, _python_version + +USER_AGENT = f"Ape/{__version__} (Python/{_python_version})" + + +def allow_disconnected(fn: Callable): + """ + A decorator that instead of raising :class:`~ape.exceptions.ProviderNotConnectedError` + warns and returns ``None``. + + Usage example:: + + from typing import Optional + from ape.types import SnapshotID + from ape.utils import return_none_when_disconnected + + @allow_disconnected + def try_snapshot(self) -> Optional[SnapshotID]: + return self.chain.snapshot() + + """ + + def inner(*args, **kwargs): + try: + return fn(*args, **kwargs) + except ProviderNotConnectedError: + logger.warning("Provider is not connected.") + return None + + return inner + + +def stream_response(download_url: str, progress_bar_description: str = "Downloading") -> bytes: + """ + Download HTTP content by streaming and returning the bytes. + Progress bar will be displayed in the CLI. + + Args: + download_url (str): String to get files to download. + progress_bar_description (str): Downloading word. + + Returns: + bytes: Content in bytes to show the progress. + """ + response = requests.get(download_url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True, leave=False) + progress_bar.set_description(progress_bar_description) + content = b"" + for data in response.iter_content(1024, decode_unicode=True): + progress_bar.update(len(data)) + content += data + + progress_bar.close() + return content + + +class RPCHeaders(CaseInsensitiveDict): + """ + A dict-like data-structure for HTTP-headers. + It is case-insensitive and appends user-agent strings + rather than overrides. + """ + + def __setitem__(self, key, value): + if key.lower() != "user-agent" or not self.__contains__("user-agent"): + return super().__setitem__(key, value) + + # Handle appending the user-agent (without replacing). + existing_user_agent = self.__getitem__("user-agent") + parts = [a.strip() for a in value.split(" ")] + new_parts = [] + for part in parts: + if part in existing_user_agent: + # Already added. + continue + else: + new_parts.append(part) + + if new_user_agent := " ".join(new_parts): + super().__setitem__(key, f"{existing_user_agent} {new_user_agent}") diff --git a/src/ape_ethereum/ecosystem.py b/src/ape_ethereum/ecosystem.py index 738d0fc942..dd62e42e4b 100644 --- a/src/ape_ethereum/ecosystem.py +++ b/src/ape_ethereum/ecosystem.py @@ -147,6 +147,9 @@ class NetworkConfig(PluginConfig): base_fee_multiplier: float = 1.0 """A multiplier to apply to a transaction base fee.""" + request_headers: dict = {} + """Optionally config extra request headers whenever using this network.""" + @field_validator("gas_limit", mode="before") @classmethod def validate_gas_limit(cls, value): @@ -220,6 +223,9 @@ class BaseEthereumConfig(PluginConfig): _forked_configs: dict[str, ForkedNetworkConfig] = {} _custom_networks: dict[str, NetworkConfig] = {} + # NOTE: This gets appended to Ape's root User-Agent string. + request_headers: dict = {} + model_config = SettingsConfigDict(extra="allow") @model_validator(mode="before") @@ -243,7 +249,12 @@ def load_network_configs(cls, values): data = merge_configs(default_fork_model, obj) cfg_forks[key] = ForkedNetworkConfig.model_validate(data) - elif key != LOCAL_NETWORK_NAME and key not in cls.NETWORKS and isinstance(obj, dict): + elif ( + key != LOCAL_NETWORK_NAME + and key not in cls.NETWORKS + and isinstance(obj, dict) + and key not in ("request_headers",) + ): # Custom network. default_network_model = create_network_config( default_transaction_type=cls.DEFAULT_TRANSACTION_TYPE diff --git a/src/ape_ethereum/provider.py b/src/ape_ethereum/provider.py index a53c1bcd9f..e488db70bb 100644 --- a/src/ape_ethereum/provider.py +++ b/src/ape_ethereum/provider.py @@ -22,6 +22,7 @@ from requests import HTTPError from web3 import HTTPProvider, IPCProvider, Web3 from web3 import WebsocketProvider as WebSocketProvider +from web3._utils.http import construct_user_agent from web3.exceptions import ContractLogicError as Web3ContractLogicError from web3.exceptions import ( ExtraDataLengthError, @@ -1297,6 +1298,11 @@ class EthereumNodeProvider(Web3Provider, ABC): name: str = "node" + # NOTE: Appends user-agent to base User-Agent string. + request_header: dict = { + "User-Agent": construct_user_agent(str(HTTPProvider)), + } + @property def uri(self) -> str: if "url" in self.provider_settings: @@ -1444,8 +1450,14 @@ def _ots_api_level(self) -> Optional[int]: def _set_web3(self): # Clear cached version when connecting to another URI. self._client_version = None + headers = self.network_manager.get_request_headers( + self.network.ecosystem.name, self.network.name, self.name + ) self._web3 = _create_web3( - http_uri=self.http_uri, ipc_path=self.ipc_path, ws_uri=self.ws_uri + http_uri=self.http_uri, + ipc_path=self.ipc_path, + ws_uri=self.ws_uri, + request_kwargs={"headers": headers}, ) def _complete_connect(self): @@ -1544,7 +1556,10 @@ def connect(self): def _create_web3( - http_uri: Optional[str] = None, ipc_path: Optional[Path] = None, ws_uri: Optional[str] = None + http_uri: Optional[str] = None, + ipc_path: Optional[Path] = None, + ws_uri: Optional[str] = None, + request_kwargs: Optional[dict] = None, ): # NOTE: This list is ordered by try-attempt. # Try ENV, then IPC, and then HTTP last. @@ -1552,9 +1567,11 @@ def _create_web3( if ipc := ipc_path: providers.append(lambda: IPCProvider(ipc_path=ipc)) if http := http_uri: - providers.append( - lambda: HTTPProvider(endpoint_uri=http, request_kwargs={"timeout": 30 * 60}) - ) + request_kwargs = request_kwargs or {} + if "timeout" not in request_kwargs: + request_kwargs["timeout"] = 30 * 60 + + providers.append(lambda: HTTPProvider(endpoint_uri=http, request_kwargs=request_kwargs)) if ws := ws_uri: providers.append(lambda: WebSocketProvider(endpoint_uri=ws)) diff --git a/src/ape_node/provider.py b/src/ape_node/provider.py index 857a516646..44581963b3 100644 --- a/src/ape_node/provider.py +++ b/src/ape_node/provider.py @@ -262,6 +262,11 @@ class EthereumNodeConfig(PluginConfig): based on your node's client-version and available RPCs. """ + request_headers: dict = {} + """ + Optionally specify request headers to use whenever using this provider. + """ + model_config = SettingsConfigDict(extra="allow") @field_validator("call_trace_approach", mode="before") diff --git a/tests/functional/geth/conftest.py b/tests/functional/geth/conftest.py index e903284534..97bdffb6e5 100644 --- a/tests/functional/geth/conftest.py +++ b/tests/functional/geth/conftest.py @@ -57,7 +57,6 @@ def mock_geth(geth_provider, mock_web3): network=geth_provider.network, provider_settings={}, data_folder=Path("."), - request_header={}, ) original_web3 = provider._web3 provider._web3 = mock_web3 diff --git a/tests/functional/geth/test_provider.py b/tests/functional/geth/test_provider.py index dc5e76a97e..fb16a03b94 100644 --- a/tests/functional/geth/test_provider.py +++ b/tests/functional/geth/test_provider.py @@ -7,9 +7,11 @@ from eth_utils import keccak, to_hex from evmchains import PUBLIC_CHAIN_META from hexbytes import HexBytes +from web3 import AutoProvider from web3.exceptions import ContractLogicError as Web3ContractLogicError from web3.exceptions import ExtraDataLengthError from web3.middleware import geth_poa_middleware as ExtraDataToPOAMiddleware +from web3.providers import HTTPProvider from ape.exceptions import ( APINotImplementedError, @@ -109,7 +111,7 @@ def test_uri_non_dev_and_not_configured(mocker, ethereum): network.name = "gorillanet" network.ecosystem.name = "gorillas" - provider = Node.model_construct(network=network, request_header={}) + provider = Node.model_construct(network=network) with pytest.raises(ProviderError): _ = provider.uri @@ -241,6 +243,61 @@ def test_connect_using_only_ipc_for_uri(project, networks, geth_provider): assert node.uri == f"{ipc_path}" +@geth_process_test +def test_connect_request_headers(project, geth_provider, networks): + http_provider = None + config = { + "request_headers": {"h0": 0, "User-Agent": "myapp/2.0"}, + "ethereum": { + "request_headers": {"h1": 1, "User-Agent": "ETH/1.0"}, + "local": { + "request_headers": {"h2": 2, "user-agent": "MyPrivateNetwork/0.0.1"}, + }, + }, + "node": {"request_headers": {"h3": 3, "USER-AGENT": "custom-geth-client/v100"}}, + } + with project.temp_config(**config): + with networks.ethereum.local.use_provider("node") as geth: + w3_provider = geth.web3.provider + if isinstance(w3_provider, AutoProvider): + for pot_provider_fn in w3_provider._potential_providers: + pot_provider = pot_provider_fn() + if not isinstance(pot_provider, HTTPProvider): + continue + else: + http_provider = pot_provider + + elif isinstance(w3_provider, HTTPProvider): + http_provider = w3_provider + + else: + pytest.fail("Not using HTTP. Please adjust test.") + + assert http_provider is not None, "Setup failed - HTTP Provider still None." + + assert isinstance(http_provider._request_kwargs, dict) + actual = http_provider._request_kwargs["headers"] + assert actual["h0"] == 0 # top-level + assert actual["h1"] == 1 # ecosystem + assert actual["h2"] == 2 # network + assert actual["h3"] == 3 # provider + + # Also, assert Ape's default user-agent strings. + assert actual["User-Agent"].startswith("Ape/") + assert "Python" in actual["User-Agent"] + assert "ape-ethereum" in actual["User-Agent"] + assert "web3.py/" in actual["User-Agent"] + + # Show other default headers. + assert actual["Content-Type"] == "application/json" + + # Show appended user-agents strings. + assert "myapp/2.0" in actual["User-Agent"] + assert "ETH/1.0" in actual["User-Agent"] + assert "MyPrivateNetwork/0.0.1" in actual["User-Agent"] + assert "custom-geth-client/v100" in actual["User-Agent"] + + @geth_process_test @pytest.mark.parametrize("block_id", (0, "0", "0x0", HexStr("0x0"))) def test_get_block(geth_provider, block_id): diff --git a/tests/functional/test_ecosystem.py b/tests/functional/test_ecosystem.py index 1aa2c1584c..86e78a3799 100644 --- a/tests/functional/test_ecosystem.py +++ b/tests/functional/test_ecosystem.py @@ -14,7 +14,7 @@ from ape.types import AddressType, CurrencyValueComparable from ape.utils import DEFAULT_LOCAL_TRANSACTION_ACCEPTANCE_TIMEOUT from ape_ethereum import TransactionTrace -from ape_ethereum.ecosystem import BLUEPRINT_HEADER, BaseEthereumConfig, Block +from ape_ethereum.ecosystem import BLUEPRINT_HEADER, BaseEthereumConfig, Block, Ethereum from ape_ethereum.transactions import ( DynamicFeeTransaction, Receipt, @@ -80,6 +80,22 @@ def test_name(ethereum): assert ethereum.name == "ethereum" +def test_request_header(ethereum): + actual = ethereum.request_header + expected = {"User-Agent": "ape-ethereum"} + assert actual == expected + + +def test_request_header_subclass(): + class L2(Ethereum): + name: str = "l2" + + l2 = L2() + actual = l2.request_header + expected = {"User-Agent": "ape-l2"} + assert actual == expected + + def test_name_when_custom(configured_custom_ecosystem, networks): ecosystem = networks.get_ecosystem(CUSTOM_ECOSYSTEM_NAME) actual = ecosystem.name diff --git a/tests/functional/test_provider.py b/tests/functional/test_provider.py index a054df6488..e30ae84bbd 100644 --- a/tests/functional/test_provider.py +++ b/tests/functional/test_provider.py @@ -478,7 +478,7 @@ def disconnect(self): try: with pytest.raises(ProviderError, match=expected): - _ = MyProvider(data_folder=None, name=None, network=None, request_header=None) + _ = MyProvider(data_folder=None, name=None, network=None) finally: if WEB3_PROVIDER_URI_ENV_VAR_NAME in os.environ: @@ -494,7 +494,6 @@ def test_account_balance_state(project, eth_tester_provider, owner): provider = LocalProvider( name="test", network=eth_tester_provider.network, - request_header=eth_tester_provider.request_header, ) provider.connect() bal = provider.get_balance(owner.address) diff --git a/tests/functional/utils/test_rpc.py b/tests/functional/utils/test_rpc.py new file mode 100644 index 0000000000..d2caf36bbd --- /dev/null +++ b/tests/functional/utils/test_rpc.py @@ -0,0 +1,45 @@ +import pytest + +from ape.utils.rpc import RPCHeaders + + +class TestRPCHeaders: + @pytest.fixture + def headers(self): + return RPCHeaders() + + @pytest.mark.parametrize("key", ("Content-Type", "CONTENT-TYPE")) + def test_setitem_key_case_insensitive(self, key, headers): + headers[key] = "application/javascript" + headers[key.lower()] = "application/json" + assert headers[key] == "application/json" + assert headers[key.lower()] == "application/json" + + def test_setitem_user_agent_does_not_add_twice(self, headers): + expected = "test-user-agent/1.0" + headers["User-Agent"] = expected + # Add again. It should not add twice. + headers["User-Agent"] = expected + assert headers["User-Agent"] == expected + + def test_setitem_user_agent_appends(self, headers): + headers["User-Agent"] = "test0/1.0" + headers["User-Agent"] = "test1/2.0" + assert headers["User-Agent"] == "test0/1.0 test1/2.0" + + def test_setitem_user_agent_parts_exist(self, headers): + """ + Tests the case when user-agents share a sub-set + of each other, that it does not duplicate. + """ + headers["User-Agent"] = "test0/1.0" + # The beginning of the user-agent is already present. + # It shouldn't add the full thing. + headers["User-Agent"] = "test0/1.0 test1/2.0" + assert headers["User-Agent"] == "test0/1.0 test1/2.0" + # unexpected = "test0/1.0 test0/1.0 test1/2.0" + + @pytest.mark.parametrize("key", ("user-agent", "User-Agent", "USER-AGENT")) + def test_contains_user_agent(self, key, headers): + headers["User-Agent"] = "test0/1.0" + assert key in headers