Skip to content

Commit

Permalink
Revert "Make builtin window function output datatype to be derived from schema (apache#9686)"
Browse files Browse the repository at this point in the history

This reverts commit 1d0171a.
  • Loading branch information
mwylde committed May 8, 2024
1 parent b97a5b6 commit ad26df9
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 74 deletions.
33 changes: 21 additions & 12 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -743,13 +743,13 @@ impl DefaultPhysicalPlanner {
);
}

let logical_schema = logical_plan.schema();
let logical_input_schema = input.schema();
let window_expr = window_expr
.iter()
.map(|e| {
create_window_expr(
e,
logical_schema,
logical_input_schema,
session_state.execution_props(),
)
})
Expand Down Expand Up @@ -1572,11 +1572,11 @@ pub fn is_window_frame_bound_valid(window_frame: &WindowFrame) -> bool {
pub fn create_window_expr_with_name(
e: &Expr,
name: impl Into<String>,
logical_schema: &DFSchema,
logical_input_schema: &DFSchema,
execution_props: &ExecutionProps,
) -> Result<Arc<dyn WindowExpr>> {
let name = name.into();
let physical_schema: &Schema = &logical_schema.into();
let physical_input_schema: &Schema = &logical_input_schema.into();
match e {
Expr::WindowFunction(WindowFunction {
fun,
Expand All @@ -1586,11 +1586,20 @@ pub fn create_window_expr_with_name(
window_frame,
null_treatment,
}) => {
let args = create_physical_exprs(args, logical_schema, execution_props)?;
let partition_by =
create_physical_exprs(partition_by, logical_schema, execution_props)?;
let order_by =
create_physical_sort_exprs(order_by, logical_schema, execution_props)?;
let args = args
.iter()
.map(|e| create_physical_expr(e, logical_input_schema, execution_props))
.collect::<Result<Vec<_>>>()?;
let partition_by = partition_by
.iter()
.map(|e| create_physical_expr(e, logical_input_schema, execution_props))
.collect::<Result<Vec<_>>>()?;
let order_by = order_by
.iter()
.map(|e| {
create_physical_sort_expr(e, logical_input_schema, execution_props)
})
.collect::<Result<Vec<_>>>()?;

if !is_window_frame_bound_valid(window_frame) {
return plan_err!(
Expand All @@ -1610,7 +1619,7 @@ pub fn create_window_expr_with_name(
&partition_by,
&order_by,
window_frame,
physical_schema,
physical_input_schema,
ignore_nulls,
)
}
Expand All @@ -1621,15 +1630,15 @@ pub fn create_window_expr_with_name(
/// Create a window expression from a logical expression or an alias
pub fn create_window_expr(
e: &Expr,
logical_schema: &DFSchema,
logical_input_schema: &DFSchema,
execution_props: &ExecutionProps,
) -> Result<Arc<dyn WindowExpr>> {
// unpack aliased logical expressions, e.g. "sum(col) over () as total"
let (name, e) = match e {
Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()),
_ => (e.display_name()?, e),
};
create_window_expr_with_name(e, name, logical_schema, execution_props)
create_window_expr_with_name(e, name, logical_input_schema, execution_props)
}

type AggregateExprWithOptionalArgs = (
Expand Down
39 changes: 3 additions & 36 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use arrow::compute::{concat_batches, SortOptions};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use arrow_schema::{Field, Schema};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::windows::{
Expand All @@ -40,7 +39,6 @@ use datafusion_expr::{
};
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use itertools::Itertools;
use test_utils::add_empty_batches;

use hashbrown::HashMap;
Expand Down Expand Up @@ -275,17 +273,14 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
window_frame.is_causal()
};

let extended_schema =
schema_add_window_fields(&args, &schema, &window_fn, fn_name)?;

let window_expr = create_window_expr(
&window_fn,
fn_name.to_string(),
&args,
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame),
&extended_schema,
schema.as_ref(),
false,
)?;
let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
Expand Down Expand Up @@ -683,8 +678,6 @@ async fn run_window_test(
exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _;
}

let extended_schema = schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?;

let usual_window_exec = Arc::new(WindowAggExec::try_new(
vec![create_window_expr(
&window_fn,
Expand All @@ -693,7 +686,7 @@ async fn run_window_test(
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame.clone()),
&extended_schema,
schema.as_ref(),
false,
)?],
exec1,
Expand All @@ -711,7 +704,7 @@ async fn run_window_test(
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame.clone()),
&extended_schema,
schema.as_ref(),
false,
)?],
exec2,
Expand Down Expand Up @@ -754,32 +747,6 @@ async fn run_window_test(
Ok(())
}

// The planner has fully updated schema before calling the `create_window_expr`
// Replicate the same for this test
fn schema_add_window_fields(
args: &[Arc<dyn PhysicalExpr>],
schema: &Arc<Schema>,
window_fn: &WindowFunctionDefinition,
fn_name: &str,
) -> Result<Arc<Schema>> {
let data_types = args
.iter()
.map(|e| e.clone().as_ref().data_type(schema))
.collect::<Result<Vec<_>>>()?;
let window_expr_return_type = window_fn.return_type(&data_types)?;
let mut window_fields = schema
.fields()
.iter()
.map(|f| f.as_ref().clone())
.collect_vec();
window_fields.extend_from_slice(&[Field::new(
fn_name,
window_expr_return_type,
true,
)]);
Ok(Arc::new(Schema::new(window_fields)))
}

/// Return randomly sized record batches with:
/// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns
/// one random int32 column x
Expand Down
47 changes: 21 additions & 26 deletions datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,20 @@ fn create_built_in_window_expr(
name: String,
ignore_nulls: bool,
) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
// derive the output datatype from incoming schema
let out_data_type: &DataType = input_schema.field_with_name(&name)?.data_type();
// need to get the types into an owned vec for some reason
let input_types: Vec<_> = args
.iter()
.map(|arg| arg.data_type(input_schema))
.collect::<Result<_>>()?;

// figure out the output type
let data_type = &fun.return_type(&input_types)?;
Ok(match fun {
BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, out_data_type)),
BuiltInWindowFunction::Rank => Arc::new(rank(name, out_data_type)),
BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, out_data_type)),
BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, out_data_type)),
BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, out_data_type)),
BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, data_type)),
BuiltInWindowFunction::Rank => Arc::new(rank(name, data_type)),
BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, data_type)),
BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, data_type)),
BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, data_type)),
BuiltInWindowFunction::Ntile => {
let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| {
DataFusionError::Execution(
Expand All @@ -196,13 +201,13 @@ fn create_built_in_window_expr(

if n.is_unsigned() {
let n: u64 = n.try_into()?;
Arc::new(Ntile::new(name, n, out_data_type))
Arc::new(Ntile::new(name, n, data_type))
} else {
let n: i64 = n.try_into()?;
if n <= 0 {
return exec_err!("NTILE requires a positive integer");
}
Arc::new(Ntile::new(name, n as u64, out_data_type))
Arc::new(Ntile::new(name, n as u64, data_type))
}
}
BuiltInWindowFunction::Lag => {
Expand All @@ -211,10 +216,10 @@ fn create_built_in_window_expr(
.map(|v| v.try_into())
.and_then(|v| v.ok());
let default_value =
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?;
Arc::new(lag(
name,
out_data_type.clone(),
data_type.clone(),
arg,
shift_offset,
default_value,
Expand All @@ -227,10 +232,10 @@ fn create_built_in_window_expr(
.map(|v| v.try_into())
.and_then(|v| v.ok());
let default_value =
get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?;
get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?;
Arc::new(lead(
name,
out_data_type.clone(),
data_type.clone(),
arg,
shift_offset,
default_value,
Expand All @@ -247,28 +252,18 @@ fn create_built_in_window_expr(
Arc::new(NthValue::nth(
name,
arg,
out_data_type.clone(),
data_type.clone(),
n,
ignore_nulls,
)?)
}
BuiltInWindowFunction::FirstValue => {
let arg = args[0].clone();
Arc::new(NthValue::first(
name,
arg,
out_data_type.clone(),
ignore_nulls,
))
Arc::new(NthValue::first(name, arg, data_type.clone(), ignore_nulls))
}
BuiltInWindowFunction::LastValue => {
let arg = args[0].clone();
Arc::new(NthValue::last(
name,
arg,
out_data_type.clone(),
ignore_nulls,
))
Arc::new(NthValue::last(name, arg, data_type.clone(), ignore_nulls))
}
})
}
Expand Down

0 comments on commit ad26df9

Please sign in to comment.