Skip to content

Commit

Permalink
references
Browse files Browse the repository at this point in the history
  • Loading branch information
edg-l committed Jan 18, 2024
1 parent a490339 commit 37a040c
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 165 deletions.
7 changes: 6 additions & 1 deletion crates/concrete_ast/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use crate::{common::Ident, statements::Statement, types::TypeSpec};
use crate::{
common::Ident,
statements::Statement,
types::{RefType, TypeSpec},
};

#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Expression {
Expand All @@ -19,6 +23,7 @@ pub enum ValueExpr {
ConstStr(String),
Path(PathOp),
Deref(PathOp),
AsRef { path: PathOp, ref_type: RefType },
}

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
Expand Down
219 changes: 56 additions & 163 deletions crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{char::MAX, collections::HashMap, error::Error};
use std::{collections::HashMap, error::Error};

use bumpalo::Bump;
use concrete_ast::{
Expand All @@ -16,19 +16,13 @@ use concrete_session::Session;
use melior::{
dialect::{
arith::{self, CmpiPredicate},
cf, func,
llvm::{self, r#type::opaque_pointer, LoadStoreOptions},
memref,
cf, func, memref,
},
ir::{
attribute::{
DenseI64ArrayAttribute, FlatSymbolRefAttribute, IntegerAttribute, StringAttribute,
TypeAttribute,
},
operation::OperationBuilder,
attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute},
r#type::{FunctionType, IntegerType, MemRefType},
Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, TypeLike,
Value, ValueLike,
Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, Value,
ValueLike,
},
Context as MeliorContext,
};
Expand Down Expand Up @@ -172,65 +166,27 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> {
Ok(match spec {
TypeSpec::Simple { name, .. } => self.resolve_type(context, &name.name)?,
TypeSpec::Generic { name, .. } => self.resolve_type(context, &name.name)?,
TypeSpec::Array {
of_type,
span: _,
is_ref: _,
size,
} => match size {
Some(size) => {
let inner_type = self.resolve_type_spec(context, of_type)?;
MemRefType::new(inner_type, &[*size], None, None).into()
/*
llvm::r#type::array(
self.resolve_type_spec(context, &of_type)?,
(*size).try_into().expect("size was above u32"),
)
*/
}
None => {
//
let inner_type = self.resolve_type_spec(context, of_type)?;
// Type::parse(context, &format!("memref<?x{inner_type}>")).unwrap()

llvm::r#type::r#struct(
context,
&[
Type::parse(context, &format!("memref<?x{inner_type}>")).unwrap(),
IntegerType::new(context, 64).into(),
],
false,
)
}
},
TypeSpec::Array { .. } => {
todo!("implement arrays")
}
})
}

fn is_type_signed(&self, type_info: &TypeSpec) -> bool {
let signed = ["i8", "i16", "i32", "i64", "i128"];
match type_info {
TypeSpec::Simple { name, is_ref, .. } => signed.contains(&name.name.as_str()),
TypeSpec::Generic { name, is_ref, .. } => signed.contains(&name.name.as_str()),
TypeSpec::Array {
of_type,
span,
is_ref,
size,
} => unreachable!(),
TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()),
TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()),
TypeSpec::Array { .. } => unreachable!(),
}
}

fn is_float(&self, type_info: &TypeSpec) -> bool {
let signed = ["f32", "f64"];
match type_info {
TypeSpec::Simple { name, is_ref, .. } => signed.contains(&name.name.as_str()),
TypeSpec::Generic { name, is_ref, .. } => signed.contains(&name.name.as_str()),
TypeSpec::Array {
of_type,
span,
is_ref,
size,
} => unreachable!(),
TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()),
TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()),
TypeSpec::Array { .. } => unreachable!(),
}
}
}
Expand Down Expand Up @@ -565,17 +521,6 @@ fn compile_while<'c, 'this: 'c>(
Ok(merge_block)
}

fn is_local_copy(a: &Expression) -> Option<(&PathOp, bool)> {
match a {
Expression::Value(value) => match value {
ValueExpr::Path(path) => Some((path, false)),
ValueExpr::Deref(path) => Some((path, true)),
_ => None,
},
_ => None,
}
}

fn compile_let_stmt<'ctx, 'parent: 'ctx>(
session: &Session,
context: &'ctx MeliorContext,
Expand Down Expand Up @@ -610,16 +555,8 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>(
))
.result(0)?
.into();
let k0 = block
.append_operation(arith::constant(
context,
IntegerAttribute::new(0, Type::index(context)).into(),
location,
))
.result(0)?
.into();

block.append_operation(memref::store(value, alloca, &[k0], location));
block.append_operation(memref::store(value, alloca, &[], location));

scope_ctx
.locals
Expand Down Expand Up @@ -907,6 +844,7 @@ fn compile_value_expr<'ctx, 'parent: 'ctx>(
value: &ValueExpr,
type_info: Option<&TypeSpec>,
) -> Result<Value<'ctx, 'parent>, Box<dyn Error>> {
tracing::debug!("compiling value_expr for {:?}", value);
let location = Location::unknown(context);
match value {
ValueExpr::ConstBool(value) => {
Expand Down Expand Up @@ -944,6 +882,9 @@ fn compile_value_expr<'ctx, 'parent: 'ctx>(
ValueExpr::Deref(value) => {
compile_deref(session, context, scope_ctx, _helper, block, value)
}
ValueExpr::AsRef { path, ref_type: _ } => {
compile_asref(session, context, scope_ctx, _helper, block, path)
}
}
}

Expand All @@ -955,6 +896,7 @@ fn compile_fn_call<'ctx, 'parent: 'ctx>(
block: &'parent Block<'ctx>,
info: &FnCallOp,
) -> Result<Value<'ctx, 'parent>, Box<dyn Error>> {
tracing::debug!("compiling fncall: {:?}", info);
let mut args = Vec::with_capacity(info.args.len());
let location = get_location(context, session, info.target.span.from);

Expand Down Expand Up @@ -1006,10 +948,11 @@ fn compile_path_op<'ctx, 'parent: 'ctx>(
session: &Session,
context: &'ctx MeliorContext,
scope_ctx: &mut ScopeContext<'ctx, 'parent>,
helper: &BlockHelper<'ctx, 'parent>,
_helper: &BlockHelper<'ctx, 'parent>,
block: &'parent Block<'ctx>,
path: &PathOp,
) -> Result<Value<'ctx, 'parent>, Box<dyn Error>> {
tracing::debug!("compiling pathop {:?}", path);
// For now only simple and array variables work.
// TODO: implement properly, this requires having structs implemented.

Expand All @@ -1021,7 +964,7 @@ fn compile_path_op<'ctx, 'parent: 'ctx>(

let location = get_location(context, session, path.first.span.from);

let mut value = if local.alloca {
let value = if local.alloca {
block
.append_operation(memref::load(local.value, &[], location))
.result(0)?
Expand All @@ -1030,109 +973,59 @@ fn compile_path_op<'ctx, 'parent: 'ctx>(
local.value
};

for segment in &path.extra {
match segment {
PathSegment::FieldAccess(_) => todo!(),
PathSegment::ArrayIndex(index) => {
let index =
compile_value_expr(session, context, scope_ctx, helper, block, index, None)?;
let index_ty = Type::index(context);
let index = block
.append_operation(melior::dialect::index::castu(index, index_ty, location))
.result(0)?
.into();

if let TypeSpec::Array {
of_type: _,
size,
is_ref: _,
span,
} = &local.type_spec
{
let location = get_location(context, session, span.from);
#[allow(clippy::if_same_then_else)]
if size.is_some() {
// todo: check inbounds?
value = block
.append_operation(memref::load(value, &[index], location))
.result(0)?
.into();
} else {
value = block
.append_operation(memref::load(value, &[index], location))
.result(0)?
.into();
}
} else {
panic!("type should be a array when indexing a value");
}
}
}
}

Ok(value)
}

fn compile_deref<'ctx, 'parent: 'ctx>(
session: &Session,
context: &'ctx MeliorContext,
scope_ctx: &mut ScopeContext<'ctx, 'parent>,
helper: &BlockHelper<'ctx, 'parent>,
_helper: &BlockHelper<'ctx, 'parent>,
block: &'parent Block<'ctx>,
path: &PathOp,
) -> Result<Value<'ctx, 'parent>, Box<dyn Error>> {
tracing::debug!("compiling deref for {:?}", path);
let local = scope_ctx
.locals
.get(&path.first.name)
.expect("local not found")
.clone();

let location = get_location(context, session, path.first.span.from);
let inner_type = scope_ctx.resolve_type_spec_ref(context, &local.type_spec)?;

let mut value = block
.append_operation(memref::load(local.value, &[], location)).result(0)?.into();


for segment in &path.extra {
match segment {
PathSegment::FieldAccess(_) => todo!(),
PathSegment::ArrayIndex(index) => {
let index =
compile_value_expr(session, context, scope_ctx, helper, block, index, None)?;
let index_ty = Type::index(context);
let index = block
.append_operation(melior::dialect::index::castu(index, index_ty, location))
.result(0)?
.into();

if let TypeSpec::Array {
of_type: _,
size,
is_ref: _,
span,
} = &local.type_spec
{
let location = get_location(context, session, span.from);
#[allow(clippy::if_same_then_else)]
if size.is_some() {
// todo: check inbounds?
value = block
.append_operation(memref::load(value, &[index], location))
.result(0)?
.into();
} else {
value = block
.append_operation(memref::load(value, &[index], location))
.result(0)?
.into();
}
} else {
panic!("type should be a array when indexing a value");
}
}
}
.append_operation(memref::load(local.value, &[], location))
.result(0)?
.into();

if local.alloca {
value = block
.append_operation(memref::load(value, &[], location))
.result(0)?
.into();
}

Ok(value)
}

fn compile_asref<'ctx, 'parent: 'ctx>(
_session: &Session,
_context: &'ctx MeliorContext,
scope_ctx: &mut ScopeContext<'ctx, 'parent>,
_helper: &BlockHelper<'ctx, 'parent>,
_block: &'parent Block<'ctx>,
path: &PathOp,
) -> Result<Value<'ctx, 'parent>, Box<dyn Error>> {
tracing::debug!("compiling asref for {:?}", path);
let local = scope_ctx
.locals
.get(&path.first.name)
.expect("local not found")
.clone();

if !local.alloca {
panic!("can only take refs to non register values");
}

Ok(local.value)
}
28 changes: 28 additions & 0 deletions crates/concrete_driver/tests/programs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,31 @@ fn test_import() {
let code = output.status.code().unwrap();
assert_eq!(code, 8);
}

#[test]
fn test_reference() {
let source = r#"
mod Simple {
fn main(argc: i64) -> i64 {
let x: i64 = argc;
return references(x) + dereference(&x);
}
fn dereference(a: &i64) -> i64 {
return *a;
}
fn references(a: i64) -> i64 {
let x: i64 = a;
let y: &i64 = &x;
return *y;
}
}
"#;

let result = compile_program(source, "references", false).expect("failed to compile");

let output = run_program(&result.binary_file).expect("failed to run");
let code = output.status.code().unwrap();
assert_eq!(code, 2);
}
4 changes: 4 additions & 0 deletions crates/concrete_parser/src/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ pub(crate) ValueExpr: ast::expressions::ValueExpr = {
<"string"> => ast::expressions::ValueExpr::ConstStr(<>),
<PathOp> => ast::expressions::ValueExpr::Path(<>),
"*" <PathOp> => ast::expressions::ValueExpr::Deref(<>),
<ref_type:RefType> <path:PathOp> => ast::expressions::ValueExpr::AsRef {
path,
ref_type
},
}

pub(crate) IfExpr: ast::expressions::IfExpr = {
Expand Down
Loading

0 comments on commit 37a040c

Please sign in to comment.