Skip to content

Commit

Permalink
JIT support for match expressions. (powdr-labs#1838)
Browse files Browse the repository at this point in the history
This PR supports match expression, while most of the work regards
patterns.

Unfortunately, we cannot use rust patterns directly: The types are too
incompatible. For example a `match ("abc", 7) { ("abc", 7) => ...}`
cannot be directly translated, since the rust types here are `String`
and `ibig::IBig` and they cannot be used with these literals like that,
at least not directly and in all circumstances.

The way it is implemented here is that each pattern is compiled to code
that is supposed to evaluate to an Option. If the pattern matches, the
Option is a Some-value that contains the values that are assigned to the
variables in the pattern. The function also returns a string containing
the variable names. The function calls itself recursively on recursive
data structures.

---------

Co-authored-by: Gastón Zanitti <[email protected]>
Co-authored-by: Georg Wiese <[email protected]>
  • Loading branch information
3 people authored Oct 2, 2024
1 parent 990d357 commit eebbe5f
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 4 deletions.
127 changes: 123 additions & 4 deletions jit-compiler/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ use powdr_ast::{
display::quote,
types::{ArrayType, FunctionType, Type, TypeScheme},
ArrayLiteral, BinaryOperation, BinaryOperator, BlockExpression, FunctionCall, IfExpression,
IndexAccess, LambdaExpression, Number, StatementInsideBlock, UnaryOperation,
IndexAccess, LambdaExpression, MatchArm, MatchExpression, Number, Pattern,
StatementInsideBlock, UnaryOperation,
},
};
use powdr_number::{BigUint, FieldElement, LargeInt};
use powdr_number::{BigInt, BigUint, FieldElement, LargeInt};

pub struct CodeGenerator<'a, T> {
analyzed: &'a Analyzed<T>,
Expand Down Expand Up @@ -265,8 +266,8 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
"({})",
items
.iter()
.map(|i| self.format_expr(i))
.collect::<Result<Vec<_>, _>>()?
.map(|i| Ok(format!("({}.clone())", self.format_expr(i)?)))
.collect::<Result<Vec<_>, String>>()?
.join(", ")
),
Expression::BlockExpression(_, BlockExpression { statements, expr }) => {
Expand All @@ -283,6 +284,29 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
.unwrap_or_default()
)
}
Expression::MatchExpression(_, MatchExpression { scrutinee, arms }) => {
// We cannot use rust match expressions directly.
// Instead, we compile to a sequence of `if let Some(...)` statements.

// TODO try to find a solution where we do not introduce a variable
// or at least make it unique.
let var_name = "scrutinee__";
format!(
"{{\nlet {var_name} = ({}).clone();\n{}\n}}\n",
self.format_expr(scrutinee)?,
arms.iter()
.map(|MatchArm { pattern, value }| {
let (bound_vars, arm_test) = check_pattern(var_name, pattern)?;
Ok(format!(
"if let Some({bound_vars}) = ({arm_test}) {{\n{}\n}}",
self.format_expr(value)?,
))
})
.chain(std::iter::once(Ok("{ panic!(\"No match\"); }".to_string())))
.collect::<Result<Vec<_>, String>>()?
.join(" else ")
)
}
_ => return Err(format!("Implement {e}")),
})
}
Expand Down Expand Up @@ -315,6 +339,90 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
}
}

/// Used for patterns in match and let statements:
/// `value_name` is an expression string that is to be matched against `pattern`.
/// Returns a rust pattern string (tuple of bound variables, might be nested) and a code string
/// that, when executed, returns an Option with the values for the bound variables if the pattern
/// matched `value_name` and `None` otherwise.
///
/// So if `let (vars, code) = check_pattern("x", pattern)?;`, then the return value
/// can be used like this: `if let Some({vars}) = ({code}) {{ .. }}`
fn check_pattern(value_name: &str, pattern: &Pattern) -> Result<(String, String), String> {
Ok(match pattern {
Pattern::CatchAll(_) => ("()".to_string(), "Some(())".to_string()),
Pattern::Number(_, n) => (
"_".to_string(),
format!(
"({value_name}.clone() == {}).then_some(())",
format_signed_integer(n)
),
),
Pattern::String(_, s) => (
"_".to_string(),
format!("({value_name}.clone() == {}).then_some(())", quote(s)),
),
Pattern::Tuple(_, items) => {
let mut vars = vec![];
let inner_code = items
.iter()
.enumerate()
.map(|(i, item)| {
let (v, code) = check_pattern(&format!("{value_name}.{i}"), item)?;
vars.push(v);
Ok(format!("({code})?"))
})
.collect::<Result<Vec<_>, String>>()?
.join(", ");
(
format!("({})", vars.join(", ")),
format!("(|| Some(({inner_code})))()"),
)
}
Pattern::Array(_, items) => {
let mut vars = vec![];
let mut ellipsis_seen = false;
// This will be code to check the individual items in the array pattern.
let inner_code = items
.iter()
.enumerate()
.filter_map(|(i, item)| {
if matches!(item, Pattern::Ellipsis(_)) {
ellipsis_seen = true;
return None;
}
// Compute an expression to access the item.
Some(if ellipsis_seen {
let i_rev = items.len() - i;
(format!("{value_name}[{value_name}.len() - {i_rev}]"), item)
} else {
(format!("{value_name}[{i}]"), item)
})
})
.map(|(access_name, item)| {
let (v, code) = check_pattern(&access_name, item)?;
vars.push(v);
Ok(format!("({code})?"))
})
.collect::<Result<Vec<_>, String>>()?
.join(", ");
let length_check = if ellipsis_seen {
format!("{value_name}.len() >= {}", items.len() - 1)
} else {
format!("{value_name}.len() == {}", items.len())
};
(
format!("({})", vars.join(", ")),
format!("if {length_check} {{ (|| Some(({inner_code})))() }} else {{ None }}"),
)
}
Pattern::Variable(_, var) => (var.to_string(), format!("Some({value_name}.clone())")),
Pattern::Enum(..) => {
return Err(format!("Enums as patterns not yet implemented: {pattern}"));
}
Pattern::Ellipsis(_) => unreachable!(),
})
}

fn format_unsigned_integer(n: &BigUint) -> String {
if let Ok(n) = u64::try_from(n) {
format!("ibig::IBig::from({n}_u64)")
Expand All @@ -329,6 +437,17 @@ fn format_unsigned_integer(n: &BigUint) -> String {
}
}

fn format_signed_integer(n: &BigInt) -> String {
if let Ok(n) = BigUint::try_from(n) {
format_unsigned_integer(&n)
} else {
format!(
"-{}",
format_unsigned_integer(&BigUint::try_from(-n).unwrap())
)
}
}

fn map_type(ty: &Type) -> String {
match ty {
Type::Bottom | Type::Bool => format!("{ty}"),
Expand Down
96 changes: 96 additions & 0 deletions jit-compiler/tests/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,99 @@ fn simple_field() {
assert_eq!(r.call(2), 2);
assert_eq!(r.call(3), 3);
}

#[test]
fn match_number() {
let f = compile(
r#"let f: int -> int = |x| match x {
0 => 1,
1 => 2,
2 => 3,
_ => 0,
};"#,
"f",
);

assert_eq!(f.call(0), 1);
assert_eq!(f.call(1), 2);
assert_eq!(f.call(2), 3);
assert_eq!(f.call(3), 0);
}

#[test]
fn match_negative() {
let f = compile(
r#"let f: int -> int = |x| match -x {
-0 => 1,
-1 => 2,
-2 => 3,
_ => 9,
};"#,
"f",
);

assert_eq!(f.call(0), 1);
assert_eq!(f.call(1), 2);
assert_eq!(f.call(2), 3);
assert_eq!(f.call(3), 9);
}

#[test]
fn match_string() {
let f = compile(
r#"let f: int -> int = |x| match "abc" {
"ab" => 1,
"abc" => 2,
_ => 0,
};"#,
"f",
);

assert_eq!(f.call(0), 2);
assert_eq!(f.call(1), 2);
}

#[test]
fn match_tuples() {
let f = compile(
r#"let f: int -> int = |x| match (x, ("abc", x + 3)) {
(0, _) => 1,
(1, ("ab", _)) => 2,
(1, ("abc", t)) => t,
(a, (_, b)) => a + b,
};"#,
"f",
);

assert_eq!(f.call(0), 1);
assert_eq!(f.call(1), 4);
assert_eq!(f.call(2), 7);
assert_eq!(f.call(3), 9);
}

#[test]
fn match_array() {
let f = compile(
r#"let f: int -> int = |y| match (y, [1, 3, 3, 4]) {
(0, _) => 1,
(1, [1, 3]) => 20,
(1, [.., 2, 4]) => 20,
(1, [.., x, 4]) => x - 1,
(2, [x, .., 0]) => 22,
(2, [x, .., 4]) => x + 2,
(3, [1, 3, 3, 4, ..]) => 4,
(4, [1, 3, 3, 4]) => 5,
(5, [..]) => 6,
_ => 7
};"#,
"f",
);

assert_eq!(f.call(0), 1);
assert_eq!(f.call(1), 2);
assert_eq!(f.call(2), 3);
assert_eq!(f.call(3), 4);
assert_eq!(f.call(4), 5);
assert_eq!(f.call(5), 6);
assert_eq!(f.call(6), 7);
}

0 comments on commit eebbe5f

Please sign in to comment.