Skip to content

Commit

Permalink
Add pointer addition (arithmetic) and intrinsic parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
edg-l committed May 8, 2024
1 parent 27904dc commit aff819d
Show file tree
Hide file tree
Showing 12 changed files with 198 additions and 13 deletions.
7 changes: 7 additions & 0 deletions crates/concrete_ast/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,10 @@ pub struct GenericParam {
pub params: Vec<TypeSpec>,
pub span: Span,
}

#[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)]
pub struct Attribute {
pub name: String,
pub value: Option<String>,
pub span: Span,
}
3 changes: 2 additions & 1 deletion crates/concrete_ast/src/functions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
common::{DocString, GenericParam, Ident, Span},
common::{Attribute, DocString, GenericParam, Ident, Span},
statements::Statement,
types::TypeSpec,
};
Expand All @@ -13,6 +13,7 @@ pub struct FunctionDecl {
pub ret_type: Option<TypeSpec>,
pub is_extern: bool,
pub is_pub: bool,
pub attributes: Vec<Attribute>,
pub span: Span,
}

Expand Down
39 changes: 38 additions & 1 deletion crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,10 +613,32 @@ fn compile_binop<'c: 'b, 'b>(

let is_float = matches!(lhs_ty.kind, TyKind::Float(_));
let is_signed = matches!(lhs_ty.kind, TyKind::Int(_));
let is_ptr = if let TyKind::Ptr(inner, _) = &lhs_ty.kind {
Some((*inner).clone())
} else {
None
};

Ok(match op {
BinOp::Add => {
let value = if is_float {
let value = if let Some(inner) = is_ptr {
let inner_ty = compile_type(ctx.module_ctx, &inner);
block
.append_operation(
ods::llvm::getelementptr(
ctx.context(),
pointer(ctx.context(), 0),
lhs,
&[rhs],
DenseI32ArrayAttribute::new(ctx.context(), &[i32::MIN]),
TypeAttribute::new(inner_ty),
location,
)
.into(),
)
.result(0)?
.into()
} else if is_float {
block
.append_operation(arith::addf(lhs, rhs, location))
.result(0)?
Expand All @@ -630,6 +652,11 @@ fn compile_binop<'c: 'b, 'b>(
(value, lhs_ty)
}
BinOp::Sub => {
if is_ptr.is_some() {
return Err(CodegenError::NotImplemented(
"substracting from a pointer is not yet implemented".to_string(),
));
}
let value = if is_float {
block
.append_operation(arith::subf(lhs, rhs, location))
Expand All @@ -644,6 +671,11 @@ fn compile_binop<'c: 'b, 'b>(
(value, lhs_ty)
}
BinOp::Mul => {
if is_ptr.is_some() {
return Err(CodegenError::NotImplemented(
"multiplying a pointer is not yet implemented".to_string(),
));
}
let value = if is_float {
block
.append_operation(arith::mulf(lhs, rhs, location))
Expand All @@ -658,6 +690,11 @@ fn compile_binop<'c: 'b, 'b>(
(value, lhs_ty)
}
BinOp::Div => {
if is_ptr.is_some() {
return Err(CodegenError::NotImplemented(
"dividing a pointer is not yet implemented".to_string(),
));
}
let value = if is_float {
block
.append_operation(arith::divf(lhs, rhs, location))
Expand Down
2 changes: 2 additions & 0 deletions crates/concrete_codegen_mlir/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ pub enum CodegenError {
LLVMCompileError(String),
#[error("melior error: {0}")]
MeliorError(#[from] melior::Error),
#[error("not yet implemented: {0}")]
NotImplemented(String),
}
18 changes: 17 additions & 1 deletion crates/concrete_driver/tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{
borrow::Cow,
fmt,
path::{Path, PathBuf},
process::Output,
process::{Output, Stdio},
};

use ariadne::Source;
Expand Down Expand Up @@ -110,6 +110,7 @@ pub fn compile_program(

pub fn run_program(program: &Path) -> Result<Output, std::io::Error> {
std::process::Command::new(program)
.stdout(Stdio::piped())
.spawn()?
.wait_with_output()
}
Expand All @@ -122,3 +123,18 @@ pub fn compile_and_run(source: &str, name: &str, library: bool, optlevel: OptLev

output.status.code().unwrap()
}

#[allow(unused)] // false positive
#[track_caller]
pub fn compile_and_run_output(
source: &str,
name: &str,
library: bool,
optlevel: OptLevel,
) -> String {
let result = compile_program(source, name, library, optlevel).expect("failed to compile");

let output = run_program(&result.binary_file).expect("failed to run");

std::str::from_utf8(&output.stdout).unwrap().to_string()
}
22 changes: 21 additions & 1 deletion crates/concrete_driver/tests/examples.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::common::compile_and_run;
use crate::common::{compile_and_run, compile_and_run_output};
use concrete_session::config::OptLevel;
use test_case::test_case;

Expand Down Expand Up @@ -39,3 +39,23 @@ fn example_tests(source: &str, name: &str, is_library: bool, status_code: i32) {
compile_and_run(source, name, is_library, OptLevel::Aggressive)
);
}

#[test_case(include_str!("../../../examples/hello_world_hacky.con"), "hello_world_hacky", false, "Hello World\n" ; "hello_world_hacky.con")]
fn example_tests_with_output(source: &str, name: &str, is_library: bool, result: &str) {
assert_eq!(
result,
compile_and_run_output(source, name, is_library, OptLevel::None)
);
assert_eq!(
result,
compile_and_run_output(source, name, is_library, OptLevel::Less)
);
assert_eq!(
result,
compile_and_run_output(source, name, is_library, OptLevel::Default)
);
assert_eq!(
result,
compile_and_run_output(source, name, is_library, OptLevel::Aggressive)
);
}
16 changes: 15 additions & 1 deletion crates/concrete_ir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ pub struct FnBody {
pub id: DefId,
pub name: String,
pub is_extern: bool,
pub is_intrinsic: Option<ConcreteIntrinsic>,
pub basic_blocks: Vec<BasicBlock>,
pub locals: Vec<Local>,
}
Expand Down Expand Up @@ -397,7 +398,15 @@ impl fmt::Display for TyKind {
FloatTy::F64 => write!(f, "f32"),
},
TyKind::String => write!(f, "string"),
TyKind::Array(_, _) => todo!(),
TyKind::Array(inner, size) => {
let value =
if let ConstKind::Value(ValueTree::Leaf(ConstValue::U64(x))) = &size.data {
*x
} else {
unreachable!("const data for array sizes should always be u64")
};
write!(f, "[{}; {:?}]", inner.kind, value)
}
TyKind::Ref(inner, is_mut) => {
let word = if let Mutability::Mut = is_mut {
"mut"
Expand Down Expand Up @@ -571,3 +580,8 @@ pub enum ConstValue {
F32(f32),
F64(f64),
}

#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub enum ConcreteIntrinsic {
// Todo: Add intrinsics here
}
35 changes: 28 additions & 7 deletions crates/concrete_ir/src/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ use concrete_ast::{
};

use crate::{
AdtBody, BasicBlock, BinOp, ConstData, ConstKind, ConstValue, DefId, FloatTy, FnBody, IntTy,
Local, LocalKind, LogOp, Mutability, Operand, Place, PlaceElem, ProgramBody, Rvalue, Statement,
StatementKind, SwitchTargets, Terminator, TerminatorKind, Ty, TyKind, UintTy, ValueTree,
VariantDef,
AdtBody, BasicBlock, BinOp, ConcreteIntrinsic, ConstData, ConstKind, ConstValue, DefId,
FloatTy, FnBody, IntTy, Local, LocalKind, LogOp, Mutability, Operand, Place, PlaceElem,
ProgramBody, Rvalue, Statement, StatementKind, SwitchTargets, Terminator, TerminatorKind, Ty,
TyKind, UintTy, ValueTree, VariantDef,
};

use self::errors::LoweringError;
Expand Down Expand Up @@ -217,11 +217,16 @@ fn lower_func(
func: &FunctionDef,
module_id: DefId,
) -> Result<BuildCtx, LoweringError> {
let is_intrinsic: Option<ConcreteIntrinsic> = None;

// TODO: parse insintrics here.

let mut builder = FnBodyBuilder {
body: FnBody {
basic_blocks: Vec::new(),
locals: Vec::new(),
is_extern: func.decl.is_extern,
is_intrinsic,
name: func.decl.name.name.clone(),
id: {
let body = ctx.body.modules.get(&module_id).unwrap();
Expand Down Expand Up @@ -350,11 +355,16 @@ fn lower_func_decl(
func: &FunctionDecl,
module_id: DefId,
) -> Result<BuildCtx, LoweringError> {
let is_intrinsic: Option<ConcreteIntrinsic> = None;

// TODO: parse insintrics here.

let builder = FnBodyBuilder {
body: FnBody {
basic_blocks: Vec::new(),
locals: Vec::new(),
is_extern: func.is_extern,
is_intrinsic,
name: func.name.name.clone(),
id: {
let body = ctx.body.modules.get(&module_id).unwrap();
Expand Down Expand Up @@ -1236,14 +1246,22 @@ fn lower_binary_op(
} else {
lower_expression(builder, lhs, type_hint.clone())?
};

// We must handle the special case where you can do ptr + offset.
let is_lhs_ptr = matches!(lhs_ty.kind, TyKind::Ptr(_, _));

let (rhs, rhs_ty, rhs_span) = if type_hint.is_none() {
let ty = find_expression_type(builder, rhs).unwrap_or(lhs_ty.clone());
lower_expression(builder, rhs, Some(ty))?
lower_expression(builder, rhs, if is_lhs_ptr { None } else { Some(ty) })?
} else {
lower_expression(builder, rhs, type_hint.clone())?
lower_expression(
builder,
rhs,
if is_lhs_ptr { None } else { type_hint.clone() },
)?
};

if lhs_ty != rhs_ty {
if !is_lhs_ptr && lhs_ty != rhs_ty {
return Err(LoweringError::UnexpectedType {
span: rhs_span,
found: rhs_ty,
Expand Down Expand Up @@ -1409,6 +1427,9 @@ fn lower_value_expr(
UintTy::U128 => ConstValue::U128(*value),
},
TyKind::Bool => ConstValue::Bool(*value != 0),
TyKind::Ptr(ref _inner, _mutable) => {
ConstValue::I64((*value).try_into().expect("value out of range"))
}
x => unreachable!("{:?}", x),
})),
},
Expand Down
20 changes: 19 additions & 1 deletion crates/concrete_parser/src/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ extern {
":" => Token::Colon,
"->" => Token::Arrow,
"," => Token::Coma,
"#" => Token::Hashtag,
"<" => Token::LessThanSign,
">" => Token::MoreThanSign,
">=" => Token::MoreThanEqSign,
Expand Down Expand Up @@ -119,6 +120,14 @@ PlusSeparated<T>: Vec<T> = {
}
};

List<T>: Vec<T> = {
<T> => vec![<>],
<mut s:List<T>> <n:T> => {
s.push(n);
s
},
}

// Requires the semicolon at end
SemiColonSeparated<T>: Vec<T> = {
<T> ";" => vec![<>],
Expand Down Expand Up @@ -291,13 +300,22 @@ pub(crate) Param: ast::functions::Param = {
}
}

pub(crate) Attribute: ast::common::Attribute = {
<lo:@L> "#" "[" <name:"identifier"> <value:("=" <"string">)?> "]" <hi:@R> => ast::common::Attribute {
name,
value,
span: ast::common::Span::new(lo, hi),
}
}

pub(crate) FunctionDecl: ast::functions::FunctionDecl = {
<lo:@L> <is_pub:"pub"?> <is_extern:"extern"?>
<lo:@L> <attributes:List<Attribute>?> <is_pub:"pub"?> <is_extern:"extern"?>
"fn" <name:Ident> <generic_params:GenericParams?> "(" <params:Comma<Param>> ")"
<ret_type:FunctionRetType?> <hi:@R> =>
ast::functions::FunctionDecl {
doc_string: None,
generic_params: generic_params.unwrap_or(vec![]),
attributes: attributes.unwrap_or(vec![]),
name,
params,
ret_type,
Expand Down
11 changes: 11 additions & 0 deletions crates/concrete_parser/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,17 @@ mod ModuleName {
return arr[1][0];
}
}"##;
let lexer = Lexer::new(source);
let parser = grammar::ProgramParser::new();
parser.parse(lexer).unwrap();
}

#[test]
fn parse_intrinsic() {
let source = r##"mod MyMod {
#[intrinsic = "simdsomething"]
pub extern fn myintrinsic();
}"##;
let lexer = Lexer::new(source);
let parser = grammar::ProgramParser::new();
Expand Down
2 changes: 2 additions & 0 deletions crates/concrete_parser/src/tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ pub enum Token {
Coma,
#[token(".")]
Dot,
#[token("#")]
Hashtag,
#[token("<")]
LessThanSign,
#[token(">")]
Expand Down
36 changes: 36 additions & 0 deletions examples/hello_world_hacky.con
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
mod HelloWorld {
pub extern fn malloc(size: u64) -> *mut u8;
pub extern fn puts(data: *mut u8) -> i32;

fn main() -> i32 {
let origin: *mut u8 = malloc(12);
let mut p: *mut u8 = origin;

*p = 'H';
p = p + 1;
*p = 'e';
p = p + 1;
*p = 'l';
p = p + 1;
*p = 'l';
p = p + 1;
*p = 'o';
p = p + 1;
*p = ' ';
p = p + 1;
*p = 'W';
p = p + 1;
*p = 'o';
p = p + 1;
*p = 'r';
p = p + 1;
*p = 'l';
p = p + 1;
*p = 'd';
p = p + 1;
*p = '\0';
puts(origin);

return 0;
}
}

0 comments on commit aff819d

Please sign in to comment.