diff --git a/src/executor/aot.rs b/src/executor/aot.rs index 242d38e6b..de9b663df 100644 --- a/src/executor/aot.rs +++ b/src/executor/aot.rs @@ -209,12 +209,15 @@ mod tests { } #[rstest] - fn test_invoke_dynamic(program: Program) { + #[case(OptLevel::None)] + #[case(OptLevel::Default)] + #[case(OptLevel::Aggressive)] + fn test_invoke_dynamic(program: Program, #[case] optlevel: OptLevel) { let native_context = NativeContext::new(); let module = native_context .compile(&program) .expect("failed to compile context"); - let executor = AotNativeExecutor::from_native_module(module, OptLevel::default()); + let executor = AotNativeExecutor::from_native_module(module, optlevel); // The first function in the program is `run_test`. let entrypoint_function_id = &program.funcs.first().expect("should have a function").id; @@ -227,12 +230,15 @@ mod tests { } #[rstest] - fn test_invoke_dynamic_with_syscall_handler(program: Program) { + #[case(OptLevel::None)] + #[case(OptLevel::Default)] + #[case(OptLevel::Aggressive)] + fn test_invoke_dynamic_with_syscall_handler(program: Program, #[case] optlevel: OptLevel) { let native_context = NativeContext::new(); let module = native_context .compile(&program) .expect("failed to compile context"); - let executor = AotNativeExecutor::from_native_module(module, OptLevel::default()); + let executor = AotNativeExecutor::from_native_module(module, optlevel); // The second function in the program is `get_block_hash`. let entrypoint_function_id = &program.funcs.get(1).expect("should have a function").id; @@ -263,12 +269,15 @@ mod tests { } #[rstest] - fn test_invoke_contract_dynamic(starknet_program: Program) { + #[case(OptLevel::None)] + #[case(OptLevel::Default)] + #[case(OptLevel::Aggressive)] + fn test_invoke_contract_dynamic(starknet_program: Program, #[case] optlevel: OptLevel) { let native_context = NativeContext::new(); let module = native_context .compile(&starknet_program) .expect("failed to compile context"); - let executor = AotNativeExecutor::from_native_module(module, OptLevel::default()); + let executor = AotNativeExecutor::from_native_module(module, optlevel); // The last function in the program is the `get` wrapper function. let entrypoint_function_id = &starknet_program diff --git a/src/ffi.rs b/src/ffi.rs index 20657ab77..48c2204b8 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -9,7 +9,8 @@ use llvm_sys::{ LLVMContextCreate, LLVMContextDispose, LLVMDisposeMemoryBuffer, LLVMDisposeMessage, LLVMDisposeModule, LLVMGetBufferSize, LLVMGetBufferStart, }, - prelude::{LLVMContextRef, LLVMMemoryBufferRef, LLVMModuleRef}, + error::LLVMGetErrorMessage, + prelude::LLVMMemoryBufferRef, target::{ LLVM_InitializeAllAsmParsers, LLVM_InitializeAllAsmPrinters, LLVM_InitializeAllTargetInfos, LLVM_InitializeAllTargetMCs, LLVM_InitializeAllTargets, @@ -20,13 +21,16 @@ use llvm_sys::{ LLVMGetHostCPUName, LLVMGetTargetFromTriple, LLVMRelocMode, LLVMTargetMachineEmitToMemoryBuffer, LLVMTargetRef, }, + transforms::pass_builder::{ + LLVMCreatePassBuilderOptions, LLVMDisposePassBuilderOptions, LLVMRunPasses, + }, }; use melior::ir::{Module, Type, TypeLike}; -use mlir_sys::{MlirAttribute, MlirContext, MlirModule, MlirOperation}; +use mlir_sys::{mlirTranslateModuleToLLVMIR, MlirAttribute, MlirContext, MlirModule}; use std::{ borrow::Cow, error::Error, - ffi::{c_void, CStr}, + ffi::{c_void, CStr, CString}, fmt::Display, io::Write, mem::MaybeUninit, @@ -152,13 +156,6 @@ pub enum MlirLLVMDWTag { extern "C" { fn LLVMStructType_getFieldTypeAt(ty_ptr: *const c_void, index: u32) -> *const c_void; - /// Translate operation that satisfies LLVM dialect module requirements into an LLVM IR module living in the given context. - /// This translates operations from any dilalect that has a registered implementation of LLVMTranslationDialectInterface. - fn mlirTranslateModuleToLLVMIR( - module_operation_ptr: MlirOperation, - llvm_context: LLVMContextRef, - ) -> LLVMModuleRef; - pub fn mlirLLVMDistinctAttrCreate(attr: MlirAttribute) -> MlirAttribute; pub fn mlirLLVMDICompileUnitAttrGet( @@ -299,7 +296,7 @@ pub fn module_to_object( let op = module.as_operation().to_raw(); - let llvm_module = mlirTranslateModuleToLLVMIR(op, llvm_context); + let llvm_module = mlirTranslateModuleToLLVMIR(op, llvm_context as *mut _) as *mut _; let mut null = null_mut(); let mut error_buffer = addr_of_mut!(null); @@ -337,6 +334,23 @@ pub fn module_to_object( LLVMCodeModel::LLVMCodeModelDefault, ); + let opts = LLVMCreatePassBuilderOptions(); + let opt = match opt_level { + OptLevel::None => 0, + OptLevel::Less => 1, + OptLevel::Default => 2, + OptLevel::Aggressive => 3, + }; + let passes = CString::new(format!("default")).unwrap(); + let error = LLVMRunPasses(llvm_module, passes.as_ptr(), machine, opts); + if !error.is_null() { + let msg = LLVMGetErrorMessage(error); + let msg = CStr::from_ptr(msg); + Err(LLVMCompileError(msg.to_string_lossy().into_owned()))?; + } + + LLVMDisposePassBuilderOptions(opts); + let mut out_buf: MaybeUninit = MaybeUninit::uninit(); let ok = LLVMTargetMachineEmitToMemoryBuffer(