Skip to content

Commit

Permalink
Remove panics in executors (#920)
Browse files Browse the repository at this point in the history
* Remove panics in aot executor

* Fix som leftover unwraps

* Remove panics in jit executor

* Remove panic from contract executor

* Remove metadata unwrap

* Fix clippy
  • Loading branch information
JulianGCalderon authored Nov 19, 2024
1 parent ae234ba commit 4a6e239
Show file tree
Hide file tree
Showing 15 changed files with 92 additions and 71 deletions.
12 changes: 8 additions & 4 deletions benches/libfuncs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ pub fn bench_libfuncs(c: &mut Criterion) {
let module = native_context.compile(program, false).unwrap();
// pass manager internally verifies the MLIR output is correct.
let native_executor =
JitNativeExecutor::from_native_module(module, OptLevel::Aggressive);
JitNativeExecutor::from_native_module(module, OptLevel::Aggressive)
.unwrap();

// Execute the program.
let result = native_executor
Expand All @@ -76,7 +77,8 @@ pub fn bench_libfuncs(c: &mut Criterion) {
let module = native_context.compile(program, false).unwrap();
// pass manager internally verifies the MLIR output is correct.
let native_executor =
JitNativeExecutor::from_native_module(module, OptLevel::Aggressive);
JitNativeExecutor::from_native_module(module, OptLevel::Aggressive)
.unwrap();

// warmup
for _ in 0..5 {
Expand Down Expand Up @@ -104,7 +106,8 @@ pub fn bench_libfuncs(c: &mut Criterion) {
let module = native_context.compile(program, false).unwrap();
// pass manager internally verifies the MLIR output is correct.
let native_executor =
AotNativeExecutor::from_native_module(module, OptLevel::Aggressive);
AotNativeExecutor::from_native_module(module, OptLevel::Aggressive)
.unwrap();

// Execute the program.
let result = native_executor
Expand All @@ -123,7 +126,8 @@ pub fn bench_libfuncs(c: &mut Criterion) {
let module = native_context.compile(program, false).unwrap();
// pass manager internally verifies the MLIR output is correct.
let native_executor =
AotNativeExecutor::from_native_module(module, OptLevel::Aggressive);
AotNativeExecutor::from_native_module(module, OptLevel::Aggressive)
.unwrap();

// warmup
for _ in 0..5 {
Expand Down
3 changes: 2 additions & 1 deletion examples/easy_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ fn main() {
.expect("entry point not found");

// Instantiate the executor.
let native_executor = JitNativeExecutor::from_native_module(native_program, Default::default());
let native_executor =
JitNativeExecutor::from_native_module(native_program, Default::default()).unwrap();

// Execute the program.
let result = native_executor
Expand Down
3 changes: 2 additions & 1 deletion examples/erc20.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,8 @@ fn main() {
find_entry_point_by_idx(&sierra_program, entry_point.function_idx).unwrap();
let fn_id = &entry_point_fn.id;

let native_executor = JitNativeExecutor::from_native_module(native_program, Default::default());
let native_executor =
JitNativeExecutor::from_native_module(native_program, Default::default()).unwrap();

let result = native_executor
.invoke_contract_dynamic(
Expand Down
3 changes: 2 additions & 1 deletion examples/invoke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ fn main() {

let fn_id = &entry_point_fn.id;

let native_executor = JitNativeExecutor::from_native_module(native_program, Default::default());
let native_executor =
JitNativeExecutor::from_native_module(native_program, Default::default()).unwrap();

let output = native_executor.invoke_dynamic(fn_id, &[Value::Felt252(1.into())], None);

Expand Down
3 changes: 2 additions & 1 deletion examples/starknet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,8 @@ fn main() {

let fn_id = &entry_point_fn.id;

let native_executor = JitNativeExecutor::from_native_module(native_program, Default::default());
let native_executor =
JitNativeExecutor::from_native_module(native_program, Default::default()).unwrap();

let result = native_executor
.invoke_contract_dynamic(fn_id, &[Felt::ONE], Some(u64::MAX), SyscallHandler::new())
Expand Down
4 changes: 2 additions & 2 deletions src/bin/cairo-native-run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ fn main() -> anyhow::Result<()> {
let native_executor: Box<dyn Fn(_, _, _, &mut StubSyscallHandler) -> _> = match args.run_mode {
RunMode::Aot => {
let executor =
AotNativeExecutor::from_native_module(native_module, args.opt_level.into());
AotNativeExecutor::from_native_module(native_module, args.opt_level.into())?;
Box::new(move |function_id, args, gas, syscall_handler| {
executor.invoke_dynamic_with_syscall_handler(
function_id,
Expand All @@ -89,7 +89,7 @@ fn main() -> anyhow::Result<()> {
}
RunMode::Jit => {
let executor =
JitNativeExecutor::from_native_module(native_module, args.opt_level.into());
JitNativeExecutor::from_native_module(native_module, args.opt_level.into())?;
Box::new(move |function_id, args, gas, syscall_handler| {
executor.invoke_dynamic_with_syscall_handler(
function_id,
Expand Down
4 changes: 2 additions & 2 deletions src/bin/utils/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ pub fn run_tests(
let native_executor: Box<dyn Fn(_, _, _, &mut StubSyscallHandler) -> _> = match args.run_mode {
RunMode::Aot => {
let executor =
AotNativeExecutor::from_native_module(native_module, args.opt_level.into());
AotNativeExecutor::from_native_module(native_module, args.opt_level.into())?;
Box::new(move |function_id, args, gas, syscall_handler| {
executor.invoke_dynamic_with_syscall_handler(
function_id,
Expand All @@ -153,7 +153,7 @@ pub fn run_tests(
}
RunMode::Jit => {
let executor =
JitNativeExecutor::from_native_module(native_module, args.opt_level.into());
JitNativeExecutor::from_native_module(native_module, args.opt_level.into())?;
Box::new(move |function_id, args, gas, syscall_handler| {
executor.invoke_dynamic_with_syscall_handler(
function_id,
Expand Down
2 changes: 1 addition & 1 deletion src/cache/jit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ where
opt_level: OptLevel,
) -> Result<Arc<JitNativeExecutor<'a>>> {
let module = self.context.compile(program, false)?;
let executor = JitNativeExecutor::from_native_module(module, opt_level);
let executor = JitNativeExecutor::from_native_module(module, opt_level)?;

let executor = Arc::new(executor);
self.cache.insert(key, executor.clone());
Expand Down
8 changes: 4 additions & 4 deletions src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ mod tests {
let module = native_context
.compile(&program, false)
.expect("failed to compile context");
let executor = AotNativeExecutor::from_native_module(module, OptLevel::default());
let executor = AotNativeExecutor::from_native_module(module, OptLevel::default()).unwrap();

// The first function in the program is `run_test`.
let entrypoint_function_id = &program.funcs.first().expect("should have a function").id;
Expand All @@ -702,7 +702,7 @@ mod tests {
let module = native_context
.compile(&program, false)
.expect("failed to compile context");
let executor = JitNativeExecutor::from_native_module(module, OptLevel::default());
let executor = JitNativeExecutor::from_native_module(module, OptLevel::default()).unwrap();

// The first function in the program is `run_test`.
let entrypoint_function_id = &program.funcs.first().expect("should have a function").id;
Expand All @@ -720,7 +720,7 @@ mod tests {
let module = native_context
.compile(&starknet_program, false)
.expect("failed to compile context");
let executor = AotNativeExecutor::from_native_module(module, OptLevel::default());
let executor = AotNativeExecutor::from_native_module(module, OptLevel::default()).unwrap();

// The last function in the program is the `get` wrapper function.
let entrypoint_function_id = &starknet_program
Expand All @@ -747,7 +747,7 @@ mod tests {
let module = native_context
.compile(&starknet_program, false)
.expect("failed to compile context");
let executor = JitNativeExecutor::from_native_module(module, OptLevel::default());
let executor = JitNativeExecutor::from_native_module(module, OptLevel::default()).unwrap();

// The last function in the program is the `get` wrapper function.
let entrypoint_function_id = &starknet_program
Expand Down
53 changes: 27 additions & 26 deletions src/executor/aot.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::io;

use crate::{
error::Error,
execution_result::{ContractExecutionResult, ExecutionResult},
Expand Down Expand Up @@ -48,27 +50,26 @@ impl AotNativeExecutor {
}

/// Utility to convert a [`NativeModule`] into an [`AotNativeExecutor`].
pub fn from_native_module(module: NativeModule, opt_level: OptLevel) -> Self {
pub fn from_native_module(module: NativeModule, opt_level: OptLevel) -> Result<Self, Error> {
let NativeModule {
module,
registry,
mut metadata,
} = module;

let library_path = NamedTempFile::new()
.unwrap()
let library_path = NamedTempFile::new()?
.into_temp_path()
.keep()
.unwrap();
.map_err(io::Error::from)?;

let object_data = crate::module_to_object(&module, opt_level).unwrap();
crate::object_to_shared_lib(&object_data, &library_path).unwrap();
let object_data = crate::module_to_object(&module, opt_level)?;
crate::object_to_shared_lib(&object_data, &library_path)?;

Self {
library: unsafe { Library::new(&library_path).unwrap() },
Ok(Self {
library: unsafe { Library::new(&library_path)? },
registry,
gas_metadata: metadata.remove().unwrap(),
}
gas_metadata: metadata.remove().ok_or(Error::MissingMetadata)?,
})
}

pub fn invoke_dynamic(
Expand All @@ -95,9 +96,9 @@ impl AotNativeExecutor {

super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id),
self.find_function_ptr(function_id)?,
set_costs_builtin,
self.extract_signature(function_id),
self.extract_signature(function_id)?,
args,
available_gas,
Option::<DummySyscallHandler>::None,
Expand Down Expand Up @@ -129,9 +130,9 @@ impl AotNativeExecutor {

super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id),
self.find_function_ptr(function_id)?,
set_costs_builtin,
self.extract_signature(function_id),
self.extract_signature(function_id)?,
args,
available_gas,
Some(syscall_handler),
Expand Down Expand Up @@ -163,9 +164,9 @@ impl AotNativeExecutor {

ContractExecutionResult::from_execution_result(super::invoke_dynamic(
&self.registry,
self.find_function_ptr(function_id),
self.find_function_ptr(function_id)?,
set_costs_builtin,
self.extract_signature(function_id),
self.extract_signature(function_id)?,
&[Value::Struct {
fields: vec![Value::Array(
args.iter().cloned().map(Value::Felt252).collect(),
Expand All @@ -177,17 +178,17 @@ impl AotNativeExecutor {
)?)
}

pub fn find_function_ptr(&self, function_id: &FunctionId) -> *mut c_void {
pub fn find_function_ptr(&self, function_id: &FunctionId) -> Result<*mut c_void, Error> {
let function_name = generate_function_name(function_id, false);
let function_name = format!("_mlir_ciface_{function_name}");

// Arguments and return values are hardcoded since they'll be handled by the trampoline.
unsafe {
self.library
.get::<extern "C" fn()>(function_name.as_bytes())
.unwrap()
.into_raw()
Ok(self
.library
.get::<extern "C" fn()>(function_name.as_bytes())?
.into_raw()
.into_raw())
}
}

Expand All @@ -200,8 +201,8 @@ impl AotNativeExecutor {
}
}

fn extract_signature(&self, function_id: &FunctionId) -> &FunctionSignature {
&self.registry.get_function(function_id).unwrap().signature
fn extract_signature(&self, function_id: &FunctionId) -> Result<&FunctionSignature, Error> {
Ok(&self.registry.get_function(function_id)?.signature)
}
}

Expand Down Expand Up @@ -265,7 +266,7 @@ mod tests {
let module = native_context
.compile(&program, false)
.expect("failed to compile context");
let executor = AotNativeExecutor::from_native_module(module, optlevel);
let executor = AotNativeExecutor::from_native_module(module, optlevel).unwrap();

// The first function in the program is `run_test`.
let entrypoint_function_id = &program.funcs.first().expect("should have a function").id;
Expand All @@ -286,7 +287,7 @@ mod tests {
let module = native_context
.compile(&program, false)
.expect("failed to compile context");
let executor = AotNativeExecutor::from_native_module(module, optlevel);
let executor = AotNativeExecutor::from_native_module(module, optlevel).unwrap();

// The second function in the program is `get_block_hash`.
let entrypoint_function_id = &program.funcs.get(1).expect("should have a function").id;
Expand Down Expand Up @@ -325,7 +326,7 @@ mod tests {
let module = native_context
.compile(&starknet_program, false)
.expect("failed to compile context");
let executor = AotNativeExecutor::from_native_module(module, optlevel);
let executor = AotNativeExecutor::from_native_module(module, optlevel).unwrap();

// The last function in the program is the `get` wrapper function.
let entrypoint_function_id = &starknet_program
Expand Down
Loading

0 comments on commit 4a6e239

Please sign in to comment.