diff --git a/crates/concrete_ast/src/statements.rs b/crates/concrete_ast/src/statements.rs index 8076a7a..4c0a3cd 100644 --- a/crates/concrete_ast/src/statements.rs +++ b/crates/concrete_ast/src/statements.rs @@ -26,6 +26,24 @@ pub enum LetStmtTarget { pub struct LetStmt { pub is_mutable: bool, pub target: LetStmtTarget, + pub value: LetValue, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum LetValue { + Expr(Expression), + StructConstruct(StructConstruct), +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct StructConstruct { + pub name: Ident, + pub fields: Vec, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct FieldConstruct { + pub name: Ident, pub value: Expression, } @@ -38,6 +56,7 @@ pub struct ReturnStmt { pub struct AssignStmt { pub target: PathOp, pub value: Expression, + pub is_deref: bool, } #[derive(Clone, Debug, Eq, Hash, PartialEq)] diff --git a/crates/concrete_ast/src/structs.rs b/crates/concrete_ast/src/structs.rs index bd36ab6..9d134b0 100644 --- a/crates/concrete_ast/src/structs.rs +++ b/crates/concrete_ast/src/structs.rs @@ -5,6 +5,7 @@ use crate::{ #[derive(Clone, Debug, Eq, PartialEq)] pub struct StructDecl { + pub is_pub: bool, pub doc_string: Option, pub name: Ident, pub type_params: Vec, diff --git a/crates/concrete_ast/src/types.rs b/crates/concrete_ast/src/types.rs index 6454c4a..970fb3b 100644 --- a/crates/concrete_ast/src/types.rs +++ b/crates/concrete_ast/src/types.rs @@ -21,7 +21,7 @@ pub enum TypeSpec { }, Array { of_type: Box, - size: Option, + size: Option, is_ref: Option, span: Span, }, @@ -36,6 +36,10 @@ impl TypeSpec { } } + pub fn is_mut_ref(&self) -> bool { + matches!(self.is_ref(), Some(RefType::MutBorrow)) + } + pub fn get_name(&self) -> String { match self { TypeSpec::Simple { name, .. } => name.name.clone(), @@ -43,6 +47,42 @@ impl TypeSpec { TypeSpec::Array { of_type, .. } => format!("[{}]", of_type.get_name()), } } + + pub fn to_nonref_type(&self) -> TypeSpec { + match self { + TypeSpec::Simple { + name, + is_ref: _, + span, + } => TypeSpec::Simple { + name: name.clone(), + is_ref: None, + span: *span, + }, + TypeSpec::Generic { + name, + is_ref: _, + type_params, + span, + } => TypeSpec::Generic { + name: name.clone(), + is_ref: None, + type_params: type_params.clone(), + span: *span, + }, + TypeSpec::Array { + of_type, + size, + is_ref: _, + span, + } => TypeSpec::Array { + of_type: of_type.clone(), + size: *size, + is_ref: None, + span: *span, + }, + } + } } #[derive(Clone, Debug, Eq, Hash, PartialEq)] diff --git a/crates/concrete_check/src/lib.rs b/crates/concrete_check/src/lib.rs index a5f02c7..61ca6a5 100644 --- a/crates/concrete_check/src/lib.rs +++ b/crates/concrete_check/src/lib.rs @@ -155,7 +155,9 @@ impl<'parent> ScopeContext<'parent> { "f64" => name.to_string(), "bool" => name.to_string(), name => { - if let Some(module) = self.imports.get(name) { + if let Some(x) = self.module_info.structs.get(name) { + name.to_string() + } else if let Some(module) = self.imports.get(name) { // a import self.resolve_type_spec(&module.types.get(name)?.value)? } else { diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 4bcae6f..ab7cf29 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -3,11 +3,12 @@ use std::{collections::HashMap, error::Error}; use bumpalo::Bump; use concrete_ast::{ expressions::{ - ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, ValueExpr, + ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, PathSegment, + ValueExpr, }, functions::FunctionDef, modules::{Module, ModuleDefItem}, - statements::{AssignStmt, LetStmt, LetStmtTarget, ReturnStmt, Statement, WhileStmt}, + statements::{AssignStmt, LetStmt, LetStmtTarget, LetValue, ReturnStmt, Statement, WhileStmt}, types::TypeSpec, Program, }; @@ -16,10 +17,15 @@ use concrete_session::Session; use melior::{ dialect::{ arith::{self, CmpiPredicate}, - cf, func, memref, + cf, func, + llvm::{self, r#type::opaque_pointer, LoadStoreOptions}, + memref, }, ir::{ - attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute}, + attribute::{ + DenseI32ArrayAttribute, DenseI64ArrayAttribute, FlatSymbolRefAttribute, + IntegerAttribute, StringAttribute, TypeAttribute, + }, r#type::{FunctionType, IntegerType, MemRefType}, Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, Value, ValueLike, @@ -27,6 +33,8 @@ use melior::{ Context as MeliorContext, }; +use crate::scope_context::ScopeContext; + pub fn compile_program( session: &Session, ctx: &MeliorContext, @@ -50,34 +58,29 @@ pub struct LocalVar<'ctx, 'parent: 'ctx> { // If it's none its on a register, otherwise allocated on the stack. pub alloca: bool, pub value: Value<'ctx, 'parent>, + pub is_mut: bool, } impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> { - pub fn param(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self { + pub fn param(value: Value<'ctx, 'parent>, type_spec: TypeSpec, is_mut: bool) -> Self { Self { value, type_spec, alloca: false, + is_mut, } } - pub fn alloca(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self { + pub fn alloca(value: Value<'ctx, 'parent>, type_spec: TypeSpec, is_mut: bool) -> Self { Self { value, type_spec, alloca: true, + is_mut, } } } -#[derive(Debug, Clone)] -struct ScopeContext<'ctx, 'parent: 'ctx> { - pub locals: HashMap>, - pub function: Option, - pub imports: HashMap>, - pub module_info: &'parent ModuleInfo<'parent>, -} - struct BlockHelper<'ctx, 'region: 'ctx> { region: &'region Region<'ctx>, blocks_arena: &'region Bump, @@ -93,120 +96,6 @@ impl<'ctx, 'region> BlockHelper<'ctx, 'region> { } } -impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { - /// Returns the symbol name from a local name. - pub fn get_symbol_name(&self, local_name: &str) -> String { - if local_name == "main" { - return local_name.to_string(); - } - - if let Some(module) = self.imports.get(local_name) { - // a import - module.get_symbol_name(local_name) - } else { - let mut result = self.module_info.name.clone(); - - result.push_str("::"); - result.push_str(local_name); - - result - } - } - - pub fn get_function(&self, local_name: &str) -> Option<&FunctionDef> { - if let Some(module) = self.imports.get(local_name) { - // a import - module.functions.get(local_name).copied() - } else { - self.module_info.functions.get(local_name).copied() - } - } - - fn resolve_type( - &self, - context: &'ctx MeliorContext, - name: &str, - ) -> Result, Box> { - Ok(match name { - "u64" | "i64" => IntegerType::new(context, 64).into(), - "u32" | "i32" => IntegerType::new(context, 32).into(), - "u16" | "i16" => IntegerType::new(context, 16).into(), - "u8" | "i8" => IntegerType::new(context, 8).into(), - "f32" => Type::float32(context), - "f64" => Type::float64(context), - "bool" => IntegerType::new(context, 1).into(), - name => { - if let Some(module) = self.imports.get(name) { - // a import - self.resolve_type_spec( - context, - &module.types.get(name).expect("failed to find type").value, - )? - } else { - self.resolve_type_spec( - context, - &self - .module_info - .types - .get(name) - .expect("failed to find type") - .value, - )? - } - } - }) - } - - fn resolve_type_spec( - &self, - context: &'ctx MeliorContext, - spec: &TypeSpec, - ) -> Result, Box> { - match spec.is_ref() { - Some(_) => { - Ok( - MemRefType::new(self.resolve_type_spec_ref(context, spec)?, &[], None, None) - .into(), - ) - } - None => self.resolve_type_spec_ref(context, spec), - } - } - - /// Resolves the type this ref points to. - fn resolve_type_spec_ref( - &self, - context: &'ctx MeliorContext, - spec: &TypeSpec, - ) -> Result, Box> { - Ok(match spec { - TypeSpec::Simple { name, .. } => self.resolve_type(context, &name.name)?, - TypeSpec::Generic { name, .. } => self.resolve_type(context, &name.name)?, - TypeSpec::Array { .. } => { - todo!("implement arrays") - } - }) - } - - fn is_type_signed(&self, type_info: &TypeSpec) -> bool { - let signed = ["i8", "i16", "i32", "i64", "i128"]; - match type_info { - TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()), - TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), - TypeSpec::Array { .. } => unreachable!(), - } - } - - fn is_float(&self, type_info: &TypeSpec) -> bool { - let signed = ["f32", "f64"]; - match type_info { - TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()), - TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), - TypeSpec::Array { .. } => unreachable!(), - } - } -} - fn compile_module( session: &Session, context: &MeliorContext, @@ -250,7 +139,7 @@ fn compile_module( let op = compile_function_def(session, context, &scope_ctx, info)?; body.append_operation(op); } - ModuleDefItem::Struct(_) => todo!(), + ModuleDefItem::Struct(_) => {} ModuleDefItem::Type(_) => todo!(), ModuleDefItem::Module(info) => { let module_info = module_info.modules.get(&info.name.name).unwrap_or_else(|| { @@ -331,7 +220,7 @@ fn compile_function_def<'ctx, 'parent: 'ctx>( for (i, param) in info.decl.params.iter().enumerate() { scope_ctx.locals.insert( param.name.name.clone(), - LocalVar::param(fn_block.argument(i)?.into(), param.r#type.clone()), + LocalVar::param(fn_block.argument(i)?.into(), param.r#type.clone(), false), ); } @@ -339,6 +228,10 @@ fn compile_function_def<'ctx, 'parent: 'ctx>( fn_block = compile_statement(session, context, &mut scope_ctx, &helper, fn_block, stmt)?; } + + if fn_block.terminator().is_none() { + fn_block.append_operation(func::r#return(&[], location)); + } } let fn_name = scope_ctx.get_symbol_name(&info.decl.name.name); @@ -547,15 +440,74 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( LetStmtTarget::Simple { name, r#type } => { let location = get_location(context, session, name.span.from); - let value = compile_expression( - session, - context, - scope_ctx, - helper, - block, - &info.value, - Some(r#type), - )?; + let value = match &info.value { + LetValue::Expr(value) => compile_expression( + session, + context, + scope_ctx, + helper, + block, + value, + Some(r#type), + )?, + LetValue::StructConstruct(struct_construct) => { + let struct_decl = scope_ctx + .module_info + .structs + .get(&struct_construct.name.name) + .unwrap_or_else(|| { + panic!("failed to find struct {:?}", struct_construct.name.name) + }); + assert_eq!( + struct_construct.fields.len(), + struct_decl.fields.len(), + "struct field len mismatch" + ); + let (ty, field_indexes) = scope_ctx.get_struct_type(context, struct_decl)?; + let mut struct_value = block + .append_operation(llvm::undef(ty, location)) + .result(0)? + .into(); + let field_types: HashMap = struct_decl + .fields + .iter() + .map(|x| (x.name.name.clone(), &x.r#type)) + .collect(); + + for field in &struct_construct.fields { + let field_idx = field_indexes.get(&field.name.name).unwrap_or_else(|| { + panic!( + "failed to find field {:?} for struct {:?}", + field.name.name, struct_construct.name.name + ) + }); + + let field_ty = field_types.get(&field.name.name).expect("field not found"); + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &field.value, + Some(field_ty), + )?; + + struct_value = block + .append_operation(llvm::insert_value( + context, + struct_value, + DenseI64ArrayAttribute::new(context, &[(*field_idx) as i64]), + value, + location, + )) + .result(0)? + .into(); + } + + struct_value + } + }; let memref_type = MemRefType::new(value.r#type(), &[], None, None); let alloca: Value = block @@ -572,9 +524,10 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( block.append_operation(memref::store(value, alloca, &[], location)); - scope_ctx - .locals - .insert(name.name.clone(), LocalVar::alloca(alloca, r#type.clone())); + scope_ctx.locals.insert( + name.name.clone(), + LocalVar::alloca(alloca, r#type.clone(), info.is_mutable), + ); Ok(()) } @@ -590,6 +543,7 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( block: &'parent Block<'ctx>, info: &AssignStmt, ) -> Result<(), Box> { + tracing::debug!("compiling assign for {:?}", info.target); // todo: implement properly for structs, right now only really works for simple variables. let local = scope_ctx @@ -598,21 +552,143 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( .expect("local should exist") .clone(); - assert!(local.alloca, "can only mutate local stack variables"); + assert!( + local.is_mut || local.type_spec.is_mut_ref(), + "can only mutate mutable or ref mut variables" + ); + assert!( + local.type_spec.is_mut_ref() || local.alloca, + "can only mutate local stack variables" + ); let location = get_location(context, session, info.target.first.span.from); - let value = compile_expression( - session, - context, - scope_ctx, - helper, - block, - &info.value, - Some(&local.type_spec), - )?; + if info.target.extra.is_empty() { + let mut target_value = local.value; + let mut type_spec = local.type_spec.clone(); + + if info.is_deref { + assert!(local.type_spec.is_mut_ref(), "can only mutate mutable refs"); + if local.alloca { + target_value = block + .append_operation(memref::load(local.value, &[], location)) + .result(0)? + .into() + } + type_spec = type_spec.to_nonref_type(); + } + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &info.value, + Some(&type_spec), + )?; + + block.append_operation(memref::store(value, target_value, &[], location)); + } else { + let mut current_type_spec = &local.type_spec; + + // get a ptr to the field + let target_ptr = block + .append_operation( + melior::dialect::ods::memref::extract_aligned_pointer_as_index( + context, + Type::index(context), + local.value, + location, + ) + .into(), + ) + .result(0)? + .into(); + + let target_ptr = block + .append_operation(arith::index_cast( + target_ptr, + IntegerType::new(context, 64).into(), + location, + )) + .result(0)? + .into(); - block.append_operation(memref::store(value, local.value, &[], location)); + let mut target_ptr = block + .append_operation( + melior::dialect::ods::llvm::inttoptr( + context, + opaque_pointer(context), + target_ptr, + location, + ) + .into(), + ) + .result(0)? + .into(); + + let mut extra_it = info.target.extra.iter().peekable(); + + while let Some(extra) = extra_it.next() { + match extra { + PathSegment::FieldAccess(ident) => { + let (struct_decl, (_, field_indexes)) = match current_type_spec { + TypeSpec::Simple { name, .. } => { + let struct_decl = + scope_ctx.module_info.structs.get(&name.name).unwrap(); + ( + struct_decl, + scope_ctx.get_struct_type(context, struct_decl)?, + ) + } + _ => unreachable!(), + }; + + let field = struct_decl + .fields + .iter() + .find(|x| x.name.name == ident.name) + .unwrap_or_else(|| panic!("failed to find field {:?}", ident.name)); + let field_idx = *field_indexes.get(&ident.name).unwrap(); + let field_ty = scope_ctx.resolve_type_spec(context, &field.r#type)?; + + current_type_spec = &field.r#type; + target_ptr = block + .append_operation(llvm::get_element_ptr( + context, + target_ptr, + DenseI32ArrayAttribute::new(context, &[field_idx as i32]), + field_ty, + opaque_pointer(context), + location, + )) + .result(0)? + .into(); + + if extra_it.peek().is_none() { + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &info.value, + Some(current_type_spec), + )?; + + block.append_operation(llvm::store( + context, + value, + target_ptr, + location, + LoadStoreOptions::default(), + )); + } + } + PathSegment::ArrayIndex(_) => todo!(), + } + } + } Ok(()) } @@ -660,7 +736,7 @@ fn compile_expression<'ctx, 'parent: 'ctx>( session, context, scope_ctx, _helper, block, value, type_info, ), Expression::FnCall(value) => { - compile_fn_call(session, context, scope_ctx, _helper, block, value) + compile_fn_call_with_return(session, context, scope_ctx, _helper, block, value) } Expression::Match(_) => todo!(), Expression::If(_) => todo!(), @@ -849,6 +925,54 @@ fn compile_fn_call<'ctx, 'parent: 'ctx>( _helper: &BlockHelper<'ctx, 'parent>, block: &'parent Block<'ctx>, info: &FnCallOp, +) -> Result<(), Box> { + tracing::debug!("compiling fncall: {:?}", info); + let mut args = Vec::with_capacity(info.args.len()); + let location = get_location(context, session, info.target.span.from); + + let target_fn = scope_ctx + .get_function(&info.target.name) + .expect("function not found") + .clone(); + + assert_eq!( + info.args.len(), + target_fn.decl.params.len(), + "parameter length doesnt match" + ); + + for (arg, arg_info) in info.args.iter().zip(&target_fn.decl.params) { + let value = compile_expression( + session, + context, + scope_ctx, + _helper, + block, + arg, + Some(&arg_info.r#type), + )?; + args.push(value); + } + + let fn_name = scope_ctx.get_symbol_name(&info.target.name); + + block.append_operation(func::call( + context, + FlatSymbolRefAttribute::new(context, &fn_name), + &args, + &[], + location, + )); + Ok(()) +} + +fn compile_fn_call_with_return<'ctx, 'parent: 'ctx>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, + info: &FnCallOp, ) -> Result, Box> { tracing::debug!("compiling fncall: {:?}", info); let mut args = Vec::with_capacity(info.args.len()); @@ -918,7 +1042,7 @@ fn compile_path_op<'ctx, 'parent: 'ctx>( let location = get_location(context, session, path.first.span.from); - let value = if local.alloca { + let mut value = if local.alloca { block .append_operation(memref::load(local.value, &[], location)) .result(0)? @@ -927,7 +1051,52 @@ fn compile_path_op<'ctx, 'parent: 'ctx>( local.value }; - Ok(value) + if path.extra.is_empty() { + Ok(value) + } else { + let mut current_type_spec = &local.type_spec; + + for extra in &path.extra { + match extra { + PathSegment::FieldAccess(ident) => { + let (struct_decl, (_, field_indexes)) = match current_type_spec { + TypeSpec::Simple { name, .. } => { + let struct_decl = + scope_ctx.module_info.structs.get(&name.name).unwrap(); + ( + struct_decl, + scope_ctx.get_struct_type(context, struct_decl)?, + ) + } + _ => unreachable!(), + }; + + let field = struct_decl + .fields + .iter() + .find(|x| x.name.name == ident.name) + .unwrap(); + let field_idx = *field_indexes.get(&ident.name).unwrap(); + let field_ty = scope_ctx.resolve_type_spec(context, &field.r#type)?; + + current_type_spec = &field.r#type; + value = block + .append_operation(llvm::extract_value( + context, + value, + DenseI64ArrayAttribute::new(context, &[field_idx as i64]), + field_ty, + location, + )) + .result(0)? + .into(); + } + PathSegment::ArrayIndex(_) => todo!(), + } + } + + Ok(value) + } } fn compile_deref<'ctx, 'parent: 'ctx>( @@ -959,6 +1128,8 @@ fn compile_deref<'ctx, 'parent: 'ctx>( .into(); } + // todo: handle deref for struct fields + Ok(value) } @@ -981,5 +1152,7 @@ fn compile_asref<'ctx, 'parent: 'ctx>( panic!("can only take refs to non register values"); } + // todo: handle asref for struct fields + Ok(local.value) } diff --git a/crates/concrete_codegen_mlir/src/context.rs b/crates/concrete_codegen_mlir/src/context.rs index 70961fd..db65869 100644 --- a/crates/concrete_codegen_mlir/src/context.rs +++ b/crates/concrete_codegen_mlir/src/context.rs @@ -43,7 +43,7 @@ impl Context { super::codegen::compile_program(session, &self.melior_context, &melior_module, program)?; - let print_flags = OperationPrintingFlags::new().enable_debug_info(true, true); + let print_flags = OperationPrintingFlags::new().enable_debug_info(false, true); tracing::debug!( "MLIR Code before passes:\n{}", melior_module diff --git a/crates/concrete_codegen_mlir/src/lib.rs b/crates/concrete_codegen_mlir/src/lib.rs index 959807b..dfa9b1f 100644 --- a/crates/concrete_codegen_mlir/src/lib.rs +++ b/crates/concrete_codegen_mlir/src/lib.rs @@ -35,6 +35,7 @@ mod error; pub mod linker; mod module; mod pass_manager; +mod scope_context; /// Returns the object file path. pub fn compile(session: &Session, program: &Program) -> Result> { diff --git a/crates/concrete_codegen_mlir/src/scope_context.rs b/crates/concrete_codegen_mlir/src/scope_context.rs new file mode 100644 index 0000000..6d077c9 --- /dev/null +++ b/crates/concrete_codegen_mlir/src/scope_context.rs @@ -0,0 +1,299 @@ +use std::{collections::HashMap, error::Error}; + +use concrete_ast::{ + expressions::{Expression, PathSegment, ValueExpr}, + functions::FunctionDef, + structs::StructDecl, + types::TypeSpec, +}; +use concrete_check::ast_helper::ModuleInfo; +use melior::{ + dialect::llvm, + ir::{ + r#type::{IntegerType, MemRefType}, + Type, + }, + Context as MeliorContext, +}; + +use crate::codegen::LocalVar; + +#[derive(Debug, Clone)] +pub struct ScopeContext<'ctx, 'parent: 'ctx> { + pub locals: HashMap>, + pub function: Option, + pub imports: HashMap>, + pub module_info: &'parent ModuleInfo<'parent>, +} + +impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { + /// Returns the symbol name from a local name. + pub fn get_symbol_name(&self, local_name: &str) -> String { + if local_name == "main" { + return local_name.to_string(); + } + + if let Some(module) = self.imports.get(local_name) { + // a import + module.get_symbol_name(local_name) + } else { + let mut result = self.module_info.name.clone(); + + result.push_str("::"); + result.push_str(local_name); + + result + } + } + + pub fn get_function(&self, local_name: &str) -> Option<&FunctionDef> { + if let Some(module) = self.imports.get(local_name) { + // a import + module.functions.get(local_name).copied() + } else { + self.module_info.functions.get(local_name).copied() + } + } + + /// Returns the size in bytes for a type. + fn get_type_size(&self, type_info: &TypeSpec) -> Result> { + Ok(match type_info { + TypeSpec::Simple { name, .. } => match name.name.as_str() { + "u128" | "i128" => 16, + "u64" | "i64" => 8, + "u32" | "i32" => 4, + "u16" | "i16" => 2, + "u8" | "i8" => 1, + "f64" => 8, + "f32" => 4, + "bool" => 1, + name => { + if let Some(x) = self.module_info.structs.get(name) { + let mut size = 0u32; + + for field in &x.fields { + let ty_size = self.get_type_size(&field.r#type)?; + let ty_align = self.get_align_for_size(ty_size); + + // Calculate padding needed. + let size_rounded_up = size.wrapping_add(ty_align).wrapping_sub(1) + & !ty_align.wrapping_sub(1u32); + let pad = size_rounded_up.wrapping_sub(size); + + size += pad; + size += ty_size; + } + + let struct_align = self.get_align_for_size(size); + let size_rounded_up = size.wrapping_add(struct_align).wrapping_sub(1) + & !struct_align.wrapping_sub(1u32); + let pad = size_rounded_up.wrapping_sub(size); + + size + pad + } else if let Some(module) = self.imports.get(name) { + // a import + self.get_type_size( + &module.types.get(name).expect("failed to find type").value, + )? + } else { + self.get_type_size( + &self + .module_info + .types + .get(name) + .expect("failed to find type") + .value, + )? + } + } + }, + TypeSpec::Generic { .. } => todo!(), + TypeSpec::Array { of_type, size, .. } => { + self.get_type_size(of_type)? * size.unwrap_or(1u32) + } + }) + } + + fn get_align_for_size(&self, size: u32) -> u32 { + if size <= 1 { + 1 + } else if size <= 2 { + 2 + } else if size <= 4 { + 4 + } else { + 8 + } + } + + /// Returns the struct type along with the field indexes. + pub fn get_struct_type( + &self, + context: &'ctx MeliorContext, + strct: &StructDecl, + ) -> Result<(Type<'ctx>, HashMap), Box> { + let mut fields = Vec::with_capacity(strct.fields.len()); + + let mut field_indexes = HashMap::new(); + let mut size: u32 = 0; + + for field in &strct.fields { + let ty = self.resolve_type_spec(context, &field.r#type)?; + let ty_size = self.get_type_size(&field.r#type)?; + let ty_align = self.get_align_for_size(ty_size); + + // Calculate padding needed. + let size_rounded_up = + size.wrapping_add(ty_align).wrapping_sub(1) & !ty_align.wrapping_sub(1u32); + let pad = size_rounded_up.wrapping_sub(size); + + if pad > 0 { + fields.push(llvm::r#type::array( + IntegerType::new(context, 8).into(), + pad, + )); + } + + size += pad; + size += ty_size; + field_indexes.insert(field.name.name.clone(), fields.len()); + fields.push(ty); + } + + let struct_align = self.get_align_for_size(size); + + // Calculate padding needed for whole struct. + let size_rounded_up = + size.wrapping_add(struct_align).wrapping_sub(1) & !struct_align.wrapping_sub(1u32); + let pad = size_rounded_up.wrapping_sub(size); + + if pad > 0 { + fields.push(llvm::r#type::array( + IntegerType::new(context, 8).into(), + pad, + )); + } + + Ok(( + llvm::r#type::r#struct(context, &fields, false), + field_indexes, + )) + } + + pub fn resolve_type( + &self, + context: &'ctx MeliorContext, + name: &str, + ) -> Result, Box> { + Ok(match name { + "u64" | "i64" => IntegerType::new(context, 64).into(), + "u32" | "i32" => IntegerType::new(context, 32).into(), + "u16" | "i16" => IntegerType::new(context, 16).into(), + "u8" | "i8" => IntegerType::new(context, 8).into(), + "f32" => Type::float32(context), + "f64" => Type::float64(context), + "bool" => IntegerType::new(context, 1).into(), + name => { + if let Some(strct) = self.module_info.structs.get(name) { + self.get_struct_type(context, strct)?.0 + } else if let Some(module) = self.imports.get(name) { + // a import + self.resolve_type_spec( + context, + &module.types.get(name).expect("failed to find type").value, + )? + } else { + self.resolve_type_spec( + context, + &self + .module_info + .types + .get(name) + .expect("failed to find type") + .value, + )? + } + } + }) + } + + pub fn resolve_type_spec( + &self, + context: &'ctx MeliorContext, + spec: &TypeSpec, + ) -> Result, Box> { + match spec.is_ref() { + Some(_) => { + Ok( + MemRefType::new(self.resolve_type_spec_ref(context, spec)?, &[], None, None) + .into(), + ) + } + None => self.resolve_type_spec_ref(context, spec), + } + } + + /// Resolves the type this ref points to. + pub fn resolve_type_spec_ref( + &self, + context: &'ctx MeliorContext, + spec: &TypeSpec, + ) -> Result, Box> { + Ok(match spec { + TypeSpec::Simple { name, .. } => self.resolve_type(context, &name.name)?, + TypeSpec::Generic { name, .. } => self.resolve_type(context, &name.name)?, + TypeSpec::Array { .. } => { + todo!("implement arrays") + } + }) + } + + pub fn is_type_signed(&self, type_info: &TypeSpec) -> bool { + let signed = ["i8", "i16", "i32", "i64", "i128"]; + match type_info { + TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Array { .. } => unreachable!(), + } + } + + pub fn is_float(&self, type_info: &TypeSpec) -> bool { + let signed = ["f32", "f64"]; + match type_info { + TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Array { .. } => unreachable!(), + } + } + + pub fn get_expr_type(&self, exp: &Expression) -> Option { + match exp { + Expression::Value(value) => match value { + ValueExpr::Path(path) => { + let first = self.locals.get(&path.first.name)?; + + if path.extra.is_empty() { + Some(first.type_spec.clone()) + } else { + let mut current = &first.type_spec; + for extra in &path.extra { + match extra { + PathSegment::FieldAccess(ident) => { + let st = self.module_info.structs.get(&ident.name)?; + let field = st.fields.get(ident.name); + } + PathSegment::ArrayIndex(_) => todo!(), + } + } + } + } + _ => None, + }, + Expression::FnCall(_) => todo!(), + Expression::Match(_) => todo!(), + Expression::If(_) => todo!(), + Expression::UnaryOp(_, _) => todo!(), + Expression::BinaryOp(_, _, _) => todo!(), + } + } +} diff --git a/crates/concrete_driver/tests/programs.rs b/crates/concrete_driver/tests/programs.rs index fecc2f3..c41b6c2 100644 --- a/crates/concrete_driver/tests/programs.rs +++ b/crates/concrete_driver/tests/programs.rs @@ -158,3 +158,73 @@ fn test_reference() { let code = output.status.code().unwrap(); assert_eq!(code, 2); } + +#[test] +fn test_mut_reference() { + let source = r#" + mod Simple { + fn main(argc: i64) -> i64 { + let mut x: i64 = 2; + change(&mut x); + return x; + } + + fn change(a: &mut i64) { + *a = 4; + } + } + "#; + + let result = compile_program(source, "mut_ref", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 4); +} + +#[test] +fn test_structs() { + let source = r#" + mod Structs { + + struct Leaf { + x: i32, + y: i64, + } + struct Node { + a: Leaf, + b: Leaf, + } + + fn main() -> i32 { + let a: Leaf = Leaf { + x: 1, + y: 2, + }; + let b: Leaf = Leaf { + x: 1, + y: 2, + }; + let mut x: Node = Node { + a: a, + b: b, + }; + x.a.x = 2; + modify(&mut x); + return x.a.x + x.b.x; + } + + fn modify(node: &mut Node) { + node.a.x = 3; + node.b.x = 3; + } + } + + "#; + + let result = compile_program(source, "structs", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 6); +} diff --git a/crates/concrete_parser/src/grammar.lalrpop b/crates/concrete_parser/src/grammar.lalrpop index e9ffbd8..c6f7851 100644 --- a/crates/concrete_parser/src/grammar.lalrpop +++ b/crates/concrete_parser/src/grammar.lalrpop @@ -150,7 +150,7 @@ pub(crate) TypeSpec: ast::types::TypeSpec = { }, "[" )?> "]" => ast::types::TypeSpec::Array { of_type: Box::new(of_type), - size, + size: size.map(|x| x.try_into().expect("size is bigger than u32::MAX")), is_ref, span: Span::new(lo, hi), } @@ -237,6 +237,9 @@ pub(crate) ModuleDefItem: ast::modules::ModuleDefItem = { => { ast::modules::ModuleDefItem::Module(<>) }, + => { + ast::modules::ModuleDefItem::Struct(<>) + }, } // Constants @@ -255,6 +258,27 @@ pub(crate) ConstantDef: ast::constants::ConstantDef = { }, } +// - Structs + +pub(crate) Field: ast::structs::Field = { + ":" => ast::structs::Field { + name, + r#type: type_spec, + doc_string: None, + } +} + +pub(crate) Struct: ast::structs::StructDecl = { + "struct" + "{" > "}" => ast::structs::StructDecl { + doc_string: None, + is_pub: is_pub.is_some(), + name, + fields, + type_params: type_params.unwrap_or_else(Vec::new), + } +} + // -- Functions pub(crate) FunctionRetType: ast::types::TypeSpec = { @@ -449,7 +473,7 @@ pub(crate) Statement: ast::statements::Statement = { ";" => ast::statements::Statement::Let(<>), ";" => ast::statements::Statement::Assign(<>), ";" => ast::statements::Statement::FnCall(<>), - ";"? => ast::statements::Statement::Return(<>), + ";" => ast::statements::Statement::Return(<>), } pub(crate) LetStmt: ast::statements::LetStmt = { @@ -459,14 +483,38 @@ pub(crate) LetStmt: ast::statements::LetStmt = { name, r#type: target_type }, - value + value: ast::statements::LetValue::Expr(value) }, + "let" ":" "=" => ast::statements::LetStmt { + is_mutable: is_mutable.is_some(), + target: ast::statements::LetStmtTarget::Simple { + name, + r#type: target_type + }, + value: ast::statements::LetValue::StructConstruct(value) + }, +} + +pub(crate) StructConstruct: ast::statements::StructConstruct = { + "{" > "}" => ast::statements::StructConstruct { + name, + fields, + } +} + + +pub(crate) FieldConstruct: ast::statements::FieldConstruct = { + ":" => ast::statements::FieldConstruct { + name, + value + } } pub(crate) AssignStmt: ast::statements::AssignStmt = { - "=" => ast::statements::AssignStmt { + "=" => ast::statements::AssignStmt { target, - value + value, + is_deref: is_deref.is_some(), }, } diff --git a/crates/concrete_parser/src/lib.rs b/crates/concrete_parser/src/lib.rs index f6fcb16..2881a4a 100644 --- a/crates/concrete_parser/src/lib.rs +++ b/crates/concrete_parser/src/lib.rs @@ -59,7 +59,9 @@ mod ModuleName { x = x % 2; match x { - 0 -> return 2, + 0 -> { + return 2; + }, 1 -> { let y: u64 = x * 2; return y * 10; @@ -97,8 +99,12 @@ mod ModuleName { let source = r##"mod FactorialModule { pub fn factorial(x: u64) -> u64 { return match x { - 0 -> return 1, - n -> return n * factorial(n-1), + 0 -> { + return 1; + }, + n -> { + return n * factorial(n-1); + }, }; } }"##; diff --git a/examples/aabb.con b/examples/aabb.con new file mode 100644 index 0000000..40762a2 --- /dev/null +++ b/examples/aabb.con @@ -0,0 +1,20 @@ +mod RectCheck { + struct Point2D { + x: i64, + y: i64, + } + + struct Rect2D { + pos: Point2D, + size: Point2D, + } + + pub fn is_point_inbounds(rect: &Rect2D, point: &Point2D) -> bool { + if point.x >= rect.pos.x && point.x <= (rect.pos.x + rect.size.x) + && point.y >= rect.pos.y && point.y <= (rect.pos.y + rect.size.y) { + return true; + } else { + return false; + } + } +} diff --git a/examples/complex.con b/examples/complex.con new file mode 100644 index 0000000..e982712 --- /dev/null +++ b/examples/complex.con @@ -0,0 +1,34 @@ +mod Structs { + + struct Leaf { + x: i32, + y: i64, + } + struct Node { + a: Leaf, + b: Leaf, + } + + fn main() -> i32 { + let a: Leaf = Leaf { + x: 1, + y: 2, + }; + let b: Leaf = Leaf { + x: 1, + y: 2, + }; + let mut x: Node = Node { + a: a, + b: b, + }; + x.a.x = 2; + modify(&mut x); + return x.a.x + x.b.x; + } + + fn modify(node: &mut Node) { + node.a.x = 3; + node.b.x = 3; + } +} diff --git a/examples/mutborrow.con b/examples/mutborrow.con new file mode 100644 index 0000000..75a4fb2 --- /dev/null +++ b/examples/mutborrow.con @@ -0,0 +1,11 @@ +mod Simple { + fn main(argc: i64) -> i64 { + let mut x: i64 = 2; + change(&mut x); + return x; + } + + fn change(a: &mut i64) { + *a = 4; + } +} diff --git a/examples/structs.con b/examples/structs.con new file mode 100644 index 0000000..95fa1d5 --- /dev/null +++ b/examples/structs.con @@ -0,0 +1,25 @@ +mod Structs { + struct Node { + a: i32, + b: i32, + } + + fn main() -> i32 { + let mut x: Node = create_node(2, 4); + x.a = 100; + modify_node(&mut x); + return x.a + x.b; + } + + fn create_node(a: i32, b: i32) -> Node { + let x: Node = Node { + a: a, + b: b, + }; + return x; + } + + fn modify_node(x: &mut Node) { + x.a = 1; + } +}