From 148dc9bbdf9e8553a238a0faa0ee9a98eb1574ed Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Fri, 10 Jan 2025 14:12:52 +0100 Subject: [PATCH] refactor(rust): Remove implicit reverse from AExpr::replace_inputs() --- crates/polars-plan/src/dsl/meta.rs | 2 +- .../polars-plan/src/plans/aexpr/properties.rs | 4 +- .../polars-plan/src/plans/aexpr/traverse.rs | 57 ++++++++----------- .../src/plans/conversion/stack_opt.rs | 4 +- crates/polars-plan/src/plans/iterator.rs | 2 +- .../src/plans/optimizer/collapse_joins.rs | 2 +- .../plans/optimizer/slice_pushdown_expr.rs | 2 +- .../src/plans/optimizer/stack_opt.rs | 2 +- crates/polars-plan/src/plans/visitor/expr.rs | 9 +-- 9 files changed, 37 insertions(+), 47 deletions(-) diff --git a/crates/polars-plan/src/dsl/meta.rs b/crates/polars-plan/src/dsl/meta.rs index 76a881f08ed1..7d4d421f688e 100644 --- a/crates/polars-plan/src/dsl/meta.rs +++ b/crates/polars-plan/src/dsl/meta.rs @@ -17,7 +17,7 @@ impl MetaNameSpace { let node = to_aexpr(self.0, &mut arena)?; let ae = arena.get(node); let mut inputs = Vec::with_capacity(2); - ae.nodes(&mut inputs); + ae.inputs_rev(&mut inputs); Ok(inputs .iter() .map(|node| node_to_expr(*node, &arena)) diff --git a/crates/polars-plan/src/plans/aexpr/properties.rs b/crates/polars-plan/src/plans/aexpr/properties.rs index a8870839a44c..79cdc507c2ce 100644 --- a/crates/polars-plan/src/plans/aexpr/properties.rs +++ b/crates/polars-plan/src/plans/aexpr/properties.rs @@ -78,9 +78,9 @@ pub fn is_elementwise(stack: &mut UnitVec, ae: &AExpr, expr_arena: &Arena< } }; - ae.nodes(stack); + ae.inputs_rev(stack); })(), - _ => ae.nodes(stack), + _ => ae.inputs_rev(stack), } true diff --git a/crates/polars-plan/src/plans/aexpr/traverse.rs b/crates/polars-plan/src/plans/aexpr/traverse.rs index 80277619a9a3..a5ff7c5af217 100644 --- a/crates/polars-plan/src/plans/aexpr/traverse.rs +++ b/crates/polars-plan/src/plans/aexpr/traverse.rs @@ -1,8 +1,9 @@ use super::*; impl AExpr { - /// Push nodes at this level to a pre-allocated stack. - pub(crate) fn nodes(&self, container: &mut E) + /// Push the inputs of this node to the given container, in reverse order. + /// This ensures the primary node responsible for the name is pushed last. + pub(crate) fn inputs_rev(&self, container: &mut E) where E: Extend, { @@ -12,7 +13,6 @@ impl AExpr { Column(_) | Literal(_) | Len => {}, Alias(e, _) => container.extend([*e]), BinaryExpr { left, op: _, right } => { - // reverse order so that left is popped first container.extend([*right, *left]); }, Cast { expr, .. } => container.extend([*expr]), @@ -21,8 +21,7 @@ impl AExpr { container.extend([*idx, *expr]); }, SortBy { expr, by, .. } => { - container.extend(by.iter().cloned()); - // latest, so that it is popped first + container.extend(by.iter().cloned().rev()); container.extend([*expr]); }, Filter { input, by } => { @@ -30,7 +29,7 @@ impl AExpr { }, Agg(agg_e) => match agg_e.get_input() { NodeInputs::Single(node) => container.extend([node]), - NodeInputs::Many(nodes) => container.extend(nodes), + NodeInputs::Many(nodes) => container.extend(nodes.into_iter().rev()), NodeInputs::Leaf => {}, }, Ternary { @@ -40,10 +39,7 @@ impl AExpr { } => { container.extend([*predicate, *falsy, *truthy]); }, - AnonymousFunction { input, .. } | Function { input, .. } => - // we iterate in reverse order, so that the lhs is popped first and will be found - // as the root columns/ input columns by `_suffix` and `_keep_name` etc. - { + AnonymousFunction { input, .. } | Function { input, .. } => { container.extend(input.iter().rev().map(|e| e.node())) }, Explode(e) => container.extend([*e]), @@ -56,10 +52,7 @@ impl AExpr { if let Some((n, _)) = order_by { container.extend([*n]); } - container.extend(partition_by.iter().rev().cloned()); - - // latest so that it is popped first container.extend([*function]); }, Slice { @@ -80,25 +73,25 @@ impl AExpr { Cast { expr, .. } => expr, Explode(input) => input, BinaryExpr { left, right, .. } => { - *right = inputs[0]; - *left = inputs[1]; + *left = inputs[0]; + *right = inputs[1]; return self; }, Gather { expr, idx, .. } => { - *idx = inputs[0]; - *expr = inputs[1]; + *expr = inputs[0]; + *idx = inputs[1]; return self; }, Sort { expr, .. } => expr, SortBy { expr, by, .. } => { - *expr = *inputs.last().unwrap(); + *expr = inputs[0]; by.clear(); - by.extend_from_slice(&inputs[..inputs.len() - 1]); + by.extend_from_slice(&inputs[1..]); return self; }, Filter { input, by, .. } => { - *by = inputs[0]; - *input = inputs[1]; + *input = inputs[0]; + *by = inputs[1]; return self; }, Agg(a) => { @@ -118,16 +111,14 @@ impl AExpr { falsy, predicate, } => { - *predicate = inputs[0]; + *truthy = inputs[0]; *falsy = inputs[1]; - *truthy = inputs[2]; + *predicate = inputs[2]; return self; }, AnonymousFunction { input, .. } | Function { input, .. } => { - debug_assert_eq!(input.len(), inputs.len()); - - // Assign in reverse order as that was the order in which nodes were extracted. - for (e, node) in input.iter_mut().zip(inputs.iter().rev()) { + assert_eq!(input.len(), inputs.len()); + for (e, node) in input.iter_mut().zip(inputs.iter()) { e.set_node(*node); } return self; @@ -137,9 +128,9 @@ impl AExpr { offset, length, } => { - *length = inputs[0]; + *input = inputs[0]; *offset = inputs[1]; - *input = inputs[2]; + *length = inputs[2]; return self; }, Window { @@ -149,14 +140,12 @@ impl AExpr { .. } => { let offset = order_by.is_some() as usize; - *function = *inputs.last().unwrap(); + *function = inputs[0]; partition_by.clear(); - partition_by.extend_from_slice(&inputs[offset..inputs.len() - 1]); - + partition_by.extend_from_slice(&inputs[1..inputs.len() - offset]); if let Some((_, options)) = order_by { - *order_by = Some((inputs[0], *options)); + *order_by = Some((*inputs.last().unwrap(), *options)); } - return self; }, }; diff --git a/crates/polars-plan/src/plans/conversion/stack_opt.rs b/crates/polars-plan/src/plans/conversion/stack_opt.rs index 3401a892ced3..6e033ac90f86 100644 --- a/crates/polars-plan/src/plans/conversion/stack_opt.rs +++ b/crates/polars-plan/src/plans/conversion/stack_opt.rs @@ -51,7 +51,7 @@ impl ConversionOptimizer { self.scratch.push(expr); // traverse all subexpressions and add to the stack let expr = unsafe { expr_arena.get_unchecked(expr) }; - expr.nodes(&mut self.scratch); + expr.inputs_rev(&mut self.scratch); } pub(super) fn fill_scratch>(&mut self, exprs: &[N], expr_arena: &Arena) { @@ -100,7 +100,7 @@ impl ConversionOptimizer { let expr = unsafe { expr_arena.get_unchecked(current_expr_node) }; // traverse subexpressions and add to the stack - expr.nodes(&mut self.scratch) + expr.inputs_rev(&mut self.scratch) } Ok(()) diff --git a/crates/polars-plan/src/plans/iterator.rs b/crates/polars-plan/src/plans/iterator.rs index 2dc13870b553..257962f815f9 100644 --- a/crates/polars-plan/src/plans/iterator.rs +++ b/crates/polars-plan/src/plans/iterator.rs @@ -176,7 +176,7 @@ impl<'a> Iterator for AExprIter<'a> { // take the arena because the bchk doesn't allow a mutable borrow to the field. let arena = self.arena.unwrap(); let current_expr = arena.get(node); - current_expr.nodes(&mut self.stack); + current_expr.inputs_rev(&mut self.stack); self.arena = Some(arena); (node, current_expr) diff --git a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs index 2716340b4e52..ef954fb57e52 100644 --- a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs +++ b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs @@ -85,7 +85,7 @@ fn remove_suffix( stack.push(expr.node()); while let Some(node) = stack.pop() { let expr = expr_arena.get_mut(node); - expr.nodes(&mut stack); + expr.inputs_rev(&mut stack); let AExpr::Column(colname) = expr else { continue; diff --git a/crates/polars-plan/src/plans/optimizer/slice_pushdown_expr.rs b/crates/polars-plan/src/plans/optimizer/slice_pushdown_expr.rs index 8d9bff3ea868..fb3508170bab 100644 --- a/crates/polars-plan/src/plans/optimizer/slice_pushdown_expr.rs +++ b/crates/polars-plan/src/plans/optimizer/slice_pushdown_expr.rs @@ -30,7 +30,7 @@ impl OptimizationRule for SlicePushDown { ae @ Alias(..) | ae @ Cast { .. } => { let ae = ae.clone(); let scratch = self.empty_nodes_scratch_mut(); - ae.nodes(scratch); + ae.inputs_rev(scratch); let input = scratch[0]; let new_input = pushdown(input, offset, length, expr_arena); Some(ae.replace_inputs(&[new_input])) diff --git a/crates/polars-plan/src/plans/optimizer/stack_opt.rs b/crates/polars-plan/src/plans/optimizer/stack_opt.rs index 5468960661cf..8c1658a0ffda 100644 --- a/crates/polars-plan/src/plans/optimizer/stack_opt.rs +++ b/crates/polars-plan/src/plans/optimizer/stack_opt.rs @@ -74,7 +74,7 @@ impl StackOptimizer { let expr = unsafe { expr_arena.get_unchecked(current_expr_node) }; // traverse subexpressions and add to the stack - expr.nodes(&mut exprs) + expr.inputs_rev(&mut exprs) } } } diff --git a/crates/polars-plan/src/plans/visitor/expr.rs b/crates/polars-plan/src/plans/visitor/expr.rs index 175396235613..2db617bb453d 100644 --- a/crates/polars-plan/src/plans/visitor/expr.rs +++ b/crates/polars-plan/src/plans/visitor/expr.rs @@ -261,8 +261,8 @@ impl PartialEq for AExprArena<'_> { return false; } - l.to_aexpr().nodes(&mut scratch1); - r.to_aexpr().nodes(&mut scratch2); + l.to_aexpr().inputs_rev(&mut scratch1); + r.to_aexpr().inputs_rev(&mut scratch2); }, (None, None) => return true, _ => return false, @@ -280,7 +280,7 @@ impl TreeWalker for AexprNode { ) -> PolarsResult { let mut scratch = unitvec![]; - self.to_aexpr(arena).nodes(&mut scratch); + self.to_aexpr(arena).inputs_rev(&mut scratch); for node in scratch.as_slice() { let aenode = AexprNode::new(*node); match op(&aenode, arena)? { @@ -301,7 +301,7 @@ impl TreeWalker for AexprNode { let mut scratch = unitvec![]; let ae = arena.get(self.node).clone(); - ae.nodes(&mut scratch); + ae.inputs_rev(&mut scratch); // rewrite the nodes for node in scratch.as_mut_slice() { @@ -309,6 +309,7 @@ impl TreeWalker for AexprNode { *node = op(aenode, arena)?.node; } + scratch.as_mut_slice().reverse(); let ae = ae.replace_inputs(&scratch); self.node = arena.add(ae); Ok(self)