Skip to content

Commit

Permalink
feat: use a_sync.SmartProcessingQueue instead of semaphores (#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
BobTheBuidler authored Apr 24, 2024
1 parent 320af2a commit 6ff6e4a
Show file tree
Hide file tree
Showing 12 changed files with 117 additions and 53 deletions.
2 changes: 1 addition & 1 deletion dank_mids/ENVIRONMENT_VARIABLES.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@

# type: ignore [attr-defined]
import logging

import a_sync
Expand Down
8 changes: 4 additions & 4 deletions dank_mids/_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
ExceedsMaxBatchSize, PayloadTooLarge,
ResponseNotReady, internal_err_types)
from dank_mids._uid import _AlertingRLock
from dank_mids.helpers import _decode, _session
from dank_mids.helpers import _codec, _session
from dank_mids.helpers._helpers import set_done
from dank_mids.types import (BatchId, BlockId, JSONRPCBatchResponse,
JsonrpcParams, PartialRequest, PartialResponse,
Expand Down Expand Up @@ -633,12 +633,12 @@ def data(self) -> bytes:
if not self.calls:
raise EmptyBatch(f"batch {self.uid} is empty and should not be processed.")
try:
return msgspec.json.encode([call.request for call in self.calls])
return _codec.encode([call.request for call in self.calls])
except TypeError:
# If we can't encode one of the calls, lets figure out which one and pass some useful info downstream
for call in self.calls:
try:
msgspec.json.encode(call.request)
_codec.encode(call.request)
except TypeError as e:
raise TypeError(e, call.request) from None
raise
Expand Down Expand Up @@ -723,7 +723,7 @@ async def get_response(self) -> None:
async def post(self) -> List[RawResponse]:
"this function raises `BadResponse` if a successful 'error' response was received from the rpc"
try:
response: JSONRPCBatchResponse = await _session.post(self.controller.endpoint, data=self.data, loads=_decode.jsonrpc_batch)
response: JSONRPCBatchResponse = await _session.post(self.controller.endpoint, data=self.data, loads=_codec.decode_jsonrpc_batch)
except ClientResponseError as e:
if e.message == "Payload Too Large":
logger.warning("Payload Too Large")
Expand Down
10 changes: 1 addition & 9 deletions dank_mids/brownie_patch/_abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any

from brownie.convert.utils import build_function_selector, build_function_signature
from web3.datastructures import AttributeDict


@functools.lru_cache(maxsize=None)
Expand All @@ -13,11 +12,4 @@ class FunctionABI:
def __init__(self, **abi: Any):
self.abi = abi
self.input_sig = build_function_signature(abi)
self.signature = build_function_selector(abi)

def _make_hashable(obj: Any) -> Any:
if isinstance(obj, (list, tuple)):
return tuple((_make_hashable(o) for o in obj))
elif isinstance(obj, dict):
return AttributeDict({k: _make_hashable(v) for k, v in obj.items()})
return obj
self.signature = build_function_selector(abi)
3 changes: 2 additions & 1 deletion dank_mids/brownie_patch/_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from web3 import Web3

from dank_mids import ENVIRONMENT_VARIABLES as ENVS
from dank_mids.brownie_patch._abi import FunctionABI, _make_hashable
from dank_mids.brownie_patch._abi import FunctionABI
from dank_mids.helpers._helpers import _make_hashable

_EVMType = TypeVar("_EVMType")

Expand Down
40 changes: 28 additions & 12 deletions dank_mids/controller.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import logging
from collections import defaultdict
from contextlib import suppress
from functools import lru_cache
from importlib.metadata import version
from typing import Any, DefaultDict, List, Literal, Optional
Expand All @@ -22,8 +23,8 @@
from dank_mids._exceptions import DankMidsInternalError
from dank_mids._requests import JSONRPCBatch, Multicall, RPCRequest, eth_call
from dank_mids._uid import UIDGenerator, _AlertingRLock
from dank_mids.helpers import _decode, _session
from dank_mids.semaphores import _MethodSemaphores
from dank_mids.helpers import _codec, _helpers, _session
from dank_mids.semaphores import _MethodQueues, _MethodSemaphores
from dank_mids.types import (BlockId, ChainId, PartialRequest, RawResponse,
Request)

Expand Down Expand Up @@ -108,6 +109,8 @@ def __init__(self, w3: Web3) -> None:
self.no_multicall.add(self.mc3.address)

self.method_semaphores = _MethodSemaphores(self)
# semaphores soon to be deprecated for smart queue
self.method_queues = _MethodQueues(self)
self.batcher = NotSoBrightBatcher()
self.batcher.step = ENVS.MAX_MULTICALL_SIZE

Expand All @@ -124,17 +127,30 @@ def __repr__(self) -> str:
return f"<DankMiddlewareController instance={self._instance} chain={self.chain_id} endpoint={self.endpoint}>"

async def __call__(self, method: RPCEndpoint, params: Any) -> RPCResponse:
call_semaphore = self.method_semaphores[method][params[1]] if method == "eth_call" else self.method_semaphores[method]
async with call_semaphore:
logger.debug(f'making {self.request_type.__name__} {method} with params {params}')
call = eth_call(self, params) if method == "eth_call" and params[0]["to"] not in self.no_multicall else RPCRequest(self, method, params)
return await call
with suppress(KeyError):
# some methods go thru a SmartProcessingQueue, we try this first
try:
queue = self.method_queues[method]
return await queue(self, method, params)
except TypeError as e:
if "unhashable type" in str(e):
return await queue(self, method, _helpers._make_hashable(params))
raise e

# eth_call go thru a specialized Semaphore and other methods pass thru unblocked
logger.debug(f'making {self.request_type.__name__} {method} with params {params}')
if method != "eth_call":
return await RPCRequest(self, method, params)
async with self.method_semaphores[method][params[1]]:
if params[0]["to"] not in self.no_multicall:
return await eth_call(self, params)
return await RPCRequest(self, method, params)

@eth_retry.auto_retry
async def make_request(self, method: str, params: List[Any], request_id: Optional[int] = None) -> RawResponse:
request = self.request_type(method=method, params=params, id=request_id or self.call_uid.next)
try:
return await _session.post(self.endpoint, data=request, loads=_decode.raw)
return await _session.post(self.endpoint, data=request, loads=_codec.decode_raw)
except Exception as e:
if ENVS.DEBUG:
_debugging.failures.record(self.chain_id, e, "eth_call" if method == "eth_call" else "RPCRequest", "unknown", "unknown", request.data)
Expand All @@ -155,8 +171,8 @@ async def execute_batch(self) -> None:
@property
def queue_is_full(self) -> bool:
with self.pools_closed_lock:
if ENVS.OPERATION_MODE.infura:
return sum(len(call) for call in self.pending_rpc_calls) >= ENVS.MAX_JSONRPC_BATCH_SIZE
if ENVS.OPERATION_MODE.infura: # type: ignore [attr-defined]
return sum(len(call) for call in self.pending_rpc_calls) >= ENVS.MAX_JSONRPC_BATCH_SIZE # type: ignore [attr-defined]
eth_calls = sum(len(calls) for calls in self.pending_eth_calls.values())
other_calls = sum(len(call) for call in self.pending_rpc_calls)
return eth_calls + other_calls >= self.batcher.step
Expand Down Expand Up @@ -199,9 +215,9 @@ def set_multicall_size_limit(self, new_limit: int) -> None:
logger.info("new multicall size limit %s is not lower than existing limit %s", new_limit, existing_limit)

def set_batch_size_limit(self, new_limit: int) -> None:
existing_limit = ENVS.MAX_JSONRPC_BATCH_SIZE
existing_limit = ENVS.MAX_JSONRPC_BATCH_SIZE # type: ignore [attr-defined]
if new_limit < existing_limit:
ENVS.MAX_JSONRPC_BATCH_SIZE = new_limit
ENVS.MAX_JSONRPC_BATCH_SIZE = new_limit # type: ignore [attr-defined]
logger.warning("jsonrpc batch size limit reduced from %s to %s", existing_limit, new_limit)
else:
logger.info("new jsonrpc batch size limit %s is not lower than existing limit %s", new_limit, int(existing_limit))
Expand Down
24 changes: 24 additions & 0 deletions dank_mids/helpers/_codec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

from msgspec import Raw, json

from dank_mids.types import (Any, Callable, JSONRPCBatchResponseRaw, List, PartialResponse,
RawResponse, Union, nested_dict_of_stuff, _encode_hook)


decode_raw = lambda data: RawResponse(json.decode(data, type=Raw))
"""Decode json-encoded bytes into a `msgspec.Raw` object."""

decode_nested_dict = lambda data: json.decode(data, type=nested_dict_of_stuff)
"""Decode json-encoded bytes into a nested dictionary."""

def decode_jsonrpc_batch(data: bytes) -> Union[PartialResponse, List[RawResponse]]:
"""Decode json-encoded bytes into a list of response structs, or a single error response struct if applicable."""
decoded = json.decode(data, type=JSONRPCBatchResponseRaw)
return decoded if isinstance(decoded, PartialResponse) else _map_raw(decoded)

_map_raw: Callable[[List[Raw]], List[RawResponse]] = lambda decoded: list(map(RawResponse, decoded))
"""Converts a list of `msgspec.Raw` objects into a list of `RawResponse` objects."""

def encode(obj: Any) -> bytes:
"""Encode an object to json-encoded bytes."""
return json.encode(obj, enc_hook=_encode_hook)
11 changes: 0 additions & 11 deletions dank_mids/helpers/_decode.py

This file was deleted.

9 changes: 9 additions & 0 deletions dank_mids/helpers/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from multicall.utils import get_async_w3
from typing_extensions import ParamSpec
from web3 import Web3
from web3.datastructures import AttributeDict
from web3._utils.rpc_abi import RPC
from web3.eth import AsyncEth
from web3.providers.async_base import AsyncBaseProvider
Expand Down Expand Up @@ -172,3 +173,11 @@ def _format_response(
return _format_response("error", error_formatters[method])
else:
return response


def _make_hashable(obj: Any) -> Any:
if isinstance(obj, (list, tuple)):
return tuple((_make_hashable(o) for o in obj))
elif isinstance(obj, dict):
return AttributeDict({k: _make_hashable(v) for k, v in obj.items()})
return obj
4 changes: 2 additions & 2 deletions dank_mids/helpers/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from async_lru import alru_cache

from dank_mids import ENVIRONMENT_VARIABLES
from dank_mids.helpers import _decode
from dank_mids.helpers import _codec
from dank_mids.types import JSONRPCBatchResponse, PartialRequest, RawResponse

logger = logging.getLogger("dank_mids.session")
Expand Down Expand Up @@ -91,7 +91,7 @@ async def post(self, endpoint: str, *args, loads: Optional[JSONDecoder] = None,
# Process input arguments.
if isinstance(kwargs.get('data'), PartialRequest):
logger.debug("making request for %s", kwargs['data'])
kwargs['data'] = msgspec.json.encode(kwargs['data'])
kwargs['data'] = _codec.encode(kwargs['data'])
logger.debug("making request with (args, kwargs): (%s %s)", tuple(endpoint, *args), kwargs)

# Try the request until success or 5 failures.
Expand Down
23 changes: 22 additions & 1 deletion dank_mids/semaphores.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@

import functools
from decimal import Decimal
from typing import TYPE_CHECKING, Literal, Optional, Union

import a_sync
from a_sync.primitives import DummySemaphore, ThreadsafeSemaphore
from a_sync.primitives.locks.prio_semaphore import (
_AbstractPrioritySemaphore, _PrioritySemaphoreContextManager)
Expand Down Expand Up @@ -33,6 +35,7 @@ def __getitem__(self, block: Union[int, str, Literal["latest", None]]) -> "_Bloc

class _MethodSemaphores:
def __init__(self, controller: "DankMiddlewareController") -> None:
# TODO: refactor this out, just use BlockSemaphore for eth_call and SmartProcessingQueue to limit other methods
from dank_mids import ENVIRONMENT_VARIABLES
self.controller = controller
self.method_semaphores = {
Expand All @@ -41,6 +44,24 @@ def __init__(self, controller: "DankMiddlewareController") -> None:
}
self.keys = self.method_semaphores.keys()
self.dummy = DummySemaphore()

@functools.lru_cache(maxsize=None)
def __getitem__(self, method: RPCEndpoint) -> Union[ThreadsafeSemaphore, DummySemaphore]:
return next((self.method_semaphores[key] for key in self.keys if key in method), self.dummy)

class _MethodQueues:
def __init__(self, controller: "DankMiddlewareController") -> None:
from dank_mids import ENVIRONMENT_VARIABLES
from dank_mids._requests import RPCRequest
self.controller = controller
self.method_queues = {
method: a_sync.SmartProcessingQueue(RPCRequest, num_workers=sem._value, name=f"{method} {controller}")
for method, sem in ENVIRONMENT_VARIABLES.method_semaphores.items()
if method != "eth_call"
}
self.keys = self.method_queues.keys()
@functools.lru_cache(maxsize=None)
def __getitem__(self, method: RPCEndpoint) -> a_sync.SmartProcessingQueue:
try:
return next(self.method_queues[key] for key in self.keys if key in method)
except StopIteration:
raise KeyError(method) from None
34 changes: 23 additions & 11 deletions dank_mids/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from time import time
from typing import (TYPE_CHECKING, Any, Callable, Coroutine, DefaultDict, Dict,
List, Literal, NewType, Optional, TypedDict, TypeVar,
List, Literal, NewType, Optional, Set, TypedDict, TypeVar,
Union, overload)

import msgspec
Expand Down Expand Up @@ -37,6 +37,7 @@ class _DictStruct(msgspec.Struct):
def __getitem__(self, attr: str) -> Any:
return getattr(self, attr)
def to_dict(self) -> Dict[str, Any]:
"""Returns a complete dictionary representation of this ``Struct``'s attributes and values."""
data = {}
for field in self.__struct_fields__:
attr = getattr(self, field)
Expand Down Expand Up @@ -64,8 +65,8 @@ class Error(_DictStruct):
data: Optional[Any] = ''

# some devving tools that will go away eventually
_dict_responses = set()
_str_responses = set()
_dict_responses: Set[str] = set()
_str_responses: Set[str] = set()

# TODO: use the types from snek
Log = Dict[str, Union[bool, str, None, List[str]]]
Expand All @@ -91,24 +92,28 @@ class Error(_DictStruct):

class PartialResponse(_DictStruct):
result: msgspec.Raw = None # type: ignore
"If the rpc response contains a 'result' field, it is set here"
error: Optional[Error] = None
"If the rpc response contains an 'error' field, it is set here"

@property
def exception(self) -> BadResponse:
def exception(self) -> Exception:
"If the rpc response contains an 'error' field, returns a specialized exception for the specified rpc error."
if self.error is None:
raise AttributeError(f"{self} did not error.")
return (
PayloadTooLarge(self) if self.payload_too_large
else ExceedsMaxBatchSize(self) if re.search(r'batch limit (\d+) exceeded', self.error.message)
else TypeError(self.error.message, "You're probably passing what should be an integer type as a string type. The usual culprit is a block number.") if self.error.message == 'invalid argument 1: hex string without 0x prefix'
else TypeError(self.error.message, "DANKMIDS NOTE: You're probably passing what should be an integer type as a string type. The usual culprit is a block number.") if self.error.message == 'invalid argument 1: hex string without 0x prefix'
else BadResponse(self)
)

@property
def payload_too_large(self) -> bool:
return any(err in self.error.message for err in constants.TOO_MUCH_DATA_ERRS)
return any(err in self.error.message for err in constants.TOO_MUCH_DATA_ERRS) # type: ignore [union-attr]

def to_dict(self, method: Optional[str] = None) -> Dict[str, Any]:
def to_dict(self, method: Optional[RPCEndpoint] = None) -> Dict[str, Any]:
"""Returns a complete dictionary representation of this response ``Struct``."""
data = {}
for field in self.__struct_fields__:
attr = getattr(self, field)
Expand All @@ -121,7 +126,7 @@ def to_dict(self, method: Optional[str] = None) -> Dict[str, Any]:
data[field] = AttributeDict(attr) if isinstance(attr, dict) and field != "error" else attr
return data

def decode_result(self, method: Optional[str] = None, _caller = None) -> Any:
def decode_result(self, method: Optional[RPCEndpoint] = None, _caller = None) -> Any:
# NOTE: These must be added to the `RETURN_TYPES` constant above manually
if method and (typ := RETURN_TYPES.get(method)):
if method in ["eth_call", "eth_blockNumber", "eth_getCode", "eth_getBlockByNumber", "eth_getTransactionReceipt", "eth_getTransactionCount", "eth_getBalance", "eth_chainId", "erigon_getHeaderByNumber"]:
Expand Down Expand Up @@ -174,15 +179,22 @@ class RawResponse:
"""
def __init__(self, raw: msgspec.Raw) -> None:
self._raw = raw
"""The `msgspec.Raw` object wrapped by this wrapper."""
@overload
def decode(self, partial = True) -> PartialResponse:...
def decode(self, partial: Literal[True]) -> PartialResponse:...
@overload
def decode(self, partial = False) -> Response:...
def decode(self, partial: Literal[False]) -> Response:...
def decode(self, partial: bool = False) -> Union[Response, PartialResponse]:
"""Decode the wrapped `msgspec.Raw` object into a `Response` or a `PartialResponse`."""
return msgspec.json.decode(self._raw, type=PartialResponse if partial else Response)

JSONRPCBatchRequest = List[Request]
# NOTE: A PartialResponse result implies a failure response from the rpc.
JSONRPCBatchResponse = Union[List[RawResponse], PartialResponse]
# We need this for proper decoding.
_JSONRPCBatchResponse = Union[List[msgspec.Raw], PartialResponse]
JSONRPCBatchResponseRaw = Union[List[msgspec.Raw], PartialResponse]

def _encode_hook(obj: Any) -> Any:
if isinstance(obj, AttributeDict):
return dict(obj)
raise NotImplementedError(type(obj))
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
aiofiles
eth_retry>=0.1.15,<0.2
ez-a-sync>=0.19.4
ez-a-sync>=0.20.6,<0.22
msgspec
multicall>=0.6.2,<1
typed-envs>=0.0.2
Expand Down

0 comments on commit 6ff6e4a

Please sign in to comment.