Merge branch 'trie_compression'
mmoskal committed Dec 19, 2024
2 parents 0c8edea + 136013f commit a328a15
Showing 12 changed files with 388 additions and 208 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion json_stats/run.sh
@@ -7,7 +7,7 @@ else
fi

if [ -z "$PERF" ]; then
cargo run --release $DEFAULT_ARGS "$@"
cargo run --release -- $DEFAULT_ARGS "$@"
else
PERF='perf record -F 999 -g'
RUSTFLAGS='-C force-frame-pointers=y' cargo build --profile perf
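Editor's note (not part of the commit): the added "--" makes cargo stop parsing options at that point and forward everything after it to the json_stats binary itself, so the new flags introduced below (for example --csv or --tokenizer, assuming clap's usual kebab-case flag names) reach the program instead of being interpreted by cargo run.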
95 changes: 75 additions & 20 deletions json_stats/src/json_stats.rs
@@ -3,6 +3,7 @@ use clap::Parser;
use json_stats::SchemaStats;
use jsonschema::Validator;
use llguidance::{
earley::regexvec::LexerStats,
toktrie::{InferenceCapabilities, TokEnv},
Constraint, JsonCompileOptions, ParserFactory, TokenParser,
};
@@ -12,7 +13,7 @@ use std::{
collections::HashMap,
fs::File,
io::{Read, Write},
sync::Arc,
sync::{atomic::AtomicUsize, Arc},
};

use rayon::prelude::*;
@@ -42,6 +43,12 @@ pub struct CliOptions {
#[arg(long, short = 's')]
llg_slicer: bool,

#[arg(long)]
llg_no_forcing: bool,

#[arg(long)]
csv: bool,

#[arg(long)]
num_threads: Option<usize>,

@@ -54,6 +61,9 @@ pub struct CliOptions {
#[arg(long)]
additional_features: bool,

#[arg(long, default_value = "meta-llama/Llama-3.1-8B-Instruct")]
tokenizer: String,

// .json files or folders with .json files
#[arg(value_name = "FILES")]
files: Vec<String>,
@@ -68,12 +78,16 @@ struct LlgResult {
#[serde(skip_serializing_if = "is_zero")]
ttfm_us: usize,
#[serde(skip_serializing_if = "is_zero")]
max_ttfm_us: usize,
#[serde(skip_serializing_if = "is_zero")]
masks_us: usize,
#[serde(skip_serializing_if = "is_zero")]
max_mask_us: usize,
#[serde(skip_serializing_if = "is_zero")]
slicer_leftover_us: usize,

one: usize,

num_tokens: usize,
num_tests: usize,
num_valid_tests: usize,
@@ -85,6 +99,11 @@ struct LlgResult {
max_sum_parser_items: usize,
max_parser_items: usize,
max_lexer_cost: u64,
max_lexer_states: usize,
lexer_cost: u64,
trie_nodes_walked: usize,

lexer_stats: LexerStats,

#[serde(skip)]
slow_mask_count: [usize; MASK_STEPS],
@@ -212,8 +231,10 @@ impl TestEnv {

stats.num_tests += 1;

// println!("tokenized: {}", trie.tokens_dbg(&tokens));

for (tidx, &token) in tokens.iter().enumerate() {
//println!("WILL TEST {}: {}", tidx, trie.token_dbg(token));
// eprintln!("WILL TEST {}: {}", tidx, trie.token_dbg(token));

stats.num_tokens += 1;

@@ -223,17 +244,44 @@ impl TestEnv {
let us = t0.elapsed().as_micros() as usize;
let pstats = parser.last_step_stats();

// && pstats.lexer_cost < 7 * us as u64
if self.cli.csv && us > 1000 {
static CSV_LINE: AtomicUsize = AtomicUsize::new(0);
let line_no = CSV_LINE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if line_no == 0 {
println!("MASK,us,lexer_cost,slices,items,rows,cached_rows,trie_nodes,allowed_tokens,est_time");
}
println!(
"{},{},{},{},{},{},{},{},{},{}",
if us > 1000 { "SLOW" } else { "OK" },
us,
pstats.lexer_cost,
pstats.slices_applied,
pstats.all_items,
pstats.rows,
pstats.cached_rows,
pstats.trie_nodes_walked,
m.num_set(),
(pstats.trie_nodes_walked as u64 * 4 + pstats.lexer_cost * 60) / 1000
);
// eprintln!("{}", parser.parser.lexer_stats());

// eprintln!("{:?}", pstats);
}

stats.sum_parser_items += pstats.all_items;
stats.max_parser_items = std::cmp::max(stats.max_parser_items, pstats.all_items);
stats.max_lexer_cost = std::cmp::max(stats.max_lexer_cost, pstats.lexer_cost);
stats.lexer_cost += pstats.lexer_cost;
stats.trie_nodes_walked += pstats.trie_nodes_walked;

let step = us.next_power_of_two().trailing_zeros() as usize;
let step = std::cmp::min(step, MASK_STEPS - 1);

stats.slow_mask_count[step] += 1;
stats.slow_mask_us[step] += us;

assert!(pstats.slices_applied <= 1);
// assert!(pstats.slices_applied <= 1);

let is_big = m.num_set() >= 120_000;
let sliced = pstats.slices_applied > 0;
@@ -301,6 +349,9 @@ impl TestEnv {
let m = parser.parser.metrics_mut();
stats.slicer_leftover_us += m.slicer_leftover_us;

let lx = parser.parser.lexer_stats();
stats.max_lexer_states = std::cmp::max(stats.max_lexer_states, lx.num_states);

r
}

@@ -311,7 +362,7 @@ impl TestEnv {
let t0 = std::time::Instant::now();
let schema = opts.json_to_llg(test_file.schema.clone());

let schema = match schema {
let mut schema = match schema {
Ok(schema) => schema,
Err(e) => {
res.compile_error = Some(format!("{e}"));
@@ -321,13 +372,19 @@ impl TestEnv {
}
};

if self.cli.llg_no_forcing {
schema.grammars[0].no_forcing = true;
}

let parser = self.factory.create_parser(schema);

let parser = match parser {
Ok(parser) => {
let mut constraint = Constraint::new(parser.clone());
constraint.compute_mask().unwrap();
res.ttfm_us = t0.elapsed().as_micros() as usize;
res.max_ttfm_us = res.ttfm_us;
res.one = 1;
parser
// eprintln!("{} OK", file);
}
@@ -339,6 +396,8 @@ impl TestEnv {
}
};

res.lexer_stats = parser.parser.lexer_stats();

if self.cli.llg_test {
for (idx, t) in test_file.tests.iter().enumerate() {
let t0 = std::time::Instant::now();
@@ -523,18 +582,12 @@ fn main() {
files.retain(|f| !f.contains("Handwritten") && !f.contains("Synthesized"));
}

// "microsoft/Phi-3.5-mini-instruct"
let tok_env: TokEnv = toktrie_hf_tokenizers::ByteTokenizerEnv::from_name(
"meta-llama/Llama-3.1-8B-Instruct",
None,
)
.unwrap()
.to_env();

let mut slices = vec![
r#"[^"\\\x00-\x1F\x7F]{1,30}"#.to_string(),
// r#"[^"\\\x00-\x1F\x7F]+"#.to_string(),
];
let tok_env: TokEnv =
toktrie_hf_tokenizers::ByteTokenizerEnv::from_name(&options.tokenizer, None)
.unwrap()
.to_env();

let mut slices = llguidance::earley::SlicedBiasComputer::json_slices();
if !options.llg_slicer {
slices.clear();
}
@@ -552,6 +605,9 @@ fn main() {
factory.quiet();
let factory = Arc::new(factory);

save_text_to_file("tmp/slices.txt", &factory.slicer().stats(false));
save_text_to_file("tmp/slices_tokens.txt", &factory.slicer().stats(true));

let t0 = std::time::Instant::now();
let par = num_threads > 1;
let do_file = |file: &String| {
@@ -692,13 +748,12 @@ fn main() {
}
}

println!("{}", serde_json::to_string_pretty(&total).unwrap());
println!(
eprintln!("{}", serde_json::to_string_pretty(&total).unwrap());
eprintln!(
"LLG: {}",
serde_json::to_string_pretty(&llg_totals).unwrap()
);

println!("Total time: {}ms", t0.elapsed().as_millis());
eprintln!("Total time: {}ms", t0.elapsed().as_millis());

save_text_to_file("tmp/mask_histogram.csv", &histogram_csv);
save_json_to_file("tmp/test_total.json", &total);
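Editor's note (not part of the commit): the per-mask timings measured in the loop above are grouped into power-of-two buckets (slow_mask_count / slow_mask_us), which feed the histogram that main() saves to tmp/mask_histogram.csv. A minimal standalone sketch of that bucketing step, with the MASK_STEPS value assumed here for illustration (the real constant is defined elsewhere in json_stats.rs):

    // Sketch only: bucket a mask-computation time (in microseconds) the way the
    // loop above does. Bucket i collects times of roughly 2^i us; anything slower
    // than 2^(MASK_STEPS-1) us lands in the last bucket.
    const MASK_STEPS: usize = 16; // assumed value, for illustration only

    fn bucket(us: usize) -> usize {
        let step = us.next_power_of_two().trailing_zeros() as usize;
        step.min(MASK_STEPS - 1)
    }
    // e.g. bucket(1) == 0, bucket(1000) == 10, bucket(1_000_000) == MASK_STEPS - 1

The est_time column in the new CSV output is a similar rough estimate, folding trie_nodes_walked and lexer_cost into microseconds as (trie_nodes_walked * 4 + lexer_cost * 60) / 1000, per the code above.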
2 changes: 1 addition & 1 deletion parser/Cargo.toml
@@ -5,7 +5,7 @@ edition = "2021"

[dependencies]
toktrie = { workspace = true }
derivre = { git = "https://github.com/microsoft/derivre", rev = "5cd57842ed6d2b64156b5d44f4d2094097855e44" }
derivre = { git = "https://github.com/microsoft/derivre", rev = "a629ecc9dfbd9297bc3a35a5f2028f4bfce91060" }
serde = { version = "1.0.210", features = ["derive"] }
serde_json = { version = "1.0.132", features = ["preserve_order"] }
anyhow = "1.0.90"
2 changes: 1 addition & 1 deletion parser/llguidance.h
@@ -38,7 +38,7 @@ typedef struct LlgParserLimits {
size_t step_max_items;
/**
* Maximum number of lexer states.
* Default: 10_000
* Default: 50_000
*/
size_t max_lexer_states;
/**
4 changes: 2 additions & 2 deletions parser/src/api.rs
@@ -382,7 +382,7 @@ pub struct ParserLimits {
pub step_max_items: usize,

/// Maximum number of lexer states.
/// Default: 10_000
/// Default: 50_000
pub max_lexer_states: usize,

/// Maximum size of the grammar (symbols in productions)
@@ -396,7 +396,7 @@ impl Default for ParserLimits {
max_items_in_row: 2000,
initial_lexer_fuel: 1_000_000, // fhir schema => 500k
step_lexer_fuel: 200_000, //
max_lexer_states: 10_000, // ?
max_lexer_states: 50_000, // ?
max_grammar_size: 500_000, // fhir schema => 200k
step_max_items: 50_000, //
}
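Editor's note (not part of the commit): the default lexer-state budget rises from 10_000 to 50_000. Callers that build their own limits can still override just this field while keeping the other defaults; a minimal sketch, assuming only the ParserLimits fields and Default impl visible in the diff above:

    // Sketch only: raise the lexer-state cap, keep every other default.
    let limits = ParserLimits {
        max_lexer_states: 100_000,
        ..ParserLimits::default()
    };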
22 changes: 18 additions & 4 deletions parser/src/earley/parser.rs
@@ -12,7 +12,7 @@ use std::{
};

use anyhow::{bail, ensure, Result};
use derivre::{RegexAst, StateID};
use derivre::{AlphabetInfo, RegexAst, StateID};
use hashbrown::HashSet;
use instant::Instant;
use serde::{Deserialize, Serialize};
@@ -28,6 +28,7 @@ use super::{
grammar::{CGrammar, CSymIdx, CSymbol, RhsPtr},
lexer::{LexerResult, PreLexeme},
lexerspec::{Lexeme, LexemeIdx, LexerSpec},
regexvec::LexerStats,
};

const TRACE: bool = false;
@@ -72,6 +73,7 @@ pub struct ParserStats {
pub all_items: usize,
pub lexer_cost: u64,
pub slices_applied: usize,
pub trie_nodes_walked: usize,

pub definitive_bytes: usize,
pub lexer_ops: usize,
@@ -142,6 +144,9 @@ impl ParserStats {
.compute_time_us
.saturating_sub(previous.compute_time_us),
slices_applied: self.slices_applied.saturating_sub(previous.slices_applied),
trie_nodes_walked: self
.trie_nodes_walked
.saturating_sub(previous.trie_nodes_walked),
}
}

@@ -157,6 +162,7 @@ impl ParserStats {
lexer_cost: self.lexer_cost.max(other.lexer_cost),
compute_time_us: self.compute_time_us.max(other.compute_time_us),
slices_applied: self.slices_applied.max(other.slices_applied),
trie_nodes_walked: self.trie_nodes_walked.max(other.trie_nodes_walked),
}
}
}
@@ -636,8 +642,6 @@ impl ParserState {
assert!(toks.len() == 1);
set.disallow_token(toks[0]);

computer.trie().apply_duplicates(&mut set);

if set.is_zero() {
// nothing allowed
// we're going to be stopped outside - we better flush the lexer
@@ -2106,6 +2110,10 @@ impl<'a> Recognizer for ParserRecognizer<'a> {

r
}

fn save_stats(&mut self, nodes_walked: usize) {
self.state.stats.trie_nodes_walked += nodes_walked;
}
}

fn item_to_string(g: &CGrammar, item: &Item) -> String {
@@ -2170,10 +2178,16 @@ impl Parser {
self.state.hidden_start(shared.lexer_mut())
}

pub fn lexer_stats(&self) -> String {
pub fn lexer_stats(&self) -> LexerStats {
self.shared.lock().unwrap().lexer().dfa.stats()
}

pub fn with_alphabet_info<T>(&self, f: impl FnOnce(&AlphabetInfo) -> T) -> T {
let a = self.shared.lock().unwrap();
let a = a.lexer().dfa.alpha();
f(a)
}

pub fn get_error(&self) -> Option<ParserError> {
let shared = self.shared.lock().unwrap();
if let Some(e) = shared.lexer().dfa.get_error() {
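Editor's note (not part of the commit): Parser::lexer_stats() now returns the structured LexerStats instead of a preformatted String, so callers that previously printed the string need to read the fields they care about. A minimal sketch of the new call site, assuming a parser handle like the parser.parser used in json_stats.rs above and only the num_states field that this commit itself reads:

    // Sketch only: consume the structured lexer statistics.
    let lx: LexerStats = parser.lexer_stats();
    eprintln!("lexer DFA states: {}", lx.num_states);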