Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
bpowers committed Jul 28, 2024
1 parent 3672642 commit 7fdab48
Show file tree
Hide file tree
Showing 7 changed files with 312 additions and 104 deletions.
2 changes: 1 addition & 1 deletion src/engine/error_codes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export function errorCodeDescription(code: ErrorCode): string {
case ErrorCode.UnknownBuiltin:
return 'Reference to unknown or unimplemented builtin';
case ErrorCode.BadBuiltinArgs:
return 'Builtin function arguments';
return 'Incorrect arguments to a builtin function (e.g. too many, too few)';
case ErrorCode.EmptyEquation:
return 'Variable has empty equation';
case ErrorCode.BadModuleInputDst:
Expand Down
58 changes: 54 additions & 4 deletions src/simlin-engine/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,18 @@ impl Expr {
let a = args.remove(0);
BuiltinFn::$builtin_fn(Box::new(a), Box::new(b))
}};
($builtin_fn:tt, 1, 2) => {{
if args.len() == 1 {
let a = args.remove(0);
BuiltinFn::$builtin_fn(Box::new(a), None)
} else if args.len() == 2 {
let b = args.remove(1);
let a = args.remove(0);
BuiltinFn::$builtin_fn(Box::new(a), Some(Box::new(b)))
} else {
return eqn_err!(BadBuiltinArgs, loc.start, loc.end);
}
}};
($builtin_fn:tt, 3) => {{
if args.len() != 3 {
return eqn_err!(BadBuiltinArgs, loc.start, loc.end);
Expand All @@ -339,6 +351,26 @@ impl Expr {
let a = args.remove(0);
BuiltinFn::$builtin_fn(Box::new(a), Box::new(b), Box::new(c))
}};
($builtin_fn:tt, 1, 3) => {{
if args.len() == 1 {
let a = args.remove(0);
BuiltinFn::$builtin_fn(Box::new(a), None)
} else if args.len() == 2 {
let b = args.remove(1);
let a = args.remove(0);
BuiltinFn::$builtin_fn(Box::new(a), Some((Box::new(b), None)))
} else if args.len() == 3 {
let c = args.remove(2);
let b = args.remove(1);
let a = args.remove(0);
BuiltinFn::$builtin_fn(
Box::new(a),
Some((Box::new(b), Some(Box::new(c)))),
)
} else {
return eqn_err!(BadBuiltinArgs, loc.start, loc.end);
}
}};
($builtin_fn:tt, 2, 3) => {{
if args.len() == 2 {
let b = args.remove(1);
Expand Down Expand Up @@ -381,8 +413,8 @@ impl Expr {
}
"ln" => check_arity!(Ln, 1),
"log10" => check_arity!(Log10, 1),
"max" => check_arity!(Max, 2),
"min" => check_arity!(Min, 2),
"max" => check_arity!(Max, 1, 2),
"min" => check_arity!(Min, 1, 2),
"pi" => check_arity!(Pi, 0),
"pulse" => check_arity!(Pulse, 2, 3),
"ramp" => check_arity!(Ramp, 2, 3),
Expand All @@ -395,6 +427,10 @@ impl Expr {
"time_step" | "dt" => check_arity!(TimeStep, 0),
"initial_time" => check_arity!(StartTime, 0),
"final_time" => check_arity!(FinalTime, 0),
"rank" => check_arity!(Rank, 1, 3),
"size" => check_arity!(Size, 1),
"stddev" => check_arity!(Stddev, 1),
"sum" => check_arity!(Sum, 1),
_ => {
// TODO: this could be a table reference, array reference,
// or module instantiation according to 3.3.2 of the spec
Expand Down Expand Up @@ -468,11 +504,11 @@ impl Expr {
),
BuiltinFn::Max(a, b) => BuiltinFn::Max(
Box::new(a.constify_dimensions(scope)),
Box::new(b.constify_dimensions(scope)),
b.map(|expr| Box::new(expr.constify_dimensions(scope))),
),
BuiltinFn::Min(a, b) => BuiltinFn::Min(
Box::new(a.constify_dimensions(scope)),
Box::new(b.constify_dimensions(scope)),
b.map(|expr| Box::new(expr.constify_dimensions(scope))),
),
BuiltinFn::Step(a, b) => BuiltinFn::Step(
Box::new(a.constify_dimensions(scope)),
Expand All @@ -497,6 +533,20 @@ impl Expr {
Box::new(b.constify_dimensions(scope)),
c.map(|arg| Box::new(arg.constify_dimensions(scope))),
),
BuiltinFn::Rank(a, rest) => BuiltinFn::Rank(
Box::new(a.constify_dimensions(scope)),
rest.map(|(b, c)| {
(
Box::new(b.constify_dimensions(scope)),
c.map(|c| Box::new(c.constify_dimensions(scope))),
)
}),
),
BuiltinFn::Size(a) => BuiltinFn::Size(Box::new(a.constify_dimensions(scope))),
BuiltinFn::Stddev(a) => {
BuiltinFn::Stddev(Box::new(a.constify_dimensions(scope)))
}
BuiltinFn::Sum(a) => BuiltinFn::Sum(Box::new(a.constify_dimensions(scope))),
};
Expr::App(func, loc)
}
Expand Down
143 changes: 92 additions & 51 deletions src/simlin-engine/src/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ pub enum BuiltinFn<Expr> {
IsModuleInput(String, Loc),
Ln(Box<Expr>),
Log10(Box<Expr>),
Max(Box<Expr>, Box<Expr>),
// max takes 2 scalar args OR 1-2 args for an array
Max(Box<Expr>, Option<Box<Expr>>),
Mean(Vec<Expr>),
Min(Box<Expr>, Box<Expr>),
// max takes 2 scalar args OR 1-2 args for an array
Min(Box<Expr>, Option<Box<Expr>>),
Pi,
Pulse(Box<Expr>, Box<Expr>, Option<Box<Expr>>),
Ramp(Box<Expr>, Box<Expr>, Option<Box<Expr>>),
Expand All @@ -80,38 +82,49 @@ pub enum BuiltinFn<Expr> {
TimeStep,
StartTime,
FinalTime,
// array-only builtins
Rank(Box<Expr>, Option<(Box<Expr>, Option<Box<Expr>>)>),
Size(Box<Expr>),
Stddev(Box<Expr>),
Sum(Box<Expr>),
}

impl<Expr> BuiltinFn<Expr> {
pub fn name(&self) -> &'static str {
use BuiltinFn::*;
match self {
BuiltinFn::Lookup(_, _, _) => "lookup",
BuiltinFn::Abs(_) => "abs",
BuiltinFn::Arccos(_) => "arccos",
BuiltinFn::Arcsin(_) => "arcsin",
BuiltinFn::Arctan(_) => "arctan",
BuiltinFn::Cos(_) => "cos",
BuiltinFn::Exp(_) => "exp",
BuiltinFn::Inf => "inf",
BuiltinFn::Int(_) => "int",
BuiltinFn::IsModuleInput(_, _) => "ismoduleinput",
BuiltinFn::Ln(_) => "ln",
BuiltinFn::Log10(_) => "log10",
BuiltinFn::Max(_, _) => "max",
BuiltinFn::Mean(_) => "mean",
BuiltinFn::Min(_, _) => "min",
BuiltinFn::Pi => "pi",
BuiltinFn::Pulse(_, _, _) => "pulse",
BuiltinFn::Ramp(_, _, _) => "ramp",
BuiltinFn::SafeDiv(_, _, _) => "safediv",
BuiltinFn::Sin(_) => "sin",
BuiltinFn::Sqrt(_) => "sqrt",
BuiltinFn::Step(_, _) => "step",
BuiltinFn::Tan(_) => "tan",
BuiltinFn::Time => "time",
BuiltinFn::TimeStep => "time_step",
BuiltinFn::StartTime => "initial_time",
BuiltinFn::FinalTime => "final_time",
Lookup(_, _, _) => "lookup",
Abs(_) => "abs",
Arccos(_) => "arccos",
Arcsin(_) => "arcsin",
Arctan(_) => "arctan",
Cos(_) => "cos",
Exp(_) => "exp",
Inf => "inf",
Int(_) => "int",
IsModuleInput(_, _) => "ismoduleinput",
Ln(_) => "ln",
Log10(_) => "log10",
Max(_, _) => "max",
Mean(_) => "mean",
Min(_, _) => "min",
Pi => "pi",
Pulse(_, _, _) => "pulse",
Ramp(_, _, _) => "ramp",
SafeDiv(_, _, _) => "safediv",
Sin(_) => "sin",
Sqrt(_) => "sqrt",
Step(_, _) => "step",
Tan(_) => "tan",
Time => "time",
TimeStep => "time_step",
StartTime => "initial_time",
FinalTime => "final_time",
// array only builtins
Rank(_, _) => "rank",
Size(_) => "size",
Stddev(_) => "stddev",
Sum(_) => "sum",
}
}
}
Expand All @@ -127,27 +140,33 @@ pub fn is_builtin_fn(name: &str) -> bool {
is_0_arity_builtin_fn(name)
|| matches!(
name,
// scalar builtins
"lookup"
| "abs"
| "arccos"
| "arcsin"
| "arctan"
| "cos"
| "exp"
| "int"
| "ismoduleinput"
| "ln"
| "log10"
| "max"
| "mean"
| "min"
| "pulse"
| "ramp"
| "safediv"
| "sin"
| "sqrt"
| "step"
| "tan"
| "abs"
| "arccos"
| "arcsin"
| "arctan"
| "cos"
| "exp"
| "int"
| "ismoduleinput"
| "ln"
| "log10"
| "max"
| "mean"
| "min"
| "pulse"
| "ramp"
| "safediv"
| "sin"
| "sqrt"
| "step"
| "tan"
// array-only builtins
| "rank"
| "size"
| "stddev"
| "sum"
)
}

Expand Down Expand Up @@ -183,14 +202,23 @@ where
| BuiltinFn::Log10(a)
| BuiltinFn::Sin(a)
| BuiltinFn::Sqrt(a)
| BuiltinFn::Tan(a) => cb(BuiltinContents::Expr(a)),
| BuiltinFn::Tan(a)
| BuiltinFn::Size(a)
| BuiltinFn::Stddev(a)
| BuiltinFn::Sum(a) => cb(BuiltinContents::Expr(a)),
BuiltinFn::Mean(args) => {
args.iter().for_each(|a| cb(BuiltinContents::Expr(a)));
}
BuiltinFn::Max(a, b) | BuiltinFn::Min(a, b) | BuiltinFn::Step(a, b) => {
BuiltinFn::Step(a, b) => {
cb(BuiltinContents::Expr(a));
cb(BuiltinContents::Expr(b));
}
BuiltinFn::Max(a, b) | BuiltinFn::Min(a, b) => {
cb(BuiltinContents::Expr(a));
if let Some(b) = b {
cb(BuiltinContents::Expr(b));
}
}
BuiltinFn::Pulse(a, b, c) | BuiltinFn::Ramp(a, b, c) | BuiltinFn::SafeDiv(a, b, c) => {
cb(BuiltinContents::Expr(a));
cb(BuiltinContents::Expr(b));
Expand All @@ -199,6 +227,15 @@ where
None => {}
}
}
BuiltinFn::Rank(a, rest) => {
cb(BuiltinContents::Expr(a));
if let Some((b, c)) = rest {
cb(BuiltinContents::Expr(b));
if let Some(c) = c {
cb(BuiltinContents::Expr(c));
}
}
}
}
}

Expand All @@ -207,6 +244,10 @@ fn test_is_builtin_fn() {
assert!(is_builtin_fn("lookup"));
assert!(!is_builtin_fn("lookupz"));
assert!(is_builtin_fn("log10"));
assert!(is_builtin_fn("sum"));
assert!(is_builtin_fn("rank"));
assert!(is_builtin_fn("size"));
assert!(is_builtin_fn("stddev"));
}

#[test]
Expand Down
2 changes: 2 additions & 0 deletions src/simlin-engine/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub enum ErrorCode {
TodoWildcard,
TodoStarRange,
TodoRange,
TodoArrayBuiltin,
}

impl fmt::Display for ErrorCode {
Expand Down Expand Up @@ -125,6 +126,7 @@ impl fmt::Display for ErrorCode {
TodoWildcard => "todo_wildcard",
TodoStarRange => "todo_star_range",
TodoRange => "todo_range",
TodoArrayBuiltin => "todo_array_builtin",
};

write!(f, "{}", name)
Expand Down
Loading

0 comments on commit 7fdab48

Please sign in to comment.