Skip to content

Commit

Permalink
feat: memory
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Dec 19, 2024
1 parent 61c2fa8 commit 03a851b
Show file tree
Hide file tree
Showing 6 changed files with 336 additions and 8 deletions.
212 changes: 212 additions & 0 deletions cairo/ethereum/cancun/vm/memory.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// SPDX-License-Identifier: MIT

from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.bool import FALSE, TRUE
from starkware.cairo.common.default_dict import default_dict_new, default_dict_finalize
from starkware.cairo.common.dict import DictAccess, dict_read, dict_write
from starkware.cairo.common.memset import memset
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc
from starkware.cairo.common.math import assert_le, assert_lt
from starkware.cairo.common.math_cmp import is_le, is_not_zero

from ethereum_types.bytes import Bytes, BytesStruct, Bytes1
from ethereum_types.numeric import U256

// @title Memory related functions.
// @notice Implements EVM memory operations using a mutable bytearray.
//
// @notice Backing record of a bytearray: a dict holding one byte per index, plus a length.
struct BytearrayStruct {
// Carried over unchanged by memory_write and memory_read_bytes.
// NOTE(review): presumably the first entry of the dict segment - confirm against the caller.
dict_ptr_start: Bytes1DictAccess*,
// Current dict pointer; replaced by the updated pointer after every read or write.
dict_ptr: Bytes1DictAccess*,
// Length of the bytearray in bytes; memory_write raises it when a write ends past it.
len: felt,
}

// @notice Pointer wrapper around BytearrayStruct; this is the type passed around as `memory`.
struct Bytearray {
value: BytearrayStruct*,
}

// @notice Dict access entry whose values are single bytes (Bytes1).
// @dev Pointers to this struct are cast to and from DictAccess* by the memory functions.
struct Bytes1DictAccess {
// Index of the byte in the bytearray.
key: felt,
prev_value: Bytes1,
new_value: Bytes1,
}

// @notice Write bytes to memory at a given position.
// @dev Only `start_position.value.low` is read; see the assumption below.
// @dev assumption: start_position < 2**128
// @dev assumption: value.len < 2**128
// @dev The stored length becomes the larger of the current length and
// start_position + value.len; the `memory` implicit argument is rebound to a new struct.
// @param memory The pointer to the bytearray.
// @param start_position Starting position to write at.
// @param value Bytes to write.
func memory_write{range_check_ptr, memory: Bytearray}(start_position: U256, value: Bytes) {
alloc_locals;
let bytes_len = value.value.len;
let bytes_data = value.value.data;
// Only the low limb is used as the write offset.
let start_position_felt = start_position.value.low;
// Index one past the last byte written.
let new_len = start_position_felt + bytes_len;

// View the typed dict pointer as a generic DictAccess* so the helper can use it.
let dict_ptr = cast(memory.value.dict_ptr, DictAccess*);
with dict_ptr {
Internals._write_bytes(start_position_felt, bytes_data, bytes_len);
}
// `dict_ptr` was rebound by the helper; cast it back to the typed pointer.
let new_dict_ptr = cast(dict_ptr, Bytes1DictAccess*);

// Update length if we wrote beyond current length
tempvar current_len = memory.value.len;
let is_new_le_current = is_le(new_len, current_len);
if (is_new_le_current != TRUE) {
tempvar final_len = new_len;
} else {
tempvar final_len = current_len;
}

// Rebind `memory`: same dict start, updated dict pointer, possibly larger length.
tempvar memory = Bytearray(
new BytearrayStruct(memory.value.dict_ptr_start, new_dict_ptr, final_len)
);
return ();
}

// @notice Read bytes from memory.
// @dev Only the `.value.low` part of `start_position` and `size` is read; see the assumptions.
// @dev assumption: start_position < 2**128
// @dev assumption: size < 2**128
// @dev Reading goes through the dict, so `memory` is rebound with the updated dict pointer;
// its stored length is left unchanged.
// @param memory The pointer to the bytearray.
// @param start_position Starting position to read from.
// @param size Number of bytes to read.
// @return The bytes read from memory.
func memory_read_bytes{memory: Bytearray}(start_position: U256, size: U256) -> Bytes {
alloc_locals;
// Fresh segment that receives the bytes read.
let (local output: felt*) = alloc();

let start_position_felt = start_position.value.low;
let size_felt = size.value.low;
// View the typed dict pointer as a generic DictAccess* so the helper can use it.
let dict_ptr = cast(memory.value.dict_ptr, DictAccess*);
with dict_ptr {
Internals._read_bytes(start_position_felt, size_felt, output);
}
// `dict_ptr` was rebound by the helper; cast it back to the typed pointer.
let new_dict_ptr = cast(dict_ptr, Bytes1DictAccess*);

// Rebind `memory`: same dict start and length, updated dict pointer.
tempvar memory = Bytearray(
new BytearrayStruct(memory.value.dict_ptr_start, new_dict_ptr, memory.value.len)
);
tempvar result = Bytes(new BytesStruct(output, size_felt));
return result;
}

// @notice Read bytes from a buffer with zero padding.
// @dev Positions past the end of `buffer` read as zero, so the result always has `size` bytes.
// @dev A `start_position` with a non-zero high limb (>= 2**128) is treated as past the end of
//      the buffer and yields `size` zero bytes. Truncating it to its low limb instead would
//      alias it onto an in-buffer offset and leak buffer contents into the result.
// @dev assumption: size < 2**128 (only the low limb of `size` is used)
// @dev NOTE(review): assumes buffer.value.len < 2**128 - confirm against callers.
// @param buffer Source bytes to read from.
// @param start_position Starting position to read from.
// @param size Number of bytes to read.
// @return The bytes read from the buffer.
func buffer_read{range_check_ptr}(buffer: Bytes, start_position: U256, size: U256) -> Bytes {
    alloc_locals;
    let (local output: felt*) = alloc();
    let size_felt = size.value.low;

    // Start position does not fit in 128 bits: nothing of the buffer is visible from there.
    if (start_position.value.high != 0) {
        memset(output, 0, size_felt);
        tempvar result = Bytes(new BytesStruct(output, size_felt));
        return result;
    }

    let buffer_len = buffer.value.len;
    let buffer_data = buffer.value.data;
    let start_position_felt = start_position.value.low;

    // Copies the available bytes and zero-pads the rest.
    Internals._buffer_read_internal(
        buffer_len, buffer_data, start_position_felt, size_felt, output
    );
    tempvar result = Bytes(new BytesStruct(output, size_felt));
    return result;
}

namespace Internals {
// @notice Internal function to write bytes to memory.
// @dev Writes `data[i]` under dict key `start_position + i` for every i in [0, len).
// @dev Hand-written loop: the four loop variables live in the four most recently pushed
// cells ([ap - 4] .. [ap - 1]) and are re-pushed in the same order at the end of each pass.
// @param start_position Starting position to write at.
// @param data Pointer to the bytes data.
// @param len Length of bytes to write.
func _write_bytes{dict_ptr: DictAccess*}(start_position: felt, data: felt*, len: felt) {
alloc_locals;
// Nothing to write. The loop below tests its condition only after a pass, so it must not
// be entered with len == 0.
if (len == 0) {
return ();
}

// Push the loop state in a fixed order so `body` can address it relative to ap.
tempvar start_position = start_position;
tempvar data = data;
tempvar len = len;
tempvar dict_ptr = dict_ptr;

body:
// Rebind the names to the four cells pushed last (initial push or previous pass).
let start_position = [ap - 4];
let data = cast([ap - 3], felt*);
let len = [ap - 2];
let dict_ptr = cast([ap - 1], DictAccess*);
// Store the current byte under the current key; `dict_ptr` is rebound by the call.
dict_write(start_position, [data]);

// Advance to the next byte and re-push the state in the same order as above.
tempvar start_position = start_position + 1;
tempvar data = data + 1;
tempvar len = len - 1;
tempvar dict_ptr = dict_ptr;
// Keep looping while bytes remain.
jmp body if len != 0;
return ();
}

// @notice Internal function to read bytes from memory.
// @dev Copies the dict value at key `start_position + i` to `output[i]` for every i in
// [0, size).
// @dev Hand-written loop: the four loop variables live in the four most recently pushed
// cells ([ap - 4] .. [ap - 1]) and are re-pushed in the same order at the end of each pass.
// @param start_position Starting position to read from.
// @param size Number of bytes to read.
// @param output Pointer to write output bytes to.
func _read_bytes{dict_ptr: DictAccess*}(start_position: felt, size: felt, output: felt*) {
// Nothing to read. The loop below tests its condition only after a pass, so it must not
// be entered with size == 0.
if (size == 0) {
return ();
}

// Push the loop state in a fixed order so `body` can address it relative to ap.
tempvar start_position = start_position;
tempvar size = size;
tempvar output = output;
tempvar dict_ptr = dict_ptr;

body:
// Rebind the names to the four cells pushed last (initial push or previous pass).
let start_position = [ap - 4];
let size = [ap - 3];
let output = cast([ap - 2], felt*);
let dict_ptr = cast([ap - 1], DictAccess*);

// Read the byte stored under the current key and write it to the current output cell.
let (value) = dict_read(start_position);
assert [output] = value;

// Advance to the next byte and re-push the state in the same order as above.
tempvar start_position = start_position + 1;
tempvar size = size - 1;
tempvar output = output + 1;
tempvar dict_ptr = dict_ptr;
// Keep looping while bytes remain.
jmp body if size != 0;
return ();
}

// @notice Internal function to read bytes from a buffer with zero padding.
// @dev Fills `output[0 .. size)`. Three cases: the start is at or past the end of the buffer
// (all zeros), the read runs up to or past the end (copy what exists, then zero-pad), or the
// read lies strictly inside the buffer (plain copy).
// @param data_len Length of the buffer.
// @param data Pointer to the buffer data.
// @param start_position Starting position to read from.
// @param size Number of bytes to read.
// @param output Pointer to write output bytes to.
func _buffer_read_internal{range_check_ptr}(
data_len: felt, data: felt*, start_position: felt, size: felt, output: felt*
) {
alloc_locals;
// Nothing to produce.
if (size == 0) {
return ();
}

// Check if start position is beyond buffer length
// (data_len <= start_position: no buffer byte is readable, so the output is all zeros).
let start_oob = is_le(data_len, start_position);
if (start_oob == TRUE) {
memset(output, 0, size);
return ();
}

// Check if read extends past end of buffer
// (data_len <= start_position + size also covers a read ending exactly at the end, in which
// case remaining_size below is 0).
let end_oob = is_le(data_len, start_position + size);
if (end_oob == TRUE) {
// Copy the bytes that exist, then pad the tail of the output with zeros.
let available_size = data_len - start_position;
memcpy(output, data + start_position, available_size);
let remaining_size = size - available_size;
memset(output + available_size, 0, remaining_size);
} else {
// The whole range lies inside the buffer.
memcpy(output, data + start_position, size);
}
return ();
}
}
3 changes: 3 additions & 0 deletions cairo/ethereum_types/bytes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ from ethereum_types.numeric import U128
struct Bytes0 {
value: felt,
}
struct Bytes1 {
value: felt,
}
struct Bytes8 {
value: felt,
}
Expand Down
71 changes: 71 additions & 0 deletions cairo/tests/ethereum/cancun/vm/test_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from ethereum_types.bytes import Bytes
from ethereum_types.numeric import U256
from hypothesis import given
from hypothesis import strategies as st
from hypothesis.strategies import composite

from ethereum.cancun.vm.memory import buffer_read, memory_read_bytes, memory_write
from tests.utils.strategies import smallbytes

# NOTE: The testing strategy always assumes that memory accesses are within bounds,
# because the memory is always extended to the proper size _before_ being accessed.


@composite
def memory_write_strategy(draw):
    """Draw a ``(memory, start_position, value)`` triple for an in-bounds write.

    ``memory`` is a bytearray of up to 2**10 bytes, ``start_position`` is a
    U256 offset no larger than the memory size, and ``value`` is short enough
    that the write ends inside the existing memory.
    """
    # Higher than 2**10 will cause a HealthCheck too large error.
    capacity = draw(st.integers(min_value=0, max_value=2**10))
    exact_size_bytes = st.binary(min_size=capacity, max_size=capacity)
    memory = draw(exact_size_bytes.map(bytearray))

    # Offset anywhere in the existing memory, its end included.
    start_position = draw(st.integers(min_value=0, max_value=capacity).map(U256))

    # Cap the payload so that the write never runs past the end of memory.
    room_left = capacity - int(start_position)
    value = draw(st.binary(min_size=0, max_size=room_left))

    return memory, start_position, value


@composite
def memory_read_strategy(draw):
    """Draw a ``(memory, start_position, size)`` triple for an in-bounds read.

    ``memory`` is a bytearray of up to 2**10 bytes; ``start_position`` and
    ``size`` are U256 values chosen so that the read ends inside the memory.
    """
    capacity = draw(st.integers(min_value=0, max_value=2**10))
    exact_size_bytes = st.binary(min_size=capacity, max_size=capacity)
    memory = draw(exact_size_bytes.map(bytearray))

    start_position = draw(st.integers(min_value=0, max_value=capacity).map(U256))

    # Never read more bytes than remain after the start position.
    room_left = capacity - int(start_position)
    size = draw(st.integers(min_value=0, max_value=room_left).map(U256))

    return memory, start_position, size


class TestMemory:
    """Compare the Cairo memory functions with their Python counterparts."""

    @given(memory_write_strategy())
    def test_memory_write(self, cairo_run, params):
        """Both implementations must leave the memory in the same state after a write."""
        memory, start_position, value = params
        payload = Bytes(value)
        # Run Cairo first: the Python call below mutates `memory` in place.
        cairo_memory = cairo_run("memory_write", memory, start_position, payload)
        memory_write(memory, start_position, payload)
        assert cairo_memory == memory

    @given(memory_read_strategy())
    def test_memory_read(self, cairo_run, params):
        """A read must return the same bytes and leave the memory unchanged."""
        memory, start_position, size = params
        cairo_result = cairo_run("memory_read_bytes", memory, start_position, size)
        cairo_memory, cairo_value = cairo_result
        python_value = memory_read_bytes(memory, start_position, size)
        assert cairo_memory == memory
        assert cairo_value == python_value

    @given(
        buffer=smallbytes,
        start_position=...,
        size=st.integers(min_value=0, max_value=2**10).map(U256),
    )
    def test_buffer_read(
        self, cairo_run, buffer: Bytes, start_position: U256, size: U256
    ):
        """Zero-padded buffer reads must agree for any start position."""
        expected = buffer_read(buffer, start_position, size)
        actual = cairo_run("buffer_read", buffer, start_position, size)
        assert expected == actual
22 changes: 19 additions & 3 deletions cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,15 @@
get_origin,
)

from ethereum_types.bytes import Bytes, Bytes0, Bytes8, Bytes20, Bytes32, Bytes256
from ethereum_types.bytes import (
Bytes,
Bytes0,
Bytes1,
Bytes8,
Bytes20,
Bytes32,
Bytes256,
)
from ethereum_types.numeric import U64, U256, Uint
from starkware.cairo.common.dict import DictManager, DictTracker
from starkware.cairo.lang.compiler.ast.cairo_types import (
Expand Down Expand Up @@ -109,6 +117,7 @@
("ethereum_types", "numeric", "U256"): U256,
("ethereum_types", "numeric", "SetUint"): Set[Uint],
("ethereum_types", "bytes", "Bytes0"): Bytes0,
("ethereum_types", "bytes", "Bytes1"): Bytes1,
("ethereum_types", "bytes", "Bytes8"): Bytes8,
("ethereum_types", "bytes", "Bytes20"): Bytes20,
("ethereum_types", "bytes", "Bytes32"): Bytes32,
Expand Down Expand Up @@ -172,6 +181,7 @@
Address, Account
],
("ethereum", "exceptions", "EthereumException"): EthereumException,
("ethereum", "cancun", "vm", "memory", "Bytearray"): bytearray,
}

# In the EELS, some functions are annotated with Sequence while it's actually just Bytes.
Expand Down Expand Up @@ -271,13 +281,19 @@ def _gen_arg(
segments.load_data(struct_ptr, [instances_ptr, len(arg)])
return struct_ptr

if arg_type_origin in (dict, ChainMap, abc.Mapping, set):
if arg_type_origin in (dict, ChainMap, abc.Mapping, set) or arg_type is bytearray:
dict_ptr = segments.add()
assert dict_ptr.segment_index not in dict_manager.trackers

if arg_type_origin is set:
arg = {k: True for k in arg}
arg_type = Mapping[type(next(iter(arg))), bool]
elif arg_type is bytearray:
# Create a dict with one byte per value and include length
data = defaultdict(int, {k: v for k, v in enumerate(arg)})
base = _gen_arg(dict_manager, segments, Mapping[int, int], data)
segments.load_data(base + 2, [len(arg)]) # Store length after dict pointers
return base

data = {
_gen_arg(dict_manager, segments, get_args(arg_type)[0], k): _gen_arg(
Expand Down Expand Up @@ -323,7 +339,7 @@ def _gen_arg(
)
return base

if arg_type in (Bytes, bytes, bytearray, str):
if arg_type in (Bytes, bytes, str):
if arg is None:
return 0
if isinstance(arg, str):
Expand Down
Loading

0 comments on commit 03a851b

Please sign in to comment.