Skip to content

Commit

Permalink
optimize memory
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed Dec 30, 2024
1 parent 283a532 commit 7a4a06e
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 63 deletions.
76 changes: 31 additions & 45 deletions cairo/ethereum/cancun/vm/memory.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -12,62 +12,50 @@ from starkware.cairo.common.math_cmp import is_le, is_not_zero

from ethereum_types.bytes import Bytes, BytesStruct, Bytearray, BytearrayStruct, Bytes1DictAccess
from ethereum_types.numeric import U256
from ethereum.utils.numeric import max

// @notice Write bytes to memory at a given position.
// @dev assumption: start_position < 2**128
// @dev assumption: value.len < 2**128
// @param memory The pointer to the bytearray.
// @param start_position Starting position to write at.
// @param value Bytes to write.
func memory_write{range_check_ptr, memory: Bytearray}(start_position: U256, value: Bytes) {
alloc_locals;
let bytes_len = value.value.len;
let bytes_data = value.value.data;
let start_position_felt = start_position.value.low;
with_attr error_message("memory_write: start_position > 2**128") {
with_attr error_message("memory_write: start_position > 2**128 || value.len > 2**128") {
assert start_position.value.high = 0;
}
let new_len = start_position_felt + bytes_len;

let bytes_data = value.value.data;
let dict_ptr = cast(memory.value.dict_ptr, DictAccess*);
with dict_ptr {
Internals._write_bytes(start_position_felt, bytes_data, bytes_len);
}
let new_dict_ptr = cast(dict_ptr, Bytes1DictAccess*);

// Update length if we wrote beyond current length
tempvar current_len = memory.value.len;
let is_new_le_current = is_le(new_len, current_len);
if (is_new_le_current != TRUE) {
tempvar final_len = new_len;
} else {
tempvar final_len = current_len;
}

tempvar memory = Bytearray(
new BytearrayStruct(memory.value.dict_ptr_start, new_dict_ptr, final_len)
);
let len = max(memory.value.len, start_position.value.low + value.value.len);
tempvar memory = Bytearray(new BytearrayStruct(memory.value.dict_ptr_start, new_dict_ptr, len));
return ();
}

// @notice Read bytes from memory.
// @dev assumption: start_position < 2**128
// @dev assumption: size < 2**128
// @param memory The pointer to the bytearray.
// @param start_position Starting position to read from.
// @param size Number of bytes to read.
// @return The bytes read from memory.
func memory_read_bytes{memory: Bytearray}(start_position: U256, size: U256) -> Bytes {
alloc_locals;
let (local output: felt*) = alloc();

let start_position_felt = start_position.value.low;
let size_felt = size.value.low;
with_attr error_message("memory_read_bytes: start_position > 2**128 || size > 2**128") {
assert start_position.value.high = 0;
assert size.value.high = 0;
}

let (local output: felt*) = alloc();
let dict_ptr = cast(memory.value.dict_ptr, DictAccess*);
let start_position_felt = start_position.value.low;
let size_felt = size.value.low;

with dict_ptr {
Internals._read_bytes(start_position_felt, size_felt, output);
}
Expand Down Expand Up @@ -114,23 +102,22 @@ namespace Internals {
return ();
}

tempvar start_position = start_position;
tempvar data = data;
tempvar len = len;
tempvar index = len;
tempvar dict_ptr = dict_ptr;

body:
let start_position = [ap - 4];
let data = cast([ap - 3], felt*);
let len = [ap - 2];
let index = [ap - 2] - 1;
let dict_ptr = cast([ap - 1], DictAccess*);
dict_write(start_position, [data]);
let start_position = [fp - 5];
let data = cast([fp - 4], felt*);

dict_write(start_position + index, data[index]);

tempvar start_position = start_position + 1;
tempvar data = data + 1;
tempvar len = len - 1;
tempvar index = index;
tempvar dict_ptr = dict_ptr;
jmp body if len != 0;
jmp body if index != 0;

end:
return ();
}

Expand All @@ -139,29 +126,28 @@ namespace Internals {
// @param size Number of bytes to read.
// @param output Pointer to write output bytes to.
func _read_bytes{dict_ptr: DictAccess*}(start_position: felt, size: felt, output: felt*) {
alloc_locals;
if (size == 0) {
return ();
}

tempvar start_position = start_position;
tempvar size = size;
tempvar output = output;
tempvar dict_index = start_position + size;
tempvar dict_ptr = dict_ptr;

body:
let start_position = [ap - 4];
let size = [ap - 3];
let output = cast([ap - 2], felt*);
let dict_index = [ap - 2] - 1;
let dict_ptr = cast([ap - 1], DictAccess*);
let output = cast([fp - 3], felt*);
let start_position = [fp - 5];
tempvar output_index = dict_index - start_position;

let (value) = dict_read(start_position);
assert [output] = value;
let (value) = dict_read(dict_index);
assert output[output_index] = value;

tempvar start_position = start_position + 1;
tempvar size = size - 1;
tempvar output = output + 1;
tempvar dict_index = dict_index;
tempvar dict_ptr = dict_ptr;
jmp body if size != 0;
jmp body if output_index != 0;

return ();
}

Expand Down
1 change: 0 additions & 1 deletion cairo/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dependencies = [
"cairo-lang>=0.13.2",
"ethereum",
"marshmallow-dataclass>=8.6.1",
"polars>=1.18.0",
"python-dotenv>=1.0.1",
"toml>=0.10.2",
"web3>=7.2.0",
Expand Down
2 changes: 1 addition & 1 deletion cairo/tests/utils/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def serialize_type(self, path: Tuple[str, ...], ptr) -> Any:
if "__main__" in full_path:
full_path = self.main_part + full_path[full_path.index("__main__") + 1 :]
python_cls = to_python_type(full_path)
origin_cls = get_origin(python_cls)
origin_cls = get_origin(python_cls) or python_cls
annotations = []

if get_origin(python_cls) is Annotated:
Expand Down
16 changes: 0 additions & 16 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7a4a06e

Please sign in to comment.