diff --git a/contracts/middleware/BaseImplementation.sol b/contracts/middleware/BaseImplementation.sol index 1dfee1a8..3f5b034f 100644 --- a/contracts/middleware/BaseImplementation.sol +++ b/contracts/middleware/BaseImplementation.sol @@ -47,6 +47,11 @@ abstract contract BaseImplementation is IHooks, SafeCallback { middleware = _middleware; } + function updateDynamicFee(PoolKey calldata key, uint24 fee) external { + if (msg.sender != middlewareFactory) revert NotMiddlewareFactory(); + manager.updateDynamicLPFee(key, fee); + } + function getHookPermissions() public pure virtual returns (Hooks.Permissions memory); function _unlockCallback(bytes calldata data) internal virtual override returns (bytes memory) { diff --git a/contracts/middleware/BaseMiddleware.sol b/contracts/middleware/BaseMiddleware.sol index 8074aace..1dc064f4 100644 --- a/contracts/middleware/BaseMiddleware.sol +++ b/contracts/middleware/BaseMiddleware.sol @@ -11,6 +11,9 @@ import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; contract BaseMiddleware is IHooks { + using Hooks for BaseMiddleware; + using BeforeSwapDeltaLibrary for BeforeSwapDelta; + error NotManager(); IPoolManager public immutable manager; @@ -102,11 +105,15 @@ contract BaseMiddleware is IHooks { PoolKey calldata key, IPoolManager.SwapParams calldata params, bytes calldata hookData - ) external virtual onlyByManager returns (bytes4, BeforeSwapDelta, uint24) { + ) external virtual onlyByManager returns (bytes4 selector, BeforeSwapDelta beforeSwapDelta, uint24 lpFeeOverride) { if (msg.sender == address(implementation)) { return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); } - return implementation.beforeSwap(sender, key, params, hookData); + (selector, beforeSwapDelta, lpFeeOverride) = implementation.beforeSwap(sender, key, params, hookData); + if (this.hasPermission(Hooks.BEFORE_SWAP_RETURNS_DELTA_FLAG)) { + manager.take(key.currency0, sender, uint256(uint128(beforeSwapDelta.getSpecifiedDelta()))); + manager.take(key.currency1, sender, uint256(uint128(beforeSwapDelta.getUnspecifiedDelta()))); + } } function afterSwap( diff --git a/test/BaseMiddlewareFactory.t.sol b/test/BaseMiddlewareFactory.t.sol index bef9110b..3652f84c 100644 --- a/test/BaseMiddlewareFactory.t.sol +++ b/test/BaseMiddlewareFactory.t.sol @@ -16,16 +16,18 @@ import {BaseMiddleware} from "../contracts/middleware/BaseMiddleware.sol"; import {BaseMiddlewareFactory} from "./../contracts/middleware/BaseMiddlewareFactory.sol"; import {HookMiner} from "./utils/HookMiner.sol"; import {HooksCounter} from "./middleware-implementations/HooksCounter.sol"; +import {BaseImplementation} from "./../contracts/middleware/BaseImplementation.sol"; contract BaseMiddlewareFactoryTest is Test, Deployers { HookEnabledSwapRouter router; TestERC20 token0; TestERC20 token1; - PoolId id; BaseMiddlewareFactory factory; HooksCounter hooksCounter; + address middleware; + function setUp() public { deployFreshManagerAndRouters(); (currency0, currency1) = deployMintAndApprove2Currencies(); @@ -39,25 +41,85 @@ contract BaseMiddlewareFactoryTest is Test, Deployers { token0.approve(address(router), type(uint256).max); token1.approve(address(router), type(uint256).max); + + uint160 flags = uint160( + Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG + | Hooks.BEFORE_ADD_LIQUIDITY_FLAG | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG + | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG + ); + + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(BaseMiddleware).creationCode, + abi.encode(address(manager), address(hooksCounter)) + ); + middleware = factory.createMiddleware(address(hooksCounter), salt); + assertEq(hookAddress, middleware); + } + + function testRevertOnAlreadyInitialized() public { + uint160 flags = uint160( + Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG + | Hooks.BEFORE_ADD_LIQUIDITY_FLAG | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG + | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG + ); + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), + flags, + type(BaseMiddleware).creationCode, + abi.encode(address(manager), address(hooksCounter)) + ); + vm.expectRevert(BaseMiddlewareFactory.AlreadyInitialized.selector); + factory.createMiddleware(address(hooksCounter), salt); } function testRevertOnIncorrectFlags() public { + HooksCounter hooksCounter2 = new HooksCounter(manager, address(factory)); uint160 flags = uint160(Hooks.BEFORE_INITIALIZE_FLAG); (address hookAddress, bytes32 salt) = HookMiner.find( address(factory), flags, type(BaseMiddleware).creationCode, - abi.encode(address(manager), address(hooksCounter)) + abi.encode(address(manager), address(hooksCounter2)) ); - address implementation = address(hooksCounter); + address implementation = address(hooksCounter2); vm.expectRevert(abi.encodePacked(bytes16(Hooks.HookAddressNotValid.selector), hookAddress)); factory.createMiddleware(implementation, salt); } function testRevertOnIncorrectFlagsMined() public { - address implementation = address(hooksCounter); + HooksCounter hooksCounter2 = new HooksCounter(manager, address(factory)); + address implementation = address(hooksCounter2); vm.expectRevert(); // HookAddressNotValid factory.createMiddleware(implementation, bytes32("who needs to mine a salt?")); } + + function testRevertOnIncorrectCaller() public { + vm.expectRevert(BaseImplementation.NotMiddleware.selector); + hooksCounter.afterDonate(address(this), key, 0, 0, ZERO_BYTES); + } + + function testCounters() public { + (PoolKey memory key, PoolId id) = + initPoolAndAddLiquidity(currency0, currency1, IHooks(middleware), 3000, SQRT_PRICE_1_1, ZERO_BYTES); + + assertEq(hooksCounter.beforeInitializeCount(id), 1); + assertEq(hooksCounter.afterInitializeCount(id), 1); + assertEq(hooksCounter.beforeSwapCount(id), 0); + assertEq(hooksCounter.afterSwapCount(id), 0); + assertEq(hooksCounter.beforeAddLiquidityCount(id), 1); + assertEq(hooksCounter.afterAddLiquidityCount(id), 1); + assertEq(hooksCounter.beforeRemoveLiquidityCount(id), 0); + assertEq(hooksCounter.afterRemoveLiquidityCount(id), 0); + assertEq(hooksCounter.beforeDonateCount(id), 0); + assertEq(hooksCounter.afterDonateCount(id), 0); + + assertEq(hooksCounter.lastHookData(), ZERO_BYTES); + swap(key, true, 1, bytes("hi")); + assertEq(hooksCounter.lastHookData(), bytes("hi")); + assertEq(hooksCounter.beforeSwapCount(id), 1); + assertEq(hooksCounter.afterSwapCount(id), 1); + } } diff --git a/test/middleware-implementations/HooksCounter.sol b/test/middleware-implementations/HooksCounter.sol index fa6d5e5b..49aeb50a 100644 --- a/test/middleware-implementations/HooksCounter.sol +++ b/test/middleware-implementations/HooksCounter.sol @@ -8,8 +8,27 @@ import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; import {BaseImplementation} from "./../../contracts/middleware/BaseImplementation.sol"; import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; contract HooksCounter is BaseImplementation { + using PoolIdLibrary for PoolKey; + + mapping(PoolId => uint256) public beforeInitializeCount; + mapping(PoolId => uint256) public afterInitializeCount; + + mapping(PoolId => uint256) public beforeSwapCount; + mapping(PoolId => uint256) public afterSwapCount; + + mapping(PoolId => uint256) public beforeAddLiquidityCount; + mapping(PoolId => uint256) public afterAddLiquidityCount; + mapping(PoolId => uint256) public beforeRemoveLiquidityCount; + mapping(PoolId => uint256) public afterRemoveLiquidityCount; + + mapping(PoolId => uint256) public beforeDonateCount; + mapping(PoolId => uint256) public afterDonateCount; + + bytes public lastHookData; + constructor(IPoolManager _manager, address _middlewareFactory) BaseImplementation(_manager, _middlewareFactory) {} function getHookPermissions() public pure override returns (Hooks.Permissions memory) { @@ -31,95 +50,116 @@ contract HooksCounter is BaseImplementation { }); } - function beforeInitialize(address, PoolKey calldata, uint160, bytes calldata) + function beforeInitialize(address, PoolKey calldata key, uint160, bytes calldata hookData) external - pure override + onlyByMiddleware returns (bytes4) { + beforeInitializeCount[key.toId()]++; + lastHookData = hookData; return BaseHook.beforeInitialize.selector; } - function afterInitialize(address, PoolKey calldata, uint160, int24, bytes calldata) + function afterInitialize(address, PoolKey calldata key, uint160, int24, bytes calldata hookData) external - pure override + onlyByMiddleware returns (bytes4) { + afterInitializeCount[key.toId()]++; + lastHookData = hookData; return BaseHook.afterInitialize.selector; } - function beforeAddLiquidity(address, PoolKey calldata, IPoolManager.ModifyLiquidityParams calldata, bytes calldata) - external - pure - override - returns (bytes4) - { - return BaseHook.beforeAddLiquidity.selector; - } - - function beforeRemoveLiquidity( + function beforeAddLiquidity( address, - PoolKey calldata, + PoolKey calldata key, IPoolManager.ModifyLiquidityParams calldata, - bytes calldata - ) external pure override returns (bytes4) { - return BaseHook.beforeRemoveLiquidity.selector; + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4) { + beforeAddLiquidityCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.beforeAddLiquidity.selector; } function afterAddLiquidity( address, - PoolKey calldata, + PoolKey calldata key, IPoolManager.ModifyLiquidityParams calldata, BalanceDelta, - bytes calldata - ) external pure override returns (bytes4, BalanceDelta) { + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4, BalanceDelta) { + afterAddLiquidityCount[key.toId()]++; + lastHookData = hookData; return (BaseHook.afterAddLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); } + function beforeRemoveLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4) { + beforeRemoveLiquidityCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.beforeRemoveLiquidity.selector; + } + function afterRemoveLiquidity( address, - PoolKey calldata, + PoolKey calldata key, IPoolManager.ModifyLiquidityParams calldata, BalanceDelta, - bytes calldata - ) external pure override returns (bytes4, BalanceDelta) { + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4, BalanceDelta) { + afterRemoveLiquidityCount[key.toId()]++; + lastHookData = hookData; return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); } - function beforeSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, bytes calldata) + function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata, bytes calldata hookData) external - pure override + onlyByMiddleware returns (bytes4, BeforeSwapDelta, uint24) { + beforeSwapCount[key.toId()]++; + lastHookData = hookData; return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); } - function afterSwap(address, PoolKey calldata, IPoolManager.SwapParams calldata, BalanceDelta, bytes calldata) - external - pure - override - returns (bytes4, int128) - { + function afterSwap( + address, + PoolKey calldata key, + IPoolManager.SwapParams calldata, + BalanceDelta, + bytes calldata hookData + ) external override onlyByMiddleware returns (bytes4, int128) { + afterSwapCount[key.toId()]++; + lastHookData = hookData; return (BaseHook.afterSwap.selector, 0); } - function beforeDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + function beforeDonate(address, PoolKey calldata key, uint256, uint256, bytes calldata hookData) external - pure override + onlyByMiddleware returns (bytes4) { + beforeDonateCount[key.toId()]++; + lastHookData = hookData; return BaseHook.beforeDonate.selector; } - function afterDonate(address, PoolKey calldata, uint256, uint256, bytes calldata) + function afterDonate(address, PoolKey calldata key, uint256, uint256, bytes calldata hookData) external - pure override + onlyByMiddleware returns (bytes4) { + afterDonateCount[key.toId()]++; + lastHookData = hookData; return BaseHook.afterDonate.selector; } }