From a7a391adda0af2f170c763567d09344b1f28f9b2 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 4 Nov 2024 23:09:59 +0900 Subject: [PATCH 01/20] add initial contracts --- .../contracts/contracts/pulse/IPulse.sol | 144 ++++++ .../contracts/contracts/pulse/Pulse.sol | 411 ++++++++++++++++++ .../contracts/contracts/pulse/PulseErrors.sol | 10 + .../contracts/contracts/pulse/PulseState.sol | 39 ++ .../contracts/pulse/PulseUpgradeable.sol | 70 +++ .../ethereum/contracts/forge-test/Pulse.t.sol | 401 +++++++++++++++++ 6 files changed, 1075 insertions(+) create mode 100644 target_chains/ethereum/contracts/contracts/pulse/IPulse.sol create mode 100644 target_chains/ethereum/contracts/contracts/pulse/Pulse.sol create mode 100644 target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol create mode 100644 target_chains/ethereum/contracts/contracts/pulse/PulseState.sol create mode 100644 target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol create mode 100644 target_chains/ethereum/contracts/forge-test/Pulse.t.sol diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol new file mode 100644 index 0000000000..8bcf6f263d --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "./PulseState.sol"; + +interface IPulseConsumer { + function pulseCallback( + uint64 sequenceNumber, + address provider, + uint256 publishTime, + bytes32[] calldata priceIds + ) external; +} + +interface IPulse { + // Events + event PriceUpdateRequested( + uint64 indexed sequenceNumber, + address indexed provider, + uint256 publishTime, + bytes32[] priceIds, + address requester + ); + + event PriceUpdateExecuted( + uint64 indexed sequenceNumber, + address indexed provider, + uint256 publishTime, + bytes32[] priceIds + ); + + event ProviderRegistered( + address indexed provider, + uint128 feeInWei, + bytes uri + ); + + event ProviderFeeUpdated( + address indexed provider, + uint128 oldFeeInWei, + uint128 newFeeInWei + ); + + event ProviderWithdrawn( + address indexed provider, + address indexed recipient, + uint128 amount + ); + + event ProviderFeeManagerUpdated( + address indexed provider, + address oldFeeManager, + address newFeeManager + ); + + event ProviderUriUpdated( + address indexed provider, + bytes oldUri, + bytes newUri + ); + + event ProviderMaxNumPricesUpdated( + address indexed provider, + uint32 oldMaxNumPrices, + uint32 maxNumPrices + ); + + event PriceUpdateCallbackFailed( + uint64 indexed sequenceNumber, + address indexed provider, + uint256 publishTime, + bytes32[] priceIds, + address requester, + string reason + ); + + // Core functions + function requestPriceUpdatesWithCallback( + address provider, + uint256 publishTime, + bytes32[] calldata priceIds, + bytes[] calldata updateData, + uint256 callbackGasLimit + ) external payable returns (uint64 sequenceNumber); + + function executeCallback( + uint64 sequenceNumber, + uint256 publishTime, + bytes32[] calldata priceIds, + bytes[] calldata updateData, + uint256 callbackGasLimit + ) external; + + // Provider management + function register(uint128 feeInWei, bytes calldata uri) external; + + function setProviderFee(uint128 newFeeInWei) external; + + function withdraw(uint128 amount) external; + + // Add to interface + function withdrawAsFeeManager(address provider, uint128 amount) external; + + // Add to Provider management section + function setProviderUri(bytes calldata uri) external; + + // Getters + function getFee(address provider) external view returns (uint128 feeAmount); + + function getDefaultProvider() external view returns (address); + + // Add to interface + function setFeeManager(address manager) external; + + // Add to interface + function setProviderFeeAsFeeManager( + address provider, + uint128 newFeeInWei + ) external; + + // Add to Getters section + function getAccruedPythFees() + external + view + returns (uint128 accruedPythFeesInWei); + + // Add to Getters section + function getProviderInfo( + address provider + ) external view returns (PulseState.ProviderInfo memory info); + + function getAdmin() external view returns (address admin); + + function getPythFeeInWei() external view returns (uint128 pythFeeInWei); + + function setMaxNumPrices(uint32 maxNumPrices) external; + + // Add to Getters section + function getRequest( + address provider, + uint64 sequenceNumber + ) external view returns (PulseState.Request memory req); +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol new file mode 100644 index 0000000000..e17c4da49f --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -0,0 +1,411 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "@openzeppelin/contracts/security/ReentrancyGuard.sol"; +import "@openzeppelin/contracts/utils/math/SafeCast.sol"; +import "./IPulse.sol"; +import "./PulseState.sol"; +import "./PulseErrors.sol"; + +contract Pulse is IPulse, ReentrancyGuard, PulseState { + using SafeCast for uint256; + + function _initialize( + address admin, + uint128 pythFeeInWei, + address defaultProvider, + bool prefillRequestStorage + ) internal { + require(admin != address(0), "admin is zero address"); + require( + defaultProvider != address(0), + "defaultProvider is zero address" + ); + + _state.admin = admin; + _state.pythFeeInWei = pythFeeInWei; + _state.accruedPythFeesInWei = 0; + _state.defaultProvider = defaultProvider; + + if (prefillRequestStorage) { + // Prefill storage slots to make future requests use less gas + for (uint8 i = 0; i < NUM_REQUESTS; i++) { + Request storage req = _state.requests[i]; + req.provider = address(1); + req.sequenceNumber = 0; // Keep it inactive + req.publishTime = 1; + // No need to prefill dynamic arrays (priceIds, updateData) + req.callbackGasLimit = 1; + req.requester = address(1); + } + } + } + + function requestPriceUpdatesWithCallback( + address provider, + uint256 publishTime, + bytes32[] calldata priceIds, + bytes[] calldata updateData, + uint256 callbackGasLimit + ) + external + payable + override + nonReentrant + returns (uint64 requestSequenceNumber) + { + ProviderInfo storage providerInfo = _state.providers[provider]; + if (providerInfo.sequenceNumber == 0) revert NoSuchProvider(); + + if ( + providerInfo.maxNumPrices > 0 && + priceIds.length > providerInfo.maxNumPrices + ) { + revert("Exceeds max number of prices"); + } + + // Assign sequence number and increment + requestSequenceNumber = providerInfo.sequenceNumber++; + + // Verify fee payment + uint128 requiredFee = getFee(provider); + if (msg.value < requiredFee) revert InsufficientFee(); + + // Store request for callback execution + Request storage req = allocRequest(provider, requestSequenceNumber); + req.provider = provider; + req.sequenceNumber = requestSequenceNumber; + req.publishTime = publishTime; + req.priceIds = priceIds; + req.updateData = updateData; + req.callbackGasLimit = callbackGasLimit; + req.requester = msg.sender; + + // Update fee balances + providerInfo.accruedFeesInWei += providerInfo.feeInWei; + _state.accruedPythFeesInWei += (msg.value.toUint128() - + providerInfo.feeInWei); + + emit PriceUpdateRequested( + requestSequenceNumber, + provider, + publishTime, + priceIds, + msg.sender + ); + } + + function executeCallback( + uint64 sequenceNumber, + uint256 publishTime, + bytes32[] calldata priceIds, + bytes[] calldata updateData, + uint256 callbackGasLimit + ) external override nonReentrant { + Request storage req = findActiveRequest(msg.sender, sequenceNumber); + + // Verify request parameters match + require(req.publishTime == publishTime, "Invalid publish time"); + require( + keccak256(abi.encode(req.priceIds)) == + keccak256(abi.encode(priceIds)), + "Invalid price IDs" + ); + require( + keccak256(abi.encode(req.updateData)) == + keccak256(abi.encode(updateData)), + "Invalid update data" + ); + require( + req.callbackGasLimit == callbackGasLimit, + "Invalid callback gas limit" + ); + + // Execute callback but don't revert if it fails + try + IPulseConsumer(req.requester).pulseCallback( + sequenceNumber, + msg.sender, + publishTime, + priceIds + ) + { + // Callback succeeded + emit PriceUpdateExecuted( + sequenceNumber, + msg.sender, + publishTime, + priceIds + ); + } catch Error(string memory reason) { + // Explicit revert/require + emit PriceUpdateCallbackFailed( + sequenceNumber, + msg.sender, + publishTime, + priceIds, + req.requester, + reason + ); + } catch { + // Out of gas or other low-level errors + emit PriceUpdateCallbackFailed( + sequenceNumber, + msg.sender, + publishTime, + priceIds, + req.requester, + "low-level error (possibly out of gas)" + ); + } + + // Clear request regardless of callback success + clearRequest(msg.sender, sequenceNumber); + } + + function register(uint128 feeInWei, bytes calldata uri) public override { + ProviderInfo storage provider = _state.providers[msg.sender]; + + provider.feeInWei = feeInWei; + provider.uri = uri; + + if (provider.sequenceNumber == 0) { + provider.sequenceNumber = 1; + } + + emit ProviderRegistered(msg.sender, feeInWei, uri); + } + + function setProviderFee(uint128 newFeeInWei) external override { + ProviderInfo storage provider = _state.providers[msg.sender]; + if (provider.sequenceNumber == 0) revert NoSuchProvider(); + + uint128 oldFeeInWei = provider.feeInWei; + provider.feeInWei = newFeeInWei; + + emit ProviderFeeUpdated(msg.sender, oldFeeInWei, newFeeInWei); + } + + function getFee( + address provider + ) public view override returns (uint128 feeAmount) { + feeAmount = _state.providers[provider].feeInWei + _state.pythFeeInWei; + } + + function getDefaultProvider() + external + view + override + returns (address defaultProvider) + { + defaultProvider = _state.defaultProvider; + } + + // Internal helper functions + function findActiveRequest( + address provider, + uint64 sequenceNumber + ) internal view returns (Request storage activeRequest) { + activeRequest = findRequest(provider, sequenceNumber); + if ( + !isActive(activeRequest) || + activeRequest.provider != provider || + activeRequest.sequenceNumber != sequenceNumber + ) { + revert NoSuchRequest(); + } + } + + function findRequest( + address provider, + uint64 sequenceNumber + ) internal view returns (Request storage foundRequest) { + (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); + foundRequest = _state.requests[shortKey]; + + if ( + foundRequest.provider == provider && + foundRequest.sequenceNumber == sequenceNumber + ) { + return foundRequest; + } else { + foundRequest = _state.requestsOverflow[key]; + } + } + + function clearRequest(address provider, uint64 sequenceNumber) internal { + (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); + Request storage req = _state.requests[shortKey]; + + if (req.provider == provider && req.sequenceNumber == sequenceNumber) { + req.sequenceNumber = 0; + } else { + delete _state.requestsOverflow[key]; + } + } + + function allocRequest( + address provider, + uint64 sequenceNumber + ) internal returns (Request storage newRequest) { + (, uint8 shortKey) = requestKey(provider, sequenceNumber); + newRequest = _state.requests[shortKey]; + + if (isActive(newRequest)) { + (bytes32 reqKey, ) = requestKey( + newRequest.provider, + newRequest.sequenceNumber + ); + _state.requestsOverflow[reqKey] = newRequest; + } + } + + function requestKey( + address provider, + uint64 sequenceNumber + ) internal pure returns (bytes32 hashKey, uint8 shortHashKey) { + hashKey = keccak256(abi.encodePacked(provider, sequenceNumber)); + shortHashKey = uint8(hashKey[0] & NUM_REQUESTS_MASK); + } + + function isActive( + Request storage req + ) internal view returns (bool isRequestActive) { + isRequestActive = req.sequenceNumber != 0; + } + + function withdraw(uint128 amount) public override { + ProviderInfo storage providerInfo = _state.providers[msg.sender]; + + // Use checks-effects-interactions pattern to prevent reentrancy attacks + require( + providerInfo.accruedFeesInWei >= amount, + "Insufficient balance" + ); + providerInfo.accruedFeesInWei -= amount; + + // Interaction with an external contract or token transfer + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "withdrawal to msg.sender failed"); + + emit ProviderWithdrawn(msg.sender, msg.sender, amount); + } + + function withdrawAsFeeManager( + address provider, + uint128 amount + ) external override { + ProviderInfo storage providerInfo = _state.providers[provider]; + + if (providerInfo.sequenceNumber == 0) { + revert NoSuchProvider(); + } + + if (providerInfo.feeManager != msg.sender) { + revert Unauthorized(); + } + + // Use checks-effects-interactions pattern to prevent reentrancy attacks + require( + providerInfo.accruedFeesInWei >= amount, + "Insufficient balance" + ); + providerInfo.accruedFeesInWei -= amount; + + // Interaction with an external contract or token transfer + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "withdrawal to msg.sender failed"); + + emit ProviderWithdrawn(provider, msg.sender, amount); + } + + function setFeeManager(address manager) external override { + ProviderInfo storage provider = _state.providers[msg.sender]; + if (provider.sequenceNumber == 0) revert NoSuchProvider(); + + address oldFeeManager = provider.feeManager; + provider.feeManager = manager; + + emit ProviderFeeManagerUpdated(msg.sender, oldFeeManager, manager); + } + + function setProviderFeeAsFeeManager( + address provider, + uint128 newFeeInWei + ) external override { + ProviderInfo storage providerInfo = _state.providers[provider]; + + if (providerInfo.sequenceNumber == 0) { + revert NoSuchProvider(); + } + + if (providerInfo.feeManager != msg.sender) { + revert Unauthorized(); + } + + uint128 oldFeeInWei = providerInfo.feeInWei; + providerInfo.feeInWei = newFeeInWei; + + emit ProviderFeeUpdated(provider, oldFeeInWei, newFeeInWei); + } + + function getAccruedPythFees() + public + view + override + returns (uint128 accruedPythFeesInWei) + { + accruedPythFeesInWei = _state.accruedPythFeesInWei; + } + + function getProviderInfo( + address provider + ) public view override returns (ProviderInfo memory info) { + info = _state.providers[provider]; + } + + function getAdmin() external view override returns (address adminAddress) { + adminAddress = _state.admin; + } + + function getPythFeeInWei() + external + view + override + returns (uint128 pythFee) + { + pythFee = _state.pythFeeInWei; + } + + function setProviderUri(bytes calldata uri) external override { + ProviderInfo storage provider = _state.providers[msg.sender]; + if (provider.sequenceNumber == 0) revert NoSuchProvider(); + + bytes memory oldUri = provider.uri; + provider.uri = uri; + + emit ProviderUriUpdated(msg.sender, oldUri, uri); + } + + function setMaxNumPrices(uint32 maxNumPrices) external override { + ProviderInfo storage provider = _state.providers[msg.sender]; + if (provider.sequenceNumber == 0) revert NoSuchProvider(); + + uint32 oldMaxNumPrices = provider.maxNumPrices; + provider.maxNumPrices = maxNumPrices; + + emit ProviderMaxNumPricesUpdated( + msg.sender, + oldMaxNumPrices, + maxNumPrices + ); + } + + function getRequest( + address provider, + uint64 sequenceNumber + ) public view override returns (Request memory req) { + req = findRequest(provider, sequenceNumber); + } +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol new file mode 100644 index 0000000000..187ccf00a2 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +error NoSuchProvider(); +error NoSuchRequest(); +error InsufficientFee(); +error Unauthorized(); +error InvalidCallbackGas(); +error CallbackFailed(); diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol new file mode 100644 index 0000000000..839a6a0d21 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +contract PulseState { + uint8 public constant NUM_REQUESTS = 32; + bytes1 public constant NUM_REQUESTS_MASK = 0x1f; + + struct Request { + address provider; + uint64 sequenceNumber; + uint256 publishTime; + bytes32[] priceIds; + bytes[] updateData; + uint256 callbackGasLimit; + address requester; + } + + struct ProviderInfo { + uint64 sequenceNumber; + uint128 feeInWei; + uint128 accruedFeesInWei; + bytes uri; + address feeManager; + uint32 maxNumPrices; + } + + struct State { + address admin; + uint128 pythFeeInWei; + uint128 accruedPythFeesInWei; + address defaultProvider; + Request[32] requests; + mapping(bytes32 => Request) requestsOverflow; + mapping(address => ProviderInfo) providers; + } + + State internal _state; +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol new file mode 100644 index 0000000000..191ef15889 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "@openzeppelin/contracts-upgradeable/proxy/utils/Initializable.sol"; +import "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol"; +import "@openzeppelin/contracts-upgradeable/access/Ownable2StepUpgradeable.sol"; +import "./Pulse.sol"; + +contract PulseUpgradeable is + Initializable, + Ownable2StepUpgradeable, + UUPSUpgradeable, + Pulse +{ + event ContractUpgraded( + address oldImplementation, + address newImplementation + ); + + function initialize( + address owner, + address admin, + uint128 pythFeeInWei, + address defaultProvider, + bool prefillRequestStorage + ) public initializer { + require(owner != address(0), "owner is zero address"); + + __Ownable_init(); + __UUPSUpgradeable_init(); + + Pulse._initialize( + admin, + pythFeeInWei, + defaultProvider, + prefillRequestStorage + ); + + _transferOwnership(owner); + } + + /// @custom:oz-upgrades-unsafe-allow constructor + constructor() initializer {} + + function _authorizeUpgrade(address) internal override onlyOwner {} + + function upgradeTo(address newImplementation) external override onlyProxy { + address oldImplementation = _getImplementation(); + _authorizeUpgrade(newImplementation); + _upgradeToAndCallUUPS(newImplementation, new bytes(0), false); + + emit ContractUpgraded(oldImplementation, _getImplementation()); + } + + function upgradeToAndCall( + address newImplementation, + bytes memory data + ) external payable override onlyProxy { + address oldImplementation = _getImplementation(); + _authorizeUpgrade(newImplementation); + _upgradeToAndCallUUPS(newImplementation, data, true); + + emit ContractUpgraded(oldImplementation, _getImplementation()); + } + + function version() public pure returns (string memory) { + return "1.0.0"; + } +} diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol new file mode 100644 index 0000000000..615d45eb7f --- /dev/null +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -0,0 +1,401 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "forge-std/Test.sol"; +import "../contracts/pulse/PulseUpgradeable.sol"; +import "../contracts/pulse/IPulse.sol"; +import "../contracts/pulse/PulseState.sol"; + +contract MockPulseConsumer is IPulseConsumer { + uint64 public lastSequenceNumber; + address public lastProvider; + uint256 public lastPublishTime; + bytes32[] public lastPriceIds; + + function pulseCallback( + uint64 sequenceNumber, + address provider, + uint256 publishTime, + bytes32[] calldata priceIds + ) external override { + lastSequenceNumber = sequenceNumber; + lastProvider = provider; + lastPublishTime = publishTime; + lastPriceIds = priceIds; + } +} + +contract PulseTest is Test { + PulseUpgradeable public pulse; + MockPulseConsumer public consumer; + address public owner; + address public admin; + address public provider; + uint128 constant PYTH_FEE = 0.001 ether; + uint128 constant PROVIDER_FEE = 0.002 ether; + + function setUp() public { + owner = address(1); + admin = address(2); + provider = address(3); + + // Deploy contracts + pulse = new PulseUpgradeable(); + pulse.initialize(owner, admin, PYTH_FEE, provider); + consumer = new MockPulseConsumer(); + + // Register provider + vm.prank(provider); + pulse.register(PROVIDER_FEE, "https://provider.com"); + } + + function testRequestPriceUpdate() public { + bytes32[] memory priceIds = new bytes32[](2); + priceIds[0] = bytes32("BTC/USD"); + priceIds[1] = bytes32("ETH/USD"); + + bytes[] memory updateData = new bytes[](2); + updateData[0] = bytes("data1"); + updateData[1] = bytes("data2"); + + uint256 publishTime = block.timestamp; + uint256 callbackGasLimit = 500000; + + vm.prank(address(consumer)); + uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: PYTH_FEE + PROVIDER_FEE + }(provider, publishTime, priceIds, updateData, callbackGasLimit); + + assertEq(sequenceNumber, 1); + } + + function testExecuteCallback() public { + bytes32[] memory priceIds = new bytes32[](2); + priceIds[0] = bytes32("BTC/USD"); + priceIds[1] = bytes32("ETH/USD"); + + bytes[] memory updateData = new bytes[](2); + updateData[0] = bytes("data1"); + updateData[1] = bytes("data2"); + + uint256 publishTime = block.timestamp; + uint256 callbackGasLimit = 500000; + + vm.prank(address(consumer)); + uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: PYTH_FEE + PROVIDER_FEE + }(provider, publishTime, priceIds, updateData, callbackGasLimit); + + vm.prank(provider); + pulse.executeCallback( + sequenceNumber, + publishTime, + priceIds, + updateData, + callbackGasLimit + ); + + assertEq(consumer.lastSequenceNumber(), sequenceNumber); + assertEq(consumer.lastProvider(), provider); + assertEq(consumer.lastPublishTime(), publishTime); + assertEq(consumer.lastPriceIds()[0], priceIds[0]); + assertEq(consumer.lastPriceIds()[1], priceIds[1]); + } + + function testProviderRegistration() public { + address newProvider = address(4); + vm.prank(newProvider); + pulse.register(PROVIDER_FEE, "https://newprovider.com"); + + uint128 fee = pulse.getFee(newProvider); + assertEq(fee, PYTH_FEE + PROVIDER_FEE); + } + + function testUpdateProviderFee() public { + uint128 newFee = 0.003 ether; + vm.prank(provider); + pulse.setProviderFee(newFee); + + uint128 fee = pulse.getFee(provider); + assertEq(fee, PYTH_FEE + newFee); + } + + function testFailInsufficientFee() public { + bytes32[] memory priceIds = new bytes32[](1); + bytes[] memory updateData = new bytes[](1); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Not paying provider fee + provider, + block.timestamp, + priceIds, + updateData, + 500000 + ); + } + + function testFailUnregisteredProvider() public { + bytes32[] memory priceIds = new bytes32[](1); + bytes[] memory updateData = new bytes[](1); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( + address(99), // Unregistered provider + block.timestamp, + priceIds, + updateData, + 500000 + ); + } + + function testGasCostsWithPrefill() public { + // Deploy with prefill + PulseUpgradeable pulseWithPrefill = new PulseUpgradeable(); + pulseWithPrefill.initialize(owner, admin, PYTH_FEE, provider, true); + + // Measure gas for first request + uint256 gasBefore = gasleft(); + makeRequest(address(pulseWithPrefill)); + uint256 gasUsed = gasBefore - gasleft(); + + // Should be lower due to prefill + assertLt(gasUsed, 30000); + } + + function testGasCostsWithoutPrefill() public { + // Deploy without prefill + PulseUpgradeable pulseWithoutPrefill = new PulseUpgradeable(); + pulseWithoutPrefill.initialize(owner, admin, PYTH_FEE, provider, false); + + // Measure gas for first request + uint256 gasBefore = gasleft(); + makeRequest(address(pulseWithoutPrefill)); + uint256 gasUsed = gasBefore - gasleft(); + + // Should be higher without prefill + assertGt(gasUsed, 35000); + } + + function makeRequest(address pulseAddress) internal { + // Helper to make a standard request + bytes32[] memory priceIds = new bytes32[](1); + bytes[] memory updateData = new bytes[](1); + IPulse(pulseAddress).requestPriceUpdatesWithCallback{ + value: PYTH_FEE + PROVIDER_FEE + }(provider, block.timestamp, priceIds, updateData, 500000); + } + + function testWithdraw() public { + // Setup - make a request to accrue some fees + bytes32[] memory priceIds = new bytes32[](1); + bytes[] memory updateData = new bytes[](1); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( + provider, + block.timestamp, + priceIds, + updateData, + 500000 + ); + + // Check provider balance before withdrawal + uint256 providerBalanceBefore = address(provider).balance; + + // Provider withdraws fees + vm.prank(provider); + pulse.withdraw(PROVIDER_FEE); + + // Verify balance increased + assertEq( + address(provider).balance, + providerBalanceBefore + PROVIDER_FEE + ); + } + + function testFailWithdrawTooMuch() public { + vm.prank(provider); + pulse.withdraw(1 ether); // Try to withdraw more than accrued + } + + function testFailWithdrawUnregistered() public { + vm.prank(address(99)); // Unregistered provider + pulse.withdraw(1 ether); + } + + function testWithdrawAsFeeManager() public { + // Setup fee manager + vm.prank(provider); + pulse.setFeeManager(address(99)); + + // Setup fees + bytes32[] memory priceIds = new bytes32[](1); + bytes[] memory updateData = new bytes[](1); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( + provider, + block.timestamp, + priceIds, + updateData, + 500000 + ); + + // Check fee manager balance before withdrawal + uint256 managerBalanceBefore = address(99).balance; + + // Fee manager withdraws + vm.prank(address(99)); + pulse.withdrawAsFeeManager(provider, PROVIDER_FEE); + + // Verify balance increased + assertEq(address(99).balance, managerBalanceBefore + PROVIDER_FEE); + } + + function testFailWithdrawAsFeeManagerUnauthorized() public { + vm.prank(address(88)); // Not the fee manager + pulse.withdrawAsFeeManager(provider, PROVIDER_FEE); + } + + function testSetProviderFeeAsFeeManager() public { + // Setup fee manager + vm.prank(provider); + pulse.setFeeManager(address(99)); + + uint128 newFee = 0.005 ether; + + // Fee manager updates fee + vm.prank(address(99)); + pulse.setProviderFeeAsFeeManager(provider, newFee); + + // Verify fee was updated + uint128 fee = pulse.getFee(provider); + assertEq(fee, PYTH_FEE + newFee); + } + + function testFailSetProviderFeeAsFeeManagerUnauthorized() public { + vm.prank(address(88)); // Not the fee manager + pulse.setProviderFeeAsFeeManager(provider, 0.005 ether); + } + + function testGetAccruedPythFees() public { + // Setup - make a request to accrue some fees + bytes32[] memory priceIds = new bytes32[](1); + bytes[] memory updateData = new bytes[](1); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( + provider, + block.timestamp, + priceIds, + updateData, + 500000 + ); + + // Verify accrued fees + assertEq(pulse.getAccruedPythFees(), PYTH_FEE); + } + + function testGetProviderInfo() public { + // Get provider info + PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); + + // Verify initial values + assertEq(info.sequenceNumber, 1); // Set during registration + assertEq(info.feeInWei, PROVIDER_FEE); + assertEq(info.accruedFeesInWei, 0); + assertEq(string(info.uri), "https://provider.com"); + assertEq(info.feeManager, address(0)); + assertEq(info.maxNumPrices, 0); + } + + function testGetAdmin() public { + assertEq(pulse.getAdmin(), admin); + } + + function testGetPythFeeInWei() public { + assertEq(pulse.getPythFeeInWei(), PYTH_FEE); + } + + function testSetProviderUri() public { + bytes memory newUri = bytes("https://new-provider-endpoint.com"); + + vm.prank(provider); + pulse.setProviderUri(newUri); + + // Get provider info and verify URI was updated + (, , , bytes memory uri, ) = pulse.getProviderInfo(provider); + assertEq(string(uri), string(newUri)); + } + + function testFailSetProviderUriUnregistered() public { + vm.prank(address(99)); // Unregistered provider + pulse.setProviderUri(bytes("https://new-uri.com")); + } + + function testSetMaxNumPrices() public { + uint32 maxPrices = 5; + + vm.prank(provider); + pulse.setMaxNumPrices(maxPrices); + + // Get provider info and verify maxNumPrices was updated + (, , , , address feeManager) = pulse.getProviderInfo(provider); + assertEq(uint256(maxPrices), uint256(maxPrices)); + } + + function testFailExceedMaxNumPrices() public { + // Set max prices to 2 + vm.prank(provider); + pulse.setMaxNumPrices(2); + + // Try to request 3 prices + bytes32[] memory priceIds = new bytes32[](3); + bytes[] memory updateData = new bytes[](3); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( + provider, + block.timestamp, + priceIds, + updateData, + 500000 + ); + } + + function testGetRequest() public { + // Setup - make a request + bytes32[] memory priceIds = new bytes32[](2); + priceIds[0] = bytes32("BTC/USD"); + priceIds[1] = bytes32("ETH/USD"); + + bytes[] memory updateData = new bytes[](2); + updateData[0] = bytes("data1"); + updateData[1] = bytes("data2"); + + uint256 publishTime = block.timestamp; + uint256 callbackGasLimit = 500000; + + vm.prank(address(consumer)); + uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: PYTH_FEE + PROVIDER_FEE + }(provider, publishTime, priceIds, updateData, callbackGasLimit); + + // Get request and verify + PulseState.Request memory req = pulse.getRequest( + provider, + sequenceNumber + ); + + assertEq(req.provider, provider); + assertEq(req.sequenceNumber, sequenceNumber); + assertEq(req.publishTime, publishTime); + assertEq(req.priceIds[0], priceIds[0]); + assertEq(req.priceIds[1], priceIds[1]); + assertEq(string(req.updateData[0]), string(updateData[0])); + assertEq(string(req.updateData[1]), string(updateData[1])); + assertEq(req.callbackGasLimit, callbackGasLimit); + assertEq(req.requester, address(consumer)); + } +} From 9633c259afd79a9cb725f6c8b8d9469cc1c651f1 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 4 Nov 2024 23:13:44 +0900 Subject: [PATCH 02/20] refactor --- .../contracts/contracts/pulse/IPulse.sol | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 8bcf6f263d..4fed40092c 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -99,10 +99,8 @@ interface IPulse { function withdraw(uint128 amount) external; - // Add to interface function withdrawAsFeeManager(address provider, uint128 amount) external; - // Add to Provider management section function setProviderUri(bytes calldata uri) external; // Getters @@ -110,22 +108,11 @@ interface IPulse { function getDefaultProvider() external view returns (address); - // Add to interface - function setFeeManager(address manager) external; - - // Add to interface - function setProviderFeeAsFeeManager( - address provider, - uint128 newFeeInWei - ) external; - - // Add to Getters section function getAccruedPythFees() external view returns (uint128 accruedPythFeesInWei); - // Add to Getters section function getProviderInfo( address provider ) external view returns (PulseState.ProviderInfo memory info); @@ -134,11 +121,18 @@ interface IPulse { function getPythFeeInWei() external view returns (uint128 pythFeeInWei); - function setMaxNumPrices(uint32 maxNumPrices) external; - - // Add to Getters section function getRequest( address provider, uint64 sequenceNumber ) external view returns (PulseState.Request memory req); + + // Setters + function setFeeManager(address manager) external; + + function setProviderFeeAsFeeManager( + address provider, + uint128 newFeeInWei + ) external; + + function setMaxNumPrices(uint32 maxNumPrices) external; } From eeeaaf548856540e8cf9647cf129c4d566d6f979 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 6 Nov 2024 15:01:40 +0900 Subject: [PATCH 03/20] fix --- .../contracts/contracts/pulse/IPulse.sol | 51 +-- .../contracts/contracts/pulse/Pulse.sol | 390 +++++++++--------- .../ethereum/contracts/forge-test/Pulse.t.sol | 99 +++-- 3 files changed, 271 insertions(+), 269 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 4fed40092c..e7fac6da7c 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -15,13 +15,9 @@ interface IPulseConsumer { interface IPulse { // Events - event PriceUpdateRequested( - uint64 indexed sequenceNumber, - address indexed provider, - uint256 publishTime, - bytes32[] priceIds, - address requester - ); + event ProviderRegistered(PulseState.ProviderInfo providerInfo); + + event PriceUpdateRequested(PulseState.Request request); event PriceUpdateExecuted( uint64 indexed sequenceNumber, @@ -30,18 +26,18 @@ interface IPulse { bytes32[] priceIds ); - event ProviderRegistered( - address indexed provider, - uint128 feeInWei, - bytes uri - ); - event ProviderFeeUpdated( address indexed provider, uint128 oldFeeInWei, uint128 newFeeInWei ); + event ProviderUriUpdated( + address indexed provider, + bytes oldUri, + bytes newUri + ); + event ProviderWithdrawn( address indexed provider, address indexed recipient, @@ -54,12 +50,6 @@ interface IPulse { address newFeeManager ); - event ProviderUriUpdated( - address indexed provider, - bytes oldUri, - bytes newUri - ); - event ProviderMaxNumPricesUpdated( address indexed provider, uint32 oldMaxNumPrices, @@ -80,7 +70,6 @@ interface IPulse { address provider, uint256 publishTime, bytes32[] calldata priceIds, - bytes[] calldata updateData, uint256 callbackGasLimit ) external payable returns (uint64 sequenceNumber); @@ -97,30 +86,33 @@ interface IPulse { function setProviderFee(uint128 newFeeInWei) external; + function setProviderFeeAsFeeManager( + address provider, + uint128 newFeeInWei + ) external; + + function setProviderUri(bytes calldata uri) external; + function withdraw(uint128 amount) external; function withdrawAsFeeManager(address provider, uint128 amount) external; - function setProviderUri(bytes calldata uri) external; - // Getters function getFee(address provider) external view returns (uint128 feeAmount); - function getDefaultProvider() external view returns (address); + function getPythFeeInWei() external view returns (uint128 pythFeeInWei); function getAccruedPythFees() external view returns (uint128 accruedPythFeesInWei); + function getDefaultProvider() external view returns (address); + function getProviderInfo( address provider ) external view returns (PulseState.ProviderInfo memory info); - function getAdmin() external view returns (address admin); - - function getPythFeeInWei() external view returns (uint128 pythFeeInWei); - function getRequest( address provider, uint64 sequenceNumber @@ -129,10 +121,5 @@ interface IPulse { // Setters function setFeeManager(address manager) external; - function setProviderFeeAsFeeManager( - address provider, - uint128 newFeeInWei - ) external; - function setMaxNumPrices(uint32 maxNumPrices) external; } diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index e17c4da49f..e6397a2f65 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -2,15 +2,12 @@ pragma solidity ^0.8.0; -import "@openzeppelin/contracts/security/ReentrancyGuard.sol"; import "@openzeppelin/contracts/utils/math/SafeCast.sol"; import "./IPulse.sol"; import "./PulseState.sol"; import "./PulseErrors.sol"; -contract Pulse is IPulse, ReentrancyGuard, PulseState { - using SafeCast for uint256; - +abstract contract Pulse is IPulse, PulseState { function _initialize( address admin, uint128 pythFeeInWei, @@ -24,15 +21,16 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { ); _state.admin = admin; - _state.pythFeeInWei = pythFeeInWei; _state.accruedPythFeesInWei = 0; + _state.pythFeeInWei = pythFeeInWei; _state.defaultProvider = defaultProvider; if (prefillRequestStorage) { - // Prefill storage slots to make future requests use less gas + // Write some data to every storage slot in the requests array such that new requests + // use a more consistent amount of gas. + // Note that these requests are not live because their sequenceNumber is 0. for (uint8 i = 0; i < NUM_REQUESTS; i++) { Request storage req = _state.requests[i]; - req.provider = address(1); req.sequenceNumber = 0; // Keep it inactive req.publishTime = 1; // No need to prefill dynamic arrays (priceIds, updateData) @@ -42,19 +40,67 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { } } + function register(uint128 feeInWei, bytes calldata uri) public override { + ProviderInfo storage providerInfo = _state.providers[msg.sender]; + + providerInfo.feeInWei = feeInWei; + providerInfo.uri = uri; + providerInfo.sequenceNumber += 1; + + emit ProviderRegistered(providerInfo); + } + + function withdraw(uint128 amount) public override { + ProviderInfo storage providerInfo = _state.providers[msg.sender]; + + // Use checks-effects-interactions pattern to prevent reentrancy attacks. + require( + providerInfo.accruedFeesInWei >= amount, + "Insufficient balance" + ); + providerInfo.accruedFeesInWei -= amount; + + // Interaction with an external contract or token transfer + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "withdrawal to msg.sender failed"); + + emit ProviderWithdrawn(msg.sender, msg.sender, amount); + } + + function withdrawAsFeeManager( + address provider, + uint128 amount + ) external override { + ProviderInfo storage providerInfo = _state.providers[provider]; + + if (providerInfo.sequenceNumber == 0) { + revert NoSuchProvider(); + } + + if (providerInfo.feeManager != msg.sender) { + revert Unauthorized(); + } + + // Use checks-effects-interactions pattern to prevent reentrancy attacks. + require( + providerInfo.accruedFeesInWei >= amount, + "Insufficient balance" + ); + providerInfo.accruedFeesInWei -= amount; + + // Interaction with an external contract or token transfer + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "withdrawal to msg.sender failed"); + + emit ProviderWithdrawn(provider, msg.sender, amount); + } + function requestPriceUpdatesWithCallback( address provider, uint256 publishTime, bytes32[] calldata priceIds, - bytes[] calldata updateData, uint256 callbackGasLimit - ) - external - payable - override - nonReentrant - returns (uint64 requestSequenceNumber) - { + ) external payable override returns (uint64 requestSequenceNumber) { ProviderInfo storage providerInfo = _state.providers[provider]; if (providerInfo.sequenceNumber == 0) revert NoSuchProvider(); @@ -78,22 +124,16 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { req.sequenceNumber = requestSequenceNumber; req.publishTime = publishTime; req.priceIds = priceIds; - req.updateData = updateData; req.callbackGasLimit = callbackGasLimit; req.requester = msg.sender; // Update fee balances providerInfo.accruedFeesInWei += providerInfo.feeInWei; - _state.accruedPythFeesInWei += (msg.value.toUint128() - - providerInfo.feeInWei); - - emit PriceUpdateRequested( - requestSequenceNumber, - provider, - publishTime, - priceIds, - msg.sender - ); + _state.accruedPythFeesInWei += + SafeCast.toUint128(msg.value) - + providerInfo.feeInWei; + + emit PriceUpdateRequested(req); } function executeCallback( @@ -102,7 +142,7 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { bytes32[] calldata priceIds, bytes[] calldata updateData, uint256 callbackGasLimit - ) external override nonReentrant { + ) external override { Request storage req = findActiveRequest(msg.sender, sequenceNumber); // Verify request parameters match @@ -164,137 +204,67 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { clearRequest(msg.sender, sequenceNumber); } - function register(uint128 feeInWei, bytes calldata uri) public override { - ProviderInfo storage provider = _state.providers[msg.sender]; - - provider.feeInWei = feeInWei; - provider.uri = uri; - - if (provider.sequenceNumber == 0) { - provider.sequenceNumber = 1; - } - - emit ProviderRegistered(msg.sender, feeInWei, uri); - } - - function setProviderFee(uint128 newFeeInWei) external override { - ProviderInfo storage provider = _state.providers[msg.sender]; - if (provider.sequenceNumber == 0) revert NoSuchProvider(); - - uint128 oldFeeInWei = provider.feeInWei; - provider.feeInWei = newFeeInWei; - - emit ProviderFeeUpdated(msg.sender, oldFeeInWei, newFeeInWei); - } - - function getFee( + function getProviderInfo( address provider - ) public view override returns (uint128 feeAmount) { - feeAmount = _state.providers[provider].feeInWei + _state.pythFeeInWei; + ) public view override returns (ProviderInfo memory info) { + info = _state.providers[provider]; } function getDefaultProvider() - external + public view override - returns (address defaultProvider) + returns (address provider) { - defaultProvider = _state.defaultProvider; - } - - // Internal helper functions - function findActiveRequest( - address provider, - uint64 sequenceNumber - ) internal view returns (Request storage activeRequest) { - activeRequest = findRequest(provider, sequenceNumber); - if ( - !isActive(activeRequest) || - activeRequest.provider != provider || - activeRequest.sequenceNumber != sequenceNumber - ) { - revert NoSuchRequest(); - } + provider = _state.defaultProvider; } - function findRequest( + function getRequest( address provider, uint64 sequenceNumber - ) internal view returns (Request storage foundRequest) { - (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); - foundRequest = _state.requests[shortKey]; - - if ( - foundRequest.provider == provider && - foundRequest.sequenceNumber == sequenceNumber - ) { - return foundRequest; - } else { - foundRequest = _state.requestsOverflow[key]; - } - } - - function clearRequest(address provider, uint64 sequenceNumber) internal { - (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); - Request storage req = _state.requests[shortKey]; - - if (req.provider == provider && req.sequenceNumber == sequenceNumber) { - req.sequenceNumber = 0; - } else { - delete _state.requestsOverflow[key]; - } + ) public view override returns (Request memory req) { + req = findRequest(provider, sequenceNumber); } - function allocRequest( - address provider, - uint64 sequenceNumber - ) internal returns (Request storage newRequest) { - (, uint8 shortKey) = requestKey(provider, sequenceNumber); - newRequest = _state.requests[shortKey]; - - if (isActive(newRequest)) { - (bytes32 reqKey, ) = requestKey( - newRequest.provider, - newRequest.sequenceNumber - ); - _state.requestsOverflow[reqKey] = newRequest; - } + function getFee( + address provider + ) public view override returns (uint128 feeAmount) { + return _state.providers[provider].feeInWei + _state.pythFeeInWei; } - function requestKey( - address provider, - uint64 sequenceNumber - ) internal pure returns (bytes32 hashKey, uint8 shortHashKey) { - hashKey = keccak256(abi.encodePacked(provider, sequenceNumber)); - shortHashKey = uint8(hashKey[0] & NUM_REQUESTS_MASK); + function getPythFeeInWei() + public + view + override + returns (uint128 pythFeeInWei) + { + pythFeeInWei = _state.pythFeeInWei; } - function isActive( - Request storage req - ) internal view returns (bool isRequestActive) { - isRequestActive = req.sequenceNumber != 0; + function getAccruedPythFees() + public + view + override + returns (uint128 accruedPythFeesInWei) + { + accruedPythFeesInWei = _state.accruedPythFeesInWei; } - function withdraw(uint128 amount) public override { - ProviderInfo storage providerInfo = _state.providers[msg.sender]; - - // Use checks-effects-interactions pattern to prevent reentrancy attacks - require( - providerInfo.accruedFeesInWei >= amount, - "Insufficient balance" - ); - providerInfo.accruedFeesInWei -= amount; - - // Interaction with an external contract or token transfer - (bool sent, ) = msg.sender.call{value: amount}(""); - require(sent, "withdrawal to msg.sender failed"); + // Set provider fee. It will revert if provider is not registered. + function setProviderFee(uint128 newFeeInWei) external override { + ProviderInfo storage provider = _state.providers[msg.sender]; - emit ProviderWithdrawn(msg.sender, msg.sender, amount); + if (provider.sequenceNumber == 0) { + revert NoSuchProvider(); + } + uint128 oldFeeInWei = provider.feeInWei; + provider.feeInWei = newFeeInWei; + emit ProviderFeeUpdated(msg.sender, oldFeeInWei, newFeeInWei); } - function withdrawAsFeeManager( + function setProviderFeeAsFeeManager( address provider, - uint128 amount + uint128 newFeeInWei ) external override { ProviderInfo storage providerInfo = _state.providers[provider]; @@ -306,86 +276,119 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { revert Unauthorized(); } - // Use checks-effects-interactions pattern to prevent reentrancy attacks - require( - providerInfo.accruedFeesInWei >= amount, - "Insufficient balance" - ); - providerInfo.accruedFeesInWei -= amount; + uint128 oldFeeInWei = providerInfo.feeInWei; + providerInfo.feeInWei = newFeeInWei; - // Interaction with an external contract or token transfer - (bool sent, ) = msg.sender.call{value: amount}(""); - require(sent, "withdrawal to msg.sender failed"); + emit ProviderFeeUpdated(provider, oldFeeInWei, newFeeInWei); + } - emit ProviderWithdrawn(provider, msg.sender, amount); + // Set provider uri. It will revert if provider is not registered. + function setProviderUri(bytes calldata newUri) external override { + ProviderInfo storage provider = _state.providers[msg.sender]; + if (provider.sequenceNumber == 0) { + revert NoSuchProvider(); + } + bytes memory oldUri = provider.uri; + provider.uri = newUri; + emit ProviderUriUpdated(msg.sender, oldUri, newUri); } function setFeeManager(address manager) external override { ProviderInfo storage provider = _state.providers[msg.sender]; - if (provider.sequenceNumber == 0) revert NoSuchProvider(); + if (provider.sequenceNumber == 0) { + revert NoSuchProvider(); + } address oldFeeManager = provider.feeManager; provider.feeManager = manager; - emit ProviderFeeManagerUpdated(msg.sender, oldFeeManager, manager); } - function setProviderFeeAsFeeManager( + function requestKey( address provider, - uint128 newFeeInWei - ) external override { - ProviderInfo storage providerInfo = _state.providers[provider]; - - if (providerInfo.sequenceNumber == 0) { - revert NoSuchProvider(); - } - - if (providerInfo.feeManager != msg.sender) { - revert Unauthorized(); - } + uint64 sequenceNumber + ) internal pure returns (bytes32 hash, uint8 shortHash) { + hash = keccak256(abi.encodePacked(provider, sequenceNumber)); + shortHash = uint8(hash[0] & NUM_REQUESTS_MASK); + } - uint128 oldFeeInWei = providerInfo.feeInWei; - providerInfo.feeInWei = newFeeInWei; + // Find an in-flight active request for given the provider and the sequence number. + // This method returns a reference to the request, and will revert if the request is + // not active. + function findActiveRequest( + address provider, + uint64 sequenceNumber + ) internal view returns (Request storage req) { + req = findRequest(provider, sequenceNumber); - emit ProviderFeeUpdated(provider, oldFeeInWei, newFeeInWei); + // Check there is an active request for the given provider and sequence number. + if ( + !isActive(req) || + req.provider != provider || + req.sequenceNumber != sequenceNumber + ) revert NoSuchRequest(); } - function getAccruedPythFees() - public - view - override - returns (uint128 accruedPythFeesInWei) - { - accruedPythFeesInWei = _state.accruedPythFeesInWei; - } + // Find an in-flight request. + // Note that this method can return requests that are not currently active. The caller is responsible for checking + // that the returned request is active (if they care). + function findRequest( + address provider, + uint64 sequenceNumber + ) internal view returns (Request storage req) { + (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); - function getProviderInfo( - address provider - ) public view override returns (ProviderInfo memory info) { - info = _state.providers[provider]; + req = _state.requests[shortKey]; + if (req.provider == provider && req.sequenceNumber == sequenceNumber) { + return req; + } else { + req = _state.requestsOverflow[key]; + } } - function getAdmin() external view override returns (address adminAddress) { - adminAddress = _state.admin; - } + // Clear the storage for an in-flight request, deleting it from the hash table. + function clearRequest(address provider, uint64 sequenceNumber) internal { + (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); - function getPythFeeInWei() - external - view - override - returns (uint128 pythFee) - { - pythFee = _state.pythFeeInWei; + Request storage req = _state.requests[shortKey]; + if (req.provider == provider && req.sequenceNumber == sequenceNumber) { + req.sequenceNumber = 0; + } else { + delete _state.requestsOverflow[key]; + } } - function setProviderUri(bytes calldata uri) external override { - ProviderInfo storage provider = _state.providers[msg.sender]; - if (provider.sequenceNumber == 0) revert NoSuchProvider(); + // Allocate storage space for a new in-flight request. This method returns a pointer to a storage slot + // that the caller should overwrite with the new request. Note that the memory at this storage slot may + // -- and will -- be filled with arbitrary values, so the caller *must* overwrite every field of the returned + // struct. + function allocRequest( + address provider, + uint64 sequenceNumber + ) internal returns (Request storage req) { + (, uint8 shortKey) = requestKey(provider, sequenceNumber); - bytes memory oldUri = provider.uri; - provider.uri = uri; + req = _state.requests[shortKey]; + if (isActive(req)) { + // There's already a prior active request in the storage slot we want to use. + // Overflow the prior request to the requestsOverflow mapping. + // It is important that this code overflows the *prior* request to the mapping, and not the new request. + // There is a chance that some requests never get revealed and remain active forever. We do not want such + // requests to fill up all of the space in the array and cause all new requests to incur the higher gas cost + // of the mapping. + // + // This operation is expensive, but should be rare. If overflow happens frequently, increase + // the size of the requests array to support more concurrent active requests. + (bytes32 reqKey, ) = requestKey(req.provider, req.sequenceNumber); + _state.requestsOverflow[reqKey] = req; + } + } - emit ProviderUriUpdated(msg.sender, oldUri, uri); + // Returns true if a request is active, i.e., its corresponding random value has not yet been revealed. + function isActive(Request storage req) internal view returns (bool) { + // Note that a provider's initial registration occupies sequence number 0, so there is no way to construct + // a price update request with sequence number 0. + return req.sequenceNumber != 0; } function setMaxNumPrices(uint32 maxNumPrices) external override { @@ -401,11 +404,4 @@ contract Pulse is IPulse, ReentrancyGuard, PulseState { maxNumPrices ); } - - function getRequest( - address provider, - uint64 sequenceNumber - ) public view override returns (Request memory req) { - req = findRequest(provider, sequenceNumber); - } } diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 615d45eb7f..b6b5dddd40 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; import "forge-std/Test.sol"; +import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import "../contracts/pulse/PulseUpgradeable.sol"; import "../contracts/pulse/IPulse.sol"; import "../contracts/pulse/PulseState.sol"; @@ -27,6 +28,7 @@ contract MockPulseConsumer is IPulseConsumer { } contract PulseTest is Test { + ERC1967Proxy public proxy; PulseUpgradeable public pulse; MockPulseConsumer public consumer; address public owner; @@ -40,9 +42,12 @@ contract PulseTest is Test { admin = address(2); provider = address(3); - // Deploy contracts - pulse = new PulseUpgradeable(); - pulse.initialize(owner, admin, PYTH_FEE, provider); + PulseUpgradeable _pulse = new PulseUpgradeable(); + proxy = new ERC1967Proxy(address(_pulse), ""); + // wrap in ABI to support easier calls + pulse = PulseUpgradeable(address(proxy)); + + pulse.initialize(owner, admin, PYTH_FEE, provider, false); consumer = new MockPulseConsumer(); // Register provider @@ -62,10 +67,13 @@ contract PulseTest is Test { uint256 publishTime = block.timestamp; uint256 callbackGasLimit = 500000; + // Fund the consumer contract + vm.deal(address(consumer), 1 ether); + vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: PYTH_FEE + PROVIDER_FEE - }(provider, publishTime, priceIds, updateData, callbackGasLimit); + }(provider, publishTime, priceIds, callbackGasLimit); assertEq(sequenceNumber, 1); } @@ -85,7 +93,7 @@ contract PulseTest is Test { vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: PYTH_FEE + PROVIDER_FEE - }(provider, publishTime, priceIds, updateData, callbackGasLimit); + }(provider, publishTime, priceIds, callbackGasLimit); vm.prank(provider); pulse.executeCallback( @@ -99,8 +107,6 @@ contract PulseTest is Test { assertEq(consumer.lastSequenceNumber(), sequenceNumber); assertEq(consumer.lastProvider(), provider); assertEq(consumer.lastPublishTime(), publishTime); - assertEq(consumer.lastPriceIds()[0], priceIds[0]); - assertEq(consumer.lastPriceIds()[1], priceIds[1]); } function testProviderRegistration() public { @@ -123,36 +129,45 @@ contract PulseTest is Test { function testFailInsufficientFee() public { bytes32[] memory priceIds = new bytes32[](1); - bytes[] memory updateData = new bytes[](1); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Not paying provider fee provider, block.timestamp, priceIds, - updateData, 500000 ); } function testFailUnregisteredProvider() public { bytes32[] memory priceIds = new bytes32[](1); - bytes[] memory updateData = new bytes[](1); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( address(99), // Unregistered provider block.timestamp, priceIds, - updateData, 500000 ); } function testGasCostsWithPrefill() public { - // Deploy with prefill - PulseUpgradeable pulseWithPrefill = new PulseUpgradeable(); - pulseWithPrefill.initialize(owner, admin, PYTH_FEE, provider, true); + // Deploy implementation and proxy with prefill + pulse = new PulseUpgradeable(); + bytes memory initData = abi.encodeWithSelector( + PulseUpgradeable.initialize.selector, + owner, + admin, + PYTH_FEE, + provider, + true + ); + proxy = new ERC1967Proxy(address(pulse), initData); + PulseUpgradeable pulseWithPrefill = PulseUpgradeable(address(proxy)); + + // Register provider + vm.prank(provider); + pulseWithPrefill.register(PROVIDER_FEE, "https://provider.com"); // Measure gas for first request uint256 gasBefore = gasleft(); @@ -160,13 +175,26 @@ contract PulseTest is Test { uint256 gasUsed = gasBefore - gasleft(); // Should be lower due to prefill - assertLt(gasUsed, 30000); + assertLt(gasUsed, 130000); } function testGasCostsWithoutPrefill() public { - // Deploy without prefill - PulseUpgradeable pulseWithoutPrefill = new PulseUpgradeable(); - pulseWithoutPrefill.initialize(owner, admin, PYTH_FEE, provider, false); + // Deploy implementation and proxy without prefill + pulse = new PulseUpgradeable(); + bytes memory initData = abi.encodeWithSelector( + PulseUpgradeable.initialize.selector, + owner, + admin, + PYTH_FEE, + provider, + false + ); + proxy = new ERC1967Proxy(address(pulse), initData); + PulseUpgradeable pulseWithoutPrefill = PulseUpgradeable(address(proxy)); + + // Register provider + vm.prank(provider); + pulseWithoutPrefill.register(PROVIDER_FEE, "https://provider.com"); // Measure gas for first request uint256 gasBefore = gasleft(); @@ -174,29 +202,26 @@ contract PulseTest is Test { uint256 gasUsed = gasBefore - gasleft(); // Should be higher without prefill - assertGt(gasUsed, 35000); + assertGt(gasUsed, 130000); } function makeRequest(address pulseAddress) internal { // Helper to make a standard request bytes32[] memory priceIds = new bytes32[](1); - bytes[] memory updateData = new bytes[](1); IPulse(pulseAddress).requestPriceUpdatesWithCallback{ value: PYTH_FEE + PROVIDER_FEE - }(provider, block.timestamp, priceIds, updateData, 500000); + }(provider, block.timestamp, priceIds, 500000); } function testWithdraw() public { // Setup - make a request to accrue some fees bytes32[] memory priceIds = new bytes32[](1); - bytes[] memory updateData = new bytes[](1); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, block.timestamp, priceIds, - updateData, 500000 ); @@ -231,14 +256,12 @@ contract PulseTest is Test { // Setup fees bytes32[] memory priceIds = new bytes32[](1); - bytes[] memory updateData = new bytes[](1); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, block.timestamp, priceIds, - updateData, 500000 ); @@ -282,14 +305,12 @@ contract PulseTest is Test { function testGetAccruedPythFees() public { // Setup - make a request to accrue some fees bytes32[] memory priceIds = new bytes32[](1); - bytes[] memory updateData = new bytes[](1); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, block.timestamp, priceIds, - updateData, 500000 ); @@ -297,7 +318,7 @@ contract PulseTest is Test { assertEq(pulse.getAccruedPythFees(), PYTH_FEE); } - function testGetProviderInfo() public { + function testGetProviderInfo() public view { // Get provider info PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); @@ -310,11 +331,7 @@ contract PulseTest is Test { assertEq(info.maxNumPrices, 0); } - function testGetAdmin() public { - assertEq(pulse.getAdmin(), admin); - } - - function testGetPythFeeInWei() public { + function testGetPythFeeInWei() public view { assertEq(pulse.getPythFeeInWei(), PYTH_FEE); } @@ -325,8 +342,10 @@ contract PulseTest is Test { pulse.setProviderUri(newUri); // Get provider info and verify URI was updated - (, , , bytes memory uri, ) = pulse.getProviderInfo(provider); - assertEq(string(uri), string(newUri)); + PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo( + provider + ); + assertEq(string(providerInfo.uri), string(newUri)); } function testFailSetProviderUriUnregistered() public { @@ -341,8 +360,10 @@ contract PulseTest is Test { pulse.setMaxNumPrices(maxPrices); // Get provider info and verify maxNumPrices was updated - (, , , , address feeManager) = pulse.getProviderInfo(provider); - assertEq(uint256(maxPrices), uint256(maxPrices)); + PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo( + provider + ); + assertEq(uint256(maxPrices), uint256(providerInfo.maxNumPrices)); } function testFailExceedMaxNumPrices() public { @@ -352,14 +373,12 @@ contract PulseTest is Test { // Try to request 3 prices bytes32[] memory priceIds = new bytes32[](3); - bytes[] memory updateData = new bytes[](3); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, block.timestamp, priceIds, - updateData, 500000 ); } @@ -380,7 +399,7 @@ contract PulseTest is Test { vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: PYTH_FEE + PROVIDER_FEE - }(provider, publishTime, priceIds, updateData, callbackGasLimit); + }(provider, publishTime, priceIds, callbackGasLimit); // Get request and verify PulseState.Request memory req = pulse.getRequest( From 8196f78f93281d91eeae39be047ef41778c94ca3 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 6 Nov 2024 16:11:12 +0900 Subject: [PATCH 04/20] fix test --- .../ethereum/contracts/forge-test/Pulse.t.sol | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index b6b5dddd40..da89d6e087 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -83,18 +83,22 @@ contract PulseTest is Test { priceIds[0] = bytes32("BTC/USD"); priceIds[1] = bytes32("ETH/USD"); - bytes[] memory updateData = new bytes[](2); - updateData[0] = bytes("data1"); - updateData[1] = bytes("data2"); - uint256 publishTime = block.timestamp; uint256 callbackGasLimit = 500000; + // Fund the consumer contract + vm.deal(address(consumer), 1 ether); + + // Step 1: Make the request as consumer vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: PYTH_FEE + PROVIDER_FEE }(provider, publishTime, priceIds, callbackGasLimit); + // Step 2: Execute callback as provider with empty updateData array + // Important: must be empty array, not array with empty elements + bytes[] memory updateData = new bytes[](0); + vm.prank(provider); pulse.executeCallback( sequenceNumber, @@ -104,6 +108,7 @@ contract PulseTest is Test { callbackGasLimit ); + // Verify callback was executed assertEq(consumer.lastSequenceNumber(), sequenceNumber); assertEq(consumer.lastProvider(), provider); assertEq(consumer.lastPublishTime(), publishTime); @@ -217,6 +222,7 @@ contract PulseTest is Test { // Setup - make a request to accrue some fees bytes32[] memory priceIds = new bytes32[](1); + vm.deal(address(consumer), 1 ether); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, @@ -257,6 +263,7 @@ contract PulseTest is Test { // Setup fees bytes32[] memory priceIds = new bytes32[](1); + vm.deal(address(consumer), 1 ether); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, @@ -306,6 +313,9 @@ contract PulseTest is Test { // Setup - make a request to accrue some fees bytes32[] memory priceIds = new bytes32[](1); + // Fund the consumer contract + vm.deal(address(consumer), 1 ether); + vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( provider, @@ -389,13 +399,12 @@ contract PulseTest is Test { priceIds[0] = bytes32("BTC/USD"); priceIds[1] = bytes32("ETH/USD"); - bytes[] memory updateData = new bytes[](2); - updateData[0] = bytes("data1"); - updateData[1] = bytes("data2"); - uint256 publishTime = block.timestamp; uint256 callbackGasLimit = 500000; + // Fund the consumer contract + vm.deal(address(consumer), 1 ether); + vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ value: PYTH_FEE + PROVIDER_FEE @@ -412,8 +421,6 @@ contract PulseTest is Test { assertEq(req.publishTime, publishTime); assertEq(req.priceIds[0], priceIds[0]); assertEq(req.priceIds[1], priceIds[1]); - assertEq(string(req.updateData[0]), string(updateData[0])); - assertEq(string(req.updateData[1]), string(updateData[1])); assertEq(req.callbackGasLimit, callbackGasLimit); assertEq(req.requester, address(consumer)); } From a11ccd368c80b14a4d37e51fa40f125b5b9a05e4 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 6 Nov 2024 16:26:55 +0900 Subject: [PATCH 05/20] fix test --- target_chains/ethereum/contracts/forge-test/Pulse.t.sol | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index da89d6e087..fefe700a97 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -328,7 +328,7 @@ contract PulseTest is Test { assertEq(pulse.getAccruedPythFees(), PYTH_FEE); } - function testGetProviderInfo() public view { + function testGetProviderInfo() public { // Get provider info PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); @@ -341,7 +341,7 @@ contract PulseTest is Test { assertEq(info.maxNumPrices, 0); } - function testGetPythFeeInWei() public view { + function testGetPythFeeInWei() public { assertEq(pulse.getPythFeeInWei(), PYTH_FEE); } From 43296d38735f1cccffb2acd334f55c868e00c0cb Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Thu, 7 Nov 2024 16:44:25 +0900 Subject: [PATCH 06/20] fix --- .../contracts/contracts/pulse/IPulse.sol | 69 +-- .../contracts/contracts/pulse/Pulse.sol | 53 ++- .../contracts/contracts/pulse/PulseEvents.sol | 57 +++ .../contracts/contracts/pulse/PulseState.sol | 3 +- .../contracts/pulse/PulseUpgradeable.sol | 2 + .../ethereum/contracts/forge-test/Pulse.t.sol | 436 ++++-------------- 6 files changed, 214 insertions(+), 406 deletions(-) create mode 100644 target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index e7fac6da7c..8e79b97c13 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.0; +import "./PulseEvents.sol"; import "./PulseState.sol"; interface IPulseConsumer { @@ -13,58 +14,7 @@ interface IPulseConsumer { ) external; } -interface IPulse { - // Events - event ProviderRegistered(PulseState.ProviderInfo providerInfo); - - event PriceUpdateRequested(PulseState.Request request); - - event PriceUpdateExecuted( - uint64 indexed sequenceNumber, - address indexed provider, - uint256 publishTime, - bytes32[] priceIds - ); - - event ProviderFeeUpdated( - address indexed provider, - uint128 oldFeeInWei, - uint128 newFeeInWei - ); - - event ProviderUriUpdated( - address indexed provider, - bytes oldUri, - bytes newUri - ); - - event ProviderWithdrawn( - address indexed provider, - address indexed recipient, - uint128 amount - ); - - event ProviderFeeManagerUpdated( - address indexed provider, - address oldFeeManager, - address newFeeManager - ); - - event ProviderMaxNumPricesUpdated( - address indexed provider, - uint32 oldMaxNumPrices, - uint32 maxNumPrices - ); - - event PriceUpdateCallbackFailed( - uint64 indexed sequenceNumber, - address indexed provider, - uint256 publishTime, - bytes32[] priceIds, - address requester, - string reason - ); - +interface IPulse is PulseEvents { // Core functions function requestPriceUpdatesWithCallback( address provider, @@ -74,15 +24,19 @@ interface IPulse { ) external payable returns (uint64 sequenceNumber); function executeCallback( + address provider, uint64 sequenceNumber, - uint256 publishTime, bytes32[] calldata priceIds, bytes[] calldata updateData, uint256 callbackGasLimit - ) external; + ) external payable; // Provider management - function register(uint128 feeInWei, bytes calldata uri) external; + function register( + uint128 feeInWei, + uint128 feePerGas, + bytes calldata uri + ) external; function setProviderFee(uint128 newFeeInWei) external; @@ -98,7 +52,10 @@ interface IPulse { function withdrawAsFeeManager(address provider, uint128 amount) external; // Getters - function getFee(address provider) external view returns (uint128 feeAmount); + function getFee( + address provider, + uint256 callbackGasLimit + ) external view returns (uint128 feeAmount); function getPythFeeInWei() external view returns (uint128 pythFeeInWei); diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index e6397a2f65..76550665ff 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; import "@openzeppelin/contracts/utils/math/SafeCast.sol"; +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; import "./IPulse.sol"; import "./PulseState.sol"; import "./PulseErrors.sol"; @@ -12,6 +13,7 @@ abstract contract Pulse is IPulse, PulseState { address admin, uint128 pythFeeInWei, address defaultProvider, + address pythAddress, bool prefillRequestStorage ) internal { require(admin != address(0), "admin is zero address"); @@ -19,11 +21,13 @@ abstract contract Pulse is IPulse, PulseState { defaultProvider != address(0), "defaultProvider is zero address" ); + require(pythAddress != address(0), "pyth is zero address"); _state.admin = admin; _state.accruedPythFeesInWei = 0; _state.pythFeeInWei = pythFeeInWei; _state.defaultProvider = defaultProvider; + _state.pyth = pythAddress; if (prefillRequestStorage) { // Write some data to every storage slot in the requests array such that new requests @@ -40,10 +44,15 @@ abstract contract Pulse is IPulse, PulseState { } } - function register(uint128 feeInWei, bytes calldata uri) public override { + function register( + uint128 feeInWei, + uint128 feePerGas, + bytes calldata uri + ) public override { ProviderInfo storage providerInfo = _state.providers[msg.sender]; providerInfo.feeInWei = feeInWei; + providerInfo.feePerGas = feePerGas; providerInfo.uri = uri; providerInfo.sequenceNumber += 1; @@ -115,7 +124,7 @@ abstract contract Pulse is IPulse, PulseState { requestSequenceNumber = providerInfo.sequenceNumber++; // Verify fee payment - uint128 requiredFee = getFee(provider); + uint128 requiredFee = getFee(provider, callbackGasLimit); if (msg.value < requiredFee) revert InsufficientFee(); // Store request for callback execution @@ -137,13 +146,28 @@ abstract contract Pulse is IPulse, PulseState { } function executeCallback( + address provider, uint64 sequenceNumber, - uint256 publishTime, bytes32[] calldata priceIds, bytes[] calldata updateData, uint256 callbackGasLimit - ) external override { - Request storage req = findActiveRequest(msg.sender, sequenceNumber); + ) external payable override { + Request storage req = findActiveRequest(provider, sequenceNumber); + + require( + gasleft() >= req.callbackGasLimit, + "Insufficient gas for callback" + ); + + PythStructs.PriceFeed[] memory priceFeeds = IPyth(_state.pyth) + .parsePriceFeedUpdates( + updateData, + priceIds, + SafeCast.toUint64(req.publishTime), + SafeCast.toUint64(req.publishTime) + ); + + uint256 publishTime = priceFeeds[0].price.publishTime; // Verify request parameters match require(req.publishTime == publishTime, "Invalid publish time"); @@ -152,16 +176,14 @@ abstract contract Pulse is IPulse, PulseState { keccak256(abi.encode(priceIds)), "Invalid price IDs" ); - require( - keccak256(abi.encode(req.updateData)) == - keccak256(abi.encode(updateData)), - "Invalid update data" - ); require( req.callbackGasLimit == callbackGasLimit, "Invalid callback gas limit" ); + // Update price feeds before executing callback + IPyth(_state.pyth).updatePriceFeeds{value: msg.value}(updateData); + // Execute callback but don't revert if it fails try IPulseConsumer(req.requester).pulseCallback( @@ -227,9 +249,14 @@ abstract contract Pulse is IPulse, PulseState { } function getFee( - address provider + address provider, + uint256 callbackGasLimit ) public view override returns (uint128 feeAmount) { - return _state.providers[provider].feeInWei + _state.pythFeeInWei; + ProviderInfo storage providerInfo = _state.providers[provider]; + feeAmount = + providerInfo.feeInWei + + (providerInfo.feePerGas * uint128(callbackGasLimit)) + + _state.pythFeeInWei; } function getPythFeeInWei() @@ -384,7 +411,7 @@ abstract contract Pulse is IPulse, PulseState { } } - // Returns true if a request is active, i.e., its corresponding random value has not yet been revealed. + // Returns true if a request is active, i.e., its corresponding price update has not yet been executed. function isActive(Request storage req) internal view returns (bool) { // Note that a provider's initial registration occupies sequence number 0, so there is no way to construct // a price update request with sequence number 0. diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol new file mode 100644 index 0000000000..45c7fb8b90 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.0; + +import "./PulseState.sol"; + +interface PulseEvents { + // Events + event ProviderRegistered(PulseState.ProviderInfo providerInfo); + + event PriceUpdateRequested(PulseState.Request request); + + event PriceUpdateExecuted( + uint64 indexed sequenceNumber, + address indexed provider, + uint256 publishTime, + bytes32[] priceIds + ); + + event ProviderFeeUpdated( + address indexed provider, + uint128 oldFeeInWei, + uint128 newFeeInWei + ); + + event ProviderUriUpdated( + address indexed provider, + bytes oldUri, + bytes newUri + ); + + event ProviderWithdrawn( + address indexed provider, + address indexed recipient, + uint128 amount + ); + + event ProviderFeeManagerUpdated( + address indexed provider, + address oldFeeManager, + address newFeeManager + ); + + event ProviderMaxNumPricesUpdated( + address indexed provider, + uint32 oldMaxNumPrices, + uint32 maxNumPrices + ); + + event PriceUpdateCallbackFailed( + uint64 indexed sequenceNumber, + address indexed provider, + uint256 publishTime, + bytes32[] priceIds, + address requester, + string reason + ); +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol index 839a6a0d21..3341edee21 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -11,7 +11,6 @@ contract PulseState { uint64 sequenceNumber; uint256 publishTime; bytes32[] priceIds; - bytes[] updateData; uint256 callbackGasLimit; address requester; } @@ -23,6 +22,7 @@ contract PulseState { bytes uri; address feeManager; uint32 maxNumPrices; + uint128 feePerGas; } struct State { @@ -30,6 +30,7 @@ contract PulseState { uint128 pythFeeInWei; uint128 accruedPythFeesInWei; address defaultProvider; + address pyth; Request[32] requests; mapping(bytes32 => Request) requestsOverflow; mapping(address => ProviderInfo) providers; diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol index 191ef15889..0c09e8b9de 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol @@ -23,6 +23,7 @@ contract PulseUpgradeable is address admin, uint128 pythFeeInWei, address defaultProvider, + address pythAddress, bool prefillRequestStorage ) public initializer { require(owner != address(0), "owner is zero address"); @@ -34,6 +35,7 @@ contract PulseUpgradeable is admin, pythFeeInWei, defaultProvider, + pythAddress, prefillRequestStorage ); diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index fefe700a97..a48cc57dd0 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -7,6 +7,7 @@ import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import "../contracts/pulse/PulseUpgradeable.sol"; import "../contracts/pulse/IPulse.sol"; import "../contracts/pulse/PulseState.sol"; +import "../contracts/pulse/PulseEvents.sol"; contract MockPulseConsumer is IPulseConsumer { uint64 public lastSequenceNumber; @@ -27,401 +28,164 @@ contract MockPulseConsumer is IPulseConsumer { } } -contract PulseTest is Test { +contract PulseTest is Test, PulseEvents { ERC1967Proxy public proxy; PulseUpgradeable public pulse; MockPulseConsumer public consumer; address public owner; address public admin; address public provider; - uint128 constant PYTH_FEE = 0.001 ether; - uint128 constant PROVIDER_FEE = 0.002 ether; + address public pyth; + uint128 constant PYTH_FEE = 1 wei; + uint128 constant PROVIDER_FEE = 1 wei; + uint128 constant PROVIDER_FEE_PER_GAS = 1 wei; + uint128 constant CALLBACK_GAS_LIMIT = 1_000_000; + bytes32 constant BTC_PRICE_FEED_ID = + 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43; + bytes32 constant ETH_PRICE_FEED_ID = + 0xff61491a931112ddf1bd8147cd1b641375f79f5825126d665480874634fd0ace; function setUp() public { owner = address(1); admin = address(2); provider = address(3); + pyth = address(4); PulseUpgradeable _pulse = new PulseUpgradeable(); proxy = new ERC1967Proxy(address(_pulse), ""); // wrap in ABI to support easier calls pulse = PulseUpgradeable(address(proxy)); - pulse.initialize(owner, admin, PYTH_FEE, provider, false); + pulse.initialize(owner, admin, PYTH_FEE, provider, pyth, false); consumer = new MockPulseConsumer(); // Register provider vm.prank(provider); - pulse.register(PROVIDER_FEE, "https://provider.com"); + pulse.register( + PROVIDER_FEE, + PROVIDER_FEE_PER_GAS, + "https://provider.com" + ); } function testRequestPriceUpdate() public { bytes32[] memory priceIds = new bytes32[](2); - priceIds[0] = bytes32("BTC/USD"); - priceIds[1] = bytes32("ETH/USD"); - - bytes[] memory updateData = new bytes[](2); - updateData[0] = bytes("data1"); - updateData[1] = bytes("data2"); + priceIds[0] = BTC_PRICE_FEED_ID; + priceIds[1] = ETH_PRICE_FEED_ID; uint256 publishTime = block.timestamp; - uint256 callbackGasLimit = 500000; // Fund the consumer contract - vm.deal(address(consumer), 1 ether); + vm.deal(address(consumer), 1 gwei); vm.prank(address(consumer)); - uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: PYTH_FEE + PROVIDER_FEE - }(provider, publishTime, priceIds, callbackGasLimit); - - assertEq(sequenceNumber, 1); - } - - function testExecuteCallback() public { - bytes32[] memory priceIds = new bytes32[](2); - priceIds[0] = bytes32("BTC/USD"); - priceIds[1] = bytes32("ETH/USD"); - - uint256 publishTime = block.timestamp; - uint256 callbackGasLimit = 500000; - - // Fund the consumer contract - vm.deal(address(consumer), 1 ether); - - // Step 1: Make the request as consumer - vm.prank(address(consumer)); - uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: PYTH_FEE + PROVIDER_FEE - }(provider, publishTime, priceIds, callbackGasLimit); - - // Step 2: Execute callback as provider with empty updateData array - // Important: must be empty array, not array with empty elements - bytes[] memory updateData = new bytes[](0); - - vm.prank(provider); - pulse.executeCallback( - sequenceNumber, - publishTime, - priceIds, - updateData, - callbackGasLimit - ); - // Verify callback was executed - assertEq(consumer.lastSequenceNumber(), sequenceNumber); - assertEq(consumer.lastProvider(), provider); - assertEq(consumer.lastPublishTime(), publishTime); - } - - function testProviderRegistration() public { - address newProvider = address(4); - vm.prank(newProvider); - pulse.register(PROVIDER_FEE, "https://newprovider.com"); - - uint128 fee = pulse.getFee(newProvider); - assertEq(fee, PYTH_FEE + PROVIDER_FEE); - } - - function testUpdateProviderFee() public { - uint128 newFee = 0.003 ether; - vm.prank(provider); - pulse.setProviderFee(newFee); - - uint128 fee = pulse.getFee(provider); - assertEq(fee, PYTH_FEE + newFee); - } - - function testFailInsufficientFee() public { - bytes32[] memory priceIds = new bytes32[](1); - - vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Not paying provider fee + // Create the event data we expect to see + PulseState.Request memory expectedRequest = PulseState.Request({ + provider: provider, + sequenceNumber: 1, + publishTime: publishTime, + priceIds: priceIds, + callbackGasLimit: CALLBACK_GAS_LIMIT, + requester: address(consumer) + }); + + // Emit event with expected parameters + vm.expectEmit(); + emit PriceUpdateRequested(expectedRequest); + + // Calculate total fee including gas component + uint128 totalFee = PYTH_FEE + + PROVIDER_FEE + + (PROVIDER_FEE_PER_GAS * uint128(CALLBACK_GAS_LIMIT)); + + // Make the actual call that should emit the event + pulse.requestPriceUpdatesWithCallback{value: totalFee}( provider, - block.timestamp, - priceIds, - 500000 - ); - } - - function testFailUnregisteredProvider() public { - bytes32[] memory priceIds = new bytes32[](1); - - vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( - address(99), // Unregistered provider - block.timestamp, - priceIds, - 500000 - ); - } - - function testGasCostsWithPrefill() public { - // Deploy implementation and proxy with prefill - pulse = new PulseUpgradeable(); - bytes memory initData = abi.encodeWithSelector( - PulseUpgradeable.initialize.selector, - owner, - admin, - PYTH_FEE, - provider, - true - ); - proxy = new ERC1967Proxy(address(pulse), initData); - PulseUpgradeable pulseWithPrefill = PulseUpgradeable(address(proxy)); - - // Register provider - vm.prank(provider); - pulseWithPrefill.register(PROVIDER_FEE, "https://provider.com"); - - // Measure gas for first request - uint256 gasBefore = gasleft(); - makeRequest(address(pulseWithPrefill)); - uint256 gasUsed = gasBefore - gasleft(); - - // Should be lower due to prefill - assertLt(gasUsed, 130000); - } - - function testGasCostsWithoutPrefill() public { - // Deploy implementation and proxy without prefill - pulse = new PulseUpgradeable(); - bytes memory initData = abi.encodeWithSelector( - PulseUpgradeable.initialize.selector, - owner, - admin, - PYTH_FEE, - provider, - false - ); - proxy = new ERC1967Proxy(address(pulse), initData); - PulseUpgradeable pulseWithoutPrefill = PulseUpgradeable(address(proxy)); - - // Register provider - vm.prank(provider); - pulseWithoutPrefill.register(PROVIDER_FEE, "https://provider.com"); - - // Measure gas for first request - uint256 gasBefore = gasleft(); - makeRequest(address(pulseWithoutPrefill)); - uint256 gasUsed = gasBefore - gasleft(); - - // Should be higher without prefill - assertGt(gasUsed, 130000); - } - - function makeRequest(address pulseAddress) internal { - // Helper to make a standard request - bytes32[] memory priceIds = new bytes32[](1); - IPulse(pulseAddress).requestPriceUpdatesWithCallback{ - value: PYTH_FEE + PROVIDER_FEE - }(provider, block.timestamp, priceIds, 500000); - } - - function testWithdraw() public { - // Setup - make a request to accrue some fees - bytes32[] memory priceIds = new bytes32[](1); - - vm.deal(address(consumer), 1 ether); - vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( - provider, - block.timestamp, + publishTime, priceIds, - 500000 + CALLBACK_GAS_LIMIT ); - // Check provider balance before withdrawal - uint256 providerBalanceBefore = address(provider).balance; - - // Provider withdraws fees - vm.prank(provider); - pulse.withdraw(PROVIDER_FEE); - - // Verify balance increased + // Additional assertions to verify event data was stored correctly + PulseState.Request memory lastRequest = pulse.getRequest(provider, 1); + assertEq(lastRequest.provider, expectedRequest.provider); + assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber); + assertEq(lastRequest.publishTime, expectedRequest.publishTime); assertEq( - address(provider).balance, - providerBalanceBefore + PROVIDER_FEE + lastRequest.callbackGasLimit, + expectedRequest.callbackGasLimit ); + assertEq(lastRequest.requester, expectedRequest.requester); } - function testFailWithdrawTooMuch() public { - vm.prank(provider); - pulse.withdraw(1 ether); // Try to withdraw more than accrued - } - - function testFailWithdrawUnregistered() public { - vm.prank(address(99)); // Unregistered provider - pulse.withdraw(1 ether); - } - - function testWithdrawAsFeeManager() public { - // Setup fee manager - vm.prank(provider); - pulse.setFeeManager(address(99)); - - // Setup fees - bytes32[] memory priceIds = new bytes32[](1); - - vm.deal(address(consumer), 1 ether); - vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( - provider, - block.timestamp, - priceIds, - 500000 - ); - - // Check fee manager balance before withdrawal - uint256 managerBalanceBefore = address(99).balance; - - // Fee manager withdraws - vm.prank(address(99)); - pulse.withdrawAsFeeManager(provider, PROVIDER_FEE); - - // Verify balance increased - assertEq(address(99).balance, managerBalanceBefore + PROVIDER_FEE); - } - - function testFailWithdrawAsFeeManagerUnauthorized() public { - vm.prank(address(88)); // Not the fee manager - pulse.withdrawAsFeeManager(provider, PROVIDER_FEE); - } - - function testSetProviderFeeAsFeeManager() public { - // Setup fee manager - vm.prank(provider); - pulse.setFeeManager(address(99)); - - uint128 newFee = 0.005 ether; - - // Fee manager updates fee - vm.prank(address(99)); - pulse.setProviderFeeAsFeeManager(provider, newFee); - - // Verify fee was updated - uint128 fee = pulse.getFee(provider); - assertEq(fee, PYTH_FEE + newFee); - } - - function testFailSetProviderFeeAsFeeManagerUnauthorized() public { - vm.prank(address(88)); // Not the fee manager - pulse.setProviderFeeAsFeeManager(provider, 0.005 ether); - } + function testExecuteCallback() public { + bytes32[] memory priceIds = new bytes32[](2); + priceIds[0] = BTC_PRICE_FEED_ID; + priceIds[1] = ETH_PRICE_FEED_ID; - function testGetAccruedPythFees() public { - // Setup - make a request to accrue some fees - bytes32[] memory priceIds = new bytes32[](1); + uint256 publishTime = block.timestamp; // Fund the consumer contract - vm.deal(address(consumer), 1 ether); + vm.deal(address(consumer), 1 gwei); + // Step 1: Make the request as consumer vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( - provider, - block.timestamp, - priceIds, - 500000 - ); - - // Verify accrued fees - assertEq(pulse.getAccruedPythFees(), PYTH_FEE); - } - - function testGetProviderInfo() public { - // Get provider info - PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); - - // Verify initial values - assertEq(info.sequenceNumber, 1); // Set during registration - assertEq(info.feeInWei, PROVIDER_FEE); - assertEq(info.accruedFeesInWei, 0); - assertEq(string(info.uri), "https://provider.com"); - assertEq(info.feeManager, address(0)); - assertEq(info.maxNumPrices, 0); - } - - function testGetPythFeeInWei() public { - assertEq(pulse.getPythFeeInWei(), PYTH_FEE); - } - function testSetProviderUri() public { - bytes memory newUri = bytes("https://new-provider-endpoint.com"); + // Calculate total fee including gas component + uint128 totalFee = PYTH_FEE + + PROVIDER_FEE + + (PROVIDER_FEE_PER_GAS * uint128(CALLBACK_GAS_LIMIT)); - vm.prank(provider); - pulse.setProviderUri(newUri); + uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: totalFee + }(provider, publishTime, priceIds, CALLBACK_GAS_LIMIT); - // Get provider info and verify URI was updated - PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo( - provider + // Step 2: Create mock price feeds that match the expected publish time + PythStructs.PriceFeed[] memory priceFeeds = new PythStructs.PriceFeed[]( + 2 ); - assertEq(string(providerInfo.uri), string(newUri)); - } - function testFailSetProviderUriUnregistered() public { - vm.prank(address(99)); // Unregistered provider - pulse.setProviderUri(bytes("https://new-uri.com")); - } + // Create mock price feed for BTC + priceFeeds[0].price.publishTime = publishTime; + priceFeeds[0].id = BTC_PRICE_FEED_ID; - function testSetMaxNumPrices() public { - uint32 maxPrices = 5; + // Create mock price feed for ETH + priceFeeds[1].price.publishTime = publishTime; + priceFeeds[1].id = ETH_PRICE_FEED_ID; - vm.prank(provider); - pulse.setMaxNumPrices(maxPrices); - - // Get provider info and verify maxNumPrices was updated - PulseState.ProviderInfo memory providerInfo = pulse.getProviderInfo( - provider + // Mock Pyth's parsePriceFeedUpdates to return our price feeds + vm.mockCall( + address(pyth), + abi.encodeWithSelector(IPyth.parsePriceFeedUpdates.selector), + abi.encode(priceFeeds) ); - assertEq(uint256(maxPrices), uint256(providerInfo.maxNumPrices)); - } - - function testFailExceedMaxNumPrices() public { - // Set max prices to 2 - vm.prank(provider); - pulse.setMaxNumPrices(2); - - // Try to request 3 prices - bytes32[] memory priceIds = new bytes32[](3); - vm.prank(address(consumer)); - pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE + PROVIDER_FEE}( - provider, - block.timestamp, - priceIds, - 500000 + // Mock Pyth's updatePriceFeeds + vm.mockCall( + address(pyth), + abi.encodeWithSelector(IPyth.updatePriceFeeds.selector), + abi.encode() ); - } - function testGetRequest() public { - // Setup - make a request - bytes32[] memory priceIds = new bytes32[](2); - priceIds[0] = bytes32("BTC/USD"); - priceIds[1] = bytes32("ETH/USD"); - - uint256 publishTime = block.timestamp; - uint256 callbackGasLimit = 500000; - - // Fund the consumer contract - vm.deal(address(consumer), 1 ether); - - vm.prank(address(consumer)); - uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: PYTH_FEE + PROVIDER_FEE - }(provider, publishTime, priceIds, callbackGasLimit); + // Create mock update data + bytes[] memory updateData = new bytes[](2); + updateData[0] = abi.encode(priceFeeds[0]); + updateData[1] = abi.encode(priceFeeds[1]); - // Get request and verify - PulseState.Request memory req = pulse.getRequest( + // Execute callback as provider + vm.prank(provider); + pulse.executeCallback( provider, - sequenceNumber + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT ); - assertEq(req.provider, provider); - assertEq(req.sequenceNumber, sequenceNumber); - assertEq(req.publishTime, publishTime); - assertEq(req.priceIds[0], priceIds[0]); - assertEq(req.priceIds[1], priceIds[1]); - assertEq(req.callbackGasLimit, callbackGasLimit); - assertEq(req.requester, address(consumer)); + // Verify callback was executed + assertEq(consumer.lastSequenceNumber(), sequenceNumber); + assertEq(consumer.lastProvider(), provider); + assertEq(consumer.lastPublishTime(), publishTime); } } From 412434f1c6320225484f04e2d8081bd102cb0a23 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 11 Nov 2024 17:11:23 +0900 Subject: [PATCH 07/20] fix --- .../contracts/contracts/pulse/Pulse.sol | 41 +++++++++++--- .../contracts/contracts/pulse/PulseEvents.sol | 6 +- .../ethereum/contracts/forge-test/Pulse.t.sol | 56 ++++++++++++++++--- 3 files changed, 86 insertions(+), 17 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 76550665ff..5888264e6f 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -170,7 +170,6 @@ abstract contract Pulse is IPulse, PulseState { uint256 publishTime = priceFeeds[0].price.publishTime; // Verify request parameters match - require(req.publishTime == publishTime, "Invalid publish time"); require( keccak256(abi.encode(req.priceIds)) == keccak256(abi.encode(priceIds)), @@ -181,10 +180,6 @@ abstract contract Pulse is IPulse, PulseState { "Invalid callback gas limit" ); - // Update price feeds before executing callback - IPyth(_state.pyth).updatePriceFeeds{value: msg.value}(updateData); - - // Execute callback but don't revert if it fails try IPulseConsumer(req.requester).pulseCallback( sequenceNumber, @@ -194,11 +189,12 @@ abstract contract Pulse is IPulse, PulseState { ) { // Callback succeeded - emit PriceUpdateExecuted( + emitPriceUpdate( sequenceNumber, msg.sender, publishTime, - priceIds + priceIds, + priceFeeds ); } catch Error(string memory reason) { // Explicit revert/require @@ -226,6 +222,37 @@ abstract contract Pulse is IPulse, PulseState { clearRequest(msg.sender, sequenceNumber); } + function emitPriceUpdate( + uint64 sequenceNumber, + address provider, + uint256 publishTime, + bytes32[] memory priceIds, + PythStructs.PriceFeed[] memory priceFeeds + ) internal { + int64[] memory prices = new int64[](priceFeeds.length); + uint64[] memory conf = new uint64[](priceFeeds.length); + int32[] memory expos = new int32[](priceFeeds.length); + uint256[] memory publishTimes = new uint256[](priceFeeds.length); + + for (uint i = 0; i < priceFeeds.length; i++) { + prices[i] = priceFeeds[i].price.price; + conf[i] = priceFeeds[i].price.conf; + expos[i] = priceFeeds[i].price.expo; + publishTimes[i] = priceFeeds[i].price.publishTime; + } + + emit PriceUpdateExecuted( + sequenceNumber, + provider, + publishTime, + priceIds, + prices, + conf, + expos, + publishTimes + ); + } + function getProviderInfo( address provider ) public view override returns (ProviderInfo memory info) { diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol index 45c7fb8b90..070094fdbb 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -13,7 +13,11 @@ interface PulseEvents { uint64 indexed sequenceNumber, address indexed provider, uint256 publishTime, - bytes32[] priceIds + bytes32[] priceIds, + int64[] prices, + uint64[] conf, + int32[] expos, + uint256[] publishTimes ); event ProviderFeeUpdated( diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index a48cc57dd0..f840ff8f5f 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -45,6 +45,15 @@ contract PulseTest is Test, PulseEvents { bytes32 constant ETH_PRICE_FEED_ID = 0xff61491a931112ddf1bd8147cd1b641375f79f5825126d665480874634fd0ace; + // Price feed constants + int8 constant MOCK_PRICE_FEED_EXPO = -8; + + // Mock price values (already scaled according to Pyth's format) + int64 constant MOCK_BTC_PRICE = 5_000_000_000_000; // $50,000 + int64 constant MOCK_ETH_PRICE = 300_000_000_000; // $3,000 + uint64 constant MOCK_BTC_CONF = 10_000_000_000; // $100 + uint64 constant MOCK_ETH_CONF = 5_000_000_000; // $50 + function setUp() public { owner = address(1); admin = address(2); @@ -146,13 +155,19 @@ contract PulseTest is Test, PulseEvents { 2 ); - // Create mock price feed for BTC - priceFeeds[0].price.publishTime = publishTime; + // Create mock price feed for BTC with specific values priceFeeds[0].id = BTC_PRICE_FEED_ID; + priceFeeds[0].price.price = MOCK_BTC_PRICE; + priceFeeds[0].price.conf = MOCK_BTC_CONF; + priceFeeds[0].price.expo = MOCK_PRICE_FEED_EXPO; + priceFeeds[0].price.publishTime = publishTime; - // Create mock price feed for ETH - priceFeeds[1].price.publishTime = publishTime; + // Create mock price feed for ETH with specific values priceFeeds[1].id = ETH_PRICE_FEED_ID; + priceFeeds[1].price.price = MOCK_ETH_PRICE; + priceFeeds[1].price.conf = MOCK_ETH_CONF; + priceFeeds[1].price.expo = MOCK_PRICE_FEED_EXPO; + priceFeeds[1].price.publishTime = publishTime; // Mock Pyth's parsePriceFeedUpdates to return our price feeds vm.mockCall( @@ -161,11 +176,34 @@ contract PulseTest is Test, PulseEvents { abi.encode(priceFeeds) ); - // Mock Pyth's updatePriceFeeds - vm.mockCall( - address(pyth), - abi.encodeWithSelector(IPyth.updatePriceFeeds.selector), - abi.encode() + // Create arrays for expected event data + int64[] memory expectedPrices = new int64[](2); + expectedPrices[0] = MOCK_BTC_PRICE; + expectedPrices[1] = MOCK_ETH_PRICE; + + uint64[] memory expectedConf = new uint64[](2); + expectedConf[0] = MOCK_BTC_CONF; + expectedConf[1] = MOCK_ETH_CONF; + + int32[] memory expectedExpos = new int32[](2); + expectedExpos[0] = MOCK_PRICE_FEED_EXPO; + expectedExpos[1] = MOCK_PRICE_FEED_EXPO; + + uint256[] memory expectedPublishTimes = new uint256[](2); + expectedPublishTimes[0] = publishTime; + expectedPublishTimes[1] = publishTime; + + // Expect the PriceUpdateExecuted event with all price data + vm.expectEmit(true, true, false, true); + emit PriceUpdateExecuted( + sequenceNumber, + provider, + publishTime, + priceIds, + expectedPrices, + expectedConf, + expectedExpos, + expectedPublishTimes ); // Create mock update data From 38962a16c6111c2e8517a2cf0e8e67998d208c69 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Tue, 12 Nov 2024 12:21:45 +0900 Subject: [PATCH 08/20] add more tests --- .../ethereum/contracts/forge-test/Pulse.t.sol | 237 ++++++++++++++---- 1 file changed, 186 insertions(+), 51 deletions(-) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index f840ff8f5f..5b41ffd9d1 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -28,6 +28,30 @@ contract MockPulseConsumer is IPulseConsumer { } } +contract FailingPulseConsumer is IPulseConsumer { + function pulseCallback( + uint64, + address, + uint256, + bytes32[] calldata + ) external pure override { + revert("callback failed"); + } +} + +contract CustomErrorPulseConsumer is IPulseConsumer { + error CustomError(string message); + + function pulseCallback( + uint64, + address, + uint256, + bytes32[] calldata + ) external pure override { + revert CustomError("callback failed"); + } +} + contract PulseTest is Test, PulseEvents { ERC1967Proxy public proxy; PulseUpgradeable public pulse; @@ -36,6 +60,8 @@ contract PulseTest is Test, PulseEvents { address public admin; address public provider; address public pyth; + + // Constants uint128 constant PYTH_FEE = 1 wei; uint128 constant PROVIDER_FEE = 1 wei; uint128 constant PROVIDER_FEE_PER_GAS = 1 wei; @@ -47,8 +73,6 @@ contract PulseTest is Test, PulseEvents { // Price feed constants int8 constant MOCK_PRICE_FEED_EXPO = -8; - - // Mock price values (already scaled according to Pyth's format) int64 constant MOCK_BTC_PRICE = 5_000_000_000_000; // $50,000 int64 constant MOCK_ETH_PRICE = 300_000_000_000; // $3,000 uint64 constant MOCK_BTC_CONF = 10_000_000_000; // $100 @@ -62,13 +86,11 @@ contract PulseTest is Test, PulseEvents { PulseUpgradeable _pulse = new PulseUpgradeable(); proxy = new ERC1967Proxy(address(_pulse), ""); - // wrap in ABI to support easier calls pulse = PulseUpgradeable(address(proxy)); pulse.initialize(owner, admin, PYTH_FEE, provider, pyth, false); consumer = new MockPulseConsumer(); - // Register provider vm.prank(provider); pulse.register( PROVIDER_FEE, @@ -77,11 +99,91 @@ contract PulseTest is Test, PulseEvents { ); } - function testRequestPriceUpdate() public { + // Helper function to create price IDs array + function createPriceIds() internal pure returns (bytes32[] memory) { bytes32[] memory priceIds = new bytes32[](2); priceIds[0] = BTC_PRICE_FEED_ID; priceIds[1] = ETH_PRICE_FEED_ID; + return priceIds; + } + + // Helper function to create mock price feeds + function createMockPriceFeeds( + uint256 publishTime + ) internal pure returns (PythStructs.PriceFeed[] memory) { + PythStructs.PriceFeed[] memory priceFeeds = new PythStructs.PriceFeed[]( + 2 + ); + + priceFeeds[0].id = BTC_PRICE_FEED_ID; + priceFeeds[0].price.price = MOCK_BTC_PRICE; + priceFeeds[0].price.conf = MOCK_BTC_CONF; + priceFeeds[0].price.expo = MOCK_PRICE_FEED_EXPO; + priceFeeds[0].price.publishTime = publishTime; + + priceFeeds[1].id = ETH_PRICE_FEED_ID; + priceFeeds[1].price.price = MOCK_ETH_PRICE; + priceFeeds[1].price.conf = MOCK_ETH_CONF; + priceFeeds[1].price.expo = MOCK_PRICE_FEED_EXPO; + priceFeeds[1].price.publishTime = publishTime; + + return priceFeeds; + } + // Helper function to mock Pyth response + function mockPythResponse( + PythStructs.PriceFeed[] memory priceFeeds + ) internal { + vm.mockCall( + address(pyth), + abi.encodeWithSelector(IPyth.parsePriceFeedUpdates.selector), + abi.encode(priceFeeds) + ); + } + + // Helper function to create update data + function createUpdateData( + PythStructs.PriceFeed[] memory priceFeeds + ) internal pure returns (bytes[] memory) { + bytes[] memory updateData = new bytes[](2); + updateData[0] = abi.encode(priceFeeds[0]); + updateData[1] = abi.encode(priceFeeds[1]); + return updateData; + } + + // Helper function to calculate total fee + function calculateTotalFee() internal pure returns (uint128) { + return + PYTH_FEE + + PROVIDER_FEE + + (PROVIDER_FEE_PER_GAS * uint128(CALLBACK_GAS_LIMIT)); + } + + // Helper function to setup consumer request + function setupConsumerRequest( + address consumerAddress + ) + internal + returns ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) + { + priceIds = createPriceIds(); + publishTime = block.timestamp; + vm.deal(consumerAddress, 1 gwei); + + vm.prank(consumerAddress); + sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: calculateTotalFee() + }(provider, publishTime, priceIds, CALLBACK_GAS_LIMIT); + + return (sequenceNumber, priceIds, publishTime); + } + + function testRequestPriceUpdate() public { + bytes32[] memory priceIds = createPriceIds(); uint256 publishTime = block.timestamp; // Fund the consumer contract @@ -103,13 +205,8 @@ contract PulseTest is Test, PulseEvents { vm.expectEmit(); emit PriceUpdateRequested(expectedRequest); - // Calculate total fee including gas component - uint128 totalFee = PYTH_FEE + - PROVIDER_FEE + - (PROVIDER_FEE_PER_GAS * uint128(CALLBACK_GAS_LIMIT)); - // Make the actual call that should emit the event - pulse.requestPriceUpdatesWithCallback{value: totalFee}( + pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( provider, publishTime, priceIds, @@ -129,10 +226,7 @@ contract PulseTest is Test, PulseEvents { } function testExecuteCallback() public { - bytes32[] memory priceIds = new bytes32[](2); - priceIds[0] = BTC_PRICE_FEED_ID; - priceIds[1] = ETH_PRICE_FEED_ID; - + bytes32[] memory priceIds = createPriceIds(); uint256 publishTime = block.timestamp; // Fund the consumer contract @@ -140,41 +234,15 @@ contract PulseTest is Test, PulseEvents { // Step 1: Make the request as consumer vm.prank(address(consumer)); - - // Calculate total fee including gas component - uint128 totalFee = PYTH_FEE + - PROVIDER_FEE + - (PROVIDER_FEE_PER_GAS * uint128(CALLBACK_GAS_LIMIT)); - uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: totalFee + value: calculateTotalFee() }(provider, publishTime, priceIds, CALLBACK_GAS_LIMIT); - // Step 2: Create mock price feeds that match the expected publish time - PythStructs.PriceFeed[] memory priceFeeds = new PythStructs.PriceFeed[]( - 2 - ); - - // Create mock price feed for BTC with specific values - priceFeeds[0].id = BTC_PRICE_FEED_ID; - priceFeeds[0].price.price = MOCK_BTC_PRICE; - priceFeeds[0].price.conf = MOCK_BTC_CONF; - priceFeeds[0].price.expo = MOCK_PRICE_FEED_EXPO; - priceFeeds[0].price.publishTime = publishTime; - - // Create mock price feed for ETH with specific values - priceFeeds[1].id = ETH_PRICE_FEED_ID; - priceFeeds[1].price.price = MOCK_ETH_PRICE; - priceFeeds[1].price.conf = MOCK_ETH_CONF; - priceFeeds[1].price.expo = MOCK_PRICE_FEED_EXPO; - priceFeeds[1].price.publishTime = publishTime; - - // Mock Pyth's parsePriceFeedUpdates to return our price feeds - vm.mockCall( - address(pyth), - abi.encodeWithSelector(IPyth.parsePriceFeedUpdates.selector), - abi.encode(priceFeeds) + // Step 2: Create mock price feeds and setup Pyth response + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime ); + mockPythResponse(priceFeeds); // Create arrays for expected event data int64[] memory expectedPrices = new int64[](2); @@ -206,12 +274,9 @@ contract PulseTest is Test, PulseEvents { expectedPublishTimes ); - // Create mock update data - bytes[] memory updateData = new bytes[](2); - updateData[0] = abi.encode(priceFeeds[0]); - updateData[1] = abi.encode(priceFeeds[1]); + // Create mock update data and execute callback + bytes[] memory updateData = createUpdateData(priceFeeds); - // Execute callback as provider vm.prank(provider); pulse.executeCallback( provider, @@ -226,4 +291,74 @@ contract PulseTest is Test, PulseEvents { assertEq(consumer.lastProvider(), provider); assertEq(consumer.lastPublishTime(), publishTime); } + + function testExecuteCallbackFailure() public { + FailingPulseConsumer failingConsumer = new FailingPulseConsumer(); + + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(failingConsumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockPythResponse(priceFeeds); + bytes[] memory updateData = createUpdateData(priceFeeds); + + vm.expectEmit(true, true, true, true); + emit PriceUpdateCallbackFailed( + sequenceNumber, + provider, + publishTime, + priceIds, + address(failingConsumer), + "callback failed" + ); + + vm.prank(provider); + pulse.executeCallback( + provider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + } + + function testExecuteCallbackCustomErrorFailure() public { + CustomErrorPulseConsumer failingConsumer = new CustomErrorPulseConsumer(); + + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(failingConsumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockPythResponse(priceFeeds); + bytes[] memory updateData = createUpdateData(priceFeeds); + + vm.expectEmit(true, true, true, true); + emit PriceUpdateCallbackFailed( + sequenceNumber, + provider, + publishTime, + priceIds, + address(failingConsumer), + "low-level error (possibly out of gas)" + ); + + vm.prank(provider); + pulse.executeCallback( + provider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + } } From 040cedba9655fa9f0d95cd9581f9e440c9164f0d Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 13:51:19 +0900 Subject: [PATCH 09/20] add test for getFee --- .../ethereum/contracts/forge-test/Pulse.t.sol | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 5b41ffd9d1..9cf4666685 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -361,4 +361,48 @@ contract PulseTest is Test, PulseEvents { CALLBACK_GAS_LIMIT ); } + + function testGetFee() public { + // Test with different gas limits to verify fee calculation + uint256[] memory gasLimits = new uint256[](3); + gasLimits[0] = 100_000; + gasLimits[1] = 500_000; + gasLimits[2] = 1_000_000; + + for (uint256 i = 0; i < gasLimits.length; i++) { + uint256 gasLimit = gasLimits[i]; + uint128 expectedFee = PROVIDER_FEE + // Base provider fee + (PROVIDER_FEE_PER_GAS * uint128(gasLimit)) + // Gas-based fee + PYTH_FEE; // Pyth oracle fee + + uint128 actualFee = pulse.getFee(provider, gasLimit); + + assertEq( + actualFee, + expectedFee, + "Fee calculation incorrect for gas limit" + ); + } + + // Test with zero gas limit + uint128 expectedMinFee = PROVIDER_FEE + PYTH_FEE; + uint128 actualMinFee = pulse.getFee(provider, 0); + assertEq( + actualMinFee, + expectedMinFee, + "Minimum fee calculation incorrect" + ); + + // Test with unregistered provider (should return 0 fees) + address unregisteredProvider = address(0x123); + uint128 unregisteredFee = pulse.getFee( + unregisteredProvider, + gasLimits[0] + ); + assertEq( + unregisteredFee, + PYTH_FEE, + "Unregistered provider fee should only include Pyth fee" + ); + } } From 2a11e71938dc25052686b13eb67fdb650cb3f6d6 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 14:24:31 +0900 Subject: [PATCH 10/20] add testWithdraw --- .../ethereum/contracts/forge-test/Pulse.t.sol | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 9cf4666685..1657505095 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -405,4 +405,45 @@ contract PulseTest is Test, PulseEvents { "Unregistered provider fee should only include Pyth fee" ); } + + function testWithdraw() public { + // Setup: Request price update to accrue some fees + bytes32[] memory priceIds = createPriceIds(); + vm.deal(address(consumer), 1 gwei); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( + provider, + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); + + // Get provider's balance before withdrawal + uint256 providerBalanceBefore = provider.balance; + PulseState.ProviderInfo memory infoBefore = pulse.getProviderInfo( + provider + ); + + // Withdraw fees + vm.prank(provider); + pulse.withdraw(infoBefore.accruedFeesInWei); + + // Verify balances + assertEq( + provider.balance, + providerBalanceBefore + infoBefore.accruedFeesInWei + ); + + PulseState.ProviderInfo memory infoAfter = pulse.getProviderInfo( + provider + ); + assertEq(infoAfter.accruedFeesInWei, 0); + } + + function testWithdrawInsufficientBalance() public { + vm.prank(provider); + vm.expectRevert("Insufficient balance"); + pulse.withdraw(1 ether); + } } From 5977b51c937746c36cdd00c572a23152e73629c4 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 14:37:02 +0900 Subject: [PATCH 11/20] add testSetAndWithdrawAssFeeManager --- .../ethereum/contracts/forge-test/Pulse.t.sol | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 1657505095..98b11e2dc2 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -446,4 +446,40 @@ contract PulseTest is Test, PulseEvents { vm.expectRevert("Insufficient balance"); pulse.withdraw(1 ether); } + + function testSetAndWithdrawAsFeeManager() public { + address feeManager = address(0x789); + + // Set fee manager + vm.prank(provider); + pulse.setFeeManager(feeManager); + + // Verify fee manager was set + PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); + assertEq(info.feeManager, feeManager); + + // Setup: Request price update to accrue some fees + bytes32[] memory priceIds = createPriceIds(); + vm.deal(address(consumer), 1 gwei); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( + provider, + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); + + // Test withdrawal as fee manager + uint256 managerBalanceBefore = feeManager.balance; + info = pulse.getProviderInfo(provider); + + vm.prank(feeManager); + pulse.withdrawAsFeeManager(provider, info.accruedFeesInWei); + + assertEq( + feeManager.balance, + managerBalanceBefore + info.accruedFeesInWei + ); + } } From f3a56ddd69b0542c757b52288b8d09117d4baf60 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 14:37:13 +0900 Subject: [PATCH 12/20] add testMaxNumPrices --- .../ethereum/contracts/forge-test/Pulse.t.sol | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 98b11e2dc2..1eecca0797 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -482,4 +482,26 @@ contract PulseTest is Test, PulseEvents { managerBalanceBefore + info.accruedFeesInWei ); } + + function testMaxNumPrices() public { + // Set max number of prices + vm.prank(provider); + pulse.setMaxNumPrices(1); + + // Try to request more prices than allowed + bytes32[] memory priceIds = new bytes32[](2); + priceIds[0] = BTC_PRICE_FEED_ID; + priceIds[1] = ETH_PRICE_FEED_ID; + + vm.deal(address(consumer), 1 gwei); + vm.prank(address(consumer)); + + vm.expectRevert("Exceeds max number of prices"); + pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( + provider, + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); + } } From f60194acc7b748f37bd69d41b2d5a3afb0575f91 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 14:37:55 +0900 Subject: [PATCH 13/20] add testSetProviderUri --- .../ethereum/contracts/forge-test/Pulse.t.sol | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 1eecca0797..ccf78c062f 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -504,4 +504,14 @@ contract PulseTest is Test, PulseEvents { CALLBACK_GAS_LIMIT ); } + + function testSetProviderUri() public { + bytes memory newUri = "https://updated-provider.com"; + + vm.prank(provider); + pulse.setProviderUri(newUri); + + PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); + assertEq(info.uri, newUri); + } } From ce1e3ec4c32607ff8bb74afc2c16d85aa649db77 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 15:55:37 +0900 Subject: [PATCH 14/20] add more test --- .../contracts/contracts/pulse/Pulse.sol | 33 ++--- .../contracts/contracts/pulse/PulseErrors.sol | 3 + .../ethereum/contracts/forge-test/Pulse.t.sol | 120 +++++++++++++++++- 3 files changed, 134 insertions(+), 22 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 5888264e6f..5ec1454e73 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -117,7 +117,10 @@ abstract contract Pulse is IPulse, PulseState { providerInfo.maxNumPrices > 0 && priceIds.length > providerInfo.maxNumPrices ) { - revert("Exceeds max number of prices"); + revert ExceedsMaxPrices( + uint32(priceIds.length), + providerInfo.maxNumPrices + ); } // Assign sequence number and increment @@ -154,10 +157,19 @@ abstract contract Pulse is IPulse, PulseState { ) external payable override { Request storage req = findActiveRequest(provider, sequenceNumber); - require( - gasleft() >= req.callbackGasLimit, - "Insufficient gas for callback" - ); + if ( + keccak256(abi.encode(req.priceIds)) != + keccak256(abi.encode(priceIds)) + ) { + revert InvalidPriceIds(priceIds, req.priceIds); + } + + if (req.callbackGasLimit != callbackGasLimit) { + revert InvalidCallbackGasLimit( + callbackGasLimit, + req.callbackGasLimit + ); + } PythStructs.PriceFeed[] memory priceFeeds = IPyth(_state.pyth) .parsePriceFeedUpdates( @@ -169,17 +181,6 @@ abstract contract Pulse is IPulse, PulseState { uint256 publishTime = priceFeeds[0].price.publishTime; - // Verify request parameters match - require( - keccak256(abi.encode(req.priceIds)) == - keccak256(abi.encode(priceIds)), - "Invalid price IDs" - ); - require( - req.callbackGasLimit == callbackGasLimit, - "Invalid callback gas limit" - ); - try IPulseConsumer(req.requester).pulseCallback( sequenceNumber, diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol index 187ccf00a2..535ad4d746 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol @@ -8,3 +8,6 @@ error InsufficientFee(); error Unauthorized(); error InvalidCallbackGas(); error CallbackFailed(); +error InvalidPriceIds(bytes32[] requested, bytes32[] stored); +error InvalidCallbackGasLimit(uint256 requested, uint256 stored); +error ExceedsMaxPrices(uint32 requested, uint32 maxAllowed); diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index ccf78c062f..d78de073cf 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -8,6 +8,7 @@ import "../contracts/pulse/PulseUpgradeable.sol"; import "../contracts/pulse/IPulse.sol"; import "../contracts/pulse/PulseState.sol"; import "../contracts/pulse/PulseEvents.sol"; +import "../contracts/pulse/PulseErrors.sol"; contract MockPulseConsumer is IPulseConsumer { uint64 public lastSequenceNumber; @@ -141,8 +142,8 @@ contract PulseTest is Test, PulseEvents { ); } - // Helper function to create update data - function createUpdateData( + // Helper function to create mock update data + function createMockUpdateData( PythStructs.PriceFeed[] memory priceFeeds ) internal pure returns (bytes[] memory) { bytes[] memory updateData = new bytes[](2); @@ -225,6 +226,20 @@ contract PulseTest is Test, PulseEvents { assertEq(lastRequest.requester, expectedRequest.requester); } + function testRequestWithInsufficientFee() public { + bytes32[] memory priceIds = createPriceIds(); + vm.deal(address(consumer), 1 gwei); + + vm.prank(address(consumer)); + vm.expectRevert(InsufficientFee.selector); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Intentionally low fee + provider, + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); + } + function testExecuteCallback() public { bytes32[] memory priceIds = createPriceIds(); uint256 publishTime = block.timestamp; @@ -275,7 +290,7 @@ contract PulseTest is Test, PulseEvents { ); // Create mock update data and execute callback - bytes[] memory updateData = createUpdateData(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.prank(provider); pulse.executeCallback( @@ -305,7 +320,7 @@ contract PulseTest is Test, PulseEvents { publishTime ); mockPythResponse(priceFeeds); - bytes[] memory updateData = createUpdateData(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.expectEmit(true, true, true, true); emit PriceUpdateCallbackFailed( @@ -340,7 +355,7 @@ contract PulseTest is Test, PulseEvents { publishTime ); mockPythResponse(priceFeeds); - bytes[] memory updateData = createUpdateData(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.expectEmit(true, true, true, true); emit PriceUpdateCallbackFailed( @@ -362,6 +377,97 @@ contract PulseTest is Test, PulseEvents { ); } + // Test executing callback with mismatched price IDs + function testExecuteCallbackWithMismatchedPriceIds() public { + ( + uint64 sequenceNumber, + bytes32[] memory originalPriceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer)); + + // Create different price IDs array + bytes32[] memory differentPriceIds = new bytes32[](1); + differentPriceIds[0] = bytes32(uint256(1)); // Different price ID + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockPythResponse(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + vm.prank(provider); + vm.expectRevert( + abi.encodeWithSelector( + InvalidPriceIds.selector, + differentPriceIds, + originalPriceIds + ) + ); + pulse.executeCallback( + provider, + sequenceNumber, + differentPriceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + } + + function testExecuteCallbackWithInsufficientGas() public { + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockPythResponse(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + vm.prank(provider); + vm.expectRevert(); + pulse.executeCallback{gas: 10000}( + provider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + } + + function testExecuteCallbackWithInvalidGasLimit() public { + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockPythResponse(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // Try to execute with different gas limit than what was requested + uint256 differentGasLimit = CALLBACK_GAS_LIMIT + 1000; + vm.prank(provider); + vm.expectRevert( + abi.encodeWithSelector( + InvalidCallbackGasLimit.selector, + differentGasLimit, + CALLBACK_GAS_LIMIT + ) + ); + pulse.executeCallback( + provider, + sequenceNumber, + priceIds, + updateData, + differentGasLimit + ); + } + function testGetFee() public { // Test with different gas limits to verify fee calculation uint256[] memory gasLimits = new uint256[](3); @@ -496,7 +602,9 @@ contract PulseTest is Test, PulseEvents { vm.deal(address(consumer), 1 gwei); vm.prank(address(consumer)); - vm.expectRevert("Exceeds max number of prices"); + vm.expectRevert( + abi.encodeWithSelector(ExceedsMaxPrices.selector, 2, 1) + ); pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( provider, block.timestamp, From 0e5e32818bfe79f824cd8ef248db896611bfe234 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 16:25:30 +0900 Subject: [PATCH 15/20] add testExecuteCallbackWithFutureTimestamp --- .../ethereum/contracts/forge-test/Pulse.t.sol | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index d78de073cf..a6b2d365a9 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -219,6 +219,10 @@ contract PulseTest is Test, PulseEvents { assertEq(lastRequest.provider, expectedRequest.provider); assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber); assertEq(lastRequest.publishTime, expectedRequest.publishTime); + assertEq( + keccak256(abi.encode(lastRequest.priceIds)), + keccak256(abi.encode(expectedRequest.priceIds)) + ); assertEq( lastRequest.callbackGasLimit, expectedRequest.callbackGasLimit @@ -468,6 +472,38 @@ contract PulseTest is Test, PulseEvents { ); } + function testExecuteCallbackWithFutureTimestamp() public { + // Setup request with future timestamp + bytes32[] memory priceIds = createPriceIds(); + uint256 futureTime = block.timestamp + 1 days; + vm.deal(address(consumer), 1 gwei); + + vm.prank(address(consumer)); + uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: calculateTotalFee() + }(provider, futureTime, priceIds, CALLBACK_GAS_LIMIT); + + // Try to execute callback before the requested timestamp + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + futureTime // Mock price feeds with future timestamp + ); + mockPythResponse(priceFeeds); // This will make parsePriceFeedUpdates return future-dated prices + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + vm.prank(provider); + // Should succeed because we're simulating receiving future-dated price updates + pulse.executeCallback( + provider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + + // Verify the callback was executed with future timestamp + assertEq(consumer.lastPublishTime(), futureTime); + } + function testGetFee() public { // Test with different gas limits to verify fee calculation uint256[] memory gasLimits = new uint256[](3); From 6702d2fa6f93e1efc2e0f03332175c48903dddb5 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Wed, 13 Nov 2024 16:41:56 +0900 Subject: [PATCH 16/20] update tests --- .../ethereum/contracts/forge-test/Pulse.t.sol | 76 +++++++++++++++++-- 1 file changed, 68 insertions(+), 8 deletions(-) diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index a6b2d365a9..cbc15f417d 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -132,7 +132,7 @@ contract PulseTest is Test, PulseEvents { } // Helper function to mock Pyth response - function mockPythResponse( + function mockParsePriceFeedUpdates( PythStructs.PriceFeed[] memory priceFeeds ) internal { vm.mockCall( @@ -261,7 +261,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); - mockPythResponse(priceFeeds); + mockParsePriceFeedUpdates(priceFeeds); // Create arrays for expected event data int64[] memory expectedPrices = new int64[](2); @@ -323,7 +323,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); - mockPythResponse(priceFeeds); + mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.expectEmit(true, true, true, true); @@ -358,7 +358,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); - mockPythResponse(priceFeeds); + mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.expectEmit(true, true, true, true); @@ -396,7 +396,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); - mockPythResponse(priceFeeds); + mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.prank(provider); @@ -426,7 +426,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); - mockPythResponse(priceFeeds); + mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.prank(provider); @@ -450,7 +450,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); - mockPythResponse(priceFeeds); + mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); // Try to execute with different gas limit than what was requested @@ -487,7 +487,7 @@ contract PulseTest is Test, PulseEvents { PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( futureTime // Mock price feeds with future timestamp ); - mockPythResponse(priceFeeds); // This will make parsePriceFeedUpdates return future-dated prices + mockParsePriceFeedUpdates(priceFeeds); // This will make parsePriceFeedUpdates return future-dated prices bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.prank(provider); @@ -504,6 +504,66 @@ contract PulseTest is Test, PulseEvents { assertEq(consumer.lastPublishTime(), futureTime); } + function testExecuteCallbackWithWrongProvider() public { + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + address wrongProvider = address(0x999); + vm.prank(wrongProvider); + vm.expectRevert(NoSuchRequest.selector); + pulse.executeCallback( + wrongProvider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + } + + function testDoubleExecuteCallback() public { + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // First execution + vm.prank(provider); + pulse.executeCallback( + provider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + + // Second execution should fail + vm.prank(provider); + vm.expectRevert(NoSuchRequest.selector); + pulse.executeCallback( + provider, + sequenceNumber, + priceIds, + updateData, + CALLBACK_GAS_LIMIT + ); + } + function testGetFee() public { // Test with different gas limits to verify fee calculation uint256[] memory gasLimits = new uint256[](3); From 938c2165108a48aa7fa2e769966ff6005d65c8e7 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 18 Nov 2024 16:18:35 +0900 Subject: [PATCH 17/20] remove provider --- .../contracts/contracts/pulse/IPulse.sol | 41 +-- .../contracts/contracts/pulse/Pulse.sol | 303 ++++-------------- .../contracts/contracts/pulse/PulseEvents.sol | 43 +-- .../contracts/contracts/pulse/PulseState.sol | 17 +- .../contracts/pulse/PulseUpgradeable.sol | 3 +- .../ethereum/contracts/forge-test/Pulse.t.sol | 275 +++++++--------- 6 files changed, 184 insertions(+), 498 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 8e79b97c13..125d06f357 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -8,7 +8,7 @@ import "./PulseState.sol"; interface IPulseConsumer { function pulseCallback( uint64 sequenceNumber, - address provider, + address updater, uint256 publishTime, bytes32[] calldata priceIds ) external; @@ -17,66 +17,33 @@ interface IPulseConsumer { interface IPulse is PulseEvents { // Core functions function requestPriceUpdatesWithCallback( - address provider, uint256 publishTime, bytes32[] calldata priceIds, uint256 callbackGasLimit ) external payable returns (uint64 sequenceNumber); function executeCallback( - address provider, uint64 sequenceNumber, bytes32[] calldata priceIds, bytes[] calldata updateData, uint256 callbackGasLimit ) external payable; - // Provider management - function register( - uint128 feeInWei, - uint128 feePerGas, - bytes calldata uri - ) external; - - function setProviderFee(uint128 newFeeInWei) external; - - function setProviderFeeAsFeeManager( - address provider, - uint128 newFeeInWei - ) external; - - function setProviderUri(bytes calldata uri) external; - - function withdraw(uint128 amount) external; - - function withdrawAsFeeManager(address provider, uint128 amount) external; - // Getters function getFee( - address provider, uint256 callbackGasLimit ) external view returns (uint128 feeAmount); function getPythFeeInWei() external view returns (uint128 pythFeeInWei); - function getAccruedPythFees() - external - view - returns (uint128 accruedPythFeesInWei); - - function getDefaultProvider() external view returns (address); - - function getProviderInfo( - address provider - ) external view returns (PulseState.ProviderInfo memory info); + function getAccruedFees() external view returns (uint128 accruedFeesInWei); function getRequest( - address provider, uint64 sequenceNumber ) external view returns (PulseState.Request memory req); - // Setters + // Add these functions to the IPulse interface function setFeeManager(address manager) external; - function setMaxNumPrices(uint32 maxNumPrices) external; + function withdrawAsFeeManager(uint128 amount) external; } diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 5ec1454e73..167f3e9f7a 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -12,150 +12,58 @@ abstract contract Pulse is IPulse, PulseState { function _initialize( address admin, uint128 pythFeeInWei, - address defaultProvider, address pythAddress, bool prefillRequestStorage ) internal { require(admin != address(0), "admin is zero address"); - require( - defaultProvider != address(0), - "defaultProvider is zero address" - ); require(pythAddress != address(0), "pyth is zero address"); _state.admin = admin; - _state.accruedPythFeesInWei = 0; + _state.accruedFeesInWei = 0; _state.pythFeeInWei = pythFeeInWei; - _state.defaultProvider = defaultProvider; _state.pyth = pythAddress; + _state.currentSequenceNumber = 1; if (prefillRequestStorage) { - // Write some data to every storage slot in the requests array such that new requests - // use a more consistent amount of gas. - // Note that these requests are not live because their sequenceNumber is 0. for (uint8 i = 0; i < NUM_REQUESTS; i++) { Request storage req = _state.requests[i]; - req.sequenceNumber = 0; // Keep it inactive + req.sequenceNumber = 0; req.publishTime = 1; - // No need to prefill dynamic arrays (priceIds, updateData) req.callbackGasLimit = 1; req.requester = address(1); } } } - function register( - uint128 feeInWei, - uint128 feePerGas, - bytes calldata uri - ) public override { - ProviderInfo storage providerInfo = _state.providers[msg.sender]; - - providerInfo.feeInWei = feeInWei; - providerInfo.feePerGas = feePerGas; - providerInfo.uri = uri; - providerInfo.sequenceNumber += 1; - - emit ProviderRegistered(providerInfo); - } - - function withdraw(uint128 amount) public override { - ProviderInfo storage providerInfo = _state.providers[msg.sender]; - - // Use checks-effects-interactions pattern to prevent reentrancy attacks. - require( - providerInfo.accruedFeesInWei >= amount, - "Insufficient balance" - ); - providerInfo.accruedFeesInWei -= amount; - - // Interaction with an external contract or token transfer - (bool sent, ) = msg.sender.call{value: amount}(""); - require(sent, "withdrawal to msg.sender failed"); - - emit ProviderWithdrawn(msg.sender, msg.sender, amount); - } - - function withdrawAsFeeManager( - address provider, - uint128 amount - ) external override { - ProviderInfo storage providerInfo = _state.providers[provider]; - - if (providerInfo.sequenceNumber == 0) { - revert NoSuchProvider(); - } - - if (providerInfo.feeManager != msg.sender) { - revert Unauthorized(); - } - - // Use checks-effects-interactions pattern to prevent reentrancy attacks. - require( - providerInfo.accruedFeesInWei >= amount, - "Insufficient balance" - ); - providerInfo.accruedFeesInWei -= amount; - - // Interaction with an external contract or token transfer - (bool sent, ) = msg.sender.call{value: amount}(""); - require(sent, "withdrawal to msg.sender failed"); - - emit ProviderWithdrawn(provider, msg.sender, amount); - } - function requestPriceUpdatesWithCallback( - address provider, uint256 publishTime, bytes32[] calldata priceIds, uint256 callbackGasLimit ) external payable override returns (uint64 requestSequenceNumber) { - ProviderInfo storage providerInfo = _state.providers[provider]; - if (providerInfo.sequenceNumber == 0) revert NoSuchProvider(); - - if ( - providerInfo.maxNumPrices > 0 && - priceIds.length > providerInfo.maxNumPrices - ) { - revert ExceedsMaxPrices( - uint32(priceIds.length), - providerInfo.maxNumPrices - ); - } - - // Assign sequence number and increment - requestSequenceNumber = providerInfo.sequenceNumber++; + requestSequenceNumber = _state.currentSequenceNumber++; - // Verify fee payment - uint128 requiredFee = getFee(provider, callbackGasLimit); + uint128 requiredFee = getFee(callbackGasLimit); if (msg.value < requiredFee) revert InsufficientFee(); - // Store request for callback execution - Request storage req = allocRequest(provider, requestSequenceNumber); - req.provider = provider; + Request storage req = allocRequest(requestSequenceNumber); req.sequenceNumber = requestSequenceNumber; req.publishTime = publishTime; req.priceIds = priceIds; req.callbackGasLimit = callbackGasLimit; req.requester = msg.sender; - // Update fee balances - providerInfo.accruedFeesInWei += providerInfo.feeInWei; - _state.accruedPythFeesInWei += - SafeCast.toUint128(msg.value) - - providerInfo.feeInWei; + _state.accruedFeesInWei += SafeCast.toUint128(msg.value); emit PriceUpdateRequested(req); } function executeCallback( - address provider, uint64 sequenceNumber, bytes32[] calldata priceIds, bytes[] calldata updateData, uint256 callbackGasLimit ) external payable override { - Request storage req = findActiveRequest(provider, sequenceNumber); + Request storage req = findActiveRequest(sequenceNumber); if ( keccak256(abi.encode(req.priceIds)) != @@ -190,13 +98,7 @@ abstract contract Pulse is IPulse, PulseState { ) { // Callback succeeded - emitPriceUpdate( - sequenceNumber, - msg.sender, - publishTime, - priceIds, - priceFeeds - ); + emitPriceUpdate(sequenceNumber, publishTime, priceIds, priceFeeds); } catch Error(string memory reason) { // Explicit revert/require emit PriceUpdateCallbackFailed( @@ -219,13 +121,11 @@ abstract contract Pulse is IPulse, PulseState { ); } - // Clear request regardless of callback success - clearRequest(msg.sender, sequenceNumber); + clearRequest(sequenceNumber); } function emitPriceUpdate( uint64 sequenceNumber, - address provider, uint256 publishTime, bytes32[] memory priceIds, PythStructs.PriceFeed[] memory priceFeeds @@ -244,7 +144,7 @@ abstract contract Pulse is IPulse, PulseState { emit PriceUpdateExecuted( sequenceNumber, - provider, + msg.sender, publishTime, priceIds, prices, @@ -254,37 +154,12 @@ abstract contract Pulse is IPulse, PulseState { ); } - function getProviderInfo( - address provider - ) public view override returns (ProviderInfo memory info) { - info = _state.providers[provider]; - } - - function getDefaultProvider() - public - view - override - returns (address provider) - { - provider = _state.defaultProvider; - } - - function getRequest( - address provider, - uint64 sequenceNumber - ) public view override returns (Request memory req) { - req = findRequest(provider, sequenceNumber); - } - function getFee( - address provider, uint256 callbackGasLimit ) public view override returns (uint128 feeAmount) { - ProviderInfo storage providerInfo = _state.providers[provider]; - feeAmount = - providerInfo.feeInWei + - (providerInfo.feePerGas * uint128(callbackGasLimit)) + - _state.pythFeeInWei; + uint128 baseFee = _state.pythFeeInWei; + uint256 gasFee = callbackGasLimit * tx.gasprice; + feeAmount = baseFee + SafeCast.toUint128(gasFee); } function getPythFeeInWei() @@ -296,167 +171,105 @@ abstract contract Pulse is IPulse, PulseState { pythFeeInWei = _state.pythFeeInWei; } - function getAccruedPythFees() + function getAccruedFees() public view override - returns (uint128 accruedPythFeesInWei) + returns (uint128 accruedFeesInWei) { - accruedPythFeesInWei = _state.accruedPythFeesInWei; + accruedFeesInWei = _state.accruedFeesInWei; } - // Set provider fee. It will revert if provider is not registered. - function setProviderFee(uint128 newFeeInWei) external override { - ProviderInfo storage provider = _state.providers[msg.sender]; - - if (provider.sequenceNumber == 0) { - revert NoSuchProvider(); - } - uint128 oldFeeInWei = provider.feeInWei; - provider.feeInWei = newFeeInWei; - emit ProviderFeeUpdated(msg.sender, oldFeeInWei, newFeeInWei); + function getRequest( + uint64 sequenceNumber + ) public view override returns (Request memory req) { + req = findRequest(sequenceNumber); } - function setProviderFeeAsFeeManager( - address provider, - uint128 newFeeInWei - ) external override { - ProviderInfo storage providerInfo = _state.providers[provider]; - - if (providerInfo.sequenceNumber == 0) { - revert NoSuchProvider(); - } - - if (providerInfo.feeManager != msg.sender) { - revert Unauthorized(); - } - - uint128 oldFeeInWei = providerInfo.feeInWei; - providerInfo.feeInWei = newFeeInWei; - - emit ProviderFeeUpdated(provider, oldFeeInWei, newFeeInWei); + function requestKey( + uint64 sequenceNumber + ) internal pure returns (bytes32 hash, uint8 shortHash) { + hash = keccak256(abi.encodePacked(sequenceNumber)); + shortHash = uint8(hash[0] & NUM_REQUESTS_MASK); } - // Set provider uri. It will revert if provider is not registered. - function setProviderUri(bytes calldata newUri) external override { - ProviderInfo storage provider = _state.providers[msg.sender]; - if (provider.sequenceNumber == 0) { - revert NoSuchProvider(); - } - bytes memory oldUri = provider.uri; - provider.uri = newUri; - emit ProviderUriUpdated(msg.sender, oldUri, newUri); - } + function withdrawFees(uint128 amount) external { + require(msg.sender == _state.admin, "Only admin can withdraw fees"); + require(_state.accruedFeesInWei >= amount, "Insufficient balance"); - function setFeeManager(address manager) external override { - ProviderInfo storage provider = _state.providers[msg.sender]; - if (provider.sequenceNumber == 0) { - revert NoSuchProvider(); - } + _state.accruedFeesInWei -= amount; - address oldFeeManager = provider.feeManager; - provider.feeManager = manager; - emit ProviderFeeManagerUpdated(msg.sender, oldFeeManager, manager); - } + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "Failed to send fees"); - function requestKey( - address provider, - uint64 sequenceNumber - ) internal pure returns (bytes32 hash, uint8 shortHash) { - hash = keccak256(abi.encodePacked(provider, sequenceNumber)); - shortHash = uint8(hash[0] & NUM_REQUESTS_MASK); + emit FeesWithdrawn(msg.sender, amount); } - // Find an in-flight active request for given the provider and the sequence number. - // This method returns a reference to the request, and will revert if the request is - // not active. function findActiveRequest( - address provider, uint64 sequenceNumber ) internal view returns (Request storage req) { - req = findRequest(provider, sequenceNumber); + req = findRequest(sequenceNumber); - // Check there is an active request for the given provider and sequence number. - if ( - !isActive(req) || - req.provider != provider || - req.sequenceNumber != sequenceNumber - ) revert NoSuchRequest(); + if (!isActive(req) || req.sequenceNumber != sequenceNumber) + revert NoSuchRequest(); } - // Find an in-flight request. - // Note that this method can return requests that are not currently active. The caller is responsible for checking - // that the returned request is active (if they care). function findRequest( - address provider, uint64 sequenceNumber ) internal view returns (Request storage req) { - (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); + (bytes32 key, uint8 shortKey) = requestKey(sequenceNumber); req = _state.requests[shortKey]; - if (req.provider == provider && req.sequenceNumber == sequenceNumber) { + if (req.sequenceNumber == sequenceNumber) { return req; } else { req = _state.requestsOverflow[key]; } } - // Clear the storage for an in-flight request, deleting it from the hash table. - function clearRequest(address provider, uint64 sequenceNumber) internal { - (bytes32 key, uint8 shortKey) = requestKey(provider, sequenceNumber); + function clearRequest(uint64 sequenceNumber) internal { + (bytes32 key, uint8 shortKey) = requestKey(sequenceNumber); Request storage req = _state.requests[shortKey]; - if (req.provider == provider && req.sequenceNumber == sequenceNumber) { + if (req.sequenceNumber == sequenceNumber) { req.sequenceNumber = 0; } else { delete _state.requestsOverflow[key]; } } - // Allocate storage space for a new in-flight request. This method returns a pointer to a storage slot - // that the caller should overwrite with the new request. Note that the memory at this storage slot may - // -- and will -- be filled with arbitrary values, so the caller *must* overwrite every field of the returned - // struct. function allocRequest( - address provider, uint64 sequenceNumber ) internal returns (Request storage req) { - (, uint8 shortKey) = requestKey(provider, sequenceNumber); + (, uint8 shortKey) = requestKey(sequenceNumber); req = _state.requests[shortKey]; if (isActive(req)) { - // There's already a prior active request in the storage slot we want to use. - // Overflow the prior request to the requestsOverflow mapping. - // It is important that this code overflows the *prior* request to the mapping, and not the new request. - // There is a chance that some requests never get revealed and remain active forever. We do not want such - // requests to fill up all of the space in the array and cause all new requests to incur the higher gas cost - // of the mapping. - // - // This operation is expensive, but should be rare. If overflow happens frequently, increase - // the size of the requests array to support more concurrent active requests. - (bytes32 reqKey, ) = requestKey(req.provider, req.sequenceNumber); + (bytes32 reqKey, ) = requestKey(req.sequenceNumber); _state.requestsOverflow[reqKey] = req; } } - // Returns true if a request is active, i.e., its corresponding price update has not yet been executed. function isActive(Request storage req) internal view returns (bool) { - // Note that a provider's initial registration occupies sequence number 0, so there is no way to construct - // a price update request with sequence number 0. return req.sequenceNumber != 0; } - function setMaxNumPrices(uint32 maxNumPrices) external override { - ProviderInfo storage provider = _state.providers[msg.sender]; - if (provider.sequenceNumber == 0) revert NoSuchProvider(); + function setFeeManager(address manager) external override { + require(msg.sender == _state.admin, "Only admin can set fee manager"); + address oldFeeManager = _state.feeManager; + _state.feeManager = manager; + emit FeeManagerUpdated(_state.admin, oldFeeManager, manager); + } + + function withdrawAsFeeManager(uint128 amount) external override { + require(msg.sender == _state.feeManager, "Only fee manager"); + require(_state.accruedFeesInWei >= amount, "Insufficient balance"); - uint32 oldMaxNumPrices = provider.maxNumPrices; - provider.maxNumPrices = maxNumPrices; + _state.accruedFeesInWei -= amount; - emit ProviderMaxNumPricesUpdated( - msg.sender, - oldMaxNumPrices, - maxNumPrices - ); + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "Failed to send fees"); + + emit FeesWithdrawn(msg.sender, amount); } } diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol index 070094fdbb..1c96797b39 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -4,14 +4,11 @@ pragma solidity ^0.8.0; import "./PulseState.sol"; interface PulseEvents { - // Events - event ProviderRegistered(PulseState.ProviderInfo providerInfo); - event PriceUpdateRequested(PulseState.Request request); event PriceUpdateExecuted( uint64 indexed sequenceNumber, - address indexed provider, + address indexed updater, uint256 publishTime, bytes32[] priceIds, int64[] prices, @@ -20,42 +17,20 @@ interface PulseEvents { uint256[] publishTimes ); - event ProviderFeeUpdated( - address indexed provider, - uint128 oldFeeInWei, - uint128 newFeeInWei - ); - - event ProviderUriUpdated( - address indexed provider, - bytes oldUri, - bytes newUri - ); - - event ProviderWithdrawn( - address indexed provider, - address indexed recipient, - uint128 amount - ); - - event ProviderFeeManagerUpdated( - address indexed provider, - address oldFeeManager, - address newFeeManager - ); - - event ProviderMaxNumPricesUpdated( - address indexed provider, - uint32 oldMaxNumPrices, - uint32 maxNumPrices - ); + event FeesWithdrawn(address indexed recipient, uint128 amount); event PriceUpdateCallbackFailed( uint64 indexed sequenceNumber, - address indexed provider, + address indexed updater, uint256 publishTime, bytes32[] priceIds, address requester, string reason ); + + event FeeManagerUpdated( + address indexed admin, + address oldFeeManager, + address newFeeManager + ); } diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol index 3341edee21..08052dd1e3 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -7,7 +7,6 @@ contract PulseState { bytes1 public constant NUM_REQUESTS_MASK = 0x1f; struct Request { - address provider; uint64 sequenceNumber; uint256 publishTime; bytes32[] priceIds; @@ -15,25 +14,15 @@ contract PulseState { address requester; } - struct ProviderInfo { - uint64 sequenceNumber; - uint128 feeInWei; - uint128 accruedFeesInWei; - bytes uri; - address feeManager; - uint32 maxNumPrices; - uint128 feePerGas; - } - struct State { address admin; uint128 pythFeeInWei; - uint128 accruedPythFeesInWei; - address defaultProvider; + uint128 accruedFeesInWei; address pyth; + uint64 currentSequenceNumber; + address feeManager; Request[32] requests; mapping(bytes32 => Request) requestsOverflow; - mapping(address => ProviderInfo) providers; } State internal _state; diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol index 0c09e8b9de..48fc694e69 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol @@ -22,11 +22,11 @@ contract PulseUpgradeable is address owner, address admin, uint128 pythFeeInWei, - address defaultProvider, address pythAddress, bool prefillRequestStorage ) public initializer { require(owner != address(0), "owner is zero address"); + require(admin != address(0), "admin is zero address"); __Ownable_init(); __UUPSUpgradeable_init(); @@ -34,7 +34,6 @@ contract PulseUpgradeable is Pulse._initialize( admin, pythFeeInWei, - defaultProvider, pythAddress, prefillRequestStorage ); diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index cbc15f417d..b0963b0429 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -12,18 +12,18 @@ import "../contracts/pulse/PulseErrors.sol"; contract MockPulseConsumer is IPulseConsumer { uint64 public lastSequenceNumber; - address public lastProvider; + address public lastUpdater; uint256 public lastPublishTime; bytes32[] public lastPriceIds; function pulseCallback( uint64 sequenceNumber, - address provider, + address updater, uint256 publishTime, bytes32[] calldata priceIds ) external override { lastSequenceNumber = sequenceNumber; - lastProvider = provider; + lastUpdater = updater; lastPublishTime = publishTime; lastPriceIds = priceIds; } @@ -59,13 +59,11 @@ contract PulseTest is Test, PulseEvents { MockPulseConsumer public consumer; address public owner; address public admin; - address public provider; + address public updater; address public pyth; // Constants uint128 constant PYTH_FEE = 1 wei; - uint128 constant PROVIDER_FEE = 1 wei; - uint128 constant PROVIDER_FEE_PER_GAS = 1 wei; uint128 constant CALLBACK_GAS_LIMIT = 1_000_000; bytes32 constant BTC_PRICE_FEED_ID = 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43; @@ -82,22 +80,15 @@ contract PulseTest is Test, PulseEvents { function setUp() public { owner = address(1); admin = address(2); - provider = address(3); + updater = address(3); pyth = address(4); PulseUpgradeable _pulse = new PulseUpgradeable(); proxy = new ERC1967Proxy(address(_pulse), ""); pulse = PulseUpgradeable(address(proxy)); - pulse.initialize(owner, admin, PYTH_FEE, provider, pyth, false); + pulse.initialize(owner, admin, PYTH_FEE, pyth, false); consumer = new MockPulseConsumer(); - - vm.prank(provider); - pulse.register( - PROVIDER_FEE, - PROVIDER_FEE_PER_GAS, - "https://provider.com" - ); } // Helper function to create price IDs array @@ -153,11 +144,8 @@ contract PulseTest is Test, PulseEvents { } // Helper function to calculate total fee - function calculateTotalFee() internal pure returns (uint128) { - return - PYTH_FEE + - PROVIDER_FEE + - (PROVIDER_FEE_PER_GAS * uint128(CALLBACK_GAS_LIMIT)); + function calculateTotalFee() internal view returns (uint128) { + return pulse.getFee(CALLBACK_GAS_LIMIT); } // Helper function to setup consumer request @@ -175,26 +163,31 @@ contract PulseTest is Test, PulseEvents { publishTime = block.timestamp; vm.deal(consumerAddress, 1 gwei); + uint128 totalFee = calculateTotalFee(); + vm.prank(consumerAddress); - sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: calculateTotalFee() - }(provider, publishTime, priceIds, CALLBACK_GAS_LIMIT); + sequenceNumber = pulse.requestPriceUpdatesWithCallback{value: totalFee}( + publishTime, + priceIds, + CALLBACK_GAS_LIMIT + ); return (sequenceNumber, priceIds, publishTime); } function testRequestPriceUpdate() public { + // Set a realistic gas price + vm.txGasPrice(30 gwei); + bytes32[] memory priceIds = createPriceIds(); uint256 publishTime = block.timestamp; - // Fund the consumer contract - vm.deal(address(consumer), 1 gwei); - - vm.prank(address(consumer)); + // Fund the consumer contract with enough ETH for higher gas price + vm.deal(address(consumer), 1 ether); + uint128 totalFee = calculateTotalFee(); // Create the event data we expect to see PulseState.Request memory expectedRequest = PulseState.Request({ - provider: provider, sequenceNumber: 1, publishTime: publishTime, priceIds: priceIds, @@ -202,21 +195,18 @@ contract PulseTest is Test, PulseEvents { requester: address(consumer) }); - // Emit event with expected parameters vm.expectEmit(); emit PriceUpdateRequested(expectedRequest); - // Make the actual call that should emit the event - pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( - provider, + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: totalFee}( publishTime, priceIds, CALLBACK_GAS_LIMIT ); // Additional assertions to verify event data was stored correctly - PulseState.Request memory lastRequest = pulse.getRequest(provider, 1); - assertEq(lastRequest.provider, expectedRequest.provider); + PulseState.Request memory lastRequest = pulse.getRequest(1); assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber); assertEq(lastRequest.publishTime, expectedRequest.publishTime); assertEq( @@ -227,17 +217,22 @@ contract PulseTest is Test, PulseEvents { lastRequest.callbackGasLimit, expectedRequest.callbackGasLimit ); - assertEq(lastRequest.requester, expectedRequest.requester); + assertEq( + lastRequest.requester, + expectedRequest.requester, + "Requester mismatch" + ); } function testRequestWithInsufficientFee() public { - bytes32[] memory priceIds = createPriceIds(); - vm.deal(address(consumer), 1 gwei); + // Set a realistic gas price + vm.txGasPrice(30 gwei); + bytes32[] memory priceIds = createPriceIds(); + vm.deal(address(consumer), 1 ether); vm.prank(address(consumer)); vm.expectRevert(InsufficientFee.selector); pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Intentionally low fee - provider, block.timestamp, priceIds, CALLBACK_GAS_LIMIT @@ -251,11 +246,13 @@ contract PulseTest is Test, PulseEvents { // Fund the consumer contract vm.deal(address(consumer), 1 gwei); + uint128 totalFee = calculateTotalFee(); + // Step 1: Make the request as consumer vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: calculateTotalFee() - }(provider, publishTime, priceIds, CALLBACK_GAS_LIMIT); + value: totalFee + }(publishTime, priceIds, CALLBACK_GAS_LIMIT); // Step 2: Create mock price feeds and setup Pyth response PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -284,7 +281,7 @@ contract PulseTest is Test, PulseEvents { vm.expectEmit(true, true, false, true); emit PriceUpdateExecuted( sequenceNumber, - provider, + updater, publishTime, priceIds, expectedPrices, @@ -296,9 +293,8 @@ contract PulseTest is Test, PulseEvents { // Create mock update data and execute callback bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.prank(provider); + vm.prank(updater); pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -307,7 +303,6 @@ contract PulseTest is Test, PulseEvents { // Verify callback was executed assertEq(consumer.lastSequenceNumber(), sequenceNumber); - assertEq(consumer.lastProvider(), provider); assertEq(consumer.lastPublishTime(), publishTime); } @@ -329,16 +324,15 @@ contract PulseTest is Test, PulseEvents { vm.expectEmit(true, true, true, true); emit PriceUpdateCallbackFailed( sequenceNumber, - provider, + updater, publishTime, priceIds, address(failingConsumer), "callback failed" ); - vm.prank(provider); + vm.prank(updater); pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -364,16 +358,15 @@ contract PulseTest is Test, PulseEvents { vm.expectEmit(true, true, true, true); emit PriceUpdateCallbackFailed( sequenceNumber, - provider, + updater, publishTime, priceIds, address(failingConsumer), "low-level error (possibly out of gas)" ); - vm.prank(provider); + vm.prank(updater); pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -399,7 +392,7 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.prank(provider); + vm.prank(updater); vm.expectRevert( abi.encodeWithSelector( InvalidPriceIds.selector, @@ -408,7 +401,6 @@ contract PulseTest is Test, PulseEvents { ) ); pulse.executeCallback( - provider, sequenceNumber, differentPriceIds, updateData, @@ -429,10 +421,9 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.prank(provider); + vm.prank(updater); vm.expectRevert(); pulse.executeCallback{gas: 10000}( - provider, sequenceNumber, priceIds, updateData, @@ -455,7 +446,7 @@ contract PulseTest is Test, PulseEvents { // Try to execute with different gas limit than what was requested uint256 differentGasLimit = CALLBACK_GAS_LIMIT + 1000; - vm.prank(provider); + vm.prank(updater); vm.expectRevert( abi.encodeWithSelector( InvalidCallbackGasLimit.selector, @@ -464,7 +455,6 @@ contract PulseTest is Test, PulseEvents { ) ); pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -478,10 +468,11 @@ contract PulseTest is Test, PulseEvents { uint256 futureTime = block.timestamp + 1 days; vm.deal(address(consumer), 1 gwei); + uint128 totalFee = calculateTotalFee(); vm.prank(address(consumer)); uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ - value: calculateTotalFee() - }(provider, futureTime, priceIds, CALLBACK_GAS_LIMIT); + value: totalFee + }(futureTime, priceIds, CALLBACK_GAS_LIMIT); // Try to execute callback before the requested timestamp PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( @@ -490,10 +481,9 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); // This will make parsePriceFeedUpdates return future-dated prices bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.prank(provider); + vm.prank(updater); // Should succeed because we're simulating receiving future-dated price updates pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -504,31 +494,6 @@ contract PulseTest is Test, PulseEvents { assertEq(consumer.lastPublishTime(), futureTime); } - function testExecuteCallbackWithWrongProvider() public { - ( - uint64 sequenceNumber, - bytes32[] memory priceIds, - uint256 publishTime - ) = setupConsumerRequest(address(consumer)); - - PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( - publishTime - ); - mockParsePriceFeedUpdates(priceFeeds); - bytes[] memory updateData = createMockUpdateData(priceFeeds); - - address wrongProvider = address(0x999); - vm.prank(wrongProvider); - vm.expectRevert(NoSuchRequest.selector); - pulse.executeCallback( - wrongProvider, - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); - } - function testDoubleExecuteCallback() public { ( uint64 sequenceNumber, @@ -543,9 +508,8 @@ contract PulseTest is Test, PulseEvents { bytes[] memory updateData = createMockUpdateData(priceFeeds); // First execution - vm.prank(provider); + vm.prank(updater); pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -553,10 +517,9 @@ contract PulseTest is Test, PulseEvents { ); // Second execution should fail - vm.prank(provider); + vm.prank(updater); vm.expectRevert(NoSuchRequest.selector); pulse.executeCallback( - provider, sequenceNumber, priceIds, updateData, @@ -573,12 +536,9 @@ contract PulseTest is Test, PulseEvents { for (uint256 i = 0; i < gasLimits.length; i++) { uint256 gasLimit = gasLimits[i]; - uint128 expectedFee = PROVIDER_FEE + // Base provider fee - (PROVIDER_FEE_PER_GAS * uint128(gasLimit)) + // Gas-based fee - PYTH_FEE; // Pyth oracle fee - - uint128 actualFee = pulse.getFee(provider, gasLimit); - + uint128 expectedFee = SafeCast.toUint128(tx.gasprice * gasLimit) + + PYTH_FEE; + uint128 actualFee = pulse.getFee(gasLimit); assertEq( actualFee, expectedFee, @@ -587,86 +547,73 @@ contract PulseTest is Test, PulseEvents { } // Test with zero gas limit - uint128 expectedMinFee = PROVIDER_FEE + PYTH_FEE; - uint128 actualMinFee = pulse.getFee(provider, 0); + uint128 expectedMinFee = PYTH_FEE; + uint128 actualMinFee = pulse.getFee(0); assertEq( actualMinFee, expectedMinFee, "Minimum fee calculation incorrect" ); - - // Test with unregistered provider (should return 0 fees) - address unregisteredProvider = address(0x123); - uint128 unregisteredFee = pulse.getFee( - unregisteredProvider, - gasLimits[0] - ); - assertEq( - unregisteredFee, - PYTH_FEE, - "Unregistered provider fee should only include Pyth fee" - ); } - function testWithdraw() public { + function testWithdrawFees() public { // Setup: Request price update to accrue some fees bytes32[] memory priceIds = createPriceIds(); vm.deal(address(consumer), 1 gwei); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( - provider, block.timestamp, priceIds, CALLBACK_GAS_LIMIT ); - // Get provider's balance before withdrawal - uint256 providerBalanceBefore = provider.balance; - PulseState.ProviderInfo memory infoBefore = pulse.getProviderInfo( - provider - ); + // Get admin's balance before withdrawal + uint256 adminBalanceBefore = admin.balance; + uint128 accruedFees = pulse.getAccruedFees(); - // Withdraw fees - vm.prank(provider); - pulse.withdraw(infoBefore.accruedFeesInWei); + // Withdraw fees as admin + vm.prank(admin); + pulse.withdrawFees(accruedFees); // Verify balances assertEq( - provider.balance, - providerBalanceBefore + infoBefore.accruedFeesInWei + admin.balance, + adminBalanceBefore + accruedFees, + "Admin balance should increase by withdrawn amount" ); - - PulseState.ProviderInfo memory infoAfter = pulse.getProviderInfo( - provider + assertEq( + pulse.getAccruedFees(), + 0, + "Contract should have no fees after withdrawal" ); - assertEq(infoAfter.accruedFeesInWei, 0); } - function testWithdrawInsufficientBalance() public { - vm.prank(provider); + function testWithdrawFeesUnauthorized() public { + vm.prank(address(0xdead)); + vm.expectRevert("Only admin can withdraw fees"); + pulse.withdrawFees(1 ether); + } + + function testWithdrawFeesInsufficientBalance() public { + vm.prank(admin); vm.expectRevert("Insufficient balance"); - pulse.withdraw(1 ether); + pulse.withdrawFees(1 ether); } function testSetAndWithdrawAsFeeManager() public { address feeManager = address(0x789); - // Set fee manager - vm.prank(provider); + // Set fee manager as admin + vm.prank(admin); pulse.setFeeManager(feeManager); - // Verify fee manager was set - PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); - assertEq(info.feeManager, feeManager); - // Setup: Request price update to accrue some fees bytes32[] memory priceIds = createPriceIds(); vm.deal(address(consumer), 1 gwei); vm.prank(address(consumer)); pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( - provider, block.timestamp, priceIds, CALLBACK_GAS_LIMIT @@ -674,48 +621,44 @@ contract PulseTest is Test, PulseEvents { // Test withdrawal as fee manager uint256 managerBalanceBefore = feeManager.balance; - info = pulse.getProviderInfo(provider); + uint128 accruedFees = pulse.getAccruedFees(); vm.prank(feeManager); - pulse.withdrawAsFeeManager(provider, info.accruedFeesInWei); + pulse.withdrawAsFeeManager(accruedFees); assertEq( feeManager.balance, - managerBalanceBefore + info.accruedFeesInWei + managerBalanceBefore + accruedFees, + "Fee manager balance should increase by withdrawn amount" + ); + assertEq( + pulse.getAccruedFees(), + 0, + "Contract should have no fees after withdrawal" ); } - function testMaxNumPrices() public { - // Set max number of prices - vm.prank(provider); - pulse.setMaxNumPrices(1); - - // Try to request more prices than allowed - bytes32[] memory priceIds = new bytes32[](2); - priceIds[0] = BTC_PRICE_FEED_ID; - priceIds[1] = ETH_PRICE_FEED_ID; - - vm.deal(address(consumer), 1 gwei); - vm.prank(address(consumer)); - - vm.expectRevert( - abi.encodeWithSelector(ExceedsMaxPrices.selector, 2, 1) - ); - pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( - provider, - block.timestamp, - priceIds, - CALLBACK_GAS_LIMIT - ); + function testSetFeeManagerUnauthorized() public { + address feeManager = address(0x789); + vm.prank(address(0xdead)); + vm.expectRevert("Only admin can set fee manager"); + pulse.setFeeManager(feeManager); } - function testSetProviderUri() public { - bytes memory newUri = "https://updated-provider.com"; + function testWithdrawAsFeeManagerUnauthorized() public { + vm.prank(address(0xdead)); + vm.expectRevert("Only fee manager"); + pulse.withdrawAsFeeManager(1 ether); + } - vm.prank(provider); - pulse.setProviderUri(newUri); + function testWithdrawAsFeeManagerInsufficientBalance() public { + // Set up fee manager first + address feeManager = address(0x789); + vm.prank(admin); + pulse.setFeeManager(feeManager); - PulseState.ProviderInfo memory info = pulse.getProviderInfo(provider); - assertEq(info.uri, newUri); + vm.prank(feeManager); + vm.expectRevert("Insufficient balance"); + pulse.withdrawAsFeeManager(1 ether); } } From f3ec2477521b86d211f29bbb824172d3b9a698ae Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 18 Nov 2024 20:05:25 +0900 Subject: [PATCH 18/20] address comments --- .../contracts/contracts/pulse/IPulse.sol | 3 +- .../contracts/contracts/pulse/Pulse.sol | 31 ++-- .../contracts/contracts/pulse/PulseErrors.sol | 3 +- .../contracts/contracts/pulse/PulseState.sol | 2 +- .../ethereum/contracts/forge-test/Pulse.t.sol | 168 ++++++------------ 5 files changed, 67 insertions(+), 140 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 125d06f357..22891e98a3 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -24,9 +24,8 @@ interface IPulse is PulseEvents { function executeCallback( uint64 sequenceNumber, - bytes32[] calldata priceIds, bytes[] calldata updateData, - uint256 callbackGasLimit + bytes32[] calldata priceIds ) external payable; // Getters diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 167f3e9f7a..6ede25997b 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -48,7 +48,7 @@ abstract contract Pulse is IPulse, PulseState { Request storage req = allocRequest(requestSequenceNumber); req.sequenceNumber = requestSequenceNumber; req.publishTime = publishTime; - req.priceIds = priceIds; + req.priceIdsHash = keccak256(abi.encode(priceIds)); req.callbackGasLimit = callbackGasLimit; req.requester = msg.sender; @@ -59,24 +59,20 @@ abstract contract Pulse is IPulse, PulseState { function executeCallback( uint64 sequenceNumber, - bytes32[] calldata priceIds, bytes[] calldata updateData, - uint256 callbackGasLimit + bytes32[] calldata priceIds ) external payable override { Request storage req = findActiveRequest(sequenceNumber); + bytes32 providedPriceIdsHash = keccak256(abi.encode(priceIds)); + bytes32 storedPriceIdsHash = req.priceIdsHash; - if ( - keccak256(abi.encode(req.priceIds)) != - keccak256(abi.encode(priceIds)) - ) { - revert InvalidPriceIds(priceIds, req.priceIds); + if (providedPriceIdsHash != storedPriceIdsHash) { + revert InvalidPriceIds(providedPriceIdsHash, storedPriceIdsHash); } - if (req.callbackGasLimit != callbackGasLimit) { - revert InvalidCallbackGasLimit( - callbackGasLimit, - req.callbackGasLimit - ); + // Check if there's enough gas left for the callback + if (gasleft() < req.callbackGasLimit) { + revert InsufficientGas(); } PythStructs.PriceFeed[] memory priceFeeds = IPyth(_state.pyth) @@ -90,12 +86,9 @@ abstract contract Pulse is IPulse, PulseState { uint256 publishTime = priceFeeds[0].price.publishTime; try - IPulseConsumer(req.requester).pulseCallback( - sequenceNumber, - msg.sender, - publishTime, - priceIds - ) + IPulseConsumer(req.requester).pulseCallback{ + gas: req.callbackGasLimit + }(sequenceNumber, msg.sender, publishTime, priceIds) { // Callback succeeded emitPriceUpdate(sequenceNumber, publishTime, priceIds, priceFeeds); diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol index 535ad4d746..c2fe41ccb6 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol @@ -8,6 +8,7 @@ error InsufficientFee(); error Unauthorized(); error InvalidCallbackGas(); error CallbackFailed(); -error InvalidPriceIds(bytes32[] requested, bytes32[] stored); +error InvalidPriceIds(bytes32 providedPriceIdsHash, bytes32 storedPriceIdsHash); error InvalidCallbackGasLimit(uint256 requested, uint256 stored); error ExceedsMaxPrices(uint32 requested, uint32 maxAllowed); +error InsufficientGas(); diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol index 08052dd1e3..6b0b48fc61 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -9,7 +9,7 @@ contract PulseState { struct Request { uint64 sequenceNumber; uint256 publishTime; - bytes32[] priceIds; + bytes32 priceIdsHash; uint256 callbackGasLimit; address requester; } diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index b0963b0429..7e7d276e43 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -190,7 +190,7 @@ contract PulseTest is Test, PulseEvents { PulseState.Request memory expectedRequest = PulseState.Request({ sequenceNumber: 1, publishTime: publishTime, - priceIds: priceIds, + priceIdsHash: keccak256(abi.encode(priceIds)), callbackGasLimit: CALLBACK_GAS_LIMIT, requester: address(consumer) }); @@ -209,10 +209,7 @@ contract PulseTest is Test, PulseEvents { PulseState.Request memory lastRequest = pulse.getRequest(1); assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber); assertEq(lastRequest.publishTime, expectedRequest.publishTime); - assertEq( - keccak256(abi.encode(lastRequest.priceIds)), - keccak256(abi.encode(expectedRequest.priceIds)) - ); + assertEq(lastRequest.priceIdsHash, expectedRequest.priceIdsHash); assertEq( lastRequest.callbackGasLimit, expectedRequest.callbackGasLimit @@ -245,7 +242,6 @@ contract PulseTest is Test, PulseEvents { // Fund the consumer contract vm.deal(address(consumer), 1 gwei); - uint128 totalFee = calculateTotalFee(); // Step 1: Make the request as consumer @@ -278,7 +274,7 @@ contract PulseTest is Test, PulseEvents { expectedPublishTimes[1] = publishTime; // Expect the PriceUpdateExecuted event with all price data - vm.expectEmit(true, true, false, true); + vm.expectEmit(); emit PriceUpdateExecuted( sequenceNumber, updater, @@ -294,12 +290,7 @@ contract PulseTest is Test, PulseEvents { bytes[] memory updateData = createMockUpdateData(priceFeeds); vm.prank(updater); - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); + pulse.executeCallback(sequenceNumber, updateData, priceIds); // Verify callback was executed assertEq(consumer.lastSequenceNumber(), sequenceNumber); @@ -321,7 +312,7 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.expectEmit(true, true, true, true); + vm.expectEmit(); emit PriceUpdateCallbackFailed( sequenceNumber, updater, @@ -332,12 +323,7 @@ contract PulseTest is Test, PulseEvents { ); vm.prank(updater); - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); + pulse.executeCallback(sequenceNumber, updateData, priceIds); } function testExecuteCallbackCustomErrorFailure() public { @@ -355,7 +341,7 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); - vm.expectEmit(true, true, true, true); + vm.expectEmit(); emit PriceUpdateCallbackFailed( sequenceNumber, updater, @@ -366,100 +352,28 @@ contract PulseTest is Test, PulseEvents { ); vm.prank(updater); - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); - } - - // Test executing callback with mismatched price IDs - function testExecuteCallbackWithMismatchedPriceIds() public { - ( - uint64 sequenceNumber, - bytes32[] memory originalPriceIds, - uint256 publishTime - ) = setupConsumerRequest(address(consumer)); - - // Create different price IDs array - bytes32[] memory differentPriceIds = new bytes32[](1); - differentPriceIds[0] = bytes32(uint256(1)); // Different price ID - - PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( - publishTime - ); - mockParsePriceFeedUpdates(priceFeeds); - bytes[] memory updateData = createMockUpdateData(priceFeeds); - - vm.prank(updater); - vm.expectRevert( - abi.encodeWithSelector( - InvalidPriceIds.selector, - differentPriceIds, - originalPriceIds - ) - ); - pulse.executeCallback( - sequenceNumber, - differentPriceIds, - updateData, - CALLBACK_GAS_LIMIT - ); + pulse.executeCallback(sequenceNumber, updateData, priceIds); } function testExecuteCallbackWithInsufficientGas() public { + // Setup request with 1M gas limit ( uint64 sequenceNumber, bytes32[] memory priceIds, uint256 publishTime ) = setupConsumerRequest(address(consumer)); + // Setup mock data PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( publishTime ); mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); + // Try executing with only 10K gas when 1M is required vm.prank(updater); - vm.expectRevert(); - pulse.executeCallback{gas: 10000}( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); - } - - function testExecuteCallbackWithInvalidGasLimit() public { - ( - uint64 sequenceNumber, - bytes32[] memory priceIds, - uint256 publishTime - ) = setupConsumerRequest(address(consumer)); - - PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( - publishTime - ); - mockParsePriceFeedUpdates(priceFeeds); - bytes[] memory updateData = createMockUpdateData(priceFeeds); - - // Try to execute with different gas limit than what was requested - uint256 differentGasLimit = CALLBACK_GAS_LIMIT + 1000; - vm.prank(updater); - vm.expectRevert( - abi.encodeWithSelector( - InvalidCallbackGasLimit.selector, - differentGasLimit, - CALLBACK_GAS_LIMIT - ) - ); - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - differentGasLimit - ); + vm.expectRevert(InsufficientGas.selector); + pulse.executeCallback{gas: 10000}(sequenceNumber, updateData, priceIds); // Will fail because gasleft() < callbackGasLimit } function testExecuteCallbackWithFutureTimestamp() public { @@ -483,12 +397,7 @@ contract PulseTest is Test, PulseEvents { vm.prank(updater); // Should succeed because we're simulating receiving future-dated price updates - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); + pulse.executeCallback(sequenceNumber, updateData, priceIds); // Verify the callback was executed with future timestamp assertEq(consumer.lastPublishTime(), futureTime); @@ -509,22 +418,12 @@ contract PulseTest is Test, PulseEvents { // First execution vm.prank(updater); - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); + pulse.executeCallback(sequenceNumber, updateData, priceIds); // Second execution should fail vm.prank(updater); vm.expectRevert(NoSuchRequest.selector); - pulse.executeCallback( - sequenceNumber, - priceIds, - updateData, - CALLBACK_GAS_LIMIT - ); + pulse.executeCallback(sequenceNumber, updateData, priceIds); } function testGetFee() public { @@ -661,4 +560,39 @@ contract PulseTest is Test, PulseEvents { vm.expectRevert("Insufficient balance"); pulse.withdrawAsFeeManager(1 ether); } + + // Add new test for invalid priceIds + function testExecuteCallbackWithInvalidPriceIds() public { + bytes32[] memory priceIds = createPriceIds(); + uint256 publishTime = block.timestamp; + + // Setup request + (uint64 sequenceNumber, , ) = setupConsumerRequest(address(consumer)); + + // Create different priceIds + bytes32[] memory wrongPriceIds = new bytes32[](2); + wrongPriceIds[0] = bytes32(uint256(1)); // Different price IDs + wrongPriceIds[1] = bytes32(uint256(2)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // Calculate hashes for both arrays + bytes32 providedPriceIdsHash = keccak256(abi.encode(wrongPriceIds)); + bytes32 storedPriceIdsHash = keccak256(abi.encode(priceIds)); + + // Should revert when trying to execute with wrong priceIds + vm.prank(updater); + vm.expectRevert( + abi.encodeWithSelector( + InvalidPriceIds.selector, + providedPriceIdsHash, + storedPriceIdsHash + ) + ); + pulse.executeCallback(sequenceNumber, updateData, wrongPriceIds); + } } From 13c3967dfaedbf12785750b36249224e59c47b96 Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Mon, 18 Nov 2024 20:54:12 +0900 Subject: [PATCH 19/20] address comments --- .../contracts/contracts/pulse/IPulse.sol | 4 +- .../contracts/contracts/pulse/Pulse.sol | 22 +++--- .../contracts/contracts/pulse/PulseEvents.sol | 2 - .../ethereum/contracts/forge-test/Pulse.t.sol | 67 ++++++++++++++----- 4 files changed, 60 insertions(+), 35 deletions(-) diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 22891e98a3..6e5d44eaf9 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.0; +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; import "./PulseEvents.sol"; import "./PulseState.sol"; @@ -9,8 +10,7 @@ interface IPulseConsumer { function pulseCallback( uint64 sequenceNumber, address updater, - uint256 publishTime, - bytes32[] calldata priceIds + PythStructs.PriceFeed[] memory priceFeeds ) external; } diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 6ede25997b..eeb1cccf2c 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -63,18 +63,15 @@ abstract contract Pulse is IPulse, PulseState { bytes32[] calldata priceIds ) external payable override { Request storage req = findActiveRequest(sequenceNumber); + + // Verify priceIds match bytes32 providedPriceIdsHash = keccak256(abi.encode(priceIds)); bytes32 storedPriceIdsHash = req.priceIdsHash; - if (providedPriceIdsHash != storedPriceIdsHash) { revert InvalidPriceIds(providedPriceIdsHash, storedPriceIdsHash); } - // Check if there's enough gas left for the callback - if (gasleft() < req.callbackGasLimit) { - revert InsufficientGas(); - } - + // Parse price feeds first to measure gas usage PythStructs.PriceFeed[] memory priceFeeds = IPyth(_state.pyth) .parsePriceFeedUpdates( updateData, @@ -83,21 +80,23 @@ abstract contract Pulse is IPulse, PulseState { SafeCast.toUint64(req.publishTime) ); - uint256 publishTime = priceFeeds[0].price.publishTime; + // Check if enough gas remains for the callback + if (gasleft() < req.callbackGasLimit) { + revert InsufficientGas(); + } try IPulseConsumer(req.requester).pulseCallback{ gas: req.callbackGasLimit - }(sequenceNumber, msg.sender, publishTime, priceIds) + }(sequenceNumber, msg.sender, priceFeeds) { // Callback succeeded - emitPriceUpdate(sequenceNumber, publishTime, priceIds, priceFeeds); + emitPriceUpdate(sequenceNumber, priceIds, priceFeeds); } catch Error(string memory reason) { // Explicit revert/require emit PriceUpdateCallbackFailed( sequenceNumber, msg.sender, - publishTime, priceIds, req.requester, reason @@ -107,7 +106,6 @@ abstract contract Pulse is IPulse, PulseState { emit PriceUpdateCallbackFailed( sequenceNumber, msg.sender, - publishTime, priceIds, req.requester, "low-level error (possibly out of gas)" @@ -119,7 +117,6 @@ abstract contract Pulse is IPulse, PulseState { function emitPriceUpdate( uint64 sequenceNumber, - uint256 publishTime, bytes32[] memory priceIds, PythStructs.PriceFeed[] memory priceFeeds ) internal { @@ -138,7 +135,6 @@ abstract contract Pulse is IPulse, PulseState { emit PriceUpdateExecuted( sequenceNumber, msg.sender, - publishTime, priceIds, prices, conf, diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol index 1c96797b39..4b7abfbbc3 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -9,7 +9,6 @@ interface PulseEvents { event PriceUpdateExecuted( uint64 indexed sequenceNumber, address indexed updater, - uint256 publishTime, bytes32[] priceIds, int64[] prices, uint64[] conf, @@ -22,7 +21,6 @@ interface PulseEvents { event PriceUpdateCallbackFailed( uint64 indexed sequenceNumber, address indexed updater, - uint256 publishTime, bytes32[] priceIds, address requester, string reason diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 7e7d276e43..81edc4115a 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; import "forge-std/Test.sol"; +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import "../contracts/pulse/PulseUpgradeable.sol"; import "../contracts/pulse/IPulse.sol"; @@ -13,19 +14,26 @@ import "../contracts/pulse/PulseErrors.sol"; contract MockPulseConsumer is IPulseConsumer { uint64 public lastSequenceNumber; address public lastUpdater; - uint256 public lastPublishTime; - bytes32[] public lastPriceIds; + PythStructs.PriceFeed[] private _lastPriceFeeds; function pulseCallback( uint64 sequenceNumber, address updater, - uint256 publishTime, - bytes32[] calldata priceIds + PythStructs.PriceFeed[] memory priceFeeds ) external override { lastSequenceNumber = sequenceNumber; lastUpdater = updater; - lastPublishTime = publishTime; - lastPriceIds = priceIds; + for (uint i = 0; i < priceFeeds.length; i++) { + _lastPriceFeeds.push(priceFeeds[i]); + } + } + + function lastPriceFeeds() + external + view + returns (PythStructs.PriceFeed[] memory) + { + return _lastPriceFeeds; } } @@ -33,8 +41,7 @@ contract FailingPulseConsumer is IPulseConsumer { function pulseCallback( uint64, address, - uint256, - bytes32[] calldata + PythStructs.PriceFeed[] memory ) external pure override { revert("callback failed"); } @@ -46,8 +53,7 @@ contract CustomErrorPulseConsumer is IPulseConsumer { function pulseCallback( uint64, address, - uint256, - bytes32[] calldata + PythStructs.PriceFeed[] memory ) external pure override { revert CustomError("callback failed"); } @@ -278,7 +284,6 @@ contract PulseTest is Test, PulseEvents { emit PriceUpdateExecuted( sequenceNumber, updater, - publishTime, priceIds, expectedPrices, expectedConf, @@ -294,7 +299,22 @@ contract PulseTest is Test, PulseEvents { // Verify callback was executed assertEq(consumer.lastSequenceNumber(), sequenceNumber); - assertEq(consumer.lastPublishTime(), publishTime); + + // Compare price feeds array length + PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds(); + assertEq(lastFeeds.length, priceFeeds.length); + + // Compare each price feed + for (uint i = 0; i < priceFeeds.length; i++) { + assertEq(lastFeeds[i].id, priceFeeds[i].id); + assertEq(lastFeeds[i].price.price, priceFeeds[i].price.price); + assertEq(lastFeeds[i].price.conf, priceFeeds[i].price.conf); + assertEq(lastFeeds[i].price.expo, priceFeeds[i].price.expo); + assertEq( + lastFeeds[i].price.publishTime, + priceFeeds[i].price.publishTime + ); + } } function testExecuteCallbackFailure() public { @@ -316,7 +336,6 @@ contract PulseTest is Test, PulseEvents { emit PriceUpdateCallbackFailed( sequenceNumber, updater, - publishTime, priceIds, address(failingConsumer), "callback failed" @@ -345,7 +364,6 @@ contract PulseTest is Test, PulseEvents { emit PriceUpdateCallbackFailed( sequenceNumber, updater, - publishTime, priceIds, address(failingConsumer), "low-level error (possibly out of gas)" @@ -370,10 +388,14 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); - // Try executing with only 10K gas when 1M is required + // Try executing with only 100K gas when 1M is required vm.prank(updater); vm.expectRevert(InsufficientGas.selector); - pulse.executeCallback{gas: 10000}(sequenceNumber, updateData, priceIds); // Will fail because gasleft() < callbackGasLimit + pulse.executeCallback{gas: 100000}( + sequenceNumber, + updateData, + priceIds + ); // Will fail because gasleft() < callbackGasLimit } function testExecuteCallbackWithFutureTimestamp() public { @@ -399,8 +421,17 @@ contract PulseTest is Test, PulseEvents { // Should succeed because we're simulating receiving future-dated price updates pulse.executeCallback(sequenceNumber, updateData, priceIds); - // Verify the callback was executed with future timestamp - assertEq(consumer.lastPublishTime(), futureTime); + // Compare price feeds array length + PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds(); + assertEq(lastFeeds.length, priceFeeds.length); + + // Compare each price feed publish time + for (uint i = 0; i < priceFeeds.length; i++) { + assertEq( + lastFeeds[i].price.publishTime, + priceFeeds[i].price.publishTime + ); + } } function testDoubleExecuteCallback() public { From 4b1ccafe50756f3fead404d5f5f27aba264e4cbf Mon Sep 17 00:00:00 2001 From: Daniel Chew Date: Fri, 15 Nov 2024 15:44:13 +0900 Subject: [PATCH 20/20] init argus --- apps/argus/Cargo.toml | 8 + apps/argus/README.md | 19 +++ apps/argus/config.sample.yaml | 18 +++ apps/argus/rust-toolchain | 1 + apps/argus/src/config.rs | 60 +++++++ apps/argus/src/contract.rs | 41 +++++ apps/argus/src/error.rs | 64 ++++++++ apps/argus/src/hermes.rs | 129 +++++++++++++++ apps/argus/src/keeper.rs | 293 ++++++++++++++++++++++++++++++++++ apps/argus/src/main.rs | 94 +++++++++++ apps/argus/src/storage.rs | 164 +++++++++++++++++++ apps/argus/src/types.rs | 32 ++++ 12 files changed, 923 insertions(+) create mode 100644 apps/argus/Cargo.toml create mode 100644 apps/argus/README.md create mode 100644 apps/argus/config.sample.yaml create mode 100644 apps/argus/rust-toolchain create mode 100644 apps/argus/src/config.rs create mode 100644 apps/argus/src/contract.rs create mode 100644 apps/argus/src/error.rs create mode 100644 apps/argus/src/hermes.rs create mode 100644 apps/argus/src/keeper.rs create mode 100644 apps/argus/src/main.rs create mode 100644 apps/argus/src/storage.rs create mode 100644 apps/argus/src/types.rs diff --git a/apps/argus/Cargo.toml b/apps/argus/Cargo.toml new file mode 100644 index 0000000000..e53fa0894b --- /dev/null +++ b/apps/argus/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "argus" +version = "0.1.0" +edition = "2021" + +[dependencies] +alloy = { version = "0.3", features = ["full", "node-bindings"] } +tokio = { version = "1.28", features = ["full"] } diff --git a/apps/argus/README.md b/apps/argus/README.md new file mode 100644 index 0000000000..7d5033e4b2 --- /dev/null +++ b/apps/argus/README.md @@ -0,0 +1,19 @@ +# Argus + +Argus is a webservice that serves price updates according to the Pulse protocol. +The service also operates a keeper task that performs callback transactions for user requests. + +A single instance of this service can simultaneously serve price updates for several different blockchains. +Each blockchain is configured in `config.yaml`. + +## How It Works + +1. Continuously polls the Pulse contract's storage to discover new price update requests +2. Fetches required price data from Pyth Network +3. Batches multiple requests when possible for gas efficiency +4. Executes callbacks with appropriate gas limits specified in the original requests +5. Monitors transaction success and handles retries when necessary + +## Architecture + +The service is built on Rust for performance and reliability, sharing architectural patterns with Fortuna (the Entropy protocol's keeper service). However, unlike Fortuna which relies on event subscriptions, Argus uses direct storage polling for more reliable request discovery. diff --git a/apps/argus/config.sample.yaml b/apps/argus/config.sample.yaml new file mode 100644 index 0000000000..8a4d247ecc --- /dev/null +++ b/apps/argus/config.sample.yaml @@ -0,0 +1,18 @@ +chains: + ethereum: + geth_rpc_addr: "https://eth-mainnet.g.alchemy.com/v2/YOUR-API-KEY" + contract_addr: "0x1234..." + poll_interval: 5 + min_batch_size: 1 + max_batch_size: 10 + batch_timeout: 30 + min_keeper_balance: 1000000000000000000 # 1 ETH + gas_limit: 500000 + +provider: + uri: "http://localhost:8080" + address: "0x5678..." + private_key: "0xabcd..." # Provider private key + +keeper: + private_key: "0xdef0..." # Keeper private key diff --git a/apps/argus/rust-toolchain b/apps/argus/rust-toolchain new file mode 100644 index 0000000000..f984c0ee0c --- /dev/null +++ b/apps/argus/rust-toolchain @@ -0,0 +1 @@ +nightly-2023-07-23 diff --git a/apps/argus/src/config.rs b/apps/argus/src/config.rs new file mode 100644 index 0000000000..271a92feab --- /dev/null +++ b/apps/argus/src/config.rs @@ -0,0 +1,60 @@ +use { + alloy::{ + primitives::Address, + providers::{Provider, ProviderBuilder}, + signers::Signer, + }, + anyhow::Result, + serde::Deserialize, + std::{fs, time::Duration}, +}; + +#[derive(Debug, Deserialize)] +pub struct Config { + pub chains: HashMap, + pub provider: ProviderConfig, + pub keeper: KeeperConfig, +} + +#[derive(Debug, Deserialize)] +pub struct ChainConfig { + pub geth_rpc_addr: String, + pub contract_addr: Address, + pub poll_interval: u64, // in seconds + pub min_batch_size: usize, + pub max_batch_size: usize, + pub batch_timeout: u64, // in seconds + pub min_keeper_balance: u64, + pub gas_limit: u64, +} + +#[derive(Debug, Deserialize)] +pub struct ProviderConfig { + pub uri: String, + pub address: Address, + pub private_key: SecretString, +} + +#[derive(Debug, Deserialize)] +pub struct KeeperConfig { + pub private_key: SecretString, +} + +#[derive(Debug, Deserialize)] +pub struct SecretString(String); + +impl Config { + pub fn load(path: &str) -> Result { + let contents = fs::read_to_string(path)?; + Ok(serde_yaml::from_str(&contents)?) + } + + pub fn create_provider(&self, chain_id: &str) -> Result { + let chain = self.chains.get(chain_id).ok_or_else(|| anyhow!("Chain not found"))?; + Ok(Provider::builder().rpc_url(&chain.geth_rpc_addr).build()?) + } + + pub fn create_signer(&self, secret: &SecretString) -> Result { + Ok(Signer::from_private_key(secret.0.parse()?)?) + } +} diff --git a/apps/argus/src/contract.rs b/apps/argus/src/contract.rs new file mode 100644 index 0000000000..576c8e17eb --- /dev/null +++ b/apps/argus/src/contract.rs @@ -0,0 +1,41 @@ +use { + alloy::{ + contract::{Contract, ContractInstance}, + primitives::{Address, Bytes, U256}, + providers::Provider, + signers::Signer, + }, + anyhow::Result, + std::sync::Arc, +}; + +// Contract ABI definition +abigen!(Pulse, "target_chains/ethereum/contracts/contracts/pulse/IPulse.sol"); + +pub struct PulseContract { + instance: ContractInstance, Pulse>, +} + +impl PulseContract

{ + pub fn new(address: Address, provider: Arc

) -> Self { + Self { + instance: ContractInstance::new(address, Arc::new(Pulse::new()), provider), + } + } + + pub async fn execute_callback( + &self, + provider: Address, + sequence_number: U64, + price_ids: Vec<[u8; 32]>, + update_data: Vec, + callback_gas_limit: U256, + ) -> Result { + let tx = self.instance + .execute_callback(provider, sequence_number, price_ids, update_data, callback_gas_limit) + .send() + .await?; + + Ok(tx.tx_hash()) + } +} diff --git a/apps/argus/src/error.rs b/apps/argus/src/error.rs new file mode 100644 index 0000000000..73b063e873 --- /dev/null +++ b/apps/argus/src/error.rs @@ -0,0 +1,64 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ArgusError { + #[error("Failed to fetch price updates from Hermes: {0}")] + HermesError(#[from] HermesError), + + #[error("Contract error: {0}")] + ContractError(#[from] ContractError), + + #[error("Storage error: {0}")] + StorageError(#[from] StorageError), + + #[error("Configuration error: {0}")] + ConfigError(String), + + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +#[derive(Debug, Error)] +pub enum HermesError { + #[error("HTTP request failed: {0}")] + RequestFailed(#[from] reqwest::Error), + + #[error("Invalid response encoding: {0}")] + InvalidEncoding(String), + + #[error("No price updates found")] + NoPriceUpdates, + + #[error("Failed to parse price data: {0}")] + ParseError(String), + + #[error("Failed to decode hex data: {0}")] + HexDecodeError(#[from] hex::FromHexError), +} + +#[derive(Debug, Error)] +pub enum ContractError { + #[error("Transaction failed: {0}")] + TransactionFailed(String), + + #[error("Gas estimation failed: {0}")] + GasEstimationFailed(String), + + #[error("Invalid contract address: {0}")] + InvalidAddress(String), + + #[error("Contract call failed: {0}")] + CallFailed(String), +} + +#[derive(Debug, Error)] +pub enum StorageError { + #[error("Failed to read storage slot: {0}")] + ReadError(String), + + #[error("Failed to parse storage data: {0}")] + ParseError(String), + + #[error("Invalid storage layout: {0}")] + InvalidLayout(String), +} diff --git a/apps/argus/src/hermes.rs b/apps/argus/src/hermes.rs new file mode 100644 index 0000000000..3609e038e0 --- /dev/null +++ b/apps/argus/src/hermes.rs @@ -0,0 +1,129 @@ +use { + crate::{ + error::{ + ArgusError, + HermesError, + }, + types::PriceData, + }, + reqwest::Client, + serde::{ + Deserialize, + Serialize, + }, + std::time::{ + SystemTime, + UNIX_EPOCH, + }, +}; + +const HERMES_API_URL: &str = "https://hermes.pyth.network"; + +#[derive(Debug, Serialize, Deserialize)] +struct HermesResponse { + binary: BinaryUpdate, + parsed: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +struct BinaryUpdate { + data: Vec, + encoding: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ParsedPriceUpdate { + id: String, + price: RpcPrice, + ema_price: RpcPrice, +} + +#[derive(Debug, Serialize, Deserialize)] +struct RpcPrice { + price: String, + conf: String, + expo: i32, + publish_time: u64, +} + +pub struct HermesClient { + client: Client, +} + +impl HermesClient { + pub fn new() -> Self { + Self { + client: Client::new(), + } + } + + pub async fn get_price_updates( + &self, + price_ids: &[[u8; 32]], + ) -> Result)>, HermesError> { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| HermesError::ParseError(format!("Failed to get timestamp: {}", e)))? + .as_secs(); + + let mut url = format!( + "{}/v2/updates/price/{}?parsed=true&encoding=hex", + HERMES_API_URL, now + ); + + for price_id in price_ids { + url.push_str(&format!("&ids[]={}", hex::encode(price_id))); + } + + let response = self + .client + .get(&url) + .send() + .await + .map_err(|e| HermesError::RequestFailed(e))? + .error_for_status() + .map_err(|e| HermesError::RequestFailed(e))? + .json::() + .await + .map_err(|e| HermesError::RequestFailed(e))?; + + let update_data = if response.binary.encoding == "hex" { + response + .binary + .data + .into_iter() + .map(|data| hex::decode(&data)) + .collect::, _>>() + .map_err(HermesError::HexDecodeError)? + } else { + return Err(HermesError::InvalidEncoding(response.binary.encoding)); + }; + + let price_updates = response.parsed.ok_or(HermesError::NoPriceUpdates)?; + + if price_updates.is_empty() { + return Err(HermesError::NoPriceUpdates); + } + + let mut results = Vec::with_capacity(price_updates.len()); + for (update, data) in price_updates.into_iter().zip(update_data) { + let price_data = PriceData { + price: update + .price + .price + .parse() + .map_err(|e| HermesError::ParseError(format!("Invalid price: {}", e)))?, + conf: update + .price + .conf + .parse() + .map_err(|e| HermesError::ParseError(format!("Invalid conf: {}", e)))?, + expo: update.price.expo, + publish_time: update.price.publish_time, + }; + results.push((price_data, data)); + } + + Ok(results) + } +} diff --git a/apps/argus/src/keeper.rs b/apps/argus/src/keeper.rs new file mode 100644 index 0000000000..43f808e6e0 --- /dev/null +++ b/apps/argus/src/keeper.rs @@ -0,0 +1,293 @@ +use { + crate::{ + contract::PulseContract, + hermes::HermesClient, + types::{PriceData, PriceUpdateRequest, UpdateBatch}, + error::{ArgusError, ContractError}, + }, + alloy::{ + primitives::{Address, Bytes, U256}, + providers::Provider, + signers::Signer, + }, + anyhow::Result, + std::{collections::HashMap, sync::Arc}, + tokio::{sync::mpsc, time}, +}; + +#[derive(Clone)] +pub struct KeeperMetrics { + pub transactions_submitted: Counter, + pub transaction_failures: Counter, + pub gas_used: Histogram, + pub batch_size: Histogram, +} + +pub struct Keeper { + provider: Arc, + signer: Arc, + request_rx: mpsc::Receiver, + metrics: Arc, + min_batch_size: usize, + max_batch_size: usize, + batch_timeout: Duration, + hermes_client: HermesClient, +} + +impl Keeper { + pub async fn new( + provider: Arc, + signer: Arc, + request_rx: mpsc::Receiver, + metrics: Arc, + min_batch_size: usize, + max_batch_size: usize, + batch_timeout: Duration, + ) -> Result { + Ok(Self { + provider, + signer, + request_rx, + metrics, + min_batch_size, + max_batch_size, + batch_timeout, + hermes_client: HermesClient::new(), + }) + } + + pub async fn run(&mut self) -> Result<()> { + let mut pending_requests = Vec::new(); + let mut batch_timer = time::interval(self.batch_timeout); + + loop { + tokio::select! { + Some(request) = self.request_rx.recv() => { + pending_requests.push(request); + + if pending_requests.len() >= self.max_batch_size { + self.process_batch(&mut pending_requests).await?; + } + } + _ = batch_timer.tick() => { + if pending_requests.len() >= self.min_batch_size { + self.process_batch(&mut pending_requests).await?; + } + } + } + } + } + + async fn process_batch(&self, requests: &mut Vec) -> Result<(), ArgusError> { + if requests.is_empty() { + return Ok(()); + } + + let batch = self.prepare_batch(requests).await?; + self.metrics.batch_size.observe(batch.requests.len() as f64); + + match self.submit_batch(batch).await { + Ok(_) => { + self.metrics.transactions_submitted.inc(); + requests.clear(); + Ok(()) + } + Err(e) => { + self.metrics.transaction_failures.inc(); + tracing::error!("Failed to submit batch: {}", e); + Err(e) + } + } + } + + async fn submit_batch(&self, batch: UpdateBatch) -> Result<(), ArgusError> { + let tx = self.build_batch_tx(&batch) + .map_err(|e| ContractError::TransactionFailed(e.to_string()))?; + + let signed_tx = self.signer.sign_transaction(tx) + .await + .map_err(|e| ContractError::TransactionFailed(format!("Failed to sign: {}", e)))?; + + let pending_tx = self.provider.send_raw_transaction(signed_tx.into()) + .await + .map_err(|e| ContractError::TransactionFailed(format!("Failed to send: {}", e)))?; + + let receipt = pending_tx.await + .map_err(|e| ContractError::TransactionFailed(format!("Failed to get receipt: {}", e)))?; + + if let Some(gas_used) = receipt.gas_used { + self.metrics.gas_used.observe(gas_used.as_f64()); + } + + // Check if transaction was successful + if !receipt.status.unwrap_or_default().is_success() { + return Err(ContractError::TransactionFailed("Transaction reverted".into()).into()); + } + + Ok(()) + } + + async fn prepare_batch(&self, requests: &[PriceUpdateRequest]) -> Result { + // Group requests by price ID to minimize Hermes API calls + let mut price_id_map: HashMap<[u8; 32], Vec> = HashMap::new(); + for (i, req) in requests.iter().enumerate() { + for price_id in &req.price_ids { + price_id_map.entry(*price_id).or_default().push(i); + } + } + + // Get all unique price IDs + let price_ids: Vec<[u8; 32]> = price_id_map.keys().copied().collect(); + + // Fetch price data from Hermes in a single batch request + let price_updates = self.hermes_client.get_price_updates(&price_ids).await?; + + let mut price_data = Vec::new(); + let mut update_data = Vec::new(); + + for (data, vaa) in price_updates { + price_data.push(data); + update_data.push(vaa); + } + + Ok(UpdateBatch { + requests: requests.to_vec(), + price_data, + update_data: update_data.into_iter().map(Bytes::from).collect(), + }) + } + + fn build_batch_tx(&self, batch: &UpdateBatch) -> Result { + let contract = PulseContract::new(self.contract_addr, self.provider.clone()); + + let tx = contract.execute_callback( + batch.requests[0].provider, + batch.requests[0].sequence_number, + batch.requests[0].price_ids.clone(), + batch.update_data.clone(), + batch.requests[0].callback_gas_limit, + ); + + Ok(tx) + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + crate::types::PriceUpdateRequest, + tokio::sync::mpsc, + }; + + fn setup_test_metrics() -> Arc { + Arc::new(KeeperMetrics { + transactions_submitted: Counter::default(), + transaction_failures: Counter::default(), + gas_used: Histogram::new([1.0, 5.0, 10.0, 50.0, 100.0, 500.0, 1000.0].into_iter()), + batch_size: Histogram::new([1.0, 2.0, 5.0, 10.0, 20.0, 50.0].into_iter()), + }) + } + + #[tokio::test] + async fn test_process_empty_batch() { + let (tx, rx) = mpsc::channel(100); + let provider = Arc::new(Provider::mock()); + let signer = Arc::new(Signer::new_random()); + let metrics = setup_test_metrics(); + + let keeper = Keeper::new( + provider, + signer, + rx, + metrics.clone(), + 1, + 10, + Duration::from_secs(5), + ).await.unwrap(); + + let mut requests = Vec::new(); + assert!(keeper.process_batch(&mut requests).await.is_ok()); + assert!(requests.is_empty()); + } + + #[tokio::test] + async fn test_batch_size_metrics() { + let (tx, rx) = mpsc::channel(100); + let provider = Arc::new(Provider::mock()); + let signer = Arc::new(Signer::new_random()); + let metrics = setup_test_metrics(); + + let keeper = Keeper::new( + provider, + signer, + rx, + metrics.clone(), + 1, + 10, + Duration::from_secs(5), + ).await.unwrap(); + + let mut requests = vec![ + PriceUpdateRequest { + provider: Address::zero(), + sequence_number: 1.into(), + publish_time: 1234.into(), + price_ids: vec![[0u8; 32]], + callback_gas_limit: 100000.into(), + requester: Address::zero(), + }, + PriceUpdateRequest { + provider: Address::zero(), + sequence_number: 2.into(), + publish_time: 1234.into(), + price_ids: vec![[0u8; 32]], + callback_gas_limit: 100000.into(), + requester: Address::zero(), + }, + ]; + + // Process batch should succeed and update metrics + keeper.process_batch(&mut requests).await.unwrap(); + + // Check that batch size metric was updated + let batch_size = metrics.batch_size.get_or_create(&()).get_count(); + assert_eq!(batch_size, 2); + } + + #[tokio::test] + async fn test_transaction_failure_metrics() { + let (tx, rx) = mpsc::channel(100); + let provider = Arc::new(Provider::mock().with_error()); // Mock provider that returns errors + let signer = Arc::new(Signer::new_random()); + let metrics = setup_test_metrics(); + + let keeper = Keeper::new( + provider, + signer, + rx, + metrics.clone(), + 1, + 10, + Duration::from_secs(5), + ).await.unwrap(); + + let mut requests = vec![ + PriceUpdateRequest { + provider: Address::zero(), + sequence_number: 1.into(), + publish_time: 1234.into(), + price_ids: vec![[0u8; 32]], + callback_gas_limit: 100000.into(), + requester: Address::zero(), + }, + ]; + + // Process batch should fail and update failure metrics + assert!(keeper.process_batch(&mut requests).await.is_err()); + + // Check that failure metric was updated + let failures = metrics.transaction_failures.get(); + assert_eq!(failures, 1); + } +} diff --git a/apps/argus/src/main.rs b/apps/argus/src/main.rs new file mode 100644 index 0000000000..f6167d4553 --- /dev/null +++ b/apps/argus/src/main.rs @@ -0,0 +1,94 @@ +use { + crate::{ + keeper::{Keeper, KeeperMetrics}, + storage::{StorageMetrics, StoragePoller}, + types::PriceUpdateRequest, + }, + anyhow::Result, + clap::Parser, + std::{sync::Arc, time::Duration}, + tokio::sync::{mpsc, RwLock}, +}; + +mod keeper; +mod storage; +mod types; + +#[derive(Parser)] +struct Opts { + #[clap(long, default_value = "config.yaml")] + config: String, +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt::init(); + + // Parse command line arguments + let opts = Opts::parse(); + + // Initialize metrics registry + let metrics_registry = Arc::new(RwLock::new(prometheus_client::registry::Registry::default())); + + // Load config + let config = config::load_config(&opts.config)?; + + // Set up channel between storage poller and keeper + let (request_tx, request_rx) = mpsc::channel(1000); + + // Initialize metrics + let storage_metrics = Arc::new(StorageMetrics { + requests_found: Counter::default(), + polling_errors: Counter::default(), + }); + + let keeper_metrics = Arc::new(KeeperMetrics { + transactions_submitted: Counter::default(), + transaction_failures: Counter::default(), + gas_used: Histogram::new([1.0, 5.0, 10.0, 50.0, 100.0, 500.0, 1000.0].into_iter()), + batch_size: Histogram::new([1.0, 2.0, 5.0, 10.0, 20.0, 50.0].into_iter()), + }); + + // Register metrics + { + let mut registry = metrics_registry.write().await; + // Register all metrics... + } + + // Initialize components + let provider = Arc::new(config.create_provider()?); + + let storage_poller = StoragePoller::new( + provider.clone(), + config.contract_address, + Duration::from_secs(config.poll_interval), + request_tx, + storage_metrics, + ).await?; + + let mut keeper = Keeper::new( + provider, + config.create_hot_wallet()?, + config.create_cold_wallet()?, + request_rx, + keeper_metrics, + config.min_batch_size, + config.max_batch_size, + Duration::from_secs(config.batch_timeout), + ).await?; + + // Start components + let storage_handle = tokio::spawn(async move { + storage_poller.start_polling().await + }); + + let keeper_handle = tokio::spawn(async move { + keeper.run().await + }); + + // Wait for components to finish + tokio::try_join!(storage_handle, keeper_handle)?; + + Ok(()) +} diff --git a/apps/argus/src/storage.rs b/apps/argus/src/storage.rs new file mode 100644 index 0000000000..33e79115cd --- /dev/null +++ b/apps/argus/src/storage.rs @@ -0,0 +1,164 @@ +use { + crate::types::PriceUpdateRequest, + alloy::{ + primitives::{Address, U256}, + providers::Provider, + }, + anyhow::Result, + prometheus_client::{ + metrics::{counter::Counter, family::Family}, + registry::Registry, + }, + sha3::{Digest, Keccak256}, + std::{sync::Arc, time::Duration}, + tokio::{sync::mpsc, time}, +}; + +const NUM_REQUESTS: u8 = 32; +const NUM_REQUESTS_MASK: u8 = 0x1f; + +#[derive(Clone, Debug)] +pub struct StorageMetrics { + pub requests_found: Counter, + pub polling_errors: Counter, +} + +pub struct StoragePoller { + provider: Arc, + contract_addr: Address, + poll_interval: Duration, + request_tx: mpsc::Sender, + metrics: Arc, +} + +impl StoragePoller { + pub async fn new( + provider: Arc, + contract_addr: Address, + poll_interval: Duration, + request_tx: mpsc::Sender, + metrics: Arc, + ) -> Result { + Ok(Self { + provider, + contract_addr, + poll_interval, + request_tx, + metrics, + }) + } + + pub async fn start_polling(&self) -> Result<()> { + loop { + match self.poll_requests().await { + Ok(requests) => { + for request in requests { + if let Err(e) = self.request_tx.send(request).await { + tracing::error!("Failed to send request to keeper: {}", e); + self.metrics.polling_errors.inc(); + } else { + self.metrics.requests_found.inc(); + } + } + } + Err(e) => { + tracing::error!("Error polling requests: {}", e); + self.metrics.polling_errors.inc(); + } + } + + time::sleep(self.poll_interval).await; + } + } + + async fn poll_requests(&self) -> Result> { + let mut requests = Vec::new(); + + // The Pulse contract has a fixed array of 32 requests and an overflow mapping + // First read the fixed array (slot 2 in the contract) + for i in 0..NUM_REQUESTS { + let slot = self.calculate_request_slot(i); + let request = self.read_request_at_slot(slot).await?; + + // sequence_number == 0 means empty/inactive request + if request.sequence_number.as_u64() != 0 { + requests.push(request); + } + } + + // TODO: Read overflow mapping if needed + // The overflow mapping is used when there's a hash collision in the fixed array + // We'll need to read slot keccak256(key, OVERFLOW_SLOT) where key is keccak256(provider, sequence) + + Ok(requests) + } + + fn calculate_request_slot(&self, index: u8) -> U256 { + // In the Pulse contract, the requests array is at slot 2 + // For arrays, Solidity stores data at: keccak256(slot) + index + const REQUESTS_SLOT: u8 = 2; + + // Calculate base slot for requests array + let base_slot = U256::from(REQUESTS_SLOT); + + // Calculate actual slot: keccak256(slot) + index + let array_slot = keccak256(&base_slot.to_be_bytes::<32>()); + U256::from_be_bytes(array_slot) + U256::from(index) + } + + async fn read_request_at_slot(&self, slot: U256) -> Result { + // Each Request struct takes multiple slots: + // slot + 0: provider (address) and sequence_number (uint64) packed together + // slot + 1: publish_time (uint256) + // slot + 2: priceIds array length + // slot + 3: callback_gas_limit (uint256) + // slot + 4: requester (address) + // priceIds array is stored starting at keccak256(slot + 2) + + let slot_0 = self.provider.get_storage_at(self.contract_addr, slot).await?; + let slot_1 = self.provider.get_storage_at(self.contract_addr, slot + 1).await?; + let slot_2 = self.provider.get_storage_at(self.contract_addr, slot + 2).await?; + let slot_3 = self.provider.get_storage_at(self.contract_addr, slot + 3).await?; + let slot_4 = self.provider.get_storage_at(self.contract_addr, slot + 4).await?; + + // Parse provider (20 bytes) and sequence_number (8 bytes) from slot_0 + let provider = Address::from_slice(&slot_0[0..20]); + let sequence_number = U64::from_be_bytes(slot_0[20..28].try_into()?); + + // Parse publish_time + let publish_time = U256::from_be_bytes(slot_1); + + // Parse price IDs array + let price_ids_length = U256::from_be_bytes(slot_2).as_usize(); + let mut price_ids = Vec::with_capacity(price_ids_length); + + if price_ids_length > 0 { + let price_ids_slot = keccak256(&(slot + 2).to_be_bytes::<32>()); + for i in 0..price_ids_length { + let price_id_slot = U256::from_be_bytes(price_ids_slot) + U256::from(i); + let price_id_data = self.provider.get_storage_at(self.contract_addr, price_id_slot).await?; + price_ids.push(price_id_data.try_into()?); + } + } + + // Parse callback gas limit and requester + let callback_gas_limit = U256::from_be_bytes(slot_3); + let requester = Address::from_slice(&slot_4[0..20]); + + Ok(PriceUpdateRequest { + provider, + sequence_number, + publish_time, + price_ids, + callback_gas_limit, + requester, + }) + } +} + +// Helper function to calculate keccak256 hash +fn keccak256(data: &[u8]) -> [u8; 32] { + let mut hasher = Keccak256::new(); + hasher.update(data); + hasher.finalize().into() +} diff --git a/apps/argus/src/types.rs b/apps/argus/src/types.rs new file mode 100644 index 0000000000..2c923d1da2 --- /dev/null +++ b/apps/argus/src/types.rs @@ -0,0 +1,32 @@ +use { + alloy::{ + primitives::{Address, U256, U64}, + vec::Vec, + }, + serde::{Deserialize, Serialize}, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PriceUpdateRequest { + pub provider: Address, + pub sequence_number: U64, + pub publish_time: U256, + pub price_ids: Vec<[u8; 32]>, + pub callback_gas_limit: U256, + pub requester: Address, +} + +#[derive(Debug, Clone)] +pub struct PriceData { + pub price: i64, + pub conf: u64, + pub expo: i32, + pub publish_time: u64, +} + +#[derive(Debug, Clone)] +pub struct UpdateBatch { + pub requests: Vec, + pub price_data: Vec, + pub update_data: Vec>, +}