Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
cctdaniel committed Jan 2, 2025
1 parent d4caa71 commit 62dd571
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 35 deletions.
4 changes: 2 additions & 2 deletions target_chains/ethereum/contracts/contracts/pulse/IPulse.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

pragma solidity ^0.8.0;

import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
import "./PulseEvents.sol";
import "./PulseState.sol";

interface IPulseConsumer {
function pulseCallback(
uint64 sequenceNumber,
address updater,
uint256 publishTime,
bytes32[] calldata priceIds
PythStructs.PriceFeed[] memory priceFeeds
) external;
}

Expand Down
22 changes: 9 additions & 13 deletions target_chains/ethereum/contracts/contracts/pulse/Pulse.sol
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,15 @@ abstract contract Pulse is IPulse, PulseState {
bytes32[] calldata priceIds
) external payable override {
Request storage req = findActiveRequest(sequenceNumber);

// Verify priceIds match
bytes32 providedPriceIdsHash = keccak256(abi.encode(priceIds));
bytes32 storedPriceIdsHash = req.priceIdsHash;

if (providedPriceIdsHash != storedPriceIdsHash) {
revert InvalidPriceIds(providedPriceIdsHash, storedPriceIdsHash);
}

// Check if there's enough gas left for the callback
if (gasleft() < req.callbackGasLimit) {
revert InsufficientGas();
}

// Parse price feeds first to measure gas usage
PythStructs.PriceFeed[] memory priceFeeds = IPyth(_state.pyth)
.parsePriceFeedUpdates(
updateData,
Expand All @@ -83,21 +80,23 @@ abstract contract Pulse is IPulse, PulseState {
SafeCast.toUint64(req.publishTime)
);

uint256 publishTime = priceFeeds[0].price.publishTime;
// Check if enough gas remains for the callback
if (gasleft() < req.callbackGasLimit) {
revert InsufficientGas();
}

try
IPulseConsumer(req.requester).pulseCallback{
gas: req.callbackGasLimit
}(sequenceNumber, msg.sender, publishTime, priceIds)
}(sequenceNumber, msg.sender, priceFeeds)
{
// Callback succeeded
emitPriceUpdate(sequenceNumber, publishTime, priceIds, priceFeeds);
emitPriceUpdate(sequenceNumber, priceIds, priceFeeds);
} catch Error(string memory reason) {
// Explicit revert/require
emit PriceUpdateCallbackFailed(
sequenceNumber,
msg.sender,
publishTime,
priceIds,
req.requester,
reason
Expand All @@ -107,7 +106,6 @@ abstract contract Pulse is IPulse, PulseState {
emit PriceUpdateCallbackFailed(
sequenceNumber,
msg.sender,
publishTime,
priceIds,
req.requester,
"low-level error (possibly out of gas)"
Expand All @@ -119,7 +117,6 @@ abstract contract Pulse is IPulse, PulseState {

function emitPriceUpdate(
uint64 sequenceNumber,
uint256 publishTime,
bytes32[] memory priceIds,
PythStructs.PriceFeed[] memory priceFeeds
) internal {
Expand All @@ -138,7 +135,6 @@ abstract contract Pulse is IPulse, PulseState {
emit PriceUpdateExecuted(
sequenceNumber,
msg.sender,
publishTime,
priceIds,
prices,
conf,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ interface PulseEvents {
event PriceUpdateExecuted(
uint64 indexed sequenceNumber,
address indexed updater,
uint256 publishTime,
bytes32[] priceIds,
int64[] prices,
uint64[] conf,
Expand All @@ -22,7 +21,6 @@ interface PulseEvents {
event PriceUpdateCallbackFailed(
uint64 indexed sequenceNumber,
address indexed updater,
uint256 publishTime,
bytes32[] priceIds,
address requester,
string reason
Expand Down
67 changes: 49 additions & 18 deletions target_chains/ethereum/contracts/forge-test/Pulse.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
pragma solidity ^0.8.0;

import "forge-std/Test.sol";
import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import "../contracts/pulse/PulseUpgradeable.sol";
import "../contracts/pulse/IPulse.sol";
Expand All @@ -13,28 +14,34 @@ import "../contracts/pulse/PulseErrors.sol";
contract MockPulseConsumer is IPulseConsumer {
uint64 public lastSequenceNumber;
address public lastUpdater;
uint256 public lastPublishTime;
bytes32[] public lastPriceIds;
PythStructs.PriceFeed[] private _lastPriceFeeds;

function pulseCallback(
uint64 sequenceNumber,
address updater,
uint256 publishTime,
bytes32[] calldata priceIds
PythStructs.PriceFeed[] memory priceFeeds
) external override {
lastSequenceNumber = sequenceNumber;
lastUpdater = updater;
lastPublishTime = publishTime;
lastPriceIds = priceIds;
for (uint i = 0; i < priceFeeds.length; i++) {
_lastPriceFeeds.push(priceFeeds[i]);
}
}

function lastPriceFeeds()
external
view
returns (PythStructs.PriceFeed[] memory)
{
return _lastPriceFeeds;
}
}

contract FailingPulseConsumer is IPulseConsumer {
function pulseCallback(
uint64,
address,
uint256,
bytes32[] calldata
PythStructs.PriceFeed[] memory
) external pure override {
revert("callback failed");
}
Expand All @@ -46,8 +53,7 @@ contract CustomErrorPulseConsumer is IPulseConsumer {
function pulseCallback(
uint64,
address,
uint256,
bytes32[] calldata
PythStructs.PriceFeed[] memory
) external pure override {
revert CustomError("callback failed");
}
Expand Down Expand Up @@ -278,7 +284,6 @@ contract PulseTest is Test, PulseEvents {
emit PriceUpdateExecuted(
sequenceNumber,
updater,
publishTime,
priceIds,
expectedPrices,
expectedConf,
Expand All @@ -294,7 +299,22 @@ contract PulseTest is Test, PulseEvents {

// Verify callback was executed
assertEq(consumer.lastSequenceNumber(), sequenceNumber);
assertEq(consumer.lastPublishTime(), publishTime);

// Compare price feeds array length
PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds();
assertEq(lastFeeds.length, priceFeeds.length);

// Compare each price feed
for (uint i = 0; i < priceFeeds.length; i++) {
assertEq(lastFeeds[i].id, priceFeeds[i].id);
assertEq(lastFeeds[i].price.price, priceFeeds[i].price.price);
assertEq(lastFeeds[i].price.conf, priceFeeds[i].price.conf);
assertEq(lastFeeds[i].price.expo, priceFeeds[i].price.expo);
assertEq(
lastFeeds[i].price.publishTime,
priceFeeds[i].price.publishTime
);
}
}

function testExecuteCallbackFailure() public {
Expand All @@ -316,7 +336,6 @@ contract PulseTest is Test, PulseEvents {
emit PriceUpdateCallbackFailed(
sequenceNumber,
updater,
publishTime,
priceIds,
address(failingConsumer),
"callback failed"
Expand Down Expand Up @@ -345,7 +364,6 @@ contract PulseTest is Test, PulseEvents {
emit PriceUpdateCallbackFailed(
sequenceNumber,
updater,
publishTime,
priceIds,
address(failingConsumer),
"low-level error (possibly out of gas)"
Expand All @@ -370,10 +388,14 @@ contract PulseTest is Test, PulseEvents {
mockParsePriceFeedUpdates(priceFeeds);
bytes[] memory updateData = createMockUpdateData(priceFeeds);

// Try executing with only 10K gas when 1M is required
// Try executing with only 100K gas when 1M is required
vm.prank(updater);
vm.expectRevert(InsufficientGas.selector);
pulse.executeCallback{gas: 10000}(sequenceNumber, updateData, priceIds); // Will fail because gasleft() < callbackGasLimit
pulse.executeCallback{gas: 100000}(
sequenceNumber,
updateData,
priceIds
); // Will fail because gasleft() < callbackGasLimit
}

function testExecuteCallbackWithFutureTimestamp() public {
Expand All @@ -399,8 +421,17 @@ contract PulseTest is Test, PulseEvents {
// Should succeed because we're simulating receiving future-dated price updates
pulse.executeCallback(sequenceNumber, updateData, priceIds);

// Verify the callback was executed with future timestamp
assertEq(consumer.lastPublishTime(), futureTime);
// Compare price feeds array length
PythStructs.PriceFeed[] memory lastFeeds = consumer.lastPriceFeeds();
assertEq(lastFeeds.length, priceFeeds.length);

// Compare each price feed publish time
for (uint i = 0; i < priceFeeds.length; i++) {
assertEq(
lastFeeds[i].price.publishTime,
priceFeeds[i].price.publishTime
);
}
}

function testDoubleExecuteCallback() public {
Expand Down

0 comments on commit 62dd571

Please sign in to comment.