fix(mypy): fix most mypy errs in the repo (#194)
BobTheBuidler committed Apr 30, 2024
1 parent 59960bd commit 58f7f14
Showing 18 changed files with 177 additions and 154 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yaml
@@ -23,7 +23,7 @@ jobs:
- name: Install MyPy
run: |
python -m pip install --upgrade pip
-pip install mypy types-requests
+pip install mypy types-requests types-aiofiles
- name: Run MyPy
run: mypy ./dank_mids --pretty --ignore-missing-imports --show-error-codes --show-error-context --no-warn-no-return
5 changes: 3 additions & 2 deletions dank_mids/ENVIRONMENT_VARIABLES.py
@@ -1,5 +1,6 @@
-# type: ignore [attr-defined]
+# mypy: disable-error-code="attr-defined,dict-item"
import logging
+from typing import Dict

import a_sync
import typed_envs
@@ -69,7 +70,7 @@
STUCK_CALL_TIMEOUT = _envs.create_env("STUCK_CALL_TIMEOUT", int, default=60*60*2)

# Method-specific Semaphores
-method_semaphores = {
+method_semaphores: Dict[str, a_sync.Semaphore] = {
"eth_call": _envs.create_env("ETH_CALL_SEMAPHORE", BlockSemaphore, default=BROWNIE_CALL_SEMAPHORE._value, string_converter=int, verbose=False),
"eth_getBlock": _envs.create_env("ETH_GETBLOCK_SEMAPHORE", a_sync.Semaphore, default=1_000, string_converter=int, verbose=False),
"eth_getLogs": _envs.create_env("ETH_GETLOGS_SEMAPHORE", a_sync.Semaphore, default=64, string_converter=int, verbose=False),
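The first hunk above swaps a per-line `# type: ignore [attr-defined]` for a file-level directive. A minimal sketch of how such a directive behaves, using a throwaway module rather than anything from this repo:

```python
# mypy: disable-error-code="attr-defined"
# Placed at the top of the file (as in the diff above), this inline config
# suppresses the listed error code for the whole module instead of one line at a time.
import logging

logger = logging.getLogger(__name__)

# mypy would normally report '"Logger" has no attribute "custom_field"  [attr-defined]'
# on the next line; the file-level directive silences it.
logger.custom_field = "demo"
```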
12 changes: 6 additions & 6 deletions dank_mids/_batch.py
@@ -1,11 +1,11 @@

import asyncio
import logging
-from typing import TYPE_CHECKING, Any, Generator, List
+from typing import TYPE_CHECKING, Any, Awaitable, Generator, List, Union

from dank_mids._exceptions import DankMidsInternalError
-from dank_mids._requests import JSONRPCBatch, RPCRequest, _Batch
-from dank_mids.types import Multicalls
+from dank_mids._requests import _Batch, JSONRPCBatch, Multicall, RPCRequest
+from dank_mids.types import Multicalls, RawResponse

if TYPE_CHECKING:
from dank_mids.controller import DankMiddlewareController
@@ -18,7 +18,7 @@
class DankBatch:
__slots__ = 'controller', 'multicalls', 'rpc_calls', '_started'
""" A batch of jsonrpc batches. This is pretty much deprecated and needs to be refactored away."""
-def __init__(self, controller: "DankMiddlewareController", multicalls: Multicalls, rpc_calls: List[RPCRequest]):
+def __init__(self, controller: "DankMiddlewareController", multicalls: Multicalls, rpc_calls: List[Union[Multicall, RPCRequest]]):
self.controller = controller
self.multicalls = multicalls
self.rpc_calls = rpc_calls
@@ -44,7 +44,7 @@ def start(self) -> None:
self._started = True

@property
-def coroutines(self) -> Generator["_Batch", None, None]:
+def coroutines(self) -> Generator[Union["_Batch", Awaitable[RawResponse]], None, None]:
# Combine multicalls into one or more jsonrpc batches

# Create empty batch
@@ -73,6 +73,6 @@ def coroutines(self) -> Generator["_Batch", None, None]:
if working_batch.is_single_multicall:
yield working_batch[0] # type: ignore [misc]
elif len(working_batch) == 1:
-yield working_batch[0].make_request()
+yield working_batch[0].make_request() # type: ignore [union-attr]
else:
yield working_batch
2 changes: 1 addition & 1 deletion dank_mids/_demo_mode.py
@@ -9,4 +9,4 @@ class DummyLogger:
def info(self, *args: Any, **kwargs: Any) -> None:
...

-demo_logger = logging.getLogger("dank_mids.demo") if ENVIRONMENT_VARIABLES.DEMO_MODE else DummyLogger() # type: ignore [attr-defined]
+demo_logger: logging.Logger = logging.getLogger("dank_mids.demo") if ENVIRONMENT_VARIABLES.DEMO_MODE else DummyLogger() # type: ignore [attr-defined, assignment]
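The added annotation keeps callers typed against `logging.Logger` even though a no-op stand-in may be assigned when demo mode is off. A rough sketch of the pattern, with a hard-coded flag in place of the real environment variable:

```python
import logging
from typing import Any

class DummyLogger:
    """No-op stand-in that quacks like a Logger when demo mode is disabled."""
    def info(self, *args: Any, **kwargs: Any) -> None:
        ...

DEMO_MODE = False  # assumption: stands in for ENVIRONMENT_VARIABLES.DEMO_MODE

# Annotating with the "real" type lets downstream code call demo_logger.info(...)
# without complaints; the [assignment] ignore covers the DummyLogger branch.
demo_logger: logging.Logger = logging.getLogger("demo") if DEMO_MODE else DummyLogger()  # type: ignore [assignment]
```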
2 changes: 1 addition & 1 deletion dank_mids/_exceptions.py
@@ -46,7 +46,7 @@ def __init__(
internal_err_types = Union[AttributeError, TypeError, UnboundLocalError, NotImplementedError, RuntimeError, SyntaxError]

class DankMidsInternalError(Exception):
-def __init__(self, e: internal_err_types) -> None:
+def __init__(self, e: Union[ValueError, internal_err_types]) -> None:
logger.warning(f"unhandled exception inside dank mids internals: {e}", exc_info=True)
self._original_exception = e
super().__init__(e.__repr__())
118 changes: 62 additions & 56 deletions dank_mids/_requests.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions dank_mids/_uid.py
@@ -22,11 +22,11 @@ def next(self) -> int:
self._value = new
return new

-class _AlertingRLock(threading._RLock):
+class _AlertingRLock(threading._RLock): # type: ignore [misc]
def __init__(self, name: str) -> None:
super().__init__()
self.name = name
-def acquire(self, blocking: bool = True, timeout: int = -1) -> bool:
+def acquire(self, blocking: bool = True, timeout: int = -1) -> bool: # type: ignore [override]
acquired = super().acquire(blocking=False, timeout=5)
if not acquired:
logger.warning("wtf?! %s with name %s is locked!", self, self.name)
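The `[override]` ignore added to `_AlertingRLock.acquire` is the usual escape hatch when an override's parameter types are narrower than the base class's (typeshed declares the lock's `timeout` as a `float`). A small self-contained sketch with hypothetical lock classes:

```python
class BaseLock:
    def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
        return True

class AlertingLock(BaseLock):
    # Narrowing `timeout` from float to int breaks the Liskov rule mypy enforces
    # for overrides, so the incompatibility is silenced explicitly.
    def acquire(self, blocking: bool = True, timeout: int = -1) -> bool:  # type: ignore[override]
        return super().acquire(blocking, timeout)
```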
22 changes: 15 additions & 7 deletions dank_mids/brownie_patch/_method.py
@@ -7,11 +7,11 @@
from brownie.typing import AccountsType
from brownie.convert.datatypes import EthAddress
from eth_abi.exceptions import InsufficientDataBytes
-from web3 import Web3
+from hexbytes.main import BytesLike

from dank_mids import ENVIRONMENT_VARIABLES as ENVS
from dank_mids.brownie_patch._abi import FunctionABI
-from dank_mids.helpers._helpers import _make_hashable
+from dank_mids.helpers._helpers import DankWeb3, _make_hashable

_EVMType = TypeVar("_EVMType")

@@ -36,6 +36,14 @@ def abi(self) -> dict:
@property
def signature(self) -> str:
return self._abi.signature
+async def coroutine( # type: ignore [empty-body]
+self,
+*args: Any,
+block_identifier: Optional[int] = None,
+decimals: Optional[int] = None,
+override: Optional[Dict[str, str]] = None,
+) -> _EVMType:
+raise NotImplementedError
@property
def _input_sig(self) -> str:
return self._abi.input_sig
@@ -47,13 +55,13 @@ def _skip_decoder_proc_pool(self) -> bool:
from dank_mids.brownie_patch.call import _skip_proc_pool
return self._address in _skip_proc_pool
@functools.cached_property
-def _web3(cls) -> Web3:
+def _web3(cls) -> DankWeb3:
from dank_mids import web3
return web3
@functools.cached_property
-def _prep_request_data(self) -> Callable[..., Awaitable[bytes]]:
+def _prep_request_data(self) -> Callable[..., Awaitable[BytesLike]]:
from dank_mids.brownie_patch import call
-if ENVS.OPERATION_MODE.application or self._len_inputs:
+if ENVS.OPERATION_MODE.application or self._len_inputs: # type: ignore [attr-defined]
return call.encode
else:
return call._request_data_no_args
@@ -97,9 +105,9 @@ async def coroutine( # type: ignore [empty-body]
"""
if override:
raise ValueError("Cannot use state override with `coroutine`.")
-async with ENVS.BROWNIE_ENCODER_SEMAPHORE[block_identifier]: # type: ignore [attr-defined]
+async with ENVS.BROWNIE_ENCODER_SEMAPHORE[block_identifier]: # type: ignore [attr-defined,index]
data = await self._encode_input(self, self._len_inputs, self._prep_request_data, *args)
-async with ENVS.BROWNIE_CALL_SEMAPHORE[block_identifier]: # type: ignore [attr-defined]
+async with ENVS.BROWNIE_CALL_SEMAPHORE[block_identifier]: # type: ignore [attr-defined,index]
output = await self._web3.eth.call({"to": self._address, "data": data}, block_identifier)
try:
decoded = await self._decode_output(self, output)
32 changes: 17 additions & 15 deletions dank_mids/brownie_patch/call.py
@@ -8,7 +8,6 @@
from typing import Any, Dict, Optional, Tuple, Union

import eth_abi
-import eth_retry
from a_sync import AsyncProcessPoolExecutor
from brownie import chain
from brownie.convert.normalize import format_input, format_output
@@ -17,43 +16,46 @@
from brownie.network.contract import Contract, ContractCall
from brownie.project.compiler.solidity import SOLIDITY_ERROR_CODES
from eth_abi.exceptions import InsufficientDataBytes
+from eth_typing import HexStr
from eth_utils import to_checksum_address
from hexbytes import HexBytes
+from hexbytes.main import BytesLike
from multicall.constants import MULTICALL2_ADDRESSES
-from web3 import Web3
+from web3.types import BlockIdentifier

from dank_mids import ENVIRONMENT_VARIABLES as ENVS
from dank_mids.brownie_patch.types import ContractMethod
from dank_mids.exceptions import Revert
+from dank_mids.helpers._helpers import DankWeb3

logger = logging.getLogger(__name__)
encode = lambda self, *args: ENVS.BROWNIE_ENCODER_PROCESSES.run(__encode_input, self.abi, self.signature, *args) # type: ignore [attr-defined]
decode = lambda self, data: ENVS.BROWNIE_DECODER_PROCESSES.run(__decode_output, data, self.abi) # type: ignore [attr-defined]

-def _patch_call(call: ContractCall, w3: Web3) -> None:
+def _patch_call(call: ContractCall, w3: DankWeb3) -> None:
call._skip_decoder_proc_pool = call._address in _skip_proc_pool
call.coroutine = MethodType(_get_coroutine_fn(w3, len(call.abi['inputs'])), call)
call.__await__ = MethodType(_call_no_args, call)

@functools.lru_cache
-def _get_coroutine_fn(w3: Web3, len_inputs: int):
+def _get_coroutine_fn(w3: DankWeb3, len_inputs: int):
if ENVS.OPERATION_MODE.application or len_inputs: # type: ignore [attr-defined]
get_request_data = encode
else:
-get_request_data = _request_data_no_args
+get_request_data = _request_data_no_args # type: ignore [assignment]

async def coroutine(
self: ContractCall,
*args: Tuple[Any,...],
-block_identifier: Optional[Union[int, str, bytes]] = None,
+block_identifier: Optional[BlockIdentifier] = None,
decimals: Optional[int] = None,
override: Optional[Dict[str, str]] = None
) -> Any:
if override:
raise ValueError("Cannot use state override with `coroutine`.")
-async with ENVS.BROWNIE_ENCODER_SEMAPHORE[block_identifier]: # type: ignore [attr-defined]
+async with ENVS.BROWNIE_ENCODER_SEMAPHORE[block_identifier]: # type: ignore [attr-defined, index]
data = await encode_input(self, len_inputs, get_request_data, *args)
-async with ENVS.BROWNIE_CALL_SEMAPHORE[block_identifier]: # type: ignore [attr-defined]
+async with ENVS.BROWNIE_CALL_SEMAPHORE[block_identifier]: # type: ignore [attr-defined, index]
output = await w3.eth.call({"to": self._address, "data": data}, block_identifier)
try:
decoded = await decode_output(self, output)
@@ -68,7 +70,7 @@ def _call_no_args(self: ContractMethod):
"""Asynchronously call `self` with no arguments at the latest block."""
return self.coroutine().__await__()

-async def encode_input(call: ContractCall, len_inputs, get_request_data, *args) -> bytes:
+async def encode_input(call: ContractCall, len_inputs, get_request_data, *args) -> HexStr:
if any(isinstance(arg, Contract) for arg in args) or any(hasattr(arg, "__contains__") for arg in args): # We will just assume containers contain a Contract object until we have a better way to handle this
# We can't unpickle these because of the added `coroutine` method.
data = __encode_input(call.abi, call.signature, *args)
@@ -82,7 +84,7 @@ async def encode_input(call: ContractCall, len_inputs, get_request_data, *args)
except BrokenProcessPool:
logger.critical("Oh fuck, you broke the %s while decoding %s with abi %s", ENVS.BROWNIE_ENCODER_PROCESSES, data, call.abi) # type: ignore [attr-defined]
# Let's fix that right up
-ENVS.BROWNIE_ENCODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.BROWNIE_ENCODER_PROCESSES._max_workers) # type: ignore [attr-defined]
+ENVS.BROWNIE_ENCODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.BROWNIE_ENCODER_PROCESSES._max_workers) # type: ignore [attr-defined,assignment]
data = __encode_input(call.abi, call.signature, *args) if len_inputs else call.signature
except PicklingError: # But if that fails, don't worry. I got you.
data = __encode_input(call.abi, call.signature, *args) if len_inputs else call.signature
@@ -105,7 +107,7 @@ async def decode_output(call: ContractCall, data: bytes) -> Any:
except BrokenProcessPool:
# Let's fix that right up
logger.critical("Oh fuck, you broke the %s while decoding %s with abi %s", ENVS.BROWNIE_DECODER_PROCESSES, data, call.abi) # type: ignore [attr-defined]
-ENVS.BROWNIE_DECODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.BROWNIE_DECODER_PROCESSES._max_workers) # type: ignore [attr-defined]
+ENVS.BROWNIE_DECODER_PROCESSES = AsyncProcessPoolExecutor(ENVS.BROWNIE_DECODER_PROCESSES._max_workers) # type: ignore [attr-defined, assignment]
decoded = __decode_output(data, call.abi)
# We have to do it like this so we don't break the process pool.
if isinstance(decoded, Exception):
@@ -119,14 +121,14 @@ async def decode_output(call: ContractCall, data: bytes) -> Any:
call._skip_decoder_proc_pool = call._address in _skip_proc_pool
return await decode_output(call, data)

-async def _request_data_no_args(call: ContractCall) -> str:
+async def _request_data_no_args(call: ContractCall) -> HexStr:
return call.signature

# These methods were renamed in eth-abi 4.0.0
__eth_abi_encode = eth_abi.encode if hasattr(eth_abi, 'encode') else eth_abi.encode_abi
__eth_abi_decode = eth_abi.decode if hasattr(eth_abi, 'decode') else eth_abi.decode_abi

-def __encode_input(abi: Dict[str, Any], signature: str, *args: Tuple[Any,...]) -> str:
+def __encode_input(abi: Dict[str, Any], signature: str, *args: Tuple[Any,...]) -> Union[HexStr, Exception]:
try:
data = format_input(abi, args)
types_list = get_type_strings(abi["inputs"])
@@ -146,7 +148,7 @@ def __encode_input(abi: Dict[str, Any], signature: str, *args: Tuple[Any,...]) -
if multicall2 := MULTICALL2_ADDRESSES.get(chainid, None):
_skip_proc_pool.add(to_checksum_address(multicall2))

-def __decode_output(hexstr: str, abi: Dict[str, Any]) -> Any:
+def __decode_output(hexstr: BytesLike, abi: Dict[str, Any]) -> Any:
try:
types_list = get_type_strings(abi["outputs"])
result = __eth_abi_decode(types_list, HexBytes(hexstr))
@@ -157,7 +159,7 @@ def __decode_output(hexstr: str, abi: Dict[str, Any]) -> Any:
except Exception as e:
return e

-def __validate_output(abi: Dict[str, Any], hexstr: str):
+def __validate_output(abi: Dict[str, Any], hexstr: BytesLike):
try:
selector = HexBytes(hexstr)[:4].hex()
if selector == "0x08c379a0":
11 changes: 6 additions & 5 deletions dank_mids/brownie_patch/contract.py
@@ -11,7 +11,8 @@
from dank_mids.brownie_patch.call import _patch_call
from dank_mids.brownie_patch.overloaded import _patch_overloaded_method
from dank_mids.brownie_patch.types import ContractMethod, DankContractMethod, DankOverloadedMethod, _get_method_object

+from dank_mids.helpers._helpers import DankWeb3


EventName = NewType("EventName", str)
LogTopic = NewType("LogTopic", str)
@@ -97,10 +98,10 @@ def __get_method_object__(self, name: str) -> DankContractMethod:


@overload
-def patch_contract(contract: Contract, w3: Optional[Web3] = None) -> Contract:...
+def patch_contract(contract: Contract, w3: Optional[DankWeb3] = None) -> Contract:...
@overload
-def patch_contract(contract: Union[brownie.Contract, str], w3: Optional[Web3] = None) -> brownie.Contract:...
-def patch_contract(contract: Union[Contract, brownie.Contract, str], w3: Optional[Web3] = None) -> Union[Contract, brownie.Contract]:
+def patch_contract(contract: Union[brownie.Contract, str], w3: Optional[DankWeb3] = None) -> brownie.Contract:...
+def patch_contract(contract: Union[Contract, brownie.Contract, str], w3: Optional[DankWeb3] = None) -> Union[Contract, brownie.Contract]:
"""returns a patched version of `contract` with async and call batchings functionalities"""
if not isinstance(contract, brownie.Contract):
contract = brownie.Contract(contract)
@@ -112,7 +113,7 @@ def patch_contract(contract: Union[Contract, brownie.Contract, str], w3: Optiona
_patch_if_method(v, w3)
return contract

-def _patch_if_method(method: ContractMethod, w3: Web3) -> None:
+def _patch_if_method(method: ContractMethod, w3: DankWeb3) -> None:
if isinstance(method, (ContractCall, ContractTx)):
_patch_call(method, w3)
elif isinstance(method, OverloadedMethod):
4 changes: 2 additions & 2 deletions dank_mids/brownie_patch/overloaded.py
@@ -6,10 +6,10 @@
from brownie.network.contract import ContractCall, ContractTx, OverloadedMethod
from dank_mids.brownie_patch.call import _get_coroutine_fn, _skip_proc_pool
from dank_mids.brownie_patch.types import ContractMethod
-from web3 import Web3
+from dank_mids.helpers._helpers import DankWeb3


-def _patch_overloaded_method(call: OverloadedMethod, w3: Web3) -> None:
+def _patch_overloaded_method(call: OverloadedMethod, w3: DankWeb3) -> None:
# sourcery skip: avoid-builtin-shadow
@functools.wraps(call)
async def coroutine(
8 changes: 5 additions & 3 deletions dank_mids/constants.py
@@ -1,6 +1,8 @@

+# mypy: disable-error-code="attr-defined, dict-item"
+from typing import Dict

import multicall
+from eth_typing import BlockNumber
from multicall.constants import Network

TOO_MUCH_DATA_ERRS = ["Payload Too Large", "content length too large", "request entity too large", "batch limit exceeded"]
Expand All @@ -13,14 +15,14 @@
except AttributeError:
MULTICALL3_OVERRIDE_CODE = multicall.constants.MULTICALL2_BYTECODE

-MULTICALL2_DEPLOY_BLOCKS = {
+MULTICALL2_DEPLOY_BLOCKS: Dict[Network, BlockNumber] = {
Network.Mainnet: 12336033,
Network.Fantom: 16572242,
Network.Arbitrum: 821923,
Network.Optimism: 722566,
}

-MULTICALL3_DEPLOY_BLOCKS = {
+MULTICALL3_DEPLOY_BLOCKS: Dict[Network, BlockNumber] = {
Network.Mainnet: 14353601,
Network.Fantom: 33001987,
Network.Arbitrum: 7654707,
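The explicit `Dict[Network, BlockNumber]` annotations give the deploy-block maps precise types for callers. A rough sketch, using local `NewType` stand-ins for the real `multicall`/`eth_typing` types, of what the annotation buys and why `dict-item` plausibly appears in the file-level directive above:

```python
from typing import Dict, NewType

Network = NewType("Network", int)          # stand-in for multicall.constants.Network
BlockNumber = NewType("BlockNumber", int)  # stand-in for eth_typing.BlockNumber

MAINNET = Network(1)

# With the annotation, mypy checks every entry against Dict[Network, BlockNumber].
# A bare int value (like the literals in the diff) would be reported as [dict-item]
# unless wrapped or suppressed, hence the module-wide disable for that code.
DEPLOY_BLOCKS: Dict[Network, BlockNumber] = {
    MAINNET: BlockNumber(12_336_033),
}
```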