Skip to content

Commit

Permalink
fix: issues with CurrencyComparableValue appearing `ContractLog.eve…
Browse files Browse the repository at this point in the history
…nt_arguments` & on Pydantic models (ApeWorX#2221)
  • Loading branch information
antazoey authored Aug 22, 2024
1 parent a0726d9 commit fa2ba27
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 12 deletions.
48 changes: 47 additions & 1 deletion src/ape/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@
)
from ethpm_types.abi import EventABI
from ethpm_types.source import Closure
from pydantic import BaseModel, BeforeValidator, field_validator, model_validator
from pydantic import BaseModel, BeforeValidator, field_serializer, field_validator, model_validator
from pydantic_core.core_schema import (
CoreSchema,
ValidationInfo,
int_schema,
no_info_plain_validator_function,
plain_serializer_function_ser_schema,
)
from typing_extensions import TypeAlias
from web3.types import FilterParams

Expand Down Expand Up @@ -251,6 +258,15 @@ def __eq__(self, other: Any) -> bool:

return True

@field_serializer("event_arguments")
def _serialize_event_arguments(self, event_arguments, info):
"""
Because of an issue with BigInt in Pydantic,
(https://github.com/pydantic/pydantic/issues/10152)
we have to ensure these are regular ints.
"""
return {k: int(v) if isinstance(v, int) else v for k, v in event_arguments.items()}


class ContractLog(ExtraAttributesMixin, BaseContractLog):
"""
Expand Down Expand Up @@ -484,6 +500,36 @@ def __eq__(self, other: Any) -> bool:
# Try from the other end, if hasn't already.
return NotImplemented

@classmethod
def __get_pydantic_core_schema__(cls, value, handler=None) -> CoreSchema:
return no_info_plain_validator_function(
cls._validate,
serialization=plain_serializer_function_ser_schema(
cls._serialize,
info_arg=False,
return_schema=int_schema(),
),
)

@staticmethod
def _validate(value: Any, info: Optional[ValidationInfo] = None) -> "CurrencyValueComparable":
# NOTE: For some reason, for this to work, it has to happen
# in an "after" validator, or else it always only `int` type on the model.
if value is None:
# Will fail if not optional.
# Type ignore because this is an hacky and unlikely situation.
return None # type: ignore

elif isinstance(value, str) and " " in value:
return ManagerAccessMixin.conversion_manager.convert(value, int)

# For models annotating with this type, we validate all integers into it.
return CurrencyValueComparable(value)

@staticmethod
def _serialize(value):
return int(value)


CurrencyValueComparable.__name__ = int.__name__

Expand Down
3 changes: 3 additions & 0 deletions tests/functional/geth/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def test_uri_when_configured(geth_provider, project, ethereum):
assert actual_mainnet_uri == expected


@geth_process_test
def test_uri_non_dev_and_not_configured(mocker, ethereum):
"""
If the URI was not configured and we are not using a dev
Expand Down Expand Up @@ -547,6 +548,7 @@ def test_make_request_not_exists(geth_provider):
geth_provider.make_request("ape_thisDoesNotExist")


@geth_process_test
def test_geth_bin_not_found():
bin_name = "__NOT_A_REAL_EXECUTABLE_HOPEFULLY__"
with pytest.raises(NodeSoftwareNotInstalledError):
Expand Down Expand Up @@ -677,6 +679,7 @@ def test_trace_approach_config(project):
assert provider.call_trace_approach is TraceApproach.GETH_STRUCT_LOG_PARSE


@geth_process_test
def test_start(mocker, convert, project, geth_provider):
amount = convert("100_000 ETH", int)
spy = mocker.spy(GethDevProcess, "from_uri")
Expand Down
21 changes: 20 additions & 1 deletion tests/functional/test_contract_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ape.api import ReceiptAPI
from ape.exceptions import ProviderError
from ape.types import ContractLog
from ape.types import ContractLog, CurrencyValueComparable


@pytest.fixture
Expand Down Expand Up @@ -363,3 +363,22 @@ def test_info(solidity_contract_instance):
{spec}
""".strip()
assert actual == expected


def test_model_dump(solidity_contract_container, owner):
# NOTE: deploying a new contract with a new number to lessen x-dist conflicts.
contract = owner.deploy(solidity_contract_container, 29620000000003)

# First, get an event (a normal way).
number = int(10e18)
tx = contract.setNumber(number, sender=owner)
event = tx.events[0]

# Next, invoke `.model_dump()` to get the serialized version.
log = event.model_dump()
actual = log["event_arguments"]
assert actual["newNum"] == number

# This next assertion is important because of this Pydantic bug:
# https://github.com/pydantic/pydantic/issues/10152
assert not isinstance(actual["newNum"], CurrencyValueComparable)
76 changes: 66 additions & 10 deletions tests/functional/test_types.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Optional
from typing import Any, Optional

import pytest
from eth_utils import to_hex
from ethpm_types.abi import EventABI
from hexbytes import HexBytes
from pydantic import BaseModel, Field

from ape.types import AddressType, ContractLog, HexInt, LogFilter
from ape.types import AddressType, ContractLog, CurrencyValueComparable, HexInt, LogFilter
from ape.utils import ZERO_ADDRESS

TXN_HASH = "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa222222222222222222222222"
Expand Down Expand Up @@ -131,11 +131,67 @@ class MyModel(BaseModel):


class TestHexInt:
class MyModel(BaseModel):
ual: HexInt = 0
ual_optional: Optional[HexInt] = Field(default=None, validate_default=True)

act = MyModel.model_validate({"ual": "0x123"})
expected = 291 # Base-10 form of 0x123.
assert act.ual == expected
assert act.ual_optional is None
def test_model(self):
class MyModel(BaseModel):
ual: HexInt = 0
ual_optional: Optional[HexInt] = Field(default=None, validate_default=True)

act = MyModel.model_validate({"ual": "0x123"})
expected = 291 # Base-10 form of 0x123.
assert act.ual == expected
assert act.ual_optional is None


class TestCurrencyValueComparable:
def test_use_for_int_in_pydantic_model(self):
value = 100000000000000000000000000000000000000000000

class MyBasicModel(BaseModel):
val: int

model = MyBasicModel.model_validate({"val": CurrencyValueComparable(value)})
assert model.val == value

# Ensure serializes.
dumped = model.model_dump()
assert dumped["val"] == value

@pytest.mark.parametrize("mode", ("json", "python"))
def test_use_in_model_annotation(self, mode):
value = 100000000000000000000000000000000000000000000

class MyAnnotatedModel(BaseModel):
val: CurrencyValueComparable
val_optional: Optional[CurrencyValueComparable]

model = MyAnnotatedModel.model_validate({"val": value, "val_optional": value})
assert isinstance(model.val, CurrencyValueComparable)
assert model.val == value

# Show can use currency-comparable
expected_currency_value = "100000000000000000000000000 ETH"
assert model.val == expected_currency_value
assert model.val_optional == expected_currency_value

# Ensure serializes.
dumped = model.model_dump(mode=mode)
assert dumped["val"] == value
assert dumped["val_optional"] == value

def test_validate_from_currency_value(self):
class MyAnnotatedModel(BaseModel):
val: CurrencyValueComparable
val_optional: Optional[CurrencyValueComparable]
val_in_dict: dict[str, Any]

value = "100000000000000000000000000 ETH"
expected = 100000000000000000000000000000000000000000000
data = {
"val": value,
"val_optional": value,
"val_in_dict": {"value": CurrencyValueComparable(expected)},
}
model = MyAnnotatedModel.model_validate(data)
for actual in (model.val, model.val_optional, model.val_in_dict["value"]):
for ex in (value, expected):
assert actual == ex

0 comments on commit fa2ba27

Please sign in to comment.