Skip to content

Commit

Permalink
implement casts
Browse files Browse the repository at this point in the history
  • Loading branch information
edg-l committed Mar 12, 2024
1 parent 12d52cb commit cdb02b9
Show file tree
Hide file tree
Showing 10 changed files with 394 additions and 147 deletions.
21 changes: 8 additions & 13 deletions crates/concrete_ast/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@ pub enum Expression {
UnaryOp(UnaryOp, Box<Self>),
BinaryOp(Box<Self>, BinaryOp, Box<Self>),
StructInit(StructInitExpr),
Deref(Box<Self>),
AsRef(Box<Self>, bool),
Deref(Box<Self>, Span),
AsRef(Box<Self>, bool, Span),
Cast(PathOp, TypeSpec, Span),
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ValueExpr {
ConstBool(bool),
ConstChar(char),
ConstInt(u128),
ConstFloat(String),
ConstStr(String),
ConstBool(bool, Span),
ConstChar(char, Span),
ConstInt(u128, Span),
ConstFloat(String, Span),
ConstStr(String, Span),
Path(PathOp),
}

Expand Down Expand Up @@ -124,12 +125,6 @@ pub struct PathOp {
pub span: Span,
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CastOp {
pub value: Expression,
pub r#type: TypeSpec,
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FnCallOp {
pub target: Ident,
Expand Down
2 changes: 1 addition & 1 deletion crates/concrete_check/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ pub fn lowering_error_to_report(
LoweringError::UnexpectedType { span, found, expected } => {
let mut labels = vec![
Label::new((path.clone(), span.into()))
.with_message(format!("Unexpected type '{}', expected '{}'", found, expected.kind))
.with_message(format!("Unexpected type '{}', expected '{}'", found.kind, expected.kind))
.with_color(colors.next())
];

Expand Down
157 changes: 142 additions & 15 deletions crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use melior::{
dialect::{
arith, cf, func,
llvm::{self, r#type::opaque_pointer, AllocaOptions, LoadStoreOptions},
ods,
},
ir::{
attribute::{
Expand Down Expand Up @@ -464,13 +465,13 @@ fn compile_rvalue<'c: 'b, 'b>(
block: &'b Block<'c>,
info: &Rvalue,
locals: &'b HashMap<usize, Value<'c, '_>>,
) -> Result<(Value<'c, 'b>, TyKind), CodegenError> {
) -> Result<(Value<'c, 'b>, Ty), CodegenError> {
Ok(match info {
Rvalue::Use(info) => compile_load_operand(ctx, block, info, locals)?,
Rvalue::LogicOp(_, _) => todo!(),
Rvalue::BinaryOp(op, (lhs, rhs)) => compile_binop(ctx, block, op, lhs, rhs, locals)?,
Rvalue::UnaryOp(_, _) => todo!(),
Rvalue::Ref(mutability, place) => {
Rvalue::Ref(_mutability, place) => {
let mut value = locals[&place.local];
let mut local_ty = ctx.get_fn_body().locals[place.local].ty.clone();

Expand All @@ -480,6 +481,7 @@ fn compile_rvalue<'c: 'b, 'b>(
PlaceElem::Deref => {
local_ty = match local_ty.kind {
TyKind::Ref(inner, _) => *(inner.clone()),
TyKind::Ptr(inner, _) => *(inner.clone()),
_ => unreachable!(),
};

Expand All @@ -499,7 +501,96 @@ fn compile_rvalue<'c: 'b, 'b>(
}
}

(value, TyKind::Ref(Box::new(local_ty), *mutability))
(value, local_ty)
}
Rvalue::Cast(place, target_ty, _span) => {
let location = Location::unknown(ctx.context());
let target_ty = target_ty.clone();
let target_mlir_ty = compile_type(ctx.module_ctx, &target_ty);
let (value, current_ty) = compile_load_place(ctx, block, place, locals)?;
let is_signed = target_ty.kind.is_signed();

if target_ty.kind.is_ptr_like() {
// int to ptr
if current_ty.kind.is_int() {
let value = block
.append_operation(
ods::llvm::inttoptr(ctx.context(), target_mlir_ty, value, location)
.into(),
)
.result(0)?
.into();
(value, target_ty.clone())
} else if current_ty.kind.is_ptr_like() {
// ptr to ptr: noop
(value, target_ty.clone())
} else {
unreachable!("cast from {:?} to ptr", current_ty.kind)
}
} else if target_ty.kind.is_int() {
if current_ty.kind.is_int() {
// int to int casts
match current_ty
.kind
.get_bit_width()
.unwrap()
.cmp(&target_ty.kind.get_bit_width().unwrap())
{
std::cmp::Ordering::Less => {
if is_signed {
let value = block
.append_operation(arith::extsi(value, target_mlir_ty, location))
.result(0)?
.into();
(value, target_ty.clone())
} else {
let value = block
.append_operation(arith::extui(value, target_mlir_ty, location))
.result(0)?
.into();
(value, target_ty.clone())
}
}
std::cmp::Ordering::Equal => (value, target_ty.clone()),
std::cmp::Ordering::Greater => {
let value = block
.append_operation(arith::trunci(value, target_mlir_ty, location))
.result(0)?
.into();
(value, target_ty.clone())
}
}
} else if current_ty.kind.is_float() {
// float to int
if is_signed {
let value = block
.append_operation(arith::fptosi(value, target_mlir_ty, location))
.result(0)?
.into();
(value, target_ty.clone())
} else {
let value = block
.append_operation(arith::fptoui(value, target_mlir_ty, location))
.result(0)?
.into();
(value, target_ty.clone())
}
} else if current_ty.kind.is_ptr_like() {
// ptr to int
let value = block
.append_operation(
ods::llvm::ptrtoint(ctx.context(), target_mlir_ty, value, location)
.into(),
)
.result(0)?
.into();
(value, target_ty.clone())
} else {
todo!("cast from {:?} to {:?}", current_ty, target_ty)
}
} else {
todo!("cast from {:?} to {:?}", current_ty, target_ty)
}
}
})
}
Expand All @@ -512,13 +603,13 @@ fn compile_binop<'c: 'b, 'b>(
lhs: &Operand,
rhs: &Operand,
locals: &HashMap<usize, Value<'c, '_>>,
) -> Result<(Value<'c, 'b>, TyKind), CodegenError> {
) -> Result<(Value<'c, 'b>, Ty), CodegenError> {
let (lhs, lhs_ty) = compile_load_operand(ctx, block, lhs, locals)?;
let (rhs, _rhs_ty) = compile_load_operand(ctx, block, rhs, locals)?;
let location = Location::unknown(ctx.context());

let is_float = matches!(lhs_ty, TyKind::Float(_));
let is_signed = matches!(lhs_ty, TyKind::Int(_));
let is_float = matches!(lhs_ty.kind, TyKind::Float(_));
let is_signed = matches!(lhs_ty.kind, TyKind::Int(_));

Ok(match op {
BinOp::Add => {
Expand Down Expand Up @@ -630,7 +721,13 @@ fn compile_binop<'c: 'b, 'b>(
.result(0)?
.into()
};
(value, TyKind::Bool)
(
value,
Ty {
span: None,
kind: TyKind::Bool,
},
)
}
BinOp::Lt => {
let value = if is_float {
Expand Down Expand Up @@ -667,7 +764,13 @@ fn compile_binop<'c: 'b, 'b>(
.result(0)?
.into()
};
(value, TyKind::Bool)
(
value,
Ty {
span: None,
kind: TyKind::Bool,
},
)
}
BinOp::Le => {
let value = if is_float {
Expand Down Expand Up @@ -704,7 +807,13 @@ fn compile_binop<'c: 'b, 'b>(
.result(0)?
.into()
};
(value, TyKind::Bool)
(
value,
Ty {
span: None,
kind: TyKind::Bool,
},
)
}
BinOp::Ne => {
let value = if is_float {
Expand All @@ -730,7 +839,13 @@ fn compile_binop<'c: 'b, 'b>(
.result(0)?
.into()
};
(value, TyKind::Bool)
(
value,
Ty {
span: None,
kind: TyKind::Bool,
},
)
}
BinOp::Ge => {
let value = if is_float {
Expand Down Expand Up @@ -767,7 +882,13 @@ fn compile_binop<'c: 'b, 'b>(
.result(0)?
.into()
};
(value, TyKind::Bool)
(
value,
Ty {
span: None,
kind: TyKind::Bool,
},
)
}
BinOp::Gt => {
let value = if is_float {
Expand Down Expand Up @@ -804,7 +925,13 @@ fn compile_binop<'c: 'b, 'b>(
.result(0)?
.into()
};
(value, TyKind::Bool)
(
value,
Ty {
span: None,
kind: TyKind::Bool,
},
)
}
})
}
Expand All @@ -814,7 +941,7 @@ fn compile_load_operand<'c: 'b, 'b>(
block: &'b Block<'c>,
info: &Operand,
locals: &HashMap<usize, Value<'c, '_>>,
) -> Result<(Value<'c, 'b>, TyKind), CodegenError> {
) -> Result<(Value<'c, 'b>, Ty), CodegenError> {
Ok(match info {
Operand::Place(info) => compile_load_place(ctx, block, info, locals)?,
Operand::Const(data) => match &data.data {
Expand Down Expand Up @@ -902,7 +1029,7 @@ fn compile_load_place<'c: 'b, 'b>(
block: &'b Block<'c>,
info: &Place,
locals: &HashMap<usize, Value<'c, '_>>,
) -> Result<(Value<'c, 'b>, TyKind), CodegenError> {
) -> Result<(Value<'c, 'b>, Ty), CodegenError> {
let mut ptr = locals[&info.local];
let body = ctx.get_fn_body();

Expand Down Expand Up @@ -966,7 +1093,7 @@ fn compile_load_place<'c: 'b, 'b>(
.result(0)?
.into();

Ok((value, local_ty.kind))
Ok((value, local_ty))
}

/// Used in switch
Expand Down
1 change: 1 addition & 0 deletions crates/concrete_driver/tests/examples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ mod common;
#[test_case(include_str!("../../../examples/floats.con"), "floats", false, 1 ; "floats.con")]
#[test_case(include_str!("../../../examples/refs.con"), "refs", false, 6 ; "refs.con")]
#[test_case(include_str!("../../../examples/structs.con"), "structs", false, 8 ; "structs.con")]
#[test_case(include_str!("../../../examples/casts.con"), "casts", false, 2 ; "casts.con")]
fn example_tests(source: &str, name: &str, is_library: bool, status_code: i32) {
assert_eq!(
status_code,
Expand Down
57 changes: 56 additions & 1 deletion crates/concrete_ir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ pub enum Rvalue {
UnaryOp(UnOp, Operand),
/// A reference to a place.
Ref(Mutability, Place),
/// A cast.
Cast(Place, Ty, Span),
}

/// A operand is a value, either from a place in memory or constant data.
Expand Down Expand Up @@ -300,6 +302,59 @@ pub enum TyKind {
},
}

impl TyKind {
pub fn is_ptr_like(&self) -> bool {
matches!(self, TyKind::Ptr(_, _) | TyKind::Ref(_, _))
}

pub fn is_int(&self) -> bool {
matches!(self, TyKind::Int(_) | TyKind::Uint(_))
}

pub fn is_signed(&self) -> bool {
matches!(self, TyKind::Int(_))
}

pub fn is_float(&self) -> bool {
matches!(self, TyKind::Float(_))
}

/// Returns the type bit width, None if unsized.
///
/// Meant for use in casts.
pub fn get_bit_width(&self) -> Option<usize> {
match self {
TyKind::Unit => None,
TyKind::Bool => Some(1),
TyKind::Char => Some(8),
TyKind::Int(ty) => match ty {
IntTy::I8 => Some(8),
IntTy::I16 => Some(16),
IntTy::I32 => Some(32),
IntTy::I64 => Some(64),
IntTy::I128 => Some(128),
},
TyKind::Uint(ty) => match ty {
UintTy::U8 => Some(8),
UintTy::U16 => Some(16),
UintTy::U32 => Some(32),
UintTy::U64 => Some(64),
UintTy::U128 => Some(128),
},
TyKind::Float(ty) => match ty {
FloatTy::F32 => Some(32),
FloatTy::F64 => Some(64),
},
TyKind::String => todo!(),
TyKind::Array(_, _) => todo!(),
TyKind::Ref(_, _) => todo!(),
TyKind::Ptr(_, _) => todo!(),
TyKind::Param { .. } => todo!(),
TyKind::Struct { .. } => todo!(),
}
}
}

impl fmt::Display for TyKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Expand Down Expand Up @@ -413,7 +468,7 @@ pub enum FloatTy {

#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub struct ConstData {
pub ty: TyKind,
pub ty: Ty,
pub data: ConstKind,
}

Expand Down
Loading

0 comments on commit cdb02b9

Please sign in to comment.