diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 027d549..aa2c83f 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -25,9 +25,10 @@ use melior::{ DenseI64ArrayAttribute, FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute, }, + operation::OperationBuilder, r#type::{FunctionType, IntegerType, MemRefType}, - Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, Value, - ValueLike, + Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, TypeLike, + Value, ValueLike, }, Context as MeliorContext, }; @@ -152,7 +153,12 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { spec: &TypeSpec, ) -> Result, Box> { match spec.is_ref() { - Some(_) => Ok(llvm::r#type::opaque_pointer(context)), + Some(_) => { + Ok( + MemRefType::new(self.resolve_type_spec_ref(context, spec)?, &[], None, None) + .into(), + ) + } None => self.resolve_type_spec_ref(context, spec), } } @@ -168,8 +174,8 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { TypeSpec::Generic { name, .. } => self.resolve_type(context, &name.name)?, TypeSpec::Array { of_type, - span, - is_ref, + span: _, + is_ref: _, size, } => match size { Some(size) => { @@ -185,19 +191,16 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { None => { // let inner_type = self.resolve_type_spec(context, of_type)?; - Type::parse(context, &format!("memref")).unwrap() - // MemRefType::new(inner_type, &[(u32::MAX) as u64], None, None).into() - /* + // Type::parse(context, &format!("memref")).unwrap() + llvm::r#type::r#struct( context, &[ - llvm::r#type::opaque_pointer(context), - IntegerType::new(context, 64).into(), + Type::parse(context, &format!("memref")).unwrap(), IntegerType::new(context, 64).into(), ], false, ) - */ } }, }) @@ -555,6 +558,17 @@ fn compile_while<'c, 'this: 'c>( Ok(merge_block) } +fn is_local_copy(a: &Expression) -> Option<(&PathOp, bool)> { + match a { + Expression::Value(value) => match value { + ValueExpr::Path(path) => Some((path, false)), + ValueExpr::Deref(path) => Some((path, true)), + _ => None, + }, + _ => None, + } +} + fn compile_let_stmt<'ctx, 'parent: 'ctx>( session: &Session, context: &'ctx MeliorContext, @@ -565,6 +579,8 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( ) -> Result<(), Box> { match &info.target { LetStmtTarget::Simple { name, r#type } => { + let location = get_location(context, session, name.span.from); + let value = compile_expression( session, context, @@ -574,10 +590,7 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( &info.value, Some(r#type), )?; - - let location = get_location(context, session, name.span.from); - - let memref_type = MemRefType::new(value.r#type(), &[1], None, None); + let memref_type = MemRefType::new(value.r#type(), &[], None, None); let alloca: Value = block .append_operation(memref::alloca( @@ -598,6 +611,7 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( )) .result(0)? .into(); + block.append_operation(memref::store(value, alloca, &[k0], location)); scope_ctx @@ -640,15 +654,67 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( Some(&local.type_spec), )?; - let k0 = block - .append_operation(arith::constant( - context, - IntegerAttribute::new(0, Type::index(context)).into(), - location, - )) - .result(0)? - .into(); - block.append_operation(memref::store(value, local.value, &[k0], location)); + if info.target.extra.is_empty() { + block.append_operation(memref::store(value, local.value, &[], location)); + } else { + let mut store_target = block + .append_operation(memref::load(local.value, &[], location)) + .result(0)? + .into(); + + let mut segment_iter = info.target.extra.iter().peekable(); + + while let Some(segment) = segment_iter.next() { + match segment { + PathSegment::FieldAccess(_) => todo!(), + PathSegment::ArrayIndex(index) => { + let index = compile_value_expr( + session, context, scope_ctx, helper, block, index, None, + )?; + let index_ty = Type::index(context); + let index = block + .append_operation(melior::dialect::index::castu(index, index_ty, location)) + .result(0)? + .into(); + + if let TypeSpec::Array { + of_type: _, + size, + is_ref: _, + span, + } = &local.type_spec + { + let location = get_location(context, session, span.from); + + #[allow(clippy::if_same_then_else)] + if size.is_some() { + // todo: check inbounds? + store_target = block + .append_operation(memref::load(store_target, &[index], location)) + .result(0)? + .into(); + } else { + store_target = block + .append_operation(memref::load(store_target, &[index], location)) + .result(0)? + .into(); + } + + if segment_iter.peek().is_none() { + block.append_operation(memref::store( + value, + store_target, + &[index], + location, + )); + } + } else { + panic!("type should be a array when indexing a value"); + } + } + } + } + } Ok(()) } @@ -949,16 +1015,8 @@ fn compile_path_op<'ctx, 'parent: 'ctx>( let location = get_location(context, session, path.first.span.from); let mut value = if local.alloca { - let k0 = block - .append_operation(arith::constant( - context, - IntegerAttribute::new(0, Type::index(context)).into(), - location, - )) - .result(0)? - .into(); block - .append_operation(memref::load(local.value, &[k0], location)) + .append_operation(memref::load(local.value, &[], location)) .result(0)? .into() } else { @@ -978,86 +1036,25 @@ fn compile_path_op<'ctx, 'parent: 'ctx>( .into(); if let TypeSpec::Array { - of_type, + of_type: _, size, is_ref: _, span, } = &local.type_spec { let location = get_location(context, session, span.from); - let inner_type = scope_ctx.resolve_type_spec(context, of_type)?; #[allow(clippy::if_same_then_else)] if size.is_some() { + // todo: check inbounds? value = block .append_operation(memref::load(value, &[index], location)) .result(0)? .into(); - /* - // its a llvm.array - let ptr = block - .append_operation(llvm::get_element_ptr_dynamic( - context, - value, - &[index], - inner_type, - opaque_pointer(context), - location, - )) - .result(0)? - .into(); - - value = block - .append_operation(llvm::load( - context, - ptr, - inner_type, - location, - LoadStoreOptions::new(), - )) - .result(0)? - .into(); - */ } else { value = block .append_operation(memref::load(value, &[index], location)) .result(0)? .into(); - // its a struct {ptr,u64,u64} - /* - let ptr = block - .append_operation(llvm::extract_value( - context, - value, - DenseI64ArrayAttribute::new(context, &[0]), - opaque_pointer(context), - location, - )) - .result(0)? - .into(); - - let elem_ptr = block - .append_operation(llvm::get_element_ptr_dynamic( - context, - ptr, - &[index], - inner_type, - opaque_pointer(context), - location, - )) - .result(0)? - .into(); - - value = block - .append_operation(llvm::load( - context, - elem_ptr, - inner_type, - location, - LoadStoreOptions::new(), - )) - .result(0)? - .into(); - */ } } else { panic!("type should be a array when indexing a value"); @@ -1087,15 +1084,8 @@ fn compile_deref<'ctx, 'parent: 'ctx>( let inner_type = scope_ctx.resolve_type_spec_ref(context, &local.type_spec)?; let mut value = block - .append_operation(llvm::load( - context, - local.value, - inner_type, - location, - LoadStoreOptions::new(), - )) - .result(0)? - .into(); + .append_operation(memref::load(local.value, &[], location)).result(0)?.into(); + for segment in &path.extra { match segment { @@ -1103,73 +1093,30 @@ fn compile_deref<'ctx, 'parent: 'ctx>( PathSegment::ArrayIndex(index) => { let index = compile_value_expr(session, context, scope_ctx, helper, block, index, None)?; + let index_ty = Type::index(context); + let index = block + .append_operation(melior::dialect::index::castu(index, index_ty, location)) + .result(0)? + .into(); if let TypeSpec::Array { - of_type, + of_type: _, size, is_ref: _, span, } = &local.type_spec { let location = get_location(context, session, span.from); - let inner_type = scope_ctx.resolve_type_spec(context, of_type)?; + #[allow(clippy::if_same_then_else)] if size.is_some() { - // its a llvm.array - let ptr = block - .append_operation(llvm::get_element_ptr_dynamic( - context, - value, - &[index], - inner_type, - opaque_pointer(context), - location, - )) - .result(0)? - .into(); - + // todo: check inbounds? value = block - .append_operation(llvm::load( - context, - ptr, - inner_type, - location, - LoadStoreOptions::new(), - )) + .append_operation(memref::load(value, &[index], location)) .result(0)? .into(); } else { - // its a struct {ptr,u64,u64} - let ptr = block - .append_operation(llvm::extract_value( - context, - value, - DenseI64ArrayAttribute::new(context, &[0]), - opaque_pointer(context), - location, - )) - .result(0)? - .into(); - - let elem_ptr = block - .append_operation(llvm::get_element_ptr_dynamic( - context, - ptr, - &[index], - inner_type, - opaque_pointer(context), - location, - )) - .result(0)? - .into(); - value = block - .append_operation(llvm::load( - context, - elem_ptr, - inner_type, - location, - LoadStoreOptions::new(), - )) + .append_operation(memref::load(value, &[index], location)) .result(0)? .into(); } diff --git a/examples/borrow.con b/examples/borrow.con index dd56175..b6dfca8 100644 --- a/examples/borrow.con +++ b/examples/borrow.con @@ -3,15 +3,7 @@ mod Simple { return argc; } - fn hello(a: &i64) -> i64 { + fn dereference(a: &i64) -> i64 { return *a; } - - fn hello2(a: [i64]) -> i64 { - return a[0]; - } - - fn hello3(a: [i64; 4]) -> i64 { - return a[0]; - } }