Skip to content

Commit

Permalink
program cache for SiR (#321)
Browse files Browse the repository at this point in the history
* implement debug

* initial cache idea

* impl debug

* comment

* native executor take module by ref

* add ability to remove metadata

* format

* cache engine

* getters

* better

* fix

* fix
  • Loading branch information
edg-l authored Oct 24, 2023
1 parent d24a561 commit 120f972
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 2 deletions.
91 changes: 91 additions & 0 deletions src/cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use std::{cell::RefCell, collections::HashMap, fmt::Debug, hash::Hash, rc::Rc};

use cairo_lang_sierra::program::Program;

use crate::{context::NativeContext, executor::NativeExecutor};

/// A Cache for programs with the same context.
pub struct ProgramCache<'a, K: PartialEq + Eq + Hash> {
context: &'a NativeContext,
// Since we already hold a reference to the Context, it doesn't make sense to use thread-safe refcounting.
// Using a Arc<RwLock<T>> here is useless because NativeExecutor is not Send and Sync.
cache: HashMap<K, Rc<RefCell<NativeExecutor<'a>>>>,
}

impl<'a, K: PartialEq + Eq + Hash> Debug for ProgramCache<'a, K> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("ProgramCache")
}
}

impl<'a, K: Clone + PartialEq + Eq + Hash> ProgramCache<'a, K> {
pub fn new(context: &'a NativeContext) -> Self {
Self {
context,
cache: Default::default(),
}
}

pub fn get(&self, key: K) -> Option<Rc<RefCell<NativeExecutor<'a>>>> {
self.cache.get(&key).cloned()
}

pub fn compile_and_insert(
&mut self,
key: K,
program: &Program,
) -> Rc<RefCell<NativeExecutor<'a>>> {
let module = self.context.compile(program).expect("should compile");
let executor = NativeExecutor::new(module);
self.cache
.insert(key.clone(), Rc::new(RefCell::new(executor)));
self.cache.get_mut(&key).cloned().unwrap()
}
}

#[cfg(test)]
mod test {
use std::time::Instant;

use crate::utils::test::load_cairo;

use super::*;

#[test]
fn test_cache() {
let (_, program1) = load_cairo!(
fn main(lhs: felt252, rhs: felt252) -> felt252 {
lhs + rhs
}
);

let (_, program2) = load_cairo!(
fn main(lhs: felt252, rhs: felt252) -> felt252 {
lhs - rhs
}
);

let context = NativeContext::new();
let mut cache: ProgramCache<&'static str> = ProgramCache::new(&context);

let start = Instant::now();
cache.compile_and_insert("program1", &program1);
let diff_1 = Instant::now().duration_since(start);

let start = Instant::now();
cache.get("program1").expect("exists");
let diff_2 = Instant::now().duration_since(start);

assert!(diff_2 < diff_1);

let start = Instant::now();
cache.compile_and_insert("program2", &program2);
let diff_1 = Instant::now().duration_since(start);

let start = Instant::now();
cache.get("program2").expect("exists");
let diff_2 = Instant::now().duration_since(start);

assert!(diff_2 < diff_1);
}
}
9 changes: 9 additions & 0 deletions src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use serde::{Deserializer, Serializer};
/// A MLIR JIT execution engine in the context of Cairo Native.
pub struct NativeExecutor<'m> {
engine: ExecutionEngine,
// NativeModule needs to be kept alive with the executor or it will segfault when trying to execute.
native_module: NativeModule<'m>,
}

Expand All @@ -27,6 +28,14 @@ impl<'m> NativeExecutor<'m> {
self.native_module.get_program_registry()
}

pub fn get_module(&self) -> &NativeModule<'m> {
&self.native_module
}

pub fn get_module_mut(&mut self) -> &mut NativeModule<'m> {
&mut self.native_module
}

pub fn execute<'de, D, S>(
&self,
fn_id: &FunctionId,
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@

pub use self::{compiler::compile, jit_runner::execute};

pub mod cache;
mod compiler;
pub mod context;
pub mod debug_info;
Expand Down
2 changes: 1 addition & 1 deletion src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub mod syscall_handler;
pub mod tail_recursion;

/// Metadata container.
#[derive(Default)]
#[derive(Default, Debug)]
pub struct MetadataStorage {
entries: HashMap<TypeId, Box<dyn Any>>,
}
Expand Down
16 changes: 15 additions & 1 deletion src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use cairo_lang_sierra::{
program_registry::ProgramRegistry,
};
use melior::ir::Module;
use std::any::Any;
use std::{any::Any, fmt::Debug};

/// A MLIR module in the context of Cairo Native.
/// It is conformed by the MLIR module, the Sierra program registry
Expand Down Expand Up @@ -51,6 +51,14 @@ impl<'m> NativeModule<'m> {
self.metadata.insert(meta)
}

/// Removes metadata
pub fn remove_metadata<T>(&mut self) -> Option<T>
where
T: Any,
{
self.metadata.remove()
}

/// Retrieve a reference to some stored metadata.
///
/// The retrieval will fail if there is no metadata with the requested type, in which case it'll
Expand All @@ -70,3 +78,9 @@ impl<'m> NativeModule<'m> {
&self.registry
}
}

impl Debug for NativeModule<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.module.as_operation().to_string())
}
}

0 comments on commit 120f972

Please sign in to comment.