Skip to content

Commit

Permalink
Make the contract caches thread-compatible. (#729)
Browse files Browse the repository at this point in the history
  • Loading branch information
azteca1998 authored Jul 15, 2024
1 parent 6626938 commit dc6f129
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 27 deletions.
24 changes: 12 additions & 12 deletions src/bin/cairo-native-stress/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,6 @@
//!
//! For documentation on the specific cache used, see `NaiveAotCache`.
use std::alloc::System;
use std::fmt::{Debug, Display};
use std::fs::{create_dir_all, read_dir, OpenOptions};
use std::hash::Hash;
use std::io;
use std::path::{Path, PathBuf};
use std::{collections::HashMap, fs, rc::Rc, time::Instant};

use cairo_lang_sierra::ids::FunctionId;
use cairo_lang_sierra::program::{GenericArg, Program};
use cairo_lang_sierra::program_registry::ProgramRegistry;
Expand All @@ -36,6 +28,14 @@ use clap::Parser;
use libloading::Library;
use num_bigint::BigInt;
use stats_alloc::{Region, StatsAlloc, INSTRUMENTED_SYSTEM};
use std::alloc::System;
use std::fmt::{Debug, Display};
use std::fs::{create_dir_all, read_dir, OpenOptions};
use std::hash::Hash;
use std::io;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::{collections::HashMap, fs, time::Instant};
use tracing::{debug, info, info_span, warn};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
Expand Down Expand Up @@ -246,7 +246,7 @@ where
K: PartialEq + Eq + Hash + Display,
{
context: &'a NativeContext,
cache: HashMap<K, Rc<AotNativeExecutor>>,
cache: HashMap<K, Arc<AotNativeExecutor>>,
}

impl<'a, K> NaiveAotCache<'a, K>
Expand All @@ -260,7 +260,7 @@ where
}
}

pub fn get(&self, key: &K) -> Option<Rc<AotNativeExecutor>> {
pub fn get(&self, key: &K) -> Option<Arc<AotNativeExecutor>> {
self.cache.get(key).cloned()
}

Expand All @@ -272,7 +272,7 @@ where
key: K,
program: &Program,
opt_level: OptLevel,
) -> Rc<AotNativeExecutor> {
) -> Arc<AotNativeExecutor> {
let native_module = self
.context
.compile(program, None)
Expand Down Expand Up @@ -303,7 +303,7 @@ where
};

let executor = AotNativeExecutor::new(shared_library, registry, metadata);
let executor = Rc::new(executor);
let executor = Arc::new(executor);

self.cache.insert(key, executor.clone());

Expand Down
10 changes: 5 additions & 5 deletions src/cache/aot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ use std::{
collections::HashMap,
fmt::{self, Debug},
hash::Hash,
rc::Rc,
sync::Arc,
};

pub struct AotProgramCache<'a, K>
where
K: PartialEq + Eq + Hash,
{
context: &'a NativeContext,
cache: HashMap<K, Rc<AotNativeExecutor>>,
cache: HashMap<K, Arc<AotNativeExecutor>>,
}

impl<'a, K> AotProgramCache<'a, K>
Expand All @@ -30,7 +30,7 @@ where
}
}

pub fn get(&self, key: &K) -> Option<Rc<AotNativeExecutor>> {
pub fn get(&self, key: &K) -> Option<Arc<AotNativeExecutor>> {
self.cache.get(key).cloned()
}

Expand All @@ -39,7 +39,7 @@ where
key: K,
program: &Program,
opt_level: OptLevel,
) -> Rc<AotNativeExecutor> {
) -> Arc<AotNativeExecutor> {
let NativeModule {
module,
registry,
Expand All @@ -65,7 +65,7 @@ where
metadata.get::<GasMetadata>().cloned().unwrap(),
);

let executor = Rc::new(executor);
let executor = Arc::new(executor);
self.cache.insert(key, executor.clone());

executor
Expand Down
10 changes: 5 additions & 5 deletions src/cache/jit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
collections::HashMap,
fmt::{self, Debug},
hash::Hash,
rc::Rc,
sync::Arc,
};

/// A Cache for programs with the same context.
Expand All @@ -16,7 +16,7 @@ where
// Since we already hold a reference to the Context, it doesn't make sense to use thread-safe
// reference counting. Using a Arc<RwLock<T>> here is useless because NativeExecutor is neither
// Send nor Sync.
cache: HashMap<K, Rc<JitNativeExecutor<'a>>>,
cache: HashMap<K, Arc<JitNativeExecutor<'a>>>,
}

impl<'a, K> JitProgramCache<'a, K>
Expand All @@ -35,7 +35,7 @@ where
self.context
}

pub fn get(&self, key: &K) -> Option<Rc<JitNativeExecutor<'a>>> {
pub fn get(&self, key: &K) -> Option<Arc<JitNativeExecutor<'a>>> {
self.cache.get(key).cloned()
}

Expand All @@ -44,11 +44,11 @@ where
key: K,
program: &Program,
opt_level: OptLevel,
) -> Rc<JitNativeExecutor<'a>> {
) -> Arc<JitNativeExecutor<'a>> {
let module = self.context.compile(program, None).expect("should compile");
let executor = JitNativeExecutor::from_native_module(module, opt_level);

let executor = Rc::new(executor);
let executor = Arc::new(executor);
self.cache.insert(key, executor.clone());

executor
Expand Down
10 changes: 5 additions & 5 deletions src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use std::{
alloc::Layout,
arch::global_asm,
ptr::{addr_of_mut, null_mut, NonNull},
rc::Rc,
sync::Arc,
};

mod aot;
Expand Down Expand Up @@ -56,8 +56,8 @@ extern "C" {
/// The cairo native executor, either AOT or JIT based.
#[derive(Debug, Clone)]
pub enum NativeExecutor<'m> {
Aot(Rc<AotNativeExecutor>),
Jit(Rc<JitNativeExecutor<'m>>),
Aot(Arc<AotNativeExecutor>),
Jit(Arc<JitNativeExecutor<'m>>),
}

impl<'a> NativeExecutor<'a> {
Expand Down Expand Up @@ -123,13 +123,13 @@ impl<'a> NativeExecutor<'a> {

impl<'m> From<AotNativeExecutor> for NativeExecutor<'m> {
fn from(value: AotNativeExecutor) -> Self {
Self::Aot(Rc::new(value))
Self::Aot(Arc::new(value))
}
}

impl<'m> From<JitNativeExecutor<'m>> for NativeExecutor<'m> {
fn from(value: JitNativeExecutor<'m>) -> Self {
Self::Jit(Rc::new(value))
Self::Jit(Arc::new(value))
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/executor/aot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ pub struct AotNativeExecutor {
gas_metadata: GasMetadata,
}

unsafe impl Send for AotNativeExecutor {}
unsafe impl Sync for AotNativeExecutor {}

impl AotNativeExecutor {
pub fn new(
library: Library,
Expand Down
3 changes: 3 additions & 0 deletions src/executor/jit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ pub struct JitNativeExecutor<'m> {
gas_metadata: GasMetadata,
}

unsafe impl<'a> Send for JitNativeExecutor<'a> {}
unsafe impl<'a> Sync for JitNativeExecutor<'a> {}

impl std::fmt::Debug for JitNativeExecutor<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("JitNativeExecutor")
Expand Down

0 comments on commit dc6f129

Please sign in to comment.