Skip to content

Commit

Permalink
Limit batch requests
Browse files Browse the repository at this point in the history
- Add new environment variable ETHEREUM_RPC_BATCH_REQUEST_MAX_SIZE to set the limit
- Refactor batch requests code
  • Loading branch information
Uxio0 committed Jul 13, 2022
1 parent 9ca399f commit 814ef62
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 104 deletions.
191 changes: 101 additions & 90 deletions gnosis/eth/ethereum_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
fast_to_checksum_address,
mk_contract_address,
)
from gnosis.util import chunks

from .constants import (
ERC20_721_TRANSFER_TOPIC,
Expand Down Expand Up @@ -156,28 +157,6 @@ def with_exception_handling(*args, **kwargs):
return with_exception_handling


def parse_rpc_result_or_raise(
    result: Dict[str, Any], eth_fn: str, arguments: Any
) -> Any:
    """
    Responses from RPC should return a dictionary with `result` key and they are always 200 OK
    even if errored. If not, raise an error

    :param result: Decoded JSON RPC response for a single call
    :param eth_fn: RPC function called (for more info if exception is raised)
    :param arguments: Arguments for the RPC function (for more info if exception is raised)
    :return: `result["result"]` if key exists
    :raises: ValueError if `result` is not a dictionary or it does not have a `result` key
    """

    # A malformed response (`None`, a list, a string...) must also be reported as `ValueError`
    # instead of leaking a `TypeError` from the membership test or from the subscript below
    if not isinstance(result, dict) or "result" not in result:
        message = f"Problem calling `{eth_fn}` on {arguments}, result={result}"
        logger.error(message)
        raise ValueError(message)

    return result["result"]


class EthereumTxSent(NamedTuple):
tx_hash: bytes
tx: TxParams
Expand Down Expand Up @@ -222,6 +201,9 @@ def __new__(cls):
os.environ.get("ETHEREUM_RPC_SLOW_TIMEOUT", 60)
),
retry_count=int(os.environ.get("ETHEREUM_RPC_RETRY_COUNT", 60)),
batch_request_max_size=int(
os.environ.get("ETHEREUM_RPC_BATCH_REQUEST_MAX_SIZE", 500)
),
)
return cls.instance

Expand All @@ -243,16 +225,19 @@ def batch_call_custom(
payloads: Iterable[Dict[str, Any]],
raise_exception: bool = True,
block_identifier: Optional[BlockIdentifier] = "latest",
batch_size: Optional[int] = None,
) -> List[Optional[Any]]:
"""
Do batch requests of multiple contract calls
Do batch requests of multiple contract calls (`eth_call`)
:param payloads: Iterable of Dictionaries with at least {'data': '<hex-string>',
'output_type': <solidity-output-type>, 'to': '<checksummed-address>'}. `from` can also be provided and if
`fn_name` is provided it will be used for debugging purposes
:param raise_exception: If False, exception will not be raised if there's any problem and instead `None` will
be returned as the value
:param block_identifier: `latest` by default
:param batch_size: If `payload` length is bigger than size, it will be split into smaller chunks before
sending to the server
:return: List with the ABI decoded return values
:raises: ValueError if raise_exception=True
"""
Expand Down Expand Up @@ -283,18 +268,34 @@ def batch_call_custom(
}
)

response = self.http_session.post(
self.ethereum_node_url, json=queries, timeout=self.slow_timeout
)
if not response.ok:
raise ConnectionError(
f"Error connecting to {self.ethereum_node_url}: {response.text}"
batch_size = batch_size or self.ethereum_client.batch_request_max_size
all_results = []
for chunk in chunks(queries, batch_size):
response = self.http_session.post(
self.ethereum_node_url, json=chunk, timeout=self.slow_timeout
)
if not response.ok:
raise ConnectionError(
f"Error connecting to {self.ethereum_node_url}: {response.text}"
)

results = response.json()

# If there's an error some nodes return a json instead of a list
if isinstance(results, dict) and "error" in results:
logger.error(
"Batch call custom problem with payload=%s, result=%s)",
chunk,
results,
)
raise ValueError(f"Batch request error: {results}")

all_results.extend(results)

return_values: List[Optional[Any]] = []
errors = []
for payload, result in zip(
payloads, sorted(response.json(), key=lambda x: x["id"])
payloads, sorted(all_results, key=lambda x: x["id"])
):
if "error" in result:
fn_name = payload.get("fn_name", HexBytes(payload["data"]).hex())
Expand Down Expand Up @@ -1106,20 +1107,10 @@ def trace_blocks(
}
for i, block_identifier in enumerate(block_identifiers)
]
response = self.http_session.post(
self.ethereum_node_url, json=payload, timeout=self.slow_timeout
)
if not response.ok:
message = (
f"Problem calling batch `trace_block` on blocks={block_identifiers} "
f"status_code={response.status_code} result={response.content}"
)
logger.error(message)
raise ValueError(message)
results = sorted(response.json(), key=lambda x: x["id"])

results = self.ethereum_client.raw_batch_request(payload)
traces = []
for block_identifier, result in zip(block_identifiers, results):
raw_tx = parse_rpc_result_or_raise(result, "trace_block", block_identifier)
for raw_tx in results:
if raw_tx:
try:
decoded_traces = self._decode_traces(raw_tx)
Expand Down Expand Up @@ -1160,20 +1151,9 @@ def trace_transactions(
}
for i, tx_hash in enumerate(tx_hashes)
]
response = self.http_session.post(
self.ethereum_node_url, json=payload, timeout=self.slow_timeout
)
if not response.ok:
message = (
f"Problem calling batch `trace_transaction` on tx_hashes={tx_hashes} "
f"status_code={response.status_code} result={response.content}"
)
logger.error(message)
raise ValueError(message)
results = sorted(response.json(), key=lambda x: x["id"])
results = self.ethereum_client.raw_batch_request(payload)
traces = []
for tx_hash, result in zip(tx_hashes, results):
raw_tx = parse_rpc_result_or_raise(result, "trace_transaction", tx_hash)
for raw_tx in results:
if raw_tx:
try:
decoded_traces = self._decode_traces(raw_tx)
Expand Down Expand Up @@ -1306,13 +1286,15 @@ def __init__(
slow_provider_timeout: int = 60,
retry_count: int = 3,
use_caching_middleware: bool = True,
batch_request_max_size: int = 500,
):
"""
:param ethereum_node_url: Ethereum RPC uri
:param provider_timeout: Timeout for regular RPC queries
:param slow_provider_timeout: Timeout for slow (tracing, logs...) and custom RPC queries
:param retry_count: Retry count for failed requests
:param use_caching_middleware: Use web3 simple cache middleware: https://web3py.readthedocs.io/en/stable/middleware.html#web3.middleware.construct_simple_cache_middleware
:param batch_request_max_size: Max size for JSON RPC Batch requests. Some providers have a limitation on 500
"""
self.http_session = self._prepare_http_session(retry_count)
self.ethereum_node_url: str = ethereum_node_url
Expand Down Expand Up @@ -1341,10 +1323,13 @@ def __init__(
except (IOError, FileNotFoundError):
self.w3.middleware_onion.inject(geth_poa_middleware, layer=0)

if use_caching_middleware:
self.use_caching_middleware = use_caching_middleware
if self.use_caching_middleware:
self.w3.middleware_onion.add(simple_cache_middleware)
self.slow_w3.middleware_onion.add(simple_cache_middleware)

self.batch_request_max_size = batch_request_max_size

def __str__(self):
    """Return a readable label for this client that includes the configured node url"""
    node_url = self.ethereum_node_url
    return f"EthereumClient for url={node_url}"

Expand All @@ -1366,6 +1351,56 @@ def _prepare_http_session(self, retry_count: int) -> requests.Session:
session.mount("https://", adapter)
return session

def raw_batch_request(
    self, payload: List[Dict[str, Any]], batch_size: Optional[int] = None
) -> Iterable[Optional[Dict[str, Any]]]:
    """
    Perform a raw batch JSON RPC call. This is a generator, so requests are not sent (and errors are
    not raised) until the returned iterable is consumed

    :param payload: Batch request payload. Make sure all provided `ids` inside the payload are different
    :param batch_size: If `payload` length is bigger than size, it will be split into smaller chunks before
        sending to the server. If not provided, `self.batch_request_max_size` is used
    :return: Iterable with the `result` field of every response, with the responses sorted by `id`
    :raises: ValueError if the HTTP response is not ok, if the node does not return a list of responses,
        if the number of responses does not match the number of queries or if any response does not
        have a `result` key
    """

    batch_size = batch_size or self.batch_request_max_size

    all_results = []
    for chunk in chunks(payload, batch_size):
        response = self.http_session.post(
            self.ethereum_node_url, json=chunk, timeout=self.slow_timeout
        )

        if not response.ok:
            logger.error(
                "Problem doing raw batch request with payload=%s status_code=%d result=%s",
                chunk,
                response.status_code,
                response.content,
            )
            raise ValueError(f"Batch request error: {response.content}")

        results = response.json()

        # If there's an error some nodes return a json object instead of a list. Anything that is
        # not a list cannot be matched with the queries, so it's rejected here instead of failing
        # later with an unrelated `TypeError`
        if not isinstance(results, list):
            logger.error(
                "Batch request problem with payload=%s, result=%s", chunk, results
            )
            raise ValueError(f"Batch request error: {results}")

        all_results.extend(results)

    # `zip` would silently drop the trailing queries if the node skipped some responses
    if len(all_results) != len(payload):
        message = (
            f"Batch request returned {len(all_results)} results "
            f"for {len(payload)} queries"
        )
        logger.error(message)
        raise ValueError(message)

    # Nodes like Erigon send back results out of order
    for query, result in zip(payload, sorted(all_results, key=lambda x: x["id"])):
        if "result" not in result:
            message = f"Problem with payload=`{query}` result={result}"
            logger.error(message)
            raise ValueError(message)

        yield result["result"]

@property
def current_block_number(self):
return self.w3.eth.block_number
Expand Down Expand Up @@ -1679,19 +1714,11 @@ def get_transactions(self, tx_hashes: List[EthereumHash]) -> List[Optional[TxDat
}
for i, tx_hash in enumerate(tx_hashes)
]
results = self.http_session.post(
self.ethereum_node_url, json=payload, timeout=self.slow_timeout
).json()
txs = []
for tx_hash, result in zip(tx_hashes, sorted(results, key=lambda x: x["id"])):
raw_tx = parse_rpc_result_or_raise(
result, "eth_getTransactionByHash", tx_hash
)
if raw_tx:
txs.append(transaction_result_formatter(raw_tx))
else:
txs.append(None)
return txs
results = self.raw_batch_request(payload)
return [
transaction_result_formatter(raw_tx) if raw_tx else None
for raw_tx in results
]

def get_transaction_receipt(
self, tx_hash: EthereumHash, timeout=None
Expand Down Expand Up @@ -1730,14 +1757,9 @@ def get_transaction_receipts(
}
for i, tx_hash in enumerate(tx_hashes)
]
results = self.http_session.post(
self.ethereum_node_url, json=payload, timeout=self.slow_timeout
).json()
results = self.raw_batch_request(payload)
receipts = []
for tx_hash, result in zip(tx_hashes, sorted(results, key=lambda x: x["id"])):
tx_receipt = parse_rpc_result_or_raise(
result, "eth_getTransactionReceipt", tx_hash
)
for tx_receipt in results:
# Parity returns tx_receipt even is tx is still pending, so we check `blockNumber` is not None
if tx_receipt and tx_receipt["blockNumber"] is not None:
receipts.append(receipt_formatter(tx_receipt))
Expand Down Expand Up @@ -1784,20 +1806,9 @@ def get_blocks(
}
for i, block_identifier in enumerate(block_identifiers)
]
results = self.http_session.post(
self.ethereum_node_url, json=payload, timeout=self.slow_timeout
).json()
results = self.raw_batch_request(payload)
blocks = []
for block_identifier, result in zip(
block_identifiers, sorted(results, key=lambda x: x["id"])
):
raw_block = parse_rpc_result_or_raise(
result,
"eth_getBlockByNumber"
if isinstance(block_identifier, int)
else "eth_getBlockByHash",
block_identifier,
)
for raw_block in results:
if raw_block:
if "extraData" in raw_block:
del raw_block[
Expand Down
Loading

0 comments on commit 814ef62

Please sign in to comment.