diff --git a/.gitignore b/.gitignore index be04130ad..72eab5c86 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ resources/* .hypothesis *.pb.gz +profile*.png diff --git a/blockchain-tests-skip.yml b/blockchain-tests-skip.yml index 196e482bb..5ccac05f7 100644 --- a/blockchain-tests-skip.yml +++ b/blockchain-tests-skip.yml @@ -2452,7 +2452,6 @@ testname: - stackOverflowSWAP_d0g0v0_Shanghai - stacksanitySWAP_d0g0v0_Shanghai stStaticCall: - - RevertOpcodeCalls_d0g1v0_Shanghai - StaticcallToPrecompileFromCalledContract_d0g0v0_Shanghai - StaticcallToPrecompileFromContractInitialization_d0g0v0_Shanghai - StaticcallToPrecompileFromTransaction_d0g0v0_Shanghai @@ -2538,6 +2537,7 @@ testname: - static_callWithHighValueAndGasOOG_d1g0v0_Shanghai - static_call_value_inherit_d0g0v0_Shanghai - static_call_value_inherit_from_call_d0g0v0_Shanghai + - static_RevertOpcodeCalls_d0g1v0_Shanghai stStaticFlagEnabled: - CallWithNOTZeroValueToPrecompileFromCalledContract_d0g0v0_Shanghai - CallWithNOTZeroValueToPrecompileFromCalledContract_d1g0v0_Shanghai diff --git a/src/kakarot/accounts/contract/library.cairo b/src/kakarot/accounts/contract/library.cairo index b2ab3c76c..d8eb13576 100644 --- a/src/kakarot/accounts/contract/library.cairo +++ b/src/kakarot/accounts/contract/library.cairo @@ -1,26 +1,25 @@ -// SPDX-License-Identifier: MIT - %lang starknet -// Starkware dependencies from openzeppelin.access.ownable.library import Ownable from starkware.cairo.common.alloc import alloc from starkware.cairo.common.bool import FALSE from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin -from starkware.cairo.common.math import unsigned_div_rem +from starkware.cairo.common.math import unsigned_div_rem, split_int from starkware.cairo.common.registers import get_label_location from starkware.cairo.common.uint256 import Uint256, uint256_not -from starkware.starknet.common.syscalls import storage_read, storage_write +from starkware.starknet.common.syscalls import ( + StorageRead, + StorageWrite, + STORAGE_READ_SELECTOR, + STORAGE_WRITE_SELECTOR, + storage_read, + storage_write, + StorageReadRequest, +) from starkware.cairo.common.memset import memset from kakarot.interfaces.interfaces import IERC20, IKakarot -// Storage - -@storage_var -func bytecode_(index: felt) -> (res: felt) { -} - @storage_var func bytecode_len_() -> (res: felt) { } @@ -41,13 +40,12 @@ func evm_address() -> (evm_address: felt) { func nonce() -> (nonce: felt) { } +// Define the number of bytes per felt +const BYTES_PER_FELT = 31; + // @title ContractAccount main library file. // @notice This file contains the EVM smart contract account representation logic. namespace ContractAccount { - // Define the number of bytes per felt. Above 16, the following code won't work as it uses unsigned_div_rem - // which is bounded by RC_BOUND = 2 ** 128 ~ uint128 ~ bytes16 - const BYTES_PER_FELT = 16; - // @notice This function is used to initialize the smart contract account. func initialize{ syscall_ptr: felt*, @@ -85,17 +83,12 @@ namespace ContractAccount { range_check_ptr, bitwise_ptr: BitwiseBuiltin*, }(bytecode_len: felt, bytecode: felt*) { + alloc_locals; // Access control check. Ownable.assert_only_owner(); // Recursively store the bytecode. bytecode_len_.write(bytecode_len); - internal.write_bytecode( - index=0, - bytecode_len=bytecode_len, - bytecode=bytecode, - current_felt=0, - remaining_shift=BYTES_PER_FELT, - ); + internal.write_bytecode(bytecode_len=bytecode_len, bytecode=bytecode); return (); } @@ -117,17 +110,8 @@ namespace ContractAccount { bitwise_ptr: BitwiseBuiltin*, }() -> (bytecode_len: felt, bytecode: felt*) { alloc_locals; - // Read bytecode length from storage. let (bytecode_len) = bytecode_len_.read(); - // Recursively load bytecode into specified memory location. - let bytecode_: felt* = alloc(); - internal.load_bytecode( - index=0, - bytecode_len=bytecode_len, - bytecode=bytecode_, - current_felt=0, - remaining_shift=0, - ); + let (bytecode_) = internal.load_bytecode(bytecode_len); return (bytecode_len, bytecode_); } @@ -224,107 +208,169 @@ namespace ContractAccount { } namespace internal { - // Use a precomputed 2 ** n array to save on resources usage. - // Array starts with a 0 to be shifted and have pow[i] = bit shift for byte i with - // i as a counter, ie i \in (0, BYTES_PER_FELT] - pow_: - dw 0; - dw 1; - dw 2 ** 8; - dw 2 ** 16; - dw 2 ** 24; - dw 2 ** 32; - dw 2 ** 40; - dw 2 ** 48; - dw 2 ** 56; - dw 2 ** 64; - dw 2 ** 72; - dw 2 ** 80; - dw 2 ** 88; - dw 2 ** 96; - dw 2 ** 104; - dw 2 ** 112; - dw 2 ** 120; - // @notice Store the bytecode of the contract. // @param index The current free index in the bytecode_ storage. // @param bytecode_len The length of the bytecode. // @param bytecode The bytecode of the contract. - func write_bytecode{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( - index: felt, bytecode_len: felt, bytecode: felt*, current_felt: felt, remaining_shift: felt - ) { + func write_bytecode{syscall_ptr: felt*}(bytecode_len: felt, bytecode: felt*) { alloc_locals; if (bytecode_len == 0) { - // end of bytecode case, break loop storing latest "pending" packed felt - bytecode_.write(index, current_felt); return (); } - if (remaining_shift == 0) { - // end of packed felt case, store current "pending" felt - // continue loop with a new current_felt and increment index in bytecode_ storage - bytecode_.write(index, current_felt); - return write_bytecode( - index + 1, bytecode_len, bytecode, 0, ContractAccount.BYTES_PER_FELT - ); - } - - // retrieve the precomputed pow array - let (pow_address) = get_label_location(pow_); - let pow = cast(pow_address, felt*); - - // shift the current byte and add it to the current felt - // bytes are stored big endian, ie that 3 bytes ends up being stored as a felt whose representation is 0xabcdef000...000 - // for a given remaining_shift: - // current_felt = 0x 12 34 00 00...00 00 - // bytecode = 0x 56 - // pow[remaining_shift] * bytecode = 0x 00 00 56 00...00 00 - // resulting in 0x 12 34 56 00...00 00 - let current_felt = pow[remaining_shift] * [bytecode] + current_felt; - - return write_bytecode( - index, bytecode_len - 1, bytecode + 1, current_felt, remaining_shift - 1 + tempvar value = 0; + tempvar address = 0; + tempvar syscall_ptr = syscall_ptr; + tempvar bytecode_len = bytecode_len; + tempvar count = BYTES_PER_FELT; + + body: + let value = [ap - 5]; + let address = [ap - 4]; + let syscall_ptr = cast([ap - 3], felt*); + let bytecode_len = [ap - 2]; + let count = [ap - 1]; + let initial_bytecode_len = [fp - 4]; + let bytecode = cast([fp - 3], felt*); + + tempvar value = value * 256 + bytecode[initial_bytecode_len - bytecode_len]; + tempvar address = address; + tempvar syscall_ptr = syscall_ptr; + tempvar bytecode_len = bytecode_len - 1; + tempvar count = count - 1; + + jmp cond if bytecode_len != 0; + jmp store; + + cond: + jmp body if count != 0; + + store: + assert [cast(syscall_ptr, StorageWrite*)] = StorageWrite( + selector=STORAGE_WRITE_SELECTOR, address=address, value=value ); + %{ syscall_handler.storage_write(segments=segments, syscall_ptr=ids.syscall_ptr) %} + tempvar value = 0; + tempvar address = address + 1; + tempvar syscall_ptr = syscall_ptr + StorageWrite.SIZE; + tempvar bytecode_len = bytecode_len; + tempvar count = BYTES_PER_FELT; + + jmp body if bytecode_len != 0; + + return (); } // @notice Load the bytecode of the contract in the specified array. // @param index The index in the bytecode. // @param bytecode_len The length of the bytecode. // @param bytecode The bytecode of the contract. - func load_bytecode{ - syscall_ptr: felt*, - pedersen_ptr: HashBuiltin*, - range_check_ptr, - bitwise_ptr: BitwiseBuiltin*, - }(index: felt, bytecode_len: felt, bytecode: felt*, current_felt: felt, remaining_shift: felt) { + func load_bytecode{syscall_ptr: felt*, range_check_ptr}(bytecode_len: felt) -> ( + bytecode: felt* + ) { alloc_locals; - if (bytecode_len == 0) { - // end of loop - return (); - } + let (local bytecode: felt*) = alloc(); + local bound = 256; + local base = 256; - if (remaining_shift == 0) { - // end of current packed felt, loading next stored felt and increase storage index - let (current_felt) = bytecode_.read(index); - return load_bytecode( - index + 1, bytecode_len, bytecode, current_felt, ContractAccount.BYTES_PER_FELT - ); + if (bytecode_len == 0) { + return (bytecode=bytecode); } - // retrieve the precomputed pow array - let (pow_address) = get_label_location(pow_); - let pow = cast(pow_address, felt*); - - // get the leading (big endian) byte of the current_felt - // reassign current_felt to the be remainder - let (current_byte, current_felt) = unsigned_div_rem(current_felt, pow[remaining_shift]); - // add byte to returned array - assert [bytecode] = current_byte; - - return load_bytecode( - index, bytecode_len - 1, bytecode + 1, current_felt, remaining_shift - 1 + let (local chunk_counts, local remainder) = unsigned_div_rem(bytecode_len, BYTES_PER_FELT); + + tempvar remaining_bytes = bytecode_len; + tempvar range_check_ptr = range_check_ptr; + tempvar address = 0; + tempvar syscall_ptr = syscall_ptr; + tempvar value = 0; + tempvar count = 0; + + read: + let remaining_bytes = [ap - 6]; + let range_check_ptr = [ap - 5]; + let address = [ap - 4]; + let syscall_ptr = cast([ap - 3], felt*); + let value = [ap - 2]; + let count = [ap - 1]; + + let syscall = [cast(syscall_ptr, StorageRead*)]; + assert syscall.request = StorageReadRequest( + selector=STORAGE_READ_SELECTOR, address=address ); + %{ syscall_handler.storage_read(segments=segments, syscall_ptr=ids.syscall_ptr) %} + let response = syscall.response; + + let remainder = [fp + 4]; + let chunk_counts = [fp + 3]; + tempvar remaining_chunk = chunk_counts - address; + jmp full_chunk if remaining_chunk != 0; + tempvar count = remainder; + jmp next; + + full_chunk: + tempvar count = BYTES_PER_FELT; + + next: + tempvar remaining_bytes = remaining_bytes; + tempvar range_check_ptr = range_check_ptr; + tempvar address = address + 1; + tempvar syscall_ptr = syscall_ptr + StorageRead.SIZE; + tempvar value = response.value; + tempvar count = count; + + body: + let remaining_bytes = [ap - 6]; + let range_check_ptr = [ap - 5]; + let address = [ap - 4]; + let syscall_ptr = cast([ap - 3], felt*); + let value = [ap - 2]; + let count = [ap - 1]; + + let base = [fp + 1]; + let bound = [fp + 2]; + let bytecode = cast([fp], felt*); + tempvar offset = (address - 1) * BYTES_PER_FELT + count - 1; + let output = bytecode + offset; + + // Put byte in output and assert that 0 <= byte < bound + // See math.split_int + %{ + memory[ids.output] = res = (int(ids.value) % PRIME) % ids.base + assert res < ids.bound, f'split_int(): Limb {res} is out of range.' + %} + tempvar a = [output]; + %{ + from starkware.cairo.common.math_utils import assert_integer + assert_integer(ids.a) + assert 0 <= ids.a % PRIME < range_check_builtin.bound, f'a = {ids.a} is out of range.' + %} + assert a = [range_check_ptr]; + tempvar a = bound - 1 - a; + %{ + from starkware.cairo.common.math_utils import assert_integer + assert_integer(ids.a) + assert 0 <= ids.a % PRIME < range_check_builtin.bound, f'a = {ids.a} is out of range.' + %} + assert a = [range_check_ptr + 1]; + + tempvar value = (value - [output]) / base; + tempvar remaining_bytes = remaining_bytes - 1; + tempvar range_check_ptr = range_check_ptr + 2; + tempvar address = address; + tempvar syscall_ptr = syscall_ptr; + tempvar value = value; + tempvar count = count - 1; + + jmp cond if remaining_bytes != 0; + + let bytecode = cast([fp], felt*); + return (bytecode=bytecode); + + cond: + jmp body if count != 0; + jmp read; } } diff --git a/tests/src/kakarot/accounts/test_contract_account.cairo b/tests/src/kakarot/accounts/test_contract_account.cairo index 25b86c953..af7d92c86 100644 --- a/tests/src/kakarot/accounts/test_contract_account.cairo +++ b/tests/src/kakarot/accounts/test_contract_account.cairo @@ -2,6 +2,7 @@ from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.memcpy import memcpy from kakarot.accounts.contract.library import ContractAccount @@ -51,3 +52,12 @@ func test__write_bytecode{ return (); } + +func test__read_bytecode{ + syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin* +}(output_ptr: felt*) { + alloc_locals; + let (bytecode_len, bytecode) = ContractAccount.bytecode(); + memcpy(output_ptr, bytecode, bytecode_len); + return (); +} diff --git a/tests/src/kakarot/accounts/test_contract_account.py b/tests/src/kakarot/accounts/test_contract_account.py index 383f1ab95..6992608c3 100644 --- a/tests/src/kakarot/accounts/test_contract_account.py +++ b/tests/src/kakarot/accounts/test_contract_account.py @@ -1,3 +1,8 @@ +import random +from textwrap import wrap +from unittest.mock import call, patch + +import pytest from starkware.starknet.public.abi import ( get_selector_from_name, get_storage_var_address, @@ -9,6 +14,11 @@ class TestContractAccount: + @pytest.fixture(params=[0, 10, 100, 1000, 10000]) + def bytecode(self, request): + random.seed(0) + return random.randbytes(request.param) + class TestInitialize: @SyscallHandler.patch("IKakarot.get_native_token", lambda addr, data: [0xDEAD]) @SyscallHandler.patch("IERC20.approve", lambda addr, data: [1]) @@ -73,16 +83,38 @@ def test_should_assert_only_owner(self, cairo_run): cairo_run("test__write_bytecode", bytecode=[]) @SyscallHandler.patch("Ownable_owner", SyscallHandler.caller_address) - def test_should_write_bytecode(self, cairo_run): - cairo_run("test__write_bytecode", bytecode=list(range(32))) - SyscallHandler.mock_storage.assert_any_call( - address=get_storage_var_address("bytecode_len_"), value=32 - ) - SyscallHandler.mock_storage.assert_any_call( - address=get_storage_var_address("bytecode_", 0), - value=int.from_bytes(bytes(list(range(16))), "big"), - ) + def test_should_write_bytecode(self, cairo_run, bytecode): + cairo_run("test__write_bytecode", bytecode=list(bytecode)) SyscallHandler.mock_storage.assert_any_call( - address=get_storage_var_address("bytecode_", 1), - value=int.from_bytes(bytes(list(range(16, 32))), "big"), + address=get_storage_var_address("bytecode_len_"), value=len(bytecode) ) + calls = [ + call(address=i, value=int(value, 16)) + for i, value in enumerate(wrap(bytecode.hex(), 2 * 31)) + ] + SyscallHandler.mock_storage.assert_has_calls(calls) + + class TestBytecode: + @pytest.fixture + def storage(self, bytecode): + chunks = wrap(bytecode.hex(), 2 * 31) + + def _storage(address): + return ( + int(chunks[address], 16) + if address != get_storage_var_address("bytecode_len_") + else len(bytecode) + ) + + return _storage + + def test_should_read_bytecode(self, cairo_run, bytecode, storage): + with patch.object( + SyscallHandler, "mock_storage", side_effect=storage + ) as mock_storage: + output = cairo_run("test__read_bytecode") + chunk_counts, remainder = divmod(len(bytecode), 31) + addresses = list(range(chunk_counts + (remainder > 0))) + calls = [call(address=address) for address in addresses] + mock_storage.assert_has_calls(calls) + assert output == list(bytecode) diff --git a/tests/utils/reporting.py b/tests/utils/reporting.py index d87b95d23..807dea9df 100644 --- a/tests/utils/reporting.py +++ b/tests/utils/reporting.py @@ -20,7 +20,11 @@ _time_report: List[dict] = [] _resources_report: List[dict] = [] - +# A mapping to fix the mismatch between the debug_info and the identifiers. +_label_scope = { + "kakarot.constants.opcodes_label": "kakarot.constants", + "kakarot.accounts.contract.library.internal.pow_": "kakarot.accounts.contract.library.internal", +} T = TypeVar("T", bound=Callable[..., Any]) @@ -300,9 +304,7 @@ def profile_from_tracer_data(tracer_data): if not isinstance(ident, LabelDefinition): continue builder.function_id( - name="kakarot.constants" - if str(name) == "kakarot.constants.opcodes_label" - else str(name), + name=_label_scope.get(str(name), str(name)), inst_location=tracer_data.program.debug_info.instruction_locations[ ident.pc ], diff --git a/tests/utils/syscall_handler.py b/tests/utils/syscall_handler.py index 20e2e31bc..2669802f3 100644 --- a/tests/utils/syscall_handler.py +++ b/tests/utils/syscall_handler.py @@ -162,6 +162,7 @@ def storage_read(self, segments, syscall_ptr): """ Return a constant value for the storage read system call. We use the patches dict to store the storage values; returned value is 0 if the address is not found as in Starknet. + Value can also be set by patching the underling mock_storage object. Syscall structure is: @@ -180,7 +181,11 @@ def storage_read(self, segments, syscall_ptr): } """ address = segments.memory[syscall_ptr + 1] - value = self.patches.get(address, 0) + mock = self.mock_storage(address=address) + patched = self.patches.get(address) + value = ( + patched if patched is not None else (mock if isinstance(mock, int) else 0) + ) segments.write_arg(syscall_ptr + 2, [value]) def storage_write(self, segments, syscall_ptr):