diff --git a/cairo/ethereum/cancun/vm/memory.cairo b/cairo/ethereum/cancun/vm/memory.cairo new file mode 100644 index 00000000..2bc45633 --- /dev/null +++ b/cairo/ethereum/cancun/vm/memory.cairo @@ -0,0 +1,212 @@ +// SPDX-License-Identifier: MIT + +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.bool import FALSE, TRUE +from starkware.cairo.common.default_dict import default_dict_new, default_dict_finalize +from starkware.cairo.common.dict import DictAccess, dict_read, dict_write +from starkware.cairo.common.memset import memset +from starkware.cairo.common.memcpy import memcpy +from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc +from starkware.cairo.common.math import assert_le, assert_lt +from starkware.cairo.common.math_cmp import is_le, is_not_zero + +from ethereum_types.bytes import Bytes, BytesStruct, Bytes1 +from ethereum_types.numeric import U256 + +// @title Memory related functions. +// @notice Implements EVM memory operations using a mutable bytearray. +struct BytearrayStruct { + dict_ptr_start: Bytes1DictAccess*, + dict_ptr: Bytes1DictAccess*, + len: felt, +} + +struct Bytearray { + value: BytearrayStruct*, +} + +struct Bytes1DictAccess { + key: felt, + prev_value: Bytes1, + new_value: Bytes1, +} + +// @notice Write bytes to memory at a given position. +// @dev assumption: start_position < 2**128 +// @dev assumption: value.len < 2**128 +// @param memory The pointer to the bytearray. +// @param start_position Starting position to write at. +// @param value Bytes to write. +func memory_write{range_check_ptr, memory: Bytearray}(start_position: U256, value: Bytes) { + alloc_locals; + let bytes_len = value.value.len; + let bytes_data = value.value.data; + let start_position_felt = start_position.value.low; + let new_len = start_position_felt + bytes_len; + + let dict_ptr = cast(memory.value.dict_ptr, DictAccess*); + with dict_ptr { + Internals._write_bytes(start_position_felt, bytes_data, bytes_len); + } + let new_dict_ptr = cast(dict_ptr, Bytes1DictAccess*); + + // Update length if we wrote beyond current length + tempvar current_len = memory.value.len; + let is_new_le_current = is_le(new_len, current_len); + if (is_new_le_current != TRUE) { + tempvar final_len = new_len; + } else { + tempvar final_len = current_len; + } + + tempvar memory = Bytearray( + new BytearrayStruct(memory.value.dict_ptr_start, new_dict_ptr, final_len) + ); + return (); +} + +// @notice Read bytes from memory. +// @dev assumption: start_position < 2**128 +// @dev assumption: size < 2**128 +// @param memory The pointer to the bytearray. +// @param start_position Starting position to read from. +// @param size Number of bytes to read. +// @return The bytes read from memory. +func memory_read_bytes{memory: Bytearray}(start_position: U256, size: U256) -> Bytes { + alloc_locals; + let (local output: felt*) = alloc(); + + let start_position_felt = start_position.value.low; + let size_felt = size.value.low; + let dict_ptr = cast(memory.value.dict_ptr, DictAccess*); + with dict_ptr { + Internals._read_bytes(start_position_felt, size_felt, output); + } + let new_dict_ptr = cast(dict_ptr, Bytes1DictAccess*); + + tempvar memory = Bytearray( + new BytearrayStruct(memory.value.dict_ptr_start, new_dict_ptr, memory.value.len) + ); + tempvar result = Bytes(new BytesStruct(output, size_felt)); + return result; +} + +// @notice Read bytes from a buffer with zero padding. +// @dev assumption: start_position < 2**128 +// @dev assumption: size < 2**128 +// @param buffer Source bytes to read from. +// @param start_position Starting position to read from. +// @param size Number of bytes to read. +// @return The bytes read from the buffer. +func buffer_read{range_check_ptr}(buffer: Bytes, start_position: U256, size: U256) -> Bytes { + alloc_locals; + let (local output: felt*) = alloc(); + let buffer_len = buffer.value.len; + let buffer_data = buffer.value.data; + let start_position_felt = start_position.value.low; + let size_felt = size.value.low; + + Internals._buffer_read_internal( + buffer_len, buffer_data, start_position_felt, size_felt, output + ); + tempvar result = Bytes(new BytesStruct(output, size_felt)); + return result; +} + +namespace Internals { + // @notice Internal function to write bytes to memory. + // @param start_position Starting position to write at. + // @param data Pointer to the bytes data. + // @param len Length of bytes to write. + func _write_bytes{dict_ptr: DictAccess*}(start_position: felt, data: felt*, len: felt) { + alloc_locals; + if (len == 0) { + return (); + } + + tempvar start_position = start_position; + tempvar data = data; + tempvar len = len; + tempvar dict_ptr = dict_ptr; + + body: + let start_position = [ap - 4]; + let data = cast([ap - 3], felt*); + let len = [ap - 2]; + let dict_ptr = cast([ap - 1], DictAccess*); + dict_write(start_position, [data]); + + tempvar start_position = start_position + 1; + tempvar data = data + 1; + tempvar len = len - 1; + tempvar dict_ptr = dict_ptr; + jmp body if len != 0; + return (); + } + + // @notice Internal function to read bytes from memory. + // @param start_position Starting position to read from. + // @param size Number of bytes to read. + // @param output Pointer to write output bytes to. + func _read_bytes{dict_ptr: DictAccess*}(start_position: felt, size: felt, output: felt*) { + if (size == 0) { + return (); + } + + tempvar start_position = start_position; + tempvar size = size; + tempvar output = output; + tempvar dict_ptr = dict_ptr; + + body: + let start_position = [ap - 4]; + let size = [ap - 3]; + let output = cast([ap - 2], felt*); + let dict_ptr = cast([ap - 1], DictAccess*); + + let (value) = dict_read(start_position); + assert [output] = value; + + tempvar start_position = start_position + 1; + tempvar size = size - 1; + tempvar output = output + 1; + tempvar dict_ptr = dict_ptr; + jmp body if size != 0; + return (); + } + + // @notice Internal function to read bytes from a buffer with zero padding. + // @param data_len Length of the buffer. + // @param data Pointer to the buffer data. + // @param start_position Starting position to read from. + // @param size Number of bytes to read. + // @param output Pointer to write output bytes to. + func _buffer_read_internal{range_check_ptr}( + data_len: felt, data: felt*, start_position: felt, size: felt, output: felt* + ) { + alloc_locals; + if (size == 0) { + return (); + } + + // Check if start position is beyond buffer length + let start_oob = is_le(data_len, start_position); + if (start_oob == TRUE) { + memset(output, 0, size); + return (); + } + + // Check if read extends past end of buffer + let end_oob = is_le(data_len, start_position + size); + if (end_oob == TRUE) { + let available_size = data_len - start_position; + memcpy(output, data + start_position, available_size); + + let remaining_size = size - available_size; + memset(output + available_size, 0, remaining_size); + } else { + memcpy(output, data + start_position, size); + } + return (); + } +} diff --git a/cairo/ethereum_types/bytes.cairo b/cairo/ethereum_types/bytes.cairo index 78a60d02..e900aca7 100644 --- a/cairo/ethereum_types/bytes.cairo +++ b/cairo/ethereum_types/bytes.cairo @@ -29,6 +29,9 @@ from ethereum_types.numeric import U128 struct Bytes0 { value: felt, } +struct Bytes1 { + value: felt, +} struct Bytes8 { value: felt, } diff --git a/cairo/tests/ethereum/cancun/vm/test_memory.py b/cairo/tests/ethereum/cancun/vm/test_memory.py new file mode 100644 index 00000000..6b19372b --- /dev/null +++ b/cairo/tests/ethereum/cancun/vm/test_memory.py @@ -0,0 +1,71 @@ +from ethereum_types.bytes import Bytes +from ethereum_types.numeric import U256 +from hypothesis import given +from hypothesis import strategies as st +from hypothesis.strategies import composite + +from ethereum.cancun.vm.memory import buffer_read, memory_read_bytes, memory_write +from tests.utils.strategies import smallbytes + +# NOTE: The testing strategy always assume that memory accesses are within bounds. +# Because the memory is always extended to the proper size _before_ being accessed. + + +@composite +def memory_write_strategy(draw): + # Higher than 2**10 will cause a HealthCheck too large error. + memory_size = draw(st.integers(min_value=0, max_value=2**10)) + memory = draw(st.binary(min_size=memory_size, max_size=memory_size).map(bytearray)) + + # Generate a start position in bounds with existing memory + start_position = draw(st.integers(min_value=0, max_value=memory_size).map(U256)) + + # Generate value with size that won't overflow memory + max_value_size = memory_size - int(start_position) + value = draw(st.binary(min_size=0, max_size=max_value_size)) + + return memory, start_position, value + + +@composite +def memory_read_strategy(draw): + memory_size = draw(st.integers(min_value=0, max_value=2**10)) + memory = draw(st.binary(min_size=memory_size, max_size=memory_size).map(bytearray)) + + start_position = draw(st.integers(min_value=0, max_value=memory_size).map(U256)) + size = draw( + st.integers(min_value=0, max_value=memory_size - int(start_position)).map(U256) + ) + + return memory, start_position, size + + +class TestMemory: + @given(memory_write_strategy()) + def test_memory_write(self, cairo_run, params): + memory, start_position, value = params + cairo_memory = cairo_run("memory_write", memory, start_position, Bytes(value)) + memory_write(memory, start_position, Bytes(value)) + assert cairo_memory == memory + + @given(memory_read_strategy()) + def test_memory_read(self, cairo_run, params): + memory, start_position, size = params + (cairo_memory, cairo_value) = cairo_run( + "memory_read_bytes", memory, start_position, size + ) + python_value = memory_read_bytes(memory, start_position, size) + assert cairo_memory == memory + assert cairo_value == python_value + + @given( + buffer=smallbytes, + start_position=..., + size=st.integers(min_value=0, max_value=2**10).map(U256), + ) + def test_buffer_read( + self, cairo_run, buffer: Bytes, start_position: U256, size: U256 + ): + assert buffer_read(buffer, start_position, size) == cairo_run( + "buffer_read", buffer, start_position, size + ) diff --git a/cairo/tests/utils/args_gen.py b/cairo/tests/utils/args_gen.py index fdd065a2..a4bee1ba 100644 --- a/cairo/tests/utils/args_gen.py +++ b/cairo/tests/utils/args_gen.py @@ -67,7 +67,15 @@ get_origin, ) -from ethereum_types.bytes import Bytes, Bytes0, Bytes8, Bytes20, Bytes32, Bytes256 +from ethereum_types.bytes import ( + Bytes, + Bytes0, + Bytes1, + Bytes8, + Bytes20, + Bytes32, + Bytes256, +) from ethereum_types.numeric import U64, U256, Uint from starkware.cairo.common.dict import DictManager, DictTracker from starkware.cairo.lang.compiler.ast.cairo_types import ( @@ -109,6 +117,7 @@ ("ethereum_types", "numeric", "U256"): U256, ("ethereum_types", "numeric", "SetUint"): Set[Uint], ("ethereum_types", "bytes", "Bytes0"): Bytes0, + ("ethereum_types", "bytes", "Bytes1"): Bytes1, ("ethereum_types", "bytes", "Bytes8"): Bytes8, ("ethereum_types", "bytes", "Bytes20"): Bytes20, ("ethereum_types", "bytes", "Bytes32"): Bytes32, @@ -172,6 +181,7 @@ Address, Account ], ("ethereum", "exceptions", "EthereumException"): EthereumException, + ("ethereum", "cancun", "vm", "memory", "Bytearray"): bytearray, } # In the EELS, some functions are annotated with Sequence while it's actually just Bytes. @@ -271,13 +281,19 @@ def _gen_arg( segments.load_data(struct_ptr, [instances_ptr, len(arg)]) return struct_ptr - if arg_type_origin in (dict, ChainMap, abc.Mapping, set): + if arg_type_origin in (dict, ChainMap, abc.Mapping, set) or arg_type is bytearray: dict_ptr = segments.add() assert dict_ptr.segment_index not in dict_manager.trackers if arg_type_origin is set: arg = {k: True for k in arg} arg_type = Mapping[type(next(iter(arg))), bool] + elif arg_type is bytearray: + # Create a dict with one byte per value and include length + data = defaultdict(int, {k: v for k, v in enumerate(arg)}) + base = _gen_arg(dict_manager, segments, Mapping[int, int], data) + segments.load_data(base + 2, [len(arg)]) # Store length after dict pointers + return base data = { _gen_arg(dict_manager, segments, get_args(arg_type)[0], k): _gen_arg( @@ -323,7 +339,7 @@ def _gen_arg( ) return base - if arg_type in (Bytes, bytes, bytearray, str): + if arg_type in (Bytes, bytes, str): if arg is None: return 0 if isinstance(arg, str): diff --git a/cairo/tests/utils/serde.py b/cairo/tests/utils/serde.py index 620bc474..39f7edf6 100644 --- a/cairo/tests/utils/serde.py +++ b/cairo/tests/utils/serde.py @@ -35,7 +35,15 @@ ) from eth_utils.address import to_checksum_address -from ethereum_types.bytes import Bytes, Bytes0, Bytes8, Bytes20, Bytes32, Bytes256 +from ethereum_types.bytes import ( + Bytes, + Bytes0, + Bytes1, + Bytes8, + Bytes20, + Bytes32, + Bytes256, +) from ethereum_types.numeric import U256 from starkware.cairo.lang.compiler.ast.cairo_types import ( CairoType, @@ -185,7 +193,10 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any: ] ) - if get_origin(python_cls) in (Mapping, abc.Mapping, set): + if ( + get_origin(python_cls) in (Mapping, abc.Mapping, set) + or python_cls is bytearray + ): mapping_struct_ptr = self.serialize_pointers(path, ptr)["value"] mapping_struct_path = ( get_struct_definition(self.program, path) @@ -212,6 +223,19 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any: for i in range(0, segment_size, 3) } + if python_cls is bytearray: + # For bytearray, we reconstruct it from the dictionary values up to length + d = { + self._serialize(key_type, dict_ptr + i): self._serialize( + value_type, dict_ptr + i + 2 + ) + for i in range(0, segment_size, 3) + } + length = pointers["len"] + return bytearray( + [int.from_bytes(d[i], "little") for i in range(length)] + ) + return { self._serialize(key_type, dict_ptr + i): self._serialize( value_type, dict_ptr + i + 2 @@ -219,7 +243,7 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any: for i in range(0, segment_size, 3) } - if python_cls in (bytes, bytearray, Bytes, str): + if python_cls in (bytes, Bytes, str): tuple_struct_ptr = self.serialize_pointers(path, ptr)["value"] struct_name = path[-1] + "Struct" path = (*path[:-1], struct_name) @@ -269,7 +293,7 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any: return U256(value) return python_cls(value.to_bytes(32, "little")) - if python_cls in (Bytes0, Bytes8, Bytes20): + if python_cls in (Bytes0, Bytes1, Bytes8, Bytes20): return python_cls(kwargs["value"].to_bytes(python_cls.LENGTH, "little")) # Because some types are wrapped in a value field, e.g. Account{ value: AccountStruct } diff --git a/cairo/tests/utils/strategies.py b/cairo/tests/utils/strategies.py index 6a3d9d70..6d38741c 100644 --- a/cairo/tests/utils/strategies.py +++ b/cairo/tests/utils/strategies.py @@ -5,7 +5,7 @@ from unittest.mock import patch from eth_keys.datatypes import PrivateKey -from ethereum_types.bytes import Bytes0, Bytes8, Bytes20, Bytes32, Bytes256 +from ethereum_types.bytes import Bytes, Bytes0, Bytes8, Bytes20, Bytes32, Bytes256 from ethereum_types.numeric import U64, U256, FixedUnsigned, Uint from hypothesis import strategies as st from starkware.cairo.lang.cairo_constants import DEFAULT_PRIME @@ -52,6 +52,8 @@ bytes256 = st.binary(min_size=256, max_size=256).map(Bytes256) bloom = bytes256.map(Bloom) +smallbytes = st.binary(min_size=0, max_size=2**10).map(Bytes) + # See ethereum.rlp.Simple and ethereum.rlp.Extended for the definition of Simple and Extended simple = st.recursive(st.one_of(st.binary()), st.lists) extended = st.recursive(