Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

experiment in compressing masks #99

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions json_stats/src/json_stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ struct LlgResult {
max_mask_us: usize,
#[serde(skip_serializing_if = "is_zero")]
slicer_leftover_us: usize,
#[serde(skip_serializing_if = "is_zero")]
compressed_mask_size: usize,

one: usize,

Expand Down Expand Up @@ -382,6 +384,7 @@ impl TestEnv {

let m = parser.parser.metrics_mut();
stats.slicer_leftover_us += m.slicer_leftover_us;
stats.compressed_mask_size += m.compressed_mask_size;

let lx = parser.parser.lexer_stats();
stats.max_lexer_states = std::cmp::max(stats.max_lexer_states, lx.num_states);
Expand Down
1 change: 1 addition & 0 deletions parser/src/earley/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ pub struct ParserMetrics {
pub rand: XorShift,
pub message: String,
pub slicer_leftover_us: usize,
pub compressed_mask_size: usize,
}

impl ParserStats {
Expand Down
55 changes: 53 additions & 2 deletions parser/src/earley/slicer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
toktrie::{SimpleVob, TokEnv, TokTrie, TokenId},
};

use super::parser::ITEM_TRACE;
use super::{parser::ITEM_TRACE, ParserMetrics};

struct TokenizerSlice {
idx: usize,
Expand Down Expand Up @@ -187,6 +187,7 @@ impl BiasComputer for SlicedBiasComputer {
fn compute_bias<'b>(&self, rec: &mut ParserRecognizer<'b>, start: &[u8]) -> SimpleVob {
let mut set = self.trie().alloc_token_set();
let lexer_state = rec.lexer_state();

if self.slices.len() > 0
&& start.is_empty()
&& rec.lexer_mut().subsume_possible(lexer_state)
Expand All @@ -208,9 +209,11 @@ impl BiasComputer for SlicedBiasComputer {
if slice_matches.iter().all(|&x| x == false) {
// if nothing matches, just run the full trie
self.wildcard_slice.add_bias(rec, &mut set, start);
apply_metrics(rec.metrics_mut(), &set);
debug!("no slice matches; {} tokens", set.num_set());
} else {
// otherwise, apply the matching slices, and compute the rest
let mut acc = self.trie().alloc_token_set();
for (i, slice) in self.slices.iter().enumerate() {
if slice_matches[i] {
rec.stats_mut().slices_applied += 1;
Expand All @@ -219,7 +222,8 @@ impl BiasComputer for SlicedBiasComputer {
// assert!(slice.regex == "");
let c0 = if DEBUG { set.num_set() } else { 0 };
let t0 = std::time::Instant::now();
slice.trie.add_bias(rec, &mut set, start);
slice.trie.add_bias(rec, &mut acc, start);
set.or(&acc);
let us = t0.elapsed().as_micros() as usize;
rec.metrics_mut().slicer_leftover_us += us;
debug!("slice matches #{}; {} tokens", i, set.num_set() - c0);
Expand All @@ -234,9 +238,11 @@ impl BiasComputer for SlicedBiasComputer {
// }
}
}
apply_metrics(rec.metrics_mut(), &acc);
}
} else {
self.wildcard_slice.add_bias(rec, &mut set, start);
apply_metrics(rec.metrics_mut(), &set);
debug!("slicer disabled; {} tokens", set.num_set());
}

Expand All @@ -249,3 +255,48 @@ impl BiasComputer for SlicedBiasComputer {
self.tok_env.tok_trie()
}
}

/// Record an estimate of the compressed size of `mask` in the parser metrics.
///
/// Rather than running the (currently unused) `compress_mask` encoder, this
/// approximates the encoded size: a sparse mask costs roughly two bytes per
/// set token, while a dense mask is bounded by the raw bitmap size (one byte
/// per 8 tokens) — whichever is smaller.
fn apply_metrics(parser_metrics: &mut ParserMetrics, mask: &SimpleVob) {
    let sparse_estimate = mask.num_set() * 2;
    let bitmap_bytes = mask.len() / 8;
    parser_metrics.compressed_mask_size += sparse_estimate.min(bitmap_bytes);
}

/// Run-length + sparse encoding of a token mask, used to measure how small
/// a mask could be made (only its `.len()` matters to callers today).
///
/// Encoding, per 32-bit word of the mask:
/// * a run of up to 32 all-zero words folds into one byte `64 + run_len`
///   (65..=96), flushed before the next non-zero word and at end of input;
/// * 1..=3 set bits: one byte per set-bit position (0..=31, counted from
///   the low end via `trailing_zeros`);
/// * all 32 bits set: the single marker byte 60;
/// * anything else: the marker byte 61 followed by the word's 4 bytes,
///   little-endian.
///
/// The ranges are disjoint (positions 0..=31, markers 60/61, runs 65..=96),
/// so the stream is unambiguously decodable. Fixes in this revision:
/// multi-bit words previously pushed the same `leading_zeros` byte 2–3
/// times (losing bits), zero-run bytes (33..=64) collided with markers
/// 60/61, runs longer than 32 emitted a 5-byte literal for the 33rd zero
/// word, and a trailing zero run was dropped.
///
/// NOTE(review): kept private and currently unused — `apply_metrics`
/// approximates the size instead of calling this.
#[allow(dead_code)]
fn compress_mask(s: &SimpleVob) -> Vec<u8> {
    let mut res: Vec<u8> = vec![];
    let mut num_zero: u8 = 0;
    for &d in s.as_slice() {
        if d == 0 {
            if num_zero == 32 {
                // run is full: flush it and start a new one
                res.push(num_zero + 64);
                num_zero = 0;
            }
            num_zero += 1;
            continue;
        }
        if num_zero > 0 {
            res.push(num_zero + 64);
            num_zero = 0;
        }
        let num_bits = d.count_ones();
        if num_bits <= 3 {
            // sparse word: emit each set bit's position individually
            let mut rem = d;
            while rem != 0 {
                res.push(rem.trailing_zeros() as u8);
                rem &= rem - 1; // clear the lowest set bit
            }
        } else if num_bits == 32 {
            res.push(60);
        } else {
            // dense word: emit it literally, little-endian
            res.push(61);
            res.extend_from_slice(&d.to_le_bytes());
        }
    }
    if num_zero > 0 {
        // flush a trailing zero run so the mask length is preserved
        res.push(num_zero + 64);
    }
    res
}
Loading