diff --git a/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol new file mode 100644 index 0000000000..6e5d44eaf9 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/IPulse.sol @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; +import "./PulseEvents.sol"; +import "./PulseState.sol"; + +interface IPulseConsumer { + function pulseCallback( + uint64 sequenceNumber, + address updater, + PythStructs.PriceFeed[] memory priceFeeds + ) external; +} + +interface IPulse is PulseEvents { + // Core functions + function requestPriceUpdatesWithCallback( + uint256 publishTime, + bytes32[] calldata priceIds, + uint256 callbackGasLimit + ) external payable returns (uint64 sequenceNumber); + + function executeCallback( + uint64 sequenceNumber, + bytes[] calldata updateData, + bytes32[] calldata priceIds + ) external payable; + + // Getters + function getFee( + uint256 callbackGasLimit + ) external view returns (uint128 feeAmount); + + function getPythFeeInWei() external view returns (uint128 pythFeeInWei); + + function getAccruedFees() external view returns (uint128 accruedFeesInWei); + + function getRequest( + uint64 sequenceNumber + ) external view returns (PulseState.Request memory req); + + // Add these functions to the IPulse interface + function setFeeManager(address manager) external; + + function withdrawAsFeeManager(uint128 amount) external; +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol new file mode 100644 index 0000000000..eeb1cccf2c --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/Pulse.sol @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "@openzeppelin/contracts/utils/math/SafeCast.sol"; +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; +import "./IPulse.sol"; +import "./PulseState.sol"; +import "./PulseErrors.sol"; + +abstract contract Pulse is IPulse, PulseState { + function _initialize( + address admin, + uint128 pythFeeInWei, + address pythAddress, + bool prefillRequestStorage + ) internal { + require(admin != address(0), "admin is zero address"); + require(pythAddress != address(0), "pyth is zero address"); + + _state.admin = admin; + _state.accruedFeesInWei = 0; + _state.pythFeeInWei = pythFeeInWei; + _state.pyth = pythAddress; + _state.currentSequenceNumber = 1; + + if (prefillRequestStorage) { + for (uint8 i = 0; i < NUM_REQUESTS; i++) { + Request storage req = _state.requests[i]; + req.sequenceNumber = 0; + req.publishTime = 1; + req.callbackGasLimit = 1; + req.requester = address(1); + } + } + } + + function requestPriceUpdatesWithCallback( + uint256 publishTime, + bytes32[] calldata priceIds, + uint256 callbackGasLimit + ) external payable override returns (uint64 requestSequenceNumber) { + requestSequenceNumber = _state.currentSequenceNumber++; + + uint128 requiredFee = getFee(callbackGasLimit); + if (msg.value < requiredFee) revert InsufficientFee(); + + Request storage req = allocRequest(requestSequenceNumber); + req.sequenceNumber = requestSequenceNumber; + req.publishTime = publishTime; + req.priceIdsHash = keccak256(abi.encode(priceIds)); + req.callbackGasLimit = callbackGasLimit; + req.requester = msg.sender; + + _state.accruedFeesInWei += SafeCast.toUint128(msg.value); + + emit PriceUpdateRequested(req); + } + + function executeCallback( + uint64 sequenceNumber, + bytes[] calldata updateData, + bytes32[] calldata priceIds + ) external payable override { + Request storage req = findActiveRequest(sequenceNumber); + + // Verify priceIds match + bytes32 providedPriceIdsHash = keccak256(abi.encode(priceIds)); + bytes32 storedPriceIdsHash = req.priceIdsHash; + if (providedPriceIdsHash != storedPriceIdsHash) { + revert InvalidPriceIds(providedPriceIdsHash, storedPriceIdsHash); + } + + // Parse price feeds first to measure gas usage + PythStructs.PriceFeed[] memory priceFeeds = IPyth(_state.pyth) + .parsePriceFeedUpdates( + updateData, + priceIds, + SafeCast.toUint64(req.publishTime), + SafeCast.toUint64(req.publishTime) + ); + + // Check if enough gas remains for the callback + if (gasleft() < req.callbackGasLimit) { + revert InsufficientGas(); + } + + try + IPulseConsumer(req.requester).pulseCallback{ + gas: req.callbackGasLimit + }(sequenceNumber, msg.sender, priceFeeds) + { + // Callback succeeded + emitPriceUpdate(sequenceNumber, priceIds, priceFeeds); + } catch Error(string memory reason) { + // Explicit revert/require + emit PriceUpdateCallbackFailed( + sequenceNumber, + msg.sender, + priceIds, + req.requester, + reason + ); + } catch { + // Out of gas or other low-level errors + emit PriceUpdateCallbackFailed( + sequenceNumber, + msg.sender, + priceIds, + req.requester, + "low-level error (possibly out of gas)" + ); + } + + clearRequest(sequenceNumber); + } + + function emitPriceUpdate( + uint64 sequenceNumber, + bytes32[] memory priceIds, + PythStructs.PriceFeed[] memory priceFeeds + ) internal { + int64[] memory prices = new int64[](priceFeeds.length); + uint64[] memory conf = new uint64[](priceFeeds.length); + int32[] memory expos = new int32[](priceFeeds.length); + uint256[] memory publishTimes = new uint256[](priceFeeds.length); + + for (uint i = 0; i < priceFeeds.length; i++) { + prices[i] = priceFeeds[i].price.price; + conf[i] = priceFeeds[i].price.conf; + expos[i] = priceFeeds[i].price.expo; + publishTimes[i] = priceFeeds[i].price.publishTime; + } + + emit PriceUpdateExecuted( + sequenceNumber, + msg.sender, + priceIds, + prices, + conf, + expos, + publishTimes + ); + } + + function getFee( + uint256 callbackGasLimit + ) public view override returns (uint128 feeAmount) { + uint128 baseFee = _state.pythFeeInWei; + uint256 gasFee = callbackGasLimit * tx.gasprice; + feeAmount = baseFee + SafeCast.toUint128(gasFee); + } + + function getPythFeeInWei() + public + view + override + returns (uint128 pythFeeInWei) + { + pythFeeInWei = _state.pythFeeInWei; + } + + function getAccruedFees() + public + view + override + returns (uint128 accruedFeesInWei) + { + accruedFeesInWei = _state.accruedFeesInWei; + } + + function getRequest( + uint64 sequenceNumber + ) public view override returns (Request memory req) { + req = findRequest(sequenceNumber); + } + + function requestKey( + uint64 sequenceNumber + ) internal pure returns (bytes32 hash, uint8 shortHash) { + hash = keccak256(abi.encodePacked(sequenceNumber)); + shortHash = uint8(hash[0] & NUM_REQUESTS_MASK); + } + + function withdrawFees(uint128 amount) external { + require(msg.sender == _state.admin, "Only admin can withdraw fees"); + require(_state.accruedFeesInWei >= amount, "Insufficient balance"); + + _state.accruedFeesInWei -= amount; + + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "Failed to send fees"); + + emit FeesWithdrawn(msg.sender, amount); + } + + function findActiveRequest( + uint64 sequenceNumber + ) internal view returns (Request storage req) { + req = findRequest(sequenceNumber); + + if (!isActive(req) || req.sequenceNumber != sequenceNumber) + revert NoSuchRequest(); + } + + function findRequest( + uint64 sequenceNumber + ) internal view returns (Request storage req) { + (bytes32 key, uint8 shortKey) = requestKey(sequenceNumber); + + req = _state.requests[shortKey]; + if (req.sequenceNumber == sequenceNumber) { + return req; + } else { + req = _state.requestsOverflow[key]; + } + } + + function clearRequest(uint64 sequenceNumber) internal { + (bytes32 key, uint8 shortKey) = requestKey(sequenceNumber); + + Request storage req = _state.requests[shortKey]; + if (req.sequenceNumber == sequenceNumber) { + req.sequenceNumber = 0; + } else { + delete _state.requestsOverflow[key]; + } + } + + function allocRequest( + uint64 sequenceNumber + ) internal returns (Request storage req) { + (, uint8 shortKey) = requestKey(sequenceNumber); + + req = _state.requests[shortKey]; + if (isActive(req)) { + (bytes32 reqKey, ) = requestKey(req.sequenceNumber); + _state.requestsOverflow[reqKey] = req; + } + } + + function isActive(Request storage req) internal view returns (bool) { + return req.sequenceNumber != 0; + } + + function setFeeManager(address manager) external override { + require(msg.sender == _state.admin, "Only admin can set fee manager"); + address oldFeeManager = _state.feeManager; + _state.feeManager = manager; + emit FeeManagerUpdated(_state.admin, oldFeeManager, manager); + } + + function withdrawAsFeeManager(uint128 amount) external override { + require(msg.sender == _state.feeManager, "Only fee manager"); + require(_state.accruedFeesInWei >= amount, "Insufficient balance"); + + _state.accruedFeesInWei -= amount; + + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "Failed to send fees"); + + emit FeesWithdrawn(msg.sender, amount); + } +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol new file mode 100644 index 0000000000..c2fe41ccb6 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseErrors.sol @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +error NoSuchProvider(); +error NoSuchRequest(); +error InsufficientFee(); +error Unauthorized(); +error InvalidCallbackGas(); +error CallbackFailed(); +error InvalidPriceIds(bytes32 providedPriceIdsHash, bytes32 storedPriceIdsHash); +error InvalidCallbackGasLimit(uint256 requested, uint256 stored); +error ExceedsMaxPrices(uint32 requested, uint32 maxAllowed); +error InsufficientGas(); diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol new file mode 100644 index 0000000000..4b7abfbbc3 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseEvents.sol @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.0; + +import "./PulseState.sol"; + +interface PulseEvents { + event PriceUpdateRequested(PulseState.Request request); + + event PriceUpdateExecuted( + uint64 indexed sequenceNumber, + address indexed updater, + bytes32[] priceIds, + int64[] prices, + uint64[] conf, + int32[] expos, + uint256[] publishTimes + ); + + event FeesWithdrawn(address indexed recipient, uint128 amount); + + event PriceUpdateCallbackFailed( + uint64 indexed sequenceNumber, + address indexed updater, + bytes32[] priceIds, + address requester, + string reason + ); + + event FeeManagerUpdated( + address indexed admin, + address oldFeeManager, + address newFeeManager + ); +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol new file mode 100644 index 0000000000..6b0b48fc61 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseState.sol @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +contract PulseState { + uint8 public constant NUM_REQUESTS = 32; + bytes1 public constant NUM_REQUESTS_MASK = 0x1f; + + struct Request { + uint64 sequenceNumber; + uint256 publishTime; + bytes32 priceIdsHash; + uint256 callbackGasLimit; + address requester; + } + + struct State { + address admin; + uint128 pythFeeInWei; + uint128 accruedFeesInWei; + address pyth; + uint64 currentSequenceNumber; + address feeManager; + Request[32] requests; + mapping(bytes32 => Request) requestsOverflow; + } + + State internal _state; +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol new file mode 100644 index 0000000000..48fc694e69 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/PulseUpgradeable.sol @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "@openzeppelin/contracts-upgradeable/proxy/utils/Initializable.sol"; +import "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol"; +import "@openzeppelin/contracts-upgradeable/access/Ownable2StepUpgradeable.sol"; +import "./Pulse.sol"; + +contract PulseUpgradeable is + Initializable, + Ownable2StepUpgradeable, + UUPSUpgradeable, + Pulse +{ + event ContractUpgraded( + address oldImplementation, + address newImplementation + ); + + function initialize( + address owner, + address admin, + uint128 pythFeeInWei, + address pythAddress, + bool prefillRequestStorage + ) public initializer { + require(owner != address(0), "owner is zero address"); + require(admin != address(0), "admin is zero address"); + + __Ownable_init(); + __UUPSUpgradeable_init(); + + Pulse._initialize( + admin, + pythFeeInWei, + pythAddress, + prefillRequestStorage + ); + + _transferOwnership(owner); + } + + /// @custom:oz-upgrades-unsafe-allow constructor + constructor() initializer {} + + function _authorizeUpgrade(address) internal override onlyOwner {} + + function upgradeTo(address newImplementation) external override onlyProxy { + address oldImplementation = _getImplementation(); + _authorizeUpgrade(newImplementation); + _upgradeToAndCallUUPS(newImplementation, new bytes(0), false); + + emit ContractUpgraded(oldImplementation, _getImplementation()); + } + + function upgradeToAndCall( + address newImplementation, + bytes memory data + ) external payable override onlyProxy { + address oldImplementation = _getImplementation(); + _authorizeUpgrade(newImplementation); + _upgradeToAndCallUUPS(newImplementation, data, true); + + emit ContractUpgraded(oldImplementation, _getImplementation()); + } + + function version() public pure returns (string memory) { + return "1.0.0"; + } +} diff --git a/target_chains/ethereum/contracts/forge-test/Pulse.t.sol b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol new file mode 100644 index 0000000000..81edc4115a --- /dev/null +++ b/target_chains/ethereum/contracts/forge-test/Pulse.t.sol @@ -0,0 +1,629 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "forge-std/Test.sol"; +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; +import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; +import "../contracts/pulse/PulseUpgradeable.sol"; +import "../contracts/pulse/IPulse.sol"; +import "../contracts/pulse/PulseState.sol"; +import "../contracts/pulse/PulseEvents.sol"; +import "../contracts/pulse/PulseErrors.sol"; + +contract MockPulseConsumer is IPulseConsumer { + uint64 public lastSequenceNumber; + address public lastUpdater; + PythStructs.PriceFeed[] private _lastPriceFeeds; + + function pulseCallback( + uint64 sequenceNumber, + address updater, + PythStructs.PriceFeed[] memory priceFeeds + ) external override { + lastSequenceNumber = sequenceNumber; + lastUpdater = updater; + for (uint i = 0; i < priceFeeds.length; i++) { + _lastPriceFeeds.push(priceFeeds[i]); + } + } + + function lastPriceFeeds() + external + view + returns (PythStructs.PriceFeed[] memory) + { + return _lastPriceFeeds; + } +} + +contract FailingPulseConsumer is IPulseConsumer { + function pulseCallback( + uint64, + address, + PythStructs.PriceFeed[] memory + ) external pure override { + revert("callback failed"); + } +} + +contract CustomErrorPulseConsumer is IPulseConsumer { + error CustomError(string message); + + function pulseCallback( + uint64, + address, + PythStructs.PriceFeed[] memory + ) external pure override { + revert CustomError("callback failed"); + } +} + +contract PulseTest is Test, PulseEvents { + ERC1967Proxy public proxy; + PulseUpgradeable public pulse; + MockPulseConsumer public consumer; + address public owner; + address public admin; + address public updater; + address public pyth; + + // Constants + uint128 constant PYTH_FEE = 1 wei; + uint128 constant CALLBACK_GAS_LIMIT = 1_000_000; + bytes32 constant BTC_PRICE_FEED_ID = + 0xe62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43; + bytes32 constant ETH_PRICE_FEED_ID = + 0xff61491a931112ddf1bd8147cd1b641375f79f5825126d665480874634fd0ace; + + // Price feed constants + int8 constant MOCK_PRICE_FEED_EXPO = -8; + int64 constant MOCK_BTC_PRICE = 5_000_000_000_000; // $50,000 + int64 constant MOCK_ETH_PRICE = 300_000_000_000; // $3,000 + uint64 constant MOCK_BTC_CONF = 10_000_000_000; // $100 + uint64 constant MOCK_ETH_CONF = 5_000_000_000; // $50 + + function setUp() public { + owner = address(1); + admin = address(2); + updater = address(3); + pyth = address(4); + + PulseUpgradeable _pulse = new PulseUpgradeable(); + proxy = new ERC1967Proxy(address(_pulse), ""); + pulse = PulseUpgradeable(address(proxy)); + + pulse.initialize(owner, admin, PYTH_FEE, pyth, false); + consumer = new MockPulseConsumer(); + } + + // Helper function to create price IDs array + function createPriceIds() internal pure returns (bytes32[] memory) { + bytes32[] memory priceIds = new bytes32[](2); + priceIds[0] = BTC_PRICE_FEED_ID; + priceIds[1] = ETH_PRICE_FEED_ID; + return priceIds; + } + + // Helper function to create mock price feeds + function createMockPriceFeeds( + uint256 publishTime + ) internal pure returns (PythStructs.PriceFeed[] memory) { + PythStructs.PriceFeed[] memory priceFeeds = new PythStructs.PriceFeed[]( + 2 + ); + + priceFeeds[0].id = BTC_PRICE_FEED_ID; + priceFeeds[0].price.price = MOCK_BTC_PRICE; + priceFeeds[0].price.conf = MOCK_BTC_CONF; + priceFeeds[0].price.expo = MOCK_PRICE_FEED_EXPO; + priceFeeds[0].price.publishTime = publishTime; + + priceFeeds[1].id = ETH_PRICE_FEED_ID; + priceFeeds[1].price.price = MOCK_ETH_PRICE; + priceFeeds[1].price.conf = MOCK_ETH_CONF; + priceFeeds[1].price.expo = MOCK_PRICE_FEED_EXPO; + priceFeeds[1].price.publishTime = publishTime; + + return priceFeeds; + } + + // Helper function to mock Pyth response + function mockParsePriceFeedUpdates( + PythStructs.PriceFeed[] memory priceFeeds + ) internal { + vm.mockCall( + address(pyth), + abi.encodeWithSelector(IPyth.parsePriceFeedUpdates.selector), + abi.encode(priceFeeds) + ); + } + + // Helper function to create mock update data + function createMockUpdateData( + PythStructs.PriceFeed[] memory priceFeeds + ) internal pure returns (bytes[] memory) { + bytes[] memory updateData = new bytes[](2); + updateData[0] = abi.encode(priceFeeds[0]); + updateData[1] = abi.encode(priceFeeds[1]); + return updateData; + } + + // Helper function to calculate total fee + function calculateTotalFee() internal view returns (uint128) { + return pulse.getFee(CALLBACK_GAS_LIMIT); + } + + // Helper function to setup consumer request + function setupConsumerRequest( + address consumerAddress + ) + internal + returns ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) + { + priceIds = createPriceIds(); + publishTime = block.timestamp; + vm.deal(consumerAddress, 1 gwei); + + uint128 totalFee = calculateTotalFee(); + + vm.prank(consumerAddress); + sequenceNumber = pulse.requestPriceUpdatesWithCallback{value: totalFee}( + publishTime, + priceIds, + CALLBACK_GAS_LIMIT + ); + + return (sequenceNumber, priceIds, publishTime); + } + + function testRequestPriceUpdate() public { + // Set a realistic gas price + vm.txGasPrice(30 gwei); + + bytes32[] memory priceIds = createPriceIds(); + uint256 publishTime = block.timestamp; + + // Fund the consumer contract with enough ETH for higher gas price + vm.deal(address(consumer), 1 ether); + uint128 totalFee = calculateTotalFee(); + + // Create the event data we expect to see + PulseState.Request memory expectedRequest = PulseState.Request({ + sequenceNumber: 1, + publishTime: publishTime, + priceIdsHash: keccak256(abi.encode(priceIds)), + callbackGasLimit: CALLBACK_GAS_LIMIT, + requester: address(consumer) + }); + + vm.expectEmit(); + emit PriceUpdateRequested(expectedRequest); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: totalFee}( + publishTime, + priceIds, + CALLBACK_GAS_LIMIT + ); + + // Additional assertions to verify event data was stored correctly + PulseState.Request memory lastRequest = pulse.getRequest(1); + assertEq(lastRequest.sequenceNumber, expectedRequest.sequenceNumber); + assertEq(lastRequest.publishTime, expectedRequest.publishTime); + assertEq(lastRequest.priceIdsHash, expectedRequest.priceIdsHash); + assertEq( + lastRequest.callbackGasLimit, + expectedRequest.callbackGasLimit + ); + assertEq( + lastRequest.requester, + expectedRequest.requester, + "Requester mismatch" + ); + } + + function testRequestWithInsufficientFee() public { + // Set a realistic gas price + vm.txGasPrice(30 gwei); + + bytes32[] memory priceIds = createPriceIds(); + vm.deal(address(consumer), 1 ether); + vm.prank(address(consumer)); + vm.expectRevert(InsufficientFee.selector); + pulse.requestPriceUpdatesWithCallback{value: PYTH_FEE}( // Intentionally low fee + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); + } + + function testExecuteCallback() public { + bytes32[] memory priceIds = createPriceIds(); + uint256 publishTime = block.timestamp; + + // Fund the consumer contract + vm.deal(address(consumer), 1 gwei); + uint128 totalFee = calculateTotalFee(); + + // Step 1: Make the request as consumer + vm.prank(address(consumer)); + uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: totalFee + }(publishTime, priceIds, CALLBACK_GAS_LIMIT); + + // Step 2: Create mock price feeds and setup Pyth response + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + + // Create arrays for expected event data + int64[] memory expectedPrices = new int64[](2); + expectedPrices[0] = MOCK_BTC_PRICE; + expectedPrices[1] = MOCK_ETH_PRICE; + + uint64[] memory expectedConf = new uint64[](2); + expectedConf[0] = MOCK_BTC_CONF; + expectedConf[1] = MOCK_ETH_CONF; + + int32[] memory expectedExpos = new int32[](2); + expectedExpos[0] = MOCK_PRICE_FEED_EXPO; + expectedExpos[1] = MOCK_PRICE_FEED_EXPO; + + uint256[] memory expectedPublishTimes = new uint256[](2); + expectedPublishTimes[0] = publishTime; + expectedPublishTimes[1] = publishTime; + + // Expect the PriceUpdateExecuted event with all price data + vm.expectEmit(); + emit PriceUpdateExecuted( + sequenceNumber, + updater, + priceIds, + expectedPrices, + expectedConf, + expectedExpos, + expectedPublishTimes + ); + + // Create mock update data and execute callback + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + vm.prank(updater); + pulse.executeCallback(sequenceNumber, updateData, priceIds); + + // Verify callback was executed + assertEq(consumer.lastSequenceNumber(), sequenceNumber); + + // Compare price feeds array length + PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds(); + assertEq(lastFeeds.length, priceFeeds.length); + + // Compare each price feed + for (uint i = 0; i < priceFeeds.length; i++) { + assertEq(lastFeeds[i].id, priceFeeds[i].id); + assertEq(lastFeeds[i].price.price, priceFeeds[i].price.price); + assertEq(lastFeeds[i].price.conf, priceFeeds[i].price.conf); + assertEq(lastFeeds[i].price.expo, priceFeeds[i].price.expo); + assertEq( + lastFeeds[i].price.publishTime, + priceFeeds[i].price.publishTime + ); + } + } + + function testExecuteCallbackFailure() public { + FailingPulseConsumer failingConsumer = new FailingPulseConsumer(); + + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(failingConsumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + vm.expectEmit(); + emit PriceUpdateCallbackFailed( + sequenceNumber, + updater, + priceIds, + address(failingConsumer), + "callback failed" + ); + + vm.prank(updater); + pulse.executeCallback(sequenceNumber, updateData, priceIds); + } + + function testExecuteCallbackCustomErrorFailure() public { + CustomErrorPulseConsumer failingConsumer = new CustomErrorPulseConsumer(); + + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(failingConsumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + vm.expectEmit(); + emit PriceUpdateCallbackFailed( + sequenceNumber, + updater, + priceIds, + address(failingConsumer), + "low-level error (possibly out of gas)" + ); + + vm.prank(updater); + pulse.executeCallback(sequenceNumber, updateData, priceIds); + } + + function testExecuteCallbackWithInsufficientGas() public { + // Setup request with 1M gas limit + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer)); + + // Setup mock data + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // Try executing with only 100K gas when 1M is required + vm.prank(updater); + vm.expectRevert(InsufficientGas.selector); + pulse.executeCallback{gas: 100000}( + sequenceNumber, + updateData, + priceIds + ); // Will fail because gasleft() < callbackGasLimit + } + + function testExecuteCallbackWithFutureTimestamp() public { + // Setup request with future timestamp + bytes32[] memory priceIds = createPriceIds(); + uint256 futureTime = block.timestamp + 1 days; + vm.deal(address(consumer), 1 gwei); + + uint128 totalFee = calculateTotalFee(); + vm.prank(address(consumer)); + uint64 sequenceNumber = pulse.requestPriceUpdatesWithCallback{ + value: totalFee + }(futureTime, priceIds, CALLBACK_GAS_LIMIT); + + // Try to execute callback before the requested timestamp + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + futureTime // Mock price feeds with future timestamp + ); + mockParsePriceFeedUpdates(priceFeeds); // This will make parsePriceFeedUpdates return future-dated prices + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + vm.prank(updater); + // Should succeed because we're simulating receiving future-dated price updates + pulse.executeCallback(sequenceNumber, updateData, priceIds); + + // Compare price feeds array length + PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds(); + assertEq(lastFeeds.length, priceFeeds.length); + + // Compare each price feed publish time + for (uint i = 0; i < priceFeeds.length; i++) { + assertEq( + lastFeeds[i].price.publishTime, + priceFeeds[i].price.publishTime + ); + } + } + + function testDoubleExecuteCallback() public { + ( + uint64 sequenceNumber, + bytes32[] memory priceIds, + uint256 publishTime + ) = setupConsumerRequest(address(consumer)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // First execution + vm.prank(updater); + pulse.executeCallback(sequenceNumber, updateData, priceIds); + + // Second execution should fail + vm.prank(updater); + vm.expectRevert(NoSuchRequest.selector); + pulse.executeCallback(sequenceNumber, updateData, priceIds); + } + + function testGetFee() public { + // Test with different gas limits to verify fee calculation + uint256[] memory gasLimits = new uint256[](3); + gasLimits[0] = 100_000; + gasLimits[1] = 500_000; + gasLimits[2] = 1_000_000; + + for (uint256 i = 0; i < gasLimits.length; i++) { + uint256 gasLimit = gasLimits[i]; + uint128 expectedFee = SafeCast.toUint128(tx.gasprice * gasLimit) + + PYTH_FEE; + uint128 actualFee = pulse.getFee(gasLimit); + assertEq( + actualFee, + expectedFee, + "Fee calculation incorrect for gas limit" + ); + } + + // Test with zero gas limit + uint128 expectedMinFee = PYTH_FEE; + uint128 actualMinFee = pulse.getFee(0); + assertEq( + actualMinFee, + expectedMinFee, + "Minimum fee calculation incorrect" + ); + } + + function testWithdrawFees() public { + // Setup: Request price update to accrue some fees + bytes32[] memory priceIds = createPriceIds(); + vm.deal(address(consumer), 1 gwei); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); + + // Get admin's balance before withdrawal + uint256 adminBalanceBefore = admin.balance; + uint128 accruedFees = pulse.getAccruedFees(); + + // Withdraw fees as admin + vm.prank(admin); + pulse.withdrawFees(accruedFees); + + // Verify balances + assertEq( + admin.balance, + adminBalanceBefore + accruedFees, + "Admin balance should increase by withdrawn amount" + ); + assertEq( + pulse.getAccruedFees(), + 0, + "Contract should have no fees after withdrawal" + ); + } + + function testWithdrawFeesUnauthorized() public { + vm.prank(address(0xdead)); + vm.expectRevert("Only admin can withdraw fees"); + pulse.withdrawFees(1 ether); + } + + function testWithdrawFeesInsufficientBalance() public { + vm.prank(admin); + vm.expectRevert("Insufficient balance"); + pulse.withdrawFees(1 ether); + } + + function testSetAndWithdrawAsFeeManager() public { + address feeManager = address(0x789); + + // Set fee manager as admin + vm.prank(admin); + pulse.setFeeManager(feeManager); + + // Setup: Request price update to accrue some fees + bytes32[] memory priceIds = createPriceIds(); + vm.deal(address(consumer), 1 gwei); + + vm.prank(address(consumer)); + pulse.requestPriceUpdatesWithCallback{value: calculateTotalFee()}( + block.timestamp, + priceIds, + CALLBACK_GAS_LIMIT + ); + + // Test withdrawal as fee manager + uint256 managerBalanceBefore = feeManager.balance; + uint128 accruedFees = pulse.getAccruedFees(); + + vm.prank(feeManager); + pulse.withdrawAsFeeManager(accruedFees); + + assertEq( + feeManager.balance, + managerBalanceBefore + accruedFees, + "Fee manager balance should increase by withdrawn amount" + ); + assertEq( + pulse.getAccruedFees(), + 0, + "Contract should have no fees after withdrawal" + ); + } + + function testSetFeeManagerUnauthorized() public { + address feeManager = address(0x789); + vm.prank(address(0xdead)); + vm.expectRevert("Only admin can set fee manager"); + pulse.setFeeManager(feeManager); + } + + function testWithdrawAsFeeManagerUnauthorized() public { + vm.prank(address(0xdead)); + vm.expectRevert("Only fee manager"); + pulse.withdrawAsFeeManager(1 ether); + } + + function testWithdrawAsFeeManagerInsufficientBalance() public { + // Set up fee manager first + address feeManager = address(0x789); + vm.prank(admin); + pulse.setFeeManager(feeManager); + + vm.prank(feeManager); + vm.expectRevert("Insufficient balance"); + pulse.withdrawAsFeeManager(1 ether); + } + + // Add new test for invalid priceIds + function testExecuteCallbackWithInvalidPriceIds() public { + bytes32[] memory priceIds = createPriceIds(); + uint256 publishTime = block.timestamp; + + // Setup request + (uint64 sequenceNumber, , ) = setupConsumerRequest(address(consumer)); + + // Create different priceIds + bytes32[] memory wrongPriceIds = new bytes32[](2); + wrongPriceIds[0] = bytes32(uint256(1)); // Different price IDs + wrongPriceIds[1] = bytes32(uint256(2)); + + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + // Calculate hashes for both arrays + bytes32 providedPriceIdsHash = keccak256(abi.encode(wrongPriceIds)); + bytes32 storedPriceIdsHash = keccak256(abi.encode(priceIds)); + + // Should revert when trying to execute with wrong priceIds + vm.prank(updater); + vm.expectRevert( + abi.encodeWithSelector( + InvalidPriceIds.selector, + providedPriceIdsHash, + storedPriceIdsHash + ) + ); + pulse.executeCallback(sequenceNumber, updateData, wrongPriceIds); + } +}