Skip to content

Commit

Permalink
Refacto CA bytecode read and write (#921)
Browse files Browse the repository at this point in the history
Time spent on this PR: 0.2

## Pull request type

Please check the type of change your PR introduces:

- [ ] Bugfix
- [ ] Feature
- [ ] Code style update (formatting, renaming)
- [x] Refactoring (no functional changes, no api changes)
- [ ] Build related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

Resolves #906

## What is the new behavior?

Reduced by ~2.5x the number of steps for write, ~x5 for read

<!-- Reviewable:start -->
- - -
This change is [<img src="https://reviewable.io/review_button.svg"
height="34" align="absmiddle"
alt="Reviewable"/>](https://reviewable.io/reviews/kkrt-labs/kakarot/921)
<!-- Reviewable:end -->
  • Loading branch information
ClementWalter authored Feb 1, 2024
1 parent f5fffb6 commit 3efdde1
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 128 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ resources/*

.hypothesis
*.pb.gz
profile*.png
2 changes: 1 addition & 1 deletion blockchain-tests-skip.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2452,7 +2452,6 @@ testname:
- stackOverflowSWAP_d0g0v0_Shanghai
- stacksanitySWAP_d0g0v0_Shanghai
stStaticCall:
- RevertOpcodeCalls_d0g1v0_Shanghai
- StaticcallToPrecompileFromCalledContract_d0g0v0_Shanghai
- StaticcallToPrecompileFromContractInitialization_d0g0v0_Shanghai
- StaticcallToPrecompileFromTransaction_d0g0v0_Shanghai
Expand Down Expand Up @@ -2538,6 +2537,7 @@ testname:
- static_callWithHighValueAndGasOOG_d1g0v0_Shanghai
- static_call_value_inherit_d0g0v0_Shanghai
- static_call_value_inherit_from_call_d0g0v0_Shanghai
- static_RevertOpcodeCalls_d0g1v0_Shanghai
stStaticFlagEnabled:
- CallWithNOTZeroValueToPrecompileFromCalledContract_d0g0v0_Shanghai
- CallWithNOTZeroValueToPrecompileFromCalledContract_d1g0v0_Shanghai
Expand Down
268 changes: 157 additions & 111 deletions src/kakarot/accounts/contract/library.cairo
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
// SPDX-License-Identifier: MIT

%lang starknet

// Starkware dependencies
from openzeppelin.access.ownable.library import Ownable
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.bool import FALSE
from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin
from starkware.cairo.common.math import unsigned_div_rem
from starkware.cairo.common.math import unsigned_div_rem, split_int
from starkware.cairo.common.registers import get_label_location
from starkware.cairo.common.uint256 import Uint256, uint256_not
from starkware.starknet.common.syscalls import storage_read, storage_write
from starkware.starknet.common.syscalls import (
StorageRead,
StorageWrite,
STORAGE_READ_SELECTOR,
STORAGE_WRITE_SELECTOR,
storage_read,
storage_write,
StorageReadRequest,
)
from starkware.cairo.common.memset import memset

from kakarot.interfaces.interfaces import IERC20, IKakarot

// Storage

@storage_var
func bytecode_(index: felt) -> (res: felt) {
}

@storage_var
func bytecode_len_() -> (res: felt) {
}
Expand All @@ -41,13 +40,12 @@ func evm_address() -> (evm_address: felt) {
func nonce() -> (nonce: felt) {
}

// Define the number of bytes per felt
const BYTES_PER_FELT = 31;

// @title ContractAccount main library file.
// @notice This file contains the EVM smart contract account representation logic.
namespace ContractAccount {
// Define the number of bytes per felt. Above 16, the following code won't work as it uses unsigned_div_rem
// which is bounded by RC_BOUND = 2 ** 128 ~ uint128 ~ bytes16
const BYTES_PER_FELT = 16;
// @notice This function is used to initialize the smart contract account.
func initialize{
syscall_ptr: felt*,
Expand Down Expand Up @@ -85,17 +83,12 @@ namespace ContractAccount {
range_check_ptr,
bitwise_ptr: BitwiseBuiltin*,
}(bytecode_len: felt, bytecode: felt*) {
alloc_locals;
// Access control check.
Ownable.assert_only_owner();
// Recursively store the bytecode.
bytecode_len_.write(bytecode_len);
internal.write_bytecode(
index=0,
bytecode_len=bytecode_len,
bytecode=bytecode,
current_felt=0,
remaining_shift=BYTES_PER_FELT,
);
internal.write_bytecode(bytecode_len=bytecode_len, bytecode=bytecode);
return ();
}

Expand All @@ -117,17 +110,8 @@ namespace ContractAccount {
bitwise_ptr: BitwiseBuiltin*,
}() -> (bytecode_len: felt, bytecode: felt*) {
alloc_locals;
// Read bytecode length from storage.
let (bytecode_len) = bytecode_len_.read();
// Recursively load bytecode into specified memory location.
let bytecode_: felt* = alloc();
internal.load_bytecode(
index=0,
bytecode_len=bytecode_len,
bytecode=bytecode_,
current_felt=0,
remaining_shift=0,
);
let (bytecode_) = internal.load_bytecode(bytecode_len);
return (bytecode_len, bytecode_);
}

Expand Down Expand Up @@ -224,107 +208,169 @@ namespace ContractAccount {
}

namespace internal {
// Use a precomputed 2 ** n array to save on resources usage.
// Array starts with a 0 to be shifted and have pow[i] = bit shift for byte i with
// i as a counter, ie i \in (0, BYTES_PER_FELT]
pow_:
dw 0;
dw 1;
dw 2 ** 8;
dw 2 ** 16;
dw 2 ** 24;
dw 2 ** 32;
dw 2 ** 40;
dw 2 ** 48;
dw 2 ** 56;
dw 2 ** 64;
dw 2 ** 72;
dw 2 ** 80;
dw 2 ** 88;
dw 2 ** 96;
dw 2 ** 104;
dw 2 ** 112;
dw 2 ** 120;
// @notice Store the bytecode of the contract.
// @param index The current free index in the bytecode_ storage.
// @param bytecode_len The length of the bytecode.
// @param bytecode The bytecode of the contract.
func write_bytecode{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
index: felt, bytecode_len: felt, bytecode: felt*, current_felt: felt, remaining_shift: felt
) {
func write_bytecode{syscall_ptr: felt*}(bytecode_len: felt, bytecode: felt*) {
alloc_locals;

if (bytecode_len == 0) {
// end of bytecode case, break loop storing latest "pending" packed felt
bytecode_.write(index, current_felt);
return ();
}

if (remaining_shift == 0) {
// end of packed felt case, store current "pending" felt
// continue loop with a new current_felt and increment index in bytecode_ storage
bytecode_.write(index, current_felt);
return write_bytecode(
index + 1, bytecode_len, bytecode, 0, ContractAccount.BYTES_PER_FELT
);
}

// retrieve the precomputed pow array
let (pow_address) = get_label_location(pow_);
let pow = cast(pow_address, felt*);

// shift the current byte and add it to the current felt
// bytes are stored big endian, ie that 3 bytes ends up being stored as a felt whose representation is 0xabcdef000...000
// for a given remaining_shift:
// current_felt = 0x 12 34 00 00...00 00
// bytecode = 0x 56
// pow[remaining_shift] * bytecode = 0x 00 00 56 00...00 00
// resulting in 0x 12 34 56 00...00 00
let current_felt = pow[remaining_shift] * [bytecode] + current_felt;

return write_bytecode(
index, bytecode_len - 1, bytecode + 1, current_felt, remaining_shift - 1
tempvar value = 0;
tempvar address = 0;
tempvar syscall_ptr = syscall_ptr;
tempvar bytecode_len = bytecode_len;
tempvar count = BYTES_PER_FELT;

body:
let value = [ap - 5];
let address = [ap - 4];
let syscall_ptr = cast([ap - 3], felt*);
let bytecode_len = [ap - 2];
let count = [ap - 1];
let initial_bytecode_len = [fp - 4];
let bytecode = cast([fp - 3], felt*);

tempvar value = value * 256 + bytecode[initial_bytecode_len - bytecode_len];
tempvar address = address;
tempvar syscall_ptr = syscall_ptr;
tempvar bytecode_len = bytecode_len - 1;
tempvar count = count - 1;

jmp cond if bytecode_len != 0;
jmp store;

cond:
jmp body if count != 0;

store:
assert [cast(syscall_ptr, StorageWrite*)] = StorageWrite(
selector=STORAGE_WRITE_SELECTOR, address=address, value=value
);
%{ syscall_handler.storage_write(segments=segments, syscall_ptr=ids.syscall_ptr) %}
tempvar value = 0;
tempvar address = address + 1;
tempvar syscall_ptr = syscall_ptr + StorageWrite.SIZE;
tempvar bytecode_len = bytecode_len;
tempvar count = BYTES_PER_FELT;

jmp body if bytecode_len != 0;

return ();
}

// @notice Load the bytecode of the contract in the specified array.
// @param index The index in the bytecode.
// @param bytecode_len The length of the bytecode.
// @param bytecode The bytecode of the contract.
func load_bytecode{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
bitwise_ptr: BitwiseBuiltin*,
}(index: felt, bytecode_len: felt, bytecode: felt*, current_felt: felt, remaining_shift: felt) {
func load_bytecode{syscall_ptr: felt*, range_check_ptr}(bytecode_len: felt) -> (
bytecode: felt*
) {
alloc_locals;

if (bytecode_len == 0) {
// end of loop
return ();
}
let (local bytecode: felt*) = alloc();
local bound = 256;
local base = 256;

if (remaining_shift == 0) {
// end of current packed felt, loading next stored felt and increase storage index
let (current_felt) = bytecode_.read(index);
return load_bytecode(
index + 1, bytecode_len, bytecode, current_felt, ContractAccount.BYTES_PER_FELT
);
if (bytecode_len == 0) {
return (bytecode=bytecode);
}

// retrieve the precomputed pow array
let (pow_address) = get_label_location(pow_);
let pow = cast(pow_address, felt*);

// get the leading (big endian) byte of the current_felt
// reassign current_felt to the be remainder
let (current_byte, current_felt) = unsigned_div_rem(current_felt, pow[remaining_shift]);
// add byte to returned array
assert [bytecode] = current_byte;

return load_bytecode(
index, bytecode_len - 1, bytecode + 1, current_felt, remaining_shift - 1
let (local chunk_counts, local remainder) = unsigned_div_rem(bytecode_len, BYTES_PER_FELT);

tempvar remaining_bytes = bytecode_len;
tempvar range_check_ptr = range_check_ptr;
tempvar address = 0;
tempvar syscall_ptr = syscall_ptr;
tempvar value = 0;
tempvar count = 0;

read:
let remaining_bytes = [ap - 6];
let range_check_ptr = [ap - 5];
let address = [ap - 4];
let syscall_ptr = cast([ap - 3], felt*);
let value = [ap - 2];
let count = [ap - 1];

let syscall = [cast(syscall_ptr, StorageRead*)];
assert syscall.request = StorageReadRequest(
selector=STORAGE_READ_SELECTOR, address=address
);
%{ syscall_handler.storage_read(segments=segments, syscall_ptr=ids.syscall_ptr) %}
let response = syscall.response;

let remainder = [fp + 4];
let chunk_counts = [fp + 3];
tempvar remaining_chunk = chunk_counts - address;
jmp full_chunk if remaining_chunk != 0;
tempvar count = remainder;
jmp next;

full_chunk:
tempvar count = BYTES_PER_FELT;

next:
tempvar remaining_bytes = remaining_bytes;
tempvar range_check_ptr = range_check_ptr;
tempvar address = address + 1;
tempvar syscall_ptr = syscall_ptr + StorageRead.SIZE;
tempvar value = response.value;
tempvar count = count;

body:
let remaining_bytes = [ap - 6];
let range_check_ptr = [ap - 5];
let address = [ap - 4];
let syscall_ptr = cast([ap - 3], felt*);
let value = [ap - 2];
let count = [ap - 1];

let base = [fp + 1];
let bound = [fp + 2];
let bytecode = cast([fp], felt*);
tempvar offset = (address - 1) * BYTES_PER_FELT + count - 1;
let output = bytecode + offset;

// Put byte in output and assert that 0 <= byte < bound
// See math.split_int
%{
memory[ids.output] = res = (int(ids.value) % PRIME) % ids.base
assert res < ids.bound, f'split_int(): Limb {res} is out of range.'
%}
tempvar a = [output];
%{
from starkware.cairo.common.math_utils import assert_integer
assert_integer(ids.a)
assert 0 <= ids.a % PRIME < range_check_builtin.bound, f'a = {ids.a} is out of range.'
%}
assert a = [range_check_ptr];
tempvar a = bound - 1 - a;
%{
from starkware.cairo.common.math_utils import assert_integer
assert_integer(ids.a)
assert 0 <= ids.a % PRIME < range_check_builtin.bound, f'a = {ids.a} is out of range.'
%}
assert a = [range_check_ptr + 1];

tempvar value = (value - [output]) / base;
tempvar remaining_bytes = remaining_bytes - 1;
tempvar range_check_ptr = range_check_ptr + 2;
tempvar address = address;
tempvar syscall_ptr = syscall_ptr;
tempvar value = value;
tempvar count = count - 1;

jmp cond if remaining_bytes != 0;

let bytecode = cast([fp], felt*);
return (bytecode=bytecode);

cond:
jmp body if count != 0;
jmp read;
}
}
10 changes: 10 additions & 0 deletions tests/src/kakarot/accounts/test_contract_account.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.memcpy import memcpy

from kakarot.accounts.contract.library import ContractAccount

Expand Down Expand Up @@ -51,3 +52,12 @@ func test__write_bytecode{

return ();
}

func test__read_bytecode{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*
}(output_ptr: felt*) {
alloc_locals;
let (bytecode_len, bytecode) = ContractAccount.bytecode();
memcpy(output_ptr, bytecode, bytecode_len);
return ();
}
Loading

0 comments on commit 3efdde1

Please sign in to comment.