From a9ac919fd5497b4dbb83dd18e5daac7d1ad594fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20Gonz=C3=A1lez=20Calder=C3=B3n?= Date: Tue, 10 Dec 2024 11:54:57 -0300 Subject: [PATCH] Add support for multiple contracts --- src/bin/cairo-native-run.rs | 26 +++++++++++++--------- src/metadata/profiler.rs | 44 ++++++++++++++++++++++++++++--------- 2 files changed, 50 insertions(+), 20 deletions(-) diff --git a/src/bin/cairo-native-run.rs b/src/bin/cairo-native-run.rs index 1d1a38a8c..255dcaa1e 100644 --- a/src/bin/cairo-native-run.rs +++ b/src/bin/cairo-native-run.rs @@ -10,10 +10,10 @@ use cairo_native::{ starknet_stub::StubSyscallHandler, }; use clap::{Parser, ValueEnum}; +use starknet_types_core::felt::Felt; use std::path::PathBuf; use tracing_subscriber::{EnvFilter, FmtSubscriber}; use utils::{find_function, result_to_runresult}; - mod utils; #[derive(Clone, Debug, ValueEnum)] @@ -63,6 +63,8 @@ fn main() -> anyhow::Result<()> { let mut db = RootDatabase::builder().detect_corelib().build()?; let main_crate_ids = setup_project(&mut db, &args.path)?; + cairo_native::metadata::profiler::ProfilerImpl::push_contract(Felt::ZERO); + let sierra_program = compile_prepared_db( &db, main_crate_ids, @@ -149,15 +151,19 @@ fn main() -> anyhow::Result<()> { let mut trace = HashMap::, u64)>::new(); - for (statement_idx, tick_delta) in cairo_native::metadata::profiler::ProfilerImpl::take() { - if let Statement::Invocation(invocation) = &sierra_program.statements[statement_idx.0] { - let (tick_deltas, extra_count) = - trace.entry(invocation.libfunc_id.clone()).or_default(); - - if tick_delta != u64::MAX { - tick_deltas.push(tick_delta); - } else { - *extra_count += 1; + for (_, contract_trace) in cairo_native::metadata::profiler::ProfilerImpl::take() { + for (statement_idx, tick_delta) in contract_trace { + if let Statement::Invocation(invocation) = + &sierra_program.statements[statement_idx.0] + { + let (tick_deltas, extra_count) = + trace.entry(invocation.libfunc_id.clone()).or_default(); + + if tick_delta != u64::MAX { + tick_deltas.push(tick_delta); + } else { + *extra_count += 1; + } } } } diff --git a/src/metadata/profiler.rs b/src/metadata/profiler.rs index 2a0434c3f..7be53d792 100644 --- a/src/metadata/profiler.rs +++ b/src/metadata/profiler.rs @@ -16,7 +16,8 @@ use melior::{ }, Context, }; -use std::{cell::UnsafeCell, mem}; +use starknet_types_core::felt::Felt; +use std::{cell::UnsafeCell, collections::HashMap, mem}; pub struct ProfilerMeta { _private: (), @@ -193,26 +194,47 @@ impl ProfilerMeta { } thread_local! { - static PROFILER_IMPL: UnsafeCell = const { UnsafeCell::new(ProfilerImpl::new()) }; + static PROFILER_IMPL: UnsafeCell = UnsafeCell::new(ProfilerImpl::new()) ; } pub struct ProfilerImpl { - trace: Vec<(StatementIdx, u64)>, + traces: HashMap>, + contracts: Vec, } impl ProfilerImpl { - const fn new() -> Self { - Self { trace: Vec::new() } + fn new() -> Self { + Self { + traces: HashMap::new(), + contracts: Vec::new(), + } } - pub fn take() -> Vec<(StatementIdx, u64)> { + pub fn push_contract(hash: Felt) { PROFILER_IMPL.with(|x| { let x = unsafe { &mut *x.get() }; - let mut trace = Vec::new(); - mem::swap(&mut x.trace, &mut trace); + x.contracts.push(hash.clone()); + x.traces.entry(hash).or_insert(Vec::new()); + }) + } + + pub fn pop_contract(&mut self) { + PROFILER_IMPL.with(|x| { + let x = unsafe { &mut *x.get() }; + + x.contracts.pop(); + }) + } + + pub fn take() -> HashMap> { + PROFILER_IMPL.with(|x| { + let x = unsafe { &mut *x.get() }; + + let mut traces = HashMap::new(); + mem::swap(&mut x.traces, &mut traces); - trace + traces }) } @@ -220,7 +242,9 @@ impl ProfilerImpl { PROFILER_IMPL.with(|x| { let x = unsafe { &mut *x.get() }; - x.trace + x.traces + .get_mut(x.contracts.last().unwrap()) + .unwrap() .push((StatementIdx(statement_idx as usize), tick_delta)); }); }