diff --git a/packages/contracts/test/src/multi/Frequencies.s.sol b/packages/contracts/test/src/multi/Frequencies.s.sol new file mode 100644 index 000000000..a75e05d78 --- /dev/null +++ b/packages/contracts/test/src/multi/Frequencies.s.sol @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.23; + +library Frequencies { + error InvalidFrequency(Frequency frequency); + + enum Frequency { + ANNUAL, + MONTHLY, + QUARTERLY, + DAYS_360, + DAYS_365 + } + + // Helper function to convert enum value to corresponding uint256 frequency + function toValue(Frequency frequency) external pure returns (uint256) { + if (frequency == Frequency.ANNUAL) return 1; + if (frequency == Frequency.MONTHLY) return 12; + if (frequency == Frequency.QUARTERLY) return 4; + if (frequency == Frequency.DAYS_360) return 360; + if (frequency == Frequency.DAYS_365) return 365; + + revert InvalidFrequency(frequency); + } +} diff --git a/packages/contracts/test/src/multi/IERC4626Interest.s.sol b/packages/contracts/test/src/multi/IERC4626Interest.s.sol new file mode 100644 index 000000000..1fe70524c --- /dev/null +++ b/packages/contracts/test/src/multi/IERC4626Interest.s.sol @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// OpenZeppelin Contracts (last updated v5.0.0) (interfaces/IERC4626.sol) + +pragma solidity ^0.8.20; + +import { IERC4626 } from "@openzeppelin/contracts/interfaces/IERC4626.sol"; +import { ISimpleInterest } from "./ISimpleInterest.s.sol"; + +/** + * @dev Extension to Interface Vault Standard + */ +interface IERC4626Interest is IERC4626, ISimpleInterest { + function convertToSharesAtPeriod(uint256 assets, uint256 numTimePeriodsElapsed) + external + view + returns (uint256 shares); + + function convertToAssetsAtPeriod(uint256 shares, uint256 numTimePeriodsElapsed) + external + view + returns (uint256 assets); + + // TODO - confirm if required on interface + function getCurrentTimePeriodsElapsed() external pure returns (uint256 currentTimePeriodsElapsed); + + // TODO - confirm if required on interface + function setCurrentTimePeriodsElapsed(uint256 currentTimePeriodsElapsed) external; + + // TODO - confirm if required on interface + function getTenor() external view returns (uint256 tenor); +} diff --git a/packages/contracts/test/src/multi/ISimpleInterest.s.sol b/packages/contracts/test/src/multi/ISimpleInterest.s.sol new file mode 100644 index 000000000..696c4298c --- /dev/null +++ b/packages/contracts/test/src/multi/ISimpleInterest.s.sol @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.20; + +/** + * @title Simple Interest Interface + * @dev This interface provides functions to calculate interest and principal amounts over time. + * + * @notice The `calcPrincipalFromDiscounted` and `calcDiscounted` functions are designed to be mathematical inverses of each other. + * This means that applying `calcPrincipalFromDiscounted` to the output of `calcDiscounted` will return the original principal amount. + * + * For example: + * ``` + * uint256 originalPrincipal = 1000; + * uint256 discountedValue = calcDiscounted(originalPrincipal); + * uint256 recoveredPrincipal = calcPrincipalFromDiscounted(discountedValue); + * assert(recoveredPrincipal == originalPrincipal); + * ``` + * + * This property ensures that no information is lost when discounting and then recovering the principal, + * making the system consistent and predictable. + */ +interface ISimpleInterest { + function calcInterest(uint256 principal, uint256 numTimePeriodsElapsed) external view returns (uint256 interest); + + function calcDiscounted(uint256 principal, uint256 numTimePeriodsElapsed) + external + view + returns (uint256 discounted); + + function calcPrincipalFromDiscounted(uint256 discounted, uint256 numTimePeriodsElapsed) + external + view + returns (uint256 principal); + + function getFrequency() external view returns (uint256 frequency); + + function getInterestInPercentage() external view returns (uint256 interestRateInPercentage); +} diff --git a/packages/contracts/test/src/multi/InterestTest.t.sol b/packages/contracts/test/src/multi/InterestTest.t.sol new file mode 100644 index 000000000..a0a236a1b --- /dev/null +++ b/packages/contracts/test/src/multi/InterestTest.t.sol @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.23; + +import { ISimpleInterest } from "./ISimpleInterest.s.sol"; + +import { IERC20 } from "@openzeppelin/contracts/interfaces/IERC20.sol"; +import { Math } from "@openzeppelin/contracts/utils/math/Math.sol"; + +import { Test } from "forge-std/Test.sol"; +import { console2 } from "forge-std/console2.sol"; + +abstract contract InterestTest is Test { + uint256 public constant TOLERANCE = 500; // with 18 decimals, means allowed difference of 5E+16 + uint256 public constant NUM_CYCLES_TO_TEST = 2; // number of cycles in test (e.g. 2 years, 24 months, 720 days) + + uint256 public constant SCALE = 1 * 10 ** 18; // number of cycles in test (e.g. 2 years, 24 months, 720 days) + + using Math for uint256; + + function testInterestToMaxPeriods(uint256 principal, ISimpleInterest simpleInterest) internal { + uint256 maxNumPeriods = simpleInterest.getFrequency() * NUM_CYCLES_TO_TEST; // e.g. 2 years, 24 months, 720 days + + // due to small fractional numbers, principal needs to be SCALED to calculate correctly + assertGe(principal, SCALE, "principal not in SCALE"); + + // check all periods for 24 months + for (uint256 numTimePeriods = 0; numTimePeriods <= maxNumPeriods; numTimePeriods++) { + testInterestAtPeriod(principal, simpleInterest, numTimePeriods); + } + } + + function testInterestAtPeriod(uint256 principal, ISimpleInterest simpleInterest, uint256 numTimePeriods) + internal + virtual + { + console2.log("---------------------- simpleInterestTestHarness ----------------------"); + + // The `calcPrincipalFromDiscounted` and `calcDiscounted` functions are designed to be mathematical inverses of each other. + // This means that applying `calcPrincipalFromDiscounted` to the output of `calcDiscounted` will return the original principal amount. + + uint256 discounted = simpleInterest.calcDiscounted(principal, numTimePeriods); + uint256 principalFromDiscounted = simpleInterest.calcPrincipalFromDiscounted(discounted, numTimePeriods); + + assertApproxEqAbs( + principal, + principalFromDiscounted, + TOLERANCE, + assertMsg("principalFromDiscountW not inverse of principalInWei", simpleInterest, numTimePeriods) + ); + + // discountedFactor = principal - interest, therefore interest = principal - discountedFactor + assertApproxEqAbs( + principal - discounted, + simpleInterest.calcInterest(principal, numTimePeriods), + 10, // even smaller tolerance here + assertMsg("calcInterest incorrect for ", simpleInterest, numTimePeriods) + ); + } + + function assertMsg(string memory prefix, ISimpleInterest simpleInterest, uint256 numTimePeriods) + internal + view + returns (string memory) + { + return string.concat(prefix, toString(simpleInterest), " timePeriod= ", vm.toString(numTimePeriods)); + } + + function toString(ISimpleInterest simpleInterest) internal view returns (string memory) { + return string.concat( + " ISimpleInterest [ ", + " IR = ", + vm.toString(simpleInterest.getInterestInPercentage()), + " Freq = ", + vm.toString(simpleInterest.getFrequency()), + " ] " + ); + } + + function transferAndAssert(IERC20 _token, address fromAddress, address toAddress, uint256 amount) internal { + uint256 beforeBalance = _token.balanceOf(toAddress); + + vm.startPrank(fromAddress); + _token.transfer(toAddress, amount); + vm.stopPrank(); + + assertEq(beforeBalance + amount, _token.balanceOf(toAddress)); + } +} diff --git a/packages/contracts/test/src/multi/SimpleInterest.s.sol b/packages/contracts/test/src/multi/SimpleInterest.s.sol new file mode 100644 index 000000000..c667a5f80 --- /dev/null +++ b/packages/contracts/test/src/multi/SimpleInterest.s.sol @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.23; + +import { Math } from "@openzeppelin/contracts/utils/math/Math.sol"; +import { Strings } from "@openzeppelin/contracts/utils/Strings.sol"; + +import { console2 } from "forge-std/console2.sol"; +import { ISimpleInterest } from "./ISimpleInterest.s.sol"; + +/** + * https://en.wikipedia.org/wiki/Interest + * + * Simple interest is calculated only on the principal amount, or on that portion of the principal amount that remains. + * It excludes the effect of compounding. Simple interest can be applied over a time period other than a year, for example, every month. + * + * Simple interest is calculated according to the following formula: (IR * P * m) / f + * - IR is the simple annual interest rate + * - P is the Principal (aka initial amount) + * - m is the number of time periods elapsed + * - f is the frequency of applying interest (how many interest periods in a year) + * + * + * @notice The `calcPrincipalFromDiscounted` and `calcDiscounted` functions are designed to be mathematical inverses of each other. + * This means that applying `calcPrincipalFromDiscounted` to the output of `calcDiscounted` will return the original principal amount. + * + * For example: + * ``` + * uint256 originalPrincipal = 1000; + * uint256 discountedValue = calcDiscounted(originalPrincipal); + * uint256 recoveredPrincipal = calcPrincipalFromDiscounted(discountedValue); + * assert(recoveredPrincipal == originalPrincipal); + * ``` + * + * This property ensures that no information is lost when discounting and then recovering the principal, + * making the system consistent and predictable. + * + */ +contract SimpleInterest is ISimpleInterest { + using Math for uint256; + + uint256 public immutable INTEREST_RATE_PERCENTAGE; + uint256 public immutable FREQUENCY; + + uint256 public constant DECIMALS = 18; + uint256 public constant SCALE = 10 ** DECIMALS; + + uint256 public immutable PAR = 1; + + Math.Rounding public constant ROUNDING = Math.Rounding.Floor; + + error PrincipalLessThanScale(uint256 principal, uint256 scale); + + constructor(uint256 interestRatePercentage, uint256 frequency) { + INTEREST_RATE_PERCENTAGE = interestRatePercentage; + FREQUENCY = frequency; + } + + function calcInterest(uint256 principal, uint256 numTimePeriodsElapsed) + public + view + virtual + returns (uint256 interest) + { + if (principal < SCALE) { + revert PrincipalLessThanScale(principal, SCALE); + } + + uint256 interestScaled = + principal.mulDiv(INTEREST_RATE_PERCENTAGE * numTimePeriodsElapsed * SCALE, FREQUENCY * 100, ROUNDING); + + console2.log( + string.concat( + "Interest = (IR * P * m) / f = ", + Strings.toString(INTEREST_RATE_PERCENTAGE), + "% * ", + Strings.toString(principal), + " * ", + Strings.toString(numTimePeriodsElapsed), + " / ", + Strings.toString(FREQUENCY), + " = ", + Strings.toString(interestScaled) + ) + ); + + return unscale(interestScaled); + } + + function _calcInterestWithScale(uint256 principal, uint256 numTimePeriodsElapsed) + internal + view + returns (uint256 _interestScaled) + { + uint256 interestScaled = + principal.mulDiv(INTEREST_RATE_PERCENTAGE * numTimePeriodsElapsed * SCALE, FREQUENCY * 100, ROUNDING); + + console2.log( + string.concat( + "Interest = (IR * P * m) / f = ", + Strings.toString(INTEREST_RATE_PERCENTAGE), + "% * ", + Strings.toString(principal), + " * ", + Strings.toString(numTimePeriodsElapsed), + " / ", + Strings.toString(FREQUENCY), + " = ", + Strings.toString(interestScaled) + ) + ); + + return interestScaled; + } + + function calcDiscounted(uint256 principal, uint256 numTimePeriodsElapsed) public view returns (uint256) { + if (principal < SCALE) { + revert PrincipalLessThanScale(principal, SCALE); + } + + uint256 discountedScaled = principal * SCALE - _calcInterestWithScale(principal, numTimePeriodsElapsed); + + return unscale(discountedScaled); + } + + function calcPrincipalFromDiscounted(uint256 discounted, uint256 numTimePeriodsElapsed) + public + view + virtual + returns (uint256) + { + uint256 interestFactor = + INTEREST_RATE_PERCENTAGE.mulDiv(numTimePeriodsElapsed * SCALE, FREQUENCY * 100, ROUNDING); + + uint256 scaledPrincipal = discounted.mulDiv(SCALE * SCALE, SCALE - interestFactor, ROUNDING); + + console2.log( + string.concat( + "Principal = Discounted / (1 - ((IR * m) / f)) = ", + Strings.toString(discounted), + " / (1 - ((", + Strings.toString(INTEREST_RATE_PERCENTAGE), + " * ", + Strings.toString(numTimePeriodsElapsed), + " ) / ", + Strings.toString(FREQUENCY), + " = ", + Strings.toString(scaledPrincipal) + ) + ); + + return unscale(scaledPrincipal); + } + + function unscale(uint256 amount) internal pure returns (uint256) { + return amount / SCALE; + } + + function getFrequency() public view returns (uint256 frequency) { + return FREQUENCY; + } + + function getInterestInPercentage() public view returns (uint256 interestRateInPercentage) { + return INTEREST_RATE_PERCENTAGE; + } +} diff --git a/packages/contracts/test/src/multi/SimpleInterestTest.t.sol b/packages/contracts/test/src/multi/SimpleInterestTest.t.sol new file mode 100644 index 000000000..73173acd0 --- /dev/null +++ b/packages/contracts/test/src/multi/SimpleInterestTest.t.sol @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.23; + +import { Math } from "@openzeppelin/contracts/utils/math/Math.sol"; +import { SimpleInterest } from "./SimpleInterest.s.sol"; +import { Frequencies } from "./Frequencies.s.sol"; + +import { ISimpleInterest } from "./ISimpleInterest.s.sol"; +import { InterestTest } from "./InterestTest.t.sol"; + +contract SimpleInterestTest is InterestTest { + using Math for uint256; + + function test__SimpleInterestTest__CheckScale() public { + uint256 apy = 10; // APY in percentage + + ISimpleInterest simpleInterest = new SimpleInterest(apy, Frequencies.toValue(Frequencies.Frequency.DAYS_360)); + + uint256 scaleMinus1 = SCALE - 1; + + // expect revert when principal not scaled + vm.expectRevert(); + simpleInterest.calcInterest(scaleMinus1, 0); + + vm.expectRevert(); + simpleInterest.calcDiscounted(scaleMinus1, 0); + } + + function test__SimpleInterestTest__Monthly() public { + uint256 apy = 12; // APY in percentage + + ISimpleInterest simpleInterest = new SimpleInterest(apy, Frequencies.toValue(Frequencies.Frequency.MONTHLY)); + + testInterestToMaxPeriods(200 * SCALE, simpleInterest); + } + + function test__SimpleInterestTest__Daily360() public { + uint256 apy = 10; // APY in percentage + + ISimpleInterest simpleInterest = new SimpleInterest(apy, Frequencies.toValue(Frequencies.Frequency.DAYS_360)); + + testInterestToMaxPeriods(200 * SCALE, simpleInterest); + } +} diff --git a/packages/contracts/test/src/multi/SimpleInterestVault.s.sol b/packages/contracts/test/src/multi/SimpleInterestVault.s.sol new file mode 100644 index 000000000..d46e2efe8 --- /dev/null +++ b/packages/contracts/test/src/multi/SimpleInterestVault.s.sol @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.23; + +import { SimpleInterest } from "./SimpleInterest.s.sol"; +import { IERC4626Interest } from "./IERC4626Interest.s.sol"; +import { TimelockVault } from "./TimelockVault.s.sol"; + +import { Math } from "openzeppelin-contracts/contracts/utils/math/Math.sol"; + +import { IERC4626 } from "@openzeppelin/contracts/interfaces/IERC4626.sol"; +import { ERC4626 } from "@openzeppelin/contracts/token/ERC20/extensions/ERC4626.sol"; + +import { IERC20 } from "@openzeppelin/contracts/interfaces/IERC20.sol"; + +// Vault that uses SimpleInterest to calculate Shares per Asset +// - At the start, 1 asset gives 1 share +// - At numPeriod of N, 1 asset gives as discounted amount of "1 - N * interest" +contract SimpleInterestVault is IERC4626Interest, SimpleInterest, TimelockVault { + using Math for uint256; + + uint256 public currentTimePeriodsElapsed = 0; // the current interest frequency + + // how many time periods for vault redeem + // should use the same time unit (day / month or years) as the interest frequency + uint256 public immutable TENOR; + + constructor(IERC20 asset, uint256 interestRatePercentage, uint256 frequency, uint256 tenor) + SimpleInterest(interestRatePercentage, frequency) + TimelockVault(asset, "Simple Interest Rate Claim", "cSIR", 0) + { + TENOR = tenor; + } + + // =============== Deposit =============== + + function convertToSharesAtPeriod(uint256 assetsInWei, uint256 numTimePeriodsElapsed) + public + view + returns (uint256 sharesInWei) + { + if (assetsInWei < SCALE) return 0; // no shares for fractional assets + + return calcDiscounted(assetsInWei, numTimePeriodsElapsed); + } + + function previewDeposit(uint256 assetsInWei) + public + view + override(ERC4626, IERC4626) + returns (uint256 sharesInWei) + { + return convertToShares(assetsInWei); + } + + function convertToShares(uint256 assetsInWei) + public + view + override(ERC4626, IERC4626) + returns (uint256 sharesInWei) + { + return convertToSharesAtPeriod(assetsInWei, currentTimePeriodsElapsed); + } + + // =============== Redeem =============== + + // asset that would be exchanged for the amount of shares + // for a given numberOfTimePeriodsElapsed + // assets = principal + interest + function convertToAssetsAtPeriod(uint256 sharesInWei, uint256 numTimePeriodsElapsed) + public + view + returns (uint256 assetsInWei) + { + if (sharesInWei < SCALE) return 0; // no assets for fractional shares + + // trying to redeem before TENOR - just give back the Discounted Amount + // this is a slash of Principal (and no Interest) + // NB - according to spec, this function should not revert + if (numTimePeriodsElapsed < TENOR) return sharesInWei; + + uint256 impliedNumTimePeriodsAtDeposit = (numTimePeriodsElapsed - TENOR); + + uint256 principal = calcPrincipalFromDiscounted(sharesInWei, impliedNumTimePeriodsAtDeposit); + + return principal + calcInterest(principal, TENOR); + } + + function previewRedeem(uint256 sharesInWei) public view override(ERC4626, IERC4626) returns (uint256 assetsInWei) { + return convertToAssets(sharesInWei); + } + + function convertToAssets(uint256 sharesInWei) + public + view + override(ERC4626, IERC4626) + returns (uint256 assetsInWei) + { + return convertToAssetsAtPeriod(sharesInWei, currentTimePeriodsElapsed); + } + + // =============== Utility =============== + + function getCurrentTimePeriodsElapsed() public pure returns (uint256 currentTimePeriodElapsed) { + return currentTimePeriodElapsed; + } + + function setCurrentTimePeriodsElapsed(uint256 _currentTimePeriodsElapsed) public { + currentTimePeriodsElapsed = _currentTimePeriodsElapsed; + } + + function getTenor() public view returns (uint256 tenor) { + return TENOR; + } +} diff --git a/packages/contracts/test/src/multi/SimpleInterestVaultTest.t.sol b/packages/contracts/test/src/multi/SimpleInterestVaultTest.t.sol new file mode 100644 index 000000000..a3a25f7ba --- /dev/null +++ b/packages/contracts/test/src/multi/SimpleInterestVaultTest.t.sol @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.23; + +import { IERC4626Interest } from "./IERC4626Interest.s.sol"; +import { SimpleInterestVault } from "./SimpleInterestVault.s.sol"; +import { Frequencies } from "./Frequencies.s.sol"; + +import { InterestTest } from "./InterestTest.t.sol"; +import { ISimpleInterest } from "./ISimpleInterest.s.sol"; + +import { SimpleToken } from "@test/test/token/SimpleToken.t.sol"; + +import { Math } from "openzeppelin-contracts/contracts/utils/math/Math.sol"; +import { IERC20 } from "@openzeppelin/contracts/interfaces/IERC20.sol"; + +contract SimpleInterestVaultTest is InterestTest { + using Math for uint256; + + IERC20 private asset; + + address private owner = makeAddr("owner"); + address private alice = makeAddr("alice"); + address private bob = makeAddr("bob"); + address private charlie = makeAddr("charlie"); + + function setUp() public { + uint256 tokenSupply = 100000 ether; + + vm.startPrank(owner); + asset = new SimpleToken(tokenSupply); + vm.stopPrank(); + + uint256 userTokenAmount = 1000 ether; + + assertEq(asset.balanceOf(owner), tokenSupply, "owner should start with total supply"); + transferAndAssert(asset, owner, alice, userTokenAmount); + transferAndAssert(asset, owner, bob, userTokenAmount); + transferAndAssert(asset, owner, charlie, userTokenAmount); + } + + function test__SimpleInterestVaultTest__CheckScale() public { + uint256 apy = 10; // APY in percentage + uint256 frequencyValue = Frequencies.toValue(Frequencies.Frequency.DAYS_360); + uint256 tenor = 90; + + IERC4626Interest vault = new SimpleInterestVault(asset, apy, frequencyValue, tenor); + + uint256 scaleMinus1 = SCALE - 1; + + assertEq(0, vault.convertToAssets(scaleMinus1), "convert to assets not scaled"); + + assertEq(0, vault.convertToShares(scaleMinus1), "convert to shares not scaled"); + } + + function test__SimpleInterestVaultTest__Monthly() public { + uint256 apy = 12; // APY in percentage + uint256 frequencyValue = Frequencies.toValue(Frequencies.Frequency.MONTHLY); + uint256 tenor = 3; + + IERC4626Interest vault = new SimpleInterestVault(asset, apy, frequencyValue, tenor); + + testInterestToMaxPeriods(200 * SCALE, vault); + } + + function test__SimpleInterestVaultTest__Daily360() public { + uint256 apy = 10; // APY in percentage + uint256 frequencyValue = Frequencies.toValue(Frequencies.Frequency.DAYS_360); + uint256 tenor = 90; + + IERC4626Interest vault = new SimpleInterestVault(asset, apy, frequencyValue, tenor); + + testInterestToMaxPeriods(200 * SCALE, vault); + } + + function testInterestAtPeriod(uint256 principal, ISimpleInterest simpleInterest, uint256 numTimePeriods) + internal + override + { + // test against the simple interest harness + super.testInterestAtPeriod(principal, simpleInterest, numTimePeriods); + + // test the vault related + IERC4626Interest vault = (IERC4626Interest)(address(simpleInterest)); + + uint256 expectedYield = principal + vault.calcInterest(principal, vault.getTenor()); + + // check convertAtSharesAtPeriod and convertToAssetsAtPeriod + + // yieldAt(Periods+Tenor) = principalAtDeposit + interestForTenor - similar to how we test the interest. + uint256 sharesInWeiAtPeriod = vault.convertToSharesAtPeriod(principal, numTimePeriods); + uint256 assetsInWeiAtPeriod = + vault.convertToAssetsAtPeriod(sharesInWeiAtPeriod, numTimePeriods + vault.getTenor()); + + assertApproxEqAbs( + expectedYield, + assetsInWeiAtPeriod, + TOLERANCE, + assertMsg("yield does not equal principal + interest", simpleInterest, numTimePeriods) + ); + + // check convertAtShares and convertToAssets -- simulates the passage of time (e.g. block times) + uint256 prevVaultTimePeriodsElapsed = vault.getCurrentTimePeriodsElapsed(); + + vault.setCurrentTimePeriodsElapsed(numTimePeriods); // set deposit numTimePeriods + uint256 sharesInWei = vault.convertToShares(principal); // now deposit + + vault.setCurrentTimePeriodsElapsed(numTimePeriods + vault.getTenor()); // set redeem numTimePeriods + uint256 assetsInWei = vault.convertToAssets(sharesInWei); // now redeem + + assertApproxEqAbs( + principal + vault.calcInterest(principal, vault.getTenor()), + assetsInWei, + TOLERANCE, + assertMsg("yield does not equal principal + interest", simpleInterest, numTimePeriods) + ); + + vault.setCurrentTimePeriodsElapsed(prevVaultTimePeriodsElapsed); // restore the vault to previous state + } +} diff --git a/packages/contracts/test/src/multi/TimelockVault.s.sol b/packages/contracts/test/src/multi/TimelockVault.s.sol new file mode 100644 index 000000000..2ed46ba98 --- /dev/null +++ b/packages/contracts/test/src/multi/TimelockVault.s.sol @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.20; + +import { ERC4626 } from "@openzeppelin/contracts/token/ERC20/extensions/ERC4626.sol"; + +import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { ERC20 } from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +contract TimelockVault is ERC4626 { + struct LockInfo { + uint256 amount; + uint256 releaseTime; + } + + mapping(address => LockInfo) private _locks; + uint256 public lockDuration; + + error SharesLocked(uint256 releaseTime); + error TransferNotSupported(); + + constructor(IERC20 asset, string memory name, string memory symbol, uint256 initialLockDuration) + ERC4626(asset) + ERC20(name, symbol) + { + lockDuration = initialLockDuration; + } + + function setLockDuration(uint256 newLockDuration) external { + lockDuration = newLockDuration; + } + + function deposit(uint256 assets, address receiver) public override returns (uint256) { + uint256 shares = super.deposit(assets, receiver); + _locks[receiver] = LockInfo(shares, block.timestamp + lockDuration); + return shares; + } + + function redeem(uint256 shares, address receiver, address owner) public override returns (uint256) { + if (block.timestamp < _locks[owner].releaseTime) { + revert SharesLocked(_locks[owner].releaseTime); + } + return super.redeem(shares, receiver, owner); + } + + function transfer(address, uint256) public pure override(ERC20, IERC20) returns (bool) { + revert TransferNotSupported(); + } + + function transferFrom(address, address, uint256) public pure override(ERC20, IERC20) returns (bool) { + revert TransferNotSupported(); + } + + function getLockInfo(address account) external view returns (uint256 lockedAmount, uint256 releaseTime) { + LockInfo memory lock = _locks[account]; + return (lock.amount, lock.releaseTime); + } + + function getLockTimeLeft(address account) external view returns (uint256) { + if (block.timestamp >= _locks[account].releaseTime) { + return 0; + } else { + return _locks[account].releaseTime - block.timestamp; + } + } +} diff --git a/packages/contracts/test/src/multi/TimelockVaultTest.t.sol b/packages/contracts/test/src/multi/TimelockVaultTest.t.sol new file mode 100644 index 000000000..1af3d529b --- /dev/null +++ b/packages/contracts/test/src/multi/TimelockVaultTest.t.sol @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.20; + +import { TimelockVault } from "./TimelockVault.s.sol"; +import { SimpleToken } from "@test/test/token/SimpleToken.t.sol"; + +import { Test } from "forge-std/Test.sol"; +import { ERC20 } from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +contract TimelockVaultTest is Test { + TimelockVault private vault; + ERC20 private underlyingAsset; + address private owner = makeAddr("owner"); + address private alice = makeAddr("alice"); + uint256 private lockDuration = 30 days; + uint256 private initialSupply = 1000000; + + function setUp() public { + // Setup the underlying asset token and mint some to the owner + vm.startPrank(owner); + underlyingAsset = new SimpleToken(initialSupply); + vm.stopPrank(); + + vault = new TimelockVault(underlyingAsset, "VaultToken", "VT", lockDuration); + } + + function test__TimelockVault__DepositAndLock() public { + uint256 depositAmount = 1000; + + // Transfer underlying assets to Alice + vm.startPrank(owner); + underlyingAsset.transfer(alice, depositAmount); + vm.stopPrank(); + + vm.startPrank(alice); + underlyingAsset.approve(address(vault), depositAmount); + vault.deposit(depositAmount, alice); + vm.stopPrank(); + + (uint256 lockedAmount, uint256 releaseTime) = vault.getLockInfo(alice); + + assertEq(lockedAmount, depositAmount, "Incorrect locked amount"); + assertEq(releaseTime, block.timestamp + lockDuration, "Incorrect release time"); + + // Ensure redemption fails before the lock period is over + vm.startPrank(alice); + vm.expectRevert(abi.encodeWithSelector(TimelockVault.SharesLocked.selector, releaseTime)); + vault.redeem(depositAmount, alice, alice); + vm.stopPrank(); + } + + function test__TimelockVault__RedeemAfterUnlock() public { + uint256 depositAmount = 1000; + + // Transfer underlying assets to Alice + vm.startPrank(owner); + underlyingAsset.transfer(alice, depositAmount); + vm.stopPrank(); + + vm.startPrank(alice); + underlyingAsset.approve(address(vault), depositAmount); + vault.deposit(depositAmount, alice); + vm.stopPrank(); + + // Fast forward time to after the lock period + vm.warp(block.timestamp + lockDuration + 1); + + // Redemption should succeed after the lock period + vm.startPrank(alice); + uint256 assets = vault.redeem(depositAmount, alice, alice); + vm.stopPrank(); + + assertEq(assets, depositAmount, "Redemption after unlock failed"); + assertEq(underlyingAsset.balanceOf(alice), depositAmount, "Alice should have the redeemed assets"); + assertEq(vault.balanceOf(alice), 0, "Alice should have no vault shares left"); + } + + function test__TimelockVault__TransferNotSupported() public { + vm.startPrank(alice); + vm.expectRevert(TimelockVault.TransferNotSupported.selector); + vault.transfer(owner, 1); + vm.stopPrank(); + + vm.startPrank(alice); + vm.expectRevert(TimelockVault.TransferNotSupported.selector); + vault.transferFrom(alice, owner, 1); + vm.stopPrank(); + } +}