diff --git a/.gitmodules b/.gitmodules index 690924b..04ea64a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,9 @@ [submodule "lib/forge-std"] path = lib/forge-std url = https://github.com/foundry-rs/forge-std -[submodule "lib/openzeppelin-contracts"] - path = lib/openzeppelin-contracts - url = https://github.com/OpenZeppelin/openzeppelin-contracts +[submodule "lib/modular-contracts"] + path = lib/modular-contracts + url = https://github.com/thirdweb-dev/modular-contracts +[submodule "lib/solady"] + path = lib/solady + url = https://github.com/vectorized/solady diff --git a/README.md b/README.md index ebfb9e9..4297b13 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,20 @@ -## **Thirdweb Gateway Contract** +## **Thirdweb PayGateway Contract** -Thirdweb Gateway Contract is used as the entrypoint to thirdweb Pay for swaps and bridges. +Thirdweb PayGateway Contract is used as the entrypoint to thirdweb Pay for swaps and bridges. -This is a forwarder contract that forwards the swap providers transaction (LiFi, Decent, etc) to their contract. Thirdweb Gateway Contract has the following responsibilities: +This is a forwarder contract that forwards the swap providers transaction (LiFi, Decent, etc) to their contract. Thirdweb PayGateway Contract has the following responsibilities: - Data Logging - this is essential for attribution and linking on-chain and off-chain data - Fee Splitting - this allows us to split the fees in-flight and flexibility to change fees on a per client basis - Data validation - this provides high-security as only thirdweb originated swaps with untampered data can use this contract -- exit point for contract calls - for LiFi, they can only guarantee toAmount for contract calls. This allows use to add a contract call to transferEnd that forwards the end funds to the user - Stateless - this will be deployed on many different chains. We don’t want to have to call addClient, changeFee, addSwapProvider, etc on every single chain for every change. Therefore, this should not rely on data held in the state of the contract, but rather data passed in -[Gateway Reference](img/gateway.png) - -[Gateway With Transfer End](img/gateway-transfer-end.png) +[PayGateway Reference](img/gateway.png) ## Features - Event Logging - - TransferStart logs the necessary events attribution and link off-chain and on-chain through clientId and transactionId. We use bytes32 instead of string for clientId and transactionId (uuid in database) because this allows recovering indexed pre-image - - TransferEnd logs the transfer end in case of a contract call and can be used for indexing bridge transactions by just listening to our Thirdweb Gateway deployments + - TokenPurchaseInitiated logs the necessary events attribution and link off-chain and on-chain through clientId and transactionId. We use bytes32 instead of string for clientId and transactionId (uuid in database) because this allows recovering indexed pre-image - FeePayout logs the fees distributed among the payees - Fee Splitting - supports many parties for fee payouts (we only expect us and client). It also allows for flexible fees on a per client basis diff --git a/foundry.toml b/foundry.toml index bbc1b8d..e853876 100644 --- a/foundry.toml +++ b/foundry.toml @@ -13,7 +13,6 @@ test = 'test' out = 'artifacts_forge' libs = ["lib"] remappings = [ - '@openzeppelin/contracts=lib/openzeppelin-contracts/contracts', '@ds-test=lib/ds-test/src/', '@std=lib/forge-std/src/', ] diff --git a/gasreport.txt b/gasreport.txt index 35000cb..7e3735c 100644 --- a/gasreport.txt +++ b/gasreport.txt @@ -1,26 +1,15 @@ No files changed, compilation skipped -Running 2 tests for test/benchmarks/BenchmarkPaymentsGatewaySplit.t.sol:BenchmarkPaymentsGatewaySplitTest -[PASS] test_startTransfer_erc20() (gas: 131301) -[PASS] test_startTransfer_nativeToken() (gas: 141479) -Test result: ok. 2 passed; 0 failed; 0 skipped; finished in 1.69ms +Ran 2 tests for test/benchmarks/BenchmarkModularPaymentsGateway.t.sol:BenchmarkModularPaymentsGatewayTest +[PASS] test_initiateTokenPurchase_erc20() (gas: 181350) +[PASS] test_initiateTokenPurchase_nativeToken() (gas: 246176) +Suite result: ok. 2 passed; 0 failed; 0 skipped; finished in 1.83ms (777.63µs CPU time) -Running 2 tests for test/benchmarks/BenchmarkPaymentsGateway.t.sol:BenchmarkPaymentsGatewayTest -[PASS] test_startTransfer_erc20() (gas: 158244) -[PASS] test_startTransfer_nativeToken() (gas: 183194) -Test result: ok. 2 passed; 0 failed; 0 skipped; finished in 1.94ms +Ran 2 tests for test/benchmarks/BenchmarkPaymentsGateway.t.sol:BenchmarkPaymentsGatewayTest +[PASS] test_initiateTokenPurchase_erc20() (gas: 150635) +[PASS] test_initiateTokenPurchase_nativeToken() (gas: 209936) +Suite result: ok. 2 passed; 0 failed; 0 skipped; finished in 1.83ms (1.01ms CPU time) -Running 2 tests for test/benchmarks/BenchmarkPaymentsGatewayPull.t.sol:BenchmarkPaymentsGatewayPullTest -[PASS] test_startTransfer_erc20() (gas: 184188) -[PASS] test_startTransfer_nativeToken() (gas: 181962) -Test result: ok. 2 passed; 0 failed; 0 skipped; finished in 1.81ms - -Ran 3 test suites: 6 tests passed, 0 failed, 0 skipped (6 total tests) -test_startTransfer_erc20() (gas: 0 (0.000%)) -test_startTransfer_nativeToken() (gas: 0 (0.000%)) -test_startTransfer_erc20() (gas: 0 (0.000%)) -test_startTransfer_nativeToken() (gas: 0 (0.000%)) -test_startTransfer_erc20() (gas: 0 (0.000%)) -test_startTransfer_nativeToken() (gas: 0 (0.000%)) -Overall gas change: 0 (0.000%) +Ran 2 test suites in 2.70ms (3.65ms CPU time): 4 tests passed, 0 failed, 0 skipped (4 total tests) +Overall gas change: 0 (NaN%) diff --git a/img/gateway-transfer-end.png b/img/gateway-transfer-end.png deleted file mode 100644 index ed55fd1..0000000 Binary files a/img/gateway-transfer-end.png and /dev/null differ diff --git a/lib/modular-contracts b/lib/modular-contracts new file mode 160000 index 0000000..99e6e5d --- /dev/null +++ b/lib/modular-contracts @@ -0,0 +1 @@ +Subproject commit 99e6e5ddc3a46b9b08acef3c3cdbab0c707e71ac diff --git a/lib/openzeppelin-contracts b/lib/openzeppelin-contracts deleted file mode 160000 index 01ef448..0000000 --- a/lib/openzeppelin-contracts +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 01ef448981be9d20ca85f2faf6ebdf591ce409f3 diff --git a/lib/solady b/lib/solady new file mode 160000 index 0000000..a1f9be9 --- /dev/null +++ b/lib/solady @@ -0,0 +1 @@ +Subproject commit a1f9be988d3c12655692cb8cdfc6864cc393cff6 diff --git a/src/PayGateway.sol b/src/PayGateway.sol new file mode 100644 index 0000000..5cf8564 --- /dev/null +++ b/src/PayGateway.sol @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.22; + +import { ModularCore } from "lib/modular-contracts/src/ModularCore.sol"; + +contract PayGateway is ModularCore { + constructor(address _owner, address[] memory _modules, bytes[] memory _moduleInstallData) { + _initializeOwner(_owner); + + // Install and initialize modules + require(_modules.length == _moduleInstallData.length); + for (uint256 i = 0; i < _modules.length; i++) { + _installModule(_modules[i], _moduleInstallData[i]); + } + } + + function getSupportedCallbackFunctions() + public + pure + override + returns (SupportedCallbackFunction[] memory supportedCallbackFunctions) + {} +} diff --git a/src/PaymentsGateway.sol b/src/PayGatewayModule.sol similarity index 53% rename from src/PaymentsGateway.sol rename to src/PayGatewayModule.sol index 7536cae..252426e 100644 --- a/src/PaymentsGateway.sol +++ b/src/PayGatewayModule.sol @@ -3,43 +3,59 @@ pragma solidity ^0.8.22; /// @author thirdweb -import "@openzeppelin/contracts/access/Ownable.sol"; -import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; -import "@openzeppelin/contracts/utils/ReentrancyGuard.sol"; -import { EIP712 } from "./utils/EIP712.sol"; +import { EIP712 } from "lib/solady/src/utils/EIP712.sol"; +import { SafeTransferLib } from "lib/solady/src/utils/SafeTransferLib.sol"; +import { ReentrancyGuard } from "lib/solady/src/utils/ReentrancyGuard.sol"; +import { ECDSA } from "lib/solady/src/utils/ECDSA.sol"; +import { ModularModule } from "lib/modular-contracts/src/ModularModule.sol"; +import { Ownable } from "lib/solady/src/auth/Ownable.sol"; + +library PayGatewayModuleStorage { + /// @custom:storage-location erc7201:pay.gateway.module + bytes32 public constant PAY_GATEWAY_EXTENSION_STORAGE_POSITION = + keccak256(abi.encode(uint256(keccak256("pay.gateway.module")) - 1)) & ~bytes32(uint256(0xff)); + + struct Data { + /// @dev Mapping from pay request UID => whether the pay request is processed. + mapping(bytes32 => bool) processed; + } -import { SafeTransferLib } from "./lib/SafeTransferLib.sol"; -import { ECDSA } from "./lib/ECDSA.sol"; + function data() internal pure returns (Data storage data_) { + bytes32 position = PAY_GATEWAY_EXTENSION_STORAGE_POSITION; + assembly { + data_.slot := position + } + } +} -contract PaymentsGateway is EIP712, Ownable, ReentrancyGuard { +contract PayGatewayModule is EIP712, ModularModule, ReentrancyGuard { using ECDSA for bytes32; /*/////////////////////////////////////////////////////////////// State, constants, structs //////////////////////////////////////////////////////////////*/ + uint256 private constant _ADMIN_ROLE = 1 << 2; + bytes32 private constant PAYOUTINFO_TYPEHASH = - keccak256("PayoutInfo(bytes32 clientId,address payoutAddress,uint256 feeBPS)"); + keccak256("PayoutInfo(bytes32 clientId,address payoutAddress,uint256 feeAmount)"); bytes32 private constant REQUEST_TYPEHASH = keccak256( - "PayRequest(bytes32 clientId,bytes32 transactionId,address tokenAddress,uint256 tokenAmount,uint256 expirationTimestamp,PayoutInfo[] payouts,address forwardAddress,bytes data)PayoutInfo(bytes32 clientId,address payoutAddress,uint256 feeBPS)" + "PayRequest(bytes32 clientId,bytes32 transactionId,address tokenAddress,uint256 tokenAmount,uint256 expirationTimestamp,PayoutInfo[] payouts,address forwardAddress,bool directTransfer,bytes data)PayoutInfo(bytes32 clientId,address payoutAddress,uint256 feeAmount)" ); address private constant NATIVE_TOKEN_ADDRESS = 0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE; - /// @dev Mapping from pay request UID => whether the pay request is processed. - mapping(bytes32 => bool) private processed; - /** * @notice Info of fee payout recipients. * * @param clientId ClientId of fee recipient * @param payoutAddress Recipient address - * @param feeBPS The fee basis points to be charged. Max = 10000 (10000 = 100%, 1000 = 10%) + * @param feeAmount The fee amount to be paid to each recipient */ struct PayoutInfo { bytes32 clientId; address payable payoutAddress; - uint256 feeBPS; + uint256 feeAmount; } /** @@ -52,6 +68,7 @@ contract PaymentsGateway is EIP712, Ownable, ReentrancyGuard { * @param expirationTimestamp The unix timestamp at which the request expires * @param payouts Array of Payout struct - containing fee recipients' info * @param forwardAddress Address of swap provider contract + * @param directTransfer Whether the payment is a direct transfer to another address * @param data Calldata for swap provider */ struct PayRequest { @@ -62,6 +79,7 @@ contract PaymentsGateway is EIP712, Ownable, ReentrancyGuard { uint256 expirationTimestamp; PayoutInfo[] payouts; address payable forwardAddress; + bool directTransfer; bytes data; } @@ -77,57 +95,68 @@ contract PaymentsGateway is EIP712, Ownable, ReentrancyGuard { uint256 tokenAmount ); - event TokenPurchaseCompleted( - bytes32 indexed clientId, - address indexed receiver, - bytes32 transactionId, - address tokenAddress, - uint256 tokenAmount - ); - event FeePayout( bytes32 indexed clientId, address indexed sender, address payoutAddress, address tokenAddress, - uint256 feeAmount, - uint256 feeBPS + uint256 feeAmount ); /*/////////////////////////////////////////////////////////////// Errors //////////////////////////////////////////////////////////////*/ - error PaymentsGatewayMismatchedValue(uint256 expected, uint256 actual); - error PaymentsGatewayInvalidAmount(uint256 amount); - error PaymentsGatewayVerificationFailed(); - error PaymentsGatewayFailedToForward(); - error PaymentsGatewayRequestExpired(uint256 expirationTimestamp); + error PayGatewayMismatchedValue(uint256 expected, uint256 actual); + error PayGatewayInvalidAmount(uint256 amount); + error PayGatewayVerificationFailed(); + error PayGatewayFailedToForward(); + error PayGatewayRequestExpired(uint256 expirationTimestamp); + error PayGatewayMsgValueNotZero(); - /*/////////////////////////////////////////////////////////////// - Constructor + /*////////////////////////////////////////////////////////////// + EXTENSION CONFIG //////////////////////////////////////////////////////////////*/ - constructor(address contractOwner) Ownable(contractOwner) {} + /// @notice Returns all implemented callback and fallback functions. + function getModuleConfig() external pure override returns (ModuleConfig memory config) { + config.fallbackFunctions = new FallbackFunction[](5); + + config.fallbackFunctions[0] = FallbackFunction({ + selector: this.withdrawTo.selector, + permissionBits: _ADMIN_ROLE + }); + config.fallbackFunctions[1] = FallbackFunction({ + selector: this.withdraw.selector, + permissionBits: _ADMIN_ROLE + }); + config.fallbackFunctions[2] = FallbackFunction({ + selector: this.initiateTokenPurchase.selector, + permissionBits: 0 + }); + config.fallbackFunctions[3] = FallbackFunction({ selector: this.eip712Domain.selector, permissionBits: 0 }); + config.fallbackFunctions[4] = FallbackFunction({ selector: this.isProcessed.selector, permissionBits: 0 }); + } /*/////////////////////////////////////////////////////////////// External / public functions //////////////////////////////////////////////////////////////*/ + /// @notice check if transaction id has been used / processed + function isProcessed(bytes32 transactionId) external view returns (bool) { + return PayGatewayModuleStorage.data().processed[transactionId]; + } + /// @notice some bridges may refund need a way to get funds back to user - function withdrawTo( - address tokenAddress, - uint256 tokenAmount, - address payable receiver - ) public onlyOwner nonReentrant { + function withdrawTo(address tokenAddress, uint256 tokenAmount, address payable receiver) public nonReentrant { if (_isTokenERC20(tokenAddress)) { - SafeTransferLib.safeTransferFrom(tokenAddress, address(this), receiver, tokenAmount); + SafeTransferLib.safeTransfer(tokenAddress, receiver, tokenAmount); } else { SafeTransferLib.safeTransferETH(receiver, tokenAmount); } } - function withdraw(address tokenAddress, uint256 tokenAmount) external onlyOwner nonReentrant { + function withdraw(address tokenAddress, uint256 tokenAmount) external nonReentrant { withdrawTo(tokenAddress, tokenAmount, payable(msg.sender)); } @@ -146,30 +175,24 @@ contract PaymentsGateway is EIP712, Ownable, ReentrancyGuard { function initiateTokenPurchase(PayRequest calldata req, bytes calldata signature) external payable nonReentrant { // verify amount if (req.tokenAmount == 0) { - revert PaymentsGatewayInvalidAmount(req.tokenAmount); + revert PayGatewayInvalidAmount(req.tokenAmount); } // verify expiration timestamp if (req.expirationTimestamp < block.timestamp) { - revert PaymentsGatewayRequestExpired(req.expirationTimestamp); + revert PayGatewayRequestExpired(req.expirationTimestamp); } // verify data if (!_verifyTransferStart(req, signature)) { - revert PaymentsGatewayVerificationFailed(); - } - - if (_isTokenNative(req.tokenAddress)) { - if (msg.value < req.tokenAmount) { - revert PaymentsGatewayMismatchedValue(req.tokenAmount, msg.value); - } + revert PayGatewayVerificationFailed(); } // mark the pay request as processed - processed[req.transactionId] = true; + PayGatewayModuleStorage.data().processed[req.transactionId] = true; // distribute fees - uint256 totalFeeAmount = _distributeFees(req.tokenAddress, req.tokenAmount, req.payouts); + uint256 totalFeeAmount = _distributeFees(req.tokenAddress, req.payouts); // determine native value to send uint256 sendValue = msg.value; // includes bridge fee etc. (if any) @@ -177,70 +200,56 @@ contract PaymentsGateway is EIP712, Ownable, ReentrancyGuard { sendValue = msg.value - totalFeeAmount; if (sendValue < req.tokenAmount) { - revert PaymentsGatewayMismatchedValue(sendValue, req.tokenAmount); + revert PayGatewayMismatchedValue(req.tokenAmount, sendValue); } } - if (_isTokenERC20(req.tokenAddress)) { - // pull user funds - SafeTransferLib.safeTransferFrom(req.tokenAddress, msg.sender, address(this), req.tokenAmount); - SafeTransferLib.safeApprove(req.tokenAddress, req.forwardAddress, req.tokenAmount); - } - - { - (bool success, bytes memory response) = req.forwardAddress.call{ value: sendValue }(req.data); - if (!success) { - // If there is return data, the delegate call reverted with a reason or a custom error, which we bubble up. - if (response.length > 0) { - assembly { - let returndata_size := mload(response) - revert(add(32, response), returndata_size) + if (req.directTransfer) { + if (_isTokenNative(req.tokenAddress)) { + (bool success, bytes memory response) = req.forwardAddress.call{ value: sendValue }(""); + + if (!success) { + // If there is return data, the delegate call reverted with a reason or a custom error, which we bubble up. + if (response.length > 0) { + assembly { + let returndata_size := mload(response) + revert(add(32, response), returndata_size) + } + } else { + revert PayGatewayFailedToForward(); } - } else { - revert PaymentsGatewayFailedToForward(); } - } - } - - emit TokenPurchaseInitiated(req.clientId, msg.sender, req.transactionId, req.tokenAddress, req.tokenAmount); - } - - /** - @notice - The purpose of completeTokenPurchase is to provide a forwarding contract call - on the destination chain. For some swap providers, they can only guarantee the toAmount - if we use a contract call. This allows us to call the endTransfer function and forward the - funds to the end user. - - Requirements: - 1. Log the transfer end - 2. forward the user funds - */ - function completeTokenPurchase( - bytes32 clientId, - bytes32 transactionId, - address tokenAddress, - uint256 tokenAmount, - address payable receiverAddress - ) external payable nonReentrant { - if (tokenAmount == 0) { - revert PaymentsGatewayInvalidAmount(tokenAmount); - } + } else { + if (msg.value != 0) { + revert PayGatewayMsgValueNotZero(); + } - if (_isTokenNative(tokenAddress)) { - if (msg.value < tokenAmount) { - revert PaymentsGatewayMismatchedValue(tokenAmount, msg.value); + SafeTransferLib.safeTransferFrom(req.tokenAddress, msg.sender, req.forwardAddress, req.tokenAmount); } - } - - // pull user funds - if (_isTokenERC20(tokenAddress)) { - SafeTransferLib.safeTransferFrom(tokenAddress, msg.sender, receiverAddress, tokenAmount); } else { - SafeTransferLib.safeTransferETH(receiverAddress, tokenAmount); + if (_isTokenERC20(req.tokenAddress)) { + // pull user funds + SafeTransferLib.safeTransferFrom(req.tokenAddress, msg.sender, address(this), req.tokenAmount); + SafeTransferLib.safeApprove(req.tokenAddress, req.forwardAddress, req.tokenAmount); + } + + { + (bool success, bytes memory response) = req.forwardAddress.call{ value: sendValue }(req.data); + if (!success) { + // If there is return data, the delegate call reverted with a reason or a custom error, which we bubble up. + if (response.length > 0) { + assembly { + let returndata_size := mload(response) + revert(add(32, response), returndata_size) + } + } else { + revert PayGatewayFailedToForward(); + } + } + } } - emit TokenPurchaseCompleted(clientId, receiverAddress, transactionId, tokenAddress, tokenAmount); + emit TokenPurchaseInitiated(req.clientId, msg.sender, req.transactionId, req.tokenAddress, req.tokenAmount); } /*/////////////////////////////////////////////////////////////// @@ -248,7 +257,7 @@ contract PaymentsGateway is EIP712, Ownable, ReentrancyGuard { //////////////////////////////////////////////////////////////*/ function _domainNameAndVersion() internal pure override returns (string memory name, string memory version) { - name = "PaymentsGateway"; + name = "PayGateway"; version = "1"; } @@ -256,45 +265,43 @@ contract PaymentsGateway is EIP712, Ownable, ReentrancyGuard { bytes32[] memory payoutsHashes = new bytes32[](payouts.length); for (uint i = 0; i < payouts.length; i++) { payoutsHashes[i] = keccak256( - abi.encode(PAYOUTINFO_TYPEHASH, payouts[i].clientId, payouts[i].payoutAddress, payouts[i].feeBPS) + abi.encode(PAYOUTINFO_TYPEHASH, payouts[i].clientId, payouts[i].payoutAddress, payouts[i].feeAmount) ); } return keccak256(abi.encodePacked(payoutsHashes)); } - function _distributeFees( - address tokenAddress, - uint256 tokenAmount, - PayoutInfo[] calldata payouts - ) private returns (uint256) { + function _distributeFees(address tokenAddress, PayoutInfo[] calldata payouts) private returns (uint256) { uint256 totalFeeAmount = 0; for (uint32 payeeIdx = 0; payeeIdx < payouts.length; payeeIdx++) { - uint256 feeAmount = _calculateFee(tokenAmount, payouts[payeeIdx].feeBPS); - totalFeeAmount += feeAmount; + totalFeeAmount += payouts[payeeIdx].feeAmount; emit FeePayout( payouts[payeeIdx].clientId, msg.sender, payouts[payeeIdx].payoutAddress, tokenAddress, - feeAmount, - payouts[payeeIdx].feeBPS + payouts[payeeIdx].feeAmount ); if (_isTokenNative(tokenAddress)) { - SafeTransferLib.safeTransferETH(payouts[payeeIdx].payoutAddress, feeAmount); + SafeTransferLib.safeTransferETH(payouts[payeeIdx].payoutAddress, payouts[payeeIdx].feeAmount); } else { - SafeTransferLib.safeTransferFrom(tokenAddress, msg.sender, payouts[payeeIdx].payoutAddress, feeAmount); + SafeTransferLib.safeTransferFrom( + tokenAddress, + msg.sender, + payouts[payeeIdx].payoutAddress, + payouts[payeeIdx].feeAmount + ); } } - if (totalFeeAmount > tokenAmount) { - revert PaymentsGatewayMismatchedValue(totalFeeAmount, tokenAmount); - } return totalFeeAmount; } function _verifyTransferStart(PayRequest calldata req, bytes calldata signature) private view returns (bool) { + bool processed = PayGatewayModuleStorage.data().processed[req.transactionId]; + bytes32 payoutsHash = _hashPayoutInfo(req.payouts); bytes32 structHash = keccak256( abi.encode( @@ -306,13 +313,14 @@ contract PaymentsGateway is EIP712, Ownable, ReentrancyGuard { req.expirationTimestamp, payoutsHash, req.forwardAddress, + req.directTransfer, keccak256(req.data) ) ); bytes32 digest = _hashTypedData(structHash); address recovered = digest.recover(signature); - bool valid = recovered == owner() && !processed[req.transactionId]; + bool valid = recovered == Ownable(address(this)).owner() && !processed; return valid; } @@ -324,9 +332,4 @@ contract PaymentsGateway is EIP712, Ownable, ReentrancyGuard { function _isTokenNative(address tokenAddress) private pure returns (bool) { return tokenAddress == NATIVE_TOKEN_ADDRESS; } - - function _calculateFee(uint256 amount, uint256 feeBPS) private pure returns (uint256) { - uint256 feeAmount = (amount * feeBPS) / 10_000; - return feeAmount; - } } diff --git a/test/PaymentsGateway.t.sol b/test/PayGateway.t.sol similarity index 69% rename from test/PaymentsGateway.t.sol rename to test/PayGateway.t.sol index 068cfd7..cb45aa3 100644 --- a/test/PaymentsGateway.t.sol +++ b/test/PayGateway.t.sol @@ -1,12 +1,17 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.13; +// SPDX-License-Identifier: Apache 2.0 +pragma solidity ^0.8.0; -import { Test, console, console2 } from "forge-std/Test.sol"; -import { PaymentsGateway } from "src/PaymentsGateway.sol"; +import { Test, console } from "forge-std/Test.sol"; + +import { PayGateway } from "src/PayGateway.sol"; +import { PayGatewayModule } from "src/PayGatewayModule.sol"; +import { IModuleConfig } from "lib/modular-contracts/src/interface/IModuleConfig.sol"; +import { IModularCore } from "lib/modular-contracts/src/interface/IModularCore.sol"; +import { LibClone } from "lib/solady/src/utils/LibClone.sol"; import { MockERC20 } from "./utils/MockERC20.sol"; import { MockTarget } from "./utils/MockTarget.sol"; -contract PaymentsGatewayTest is Test { +contract PayGatewayTest is Test { event TokenPurchaseInitiated( bytes32 indexed clientId, address indexed sender, @@ -15,14 +20,6 @@ contract PaymentsGatewayTest is Test { uint256 tokenAmount ); - event TokenPurchaseCompleted( - bytes32 indexed clientId, - address indexed receiver, - bytes32 transactionId, - address tokenAddress, - uint256 tokenAmount - ); - event FeePayout( bytes32 indexed clientId, address indexed sender, @@ -34,7 +31,7 @@ contract PaymentsGatewayTest is Test { event OperatorChanged(address indexed previousOperator, address indexed newOperator); - PaymentsGateway internal gateway; + PayGatewayModule internal gateway; MockERC20 internal mockERC20; MockTarget internal mockTarget; @@ -47,11 +44,11 @@ contract PaymentsGatewayTest is Test { bytes32 internal ownerClientId; bytes32 internal clientId; - uint256 internal ownerFeeBps; - uint256 internal clientFeeBps; - uint256 internal totalFeeBps; + uint256 internal ownerFeeAmount; + uint256 internal clientFeeAmount; + uint256 internal totalFeeAmount; - PaymentsGateway.PayoutInfo[] internal payouts; + PayGatewayModule.PayoutInfo[] internal payouts; bytes32 internal typehashPayRequest; bytes32 internal typehashPayoutInfo; @@ -70,10 +67,19 @@ contract PaymentsGatewayTest is Test { ownerClientId = keccak256("owner"); clientId = keccak256("client"); - ownerFeeBps = 200; - clientFeeBps = 100; + ownerFeeAmount = 20; + clientFeeAmount = 10; + + // deploy and install module + address module = address(new PayGatewayModule()); + + address[] memory modules = new address[](1); + bytes[] memory moduleData = new bytes[](1); + modules[0] = address(module); + moduleData[0] = ""; + + gateway = PayGatewayModule(address(new PayGateway(operator, modules, moduleData))); - gateway = new PaymentsGateway(operator); mockERC20 = new MockERC20("Token", "TKN"); mockTarget = new MockTarget(); @@ -83,24 +89,22 @@ contract PaymentsGatewayTest is Test { // build payout info payouts.push( - PaymentsGateway.PayoutInfo({ clientId: ownerClientId, payoutAddress: owner, feeBPS: ownerFeeBps }) + PayGatewayModule.PayoutInfo({ clientId: ownerClientId, payoutAddress: owner, feeAmount: ownerFeeAmount }) + ); + payouts.push( + PayGatewayModule.PayoutInfo({ clientId: clientId, payoutAddress: client, feeAmount: clientFeeAmount }) ); - payouts.push(PaymentsGateway.PayoutInfo({ clientId: clientId, payoutAddress: client, feeBPS: clientFeeBps })); - // console.logBytes32(clientId); - // console.log(client); - // console.log(clientFeeBps); - console.log(address(gateway)); for (uint256 i = 0; i < payouts.length; i++) { - totalFeeBps += payouts[i].feeBPS; + totalFeeAmount += payouts[i].feeAmount; } // EIP712 - typehashPayoutInfo = keccak256("PayoutInfo(bytes32 clientId,address payoutAddress,uint256 feeBPS)"); + typehashPayoutInfo = keccak256("PayoutInfo(bytes32 clientId,address payoutAddress,uint256 feeAmount)"); typehashPayRequest = keccak256( - "PayRequest(bytes32 clientId,bytes32 transactionId,address tokenAddress,uint256 tokenAmount,uint256 expirationTimestamp,PayoutInfo[] payouts,address forwardAddress,bytes data)PayoutInfo(bytes32 clientId,address payoutAddress,uint256 feeBPS)" + "PayRequest(bytes32 clientId,bytes32 transactionId,address tokenAddress,uint256 tokenAmount,uint256 expirationTimestamp,PayoutInfo[] payouts,address forwardAddress,bool directTransfer,bytes data)PayoutInfo(bytes32 clientId,address payoutAddress,uint256 feeAmount)" ); - nameHash = keccak256(bytes("PaymentsGateway")); + nameHash = keccak256(bytes("PayGateway")); versionHash = keccak256(bytes("1")); typehashEip712 = keccak256( "EIP712Domain(string name,string version,uint256 chainId,address verifyingContract)" @@ -122,13 +126,13 @@ contract PaymentsGatewayTest is Test { data = abi.encode(_sender, _receiver, _token, _sendValue, _message); } - function _hashPayoutInfo(PaymentsGateway.PayoutInfo[] memory _payouts) private view returns (bytes32) { + function _hashPayoutInfo(PayGatewayModule.PayoutInfo[] memory _payouts) private view returns (bytes32) { bytes32 payoutHash = typehashPayoutInfo; bytes32[] memory payoutsHashes = new bytes32[](_payouts.length); for (uint i = 0; i < payouts.length; i++) { payoutsHashes[i] = keccak256( - abi.encode(payoutHash, _payouts[i].clientId, _payouts[i].payoutAddress, _payouts[i].feeBPS) + abi.encode(payoutHash, _payouts[i].clientId, _payouts[i].payoutAddress, _payouts[i].feeAmount) ); } return keccak256(abi.encodePacked(payoutsHashes)); @@ -136,7 +140,7 @@ contract PaymentsGatewayTest is Test { function _prepareAndSignData( uint256 _operatorPrivateKey, - PaymentsGateway.PayRequest memory req + PayGatewayModule.PayRequest memory req ) internal view returns (bytes memory signature) { bytes memory dataToHash; { @@ -150,6 +154,7 @@ contract PaymentsGatewayTest is Test { req.expirationTimestamp, _payoutsHash, req.forwardAddress, + req.directTransfer, keccak256(req.data) ); } @@ -169,7 +174,7 @@ contract PaymentsGatewayTest is Test { function test_initiateTokenPurchase_erc20() public { uint256 sendValue = 1 ether; - uint256 sendValueWithFees = sendValue + (sendValue * totalFeeBps) / 10_000; + uint256 sendValueWithFees = sendValue + totalFeeAmount; bytes memory targetCalldata = _buildMockTargetCalldata(sender, receiver, address(mockERC20), sendValue, ""); // approve amount to gateway contract @@ -177,7 +182,7 @@ contract PaymentsGatewayTest is Test { mockERC20.approve(address(gateway), sendValueWithFees); // create pay request - PaymentsGateway.PayRequest memory req; + PayGatewayModule.PayRequest memory req; bytes32 _transactionId = keccak256("transaction ID"); req.clientId = clientId; @@ -206,15 +211,60 @@ contract PaymentsGatewayTest is Test { gateway.initiateTokenPurchase(req, _signature); // check balances after transaction - assertEq(mockERC20.balanceOf(owner), ownerBalanceBefore + (sendValue * ownerFeeBps) / 10_000); - assertEq(mockERC20.balanceOf(client), clientBalanceBefore + (sendValue * clientFeeBps) / 10_000); + assertEq(mockERC20.balanceOf(owner), ownerBalanceBefore + ownerFeeAmount); + assertEq(mockERC20.balanceOf(client), clientBalanceBefore + clientFeeAmount); + assertEq(mockERC20.balanceOf(sender), senderBalanceBefore - sendValueWithFees); + assertEq(mockERC20.balanceOf(receiver), receiverBalanceBefore + sendValue); + } + + function test_initiateTokenPurchase_erc20_directTransfer() public { + uint256 sendValue = 1 ether; + uint256 sendValueWithFees = sendValue + totalFeeAmount; + bytes memory targetCalldata = abi.encodeWithSignature("transfer(address,uint256)", receiver, sendValue); + + // approve amount to gateway contract + vm.prank(sender); + mockERC20.approve(address(gateway), sendValueWithFees); + + // create pay request + PayGatewayModule.PayRequest memory req; + bytes32 _transactionId = keccak256("transaction ID"); + + req.clientId = clientId; + req.transactionId = _transactionId; + req.tokenAddress = address(mockERC20); + req.tokenAmount = sendValue; + req.forwardAddress = payable(address(mockERC20)); + req.expirationTimestamp = 1000; + req.data = targetCalldata; + req.payouts = payouts; + + // generate signature + bytes memory _signature = _prepareAndSignData( + 2, // sign with operator private key, i.e. 2 + req + ); + + // state/balances before sending transaction + uint256 ownerBalanceBefore = mockERC20.balanceOf(owner); + uint256 clientBalanceBefore = mockERC20.balanceOf(client); + uint256 senderBalanceBefore = mockERC20.balanceOf(sender); + uint256 receiverBalanceBefore = mockERC20.balanceOf(receiver); + + // send transaction + vm.prank(sender); + gateway.initiateTokenPurchase(req, _signature); + + // check balances after transaction + assertEq(mockERC20.balanceOf(owner), ownerBalanceBefore + ownerFeeAmount); + assertEq(mockERC20.balanceOf(client), clientBalanceBefore + clientFeeAmount); assertEq(mockERC20.balanceOf(sender), senderBalanceBefore - sendValueWithFees); assertEq(mockERC20.balanceOf(receiver), receiverBalanceBefore + sendValue); } function test_initiateTokenPurchase_nativeToken() public { uint256 sendValue = 1 ether; - uint256 sendValueWithFees = sendValue + (sendValue * totalFeeBps) / 10_000; + uint256 sendValueWithFees = sendValue + totalFeeAmount; bytes memory targetCalldata = _buildMockTargetCalldata( sender, receiver, @@ -224,7 +274,7 @@ contract PaymentsGatewayTest is Test { ); // create pay request - PaymentsGateway.PayRequest memory req; + PayGatewayModule.PayRequest memory req; bytes32 _transactionId = keccak256("transaction ID"); req.clientId = clientId; @@ -248,8 +298,52 @@ contract PaymentsGatewayTest is Test { req ); - console.logBytes(_signature); - console.log(address(uint160(gateway._cachedThis()))); + // state/balances before sending transaction + uint256 ownerBalanceBefore = owner.balance; + uint256 clientBalanceBefore = client.balance; + uint256 senderBalanceBefore = sender.balance; + uint256 receiverBalanceBefore = receiver.balance; + + // send transaction + vm.prank(sender); + gateway.initiateTokenPurchase{ value: sendValueWithFees }(req, _signature); + + // check balances after transaction + assertEq(owner.balance, ownerBalanceBefore + ownerFeeAmount); + assertEq(client.balance, clientBalanceBefore + clientFeeAmount); + assertEq(sender.balance, senderBalanceBefore - sendValueWithFees); + assertEq(receiver.balance, receiverBalanceBefore + sendValue); + } + + function test_initiateTokenPurchase_nativeToken_directTransfer() public { + uint256 sendValue = 1 ether; + uint256 sendValueWithFees = sendValue + totalFeeAmount; + bytes memory targetCalldata = ""; + + // create pay request + PayGatewayModule.PayRequest memory req; + bytes32 _transactionId = keccak256("transaction ID"); + + req.clientId = clientId; + req.transactionId = _transactionId; + req.tokenAddress = address(0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE); + req.tokenAmount = sendValue; + req.forwardAddress = payable(address(receiver)); + req.expirationTimestamp = 1000; + req.data = targetCalldata; + req.payouts = payouts; + + console.logBytes32(clientId); + console.logBytes32(_transactionId); + console.log(sendValue); + console.log(address(mockTarget)); + console.logBytes(targetCalldata); + + // generate signature + bytes memory _signature = _prepareAndSignData( + 2, // sign with operator private key, i.e. 2 + req + ); // state/balances before sending transaction uint256 ownerBalanceBefore = owner.balance; @@ -262,15 +356,15 @@ contract PaymentsGatewayTest is Test { gateway.initiateTokenPurchase{ value: sendValueWithFees }(req, _signature); // check balances after transaction - assertEq(owner.balance, ownerBalanceBefore + (sendValue * ownerFeeBps) / 10_000); - assertEq(client.balance, clientBalanceBefore + (sendValue * clientFeeBps) / 10_000); + assertEq(owner.balance, ownerBalanceBefore + ownerFeeAmount); + assertEq(client.balance, clientBalanceBefore + clientFeeAmount); assertEq(sender.balance, senderBalanceBefore - sendValueWithFees); assertEq(receiver.balance, receiverBalanceBefore + sendValue); } function test_initiateTokenPurchase_events() public { uint256 sendValue = 1 ether; - uint256 sendValueWithFees = sendValue + (sendValue * totalFeeBps) / 10_000; + uint256 sendValueWithFees = sendValue + totalFeeAmount; bytes memory targetCalldata = _buildMockTargetCalldata(sender, receiver, address(mockERC20), sendValue, ""); // approve amount to gateway contract @@ -278,7 +372,7 @@ contract PaymentsGatewayTest is Test { mockERC20.approve(address(gateway), sendValueWithFees); // create pay request - PaymentsGateway.PayRequest memory req; + PayGatewayModule.PayRequest memory req; bytes32 _transactionId = keccak256("transaction ID"); req.clientId = clientId; @@ -305,7 +399,7 @@ contract PaymentsGatewayTest is Test { function test_revert_initiateTokenPurchase_invalidSignature() public { uint256 sendValue = 1 ether; - uint256 sendValueWithFees = sendValue + (sendValue * totalFeeBps) / 10_000; + uint256 sendValueWithFees = sendValue + totalFeeAmount; bytes memory targetCalldata = _buildMockTargetCalldata(sender, receiver, address(mockERC20), sendValue, ""); // approve amount to gateway contract @@ -313,7 +407,7 @@ contract PaymentsGatewayTest is Test { mockERC20.approve(address(gateway), sendValueWithFees); // create pay request - PaymentsGateway.PayRequest memory req; + PayGatewayModule.PayRequest memory req; bytes32 _transactionId = keccak256("transaction ID"); req.clientId = clientId; @@ -333,7 +427,7 @@ contract PaymentsGatewayTest is Test { // send transaction vm.prank(sender); - vm.expectRevert(abi.encodeWithSelector(PaymentsGateway.PaymentsGatewayVerificationFailed.selector)); + vm.expectRevert(abi.encodeWithSelector(PayGatewayModule.PayGatewayVerificationFailed.selector)); gateway.initiateTokenPurchase(req, _signature); } @@ -342,7 +436,7 @@ contract PaymentsGatewayTest is Test { bytes memory targetCalldata = ""; // create pay request - PaymentsGateway.PayRequest memory req; + PayGatewayModule.PayRequest memory req; bytes32 _transactionId = keccak256("transaction ID"); req.clientId = clientId; @@ -361,75 +455,8 @@ contract PaymentsGatewayTest is Test { // send transaction vm.prank(sender); vm.expectRevert( - abi.encodeWithSelector(PaymentsGateway.PaymentsGatewayRequestExpired.selector, req.expirationTimestamp) + abi.encodeWithSelector(PayGatewayModule.PayGatewayRequestExpired.selector, req.expirationTimestamp) ); gateway.initiateTokenPurchase(req, _signature); } - - // /*/////////////////////////////////////////////////////////////// - // Test `completeTokenPurchase` - // //////////////////////////////////////////////////////////////*/ - - function test_completeTokenPurchase_erc20() public { - uint256 sendValue = 1 ether; - - // approve amount to gateway contract - vm.prank(sender); - mockERC20.approve(address(gateway), sendValue); - - // state/balances before sending transaction - uint256 ownerBalanceBefore = mockERC20.balanceOf(owner); - uint256 senderBalanceBefore = mockERC20.balanceOf(sender); - uint256 receiverBalanceBefore = mockERC20.balanceOf(receiver); - - // send transaction - bytes32 _transactionId = keccak256("transaction ID"); - vm.prank(sender); - gateway.completeTokenPurchase(clientId, _transactionId, address(mockERC20), sendValue, receiver); - - // check balances after transaction - assertEq(mockERC20.balanceOf(owner), ownerBalanceBefore); - assertEq(mockERC20.balanceOf(sender), senderBalanceBefore - sendValue); - assertEq(mockERC20.balanceOf(receiver), receiverBalanceBefore + sendValue); - } - - function test_completeTokenPurchase_nativeToken() public { - uint256 sendValue = 1 ether; - - // state/balances before sending transaction - uint256 ownerBalanceBefore = owner.balance; - uint256 senderBalanceBefore = sender.balance; - uint256 receiverBalanceBefore = receiver.balance; - - // send transaction - bytes32 _transactionId = keccak256("transaction ID"); - vm.prank(sender); - gateway.completeTokenPurchase{ value: sendValue }( - clientId, - _transactionId, - address(0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE), - sendValue, - receiver - ); - - // check balances after transaction - assertEq(owner.balance, ownerBalanceBefore); - assertEq(sender.balance, senderBalanceBefore - sendValue); - assertEq(receiver.balance, receiverBalanceBefore + sendValue); - } - - function test_completeTokenPurchase_events() public { - uint256 sendValue = 1 ether; - - // approve amount to gateway contract - vm.prank(sender); - mockERC20.approve(address(gateway), sendValue); - - // send transaction - bytes32 _transactionId = keccak256("transaction ID"); - vm.prank(sender); - vm.expectEmit(true, true, false, true); - emit TokenPurchaseCompleted(clientId, receiver, _transactionId, address(mockERC20), sendValue); - gateway.completeTokenPurchase(clientId, _transactionId, address(mockERC20), sendValue, receiver); - } } diff --git a/test/benchmarks/BenchmarkPaymentsGateway.t.sol b/test/benchmarks/BenchmarkPayGateway.t.sol similarity index 76% rename from test/benchmarks/BenchmarkPaymentsGateway.t.sol rename to test/benchmarks/BenchmarkPayGateway.t.sol index f5d45ce..68696cf 100644 --- a/test/benchmarks/BenchmarkPaymentsGateway.t.sol +++ b/test/benchmarks/BenchmarkPayGateway.t.sol @@ -1,13 +1,15 @@ // SPDX-License-Identifier: UNLICENSED pragma solidity ^0.8.13; -import { Test, console, console2 } from "forge-std/Test.sol"; -import { PaymentsGateway } from "src/PaymentsGateway.sol"; +import { Test, console } from "forge-std/Test.sol"; +import { PayGateway } from "src/PayGateway.sol"; +import { PayGatewayModule } from "src/PayGatewayModule.sol"; +import { LibClone } from "lib/solady/src/utils/LibClone.sol"; import { MockERC20 } from "../utils/MockERC20.sol"; import { MockTarget } from "../utils/MockTarget.sol"; -contract BenchmarkPaymentsGatewayTest is Test { - PaymentsGateway internal gateway; +contract BenchmarkPayGatewayTest is Test { + PayGatewayModule internal gateway; MockERC20 internal mockERC20; MockTarget internal mockTarget; @@ -20,11 +22,11 @@ contract BenchmarkPaymentsGatewayTest is Test { bytes32 internal ownerClientId; bytes32 internal clientId; - uint256 internal ownerFeeBps; - uint256 internal clientFeeBps; - uint256 internal totalFeeBps; + uint256 internal ownerFeeAmount; + uint256 internal clientFeeAmount; + uint256 internal totalFeeAmount; - PaymentsGateway.PayoutInfo[] internal payouts; + PayGatewayModule.PayoutInfo[] internal payouts; bytes32 internal typehashPayRequest; bytes32 internal typehashPayoutInfo; @@ -43,10 +45,19 @@ contract BenchmarkPaymentsGatewayTest is Test { ownerClientId = keccak256("owner"); clientId = keccak256("client"); - ownerFeeBps = 200; - clientFeeBps = 100; + ownerFeeAmount = 20; + clientFeeAmount = 10; + + // deploy and install module + address module = address(new PayGatewayModule()); + + address[] memory modules = new address[](1); + bytes[] memory moduleData = new bytes[](1); + modules[0] = address(module); + moduleData[0] = ""; + + gateway = PayGatewayModule(address(new PayGateway(operator, modules, moduleData))); - gateway = new PaymentsGateway(operator); mockERC20 = new MockERC20("Token", "TKN"); mockTarget = new MockTarget(); @@ -56,19 +67,22 @@ contract BenchmarkPaymentsGatewayTest is Test { // build payout info payouts.push( - PaymentsGateway.PayoutInfo({ clientId: ownerClientId, payoutAddress: owner, feeBPS: ownerFeeBps }) + PayGatewayModule.PayoutInfo({ clientId: ownerClientId, payoutAddress: owner, feeAmount: ownerFeeAmount }) ); - payouts.push(PaymentsGateway.PayoutInfo({ clientId: clientId, payoutAddress: client, feeBPS: clientFeeBps })); + payouts.push( + PayGatewayModule.PayoutInfo({ clientId: clientId, payoutAddress: client, feeAmount: clientFeeAmount }) + ); + for (uint256 i = 0; i < payouts.length; i++) { - totalFeeBps += payouts[i].feeBPS; + totalFeeAmount += payouts[i].feeAmount; } // EIP712 - typehashPayoutInfo = keccak256("PayoutInfo(bytes32 clientId,address payoutAddress,uint256 feeBPS)"); + typehashPayoutInfo = keccak256("PayoutInfo(bytes32 clientId,address payoutAddress,uint256 feeAmount)"); typehashPayRequest = keccak256( - "PayRequest(bytes32 clientId,bytes32 transactionId,address tokenAddress,uint256 tokenAmount,uint256 expirationTimestamp,PayoutInfo[] payouts,address forwardAddress,bytes data)PayoutInfo(bytes32 clientId,address payoutAddress,uint256 feeBPS)" + "PayRequest(bytes32 clientId,bytes32 transactionId,address tokenAddress,uint256 tokenAmount,uint256 expirationTimestamp,PayoutInfo[] payouts,address forwardAddress,bool directTransfer,bytes data)PayoutInfo(bytes32 clientId,address payoutAddress,uint256 feeAmount)" ); - nameHash = keccak256(bytes("PaymentsGateway")); + nameHash = keccak256(bytes("PayGateway")); versionHash = keccak256(bytes("1")); typehashEip712 = keccak256( "EIP712Domain(string name,string version,uint256 chainId,address verifyingContract)" @@ -90,13 +104,13 @@ contract BenchmarkPaymentsGatewayTest is Test { data = abi.encode(_sender, _receiver, _token, _sendValue, _message); } - function _hashPayoutInfo(PaymentsGateway.PayoutInfo[] memory _payouts) private view returns (bytes32) { + function _hashPayoutInfo(PayGatewayModule.PayoutInfo[] memory _payouts) private view returns (bytes32) { bytes32 payoutHash = typehashPayoutInfo; bytes32[] memory payoutsHashes = new bytes32[](_payouts.length); for (uint i = 0; i < payouts.length; i++) { payoutsHashes[i] = keccak256( - abi.encode(payoutHash, _payouts[i].clientId, _payouts[i].payoutAddress, _payouts[i].feeBPS) + abi.encode(payoutHash, _payouts[i].clientId, _payouts[i].payoutAddress, _payouts[i].feeAmount) ); } return keccak256(abi.encodePacked(payoutsHashes)); @@ -104,7 +118,7 @@ contract BenchmarkPaymentsGatewayTest is Test { function _prepareAndSignData( uint256 _operatorPrivateKey, - PaymentsGateway.PayRequest memory req + PayGatewayModule.PayRequest memory req ) internal view returns (bytes memory signature) { bytes memory dataToHash; { @@ -118,6 +132,7 @@ contract BenchmarkPaymentsGatewayTest is Test { req.expirationTimestamp, _payoutsHash, req.forwardAddress, + req.directTransfer, keccak256(req.data) ); } @@ -138,7 +153,7 @@ contract BenchmarkPaymentsGatewayTest is Test { function test_initiateTokenPurchase_erc20() public { vm.pauseGasMetering(); uint256 sendValue = 1 ether; - uint256 sendValueWithFees = sendValue + (sendValue * totalFeeBps) / 10_000; + uint256 sendValueWithFees = sendValue + totalFeeAmount; bytes memory targetCalldata = _buildMockTargetCalldata(sender, receiver, address(mockERC20), sendValue, ""); // approve amount to gateway contract @@ -146,7 +161,7 @@ contract BenchmarkPaymentsGatewayTest is Test { mockERC20.approve(address(gateway), sendValueWithFees); // create pay request - PaymentsGateway.PayRequest memory req; + PayGatewayModule.PayRequest memory req; bytes32 _transactionId = keccak256("transaction ID"); req.clientId = clientId; @@ -173,7 +188,7 @@ contract BenchmarkPaymentsGatewayTest is Test { function test_initiateTokenPurchase_nativeToken() public { vm.pauseGasMetering(); uint256 sendValue = 1 ether; - uint256 sendValueWithFees = sendValue + (sendValue * totalFeeBps) / 10_000; + uint256 sendValueWithFees = sendValue + totalFeeAmount; bytes memory targetCalldata = _buildMockTargetCalldata( sender, receiver, @@ -183,7 +198,7 @@ contract BenchmarkPaymentsGatewayTest is Test { ); // create pay request - PaymentsGateway.PayRequest memory req; + PayGatewayModule.PayRequest memory req; bytes32 _transactionId = keccak256("transaction ID"); req.clientId = clientId; diff --git a/test/utils/MockERC20.sol b/test/utils/MockERC20.sol index b71b22c..8dabd48 100644 --- a/test/utils/MockERC20.sol +++ b/test/utils/MockERC20.sol @@ -1,10 +1,24 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.22; -import "@openzeppelin/contracts/token/ERC20/ERC20.sol"; +import "lib/solady/src/tokens/ERC20.sol"; contract MockERC20 is ERC20 { - constructor(string memory name, string memory symbol) ERC20(name, symbol) {} + string private _name; + string private _symbol; + + constructor(string memory name, string memory symbol) { + _name = name; + _symbol = symbol; + } + + function name() public view override returns (string memory) { + return _name; + } + + function symbol() public view override returns (string memory) { + return _symbol; + } function mint(address to, uint256 amount) public { _mint(to, amount); diff --git a/test/utils/MockTarget.sol b/test/utils/MockTarget.sol index 2d121bb..fccce10 100644 --- a/test/utils/MockTarget.sol +++ b/test/utils/MockTarget.sol @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.0; -import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import "lib/solady/src/tokens/ERC20.sol"; import "lib/forge-std/src/console.sol"; contract MockTarget { @@ -23,7 +23,7 @@ contract MockTarget { emit TargetLog(sender, receiver, tokenAddress, tokenAmount, message); console.log("Transferring %s erc20 tokens from %s to %s", tokenAmount, sender, receiver); - require(IERC20(tokenAddress).transferFrom(msg.sender, receiver, tokenAmount), "Token transfer failed"); + require(ERC20(tokenAddress).transferFrom(msg.sender, receiver, tokenAmount), "Token transfer failed"); } function performNativeTokenAction(