Skip to content

Commit

Permalink
refactor(rust): Remove implicit reverse from AExpr::replace_inputs() (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Jan 10, 2025
1 parent 04ad50d commit 17556e4
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 47 deletions.
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ impl MetaNameSpace {
let node = to_aexpr(self.0, &mut arena)?;
let ae = arena.get(node);
let mut inputs = Vec::with_capacity(2);
ae.nodes(&mut inputs);
ae.inputs_rev(&mut inputs);
Ok(inputs
.iter()
.map(|node| node_to_expr(*node, &arena))
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/plans/aexpr/properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ pub fn is_elementwise(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<
}
};

ae.nodes(stack);
ae.inputs_rev(stack);
})(),
_ => ae.nodes(stack),
_ => ae.inputs_rev(stack),
}

true
Expand Down
57 changes: 23 additions & 34 deletions crates/polars-plan/src/plans/aexpr/traverse.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use super::*;

impl AExpr {
/// Push nodes at this level to a pre-allocated stack.
pub(crate) fn nodes<E>(&self, container: &mut E)
/// Push the inputs of this node to the given container, in reverse order.
/// This ensures the primary node responsible for the name is pushed last.
pub(crate) fn inputs_rev<E>(&self, container: &mut E)
where
E: Extend<Node>,
{
Expand All @@ -12,7 +13,6 @@ impl AExpr {
Column(_) | Literal(_) | Len => {},
Alias(e, _) => container.extend([*e]),
BinaryExpr { left, op: _, right } => {
// reverse order so that left is popped first
container.extend([*right, *left]);
},
Cast { expr, .. } => container.extend([*expr]),
Expand All @@ -21,16 +21,15 @@ impl AExpr {
container.extend([*idx, *expr]);
},
SortBy { expr, by, .. } => {
container.extend(by.iter().cloned());
// latest, so that it is popped first
container.extend(by.iter().cloned().rev());
container.extend([*expr]);
},
Filter { input, by } => {
container.extend([*by, *input]);
},
Agg(agg_e) => match agg_e.get_input() {
NodeInputs::Single(node) => container.extend([node]),
NodeInputs::Many(nodes) => container.extend(nodes),
NodeInputs::Many(nodes) => container.extend(nodes.into_iter().rev()),
NodeInputs::Leaf => {},
},
Ternary {
Expand All @@ -40,10 +39,7 @@ impl AExpr {
} => {
container.extend([*predicate, *falsy, *truthy]);
},
AnonymousFunction { input, .. } | Function { input, .. } =>
// we iterate in reverse order, so that the lhs is popped first and will be found
// as the root columns/ input columns by `_suffix` and `_keep_name` etc.
{
AnonymousFunction { input, .. } | Function { input, .. } => {
container.extend(input.iter().rev().map(|e| e.node()))
},
Explode(e) => container.extend([*e]),
Expand All @@ -56,10 +52,7 @@ impl AExpr {
if let Some((n, _)) = order_by {
container.extend([*n]);
}

container.extend(partition_by.iter().rev().cloned());

// latest so that it is popped first
container.extend([*function]);
},
Slice {
Expand All @@ -80,25 +73,25 @@ impl AExpr {
Cast { expr, .. } => expr,
Explode(input) => input,
BinaryExpr { left, right, .. } => {
*right = inputs[0];
*left = inputs[1];
*left = inputs[0];
*right = inputs[1];
return self;
},
Gather { expr, idx, .. } => {
*idx = inputs[0];
*expr = inputs[1];
*expr = inputs[0];
*idx = inputs[1];
return self;
},
Sort { expr, .. } => expr,
SortBy { expr, by, .. } => {
*expr = *inputs.last().unwrap();
*expr = inputs[0];
by.clear();
by.extend_from_slice(&inputs[..inputs.len() - 1]);
by.extend_from_slice(&inputs[1..]);
return self;
},
Filter { input, by, .. } => {
*by = inputs[0];
*input = inputs[1];
*input = inputs[0];
*by = inputs[1];
return self;
},
Agg(a) => {
Expand All @@ -118,16 +111,14 @@ impl AExpr {
falsy,
predicate,
} => {
*predicate = inputs[0];
*truthy = inputs[0];
*falsy = inputs[1];
*truthy = inputs[2];
*predicate = inputs[2];
return self;
},
AnonymousFunction { input, .. } | Function { input, .. } => {
debug_assert_eq!(input.len(), inputs.len());

// Assign in reverse order as that was the order in which nodes were extracted.
for (e, node) in input.iter_mut().zip(inputs.iter().rev()) {
assert_eq!(input.len(), inputs.len());
for (e, node) in input.iter_mut().zip(inputs.iter()) {
e.set_node(*node);
}
return self;
Expand All @@ -137,9 +128,9 @@ impl AExpr {
offset,
length,
} => {
*length = inputs[0];
*input = inputs[0];
*offset = inputs[1];
*input = inputs[2];
*length = inputs[2];
return self;
},
Window {
Expand All @@ -149,14 +140,12 @@ impl AExpr {
..
} => {
let offset = order_by.is_some() as usize;
*function = *inputs.last().unwrap();
*function = inputs[0];
partition_by.clear();
partition_by.extend_from_slice(&inputs[offset..inputs.len() - 1]);

partition_by.extend_from_slice(&inputs[1..inputs.len() - offset]);
if let Some((_, options)) = order_by {
*order_by = Some((inputs[0], *options));
*order_by = Some((*inputs.last().unwrap(), *options));
}

return self;
},
};
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/plans/conversion/stack_opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl ConversionOptimizer {
self.scratch.push(expr);
// traverse all subexpressions and add to the stack
let expr = unsafe { expr_arena.get_unchecked(expr) };
expr.nodes(&mut self.scratch);
expr.inputs_rev(&mut self.scratch);
}

pub(super) fn fill_scratch<N: Borrow<Node>>(&mut self, exprs: &[N], expr_arena: &Arena<AExpr>) {
Expand Down Expand Up @@ -100,7 +100,7 @@ impl ConversionOptimizer {

let expr = unsafe { expr_arena.get_unchecked(current_expr_node) };
// traverse subexpressions and add to the stack
expr.nodes(&mut self.scratch)
expr.inputs_rev(&mut self.scratch)
}

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ impl<'a> Iterator for AExprIter<'a> {
// take the arena because the bchk doesn't allow a mutable borrow to the field.
let arena = self.arena.unwrap();
let current_expr = arena.get(node);
current_expr.nodes(&mut self.stack);
current_expr.inputs_rev(&mut self.stack);

self.arena = Some(arena);
(node, current_expr)
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/optimizer/collapse_joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ fn remove_suffix(
stack.push(expr.node());
while let Some(node) = stack.pop() {
let expr = expr_arena.get_mut(node);
expr.nodes(&mut stack);
expr.inputs_rev(&mut stack);

let AExpr::Column(colname) = expr else {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl OptimizationRule for SlicePushDown {
ae @ Alias(..) | ae @ Cast { .. } => {
let ae = ae.clone();
let scratch = self.empty_nodes_scratch_mut();
ae.nodes(scratch);
ae.inputs_rev(scratch);
let input = scratch[0];
let new_input = pushdown(input, offset, length, expr_arena);
Some(ae.replace_inputs(&[new_input]))
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/optimizer/stack_opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl StackOptimizer {

let expr = unsafe { expr_arena.get_unchecked(current_expr_node) };
// traverse subexpressions and add to the stack
expr.nodes(&mut exprs)
expr.inputs_rev(&mut exprs)
}
}
}
Expand Down
9 changes: 5 additions & 4 deletions crates/polars-plan/src/plans/visitor/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ impl PartialEq for AExprArena<'_> {
return false;
}

l.to_aexpr().nodes(&mut scratch1);
r.to_aexpr().nodes(&mut scratch2);
l.to_aexpr().inputs_rev(&mut scratch1);
r.to_aexpr().inputs_rev(&mut scratch2);
},
(None, None) => return true,
_ => return false,
Expand All @@ -280,7 +280,7 @@ impl TreeWalker for AexprNode {
) -> PolarsResult<VisitRecursion> {
let mut scratch = unitvec![];

self.to_aexpr(arena).nodes(&mut scratch);
self.to_aexpr(arena).inputs_rev(&mut scratch);
for node in scratch.as_slice() {
let aenode = AexprNode::new(*node);
match op(&aenode, arena)? {
Expand All @@ -301,14 +301,15 @@ impl TreeWalker for AexprNode {
let mut scratch = unitvec![];

let ae = arena.get(self.node).clone();
ae.nodes(&mut scratch);
ae.inputs_rev(&mut scratch);

// rewrite the nodes
for node in scratch.as_mut_slice() {
let aenode = AexprNode::new(*node);
*node = op(aenode, arena)?.node;
}

scratch.as_mut_slice().reverse();
let ae = ae.replace_inputs(&scratch);
self.node = arena.add(ae);
Ok(self)
Expand Down

0 comments on commit 17556e4

Please sign in to comment.