Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
edg-l committed Jan 17, 2024
1 parent ace0fcc commit a49013c
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 167 deletions.
263 changes: 105 additions & 158 deletions crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ use melior::{
DenseI64ArrayAttribute, FlatSymbolRefAttribute, IntegerAttribute, StringAttribute,
TypeAttribute,
},
operation::OperationBuilder,
r#type::{FunctionType, IntegerType, MemRefType},
Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, Value,
ValueLike,
Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, TypeLike,
Value, ValueLike,
},
Context as MeliorContext,
};
Expand Down Expand Up @@ -152,7 +153,12 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> {
spec: &TypeSpec,
) -> Result<Type<'ctx>, Box<dyn Error>> {
match spec.is_ref() {
Some(_) => Ok(llvm::r#type::opaque_pointer(context)),
Some(_) => {
Ok(
MemRefType::new(self.resolve_type_spec_ref(context, spec)?, &[], None, None)
.into(),
)
}
None => self.resolve_type_spec_ref(context, spec),
}
}
Expand All @@ -168,8 +174,8 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> {
TypeSpec::Generic { name, .. } => self.resolve_type(context, &name.name)?,
TypeSpec::Array {
of_type,
span,
is_ref,
span: _,
is_ref: _,
size,
} => match size {
Some(size) => {
Expand All @@ -185,19 +191,16 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> {
None => {
//
let inner_type = self.resolve_type_spec(context, of_type)?;
Type::parse(context, &format!("memref<?x{inner_type}>")).unwrap()
// MemRefType::new(inner_type, &[(u32::MAX) as u64], None, None).into()
/*
// Type::parse(context, &format!("memref<?x{inner_type}>")).unwrap()

llvm::r#type::r#struct(
context,
&[
llvm::r#type::opaque_pointer(context),
IntegerType::new(context, 64).into(),
Type::parse(context, &format!("memref<?x{inner_type}>")).unwrap(),
IntegerType::new(context, 64).into(),
],
false,
)
*/
}
},
})
Expand Down Expand Up @@ -555,6 +558,17 @@ fn compile_while<'c, 'this: 'c>(
Ok(merge_block)
}

fn is_local_copy(a: &Expression) -> Option<(&PathOp, bool)> {
match a {
Expression::Value(value) => match value {
ValueExpr::Path(path) => Some((path, false)),
ValueExpr::Deref(path) => Some((path, true)),
_ => None,
},
_ => None,
}
}

fn compile_let_stmt<'ctx, 'parent: 'ctx>(
session: &Session,
context: &'ctx MeliorContext,
Expand All @@ -565,6 +579,8 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>(
) -> Result<(), Box<dyn Error>> {
match &info.target {
LetStmtTarget::Simple { name, r#type } => {
let location = get_location(context, session, name.span.from);

let value = compile_expression(
session,
context,
Expand All @@ -574,10 +590,7 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>(
&info.value,
Some(r#type),
)?;

let location = get_location(context, session, name.span.from);

let memref_type = MemRefType::new(value.r#type(), &[1], None, None);
let memref_type = MemRefType::new(value.r#type(), &[], None, None);

let alloca: Value = block
.append_operation(memref::alloca(
Expand All @@ -598,6 +611,7 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>(
))
.result(0)?
.into();

block.append_operation(memref::store(value, alloca, &[k0], location));

scope_ctx
Expand Down Expand Up @@ -640,15 +654,67 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>(
Some(&local.type_spec),
)?;

let k0 = block
.append_operation(arith::constant(
context,
IntegerAttribute::new(0, Type::index(context)).into(),
location,
))
.result(0)?
.into();
block.append_operation(memref::store(value, local.value, &[k0], location));
if info.target.extra.is_empty() {
block.append_operation(memref::store(value, local.value, &[], location));
} else {
let mut store_target = block
.append_operation(memref::load(local.value, &[], location))
.result(0)?
.into();

let mut segment_iter = info.target.extra.iter().peekable();

while let Some(segment) = segment_iter.next() {
match segment {
PathSegment::FieldAccess(_) => todo!(),
PathSegment::ArrayIndex(index) => {
let index = compile_value_expr(
session, context, scope_ctx, helper, block, index, None,
)?;
let index_ty = Type::index(context);
let index = block
.append_operation(melior::dialect::index::castu(index, index_ty, location))
.result(0)?
.into();

if let TypeSpec::Array {
of_type: _,
size,
is_ref: _,
span,
} = &local.type_spec
{
let location = get_location(context, session, span.from);

#[allow(clippy::if_same_then_else)]
if size.is_some() {
// todo: check inbounds?
store_target = block
.append_operation(memref::load(store_target, &[index], location))
.result(0)?
.into();
} else {
store_target = block
.append_operation(memref::load(store_target, &[index], location))
.result(0)?
.into();
}

if segment_iter.peek().is_none() {
block.append_operation(memref::store(
value,
store_target,
&[index],
location,
));
}
} else {
panic!("type should be a array when indexing a value");
}
}
}
}
}

Ok(())
}
Expand Down Expand Up @@ -949,16 +1015,8 @@ fn compile_path_op<'ctx, 'parent: 'ctx>(
let location = get_location(context, session, path.first.span.from);

let mut value = if local.alloca {
let k0 = block
.append_operation(arith::constant(
context,
IntegerAttribute::new(0, Type::index(context)).into(),
location,
))
.result(0)?
.into();
block
.append_operation(memref::load(local.value, &[k0], location))
.append_operation(memref::load(local.value, &[], location))
.result(0)?
.into()
} else {
Expand All @@ -978,86 +1036,25 @@ fn compile_path_op<'ctx, 'parent: 'ctx>(
.into();

if let TypeSpec::Array {
of_type,
of_type: _,
size,
is_ref: _,
span,
} = &local.type_spec
{
let location = get_location(context, session, span.from);
let inner_type = scope_ctx.resolve_type_spec(context, of_type)?;
#[allow(clippy::if_same_then_else)]
if size.is_some() {
// todo: check inbounds?
value = block
.append_operation(memref::load(value, &[index], location))
.result(0)?
.into();
/*
// its a llvm.array
let ptr = block
.append_operation(llvm::get_element_ptr_dynamic(
context,
value,
&[index],
inner_type,
opaque_pointer(context),
location,
))
.result(0)?
.into();
value = block
.append_operation(llvm::load(
context,
ptr,
inner_type,
location,
LoadStoreOptions::new(),
))
.result(0)?
.into();
*/
} else {
value = block
.append_operation(memref::load(value, &[index], location))
.result(0)?
.into();
// its a struct {ptr,u64,u64}
/*
let ptr = block
.append_operation(llvm::extract_value(
context,
value,
DenseI64ArrayAttribute::new(context, &[0]),
opaque_pointer(context),
location,
))
.result(0)?
.into();
let elem_ptr = block
.append_operation(llvm::get_element_ptr_dynamic(
context,
ptr,
&[index],
inner_type,
opaque_pointer(context),
location,
))
.result(0)?
.into();
value = block
.append_operation(llvm::load(
context,
elem_ptr,
inner_type,
location,
LoadStoreOptions::new(),
))
.result(0)?
.into();
*/
}
} else {
panic!("type should be a array when indexing a value");
Expand Down Expand Up @@ -1087,89 +1084,39 @@ fn compile_deref<'ctx, 'parent: 'ctx>(
let inner_type = scope_ctx.resolve_type_spec_ref(context, &local.type_spec)?;

let mut value = block
.append_operation(llvm::load(
context,
local.value,
inner_type,
location,
LoadStoreOptions::new(),
))
.result(0)?
.into();
.append_operation(memref::load(local.value, &[], location)).result(0)?.into();


for segment in &path.extra {
match segment {
PathSegment::FieldAccess(_) => todo!(),
PathSegment::ArrayIndex(index) => {
let index =
compile_value_expr(session, context, scope_ctx, helper, block, index, None)?;
let index_ty = Type::index(context);
let index = block
.append_operation(melior::dialect::index::castu(index, index_ty, location))
.result(0)?
.into();

if let TypeSpec::Array {
of_type,
of_type: _,
size,
is_ref: _,
span,
} = &local.type_spec
{
let location = get_location(context, session, span.from);
let inner_type = scope_ctx.resolve_type_spec(context, of_type)?;
#[allow(clippy::if_same_then_else)]
if size.is_some() {
// its a llvm.array
let ptr = block
.append_operation(llvm::get_element_ptr_dynamic(
context,
value,
&[index],
inner_type,
opaque_pointer(context),
location,
))
.result(0)?
.into();

// todo: check inbounds?
value = block
.append_operation(llvm::load(
context,
ptr,
inner_type,
location,
LoadStoreOptions::new(),
))
.append_operation(memref::load(value, &[index], location))
.result(0)?
.into();
} else {
// its a struct {ptr,u64,u64}
let ptr = block
.append_operation(llvm::extract_value(
context,
value,
DenseI64ArrayAttribute::new(context, &[0]),
opaque_pointer(context),
location,
))
.result(0)?
.into();

let elem_ptr = block
.append_operation(llvm::get_element_ptr_dynamic(
context,
ptr,
&[index],
inner_type,
opaque_pointer(context),
location,
))
.result(0)?
.into();

value = block
.append_operation(llvm::load(
context,
elem_ptr,
inner_type,
location,
LoadStoreOptions::new(),
))
.append_operation(memref::load(value, &[index], location))
.result(0)?
.into();
}
Expand Down
Loading

0 comments on commit a49013c

Please sign in to comment.