Skip to content

Commit

Permalink
Merge pull request #68 from lambdaclass/locals_use_allocas
Browse files Browse the repository at this point in the history
use allocas for variables
  • Loading branch information
unbalancedparentheses authored Jan 10, 2024
2 parents 538a911 + f29f7d3 commit a55b079
Showing 1 changed file with 80 additions and 17 deletions.
97 changes: 80 additions & 17 deletions crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ use concrete_session::Session;
use melior::{
dialect::{
arith::{self, CmpiPredicate},
func,
func, memref,
},
ir::{
attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute},
r#type::{FunctionType, IntegerType},
Block, Location, Module as MeliorModule, Region, Type, Value,
r#type::{FunctionType, IntegerType, MemRefType},
Block, Location, Module as MeliorModule, Region, Type, Value, ValueLike,
},
Context as MeliorContext,
};
Expand All @@ -38,6 +38,8 @@ pub fn compile_program(
#[derive(Debug, Clone)]
pub struct LocalVar<'c, 'op> {
pub type_spec: TypeSpec,
// If it's none its on a register, otherwise allocated on the stack.
pub memref_type: Option<MemRefType<'c>>,
pub value: Value<'c, 'op>,
}

Expand Down Expand Up @@ -152,6 +154,7 @@ fn compile_function_def<'c, 'op>(
LocalVar {
type_spec: param.r#type.clone(),
value: fn_block.argument(i)?.into(),
memref_type: None,
},
);
}
Expand Down Expand Up @@ -216,11 +219,37 @@ fn compile_let_stmt<'c, 'op>(
Some(r#type),
)?;

let location = get_location(context, session, &name.span);

let memref_type = MemRefType::new(value.r#type(), &[1], None, None);

let alloca: Value = block
.append_operation(memref::alloca(
context,
memref_type,
&[],
&[],
None,
location,
))
.result(0)?
.into();
let k0 = block
.append_operation(arith::constant(
context,
IntegerAttribute::new(0, Type::index(context)).into(),
location,
))
.result(0)?
.into();
block.append_operation(memref::store(value, alloca, &[k0], location));

compiler_ctx.locals.insert(
name.name.clone(),
LocalVar {
type_spec: r#type.clone(),
value,
memref_type: Some(memref_type),
value: alloca,
},
);

Expand All @@ -244,6 +273,14 @@ fn compile_assign_stmt<'c, 'op>(
.get(&info.target.first.name)
.expect("local should exist")
.clone();

assert!(
local.memref_type.is_some(),
"can only mutate local stack variables"
);

let location = get_location(context, session, &info.target.first.span);

let value = compile_expression(
session,
context,
Expand All @@ -252,13 +289,16 @@ fn compile_assign_stmt<'c, 'op>(
&info.value,
Some(&local.type_spec),
)?;
compiler_ctx.locals.insert(
info.target.first.name.clone(),
LocalVar {
type_spec: local.type_spec,
value,
},
);

let k0 = block
.append_operation(arith::constant(
context,
IntegerAttribute::new(0, Type::index(context)).into(),
location,
))
.result(0)?
.into();
block.append_operation(memref::store(value, local.value, &[k0], location));

Ok(())
}
Expand Down Expand Up @@ -316,7 +356,9 @@ fn compile_expression<'c, 'op>(
}
SimpleExpr::ConstFloat(_) => todo!(),
SimpleExpr::ConstStr(_) => todo!(),
SimpleExpr::Path(value) => compile_path_op(context, compiler_ctx, block, value),
SimpleExpr::Path(value) => {
compile_path_op(session, context, compiler_ctx, block, value)
}
},
Expression::FnCall(value) => {
let mut args = Vec::with_capacity(value.args.len());
Expand Down Expand Up @@ -464,16 +506,37 @@ fn compile_expression<'c, 'op>(
}

fn compile_path_op<'c, 'op>(
_context: &MeliorContext,
session: &Session,
context: &'c MeliorContext,
compiler_ctx: &mut CompilerContext<'c, 'op>,
_block: &Block<'c>,
block: &'op Block<'c>,
path: &PathOp,
) -> Result<Value<'c, 'op>, Box<dyn Error>> {
// For now only simple variables work.
// TODO: implement properly, this requires having structs implemented.
Ok(compiler_ctx

let local = compiler_ctx
.locals
.get(&path.first.name)
.map(|x| x.value)
.expect("local not found"))
.expect("local not found");

let location = get_location(context, session, &path.first.span);

if let Some(_memref_type) = local.memref_type {
let k0 = block
.append_operation(arith::constant(
context,
IntegerAttribute::new(0, Type::index(context)).into(),
location,
))
.result(0)?
.into();
let value = block
.append_operation(memref::load(local.value, &[k0], location))
.result(0)?
.into();
Ok(value)
} else {
Ok(local.value)
}
}

0 comments on commit a55b079

Please sign in to comment.