diff --git a/.gitmodules b/.gitmodules index 0f3aff0..870f249 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "lib/amm-core"] path = lib/amm-core url = https://github.com/reservoir-labs/amm-core +[submodule "lib/solady"] + path = lib/solady + url = https://github.com/Vectorized/solady diff --git a/foundry.toml b/foundry.toml index df54b68..3da5378 100644 --- a/foundry.toml +++ b/foundry.toml @@ -1,5 +1,5 @@ [profile.default] -solc = "0.8.23" +solc = "0.8.26" #via_ir = true bytecode_hash = "ipfs" optimizer_runs = 1_000_000 @@ -8,10 +8,10 @@ remappings = [ "amm-core/=lib/amm-core/src/" ] match_path = "test/unit/*.sol" +block_base_fee_per_gas = 10_000_000 # set to arbitrum's base fee of 0.01 gwei verbosity = 3 fs_permissions = [ - { access = "write", path = "./script/optimized-deployer-meta" }, - { access = "write", path = "./script/unoptimized-deployer-meta" }, + { access = "read", path = "./out" } ] ignored_error_codes = [] diff --git a/lib/solady b/lib/solady new file mode 160000 index 0000000..183a5c9 --- /dev/null +++ b/lib/solady @@ -0,0 +1 @@ +Subproject commit 183a5c9cc3ca982492ae5fcca9e7ed6668ddb6ac diff --git a/src/Enums.sol b/src/Enums.sol index f6fc387..e726fe4 100644 --- a/src/Enums.sol +++ b/src/Enums.sol @@ -1,18 +1,17 @@ -// SPDX-License-Identifier: UNLICENSED +// SPDX-License-Identifier: GPL-3.0-or-later pragma solidity ^0.8.0; // The two values that can be queried: // -// - PAIR_PRICE: the price of the tokens in the Pool, expressed as the price of the second token in units of the +// - RAW_PRICE: the price of the tokens in the Pool, expressed as the price of the second token in units of the // first token. For example, if token A is worth $2, and token B is worth $4, the pair price will be 2.0. // Note that the price is computed *including* the tokens decimals. This means that the pair price of a Pool with // DAI and USDC will be close to 1.0, despite DAI having 18 decimals and USDC 6. // -// - BPT_PRICE: the price of the Pool share token (BPT), in units of the first token. -// Note that the price is computed *including* the tokens decimals. This means that the BPT price of a Pool with -// USDC in which BPT is worth $5 will be 5.0, despite the BPT having 18 decimals and USDC 6. -// -// - INVARIANT: the value of the Pool's invariant, which serves as a measure of its liquidity. +// - CLAMPED_PRICE: the clamped price of the tokens in the Pool, in units of the first token. Clamping is necessary as +// as a countermeasure to oracle manipulation attempts. +// Refer to `maxChangeRate` and `maxChangePerTrade` in `ReservoirPair` and the `Observation` struct +// Note that the price is computed *including* the tokens decimals, just like the raw price. enum Variable { RAW_PRICE, CLAMPED_PRICE diff --git a/src/Errors.sol b/src/Errors.sol deleted file mode 100644 index b2e394c..0000000 --- a/src/Errors.sol +++ /dev/null @@ -1,9 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.0; - -// oracle related errors -error BadVariableRequest(); -error OracleNotInitialized(); -error InvalidSeconds(); -error QueryTooOld(); -error BadSecs(); diff --git a/src/ReservoirPriceCache.sol b/src/ReservoirPriceCache.sol deleted file mode 100644 index eed8c66..0000000 --- a/src/ReservoirPriceCache.sol +++ /dev/null @@ -1,174 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.0; - -import { Owned } from "lib/amm-core/lib/solmate/src/auth/Owned.sol"; -import { ReentrancyGuard } from "lib/amm-core/lib/solmate/src/utils/ReentrancyGuard.sol"; -import { FixedPointMathLib } from "lib/amm-core/lib/solmate/src/utils/FixedPointMathLib.sol"; - -import { ReservoirPair } from "amm-core/ReservoirPair.sol"; - -import { - IReservoirPriceOracle, - OracleAverageQuery, - OracleAccumulatorQuery, - Variable -} from "src/interfaces/IReservoirPriceOracle.sol"; - -contract ReservoirPriceCache is Owned(msg.sender), ReentrancyGuard { - using FixedPointMathLib for uint256; - - /////////////////////////////////////////////////////////////////////////////////////////////// - // CONSTANTS // - /////////////////////////////////////////////////////////////////////////////////////////////// - - uint256 MAX_DEVIATION_THRESHOLD = 0.2e18; // 20% - uint256 MAX_TWAP_PERIOD = 1 hours; - - /////////////////////////////////////////////////////////////////////////////////////////////// - // EVENTS // - /////////////////////////////////////////////////////////////////////////////////////////////// - - event Oracle(address newOracle); - event TwapPeriod(uint256 newPeriod); - event PriceDeviationThreshold(uint256 newThreshold); - event RewardMultiplier(uint256 newMultiplier); - event Price(address indexed pair, uint256 price); - - /////////////////////////////////////////////////////////////////////////////////////////////// - // ERRORS // - /////////////////////////////////////////////////////////////////////////////////////////////// - - error RPC_THRESHOLD_TOO_HIGH(); - error RPC_TWAP_PERIOD_TOO_HIGH(); - - /////////////////////////////////////////////////////////////////////////////////////////////// - // STORAGE // - /////////////////////////////////////////////////////////////////////////////////////////////// - - IReservoirPriceOracle public oracle; - - /// @notice percentage change greater than which, a price update with the oracles would succeed - /// 1e18 == 100% - uint64 public priceDeviationThreshold; - - /// @notice percentage of gas fee the contract rewards the caller for updating the price - /// 1e18 == 100% - uint64 public rewardMultiplier; - - /// @notice TWAP period for querying the oracle - uint64 public twapPeriod; - - // for a certain pair, regardless of the curve id, the latest cached price of token1/token0 - // calculate reciprocal to for price of token0/token1 - mapping(address => uint256) public priceCache; - - /////////////////////////////////////////////////////////////////////////////////////////////// - // CONSTRUCTOR, FALLBACKS // - /////////////////////////////////////////////////////////////////////////////////////////////// - - constructor(IReservoirPriceOracle aOracle, uint64 aThreshold, uint64 aTwapPeriod, uint64 aMultiplier) { - updatePriceDeviationThreshold(aThreshold); - updateTwapPeriod(aTwapPeriod); - updateRewardMultiplier(aMultiplier); - } - - /// @dev contract will hold native tokens to be distributed as gas bounty for updating the prices - /// anyone can contribute native tokens to this contract - receive() external payable { } - - /////////////////////////////////////////////////////////////////////////////////////////////// - // PUBLIC FUNCTIONS // - /////////////////////////////////////////////////////////////////////////////////////////////// - - // admin functions - - function updateOracle(address aOracle) public onlyOwner { - oracle = IReservoirPriceOracle(aOracle); - } - - function updatePriceDeviationThreshold(uint64 aNewThreshold) public onlyOwner { - if (aNewThreshold > MAX_DEVIATION_THRESHOLD) { - revert RPC_THRESHOLD_TOO_HIGH(); - } - - priceDeviationThreshold = aNewThreshold; - emit PriceDeviationThreshold(aNewThreshold); - } - - function updateTwapPeriod(uint64 aNewPeriod) public onlyOwner { - if (aNewPeriod > MAX_TWAP_PERIOD) { - revert RPC_TWAP_PERIOD_TOO_HIGH(); - } - twapPeriod = aNewPeriod; - emit TwapPeriod(aNewPeriod); - } - - function updateRewardMultiplier(uint64 aNewMultiplier) public onlyOwner { - rewardMultiplier = aNewMultiplier; - emit RewardMultiplier(aNewMultiplier); - } - - // oracle price functions - - function getPriceForPair(address aPair) external view returns (uint256) { - return priceCache[aPair]; - } - - // price update related functions - - function isPriceUpdateIncentivized() external view returns (bool) { - return address(this).balance > 0; - } - - function gasBountyAvailable() external view returns (uint256) { - return address(this).balance; - } - - /// @dev we do not allow specifying which address gets the reward, to save on calldata gas - function updatePrice(address aPair) external nonReentrant { - ReservoirPair lPair = ReservoirPair(aPair); - - OracleAverageQuery[] memory lQueries; - lQueries[0] = OracleAverageQuery( - Variable.RAW_PRICE, - address(lPair.token0()), - address(lPair.token1()), - twapPeriod, - 0 // now - ); - - // reads new price from pair - uint256 lNewPrice = oracle.getTimeWeightedAverage(lQueries)[0]; - - // determine if price has moved beyond the threshold - // reward caller if so - if (_calcPercentageDiff(lNewPrice, priceCache[aPair]) >= priceDeviationThreshold) { - _rewardUpdater(msg.sender); - } - - priceCache[aPair] = lNewPrice; - emit Price(aPair, lNewPrice); - } - - /////////////////////////////////////////////////////////////////////////////////////////////// - // INTERNAL FUNCTIONS // - /////////////////////////////////////////////////////////////////////////////////////////////// - - // TODO: replace this with safe, audited lib function - function _calcPercentageDiff(uint256 aOriginal, uint256 aNew) internal pure returns (uint256) { - if (aOriginal > aNew) { - return (aOriginal - aNew) * 1e18 / aOriginal; - } else { - return (aNew - aOriginal) * 1e18 / aOriginal; - } - } - - function _rewardUpdater(address lRecipient) internal { - // TODO: make sure this works on L1 as well as L2s - uint256 lPayoutAmt = block.basefee.mulWadDown(rewardMultiplier); - - if (lPayoutAmt <= address(this).balance) { - payable(lRecipient).transfer(lPayoutAmt); - } else { } // do nothing if lPayoutAmt is greater than the balance - } -} diff --git a/src/ReservoirPriceOracle.sol b/src/ReservoirPriceOracle.sol index dbb843d..fac5f61 100644 --- a/src/ReservoirPriceOracle.sol +++ b/src/ReservoirPriceOracle.sol @@ -1,6 +1,9 @@ -// SPDX-License-Identifier: UNLICENSED +// SPDX-License-Identifier: GPL-3.0-or-later pragma solidity ^0.8.0; +import { IERC20 } from "forge-std/interfaces/IERC20.sol"; + +import { OracleErrors } from "src/libraries/OracleErrors.sol"; import { IReservoirPriceOracle, OracleAverageQuery, @@ -8,25 +11,541 @@ import { OracleAccumulatorQuery, Variable } from "src/interfaces/IReservoirPriceOracle.sol"; +import { IPriceOracle } from "src/interfaces/IPriceOracle.sol"; +import { QueryProcessor, ReservoirPair, Buffer } from "src/libraries/QueryProcessor.sol"; +import { Utils } from "src/libraries/Utils.sol"; +import { Owned } from "lib/amm-core/lib/solmate/src/auth/Owned.sol"; +import { ReentrancyGuard } from "lib/amm-core/lib/solmate/src/utils/ReentrancyGuard.sol"; +import { FixedPointMathLib } from "lib/amm-core/lib/solady/src/utils/FixedPointMathLib.sol"; +import { LibSort } from "lib/solady/src/utils/LibSort.sol"; +import { Constants } from "src/libraries/Constants.sol"; +import { FlagsLib } from "src/libraries/FlagsLib.sol"; -contract ReservoirPriceOracle is IReservoirPriceOracle { - function getTimeWeightedAverage(OracleAverageQuery[] memory aQueries) +contract ReservoirPriceOracle is IPriceOracle, IReservoirPriceOracle, Owned(msg.sender), ReentrancyGuard { + using FixedPointMathLib for uint256; + using LibSort for address[]; + using FlagsLib for *; + using QueryProcessor for ReservoirPair; + using Utils for *; + + /////////////////////////////////////////////////////////////////////////////////////////////// + // EVENTS // + /////////////////////////////////////////////////////////////////////////////////////////////// + + event DesignatePair(address token0, address token1, ReservoirPair pair); + event PriceDeviationThreshold(uint256 newThreshold); + event RewardGasAmount(uint256 newAmount); + event Route(address token0, address token1, address[] route); + event Price(address token0, address token1, uint256 price); + event TwapPeriod(uint256 newPeriod); + + /////////////////////////////////////////////////////////////////////////////////////////////// + // STORAGE // + /////////////////////////////////////////////////////////////////////////////////////////////// + + /// @notice percentage change greater than which, a price update may result in a reward payout of native tokens, + /// subject to availability of rewards. + /// 1e18 == 100% + uint64 public priceDeviationThreshold; + + /// @notice multiples of the base fee the contract rewards the caller for updating the price when it goes + /// beyond the `priceDeviationThreshold` + uint64 public rewardGasAmount; + + /// @notice TWAP period (in seconds) for querying the oracle + uint64 public twapPeriod; + + /// @notice Designated pairs to serve as price feed for a certain token0 and token1 + mapping(address token0 => mapping(address token1 => ReservoirPair pair)) public pairs; + + /////////////////////////////////////////////////////////////////////////////////////////////// + // CONSTRUCTOR, FALLBACKS // + /////////////////////////////////////////////////////////////////////////////////////////////// + + constructor(uint64 aThreshold, uint64 aTwapPeriod, uint64 aMultiplier) { + updatePriceDeviationThreshold(aThreshold); + updateTwapPeriod(aTwapPeriod); + updateRewardGasAmount(aMultiplier); + } + + /// @dev contract will hold native tokens to be distributed as gas bounty for updating the prices + /// anyone can contribute native tokens to this contract + receive() external payable { } + + /////////////////////////////////////////////////////////////////////////////////////////////// + // PUBLIC FUNCTIONS // + /////////////////////////////////////////////////////////////////////////////////////////////// + + // IPriceOracle + + function name() external pure returns (string memory) { + return "RESERVOIR PRICE ORACLE"; + } + + /// @inheritdoc IPriceOracle + function getQuote(uint256 aAmount, address aBase, address aQuote) external view returns (uint256 rOut) { + rOut = _getQuote(aAmount, aBase, aQuote); + } + + /// @inheritdoc IPriceOracle + function getQuotes(uint256 aAmount, address aBase, address aQuote) external view + returns (uint256 rBidOut, uint256 rAskOut) + { + uint256 lResult = _getQuote(aAmount, aBase, aQuote); + (rBidOut, rAskOut) = (lResult, lResult); + } + + // price update related functions + + function gasBountyAvailable() external view returns (uint256) { + return address(this).balance; + } + + function route(address aToken0, address aToken1) external view returns (address[] memory rRoute) { + (rRoute,,) = _getRouteDecimalDifferencePrice(aToken0, aToken1); + } + + /// @notice The latest cached geometric TWAP of token0/token1. + /// Stored in the form of a 18 decimal fixed point number. + /// Supported price range: 1wei to `Constants.MAX_SUPPORTED_PRICE`. + /// Only stores prices of simple routes. Does not store prices of composite routes. + /// @param aToken0 Address of the lower token. + /// @param aToken1 Address of the higher token. + /// @return rPrice The cached price of aToken0/aToken1 for simple routes. Returns 0 for prices of composite routes. + /// @return rDecimalDiff The difference in decimals as defined by aToken1.decimals() - aToken0.decimals(). Only valid for simple routes. + function priceCache(address aToken0, address aToken1) external view returns (uint256 rPrice, int256 rDecimalDiff) { + (rPrice, rDecimalDiff) = _priceCache(aToken0, aToken1); + } + + /// @notice Updates the TWAP price for all simple routes between `aTokenA` and `aTokenB`. Will also update intermediate routes if the route defined between + /// `aTokenA` and `aTokenB` is longer than 1 hop + /// However, if the route between aTokenA and aTokenB is composite route (more than 1 hop), no cache entry is written + /// for priceCache[aTokenA][aTokenB] but instead the prices of its constituent simple routes will be written. + /// Reverts if price is 0 or greater than `Constants.MAX_SUPPORTED_PRICE`. + /// @param aTokenA Address of one of the tokens for the price update. Does not have to be less than address of aTokenB + /// @param aTokenB Address of one of the tokens for the price update. Does not have to be greater than address of aTokenA + /// @param aRewardRecipient The beneficiary of the reward. Must be able to receive ether. Set to address(0) if not seeking a reward + function updatePrice(address aTokenA, address aTokenB, address aRewardRecipient) public nonReentrant { + (address lToken0, address lToken1) = aTokenA.sortTokens(aTokenB); + + (address[] memory lRoute,,) = _getRouteDecimalDifferencePrice(lToken0, lToken1); + if (lRoute.length == 0) revert OracleErrors.NoPath(); + + OracleAverageQuery[] memory lQueries = new OracleAverageQuery[](lRoute.length - 1); + + for (uint256 i = 0; i < lRoute.length - 1; ++i) { + (lToken0, lToken1) = lRoute[i].sortTokens(lRoute[i + 1]); + + lQueries[i] = OracleAverageQuery( + Variable.RAW_PRICE, + lToken0, + lToken1, + twapPeriod, + 0 // now + ); + } + + uint256[] memory lNewPrices = getTimeWeightedAverage(lQueries); + + for (uint256 i = 0; i < lNewPrices.length; ++i) { + address lBase = lQueries[i].base; + address lQuote = lQueries[i].quote; + uint256 lNewPrice = lNewPrices[i]; + + // assumed to be simple routes and therefore lPrevPrice would only be 0 for the first update + // consider an optimization here for simple routes: no need to read the price cache again + // as it has been returned by _getRouteDecimalDifferencePrice in the beginning of the function + (uint256 lPrevPrice,) = _priceCache(lBase, lQuote); + + // determine if price has moved beyond the threshold, and pay out reward if so + if (_calcPercentageDiff(lPrevPrice, lNewPrice) >= priceDeviationThreshold) { + _rewardUpdater(aRewardRecipient); + } + + _writePriceCache(lBase, lQuote, lNewPrice); + } + } + + // IReservoirPriceOracle + + /// @inheritdoc IReservoirPriceOracle + function getTimeWeightedAverage(OracleAverageQuery[] memory aQueries) + public + view returns (uint256[] memory rResults) - { } + { + rResults = new uint256[](aQueries.length); - function getLatest(OracleLatestQuery calldata aQuery) external view returns (uint256) { - return 0; + OracleAverageQuery memory lQuery; + for (uint256 i = 0; i < aQueries.length; ++i) { + lQuery = aQueries[i]; + ReservoirPair lPair = pairs[lQuery.base][lQuery.quote]; + _validatePair(lPair); + + (,,, uint16 lIndex) = lPair.getReserves(); + uint256 lResult = lPair.getTimeWeightedAverage(lQuery.variable, lQuery.secs, lQuery.ago, lIndex); + rResults[i] = lResult; + } } - function getLargestSafeQueryWindow() external view returns (uint256) { - return 0; + /// @inheritdoc IReservoirPriceOracle + function getLatest(OracleLatestQuery calldata aQuery) external view returns (uint256) { + ReservoirPair lPair = pairs[aQuery.base][aQuery.quote]; + _validatePair(lPair); + + (,,, uint256 lIndex) = lPair.getReserves(); + uint256 lResult = lPair.getInstantValue(aQuery.variable, lIndex); + return lResult; } + /// @inheritdoc IReservoirPriceOracle function getPastAccumulators(OracleAccumulatorQuery[] memory aQueries) external view returns (int256[] memory rResults) - { } + { + rResults = new int256[](aQueries.length); + + OracleAccumulatorQuery memory lQuery; + for (uint256 i = 0; i < aQueries.length; ++i) { + lQuery = aQueries[i]; + ReservoirPair lPair = pairs[lQuery.base][lQuery.quote]; + _validatePair(lPair); + + (,,, uint16 lIndex) = lPair.getReserves(); + int256 lAcc = lPair.getPastAccumulator(lQuery.variable, lIndex, lQuery.ago); + rResults[i] = lAcc; + } + } + + /// @inheritdoc IReservoirPriceOracle + function getLargestSafeQueryWindow() external pure returns (uint256) { + return Buffer.SIZE; + } + + /////////////////////////////////////////////////////////////////////////////////////////////// + // INTERNAL FUNCTIONS // + /////////////////////////////////////////////////////////////////////////////////////////////// + + function _validatePair(ReservoirPair aPair) internal pure { + if (address(aPair) == address(0)) revert OracleErrors.NoDesignatedPair(); + } + + // TODO: replace this with safe, audited lib function + function _calcPercentageDiff(uint256 aOriginal, uint256 aNew) internal pure returns (uint256) { + if (aOriginal == 0) return 0; + + if (aOriginal > aNew) { + return (aOriginal - aNew) * 1e18 / aOriginal; + } else { + return (aNew - aOriginal) * 1e18 / aOriginal; + } + } + + function _rewardUpdater(address aRecipient) internal { + if (aRecipient == address(0)) return; + + // N.B. Revisit this whenever deployment on a new chain is needed + // we use `block.basefee` instead of `ArbGasInfo::getMinimumGasPrice()` on ARB because the latter will always return + // the demand insensitive base fee, while the former can return real higher fees during times of congestion + // safety: this mul will not overflow even in extreme cases of `block.basefee` + uint256 lPayoutAmt = block.basefee * rewardGasAmount; + + if (lPayoutAmt <= address(this).balance) { + payable(aRecipient).transfer(lPayoutAmt); + } else { } // do nothing if lPayoutAmt is greater than the balance + } + + /// @return rRoute The route to determine the price between aToken0 and aToken1 + /// @return rDecimalDiff The result of token1.decimals() - token0.decimals() if it's a simple route. 0 otherwise + /// @return rPrice The price of aToken0/aToken1 if it's a simple route (i.e. rRoute.length == 2). 0 otherwise + function _getRouteDecimalDifferencePrice(address aToken0, address aToken1) + internal + view + returns (address[] memory rRoute, int256 rDecimalDiff, uint256 rPrice) + { + address[] memory lResults = new address[](Constants.MAX_ROUTE_LENGTH); + bytes32 lSlot = aToken0.calculateSlot(aToken1); + + bytes32 lFirstWord; + uint256 lRouteLength; + assembly { + lFirstWord := sload(lSlot) + } + + // simple route + if (lFirstWord.isSimplePrice()) { + lResults[0] = aToken0; + lResults[1] = aToken1; + lRouteLength = 2; + rDecimalDiff = lFirstWord.getDecimalDifference(); + rPrice = lFirstWord.getPrice(); + } + // composite route + else if (lFirstWord.isCompositeRoute()) { + address lSecondToken = lFirstWord.getTokenFirstWord(); + + lResults[0] = aToken0; + lResults[1] = lSecondToken; + + if (lFirstWord.is3HopRoute()) { + bytes32 lSecondWord; + assembly { + lSecondWord := sload(add(lSlot, 1)) + } + address lThirdToken = lFirstWord.getThirdToken(lSecondWord); + + lResults[2] = lThirdToken; + lResults[3] = aToken1; + lRouteLength = 4; + } else { + lResults[2] = aToken1; + lRouteLength = 3; + } + } + // no route + else if (lFirstWord.isUninitialized()) { } + + rRoute = new address[](lRouteLength); + for (uint256 i = 0; i < lRouteLength; ++i) { + rRoute[i] = lResults[i]; + } + } + + /// Calculate the storage slot for this intermediate segment and read it to see if there is an existing + /// route. If there isn't an existing route, we write it as well. + /// @dev assumed that aToken0 and aToken1 are not necessarily sorted + function _checkAndPopulateIntermediateRoute(address aToken0, address aToken1) internal { + (address lLowerToken, address lHigherToken) = aToken0.sortTokens(aToken1); + + bytes32 lSlot = lLowerToken.calculateSlot(lHigherToken); + bytes32 lData; + assembly { + lData := sload(lSlot) + } + if (lData == bytes32(0)) { + address[] memory lIntermediateRoute = new address[](2); + lIntermediateRoute[0] = lLowerToken; + lIntermediateRoute[1] = lHigherToken; + setRoute(lLowerToken, lHigherToken, lIntermediateRoute); + } + } + + // performs an SLOAD to load 1 word which contains the simple price and decimal difference + function _priceCache(address aToken0, address aToken1) + internal + view + returns (uint256 rPrice, int256 rDecimalDiff) + { + bytes32 lSlot = aToken0.calculateSlot(aToken1); + + bytes32 lData; + assembly { + lData := sload(lSlot) + } + if (lData.isSimplePrice()) { + rPrice = lData.getPrice(); + rDecimalDiff = lData.getDecimalDifference(); + } + } + + function _writePriceCache(address aToken0, address aToken1, uint256 aNewPrice) internal { + if (aNewPrice == 0 || aNewPrice > Constants.MAX_SUPPORTED_PRICE) revert OracleErrors.PriceOutOfRange(aNewPrice); + + bytes32 lSlot = aToken0.calculateSlot(aToken1); + bytes32 lData; + assembly { + lData := sload(lSlot) + } + if (!lData.isSimplePrice()) revert OracleErrors.WriteToNonSimpleRoute(); + + int256 lDiff = lData.getDecimalDifference(); + + lData = lDiff.packSimplePrice(aNewPrice); + assembly { + sstore(lSlot, lData) + } + } + + function _getQuote(uint256 aAmount, address aBase, address aQuote) internal view returns (uint256 rOut) { + if (aBase == aQuote) return aAmount; + if (aAmount > Constants.MAX_AMOUNT_IN) revert OracleErrors.AmountInTooLarge(); + + (address lToken0, address lToken1) = aBase.sortTokens(aQuote); + (address[] memory lRoute, int256 lDecimalDiff, uint256 lPrice) = + _getRouteDecimalDifferencePrice(lToken0, lToken1); + + if (lRoute.length == 0) { + revert OracleErrors.NoPath(); + } else if (lRoute.length == 2) { + if (lPrice == 0) revert OracleErrors.PriceZero(); + rOut = _calcAmtOut(aAmount, lPrice, lDecimalDiff, lRoute[0] != aBase); + } + // for composite route, read simple prices to derive composite price + else { + uint256 lIntermediateAmount = aAmount; + + // reverse the route so we always perform calculations starting from index 0 + if (lRoute[0] != aBase) lRoute.reverse(); + assert(lRoute[0] == aBase); + + for (uint256 i = 0; i < lRoute.length - 1; ++i) { + (address lLowerToken, address lHigherToken) = lRoute[i].sortTokens(lRoute[i + 1]); + // it is assumed that intermediate routes defined here are simple routes and not composite routes + (lPrice, lDecimalDiff) = _priceCache(lLowerToken, lHigherToken); + + if (lPrice == 0) revert OracleErrors.PriceZero(); + lIntermediateAmount = _calcAmtOut(lIntermediateAmount, lPrice, lDecimalDiff, lRoute[i] != lLowerToken); + } + rOut = lIntermediateAmount; + } + } + + /// @dev aPrice assumed to be > 0, as checked by _getQuote + function _calcAmtOut(uint256 aAmountIn, uint256 aPrice, int256 aDecimalDiff, bool aInverse) + internal + pure + returns (uint256 rOut) + { + // formula: baseAmountOut = quoteAmountIn * Constants.WAD * baseDecimalScale / baseQuotePrice / quoteDecimalScale + if (aInverse) { + if (aDecimalDiff > 0) { + rOut = aAmountIn.fullMulDiv(Constants.WAD, aPrice) / 10 ** uint256(aDecimalDiff); + } else if (aDecimalDiff < 0) { + rOut = aAmountIn.fullMulDiv(Constants.WAD * 10 ** uint256(-aDecimalDiff), aPrice); + } + // equal decimals + else { + rOut = aAmountIn.fullMulDiv(Constants.WAD, aPrice); + } + } else { + // formula: quoteAmountOut = baseAmountIn * baseQuotePrice * quoteDecimalScale / baseDecimalScale / Constants.WAD + if (aDecimalDiff > 0) { + rOut = aAmountIn.fullMulDiv(aPrice * 10 ** uint256(aDecimalDiff), Constants.WAD); + } else if (aDecimalDiff < 0) { + rOut = aAmountIn.fullMulDiv(aPrice, 10 ** uint256(-aDecimalDiff) * Constants.WAD); + } else { + rOut = aAmountIn.fullMulDiv(aPrice, Constants.WAD); + } + } + } + + /////////////////////////////////////////////////////////////////////////////////////////////// + // ADMIN FUNCTIONS // + /////////////////////////////////////////////////////////////////////////////////////////////// + + function updatePriceDeviationThreshold(uint64 aNewThreshold) public onlyOwner { + if (aNewThreshold > Constants.MAX_DEVIATION_THRESHOLD) { + revert OracleErrors.PriceDeviationThresholdTooHigh(); + } + + priceDeviationThreshold = aNewThreshold; + emit PriceDeviationThreshold(aNewThreshold); + } + + function updateTwapPeriod(uint64 aNewPeriod) public onlyOwner { + if (aNewPeriod == 0 || aNewPeriod > Constants.MAX_TWAP_PERIOD) { + revert OracleErrors.InvalidTwapPeriod(); + } + twapPeriod = aNewPeriod; + emit TwapPeriod(aNewPeriod); + } + + function updateRewardGasAmount(uint64 aNewMultiplier) public onlyOwner { + rewardGasAmount = aNewMultiplier; + emit RewardGasAmount(aNewMultiplier); + } + + // sets a specific pair to serve as price feed for a certain route + function designatePair(address aToken0, address aToken1, ReservoirPair aPair) external onlyOwner { + (aToken0, aToken1) = aToken0.sortTokens(aToken1); + assert(aToken0 == address(aPair.token0()) && aToken1 == address(aPair.token1())); + + pairs[aToken0][aToken1] = aPair; + emit DesignatePair(aToken0, aToken1, aPair); + } + + function undesignatePair(address aToken0, address aToken1) external onlyOwner { + (aToken0, aToken1) = aToken0.sortTokens(aToken1); + + delete pairs[aToken0][aToken1]; + emit DesignatePair(aToken0, aToken1, ReservoirPair(address(0))); + } + + /// @notice Sets the price route between aToken0 and aToken1, and also intermediate routes if previously undefined + /// @param aToken0 Address of the lower token + /// @param aToken1 Address of the higher token + /// @param aRoute Path with which the price between aToken0 and aToken1 should be derived + function setRoute(address aToken0, address aToken1, address[] memory aRoute) public onlyOwner { + uint256 lRouteLength = aRoute.length; + + if (aToken0 == aToken1) revert OracleErrors.SameToken(); + if (aToken1 < aToken0) revert OracleErrors.TokensUnsorted(); + if (lRouteLength > Constants.MAX_ROUTE_LENGTH || lRouteLength < 2) revert OracleErrors.InvalidRouteLength(); + if (aRoute[0] != aToken0 || aRoute[lRouteLength - 1] != aToken1) revert OracleErrors.InvalidRoute(); + + bytes32 lSlot = aToken0.calculateSlot(aToken1); + + // simple route + if (lRouteLength == 2) { + uint256 lToken0Decimals = IERC20(aToken0).decimals(); + uint256 lToken1Decimals = IERC20(aToken1).decimals(); + if (lToken0Decimals > 18 || lToken1Decimals > 18) revert OracleErrors.UnsupportedTokenDecimals(); + + int256 lDiff = int256(lToken1Decimals) - int256(lToken0Decimals); + + bytes32 lData = lDiff.packSimplePrice(0); + assembly { + // Write data to storage. + sstore(lSlot, lData) + } + } + // composite route + else { + address lSecondToken = aRoute[1]; + address lThirdToken = aRoute[2]; + + if (lRouteLength == 3) { + bytes32 lData = lSecondToken.pack2HopRoute(); + assembly { + sstore(lSlot, lData) + } + } else if (lRouteLength == 4) { + (bytes32 lFirstWord, bytes32 lSecondWord) = lSecondToken.pack3HopRoute(lThirdToken); + + // Write two words to storage. + assembly { + sstore(lSlot, lFirstWord) + sstore(add(lSlot, 1), lSecondWord) + } + _checkAndPopulateIntermediateRoute(lThirdToken, aToken1); + } + _checkAndPopulateIntermediateRoute(aToken0, lSecondToken); + _checkAndPopulateIntermediateRoute(lSecondToken, lThirdToken); + } + emit Route(aToken0, aToken1, aRoute); + } + + function clearRoute(address aToken0, address aToken1) external onlyOwner { + if (aToken0 == aToken1) revert OracleErrors.SameToken(); + if (aToken1 < aToken0) revert OracleErrors.TokensUnsorted(); + + (address[] memory lRoute,,) = _getRouteDecimalDifferencePrice(aToken0, aToken1); + + bytes32 lSlot = aToken0.calculateSlot(aToken1); + + // clear all storage slots that the route has written to previously + assembly { + sstore(lSlot, 0) + } + // routes with length 4 use two words of storage + if (lRoute.length == 4) { + assembly { + sstore(add(lSlot, 1), 0) + } + } + emit Route(aToken0, aToken1, new address[](0)); + } } diff --git a/src/Structs.sol b/src/Structs.sol index fb03eb3..3323789 100644 --- a/src/Structs.sol +++ b/src/Structs.sol @@ -1,4 +1,4 @@ -// SPDX-License-Identifier: UNLICENSED +// SPDX-License-Identifier: GPL-3.0-or-later pragma solidity ^0.8.0; import { Variable } from "src/Enums.sol"; @@ -9,6 +9,7 @@ import { Variable } from "src/Enums.sol"; * Each query computes the average over a window of duration `secs` seconds that ended `ago` seconds ago. For * example, the average over the past 30 minutes is computed by settings secs to 1800 and ago to 0. If secs is 1800 * and ago is 1800 as well, the average between 60 and 30 minutes ago is computed instead. + * The address of `base` is strictly less than the address of `quote` */ struct OracleAverageQuery { Variable variable; @@ -21,7 +22,8 @@ struct OracleAverageQuery { /** * @dev Information for a query for the latest variable * - * TODO: fill this in + * Each query computes the latest instantaneous variable. + * The address of `base` is strictly less than the address of `quote` */ struct OracleLatestQuery { Variable variable; @@ -33,6 +35,7 @@ struct OracleLatestQuery { * @dev Information for an Accumulator query. * * Each query estimates the accumulator at a time `ago` seconds ago. + * The address of `base` is strictly less than the address of `quote` */ struct OracleAccumulatorQuery { Variable variable; diff --git a/src/interfaces/IPriceOracle.sol b/src/interfaces/IPriceOracle.sol new file mode 100644 index 0000000..adad550 --- /dev/null +++ b/src/interfaces/IPriceOracle.sol @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.0; + +interface IPriceOracle { + function name() external view returns (string memory); + + /// @notice Returns the quote for a given amount of base asset in quote asset. + /// @param amount The amount of base asset. + /// @param base The address of the base asset. + /// @param quote The address of the quote asset. + /// @return out The quote amount in quote asset. + function getQuote(uint256 amount, address base, address quote) external view returns (uint256 out); + + /// @notice Returns the bid and ask quotes for a given amount of base asset in quote asset. + /// @param amount The amount of base asset. + /// @param base The address of the base asset. + /// @param quote The address of the quote asset. + /// @return bidOut The bid quote amount in quote asset. + /// @return askOut The ask quote amount in quote asset. + function getQuotes(uint256 amount, address base, address quote) + external + view + returns (uint256 bidOut, uint256 askOut); +} diff --git a/src/interfaces/IReservoirPriceOracle.sol b/src/interfaces/IReservoirPriceOracle.sol index d9d79ef..2c59bf7 100644 --- a/src/interfaces/IReservoirPriceOracle.sol +++ b/src/interfaces/IReservoirPriceOracle.sol @@ -48,6 +48,9 @@ interface IReservoirPriceOracle { * * If a query has a non-zero `ago` value, then `secs + ago` (the oldest point in time) must be smaller than this * value for 'safe' queries. + * + * Since ReservoirPair's oracle writes every second, the largest safe query window is the number of seconds + * same as the size of the buffer. */ function getLargestSafeQueryWindow() external view returns (uint256); diff --git a/src/libraries/Constants.sol b/src/libraries/Constants.sol new file mode 100644 index 0000000..e69032b --- /dev/null +++ b/src/libraries/Constants.sol @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.0; + +library Constants { + /////////////////////////////////////////////////////////////////////////////////////////////// + // CONSTANTS // + /////////////////////////////////////////////////////////////////////////////////////////////// + + uint256 public constant MAX_DEVIATION_THRESHOLD = 0.1e18; // 10% + uint256 public constant MAX_TWAP_PERIOD = 1 hours; + uint256 public constant MAX_ROUTE_LENGTH = 4; + uint256 public constant WAD = 1e18; + uint256 public constant MAX_SUPPORTED_PRICE = type(uint128).max; + uint256 public constant MAX_AMOUNT_IN = type(uint128).max; +} diff --git a/src/libraries/FlagsLib.sol b/src/libraries/FlagsLib.sol new file mode 100644 index 0000000..a955e85 --- /dev/null +++ b/src/libraries/FlagsLib.sol @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.0; + +library FlagsLib { + bytes32 public constant FLAG_UNINITIALIZED = bytes32(hex"00"); + bytes32 public constant FLAG_SIMPLE_PRICE = bytes32(hex"01"); + bytes32 public constant FLAG_2_HOP_ROUTE = bytes32(hex"02"); + bytes32 public constant FLAG_3_HOP_ROUTE = bytes32(hex"03"); + + function isUninitialized(bytes32 aData) internal pure returns (bool) { + return aData[0] == FLAG_UNINITIALIZED; + } + + function isSimplePrice(bytes32 aData) internal pure returns (bool) { + return aData[0] == FLAG_SIMPLE_PRICE; + } + + function isCompositeRoute(bytes32 aData) internal pure returns (bool) { + return aData[0] & hex"02" > 0; + } + + function is3HopRoute(bytes32 aData) internal pure returns (bool) { + return aData[0] == FLAG_3_HOP_ROUTE; + } + + // Positive value indicates that token1 has a greater number of decimals compared to token2 + // while a negative value indicates otherwise. + // range of values between -18 and 18 + function getDecimalDifference(bytes32 aData) internal pure returns (int256 rDiff) { + rDiff = int8(uint8(aData[1])); + } + + // Assumes that aDecimalDifference is between -18 and 18 + // Assumes that aPrice is between 1 and 1e36 + function packSimplePrice(int256 aDecimalDifference, uint256 aPrice) internal pure returns (bytes32 rPacked) { + bytes32 lDecimalDifferenceRaw = bytes1(uint8(int8(aDecimalDifference))); + rPacked = FLAG_SIMPLE_PRICE | lDecimalDifferenceRaw >> 8 | bytes32(aPrice); + } + + function pack2HopRoute(address aSecondToken) internal pure returns (bytes32 rPacked) { + // Move aSecondToken to start on the 2nd byte. + rPacked = FLAG_2_HOP_ROUTE | bytes32(bytes20(aSecondToken)) >> 8; + } + + function pack3HopRoute(address aSecondToken, address aThirdToken) + internal + pure + returns (bytes32 rFirstWord, bytes32 rSecondWord) + { + bytes32 lThirdTokenTop10Bytes = bytes32(bytes20(aThirdToken)) >> 176; + // Trim away the first 10 bytes since we only want the last 10 bytes. + bytes32 lThirdTokenBottom10Bytes = bytes32(bytes20(aThirdToken) << 80); + + rFirstWord = FLAG_3_HOP_ROUTE | bytes32(bytes20(aSecondToken)) >> 8 | lThirdTokenTop10Bytes; + rSecondWord = lThirdTokenBottom10Bytes; + } + + function getPrice(bytes32 aData) internal pure returns (uint256 rPrice) { + rPrice = uint256(aData & 0x0000ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff); + } + + function getTokenFirstWord(bytes32 aData) internal pure returns (address rToken) { + rToken = + address(uint160(uint256(aData & 0x00ffffffffffffffffffffffffffffffffffffffff0000000000000000000000) >> 88)); + } + + function getThirdToken(bytes32 aFirstWord, bytes32 aSecondWord) internal pure returns (address rToken) { + bytes32 lTop10Bytes = (aFirstWord & 0x00000000000000000000000000000000000000000000ffffffffffffffffffff) << 80; + bytes32 lBottom10Bytes = aSecondWord >> 176; + rToken = address(uint160(uint256(lTop10Bytes | lBottom10Bytes))); + } +} diff --git a/src/libraries/OracleErrors.sol b/src/libraries/OracleErrors.sol new file mode 100644 index 0000000..3d1fa97 --- /dev/null +++ b/src/libraries/OracleErrors.sol @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.0; + +/// @dev Collection of all oracle related errors. +library OracleErrors { + // config errors + error InvalidRoute(); + error InvalidRouteLength(); + error InvalidTwapPeriod(); + error NoDesignatedPair(); + error PriceDeviationThresholdTooHigh(); + error SameToken(); + error TokensUnsorted(); + error UnsupportedTokenDecimals(); + + // query errors + error AmountInTooLarge(); + error BadSecs(); + error BadVariableRequest(); + error InvalidSeconds(); + error NoPath(); + error OracleNotInitialized(); + error PriceZero(); + error QueryTooOld(); + + // price update and calculation errors + error PriceOutOfRange(uint256 aPrice); + error WriteToNonSimpleRoute(); +} diff --git a/src/libraries/QueryProcessor.sol b/src/libraries/QueryProcessor.sol index d537894..a4d77fa 100644 --- a/src/libraries/QueryProcessor.sol +++ b/src/libraries/QueryProcessor.sol @@ -18,7 +18,7 @@ import { Buffer } from "amm-core/libraries/Buffer.sol"; import { ReservoirPair, Observation } from "amm-core/ReservoirPair.sol"; import { Variable, OracleAverageQuery, OracleAccumulatorQuery } from "src/interfaces/IReservoirPriceOracle.sol"; -import { BadVariableRequest, OracleNotInitialized, InvalidSeconds, QueryTooOld, BadSecs } from "src/Errors.sol"; +import { OracleErrors } from "src/libraries/OracleErrors.sol"; import { Samples } from "src/libraries/Samples.sol"; /** @@ -35,27 +35,37 @@ library QueryProcessor { /** * @dev Returns the value for `variable` at the indexed sample. */ - function getInstantValue(ReservoirPair pair, Variable variable, uint256 index) external view returns (uint256) { + function getInstantValue(ReservoirPair pair, Variable variable, uint256 index) internal view returns (uint256) { Observation memory sample = pair.observation(index); - if (sample.timestamp == 0) revert OracleNotInitialized(); + if (sample.timestamp == 0) revert OracleErrors.OracleNotInitialized(); int256 rawInstantValue = sample.instant(variable); return LogCompression.fromLowResLog(rawInstantValue); } /** - * @dev Returns the time average weighted price corresponding to `query`. + * @dev Returns the time average weighted price */ - function getTimeWeightedAverage(ReservoirPair pair, OracleAverageQuery memory query, uint16 latestIndex) - external - view - returns (uint256) - { - if (query.secs == 0) revert BadSecs(); - - int256 beginAccumulator = getPastAccumulator(pair, query.variable, latestIndex, query.ago + query.secs); - int256 endAccumulator = getPastAccumulator(pair, query.variable, latestIndex, query.ago); - return LogCompression.fromLowResLog((endAccumulator - beginAccumulator) / int256(query.secs)); + function getTimeWeightedAverage( + ReservoirPair pair, + Variable variable, + uint256 secs, + uint256 ago, + uint16 latestIndex + ) internal view returns (uint256) { + if (secs == 0) revert OracleErrors.BadSecs(); + + // SAFETY: + // + // `getPastAccumulator` reverts for any `ago`` greater than 32 bits anyway (i.e. greater than the current block.timestamp till year 2106) + // So if either `ago` or `ago + secs` is larger than 32 bits, it will revert + // `endAccumulator` and `beginAccumulators` themselves will not overflow/underflow until at least after year 2106. So the subtraction will not underflow as well. + // Therefore it is safe to use unchecked here + unchecked { + int256 beginAccumulator = getPastAccumulator(pair, variable, latestIndex, ago + secs); + int256 endAccumulator = getPastAccumulator(pair, variable, latestIndex, ago); + return LogCompression.fromLowResLog((endAccumulator - beginAccumulator) / int256(secs)); + } } /** @@ -66,39 +76,49 @@ library QueryProcessor { * - if the buffer is empty. * - if querying past information and the buffer has not been fully initialized. * - if querying older information than available in the buffer. Note that a full buffer guarantees queries for the - * past 34 hours will not revert. + * past largest safe query window will not revert. * * If requesting information for a timestamp later than the latest one, it is extrapolated using the latest * available data. * * When no exact information is available for the requested past timestamp (as usually happens, since at most one * timestamp is stored every two minutes), it is estimated by performing linear interpolation using the closest - * values. This process is guaranteed to complete performing at most 10 storage reads. + * values. This process is guaranteed to complete performing at most 11 storage reads. */ function getPastAccumulator(ReservoirPair pair, Variable variable, uint16 latestIndex, uint256 ago) - public + internal view returns (int256) { // solhint-disable not-rely-on-time // `ago` must not be before the epoch. - if (block.timestamp < ago) revert InvalidSeconds(); - uint256 lookUpTime = block.timestamp - ago; + if (block.timestamp < ago) revert OracleErrors.InvalidSeconds(); + uint256 lookUpTime; + // SAFETY: + // + // `ago` is guaranteed to be equal to or less than `block.timestamp` as checked above, so subtraction will not underflow. + unchecked { + lookUpTime = block.timestamp - ago; + } Observation memory latestSample = pair.observation(latestIndex); uint256 latestTimestamp = latestSample.timestamp; // The latest sample only has a non-zero timestamp if no data was ever processed and stored in the buffer. - if (latestTimestamp == 0) revert OracleNotInitialized(); + if (latestTimestamp == 0) revert OracleErrors.OracleNotInitialized(); if (latestTimestamp <= lookUpTime) { // The accumulator at times ahead of the latest one are computed by extrapolating the latest data. This is // equivalent to the instant value not changing between the last timestamp and the look up time. - // We can use unchecked arithmetic since the accumulator can be represented in 53 bits, timestamps in 31 - // bits, and the instant value in 22 bits. - uint256 elapsed = lookUpTime - latestTimestamp; - return latestSample.accumulator(variable) + (latestSample.instant(variable) * int256(elapsed)); + // SAFETY: + // + // `latestTimestamp` is guaranteed to be equal or less than `lookUpTime` as checked above. So this subtraction will not underflow. + // The accumulator can be represented in 53 bits, timestamps in 31bits, and the instant value in 22 bits. So this addition will not overflow. + unchecked { + uint256 elapsed = lookUpTime - latestTimestamp; + return latestSample.accumulator(variable) + (latestSample.instant(variable) * int256(elapsed)); + } } else { // The look up time is before the latest sample, but we need to make sure that it is not before the oldest // sample as well. @@ -124,25 +144,33 @@ library QueryProcessor { } // Finally check that the look up time is not previous to the oldest timestamp. - if (oldestTimestamp > lookUpTime) revert QueryTooOld(); + if (oldestTimestamp > lookUpTime) revert OracleErrors.QueryTooOld(); } // Perform binary search to find nearest samples to the desired timestamp. (Observation memory prev, Observation memory next) = findNearestSample(pair, lookUpTime, oldestIndex, bufferLength); - // `next`'s timestamp is guaranteed to be larger than `prev`'s, so we can skip checked arithmetic. - uint256 samplesTimeDiff = next.timestamp - prev.timestamp; - + // SAFETY: + // + // `next.timestamp` is guaranteed to be larger than `prev.timestamp`, so subtraction will not underflow. + uint256 samplesTimeDiff; + unchecked { + samplesTimeDiff = next.timestamp - prev.timestamp; + } if (samplesTimeDiff > 0) { // We estimate the accumulator at the requested look up time by interpolating linearly between the // previous and next accumulators. - // We can use unchecked arithmetic since the accumulators can be represented in 53 bits, and timestamps - // in 31 bits. - int256 samplesAccDiff = next.accumulator(variable) - prev.accumulator(variable); - uint256 elapsed = lookUpTime - prev.timestamp; - return prev.accumulator(variable) + ((samplesAccDiff * int256(elapsed)) / int256(samplesTimeDiff)); + // SAFETY: + // + // The accumulators can be represented in 53 bits, and timestamps are in 31 bits. So the addition and subtraction will not under/overflow. + // `lookupTime` is greater than `latestTimestamp` and is thus also greater than `prev.timestamp` so subtraction will not underflow. + unchecked { + int256 samplesAccDiff = next.accumulator(variable) - prev.accumulator(variable); + uint256 elapsed = lookUpTime - prev.timestamp; + return prev.accumulator(variable) + ((samplesAccDiff * int256(elapsed)) / int256(samplesTimeDiff)); + } } else { // Rarely, one of the samples will have the exact requested look up time, which is indicated by `prev` // and `next` being the same. In this case, we simply return the accumulator at that point in time. @@ -157,57 +185,64 @@ library QueryProcessor { * of the samples list. * * Assumes `lookUpDate` is greater or equal than the timestamp of the oldest sample, and less or equal than the - * timestamp of the latest sample. + * timestamp of the latest sample. Assumes that `length` is at least 1. */ function findNearestSample(ReservoirPair pair, uint256 lookUpDate, uint16 offset, uint16 length) - public + internal view returns (Observation memory prev, Observation memory next) { - // We're going to perform a binary search in the circular buffer, which requires it to be sorted. To achieve - // this, we offset all buffer accesses by `offset`, making the first element the oldest one. - - // Auxiliary variables in a typical binary search: we will look at some value `mid` between `low` and `high`, - // periodically increasing `low` or decreasing `high` until we either find a match or determine the element is - // not in the array. - uint16 low = 0; - uint16 high = length - 1; - uint16 mid; - - // If the search fails and no sample has a timestamp of `lookUpDate` (as is the most common scenario), `sample` - // will be either the sample with the largest timestamp smaller than `lookUpDate`, or the one with the smallest - // timestamp larger than `lookUpDate`. - Observation memory sample; - uint256 sampleTimestamp; - - while (low <= high) { - // Mid is the floor of the average. - uint16 midWithoutOffset = (high + low) / 2; - - // Recall that the buffer is not actually sorted: we need to apply the offset to access it in a sorted way. - mid = midWithoutOffset.add(offset); - sample = pair.observation(mid); - sampleTimestamp = sample.timestamp; - - if (sampleTimestamp < lookUpDate) { - // If the mid sample is bellow the look up date, then increase the low index to start from there. - low = midWithoutOffset + 1; - } else if (sampleTimestamp > lookUpDate) { - // If the mid sample is above the look up date, then decrease the high index to start from there. - - // We can skip checked arithmetic: it is impossible for `high` to ever be 0, as a scenario where `low` - // equals 0 and `high` equals 1 would result in `low` increasing to 1 in the previous `if` clause. - high = midWithoutOffset - 1; - } else { - // sampleTimestamp == lookUpDate - // If we have an exact match, return the sample as both `prev` and `next`. - return (sample, sample); + // SAFETY: + // + // As `length` is at least 1, subtractions will not underflow + // Additions will also not overflow as the max length is `Buffer.SIZE` + unchecked { + // We're going to perform a binary search in the circular buffer, which requires it to be sorted. To achieve + // this, we offset all buffer accesses by `offset`, making the first element the oldest one. + + // Auxiliary variables in a typical binary search: we will look at some value `mid` between `low` and `high`, + // periodically increasing `low` or decreasing `high` until we either find a match or determine the element is + // not in the array. + uint16 low = 0; + uint16 high = length - 1; + uint16 mid; + + // If the search fails and no sample has a timestamp of `lookUpDate` (as is the most common scenario), `sample` + // will be either the sample with the largest timestamp smaller than `lookUpDate`, or the one with the smallest + // timestamp larger than `lookUpDate`. + Observation memory sample; + uint256 sampleTimestamp; + + while (low <= high) { + // Mid is the floor of the average. + // Additions does not overflow as they are Buffer.SIZE max + uint16 midWithoutOffset = (high + low) / 2; + + // Recall that the buffer is not actually sorted: we need to apply the offset to access it in a sorted way. + mid = midWithoutOffset.add(offset); + sample = pair.observation(mid); + sampleTimestamp = sample.timestamp; + + if (sampleTimestamp < lookUpDate) { + // If the mid sample is bellow the look up date, then increase the low index to start from there. + low = midWithoutOffset + 1; + } else if (sampleTimestamp > lookUpDate) { + // If the mid sample is above the look up date, then decrease the high index to start from there. + + // We can skip checked arithmetic: it is impossible for `high` to ever be 0, as a scenario where `low` + // equals 0 and `high` equals 1 would result in `low` increasing to 1 in the previous `if` clause. + high = midWithoutOffset - 1; + } else { + // sampleTimestamp == lookUpDate + // If we have an exact match, return the sample as both `prev` and `next`. + return (sample, sample); + } } - } - // In case we reach here, it means we didn't find exactly the sample we where looking for. - return sampleTimestamp < lookUpDate - ? (sample, pair.observation(mid.next())) - : (pair.observation(mid.prev()), sample); + // In case we reach here, it means we didn't find exactly the sample we where looking for. + return sampleTimestamp < lookUpDate + ? (sample, pair.observation(mid.next())) + : (pair.observation(mid.prev()), sample); + } } } diff --git a/src/libraries/Samples.sol b/src/libraries/Samples.sol index 7408979..03cd638 100644 --- a/src/libraries/Samples.sol +++ b/src/libraries/Samples.sol @@ -16,7 +16,7 @@ pragma solidity ^0.8.0; import { Observation } from "amm-core/ReservoirPair.sol"; import { Variable } from "src/Enums.sol"; -import { BadVariableRequest } from "src/Errors.sol"; +import { OracleErrors } from "src/libraries/OracleErrors.sol"; library Samples { /** @@ -28,7 +28,7 @@ library Samples { } else if (variable == Variable.CLAMPED_PRICE) { return sample.logInstantClampedPrice; } else { - revert BadVariableRequest(); + revert OracleErrors.BadVariableRequest(); } } @@ -41,7 +41,7 @@ library Samples { } else if (variable == Variable.CLAMPED_PRICE) { return sample.logAccClampedPrice; } else { - revert BadVariableRequest(); + revert OracleErrors.BadVariableRequest(); } } } diff --git a/src/libraries/Utils.sol b/src/libraries/Utils.sol new file mode 100644 index 0000000..c67d4c9 --- /dev/null +++ b/src/libraries/Utils.sol @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.0; + +import { OracleErrors } from "src/libraries/OracleErrors.sol"; + +library Utils { + /// @dev Square of 1e18 (WAD) + uint256 internal constant WAD_SQUARED = 1e36; + + error OutOfRange(uint256 value); + + // returns the lower address followed by the higher address + function sortTokens(address tokenA, address tokenB) internal pure returns (address, address) { + return tokenA < tokenB ? (tokenA, tokenB) : (tokenB, tokenA); + } + + /// @dev aToken0 has to be strictly less than aToken1 + function calculateSlot(address aToken0, address aToken1) internal pure returns (bytes32) { + return keccak256(abi.encode(aToken0, aToken1)); + } + + function invertWad(uint256 x) internal pure returns (uint256) { + if (x == 0 || x > WAD_SQUARED) revert OutOfRange(x); + + return WAD_SQUARED / x; + } +} diff --git a/test/__fixtures/BaseTest.t.sol b/test/__fixtures/BaseTest.t.sol new file mode 100644 index 0000000..f335695 --- /dev/null +++ b/test/__fixtures/BaseTest.t.sol @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.0; + +import { Test, console2, stdError } from "forge-std/Test.sol"; + +import { GenericFactory, IERC20 } from "amm-core/GenericFactory.sol"; +import { ReservoirPair } from "amm-core/ReservoirPair.sol"; +import { ConstantProductPair } from "amm-core/curve/constant-product/ConstantProductPair.sol"; +import { StablePair } from "amm-core/curve/stable/StablePair.sol"; +import { Constants } from "amm-core/Constants.sol"; +import { FactoryStoreLib } from "amm-core/libraries/FactoryStore.sol"; +import { MintableERC20 } from "lib/amm-core/test/__fixtures/MintableERC20.sol"; + +import { ReservoirPriceOracle, IReservoirPriceOracle, IPriceOracle } from "src/ReservoirPriceOracle.sol"; + +contract BaseTest is Test { + using FactoryStoreLib for GenericFactory; + + GenericFactory internal _factory = new GenericFactory(); + ReservoirPair internal _pair; + + ReservoirPriceOracle internal _oracle = new ReservoirPriceOracle(0.02e18, 15 minutes, 500_000); + + MintableERC20 internal _tokenA = MintableERC20(address(0x100)); + MintableERC20 internal _tokenB = MintableERC20(address(0x200)); + MintableERC20 internal _tokenC = MintableERC20(address(0x300)); + MintableERC20 internal _tokenD = MintableERC20(address(0x400)); + + constructor() { + // we do this to have certainty that these token addresses are in ascending order, for easy testing + deployCodeTo("MintableERC20.sol", abi.encode("TokenA", "TA", uint8(6)), address(_tokenA)); + deployCodeTo("MintableERC20.sol", abi.encode("TokenB", "TB", uint8(18)), address(_tokenB)); + deployCodeTo("MintableERC20.sol", abi.encode("TokenC", "TC", uint8(10)), address(_tokenC)); + deployCodeTo("MintableERC20.sol", abi.encode("TokenD", "TD", uint8(6)), address(_tokenD)); + + _factory.addCurve(type(ConstantProductPair).creationCode); + _factory.addCurve(type(StablePair).creationCode); + + _factory.write("CP::swapFee", Constants.DEFAULT_SWAP_FEE_CP); + _factory.write("SP::swapFee", Constants.DEFAULT_SWAP_FEE_SP); + _factory.write("SP::amplificationCoefficient", Constants.DEFAULT_AMP_COEFF); + _factory.write("Shared::platformFee", Constants.DEFAULT_PLATFORM_FEE); + _factory.write("Shared::platformFeeTo", address(this)); + _factory.write("Shared::recoverer", address(this)); + _factory.write("Shared::maxChangeRate", Constants.DEFAULT_MAX_CHANGE_RATE); + _factory.write("Shared::oracleCaller", address(_oracle)); + + _pair = ReservoirPair(_createPair(address(_tokenA), address(_tokenB), 0)); + _tokenA.mint(address(_pair), 103e6); + _tokenB.mint(address(_pair), 10_189e18); + _pair.mint(address(this)); + } + + function _createPair(address aTokenA, address aTokenB, uint256 aCurveId) internal returns (address rPair) { + rPair = _factory.createPair(IERC20(aTokenA), IERC20(aTokenB), aCurveId); + } +} diff --git a/test/integration/ReservoirPriceOracle.t.sol b/test/integration/ReservoirPriceOracle.t.sol new file mode 100644 index 0000000..f59222c --- /dev/null +++ b/test/integration/ReservoirPriceOracle.t.sol @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.0; + +import { Test, console2 } from "forge-std/Test.sol"; + +contract ReservoirPriceOracleIntegrationTest is Test { + function setUp() external { + uint256 lForkId = vm.createFork(getChain("arbitrum_one").rpcUrl); + vm.selectFork(lForkId); + } + + function testBlockBaseFee() external view { + // assert + assertEq(block.basefee, 0.01 gwei); + } +} diff --git a/test/unit/ReservoirPriceOracle.t.sol b/test/unit/ReservoirPriceOracle.t.sol new file mode 100644 index 0000000..9e5967c --- /dev/null +++ b/test/unit/ReservoirPriceOracle.t.sol @@ -0,0 +1,1129 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.0; + +import { BaseTest, console2, ReservoirPair, MintableERC20 } from "test/__fixtures/BaseTest.t.sol"; + +import { Utils } from "src/libraries/Utils.sol"; +import { + Buffer, + Variable, + OracleErrors, + OracleLatestQuery, + OracleAccumulatorQuery, + OracleAverageQuery, + ReservoirPriceOracle, + IERC20, + IPriceOracle, + FlagsLib +} from "src/ReservoirPriceOracle.sol"; +import { Bytes32Lib } from "amm-core/libraries/Bytes32.sol"; +import { EnumerableSetLib } from "lib/solady/src/utils/EnumerableSetLib.sol"; +import { Constants } from "src/libraries/Constants.sol"; + +contract ReservoirPriceOracleTest is BaseTest { + using Utils for *; + using FlagsLib for *; + using Bytes32Lib for *; + using EnumerableSetLib for EnumerableSetLib.AddressSet; + + event DesignatePair(address token0, address token1, ReservoirPair pair); + event Oracle(address newOracle); + event RewardGasAmount(uint256 newAmount); + event Route(address token0, address token1, address[] route); + + uint256 private constant WAD = 1e18; + + // to keep track of addresses to ensure no clash for fuzz tests + EnumerableSetLib.AddressSet internal _addressSet; + + address internal constant ADDRESS_THRESHOLD = address(0x1000); + + // writes the cached prices, for easy testing + function _writePriceCache(address aToken0, address aToken1, uint256 aPrice) internal { + require(aToken0 < aToken1, "tokens unsorted"); + require(bytes32(aPrice) & bytes2(0xffff) == 0, "PRICE WILL OVERLAP FLAG"); + + vm.record(); + _oracle.priceCache(aToken0, aToken1); + (bytes32[] memory lAccesses,) = vm.accesses(address(_oracle)); + require(lAccesses.length == 1, "incorrect number of accesses"); + + int256 lDecimalDiff = int256(uint256(IERC20(aToken1).decimals())) - int256(uint256(IERC20(aToken0).decimals())); + bytes32 lData = lDecimalDiff.packSimplePrice(aPrice); + require(lData.getDecimalDifference() == lDecimalDiff, "decimal diff incorrect"); + require(lData.isSimplePrice(), "flag incorrect"); + vm.store(address(_oracle), lAccesses[0], lData); + } + + constructor() { + // sanity - ensure that base fee is correct, for testing reward payout + assertEq(block.basefee, 0.01 gwei); + + // make sure ether balance of test contract is 0 + deal(address(this), 0); + + _addressSet.add(address(_tokenA)); + _addressSet.add(address(_tokenB)); + _addressSet.add(address(_tokenC)); + _addressSet.add(address(_tokenD)); + } + + receive() external payable { } // required to receive reward payout from priceCache + + function setUp() external { + // define route + address[] memory lRoute = new address[](2); + lRoute[0] = address(_tokenA); + lRoute[1] = address(_tokenB); + + _oracle.designatePair(address(_tokenB), address(_tokenA), _pair); + _oracle.setRoute(address(_tokenA), address(_tokenB), lRoute); + } + + function testWritePriceCache(uint256 aPrice) external { + // arrange + uint256 lPrice = bound(aPrice, 1, 1e36); + + // act + _writePriceCache(address(_tokenB), address(_tokenC), lPrice); + + // assert + (uint256 lQueriedPrice,) = _oracle.priceCache(address(_tokenB), address(_tokenC)); + assertEq(lQueriedPrice, lPrice); + } + + function testGasBountyAvailable(uint256 aBountyAmount) external { + // assume + uint256 lBounty = bound(aBountyAmount, 1, type(uint256).max); + + // arrange + deal(address(_oracle), lBounty); + + // act & assert + assertEq(_oracle.gasBountyAvailable(), lBounty); + } + + function testGasBountyAvailable_Zero() external view { + // sanity + assertEq(address(_oracle).balance, 0); + + // act & assert + assertEq(_oracle.gasBountyAvailable(), 0); + } + + function testGetQuote(uint256 aPrice, uint256 aAmountIn) public { + // assume + uint256 lPrice = bound(aPrice, 1, 1e36); + uint256 lAmountIn = bound(aAmountIn, 100, 10_000_000e6); + + // arrange + _writePriceCache(address(_tokenA), address(_tokenB), lPrice); + + // act + uint256 lAmountOut = _oracle.getQuote(lAmountIn, address(_tokenA), address(_tokenB)); + + // assert + assertEq(lAmountOut, lAmountIn * lPrice * 10 ** _tokenB.decimals() / 10 ** _tokenA.decimals() / 1e18); + } + + function testGetQuote_Inverse(uint256 aPrice, uint256 aAmountIn) external { + // assume + uint256 lPrice = bound(aPrice, 1, 1e36); + uint256 lAmountIn = bound(aAmountIn, 100, 100_000_000_000e18); + + // arrange + _writePriceCache(address(_tokenA), address(_tokenB), lPrice); + (uint256 lQueriedPrice,) = _oracle.priceCache(address(_tokenA), address(_tokenB)); + assertEq(lQueriedPrice, lPrice); + + // act + uint256 lAmountOut = _oracle.getQuote(lAmountIn, address(_tokenB), address(_tokenA)); + + // assert + assertEq(lAmountOut, lAmountIn * WAD * (10 ** _tokenA.decimals()) / lPrice / (10 ** _tokenB.decimals())); + } + + function testGetQuote_MultipleHops() public { + // assume + uint256 lPriceAB = 1e18; + uint256 lPriceBC = 2e18; + uint256 lPriceCD = 4e18; + + // arrange + _writePriceCache(address(_tokenA), address(_tokenB), lPriceAB); + _writePriceCache(address(_tokenB), address(_tokenC), lPriceBC); + _writePriceCache(address(_tokenC), address(_tokenD), lPriceCD); + + address[] memory lRoute = new address[](4); + lRoute[0] = address(_tokenA); + lRoute[1] = address(_tokenB); + lRoute[2] = address(_tokenC); + lRoute[3] = address(_tokenD); + _oracle.setRoute(address(_tokenA), address(_tokenD), lRoute); + + uint256 lAmountIn = 789e6; + + // act + uint256 lAmountOut = _oracle.getQuote(lAmountIn, address(_tokenA), address(_tokenD)); + + // assert + assertEq(lAmountOut, 6312e6); + } + + function testGetQuote_MultipleHops_Inverse() external { + // assume + uint256 lPriceAB = 1e18; + uint256 lPriceBC = 2e18; + uint256 lPriceCD = 4e18; + + // arrange + _writePriceCache(address(_tokenA), address(_tokenB), lPriceAB); + _writePriceCache(address(_tokenB), address(_tokenC), lPriceBC); + _writePriceCache(address(_tokenC), address(_tokenD), lPriceCD); + + address[] memory lRoute = new address[](4); + lRoute[0] = address(_tokenA); + lRoute[1] = address(_tokenB); + lRoute[2] = address(_tokenC); + lRoute[3] = address(_tokenD); + _oracle.setRoute(address(_tokenA), address(_tokenD), lRoute); + + uint256 lAmountIn = 789e6; + + // act + uint256 lAmountOut = _oracle.getQuote(lAmountIn, address(_tokenD), address(_tokenA)); + + // assert + assertEq(lAmountOut, 98.625e6); + } + + function testGetQuote_ComplicatedDecimals() external { + // arrange + MintableERC20 lTokenA = MintableERC20(address(0x1111)); + MintableERC20 lTokenB = MintableERC20(address(0x3333)); + MintableERC20 lTokenC = MintableERC20(address(0x2222)); + uint8 lTokenADecimals = 6; + uint8 lTokenBDecimals = 8; + uint8 lTokenCDecimals = 11; + + deployCodeTo("MintableERC20.sol", abi.encode("T", "T", lTokenADecimals), address(lTokenA)); + deployCodeTo("MintableERC20.sol", abi.encode("T", "T", lTokenBDecimals), address(lTokenB)); + deployCodeTo("MintableERC20.sol", abi.encode("T", "T", lTokenCDecimals), address(lTokenC)); + + ReservoirPair lPairAB = + ReservoirPair(_factory.createPair(IERC20(address(lTokenA)), IERC20(address(lTokenB)), 0)); + ReservoirPair lPairBC = + ReservoirPair(_factory.createPair(IERC20(address(lTokenB)), IERC20(address(lTokenC)), 0)); + _oracle.designatePair(address(lTokenA), address(lTokenB), lPairAB); + _oracle.designatePair(address(lTokenB), address(lTokenC), lPairBC); + + address[] memory lRoute = new address[](3); + lRoute[0] = address(lTokenA); + lRoute[1] = address(lTokenB); + lRoute[2] = address(lTokenC); + + _oracle.setRoute(address(lTokenA), address(lTokenC), lRoute); + _writePriceCache(address(lTokenA), address(lTokenB), 1e18); + _writePriceCache(address(lTokenC), address(lTokenB), 1e18); + + // act + uint256 lAmtCOut = _oracle.getQuote(10 ** lTokenADecimals, address(lTokenA), address(lTokenC)); + uint256 lAmtBOut = _oracle.getQuote(10 ** lTokenADecimals, address(lTokenA), address(lTokenB)); + uint256 lAmtAOut = _oracle.getQuote(10 ** lTokenCDecimals, address(lTokenC), address(lTokenA)); + + // assert + assertEq(lAmtCOut, 10 ** lTokenCDecimals); + assertEq(lAmtBOut, 10 ** lTokenBDecimals); + assertEq(lAmtAOut, 10 ** lTokenADecimals); + } + + function testGetQuote_RandomizeAllParam_1HopRoute( + uint256 aPrice, + uint256 aAmtIn, + address aTokenAAddress, + address aTokenBAddress, + uint8 aTokenADecimal, + uint8 aTokenBDecimal + ) external { + // assume + vm.assume(aTokenAAddress > ADDRESS_THRESHOLD && aTokenBAddress > ADDRESS_THRESHOLD); // avoid precompile addresses + vm.assume(_addressSet.add(aTokenAAddress) && _addressSet.add(aTokenBAddress)); + uint256 lPrice = bound(aPrice, 1, 1e36); + uint256 lAmtIn = bound(aAmtIn, 0, 1_000_000_000); + uint256 lTokenADecimal = bound(aTokenADecimal, 0, 18); + uint256 lTokenBDecimal = bound(aTokenBDecimal, 0, 18); + + // arrange + MintableERC20 lTokenA = MintableERC20(aTokenAAddress); + MintableERC20 lTokenB = MintableERC20(aTokenBAddress); + deployCodeTo("MintableERC20.sol", abi.encode("T", "T", uint8(lTokenADecimal)), address(lTokenA)); + deployCodeTo("MintableERC20.sol", abi.encode("T", "T", uint8(lTokenBDecimal)), address(lTokenB)); + + ReservoirPair lPair = ReservoirPair(_factory.createPair(IERC20(address(lTokenA)), IERC20(address(lTokenB)), 0)); + _oracle.designatePair(address(lTokenA), address(lTokenB), lPair); + + address[] memory lRoute = new address[](2); + (lRoute[0], lRoute[1]) = + lTokenA < lTokenB ? (address(lTokenA), address(lTokenB)) : (address(lTokenB), address(lTokenA)); + _oracle.setRoute(lRoute[0], lRoute[1], lRoute); + _writePriceCache(lRoute[0], lRoute[1], lPrice); // price written could be tokenB/tokenA or tokenA/tokenB depending on the fuzz addresses + + // act + uint256 lAmtBOut = _oracle.getQuote(lAmtIn * 10 ** lTokenADecimal, address(lTokenA), address(lTokenB)); + + // assert + uint256 lExpectedAmt = lTokenA < lTokenB + ? lAmtIn * 10 ** lTokenADecimal * lPrice * 10 ** lTokenBDecimal / 10 ** lTokenADecimal / WAD + : lAmtIn * 10 ** lTokenADecimal * WAD * 10 ** lTokenBDecimal / lPrice / 10 ** lTokenADecimal; + + assertEq(lAmtBOut, lExpectedAmt); + } + + function testGetQuote_RandomizeAllParam_2HopRoute( + uint256 aPrice1, + uint256 aPrice2, + uint256 aAmtIn, + address aTokenAAddress, + address aTokenBAddress, + address aTokenCAddress, + uint8 aTokenADecimal, + uint8 aTokenBDecimal, + uint8 aTokenCDecimal + ) external { + // assume + vm.assume( + aTokenAAddress > ADDRESS_THRESHOLD && aTokenBAddress > ADDRESS_THRESHOLD + && aTokenCAddress > ADDRESS_THRESHOLD + ); + vm.assume(_addressSet.add(aTokenAAddress) && _addressSet.add(aTokenBAddress) && _addressSet.add(aTokenCAddress)); + uint256 lPrice1 = bound(aPrice1, 1e9, 1e25); // need to bound price within this range as a price below this will go to zero as during the mul and div of prices + uint256 lPrice2 = bound(aPrice2, 1e9, 1e25); + uint256 lAmtIn = bound(aAmtIn, 0, 1_000_000_000); + uint256 lTokenADecimal = bound(aTokenADecimal, 0, 18); + uint256 lTokenBDecimal = bound(aTokenBDecimal, 0, 18); + uint256 lTokenCDecimal = bound(aTokenCDecimal, 0, 18); + + // arrange + MintableERC20 lTokenA = MintableERC20(aTokenAAddress); + MintableERC20 lTokenB = MintableERC20(aTokenBAddress); + MintableERC20 lTokenC = MintableERC20(aTokenCAddress); + deployCodeTo("MintableERC20.sol", abi.encode("T", "T", uint8(lTokenADecimal)), address(lTokenA)); + deployCodeTo("MintableERC20.sol", abi.encode("T", "T", uint8(lTokenBDecimal)), address(lTokenB)); + deployCodeTo("MintableERC20.sol", abi.encode("T", "T", uint8(lTokenCDecimal)), address(lTokenC)); + + ReservoirPair lPair1 = ReservoirPair(_factory.createPair(IERC20(address(lTokenA)), IERC20(address(lTokenB)), 0)); + ReservoirPair lPair2 = ReservoirPair(_factory.createPair(IERC20(address(lTokenB)), IERC20(address(lTokenC)), 0)); + + _oracle.designatePair(address(lTokenA), address(lTokenB), lPair1); + _oracle.designatePair(address(lTokenB), address(lTokenC), lPair2); + { + // to avoid stack too deep error + address[] memory lRoute = new address[](3); + (lRoute[0], lRoute[2]) = + lTokenA < lTokenC ? (address(lTokenA), address(lTokenC)) : (address(lTokenC), address(lTokenA)); + lRoute[1] = address(lTokenB); + + _oracle.setRoute(lRoute[0], lRoute[2], lRoute); + _writePriceCache( + address(lTokenA) < address(lTokenB) ? address(lTokenA) : address(lTokenB), + address(lTokenA) < address(lTokenB) ? address(lTokenB) : address(lTokenA), + lPrice1 + ); + _writePriceCache( + address(lTokenB) < address(lTokenC) ? address(lTokenB) : address(lTokenC), + address(lTokenB) < address(lTokenC) ? address(lTokenC) : address(lTokenB), + lPrice2 + ); + } + // act + uint256 lAmtCOut = _oracle.getQuote(lAmtIn * 10 ** lTokenADecimal, address(lTokenA), address(lTokenC)); + + // assert + uint256 lExpectedAmtBOut = lTokenA < lTokenB + ? lAmtIn * 10 ** lTokenADecimal * lPrice1 * 10 ** lTokenBDecimal / 10 ** lTokenADecimal / WAD + : lAmtIn * 10 ** lTokenADecimal * WAD * 10 ** lTokenBDecimal / lPrice1 / 10 ** lTokenADecimal; + uint256 lExpectedAmtCOut = lTokenB < lTokenC + ? lExpectedAmtBOut * lPrice2 * 10 ** lTokenCDecimal / 10 ** lTokenBDecimal / WAD + : lExpectedAmtBOut * WAD * 10 ** lTokenCDecimal / lPrice2 / 10 ** lTokenBDecimal; + + assertEq(lAmtCOut, lExpectedAmtCOut); + } + + // function testGetQuote_RandomizeAllParam_3HopRoute( + // uint256 aPrice1, + // uint256 aPrice2, + // uint256 aPrice3, + // uint256 aAmtIn, + // address aTokenAAddress, + // address aTokenBAddress, + // address aTokenCAddress, + // address aTokenDAddress, + // uint8 aTokenADecimal, + // uint8 aTokenBDecimal, + // uint8 aTokenCDecimal, + // uint8 aTokenDDecimal + // ) external { + // // assume + // vm.assume( + // aTokenAAddress > ADDRESS_THRESHOLD && aTokenBAddress > ADDRESS_THRESHOLD + // && aTokenCAddress > ADDRESS_THRESHOLD && aTokenDAddress > ADDRESS_THRESHOLD + // ); + // vm.assume( + // _addressSet.add(aTokenAAddress) && _addressSet.add(aTokenBAddress) && _addressSet.add(aTokenCAddress) + // && _addressSet.add(aTokenDAddress) + // ); + // uint256 lPrice1 = bound(aPrice1, 1e12, 1e24); // need to bound price within this range as a price below this will go to zero as during the mul and div of prices + // uint256 lPrice2 = bound(aPrice2, 1e12, 1e24); + // uint256 lPrice3 = bound(aPrice3, 1e12, 1e24); + // uint256 lAmtIn = bound(aAmtIn, 0, 1_000_000_000); + // uint256 lTokenADecimal = bound(aTokenADecimal, 0, 18); + // uint256 lTokenBDecimal = bound(aTokenBDecimal, 0, 18); + // uint256 lTokenCDecimal = bound(aTokenCDecimal, 0, 18); + // uint256 lTokenDDecimal = bound(aTokenDDecimal, 0, 18); + // + // // arrange + // MintableERC20 lTokenA = MintableERC20(aTokenAAddress); + // MintableERC20 lTokenB = MintableERC20(aTokenBAddress); + // MintableERC20 lTokenC = MintableERC20(aTokenCAddress); + // MintableERC20 lTokenD = MintableERC20(aTokenDAddress); + // deployCodeTo("MintableERC20.sol", abi.encode("T", "T", uint8(lTokenADecimal)), address(lTokenA)); + // deployCodeTo("MintableERC20.sol", abi.encode("T", "T", uint8(lTokenBDecimal)), address(lTokenB)); + // deployCodeTo("MintableERC20.sol", abi.encode("T", "T", uint8(lTokenCDecimal)), address(lTokenC)); + // deployCodeTo("MintableERC20.sol", abi.encode("T", "T", uint8(lTokenDDecimal)), address(lTokenD)); + // + // ReservoirPair lPair1 = ReservoirPair(_factory.createPair(IERC20(address(lTokenA)), IERC20(address(lTokenB)), 0)); + // ReservoirPair lPair2 = ReservoirPair(_factory.createPair(IERC20(address(lTokenB)), IERC20(address(lTokenC)), 0)); + // ReservoirPair lPair3 = ReservoirPair(_factory.createPair(IERC20(address(lTokenC)), IERC20(address(lTokenD)), 0)); + // + // _oracle.designatePair(address(lTokenA), address(lTokenB), lPair1); + // _oracle.designatePair(address(lTokenB), address(lTokenC), lPair2); + // _oracle.designatePair(address(lTokenC), address(lTokenD), lPair3); + // + // address[] memory lRoute = new address[](4); + // (lRoute[0], lRoute[3]) = + // lTokenA < lTokenD ? (address(lTokenA), address(lTokenD)) : (address(lTokenD), address(lTokenA)); + // lRoute[1] = address(lTokenB); + // lRoute[2] = address(lTokenC); + // + // _oracle.setRoute(lRoute[0], lRoute[3], lRoute); + // _writePriceCache( + // lRoute[0] < lRoute[1] ? lRoute[0] : lRoute[1], lRoute[0] < lRoute[1] ? lRoute[1] : lRoute[0], lPrice1 + // ); + // _writePriceCache( + // address(lTokenB) < address(lTokenC) ? address(lTokenB) : address(lTokenC), + // address(lTokenB) < address(lTokenC) ? address(lTokenC) : address(lTokenB), + // lPrice2 + // ); + // _writePriceCache( + // lRoute[2] < lRoute[3] ? lRoute[2] : lRoute[3], lRoute[2] < lRoute[3] ? lRoute[3] : lRoute[2], lPrice3 + // ); + // + // // act + // uint256 lAmtDOut = _oracle.getQuote(lAmtIn * 10 ** lTokenADecimal, address(lTokenA), address(lTokenD)); + // + // // assert + // uint256 lPriceStartEnd = (lRoute[0] < lRoute[1] ? lPrice1 : lPrice1.invertWad()) + // * (lRoute[1] < lRoute[2] ? lPrice2 : lPrice2.invertWad()) / WAD + // * (lRoute[2] < lRoute[3] ? lPrice3 : lPrice3.invertWad()) / WAD; + // assertEq( + // lAmtDOut, + // lAmtIn * (lRoute[0] == address(lTokenA) ? lPriceStartEnd : lPriceStartEnd.invertWad()) + // * (10 ** lTokenDDecimal) / WAD + // ); + // } + + function testGetQuotes(uint256 aPrice, uint256 aAmountIn) external { + // assume + uint256 lPrice = bound(aPrice, 1, 1e36); + uint256 lAmountIn = bound(aAmountIn, 100, 10_000_000e6); + + // arrange + _writePriceCache(address(_tokenA), address(_tokenB), lPrice); + + // act + (uint256 lBidOut, uint256 lAskOut) = _oracle.getQuotes(lAmountIn, address(_tokenA), address(_tokenB)); + + // assert + assertEq(lBidOut, lAskOut); + } + + function testGetQuote_ZeroIn() external { + // arrange + testGetQuote(1e18, 1_000_000e6); + + // act + uint256 lAmountOut = _oracle.getQuote(0, address(_tokenA), address(_tokenB)); + + // assert + assertEq(lAmountOut, 0); + } + + function testGetQuote_SameBaseQuote(uint256 aAmtIn, address aToken) external view { + // act + uint256 lAmtOut = _oracle.getQuote(aAmtIn, aToken, aToken); + + // assert + assertEq(lAmtOut, aAmtIn); + } + + function testUpdatePriceDeviationThreshold(uint256 aNewThreshold) external { + // assume + uint64 lNewThreshold = uint64(bound(aNewThreshold, 0, 0.1e18)); + + // act + _oracle.updatePriceDeviationThreshold(lNewThreshold); + + // assert + assertEq(_oracle.priceDeviationThreshold(), lNewThreshold); + } + + function testUpdateTwapPeriod(uint256 aNewPeriod) external { + // assume + uint64 lNewPeriod = uint64(bound(aNewPeriod, 1, 1 hours)); + + // act + _oracle.updateTwapPeriod(lNewPeriod); + + // assert + assertEq(_oracle.twapPeriod(), lNewPeriod); + } + + function testUpdateRewardGasAmount() external { + // arrange + uint64 lNewRewardMultiplier = 50; + + // act + vm.expectEmit(false, false, false, false); + emit RewardGasAmount(lNewRewardMultiplier); + _oracle.updateRewardGasAmount(lNewRewardMultiplier); + + // assert + assertEq(_oracle.rewardGasAmount(), lNewRewardMultiplier); + } + + function testUpdatePrice_FirstUpdate() external { + // sanity + (uint256 lPrice,) = _oracle.priceCache(address(_tokenA), address(_tokenB)); + assertEq(lPrice, 0); + + // arrange + deal(address(_oracle), 1 ether); + + skip(1); + _pair.sync(); + skip(_oracle.twapPeriod() * 2); + _tokenA.mint(address(_pair), 2e18); + _pair.swap(2e18, true, address(this), ""); + + // act + _oracle.updatePrice(address(_tokenB), address(_tokenA), address(this)); + + // assert + (lPrice,) = _oracle.priceCache(address(_tokenA), address(_tokenB)); + assertEq(lPrice, 98_918_868_099_219_913_512); + (lPrice,) = _oracle.priceCache(address(_tokenB), address(_tokenA)); + assertEq(lPrice, 0); + assertEq(address(this).balance, 0); // there should be no reward for the first price update + } + + function testUpdatePrice_WithinThreshold() external { + // arrange + _writePriceCache(address(_tokenA), address(_tokenB), 98.9223e18); + deal(address(_oracle), 1 ether); + + skip(1); + _pair.sync(); + skip(_oracle.twapPeriod() * 2); + _tokenA.mint(address(_pair), 2e18); + _pair.swap(2e18, true, address(this), ""); + + // act + _oracle.updatePrice(address(_tokenB), address(_tokenA), address(this)); + + // assert + (uint256 lPrice,) = _oracle.priceCache(address(_tokenA), address(_tokenB)); + assertEq(lPrice, 98_918_868_099_219_913_512); + (lPrice,) = _oracle.priceCache(address(_tokenB), address(_tokenA)); + assertEq(lPrice, 0); + assertEq(address(this).balance, 0); // no reward since price is within threshold + } + + function testUpdatePrice_BeyondThreshold() external { + // arrange + _writePriceCache(address(_tokenA), address(_tokenB), 5e18); + deal(address(_oracle), 1 ether); + + skip(1); + _pair.sync(); + skip(_oracle.twapPeriod() * 2); + _tokenA.mint(address(_pair), 2e18); + _pair.swap(2e18, true, address(this), ""); + + // act + _oracle.updatePrice(address(_tokenB), address(_tokenA), address(this)); + + // assert + (uint256 lPrice,) = _oracle.priceCache(address(_tokenA), address(_tokenB)); + assertEq(lPrice, 98_918_868_099_219_913_512); + (lPrice,) = _oracle.priceCache(address(_tokenB), address(_tokenA)); + assertEq(lPrice, 0); + assertEq(address(this).balance, block.basefee * _oracle.rewardGasAmount()); + assertEq(address(_oracle).balance, 1 ether - block.basefee * _oracle.rewardGasAmount()); + } + + function testUpdatePrice_BeyondThreshold_InsufficientReward(uint256 aRewardAvailable) external { + // assume + uint256 lRewardAvailable = bound(aRewardAvailable, 1, block.basefee * _oracle.rewardGasAmount() - 1); + + // arrange + deal(address(_oracle), lRewardAvailable); + _writePriceCache(address(_tokenA), address(_tokenB), 5e18); + + skip(1); + _pair.sync(); + skip(_oracle.twapPeriod() * 2); + _tokenA.mint(address(_pair), 2e18); + _pair.swap(2e18, true, address(this), ""); + + // act + _oracle.updatePrice(address(_tokenA), address(_tokenB), address(this)); + + // assert + assertEq(address(this).balance, 0); // no reward as there's insufficient ether in the contract + } + + function testUpdatePrice_BeyondThreshold_ZeroRecipient() external { + // arrange + uint256 lBalance = 10 ether; + deal(address(_oracle), lBalance); + _writePriceCache(address(_tokenA), address(_tokenB), 5e18); + + skip(1); + _pair.sync(); + skip(_oracle.twapPeriod() * 2); + _tokenA.mint(address(_pair), 2e18); + _pair.swap(2e18, true, address(this), ""); + + // act + _oracle.updatePrice(address(_tokenA), address(_tokenB), address(0)); + + // assert - no change to balance + assertEq(address(_oracle).balance, lBalance); + } + + function testUpdatePrice_IntermediateRoutes() external { + // arrange + address lStart = address(_tokenA); + address lIntermediate1 = address(_tokenC); + address lIntermediate2 = address(_tokenD); + address lEnd = address(_tokenB); + address[] memory lRoute = new address[](4); + lRoute[0] = lStart; + lRoute[1] = lIntermediate1; + lRoute[2] = lIntermediate2; + lRoute[3] = lEnd; + _oracle.setRoute(lStart, lEnd, lRoute); + + ReservoirPair lAC = ReservoirPair(_createPair(address(_tokenA), address(_tokenC), 0)); + ReservoirPair lCD = ReservoirPair(_createPair(address(_tokenC), address(_tokenD), 0)); + ReservoirPair lBD = ReservoirPair(_createPair(address(_tokenB), address(_tokenD), 0)); + + _tokenA.mint(address(lAC), 200 * 10 ** _tokenA.decimals()); + _tokenC.mint(address(lAC), 100 * 10 ** _tokenC.decimals()); + lAC.mint(address(this)); + + _tokenC.mint(address(lCD), 100 * 10 ** _tokenC.decimals()); + _tokenD.mint(address(lCD), 200 * 10 ** _tokenD.decimals()); + lCD.mint(address(this)); + + _tokenB.mint(address(lBD), 100 * 10 ** _tokenB.decimals()); + _tokenD.mint(address(lBD), 200 * 10 ** _tokenD.decimals()); + lBD.mint(address(this)); + + _oracle.designatePair(lStart, lIntermediate1, lAC); + _oracle.designatePair(lIntermediate2, lIntermediate1, lCD); + _oracle.designatePair(lIntermediate2, lEnd, lBD); + + skip(1); + _pair.sync(); + lAC.sync(); + lCD.sync(); + lBD.sync(); + skip(_oracle.twapPeriod() * 2); + + // act + _oracle.updatePrice(address(_tokenA), address(_tokenB), address(this)); + + // assert + (uint256 lPriceAC,) = _oracle.priceCache(lStart, lIntermediate1); + (uint256 lPriceCD,) = _oracle.priceCache(lIntermediate1, lIntermediate2); + (uint256 lPriceBD,) = _oracle.priceCache(lEnd, lIntermediate2); + (uint256 lPriceAB,) = _oracle.priceCache(lStart, lEnd); + assertApproxEqRel(lPriceAC, 0.5e18, 0.0001e18); + assertApproxEqRel(lPriceCD, 2e18, 0.0001e18); + assertApproxEqRel(lPriceBD, 2e18, 0.0001e18); + assertEq(lPriceAB, 0); // composite price is not stored in the cache + } + + function testSetRoute() public { + // arrange + address lToken0 = address(_tokenB); + address lToken1 = address(_tokenC); + address[] memory lRoute = new address[](2); + lRoute[0] = lToken0; + lRoute[1] = lToken1; + + // act + vm.expectEmit(false, false, false, false); + emit Route(lToken0, lToken1, lRoute); + _oracle.setRoute(lToken0, lToken1, lRoute); + + // assert + address[] memory lQueriedRoute = _oracle.route(lToken0, lToken1); + assertEq(lQueriedRoute, lRoute); + (, int256 lDecimalDiff) = _oracle.priceCache(lToken0, lToken1); + int256 lActualDiff = int256(uint256(IERC20(lToken1).decimals())) - int256(uint256(IERC20(lToken0).decimals())); + assertEq(lDecimalDiff, lActualDiff); + } + + function testSetRoute_OverwriteExisting() external { + // arrange + testSetRoute(); + address lToken0 = address(_tokenB); + address lToken1 = address(_tokenC); + address[] memory lRoute = new address[](4); + lRoute[0] = lToken0; + lRoute[1] = address(_tokenA); + lRoute[2] = address(_tokenD); + lRoute[3] = lToken1; + + // act + _oracle.setRoute(lToken0, lToken1, lRoute); + + // assert + address[] memory lQueriedRoute = _oracle.route(lToken0, lToken1); + assertEq(lQueriedRoute, lRoute); + assertEq(lQueriedRoute.length, 4); + } + + function testSetRoute_MultipleHops() external { + // arrange + address lStart = address(_tokenA); + address lIntermediate1 = address(_tokenC); + address lIntermediate2 = address(_tokenB); + address lEnd = address(_tokenD); + address[] memory lRoute = new address[](4); + lRoute[0] = lStart; + lRoute[1] = lIntermediate1; + lRoute[2] = lIntermediate2; + lRoute[3] = lEnd; + + address[] memory lIntermediateRoute1 = new address[](2); + lIntermediateRoute1[0] = lStart; + lIntermediateRoute1[1] = lIntermediate1; + + // note that the seq should be reversed cuz lIntermediate2 < lIntermediate1 + address[] memory lIntermediateRoute2 = new address[](2); + lIntermediateRoute2[0] = lIntermediate2; + lIntermediateRoute2[1] = lIntermediate1; + + address[] memory lIntermediateRoute3 = new address[](2); + lIntermediateRoute3[0] = lIntermediate2; + lIntermediateRoute3[1] = lEnd; + + // act + vm.expectEmit(false, false, false, true); + emit Route(lIntermediate2, lEnd, lIntermediateRoute3); + vm.expectEmit(false, false, false, true); + emit Route(lStart, lIntermediate1, lIntermediateRoute1); + vm.expectEmit(true, true, true, true); + // note the reverse seq here as well + emit Route(lIntermediate2, lIntermediate1, lIntermediateRoute2); + vm.expectEmit(false, false, false, true); + emit Route(lStart, lEnd, lRoute); + _oracle.setRoute(lStart, lEnd, lRoute); + + // assert + assertEq(_oracle.route(lStart, lEnd), lRoute); + assertEq(_oracle.route(lStart, lIntermediate1), lIntermediateRoute1); + assertEq(_oracle.route(lIntermediate2, lIntermediate1), lIntermediateRoute2); + assertEq(_oracle.route(lIntermediate2, lEnd), lIntermediateRoute3); + } + + function testClearRoute() external { + // arrange + address lToken0 = address(_tokenB); + address lToken1 = address(_tokenC); + address[] memory lRoute = new address[](2); + lRoute[0] = lToken0; + lRoute[1] = lToken1; + _oracle.setRoute(lToken0, lToken1, lRoute); + address[] memory lQueriedRoute = _oracle.route(lToken0, lToken1); + assertEq(lQueriedRoute, lRoute); + _writePriceCache(lToken0, lToken1, 1e18); + + // act + vm.expectEmit(false, false, false, true); + emit Route(lToken0, lToken1, new address[](0)); + _oracle.clearRoute(lToken0, lToken1); + + // assert + lQueriedRoute = _oracle.route(lToken0, lToken1); + assertEq(lQueriedRoute, new address[](0)); + (uint256 lPrice,) = _oracle.priceCache(lToken0, lToken1); + assertEq(lPrice, 0); + } + + function testClearRoute_AllWordsCleared() external { + // arrange + address[] memory lRoute = new address[](4); + lRoute[0] = address(_tokenA); + lRoute[1] = address(_tokenC); + lRoute[2] = address(_tokenB); + lRoute[3] = address(_tokenD); + _oracle.setRoute(address(_tokenA), address(_tokenD), lRoute); + address[] memory lQueriedRoute = _oracle.route(address(_tokenA), address(_tokenD)); + assertEq(lQueriedRoute, lRoute); + bytes32 lSlot1 = address(_tokenA).calculateSlot(address(_tokenD)); + bytes32 lSlot2 = bytes32(uint256(lSlot1) + 1); + bytes32 lData = vm.load(address(_oracle), lSlot2); + assertNotEq(lData, 0); + + // act + vm.expectEmit(false, false, false, true); + emit Route(address(_tokenA), address(_tokenD), new address[](0)); + _oracle.clearRoute(address(_tokenA), address(_tokenD)); + + // assert + lQueriedRoute = _oracle.route(address(_tokenA), address(_tokenD)); + assertEq(lQueriedRoute, new address[](0)); + // intermediate routes should still remain + lQueriedRoute = _oracle.route(address(_tokenB), address(_tokenC)); + address[] memory lIntermediate1 = new address[](2); + lIntermediate1[0] = address(_tokenB); + lIntermediate1[1] = address(_tokenC); + assertEq(lQueriedRoute, lIntermediate1); + lQueriedRoute = _oracle.route(address(_tokenB), address(_tokenD)); + address[] memory lIntermediate2 = new address[](2); + lIntermediate2[0] = address(_tokenB); + lIntermediate2[1] = address(_tokenD); + assertEq(lQueriedRoute, lIntermediate2); + + // all used slots should be cleared + lData = vm.load(address(_oracle), lSlot1); + assertEq(lData, 0); + lData = vm.load(address(_oracle), lSlot2); + assertEq(lData, 0); + } + + function testGetTimeWeightedAverage() external { + // arrange + skip(60); + _pair.sync(); + skip(60); + _pair.sync(); + _oracle.designatePair(address(_tokenA), address(_tokenB), _pair); + OracleAverageQuery[] memory lQueries = new OracleAverageQuery[](1); + lQueries[0] = OracleAverageQuery(Variable.RAW_PRICE, address(_tokenA), address(_tokenB), 10, 0); + + // act + uint256[] memory lResults = _oracle.getTimeWeightedAverage(lQueries); + + // assert + assertEq(lResults[0], 98_918_868_099_219_913_512); + } + + function testGetLatest(uint32 aFastForward) public { + // assume - latest price should always be the same no matter how much time has elapsed + uint32 lFastForward = uint32(bound(aFastForward, 1, 2 ** 31 - 2)); + + // arrange + skip(lFastForward); + _pair.sync(); + _oracle.designatePair(address(_tokenA), address(_tokenB), _pair); + + // act + uint256 lLatestPrice = + _oracle.getLatest(OracleLatestQuery(Variable.RAW_PRICE, address(_tokenA), address(_tokenB))); + + // assert + assertEq(lLatestPrice, 98_918_868_099_219_913_512); + } + + function testGetPastAccumulators() external { + // arrange + skip(1 hours); + _pair.sync(); + skip(1 hours); + _pair.sync(); + skip(1 hours); + _pair.sync(); + _oracle.designatePair(address(_tokenA), address(_tokenB), _pair); + OracleAccumulatorQuery[] memory lQueries = new OracleAccumulatorQuery[](3); + lQueries[0] = OracleAccumulatorQuery(Variable.RAW_PRICE, address(_tokenA), address(_tokenB), 0); + lQueries[1] = OracleAccumulatorQuery(Variable.RAW_PRICE, address(_tokenA), address(_tokenB), 1 hours); + lQueries[2] = OracleAccumulatorQuery(Variable.RAW_PRICE, address(_tokenA), address(_tokenB), 2 hours); + + // act + int256[] memory lResults = _oracle.getPastAccumulators(lQueries); + + // assert + assertEq(lResults.length, lQueries.length); + vm.startPrank(address(_oracle)); + assertEq(lResults[0], _pair.observation(2).logAccRawPrice); + assertEq(lResults[1], _pair.observation(1).logAccRawPrice); + assertEq(lResults[2], _pair.observation(0).logAccRawPrice); + vm.stopPrank(); + } + + function testGetLargestSafeQueryWindow() external view { + // assert + assertEq(_oracle.getLargestSafeQueryWindow(), Buffer.SIZE); + } + + function testDesignatePair() external { + // act + vm.expectEmit(false, false, false, true); + emit DesignatePair(address(_tokenA), address(_tokenB), _pair); + _oracle.designatePair(address(_tokenA), address(_tokenB), _pair); + + // assert + assertEq(address(_oracle.pairs(address(_tokenA), address(_tokenB))), address(_pair)); + } + + function testDesignatePair_TokenOrderReversed() external { + // act + _oracle.designatePair(address(_tokenB), address(_tokenA), _pair); + + // assert + assertEq(address(_oracle.pairs(address(_tokenA), address(_tokenB))), address(_pair)); + assertEq(address(_oracle.pairs(address(_tokenB), address(_tokenA))), address(0)); + } + + function testUndesignatePair() external { + // arrange + _oracle.designatePair(address(_tokenA), address(_tokenB), _pair); + + // act + vm.expectEmit(false, false, false, true); + emit DesignatePair(address(_tokenA), address(_tokenB), ReservoirPair(address(0))); + _oracle.undesignatePair(address(_tokenA), address(_tokenB)); + + // assert + assertEq(address(_oracle.pairs(address(_tokenA), address(_tokenB))), address(0)); + } + + /////////////////////////////////////////////////////////////////////////////////////////////// + // ERROR CONDITIONS // + /////////////////////////////////////////////////////////////////////////////////////////////// + + function testGetLatest_Inverted() external { + // arrange + testGetLatest(5); + + // act & assert + vm.expectRevert(OracleErrors.NoDesignatedPair.selector); + _oracle.getLatest(OracleLatestQuery(Variable.RAW_PRICE, address(_tokenB), address(_tokenA))); + } + + function testGetPastAccumulators_Inverted() external { + // arrange + skip(1 hours); + _pair.sync(); + skip(1 hours); + _pair.sync(); + skip(1 hours); + _pair.sync(); + _oracle.designatePair(address(_tokenA), address(_tokenB), _pair); + OracleAccumulatorQuery[] memory lQueries = new OracleAccumulatorQuery[](3); + lQueries[0] = OracleAccumulatorQuery(Variable.RAW_PRICE, address(_tokenB), address(_tokenA), 0); + lQueries[1] = OracleAccumulatorQuery(Variable.RAW_PRICE, address(_tokenB), address(_tokenA), 1 hours); + lQueries[2] = OracleAccumulatorQuery(Variable.RAW_PRICE, address(_tokenB), address(_tokenA), 2 hours); + + // act & assert + vm.expectRevert(OracleErrors.NoDesignatedPair.selector); + _oracle.getPastAccumulators(lQueries); + } + + function testGetTimeWeightedAverage_Inverted() external { + // arrange + skip(60); + _pair.sync(); + skip(60); + _pair.sync(); + _oracle.designatePair(address(_tokenB), address(_tokenA), _pair); + OracleAverageQuery[] memory lQueries = new OracleAverageQuery[](1); + lQueries[0] = OracleAverageQuery(Variable.RAW_PRICE, address(_tokenB), address(_tokenA), 10, 0); + + // act & assert + vm.expectRevert(OracleErrors.NoDesignatedPair.selector); + _oracle.getTimeWeightedAverage(lQueries); + } + + function testDesignatePair_IncorrectPair() external { + // act & assert + vm.expectRevert(); + _oracle.designatePair(address(_tokenA), address(_tokenC), _pair); + } + + function testDesignatePair_NotOwner() external { + // act & assert + vm.prank(address(123)); + vm.expectRevert("UNAUTHORIZED"); + _oracle.designatePair(address(_tokenA), address(_tokenB), _pair); + } + + function testUndesignatePair_NotOwner() external { + // act & assert + vm.prank(address(123)); + vm.expectRevert("UNAUTHORIZED"); + _oracle.undesignatePair(address(_tokenA), address(_tokenB)); + } + + function testUpdateTwapPeriod_InvalidTwapPeriod(uint256 aNewPeriod) external { + // assume + uint64 lNewPeriod = uint64(bound(aNewPeriod, 1 hours + 1, type(uint64).max)); + + // act & assert + vm.expectRevert(OracleErrors.InvalidTwapPeriod.selector); + _oracle.updateTwapPeriod(lNewPeriod); + vm.expectRevert(OracleErrors.InvalidTwapPeriod.selector); + _oracle.updateTwapPeriod(0); + } + + function testUpdatePrice_PriceOutOfRange() external { + // arrange + ReservoirPair lPair = ReservoirPair(_factory.createPair(IERC20(address(_tokenB)), IERC20(address(_tokenC)), 0)); + _tokenB.mint(address(lPair), 1); + _tokenC.mint(address(lPair), type(uint104).max); + lPair.mint(address(this)); + + skip(10); + lPair.sync(); + skip(_oracle.twapPeriod() * 2); + lPair.sync(); + + address[] memory lRoute = new address[](2); + lRoute[0] = address(_tokenB); + lRoute[1] = address(_tokenC); + + _oracle.designatePair(address(_tokenB), address(_tokenC), lPair); + _oracle.setRoute(address(_tokenB), address(_tokenC), lRoute); + + // act & assert + vm.expectRevert( + abi.encodeWithSelector( + OracleErrors.PriceOutOfRange.selector, + 2_028_266_268_535_138_201_503_457_042_228_640_366_328_194_935_292_146_200_000 + ) + ); + _oracle.updatePrice(address(_tokenB), address(_tokenC), address(0)); + } + + function testSetRoute_SameToken() external { + // arrange + address lToken0 = address(0x1); + address lToken1 = address(0x1); + address[] memory lRoute = new address[](2); + lRoute[0] = lToken0; + lRoute[1] = lToken1; + + // act & assert + vm.expectRevert(OracleErrors.SameToken.selector); + _oracle.setRoute(lToken0, lToken1, lRoute); + } + + function testSetRoute_NotSorted() external { + // arrange + address lToken0 = address(0x21); + address lToken1 = address(0x2); + address[] memory lRoute = new address[](2); + lRoute[0] = lToken0; + lRoute[1] = lToken1; + + // act & assert + vm.expectRevert(OracleErrors.TokensUnsorted.selector); + _oracle.setRoute(lToken0, lToken1, lRoute); + } + + function testSetRoute_InvalidRouteLength() external { + // arrange + address lToken0 = address(0x1); + address lToken1 = address(0x2); + address[] memory lTooLong = new address[](5); + lTooLong[0] = lToken0; + lTooLong[1] = address(0); + lTooLong[2] = address(0); + lTooLong[3] = address(0); + lTooLong[4] = lToken1; + address[] memory lTooShort = new address[](1); + lTooShort[0] = lToken0; + + // act & assert + vm.expectRevert(OracleErrors.InvalidRouteLength.selector); + _oracle.setRoute(lToken0, lToken1, lTooLong); + + // act & assert + vm.expectRevert(OracleErrors.InvalidRouteLength.selector); + _oracle.setRoute(lToken0, lToken1, lTooShort); + } + + function testSetRoute_InvalidRoute() external { + // arrange + address lToken0 = address(0x1); + address lToken1 = address(0x2); + address[] memory lInvalidRoute1 = new address[](3); + lInvalidRoute1[0] = lToken0; + lInvalidRoute1[1] = lToken1; + lInvalidRoute1[2] = address(0); + + address[] memory lInvalidRoute2 = new address[](3); + lInvalidRoute2[0] = address(0); + lInvalidRoute2[1] = address(54); + lInvalidRoute2[2] = lToken1; + + // act & assert + vm.expectRevert(OracleErrors.InvalidRoute.selector); + _oracle.setRoute(lToken0, lToken1, lInvalidRoute1); + vm.expectRevert(OracleErrors.InvalidRoute.selector); + _oracle.setRoute(lToken0, lToken1, lInvalidRoute2); + } + + function testUpdateRewardGasAmount_NotOwner() external { + // act & assert + vm.prank(address(123)); + vm.expectRevert("UNAUTHORIZED"); + _oracle.updateRewardGasAmount(111); + } + + function testGetQuote_NoPath() external { + // act & assert + vm.expectRevert(OracleErrors.NoPath.selector); + _oracle.getQuote(123, address(123), address(456)); + } + + function testGetQuote_PriceZero() external { + // act & assert + vm.expectRevert(OracleErrors.PriceZero.selector); + _oracle.getQuote(32_111, address(_tokenA), address(_tokenB)); + } + + function testGetQuote_MultipleHops_PriceZero() external { + // arrange + testGetQuote_MultipleHops(); + _writePriceCache(address(_tokenB), address(_tokenC), 0); + + // act & assert + vm.expectRevert(OracleErrors.PriceZero.selector); + _oracle.getQuote(321_321, address(_tokenA), address(_tokenD)); + } + + function testGetQuote_AmountInTooLarge() external { + // arrange + uint256 lAmtIn = Constants.MAX_AMOUNT_IN + 1; + + // act & assert + vm.expectRevert(OracleErrors.AmountInTooLarge.selector); + _oracle.getQuote(lAmtIn, address(_tokenA), address(_tokenB)); + } +} diff --git a/test/unit/libraries/FlagsLib.t.sol b/test/unit/libraries/FlagsLib.t.sol new file mode 100644 index 0000000..e2f1021 --- /dev/null +++ b/test/unit/libraries/FlagsLib.t.sol @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.0; + +import { Test, console2, stdError } from "forge-std/Test.sol"; + +import { FlagsLib } from "src/libraries/FlagsLib.sol"; + +contract FlagsLibTest is Test { + using FlagsLib for bytes32; + using FlagsLib for int256; + + function testIsCompositeRoute() external pure { + // arrange + bytes32 lUninitialized = FlagsLib.FLAG_UNINITIALIZED; + bytes32 l1HopRoute = FlagsLib.FLAG_SIMPLE_PRICE; + bytes32 l2HopRoute = FlagsLib.FLAG_2_HOP_ROUTE; + bytes32 l3HopRoute = FlagsLib.FLAG_3_HOP_ROUTE; + + // act & assert + assertTrue(l2HopRoute.isCompositeRoute()); + assertTrue(l3HopRoute.isCompositeRoute()); + assertFalse(lUninitialized.isCompositeRoute()); + assertFalse(l1HopRoute.isCompositeRoute()); + } + + function testGetDecimalDifference() external pure { + // arrange + bytes32 lPositive = hex"0012"; + bytes32 lNegative = hex"00ee"; + bytes32 lZero = hex"0000"; + + // act & assert + assertEq(lPositive.getDecimalDifference(), 18); + assertEq(lNegative.getDecimalDifference(), -18); + assertEq(lZero.getDecimalDifference(), 0); + } + + function testPackSimplePrice(int8 aDiff, uint256 aPrice) external pure { + // assume + uint256 lPrice = bound(aPrice, 1, 1e36); + + // act + bytes32 lResult = int256(aDiff).packSimplePrice(lPrice); + + // assert + assertEq(lResult[0], FlagsLib.FLAG_SIMPLE_PRICE); + assertEq(lResult[1], bytes1(uint8(aDiff))); + assertEq(lResult.getPrice(), lPrice); + } +} diff --git a/test/unit/libraries/QueryProcessor.t.sol b/test/unit/libraries/QueryProcessor.t.sol new file mode 100644 index 0000000..e9ec4f6 --- /dev/null +++ b/test/unit/libraries/QueryProcessor.t.sol @@ -0,0 +1,459 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.0; + +import { BaseTest, FactoryStoreLib, GenericFactory } from "test/__fixtures/BaseTest.t.sol"; + +import { Buffer, OracleErrors } from "src/libraries/QueryProcessor.sol"; +import { QueryProcessorWrapper, ReservoirPair, Observation, Variable } from "test/wrapper/QueryProcessorWrapper.sol"; + +contract QueryProcessorTest is BaseTest { + using FactoryStoreLib for GenericFactory; + using Buffer for uint16; + + QueryProcessorWrapper internal _queryProcessor = new QueryProcessorWrapper(); + + constructor() { + _factory.write("Shared::oracleCaller", address(_queryProcessor)); + _pair.updateOracleCaller(); + } + + // TODO: test both negative and positive acc values + // i.e. accumulator value keeps getting more negative and more positive + + function _fillBuffer(uint256 aBlockTime, uint256 aObservationsToWrite) internal { + for (uint256 i = 0; i < aObservationsToWrite; ++i) { + skip(aBlockTime); + _pair.sync(); + } + } + + function _writeObservation( + ReservoirPair aPair, + uint256 aIndex, + int24 aLogInstantRawPrice, + int24 aLogInstantClampedPrice, + int88 aLogAccRawPrice, + int88 aLogAccClampedPrice, + uint32 aTime + ) internal { + require(aTime < 2 ** 31, "TIMESTAMP TOO BIG"); + bytes32 lEncoded = bytes32( + bytes.concat( + bytes4(aTime), + bytes11(uint88(aLogAccClampedPrice)), + bytes11(uint88(aLogAccRawPrice)), + bytes3(uint24(aLogInstantClampedPrice)), + bytes3(uint24(aLogInstantRawPrice)) + ) + ); + + vm.prank(address(_queryProcessor)); + vm.record(); + aPair.observation(aIndex); + (bytes32[] memory lAccesses,) = vm.accesses(address(aPair)); + require(lAccesses.length == 2, "invalid number of accesses"); + + vm.store(address(aPair), lAccesses[1], lEncoded); + } + + modifier setAccumulatorPositive(bool aIsPositive) { + _; + } + + modifier randomizeStartTime(uint32 aNewStartTime) { + vm.assume(aNewStartTime > 1 && aNewStartTime < 2 ** 31 / 2); + + vm.warp(aNewStartTime); + _; + } + + function testGetInstantValue() external { + // arrange + skip(123); + _tokenB.mint(address(_pair), 105e18); + _pair.swap(-105e18, true, address(this), ""); + + // act + uint256 lInstantRawPrice = _queryProcessor.getInstantValue(_pair, Variable.RAW_PRICE, 0); + uint256 lInstantClampedPrice = _queryProcessor.getInstantValue(_pair, Variable.CLAMPED_PRICE, 0); + + // assert - instant price should be the new price after swap, not the price before swap + assertApproxEqRel(lInstantRawPrice, 100e18, 0.01e18); + assertApproxEqRel(lInstantClampedPrice, 100e18, 0.01e18); + } + + function testGetTimeWeightedAverage( + uint32 aStartTime, + uint256 aBlockTime, + uint256 aObservationsToWrite, + uint256 aSecs, + uint256 aAgo + ) external randomizeStartTime(aStartTime) { + // assume + uint256 lBlockTime = bound(aBlockTime, 1, 60); + uint16 lObservationsToWrite = uint16(bound(aObservationsToWrite, 3, Buffer.SIZE * 3)); + uint256 lSecs = bound(aSecs, 1, 1 hours); + uint256 lAgo = bound(aAgo, 0, 1 hours); + + // ensure that the query window is within what is still available in the buffer + // the fact that we potentially go around the buffer more than one means that maybe the query window's + // samples have been overwritten. Thus the need for the modulus. + vm.assume(lSecs + lAgo <= (lBlockTime * (lObservationsToWrite - 1)) % (lBlockTime * Buffer.SIZE)); + + // arrange - perform some swaps + uint256 lSwapAmt = 1e6; + for (uint256 i = 0; i < lObservationsToWrite; ++i) { + skip(lBlockTime); + _tokenA.mint(address(_pair), lSwapAmt); + _pair.swap(int256(lSwapAmt), true, address(this), ""); + } + + // act + (,,, uint16 lLatestIndex) = _pair.getReserves(); + uint256 lAveragePrice = + _queryProcessor.getTimeWeightedAverage(_pair, Variable.RAW_PRICE, lSecs, lAgo, lLatestIndex); + + // assert + // as it is hard to calc the exact average price given so many fuzz parameters, we just assert that the price should be within a range + uint256 lStartingPrice = 98.9223e18; + uint256 lEndingPrice = _queryProcessor.getInstantValue(_pair, Variable.RAW_PRICE, lLatestIndex); + assertLt(lAveragePrice, lStartingPrice); + assertGt(lAveragePrice, lEndingPrice); + } + + function testGetTimeWeightedAverage_AccumulatorOverflow() external { + // arrange + uint256 lSecs = 600; // 10 minutes + skip(10); + _pair.sync(); // write first observation at index 0 + uint32 lNow = uint32(block.timestamp); + _writeObservation(_pair, 0, 0, 0, type(int88).max, type(int88).max, lNow); + + skip(lSecs); + _pair.sync(); // write second observation at index 1 + (,,, uint16 lLatestIndex) = _pair.getReserves(); + // sanity + assertEq(lLatestIndex, 1); + lNow = uint32(block.timestamp); + _writeObservation(_pair, 1, 0, 0, type(int88).min + 3000, type(int88).min + 3000, lNow); + + // act + uint256 lAgo = 0; + uint256 lResult = _queryProcessor.getTimeWeightedAverage(_pair, Variable.RAW_PRICE, lSecs, lAgo, lLatestIndex); + + // assert6 + assertEq(lResult, 0); + } + + function testGetPastAccumulator_ExactMatch( + uint32 aStartTime, + uint256 aBlockTime, + uint256 aObservationsToWrite, + uint16 aBlocksAgo + ) external randomizeStartTime(aStartTime) { + // assume + uint256 lBlockTime = bound(aBlockTime, 1, 60); + uint16 lObservationsToWrite = uint16(bound(aObservationsToWrite, 3, Buffer.SIZE * 3)); // go around it 3 times maximum + uint16 lBlocksAgo = uint16(bound(aBlocksAgo, 0, lObservationsToWrite.sub(1))); + + // arrange + _fillBuffer(lBlockTime, lObservationsToWrite); + (,,, uint16 lIndex) = _pair.getReserves(); + + // act + uint256 lAgo = lBlockTime * lBlocksAgo; + int256 lAcc = _queryProcessor.getPastAccumulator(_pair, Variable.RAW_PRICE, lIndex, lAgo); + + // assert + uint256 lDesiredIndex = lIndex.sub(lBlocksAgo); + vm.prank(address(_queryProcessor)); + Observation memory lObs = _pair.observation(lDesiredIndex); + assertEq(lAcc, lObs.logAccRawPrice); + } + + function testGetPastAccumulator_ExactMatch_LatestAccumulator( + uint32 aStartTime, + uint256 aBlockTime, + uint256 aObservationsToWrite + ) external randomizeStartTime(aStartTime) { + // assume + uint256 lBlockTime = bound(aBlockTime, 1, 60); + uint16 lObservationsToWrite = uint16(bound(aObservationsToWrite, 3, Buffer.SIZE * 3)); + + // arrange + _fillBuffer(lBlockTime, lObservationsToWrite); + (,,, uint16 lIndex) = _pair.getReserves(); + + // act + int256 lAcc = _queryProcessor.getPastAccumulator(_pair, Variable.RAW_PRICE, lIndex, 0); + + // assert + vm.prank(address(_queryProcessor)); + Observation memory lObs = _pair.observation(lIndex); + assertEq(lAcc, lObs.logAccRawPrice); + } + + function testGetPastAccumulator_ExactMatch_OldestAccumulator( + uint32 aStartTime, + uint256 aBlockTime, + uint256 aObservationsToWrite + ) external randomizeStartTime(aStartTime) { + // assume + uint256 lBlockTime = bound(aBlockTime, 1, 30); + uint16 lObservationsToWrite = uint16(bound(aObservationsToWrite, 3, Buffer.SIZE * 3)); // go around it 3 times maximum + + // arrange + uint256 lStartTime = block.timestamp; + _fillBuffer(lBlockTime, lObservationsToWrite); + (,,, uint16 lIndex) = _pair.getReserves(); + + // act + vm.startPrank(address(_queryProcessor)); + uint256 lAgo = lObservationsToWrite > Buffer.SIZE + ? block.timestamp - _pair.observation(lIndex.next()).timestamp + : block.timestamp - (lStartTime + lBlockTime); + int256 lAcc = _queryProcessor.getPastAccumulator(_pair, Variable.RAW_PRICE, lIndex, lAgo); + + // assert + Observation memory lObs = _pair.observation(lObservationsToWrite > Buffer.SIZE ? lIndex.next() : 0); + assertEq(lAcc, lObs.logAccRawPrice); + vm.stopPrank(); + } + + function testGetPastAccumulator_InterpolatesBetweenPastAccumulators( + uint32 aStartTime, + uint256 aBlockTime, + uint256 aObservationsToWrite, + uint256 aRandomSlot + ) external randomizeStartTime(aStartTime) { + // assume + uint256 lBlockTime = bound(aBlockTime, 3, 60); + uint16 lObservationsToWrite = uint16(bound(aObservationsToWrite, 3, Buffer.SIZE * 3)); + uint16 lRandomSlot = uint16(bound(aRandomSlot, 0, lObservationsToWrite.sub(2))); + + // arrange + _fillBuffer(lBlockTime, lObservationsToWrite); + (,,, uint16 lIndex) = _pair.getReserves(); + + // act + vm.startPrank(address(_queryProcessor)); + Observation memory lPrevObs = _pair.observation(lRandomSlot); + uint256 lWantedTimestamp = lPrevObs.timestamp + lBlockTime / 2; + uint256 lAgo = block.timestamp - lWantedTimestamp; + int256 lAcc = _queryProcessor.getPastAccumulator(_pair, Variable.RAW_PRICE, lIndex, lAgo); + + // assert + Observation memory lNextObs = _pair.observation(lRandomSlot.next()); + int256 lAccDiff = lNextObs.logAccRawPrice - lPrevObs.logAccRawPrice; + assertGt(lNextObs.timestamp, lWantedTimestamp); + assertLt(lPrevObs.timestamp, lWantedTimestamp); + assertEq(lAcc, lPrevObs.logAccRawPrice + lAccDiff * int256(lBlockTime / 2) / int256(lBlockTime)); + vm.stopPrank(); + } + + function testGetPastAccumulator_ExtrapolatesBeyondLatest( + uint32 aStartTime, + uint256 aBlockTime, + uint256 aObservationsToWrite, + uint256 aTimeBeyondLatest + ) external randomizeStartTime(aStartTime) { + // assume + uint256 lBlockTime = bound(aBlockTime, 1, 30); + uint16 lObservationsToWrite = uint16(bound(aObservationsToWrite, 3, Buffer.SIZE * 3)); // go around it 3 times maximum + uint256 lTimeBeyondLatest = bound(aTimeBeyondLatest, 1, 90 days); + + // arrange + _fillBuffer(lBlockTime, lObservationsToWrite); + skip(lTimeBeyondLatest); + (,,, uint16 lIndex) = _pair.getReserves(); + + // act + int256 lAcc = _queryProcessor.getPastAccumulator(_pair, Variable.RAW_PRICE, lIndex, 0); + + // assert + vm.prank(address(_queryProcessor)); + Observation memory lObs = _pair.observation(lIndex); + if (lAcc > 0) { + assertGt(lAcc, lObs.logAccRawPrice); + } else { + assertLt(lAcc, lObs.logAccRawPrice); + } + assertEq(lAcc, lObs.logAccRawPrice + int256(lTimeBeyondLatest) * lObs.logInstantRawPrice); + } + + function testFindNearestSample_CanFindExactValue( + uint32 aStartTime, + uint256 aBlockTime, + uint256 aObservationsToWrite, + uint256 aRandomSlot + ) external randomizeStartTime(aStartTime) { + // assume + uint256 lBlockTime = bound(aBlockTime, 1, 30); + uint16 lObservationsToWrite = uint16(bound(aObservationsToWrite, 2, Buffer.SIZE * 3)); // go around it 3 times maximum + uint256 lRandomSlot = bound(aRandomSlot, 0, lObservationsToWrite.sub(1)); + + // arrange + _fillBuffer(lBlockTime, lObservationsToWrite); + + // act + vm.prank(address(_queryProcessor)); + uint256 lLookupTime = _pair.observation(lRandomSlot).timestamp; + uint16 lOffset = lObservationsToWrite > Buffer.SIZE ? lObservationsToWrite % Buffer.SIZE : 0; + uint16 lBufferLength = lObservationsToWrite > Buffer.SIZE ? Buffer.SIZE : lObservationsToWrite; + (Observation memory prev, Observation memory next) = + _queryProcessor.findNearestSample(_pair, lLookupTime, lOffset, lBufferLength); + + // assert + assertEq(prev.timestamp, next.timestamp, "prev.timestamp != next.timestamp"); + assertEq(prev.timestamp, lLookupTime); + } + + function testFindNearestSample_CanFindIntermediateValue( + uint32 aStartTime, + uint256 aBlockTime, + uint256 aObservationsToWrite, + uint256 aRandomSlot + ) external randomizeStartTime(aStartTime) { + // assume + uint256 lBlockTime = bound(aBlockTime, 3, 60); + uint16 lObservationsToWrite = uint16(bound(aObservationsToWrite, 3, Buffer.SIZE * 3)); // go around it 3 times maximum + uint256 lRandomSlot = bound(aRandomSlot, 0, lObservationsToWrite.sub(2)); // can't be the latest one as lookupTime will go beyond + + // arrange + _fillBuffer(lBlockTime, lObservationsToWrite); + + // act + vm.prank(address(_queryProcessor)); + uint256 lLookupTime = _pair.observation(lRandomSlot).timestamp + lBlockTime / 2; + uint16 lOffset = lObservationsToWrite > Buffer.SIZE ? lObservationsToWrite % Buffer.SIZE : 0; + uint16 lBufferLength = lObservationsToWrite > Buffer.SIZE ? Buffer.SIZE : lObservationsToWrite; + (Observation memory prev, Observation memory next) = + _queryProcessor.findNearestSample(_pair, lLookupTime, lOffset, lBufferLength); + + // assert + assertEq(prev.timestamp + lBlockTime, next.timestamp, "next is not prev + blocktime"); + assertNotEq(prev.timestamp, lLookupTime, "prev eq lookup"); + assertNotEq(next.timestamp, lLookupTime, "next eq lookup"); + assertGt(lLookupTime, prev.timestamp); + assertLt(lLookupTime, next.timestamp); + } + + function testFindNearestSample_OneSample(uint256 aBlockTime) external { + // assume + uint256 lBlockTime = bound(aBlockTime, 1, 60); + + // arrange + _fillBuffer(lBlockTime, 1); + + // act + (Observation memory prev, Observation memory next) = + _queryProcessor.findNearestSample(_pair, block.timestamp, 0, 1); + + // assert + assertEq(prev.timestamp, next.timestamp); + assertGt(prev.logAccRawPrice, 0); + assertGt(prev.timestamp, 0); + } + + /////////////////////////////////////////////////////////////////////////////////////////////// + // ERROR CONDITIONS // + /////////////////////////////////////////////////////////////////////////////////////////////// + + function testGetInstantValue_NotInitialized(uint256 aIndex) external { + // act & assert + vm.expectRevert(OracleErrors.OracleNotInitialized.selector); + _queryProcessor.getInstantValue(_pair, Variable.RAW_PRICE, aIndex); + } + + function testGetInstantValue_NotInitialized_BeyondBufferSize(uint8 aVariable, uint16 aIndex) external { + // assume + Variable lVar = Variable(bound(aVariable, 0, 1)); + uint16 lIndex = uint16(bound(aIndex, Buffer.SIZE, type(uint16).max)); + + // arrange - fill up buffer size + _fillBuffer(5, Buffer.SIZE); + + // act & assert - should revert for all indexes that are beyond the bounds of buffer + vm.expectRevert(OracleErrors.OracleNotInitialized.selector); + _queryProcessor.getInstantValue(_pair, lVar, lIndex); + } + + function testGetPastAccumulator_BufferEmpty(uint8 aVariable) external { + // assume + Variable lVar = Variable(bound(aVariable, 0, 1)); + + // arrange + (,,, uint16 lIndex) = _pair.getReserves(); + + // act & assert + vm.expectRevert(OracleErrors.OracleNotInitialized.selector); + _queryProcessor.getPastAccumulator(_pair, lVar, lIndex, 0); + } + + function testGetPastAccumulator_InvalidAgo( + uint32 aStartTime, + uint256 aBlockTime, + uint256 aObservationsToWrite, + uint256 aAgo + ) external randomizeStartTime(aStartTime) { + // assume + uint256 lBlockTime = bound(aBlockTime, 3, 60); + uint16 lObservationsToWrite = uint16(bound(aObservationsToWrite, 3, Buffer.SIZE * 3)); + uint256 lAgo = bound(aAgo, aStartTime + lBlockTime * lObservationsToWrite + 1, type(uint256).max); + + // arrange + _fillBuffer(lBlockTime, lObservationsToWrite); + (,,, uint16 lIndex) = _pair.getReserves(); + + // act & assert + vm.expectRevert(OracleErrors.InvalidSeconds.selector); + _queryProcessor.getPastAccumulator(_pair, Variable.RAW_PRICE, lIndex, lAgo); + } + + function testGetPastAccumulator_QueryTooOld( + uint32 aStartTime, + uint256 aBlockTime, + uint256 aObservationsToWrite, + uint256 aAgo + ) external randomizeStartTime(aStartTime) { + // assume + uint256 lBlockTime = bound(aBlockTime, 3, 60); + uint16 lObservationsToWrite = uint16(bound(aObservationsToWrite, 3, Buffer.SIZE * 3)); + + // arrange + _fillBuffer(lBlockTime, lObservationsToWrite); + (,,, uint16 lIndex) = _pair.getReserves(); + uint256 lOldestSample = lObservationsToWrite > Buffer.SIZE ? lIndex.next() : 0; + vm.prank(address(_queryProcessor)); + uint256 lAgo = bound( + aAgo, + block.timestamp - _pair.observation(lOldestSample).timestamp + 1, + aStartTime + lBlockTime * lObservationsToWrite + ); + + // act & assert + vm.expectRevert(OracleErrors.QueryTooOld.selector); + _queryProcessor.getPastAccumulator(_pair, Variable.RAW_PRICE, lIndex, lAgo); + } + + // technically this should never happen in production as `getPastAccumulator` would have reverted with the + // `OracleNotInitialized` error if the oracle is not initialized + // so the expected revert is one of running out of gas, having done too many iterations as the subtraction has underflown + // due to supplying a buffer length of 0 + function testFindNearestSample_NotInitialized() external { + // arrange + uint256 lLookupTime = 123; + uint16 lOffset = 0; + uint16 lBufferLength = 0; + + // act & assert + vm.expectRevert(); + _queryProcessor.findNearestSample(_pair, lLookupTime, lOffset, lBufferLength); + } + + function testGetTimeWeightedAverage_BadSecs() external { + // act & assert + vm.expectRevert(OracleErrors.BadSecs.selector); + _queryProcessor.getTimeWeightedAverage(_pair, Variable.RAW_PRICE, 0, 0, 0); + } +} diff --git a/test/unit/libraries/Samples.t.sol b/test/unit/libraries/Samples.t.sol new file mode 100644 index 0000000..692e45d --- /dev/null +++ b/test/unit/libraries/Samples.t.sol @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.0; + +import { Test, console2, stdError } from "forge-std/Test.sol"; + +import { Samples, Observation, Variable } from "src/libraries/Samples.sol"; + +contract SamplesTest is Test { + using Samples for Observation; + + function testInstant() external pure { + // arrange + Observation memory lObs = Observation(-123, -456, 3, 4, 5); + + // act + int256 lInstantRawPrice = lObs.instant(Variable.RAW_PRICE); + int256 lInstantClampedPrice = lObs.instant(Variable.CLAMPED_PRICE); + + // assert + assertEq(lInstantRawPrice, -123); + assertEq(lInstantClampedPrice, -456); + } + + function testInstant_BadVariableRequest() external { + // would like to test the revert behavior when passing an invalid enum + // but solidity has a check to prevent casting a uint that is out of range of the enum + vm.expectRevert(stdError.enumConversionError); + Variable(uint256(5)); + } + + function testAccumulator() external pure { + // arrange + Observation memory lObs = Observation(-789, -569, -401, -1238, 5); + + // act + int256 lAccRawPrice = lObs.accumulator(Variable.RAW_PRICE); + int256 lAccClampedPrice = lObs.accumulator(Variable.CLAMPED_PRICE); + + // assert + assertEq(lAccRawPrice, -401); + assertEq(lAccClampedPrice, -1238); + } + + function testAccumulator_BadVariableRequest() external { + // would like to test the revert behavior when passing an invalid enum + // but solidity has a check to prevent casting a uint that is out of range of the enum + vm.expectRevert(stdError.enumConversionError); + Variable(uint256(5)); + } +} diff --git a/test/wrapper/QueryProcessorWrapper.sol b/test/wrapper/QueryProcessorWrapper.sol new file mode 100644 index 0000000..2644fc9 --- /dev/null +++ b/test/wrapper/QueryProcessorWrapper.sol @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.0; + +import { QueryProcessor, ReservoirPair, Variable, Observation } from "src/libraries/QueryProcessor.sol"; + +contract QueryProcessorWrapper { + function getInstantValue(ReservoirPair pair, Variable variable, uint256 index) external view returns (uint256) { + return QueryProcessor.getInstantValue(pair, variable, index); + } + + function getTimeWeightedAverage( + ReservoirPair pair, + Variable variable, + uint256 secs, + uint256 ago, + uint16 latestIndex + ) external view returns (uint256) { + return QueryProcessor.getTimeWeightedAverage(pair, variable, secs, ago, latestIndex); + } + + function getPastAccumulator(ReservoirPair pair, Variable variable, uint16 latestIndex, uint256 ago) + external + view + returns (int256) + { + return QueryProcessor.getPastAccumulator(pair, variable, latestIndex, ago); + } + + function findNearestSample(ReservoirPair pair, uint256 lookUpDate, uint16 offset, uint16 length) + external + view + returns (Observation memory prev, Observation memory next) + { + return QueryProcessor.findNearestSample(pair, lookUpDate, offset, length); + } +}