From 94b408759577fc150bf2681836d94ef7181854f0 Mon Sep 17 00:00:00 2001 From: Gijs Burghoorn Date: Fri, 10 Jan 2025 18:06:25 +0100 Subject: [PATCH] perf: Use downcast_ref instead of dtype equality in `>` (#20664) --- crates/polars-core/src/datatypes/mod.rs | 34 +++++++---- .../src/series/implementations/array.rs | 3 +- .../src/series/implementations/binary.rs | 4 ++ .../series/implementations/binary_offset.rs | 4 ++ .../src/series/implementations/boolean.rs | 4 ++ .../src/series/implementations/categorical.rs | 5 ++ .../src/series/implementations/date.rs | 4 ++ .../src/series/implementations/datetime.rs | 5 ++ .../src/series/implementations/decimal.rs | 4 ++ .../src/series/implementations/duration.rs | 4 ++ .../src/series/implementations/floats.rs | 5 ++ .../src/series/implementations/list.rs | 3 +- .../src/series/implementations/mod.rs | 4 ++ .../src/series/implementations/null.rs | 5 ++ .../src/series/implementations/object.rs | 4 ++ .../src/series/implementations/string.rs | 4 ++ .../src/series/implementations/struct_.rs | 4 ++ .../src/series/implementations/time.rs | 4 ++ crates/polars-core/src/series/mod.rs | 60 +++++++------------ crates/polars-core/src/series/series_trait.rs | 13 ++-- crates/polars-expr/src/reduce/mod.rs | 2 +- 21 files changed, 118 insertions(+), 61 deletions(-) diff --git a/crates/polars-core/src/datatypes/mod.rs b/crates/polars-core/src/datatypes/mod.rs index f9e4b71e5602..4c4c726037eb 100644 --- a/crates/polars-core/src/datatypes/mod.rs +++ b/crates/polars-core/src/datatypes/mod.rs @@ -78,6 +78,7 @@ pub unsafe trait PolarsDataType: Send + Sync + Sized + 'static { type HasViews; type IsStruct; type IsObject; + type IsLogical; fn get_dtype() -> DataType where @@ -95,6 +96,7 @@ where HasViews = FalseT, IsStruct = FalseT, IsObject = FalseT, + IsLogical = FalseT, >, { type Native: NumericNative; @@ -117,6 +119,7 @@ macro_rules! impl_polars_num_datatype { type HasViews = FalseT; type IsStruct = FalseT; type IsObject = FalseT; + type IsLogical = FalseT; #[inline] fn get_dtype() -> DataType { @@ -133,7 +136,7 @@ macro_rules! impl_polars_num_datatype { } macro_rules! impl_polars_datatype_pass_dtype { - ($ca:ident, $dtype:expr, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty, $owned_phys:ty, $has_views:ident) => { + ($ca:ident, $dtype:expr, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty, $owned_phys:ty, $has_views:ident, $is_logical:ident) => { #[derive(Clone, Copy)] pub struct $ca {} @@ -146,6 +149,7 @@ macro_rules! impl_polars_datatype_pass_dtype { type HasViews = $has_views; type IsStruct = FalseT; type IsObject = FalseT; + type IsLogical = $is_logical; #[inline] fn get_dtype() -> DataType { @@ -164,13 +168,14 @@ macro_rules! impl_polars_binview_datatype { $phys, $zerophys, $owned_phys, - TrueT + TrueT, + FalseT ); }; } macro_rules! impl_polars_datatype { - ($ca:ident, $variant:ident, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty, $owned_phys:ty) => { + ($ca:ident, $variant:ident, $arr:ty, $lt:lifetime, $phys:ty, $zerophys:ty, $owned_phys:ty, $is_logical:ident) => { impl_polars_datatype_pass_dtype!( $ca, DataType::$variant, @@ -179,7 +184,8 @@ macro_rules! impl_polars_datatype { $phys, $zerophys, $owned_phys, - FalseT + FalseT, + $is_logical ); }; } @@ -197,18 +203,18 @@ impl_polars_num_datatype!(PolarsIntegerType, Int64Type, Int64, i64, i64); impl_polars_num_datatype!(PolarsIntegerType, Int128Type, Int128, i128, i128); impl_polars_num_datatype!(PolarsFloatType, Float32Type, Float32, f32, f32); impl_polars_num_datatype!(PolarsFloatType, Float64Type, Float64, f64, f64); -impl_polars_datatype!(DateType, Date, PrimitiveArray, 'a, i32, i32, i32); -impl_polars_datatype!(TimeType, Time, PrimitiveArray, 'a, i64, i64, i64); +impl_polars_datatype!(DateType, Date, PrimitiveArray, 'a, i32, i32, i32, TrueT); +impl_polars_datatype!(TimeType, Time, PrimitiveArray, 'a, i64, i64, i64, TrueT); impl_polars_binview_datatype!(StringType, String, Utf8ViewArray, 'a, &'a str, Option<&'a str>, String); impl_polars_binview_datatype!(BinaryType, Binary, BinaryViewArray, 'a, &'a [u8], Option<&'a [u8]>, Box<[u8]>); -impl_polars_datatype!(BinaryOffsetType, BinaryOffset, BinaryArray, 'a, &'a [u8], Option<&'a [u8]>, Box<[u8]>); -impl_polars_datatype!(BooleanType, Boolean, BooleanArray, 'a, bool, bool, bool); +impl_polars_datatype!(BinaryOffsetType, BinaryOffset, BinaryArray, 'a, &'a [u8], Option<&'a [u8]>, Box<[u8]>, FalseT); +impl_polars_datatype!(BooleanType, Boolean, BooleanArray, 'a, bool, bool, bool, FalseT); #[cfg(feature = "dtype-decimal")] -impl_polars_datatype_pass_dtype!(DecimalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray, 'a, i128, i128, i128, FalseT); -impl_polars_datatype_pass_dtype!(DatetimeType, DataType::Unknown(UnknownKind::Any), PrimitiveArray, 'a, i64, i64, i64, FalseT); -impl_polars_datatype_pass_dtype!(DurationType, DataType::Unknown(UnknownKind::Any), PrimitiveArray, 'a, i64, i64, i64, FalseT); -impl_polars_datatype_pass_dtype!(CategoricalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray, 'a, u32, u32, u32, FalseT); +impl_polars_datatype_pass_dtype!(DecimalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray, 'a, i128, i128, i128, FalseT, TrueT); +impl_polars_datatype_pass_dtype!(DatetimeType, DataType::Unknown(UnknownKind::Any), PrimitiveArray, 'a, i64, i64, i64, FalseT, TrueT); +impl_polars_datatype_pass_dtype!(DurationType, DataType::Unknown(UnknownKind::Any), PrimitiveArray, 'a, i64, i64, i64, FalseT, TrueT); +impl_polars_datatype_pass_dtype!(CategoricalType, DataType::Unknown(UnknownKind::Any), PrimitiveArray, 'a, u32, u32, u32, FalseT, TrueT); #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct ListType {} @@ -221,6 +227,7 @@ unsafe impl PolarsDataType for ListType { type HasViews = FalseT; type IsStruct = FalseT; type IsObject = FalseT; + type IsLogical = FalseT; fn get_dtype() -> DataType { // Null as we cannot know anything without self. @@ -245,6 +252,7 @@ unsafe impl PolarsDataType for StructType { type HasViews = FalseT; type IsStruct = TrueT; type IsObject = FalseT; + type IsLogical = FalseT; fn get_dtype() -> DataType where @@ -266,6 +274,7 @@ unsafe impl PolarsDataType for FixedSizeListType { type HasViews = FalseT; type IsStruct = FalseT; type IsObject = FalseT; + type IsLogical = FalseT; fn get_dtype() -> DataType { // Null as we cannot know anything without self. @@ -285,6 +294,7 @@ unsafe impl PolarsDataType for ObjectType { type HasViews = FalseT; type IsStruct = FalseT; type IsObject = TrueT; + type IsLogical = FalseT; fn get_dtype() -> DataType { DataType::Object(T::type_name(), None) diff --git a/crates/polars-core/src/series/implementations/array.rs b/crates/polars-core/src/series/implementations/array.rs index 893ee2b6b0c8..156076a4c295 100644 --- a/crates/polars-core/src/series/implementations/array.rs +++ b/crates/polars-core/src/series/implementations/array.rs @@ -207,12 +207,11 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } + fn as_any(&self) -> &dyn Any { &self.0 } - /// Get a hold to self as `Any` trait reference. - /// Only implemented for ObjectType fn as_any_mut(&mut self) -> &mut dyn Any { &mut self.0 } diff --git a/crates/polars-core/src/series/implementations/binary.rs b/crates/polars-core/src/series/implementations/binary.rs index 8c011f5b8104..52bdd857edb1 100644 --- a/crates/polars-core/src/series/implementations/binary.rs +++ b/crates/polars-core/src/series/implementations/binary.rs @@ -247,4 +247,8 @@ impl SeriesTrait for SeriesWrap { fn as_any(&self) -> &dyn Any { &self.0 } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } } diff --git a/crates/polars-core/src/series/implementations/binary_offset.rs b/crates/polars-core/src/series/implementations/binary_offset.rs index 574f6617252e..6cb2a4b3e86c 100644 --- a/crates/polars-core/src/series/implementations/binary_offset.rs +++ b/crates/polars-core/src/series/implementations/binary_offset.rs @@ -189,4 +189,8 @@ impl SeriesTrait for SeriesWrap { fn as_any(&self) -> &dyn Any { &self.0 } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } } diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs index eaa9bd9a641a..ab92d66bbf40 100644 --- a/crates/polars-core/src/series/implementations/boolean.rs +++ b/crates/polars-core/src/series/implementations/boolean.rs @@ -354,4 +354,8 @@ impl SeriesTrait for SeriesWrap { fn as_any(&self) -> &dyn Any { &self.0 } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } } diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index 8bad6f18e1db..fc9dc0eb5760 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -293,9 +293,14 @@ impl SeriesTrait for SeriesWrap { fn max_reduce(&self) -> PolarsResult { Ok(ChunkAggSeries::max_reduce(&self.0)) } + fn as_any(&self) -> &dyn Any { &self.0 } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } } impl private::PrivateSeriesNumeric for SeriesWrap { diff --git a/crates/polars-core/src/series/implementations/date.rs b/crates/polars-core/src/series/implementations/date.rs index da11edf51e70..e5e319c3fc16 100644 --- a/crates/polars-core/src/series/implementations/date.rs +++ b/crates/polars-core/src/series/implementations/date.rs @@ -351,6 +351,10 @@ impl SeriesTrait for SeriesWrap { fn as_any(&self) -> &dyn Any { &self.0 } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } } impl private::PrivateSeriesNumeric for SeriesWrap { diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index 85772c82e545..c547be64b504 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -360,7 +360,12 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } + fn as_any(&self) -> &dyn Any { &self.0 } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } } diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index b9cbaf6514fd..0494da7bffd2 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -419,4 +419,8 @@ impl SeriesTrait for SeriesWrap { fn as_any(&self) -> &dyn Any { &self.0 } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } } diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 1ededbed7d16..c7563221cfa3 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -520,4 +520,8 @@ impl SeriesTrait for SeriesWrap { fn as_any(&self) -> &dyn Any { &self.0 } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } } diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 0a80f1a0474e..e787c158d5e2 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -362,9 +362,14 @@ macro_rules! impl_dyn_series { fn checked_div(&self, rhs: &Series) -> PolarsResult { self.0.checked_div(rhs) } + fn as_any(&self) -> &dyn Any { &self.0 } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } } }; } diff --git a/crates/polars-core/src/series/implementations/list.rs b/crates/polars-core/src/series/implementations/list.rs index 4217b0b2c9f2..bfee61814fcc 100644 --- a/crates/polars-core/src/series/implementations/list.rs +++ b/crates/polars-core/src/series/implementations/list.rs @@ -254,12 +254,11 @@ impl SeriesTrait for SeriesWrap { fn clone_inner(&self) -> Arc { Arc::new(SeriesWrap(Clone::clone(&self.0))) } + fn as_any(&self) -> &dyn Any { &self.0 } - /// Get a hold to self as `Any` trait reference. - /// Only implemented for ObjectType fn as_any_mut(&mut self) -> &mut dyn Any { &mut self.0 } diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 9e9a5bb96b14..4214f645f381 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -439,6 +439,10 @@ macro_rules! impl_dyn_series { fn as_any(&self) -> &dyn Any { &self.0 } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } } }; } diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index 912a191ead43..d844e24aa8c1 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ b/crates/polars-core/src/series/implementations/null.rs @@ -316,9 +316,14 @@ impl SeriesTrait for NullChunked { fn clone_inner(&self) -> Arc { Arc::new(self.clone()) } + fn as_any(&self) -> &dyn Any { self } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } } unsafe impl IntoSeries for NullChunked { diff --git a/crates/polars-core/src/series/implementations/object.rs b/crates/polars-core/src/series/implementations/object.rs index 22afb23fec9e..b70ef3f074b6 100644 --- a/crates/polars-core/src/series/implementations/object.rs +++ b/crates/polars-core/src/series/implementations/object.rs @@ -237,6 +237,10 @@ where fn as_any(&self) -> &dyn Any { &self.0 } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } } #[cfg(test)] diff --git a/crates/polars-core/src/series/implementations/string.rs b/crates/polars-core/src/series/implementations/string.rs index 44c8d5522491..c98337af075d 100644 --- a/crates/polars-core/src/series/implementations/string.rs +++ b/crates/polars-core/src/series/implementations/string.rs @@ -261,4 +261,8 @@ impl SeriesTrait for SeriesWrap { fn as_any(&self) -> &dyn Any { &self.0 } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } } diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index 7cd943771351..969601c338c6 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -257,6 +257,10 @@ impl SeriesTrait for SeriesWrap { &self.0 } + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } + fn sort_with(&self, options: SortOptions) -> PolarsResult { Ok(self.0.sort_with(options).into_series()) } diff --git a/crates/polars-core/src/series/implementations/time.rs b/crates/polars-core/src/series/implementations/time.rs index 75f64e651e1a..247f48091f42 100644 --- a/crates/polars-core/src/series/implementations/time.rs +++ b/crates/polars-core/src/series/implementations/time.rs @@ -324,6 +324,10 @@ impl SeriesTrait for SeriesWrap { fn as_any(&self) -> &dyn Any { &self.0 } + + fn as_any_mut(&mut self) -> &mut dyn Any { + &mut self.0 + } } impl private::PrivateSeriesNumeric for SeriesWrap { diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 7ee9991c1f6b..f84183f7a4af 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -1006,55 +1006,41 @@ impl Default for Series { } } -fn equal_outer_type(dtype: &DataType) -> bool { - match (T::get_dtype(), dtype) { - (DataType::List(_), DataType::List(_)) => true, - #[cfg(feature = "dtype-array")] - (DataType::Array(_, _), DataType::Array(_, _)) => true, - #[cfg(feature = "dtype-struct")] - (DataType::Struct(_), DataType::Struct(_)) => true, - (a, b) => &a == b, - } -} - impl AsRef> for dyn SeriesTrait + '_ where - T: 'static + PolarsDataType, + T: 'static + PolarsDataType, { fn as_ref(&self) -> &ChunkedArray { - let dtype = self.dtype(); + // @NOTE: SeriesTrait `as_any` returns a std::any::Any for the underlying ChunkedArray / + // Logical (so not the SeriesWrap). + let Some(ca) = self.as_any().downcast_ref::>() else { + panic!( + "implementation error, cannot get ref {:?} from {:?}", + T::get_dtype(), + self.dtype() + ); + }; - #[cfg(feature = "dtype-decimal")] - if dtype.is_decimal() { - let logical = self.as_any().downcast_ref::().unwrap(); - let ca = logical.physical(); - return ca.as_any().downcast_ref::>().unwrap(); - } - let eq = equal_outer_type::(dtype); - assert!( - eq, - "implementation error, cannot get ref {:?} from {:?}", - T::get_dtype(), - self.dtype() - ); - // SAFETY: we just checked the type. - unsafe { &*(self as *const dyn SeriesTrait as *const ChunkedArray) } + ca } } impl AsMut> for dyn SeriesTrait + '_ where - T: 'static + PolarsDataType, + T: 'static + PolarsDataType, { fn as_mut(&mut self) -> &mut ChunkedArray { - let eq = equal_outer_type::(self.dtype()); - assert!( - eq, - "implementation error, cannot get ref {:?} from {:?}", - T::get_dtype(), - self.dtype() - ); - unsafe { &mut *(self as *mut dyn SeriesTrait as *mut ChunkedArray) } + if !self.as_any_mut().is::>() { + panic!( + "implementation error, cannot get ref {:?} from {:?}", + T::get_dtype(), + self.dtype() + ); + } + + // @NOTE: SeriesTrait `as_any` returns a std::any::Any for the underlying ChunkedArray / + // Logical (so not the SeriesWrap). + self.as_any_mut().downcast_mut::>().unwrap() } } diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 64f403e0393e..af89119feb18 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -563,14 +563,13 @@ pub trait SeriesTrait: invalid_operation_panic!(get_object_chunked_unchecked, self) } - /// Get a hold to self as `Any` trait reference. + /// Get a hold of the [`ChunkedArray`], [`Logical`] or `NullChunked` as an `Any` trait + /// reference. fn as_any(&self) -> &dyn Any; - /// Get a hold to self as `Any` trait reference. - /// Only implemented for ObjectType - fn as_any_mut(&mut self) -> &mut dyn Any { - invalid_operation_panic!(as_any_mut, self) - } + /// Get a hold of the [`ChunkedArray`], [`Logical`] or `NullChunked` as an `Any` trait mutable + /// reference. + fn as_any_mut(&mut self) -> &mut dyn Any; #[cfg(feature = "checked_arithmetic")] fn checked_div(&self, _rhs: &Series) -> PolarsResult { @@ -592,7 +591,7 @@ pub trait SeriesTrait: impl (dyn SeriesTrait + '_) { pub fn unpack(&self) -> PolarsResult<&ChunkedArray> where - N: 'static + PolarsDataType, + N: 'static + PolarsDataType, { polars_ensure!(&N::get_dtype() == self.dtype(), unpack); Ok(self.as_ref()) diff --git a/crates/polars-expr/src/reduce/mod.rs b/crates/polars-expr/src/reduce/mod.rs index 55c4db56ddbe..bf38d62a988b 100644 --- a/crates/polars-expr/src/reduce/mod.rs +++ b/crates/polars-expr/src/reduce/mod.rs @@ -94,7 +94,7 @@ pub trait GroupedReduction: Any + Send + Sync { // Helper traits used in the VecGroupedReduction and VecMaskGroupedReduction to // reduce code duplication. pub trait Reducer: Send + Sync + Clone + 'static { - type Dtype: PolarsDataType; + type Dtype: PolarsDataType; type Value: Clone + Send + Sync + 'static; fn init(&self) -> Self::Value; #[inline(always)]