From 4508d82d3a6e306c76828a74da0b2e564ec28c72 Mon Sep 17 00:00:00 2001 From: Lucas Date: Sat, 2 Dec 2023 20:07:41 +0100 Subject: [PATCH] add bench vectorizer --- .gitignore | 3 + algorithms/linfa-preprocessing/Cargo.toml | 4 + .../benches/vectorizer_bench.rs | 98 +++++++++++-------- 3 files changed, 65 insertions(+), 40 deletions(-) diff --git a/.gitignore b/.gitignore index ba8a9bdd8..e32d1694a 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,6 @@ poetry.lock # Generated artifacts of website (with Zola) docs/website/public/* docs/website/static/rustdocs/ + +# Downloaded data for the linfa-preprocessing benches +20news/ \ No newline at end of file diff --git a/algorithms/linfa-preprocessing/Cargo.toml b/algorithms/linfa-preprocessing/Cargo.toml index 7f39de194..c2e6aa152 100644 --- a/algorithms/linfa-preprocessing/Cargo.toml +++ b/algorithms/linfa-preprocessing/Cargo.toml @@ -62,3 +62,7 @@ harness = false [[bench]] name = "whitening_bench" harness = false + +[[bench]] +name = "norm_scaler_bench" +harness = false \ No newline at end of file diff --git a/algorithms/linfa-preprocessing/benches/vectorizer_bench.rs b/algorithms/linfa-preprocessing/benches/vectorizer_bench.rs index 8dfe6eba0..0debbbafd 100644 --- a/algorithms/linfa-preprocessing/benches/vectorizer_bench.rs +++ b/algorithms/linfa-preprocessing/benches/vectorizer_bench.rs @@ -5,6 +5,9 @@ use linfa_preprocessing::CountVectorizer; use std::path::Path; use tar::Archive; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use linfa::benchmarks::config; + fn download_20news_bydate() { let mut data = Vec::new(); let mut easy = Easy::new(); @@ -77,95 +80,110 @@ fn load_test_set(desired_targets: &[&str]) -> Result, st load_set("./20news/20news-bydate-test", desired_targets) } -fn iai_fit_vectorizer() { - let file_names = load_20news_bydate(); +fn fit_vectorizer(file_names: &Vec) { CountVectorizer::params() .document_frequency(0.05, 0.75) .n_gram_range(1, 2) .fit_files( - 
&file_names, + file_names, encoding::all::ISO_8859_1, encoding::DecoderTrap::Strict, ) .unwrap(); } -fn iai_fit_tf_idf() { - let file_names = load_20news_bydate(); +fn fit_tf_idf(file_names: &Vec) { TfIdfVectorizer::default() .document_frequency(0.05, 0.75) .n_gram_range(1, 2) .fit_files( - &file_names, + file_names, encoding::all::ISO_8859_1, encoding::DecoderTrap::Strict, ) .unwrap(); } -fn iai_fit_transform_vectorizer() { - let file_names = load_20news_bydate(); +fn fit_transform_vectorizer(file_names: &Vec) { CountVectorizer::params() .document_frequency(0.05, 0.75) .n_gram_range(1, 2) .fit_files( - &file_names, + file_names, encoding::all::ISO_8859_1, encoding::DecoderTrap::Strict, ) .unwrap() .transform_files( - &file_names, + file_names, encoding::all::ISO_8859_1, encoding::DecoderTrap::Strict, ); } -fn iai_fit_transform_tf_idf() { - let file_names = load_20news_bydate(); +fn fit_transform_tf_idf(file_names: &Vec) { TfIdfVectorizer::default() .document_frequency(0.05, 0.75) .n_gram_range(1, 2) .fit_files( - &file_names, + file_names, encoding::all::ISO_8859_1, encoding::DecoderTrap::Strict, ) .unwrap() .transform_files( - &file_names, + file_names, encoding::all::ISO_8859_1, encoding::DecoderTrap::Strict, ); } -macro_rules! 
iai_main { - ( $( $func_name:ident ),+ $(,)* ) => { - mod iai_wrappers { - $( - pub fn $func_name() { - let _ = iai::black_box(super::$func_name()); - } - )+ - } +fn bench(c: &mut Criterion) { + let mut benchmark = c.benchmark_group("Linfa_preprocessing_vectorizer"); + config::set_default_benchmark_configs(&mut benchmark); - fn main() { - load_20news_bydate(); - let benchmarks : &[&(&'static str, fn())]= &[ + let file_names = load_20news_bydate(); - $( - &(stringify!($func_name), iai_wrappers::$func_name), - )+ - ]; + benchmark.bench_function( + BenchmarkId::new("Fit-Vectorizer", "20news_bydate"), + |bencher| { + bencher.iter(|| { + fit_vectorizer(black_box(&file_names)); + }); + }, + ); - iai::runner(benchmarks); - std::fs::remove_dir_all("./20news").unwrap_or(()); - } - } + benchmark.bench_function(BenchmarkId::new("Fit-Tf-Idf", "20news_bydate"), |bencher| { + bencher.iter(|| { + fit_tf_idf(black_box(&file_names)); + }); + }); + + benchmark.bench_function( + BenchmarkId::new("Fit-Transfor-Vectorizer", "20news_bydate"), + |bencher| { + bencher.iter(|| { + fit_transform_vectorizer(black_box(&file_names)); + }); + }, + ); + + benchmark.bench_function( + BenchmarkId::new("Fit-Transfor-Tf-Idf", "20news_bydate"), + |bencher| { + bencher.iter(|| { + fit_transform_tf_idf(black_box(&file_names)); + }); + }, + ); +} + +#[cfg(not(target_os = "windows"))] +criterion_group! { + name = benches; + config = config::get_default_profiling_configs(); + targets = bench } +#[cfg(target_os = "windows")] +criterion_group!(benches, bench); -iai_main!( - iai_fit_vectorizer, - iai_fit_transform_vectorizer, - iai_fit_tf_idf, - iai_fit_transform_tf_idf -); +criterion_main!(benches);