Skip to content

Commit

Permalink
perf: avoid enriching entire trace when only requesting return_value (
Browse files Browse the repository at this point in the history
ApeWorX#2208)

Co-authored-by: El De-dog-lo <[email protected]>
  • Loading branch information
antazoey and fubuloubu authored Aug 7, 2024
1 parent e2839a0 commit f854ed5
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 84 deletions.
3 changes: 3 additions & 0 deletions src/ape/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,9 @@ def __eq__(self, other: Any) -> bool:
return NotImplemented


CurrencyValueComparable.__name__ = int.__name__


CurrencyValue: TypeAlias = CurrencyValueComparable
"""
An alias to :class:`~ape.types.CurrencyValueComparable` for
Expand Down
122 changes: 65 additions & 57 deletions src/ape_ethereum/ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
is_hex,
is_hex_address,
keccak,
to_bytes,
to_checksum_address,
to_hex,
)
from ethpm_types import ContractType
from ethpm_types.abi import ABIType, ConstructorABI, EventABI, MethodABI
Expand Down Expand Up @@ -57,7 +59,6 @@
StructParser,
is_array,
returns_array,
to_int,
)
from ape.utils.basemodel import _assert_not_ipython_check, only_raise_attribute_error
from ape.utils.misc import DEFAULT_MAX_RETRIES_TX, DEFAULT_TRANSACTION_TYPE
Expand Down Expand Up @@ -161,7 +162,7 @@ def validate_gas_limit(cls, value):
return int(value)

elif isinstance(value, str) and is_hex(value) and is_0x_prefixed(value):
return to_int(HexBytes(value))
return int(value, 16)

elif is_hex(value):
raise ValueError("Gas limit hex str must include '0x' prefix.")
Expand Down Expand Up @@ -400,7 +401,7 @@ def decode_address(cls, raw_address: RawAddress) -> AddressType:

@classmethod
def encode_address(cls, address: AddressType) -> RawAddress:
return str(address)
return f"{address}"

def decode_transaction_type(self, transaction_type_id: Any) -> type[TransactionAPI]:
if isinstance(transaction_type_id, TransactionType):
Expand Down Expand Up @@ -721,15 +722,16 @@ def decode_returndata(self, abi: MethodABI, raw_data: bytes) -> tuple[Any, ...]:
):
# Array of structs or tuples: don't convert to list
# Array of anything else: convert to single list
return (
(
[
output_values[0],
],
)
if issubclass(type(output_values[0]), Struct)
else ([o for o in output_values[0]],) # type: ignore[union-attr]
)

if issubclass(type(output_values[0]), Struct):
return ([output_values[0]],)

else:
try:
return ([o for o in output_values[0]],) # type: ignore[union-attr]
except Exception:
# On-chains transaction data errors.
return (output_values,)

elif returns_array(abi):
# Tuple with single item as the array.
Expand All @@ -747,7 +749,7 @@ def _enrich_value(self, value: Any, **kwargs) -> Any:
if len(value) > 24:
return humanize_hash(cast(Hash32, value))

hex_str = HexBytes(value).hex()
hex_str = to_hex(value)
if is_hex_address(hex_str):
return self._enrich_value(hex_str, **kwargs)

Expand Down Expand Up @@ -1028,6 +1030,7 @@ def get_abi(_topic: HexStr) -> Optional[LogInputABICollection]:
def enrich_trace(self, trace: TraceAPI, **kwargs) -> TraceAPI:
kwargs["trace"] = trace
if not isinstance(trace, Trace):
# Can only enrich `ape_ethereum.trace.Trace` (or subclass) implementations.
return trace

elif trace._enriched_calltree is not None:
Expand All @@ -1047,11 +1050,8 @@ def enrich_trace(self, trace: TraceAPI, **kwargs) -> TraceAPI:
# Return value was discovered already.
kwargs["return_value"] = return_value

enriched_calltree = self._enrich_calltree(data, **kwargs)

# Cache the result back on the trace.
trace._enriched_calltree = enriched_calltree

trace._enriched_calltree = self._enrich_calltree(data, **kwargs)
return trace

def _enrich_calltree(self, call: dict, **kwargs) -> dict:
Expand Down Expand Up @@ -1080,10 +1080,10 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict:
call["calls"] = [self._enrich_calltree(c, **kwargs) for c in subcalls]

# Figure out the contract.
address = call.pop("address", "")
address: AddressType = call.pop("address", "")
try:
call["contract_id"] = address = kwargs["contract_address"] = str(
self.decode_address(address)
call["contract_id"] = address = kwargs["contract_address"] = self.decode_address(
address
)
except Exception:
# Tx was made with a weird address.
Expand All @@ -1104,25 +1104,25 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict:
else:
# Collapse pre-compile address calls
if 1 <= address_int <= 9:
if len(call.get("calls", [])) == 1:
return call["calls"][0]

return {"contract_id": f"{address_int}", "calls": call["calls"]}
return (
call["calls"][0]
if len(call.get("calls", [])) == 1
else {"contract_id": f"{address_int}", "calls": call["calls"]}
)

depth = call.get("depth", 0)
if depth == 0 and address in self.account_manager:
call["contract_id"] = f"__{self.fee_token_symbol}_transfer__"
else:
call["contract_id"] = self._enrich_contract_id(call["contract_id"], **kwargs)

if not (contract_type := self.chain_manager.contracts.get(address)):
# Without a contract, we can enrich no further.
if not (contract_type := self._get_contract_type_for_enrichment(address, **kwargs)):
# Without a contract type, we can enrich no further.
return call

kwargs["contract_type"] = contract_type
if events := call.get("events"):
call["events"] = self._enrich_trace_events(
events, address=address, contract_type=contract_type
)
call["events"] = self._enrich_trace_events(events, address=address, **kwargs)

method_abi: Optional[Union[MethodABI, ConstructorABI]] = None
if is_create:
Expand All @@ -1131,24 +1131,26 @@ def _enrich_calltree(self, call: dict, **kwargs) -> dict:

elif call["method_id"] != "0x":
method_id_bytes = HexBytes(call["method_id"])
if method_id_bytes in contract_type.methods:

# perf: use try/except instead of __contains__ check.
try:
method_abi = contract_type.methods[method_id_bytes]
except KeyError:
name = call["method_id"]
else:
assert isinstance(method_abi, MethodABI) # For mypy

# Check if method name duplicated. If that is the case, use selector.
times = len([x for x in contract_type.methods if x.name == method_abi.name])
name = (method_abi.name if times == 1 else method_abi.selector) or call["method_id"]
call = self._enrich_calldata(call, method_abi, contract_type, **kwargs)

else:
name = call["method_id"]
call = self._enrich_calldata(call, method_abi, **kwargs)
else:
name = call.get("method_id") or "0x"

call["method_id"] = name

if method_abi:
call = self._enrich_calldata(call, method_abi, contract_type, **kwargs)
call = self._enrich_calldata(call, method_abi, **kwargs)

if kwargs.get("return_value"):
# Return value was separately enriched.
Expand All @@ -1172,10 +1174,12 @@ def _enrich_contract_id(self, address: AddressType, **kwargs) -> str:
elif address == ZERO_ADDRESS:
return "ZERO_ADDRESS"

if not (contract_type := self.chain_manager.contracts.get(address)):
elif not (contract_type := self._get_contract_type_for_enrichment(address, **kwargs)):
# Without a contract type, we can enrich no further.
return address

elif kwargs.get("use_symbol_for_tokens") and "symbol" in contract_type.view_methods:
kwargs["contract_type"] = contract_type
if kwargs.get("use_symbol_for_tokens") and "symbol" in contract_type.view_methods:
# Use token symbol as name
contract = self.chain_manager.contracts.instance_at(
address, contract_type=contract_type
Expand Down Expand Up @@ -1203,17 +1207,18 @@ def _enrich_calldata(
self,
call: dict,
method_abi: Union[MethodABI, ConstructorABI],
contract_type: ContractType,
**kwargs,
) -> dict:
calldata = call["calldata"]
if isinstance(calldata, (str, bytes, int)):
calldata_arg = HexBytes(calldata)
if isinstance(calldata, str):
calldata_arg = to_bytes(hexstr=calldata)
elif isinstance(calldata, bytes):
calldata_arg = calldata
else:
# Not sure if we can get here.
# Mostly for mypy's sake.
# Already enriched.
return call

contract_type = kwargs["contract_type"]
if call.get("call_type") and "CREATE" in call.get("call_type", ""):
# Strip off bytecode
bytecode = (
Expand Down Expand Up @@ -1316,18 +1321,15 @@ def _enrich_trace_events(
self,
events: list[dict],
address: Optional[AddressType] = None,
contract_type: Optional[ContractType] = None,
**kwargs,
) -> list[dict]:
return [
self._enrich_trace_event(e, address=address, contract_type=contract_type)
for e in events
]
return [self._enrich_trace_event(e, address=address, **kwargs) for e in events]

def _enrich_trace_event(
self,
event: dict,
address: Optional[AddressType] = None,
contract_type: Optional[ContractType] = None,
**kwargs,
) -> dict:
if "topics" not in event or len(event["topics"]) < 1:
# Already enriched or wrong.
Expand All @@ -1339,16 +1341,11 @@ def _enrich_trace_event(
# Cannot enrich further w/o an address.
return event

if not contract_type:
try:
contract_type = self.chain_manager.contracts.get(address)
except Exception as err:
logger.debug(f"Error getting contract type during event enrichment: {err}")
return event
if not (contract_type := self._get_contract_type_for_enrichment(address, **kwargs)):
# Without a contract type, we can enrich no further.
return event

if not contract_type:
# Cannot enrich further w/o an contract type.
return event
kwargs["contract_type"] = contract_type

# The selector is always the first topic.
selector = event["topics"][0]
Expand Down Expand Up @@ -1393,6 +1390,17 @@ def _enrich_revert_message(self, call: dict) -> dict:

return call

def _get_contract_type_for_enrichment(
self, address: AddressType, **kwargs
) -> Optional[ContractType]:
if not (contract_type := kwargs.get("contract_type")):
try:
contract_type = self.chain_manager.contracts.get(address)
except Exception as err:
logger.debug(f"Error getting contract type during event enrichment: {err}")

return contract_type

def get_python_types(self, abi_type: ABIType) -> Union[type, Sequence]:
return self._python_type_for_abi_type(abi_type)

Expand Down
9 changes: 8 additions & 1 deletion src/ape_ethereum/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ class Web3Provider(ProviderAPI, ABC):

_supports_debug_trace_call: Optional[bool] = None

_transaction_trace_cache: dict[str, TransactionTrace] = {}

def __new__(cls, *args, **kwargs):
assert_web3_provider_uri_env_var_not_set()

Expand Down Expand Up @@ -440,10 +442,15 @@ def get_storage(
raise # Raise original error

def get_transaction_trace(self, transaction_hash: str, **kwargs) -> TraceAPI:
if transaction_hash in self._transaction_trace_cache:
return self._transaction_trace_cache[transaction_hash]

if "call_trace_approach" not in kwargs:
kwargs["call_trace_approach"] = self.call_trace_approach

return TransactionTrace(transaction_hash=transaction_hash, **kwargs)
trace = TransactionTrace(transaction_hash=transaction_hash, **kwargs)
self._transaction_trace_cache[transaction_hash] = trace
return trace

def send_call(
self,
Expand Down
Loading

0 comments on commit f854ed5

Please sign in to comment.