Skip to content

Commit

Permalink
feat: EELS memory (#307)
Browse files Browse the repository at this point in the history
Closes #296
  • Loading branch information
enitrat authored Dec 31, 2024
1 parent c2feefc commit adf8deb
Show file tree
Hide file tree
Showing 7 changed files with 334 additions and 30 deletions.
196 changes: 196 additions & 0 deletions cairo/ethereum/cancun/vm/memory.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// SPDX-License-Identifier: MIT

from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.bool import FALSE, TRUE
from starkware.cairo.common.default_dict import default_dict_new, default_dict_finalize
from starkware.cairo.common.dict import DictAccess, dict_read, dict_write
from starkware.cairo.common.memset import memset
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc
from starkware.cairo.common.math import assert_le, assert_lt
from starkware.cairo.common.math_cmp import is_le, is_not_zero

from ethereum_types.bytes import Bytes, BytesStruct, Bytes1DictAccess
from ethereum_types.numeric import U256
from ethereum.utils.numeric import max

// State of the VM memory, backed by a dict with one Bytes1DictAccess entry per access.
struct MemoryStruct {
// Dict pointer carried over unchanged whenever memory_write / memory_read_bytes rebuild
// the struct.
dict_ptr_start: Bytes1DictAccess*,
// Current dict pointer; replaced by the advanced pointer after each read or write.
dict_ptr: Bytes1DictAccess*,
// Memory length; memory_write sets it to max(len, start_position + value length).
len: felt,
}

// Handle passed around as the `memory` implicit argument: a pointer to a MemoryStruct.
// Updating the memory means binding a new Memory that points to a fresh MemoryStruct.
struct Memory {
value: MemoryStruct*,
}

// @notice Store a byte string into memory, beginning at the given offset.
// @dev Only `start_position.high` is asserted to be zero; the length of `value` is not
//      checked here even though the error message mentions it.
// @dev The recorded memory length becomes max(current length, start_position + value length).
// @param memory Implicit argument: the memory written to; rebound to the updated state.
// @param start_position Offset of the first byte to write (high limb must be zero).
// @param value Bytes to store.
func memory_write{range_check_ptr, memory: Memory}(start_position: U256, value: Bytes) {
    alloc_locals;

    with_attr error_message("memory_write: start_position > 2**128 || value.len > 2**128") {
        assert start_position.value.high = 0;
    }

    let offset = start_position.value.low;
    let payload = value.value.data;
    let payload_len = value.value.len;

    // The byte-wise writer works on the generic DictAccess view of the memory dict.
    let dict_ptr = cast(memory.value.dict_ptr, DictAccess*);
    with dict_ptr {
        _write_bytes(offset, payload, payload_len);
    }
    let updated_dict_ptr = cast(dict_ptr, Bytes1DictAccess*);

    // Writing past the current end grows the recorded length; otherwise it is kept.
    let updated_len = max(memory.value.len, offset + payload_len);
    tempvar memory = Memory(
        new MemoryStruct(memory.value.dict_ptr_start, updated_dict_ptr, updated_len)
    );
    return ();
}

// @notice Copy a range of memory into a freshly allocated byte string.
// @dev Fails if the high limb of `start_position` or of `size` is non-zero.
// @dev The memory length is left unchanged; only the dict pointer advances.
// @param memory Implicit argument: the memory read from; rebound with the advanced dict pointer.
// @param start_position Offset of the first byte to read.
// @param size Number of bytes to read.
// @return A Bytes of length `size` holding the values read from the dict.
func memory_read_bytes{memory: Memory}(start_position: U256, size: U256) -> Bytes {
    alloc_locals;

    with_attr error_message("memory_read_bytes: start_position > 2**128 || size > 2**128") {
        assert start_position.value.high = 0;
        assert size.value.high = 0;
    }

    let offset = start_position.value.low;
    let nb_bytes = size.value.low;
    let (local read_buffer: felt*) = alloc();

    // The byte-wise reader works on the generic DictAccess view of the memory dict.
    let dict_ptr = cast(memory.value.dict_ptr, DictAccess*);
    with dict_ptr {
        _read_bytes(offset, nb_bytes, read_buffer);
    }
    let updated_dict_ptr = cast(dict_ptr, Bytes1DictAccess*);

    tempvar memory = Memory(
        new MemoryStruct(memory.value.dict_ptr_start, updated_dict_ptr, memory.value.len)
    );
    tempvar result = Bytes(new BytesStruct(read_buffer, nb_bytes));
    return result;
}

// @notice Slice `size` bytes out of a buffer, zero-filling whatever lies past its end.
// @dev Fails if the high limb of `start_position` or of `size` is non-zero.
// @param buffer Source bytes to read from.
// @param start_position Offset of the first byte to read.
// @param size Number of bytes to return.
// @return A Bytes of length `size`.
func buffer_read{range_check_ptr}(buffer: Bytes, start_position: U256, size: U256) -> Bytes {
    alloc_locals;
    let (local padded: felt*) = alloc();

    with_attr error_message("buffer_read: start_position > 2**128 || size > 2**128") {
        assert start_position.value.high = 0;
        assert size.value.high = 0;
    }

    let offset = start_position.value.low;
    let nb_bytes = size.value.low;
    _buffer_read(buffer.value.len, buffer.value.data, offset, nb_bytes, padded);

    tempvar result = Bytes(new BytesStruct(padded, nb_bytes));
    return result;
}

// @notice Internal function to write bytes to memory.
// @dev Writes one dict entry per byte, iterating from the last byte (index len - 1)
//      down to index 0. Returns immediately when len is 0.
// @dev The loop body re-derives its operands from fixed ap/fp offsets, so the order of
//      the tempvar statements and of the function arguments must not change.
// @param dict_ptr Implicit argument: dict pointer passed to every dict_write call.
// @param start_position Starting position to write at.
// @param data Pointer to the bytes data.
// @param len Length of bytes to write.
func _write_bytes{dict_ptr: DictAccess*}(start_position: felt, data: felt*, len: felt) {
if (len == 0) {
return ();
}

// Loop state: `index` starts at len, and is decremented before use in the body.
// `index` and `dict_ptr` are the last two values pushed before reaching `body`.
tempvar index = len;
tempvar dict_ptr = dict_ptr;

body:
// Recover the loop state from the two tempvars pushed just before this label is reached.
let index = [ap - 2] - 1;
let dict_ptr = cast([ap - 1], DictAccess*);
// Re-bind the arguments from their frame slots.
let start_position = [fp - 5];
let data = cast([fp - 4], felt*);

// Key is the absolute memory position; value is the byte at the same relative index.
dict_write(start_position + index, data[index]);

// Push the updated state in the same order, then loop until index 0 has been written.
tempvar index = index;
tempvar dict_ptr = dict_ptr;
jmp body if index != 0;

// NOTE(review): no jump in this function targets the `end` label.
end:
return ();
}

// @notice Internal function to read bytes from memory.
// @dev Reads one dict entry per byte, from position start_position + size - 1 down to
//      start_position, storing each value at output[position - start_position].
//      Returns immediately when size is 0.
// @dev The loop body re-derives its operands from fixed ap/fp offsets, so the order of
//      the tempvar statements and of the function arguments must not change.
// @param dict_ptr Implicit argument: dict pointer passed to every dict_read call.
// @param start_position Starting position to read from.
// @param size Number of bytes to read.
// @param output Pointer to write output bytes to.
func _read_bytes{dict_ptr: DictAccess*}(start_position: felt, size: felt, output: felt*) {
alloc_locals;
if (size == 0) {
return ();
}

// Loop state: `dict_index` starts one past the last position to read and is decremented
// before use in the body. It and `dict_ptr` are the last two values pushed before `body`.
tempvar dict_index = start_position + size;
tempvar dict_ptr = dict_ptr;

body:
// Recover the loop state from the two tempvars pushed just before this label is reached.
let dict_index = [ap - 2] - 1;
let dict_ptr = cast([ap - 1], DictAccess*);
// Re-bind the arguments from their frame slots.
let output = cast([fp - 3], felt*);
let start_position = [fp - 5];
// Offset of the current byte within the output buffer.
tempvar output_index = dict_index - start_position;

let (value) = dict_read(dict_index);
assert output[output_index] = value;

// Push the updated state in the same order; stop once output_index 0 (the byte at
// start_position) has been read.
tempvar dict_index = dict_index;
tempvar dict_ptr = dict_ptr;
jmp body if output_index != 0;

return ();
}

// @notice Copy `size` bytes of a buffer into `output`, zero-filling past the buffer end.
// @dev Three cases are handled after the empty read:
//      - data_len <= start_position: the output is all zeros;
//      - data_len <= start_position + size: the tail of the buffer, then zeros;
//      - otherwise: a plain copy of `size` bytes.
// @param data_len Length of the buffer.
// @param data Pointer to the buffer data.
// @param start_position Offset of the first byte to read.
// @param size Number of bytes to produce.
// @param output Pointer the `size` output bytes are written to.
func _buffer_read{range_check_ptr}(
    data_len: felt, data: felt*, start_position: felt, size: felt, output: felt*
) {
    alloc_locals;
    if (size == 0) {
        return ();
    }

    // The read starts at or after the end of the buffer: everything is padding.
    let starts_past_end = is_le(data_len, start_position);
    if (starts_past_end != FALSE) {
        memset(output, 0, size);
        return ();
    }

    // The read stays strictly inside the buffer: plain copy.
    let reaches_end = is_le(data_len, start_position + size);
    if (reaches_end == FALSE) {
        memcpy(output, data + start_position, size);
        return ();
    }

    // The read reaches the end of the buffer: copy what is available, pad the rest.
    let available_size = data_len - start_position;
    memcpy(output, data + start_position, available_size);
    memset(output + available_size, 0, size - available_size);
    return ();
}
9 changes: 9 additions & 0 deletions cairo/ethereum_types/bytes.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ from ethereum_types.numeric import U128
// Single-felt wrapper type.
struct Bytes0 {
value: felt,
}
// Single-felt wrapper type, used as the value type of Bytes1DictAccess.
// NOTE(review): presumably holds one byte (0..255); no range check is visible here.
struct Bytes1 {
value: felt,
}
// Single-felt wrapper type.
struct Bytes8 {
value: felt,
}
Expand Down Expand Up @@ -119,3 +122,9 @@ struct TupleBytes32Struct {
// Wrapper holding a pointer to a TupleBytes32Struct.
struct TupleBytes32 {
value: TupleBytes32Struct*,
}

// Dict access record whose previous and new values are typed as Bytes1.
// Field order is key, prev_value, new_value.
// NOTE(review): presumably mirrors the layout of the common library's DictAccess so that
// pointers can be cast between the two; confirm against starkware.cairo.common.dict.
struct Bytes1DictAccess {
key: felt,
prev_value: Bytes1,
new_value: Bytes1,
}
4 changes: 1 addition & 3 deletions cairo/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ dev-dependencies = [
"pytest-xdist>=3.6.1",
"pytest>=8.3.3",
"pydantic>=2.9.1",
"polars>=1.18.0",
]

[tool.isort]
Expand All @@ -159,8 +160,5 @@ ethereum = { git = "https://github.com/kkrt-labs/execution-specs.git", rev = "b2
requires = ["hatchling"]
build-backend = "hatchling.build"

[dependency-groups]
dev = ["polars>=1.17.1"]

[tool.hatch.build.targets.wheel]
packages = ["src"]
70 changes: 70 additions & 0 deletions cairo/tests/ethereum/cancun/vm/test_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from ethereum_types.bytes import Bytes
from ethereum_types.numeric import U256
from hypothesis import given
from hypothesis import strategies as st
from hypothesis.strategies import composite

from ethereum.cancun.vm.memory import buffer_read, memory_read_bytes, memory_write

# NOTE: The testing strategy always assumes that memory accesses are within bounds,
# because the memory is always extended to the proper size _before_ being accessed.


@composite
def memory_write_strategy(draw):
    """Draw a ``(memory, start_position, value)`` triple for ``memory_write``.

    The write always fits inside the generated memory: ``start_position`` is at
    most the memory size, and ``value`` is no longer than the number of bytes
    left after ``start_position``.
    """
    # Sizes above 2**10 cause a HealthCheck too large error.
    size = draw(st.integers(min_value=0, max_value=2**10))
    memory = draw(st.binary(min_size=size, max_size=size).map(bytearray))

    # Start position within the bounds of the existing memory.
    in_bounds_offsets = st.integers(min_value=0, max_value=size).map(U256)
    start_position = draw(in_bounds_offsets)

    # Cap the value length so the write never runs past the end of memory.
    remaining = size - int(start_position)
    value = draw(st.binary(min_size=0, max_size=remaining))

    return memory, start_position, value


@composite
def memory_read_strategy(draw):
    """Draw a ``(memory, start_position, size)`` triple for ``memory_read_bytes``.

    The read always stays inside the generated memory: ``start_position`` is at
    most the memory size, and ``size`` is at most the number of bytes left after
    ``start_position``.
    """
    memory_len = draw(st.integers(min_value=0, max_value=2**10))
    memory = draw(st.binary(min_size=memory_len, max_size=memory_len).map(bytearray))

    in_bounds_offsets = st.integers(min_value=0, max_value=memory_len).map(U256)
    start_position = draw(in_bounds_offsets)

    remaining = memory_len - int(start_position)
    size = draw(st.integers(min_value=0, max_value=remaining).map(U256))

    return memory, start_position, size


class TestMemory:
    """Compare the Cairo memory functions against the Python implementations."""

    @given(memory_write_strategy())
    def test_memory_write(self, cairo_run, params):
        """Memory returned by the Cairo write equals memory after the Python write."""
        memory, start_position, value = params
        payload = Bytes(value)
        # Cairo runs first, on the memory as drawn; the Python call comes afterwards
        # and the resulting `memory` is the expected value.
        cairo_memory = cairo_run("memory_write", memory, start_position, payload)
        memory_write(memory, start_position, payload)
        assert cairo_memory == memory

    @given(memory_read_strategy())
    def test_memory_read(self, cairo_run, params):
        """The Cairo read returns the same bytes and an unchanged memory."""
        memory, start_position, size = params
        cairo_memory, cairo_value = cairo_run(
            "memory_read_bytes", memory, start_position, size
        )
        expected_value = memory_read_bytes(memory, start_position, size)
        assert cairo_memory == memory
        assert cairo_value == expected_value

    @given(
        buffer=st.binary(min_size=0, max_size=2**10).map(Bytes),
        start_position=st.integers(min_value=0, max_value=2**128 - 1).map(U256),
        size=st.integers(min_value=0, max_value=2**10).map(U256),
    )
    def test_buffer_read(
        self, cairo_run, buffer: Bytes, start_position: U256, size: U256
    ):
        """Both implementations return the same bytes, including out-of-range starts."""
        expected = buffer_read(buffer, start_position, size)
        actual = cairo_run("buffer_read", buffer, start_position, size)
        assert expected == actual
24 changes: 20 additions & 4 deletions cairo/tests/utils/args_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,15 @@
get_origin,
)

from ethereum_types.bytes import Bytes, Bytes0, Bytes8, Bytes20, Bytes32, Bytes256
from ethereum_types.bytes import (
Bytes,
Bytes0,
Bytes1,
Bytes8,
Bytes20,
Bytes32,
Bytes256,
)
from ethereum_types.numeric import U64, U256, Uint
from starkware.cairo.common.dict import DictManager, DictTracker
from starkware.cairo.lang.compiler.ast.cairo_types import (
Expand Down Expand Up @@ -105,6 +113,11 @@
from ethereum.rlp import Extended, Simple
from tests.utils.helpers import flatten


class Memory(bytearray):
    """Distinct ``bytearray`` subclass used as the Python type for Cairo ``Memory``."""


_cairo_struct_to_python_type: Dict[Tuple[str, ...], Any] = {
("ethereum_types", "others", "None"): type(None),
("ethereum_types", "numeric", "bool"): bool,
Expand All @@ -114,6 +127,7 @@
("ethereum_types", "numeric", "SetUint"): Set[Uint],
("ethereum_types", "numeric", "UnionUintU256"): Union[Uint, U256],
("ethereum_types", "bytes", "Bytes0"): Bytes0,
("ethereum_types", "bytes", "Bytes1"): Bytes1,
("ethereum_types", "bytes", "Bytes8"): Bytes8,
("ethereum_types", "bytes", "Bytes20"): Bytes20,
("ethereum_types", "bytes", "Bytes32"): Bytes32,
Expand Down Expand Up @@ -184,6 +198,7 @@
Address, Account
],
("ethereum", "exceptions", "EthereumException"): EthereumException,
("ethereum", "cancun", "vm", "memory", "Memory"): Memory,
("ethereum", "cancun", "vm", "stack", "Stack"): List[U256],
(
"ethereum",
Expand Down Expand Up @@ -287,9 +302,10 @@ def _gen_arg(
segments.load_data(struct_ptr, data)
return struct_ptr

if arg_type_origin is list:
# A `list` is represented as a Dict[felt, V] along with a length field.
value_type = get_args(arg_type)[0] # Get the concrete type parameter
if arg_type_origin in (list, Memory):
# Collection types are represented as a Dict[felt, V] along with a length field.
# Get the concrete type parameter. For bytearray, the value type is int.
value_type = next(iter(get_args(arg_type)), int)
data = defaultdict(int, {k: v for k, v in enumerate(arg)})
base = _gen_arg(dict_manager, segments, Dict[Uint, value_type], data)
segments.load_data(base + 2, [len(arg)])
Expand Down
Loading

0 comments on commit adf8deb

Please sign in to comment.