diff --git a/kakarot_scripts/constants.py b/kakarot_scripts/constants.py index ddcaff59f..b04d9c138 100644 --- a/kakarot_scripts/constants.py +++ b/kakarot_scripts/constants.py @@ -323,11 +323,11 @@ def __init__(self, relayers: List[Dict[str, int]]): ) for relayer in relayers ] - self._index = 0 + self.index = 0 def __next__(self) -> Account: - relayer = self.relayer_accounts[self._index] - self._index = (self._index + 1) % len(self.relayer_accounts) + relayer = self.relayer_accounts[self.index] + self.index = (self.index + 1) % len(self.relayer_accounts) return relayer diff --git a/kakarot_scripts/utils/kakarot.py b/kakarot_scripts/utils/kakarot.py index 6ec971750..af7c0da7e 100644 --- a/kakarot_scripts/utils/kakarot.py +++ b/kakarot_scripts/utils/kakarot.py @@ -709,7 +709,10 @@ async def deploy_and_fund_evm_address(evm_address: str, amount: float): await fund_address(evm_address, amount - account_balance) if not await _contract_exists(starknet_address): await _invoke_starknet( - "kakarot", "deploy_externally_owned_account", int(evm_address, 16) + "kakarot", + "deploy_externally_owned_account", + int(evm_address, 16), + account=next(NETWORK["relayers"]), ) return starknet_address diff --git a/kakarot_scripts/utils/starknet.py b/kakarot_scripts/utils/starknet.py index 35da2abe0..587d71be8 100644 --- a/kakarot_scripts/utils/starknet.py +++ b/kakarot_scripts/utils/starknet.py @@ -164,8 +164,8 @@ async def fund_address( else: logger.info(f"{amount / 1e18} ETH minted to {hex(address)}") else: - account = funding_account or await get_starknet_account() - eth_contract = token_contract or await get_eth_contract() + account = funding_account or next(NETWORK["relayers"]) + eth_contract = token_contract or await get_eth_contract(account) balance = await get_balance(account.address, eth_contract) if balance < amount: raise ValueError( diff --git a/tests/end_to_end/conftest.py b/tests/end_to_end/conftest.py index b3fcc675e..b322abb0d 100644 --- a/tests/end_to_end/conftest.py +++ b/tests/end_to_end/conftest.py @@ -8,7 +8,7 @@ from starknet_py.contract import Contract from starknet_py.net.account.account import Account -from kakarot_scripts.constants import RPC_CLIENT, NetworkType +from kakarot_scripts.constants import NETWORK, RPC_CLIENT, NetworkType from kakarot_scripts.utils.kakarot import eth_balance_of from kakarot_scripts.utils.kakarot import get_contract as get_solidity_contract from kakarot_scripts.utils.kakarot import get_eoa @@ -182,3 +182,16 @@ async def _factory(block_number: Optional[Union[int, str]] = "latest"): ).block_hash return _factory + + +@pytest.fixture(autouse=True, scope="session") +def relayers(worker_id): + """ + Override NETWORK["relayers"] to use the worker_id as the index and avoid nonce issues. + """ + try: + logger.info(f"Setting relayer index to {int(worker_id[2:])}") + NETWORK["relayers"].index = int(worker_id[2:]) + except ValueError: + logger.info(f"Error while setting relayer index to {worker_id}") + return diff --git a/tests/fixtures/starknet.py b/tests/fixtures/starknet.py index 9aecd9dde..a899fb347 100644 --- a/tests/fixtures/starknet.py +++ b/tests/fixtures/starknet.py @@ -1,7 +1,6 @@ import json import logging import math -import shutil from hashlib import md5 from pathlib import Path from time import perf_counter, time_ns @@ -24,9 +23,9 @@ from starkware.starknet.compiler.starknet_pass_manager import starknet_pass_manager from tests.utils.constants import Opcodes -from tests.utils.coverage import VmWithCoverage, report_runs +from tests.utils.coverage import VmWithCoverage from tests.utils.hints import debug_info -from tests.utils.reporting import dump_coverage, profile_from_tracer_data +from tests.utils.reporting import profile_from_tracer_data from tests.utils.serde import Serde from tests.utils.syscall_handler import SyscallHandler @@ -38,23 +37,6 @@ logger = logging.getLogger() -@pytest.fixture(scope="session", autouse=True) -async def coverage(worker_id, request): - - output_dir = Path("coverage") - shutil.rmtree(output_dir, ignore_errors=True) - - yield - - output_dir.mkdir(exist_ok=True, parents=True) - files = report_runs(excluded_file={"site-packages", "tests"}) - - if worker_id == "master": - dump_coverage(output_dir, files) - else: - dump_coverage(output_dir / worker_id, files) - - def cairo_compile(path): module_reader = get_module_reader(cairo_path=["src"]) diff --git a/tests/src/conftest.py b/tests/src/conftest.py new file mode 100644 index 000000000..f8ee3948e --- /dev/null +++ b/tests/src/conftest.py @@ -0,0 +1,24 @@ +import shutil +from pathlib import Path + +import pytest + +from tests.utils.coverage import report_runs +from tests.utils.reporting import dump_coverage + + +@pytest.fixture(scope="session", autouse=True) +async def coverage(worker_id): + + output_dir = Path("coverage") + shutil.rmtree(output_dir, ignore_errors=True) + + yield + + output_dir.mkdir(exist_ok=True, parents=True) + files = report_runs(excluded_file={"site-packages", "tests"}) + + if worker_id == "master": + dump_coverage(output_dir, files) + else: + dump_coverage(output_dir / worker_id, files)