From aac64f96638af28f4d63882935ba949650bdbd73 Mon Sep 17 00:00:00 2001 From: Oba Date: Tue, 3 Sep 2024 11:57:33 +0200 Subject: [PATCH] fix: cairo precompile revert for child contexts --- .../CairoPrecompiles/SubContextPrecompile.sol | 36 +++++++++++++++++++ .../instructions/system_operations.cairo | 30 ++++++++++++++-- .../test_cairo_precompiles.py | 31 ++++++++++++++-- 3 files changed, 92 insertions(+), 5 deletions(-) create mode 100644 solidity_contracts/src/CairoPrecompiles/SubContextPrecompile.sol diff --git a/solidity_contracts/src/CairoPrecompiles/SubContextPrecompile.sol b/solidity_contracts/src/CairoPrecompiles/SubContextPrecompile.sol new file mode 100644 index 000000000..33934cfca --- /dev/null +++ b/solidity_contracts/src/CairoPrecompiles/SubContextPrecompile.sol @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +pragma solidity >=0.7.0 <0.9.0; + +import {CairoCounterCaller} from "./CairoCounterCaller.sol"; + +contract SubContextPrecompile { + RevertingSubContext immutable revertingSubContext; + + constructor(address _cairo_counter_caller) { + revertingSubContext = new RevertingSubContext(_cairo_counter_caller); + } + + function exploitLowLevelCall() public { + (bool success,) = address(revertingSubContext).call(abi.encodeWithSignature("reverting()")); + } + + function exploitChildContext() public { + revertingSubContext.reverting(); + } +} + +contract RevertingSubContext { + CairoCounterCaller immutable cairo_counter_caller; + uint256 dummyCounter; + + constructor(address _cairo_counter_caller) { + cairo_counter_caller = CairoCounterCaller(_cairo_counter_caller); + } + + function reverting() public { + dummyCounter = 1; + cairo_counter_caller.incrementCairoCounter(); + // force a revert after a call to a cairo precompile in a subcontext + require(false); + } +} diff --git a/src/kakarot/instructions/system_operations.cairo b/src/kakarot/instructions/system_operations.cairo index 13d9dcfca..db4512e77 100644 --- a/src/kakarot/instructions/system_operations.cairo +++ b/src/kakarot/instructions/system_operations.cairo @@ -947,6 +947,10 @@ namespace CallHelper { code_account, evm.message.valid_jumpdests_start, evm.message.valid_jumpdests ); State.update_account(code_account); + + let cairo_precompile_called = evm.message.cairo_precompile_called + + evm.message.parent.evm.message.cairo_precompile_called; + tempvar message = new model.Message( bytecode=evm.message.parent.evm.message.bytecode, bytecode_len=evm.message.parent.evm.message.bytecode_len, @@ -963,8 +967,17 @@ namespace CallHelper { is_create=evm.message.parent.evm.message.is_create, depth=evm.message.parent.evm.message.depth, env=evm.message.parent.evm.message.env, - cairo_precompile_called=evm.message.cairo_precompile_called, + cairo_precompile_called=cairo_precompile_called, ); + + if (evm.reverted != FALSE) { + // If a call to a cairo precompile has been made, the tx should be reverted + with_attr error_message( + "EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles") { + assert cairo_precompile_called = FALSE; + } + } + if (evm.reverted == Errors.EXCEPTIONAL_HALT) { // If the call has halted exceptionnaly, the return_data is empty // and nothing is copied to memory, and the gas is not returned; @@ -1165,6 +1178,9 @@ namespace CreateHelper { }(evm: model.EVM*) -> model.EVM* { alloc_locals; + let cairo_precompile_called = evm.message.cairo_precompile_called + + evm.message.parent.evm.message.cairo_precompile_called; + tempvar message = new model.Message( bytecode=evm.message.parent.evm.message.bytecode, bytecode_len=evm.message.parent.evm.message.bytecode_len, @@ -1181,10 +1197,14 @@ namespace CreateHelper { is_create=evm.message.parent.evm.message.is_create, depth=evm.message.parent.evm.message.depth, env=evm.message.parent.evm.message.env, - cairo_precompile_called=evm.message.cairo_precompile_called, + cairo_precompile_called=cairo_precompile_called, ); // Reverted during execution - either REVERT or exceptional if (evm.reverted != FALSE) { + with_attr error_message( + "EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles") { + assert cairo_precompile_called = FALSE; + } let is_exceptional_revert = is_not_zero(Errors.REVERT - evm.reverted); let return_data_len = (1 - is_exceptional_revert) * evm.return_data_len; let gas_left = evm.message.parent.evm.gas_left + (1 - is_exceptional_revert) * @@ -1231,6 +1251,10 @@ namespace CreateHelper { if (success == FALSE) { tempvar state = evm.message.parent.state; + with_attr error_message( + "EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles") { + assert cairo_precompile_called = FALSE; + } tempvar evm = new model.EVM( message=message, @@ -1253,7 +1277,7 @@ namespace CreateHelper { State.update_account(account); tempvar evm = new model.EVM( - message=evm.message.parent.evm.message, + message=message, return_data_len=0, return_data=evm.return_data, program_counter=evm.message.parent.evm.program_counter + 1, diff --git a/tests/end_to_end/CairoPrecompiles/test_cairo_precompiles.py b/tests/end_to_end/CairoPrecompiles/test_cairo_precompiles.py index e9a73299b..390098b56 100644 --- a/tests/end_to_end/CairoPrecompiles/test_cairo_precompiles.py +++ b/tests/end_to_end/CairoPrecompiles/test_cairo_precompiles.py @@ -6,7 +6,7 @@ from tests.utils.errors import cairo_error -@pytest_asyncio.fixture() +@pytest_asyncio.fixture(scope="module") async def cairo_counter(max_fee, deployer): cairo_counter = get_contract("Counter", provider=deployer) @@ -16,7 +16,7 @@ async def cairo_counter(max_fee, deployer): await wait_for_transaction(tx.hash) -@pytest_asyncio.fixture() +@pytest_asyncio.fixture(scope="module") async def cairo_counter_caller(owner, cairo_counter): caller_contract = await deploy( "CairoPrecompiles", @@ -34,6 +34,17 @@ async def cairo_counter_caller(owner, cairo_counter): return caller_contract +@pytest_asyncio.fixture(scope="module") +async def sub_context_precompile(owner, cairo_counter_caller): + sub_context_precompile = await deploy( + "CairoPrecompiles", + "SubContextPrecompile", + cairo_counter_caller.address, + caller_eoa=owner.starknet_contract, + ) + return sub_context_precompile + + @pytest.mark.asyncio(scope="module") @pytest.mark.CairoPrecompiles class TestCairoPrecompiles: @@ -79,3 +90,19 @@ async def test_last_caller_address_should_be_eoa(self, cairo_counter_caller): await cairo_counter_caller.incrementCairoCounter(caller_eoa=eoa) last_caller_address = await cairo_counter_caller.getLastCaller() assert last_caller_address == eoa.address + + async def test_should_fail_when_precompiles_called_and_low_level_call_fails( + self, sub_context_precompile + ): + with cairo_error( + "EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles" + ): + await sub_context_precompile.exploitLowLevelCall() + + async def test_should_fail_when_precompiles_called_and_child_context_fails( + self, sub_context_precompile + ): + with cairo_error( + "EVM tx reverted, reverting SN tx because of previous calls to cairo precompiles" + ): + await sub_context_precompile.exploitChildContext()