From 856ffc7d0567928b0cb8968170c6942b734d8a04 Mon Sep 17 00:00:00 2001 From: Edgar Date: Mon, 22 Jan 2024 17:50:21 -0300 Subject: [PATCH] works --- crates/concrete_ast/src/expressions.rs | 3 - crates/concrete_ast/src/statements.rs | 1 - crates/concrete_codegen_mlir/src/codegen.rs | 119 ++++++++++++++++-- crates/concrete_codegen_mlir/src/context.rs | 2 +- .../src/scope_context.rs | 2 +- crates/concrete_parser/src/grammar.lalrpop | 2 +- examples/structs.con | 29 ++--- 7 files changed, 122 insertions(+), 36 deletions(-) diff --git a/crates/concrete_ast/src/expressions.rs b/crates/concrete_ast/src/expressions.rs index e016274..12f8752 100644 --- a/crates/concrete_ast/src/expressions.rs +++ b/crates/concrete_ast/src/expressions.rs @@ -1,9 +1,6 @@ -use std::collections::HashMap; - use crate::{ common::Ident, statements::Statement, - structs::Field, types::{RefType, TypeSpec}, }; diff --git a/crates/concrete_ast/src/statements.rs b/crates/concrete_ast/src/statements.rs index 4ef2fe6..f316fc5 100644 --- a/crates/concrete_ast/src/statements.rs +++ b/crates/concrete_ast/src/statements.rs @@ -1,7 +1,6 @@ use crate::{ common::Ident, expressions::{Expression, FnCallOp, IfExpr, MatchExpr, PathOp}, - structs::Field, types::TypeSpec, }; diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 6e7a817..40db12f 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -3,7 +3,8 @@ use std::{collections::HashMap, error::Error}; use bumpalo::Bump; use concrete_ast::{ expressions::{ - ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, ValueExpr, + ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, PathSegment, + ValueExpr, }, functions::FunctionDef, modules::{Module, ModuleDefItem}, @@ -15,11 +16,14 @@ use concrete_session::Session; use melior::{ dialect::{ arith::{self, CmpiPredicate}, - cf, func, memref, + cf, func, llvm, memref, }, ir::{ - attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute}, - r#type::{FunctionType, IntegerType, MemRefType}, + attribute::{ + DenseI64ArrayAttribute, FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, + TypeAttribute, + }, + r#type::{id, FunctionType, IntegerType, MemRefType}, Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Value, ValueLike, }, Context as MeliorContext, @@ -440,8 +444,62 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( value, Some(r#type), )?, - LetValue::StructConstruct(_) => { - todo!() + LetValue::StructConstruct(struct_construct) => { + let struct_decl = scope_ctx + .module_info + .structs + .get(&struct_construct.name.name) + .unwrap_or_else(|| { + panic!("failed to find struct {:?}", struct_construct.name.name) + }); + assert_eq!( + struct_construct.fields.len(), + struct_decl.fields.len(), + "struct field len mismatch" + ); + let (ty, field_indexes) = scope_ctx.get_struct_type(context, struct_decl)?; + let mut struct_value = block + .append_operation(llvm::undef(ty, location)) + .result(0)? + .into(); + let field_types: HashMap = struct_decl + .fields + .iter() + .map(|x| (x.name.name.clone(), &x.r#type)) + .collect(); + + for field in &struct_construct.fields { + let field_idx = field_indexes.get(&field.name.name).unwrap_or_else(|| { + panic!( + "failed to find field {:?} for struct {:?}", + field.name.name, struct_construct.name.name + ) + }); + + let field_ty = field_types.get(&field.name.name).expect("field not found"); + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &field.value, + Some(field_ty), + )?; + + struct_value = block + .append_operation(llvm::insert_value( + context, + struct_value, + DenseI64ArrayAttribute::new(context, &[(*field_idx) as i64]), + value, + location, + )) + .result(0)? + .into(); + } + + struct_value } }; let memref_type = MemRefType::new(value.r#type(), &[], None, None); @@ -806,7 +864,7 @@ fn compile_path_op<'ctx, 'parent: 'ctx>( let location = get_location(context, session, path.first.span.from); - let value = if local.alloca { + let mut value = if local.alloca { block .append_operation(memref::load(local.value, &[], location)) .result(0)? @@ -815,7 +873,52 @@ fn compile_path_op<'ctx, 'parent: 'ctx>( local.value }; - Ok(value) + if path.extra.is_empty() { + Ok(value) + } else { + let mut current_type_spec = &local.type_spec; + + for extra in &path.extra { + match extra { + PathSegment::FieldAccess(ident) => { + let (struct_decl, (_, field_indexes)) = match current_type_spec { + TypeSpec::Simple { name, .. } => { + let struct_decl = + scope_ctx.module_info.structs.get(&name.name).unwrap(); + ( + struct_decl, + scope_ctx.get_struct_type(context, struct_decl)?, + ) + } + _ => unreachable!(), + }; + + let field = struct_decl + .fields + .iter() + .find(|x| x.name.name == ident.name) + .unwrap(); + let field_idx = *field_indexes.get(&ident.name).unwrap(); + let field_ty = scope_ctx.resolve_type_spec(context, &field.r#type)?; + + current_type_spec = &field.r#type; + value = block + .append_operation(llvm::extract_value( + context, + value, + DenseI64ArrayAttribute::new(context, &[field_idx as i64]), + field_ty, + location, + )) + .result(0)? + .into(); + } + PathSegment::ArrayIndex(_) => todo!(), + } + } + + Ok(value) + } } fn compile_deref<'ctx, 'parent: 'ctx>( diff --git a/crates/concrete_codegen_mlir/src/context.rs b/crates/concrete_codegen_mlir/src/context.rs index 70961fd..db65869 100644 --- a/crates/concrete_codegen_mlir/src/context.rs +++ b/crates/concrete_codegen_mlir/src/context.rs @@ -43,7 +43,7 @@ impl Context { super::codegen::compile_program(session, &self.melior_context, &melior_module, program)?; - let print_flags = OperationPrintingFlags::new().enable_debug_info(true, true); + let print_flags = OperationPrintingFlags::new().enable_debug_info(false, true); tracing::debug!( "MLIR Code before passes:\n{}", melior_module diff --git a/crates/concrete_codegen_mlir/src/scope_context.rs b/crates/concrete_codegen_mlir/src/scope_context.rs index 876d35e..5f97eb1 100644 --- a/crates/concrete_codegen_mlir/src/scope_context.rs +++ b/crates/concrete_codegen_mlir/src/scope_context.rs @@ -121,7 +121,7 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { } /// Returns the struct type along with the field indexes. - fn get_struct_type( + pub fn get_struct_type( &self, context: &'ctx MeliorContext, strct: &StructDecl, diff --git a/crates/concrete_parser/src/grammar.lalrpop b/crates/concrete_parser/src/grammar.lalrpop index a6142ee..ae623e4 100644 --- a/crates/concrete_parser/src/grammar.lalrpop +++ b/crates/concrete_parser/src/grammar.lalrpop @@ -272,7 +272,7 @@ pub(crate) Struct: ast::structs::StructDecl = { doc_string: None, is_pub: is_pub.is_some(), name, - fields: fields, + fields, type_params: type_params.unwrap_or_else(Vec::new), } } diff --git a/examples/structs.con b/examples/structs.con index 218632f..254509a 100644 --- a/examples/structs.con +++ b/examples/structs.con @@ -1,32 +1,19 @@ -mod Fibonacci { - +mod Structs { struct Node { a: i32, b: i32, } - struct Node2 { - a: i32, - b: i64, - } - - struct Nod3 { - a: i64, - b: i32, - } - - struct Node5 { - a: i8, - b: i32, + fn main() -> i32 { + let x: Node = create_node(2, 4); + return x.a + x.b; } - fn main() -> i64 { - + fn create_node(a: i32, b: i32) -> Node { let x: Node = Node { - a: 2, - b: 2, + a: a, + b: b, }; - - return 1; + return x; } }