diff --git a/Cargo.toml b/Cargo.toml index 086bda8..b2555a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ name = "repl" path = "src/bin/repl.rs" [features] -default = ["pemdas"] +default = ["pemdas", "compile-time-optimizations"] pemdas = [] pejmdas = [] +compile-time-optimizations = [] diff --git a/src/element/comp_time.rs b/src/element/comp_time.rs new file mode 100644 index 0000000..7e97786 --- /dev/null +++ b/src/element/comp_time.rs @@ -0,0 +1,211 @@ +use super::{BinOp, Element, FunctionCall, UnOp}; +use crate::token::Operator; + +pub trait CompTime<'a> { + fn simplify(self) -> Element<'a>; +} + +impl<'a> CompTime<'a> for Element<'a> { + #[inline] + fn simplify(self) -> Self { + match self { + Self::BinOp(binop) => binop.simplify(), + Self::UnOp(unop) => unop.simplify(), + Self::Function(func) => func.simplify(), + Self::Number(_) | Self::Variable(_) => self, + } + } +} + +impl<'a> CompTime<'a> for BinOp<'a> { + #[inline] + #[allow(clippy::too_many_lines)] + fn simplify(self) -> Element<'a> { + use Element::Number; + use Operator::{Divide, Minus, Modulo, Plus, Power, Times}; + match self { + /////////////////////////// Additions /////////////////////////// + // 0 + .. => .. + BinOp { op: Plus, lhs, rhs } if lhs == Number(0.0_f64) => { + rhs.simplify() + }, + // .. + 0 => .. + BinOp { op: Plus, lhs, rhs } if rhs == Number(0.0_f64) => { + lhs.simplify() + }, + ////////////////////////// Subtractions ///////////////////////// + // 0 - .. => -.. + BinOp { + op: Minus, + lhs, + rhs, + } if lhs == Number(0.0_f64) => { + UnOp::new(Operator::Minus, rhs.simplify()).simplify() + }, + // .. - 0 => .. + BinOp { + op: Minus, + lhs, + rhs, + } if rhs == Number(0.0_f64) => lhs.simplify(), + // .. - .. => 0 + BinOp { + op: Minus, + lhs, + rhs, + } if lhs == rhs => Number(0.0_f64), + //////////////////////// Multiplications //////////////////////// + // 0 * .. => 0 + BinOp { op: Times, lhs, .. } if lhs == Number(0.0_f64) => { + Number(0.0_f64) + }, + // .. * 0 => 0 + BinOp { op: Times, rhs, .. } if rhs == Number(0.0_f64) => { + Number(0.0_f64) + }, + // 1 * .. => .. + BinOp { + op: Times, + lhs, + rhs, + } if lhs == Number(1.0_f64) => rhs.simplify(), + // .. * 1 => .. + BinOp { + op: Times, + lhs, + rhs, + } if rhs == Number(1.0_f64) => lhs.simplify(), + /////////////////////////// Divisions /////////////////////////// + // 0 / .. => 0 + BinOp { + op: Divide, lhs, .. + } if lhs == Number(0.0_f64) => Number(0.0_f64), + // .. / 0 => inf + BinOp { + op: Divide, rhs, .. + } if rhs == Number(0.0_f64) => Number(f64::INFINITY), + // .. / 1 => .. + BinOp { + op: Divide, + lhs, + rhs, + } if rhs == Number(1.0_f64) => lhs.simplify(), + // .. / .. => 1 + BinOp { + op: Divide, + lhs, + rhs, + } if lhs == rhs => Number(1.0_f64), + ///////////////////////////// Powers //////////////////////////// + // 0 ^ .. => 0 + BinOp { op: Power, lhs, .. } if lhs == Number(0.0_f64) => { + Number(0.0_f64) + }, + // .. ^ 0 => 1 + BinOp { + op: Divide, rhs, .. + } if rhs == Number(0.0_f64) => Number(1.0_f64), + // .. ^ 1 => .. + BinOp { + op: Power, + lhs, + rhs, + } if rhs == Number(1.0_f64) => lhs.simplify(), + //////////////////////////// Modulos //////////////////////////// + // 0 % .. => 0 + BinOp { + op: Modulo, lhs, .. + } if lhs == Number(0.0_f64) => Number(0.0_f64), + // .. % 0 => NaN + BinOp { + op: Modulo, rhs, .. + } if rhs == Number(0.0_f64) => Number(f64::NAN), + // .. % 1 => 0 + BinOp { + op: Modulo, rhs, .. + } if rhs == Number(1.0_f64) => Number(0.0_f64), + // .. % .. => 0 + BinOp { + op: Modulo, + lhs, + rhs, + } if lhs == rhs => Number(0.0_f64), + // other + BinOp { + op, + rhs: Number(rhs), + lhs: Number(lhs), + } => { + let result = match op { + Plus => lhs + rhs, + Minus => lhs - rhs, + Times => lhs * rhs, + Divide => lhs / rhs, + Power => lhs.powf(rhs), + #[allow(clippy::modulo_arithmetic)] + Modulo => lhs % rhs, + }; + Number(result) + }, + BinOp { op, rhs, lhs } => Element::BinOp(Box::new(BinOp::new( + op, + lhs.simplify(), + rhs.simplify(), + ))), + } + } +} + +impl<'a> CompTime<'a> for UnOp<'a> { + #[inline] + fn simplify(self) -> Element<'a> { + let operand = self.operand.simplify(); + #[allow(clippy::unreachable)] + match self.op { + Operator::Plus => operand, + Operator::Minus => match operand { + Element::Number(num) => Element::Number(-num), + Element::UnOp(unop) => match unop.op { + Operator::Plus => Element::UnOp(Box::new(UnOp::new( + Operator::Minus, + unop.operand, + ))), + Operator::Minus => unop.operand, + Operator::Times + | Operator::Divide + | Operator::Power + | Operator::Modulo => unreachable!(), + }, + Element::BinOp(_) + | Element::Function(_) + | Element::Variable(_) => Element::UnOp(Box::new(UnOp { + op: self.op, + operand, + })), + }, + Operator::Times + | Operator::Divide + | Operator::Power + | Operator::Modulo => unreachable!(), + } + } +} + +impl<'a> CompTime<'a> for FunctionCall<'a> { + #[inline] + fn simplify(self) -> Element<'a> { + let arg = self.arg.simplify(); + match arg { + Element::Number(num) => Element::Number((self.func)(num)), + Element::BinOp(_) + | Element::UnOp(_) + | Element::Function(_) + | Element::Variable(_) => { + Element::Function(Box::new(FunctionCall { + func: self.func, + arg, + })) + }, + } + } +} diff --git a/src/element/mod.rs b/src/element/mod.rs index 059b81a..e5b0454 100644 --- a/src/element/mod.rs +++ b/src/element/mod.rs @@ -2,10 +2,12 @@ use core::fmt; /* Modules */ mod binop; +mod comp_time; mod function_call; mod unop; /* Exports */ pub use binop::BinOp; +pub use comp_time::CompTime; pub use function_call::FunctionCall; pub use unop::UnOp; diff --git a/src/parser.rs b/src/parser.rs index 985923e..4f98c7f 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -6,7 +6,7 @@ use std::collections::HashSet; /* Crate imports */ use crate::{ context::Context, - element::{BinOp, Element, FunctionCall, UnOp}, + element::{BinOp, CompTime, Element, FunctionCall, UnOp}, macros::{trust_me, yeet}, token::{Identifier, Operator}, utils::precedence, @@ -112,6 +112,11 @@ impl<'a> ParserImpl<'a> { el = Element::BinOp(Box::new(BinOp::new(op, el, rhs))); } + #[cfg(feature = "compile-time-optimizations")] + { + el = el.simplify(); + }; + Ok(el) }