diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs index bb1696fcd9a..bb954021149 100644 --- a/crates/polars-stream/src/physical_plan/lower_expr.rs +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -207,12 +207,12 @@ pub fn is_input_independent_rec( ret } -pub fn is_input_independent(expr_key: ExprNodeKey, expr_arena: &Arena, cache: &mut ExprCache) -> bool { - is_input_independent_rec( - expr_key, - expr_arena, - &mut cache.is_input_independent, - ) +pub fn is_input_independent( + expr_key: ExprNodeKey, + expr_arena: &Arena, + cache: &mut ExprCache, +) -> bool { + is_input_independent_rec(expr_key, expr_arena, &mut cache.is_input_independent) } fn is_input_independent_ctx(expr_key: ExprNodeKey, ctx: &mut LowerExprContext) -> bool { @@ -687,7 +687,10 @@ fn build_select_stream_with_ctx( exprs: &[ExprIR], ctx: &mut LowerExprContext, ) -> PolarsResult { - if exprs.iter().all(|e| is_input_independent_ctx(e.node(), ctx)) { + if exprs + .iter() + .all(|e| is_input_independent_ctx(e.node(), ctx)) + { return Ok(PhysStream::first(build_input_independent_node_with_ctx( exprs, ctx, )?)); diff --git a/crates/polars-stream/src/physical_plan/lower_group_by.rs b/crates/polars-stream/src/physical_plan/lower_group_by.rs index 9e5f96e3f64..d4f8b1d5456 100644 --- a/crates/polars-stream/src/physical_plan/lower_group_by.rs +++ b/crates/polars-stream/src/physical_plan/lower_group_by.rs @@ -16,7 +16,10 @@ use slotmap::SlotMap; use super::lower_expr::{is_elementwise_rec_cached, lower_exprs}; use super::{ExprCache, PhysNode, PhysNodeKey, PhysNodeKind, PhysStream}; -use crate::physical_plan::lower_expr::{build_select_stream, compute_output_schema, is_input_independent, is_input_independent_rec, unique_column_name}; +use crate::physical_plan::lower_expr::{ + build_select_stream, compute_output_schema, is_input_independent, is_input_independent_rec, + unique_column_name, +}; use crate::utils::late_materialized_df::LateMaterializedDataFrame; fn build_group_by_fallback( @@ -117,13 +120,12 @@ fn try_lower_elementwise_scalar_agg_expr( | AExpr::Sort { .. } | AExpr::SortBy { .. } | AExpr::Gather { .. } => None, - + // Explode and filter are row-separable and should thus in theory work // in a streaming fashion but they change the length of the input which // means the same filter/explode should also be applied to the key // column, which is not (yet) supported. - AExpr::Explode(_) - | AExpr::Filter { .. } => None, + AExpr::Explode(_) | AExpr::Filter { .. } => None, AExpr::BinaryExpr { left, op, right } => { let (left, op, right) = (*left, *op, *right); @@ -197,7 +199,7 @@ fn try_lower_elementwise_scalar_agg_expr( let trans_agg_node = expr_arena.add(AExpr::Agg(trans_agg)); // Add to aggregation expressions and replace with a reference to its output. - + let agg_expr = if let Some(name) = outer_name { ExprIR::new(trans_agg_node, OutputName::Alias(name)) } else { @@ -251,13 +253,15 @@ fn try_build_streaming_group_by( } if keys.len() == 0 { - return Some(Err(polars_err!(ComputeError: "at least one key is required in a group_by operation"))); + return Some(Err( + polars_err!(ComputeError: "at least one key is required in a group_by operation"), + )); } - - let all_independent = keys.iter().chain(aggs.iter()).all(|expr| - is_input_independent(expr.node(), expr_arena, expr_cache) - ); + let all_independent = keys + .iter() + .chain(aggs.iter()) + .all(|expr| is_input_independent(expr.node(), expr_arena, expr_cache)); if all_independent { return None; } @@ -295,10 +299,13 @@ fn try_build_streaming_group_by( // substituting the translated input columns and extracting the aggregate // expressions. let mut trans_agg_exprs = Vec::new(); - let mut trans_output_exprs = keys.iter().map(|key| { - let key_node = expr_arena.add(AExpr::Column(key.output_name().clone())); - ExprIR::from_node(key_node, expr_arena) - }).collect_vec(); + let mut trans_output_exprs = keys + .iter() + .map(|key| { + let key_node = expr_arena.add(AExpr::Column(key.output_name().clone())); + ExprIR::from_node(key_node, expr_arena) + }) + .collect_vec(); for agg in aggs { let trans_node = try_lower_elementwise_scalar_agg_expr( agg.node(), @@ -311,9 +318,14 @@ fn try_build_streaming_group_by( )?; trans_output_exprs.push(ExprIR::new(trans_node, agg.output_name_inner().clone())); } - + let input_schema = &phys_sm[trans_input.node].output_schema; - let group_by_output_schema = compute_output_schema(input_schema, &[trans_keys.clone(), trans_agg_exprs.clone()].concat(), expr_arena).unwrap(); + let group_by_output_schema = compute_output_schema( + input_schema, + &[trans_keys.clone(), trans_agg_exprs.clone()].concat(), + expr_arena, + ) + .unwrap(); let agg_node = phys_sm.insert(PhysNode::new( group_by_output_schema, PhysNodeKind::GroupBy { diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index f8a2289368d..f0178afcb80 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -16,7 +16,7 @@ use super::{PhysNode, PhysNodeKey, PhysNodeKind, PhysStream}; use crate::physical_plan::lower_expr::{ build_select_stream, is_elementwise_rec_cached, lower_exprs, ExprCache, }; -use crate::physical_plan::lower_group_by::{build_group_by_stream}; +use crate::physical_plan::lower_group_by::build_group_by_stream; /// Creates a new PhysStream which outputs a slice of the input stream. fn build_slice_stream( diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index ab5dca031ca..87acf2c3a72 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -13,8 +13,8 @@ use polars_plan::prelude::expr_ir::ExprIR; mod fmt; mod lower_expr; -mod lower_ir; mod lower_group_by; +mod lower_ir; mod to_graph; pub use fmt::visualize_plan;