Skip to content

Commit

Permalink
compiles
Browse files Browse the repository at this point in the history
  • Loading branch information
edg-l committed Feb 13, 2024
1 parent 3dc5af8 commit e9a2181
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 67 deletions.
120 changes: 55 additions & 65 deletions crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::{collections::HashMap, error::Error};

use concrete_ir::{
BinOp, DefId, FnBody, LocalKind, ModuleBody, Operand, Place, ProgramBody, Rvalue, Ty, TyKind,
ValueTree,
BinOp, DefId, FnBody, LocalKind, ModuleBody, Operand, Place, ProgramBody, Rvalue, Span, Ty,
TyKind, ValueTree,
};
use concrete_session::Session;
use melior::{
dialect::{arith, cf, func, llvm, memref},
ir::{
attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute},
attribute::{FlatSymbolRefAttribute, StringAttribute, TypeAttribute},
r#type::{FunctionType, IntegerType, MemRefType},
Block, Location, Module as MeliorModule, Region, Type, Value, ValueLike,
Attribute, Block, Location, Module as MeliorModule, Region, Type, Value,
},
Context as MeliorContext,
};
Expand All @@ -37,6 +37,26 @@ impl<'a> ModuleCodegenCtx<'a> {
.get(&self.module_id)
.expect("module should exist")
}

pub fn get_location(&self, span: Option<Span>) -> Location {
if let Some(span) = span {
let (_, line, col) = self.ctx.session.source.get_offset_line(span.from).unwrap();
Location::new(
self.ctx.mlir_context,
self.ctx
.session
.file_path
.file_name()
.unwrap()
.to_str()
.unwrap(),
line + 1,
col + 1,
)
} else {
Location::unknown(self.ctx.mlir_context)
}
}
}

pub fn compile_program(ctx: CodegenCtx) -> Result<(), Box<dyn Error>> {
Expand Down Expand Up @@ -95,10 +115,8 @@ impl<'a> FunctionCodegenCtx<'a> {
}

fn compile_function(ctx: FunctionCodegenCtx) -> Result<(), Box<dyn std::error::Error>> {
let module = ctx.module_ctx.get_module_body();
let body = ctx.get_fn_body();
let body_sig = ctx.get_fn_sig();
let (param_types, ret_type) = ctx.get_fn_sig();

let region = Region::new();

Expand All @@ -109,7 +127,7 @@ fn compile_function(ctx: FunctionCodegenCtx) -> Result<(), Box<dyn std::error::E
.map(|x| {
(
compile_type(ctx.module_ctx, &x.ty),
Location::unknown(ctx.context()),
ctx.module_ctx.get_location(x.span),
)
})
.collect();
Expand Down Expand Up @@ -289,10 +307,7 @@ fn compile_function(ctx: FunctionCodegenCtx) -> Result<(), Box<dyn std::error::E
));
}
}
concrete_ir::TerminatorKind::SwitchInt {
discriminator,
targets,
} => todo!(),
concrete_ir::TerminatorKind::SwitchInt { .. } => todo!(),
}
}
}
Expand Down Expand Up @@ -518,11 +533,7 @@ fn compile_value_tree<'c: 'b, 'b>(
concrete_ir::ConstValue::Bool(value) => block
.append_operation(arith::constant(
ctx.context(),
IntegerAttribute::new(
(*value) as i64,
IntegerType::new(ctx.context(), 1).into(),
)
.into(),
Attribute::parse(ctx.context(), &format!("{} : i1", (*value) as u8)).unwrap(),
Location::unknown(ctx.context()),
))
.result(0)
Expand All @@ -531,11 +542,7 @@ fn compile_value_tree<'c: 'b, 'b>(
concrete_ir::ConstValue::I8(value) => block
.append_operation(arith::constant(
ctx.context(),
IntegerAttribute::new(
(*value) as i64,
IntegerType::new(ctx.context(), 8).into(),
)
.into(),
Attribute::parse(ctx.context(), &format!("{} : i8", value)).unwrap(),
Location::unknown(ctx.context()),
))
.result(0)
Expand All @@ -544,11 +551,7 @@ fn compile_value_tree<'c: 'b, 'b>(
concrete_ir::ConstValue::I16(value) => block
.append_operation(arith::constant(
ctx.context(),
IntegerAttribute::new(
(*value) as i64,
IntegerType::new(ctx.context(), 16).into(),
)
.into(),
Attribute::parse(ctx.context(), &format!("{} : i16", value)).unwrap(),
Location::unknown(ctx.context()),
))
.result(0)
Expand All @@ -557,11 +560,7 @@ fn compile_value_tree<'c: 'b, 'b>(
concrete_ir::ConstValue::I32(value) => block
.append_operation(arith::constant(
ctx.context(),
IntegerAttribute::new(
(*value) as i64,
IntegerType::new(ctx.context(), 32).into(),
)
.into(),
Attribute::parse(ctx.context(), &format!("{} : i32", value)).unwrap(),
Location::unknown(ctx.context()),
))
.result(0)
Expand All @@ -570,8 +569,7 @@ fn compile_value_tree<'c: 'b, 'b>(
concrete_ir::ConstValue::I64(value) => block
.append_operation(arith::constant(
ctx.context(),
IntegerAttribute::new((*value), IntegerType::new(ctx.context(), 64).into())
.into(),
Attribute::parse(ctx.context(), &format!("{} : i64", value)).unwrap(),
Location::unknown(ctx.context()),
))
.result(0)
Expand All @@ -580,11 +578,7 @@ fn compile_value_tree<'c: 'b, 'b>(
concrete_ir::ConstValue::I128(value) => block
.append_operation(arith::constant(
ctx.context(),
IntegerAttribute::new(
(*value) as i64,
IntegerType::new(ctx.context(), 128).into(),
)
.into(),
Attribute::parse(ctx.context(), &format!("{} : i128", value)).unwrap(),
Location::unknown(ctx.context()),
))
.result(0)
Expand All @@ -593,11 +587,7 @@ fn compile_value_tree<'c: 'b, 'b>(
concrete_ir::ConstValue::U8(value) => block
.append_operation(arith::constant(
ctx.context(),
IntegerAttribute::new(
(*value) as i64,
IntegerType::new(ctx.context(), 8).into(),
)
.into(),
Attribute::parse(ctx.context(), &format!("{} : i8", value)).unwrap(),
Location::unknown(ctx.context()),
))
.result(0)
Expand All @@ -606,11 +596,7 @@ fn compile_value_tree<'c: 'b, 'b>(
concrete_ir::ConstValue::U16(value) => block
.append_operation(arith::constant(
ctx.context(),
IntegerAttribute::new(
(*value) as i64,
IntegerType::new(ctx.context(), 16).into(),
)
.into(),
Attribute::parse(ctx.context(), &format!("{} : i16", value)).unwrap(),
Location::unknown(ctx.context()),
))
.result(0)
Expand All @@ -619,11 +605,7 @@ fn compile_value_tree<'c: 'b, 'b>(
concrete_ir::ConstValue::U32(value) => block
.append_operation(arith::constant(
ctx.context(),
IntegerAttribute::new(
(*value) as i64,
IntegerType::new(ctx.context(), 32).into(),
)
.into(),
Attribute::parse(ctx.context(), &format!("{} : i32", value)).unwrap(),
Location::unknown(ctx.context()),
))
.result(0)
Expand All @@ -632,11 +614,7 @@ fn compile_value_tree<'c: 'b, 'b>(
concrete_ir::ConstValue::U64(value) => block
.append_operation(arith::constant(
ctx.context(),
IntegerAttribute::new(
(*value) as i64,
IntegerType::new(ctx.context(), 64).into(),
)
.into(),
Attribute::parse(ctx.context(), &format!("{} : i64", value)).unwrap(),
Location::unknown(ctx.context()),
))
.result(0)
Expand All @@ -645,18 +623,30 @@ fn compile_value_tree<'c: 'b, 'b>(
concrete_ir::ConstValue::U128(value) => block
.append_operation(arith::constant(
ctx.context(),
IntegerAttribute::new(
(*value) as i64,
IntegerType::new(ctx.context(), 128).into(),
)
.into(),
Attribute::parse(ctx.context(), &format!("{} : i128", value)).unwrap(),
Location::unknown(ctx.context()),
))
.result(0)
.unwrap()
.into(),
concrete_ir::ConstValue::F32(value) => block
.append_operation(arith::constant(
ctx.context(),
Attribute::parse(ctx.context(), &format!("{} : f32", value)).unwrap(),
Location::unknown(ctx.context()),
))
.result(0)
.unwrap()
.into(),
concrete_ir::ConstValue::F64(value) => block
.append_operation(arith::constant(
ctx.context(),
Attribute::parse(ctx.context(), &format!("{} : f64", value)).unwrap(),
Location::unknown(ctx.context()),
))
.result(0)
.unwrap()
.into(),
concrete_ir::ConstValue::F32(_) => todo!(),
concrete_ir::ConstValue::F64(_) => todo!(),
},
ValueTree::Branch(_) => todo!(),
}
Expand Down
2 changes: 1 addition & 1 deletion crates/concrete_codegen_mlir/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use concrete_ir::ProgramBody;
use concrete_session::Session;
use melior::{
dialect::DialectRegistry,
ir::{operation::OperationPrintingFlags, Location, Module as MeliorModule},
ir::{Location, Module as MeliorModule},
utility::{register_all_dialects, register_all_llvm_translations, register_all_passes},
Context as MeliorContext,
};
Expand Down
4 changes: 3 additions & 1 deletion crates/concrete_ir/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::{BTreeMap, HashMap, HashSet};

use concrete_ast::common::{Ident, Span};
use concrete_ast::common::Ident;

pub mod lowering;

Expand All @@ -9,6 +9,8 @@ pub type BlockIndex = usize;
pub type TypeIndex = usize;
pub type FieldIndex = usize;

pub use concrete_ast::common::Span;

#[derive(Debug, Clone, Default)]
pub struct SymbolTable {
pub symbols: HashMap<DefId, String>,
Expand Down

0 comments on commit e9a2181

Please sign in to comment.