Skip to content

Commit

Permalink
feat: memory
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Dec 30, 2024
1 parent 7fa1396 commit f930ddf
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 35 deletions.
188 changes: 188 additions & 0 deletions cairo/ethereum/cancun/vm/memory.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
// SPDX-License-Identifier: MIT

from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.bool import FALSE, TRUE
from starkware.cairo.common.default_dict import default_dict_new, default_dict_finalize
from starkware.cairo.common.dict import DictAccess, dict_read, dict_write
from starkware.cairo.common.memset import memset
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc
from starkware.cairo.common.math import assert_le, assert_lt
from starkware.cairo.common.math_cmp import is_le, is_not_zero

from ethereum_types.bytes import Bytes, BytesStruct, Bytearray, BytearrayStruct, Bytes1DictAccess
from ethereum_types.numeric import U256
from ethereum.utils.numeric import max

// @notice Store `value` into the memory bytearray, beginning at `start_position`.
// @dev Only the low 128-bit limb of `start_position` is used as the offset; the call fails
// if the high limb is non-zero. The recorded length becomes the larger of the previous
// length and `start_position + value.len`.
// @param memory The bytearray being written to (implicit, rebound on return).
// @param start_position Offset of the first byte to overwrite.
// @param value The bytes to store.
func memory_write{range_check_ptr, memory: Bytearray}(start_position: U256, value: Bytes) {
    alloc_locals;

    with_attr error_message("memory_write: start_position > 2**128 || value.len > 2**128") {
        assert start_position.value.high = 0;
    }
    let write_start = start_position.value.low;
    let src_data = value.value.data;
    let src_len = value.value.len;

    // The bytearray's access log is handed to the helper as a plain DictAccess*.
    let dict_ptr = cast(memory.value.dict_ptr, DictAccess*);
    with dict_ptr {
        Internals._write_bytes(write_start, src_data, src_len);
    }
    let updated_dict_ptr = cast(dict_ptr, Bytes1DictAccess*);

    let updated_len = max(memory.value.len, write_start + src_len);
    tempvar memory = Bytearray(
        new BytearrayStruct(memory.value.dict_ptr_start, updated_dict_ptr, updated_len)
    );
    return ();
}

// @notice Copy `size` bytes out of the memory bytearray, beginning at `start_position`.
// @dev Both arguments must fit in their low 128-bit limb; the call fails otherwise.
// The bytearray is rebound with the advanced dict pointer and an unchanged length.
// @param memory The bytearray being read from (implicit, rebound on return).
// @param start_position Offset of the first byte to read.
// @param size How many bytes to read.
// @return A freshly allocated `Bytes` of length `size`.
func memory_read_bytes{memory: Bytearray}(start_position: U256, size: U256) -> Bytes {
    alloc_locals;

    with_attr error_message("memory_read_bytes: start_position > 2**128 || size > 2**128") {
        assert start_position.value.high = 0;
        assert size.value.high = 0;
    }
    let read_start = start_position.value.low;
    let n_bytes = size.value.low;

    let (local out_buf: felt*) = alloc();

    // Reads go through the dict API, so they also append to the access log.
    let dict_ptr = cast(memory.value.dict_ptr, DictAccess*);
    with dict_ptr {
        Internals._read_bytes(read_start, n_bytes, out_buf);
    }
    let updated_dict_ptr = cast(dict_ptr, Bytes1DictAccess*);

    tempvar memory = Bytearray(
        new BytearrayStruct(memory.value.dict_ptr_start, updated_dict_ptr, memory.value.len)
    );
    tempvar result = Bytes(new BytesStruct(out_buf, n_bytes));
    return result;
}

// @notice Slice `size` bytes out of `buffer` starting at `start_position`, padding with
// zeros wherever the requested range lies beyond the end of the buffer.
// @dev Fails if `start_position` or `size` has a non-zero high 128-bit limb.
// @param buffer The bytes to slice.
// @param start_position Offset of the first byte to read.
// @param size How many bytes to return.
// @return A freshly allocated `Bytes` of length `size`.
func buffer_read{range_check_ptr}(buffer: Bytes, start_position: U256, size: U256) -> Bytes {
    alloc_locals;
    let (local padded: felt*) = alloc();

    with_attr error_message("buffer_read: start_position > 2**128 || size > 2**128") {
        assert start_position.value.high = 0;
        assert size.value.high = 0;
    }
    let src_len = buffer.value.len;
    let src_data = buffer.value.data;
    let read_start = start_position.value.low;
    let n_bytes = size.value.low;

    Internals._buffer_read(src_len, src_data, read_start, n_bytes, padded);
    tempvar result = Bytes(new BytesStruct(padded, n_bytes));
    return result;
}

namespace Internals {
// @notice Internal function to write bytes to memory.
// @dev Stores `data[i]` under dictionary key `start_position + i`, walking `i` from
// `len - 1` down to 0. Returns without touching the dictionary when `len` is 0.
// @dev The loop is written with a label and a conditional jump, and reads its state
// through raw `[ap - k]` / `[fp - k]` offsets: inserting or removing instructions
// between the tempvars and the label would shift those offsets.
// @param start_position Starting position to write at.
// @param data Pointer to the bytes data.
// @param len Length of bytes to write.
func _write_bytes{dict_ptr: DictAccess*}(start_position: felt, data: felt*, len: felt) {
if (len == 0) {
return ();
}

// Seed the loop state. `body` reads these two values back as [ap - 2] and [ap - 1].
tempvar index = len;
tempvar dict_ptr = dict_ptr;

body:
// Recover the loop state from the two most recently pushed tempvars and step the
// index down by one, so the first iteration handles `len - 1`.
let index = [ap - 2] - 1;
let dict_ptr = cast([ap - 1], DictAccess*);
// Re-bind the arguments to explicit frame slots.
// NOTE(review): presumably done so the references stay usable after the label - confirm.
let start_position = [fp - 5];
let data = cast([fp - 4], felt*);

dict_write(start_position + index, data[index]);

// Push the updated state in the same order as before the label, then loop until
// the byte at index 0 has been written.
tempvar index = index;
tempvar dict_ptr = dict_ptr;
jmp body if index != 0;

// This label is not the target of any jump in this function.
end:
return ();
}

// @notice Internal function to read bytes from memory.
// @dev Reads dictionary keys `start_position + size - 1` down to `start_position` and
// stores each value at `output[key - start_position]`. Returns without reading when
// `size` is 0.
// @dev Like `_write_bytes`, the loop state is carried through raw `[ap - k]` / `[fp - k]`
// offsets, so the instruction sequence around the label must not be altered casually.
// @param start_position Starting position to read from.
// @param size Number of bytes to read.
// @param output Pointer to write output bytes to.
func _read_bytes{dict_ptr: DictAccess*}(start_position: felt, size: felt, output: felt*) {
alloc_locals;
if (size == 0) {
return ();
}

// Seed the loop state with one-past-the-last key; `body` decrements before use.
tempvar dict_index = start_position + size;
tempvar dict_ptr = dict_ptr;

body:
// Recover the loop state from the two most recently pushed tempvars.
let dict_index = [ap - 2] - 1;
let dict_ptr = cast([ap - 1], DictAccess*);
// Re-bind the arguments to explicit frame slots.
let output = cast([fp - 3], felt*);
let start_position = [fp - 5];
// Position inside `output` that corresponds to the current dictionary key.
tempvar output_index = dict_index - start_position;

let (value) = dict_read(dict_index);
assert output[output_index] = value;

// Push the updated state, then loop until output[0] has been filled.
tempvar dict_index = dict_index;
tempvar dict_ptr = dict_ptr;
jmp body if output_index != 0;

return ();
}

// @notice Internal function to read bytes from a buffer with zero padding.
// @dev Three cases, after the early return for `size == 0`:
// - the read starts at or past the end of the buffer: `output` is `size` zeros;
// - the read reaches or passes the end of the buffer: the available bytes are copied
// and the remainder (possibly 0 bytes) is zero-filled;
// - the read lies strictly inside the buffer: `size` bytes are copied.
// @param data_len Length of the buffer.
// @param data Pointer to the buffer data.
// @param start_position Starting position to read from.
// @param size Number of bytes to read.
// @param output Pointer to write output bytes to.
func _buffer_read{range_check_ptr}(
data_len: felt, data: felt*, start_position: felt, size: felt, output: felt*
) {
alloc_locals;
if (size == 0) {
return ();
}

// Check if start position is beyond buffer length
// (`is_le` is inclusive, so a start exactly at `data_len` also takes this branch).
let start_oob = is_le(data_len, start_position);
if (start_oob == TRUE) {
memset(output, 0, size);
return ();
}

// Check if read extends past end of buffer
// (a read ending exactly at `data_len` also takes this branch, with nothing to pad).
let end_oob = is_le(data_len, start_position + size);
if (end_oob == TRUE) {
let available_size = data_len - start_position;
memcpy(output, data + start_position, available_size);
let remaining_size = size - available_size;
memset(output + available_size, 0, remaining_size);
} else {
memcpy(output, data + start_position, size);
}
return ();
}
}
20 changes: 20 additions & 0 deletions cairo/ethereum_types/bytes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ from ethereum_types.numeric import U128
struct Bytes0 {
value: felt,
}
// Wrapper around one felt; used as the value type of `Bytes1DictAccess` entries.
// NOTE(review): presumably holds a single byte (0..255) - the range is not enforced here.
struct Bytes1 {
value: felt,
}
struct Bytes8 {
value: felt,
}
Expand Down Expand Up @@ -119,3 +122,20 @@ struct TupleBytes32Struct {
struct TupleBytes32 {
value: TupleBytes32Struct*,
}

// mutable bytearray collection type
// Backed by a dictionary rather than a flat array: bytes are stored as dictionary
// entries keyed by position, and the length is tracked separately.
struct BytearrayStruct {
// Start of the dictionary access log.
dict_ptr_start: Bytes1DictAccess*,
// Current end of the access log; memory.cairo stores an updated pointer here after
// each read or write.
dict_ptr: Bytes1DictAccess*,
// Length of the bytearray in bytes.
len: felt,
}

// Handle to a bytearray: a single pointer to its `BytearrayStruct`.
struct Bytearray {
value: BytearrayStruct*,
}

// One entry of a bytearray's dictionary access log: a key together with the value
// before and after the access. memory.cairo casts between this type and `DictAccess*`,
// so the field order must keep matching `DictAccess`.
struct Bytes1DictAccess {
key: felt,
prev_value: Bytes1,
new_value: Bytes1,
}
70 changes: 70 additions & 0 deletions cairo/tests/ethereum/cancun/vm/test_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from ethereum_types.bytes import Bytes
from ethereum_types.numeric import U256
from hypothesis import given
from hypothesis import strategies as st
from hypothesis.strategies import composite

from ethereum.cancun.vm.memory import buffer_read, memory_read_bytes, memory_write

# NOTE: The testing strategies always assume that memory accesses are within bounds,
# because the memory is always extended to the proper size _before_ being accessed.


@composite
def memory_write_strategy(draw):
    """Draw a ``(memory, start_position, value)`` triple for an in-bounds write.

    ``memory`` is a bytearray of up to 2**10 bytes, ``start_position`` is a U256
    offset inside it, and ``value`` is short enough to end within ``memory``.
    """
    # Sizes above 2**10 trigger Hypothesis' "too large" HealthCheck error.
    total = draw(st.integers(min_value=0, max_value=2**10))
    exact_size_bytes = st.binary(min_size=total, max_size=total)
    memory = draw(exact_size_bytes.map(bytearray))

    # The write may begin anywhere in the existing memory, its end included.
    offsets = st.integers(min_value=0, max_value=total).map(U256)
    start_position = draw(offsets)

    # Cap the payload so that the write never runs past the end of memory.
    room = total - int(start_position)
    value = draw(st.binary(min_size=0, max_size=room))

    return memory, start_position, value


@composite
def memory_read_strategy(draw):
    """Draw a ``(memory, start_position, size)`` triple for an in-bounds read.

    ``memory`` is a bytearray of up to 2**10 bytes; ``start_position`` and ``size``
    are U256 values chosen so that the read ends within ``memory``.
    """
    total = draw(st.integers(min_value=0, max_value=2**10))
    exact_size_bytes = st.binary(min_size=total, max_size=total)
    memory = draw(exact_size_bytes.map(bytearray))

    offsets = st.integers(min_value=0, max_value=total).map(U256)
    start_position = draw(offsets)

    # Never ask for more bytes than remain after the start position.
    readable = total - int(start_position)
    size = draw(st.integers(min_value=0, max_value=readable).map(U256))

    return memory, start_position, size


class TestMemory:
    """Differential tests comparing the Cairo memory functions with the Python ones."""

    @given(memory_write_strategy())
    def test_memory_write(self, cairo_run, params):
        """After a write, the Cairo memory equals the Python memory."""
        memory, start_position, value = params
        payload = Bytes(value)

        # Run Cairo first: the Python implementation mutates `memory` in place.
        cairo_memory = cairo_run("memory_write", memory, start_position, payload)
        memory_write(memory, start_position, payload)

        assert cairo_memory == memory

    @given(memory_read_strategy())
    def test_memory_read(self, cairo_run, params):
        """A read returns the same bytes in both implementations and leaves memory equal."""
        memory, start_position, size = params

        cairo_result = cairo_run("memory_read_bytes", memory, start_position, size)
        cairo_memory, cairo_value = cairo_result
        python_value = memory_read_bytes(memory, start_position, size)

        assert cairo_memory == memory
        assert cairo_value == python_value

    @given(
        buffer=st.binary(min_size=0, max_size=2**10).map(Bytes),
        start_position=st.integers(min_value=0, max_value=2**128 - 1).map(U256),
        size=st.integers(min_value=0, max_value=2**10).map(U256),
    )
    def test_buffer_read(
        self, cairo_run, buffer: Bytes, start_position: U256, size: U256
    ):
        """Zero-padded buffer reads agree, including starts far beyond the buffer."""
        expected = buffer_read(buffer, start_position, size)
        actual = cairo_run("buffer_read", buffer, start_position, size)

        assert expected == actual
21 changes: 16 additions & 5 deletions cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,15 @@
get_origin,
)

from ethereum_types.bytes import Bytes, Bytes0, Bytes8, Bytes20, Bytes32, Bytes256
from ethereum_types.bytes import (
Bytes,
Bytes0,
Bytes1,
Bytes8,
Bytes20,
Bytes32,
Bytes256,
)
from ethereum_types.numeric import U64, U256, Uint
from starkware.cairo.common.dict import DictManager, DictTracker
from starkware.cairo.lang.compiler.ast.cairo_types import (
Expand Down Expand Up @@ -114,6 +122,7 @@
("ethereum_types", "numeric", "SetUint"): Set[Uint],
("ethereum_types", "numeric", "UnionUintU256"): Union[Uint, U256],
("ethereum_types", "bytes", "Bytes0"): Bytes0,
("ethereum_types", "bytes", "Bytes1"): Bytes1,
("ethereum_types", "bytes", "Bytes8"): Bytes8,
("ethereum_types", "bytes", "Bytes20"): Bytes20,
("ethereum_types", "bytes", "Bytes32"): Bytes32,
Expand Down Expand Up @@ -184,6 +193,7 @@
Address, Account
],
("ethereum", "exceptions", "EthereumException"): EthereumException,
("ethereum_types", "bytes", "Bytearray"): bytearray,
("ethereum", "cancun", "vm", "stack", "Stack"): List[U256],
(
"ethereum",
Expand Down Expand Up @@ -287,9 +297,10 @@ def _gen_arg(
segments.load_data(struct_ptr, data)
return struct_ptr

if arg_type_origin is list:
# A `list` is represented as a Dict[felt, V] along with a length field.
value_type = get_args(arg_type)[0] # Get the concrete type parameter
if arg_type_origin in (list, bytearray):
# Collection types are represented as a Dict[felt, V] along with a length field.
# Get the concrete type parameter. For bytearray, the value type is int.
value_type = next(iter(get_args(arg_type)), int)
data = defaultdict(int, {k: v for k, v in enumerate(arg)})
base = _gen_arg(dict_manager, segments, Dict[Uint, value_type], data)
segments.load_data(base + 2, [len(arg)])
Expand Down Expand Up @@ -383,7 +394,7 @@ def _gen_arg(
)
return base

if arg_type in (Bytes, bytes, bytearray, str):
if arg_type in (Bytes, bytes, str):
if arg is None:
return 0
if isinstance(arg, str):
Expand Down
Loading

0 comments on commit f930ddf

Please sign in to comment.