Skip to content

Commit

Permalink
add bench vecotrizer
Browse files Browse the repository at this point in the history
  • Loading branch information
lucas-montes committed Dec 2, 2023
1 parent a6721c8 commit 4508d82
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 40 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@ poetry.lock
# Generated artifacts of website (with Zola)
docs/website/public/*
docs/website/static/rustdocs/

# Downloaded data for the linfa-preprocessing benches
20news/
4 changes: 4 additions & 0 deletions algorithms/linfa-preprocessing/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,7 @@ harness = false
[[bench]]
name = "whitening_bench"
harness = false

[[bench]]
name = "norm_scaler_bench"
harness = false
98 changes: 58 additions & 40 deletions algorithms/linfa-preprocessing/benches/vectorizer_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ use linfa_preprocessing::CountVectorizer;
use std::path::Path;
use tar::Archive;

use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use linfa::benchmarks::config;

fn download_20news_bydate() {
let mut data = Vec::new();
let mut easy = Easy::new();
Expand Down Expand Up @@ -77,95 +80,110 @@ fn load_test_set(desired_targets: &[&str]) -> Result<Vec<std::path::PathBuf>, st
load_set("./20news/20news-bydate-test", desired_targets)
}

fn iai_fit_vectorizer() {
let file_names = load_20news_bydate();
fn fit_vectorizer(file_names: &Vec<std::path::PathBuf>) {
CountVectorizer::params()
.document_frequency(0.05, 0.75)
.n_gram_range(1, 2)
.fit_files(
&file_names,
file_names,
encoding::all::ISO_8859_1,
encoding::DecoderTrap::Strict,
)
.unwrap();
}

fn iai_fit_tf_idf() {
let file_names = load_20news_bydate();
fn fit_tf_idf(file_names: &Vec<std::path::PathBuf>) {
TfIdfVectorizer::default()
.document_frequency(0.05, 0.75)
.n_gram_range(1, 2)
.fit_files(
&file_names,
file_names,
encoding::all::ISO_8859_1,
encoding::DecoderTrap::Strict,
)
.unwrap();
}

fn iai_fit_transform_vectorizer() {
let file_names = load_20news_bydate();
fn fit_transform_vectorizer(file_names: &Vec<std::path::PathBuf>) {
CountVectorizer::params()
.document_frequency(0.05, 0.75)
.n_gram_range(1, 2)
.fit_files(
&file_names,
file_names,
encoding::all::ISO_8859_1,
encoding::DecoderTrap::Strict,
)
.unwrap()
.transform_files(
&file_names,
file_names,
encoding::all::ISO_8859_1,
encoding::DecoderTrap::Strict,
);
}
fn iai_fit_transform_tf_idf() {
let file_names = load_20news_bydate();
fn fit_transform_tf_idf(file_names: &Vec<std::path::PathBuf>) {
TfIdfVectorizer::default()
.document_frequency(0.05, 0.75)
.n_gram_range(1, 2)
.fit_files(
&file_names,
file_names,
encoding::all::ISO_8859_1,
encoding::DecoderTrap::Strict,
)
.unwrap()
.transform_files(
&file_names,
file_names,
encoding::all::ISO_8859_1,
encoding::DecoderTrap::Strict,
);
}

macro_rules! iai_main {
( $( $func_name:ident ),+ $(,)* ) => {
mod iai_wrappers {
$(
pub fn $func_name() {
let _ = iai::black_box(super::$func_name());
}
)+
}
fn bench(c: &mut Criterion) {
let mut benchmark = c.benchmark_group("Linfa_preprocessing_vectorizer");
config::set_default_benchmark_configs(&mut benchmark);

fn main() {
load_20news_bydate();
let benchmarks : &[&(&'static str, fn())]= &[
let file_names = load_20news_bydate();

$(
&(stringify!($func_name), iai_wrappers::$func_name),
)+
];
benchmark.bench_function(
BenchmarkId::new("Fit-Vectorizer", "20news_bydate"),
|bencher| {
bencher.iter(|| {
fit_vectorizer(black_box(&file_names));
});
},
);

iai::runner(benchmarks);
std::fs::remove_dir_all("./20news").unwrap_or(());
}
}
benchmark.bench_function(BenchmarkId::new("Fit-Tf-Idf", "20news_bydate"), |bencher| {
bencher.iter(|| {
fit_tf_idf(black_box(&file_names));
});
});

benchmark.bench_function(
BenchmarkId::new("Fit-Transfor-Vectorizer", "20news_bydate"),
|bencher| {
bencher.iter(|| {
fit_transform_vectorizer(black_box(&file_names));
});
},
);

benchmark.bench_function(
BenchmarkId::new("Fit-Transfor-Tf-Idf", "20news_bydate"),
|bencher| {
bencher.iter(|| {
fit_transform_tf_idf(black_box(&file_names));
});
},
);
}

#[cfg(not(target_os = "windows"))]
criterion_group! {
name = benches;
config = config::get_default_profiling_configs();
targets = bench
}
#[cfg(target_os = "windows")]
criterion_group!(benches, bench);

iai_main!(
iai_fit_vectorizer,
iai_fit_transform_vectorizer,
iai_fit_tf_idf,
iai_fit_transform_tf_idf
);
criterion_main!(benches);

0 comments on commit 4508d82

Please sign in to comment.