diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index c0991ec..2d20d5d 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -16,15 +16,18 @@ use concrete_session::Session; use melior::{ dialect::{ arith::{self, CmpiPredicate}, - cf, func, llvm, memref, + cf, func, + llvm::{self, r#type::opaque_pointer, LoadStoreOptions}, + memref, }, ir::{ attribute::{ - DenseI64ArrayAttribute, FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, - TypeAttribute, + DenseI32ArrayAttribute, DenseI64ArrayAttribute, FlatSymbolRefAttribute, + IntegerAttribute, StringAttribute, TypeAttribute, }, r#type::{FunctionType, IntegerType, MemRefType}, - Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Value, ValueLike, + Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, Value, + ValueLike, }, Context as MeliorContext, }; @@ -548,17 +551,119 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( let location = get_location(context, session, info.target.first.span.from); - let value = compile_expression( - session, - context, - scope_ctx, - helper, - block, - &info.value, - Some(&local.type_spec), - )?; + if info.target.extra.is_empty() { + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &info.value, + Some(&local.type_spec), + )?; + + block.append_operation(memref::store(value, local.value, &[], location)); + } else { + let mut current_type_spec = &local.type_spec; + + // todo: instead of loading, use memref.extract_aligned_pointer_as_index + + let target_ptr = block + .append_operation( + melior::dialect::ods::memref::extract_aligned_pointer_as_index( + context, + Type::index(context), + local.value, + location, + ) + .into(), + ) + .result(0)? + .into(); + + let target_ptr = block + .append_operation(arith::index_cast( + target_ptr, + IntegerType::new(context, 64).into(), + location, + )) + .result(0)? + .into(); + + let mut target_ptr = block + .append_operation( + melior::dialect::ods::llvm::inttoptr( + context, + opaque_pointer(context), + target_ptr, + location, + ) + .into(), + ) + .result(0)? + .into(); + + let mut extra_it = info.target.extra.iter().peekable(); + + while let Some(extra) = extra_it.next() { + match extra { + PathSegment::FieldAccess(ident) => { + let (struct_decl, (struct_ty, field_indexes)) = match current_type_spec { + TypeSpec::Simple { name, .. } => { + let struct_decl = + scope_ctx.module_info.structs.get(&name.name).unwrap(); + ( + struct_decl, + scope_ctx.get_struct_type(context, struct_decl)?, + ) + } + _ => unreachable!(), + }; + + let field = struct_decl + .fields + .iter() + .find(|x| x.name.name == ident.name) + .unwrap(); + let field_idx = *field_indexes.get(&ident.name).unwrap(); - block.append_operation(memref::store(value, local.value, &[], location)); + current_type_spec = &field.r#type; + target_ptr = block + .append_operation(llvm::get_element_ptr( + context, + target_ptr, + DenseI32ArrayAttribute::new(context, &[field_idx as i32]), + struct_ty, + opaque_pointer(context), + location, + )) + .result(0)? + .into(); + + if extra_it.peek().is_none() { + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &info.value, + Some(current_type_spec), + )?; + + block.append_operation(llvm::store( + context, + value, + target_ptr, + location, + LoadStoreOptions::default(), + )); + } + } + PathSegment::ArrayIndex(_) => todo!(), + } + } + } Ok(()) } diff --git a/examples/structs.con b/examples/structs.con index 254509a..a31c069 100644 --- a/examples/structs.con +++ b/examples/structs.con @@ -5,7 +5,8 @@ mod Structs { } fn main() -> i32 { - let x: Node = create_node(2, 4); + let mut x: Node = create_node(2, 4); + x.a = 5; return x.a + x.b; }