Skip to content

Commit

Permalink
works
Browse files Browse the repository at this point in the history
  • Loading branch information
edg-l committed Jan 22, 2024
1 parent b7f7014 commit 856ffc7
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 36 deletions.
3 changes: 0 additions & 3 deletions crates/concrete_ast/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use std::collections::HashMap;

use crate::{
common::Ident,
statements::Statement,
structs::Field,
types::{RefType, TypeSpec},
};

Expand Down
1 change: 0 additions & 1 deletion crates/concrete_ast/src/statements.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::{
common::Ident,
expressions::{Expression, FnCallOp, IfExpr, MatchExpr, PathOp},
structs::Field,
types::TypeSpec,
};

Expand Down
119 changes: 111 additions & 8 deletions crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use std::{collections::HashMap, error::Error};
use bumpalo::Bump;
use concrete_ast::{
expressions::{
ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, ValueExpr,
ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, PathSegment,
ValueExpr,
},
functions::FunctionDef,
modules::{Module, ModuleDefItem},
Expand All @@ -15,11 +16,14 @@ use concrete_session::Session;
use melior::{
dialect::{
arith::{self, CmpiPredicate},
cf, func, memref,
cf, func, llvm, memref,
},
ir::{
attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute},
r#type::{FunctionType, IntegerType, MemRefType},
attribute::{
DenseI64ArrayAttribute, FlatSymbolRefAttribute, IntegerAttribute, StringAttribute,
TypeAttribute,
},
r#type::{id, FunctionType, IntegerType, MemRefType},
Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Value, ValueLike,
},
Context as MeliorContext,
Expand Down Expand Up @@ -440,8 +444,62 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>(
value,
Some(r#type),
)?,
LetValue::StructConstruct(_) => {
todo!()
LetValue::StructConstruct(struct_construct) => {
let struct_decl = scope_ctx
.module_info
.structs
.get(&struct_construct.name.name)
.unwrap_or_else(|| {
panic!("failed to find struct {:?}", struct_construct.name.name)
});
assert_eq!(
struct_construct.fields.len(),
struct_decl.fields.len(),
"struct field len mismatch"
);
let (ty, field_indexes) = scope_ctx.get_struct_type(context, struct_decl)?;
let mut struct_value = block
.append_operation(llvm::undef(ty, location))
.result(0)?
.into();
let field_types: HashMap<String, &TypeSpec> = struct_decl
.fields
.iter()
.map(|x| (x.name.name.clone(), &x.r#type))
.collect();

for field in &struct_construct.fields {
let field_idx = field_indexes.get(&field.name.name).unwrap_or_else(|| {
panic!(
"failed to find field {:?} for struct {:?}",
field.name.name, struct_construct.name.name
)
});

let field_ty = field_types.get(&field.name.name).expect("field not found");
let value = compile_expression(
session,
context,
scope_ctx,
helper,
block,
&field.value,
Some(field_ty),
)?;

struct_value = block
.append_operation(llvm::insert_value(
context,
struct_value,
DenseI64ArrayAttribute::new(context, &[(*field_idx) as i64]),
value,
location,
))
.result(0)?
.into();
}

struct_value
}
};
let memref_type = MemRefType::new(value.r#type(), &[], None, None);
Expand Down Expand Up @@ -806,7 +864,7 @@ fn compile_path_op<'ctx, 'parent: 'ctx>(

let location = get_location(context, session, path.first.span.from);

let value = if local.alloca {
let mut value = if local.alloca {
block
.append_operation(memref::load(local.value, &[], location))
.result(0)?
Expand All @@ -815,7 +873,52 @@ fn compile_path_op<'ctx, 'parent: 'ctx>(
local.value
};

Ok(value)
if path.extra.is_empty() {
Ok(value)
} else {
let mut current_type_spec = &local.type_spec;

for extra in &path.extra {
match extra {
PathSegment::FieldAccess(ident) => {
let (struct_decl, (_, field_indexes)) = match current_type_spec {
TypeSpec::Simple { name, .. } => {
let struct_decl =
scope_ctx.module_info.structs.get(&name.name).unwrap();
(
struct_decl,
scope_ctx.get_struct_type(context, struct_decl)?,
)
}
_ => unreachable!(),
};

let field = struct_decl
.fields
.iter()
.find(|x| x.name.name == ident.name)
.unwrap();
let field_idx = *field_indexes.get(&ident.name).unwrap();
let field_ty = scope_ctx.resolve_type_spec(context, &field.r#type)?;

current_type_spec = &field.r#type;
value = block
.append_operation(llvm::extract_value(
context,
value,
DenseI64ArrayAttribute::new(context, &[field_idx as i64]),
field_ty,
location,
))
.result(0)?
.into();
}
PathSegment::ArrayIndex(_) => todo!(),
}
}

Ok(value)
}
}

fn compile_deref<'ctx, 'parent: 'ctx>(
Expand Down
2 changes: 1 addition & 1 deletion crates/concrete_codegen_mlir/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl Context {

super::codegen::compile_program(session, &self.melior_context, &melior_module, program)?;

let print_flags = OperationPrintingFlags::new().enable_debug_info(true, true);
let print_flags = OperationPrintingFlags::new().enable_debug_info(false, true);
tracing::debug!(
"MLIR Code before passes:\n{}",
melior_module
Expand Down
2 changes: 1 addition & 1 deletion crates/concrete_codegen_mlir/src/scope_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> {
}

/// Returns the struct type along with the field indexes.
fn get_struct_type(
pub fn get_struct_type(
&self,
context: &'ctx MeliorContext,
strct: &StructDecl,
Expand Down
2 changes: 1 addition & 1 deletion crates/concrete_parser/src/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ pub(crate) Struct: ast::structs::StructDecl = {
doc_string: None,
is_pub: is_pub.is_some(),
name,
fields: fields,
fields,
type_params: type_params.unwrap_or_else(Vec::new),
}
}
Expand Down
29 changes: 8 additions & 21 deletions examples/structs.con
Original file line number Diff line number Diff line change
@@ -1,32 +1,19 @@
mod Fibonacci {

mod Structs {
struct Node {
a: i32,
b: i32,
}

struct Node2 {
a: i32,
b: i64,
}

struct Nod3 {
a: i64,
b: i32,
}

struct Node5 {
a: i8,
b: i32,
fn main() -> i32 {
let x: Node = create_node(2, 4);
return x.a + x.b;
}

fn main() -> i64 {

fn create_node(a: i32, b: i32) -> Node {
let x: Node = Node {
a: 2,
b: 2,
a: a,
b: b,
};

return 1;
return x;
}
}

0 comments on commit 856ffc7

Please sign in to comment.