Skip to content

Commit

Permalink
fix: cairo precompile revert for child contexts (#1378)
Browse files Browse the repository at this point in the history
<!--- Please provide a general summary of your changes in the title
above -->

<!-- Give an estimate of the time you spent on this PR in terms of work
days.
Did you spend 0.5 days on this PR or rather 2 days?  -->

Time spent on this PR:

## Pull request type

<!-- Please try to limit your pull request to one type,
submit multiple pull requests if needed. -->

Please check the type of change your PR introduces:

- [x] Bugfix
- [ ] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no api changes)
- [ ] Build related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

<!-- Please describe the current behavior that you are modifying,
or link to a relevant issue. -->

Resolves #1377

## What is the new behavior?

<!-- Please describe the behavior or changes that are being added by
this PR. -->

- Revert cairo tx when evm revert
- cairo_precompile_called ORed when propagated to parent

<!-- Reviewable:start -->
- - -
This change is [<img src="https://reviewable.io/review_button.svg"
height="34" align="absmiddle"
alt="Reviewable"/>](https://reviewable.io/reviews/kkrt-labs/kakarot/1378)
<!-- Reviewable:end -->

---------

Co-authored-by: Clément Walter <[email protected]>
  • Loading branch information
obatirou and ClementWalter authored Sep 3, 2024
1 parent b2f8e18 commit e9b3f95
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 8 deletions.
36 changes: 36 additions & 0 deletions solidity_contracts/src/CairoPrecompiles/SubContextPrecompile.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
pragma solidity >=0.7.0 <0.9.0;

import {CairoCounterCaller} from "./CairoCounterCaller.sol";

contract SubContextPrecompile {
RevertingSubContext immutable revertingSubContext;

constructor(address _cairo_counter_caller) {
revertingSubContext = new RevertingSubContext(_cairo_counter_caller);
}

function exploitLowLevelCall() public {
(bool success,) = address(revertingSubContext).call(abi.encodeWithSignature("reverting()"));
}

function exploitChildContext() public {
revertingSubContext.reverting();
}
}

contract RevertingSubContext {
CairoCounterCaller immutable cairo_counter_caller;
uint256 dummyCounter;

constructor(address _cairo_counter_caller) {
cairo_counter_caller = CairoCounterCaller(_cairo_counter_caller);
}

function reverting() public {
dummyCounter = 1;
cairo_counter_caller.incrementCairoCounter();
// force a revert after a call to a cairo precompile in a subcontext
require(false);
}
}
30 changes: 27 additions & 3 deletions src/kakarot/instructions/system_operations.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,10 @@ namespace CallHelper {
code_account, evm.message.valid_jumpdests_start, evm.message.valid_jumpdests
);
State.update_account(code_account);

let cairo_precompile_called = evm.message.cairo_precompile_called +
evm.message.parent.evm.message.cairo_precompile_called;

tempvar message = new model.Message(
bytecode=evm.message.parent.evm.message.bytecode,
bytecode_len=evm.message.parent.evm.message.bytecode_len,
Expand All @@ -963,8 +967,17 @@ namespace CallHelper {
is_create=evm.message.parent.evm.message.is_create,
depth=evm.message.parent.evm.message.depth,
env=evm.message.parent.evm.message.env,
cairo_precompile_called=evm.message.cairo_precompile_called,
cairo_precompile_called=cairo_precompile_called,
);

if (evm.reverted != FALSE) {
// If a call to a cairo precompile has been made, the tx should be reverted
with_attr error_message(
"EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles") {
assert cairo_precompile_called = FALSE;
}
}

if (evm.reverted == Errors.EXCEPTIONAL_HALT) {
// If the call has halted exceptionnaly, the return_data is empty
// and nothing is copied to memory, and the gas is not returned;
Expand Down Expand Up @@ -1165,6 +1178,9 @@ namespace CreateHelper {
}(evm: model.EVM*) -> model.EVM* {
alloc_locals;

let cairo_precompile_called = evm.message.cairo_precompile_called +
evm.message.parent.evm.message.cairo_precompile_called;

tempvar message = new model.Message(
bytecode=evm.message.parent.evm.message.bytecode,
bytecode_len=evm.message.parent.evm.message.bytecode_len,
Expand All @@ -1181,10 +1197,14 @@ namespace CreateHelper {
is_create=evm.message.parent.evm.message.is_create,
depth=evm.message.parent.evm.message.depth,
env=evm.message.parent.evm.message.env,
cairo_precompile_called=evm.message.cairo_precompile_called,
cairo_precompile_called=cairo_precompile_called,
);
// Reverted during execution - either REVERT or exceptional
if (evm.reverted != FALSE) {
with_attr error_message(
"EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles") {
assert cairo_precompile_called = FALSE;
}
let is_exceptional_revert = is_not_zero(Errors.REVERT - evm.reverted);
let return_data_len = (1 - is_exceptional_revert) * evm.return_data_len;
let gas_left = evm.message.parent.evm.gas_left + (1 - is_exceptional_revert) *
Expand Down Expand Up @@ -1231,6 +1251,10 @@ namespace CreateHelper {

if (success == FALSE) {
tempvar state = evm.message.parent.state;
with_attr error_message(
"EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles") {
assert cairo_precompile_called = FALSE;
}

tempvar evm = new model.EVM(
message=message,
Expand All @@ -1253,7 +1277,7 @@ namespace CreateHelper {
State.update_account(account);

tempvar evm = new model.EVM(
message=evm.message.parent.evm.message,
message=message,
return_data_len=0,
return_data=evm.return_data,
program_counter=evm.message.parent.evm.program_counter + 1,
Expand Down
31 changes: 29 additions & 2 deletions tests/end_to_end/CairoPrecompiles/test_cairo_precompiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tests.utils.errors import cairo_error


@pytest_asyncio.fixture()
@pytest_asyncio.fixture(scope="module")
async def cairo_counter(max_fee, deployer):
cairo_counter = get_contract("Counter", provider=deployer)

Expand All @@ -16,7 +16,7 @@ async def cairo_counter(max_fee, deployer):
await wait_for_transaction(tx.hash)


@pytest_asyncio.fixture()
@pytest_asyncio.fixture(scope="module")
async def cairo_counter_caller(owner, cairo_counter):
caller_contract = await deploy(
"CairoPrecompiles",
Expand All @@ -34,6 +34,17 @@ async def cairo_counter_caller(owner, cairo_counter):
return caller_contract


@pytest_asyncio.fixture(scope="module")
async def sub_context_precompile(owner, cairo_counter_caller):
sub_context_precompile = await deploy(
"CairoPrecompiles",
"SubContextPrecompile",
cairo_counter_caller.address,
caller_eoa=owner.starknet_contract,
)
return sub_context_precompile


@pytest.mark.asyncio(scope="module")
@pytest.mark.CairoPrecompiles
class TestCairoPrecompiles:
Expand Down Expand Up @@ -79,3 +90,19 @@ async def test_last_caller_address_should_be_eoa(self, cairo_counter_caller):
await cairo_counter_caller.incrementCairoCounter(caller_eoa=eoa)
last_caller_address = await cairo_counter_caller.getLastCaller()
assert last_caller_address == eoa.address

async def test_should_fail_when_precompiles_called_and_low_level_call_fails(
self, sub_context_precompile
):
with cairo_error(
"EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles"
):
await sub_context_precompile.exploitLowLevelCall()

async def test_should_fail_when_precompiles_called_and_child_context_fails(
self, sub_context_precompile
):
with cairo_error(
"EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles"
):
await sub_context_precompile.exploitChildContext()
6 changes: 3 additions & 3 deletions tests/end_to_end/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ async def eth(deployer) -> Contract:
return await get_eth_contract(provider=deployer)


@pytest_asyncio.fixture(scope="session")
async def cairo_counter(deployer) -> Contract:
@pytest.fixture(scope="session")
def cairo_counter(deployer) -> Contract:
"""
Return a cached version of the cairo_counter contract.
"""
return await get_contract("Counter", provider=deployer)
return get_contract("Counter", provider=deployer)


@pytest.fixture(scope="session")
Expand Down

0 comments on commit e9b3f95

Please sign in to comment.