From 2747ab1f6d4a810954e236fad9861e82b690de5b Mon Sep 17 00:00:00 2001 From: Sam Zhou Date: Sun, 22 Dec 2024 08:23:47 -0800 Subject: [PATCH] [refactor] Include symbol_table in WASM AST --- crates/samlang-ast/src/wasm.rs | 23 +++++++++++--------- crates/samlang-compiler/src/lib.rs | 15 +++++++------ crates/samlang-compiler/src/wasm_lowering.rs | 6 ++--- 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/crates/samlang-ast/src/wasm.rs b/crates/samlang-ast/src/wasm.rs index 70e4e5bd..2a2665e6 100644 --- a/crates/samlang-ast/src/wasm.rs +++ b/crates/samlang-ast/src/wasm.rs @@ -323,6 +323,7 @@ impl GlobalData { } pub struct Module { + pub symbol_table: mir::SymbolTable, pub function_type_parameter_counts: Vec, pub type_definition: Vec, pub global_variables: Vec, @@ -331,7 +332,7 @@ pub struct Module { } impl Module { - pub fn pretty_print(&self, heap: &Heap, table: &mir::SymbolTable) -> String { + pub fn pretty_print(&self, heap: &Heap) -> String { let mut collector = String::new(); for count in &self.function_type_parameter_counts { collector.push_str("(type $"); @@ -348,11 +349,11 @@ impl Module { } for type_def in &self.type_definition { collector.push_str("(type $"); - type_def.name.write_encoded(&mut collector, heap, table); + type_def.name.write_encoded(&mut collector, heap, &self.symbol_table); collector.push_str(" (struct"); for field in &type_def.mappings { collector.push_str(" (field "); - field.pretty_print(&mut collector, heap, table); + field.pretty_print(&mut collector, heap, &self.symbol_table); collector.push(')'); } collector.push_str("))\n"); @@ -365,12 +366,12 @@ impl Module { collector.push_str(" funcref)\n(elem $0 (i32.const 0)"); for f in &self.functions { collector.push_str(" $"); - f.name.write_encoded(&mut collector, heap, table); + f.name.write_encoded(&mut collector, heap, &self.symbol_table); } collector.push_str(")\n"); for Function { name, parameters, local_variables, instructions } in &self.functions { collector.push_str("(func $"); - name.write_encoded(&mut collector, heap, table); + name.write_encoded(&mut collector, heap, &self.symbol_table); for param in parameters { collector.push_str(" (param $"); collector.push_str(param.as_str(heap)); @@ -383,15 +384,15 @@ impl Module { collector.push_str(" i32)\n"); } for i in instructions { - i.print_to_collector(heap, table, &mut collector, 1); + i.print_to_collector(heap, &self.symbol_table, &mut collector, 1); } collector.push_str(")\n"); } for f in &self.exported_functions { collector.push_str("(export \""); - f.write_encoded(&mut collector, heap, table); + f.write_encoded(&mut collector, heap, &self.symbol_table); collector.push_str("\" (func $"); - f.write_encoded(&mut collector, heap, table); + f.write_encoded(&mut collector, heap, &self.symbol_table); collector.push_str("))\n"); } collector @@ -419,7 +420,8 @@ mod tests { let heap = &mut Heap::new(); let mut table = mir::SymbolTable::new(); - let module = Module { + let mut module = Module { + symbol_table: mir::SymbolTable::new(), function_type_parameter_counts: vec![0, 1, 2, 3], type_definition: vec![lir::TypeDefinition { name: table.create_type_name_for_test(PStr::UPPER_F), @@ -597,6 +599,7 @@ mod tests { ], }], }; + module.symbol_table = table; let expected = r#"(type $none_=>_i32 (func (result i32))) (type $i32_=>_i32 (func (param i32) (result i32))) (type $i32_i32_=>_i32 (func (param i32 i32) (result i32))) @@ -654,6 +657,6 @@ mod tests { ) (export "__$main" (func $__$main)) "#; - assert_eq!(expected, module.pretty_print(heap, &table)); + assert_eq!(expected, module.pretty_print(heap)); } } diff --git a/crates/samlang-compiler/src/lib.rs b/crates/samlang-compiler/src/lib.rs index d3d95b22..f8ad7aaf 100644 --- a/crates/samlang-compiler/src/lib.rs +++ b/crates/samlang-compiler/src/lib.rs @@ -14,12 +14,12 @@ pub use lir_lowering::compile_mir_to_lir; pub fn compile_lir_to_wasm( heap: &mut samlang_heap::Heap, - sources: &samlang_ast::lir::Sources, + sources: samlang_ast::lir::Sources, ) -> (String, Vec) { let whole_module_string = format!( "(module\n{}\n{}\n)\n", include_str!("libsam.wat"), - wasm_lowering::compile_lir_to_wasm(heap, sources).pretty_print(heap, &sources.symbol_table) + wasm_lowering::compile_lir_to_wasm(heap, sources).pretty_print(heap) ); let wat = wat::parse_str(&whole_module_string).unwrap(); (whole_module_string, wat) @@ -84,10 +84,6 @@ pub fn compile_sources( compile_mir_to_lir(heap, optimized_mir_sources) }); let common_ts_code = lir_sources.pretty_print(heap); - let (wat_text, wasm_file) = - samlang_profiling::measure_time(enable_profiling, "Compile to WASM", || { - compile_lir_to_wasm(heap, &lir_sources) - }); let mut text_code_results = std::collections::BTreeMap::new(); for module_reference in &entry_module_references { @@ -111,6 +107,11 @@ require('./__samlang_loader__.js')(binary).{}(); text_code_results .insert(format!("{}.wasm.js", module_reference.pretty_print(heap)), wasm_js_code); } + + let (wat_text, wasm_file) = + samlang_profiling::measure_time(enable_profiling, "Compile to WASM", || { + compile_lir_to_wasm(heap, lir_sources) + }); text_code_results.insert(EMITTED_WAT_FILE.to_string(), wat_text); Ok(SourcesCompilationResult { text_code_results, wasm_file }) @@ -153,7 +154,7 @@ class HelloWorld { assert_eq!("", error_set.pretty_print_error_messages_no_frame_for_test(&heap)); let mir_sources = super::compile_sources_to_mir(&mut heap, &checked_sources); let lir_sources = super::compile_mir_to_lir(&mut heap, mir_sources); - super::compile_lir_to_wasm(&mut heap, &lir_sources); + super::compile_lir_to_wasm(&mut heap, lir_sources); } #[test] diff --git a/crates/samlang-compiler/src/wasm_lowering.rs b/crates/samlang-compiler/src/wasm_lowering.rs index a76639aa..6ab286b2 100644 --- a/crates/samlang-compiler/src/wasm_lowering.rs +++ b/crates/samlang-compiler/src/wasm_lowering.rs @@ -274,7 +274,7 @@ impl<'a> LoweringManager<'a> { } } -pub(super) fn compile_lir_to_wasm(heap: &Heap, sources: &lir::Sources) -> wasm::Module { +pub(super) fn compile_lir_to_wasm(heap: &Heap, sources: lir::Sources) -> wasm::Module { let mut data_start: usize = 4096; let mut global_variables_to_pointer_mapping = HashMap::new(); let mut function_index_mapping = HashMap::new(); @@ -297,6 +297,7 @@ pub(super) fn compile_lir_to_wasm(heap: &Heap, sources: &lir::Sources) -> wasm:: function_index_mapping.insert(f.name, i); } wasm::Module { + symbol_table: sources.symbol_table, function_type_parameter_counts: sources .functions .iter() @@ -464,8 +465,7 @@ mod tests { return_value: ZERO, }], }; - let actual = - super::compile_lir_to_wasm(heap, &sources).pretty_print(heap, &sources.symbol_table); + let actual = super::compile_lir_to_wasm(heap, sources).pretty_print(heap); let expected = r#"(type $i32_=>_i32 (func (param i32) (result i32))) (data (i32.const 4096) "\00\00\00\00\03\00\00\00FOO") (data (i32.const 4112) "\00\00\00\00\03\00\00\00BAR")