diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index a277526c6e31..50d3dcb1df6b 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -83,6 +83,8 @@ pub enum AggregateMode { /// two operators. /// This mode requires tha the input is partitioned by group key (like FinalPartitioned) SinglePartitioned, + /// Merges intermediate aggregate states (such as those emitted by `Partial`) and emits the merged intermediate state rather than final values + CombinePartial, } impl AggregateMode { @@ -94,7 +96,7 @@ impl AggregateMode { AggregateMode::Partial | AggregateMode::Single | AggregateMode::SinglePartitioned => true, - AggregateMode::Final | AggregateMode::FinalPartitioned => false, + AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::CombinePartial => false, } } } @@ -651,7 +653,7 @@ impl ExecutionPlan for AggregateExec { fn required_input_distribution(&self) -> Vec { match &self.mode { - AggregateMode::Partial => { + AggregateMode::Partial | AggregateMode::CombinePartial => { vec![Distribution::UnspecifiedDistribution] } AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => { @@ -781,7 +783,7 @@ fn create_schema( } match mode { - AggregateMode::Partial => { + AggregateMode::Partial | AggregateMode::CombinePartial => { // in partial mode, the fields of the accumulator's state for expr in aggr_expr { fields.extend(expr.state_fields()?.iter().cloned()) @@ -1050,7 +1052,7 @@ fn aggregate_expressions( }) .collect()), // In this mode, we build the merge expressions of the aggregation. 
- AggregateMode::Final | AggregateMode::FinalPartitioned => { + AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::CombinePartial => { let mut col_idx_base = col_idx_base; aggr_expr .iter() @@ -1099,7 +1101,7 @@ fn finalize_aggregation( mode: &AggregateMode, ) -> Result> { match mode { - AggregateMode::Partial => { + AggregateMode::Partial | AggregateMode::CombinePartial => { // Build the vector of states accumulators .iter_mut() diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index 5ec95bd79942..7062e3be70a2 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -81,7 +81,8 @@ impl AggregateStream { let filter_expressions = match agg.mode { AggregateMode::Partial | AggregateMode::Single - | AggregateMode::SinglePartitioned => agg_filter_expr, + | AggregateMode::SinglePartitioned + | AggregateMode::CombinePartial => agg_filter_expr, AggregateMode::Final | AggregateMode::FinalPartitioned => { vec![None; agg.aggr_expr.len()] } @@ -230,7 +231,7 @@ fn aggregate_batch( AggregateMode::Partial | AggregateMode::Single | AggregateMode::SinglePartitioned => accum.update_batch(values), - AggregateMode::Final | AggregateMode::FinalPartitioned => { + AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::CombinePartial => { accum.merge_batch(values) } }; diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index f9db0a050cfc..1e72352c2989 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -313,7 +313,8 @@ impl GroupedHashAggregateStream { let filter_expressions = match agg.mode { AggregateMode::Partial | AggregateMode::Single - | AggregateMode::SinglePartitioned => agg_filter_expr, + | AggregateMode::SinglePartitioned + | AggregateMode::CombinePartial 
=> agg_filter_expr, AggregateMode::Final | AggregateMode::FinalPartitioned => { vec![None; agg.aggr_expr.len()] } @@ -640,7 +641,8 @@ impl GroupedHashAggregateStream { // Next output each aggregate value for acc in self.accumulators.iter_mut() { match self.mode { - AggregateMode::Partial => output.extend(acc.state(emit_to)?), + AggregateMode::Partial + | AggregateMode::CombinePartial => output.extend(acc.state(emit_to)?), _ if spilling => { // If spilling, output partial state because the spilled data will be // merged and re-evaluated later. diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index d2961875d89a..ca34a35c9d8d 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -1389,8 +1389,12 @@ impl AsExecutionPlan for PhysicalPlanNode { AggregateMode::Single => protobuf::AggregateMode::Single, AggregateMode::SinglePartitioned => { protobuf::AggregateMode::SinglePartitioned + }, + AggregateMode::CombinePartial => { + unimplemented!() } }; + let input_schema = exec.input_schema(); let input = protobuf::PhysicalPlanNode::try_from_physical_plan( exec.input().to_owned(),