diff --git a/contracts/interfaces/IBaseHook.sol b/contracts/interfaces/IBaseHook.sol new file mode 100644 index 00000000..7f404cf2 --- /dev/null +++ b/contracts/interfaces/IBaseHook.sol @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.24; + +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; + +interface IBaseHook is IHooks { + function getHookPermissions() external pure returns (Hooks.Permissions memory); +} diff --git a/contracts/interfaces/IMiddlewareFactory.sol b/contracts/interfaces/IMiddlewareFactory.sol new file mode 100644 index 00000000..4af554fe --- /dev/null +++ b/contracts/interfaces/IMiddlewareFactory.sol @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +interface IMiddlewareFactory { + event MiddlewareCreated(address implementation, address middleware); + + /// @notice Returns the implementation address for a given middleware + /// @param middleware The middleware address + /// @return implementation The implementation address + function getImplementation(address middleware) external view returns (address implementation); + + /// @notice Creates a middleware for the given implementation + /// @param implementation The implementation address + /// @param salt The salt to use to deploy the middleware + /// @return middleware The address of the newly created middleware + function createMiddleware(address implementation, bytes32 salt) external returns (address middleware); +} diff --git a/contracts/middleware/BaseMiddleware.sol b/contracts/middleware/BaseMiddleware.sol new file mode 100644 index 00000000..ebe5c609 --- /dev/null +++ b/contracts/middleware/BaseMiddleware.sol @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.24; + +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BeforeSwapDelta} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {Proxy} from "@openzeppelin/contracts/proxy/Proxy.sol"; + +contract BaseMiddleware is Proxy { + /// @notice The address of the pool manager + IPoolManager public immutable poolManager; + address public immutable implementation; + + constructor(IPoolManager _poolManager, address _impl) { + poolManager = _poolManager; + implementation = _impl; + } + + function _implementation() internal view override returns (address) { + return implementation; + } + + // yo i wanna delete this function but how do i remove this warning + receive() external payable { + _delegate(_implementation()); + } +} diff --git a/contracts/middleware/BaseMiddlewareFactory.sol b/contracts/middleware/BaseMiddlewareFactory.sol new file mode 100644 index 00000000..4b06fd6d --- /dev/null +++ b/contracts/middleware/BaseMiddlewareFactory.sol @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {IMiddlewareFactory} from "../interfaces/IMiddlewareFactory.sol"; +import {BaseMiddleware} from "./BaseMiddleware.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IBaseHook} from "../interfaces/IBaseHook.sol"; + +contract BaseMiddlewareFactory is IMiddlewareFactory { + mapping(address => address) private _implementations; + + IPoolManager public immutable poolManager; + + constructor(IPoolManager _poolManager) { + poolManager = _poolManager; + } + + function getImplementation(address middleware) external view override returns (address implementation) { + return _implementations[middleware]; + } + + function createMiddleware(address implementation, bytes32 salt) external override returns (address middleware) { + middleware = _deployMiddleware(implementation, salt); + Hooks.validateHookPermissions(IHooks(middleware), IBaseHook(implementation).getHookPermissions()); + _implementations[middleware] = implementation; + emit MiddlewareCreated(implementation, middleware); + } + + function _deployMiddleware(address implementation, bytes32 salt) internal virtual returns (address middleware) { + return address(new BaseMiddleware{salt: salt}(poolManager, implementation)); + } +} diff --git a/test/BaseMiddlewareFactory.t.sol b/test/BaseMiddlewareFactory.t.sol new file mode 100644 index 00000000..6ff3c166 --- /dev/null +++ b/test/BaseMiddlewareFactory.t.sol @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {Test} from "forge-std/Test.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {Deployers} from "@uniswap/v4-core/test/utils/Deployers.sol"; +import {TestERC20} from "@uniswap/v4-core/src/test/TestERC20.sol"; +import {CurrencyLibrary, Currency} from "@uniswap/v4-core/src/types/Currency.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {HookEnabledSwapRouter} from "./utils/HookEnabledSwapRouter.sol"; +import {BalanceDelta} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {IHooks} from "@uniswap/v4-core/src/interfaces/IHooks.sol"; +import {console} from "../../../lib/forge-std/src/console.sol"; +import {BaseMiddleware} from "../contracts/middleware/BaseMiddleware.sol"; +import {BaseMiddlewareFactory} from "./../contracts/middleware/BaseMiddlewareFactory.sol"; +import {HookMiner} from "./utils/HookMiner.sol"; +import {Counter} from "./middleware/Counter.sol"; +import {SafeCallback} from "./../contracts/base/SafeCallback.sol"; + +contract BaseMiddlewareFactoryTest is Test, Deployers { + HookEnabledSwapRouter router; + TestERC20 token0; + TestERC20 token1; + + BaseMiddlewareFactory factory; + Counter counter; + + address middleware; + + function setUp() public { + deployFreshManagerAndRouters(); + (currency0, currency1) = deployMintAndApprove2Currencies(); + + router = new HookEnabledSwapRouter(manager); + token0 = TestERC20(Currency.unwrap(currency0)); + token1 = TestERC20(Currency.unwrap(currency1)); + + factory = new BaseMiddlewareFactory(manager); + counter = new Counter(manager); + + token0.approve(address(router), type(uint256).max); + token1.approve(address(router), type(uint256).max); + + uint160 flags = uint160( + Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG + | Hooks.BEFORE_ADD_LIQUIDITY_FLAG | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG + | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG + ); + + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), flags, type(BaseMiddleware).creationCode, abi.encode(address(manager), address(counter)) + ); + middleware = factory.createMiddleware(address(counter), salt); + assertEq(hookAddress, middleware); + } + + function testRevertOnSameDeployment() public { + uint160 flags = uint160( + Hooks.BEFORE_INITIALIZE_FLAG | Hooks.AFTER_INITIALIZE_FLAG | Hooks.BEFORE_SWAP_FLAG | Hooks.AFTER_SWAP_FLAG + | Hooks.BEFORE_ADD_LIQUIDITY_FLAG | Hooks.AFTER_ADD_LIQUIDITY_FLAG | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG + | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG | Hooks.BEFORE_DONATE_FLAG | Hooks.AFTER_DONATE_FLAG + ); + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), flags, type(BaseMiddleware).creationCode, abi.encode(address(manager), address(counter)) + ); + factory.createMiddleware(address(counter), salt); + // second deployment should revert + vm.expectRevert(bytes("")); + factory.createMiddleware(address(counter), salt); + } + + function testRevertOnIncorrectFlags() public { + Counter counter2 = new Counter(manager); + uint160 flags = uint160(Hooks.BEFORE_INITIALIZE_FLAG); + + (address hookAddress, bytes32 salt) = HookMiner.find( + address(factory), flags, type(BaseMiddleware).creationCode, abi.encode(address(manager), address(counter2)) + ); + address implementation = address(counter2); + vm.expectRevert(abi.encodePacked(bytes16(Hooks.HookAddressNotValid.selector), hookAddress)); + factory.createMiddleware(implementation, salt); + } + + function testRevertOnIncorrectFlagsMined() public { + Counter counter2 = new Counter(manager); + address implementation = address(counter2); + vm.expectRevert(); // HookAddressNotValid + factory.createMiddleware(implementation, bytes32("who needs to mine a salt?")); + } + + function testRevertOnIncorrectCaller() public { + vm.expectRevert(SafeCallback.NotManager.selector); + counter.afterDonate(address(this), key, 0, 0, ZERO_BYTES); + } + + function testCounters() public { + (PoolKey memory key, PoolId id) = + initPoolAndAddLiquidity(currency0, currency1, IHooks(middleware), 3000, SQRT_PRICE_1_1, ZERO_BYTES); + + Counter counterProxy = Counter(middleware); + assertEq(counterProxy.beforeInitializeCount(id), 1); + assertEq(counterProxy.afterInitializeCount(id), 1); + assertEq(counterProxy.beforeSwapCount(id), 0); + assertEq(counterProxy.afterSwapCount(id), 0); + assertEq(counterProxy.beforeAddLiquidityCount(id), 1); + assertEq(counterProxy.afterAddLiquidityCount(id), 1); + assertEq(counterProxy.beforeRemoveLiquidityCount(id), 0); + assertEq(counterProxy.afterRemoveLiquidityCount(id), 0); + assertEq(counterProxy.beforeDonateCount(id), 0); + assertEq(counterProxy.afterDonateCount(id), 0); + + assertEq(counterProxy.lastHookData(), ZERO_BYTES); + swap(key, true, 1, bytes("hi")); + assertEq(counterProxy.lastHookData(), bytes("hi")); + assertEq(counterProxy.beforeSwapCount(id), 1); + assertEq(counterProxy.afterSwapCount(id), 1); + + // counter does not store data itself + assertEq(counter.lastHookData(), bytes("")); + assertEq(counter.beforeSwapCount(id), 0); + assertEq(counter.afterSwapCount(id), 0); + } +} diff --git a/test/middleware/Counter.sol b/test/middleware/Counter.sol new file mode 100644 index 00000000..14b787af --- /dev/null +++ b/test/middleware/Counter.sol @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.24; + +import {BaseHook} from "./../../contracts/BaseHook.sol"; +import {Hooks} from "@uniswap/v4-core/src/libraries/Hooks.sol"; +import {IPoolManager} from "@uniswap/v4-core/src/interfaces/IPoolManager.sol"; +import {PoolKey} from "@uniswap/v4-core/src/types/PoolKey.sol"; +import {BeforeSwapDelta, BeforeSwapDeltaLibrary} from "@uniswap/v4-core/src/types/BeforeSwapDelta.sol"; +import {BalanceDelta, BalanceDeltaLibrary} from "@uniswap/v4-core/src/types/BalanceDelta.sol"; +import {PoolId, PoolIdLibrary} from "@uniswap/v4-core/src/types/PoolId.sol"; + +contract Counter is BaseHook { + using PoolIdLibrary for PoolKey; + + mapping(PoolId => uint256) public beforeInitializeCount; + mapping(PoolId => uint256) public afterInitializeCount; + + mapping(PoolId => uint256) public beforeSwapCount; + mapping(PoolId => uint256) public afterSwapCount; + + mapping(PoolId => uint256) public beforeAddLiquidityCount; + mapping(PoolId => uint256) public afterAddLiquidityCount; + mapping(PoolId => uint256) public beforeRemoveLiquidityCount; + mapping(PoolId => uint256) public afterRemoveLiquidityCount; + + mapping(PoolId => uint256) public beforeDonateCount; + mapping(PoolId => uint256) public afterDonateCount; + + bytes public lastHookData; + + constructor(IPoolManager _manager) BaseHook(_manager) {} + + // middleware implementations do not need to be mined + function validateHookAddress(BaseHook _this) internal pure override {} + + function getHookPermissions() public pure override returns (Hooks.Permissions memory) { + return Hooks.Permissions({ + beforeInitialize: true, + afterInitialize: true, + beforeAddLiquidity: true, + afterAddLiquidity: true, + beforeRemoveLiquidity: true, + afterRemoveLiquidity: true, + beforeSwap: true, + afterSwap: true, + beforeDonate: true, + afterDonate: true, + beforeSwapReturnDelta: false, + afterSwapReturnDelta: false, + afterAddLiquidityReturnDelta: false, + afterRemoveLiquidityReturnDelta: false + }); + } + + function beforeInitialize(address, PoolKey calldata key, uint160, bytes calldata hookData) + external + override + onlyByManager + returns (bytes4) + { + beforeInitializeCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.beforeInitialize.selector; + } + + function afterInitialize(address, PoolKey calldata key, uint160, int24, bytes calldata hookData) + external + override + onlyByManager + returns (bytes4) + { + afterInitializeCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.afterInitialize.selector; + } + + function beforeAddLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata hookData + ) external override onlyByManager returns (bytes4) { + beforeAddLiquidityCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.beforeAddLiquidity.selector; + } + + function afterAddLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata hookData + ) external override onlyByManager returns (bytes4, BalanceDelta) { + afterAddLiquidityCount[key.toId()]++; + lastHookData = hookData; + return (BaseHook.afterAddLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); + } + + function beforeRemoveLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata, + bytes calldata hookData + ) external override onlyByManager returns (bytes4) { + beforeRemoveLiquidityCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.beforeRemoveLiquidity.selector; + } + + function afterRemoveLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata, + BalanceDelta, + bytes calldata hookData + ) external override onlyByManager returns (bytes4, BalanceDelta) { + afterRemoveLiquidityCount[key.toId()]++; + lastHookData = hookData; + return (BaseHook.afterRemoveLiquidity.selector, BalanceDeltaLibrary.ZERO_DELTA); + } + + function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata, bytes calldata hookData) + external + override + onlyByManager + returns (bytes4, BeforeSwapDelta, uint24) + { + beforeSwapCount[key.toId()]++; + lastHookData = hookData; + return (BaseHook.beforeSwap.selector, BeforeSwapDeltaLibrary.ZERO_DELTA, 0); + } + + function afterSwap( + address, + PoolKey calldata key, + IPoolManager.SwapParams calldata, + BalanceDelta, + bytes calldata hookData + ) external override onlyByManager returns (bytes4, int128) { + afterSwapCount[key.toId()]++; + lastHookData = hookData; + return (BaseHook.afterSwap.selector, 0); + } + + function beforeDonate(address, PoolKey calldata key, uint256, uint256, bytes calldata hookData) + external + override + onlyByManager + returns (bytes4) + { + beforeDonateCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.beforeDonate.selector; + } + + function afterDonate(address, PoolKey calldata key, uint256, uint256, bytes calldata hookData) + external + override + onlyByManager + returns (bytes4) + { + afterDonateCount[key.toId()]++; + lastHookData = hookData; + return BaseHook.afterDonate.selector; + } +} diff --git a/test/utils/HookMiner.sol b/test/utils/HookMiner.sol new file mode 100644 index 00000000..d6b30c40 --- /dev/null +++ b/test/utils/HookMiner.sol @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.21; + +/// @title HookMiner - a library for mining hook addresses +/// @dev This library is intended for `forge test` environments. There may be gotchas when using salts in `forge script` or `forge create` +library HookMiner { + // mask to slice out the bottom 14 bit of the address + uint160 constant FLAG_MASK = 0x3FFF; + + // Maximum number of iterations to find a salt, avoid infinite loops + uint256 constant MAX_LOOP = 100_000; + + /// @notice Find a salt that produces a hook address with the desired `flags` + /// @param deployer The address that will deploy the hook. In `forge test`, this will be the test contract `address(this)` or the pranking address + /// In `forge script`, this should be `0x4e59b44847b379578588920cA78FbF26c0B4956C` (CREATE2 Deployer Proxy) + /// @param flags The desired flags for the hook address + /// @param creationCode The creation code of a hook contract. Example: `type(Counter).creationCode` + /// @param constructorArgs The encoded constructor arguments of a hook contract. Example: `abi.encode(address(manager))` + /// @return hookAddress salt and corresponding address that was found. The salt can be used in `new Hook{salt: salt}()` + function find(address deployer, uint160 flags, bytes memory creationCode, bytes memory constructorArgs) + internal + view + returns (address, bytes32) + { + address hookAddress; + bytes memory creationCodeWithArgs = abi.encodePacked(creationCode, constructorArgs); + + uint256 salt; + for (salt; salt < MAX_LOOP; salt++) { + hookAddress = computeAddress(deployer, salt, creationCodeWithArgs); + if (uint160(hookAddress) & FLAG_MASK == flags && hookAddress.code.length == 0) { + return (hookAddress, bytes32(salt)); + } + } + revert("HookMiner: could not find salt"); + } + + /// @notice Precompute a contract address deployed via CREATE2 + /// @param deployer The address that will deploy the hook. In `forge test`, this will be the test contract `address(this)` or the pranking address + /// In `forge script`, this should be `0x4e59b44847b379578588920cA78FbF26c0B4956C` (CREATE2 Deployer Proxy) + /// @param salt The salt used to deploy the hook + /// @param creationCode The creation code of a hook contract + function computeAddress(address deployer, uint256 salt, bytes memory creationCode) + internal + pure + returns (address hookAddress) + { + return address( + uint160(uint256(keccak256(abi.encodePacked(bytes1(0xFF), deployer, salt, keccak256(creationCode))))) + ); + } +}