From 4c14e706e6fe60ec21d900b7e03cfcc24d83f4e5 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Tue, 31 Dec 2024 14:03:46 +0100 Subject: [PATCH] perf: Vectorized nested loop join for in-memory engine (#20495) --- crates/polars-core/src/frame/horizontal.rs | 2 + crates/polars-core/src/frame/mod.rs | 50 ++++-- crates/polars-expr/src/expressions/window.rs | 2 +- crates/polars-io/src/avro/write.rs | 2 +- crates/polars-io/src/ipc/write.rs | 2 +- crates/polars-io/src/parquet/write/writer.rs | 2 +- .../polars-lazy/src/physical_plan/exotic.rs | 2 +- crates/polars-lazy/src/tests/io.rs | 2 +- crates/polars-lazy/src/tests/pdsh.rs | 2 +- .../polars-mem-engine/src/executors/join.rs | 6 +- .../polars-mem-engine/src/executors/sort.rs | 2 +- crates/polars-mem-engine/src/planner/lp.rs | 30 ++++ crates/polars-ops/src/frame/join/args.rs | 120 ++++++++----- .../polars-ops/src/frame/join/cross_join.rs | 163 ++++++++++++------ crates/polars-ops/src/frame/join/general.rs | 20 ++- crates/polars-ops/src/frame/join/mod.rs | 55 +++++- crates/polars-ops/src/frame/pivot/unpivot.rs | 2 +- crates/polars-ops/src/series/ops/replace.rs | 2 + .../src/executors/operators/reproject.rs | 2 +- .../sinks/joins/generic_probe_inner_left.rs | 1 + .../sinks/joins/generic_probe_outer.rs | 1 + .../src/executors/sinks/sort/sink_multiple.rs | 2 + .../polars-pipe/src/executors/sources/csv.rs | 9 +- crates/polars-pipe/src/pipeline/convert.rs | 4 +- crates/polars-plan/src/dsl/options.rs | 51 ++++++ crates/polars-plan/src/plans/builder_dsl.rs | 2 +- .../polars-plan/src/plans/conversion/join.rs | 13 +- crates/polars-plan/src/plans/functions/mod.rs | 2 +- .../src/plans/functions/python_udf.rs | 8 +- crates/polars-plan/src/plans/ir/format.rs | 26 ++- crates/polars-plan/src/plans/mod.rs | 4 +- .../src/plans/optimizer/collapse_joins.rs | 19 +- .../optimizer/predicate_pushdown/join.rs | 3 +- .../src/plans/optimizer/slice_pushdown_lp.rs | 2 +- crates/polars-python/src/lazyframe/visit.rs | 2 +- .../src/lazyframe/visitor/nodes.rs | 31 +++- crates/polars-stream/src/physical_plan/fmt.rs | 1 + .../src/physical_plan/lower_ir.rs | 2 + crates/polars-stream/src/physical_plan/mod.rs | 2 + .../src/physical_plan/to_graph.rs | 2 + crates/polars-time/src/upsample.rs | 1 + crates/polars/tests/it/chunks/parquet.rs | 8 +- crates/polars/tests/it/core/joins.rs | 40 ++++- crates/polars/tests/it/io/json.rs | 2 +- .../unit/lazyframe/test_optimizations.py | 8 +- .../tests/unit/operations/test_cross_join.py | 31 ++++ .../unit/operations/test_inequality_join.py | 5 +- 47 files changed, 557 insertions(+), 193 deletions(-) diff --git a/crates/polars-core/src/frame/horizontal.rs b/crates/polars-core/src/frame/horizontal.rs index 0886c6b3f958..3a06cbc2ff55 100644 --- a/crates/polars-core/src/frame/horizontal.rs +++ b/crates/polars-core/src/frame/horizontal.rs @@ -41,6 +41,7 @@ impl DataFrame { } } + self.clear_schema(); self.columns.extend_from_slice(columns); self } @@ -99,6 +100,7 @@ pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> Polars unsafe { df.get_columns_mut() }.iter_mut().for_each(|c| { *c = c.extend_constant(AnyValue::Null, diff).unwrap(); }); + df.clear_schema(); unsafe { df.set_height(output_height); } diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index a0cce52e65d1..35df9e3ba96f 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -176,6 +176,10 @@ pub struct DataFrame { } impl DataFrame { + pub fn clear_schema(&mut self) { + self.cached_schema = OnceLock::new(); + } + #[inline] pub fn materialized_column_iter(&self) -> impl ExactSizeIterator { self.columns.iter().map(Column::as_materialized_series) @@ -416,6 +420,8 @@ impl DataFrame { /// # Ok::<(), PolarsError>(()) /// ``` pub fn pop(&mut self) -> Option { + self.clear_schema(); + self.columns.pop() } @@ -477,6 +483,7 @@ impl DataFrame { ); ca.set_sorted_flag(IsSorted::Ascending); + self.clear_schema(); self.columns.insert(0, ca.into_series().into()); self } @@ -687,14 +694,22 @@ impl DataFrame { /// let f2: Field = Field::new("Diameter (m)".into(), DataType::Float64); /// let sc: Schema = Schema::from_iter(vec![f1, f2]); /// - /// assert_eq!(df.schema(), sc); + /// assert_eq!(&**df.schema(), &sc); /// # Ok::<(), PolarsError>(()) /// ``` - pub fn schema(&self) -> Schema { - self.columns - .iter() - .map(|x| (x.name().clone(), x.dtype().clone())) - .collect() + pub fn schema(&self) -> &SchemaRef { + let out = self.cached_schema.get_or_init(|| { + Arc::new( + self.columns + .iter() + .map(|x| (x.name().clone(), x.dtype().clone())) + .collect(), + ) + }); + + debug_assert_eq!(out.len(), self.width()); + + out } /// Get a reference to the [`DataFrame`] columns. @@ -723,6 +738,8 @@ impl DataFrame { /// /// The caller must ensure the length of all [`Series`] remains equal to `height` or /// [`DataFrame::set_height`] is called afterwards with the appropriate `height`. + /// The caller must ensure that the cached schema is cleared if it modifies the schema by + /// calling [`DataFrame::clear_schema`]. pub unsafe fn get_columns_mut(&mut self) -> &mut Vec { &mut self.columns } @@ -730,7 +747,8 @@ impl DataFrame { #[inline] /// Remove all the columns in the [`DataFrame`] but keep the `height`. pub fn clear_columns(&mut self) { - unsafe { self.get_columns_mut() }.clear() + unsafe { self.get_columns_mut() }.clear(); + self.clear_schema(); } #[inline] @@ -744,7 +762,8 @@ impl DataFrame { /// `DataFrame`]s with no columns (ZCDFs), it is important that the height is set afterwards /// with [`DataFrame::set_height`]. pub unsafe fn column_extend_unchecked(&mut self, iter: impl IntoIterator) { - unsafe { self.get_columns_mut() }.extend(iter) + unsafe { self.get_columns_mut() }.extend(iter); + self.clear_schema(); } /// Take ownership of the underlying columns vec. @@ -834,6 +853,7 @@ impl DataFrame { s }) .collect(); + self.clear_schema(); Ok(()) } @@ -1194,6 +1214,7 @@ impl DataFrame { Ok(()) })?; self.height += other.height; + self.clear_schema(); Ok(()) } @@ -1215,6 +1236,7 @@ impl DataFrame { /// ``` pub fn drop_in_place(&mut self, name: &str) -> PolarsResult { let idx = self.check_name_to_idx(name)?; + self.clear_schema(); Ok(self.columns.remove(idx)) } @@ -1347,6 +1369,7 @@ impl DataFrame { } self.columns.insert(index, column); + self.clear_schema(); Ok(self) } @@ -1370,6 +1393,7 @@ impl DataFrame { } self.columns.push(column); + self.clear_schema(); } Ok(()) } @@ -1417,6 +1441,7 @@ impl DataFrame { unsafe { self.set_height(column.len()) }; } unsafe { self.get_columns_mut() }.push(column); + self.clear_schema(); self } @@ -1433,6 +1458,7 @@ impl DataFrame { } self.columns.push(c); + self.clear_schema(); } // Schema is incorrect fallback to search else { @@ -1448,6 +1474,7 @@ impl DataFrame { } self.columns.push(c); + self.clear_schema(); } Ok(()) @@ -1637,7 +1664,7 @@ impl DataFrame { /// # Ok::<(), PolarsError>(()) /// ``` pub fn get_column_index(&self, name: &str) -> Option { - let schema = self.cached_schema.get_or_init(|| Arc::new(self.schema())); + let schema = self.schema(); if let Some(idx) = schema.index_of(name) { if self .get_columns() @@ -1775,7 +1802,7 @@ impl DataFrame { cols: &[PlSmallStr], schema: &Schema, ) -> PolarsResult> { - debug_ensure_matching_schema_names(schema, &self.schema())?; + debug_ensure_matching_schema_names(schema, self.schema())?; cols.iter() .map(|name| { @@ -1984,7 +2011,7 @@ impl DataFrame { return Ok(self); } polars_ensure!( - self.columns.iter().all(|c| c.name() != &name), + !self.schema().contains(&name), Duplicate: "column rename attempted with already existing name \"{name}\"" ); @@ -2326,6 +2353,7 @@ impl DataFrame { ); let old_col = &mut self.columns[index]; mem::swap(old_col, &mut new_column); + self.clear_schema(); Ok(self) } diff --git a/crates/polars-expr/src/expressions/window.rs b/crates/polars-expr/src/expressions/window.rs index d833278a12cb..b799d55467e5 100644 --- a/crates/polars-expr/src/expressions/window.rs +++ b/crates/polars-expr/src/expressions/window.rs @@ -399,7 +399,7 @@ impl PhysicalExpr for WindowExpr { // 4. select the final column and return if df.is_empty() { - let field = self.phys_function.to_field(&df.schema())?; + let field = self.phys_function.to_field(df.schema())?; return Ok(Column::full_null(field.name().clone(), 0, field.dtype())); } diff --git a/crates/polars-io/src/avro/write.rs b/crates/polars-io/src/avro/write.rs index 2954de97d964..2681fac9433b 100644 --- a/crates/polars-io/src/avro/write.rs +++ b/crates/polars-io/src/avro/write.rs @@ -64,7 +64,7 @@ where } fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> { - let schema = schema_to_arrow_checked(&df.schema(), CompatLevel::oldest(), "avro")?; + let schema = schema_to_arrow_checked(df.schema(), CompatLevel::oldest(), "avro")?; let record = write::to_record(&schema, self.name.clone())?; let mut data = vec![]; diff --git a/crates/polars-io/src/ipc/write.rs b/crates/polars-io/src/ipc/write.rs index 38b5d1d27fde..0f13b9967b07 100644 --- a/crates/polars-io/src/ipc/write.rs +++ b/crates/polars-io/src/ipc/write.rs @@ -116,7 +116,7 @@ where } fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> { - let schema = schema_to_arrow_checked(&df.schema(), self.compat_level, "ipc")?; + let schema = schema_to_arrow_checked(df.schema(), self.compat_level, "ipc")?; let mut ipc_writer = write::FileWriter::try_new( &mut self.writer, Arc::new(schema), diff --git a/crates/polars-io/src/parquet/write/writer.rs b/crates/polars-io/src/parquet/write/writer.rs index 6ef167672d96..13885316d9d7 100644 --- a/crates/polars-io/src/parquet/write/writer.rs +++ b/crates/polars-io/src/parquet/write/writer.rs @@ -124,7 +124,7 @@ where /// Write the given DataFrame in the writer `W`. Returns the total size of the file. pub fn finish(self, df: &mut DataFrame) -> PolarsResult { let chunked_df = chunk_df_for_writing(df, self.row_group_size.unwrap_or(512 * 512))?; - let mut batched = self.batched(&chunked_df.schema())?; + let mut batched = self.batched(chunked_df.schema())?; batched.write_batch(&chunked_df)?; batched.finish() } diff --git a/crates/polars-lazy/src/physical_plan/exotic.rs b/crates/polars-lazy/src/physical_plan/exotic.rs index 08673ca1f032..29a155310eb5 100644 --- a/crates/polars-lazy/src/physical_plan/exotic.rs +++ b/crates/polars-lazy/src/physical_plan/exotic.rs @@ -25,7 +25,7 @@ pub(crate) fn prepare_expression_for_context( // type coercion and simplify expression optimizations run. let column = Series::full_null(name, 0, dtype); let df = column.into_frame(); - let input_schema = Arc::new(df.schema()); + let input_schema = df.schema().clone(); let lf = df .lazy() .without_optimizations() diff --git a/crates/polars-lazy/src/tests/io.rs b/crates/polars-lazy/src/tests/io.rs index a1d3f2c050a8..435fa74a373a 100644 --- a/crates/polars-lazy/src/tests/io.rs +++ b/crates/polars-lazy/src/tests/io.rs @@ -672,7 +672,7 @@ fn scan_anonymous_fn_with_options() -> PolarsResult<()> { let function = Arc::new(MyScan {}); let args = ScanArgsAnonymous { - schema: Some(Arc::new(fruits_cars().schema())), + schema: Some(fruits_cars().schema().clone()), ..ScanArgsAnonymous::default() }; diff --git a/crates/polars-lazy/src/tests/pdsh.rs b/crates/polars-lazy/src/tests/pdsh.rs index 426b19506684..f0de0b641446 100644 --- a/crates/polars-lazy/src/tests/pdsh.rs +++ b/crates/polars-lazy/src/tests/pdsh.rs @@ -107,7 +107,7 @@ fn test_q2() -> PolarsResult<()> { Field::new("s_phone".into(), DataType::String), Field::new("s_comment".into(), DataType::String), ]); - assert_eq!(&out.schema(), &schema); + assert_eq!(&**out.schema(), &schema); Ok(()) } diff --git a/crates/polars-mem-engine/src/executors/join.rs b/crates/polars-mem-engine/src/executors/join.rs index 4fed3cb7a3ff..5215b0ca06c2 100644 --- a/crates/polars-mem-engine/src/executors/join.rs +++ b/crates/polars-mem-engine/src/executors/join.rs @@ -9,6 +9,7 @@ pub struct JoinExec { right_on: Vec>, parallel: bool, args: JoinArgs, + options: Option, } impl JoinExec { @@ -20,6 +21,7 @@ impl JoinExec { right_on: Vec>, parallel: bool, args: JoinArgs, + options: Option, ) -> Self { JoinExec { input_left: Some(input_left), @@ -28,6 +30,7 @@ impl JoinExec { right_on, parallel, args, + options, } } } @@ -75,7 +78,7 @@ impl Executor for JoinExec { let by = self .left_on .iter() - .map(|s| Ok(s.to_field(&df_left.schema())?.name)) + .map(|s| Ok(s.to_field(df_left.schema())?.name)) .collect::>>()?; let name = comma_delimited("join".to_string(), &by); Cow::Owned(name) @@ -142,6 +145,7 @@ impl Executor for JoinExec { left_on_series.into_iter().map(|c| c.take_materialized_series()).collect(), right_on_series.into_iter().map(|c| c.take_materialized_series()).collect(), self.args.clone(), + self.options.clone(), true, state.verbose(), ); diff --git a/crates/polars-mem-engine/src/executors/sort.rs b/crates/polars-mem-engine/src/executors/sort.rs index a50e38af2750..ec1e0aad276c 100644 --- a/crates/polars-mem-engine/src/executors/sort.rs +++ b/crates/polars-mem-engine/src/executors/sort.rs @@ -61,7 +61,7 @@ impl Executor for SortExec { let by = self .by_column .iter() - .map(|s| Ok(s.to_field(&df.schema())?.name)) + .map(|s| Ok(s.to_field(df.schema())?.name)) .collect::>>()?; let name = comma_delimited("sort".to_string(), &by); Cow::Owned(name) diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index 9a8513052477..8850ab1173d4 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -1,5 +1,6 @@ use polars_core::prelude::*; use polars_core::POOL; +use polars_expr::state::ExecutionState; use polars_plan::global::_set_n_rows_for_scan; use polars_plan::plans::expr_ir::ExprIR; @@ -484,6 +485,7 @@ fn create_physical_plan_impl( left_on, right_on, options, + schema, .. } => { let parallel = if options.force_parallel { @@ -521,6 +523,33 @@ fn create_physical_plan_impl( &mut ExpressionConversionState::new(true, state.expr_depth), )?; let options = Arc::try_unwrap(options).unwrap_or_else(|options| (*options).clone()); + + // Convert the join options, to the physical join options. This requires the physical + // planner, so we do this last minute. + let join_type_options = options + .options + .map(|o| { + o.compile(|e| { + let phys_expr = create_physical_expr( + e, + Context::Default, + expr_arena, + &schema, + &mut ExpressionConversionState::new(false, state.expr_depth), + )?; + + let execution_state = ExecutionState::default(); + + Ok(Arc::new(move |df: DataFrame| { + let mask = phys_expr.evaluate(&df, &execution_state)?; + let mask = mask.as_materialized_series(); + let mask = mask.bool()?; + df._filter_seq(mask) + })) + }) + }) + .transpose()?; + Ok(Box::new(executors::JoinExec::new( input_left, input_right, @@ -528,6 +557,7 @@ fn create_physical_plan_impl( right_on, parallel, options.args, + join_type_options, ))) }, HStack { diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index e2df38daf88b..903cd6e2e9f2 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -20,7 +20,7 @@ use polars_core::export::once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use strum_macros::IntoStaticStr; -#[derive(Clone, PartialEq, Eq, Debug, Hash)] +#[derive(Clone, PartialEq, Eq, Debug, Hash, Default)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct JoinArgs { pub how: JoinType, @@ -38,6 +38,27 @@ impl JoinArgs { } } +#[derive(Clone, PartialEq, Eq, Hash, Default, IntoStaticStr)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum JoinType { + #[default] + Inner, + Left, + Right, + Full, + #[cfg(feature = "asof_join")] + AsOf(AsOfOptions), + #[cfg(feature = "semi_anti_join")] + Semi, + #[cfg(feature = "semi_anti_join")] + Anti, + #[cfg(feature = "iejoin")] + // Options are set by optimizer/planner in Options + IEJoin, + // Options are set by optimizer/planner in Options + Cross, +} + #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum JoinCoalesce { @@ -61,7 +82,7 @@ impl JoinCoalesce { #[cfg(feature = "asof_join")] AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns), #[cfg(feature = "iejoin")] - IEJoin(_) => false, + IEJoin => false, Cross => false, #[cfg(feature = "semi_anti_join")] Semi | Anti => false, @@ -93,20 +114,6 @@ impl MaintainOrderJoin { } } -impl Default for JoinArgs { - fn default() -> Self { - Self { - how: JoinType::Inner, - validation: Default::default(), - suffix: None, - slice: None, - join_nulls: false, - coalesce: Default::default(), - maintain_order: Default::default(), - } - } -} - impl JoinArgs { pub fn new(how: JoinType) -> Self { Self { @@ -136,31 +143,66 @@ impl JoinArgs { } } -#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[strum(serialize_all = "snake_case")] -pub enum JoinType { - Inner, - Left, - Right, - Full, - #[cfg(feature = "asof_join")] - AsOf(AsOfOptions), - Cross, - #[cfg(feature = "semi_anti_join")] - Semi, - #[cfg(feature = "semi_anti_join")] - Anti, - #[cfg(feature = "iejoin")] - IEJoin(IEJoinOptions), -} - impl From for JoinArgs { fn from(value: JoinType) -> Self { JoinArgs::new(value) } } +pub trait CrossJoinFilter: Send + Sync { + fn apply(&self, df: DataFrame) -> PolarsResult; +} + +impl CrossJoinFilter for T +where + T: Fn(DataFrame) -> PolarsResult + Send + Sync, +{ + fn apply(&self, df: DataFrame) -> PolarsResult { + self(df) + } +} + +#[derive(Clone)] +pub struct CrossJoinOptions { + pub predicate: Arc, +} + +impl CrossJoinOptions { + fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter { + Arc::as_ptr(&self.predicate) + } +} + +impl Eq for CrossJoinOptions {} + +impl PartialEq for CrossJoinOptions { + fn eq(&self, other: &Self) -> bool { + std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref()) + } +} + +impl Hash for CrossJoinOptions { + fn hash(&self, state: &mut H) { + self.as_ptr_ref().hash(state); + } +} + +impl Debug for CrossJoinOptions { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "CrossJoinOptions",) + } +} + +#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] +pub enum JoinTypeOptions { + #[cfg(feature = "iejoin")] + IEJoin(IEJoinOptions), + #[cfg_attr(feature = "serde", serde(skip))] + Cross(CrossJoinOptions), +} + impl Display for JoinType { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { use JoinType::*; @@ -168,11 +210,11 @@ impl Display for JoinType { Left => "LEFT", Right => "RIGHT", Inner => "INNER", - Full { .. } => "FULL", + Full => "FULL", #[cfg(feature = "asof_join")] AsOf(_) => "ASOF", #[cfg(feature = "iejoin")] - IEJoin(_) => "IEJOIN", + IEJoin => "IEJOIN", Cross => "CROSS", #[cfg(feature = "semi_anti_join")] Semi => "SEMI", @@ -215,7 +257,7 @@ impl JoinType { pub fn is_ie(&self) -> bool { #[cfg(feature = "iejoin")] { - matches!(self, JoinType::IEJoin(_)) + matches!(self, JoinType::IEJoin) } #[cfg(not(feature = "iejoin"))] { @@ -261,7 +303,7 @@ impl JoinValidation { if !self.needs_checks() { return Ok(()); } - polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full{..} | JoinType::Left), + polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left), ComputeError: "{self} validation on a {join_type} join is not supported"); Ok(()) } diff --git a/crates/polars-ops/src/frame/join/cross_join.rs b/crates/polars-ops/src/frame/join/cross_join.rs index c4290e262627..f54a3fd71393 100644 --- a/crates/polars-ops/src/frame/join/cross_join.rs +++ b/crates/polars-ops/src/frame/join/cross_join.rs @@ -1,4 +1,7 @@ -use polars_core::utils::{concat_df_unchecked, CustomIterTools, NoNull}; +use polars_core::utils::{ + concat_df_unchecked, CustomIterTools, NoNull, _set_partition_size, + accumulate_dataframes_vertical_unchecked, split, +}; use polars_utils::pl_str::PlSmallStr; use super::*; @@ -40,60 +43,6 @@ fn take_right(total_rows: IdxSize, n_rows_right: IdxSize, slice: Option<(i64, us } pub trait CrossJoin: IntoDf { - fn cross_join_dfs( - &self, - other: &DataFrame, - slice: Option<(i64, usize)>, - parallel: bool, - ) -> PolarsResult<(DataFrame, DataFrame)> { - let df_self = self.to_df(); - let n_rows_left = df_self.height() as IdxSize; - let n_rows_right = other.height() as IdxSize; - let Some(total_rows) = n_rows_left.checked_mul(n_rows_right) else { - polars_bail!( - ComputeError: "cross joins would produce more rows than fits into 2^32; \ - consider compiling with polars-big-idx feature, or set 'streaming'" - ); - }; - if n_rows_left == 0 || n_rows_right == 0 { - return Ok((df_self.clear(), other.clear())); - } - - // the left side has the Nth row combined with every row from right. - // So let's say we have the following no. of rows - // left: 3 - // right: 4 - // - // left take idx: 000011112222 - // right take idx: 012301230123 - - let create_left_df = || { - // SAFETY: - // take left is in bounds - unsafe { df_self.take_unchecked(&take_left(total_rows, n_rows_right, slice)) } - }; - - let create_right_df = || { - // concatenation of dataframes is very expensive if we need to make the series mutable - // many times, these are atomic operations - // so we choose a different strategy at > 100 rows (arbitrarily small number) - if n_rows_left > 100 || slice.is_some() { - // SAFETY: - // take right is in bounds - unsafe { other.take_unchecked(&take_right(total_rows, n_rows_right, slice)) } - } else { - let iter = (0..n_rows_left).map(|_| other); - concat_df_unchecked(iter) - } - }; - let (l_df, r_df) = if parallel { - POOL.install(|| rayon::join(create_left_df, create_right_df)) - } else { - (create_left_df(), create_right_df()) - }; - Ok((l_df, r_df)) - } - #[doc(hidden)] /// used by streaming fn _cross_join_with_names( @@ -101,7 +50,8 @@ pub trait CrossJoin: IntoDf { other: &DataFrame, names: &[PlSmallStr], ) -> PolarsResult { - let (mut l_df, r_df) = self.cross_join_dfs(other, None, false)?; + let (mut l_df, r_df) = cross_join_dfs(self.to_df(), other, None, false)?; + l_df.clear_schema(); unsafe { l_df.get_columns_mut().extend_from_slice(r_df.get_columns()); @@ -125,10 +75,109 @@ pub trait CrossJoin: IntoDf { suffix: Option, slice: Option<(i64, usize)>, ) -> PolarsResult { - let (l_df, r_df) = self.cross_join_dfs(other, slice, true)?; + let (l_df, r_df) = cross_join_dfs(self.to_df(), other, slice, true)?; _finish_join(l_df, r_df, suffix) } } impl CrossJoin for DataFrame {} + +fn cross_join_dfs( + df_self: &DataFrame, + other: &DataFrame, + slice: Option<(i64, usize)>, + parallel: bool, +) -> PolarsResult<(DataFrame, DataFrame)> { + let n_rows_left = df_self.height() as IdxSize; + let n_rows_right = other.height() as IdxSize; + let Some(total_rows) = n_rows_left.checked_mul(n_rows_right) else { + polars_bail!( + ComputeError: "cross joins would produce more rows than fits into 2^32; \ + consider compiling with polars-big-idx feature, or set 'streaming'" + ); + }; + if n_rows_left == 0 || n_rows_right == 0 { + return Ok((df_self.clear(), other.clear())); + } + + // the left side has the Nth row combined with every row from right. + // So let's say we have the following no. of rows + // left: 3 + // right: 4 + // + // left take idx: 000011112222 + // right take idx: 012301230123 + + let create_left_df = || { + // SAFETY: + // take left is in bounds + unsafe { + df_self.take_unchecked_impl(&take_left(total_rows, n_rows_right, slice), parallel) + } + }; + + let create_right_df = || { + // concatenation of dataframes is very expensive if we need to make the series mutable + // many times, these are atomic operations + // so we choose a different strategy at > 100 rows (arbitrarily small number) + if n_rows_left > 100 || slice.is_some() { + // SAFETY: + // take right is in bounds + unsafe { + other.take_unchecked_impl(&take_right(total_rows, n_rows_right, slice), parallel) + } + } else { + let iter = (0..n_rows_left).map(|_| other); + concat_df_unchecked(iter) + } + }; + let (l_df, r_df) = if parallel { + POOL.install(|| rayon::join(create_left_df, create_right_df)) + } else { + (create_left_df(), create_right_df()) + }; + Ok((l_df, r_df)) +} + +pub(super) fn fused_cross_filter( + left: &DataFrame, + right: &DataFrame, + suffix: Option, + cross_join_options: &CrossJoinOptions, +) -> PolarsResult { + // Because we do a cartesian product, the number of partitions is squared. + // We take the sqrt, but we don't expect every partition to produce results and work can be + // imbalanced, so we multiply the number of partitions by 2; + let n_partitions = (_set_partition_size() as f32).sqrt() as usize * 2; + let splitted_a = split(left, n_partitions); + let splitted_b = split(right, n_partitions); + + let cartesian_prod = splitted_a + .iter() + .flat_map(|l| splitted_b.iter().map(move |r| (l, r))) + .collect::>(); + + let names = _finish_join(left.clear(), right.clear(), suffix)?; + let rename_names = names.get_column_names(); + let rename_names = &rename_names[left.width()..]; + + let dfs = POOL + .install(|| { + cartesian_prod.par_iter().map(|(left, right)| { + let (mut left, right) = cross_join_dfs(left, right, None, false)?; + let mut right_columns = right.take_columns(); + + for (c, name) in right_columns.iter_mut().zip(rename_names) { + c.rename((*name).clone()); + } + + unsafe { left.hstack_mut_unchecked(&right_columns) }; + + cross_join_options.predicate.apply(left) + }) + }) + .collect::>>()?; + + Ok(accumulate_dataframes_vertical_unchecked(dfs)) +} diff --git a/crates/polars-ops/src/frame/join/general.rs b/crates/polars-ops/src/frame/join/general.rs index 5ea6ef68638d..caeddbdb8c67 100644 --- a/crates/polars-ops/src/frame/join/general.rs +++ b/crates/polars-ops/src/frame/join/general.rs @@ -26,16 +26,19 @@ pub fn _finish_join( }); let mut rename_strs = Vec::with_capacity(df_right.width()); + let right_names = df_right.schema(); - df_right.get_columns().iter().for_each(|series| { - if left_names.contains(series.name()) { - rename_strs.push(series.name().to_owned()) + for name in right_names.iter_names() { + if left_names.contains(name) { + rename_strs.push(name.clone()) } - }); + } + let suffix = get_suffix(suffix); for name in rename_strs { let new_name = _join_suffix_name(name.as_str(), suffix.as_str()); + df_right.rename(&name, new_name.clone()).map_err(|_| { polars_err!(Duplicate: "column with name '{}' already exists\n\n\ You may want to try:\n\ @@ -45,7 +48,7 @@ pub fn _finish_join( } drop(left_names); - df_left.hstack_mut(df_right.get_columns())?; + unsafe { df_left.hstack_mut_unchecked(df_right.get_columns()) }; Ok(df_left) } @@ -60,12 +63,12 @@ pub fn _coalesce_full_join( // know for certain that the column name for left is `name` // and for right is `name + suffix` let schema_left = if keys_left == keys_right { - Schema::default() + Arc::new(Schema::default()) } else { - df_left.schema() + df_left.schema().clone() }; - let schema = df.schema(); + let schema = df.schema().clone(); let mut to_remove = Vec::with_capacity(keys_right.len()); // SAFETY: we maintain invariants. @@ -92,6 +95,7 @@ pub fn _coalesce_full_join( for pos in to_remove { let _ = columns.remove(pos); } + df.clear_schema(); df } diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 62ea32cb9ac6..2ef528f4232e 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -48,6 +48,7 @@ use polars_core::POOL; use polars_utils::hashing::BytesHash; use rayon::prelude::*; +use self::cross_join::fused_cross_filter; use super::IntoDf; pub trait DataFrameJoinOps: IntoDf { @@ -63,7 +64,8 @@ pub trait DataFrameJoinOps: IntoDf { /// let df2: DataFrame = df!("Name" => &["Apple", "Banana", "Pear"], /// "Potassium (mg/100g)" => &[107, 358, 115])?; /// - /// let df3: DataFrame = df1.join(&df2, ["Fruit"], ["Name"], JoinArgs::new(JoinType::Inner))?; + /// let df3: DataFrame = df1.join(&df2, ["Fruit"], ["Name"], JoinArgs::new(JoinType::Inner), + /// None)?; /// assert_eq!(df3.shape(), (3, 3)); /// println!("{}", df3); /// # Ok::<(), PolarsError>(()) @@ -91,6 +93,7 @@ pub trait DataFrameJoinOps: IntoDf { left_on: impl IntoIterator>, right_on: impl IntoIterator>, args: JoinArgs, + options: Option, ) -> PolarsResult { let df_left = self.to_df(); let selected_left = df_left.select_columns(left_on)?; @@ -105,7 +108,15 @@ pub trait DataFrameJoinOps: IntoDf { .map(Column::take_materialized_series) .collect::>(); - self._join_impl(other, selected_left, selected_right, args, true, false) + self._join_impl( + other, + selected_left, + selected_right, + args, + options, + true, + false, + ) } #[doc(hidden)] @@ -117,6 +128,7 @@ pub trait DataFrameJoinOps: IntoDf { mut selected_left: Vec, mut selected_right: Vec, mut args: JoinArgs, + options: Option, _check_rechunk: bool, _verbose: bool, ) -> PolarsResult { @@ -124,6 +136,10 @@ pub trait DataFrameJoinOps: IntoDf { #[cfg(feature = "cross_join")] if let JoinType::Cross = args.how { + if let Some(JoinTypeOptions::Cross(cross_options)) = &options { + assert!(args.slice.is_none()); + return fused_cross_filter(left_df, other, args.suffix.clone(), cross_options); + } return left_df.cross_join(other, args.suffix.clone(), args.slice); } @@ -178,6 +194,7 @@ pub trait DataFrameJoinOps: IntoDf { selected_left, selected_right, args, + options, false, _verbose, ); @@ -212,7 +229,10 @@ pub trait DataFrameJoinOps: IntoDf { } #[cfg(feature = "iejoin")] - if let JoinType::IEJoin(options) = args.how { + if let JoinType::IEJoin = args.how { + let Some(JoinTypeOptions::IEJoin(options)) = options else { + unreachable!() + }; let func = if POOL.current_num_threads() > 1 && !left_df.is_empty() && !other.is_empty() { iejoin::iejoin_par @@ -303,7 +323,7 @@ pub trait DataFrameJoinOps: IntoDf { }, }, #[cfg(feature = "iejoin")] - JoinType::IEJoin(_) => { + JoinType::IEJoin => { unreachable!() }, JoinType::Cross => { @@ -331,7 +351,7 @@ pub trait DataFrameJoinOps: IntoDf { ComputeError: "asof join not supported for join on multiple keys" ), #[cfg(feature = "iejoin")] - JoinType::IEJoin(_) => { + JoinType::IEJoin => { unreachable!() }, JoinType::Cross => { @@ -390,6 +410,7 @@ pub trait DataFrameJoinOps: IntoDf { vec![lhs_keys], vec![rhs_keys], args, + options, _check_rechunk, _verbose, ), @@ -413,7 +434,13 @@ pub trait DataFrameJoinOps: IntoDf { left_on: impl IntoIterator>, right_on: impl IntoIterator>, ) -> PolarsResult { - self.join(other, left_on, right_on, JoinArgs::new(JoinType::Inner)) + self.join( + other, + left_on, + right_on, + JoinArgs::new(JoinType::Inner), + None, + ) } /// Perform a left outer join on two DataFrames @@ -457,7 +484,13 @@ pub trait DataFrameJoinOps: IntoDf { left_on: impl IntoIterator>, right_on: impl IntoIterator>, ) -> PolarsResult { - self.join(other, left_on, right_on, JoinArgs::new(JoinType::Left)) + self.join( + other, + left_on, + right_on, + JoinArgs::new(JoinType::Left), + None, + ) } /// Perform a full outer join on two DataFrames @@ -476,7 +509,13 @@ pub trait DataFrameJoinOps: IntoDf { left_on: impl IntoIterator>, right_on: impl IntoIterator>, ) -> PolarsResult { - self.join(other, left_on, right_on, JoinArgs::new(JoinType::Full)) + self.join( + other, + left_on, + right_on, + JoinArgs::new(JoinType::Full), + None, + ) } } diff --git a/crates/polars-ops/src/frame/pivot/unpivot.rs b/crates/polars-ops/src/frame/pivot/unpivot.rs index 49eeaeba4498..45608b3c8c47 100644 --- a/crates/polars-ops/src/frame/pivot/unpivot.rs +++ b/crates/polars-ops/src/frame/pivot/unpivot.rs @@ -145,7 +145,7 @@ pub trait UnpivotDF: IntoDf { // The column name of the variable that is unpivoted let mut variable_col = MutablePlString::with_capacity(len * on.len() + 1); // prepare ids - let ids_ = self_.select_with_schema_unchecked(index, &schema)?; + let ids_ = self_.select_with_schema_unchecked(index, schema)?; let mut ids = ids_.clone(); if ids.width() > 0 { for _ in 0..on.len() - 1 { diff --git a/crates/polars-ops/src/series/ops/replace.rs b/crates/polars-ops/src/series/ops/replace.rs index 4aa84910239c..538994ce6151 100644 --- a/crates/polars-ops/src/series/ops/replace.rs +++ b/crates/polars-ops/src/series/ops/replace.rs @@ -177,6 +177,7 @@ fn replace_by_multiple( join_nulls: true, ..Default::default() }, + None, )?; let replaced = joined @@ -218,6 +219,7 @@ fn replace_by_multiple_strict(s: &Series, old: Series, new: Series) -> PolarsRes join_nulls: true, ..Default::default() }, + None, )?; let replaced = joined.column("__POLARS_REPLACE_NEW").unwrap(); diff --git a/crates/polars-pipe/src/executors/operators/reproject.rs b/crates/polars-pipe/src/executors/operators/reproject.rs index d037937896d3..b3355b30cc87 100644 --- a/crates/polars-pipe/src/executors/operators/reproject.rs +++ b/crates/polars-pipe/src/executors/operators/reproject.rs @@ -16,7 +16,7 @@ pub(crate) fn reproject_chunk( let out = chunk .data - .select_with_schema_unchecked(schema.iter_names_cloned(), &chunk_schema)?; + .select_with_schema_unchecked(schema.iter_names_cloned(), chunk_schema)?; *positions = out .get_columns() diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs index da3ed24f2fa4..6b966da07d3c 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs @@ -130,6 +130,7 @@ impl GenericJoinProbe { .for_each(|(s, name)| { s.rename(name.clone()); }); + left_df.clear_schema(); left_df }, }) diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs index 82dea0326b7d..f511ae4d7797 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs @@ -130,6 +130,7 @@ impl GenericFullOuterJoinProbe { .for_each(|(s, name)| { s.rename(name.clone()); }); + left_df.clear_schema(); left_df }, }) diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs index 024dc8522503..daa80f18b4f4 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs @@ -95,6 +95,7 @@ fn finalize_dataframe( // SAFETY: col has the same length as the df height because it was popped from df. unsafe { df.get_columns_mut() }.insert(sort_idx, col); + df.clear_schema(); } // SAFETY: We just change the sorted flag. @@ -217,6 +218,7 @@ impl SortSinkMultiple { let _ = cols.remove(sort_idx - i); }); + df.clear_schema(); let name = PlSmallStr::from_static(POLARS_SORT_COLUMN); let column = if chunk.data.height() == 0 && chunk.data.width() > 0 { Column::new_empty(name, &DataType::BinaryOffset) diff --git a/crates/polars-pipe/src/executors/sources/csv.rs b/crates/polars-pipe/src/executors/sources/csv.rs index 900be25256b4..f78674c6c3d1 100644 --- a/crates/polars-pipe/src/executors/sources/csv.rs +++ b/crates/polars-pipe/src/executors/sources/csv.rs @@ -28,7 +28,7 @@ pub(crate) struct CsvSource { // state for multi-file reads current_path_idx: usize, n_rows_read: usize, - first_schema: Schema, + first_schema: SchemaRef, include_file_path: Option, } @@ -189,9 +189,9 @@ impl Source for CsvSource { if first_read_from_file { if self.first_schema.is_empty() { - self.first_schema = batches[0].schema(); + self.first_schema = batches[0].schema().clone(); } - ensure_matching_schema(&self.first_schema, &batches[0].schema())?; + ensure_matching_schema(&self.first_schema, batches[0].schema())?; } let index = get_source_index(0); @@ -220,7 +220,8 @@ impl Source for CsvSource { // SAFETY: Columns are only replaced with columns // 1. of the same name, and // 2. of the same length. - unsafe { data_chunk.data.get_columns_mut() }.push(ca.slice(0, n).into_column()) + unsafe { data_chunk.data.get_columns_mut() }.push(ca.slice(0, n).into_column()); + data_chunk.data.clear_schema(); } } diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs index dbb9535877b3..d4bd3eb80c9c 100644 --- a/crates/polars-pipe/src/pipeline/convert.rs +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -55,9 +55,7 @@ where df, output_schema, .. } => { let mut df = (*df).clone(); - let schema = output_schema - .clone() - .unwrap_or_else(|| Arc::new(df.schema())); + let schema = output_schema.clone().unwrap_or_else(|| df.schema().clone()); if push_predicate { // projection is free if let Some(schema) = output_schema { diff --git a/crates/polars-plan/src/dsl/options.rs b/crates/polars-plan/src/dsl/options.rs index 259d66af95ae..e8a72739f6fc 100644 --- a/crates/polars-plan/src/dsl/options.rs +++ b/crates/polars-plan/src/dsl/options.rs @@ -1,3 +1,10 @@ +use std::hash::Hash; +use std::sync::Arc; + +use polars_core::error::PolarsResult; +#[cfg(feature = "iejoin")] +use polars_ops::frame::IEJoinOptions; +use polars_ops::frame::{CrossJoinFilter, CrossJoinOptions, JoinTypeOptions}; use polars_ops::prelude::{JoinArgs, JoinType}; #[cfg(feature = "dynamic_group_by")] use polars_time::RollingGroupOptions; @@ -7,6 +14,7 @@ use polars_utils::IdxSize; use serde::{Deserialize, Serialize}; use strum_macros::IntoStaticStr; +use super::ExprIR; use crate::dsl::Selector; #[derive(Copy, Clone, PartialEq, Debug, Eq, Hash)] @@ -42,12 +50,53 @@ impl Default for StrptimeOptions { } } +#[derive(Clone, PartialEq, Eq, IntoStaticStr, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[strum(serialize_all = "snake_case")] +pub enum JoinTypeOptionsIR { + #[cfg(feature = "iejoin")] + IEJoin(IEJoinOptions), + #[cfg_attr(feature = "serde", serde(skip))] + // Fused cross join and filter (only in in-memory engine) + Cross { predicate: ExprIR }, +} + +impl Hash for JoinTypeOptionsIR { + fn hash(&self, state: &mut H) { + use JoinTypeOptionsIR::*; + match self { + #[cfg(feature = "iejoin")] + IEJoin(opt) => opt.hash(state), + Cross { predicate } => predicate.node().hash(state), + } + } +} + +impl JoinTypeOptionsIR { + pub fn compile PolarsResult>>( + self, + plan: C, + ) -> PolarsResult { + use JoinTypeOptionsIR::*; + match self { + Cross { predicate } => { + let predicate = plan(&predicate)?; + + Ok(JoinTypeOptions::Cross(CrossJoinOptions { predicate })) + }, + #[cfg(feature = "iejoin")] + IEJoin(opt) => Ok(JoinTypeOptions::IEJoin(opt)), + } + } +} + #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct JoinOptions { pub allow_parallel: bool, pub force_parallel: bool, pub args: JoinArgs, + pub options: Option, /// Proxy of the number of rows in both sides of the joins /// Holds `(Option, estimated_size)` pub rows_left: (Option, usize), @@ -59,7 +108,9 @@ impl Default for JoinOptions { JoinOptions { allow_parallel: true, force_parallel: false, + // Todo!: make default args: JoinArgs::new(JoinType::Left), + options: Default::default(), rows_left: (None, usize::MAX), rows_right: (None, usize::MAX), } diff --git a/crates/polars-plan/src/plans/builder_dsl.rs b/crates/polars-plan/src/plans/builder_dsl.rs index ec8e7c4ceebe..c8eff343807f 100644 --- a/crates/polars-plan/src/plans/builder_dsl.rs +++ b/crates/polars-plan/src/plans/builder_dsl.rs @@ -335,7 +335,7 @@ impl DslBuilder { } pub fn from_existing_df(df: DataFrame) -> Self { - let schema = Arc::new(df.schema()); + let schema = df.schema().clone(); DslPlan::DataFrameScan { df: Arc::new(df), schema, diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 2d3e2a0f483d..dd94ed2d7784 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -111,7 +111,7 @@ pub fn resolve_join( let mut joined_on = PlHashSet::new(); #[cfg(feature = "iejoin")] - let check = !matches!(options.args.how, JoinType::IEJoin(_)); + let check = !matches!(options.args.how, JoinType::IEJoin); #[cfg(not(feature = "iejoin"))] let check = true; if check { @@ -641,10 +641,12 @@ fn resolve_join_where( } else if ie_right_on.len() >= 2 { // Do an IEjoin. let opts = Arc::make_mut(&mut options); - opts.args.how = JoinType::IEJoin(IEJoinOptions { + + opts.args.how = JoinType::IEJoin; + opts.options = Some(JoinTypeOptionsIR::IEJoin(IEJoinOptions { operator1: ie_op[0], operator2: Some(ie_op[1]), - }); + })); let (last_node, join_node) = resolve_join( Either::Right(input_left), @@ -671,10 +673,11 @@ fn resolve_join_where( } else if ie_right_on.len() == 1 { // For a single inequality comparison, we use the piecewise merge join algorithm let opts = Arc::make_mut(&mut options); - opts.args.how = JoinType::IEJoin(IEJoinOptions { + opts.args.how = JoinType::IEJoin; + opts.options = Some(JoinTypeOptionsIR::IEJoin(IEJoinOptions { operator1: ie_op[0], operator2: None, - }); + })); resolve_join( Either::Right(input_left), diff --git a/crates/polars-plan/src/plans/functions/mod.rs b/crates/polars-plan/src/plans/functions/mod.rs index a357d5b269ad..63d658e72e38 100644 --- a/crates/polars-plan/src/plans/functions/mod.rs +++ b/crates/polars-plan/src/plans/functions/mod.rs @@ -267,7 +267,7 @@ impl FunctionIR { validate_output, schema, .. - }) => python_udf::call_python_udf(function, df, *validate_output, schema.as_deref()), + }) => python_udf::call_python_udf(function, df, *validate_output, schema.clone()), FastCount { sources, scan_type, diff --git a/crates/polars-plan/src/plans/functions/python_udf.rs b/crates/polars-plan/src/plans/functions/python_udf.rs index 50917ef8f9fb..2dac388cd2f7 100644 --- a/crates/polars-plan/src/plans/functions/python_udf.rs +++ b/crates/polars-plan/src/plans/functions/python_udf.rs @@ -4,14 +4,14 @@ pub(super) fn call_python_udf( function: &PythonFunction, df: DataFrame, validate_output: bool, - opt_schema: Option<&Schema>, + opt_schema: Option, ) -> PolarsResult { let expected_schema = if let Some(schema) = opt_schema { - Some(Cow::Borrowed(schema)) + Some(schema) } // only materialize if we validate the output else if validate_output { - Some(Cow::Owned(df.schema())) + Some(df.schema().clone()) } // do not materialize the schema, we will ignore it. else { @@ -22,7 +22,7 @@ pub(super) fn call_python_udf( if validate_output { let output_schema = out.schema(); let expected = expected_schema.unwrap(); - if expected.as_ref() != &output_schema { + if &expected != output_schema { return Err(PolarsError::ComputeError( format!( "The output schema of 'LazyFrame.map' is incorrect. Expected: {expected:?}\n\ diff --git a/crates/polars-plan/src/plans/ir/format.rs b/crates/polars-plan/src/plans/ir/format.rs index 145c68fcbcfa..bc97a609ac56 100644 --- a/crates/polars-plan/src/plans/ir/format.rs +++ b/crates/polars-plan/src/plans/ir/format.rs @@ -319,13 +319,25 @@ impl<'a> IRDisplay<'a> { let left_on = self.display_expr_slice(left_on); let right_on = self.display_expr_slice(right_on); - let how = &options.args.how; - write!(f, "{:indent$}{how} JOIN:", "")?; - write!(f, "\n{:indent$}LEFT PLAN ON: {left_on}", "")?; - self.with_root(*input_left)._format(f, sub_indent)?; - write!(f, "\n{:indent$}RIGHT PLAN ON: {right_on}", "")?; - self.with_root(*input_right)._format(f, sub_indent)?; - write!(f, "\n{:indent$}END {how} JOIN", "") + // Fused cross + filter (show as nested loop join) + if let Some(JoinTypeOptionsIR::Cross { predicate }) = &options.options { + let predicate = self.display_expr(predicate); + let name = "NESTED LOOP"; + write!(f, "{:indent$}{name} JOIN ON {predicate}:", "")?; + write!(f, "\n{:indent$}LEFT PLAN:", "")?; + self.with_root(*input_left)._format(f, sub_indent)?; + write!(f, "\n{:indent$}RIGHT PLAN:", "")?; + self.with_root(*input_right)._format(f, sub_indent)?; + write!(f, "\n{:indent$}END {name} JOIN", "") + } else { + let how = &options.args.how; + write!(f, "{:indent$}{how} JOIN:", "")?; + write!(f, "\n{:indent$}LEFT PLAN ON: {left_on}", "")?; + self.with_root(*input_left)._format(f, sub_indent)?; + write!(f, "\n{:indent$}RIGHT PLAN ON: {right_on}", "")?; + self.with_root(*input_right)._format(f, sub_indent)?; + write!(f, "\n{:indent$}END {how} JOIN", "") + } }, HStack { input, exprs, .. } => { // @NOTE: Maybe there should be a clear delimiter here? diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs index efb01919a15a..a8bcdc198fe9 100644 --- a/crates/polars-plan/src/plans/mod.rs +++ b/crates/polars-plan/src/plans/mod.rs @@ -204,10 +204,10 @@ impl Clone for DslPlan { impl Default for DslPlan { fn default() -> Self { let df = DataFrame::empty(); - let schema = df.schema(); + let schema = df.schema().clone(); DslPlan::DataFrameScan { df: Arc::new(df), - schema: Arc::new(schema), + schema, } } } diff --git a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs index 07a4a45948b8..66bc8b8c6ce5 100644 --- a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs +++ b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs @@ -13,7 +13,7 @@ use polars_utils::arena::{Arena, Node}; use polars_utils::pl_str::PlSmallStr; use super::{aexpr_to_leaf_names_iter, AExpr, JoinOptions, IR}; -use crate::dsl::Operator; +use crate::dsl::{JoinTypeOptionsIR, Operator}; use crate::plans::{ExprIR, OutputName}; /// Join origin of an expression @@ -353,6 +353,7 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &mut Arena, expr_arena: &mut Arena, eq_right_on: Vec, #[cfg(feature = "iejoin")] ie_left_on: Vec, @@ -408,7 +409,7 @@ pub fn insert_fitting_join( debug_assert_eq!(ie_left_on.len(), ie_right_on.len()); debug_assert!(ie_op.len() <= 2); } - debug_assert_eq!(options.args.how, JoinType::Cross); + debug_assert!(matches!(options.args.how, JoinType::Cross)); let remaining_predicates = remaining_predicates .iter() @@ -444,10 +445,11 @@ pub fn insert_fitting_join( let operator2 = ie_op.get(1).copied(); // Do an IEjoin. - options.args.how = JoinType::IEJoin(IEJoinOptions { + options.args.how = JoinType::IEJoin; + options.options = Some(JoinTypeOptionsIR::IEJoin(IEJoinOptions { operator1, operator2, - }); + })); // We need to make sure not to delete any columns options.args.coalesce = JoinCoalesce::KeepColumns; @@ -471,8 +473,13 @@ pub fn insert_fitting_join( Some(acc.map_or(e, |acc| and_expr(acc, e, expr_arena))) }, ); + if let Some(pred) = remaining_predicates { + options.options = Some(JoinTypeOptionsIR::Cross { + predicate: ExprIR::from_node(pred, expr_arena), + }) + } - (Vec::new(), Vec::new(), remaining_predicates) + (Vec::new(), Vec::new(), None) }, }; diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs index 6449e63ad4ed..35cfd6f3f864 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs @@ -84,7 +84,7 @@ fn join_produces_null(how: &JoinType) -> LeftRight { #[cfg(feature = "semi_anti_join")] JoinType::Semi | JoinType::Anti => LeftRight(false, false), #[cfg(feature = "iejoin")] - JoinType::IEJoin(..) => LeftRight(false, false), + JoinType::IEJoin => LeftRight(false, false), } } @@ -252,5 +252,6 @@ pub(super) fn process_join( schema, options, }; + Ok(opt.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) } diff --git a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs index b10fc1daf9c5..eeb52e4a4d4f 100644 --- a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs @@ -354,7 +354,7 @@ impl SlicePushDown { left_on, right_on, mut options - }, Some(state)) if !self.streaming => { + }, Some(state)) if !self.streaming && !matches!(options.options, Some(JoinTypeOptionsIR::Cross { .. })) => { // first restart optimization in both inputs and get the updated LP let lp_left = lp_arena.take(input_left); let lp_left = self.pushdown(lp_left, None, lp_arena, expr_arena)?; diff --git a/crates/polars-python/src/lazyframe/visit.rs b/crates/polars-python/src/lazyframe/visit.rs index 4198b54c7ad7..636f852316b7 100644 --- a/crates/polars-python/src/lazyframe/visit.rs +++ b/crates/polars-python/src/lazyframe/visit.rs @@ -57,7 +57,7 @@ impl NodeTraverser { // Increment major on breaking changes to the IR (e.g. renaming // fields, reordering tuples), minor on backwards compatible // changes (e.g. exposing a new expression node). - const VERSION: Version = (4, 2); + const VERSION: Version = (5, 0); pub fn new(root: Node, lp_arena: Arena, expr_arena: Arena) -> Self { Self { diff --git a/crates/polars-python/src/lazyframe/visitor/nodes.rs b/crates/polars-python/src/lazyframe/visitor/nodes.rs index 773dae0bb74d..eaf5401f53d1 100644 --- a/crates/polars-python/src/lazyframe/visitor/nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/nodes.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "iejoin")] +use polars::prelude::JoinTypeOptionsIR; use polars_core::prelude::IdxSize; use polars_ops::prelude::JoinType; use polars_plan::plans::IR; @@ -471,15 +473,26 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { return Err(PyNotImplementedError::new_err("asof join")) }, #[cfg(feature = "iejoin")] - JoinType::IEJoin(ie_options) => ( - name, - crate::Wrap(ie_options.operator1).into_py_any(py)?, - ie_options.operator2.as_ref().map_or_else( - || Ok(py.None()), - |op| crate::Wrap(*op).into_py_any(py), - )?, - ) - .into_py_any(py)?, + JoinType::IEJoin => { + let Some(JoinTypeOptionsIR::IEJoin(ie_options)) = &options.options + else { + unreachable!() + }; + ( + name, + crate::Wrap(ie_options.operator1).into_py_any(py)?, + ie_options.operator2.as_ref().map_or_else( + || Ok(py.None()), + |op| crate::Wrap(*op).into_py_any(py), + )?, + ) + .into_py_any(py)? + }, + // This is a cross join fused with a predicate. Shown in the IR::explain as + // NESTED LOOP JOIN + JoinType::Cross if options.options.is_some() => { + return Err(PyNotImplementedError::new_err("nested loop join")) + }, _ => name.into_any().unbind(), }, options.args.join_nulls, diff --git a/crates/polars-stream/src/physical_plan/fmt.rs b/crates/polars-stream/src/physical_plan/fmt.rs index 37f3dcfcdaa4..11c701abdbbf 100644 --- a/crates/polars-stream/src/physical_plan/fmt.rs +++ b/crates/polars-stream/src/physical_plan/fmt.rs @@ -216,6 +216,7 @@ fn visualize_plan_rec( left_on, right_on, args, + .. } | PhysNodeKind::EquiJoin { input_left, diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index 70e5c4885c54..7b0140751159 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -533,6 +533,7 @@ pub fn lower_ir( let left_on = left_on.clone(); let right_on = right_on.clone(); let args = options.args.clone(); + let options = options.options.clone(); let phys_left = lower_ir!(input_left)?; let phys_right = lower_ir!(input_right)?; if args.how.is_equi() && !args.validation.needs_checks() { @@ -578,6 +579,7 @@ pub fn lower_ir( left_on, right_on, args, + options, } } }, diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index 217131b18b03..1d33addb4c4a 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -6,6 +6,7 @@ use polars_core::prelude::{IdxSize, InitHashMaps, PlHashMap, SortMultipleOptions use polars_core::schema::{Schema, SchemaRef}; use polars_error::PolarsResult; use polars_ops::frame::JoinArgs; +use polars_plan::dsl::JoinTypeOptionsIR; use polars_plan::plans::hive::HivePartitions; use polars_plan::plans::{AExpr, DataFrameUdf, FileInfo, FileScan, ScanSources, IR}; use polars_plan::prelude::expr_ir::ExprIR; @@ -171,6 +172,7 @@ pub enum PhysNodeKind { left_on: Vec, right_on: Vec, args: JoinArgs, + options: Option, }, } diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index a0a9a24a4e4d..1f6cbecb6cbf 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -457,6 +457,7 @@ fn to_graph_rec<'a>( left_on, right_on, args, + options, } => { let left_input_key = to_graph_rec(*input_left, ctx)?; let right_input_key = to_graph_rec(*input_right, ctx)?; @@ -480,6 +481,7 @@ fn to_graph_rec<'a>( allow_parallel: true, force_parallel: false, args: args.clone(), + options: options.clone(), rows_left: (None, 0), rows_right: (None, 0), }), diff --git a/crates/polars-time/src/upsample.rs b/crates/polars-time/src/upsample.rs index 47fef8180751..c1f9847486a9 100644 --- a/crates/polars-time/src/upsample.rs +++ b/crates/polars-time/src/upsample.rs @@ -216,6 +216,7 @@ fn upsample_single_impl( [index_col_name.clone()], [index_col_name.clone()], JoinArgs::new(JoinType::Left), + None, ) }, _ => polars_bail!( diff --git a/crates/polars/tests/it/chunks/parquet.rs b/crates/polars/tests/it/chunks/parquet.rs index 384382fdd5f9..73770d0d0faa 100644 --- a/crates/polars/tests/it/chunks/parquet.rs +++ b/crates/polars/tests/it/chunks/parquet.rs @@ -25,7 +25,13 @@ fn test_cast_join_14872() { let df2 = ParquetReader::new(buf).finish().unwrap(); let out = df1 - .join(&df2, ["ints"], ["ints"], JoinArgs::new(JoinType::Left)) + .join( + &df2, + ["ints"], + ["ints"], + JoinArgs::new(JoinType::Left), + None, + ) .unwrap(); let expected = df![ diff --git a/crates/polars/tests/it/core/joins.rs b/crates/polars/tests/it/core/joins.rs index 2b838a8a41f1..935d71b0c7d9 100644 --- a/crates/polars/tests/it/core/joins.rs +++ b/crates/polars/tests/it/core/joins.rs @@ -27,6 +27,7 @@ fn test_chunked_left_join() -> PolarsResult<()> { ["name"], ["name"], JoinArgs::new(JoinType::Left), + None, )?; let expected = df![ "name" => ["john", "paul", "keith"], @@ -135,6 +136,7 @@ fn test_full_outer_join() -> PolarsResult<()> { ["days"], ["days"], JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + None, )?; assert_eq!(joined.height(), 5); assert_eq!( @@ -162,6 +164,7 @@ fn test_full_outer_join() -> PolarsResult<()> { ["a"], ["a"], JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + None, )?; assert_eq!(out.column("c_right")?.null_count(), 1); @@ -260,13 +263,25 @@ fn test_join_multiple_columns() { // now check the join with multiple columns let joined = df_a - .join(&df_b, ["a", "b"], ["foo", "bar"], JoinType::Left.into()) + .join( + &df_b, + ["a", "b"], + ["foo", "bar"], + JoinType::Left.into(), + None, + ) .unwrap(); let ca = joined.column("ham").unwrap().str().unwrap(); assert_eq!(Vec::from(ca), correct_ham); let joined_inner_hack = df_a.inner_join(&df_b, ["dummy"], ["dummy"]).unwrap(); let joined_inner = df_a - .join(&df_b, ["a", "b"], ["foo", "bar"], JoinType::Inner.into()) + .join( + &df_b, + ["a", "b"], + ["foo", "bar"], + JoinType::Inner.into(), + None, + ) .unwrap(); assert!(joined_inner_hack @@ -281,6 +296,7 @@ fn test_join_multiple_columns() { ["a", "b"], ["foo", "bar"], JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + None, ) .unwrap(); assert!(joined_full_outer_hack @@ -309,7 +325,7 @@ fn test_join_categorical() { .unwrap(); let out = df_a - .join(&df_b, ["b"], ["bar"], JoinType::Left.into()) + .join(&df_b, ["b"], ["bar"], JoinType::Left.into(), None) .unwrap(); assert_eq!(out.shape(), (6, 5)); let correct_ham = &[ @@ -327,7 +343,7 @@ fn test_join_categorical() { // test dispatch for jt in [JoinType::Left, JoinType::Inner, JoinType::Full] { - let out = df_a.join(&df_b, ["b"], ["bar"], jt.into()).unwrap(); + let out = df_a.join(&df_b, ["b"], ["bar"], jt.into(), None).unwrap(); let out = out.column("b").unwrap(); assert_eq!( out.dtype(), @@ -350,7 +366,7 @@ fn test_join_categorical() { s.cast(&DataType::Categorical(None, Default::default())) }) .unwrap(); - let out = df_a.join(&df_b, ["b"], ["bar"], JoinType::Left.into()); + let out = df_a.join(&df_b, ["b"], ["bar"], JoinType::Left.into(), None); assert!(out.is_err()); } @@ -444,7 +460,13 @@ fn test_join_err() -> PolarsResult<()> { // dtypes don't match, error assert!(df1 - .join(&df2, vec!["a", "b"], vec!["a", "b"], JoinType::Left.into()) + .join( + &df2, + vec!["a", "b"], + vec!["a", "b"], + JoinType::Left.into(), + None + ) .is_err()); Ok(()) } @@ -490,6 +512,7 @@ fn test_joins_with_duplicates() -> PolarsResult<()> { ["col1"], ["join_col1"], JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + None, ) .unwrap(); @@ -532,6 +555,7 @@ fn test_multi_joins_with_duplicates() -> PolarsResult<()> { ["col1", "join_col2"], ["join_col1", "col2"], JoinType::Inner.into(), + None, ) .unwrap(); @@ -547,6 +571,7 @@ fn test_multi_joins_with_duplicates() -> PolarsResult<()> { ["col1", "join_col2"], ["join_col1", "col2"], JoinType::Left.into(), + None, ) .unwrap(); @@ -562,6 +587,7 @@ fn test_multi_joins_with_duplicates() -> PolarsResult<()> { ["col1", "join_col2"], ["join_col1", "col2"], JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + None, ) .unwrap(); @@ -594,6 +620,7 @@ fn test_join_floats() -> PolarsResult<()> { vec!["a", "c"], vec!["foo", "bar"], JoinType::Left.into(), + None, )?; assert_eq!( Vec::from(out.column("ham")?.str()?), @@ -605,6 +632,7 @@ fn test_join_floats() -> PolarsResult<()> { vec!["a", "c"], vec!["foo", "bar"], JoinArgs::new(JoinType::Full).with_coalesce(JoinCoalesce::CoalesceColumns), + None, )?; assert_eq!( out.dtypes(), diff --git a/crates/polars/tests/it/io/json.rs b/crates/polars/tests/it/io/json.rs index 9095d4299bdd..bc84cd72d25e 100644 --- a/crates/polars/tests/it/io/json.rs +++ b/crates/polars/tests/it/io/json.rs @@ -158,7 +158,7 @@ fn test_read_ndjson_iss_5875() { ); schema.with_column("float".into(), DataType::Float64); - assert_eq!(schema, df.unwrap().schema()); + assert_eq!(&schema, &(**df.unwrap().schema())); } #[test] diff --git a/py-polars/tests/unit/lazyframe/test_optimizations.py b/py-polars/tests/unit/lazyframe/test_optimizations.py index e2ac38bd0424..2e238cd4cf95 100644 --- a/py-polars/tests/unit/lazyframe/test_optimizations.py +++ b/py-polars/tests/unit/lazyframe/test_optimizations.py @@ -244,8 +244,8 @@ def test_collapse_joins() -> None: dont_mix = cross.filter(pl.col.x + pl.col.a != 0) e = dont_mix.explain() - assert "CROSS JOIN" in e - assert "FILTER" in e + assert "NESTED LOOP JOIN" in e + assert "FILTER" not in e assert_frame_equal( dont_mix.collect(collapse_joins=False), dont_mix.collect(), @@ -254,7 +254,7 @@ def test_collapse_joins() -> None: no_literals = cross.filter(pl.col.x == 2) e = no_literals.explain() - assert "CROSS JOIN" in e + assert "NESTED LOOP JOIN" in e assert_frame_equal( no_literals.collect(collapse_joins=False), no_literals.collect(), @@ -264,6 +264,7 @@ def test_collapse_joins() -> None: iejoin = cross.filter(pl.col.x >= pl.col.a) e = iejoin.explain() assert "IEJOIN" in e + assert "NESTED LOOP JOIN" not in e assert "CROSS JOIN" not in e assert "FILTER" not in e assert_frame_equal( @@ -276,6 +277,7 @@ def test_collapse_joins() -> None: e = iejoin.explain() assert "IEJOIN" in e assert "CROSS JOIN" not in e + assert "NESTED LOOP JOIN" not in e assert "FILTER" not in e assert_frame_equal( iejoin.collect(collapse_joins=False), iejoin.collect(), check_row_order=False diff --git a/py-polars/tests/unit/operations/test_cross_join.py b/py-polars/tests/unit/operations/test_cross_join.py index f424da5ab170..9913ab3f1094 100644 --- a/py-polars/tests/unit/operations/test_cross_join.py +++ b/py-polars/tests/unit/operations/test_cross_join.py @@ -4,6 +4,7 @@ import pytest import polars as pl +from polars.testing import assert_frame_equal def test_cross_join_predicate_pushdown_block_16956() -> None: @@ -44,3 +45,33 @@ def test_cross_join_raise_on_keys() -> None: with pytest.raises(ValueError): df.join(df, how="cross", left_on="a", right_on="b") + + +def test_nested_loop_join() -> None: + left = pl.LazyFrame( + { + "a": [1, 2, 1, 3], + "b": [1, 2, 3, 4], + } + ) + right = pl.LazyFrame( + { + "c": [4, 1, 2], + "d": [1, 2, 3], + } + ) + + actual = left.join_where(right, pl.col("a") != pl.col("c")) + plan = actual.explain() + assert "NESTED LOOP JOIN" in plan + expected = pl.DataFrame( + { + "a": [1, 1, 2, 2, 1, 1, 3, 3, 3], + "b": [1, 1, 2, 2, 3, 3, 4, 4, 4], + "c": [4, 2, 4, 1, 4, 2, 4, 1, 2], + "d": [1, 3, 1, 2, 1, 3, 1, 2, 3], + } + ) + assert_frame_equal( + actual.collect(), expected, check_row_order=False, check_exact=True + ) diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index 87a76c985b13..848a4b2b7f85 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -297,8 +297,7 @@ def test_join_where_predicates(range_constraint: list[pl.Expr]) -> None: ) explained = q.explain() - assert "CROSS" in explained - assert "FILTER" in explained + assert "NESTED LOOP" in explained actual = q.collect() assert actual.to_dict(as_series=False) == { "group": [0, 0, 0, 0, 0, 0, 1, 1, 1], @@ -603,7 +602,7 @@ def test_join_on_strings() -> None: q = df.join_where(df, pl.col("a").ge(pl.col("a_right"))) - assert "CROSS JOIN" in q.explain() + assert "NESTED LOOP JOIN" in q.explain() assert q.collect().to_dict(as_series=False) == { "a": ["a", "b", "b", "c", "c", "c"], "b": ["b", "b", "b", "b", "b", "b"],