Skip to content

Commit

Permalink
Fix cache memory leaks
Browse files Browse the repository at this point in the history
- `@cache` is stored in the Class, not in the Object. So when multiple `Safe` objects were created, and `get_contract()` was called, cache only increased (also if the same Safe was created multiple times)
- Use now `@cached_property`, which is binded to the object, so when the object is destroyed, so is the cached value
- Closes #322
- Related to #325
  • Loading branch information
Uxio0 committed Aug 25, 2022
1 parent 0147c94 commit 4f1e009
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 96 deletions.
90 changes: 35 additions & 55 deletions gnosis/safe/safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@
from .safe_tx import SafeTx

try:
from functools import cache
from functools import cached_property
except ImportError:
from functools import lru_cache

cache = lru_cache(maxsize=None)
from cached_property import cached_property


logger = getLogger(__name__)
Expand Down Expand Up @@ -502,6 +500,15 @@ def build_safe_create2_tx(

return safe_creation_tx

@cached_property
def contract(self) -> Contract:
v_1_3_0_contract = get_safe_V1_3_0_contract(self.w3, address=self.address)
version = v_1_3_0_contract.functions.VERSION().call()
if version == "1.3.0":
return v_1_3_0_contract
else:
return get_safe_V1_1_1_contract(self.w3, address=self.address)

def check_funds_for_tx_gas(
self, safe_tx_gas: int, base_gas: int, gas_price: int, gas_token: str
) -> bool:
Expand Down Expand Up @@ -542,7 +549,7 @@ def estimate_tx_base_gas(
:return:
"""
data = data or b""
safe_contract = self.get_contract()
safe_contract = self.contract
threshold = self.retrieve_threshold()
nonce = self.retrieve_nonce()

Expand Down Expand Up @@ -651,16 +658,14 @@ def parse_revert_data(result: bytes) -> int:

return int(gas_estimation.hex(), 16)

tx = (
self.get_contract()
.functions.requiredTxGas(to, value, data, operation)
.build_transaction(
{
"from": safe_address,
"gas": 0, # Don't call estimate
"gasPrice": 0, # Don't get gas price
}
)
tx = self.contract.functions.requiredTxGas(
to, value, data, operation
).build_transaction(
{
"from": safe_address,
"gas": 0, # Don't call estimate
"gasPrice": 0, # Don't get gas price
}
)

tx_params = {
Expand Down Expand Up @@ -822,15 +827,6 @@ def estimate_tx_operational_gas(self, data_bytes_length: int):
threshold = self.retrieve_threshold()
return 15000 + data_bytes_length // 32 * 100 + 5000 * threshold

@cache
def get_contract(self) -> Contract:
v_1_3_0_contract = get_safe_V1_3_0_contract(self.w3, address=self.address)
version = v_1_3_0_contract.functions.VERSION().call()
if version == "1.3.0":
return v_1_3_0_contract
else:
return get_safe_V1_1_1_contract(self.w3, address=self.address)

def retrieve_all_info(
self, block_identifier: Optional[BlockIdentifier] = "latest"
) -> SafeInfo:
Expand All @@ -842,7 +838,7 @@ def retrieve_all_info(
:raises: CannotRetrieveSafeInfoException
"""
try:
contract = self.get_contract()
contract = self.contract
master_copy = self.retrieve_master_copy_address()
fallback_handler = self.retrieve_fallback_handler()
guard = self.retrieve_guard()
Expand Down Expand Up @@ -938,7 +934,7 @@ def retrieve_modules(
except BadFunctionCallOutput:
pass

contract = self.get_contract()
contract = self.contract
address = SENTINEL_ADDRESS
all_modules: List[str] = []
while True:
Expand All @@ -960,9 +956,9 @@ def retrieve_is_hash_approved(
block_identifier: Optional[BlockIdentifier] = "latest",
) -> bool:
return (
self.get_contract()
.functions.approvedHashes(owner, safe_hash)
.call(block_identifier=block_identifier)
self.contract.functions.approvedHashes(owner, safe_hash).call(
block_identifier=block_identifier
)
== 1
)

Expand All @@ -971,56 +967,40 @@ def retrieve_is_message_signed(
message_hash: bytes,
block_identifier: Optional[BlockIdentifier] = "latest",
) -> bool:
return (
self.get_contract()
.functions.signedMessages(message_hash)
.call(block_identifier=block_identifier)
return self.contract.functions.signedMessages(message_hash).call(
block_identifier=block_identifier
)

def retrieve_is_owner(
self, owner: str, block_identifier: Optional[BlockIdentifier] = "latest"
) -> bool:
return (
self.get_contract()
.functions.isOwner(owner)
.call(block_identifier=block_identifier)
return self.contract.functions.isOwner(owner).call(
block_identifier=block_identifier
)

def retrieve_nonce(
self, block_identifier: Optional[BlockIdentifier] = "latest"
) -> int:
return (
self.get_contract()
.functions.nonce()
.call(block_identifier=block_identifier)
)
return self.contract.functions.nonce().call(block_identifier=block_identifier)

def retrieve_owners(
self, block_identifier: Optional[BlockIdentifier] = "latest"
) -> List[str]:
return (
self.get_contract()
.functions.getOwners()
.call(block_identifier=block_identifier)
return self.contract.functions.getOwners().call(
block_identifier=block_identifier
)

def retrieve_threshold(
self, block_identifier: Optional[BlockIdentifier] = "latest"
) -> int:
return (
self.get_contract()
.functions.getThreshold()
.call(block_identifier=block_identifier)
return self.contract.functions.getThreshold().call(
block_identifier=block_identifier
)

def retrieve_version(
self, block_identifier: Optional[BlockIdentifier] = "latest"
) -> str:
return (
self.get_contract()
.functions.VERSION()
.call(block_identifier=block_identifier)
)
return self.contract.functions.VERSION().call(block_identifier=block_identifier)

def build_multisig_tx(
self,
Expand Down
12 changes: 6 additions & 6 deletions gnosis/safe/tests/test_safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,9 +572,9 @@ def test_retrieve_guard(self):

guard_address = Account.create().address
set_guard_data = HexBytes(
safe.get_contract()
.functions.setGuard(guard_address)
.build_transaction({"gas": 1, "gasPrice": 1})["data"]
safe.contract.functions.setGuard(guard_address).build_transaction(
{"gas": 1, "gasPrice": 1}
)["data"]
)
set_guard_tx = safe.build_multisig_tx(safe.address, 0, set_guard_data)
set_guard_tx.sign(owner_account.key)
Expand Down Expand Up @@ -629,7 +629,7 @@ def test_retrieve_all_info(self):

def test_retrieve_modules(self):
safe = self.deploy_test_safe(owners=[self.ethereum_test_account.address])
safe_contract = safe.get_contract()
safe_contract = safe.contract
module_address = Account.create().address
self.assertEqual(safe.retrieve_modules(), [])

Expand Down Expand Up @@ -663,7 +663,7 @@ def test_retrieve_modules(self):

def test_retrieve_is_hash_approved(self):
safe = self.deploy_test_safe(owners=[self.ethereum_test_account.address])
safe_contract = safe.get_contract()
safe_contract = safe.contract
fake_tx_hash = Web3.keccak(text="Knopfler")
another_tx_hash = Web3.keccak(text="Marc")
tx = safe_contract.functions.approveHash(fake_tx_hash).build_transaction(
Expand All @@ -686,7 +686,7 @@ def test_retrieve_is_hash_approved(self):

def test_retrieve_is_message_signed(self):
safe = self.deploy_test_safe_v1_1_1(owners=[self.ethereum_test_account.address])
safe_contract = safe.get_contract()
safe_contract = safe.contract
message = b"12345"
message_hash = safe_contract.functions.getMessageHash(message).call()
sign_message_data = HexBytes(
Expand Down
4 changes: 2 additions & 2 deletions gnosis/safe/tests/test_safe_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test_contract_signature(self):
safe = self.deploy_test_safe_v1_1_1(
owners=[owner_1.address], initial_funding_wei=Web3.toWei(0.01, "ether")
)
safe_contract = safe.get_contract()
safe_contract = safe.contract
safe_tx_hash = Web3.keccak(text="test")
signature_r = HexBytes(safe.address.replace("0x", "").rjust(64, "0"))
signature_s = HexBytes(
Expand Down Expand Up @@ -241,7 +241,7 @@ def test_contract_multiple_signatures(self):
safe = self.deploy_test_safe_v1_1_1(
owners=[owner_1.address], initial_funding_wei=Web3.toWei(0.01, "ether")
)
safe_contract = safe.get_contract()
safe_contract = safe.contract
safe_tx_hash = Web3.keccak(text="test")

tx = safe_contract.functions.signMessage(safe_tx_hash).build_transaction(
Expand Down
58 changes: 25 additions & 33 deletions gnosis/safe/tests/test_safe_tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_multi_send_safe_tx(self):
threshold=threshold,
initial_funding_wei=self.w3.toWei(0.1, "ether"),
)
safe_contract = safe.get_contract()
safe_contract = safe.contract
to = self.multi_send_contract.address
value = 0
safe_tx_gas = 600000
Expand Down Expand Up @@ -273,22 +273,18 @@ def test_hash_safe_multisig_tx(self):

safe = self.deploy_test_safe_v1_1_1()
# Expected hash must be the same calculated by `getTransactionHash` of the contract
expected_hash = (
safe.get_contract()
.functions.getTransactionHash(
"0x5AC255889882aaB35A2aa939679E3F3d4Cea221E",
5212459,
HexBytes(0x00),
1,
123456,
122,
12345,
"0x" + "2" * 40,
"0x" + "2" * 40,
10789,
)
.call()
)
expected_hash = safe.contract.functions.getTransactionHash(
"0x5AC255889882aaB35A2aa939679E3F3d4Cea221E",
5212459,
HexBytes(0x00),
1,
123456,
122,
12345,
"0x" + "2" * 40,
"0x" + "2" * 40,
10789,
).call()
safe_tx_hash = SafeTx(
self.ethereum_client,
safe.address,
Expand All @@ -309,22 +305,18 @@ def test_hash_safe_multisig_tx(self):
# Safe v1.3.0
safe = self.deploy_test_safe()
# Expected hash must be the same calculated by `getTransactionHash` of the contract
expected_hash = (
safe.get_contract()
.functions.getTransactionHash(
"0x5AC255889882aaB35A2aa939679E3F3d4Cea221E",
5212459,
HexBytes(0x00),
1,
123456,
122,
12345,
"0x" + "2" * 40,
"0x" + "2" * 40,
10789,
)
.call()
)
expected_hash = safe.contract.functions.getTransactionHash(
"0x5AC255889882aaB35A2aa939679E3F3d4Cea221E",
5212459,
HexBytes(0x00),
1,
123456,
122,
12345,
"0x" + "2" * 40,
"0x" + "2" * 40,
10789,
).call()
safe_tx_hash = SafeTx(
self.ethereum_client,
safe.address,
Expand Down

0 comments on commit 4f1e009

Please sign in to comment.