-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: Fix performance regression for DataFrame serialization/pickling #20641
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ use crate::{HEAD_DEFAULT_LENGTH, TAIL_DEFAULT_LENGTH}; | |
#[cfg(feature = "dataframe_arithmetic")] | ||
mod arithmetic; | ||
mod chunks; | ||
pub use chunks::chunk_df_for_writing; | ||
pub mod column; | ||
pub mod explode; | ||
mod from; | ||
|
@@ -3578,41 +3579,4 @@ mod test { | |
assert_eq!(df.get_column_names(), &["a", "b", "c"]); | ||
Ok(()) | ||
} | ||
|
||
#[cfg(feature = "serde")] | ||
#[test] | ||
fn test_deserialize_height_validation_8751() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test no longer works as serialization directly errors now on mismatching height |
||
// Construct an invalid directly from the inner fields as the `new_unchecked_*` functions | ||
// have debug assertions | ||
|
||
use polars_utils::pl_serialize; | ||
|
||
let df = DataFrame { | ||
height: 2, | ||
columns: vec![ | ||
Int64Chunked::full("a".into(), 1, 2).into_column(), | ||
Int64Chunked::full("b".into(), 1, 1).into_column(), | ||
], | ||
cached_schema: OnceLock::new(), | ||
}; | ||
|
||
// We rely on the fact that the serialization doesn't check the heights of all columns | ||
let serialized = serde_json::to_string(&df).unwrap(); | ||
let err = serde_json::from_str::<DataFrame>(&serialized).unwrap_err(); | ||
|
||
assert!(err.to_string().contains( | ||
"successful parse invalid data: lengths don't match: could not create a new DataFrame:", | ||
)); | ||
|
||
let serialized = pl_serialize::SerializeOptions::default() | ||
.serialize_to_bytes(&df) | ||
.unwrap(); | ||
let err = pl_serialize::SerializeOptions::default() | ||
.deserialize_from_reader::<DataFrame, _>(serialized.as_slice()) | ||
.unwrap_err(); | ||
|
||
assert!(err.to_string().contains( | ||
"successful parse invalid data: lengths don't match: could not create a new DataFrame:", | ||
)); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,144 @@ | ||
use polars_error::PolarsError; | ||
use std::sync::Arc; | ||
|
||
use arrow::datatypes::Metadata; | ||
use arrow::io::ipc::read::{read_stream_metadata, StreamReader, StreamState}; | ||
use arrow::io::ipc::write::WriteOptions; | ||
use polars_error::{polars_err, to_compute_err, PolarsResult}; | ||
use polars_utils::format_pl_smallstr; | ||
use polars_utils::pl_serialize::deserialize_map_bytes; | ||
use polars_utils::pl_str::PlSmallStr; | ||
use serde::de::Error; | ||
use serde::*; | ||
|
||
use crate::prelude::{Column, DataFrame}; | ||
|
||
// utility to ensure we serde to a struct | ||
// { | ||
// columns: Vec<Series> | ||
// } | ||
// that ensures it differentiates between Vec<Series> | ||
// and is backwards compatible | ||
#[derive(Deserialize)] | ||
struct Util { | ||
columns: Vec<Column>, | ||
} | ||
use crate::chunked_array::flags::StatisticsFlags; | ||
use crate::config; | ||
use crate::frame::chunk_df_for_writing; | ||
use crate::prelude::{CompatLevel, DataFrame, SchemaExt}; | ||
use crate::utils::accumulate_dataframes_vertical_unchecked; | ||
|
||
#[derive(Serialize)] | ||
struct UtilBorrowed<'a> { | ||
columns: &'a [Column], | ||
} | ||
const FLAGS_KEY: PlSmallStr = PlSmallStr::from_static("_PL_FLAGS"); | ||
|
||
impl<'de> Deserialize<'de> for DataFrame { | ||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||
where | ||
D: Deserializer<'de>, | ||
{ | ||
let parsed = <Util>::deserialize(deserializer)?; | ||
DataFrame::new(parsed.columns).map_err(|e| { | ||
let e = PolarsError::ComputeError(format!("successful parse invalid data: {e}").into()); | ||
D::Error::custom::<PolarsError>(e) | ||
}) | ||
impl DataFrame { | ||
pub fn serialize_into_writer(&mut self, writer: &mut dyn std::io::Write) -> PolarsResult<()> { | ||
let schema = self.schema(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Serialization logic moved to |
||
|
||
if schema.iter_values().any(|x| x.is_object()) { | ||
return Err(polars_err!( | ||
ComputeError: | ||
"serializing data of type Object is not supported", | ||
)); | ||
} | ||
|
||
let mut ipc_writer = | ||
arrow::io::ipc::write::StreamWriter::new(writer, WriteOptions { compression: None }); | ||
|
||
ipc_writer.set_custom_schema_metadata(Arc::new(Metadata::from_iter( | ||
self.get_columns().iter().map(|c| { | ||
( | ||
format_pl_smallstr!("{}{}", FLAGS_KEY, c.name()), | ||
PlSmallStr::from(c.get_flags().bits().to_string()), | ||
) | ||
}), | ||
))); | ||
|
||
ipc_writer.set_custom_schema_metadata(Arc::new(Metadata::from([( | ||
FLAGS_KEY, | ||
serde_json::to_string( | ||
&self | ||
.iter() | ||
.map(|s| s.get_flags().bits()) | ||
.collect::<Vec<u32>>(), | ||
) | ||
.map_err(to_compute_err)? | ||
.into(), | ||
)]))); | ||
|
||
ipc_writer.start(&schema.to_arrow(CompatLevel::newest()), None)?; | ||
|
||
for batch in chunk_df_for_writing(self, 512 * 512)?.iter_chunks(CompatLevel::newest(), true) | ||
{ | ||
ipc_writer.write(&batch, None)?; | ||
} | ||
|
||
ipc_writer.finish()?; | ||
|
||
Ok(()) | ||
} | ||
|
||
pub fn serialize_to_bytes(&mut self) -> PolarsResult<Vec<u8>> { | ||
let mut buf = vec![]; | ||
self.serialize_into_writer(&mut buf)?; | ||
|
||
Ok(buf) | ||
} | ||
|
||
pub fn deserialize_from_reader(reader: &mut dyn std::io::Read) -> PolarsResult<Self> { | ||
let mut md = read_stream_metadata(reader)?; | ||
let arrow_schema = md.schema.clone(); | ||
|
||
let custom_metadata = md.custom_schema_metadata.take(); | ||
|
||
let reader = StreamReader::new(reader, md, None); | ||
let dfs = reader | ||
.into_iter() | ||
.map_while(|batch| match batch { | ||
Ok(StreamState::Some(batch)) => Some(DataFrame::try_from((batch, &arrow_schema))), | ||
Ok(StreamState::Waiting) => None, | ||
Err(e) => Some(Err(e)), | ||
}) | ||
.collect::<PolarsResult<Vec<DataFrame>>>()?; | ||
|
||
let mut df = accumulate_dataframes_vertical_unchecked(dfs); | ||
|
||
// Set custom metadata (fallible) | ||
(|| { | ||
let custom_metadata = custom_metadata?; | ||
let flags = custom_metadata.get(&FLAGS_KEY)?; | ||
|
||
let flags: PolarsResult<Vec<u32>> = serde_json::from_str(flags).map_err(to_compute_err); | ||
|
||
let verbose = config::verbose(); | ||
|
||
if let Err(e) = &flags { | ||
if verbose { | ||
eprintln!("DataFrame::read_ipc: Error parsing metadata flags: {}", e); | ||
} | ||
} | ||
|
||
let flags = flags.ok()?; | ||
|
||
if flags.len() != df.width() { | ||
if verbose { | ||
eprintln!( | ||
"DataFrame::read_ipc: Metadata flags width mismatch: {} != {}", | ||
flags.len(), | ||
df.width() | ||
); | ||
} | ||
|
||
return None; | ||
} | ||
|
||
let mut n_set = 0; | ||
|
||
for (c, v) in unsafe { df.get_columns_mut() }.iter_mut().zip(flags) { | ||
if let Some(flags) = StatisticsFlags::from_bits(v) { | ||
n_set += c.set_flags(flags) as usize; | ||
} | ||
} | ||
|
||
if verbose { | ||
eprintln!( | ||
"DataFrame::read_ipc: Loaded metadata for {} / {} columns", | ||
n_set, | ||
df.width() | ||
); | ||
} | ||
|
||
Some(()) | ||
})(); | ||
|
||
Ok(df) | ||
} | ||
} | ||
|
||
|
@@ -38,9 +147,26 @@ impl Serialize for DataFrame { | |
where | ||
S: Serializer, | ||
{ | ||
UtilBorrowed { | ||
columns: &self.columns, | ||
} | ||
.serialize(serializer) | ||
use serde::ser::Error; | ||
|
||
let mut bytes = vec![]; | ||
self.clone() | ||
.serialize_into_writer(&mut bytes) | ||
.map_err(S::Error::custom)?; | ||
|
||
serializer.serialize_bytes(bytes.as_slice()) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is where the extra memcopy happens when going through serde - we are calling |
||
} | ||
|
||
impl<'de> Deserialize<'de> for DataFrame { | ||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||
where | ||
D: Deserializer<'de>, | ||
{ | ||
deserialize_map_bytes(deserializer, &mut |b| { | ||
let v = &mut b.as_ref(); | ||
Self::deserialize_from_reader(v) | ||
})? | ||
.map_err(D::Error::custom) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code is moved from
polars-io