diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol index 22891e98a3..6e5d44eaf9 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -2,6 +2,7 @@ pragma solidity ^0.8.0; +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; import "./PulseEvents.sol"; import "./PulseState.sol"; @@ -9,8 +10,7 @@ interface IPulseConsumer { function pulseCallback( uint64 sequenceNumber, address updater, - uint256 publishTime, - bytes32[] calldata priceIds + PythStructs.PriceFeed[] memory priceFeeds ) external; } diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol index 6ede25997b..eeb1cccf2c 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -63,18 +63,15 @@ abstract contract Pulse is IPulse, PulseState { bytes32[] calldata priceIds ) external payable override { Request storage req = findActiveRequest(sequenceNumber); + + // Verify priceIds match bytes32 providedPriceIdsHash = keccak256(abi.encode(priceIds)); bytes32 storedPriceIdsHash = req.priceIdsHash; - if (providedPriceIdsHash != storedPriceIdsHash) { revert InvalidPriceIds(providedPriceIdsHash, storedPriceIdsHash); } - // Check if there's enough gas left for the callback - if (gasleft() < req.callbackGasLimit) { - revert InsufficientGas(); - } - + // Parse price feeds first to measure gas usage PythStructs.PriceFeed[] memory priceFeeds = IPyth(_state.pyth) .parsePriceFeedUpdates( updateData, @@ -83,21 +80,23 @@ abstract contract Pulse is IPulse, PulseState { SafeCast.toUint64(req.publishTime) ); - uint256 publishTime = priceFeeds[0].price.publishTime; + // Check if enough gas remains for the callback + if (gasleft() < req.callbackGasLimit) { + revert InsufficientGas(); + } try IPulseConsumer(req.requester).pulseCallback{ gas: req.callbackGasLimit - }(sequenceNumber, msg.sender, publishTime, priceIds) + }(sequenceNumber, msg.sender, priceFeeds) { // Callback succeeded - emitPriceUpdate(sequenceNumber, publishTime, priceIds, priceFeeds); + emitPriceUpdate(sequenceNumber, priceIds, priceFeeds); } catch Error(string memory reason) { // Explicit revert/require emit PriceUpdateCallbackFailed( sequenceNumber, msg.sender, - publishTime, priceIds, req.requester, reason @@ -107,7 +106,6 @@ abstract contract Pulse is IPulse, PulseState { emit PriceUpdateCallbackFailed( sequenceNumber, msg.sender, - publishTime, priceIds, req.requester, "low-level error (possibly out of gas)" @@ -119,7 +117,6 @@ abstract contract Pulse is IPulse, PulseState { function emitPriceUpdate( uint64 sequenceNumber, - uint256 publishTime, bytes32[] memory priceIds, PythStructs.PriceFeed[] memory priceFeeds ) internal { @@ -138,7 +135,6 @@ abstract contract Pulse is IPulse, PulseState { emit PriceUpdateExecuted( sequenceNumber, msg.sender, - publishTime, priceIds, prices, conf, diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol index 1c96797b39..4b7abfbbc3 100644 --- a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -9,7 +9,6 @@ interface PulseEvents { event PriceUpdateExecuted( uint64 indexed sequenceNumber, address indexed updater, - uint256 publishTime, bytes32[] priceIds, int64[] prices, uint64[] conf, @@ -22,7 +21,6 @@ interface PulseEvents { event PriceUpdateCallbackFailed( uint64 indexed sequenceNumber, address indexed updater, - uint256 publishTime, bytes32[] priceIds, address requester, string reason diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol index 7e7d276e43..81edc4115a 100644 --- a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; import "forge-std/Test.sol"; +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; import "../contracts/pulse/PulseUpgradeable.sol"; import "../contracts/pulse/IPulse.sol"; @@ -13,19 +14,26 @@ import "../contracts/pulse/PulseErrors.sol"; contract MockPulseConsumer is IPulseConsumer { uint64 public lastSequenceNumber; address public lastUpdater; - uint256 public lastPublishTime; - bytes32[] public lastPriceIds; + PythStructs.PriceFeed[] private _lastPriceFeeds; function pulseCallback( uint64 sequenceNumber, address updater, - uint256 publishTime, - bytes32[] calldata priceIds + PythStructs.PriceFeed[] memory priceFeeds ) external override { lastSequenceNumber = sequenceNumber; lastUpdater = updater; - lastPublishTime = publishTime; - lastPriceIds = priceIds; + for (uint i = 0; i < priceFeeds.length; i++) { + _lastPriceFeeds.push(priceFeeds[i]); + } + } + + function lastPriceFeeds() + external + view + returns (PythStructs.PriceFeed[] memory) + { + return _lastPriceFeeds; } } @@ -33,8 +41,7 @@ contract FailingPulseConsumer is IPulseConsumer { function pulseCallback( uint64, address, - uint256, - bytes32[] calldata + PythStructs.PriceFeed[] memory ) external pure override { revert("callback failed"); } @@ -46,8 +53,7 @@ contract CustomErrorPulseConsumer is IPulseConsumer { function pulseCallback( uint64, address, - uint256, - bytes32[] calldata + PythStructs.PriceFeed[] memory ) external pure override { revert CustomError("callback failed"); } @@ -278,7 +284,6 @@ contract PulseTest is Test, PulseEvents { emit PriceUpdateExecuted( sequenceNumber, updater, - publishTime, priceIds, expectedPrices, expectedConf, @@ -294,7 +299,22 @@ contract PulseTest is Test, PulseEvents { // Verify callback was executed assertEq(consumer.lastSequenceNumber(), sequenceNumber); - assertEq(consumer.lastPublishTime(), publishTime); + + // Compare price feeds array length + PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds(); + assertEq(lastFeeds.length, priceFeeds.length); + + // Compare each price feed + for (uint i = 0; i < priceFeeds.length; i++) { + assertEq(lastFeeds[i].id, priceFeeds[i].id); + assertEq(lastFeeds[i].price.price, priceFeeds[i].price.price); + assertEq(lastFeeds[i].price.conf, priceFeeds[i].price.conf); + assertEq(lastFeeds[i].price.expo, priceFeeds[i].price.expo); + assertEq( + lastFeeds[i].price.publishTime, + priceFeeds[i].price.publishTime + ); + } } function testExecuteCallbackFailure() public { @@ -316,7 +336,6 @@ contract PulseTest is Test, PulseEvents { emit PriceUpdateCallbackFailed( sequenceNumber, updater, - publishTime, priceIds, address(failingConsumer), "callback failed" @@ -345,7 +364,6 @@ contract PulseTest is Test, PulseEvents { emit PriceUpdateCallbackFailed( sequenceNumber, updater, - publishTime, priceIds, address(failingConsumer), "low-level error (possibly out of gas)" @@ -370,10 +388,14 @@ contract PulseTest is Test, PulseEvents { mockParsePriceFeedUpdates(priceFeeds); bytes[] memory updateData = createMockUpdateData(priceFeeds); - // Try executing with only 10K gas when 1M is required + // Try executing with only 100K gas when 1M is required vm.prank(updater); vm.expectRevert(InsufficientGas.selector); - pulse.executeCallback{gas: 10000}(sequenceNumber, updateData, priceIds); // Will fail because gasleft() < callbackGasLimit + pulse.executeCallback{gas: 100000}( + sequenceNumber, + updateData, + priceIds + ); // Will fail because gasleft() < callbackGasLimit } function testExecuteCallbackWithFutureTimestamp() public { @@ -399,8 +421,17 @@ contract PulseTest is Test, PulseEvents { // Should succeed because we're simulating receiving future-dated price updates pulse.executeCallback(sequenceNumber, updateData, priceIds); - // Verify the callback was executed with future timestamp - assertEq(consumer.lastPublishTime(), futureTime); + // Compare price feeds array length + PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds(); + assertEq(lastFeeds.length, priceFeeds.length); + + // Compare each price feed publish time + for (uint i = 0; i < priceFeeds.length; i++) { + assertEq( + lastFeeds[i].price.publishTime, + priceFeeds[i].price.publishTime + ); + } } function testDoubleExecuteCallback() public {