Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
edg-l committed Jan 15, 2024
1 parent c0f837d commit 8e70167
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 35 deletions.
71 changes: 37 additions & 34 deletions crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{cell::Cell, collections::HashMap, error::Error};
use std::{collections::HashMap, error::Error};

use bumpalo::Bump;
use concrete_ast::{
Expand Down Expand Up @@ -38,11 +38,29 @@ pub fn compile_program(
}

#[derive(Debug, Clone)]
pub struct LocalVar<'c, 'op> {
pub struct LocalVar<'ctx, 'parent: 'ctx> {
pub type_spec: TypeSpec,
// If it's none its on a register, otherwise allocated on the stack.
pub memref_type: Option<MemRefType<'c>>,
pub value: Value<'c, 'op>,
pub alloca: bool,
pub value: Value<'ctx, 'parent>,
}

impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> {
pub fn param(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self {
Self {
value,
type_spec,
alloca: false,
}
}

pub fn alloca(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self {
Self {
value,
type_spec,
alloca: true,
}
}
}

#[derive(Debug, Clone)]
Expand All @@ -54,17 +72,13 @@ struct CompilerContext<'c, 'this: 'c> {
struct BlockHelper<'ctx, 'this: 'ctx> {
region: &'this Region<'ctx>,
blocks_arena: &'this Bump,
last_block: Cell<&'this BlockRef<'ctx, 'this>>,
}

impl<'ctx, 'this> BlockHelper<'ctx, 'this> {
pub fn append_block(&self, block: Block<'ctx>) -> &'this BlockRef<'ctx, 'this> {
let block = self
.region
.insert_block_after(*self.last_block.get(), block);
let block = self.region.append_block(block);

let block_ref: &'this mut BlockRef<'ctx, 'this> = self.blocks_arena.alloc(block);
self.last_block.set(block_ref);

block_ref
}
Expand All @@ -81,6 +95,8 @@ impl<'c, 'this> CompilerContext<'c, 'this> {
"u32" | "i32" => IntegerType::new(context, 32).into(),
"u16" | "i16" => IntegerType::new(context, 16).into(),
"u8" | "i8" => IntegerType::new(context, 8).into(),
"f32" => Type::float32(context),
"f64" => Type::float64(context),
"bool" => IntegerType::new(context, 1).into(),
_ => todo!("custom type lookup"),
})
Expand Down Expand Up @@ -155,6 +171,7 @@ fn compile_function_def<'c, 'this: 'c>(
compiler_ctx: &mut CompilerContext<'c, 'this>,
info: &FunctionDef,
) -> Result<Operation<'c>, Box<dyn Error>> {
tracing::debug!("compiling function {:?}", info.decl.name.name);
let location = get_location(context, session, &info.decl.name.span);

// Setup function arguments
Expand Down Expand Up @@ -188,18 +205,13 @@ fn compile_function_def<'c, 'this: 'c>(
let helper = BlockHelper {
region: &region,
blocks_arena: &blocks_arena,
last_block: Cell::new(fn_block),
};

// Push arguments into locals
for (i, param) in info.decl.params.iter().enumerate() {
fn_compiler_ctx.locals.insert(
param.name.name.clone(),
LocalVar {
type_spec: param.r#type.clone(),
value: fn_block.argument(i)?.into(),
memref_type: None,
},
LocalVar::param(fn_block.argument(i)?.into(), param.r#type.clone()),
);
}

Expand Down Expand Up @@ -443,14 +455,9 @@ fn compile_let_stmt<'c, 'this: 'c>(
.into();
block.append_operation(memref::store(value, alloca, &[k0], location));

compiler_ctx.locals.insert(
name.name.clone(),
LocalVar {
type_spec: r#type.clone(),
memref_type: Some(memref_type),
value: alloca,
},
);
compiler_ctx
.locals
.insert(name.name.clone(), LocalVar::alloca(alloca, r#type.clone()));

Ok(())
}
Expand All @@ -474,10 +481,7 @@ fn compile_assign_stmt<'c, 'this: 'c>(
.expect("local should exist")
.clone();

assert!(
local.memref_type.is_some(),
"can only mutate local stack variables"
);
assert!(local.alloca, "can only mutate local stack variables");

let location = get_location(context, session, &info.target.first.span);

Expand Down Expand Up @@ -529,7 +533,7 @@ fn compile_expression<'c, 'this: 'c>(
session: &Session,
context: &'c MeliorContext,
compiler_ctx: &mut CompilerContext<'c, 'this>,
helper: &BlockHelper<'c, 'this>,
_helper: &BlockHelper<'c, 'this>,
block: &'this Block<'c>,
info: &Expression,
type_info: Option<&TypeSpec>,
Expand Down Expand Up @@ -568,7 +572,7 @@ fn compile_expression<'c, 'this: 'c>(
SimpleExpr::ConstFloat(_) => todo!(),
SimpleExpr::ConstStr(_) => todo!(),
SimpleExpr::Path(value) => {
compile_path_op(session, context, compiler_ctx, helper, block, value)
compile_path_op(session, context, compiler_ctx, block, value)
}
},
Expression::FnCall(value) => {
Expand All @@ -592,7 +596,7 @@ fn compile_expression<'c, 'this: 'c>(
session,
context,
compiler_ctx,
helper,
_helper,
block,
arg,
Some(&arg_info.r#type),
Expand Down Expand Up @@ -625,7 +629,7 @@ fn compile_expression<'c, 'this: 'c>(
session,
context,
compiler_ctx,
helper,
_helper,
block,
lhs,
type_info,
Expand All @@ -634,7 +638,7 @@ fn compile_expression<'c, 'this: 'c>(
session,
context,
compiler_ctx,
helper,
_helper,
block,
rhs,
type_info,
Expand Down Expand Up @@ -737,7 +741,6 @@ fn compile_path_op<'c, 'this: 'c>(
session: &Session,
context: &'c MeliorContext,
compiler_ctx: &mut CompilerContext<'c, 'this>,
helper: &BlockHelper<'c, 'this>,
block: &'this Block<'c>,
path: &PathOp,
) -> Result<Value<'c, 'this>, Box<dyn Error>> {
Expand All @@ -751,7 +754,7 @@ fn compile_path_op<'c, 'this: 'c>(

let location = get_location(context, session, &path.first.span);

if let Some(_memref_type) = local.memref_type {
if local.alloca {
let k0 = block
.append_operation(arith::constant(
context,
Expand Down
2 changes: 1 addition & 1 deletion crates/concrete_codegen_mlir/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::error::Error;

use concrete_ast::Program;
use concrete_session::{config::DebugInfo, Session};
use concrete_session::Session;
use melior::{
dialect::DialectRegistry,
ir::{operation::OperationPrintingFlags, Location, Module as MeliorModule},
Expand Down

0 comments on commit 8e70167

Please sign in to comment.