Skip to content

Commit

Permalink
rm all _f64 annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
vic1707 committed Nov 11, 2023
1 parent b1a51b8 commit 36cf61e
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 78 deletions.
64 changes: 25 additions & 39 deletions src/element/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,45 +46,43 @@ impl<'a> Simplify<'a> for BinOp<'a> {
match self {
/////////////////////////// Additions ///////////////////////////
// 0 + .. => ..
BinOp { op: Plus, lhs, rhs } if lhs == Number(0.0_f64) => rhs,
BinOp { op: Plus, lhs, rhs } if lhs == Number(0.0) => rhs,
// .. + 0 => ..
BinOp { op: Plus, lhs, rhs } if rhs == Number(0.0_f64) => lhs,
BinOp { op: Plus, lhs, rhs } if rhs == Number(0.0) => lhs,
////// NIGHTLY FEATURES //////
#[cfg(NIGHTLY)]
// (-..) + .. => 0
BinOp {
op: Plus,
lhs: Element::UnOp(box UnOp { op: Minus, operand }),
rhs,
} if operand == rhs => Number(0.0_f64),
} if operand == rhs => Number(0.0),
#[cfg(NIGHTLY)]
// .. + (-..) => 0
BinOp {
op: Plus,
lhs,
rhs: Element::UnOp(box UnOp { op: Minus, operand }),
} if lhs == operand => Number(0.0_f64),
} if lhs == operand => Number(0.0),
////////////////////////// Subtractions /////////////////////////
// 0 - .. => -..
BinOp {
op: Minus,
lhs,
rhs,
} if lhs == Number(0.0_f64) => {
UnOp::new_element(Operator::Minus, rhs)
},
} if lhs == Number(0.0) => UnOp::new_element(Operator::Minus, rhs),
// .. - 0 => ..
BinOp {
op: Minus,
lhs,
rhs,
} if rhs == Number(0.0_f64) => lhs,
} if rhs == Number(0.0) => lhs,
// .. - .. => 0
BinOp {
op: Minus,
lhs,
rhs,
} if lhs == rhs => Number(0.0_f64),
} if lhs == rhs => Number(0.0),
////// NIGHTLY FEATURES //////
#[cfg(NIGHTLY)]
// .. - (-..) => .. + ..
Expand All @@ -95,104 +93,92 @@ impl<'a> Simplify<'a> for BinOp<'a> {
} => BinOp::new_element(Operator::Plus, lhs, operand),
//////////////////////// Multiplications ////////////////////////
// 0 * .. => 0
BinOp { op: Times, lhs, .. } if lhs == Number(0.0_f64) => {
Number(0.0_f64)
},
BinOp { op: Times, lhs, .. } if lhs == Number(0.0) => Number(0.0),
// .. * 0 => 0
BinOp { op: Times, rhs, .. } if rhs == Number(0.0_f64) => {
Number(0.0_f64)
},
BinOp { op: Times, rhs, .. } if rhs == Number(0.0) => Number(0.0),
// 1 * .. => ..
BinOp {
op: Times,
lhs,
rhs,
} if lhs == Number(1.0_f64) => rhs,
} if lhs == Number(1.0) => rhs,
// .. * 1 => ..
BinOp {
op: Times,
lhs,
rhs,
} if rhs == Number(1.0_f64) => lhs,
} if rhs == Number(1.0) => lhs,
/////////////////////////// Divisions ///////////////////////////
// 0/0 => NaN // special case
BinOp {
op: Divide,
lhs,
rhs,
} if lhs == Number(0.0_f64) && rhs == Number(0.0_f64) => {
Number(f64::NAN)
},
} if lhs == Number(0.0) && rhs == Number(0.0) => Number(f64::NAN),
// 0 / .. => 0
BinOp {
op: Divide, lhs, ..
} if lhs == Number(0.0_f64) => Number(0.0_f64),
} if lhs == Number(0.0) => Number(0.0),
// .. / 0 => inf
BinOp {
op: Divide, rhs, ..
} if rhs == Number(0.0_f64) => Number(f64::INFINITY),
} if rhs == Number(0.0) => Number(f64::INFINITY),
// .. / 1 => ..
BinOp {
op: Divide,
lhs,
rhs,
} if rhs == Number(1.0_f64) => lhs,
} if rhs == Number(1.0) => lhs,
// .. / .. => 1
BinOp {
op: Divide,
lhs,
rhs,
} if lhs == rhs => Number(1.0_f64),
} if lhs == rhs => Number(1.0),
///////////////////////////// Powers ////////////////////////////
// 0 ^ 0 => 1 // special case
BinOp {
op: Power,
lhs,
rhs,
} if lhs == Number(0.0_f64) && rhs == Number(0.0_f64) => {
Number(1.0_f64)
},
} if lhs == Number(0.0) && rhs == Number(0.0) => Number(1.0),
// 0 ^ .. => 0
BinOp { op: Power, lhs, .. } if lhs == Number(0.0_f64) => {
Number(0.0_f64)
},
BinOp { op: Power, lhs, .. } if lhs == Number(0.0) => Number(0.0),
// .. ^ 0 => 1
BinOp {
op: Divide, rhs, ..
} if rhs == Number(0.0_f64) => Number(1.0_f64),
} if rhs == Number(0.0) => Number(1.0),
// .. ^ 1 => ..
BinOp {
op: Power,
lhs,
rhs,
} if rhs == Number(1.0_f64) => lhs,
} if rhs == Number(1.0) => lhs,
//////////////////////////// Modulos ////////////////////////////
// 0 % 0 => NaN // special case
BinOp {
op: Modulo,
lhs,
rhs,
} if lhs == Number(0.0_f64) && rhs == Number(0.0_f64) => {
Number(f64::NAN)
},
} if lhs == Number(0.0) && rhs == Number(0.0) => Number(f64::NAN),
// 0 % .. => 0
BinOp {
op: Modulo, lhs, ..
} if lhs == Number(0.0_f64) => Number(0.0_f64),
} if lhs == Number(0.0) => Number(0.0),
// .. % 0 => NaN
BinOp {
op: Modulo, rhs, ..
} if rhs == Number(0.0_f64) => Number(f64::NAN),
} if rhs == Number(0.0) => Number(f64::NAN),
// .. % 1 => 0
BinOp {
op: Modulo, rhs, ..
} if rhs == Number(1.0_f64) => Number(0.0_f64),
} if rhs == Number(1.0) => Number(0.0),
// .. % .. => 0
BinOp {
op: Modulo,
lhs,
rhs,
} if lhs == rhs => Number(0.0_f64),
} if lhs == rhs => Number(0.0),
/////////////////////////// 2 Numbers ///////////////////////////
BinOp {
op,
Expand Down
8 changes: 4 additions & 4 deletions src/tests/parser/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
};

fn double(x: f64) -> f64 {
x * 2.0_f64
x * 2.0
}
const DOUBLE: Function = xprs_fn!("DOUBLE", double, 1);
fn add(x: f64, y: f64) -> f64 {
Expand All @@ -23,15 +23,15 @@ const MEAN: Function = xprs_fn!("MEAN", mean);
fn get_parser_with_ctx() -> Parser<'static> {
let mut ctx = Context::default();

ctx.add_var("x", 2.0_f64);
ctx.add_var("phi", 1.618_033_988_749_895_f64);
ctx.add_var("x", 2.0);
ctx.add_var("phi", 1.618_033_988_749_895);

ctx.add_func("double", DOUBLE);
ctx.add_func("add", ADD);

let mut parser = Parser::new_with_ctx(ctx);

parser.ctx_mut().add_var("y", 1.0_f64);
parser.ctx_mut().add_var("y", 1.0);
parser.ctx_mut().add_func("mean", MEAN);

parser
Expand Down
60 changes: 28 additions & 32 deletions src/tests/xprs/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ use crate::Parser;

const ERROR_MARGIN: f64 = f64::EPSILON;

// shitty type because of clippy and default numeric fallback
// https://github.com/rust-lang/rust-clippy/issues/11535
type InputVarsResult = (&'static str, &'static [(&'static str, f64)], f64);
/// 2(3)
/// 2 + 3 ^ 2 * 3 + 4
/// 2^2^(2^2 + 1)
Expand All @@ -16,43 +19,33 @@ const ERROR_MARGIN: f64 = f64::EPSILON;
/// PEMDAS vs PEJMDAS
/// 6/2(2+1)
/// 1/2x
fn get_valid_test_cases(
) -> [(&'static str, HashMap<&'static str, f64>, f64); 10] {
[
("2(3)", [].into(), 6.0_f64),
("2 + 3 ^ 2 * 3 + 4", [].into(), 33.0_f64),
("2^2^(2^2 + 1)", [].into(), 1024.0_f64),
(
"2 * x + 3y + 4x + 5",
[("x", 2.0_f64), ("y", 3.0_f64)].into(),
26.0_f64,
),
("sin(-cos(2))", [].into(), 0.404_239_153_852_265_8_f64),
("sin(2)^2", [].into(), 0.826_821_810_431_806_f64),
(
"12 + 3-1x+3y x",
[("x", 2.0_f64), ("y", 3.0_f64)].into(),
31.0_f64,
),
("2 (3+4) 5", [].into(), 70.0_f64),
#[cfg(feature = "pemdas")]
("6/2(2+1)", [].into(), 9.0_f64), // is "6/2*(2+1)"
#[cfg(feature = "pejmdas")]
("6/2(2+1)", [].into(), 1.0_f64), // is "6/(2*(2+1))"
#[cfg(feature = "pemdas")]
("1/2x", [("x", 2.0_f64)].into(), 1_f64), // is "(1/2)*x"
#[cfg(feature = "pejmdas")]
("1/2x", [("x", 2.0_f64)].into(), 0.25_f64), // is "1/(2*x)"
]
}
const VALID: [InputVarsResult; 10] = [
("2(3)", &[], 6.0),
("2 + 3 ^ 2 * 3 + 4", &[], 33.0),
("2^2^(2^2 + 1)", &[], 1024.0),
("2 * x + 3y + 4x + 5", &[("x", 2.0), ("y", 3.0)], 26.0),
("sin(-cos(2))", &[], 0.404_239_153_852_265_8),
("sin(2)^2", &[], 0.826_821_810_431_806),
("12 + 3-1x+3y x", &[("x", 2.0), ("y", 3.0)], 31.0),
("2 (3+4) 5", &[], 70.0),
#[cfg(feature = "pemdas")]
("6/2(2+1)", &[], 9.0), // is "6/2*(2+1)"
#[cfg(feature = "pejmdas")]
("6/2(2+1)", &[], 1.0), // is "6/(2*(2+1))"
#[cfg(feature = "pemdas")]
("1/2x", &[("x", 2.0)], 1.0), // is "(1/2)*x"
#[cfg(feature = "pejmdas")]
("1/2x", &[("x", 2.0)], 0.25), // is "1/(2*x)"
];

#[test]
fn test_valid_eval() {
let parser = Parser::default();

for (input, vars, expected) in get_valid_test_cases() {
for (input, vars, expected) in VALID {
let var_map: HashMap<&str, f64> = vars.iter().copied().collect();
let xprs = parser.parse(input).unwrap();
let result = xprs.eval(&vars).unwrap();
let result = xprs.eval(&var_map).unwrap();
assert!(
(result - expected).abs() < ERROR_MARGIN,
"{input}\nExpected: {expected}, got: {result}"
Expand All @@ -62,10 +55,13 @@ fn test_valid_eval() {

#[test]
fn test_invalid_eval() {
// this var needs to be declared here because of clippy and default numeric fallback
// https://github.com/rust-lang/rust-clippy/issues/11535
const VARS: [(&str, f64); 1] = [("x", 2.0)];
let parser = Parser::default();

let xprs = parser.parse("2 * x + 3y + 4x + 5").unwrap();
let result = xprs.eval(&[("x", 2.0_f64)].into());
let result = xprs.eval(&VARS.into());
assert!(
result.is_err(),
"Should have failed because `y` is not provided"
Expand Down
11 changes: 8 additions & 3 deletions src/tests/xprs/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@ use crate::Parser;

#[test]
fn test_simplify() {
// these vars need to be declared here because of clippy and default numeric fallback
// https://github.com/rust-lang/rust-clippy/issues/11535
const X_VAR: (&str, f64) = ("x", 2.0);
const Y_VAR: (&str, f64) = ("y", 3.0);
const UNKNOWN_VAR: (&str, f64) = ("unknown", 4.0);
let parser = Parser::default();
let mut xprs = parser.parse("2x + 3y + 4x + 5z").unwrap();
// simplify for x
xprs = xprs.simplify_for(("x", 2.0_f64));
xprs = xprs.simplify_for(X_VAR);
assert_eq!(xprs, parser.parse("4 + 3y + 8 + 5z").unwrap());
// simplify for y
xprs.simplify_for_inplace(("y", 3.0_f64));
xprs.simplify_for_inplace(Y_VAR);
assert_eq!(xprs, parser.parse("21 + 5z").unwrap());
// try simplifying for an unknown variable
xprs.simplify_for_inplace(("unknown", 4.0_f64));
xprs.simplify_for_inplace(UNKNOWN_VAR);
assert_eq!(xprs, parser.parse("21 + 5z").unwrap());
}

0 comments on commit 36cf61e

Please sign in to comment.