Skip to content

Commit

Permalink
row-change-check and more unary operators
Browse files Browse the repository at this point in the history
  • Loading branch information
bertiqwerty committed Jul 26, 2024
1 parent ca0455d commit a26e0e8
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 32 deletions.
43 changes: 20 additions & 23 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion rormula-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version = "0.1.6"
edition = "2021"

[dependencies]
exmex = "0.20.1"
exmex = "0.20.2"
numpy = "0.21.0"

[features]
Expand Down
48 changes: 43 additions & 5 deletions rormula-rs/src/expression/expr_arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use exmex::BinOp;
use exmex::Express;
use exmex::FlatEx;
use exmex::MakeOperators;
use exmex::Operator;
Expand Down Expand Up @@ -285,15 +286,54 @@ where
is_commutative: false,
},
),
Operator::make_unary("abs", |a| op_unary(a, &|x| x.abs())),
Operator::make_unary("sqrt", |a| op_unary(a, &|x| x.sqrt())),
Operator::make_unary("round", |a| op_unary(a, &|x| x.round())),
Operator::make_unary("floor", |a| op_unary(a, &|x| x.floor())),
Operator::make_unary("ceil", |a| op_unary(a, &|x| x.ceil())),
Operator::make_unary("trunc", |a| op_unary(a, &|x| x.trunc())),
Operator::make_unary("fract", |a| op_unary(a, &|x| x.fract())),
Operator::make_unary("sign", |a| op_unary(a, &|x| x.signum())),
Operator::make_unary("sin", |a| op_unary(a, &|x| x.sin())),
Operator::make_unary("cos", |a| op_unary(a, &|x| x.cos())),
Operator::make_unary("tan", |a| op_unary(a, &|x| x.tan())),
Operator::make_unary("asin", |a| op_unary(a, &|x| x.asin())),
Operator::make_unary("acos", |a| op_unary(a, &|x| x.acos())),
Operator::make_unary("atan", |a| op_unary(a, &|x| x.atan())),
Operator::make_unary("exp", |a| op_unary(a, &|x| x.exp())),
Operator::make_unary("ln", |a| op_unary(a, &|x| x.ln())),
Operator::make_unary("log", |a| op_unary(a, &|x| x.ln())),
Operator::make_unary("log2", |a| op_unary(a, &|x| x.log2())),
Operator::make_unary("log10", |a| op_unary(a, &|x| x.log10())),
]
}
}

pub type ExprArithmetic<M=DefaultOrder> = FlatEx<Value<M>, ArithmeticOpsFactory>;
const ROW_CHANGE_OPS: [&str; 1] = ["|"];


pub fn has_row_change_op(expr: &ExprArithmetic) -> bool {
expr.operator_reprs()
.iter()
.any(|o| ROW_CHANGE_OPS.contains(&o.as_str()))
}

pub type ExprArithmetic<M = DefaultOrder> = FlatEx<Value<M>, ArithmeticOpsFactory>;

#[cfg(test)]
use crate::array::ColMajor;
#[test]
fn keep_or_change_ops() {
let x = ExprArithmetic::parse("x + 1").unwrap();
assert!(!has_row_change_op(&x));
let x = ExprArithmetic::parse("x + y - 1 == 4").unwrap();
assert!(!has_row_change_op(&x));
let x = ExprArithmetic::parse("sin(x) + y - 2").unwrap();
assert!(!has_row_change_op(&x));
let x = ExprArithmetic::parse("sin(x)|y==2").unwrap();
assert!(has_row_change_op(&x));
}
#[test]
fn test() {
let a = Array2d::<ColMajor>::from_iter([0.0, 1.0, 2.0, 3.0, 4.0, 5.0].iter(), 3, 2).unwrap();
let a_ref = Array2d::from_iter([1.0, 2.0, 3.0, 4.0, 5.0, 6.0].iter(), 3, 2).unwrap();
Expand Down Expand Up @@ -353,10 +393,8 @@ fn test() {
);
let a_ref = Value::RowInds(vec![0]);
assert_eq!(res, a_ref);
let res: Value<ColMajor> = op_compare_equals(
Value::RowInds(vec![4, 3, 2]),
Value::RowInds(vec![1, 3, 7]),
);
let res: Value<ColMajor> =
op_compare_equals(Value::RowInds(vec![4, 3, 2]), Value::RowInds(vec![1, 3, 7]));
let a_ref = Value::RowInds(vec![1]);
assert_eq!(res, a_ref);

Expand Down
2 changes: 1 addition & 1 deletion rormula-rs/src/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ mod expr_wilkinson;
mod ops_common;
mod value;

pub use expr_arithmetic::ExprArithmetic;
pub use expr_arithmetic::{has_row_change_op, ExprArithmetic};
pub use expr_wilkinson::{ExprColCount, ExprNames, ExprWilkinson};
pub use value::{NameValue, Value};
3 changes: 3 additions & 0 deletions rormula/rormula/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,6 @@ def eval_asdf(self, data: pd.DataFrame):
else:
data = pd.DataFrame(data=resulting_data, columns=[self.name])
return data

def has_row_change_op(self) -> bool:
return self.ror.has_row_change_op()
2 changes: 1 addition & 1 deletion rormula/rormula/rormula.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def eval_wilkinson(
) -> Tuple[Optional[List[str]], np.ndarray]: ...

class Arithmetic:
pass
def has_row_change_op(self) -> bool: ...

def parse_arithmetic(s: str) -> Arithmetic: ...
def eval_arithmetic(
Expand Down
11 changes: 10 additions & 1 deletion rormula/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ use pyo3::{
};
pub use rormula_rs::exmex::prelude::*;
pub use rormula_rs::exmex::ExError;
use rormula_rs::{array::Array2d, expression::ExprArithmetic};
use rormula_rs::{
array::Array2d,
expression::{has_row_change_op, ExprArithmetic},
};
use rormula_rs::{array::DefaultOrder, result::RoErr};
use rormula_rs::{
expression::{ExprColCount, ExprNames, ExprWilkinson, NameValue, Value},
Expand Down Expand Up @@ -234,6 +237,12 @@ fn parse_arithmetic(s: &str) -> PyResult<Arithmetic> {
struct Arithmetic {
expr: ExprArithmetic,
}
#[pymethods]
impl Arithmetic {
pub fn has_row_change_op(&self) -> PyResult<bool> {
Ok(has_row_change_op(&self.expr))
}
}

#[derive(Debug)]
#[pyclass]
Expand Down
48 changes: 48 additions & 0 deletions rormula/test/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,53 @@ def eval_asdf():
rormula = Arithmetic(s, "reduced")
res = rormula.eval_asdf(df)
assert np.allclose(res.to_numpy().item(), (5.0 - 2.5) / 4.0)
assert rormula.has_row_change_op()

def test_unary(repr: str, np_func):
s = f"{repr}( ((first_var|{{second.var}}==5.0) - (first_var|{{second.var}}==2.5)) / 4.0)"
data[7, :] = 5.0
df = pd.DataFrame(data=data[:10, :2], columns=["first_var", "second.var"])
rormula = Arithmetic(s, "reduced")
res = rormula.eval_asdf(df)
assert np.allclose(res.to_numpy().item(), np_func((5.0 - 2.5) / 4.0))
assert rormula.has_row_change_op()
s = f"{repr}(first_var) * {repr}({{second.var}})"
rormula = Arithmetic(s, "multiplied")
res = rormula.eval_asdf(df)
assert np.allclose(
res["multiplied"], np_func(df["first_var"]) * np_func(df["second.var"])
)

test_unary("abs", np.abs)
test_unary("floor", np.floor)
test_unary("ceil", np.ceil)
test_unary("sign", np.sign)
test_unary("sqrt", np.sqrt)
test_unary("exp", np.exp)
test_unary("log", np.log)
test_unary("log2", np.log2)
test_unary("log10", np.log10)
test_unary("sin", np.sin)
test_unary("cos", np.cos)
test_unary("tan", np.tan)

data = np.random.random((100, 1))
df = pd.DataFrame(data=data, columns=["alpha"])
df[df == 0.5] = 0.5001
s = "round(alpha)"
rormula = Arithmetic(s, "rounded")
res = rormula.eval_asdf(df)
assert np.allclose(res["rounded"], np.round(df["alpha"]))

def atest(s, np_func):
rormula = Arithmetic(f"{s}(alpha)", "atri")
res = rormula.eval_asdf(df)
assert np.allclose(res["atri"], np_func(df["alpha"]))

atest("asin", np.arcsin)
atest("acos", np.arccos)
atest("atan", np.arctan)
atest("sqrt", np.sqrt)


def test_scalar_scalar():
Expand All @@ -78,6 +125,7 @@ def test_scalar_scalar():
res = rormula.eval_asdf(df)
ref = df.eval(s)
np.allclose(res[name].to_numpy(), ref)
assert not rormula.has_row_change_op()


if __name__ == "__main__":
Expand Down

0 comments on commit a26e0e8

Please sign in to comment.