From 814ef62415dbefa4e2b4c072fb0ed7b2bd8b48be Mon Sep 17 00:00:00 2001 From: Uxio Fuentefria Date: Wed, 13 Jul 2022 13:48:45 +0200 Subject: [PATCH] Limit batch requests - Add new environment variable ETHEREUM_RPC_BATCH_REQUEST_MAX_SIZE to set the limit - Refactor batch requests code --- gnosis/eth/ethereum_client.py | 191 ++++++++++++----------- gnosis/eth/tests/test_ethereum_client.py | 96 ++++++++++-- gnosis/util/__init__.py | 4 + gnosis/util/util.py | 11 ++ 4 files changed, 198 insertions(+), 104 deletions(-) create mode 100644 gnosis/util/__init__.py create mode 100644 gnosis/util/util.py diff --git a/gnosis/eth/ethereum_client.py b/gnosis/eth/ethereum_client.py index 1c87b853e..9dfe10388 100644 --- a/gnosis/eth/ethereum_client.py +++ b/gnosis/eth/ethereum_client.py @@ -59,6 +59,7 @@ fast_to_checksum_address, mk_contract_address, ) +from gnosis.util import chunks from .constants import ( ERC20_721_TRANSFER_TOPIC, @@ -156,28 +157,6 @@ def with_exception_handling(*args, **kwargs): return with_exception_handling -def parse_rpc_result_or_raise( - result: Dict[str, Any], eth_fn: str, arguments: Any -) -> Any: - """ - Responses from RPC should return a dictionary with `result` key and they are always 200 OK - even if errored. 
If not, raise an error - - :param result: - :param eth_fn: RPC function called (for more info if exception is raised) - :param arguments: Arguments for the RPC function (for more info if exception is raised) - :return: `result["result"]` if key exists - :raises: ValueError - """ - - if "result" not in result: - message = f"Problem calling `{eth_fn}` on {arguments}, result={result}" - logger.error(message) - raise ValueError(message) - - return result["result"] - - class EthereumTxSent(NamedTuple): tx_hash: bytes tx: TxParams @@ -222,6 +201,9 @@ def __new__(cls): os.environ.get("ETHEREUM_RPC_SLOW_TIMEOUT", 60) ), retry_count=int(os.environ.get("ETHEREUM_RPC_RETRY_COUNT", 60)), + batch_request_max_size=int( + os.environ.get("ETHEREUM_RPC_BATCH_REQUEST_MAX_SIZE", 500) + ), ) return cls.instance @@ -243,9 +225,10 @@ def batch_call_custom( payloads: Iterable[Dict[str, Any]], raise_exception: bool = True, block_identifier: Optional[BlockIdentifier] = "latest", + batch_size: Optional[int] = None, ) -> List[Optional[Any]]: """ - Do batch requests of multiple contract calls + Do batch requests of multiple contract calls (`eth_call`) :param payloads: Iterable of Dictionaries with at least {'data': '', 'output_type': , 'to': ''}. 
`from` can also be provided and if @@ -253,6 +236,8 @@ def batch_call_custom( :param raise_exception: If False, exception will not be raised if there's any problem and instead `None` will be returned as the value :param block_identifier: `latest` by default + :param batch_size: If `payload` length is bigger than size, it will be split into smaller chunks before + sending to the server :return: List with the ABI decoded return values :raises: ValueError if raise_exception=True """ @@ -283,18 +268,34 @@ def batch_call_custom( } ) - response = self.http_session.post( - self.ethereum_node_url, json=queries, timeout=self.slow_timeout - ) - if not response.ok: - raise ConnectionError( - f"Error connecting to {self.ethereum_node_url}: {response.text}" + batch_size = batch_size or self.ethereum_client.batch_request_max_size + all_results = [] + for chunk in chunks(queries, batch_size): + response = self.http_session.post( + self.ethereum_node_url, json=chunk, timeout=self.slow_timeout ) + if not response.ok: + raise ConnectionError( + f"Error connecting to {self.ethereum_node_url}: {response.text}" + ) + + results = response.json() + + # If there's an error some nodes return a json instead of a list + if isinstance(results, dict) and "error" in results: + logger.error( + "Batch call custom problem with payload=%s, result=%s)", + chunk, + results, + ) + raise ValueError(f"Batch request error: {results}") + + all_results.extend(results) return_values: List[Optional[Any]] = [] errors = [] for payload, result in zip( - payloads, sorted(response.json(), key=lambda x: x["id"]) + payloads, sorted(all_results, key=lambda x: x["id"]) ): if "error" in result: fn_name = payload.get("fn_name", HexBytes(payload["data"]).hex()) @@ -1106,20 +1107,10 @@ def trace_blocks( } for i, block_identifier in enumerate(block_identifiers) ] - response = self.http_session.post( - self.ethereum_node_url, json=payload, timeout=self.slow_timeout - ) - if not response.ok: - message = ( - f"Problem 
calling batch `trace_block` on blocks={block_identifiers} " - f"status_code={response.status_code} result={response.content}" - ) - logger.error(message) - raise ValueError(message) - results = sorted(response.json(), key=lambda x: x["id"]) + + results = self.ethereum_client.raw_batch_request(payload) traces = [] - for block_identifier, result in zip(block_identifiers, results): - raw_tx = parse_rpc_result_or_raise(result, "trace_block", block_identifier) + for raw_tx in results: if raw_tx: try: decoded_traces = self._decode_traces(raw_tx) @@ -1160,20 +1151,9 @@ def trace_transactions( } for i, tx_hash in enumerate(tx_hashes) ] - response = self.http_session.post( - self.ethereum_node_url, json=payload, timeout=self.slow_timeout - ) - if not response.ok: - message = ( - f"Problem calling batch `trace_transaction` on tx_hashes={tx_hashes} " - f"status_code={response.status_code} result={response.content}" - ) - logger.error(message) - raise ValueError(message) - results = sorted(response.json(), key=lambda x: x["id"]) + results = self.ethereum_client.raw_batch_request(payload) traces = [] - for tx_hash, result in zip(tx_hashes, results): - raw_tx = parse_rpc_result_or_raise(result, "trace_transaction", tx_hash) + for raw_tx in results: if raw_tx: try: decoded_traces = self._decode_traces(raw_tx) @@ -1306,6 +1286,7 @@ def __init__( slow_provider_timeout: int = 60, retry_count: int = 3, use_caching_middleware: bool = True, + batch_request_max_size: int = 500, ): """ :param ethereum_node_url: Ethereum RPC uri @@ -1313,6 +1294,7 @@ def __init__( :param slow_provider_timeout: Timeout for slow (tracing, logs...) and custom RPC queries :param retry_count: Retry count for failed requests :param use_caching_middleware: Use web3 simple cache middleware: https://web3py.readthedocs.io/en/stable/middleware.html#web3.middleware.construct_simple_cache_middleware + :param batch_request_max_size: Max size for JSON RPC Batch requests. 
Some providers limit batches to 500 requests """ self.http_session = self._prepare_http_session(retry_count) self.ethereum_node_url: str = ethereum_node_url @@ -1341,10 +1323,13 @@ def __init__( except (IOError, FileNotFoundError): self.w3.middleware_onion.inject(geth_poa_middleware, layer=0) - if use_caching_middleware: + self.use_caching_middleware = use_caching_middleware + if self.use_caching_middleware: self.w3.middleware_onion.add(simple_cache_middleware) self.slow_w3.middleware_onion.add(simple_cache_middleware) + self.batch_request_max_size = batch_request_max_size + def __str__(self): return f"EthereumClient for url={self.ethereum_node_url}" @@ -1366,6 +1351,56 @@ def _prepare_http_session(self, retry_count: int) -> requests.Session: session.mount("https://", adapter) return session + def raw_batch_request( + self, payload: List[Dict[str, Any]], batch_size: Optional[int] = None + ) -> Iterable[Optional[Dict[str, Any]]]: + """ + Perform a raw batch JSON RPC call + + :param payload: Batch request payload.
Make sure all provided `ids` inside the payload are different + :param batch_size: If `payload` length is bigger than size, it will be split into smaller chunks before + sending to the server + :return: + :raises: ValueError + """ + + batch_size = batch_size or self.batch_request_max_size + + all_results = [] + for chunk in chunks(payload, batch_size): + response = self.http_session.post( + self.ethereum_node_url, json=chunk, timeout=self.slow_timeout + ) + + if not response.ok: + logger.error( + "Problem doing raw batch request with payload=%s status_code=%d result=%s", + chunk, + response.status_code, + response.content, + ) + raise ValueError(f"Batch request error: {response.content}") + + results = response.json() + + # If there's an error some nodes return a json instead of a list + if isinstance(results, dict) and "error" in results: + logger.error( + "Batch request problem with payload=%s, result=%s)", chunk, results + ) + raise ValueError(f"Batch request error: {results}") + + all_results.extend(results) + + # Nodes like Erigon send back results out of order + for query, result in zip(payload, sorted(all_results, key=lambda x: x["id"])): + if "result" not in result: + message = f"Problem with payload=`{query}` result={result}" + logger.error(message) + raise ValueError(message) + + yield result["result"] + @property def current_block_number(self): return self.w3.eth.block_number @@ -1679,19 +1714,11 @@ def get_transactions(self, tx_hashes: List[EthereumHash]) -> List[Optional[TxDat } for i, tx_hash in enumerate(tx_hashes) ] - results = self.http_session.post( - self.ethereum_node_url, json=payload, timeout=self.slow_timeout - ).json() - txs = [] - for tx_hash, result in zip(tx_hashes, sorted(results, key=lambda x: x["id"])): - raw_tx = parse_rpc_result_or_raise( - result, "eth_getTransactionByHash", tx_hash - ) - if raw_tx: - txs.append(transaction_result_formatter(raw_tx)) - else: - txs.append(None) - return txs + results = self.raw_batch_request(payload) 
+ return [ + transaction_result_formatter(raw_tx) if raw_tx else None + for raw_tx in results + ] def get_transaction_receipt( self, tx_hash: EthereumHash, timeout=None @@ -1730,14 +1757,9 @@ def get_transaction_receipts( } for i, tx_hash in enumerate(tx_hashes) ] - results = self.http_session.post( - self.ethereum_node_url, json=payload, timeout=self.slow_timeout - ).json() + results = self.raw_batch_request(payload) receipts = [] - for tx_hash, result in zip(tx_hashes, sorted(results, key=lambda x: x["id"])): - tx_receipt = parse_rpc_result_or_raise( - result, "eth_getTransactionReceipt", tx_hash - ) + for tx_receipt in results: # Parity returns tx_receipt even is tx is still pending, so we check `blockNumber` is not None if tx_receipt and tx_receipt["blockNumber"] is not None: receipts.append(receipt_formatter(tx_receipt)) @@ -1784,20 +1806,9 @@ def get_blocks( } for i, block_identifier in enumerate(block_identifiers) ] - results = self.http_session.post( - self.ethereum_node_url, json=payload, timeout=self.slow_timeout - ).json() + results = self.raw_batch_request(payload) blocks = [] - for block_identifier, result in zip( - block_identifiers, sorted(results, key=lambda x: x["id"]) - ): - raw_block = parse_rpc_result_or_raise( - result, - "eth_getBlockByNumber" - if isinstance(block_identifier, int) - else "eth_getBlockByHash", - block_identifier, - ) + for raw_block in results: if raw_block: if "extraData" in raw_block: del raw_block[ diff --git a/gnosis/eth/tests/test_ethereum_client.py b/gnosis/eth/tests/test_ethereum_client.py index d86b87c7f..16f4ed750 100644 --- a/gnosis/eth/tests/test_ethereum_client.py +++ b/gnosis/eth/tests/test_ethereum_client.py @@ -5,6 +5,7 @@ from django.test import TestCase import pytest +import requests from eth_account import Account from hexbytes import HexBytes from web3.eth import Eth @@ -21,7 +22,6 @@ InvalidNonce, ParityManager, SenderAccountNotFoundInNode, - parse_rpc_result_or_raise, ) from ..exceptions import 
BatchCallException, ChainIdIsRequired, InvalidERC20Info from ..utils import fast_to_checksum_address, get_eth_address_with_key @@ -764,6 +764,87 @@ def test_trace_filter(self): from_address=Account.create().address ) + @mock.patch.object(requests.Response, "json") + def test_raw_batch_request(self, session_post_mock: MagicMock): + # Ankr + session_post_mock.return_value = { + "jsonrpc": "2.0", + "error": { + "code": 0, + "message": "you can't send more than 1000 requests in a batch", + }, + "id": None, + } + payload = [ + { + "id": 0, + "jsonrpc": "2.0", + "method": "eth_getTransactionByHash", + "params": "0x5afea3f32970a22f4e63a815c174fa989e3b659826e5f52496662bb256baf3b2", + }, + { + "id": 1, + "jsonrpc": "2.0", + "method": "eth_getTransactionByHash", + "params": "0x12ab96991ddd4ac55c28ace4e7b59bc64c514b55747e1b0ea3f5b269fbb39f6b", + }, + ] + with self.assertRaisesMessage( + ValueError, + "Batch request error: {'jsonrpc': '2.0', 'error': {'code': 0, 'message': \"you can't send more than 1000 requests in a batch\"}, 'id': None}", + ): + list(self.ethereum_client.raw_batch_request(payload)) + + # Nodereal + session_post_mock.return_value = [ + { + "jsonrpc": "2.0", + "id": None, + "error": { + "code": -32000, + "message": "batch length does not support more than 500", + }, + } + ] + + with self.assertRaisesMessage( + ValueError, + "Problem with payload=`{'id': 0, 'jsonrpc': '2.0', 'method': 'eth_getTransactionByHash', 'params': '0x5afea3f32970a22f4e63a815c174fa989e3b659826e5f52496662bb256baf3b2'}` result={'jsonrpc': '2.0', 'id': None, 'error': {'code': -32000, 'message': 'batch length does not support more than 500'}}", + ): + list(self.ethereum_client.raw_batch_request(payload)) + + # Test batching chunks + session_post_mock.return_value = [ + { + "jsonrpc": "2.0", + "id": 0, + "result": { + "blockHash": "0x13e9e3262d9cf1c4d07d7324d95e6bddf27f07d7bddbdcc7df4e4ffb42a2e921", + "blockNumber": "0xa81a59", + "from": "0x136ec956eb32364f5016f3f84f56dbff59c6ead5", + "gas": 
"0x493e0", + "gasPrice": "0x3b9aca0e", + "maxPriorityFeePerGas": "0x3b9aca00", + "maxFeePerGas": "0x3b9aca1e", + "hash": "0x92898917d7bd7a51d40a903f4c55ae988cbac7c661c3e271c54bbda21415501b", + "input": "0x8ea59e1de547ab59caab9379b4b307450a29a0137c7dbbfc7b18c3cd6179d927efbab9ee", + "nonce": "0x1242f", + "to": "0x7e22c795325e76306920293f62a02f353536280b", + "transactionIndex": "0x1e", + "value": "0x0", + "type": "0x2", + "accessList": [], + "chainId": "0x4", + "v": "0x1", + "r": "0x5aaaa2a32326ca4add9a602ffba968c3d991219fde93a2531eb7a82fc61919ed", + "s": "0x1c4bff2abcc671ad2a1dd09f92a9720ac595138c666e59153711056811c1c95c", + }, + } + ] + + results = list(self.ethereum_client.raw_batch_request(payload, batch_size=1)) + self.assertEqual(len(results), 2) + class TestEthereumNetwork(EthereumTestCaseMixin, TestCase): def test_unknown_ethereum_network_name(self): @@ -777,19 +858,6 @@ def test_rinkeby_ethereum_network_name(self): class TestEthereumClient(EthereumTestCaseMixin, TestCase): - def test_parse_rpc_result_or_raise(self): - self.assertEqual(parse_rpc_result_or_raise({"result": "test"}, "", ""), "test") - - with self.assertRaisesMessage( - ValueError, - "Problem calling `trace_transaction` on 0x230b7f018951818c2a4545654d43a086ed2a3ed7c5b7c03990f4ac22ffae3840, result={'error': 'Something bad happened'}", - ): - parse_rpc_result_or_raise( - {"error": "Something bad happened"}, - "trace_transaction", - "0x230b7f018951818c2a4545654d43a086ed2a3ed7c5b7c03990f4ac22ffae3840", - ) - def test_ethereum_client_str(self): self.assertTrue(str(self.ethereum_client)) diff --git a/gnosis/util/__init__.py b/gnosis/util/__init__.py new file mode 100644 index 000000000..80507c51c --- /dev/null +++ b/gnosis/util/__init__.py @@ -0,0 +1,4 @@ +# flake8: noqa F401 +from .util import chunks + +__all__ = ["chunks"] diff --git a/gnosis/util/util.py b/gnosis/util/util.py new file mode 100644 index 000000000..063dbdd13 --- /dev/null +++ b/gnosis/util/util.py @@ -0,0 +1,11 @@ +from typing import 
Any, Iterable, List + + +def chunks(elements: List[Any], n: int) -> Iterable[Any]: + """ + :param elements: List + :param n: Number of elements per chunk + :return: Yield successive n-sized chunks from `elements` + """ + for i in range(0, len(elements), n): + yield elements[i : i + n]