Skip to content

Commit

Permalink
Add deploy_bytecode utils and fix dict_ptr error in memory.finalize (#…
Browse files Browse the repository at this point in the history
…775)

Time spent on this PR: 0.3

## Pull request type

Please check the type of change your PR introduces:

- [ ] Bugfix
- [x] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no api changes)
- [ ] Build related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

No simple way to create a contract with a given bytecode in Katana.

## What is the new behavior?

Added an util to take the target bytecode and put it as RETURN_DATA of a
simple tx.
Assert that the finally stored bytecode is the expected one.

I've also updated slightly the utils due to max_fee errors arising in
Katana
when restarting tests without restarting Katana.
  • Loading branch information
ClementWalter authored Oct 30, 2023
1 parent a52ba38 commit f6cc59b
Show file tree
Hide file tree
Showing 13 changed files with 388 additions and 186 deletions.
87 changes: 78 additions & 9 deletions scripts/utils/kakarot.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

try:
FOUNDRY_FILE = toml.loads((Path(__file__).parents[2] / "foundry.toml").read_text())
except NameError:
except (NameError, FileNotFoundError):
FOUNDRY_FILE = toml.loads(Path("foundry.toml").read_text())


Expand All @@ -62,7 +62,12 @@ class EvmTransactionError(Exception):


@functools.lru_cache()
def get_contract(contract_app: str, contract_name: str, address=None) -> Web3Contract:
def get_contract(
contract_app: str,
contract_name: str,
address=None,
caller_eoa: Optional[Account] = None,
) -> Web3Contract:
all_compilation_outputs = [
json.load(open(file))
for file in Path(FOUNDRY_FILE["profile"]["default"]["out"]).glob(
Expand Down Expand Up @@ -108,7 +113,7 @@ def get_contract(contract_app: str, contract_name: str, address=None) -> Web3Con

try:
for fun in contract.functions:
setattr(contract, fun, MethodType(_wrap_kakarot(fun), contract))
setattr(contract, fun, MethodType(_wrap_kakarot(fun, caller_eoa), contract))
except NoABIFunctionsFound:
pass
contract.events.parse_starknet_events = MethodType(_parse_events, contract.events)
Expand All @@ -119,8 +124,8 @@ async def deploy(
contract_app: str, contract_name: str, *args, **kwargs
) -> Web3Contract:
logger.info(f"⏳ Deploying {contract_name}")
contract = get_contract(contract_app, contract_name)
caller_eoa = kwargs.pop("caller_eoa", None)
contract = get_contract(contract_app, contract_name, caller_eoa=caller_eoa)
max_fee = kwargs.pop("max_fee", None)
value = kwargs.pop("value", 0)
receipt, response, success = await eth_send_transaction(
Expand Down Expand Up @@ -192,23 +197,27 @@ def _get_matching_logs_for_event(event_abi, log_receipts) -> List[dict]:
return logs


def _wrap_kakarot(fun: str):
def _wrap_kakarot(fun: str, caller_eoa: Optional[Account] = None):
"""Wrap a contract function call with the Kakarot contract."""

async def _wrapper(self, *args, **kwargs):
abi = self.get_function_by_name(fun).abi
gas_price = kwargs.pop("gas_price", 1_000)
gas_limit = kwargs.pop("gas_limit", 1_000_000_000)
value = kwargs.pop("value", 0)
caller_eoa = kwargs.pop("caller_eoa", None)
caller_eoa_ = kwargs.pop("caller_eoa", caller_eoa)
max_fee = kwargs.pop("max_fee", None)
calldata = self.get_function_by_name(fun)(
*args, **kwargs
)._encode_transaction_data()

if abi["stateMutability"] in ["pure", "view"]:
kakarot_contract = await _get_starknet_contract("kakarot")
origin = int(caller_eoa.address, 16) if caller_eoa else int(EVM_ADDRESS, 16)
origin = (
int(caller_eoa_.signer.public_key.to_address(), 16)
if caller_eoa_
else int(EVM_ADDRESS, 16)
)
result = await kakarot_contract.functions["eth_call"].call(
origin=origin,
to=int(self.address, 16),
Expand All @@ -231,7 +240,7 @@ async def _wrapper(self, *args, **kwargs):
value=value,
gas=gas_limit,
data=calldata,
caller_eoa=caller_eoa.starknet_contract if caller_eoa else None,
caller_eoa=caller_eoa_ if caller_eoa_ else None,
max_fee=max_fee,
)
if success == 0:
Expand All @@ -251,7 +260,7 @@ async def _contract_exists(address: int) -> bool:
return False


async def get_eoa(private_key=None, amount=0.1) -> Account:
async def get_eoa(private_key=None, amount=10) -> Account:
private_key = private_key or keys.PrivateKey(bytes.fromhex(EVM_PRIVATE_KEY[2:]))
starknet_address = await deploy_and_fund_evm_address(
private_key.public_key.to_checksum_address(), amount
Expand Down Expand Up @@ -360,3 +369,63 @@ async def fund_address(address: Union[str, int], amount: float):
f"ℹ️ Funding EVM address {address} at Starknet address {hex(starknet_address)}"
)
await _fund_starknet_address(starknet_address, amount)


async def store_bytecode(bytecode: Union[str, bytes], **kwargs):
"""
Deploy a contract account through Kakarot with given bytecode as finally
stored bytecode.
Note: Deploying directly a contract account and using `write_bytecode` would not
produce an EVM contract registered in Kakarot and thus is not an option. We need
to have Kakarot deploying EVM contrats.
"""
bytecode = (
bytecode
if isinstance(bytecode, bytes)
else bytes.fromhex(bytecode.replace("0x", ""))
)

# Defines variables for used opcodes to make it easier to write the mnemonic
PUSH1 = "60"
PUSH2 = "61"
CODECOPY = "39"
RETURN = "f3"
# The deploy_bytecode is crafted such that:
# - append at the end of the run bytecode the target bytecode
# - load this chunk of code in memory using CODECOPY
# - return this data in RETURN
#
# Bytecode usage
# - CODECOPY(len, offset, destOffset): set memory such that memory[destOffset:destOffset + len] = code[offset:offset + len]
# - RETURN(len, offset): return memory[offset:offset + len]
deploy_bytecode = bytes.fromhex(
f"""
{PUSH2} {len(bytecode):04x}
{PUSH1} 0e
{PUSH1} 00
{CODECOPY}
{PUSH2} {len(bytecode):04x}
{PUSH1} 00
{RETURN}
{bytecode.hex()}"""
)
receipt, response, success = await eth_send_transaction(
to=0, data=deploy_bytecode, **kwargs
)
assert success
starknet_address, evm_address = response
stored_bytecode = await get_bytecode(evm_address)
assert stored_bytecode == bytecode
return evm_address


async def get_bytecode(address: Union[int, str]):
starknet_address = await _compute_starknet_address(address)
return bytes(
(
await _call_starknet(
"contract_account", "bytecode", address=starknet_address
)
).bytecode
)
4 changes: 4 additions & 0 deletions solidity_contracts/src/PlainOpcodes/PlainOpcodes.sol
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ contract PlainOpcodes {
return address(counter).staticcall(data);
}

function opcodeStaticCallToAddress(address target, bytes memory data) public view returns (bool, bytes memory) {
return target.staticcall(data);
}

function opcodeCall() public {
counter.inc();
}
Expand Down
6 changes: 6 additions & 0 deletions solidity_contracts/tests/PlainOpcodes.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,10 @@ contract PlainOpcodesTest is Test {
(bool success,) = plainOpcodes.opcodeStaticCall2();
assert(!success);
}

function testStaticCallToCallToInc() public view {
bytes memory data = abi.encodeWithSelector(bytes4(keccak256("opcodeCall()")));
(bool success,) = plainOpcodes.opcodeStaticCallToAddress(address(plainOpcodes), data);
assert(!success);
}
}
13 changes: 7 additions & 6 deletions src/kakarot/instructions/system_operations.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ namespace SystemOperations {
);
if (ctx.call_context.read_only * sub_ctx.call_context.value != FALSE) {
let (revert_reason_len, revert_reason) = Errors.stateModificationError();
let ctx = sub_ctx.call_context.calling_context;
let ctx = ExecutionContext.stop(ctx, revert_reason_len, revert_reason, TRUE);
return ctx;
}
Expand Down Expand Up @@ -772,19 +773,19 @@ namespace CreateHelper {
let is_collision = Account.has_code_or_nonce(account);
let account = Account.set_nonce(account, 1);

if (is_collision != 0) {
let (revert_reason_len, revert_reason) = Errors.addressCollision();
tempvar ctx = ExecutionContext.stop(ctx, revert_reason_len, revert_reason, TRUE);
return ctx;
}

// Update calling context before creating sub context
let ctx = ExecutionContext.update_memory(ctx, memory);
let ctx = ExecutionContext.increment_gas_used(
ctx, gas_cost + SystemOperations.GAS_COST_CREATE
);
let ctx = ExecutionContext.update_state(ctx, state);

if (is_collision != 0) {
let (revert_reason_len, revert_reason) = Errors.addressCollision();
tempvar ctx = ExecutionContext.stop(ctx, revert_reason_len, revert_reason, TRUE);
return ctx;
}

// Create sub context with copied state
let state = State.copy(ctx.state);
let state = State.set_account(state, address, account);
Expand Down
42 changes: 17 additions & 25 deletions src/kakarot/memory.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ from starkware.cairo.common.bool import FALSE
from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin
from starkware.cairo.common.dict import DictAccess, dict_read, dict_write
from starkware.cairo.common.default_dict import default_dict_new, default_dict_finalize
from starkware.cairo.common.math import split_int, unsigned_div_rem, assert_nn
from starkware.cairo.common.math import split_int, unsigned_div_rem
from starkware.cairo.common.math_cmp import is_le
from starkware.cairo.common.memcpy import memcpy
from starkware.cairo.common.uint256 import Uint256
Expand Down Expand Up @@ -196,8 +196,8 @@ namespace Memory {
dict_write{dict_ptr=word_dict}(chunk_index_f, x_f * mask_f + w_f_l);

// Write blocks.
let (word_dict) = store_aligned_words(
word_dict, chunk_index_i + 1, chunk_index_f, element + 16 - offset_in_chunk_i
store_aligned_words{dict_ptr=word_dict}(
chunk_index_i + 1, chunk_index_f, element + 16 - offset_in_chunk_i
);

return (
Expand All @@ -207,11 +207,11 @@ namespace Memory {
);
}

func store_aligned_words{range_check_ptr}(
word_dict: DictAccess*, chunk_index: felt, chunk_index_f: felt, element: felt*
) -> (word_dict: DictAccess*) {
func store_aligned_words{range_check_ptr, dict_ptr: DictAccess*}(
chunk_index: felt, chunk_index_f: felt, element: felt*
) {
if (chunk_index == chunk_index_f) {
return (word_dict=word_dict);
return ();
}
let current = (
element[0] * 256 ** 15 +
Expand All @@ -231,12 +231,9 @@ namespace Memory {
element[14] * 256 ** 1 +
element[15] * 256 ** 0
);
dict_write{dict_ptr=word_dict}(chunk_index, current);
dict_write(chunk_index, current);
return store_aligned_words(
word_dict=word_dict,
chunk_index=chunk_index + 1,
chunk_index_f=chunk_index_f,
element=&element[16],
chunk_index=chunk_index + 1, chunk_index_f=chunk_index_f, element=&element[16]
);
}

Expand Down Expand Up @@ -346,8 +343,8 @@ namespace Memory {
Helpers.split_word(w_f_h, offset_in_chunk_f, element + element_len - offset_in_chunk_f);

// Get blocks.
let (word_dict) = load_aligned_words(
word_dict, chunk_index_i + 1, chunk_index_f, element + 16 - offset_in_chunk_i
load_aligned_words{dict_ptr=word_dict}(
chunk_index_i + 1, chunk_index_f, element + 16 - offset_in_chunk_i
);

return (
Expand All @@ -357,21 +354,16 @@ namespace Memory {
);
}

func load_aligned_words{range_check_ptr}(
word_dict: DictAccess*, chunk_index: felt, chunk_index_f: felt, element: felt*
) -> (word_dict: DictAccess*) {
func load_aligned_words{range_check_ptr, dict_ptr: DictAccess*}(
chunk_index: felt, chunk_index_f: felt, element: felt*
) {
if (chunk_index == chunk_index_f) {
return (word_dict=word_dict);
return ();
}
let original_word_dict = word_dict;
let (value) = dict_read{dict_ptr=word_dict}(chunk_index);
let word_dict = original_word_dict + 3;
let (value) = dict_read(chunk_index);
Helpers.split_word_128(value, element);
return load_aligned_words(
word_dict=word_dict,
chunk_index=chunk_index + 1,
chunk_index_f=chunk_index_f,
element=&element[16],
chunk_index=chunk_index + 1, chunk_index_f=chunk_index_f, element=&element[16]
);
}

Expand Down
38 changes: 19 additions & 19 deletions tests/end_to_end/PlainOpcodes/test_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,46 +19,46 @@ async def test_should_return_0_after_deployment(

class TestInc:
async def test_should_increase_count(self, counter, owner):
await counter.reset(caller_eoa=owner)
await counter.inc(caller_eoa=owner)
await counter.reset(caller_eoa=owner.starknet_contract)
await counter.inc(caller_eoa=owner.starknet_contract)
assert await counter.count() == 1

class TestDec:
async def test_should_raise_from_modifier_when_count_is_0(self, counter, owner):
await counter.reset(caller_eoa=owner)
await counter.reset(caller_eoa=owner.starknet_contract)
with evm_error("count should be strictly greater than 0"):
await counter.dec(caller_eoa=owner)
await counter.dec(caller_eoa=owner.starknet_contract)

@pytest.mark.xfail(
reason="https://github.com/kkrt-labs/kakarot/issues/683",
)
async def test_should_raise_from_vm_when_count_is_0(self, counter, owner):
await counter.reset(caller_eoa=owner)
await counter.reset(caller_eoa=owner.starknet_contract)
with evm_error():
await counter.decUnchecked(caller_eoa=owner)
await counter.decUnchecked(caller_eoa=owner.starknet_contract)

async def test_should_decrease_count(self, counter, owner):
await counter.reset(caller_eoa=owner)
await counter.inc(caller_eoa=owner)
await counter.dec(caller_eoa=owner)
await counter.reset(caller_eoa=owner.starknet_contract)
await counter.inc(caller_eoa=owner.starknet_contract)
await counter.dec(caller_eoa=owner.starknet_contract)
assert await counter.count() == 0

async def test_should_decrease_count_unchecked(self, counter, owner):
await counter.reset(caller_eoa=owner)
await counter.inc(caller_eoa=owner)
await counter.decUnchecked(caller_eoa=owner)
await counter.reset(caller_eoa=owner.starknet_contract)
await counter.inc(caller_eoa=owner.starknet_contract)
await counter.decUnchecked(caller_eoa=owner.starknet_contract)
assert await counter.count() == 0

async def test_should_decrease_count_in_place(self, counter, owner):
await counter.reset(caller_eoa=owner)
await counter.inc(caller_eoa=owner)
await counter.decInPlace(caller_eoa=owner)
await counter.reset(caller_eoa=owner.starknet_contract)
await counter.inc(caller_eoa=owner.starknet_contract)
await counter.decInPlace(caller_eoa=owner.starknet_contract)
assert await counter.count() == 0

class TestReset:
async def test_should_set_count_to_0(self, counter, owner):
await counter.inc(caller_eoa=owner)
await counter.reset(caller_eoa=owner)
await counter.inc(caller_eoa=owner.starknet_contract)
await counter.reset(caller_eoa=owner.starknet_contract)
assert await counter.count() == 0

class TestDeploymentWithValue:
Expand All @@ -73,12 +73,12 @@ class TestLoops:
async def test_should_set_counter_to_iterations_with_for_loop(
self, counter, owner, iterations
):
await counter.incForLoop(iterations, caller_eoa=owner)
await counter.incForLoop(iterations, caller_eoa=owner.starknet_contract)
assert await counter.count() == iterations

@pytest.mark.parametrize("iterations", [0, 50, 200])
async def test_should_set_counter_to_iterations_with_while_loop(
self, counter, owner, iterations
):
await counter.incWhileLoop(iterations, caller_eoa=owner)
await counter.incWhileLoop(iterations, caller_eoa=owner.starknet_contract)
assert await counter.count() == iterations
Loading

0 comments on commit f6cc59b

Please sign in to comment.