Skip to content

Commit

Permalink
perf: Improve unique pred-pd (#20569)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 6, 2025
1 parent e0ef7f9 commit 11dd4b3
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 80 deletions.
50 changes: 2 additions & 48 deletions crates/polars-plan/src/plans/aexpr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ mod hash;
mod scalar;
mod schema;
mod traverse;
mod utils;

use std::hash::{Hash, Hasher};

Expand All @@ -18,8 +17,8 @@ pub use scalar::is_scalar_ae;
use serde::{Deserialize, Serialize};
use strum_macros::IntoStaticStr;
pub use traverse::*;
pub(crate) use utils::permits_filter_pushdown;
pub use utils::*;
mod properties;
pub use properties::*;

use crate::constants::LEN;
use crate::plans::Context;
Expand Down Expand Up @@ -212,43 +211,6 @@ impl AExpr {
AExpr::Column(name)
}

/// Checks whether this expression is elementwise. This only checks the top level expression.
pub(crate) fn is_elementwise_top_level(&self) -> bool {
use AExpr::*;

match self {
AnonymousFunction { options, .. } => options.is_elementwise(),

// Non-strict strptime must be done in-memory to ensure the format
// is consistent across the entire dataframe.
#[cfg(all(feature = "strings", feature = "temporal"))]
Function {
options,
function: FunctionExpr::StringExpr(StringFunction::Strptime(_, opts)),
..
} => {
assert!(options.is_elementwise());
opts.strict
},

Function { options, .. } => options.is_elementwise(),

Literal(v) => v.projects_as_scalar(),

Alias(_, _) | BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true,

Agg { .. }
| Explode(_)
| Filter { .. }
| Gather { .. }
| Len
| Slice { .. }
| Sort { .. }
| SortBy { .. }
| Window { .. } => false,
}
}

/// This should be a 1 on 1 copy of the get_type method of Expr until Expr is completely phased out.
pub fn get_type(
&self,
Expand All @@ -259,12 +221,4 @@ impl AExpr {
self.to_field(schema, ctxt, arena)
.map(|f| f.dtype().clone())
}

pub(crate) fn is_leaf(&self) -> bool {
matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len)
}

pub(crate) fn is_col(&self) -> bool {
matches!(self, AExpr::Column(_))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,53 @@ use polars_utils::unitvec;

use super::*;

impl AExpr {
pub(crate) fn is_leaf(&self) -> bool {
matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len)
}

pub(crate) fn is_col(&self) -> bool {
matches!(self, AExpr::Column(_))
}

/// Checks whether this expression is elementwise. This only checks the top level expression.
pub(crate) fn is_elementwise_top_level(&self) -> bool {
use AExpr::*;

match self {
AnonymousFunction { options, .. } => options.is_elementwise(),

// Non-strict strptime must be done in-memory to ensure the format
// is consistent across the entire dataframe.
#[cfg(all(feature = "strings", feature = "temporal"))]
Function {
options,
function: FunctionExpr::StringExpr(StringFunction::Strptime(_, opts)),
..
} => {
assert!(options.is_elementwise());
opts.strict
},

Function { options, .. } => options.is_elementwise(),

Literal(v) => v.projects_as_scalar(),

Alias(_, _) | BinaryExpr { .. } | Column(_) | Ternary { .. } | Cast { .. } => true,

Agg { .. }
| Explode(_)
| Filter { .. }
| Gather { .. }
| Len
| Slice { .. }
| Sort { .. }
| SortBy { .. }
| Window { .. } => false,
}
}
}

/// Checks if the top-level expression node is elementwise. If this is the case, then `stack` will
/// be extended further with any nested expression nodes.
pub fn is_elementwise(stack: &mut UnitVec<Node>, ae: &AExpr, expr_arena: &Arena<AExpr>) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ pub(super) fn process_group_by(
for (pred_name, predicate) in acc_predicates {
// Counts change due to groupby's
// TODO! handle aliases, so that the predicate that is pushed down refers to the column before alias.
let mut push_down = !has_aexpr(predicate.node(), expr_arena, |ae| {
matches!(ae, AExpr::Len | AExpr::Alias(_, _))
});
let mut push_down = !has_aexpr(predicate.node(), expr_arena, |ae| matches!(ae, AExpr::Len));

for name in aexpr_to_leaf_names_iter(predicate.node(), expr_arena) {
push_down &= key_schema.contains(name.as_ref());
Expand Down
47 changes: 30 additions & 17 deletions crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,25 +488,38 @@ impl PredicatePushDown<'_> {
Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena))
},
Distinct { input, options } => {
if let Some(ref subset) = options.subset {
// Predicates on the subset can pass.
let subset = subset.clone();
let mut names_set = PlHashSet::<PlSmallStr>::with_capacity(subset.len());
for name in subset.iter() {
names_set.insert(name.clone());
}

let condition = |name: &PlSmallStr| !names_set.contains(name.as_str());
let local_predicates =
transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition);

self.pushdown_and_assign(input, acc_predicates, lp_arena, expr_arena)?;
let lp = Distinct { input, options };
Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena))
let subset = if let Some(ref subset) = options.subset {
subset.as_ref()
} else {
let lp = Distinct { input, options };
self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena)
&[]
};
let mut names_set = PlHashSet::<PlSmallStr>::with_capacity(subset.len());
for name in subset.iter() {
names_set.insert(name.clone());
}

let local_predicates = match options.keep_strategy {
UniqueKeepStrategy::Any => {
let condition = |e: &ExprIR| {
let ae = expr_arena.get(e.node());
// if not elementwise -> to local
!is_elementwise_rec(ae, expr_arena)
};
transfer_to_local_by_expr_ir(expr_arena, &mut acc_predicates, condition)
},
UniqueKeepStrategy::First
| UniqueKeepStrategy::Last
| UniqueKeepStrategy::None => {
let condition = |name: &PlSmallStr| {
!subset.is_empty() && !names_set.contains(name.as_str())
};
transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition)
},
};

self.pushdown_and_assign(input, acc_predicates, lp_arena, expr_arena)?;
let lp = Distinct { input, options };
Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena))
},
Join {
input_left,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,35 @@ pub(super) fn predicate_at_scan(
}
}

/// Evaluates a condition on the column name inputs of every predicate, where if
/// the condition evaluates to true on any column name the predicate is
/// transferred to local.
pub(super) fn transfer_to_local_by_expr_ir<F>(
expr_arena: &Arena<AExpr>,
acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,
mut condition: F,
) -> Vec<ExprIR>
where
F: FnMut(&ExprIR) -> bool,
{
let mut remove_keys = Vec::with_capacity(acc_predicates.len());

for predicate in acc_predicates.values() {
if condition(predicate) {
if let Some(name) = aexpr_to_leaf_names_iter(predicate.node(), expr_arena).next() {
remove_keys.push(name);
}
}
}
let mut local_predicates = Vec::with_capacity(remove_keys.len());
for key in remove_keys {
if let Some(pred) = acc_predicates.remove(&*key) {
local_predicates.push(pred)
}
}
local_predicates
}

/// Evaluates a condition on the column name inputs of every predicate, where if
/// the condition evaluates to true on any column name the predicate is
/// transferred to local.
Expand All @@ -94,7 +123,7 @@ where
let mut remove_keys = Vec::with_capacity(acc_predicates.len());

for (key, predicate) in &*acc_predicates {
let root_names = aexpr_to_leaf_names(predicate.node(), expr_arena);
let root_names = aexpr_to_leaf_names_iter(predicate.node(), expr_arena);
for name in root_names {
if condition(&name) {
remove_keys.push(key.clone());
Expand Down
23 changes: 12 additions & 11 deletions py-polars/tests/unit/operations/unique/test_unique.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import re
from datetime import date
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -47,16 +46,6 @@ def test_unique_predicate_pd() -> None:
.filter(pl.col("x") == "abc")
.filter(pl.col("z"))
)
plan = q.explain()
assert r'FILTER col("z")' in plan
# We can push filters if they only depend on the subset columns of unique()
assert (
re.search(
r"FILTER \[\(col\(\"x\"\)\) == \(String\(abc\)\)\] FROM\n\s*DF",
plan,
)
is not None
)
assert_frame_equal(q.collect(predicate_pushdown=False), q.collect())


Expand Down Expand Up @@ -256,3 +245,15 @@ def test_unique_check_order_20480() -> None:
.item()
== 1
)


def test_predicate_pushdown_unique() -> None:
q = (
pl.LazyFrame({"id": [1, 2, 3]})
.with_columns(pl.date(2024, 1, 1) + pl.duration(days=[1, 2, 3])) # type: ignore[arg-type]
.unique()
)

print(q.filter(pl.col("id").is_in([1, 2, 3])).explain())
assert not q.filter(pl.col("id").is_in([1, 2, 3])).explain().startswith("FILTER")
assert q.filter(pl.col("id").sum() == pl.col("id")).explain().startswith("FILTER")

0 comments on commit 11dd4b3

Please sign in to comment.