Skip to content

Commit a880bbf

Browse files
committed
refactor: use balances for a single user instead of multi-user
1 parent 825d932 commit a880bbf

File tree

1 file changed

+31
-19
lines changed

1 file changed

+31
-19
lines changed

src/vault/TradingVault.sol

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ contract TradingVault is ITradingVault, IAuthorizableV2, ReentrancyGuard, Ownabl
1616
/// that may not return a boolean value on success.
1717
using SafeERC20 for IERC20;
1818

19-
mapping(address user => mapping(address token => uint256 balance)) internal _balances;
19+
mapping(address token => uint256 balance) internal _balances;
20+
mapping(address token => uint256 balance) internal _broker_balances;
2021
uint256 internal _nonce;
2122

2223
address public broker;
@@ -34,14 +35,14 @@ contract TradingVault is ITradingVault, IAuthorizableV2, ReentrancyGuard, Ownabl
3435

3536
// ---------- View functions ----------
3637

37-
function balanceOf(address user, address token) external view returns (uint256) {
38-
return _balances[user][token];
38+
function balanceOf(address token) external view returns (uint256) {
39+
return _balances[token];
3940
}
4041

41-
function balancesOfTokens(address user, address[] calldata tokens) external view returns (uint256[] memory) {
42+
function balancesOfTokens(address[] calldata tokens) external view returns (uint256[] memory) {
4243
uint256[] memory balances = new uint256[](tokens.length);
4344
for (uint256 i = 0; i < tokens.length; i++) {
44-
balances[i] = _balances[user][tokens[i]];
45+
balances[i] = _balances[tokens[i]];
4546
}
4647
return balances;
4748
}
@@ -65,7 +66,6 @@ contract TradingVault is ITradingVault, IAuthorizableV2, ReentrancyGuard, Ownabl
6566

6667
function deposit(ITradingStructs.Intent calldata intent) external payable notZeroAddress(intent.trader) {
6768
address sender = msg.sender;
68-
address recipient = intent.trader;
6969
require(_nonce == intent.nonce, NonceMismatch(_nonce, intent.nonce));
7070

7171
_nonce++;
@@ -77,22 +77,21 @@ contract TradingVault is ITradingVault, IAuthorizableV2, ReentrancyGuard, Ownabl
7777

7878
if (token == address(0)) {
7979
require(msg.value == amount, IncorrectValue());
80-
_balances[recipient][address(0)] += amount;
80+
_balances[address(0)] += amount;
8181
} else {
8282
require(msg.value == 0, IncorrectValue());
83-
_balances[recipient][token] += amount;
83+
_balances[token] += amount;
8484
IERC20(token).safeTransferFrom(sender, address(this), amount);
8585
}
8686

87-
emit Deposited(recipient, token, amount);
87+
emit Deposited(intent.trader, token, amount);
8888
}
8989
}
9090

9191
function withdraw(
9292
ITradingStructs.Intent calldata intent,
9393
bytes calldata additionalAuthData
9494
) external notZeroAddress(intent.trader) {
95-
address account = intent.trader;
9695
require(_nonce == intent.nonce, NonceMismatch(_nonce, intent.nonce));
9796

9897
bytes memory authData = abi.encodePacked(
@@ -115,11 +114,12 @@ contract TradingVault is ITradingVault, IAuthorizableV2, ReentrancyGuard, Ownabl
115114
for (uint256 i = 0; i < intent.allocations.length; i++) {
116115
address asset = intent.allocations[i].asset;
117116
uint256 amount = intent.allocations[i].amount;
118-
uint256 currentBalance = _balances[account][asset];
117+
uint256 currentBalance = _balances[asset];
119118
require(currentBalance >= amount, InsufficientBalance(asset, amount, currentBalance));
120119

121-
_balances[account][asset] -= amount;
120+
_balances[asset] -= amount;
122121

122+
address account = intent.trader;
123123
if (asset == address(0)) {
124124
/// @dev using `call` instead of `transfer` to overcome 2300 gas ceiling that could make it revert with some AA wallets
125125
(bool success, ) = account.call{value: amount}("");
@@ -174,28 +174,38 @@ contract TradingVault is ITradingVault, IAuthorizableV2, ReentrancyGuard, Ownabl
174174

175175
// ---------- Internal functions ----------
176176

177+
function _getBalanceMapping(address user) private view returns (mapping(address => uint256) storage) {
178+
return user == broker ? _broker_balances : _balances;
179+
}
180+
177181
function _checkAndVaultSwap(
178182
address sender,
179183
address receiver,
180184
ITradingStructs.Allocation memory alloc
181185
) internal virtual {
182-
uint256 balance = _balances[sender][alloc.asset];
186+
mapping(address => uint256) storage senderBalances = _getBalanceMapping(sender);
187+
mapping(address => uint256) storage receiverBalances = _getBalanceMapping(receiver);
188+
189+
uint256 balance = senderBalances[alloc.asset];
183190
require(balance >= alloc.amount, InsufficientBalance(alloc.asset, alloc.amount, balance));
184191

185-
_balances[sender][alloc.asset] -= alloc.amount;
186-
_balances[receiver][alloc.asset] += alloc.amount;
192+
senderBalances[alloc.asset] -= alloc.amount;
193+
receiverBalances[alloc.asset] += alloc.amount;
187194
}
188195

189196
function _checkAndVaultSendAccount(
190197
address sender,
191198
address receiver,
192199
ITradingStructs.Allocation memory alloc
193200
) internal virtual {
194-
uint256 balance = _balances[sender][alloc.asset];
201+
mapping(address => uint256) storage senderBalances = _getBalanceMapping(sender);
202+
mapping(address => uint256) storage receiverBalances = _getBalanceMapping(receiver);
203+
204+
uint256 balance = senderBalances[alloc.asset];
195205
require(balance >= alloc.amount, InsufficientBalance(alloc.asset, alloc.amount, balance));
196206

197-
_balances[sender][alloc.asset] -= alloc.amount;
198-
_balances[receiver][alloc.asset] += alloc.amount;
207+
senderBalances[alloc.asset] -= alloc.amount;
208+
receiverBalances[alloc.asset] += alloc.amount;
199209
}
200210

201211
function _accountSendVault(
@@ -204,7 +214,9 @@ contract TradingVault is ITradingVault, IAuthorizableV2, ReentrancyGuard, Ownabl
204214
ITradingStructs.Allocation memory alloc
205215
) internal virtual {
206216
IERC20(alloc.asset).safeTransferFrom(sender, address(this), alloc.amount);
207-
_balances[receiver][alloc.asset] += alloc.amount;
217+
218+
mapping(address => uint256) storage receiverBalances = _getBalanceMapping(receiver);
219+
receiverBalances[alloc.asset] += alloc.amount;
208220
}
209221

210222
function _accountSwap(address sender, address receiver, ITradingStructs.Allocation memory alloc) internal virtual {

0 commit comments

Comments
 (0)