diff --git a/Cargo.lock b/Cargo.lock
index d81cc42..df647c4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -435,7 +435,7 @@ dependencies = [
 [[package]]
 name = "derivre"
 version = "0.1.0"
-source = "git+https://github.com/microsoft/derivre?rev=5cd57842ed6d2b64156b5d44f4d2094097855e44#5cd57842ed6d2b64156b5d44f4d2094097855e44"
+source = "git+https://github.com/microsoft/derivre?rev=a629ecc9dfbd9297bc3a35a5f2028f4bfce91060#a629ecc9dfbd9297bc3a35a5f2028f4bfce91060"
 dependencies = [
  "ahash",
  "anyhow",
diff --git a/json_stats/run.sh b/json_stats/run.sh
index 421937c..c5d36db 100755
--- a/json_stats/run.sh
+++ b/json_stats/run.sh
@@ -7,7 +7,7 @@ else
 fi
 
 if [ -z "$PERF" ]; then
-    cargo run --release $DEFAULT_ARGS "$@"
+    cargo run --release -- $DEFAULT_ARGS "$@"
 else
     PERF='perf record -F 999 -g'
     RUSTFLAGS='-C force-frame-pointers=y' cargo build --profile perf
diff --git a/json_stats/src/json_stats.rs b/json_stats/src/json_stats.rs
index 40925b1..1f5c9fe 100644
--- a/json_stats/src/json_stats.rs
+++ b/json_stats/src/json_stats.rs
@@ -3,6 +3,7 @@ use clap::Parser;
 use json_stats::SchemaStats;
 use jsonschema::Validator;
 use llguidance::{
+    earley::regexvec::LexerStats,
     toktrie::{InferenceCapabilities, TokEnv},
     Constraint, JsonCompileOptions, ParserFactory, TokenParser,
 };
@@ -12,7 +13,7 @@ use std::{
     collections::HashMap,
     fs::File,
     io::{Read, Write},
-    sync::Arc,
+    sync::{atomic::AtomicUsize, Arc},
 };
 
 use rayon::prelude::*;
@@ -42,6 +43,12 @@ pub struct CliOptions {
     #[arg(long, short = 's')]
     llg_slicer: bool,
 
+    #[arg(long)]
+    llg_no_forcing: bool,
+
+    #[arg(long)]
+    csv: bool,
+
     #[arg(long)]
     num_threads: Option<usize>,
 
@@ -54,6 +61,9 @@ pub struct CliOptions {
     #[arg(long)]
     additional_features: bool,
 
+    #[arg(long, default_value = "meta-llama/Llama-3.1-8B-Instruct")]
+    tokenizer: String,
+
     // .json files or folders with .json files
     #[arg(value_name = "FILES")]
     files: Vec<String>,
@@ -68,12 +78,16 @@ struct LlgResult {
     #[serde(skip_serializing_if = "is_zero")]
     ttfm_us: usize,
     #[serde(skip_serializing_if = "is_zero")]
+    max_ttfm_us: usize,
+    #[serde(skip_serializing_if = "is_zero")]
     masks_us: usize,
     #[serde(skip_serializing_if = "is_zero")]
     max_mask_us: usize,
     #[serde(skip_serializing_if = "is_zero")]
     slicer_leftover_us: usize,
 
+    one: usize,
+
     num_tokens: usize,
     num_tests: usize,
     num_valid_tests: usize,
@@ -85,6 +99,11 @@ struct LlgResult {
     max_sum_parser_items: usize,
     max_parser_items: usize,
     max_lexer_cost: u64,
+    max_lexer_states: usize,
+    lexer_cost: u64,
+    trie_nodes_walked: usize,
+
+    lexer_stats: LexerStats,
 
     #[serde(skip)]
     slow_mask_count: [usize; MASK_STEPS],
@@ -212,8 +231,10 @@ impl TestEnv {
 
         stats.num_tests += 1;
 
+        // println!("tokenized: {}", trie.tokens_dbg(&tokens));
+
         for (tidx, &token) in tokens.iter().enumerate() {
-            //println!("WILL TEST {}: {}", tidx, trie.token_dbg(token));
+            // eprintln!("WILL TEST {}: {}", tidx, trie.token_dbg(token));
 
             stats.num_tokens += 1;
 
@@ -223,9 +244,36 @@ impl TestEnv {
                 let us = t0.elapsed().as_micros() as usize;
                 let pstats = parser.last_step_stats();
 
+                // && pstats.lexer_cost < 7 * us as u64
+                if self.cli.csv && us > 1000 {
+                    static CSV_LINE: AtomicUsize = AtomicUsize::new(0);
+                    let line_no = CSV_LINE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
+                    if line_no == 0 {
+                        println!("MASK,us,lexer_cost,slices,items,rows,cached_rows,trie_nodes,allowed_tokens,est_time");
+                    }
+                    println!(
+                        "{},{},{},{},{},{},{},{},{},{}",
+                        if us > 1000 { "SLOW" } else { "OK" },
+                        us,
+                        pstats.lexer_cost,
+                        pstats.slices_applied,
+                        pstats.all_items,
+                        pstats.rows,
+                        pstats.cached_rows,
+                        pstats.trie_nodes_walked,
+                        m.num_set(),
+                        (pstats.trie_nodes_walked as u64 * 4 + pstats.lexer_cost * 60) / 1000
+                    );
+                    // eprintln!("{}", parser.parser.lexer_stats());
+
+                    // eprintln!("{:?}", pstats);
+                }
+
                 stats.sum_parser_items += pstats.all_items;
                 stats.max_parser_items = std::cmp::max(stats.max_parser_items, pstats.all_items);
                 stats.max_lexer_cost = std::cmp::max(stats.max_lexer_cost, pstats.lexer_cost);
+                stats.lexer_cost += pstats.lexer_cost;
+                stats.trie_nodes_walked += pstats.trie_nodes_walked;
 
                 let step = us.next_power_of_two().trailing_zeros() as usize;
                 let step = std::cmp::min(step, MASK_STEPS - 1);
@@ -233,7 +281,7 @@ impl TestEnv {
                 stats.slow_mask_count[step] += 1;
                 stats.slow_mask_us[step] += us;
 
-                assert!(pstats.slices_applied <= 1);
+                // assert!(pstats.slices_applied <= 1);
 
                 let is_big = m.num_set() >= 120_000;
                 let sliced = pstats.slices_applied > 0;
@@ -301,6 +349,9 @@ impl TestEnv {
         let m = parser.parser.metrics_mut();
         stats.slicer_leftover_us += m.slicer_leftover_us;
 
+        let lx = parser.parser.lexer_stats();
+        stats.max_lexer_states = std::cmp::max(stats.max_lexer_states, lx.num_states);
+
         r
     }
@@ -311,7 +362,7 @@ impl TestEnv {
         let t0 = std::time::Instant::now();
         let schema = opts.json_to_llg(test_file.schema.clone());
 
-        let schema = match schema {
+        let mut schema = match schema {
             Ok(schema) => schema,
             Err(e) => {
                 res.compile_error = Some(format!("{e}"));
@@ -321,6 +372,10 @@ impl TestEnv {
             }
         };
 
+        if self.cli.llg_no_forcing {
+            schema.grammars[0].no_forcing = true;
+        }
+
         let parser = self.factory.create_parser(schema);
 
         let parser = match parser {
@@ -328,6 +383,8 @@ impl TestEnv {
                 let mut constraint = Constraint::new(parser.clone());
                 constraint.compute_mask().unwrap();
                 res.ttfm_us = t0.elapsed().as_micros() as usize;
+                res.max_ttfm_us = res.ttfm_us;
+                res.one = 1;
                 parser
                 // eprintln!("{} OK", file);
             }
@@ -339,6 +396,8 @@ impl TestEnv {
             }
         };
 
+        res.lexer_stats = parser.parser.lexer_stats();
+
         if self.cli.llg_test {
             for (idx, t) in test_file.tests.iter().enumerate() {
                 let t0 = std::time::Instant::now();
@@ -523,18 +582,12 @@ fn main() {
         files.retain(|f| !f.contains("Handwritten") && !f.contains("Synthesized"));
     }
 
-    // "microsoft/Phi-3.5-mini-instruct"
-    let tok_env: TokEnv = toktrie_hf_tokenizers::ByteTokenizerEnv::from_name(
-        "meta-llama/Llama-3.1-8B-Instruct",
-        None,
-    )
-    .unwrap()
-    .to_env();
-
-    let mut slices = vec![
-        r#"[^"\\\x00-\x1F\x7F]{1,30}"#.to_string(),
-        // r#"[^"\\\x00-\x1F\x7F]+"#.to_string(),
-    ];
+    let tok_env: TokEnv =
+        toktrie_hf_tokenizers::ByteTokenizerEnv::from_name(&options.tokenizer, None)
+            .unwrap()
+            .to_env();
+
+    let mut slices = llguidance::earley::SlicedBiasComputer::json_slices();
     if !options.llg_slicer {
         slices.clear();
     }
@@ -552,6 +605,9 @@ fn main() {
     factory.quiet();
     let factory = Arc::new(factory);
 
+    save_text_to_file("tmp/slices.txt", &factory.slicer().stats(false));
+    save_text_to_file("tmp/slices_tokens.txt", &factory.slicer().stats(true));
+
     let t0 = std::time::Instant::now();
     let par = num_threads > 1;
     let do_file = |file: &String| {
@@ -692,13 +748,12 @@ fn main() {
         }
     }
 
-    println!("{}", serde_json::to_string_pretty(&total).unwrap());
-    println!(
+    eprintln!("{}", serde_json::to_string_pretty(&total).unwrap());
+    eprintln!(
         "LLG: {}",
         serde_json::to_string_pretty(&llg_totals).unwrap()
    );
-
-    println!("Total time: {}ms", t0.elapsed().as_millis());
+    eprintln!("Total time: {}ms", t0.elapsed().as_millis());
 
     save_text_to_file("tmp/mask_histogram.csv", &histogram_csv);
     save_json_to_file("tmp/test_total.json", &total);
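The `est_time` column in the CSV output above is computed as `(trie_nodes_walked * 4 + lexer_cost * 60) / 1000`. A minimal sketch of that cost model, assuming the constants are rough per-trie-node and per-lexer-cost-unit costs in nanoseconds (the helper name below is illustrative and not part of the patch):

```rust
/// Rough estimate of mask-computation time in microseconds, mirroring the
/// expression used for the `est_time` CSV column in json_stats.rs.
/// Assumption: ~4 ns per trie node walked and ~60 ns per unit of lexer cost.
fn est_time_us(trie_nodes_walked: u64, lexer_cost: u64) -> u64 {
    (trie_nodes_walked * 4 + lexer_cost * 60) / 1000
}

fn main() {
    // 100_000 nodes walked and 10_000 units of lexer cost:
    // (400_000 + 600_000) / 1000 = 1_000 us, i.e. right at the 1 ms
    // threshold above which a mask is logged as "SLOW".
    assert_eq!(est_time_us(100_000, 10_000), 1_000);
}
```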
diff --git a/parser/Cargo.toml b/parser/Cargo.toml
index 28b980d..7609824 100644
--- a/parser/Cargo.toml
+++ b/parser/Cargo.toml
@@ -5,7 +5,7 @@ edition = "2021"
 
 [dependencies]
 toktrie = { workspace = true }
-derivre = { git = "https://github.com/microsoft/derivre", rev = "5cd57842ed6d2b64156b5d44f4d2094097855e44" }
+derivre = { git = "https://github.com/microsoft/derivre", rev = "a629ecc9dfbd9297bc3a35a5f2028f4bfce91060" }
 serde = { version = "1.0.210", features = ["derive"] }
 serde_json = { version = "1.0.132", features = ["preserve_order"] }
 anyhow = "1.0.90"
diff --git a/parser/llguidance.h b/parser/llguidance.h
index 1b0b07b..0d1fdf2 100644
--- a/parser/llguidance.h
+++ b/parser/llguidance.h
@@ -38,7 +38,7 @@ typedef struct LlgParserLimits {
   size_t step_max_items;
   /**
    * Maximum number of lexer states.
-   * Default: 10_000
+   * Default: 50_000
    */
   size_t max_lexer_states;
   /**
diff --git a/parser/src/api.rs b/parser/src/api.rs
index d0cc5ea..f68ec26 100644
--- a/parser/src/api.rs
+++ b/parser/src/api.rs
@@ -382,7 +382,7 @@ pub struct ParserLimits {
     pub step_max_items: usize,
 
     /// Maximum number of lexer states.
-    /// Default: 10_000
+    /// Default: 50_000
     pub max_lexer_states: usize,
 
     /// Maximum size of the grammar (symbols in productions)
@@ -396,7 +396,7 @@ impl Default for ParserLimits {
             max_items_in_row: 2000,
             initial_lexer_fuel: 1_000_000, // fhir schema => 500k
             step_lexer_fuel: 200_000,      //
-            max_lexer_states: 10_000,      // ?
+            max_lexer_states: 50_000,      // ?
             max_grammar_size: 500_000,     // fhir schema => 200k
             step_max_items: 50_000,        //
         }
diff --git a/parser/src/earley/parser.rs b/parser/src/earley/parser.rs
index 3f2b5aa..aa120b3 100644
--- a/parser/src/earley/parser.rs
+++ b/parser/src/earley/parser.rs
@@ -12,7 +12,7 @@ use std::{
 };
 
 use anyhow::{bail, ensure, Result};
-use derivre::{RegexAst, StateID};
+use derivre::{AlphabetInfo, RegexAst, StateID};
 use hashbrown::HashSet;
 use instant::Instant;
 use serde::{Deserialize, Serialize};
@@ -28,6 +28,7 @@ use super::{
     grammar::{CGrammar, CSymIdx, CSymbol, RhsPtr},
     lexer::{LexerResult, PreLexeme},
     lexerspec::{Lexeme, LexemeIdx, LexerSpec},
+    regexvec::LexerStats,
 };
 
 const TRACE: bool = false;
@@ -72,6 +73,7 @@ pub struct ParserStats {
     pub all_items: usize,
     pub lexer_cost: u64,
     pub slices_applied: usize,
+    pub trie_nodes_walked: usize,
 
     pub definitive_bytes: usize,
     pub lexer_ops: usize,
@@ -142,6 +144,9 @@ impl ParserStats {
                 .compute_time_us
                 .saturating_sub(previous.compute_time_us),
             slices_applied: self.slices_applied.saturating_sub(previous.slices_applied),
+            trie_nodes_walked: self
+                .trie_nodes_walked
+                .saturating_sub(previous.trie_nodes_walked),
         }
     }
 
@@ -157,6 +162,7 @@ impl ParserStats {
             lexer_cost: self.lexer_cost.max(other.lexer_cost),
             compute_time_us: self.compute_time_us.max(other.compute_time_us),
             slices_applied: self.slices_applied.max(other.slices_applied),
+            trie_nodes_walked: self.trie_nodes_walked.max(other.trie_nodes_walked),
         }
     }
 }
@@ -636,8 +642,6 @@ impl ParserState {
             assert!(toks.len() == 1);
             set.disallow_token(toks[0]);
 
-            computer.trie().apply_duplicates(&mut set);
-
             if set.is_zero() {
                 // nothing allowed
                 // we're going to be stopped outside - we better flush the lexer
@@ -2106,6 +2110,10 @@ impl<'a> Recognizer for ParserRecognizer<'a> {
 
         r
     }
+
+    fn save_stats(&mut self, nodes_walked: usize) {
+        self.state.stats.trie_nodes_walked += nodes_walked;
+    }
 }
 
 fn item_to_string(g: &CGrammar, item: &Item) -> String {
@@ -2170,10 +2178,16 @@ impl Parser {
         self.state.hidden_start(shared.lexer_mut())
     }
 
-    pub fn lexer_stats(&self) -> String {
+    pub fn lexer_stats(&self) -> LexerStats {
         self.shared.lock().unwrap().lexer().dfa.stats()
     }
 
+    pub fn with_alphabet_info<T>(&self, f: impl FnOnce(&AlphabetInfo) -> T) -> T {
+        let a = self.shared.lock().unwrap();
+        let a = a.lexer().dfa.alpha();
+        f(a)
+    }
+
     pub fn get_error(&self) -> Option<String> {
         let shared = self.shared.lock().unwrap();
         if let Some(e) = shared.lexer().dfa.get_error() {
diff --git a/parser/src/earley/regexvec.rs b/parser/src/earley/regexvec.rs
index da69271..167941a 100644
--- a/parser/src/earley/regexvec.rs
+++ b/parser/src/earley/regexvec.rs
@@ -6,13 +6,56 @@
 /// https://www.khoury.northeastern.edu/home/turon/re-deriv.pdf (retrieved 15 Nov 2024)
 use anyhow::{bail, Result};
 use derivre::raw::{DerivCache, ExprSet, NextByteCache, RelevanceCache, VecHashCons};
-use std::{fmt::Debug, u64};
+use serde::{Deserialize, Serialize};
+use std::{
+    fmt::{Debug, Display},
+    u64,
+};
 use toktrie::SimpleVob;
 
 pub use derivre::{AlphabetInfo, ExprRef, NextByte, StateID};
 
 use crate::api::ParserLimits;
 
+#[derive(Clone, Serialize, Deserialize, Default)]
+pub struct LexerStats {
+    pub num_regexps: usize,
+    pub num_ast_nodes: usize,
+    pub num_derived: usize,
+    pub num_derivatives: usize,
+    pub total_fuel_spent: usize,
+    pub num_states: usize,
+    pub num_transitions: usize,
+    pub num_bytes: usize,
+    pub alphabet_size: usize,
+    pub error: bool,
+}
+
+impl Display for LexerStats {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "regexps: {} with {} nodes (+ {} derived via {} derivatives with total fuel {}), states: {}; transitions: {}; bytes: {}; alphabet size: {} {}",
+            self.num_regexps,
+            self.num_ast_nodes,
+            self.num_derived,
+            self.num_derivatives,
+            self.total_fuel_spent,
+            self.num_states,
+            self.num_transitions,
+            self.num_bytes,
+            self.alphabet_size,
+            if self.error { "ERROR" } else { "" }
+        )
+    }
+}
+
+impl Debug for LexerStats {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        Display::fmt(self, f)
+    }
+}
+
 #[derive(Clone)]
 pub struct RegexVec {
     exprs: ExprSet,
@@ -158,7 +201,6 @@ impl RegexVec {
         let t0 = instant::Instant::now();
         assert!(self.subsume_possible(state));
         let small = self.rx_list[lexeme_idx];
-        self.set_fuel(u64::MAX);
         let mut res = false;
         for (idx, e) in iter_state(&self.rx_sets, state) {
             assert!(!self.lazy[idx]);
@@ -351,22 +393,19 @@ impl RegexVec {
         }
     }
 
-    pub fn stats(&self) -> String {
-        format!(
-            "regexps: {} with {} nodes (+ {} derived via {} derivatives with total fuel {}), states: {}; transitions: {}; bytes: {}; alphabet size: {} {}",
-            self.rx_list.len(),
-            self.num_ast_nodes,
-            self.exprs.len() - self.num_ast_nodes,
-            self.deriv.num_deriv,
-            self.total_fuel_spent(),
-            self.state_descs.len(),
-            self.num_transitions,
-            self.num_bytes(),
-            self.alpha.len(),
-            if self.has_error() {
-                "ERROR"
-            } else { "" }
-        )
+    pub fn stats(&self) -> LexerStats {
+        LexerStats {
+            num_regexps: self.rx_list.len(),
+            num_ast_nodes: self.num_ast_nodes,
+            num_derived: self.exprs.len() - self.num_ast_nodes,
+            num_derivatives: self.deriv.num_deriv,
+            total_fuel_spent: self.total_fuel_spent() as usize,
+            num_states: self.state_descs.len(),
+            num_transitions: self.num_transitions,
+            num_bytes: self.num_bytes(),
+            alphabet_size: self.alpha.len(),
+            error: self.has_error(),
+        }
     }
 
     pub fn print_state_table(&self) {
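Since `RegexVec::stats()` now returns a structured `LexerStats` (with `Serialize` derived and `Display` implemented above) instead of a pre-formatted `String`, callers can pick between the legacy one-line format and a machine-readable dump. A minimal sketch, assuming the `llguidance` crate from this repo plus a `serde_json` dependency; the `earley::regexvec::LexerStats` path mirrors the import added to json_stats.rs:

```rust
use llguidance::earley::regexvec::LexerStats;

fn main() {
    // `LexerStats` derives Default, so a zeroed value is enough to show both forms.
    let stats = LexerStats::default();
    println!("{}", stats); // legacy human-readable line (Display impl above)
    println!("{}", serde_json::to_string_pretty(&stats).unwrap()); // structured dump
}
```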
diff --git a/parser/src/earley/slicer.rs b/parser/src/earley/slicer.rs
index 84b1024..4fe9393 100644
--- a/parser/src/earley/slicer.rs
+++ b/parser/src/earley/slicer.rs
@@ -1,5 +1,7 @@
 use std::sync::Arc;
 
+use derivre::AlphabetInfo;
+
 use crate::{
     derivre::Regex,
     earley::{BiasComputer, ParserRecognizer},
@@ -16,8 +18,9 @@ struct TokenizerSlice {
 }
 
 pub struct SlicedBiasComputer {
-    tok_env: TokEnv,
+    wildcard_slice: TokTrie,
     slices: Arc<Vec<TokenizerSlice>>,
+    tok_env: TokEnv,
 }
 
 const DEBUG: bool = ITEM_TRACE;
@@ -31,6 +34,19 @@ macro_rules! debug {
 }
 
 impl SlicedBiasComputer {
+    pub fn json_slices() -> Vec<String> {
+        vec![
+            r#"[^"\\\x00-\x1F\x7F]{1,10}"#.to_string(),
+            r#"[^"\\\x00-\x1F\x7F]{1,30}"#.to_string(),
+            r#"[^"\\\x00-\x1F\x7F]+"#.to_string(),
+        ]
+    }
+
+    pub fn general_slices() -> Vec<String> {
+        // to be improved in the future
+        Self::json_slices()
+    }
+
     pub fn new(tok_env: &TokEnv, regexes: &Vec<String>) -> Self {
         let mut slices = vec![];
 
@@ -38,7 +54,6 @@ impl SlicedBiasComputer {
         let n_vocab = trie.vocab_size() as TokenId;
         let mut covered = trie.alloc_token_set();
         let mut idx = 0;
-        let mut total_nodes = 0;
         let mut regexes = regexes.clone();
         if regexes.len() > 0 {
             regexes.push("".to_string()); // catch-all
@@ -79,38 +94,91 @@ impl SlicedBiasComputer {
                 trie: TokTrie::from(trie.info(), &tokens),
                 mask,
             };
-            debug!(
-                "slice{}: /{}/ -> {}",
-                idx,
-                entry.regex,
-                entry.trie.trie_stats()
-            );
-            if false && DEBUG && entry.regex == "" {
-                for (tok_idx, b) in entry.trie.sorted_tokens() {
-                    if b.len() > 0 {
-                        debug!(" tok{}-> {}", tok_idx, entry.trie.token_dbg(tok_idx));
-                    }
-                }
-            }
-            total_nodes += entry.trie.root().subtree_size();
             slices.push(entry);
             idx += 1;
         }
 
-        if total_nodes > 0 {
-            debug!("total_nodes: {}", total_nodes);
-        }
-
-        SlicedBiasComputer {
-            tok_env: tok_env.clone(),
+        let r = SlicedBiasComputer {
             slices: Arc::new(slices),
+            wildcard_slice: trie.clone(),
+            tok_env: tok_env.clone(),
+        };
+
+        debug!("slicer:\n{}", r.stats(false));
+
+        r
+    }
+
+    pub fn stats(&self, include_tokens: bool) -> String {
+        let mut total_nodes = 0;
+        let mut s = String::new();
+        for (i, slice) in self.slices.iter().enumerate() {
+            total_nodes += slice.trie.root().subtree_size();
+            s.push_str(&format!(
+                "slice{}: /{}/ -> {}\n",
+                i,
+                slice.regex,
+                slice.trie.trie_stats()
+            ));
+            if include_tokens {
+                for (tok_idx, b) in slice.trie.sorted_tokens() {
+                    if b.len() > 0 {
+                        s.push_str(&format!(
+                            " tok{}-> {}\n",
+                            tok_idx,
+                            slice.trie.token_dbg(tok_idx)
+                        ));
+                    }
+                }
+            }
         }
+        s.push_str(&format!("total_nodes: {}\n", total_nodes));
+        s.push_str(&format!("WILDCARD: {}\n", self.wildcard_slice.trie_stats()));
+        s
     }
 
     pub fn extra_lexemes(&self) -> Vec<String> {
         self.slices.iter().map(|s| s.regex.clone()).collect()
     }
+
+    pub fn compress(&self, ai: &AlphabetInfo) -> Self {
+        let slices = self
+            .slices
+            .iter()
+            .map(|s| TokenizerSlice {
+                idx: s.idx,
+                regex: s.regex.clone(),
+                trie: compress_trie(&s.trie, ai),
+                mask: s.mask.clone(),
+            })
+            .collect();
+        SlicedBiasComputer {
+            wildcard_slice: compress_trie(&self.wildcard_slice, ai),
+            slices: Arc::new(slices),
+            tok_env: self.tok_env.clone(),
+        }
+    }
+}
+
+fn compress_trie(trie: &TokTrie, ai: &AlphabetInfo) -> TokTrie {
+    let mut tokens = trie.all_tokens();
+    let mut repr = vec![None; 256];
+    let repr2 = (0..=255)
+        .map(|b| {
+            if repr[ai.map(b)].is_none() {
+                repr[ai.map(b)] = Some(b);
+            }
+            repr[ai.map(b)].unwrap()
+        })
+        .collect::<Vec<_>>();
+    for t in tokens.iter_mut() {
+        for i in 0..t.len() {
+            t[i] = repr2[t[i] as usize];
+        }
+    }
+    TokTrie::from(trie.info(), &tokens)
 }
 
 impl BiasComputer for SlicedBiasComputer {
@@ -121,9 +189,8 @@ impl BiasComputer for SlicedBiasComputer {
             && start.is_empty()
             && rec.lexer_mut().subsume_possible(lexer_state)
         {
-            // for JSON string lexer and /[a-zA-Z\u{0080}-\u{10FFFF}]+/ kind of slices
-            // we use about 200 of the budget and it takes around 20us
-            let budget = 5500;
+            // set to at least 500
+            let budget = 1000;
             let slice_matches = self
                 .slices
                 .iter()
@@ -138,7 +205,7 @@ impl BiasComputer for SlicedBiasComputer {
 
             if slice_matches.iter().all(|&x| x == false) {
                 // if nothing matches, just run the full trie
-                self.trie().add_bias(rec, &mut set, start);
+                self.wildcard_slice.add_bias(rec, &mut set, start);
                 debug!("no slice matches; {} tokens", set.num_set());
             } else {
                 // otherwise, apply the matching slices, and compute the rest
@@ -167,7 +234,7 @@ impl BiasComputer for SlicedBiasComputer {
                 }
             }
         } else {
-            self.trie().add_bias(rec, &mut set, start);
+            self.wildcard_slice.add_bias(rec, &mut set, start);
             debug!("slicer disabled; {} tokens", set.num_set());
         }
 
diff --git a/parser/src/factory.rs b/parser/src/factory.rs
index f827335..7c0112a 100644
--- a/parser/src/factory.rs
+++ b/parser/src/factory.rs
@@ -51,22 +51,53 @@ impl ParserFactory {
         self
     }
 
+    pub fn set_buffer_log_level(&mut self, level: u32) -> &mut Self {
+        self.buffer_log_level = level;
+        self
+    }
+
+    pub fn set_stderr_log_level(&mut self, level: u32) -> &mut Self {
+        self.stderr_log_level = level;
+        self
+    }
+
     pub fn extra_lexemes(&self) -> Vec<String> {
         self.slicer.extra_lexemes()
     }
 
+    pub fn slicer(&self) -> Arc<SlicedBiasComputer> {
+        self.slicer.clone()
+    }
+
     pub fn post_process_parser(&self, parser: &mut TokenParser) {
-        parser.bias_computer = self.slicer.clone();
+        if false {
+            // this only reduces the nodes walked by about 20%, but is quite
+            // expensive to compute
+            let slicer = parser
+                .parser
+                .with_alphabet_info(|a| self.slicer.compress(a));
+            parser.bias_computer = Arc::new(slicer);
+        } else {
+            parser.bias_computer = self.slicer.clone();
+        }
         let mut rng = self.seed.lock().unwrap();
         rng.next_alt();
         parser.parser.metrics_mut().rand = rng.clone();
     }
 
     pub fn create_parser(&self, grammar: TopLevelGrammar) -> Result<TokenParser> {
+        self.create_parser_ext(grammar, self.buffer_log_level)
+    }
+
+    pub fn create_parser_ext(
+        &self,
+        grammar: TopLevelGrammar,
+        buffer_log_level: u32,
+    ) -> Result<TokenParser> {
         let mut parser = TokenParser::from_llguidance_json(
             self.tok_env.clone(),
             grammar,
-            Logger::new(self.buffer_log_level, self.stderr_log_level),
+            Logger::new(buffer_log_level, self.stderr_log_level),
             self.inference_caps.clone(),
             self.limits.clone(),
             self.extra_lexemes(),
diff --git a/parser/src/tokenparser.rs b/parser/src/tokenparser.rs
index 8883e6a..2bbd806 100644
--- a/parser/src/tokenparser.rs
+++ b/parser/src/tokenparser.rs
@@ -366,10 +366,7 @@ impl TokenParser {
         let mut prefix = self.compute_ff_bytes();
 
         // if ff_tokens is enabled, we assume the user has already called compute_ff_tokens()
-        if !self.inference_caps.ff_tokens
-            && !self.parser.grammar().lexer_spec().no_forcing
-            && self.token_env.tokenize_is_canonical()
-        {
+        if !self.inference_caps.ff_tokens && self.can_force_bytes() {
             let (ff_tokens, token_prefix) = self.ff_bytes_to_tokens(prefix);
             if ff_tokens.len() > 0 {
                 let t = ff_tokens[0];
@@ -513,9 +510,17 @@ impl TokenParser {
         self.pending_grm_prefix().len() > 0 || self.parser.currently_forced_bytes().len() > 0
     }
 
+    fn can_force_bytes(&self) -> bool {
+        !self.parser.grammar().lexer_spec().no_forcing && self.token_env.tokenize_is_canonical()
+    }
+
     fn compute_ff_bytes(&mut self) -> Vec<u8> {
         // PERF: in some cases, this may be long
-        let mut new_forced = self.parser.force_bytes().to_vec();
+        if self.can_force_bytes() {
+            self.parser.force_bytes();
+        }
+
+        let mut new_forced = self.parser.currently_forced_bytes().to_vec();
 
         // handle grm_prefix we might have injected
         if self.llm_bytes.len() < self.grm_prefix.len() {
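The new `ParserFactory::create_parser_ext` keeps `create_parser` as a thin wrapper while letting a caller override the buffer log level per parser. A hedged usage sketch, assuming an already-constructed `ParserFactory` (its constructor is unchanged by this patch) and that `TopLevelGrammar`, `TokenParser` and `anyhow::Result` are re-exported as shown; the import paths are assumptions:

```rust
use anyhow::Result;
use llguidance::{ParserFactory, TokenParser, TopLevelGrammar};

// Build one parser with its log buffered at level 2, without changing the
// factory-wide default that `create_parser` forwards to `create_parser_ext`.
fn verbose_parser(factory: &ParserFactory, grammar: TopLevelGrammar) -> Result<TokenParser> {
    factory.create_parser_ext(grammar, 2)
}
```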
diff --git a/toktrie/src/toktree.rs b/toktrie/src/toktree.rs
index 08b0457..ed68823 100644
--- a/toktrie/src/toktree.rs
+++ b/toktrie/src/toktree.rs
@@ -5,12 +5,8 @@ use std::sync::Arc;
 
 use anyhow::Result;
 use bytemuck_derive::{Pod, Zeroable};
-use hashbrown::HashMap;
 
-use crate::{
-    bytes::{to_hex_string, vec_from_bytes},
-    SimpleVob,
-};
+use crate::{bytes::to_hex_string, SimpleVob};
 
 pub type TokenId = u32;
 
@@ -102,6 +98,7 @@ pub trait Recognizer {
     fn get_error(&mut self) -> Option<String> {
         None
     }
+    fn save_stats(&mut self, _nodes_walked: usize) {}
 }
 
 pub trait TokenizerEnv: Send {
@@ -200,23 +197,6 @@ pub struct TokTrie {
     token_data: Vec<u8>,
     nodes: Vec<TrieNode>,
     max_token_len: usize,
-    token_duplicates: HashMap<TokenId, Vec<TokenId>>,
-}
-
-#[derive(Clone, Copy, Zeroable, Pod)]
-#[repr(C)]
-pub struct TokTrieHeader {
-    magic: u32,
-    hd_size: u32,
-    trie_bytes: u32,
-    token_offset_bytes: u32,
-    token_data_bytes: u32,
-    info: BinTokRxInfo,
-    align: [u32; 0],
-}
-
-impl TokTrieHeader {
-    const MAGIC: u32 = 0x558b6fd3;
 }
 
 #[derive(Clone, Copy, Zeroable, Pod)]
@@ -275,9 +255,11 @@ impl TokTrie {
         let mut token_offsets = Vec::new();
         let mut token_data = Vec::new();
         assert!(info.vocab_size == words.len() as u32);
+        let mut max_token_len = 0;
         for (idx, word) in words.iter().enumerate() {
             if word.len() > 0 {
                 trie.insert(word, idx as u32);
+                max_token_len = std::cmp::max(max_token_len, word.len());
             }
             assert!(word.len() < (1 << LEN_BITS));
             assert!(token_data.len() < (1 << (32 - LEN_BITS)));
@@ -287,15 +269,14 @@ impl TokTrie {
         }
         let mut nodes = Vec::new();
         trie.serialize(&mut nodes, 0);
-        let mut r = TokTrie {
+        let r = TokTrie {
             info: info.clone(),
             token_offsets,
             token_data,
             nodes,
-            max_token_len: 0,
-            token_duplicates: HashMap::default(),
+            max_token_len,
         };
-        r.finalize_ctor();
+        r.validate();
         r
     }
 
@@ -316,21 +297,6 @@ impl TokTrie {
         self.with_eos_token(self.info.tok_end_of_turn.unwrap_or(self.info.tok_eos))
     }
 
-    fn finalize_ctor(&mut self) {
-        for tok_id in 0..self.info.vocab_size {
-            let bytes = self.token(tok_id);
-            let tok_ids = self.greedy_tokenize(bytes);
-            self.max_token_len = std::cmp::max(self.max_token_len, bytes.len());
-            if tok_ids.len() == 1 && tok_ids[0] != tok_id {
-                self.token_duplicates
-                    .entry(tok_ids[0])
-                    .or_insert_with(Vec::new)
-                    .push(tok_id);
-            }
-        }
-        self.validate();
-    }
-
     fn node_offset(&self, n: &TrieNode) -> usize {
         let off = unsafe { (n as *const TrieNode).offset_from(self.root() as *const TrieNode) };
         assert!(off >= 0);
@@ -627,31 +593,6 @@ impl TokTrie {
         return last;
     }
 
-    pub fn from_bytes(bytes: &[u8]) -> Self {
-        let pref = std::mem::size_of::<TokTrieHeader>();
-        let hd: &TokTrieHeader = bytemuck::from_bytes(&bytes[0..pref]);
-
-        assert!(hd.magic == TokTrieHeader::MAGIC);
-        assert!(hd.hd_size as usize == pref);
-
-        let trie_end = pref + hd.trie_bytes as usize;
-        let nodes = vec_from_bytes(&bytes[pref..trie_end]);
-        let offsets_end = trie_end + hd.token_offset_bytes as usize;
-        let token_offsets = vec_from_bytes(&bytes[trie_end..offsets_end]);
-        let token_data = vec_from_bytes(&bytes[offsets_end..]);
-
-        let mut r = TokTrie {
-            info: TokRxInfo::from_bin(&hd.info),
-            token_offsets,
-            token_data,
-            nodes,
-            max_token_len: 0,
-            token_duplicates: HashMap::default(),
-        };
-        r.finalize_ctor();
-        r
-    }
-
     pub fn max_token_len(&self) -> usize {
         self.max_token_len
     }
@@ -680,28 +621,6 @@ impl TokTrie {
         }
     }
 
-    pub fn serialize(&self) -> Vec<u8> {
-        let trie_data: &[u8] = bytemuck::cast_slice(&self.nodes);
-        let token_offsets: &[u8] = bytemuck::cast_slice(&self.token_offsets);
-        let token_data: &[u8] = bytemuck::cast_slice(&self.token_data);
-
-        let hd = TokTrieHeader {
-            magic: TokTrieHeader::MAGIC,
-            hd_size: std::mem::size_of::<TokTrieHeader>() as u32,
-            trie_bytes: trie_data.len() as u32,
-            token_offset_bytes: token_offsets.len() as u32,
-            token_data_bytes: trie_data.len() as u32,
-            info: self.info.to_bin(),
-            align: [],
-        };
-
-        let mut bytes = bytemuck::bytes_of(&hd).to_vec();
-        bytes.extend_from_slice(trie_data);
-        bytes.extend_from_slice(token_offsets);
-        bytes.extend_from_slice(token_data);
-        bytes
-    }
-
     pub fn root(&self) -> &TrieNode {
         &self.nodes[0]
     }
@@ -720,7 +639,15 @@ impl TokTrie {
                     .token_id()
                     .unwrap();
                 if tid != tid2 {
-                    assert!(self.token_duplicates[&tid2].contains(&tid));
+                    let par = self
+                        .child_at_bytes(root, &bytes[0..bytes.len() - 1])
+                        .unwrap();
+                    let has_it = self.node_children(par).any(|n| {
+                        n.subtree_size() == 1
+                            && n.byte() == bytes[bytes.len() - 1]
+                            && n.token_id() == Some(tid)
+                    });
+                    assert!(has_it);
                 }
             }
         }
@@ -791,17 +718,6 @@ impl TokTrie {
             }
         }
         self.add_bias(r, logits, start);
-        self.apply_duplicates(logits);
-    }
-
-    pub fn apply_duplicates(&self, logits: &mut SimpleVob) {
-        for (tok, dups) in &self.token_duplicates {
-            if logits.is_allowed(*tok) {
-                for &dup in dups {
-                    logits.allow_token(dup);
-                }
-            }
-        }
     }
 
     pub fn append_tokens(&self, r: &mut impl Recognizer, ts: &[TokenId]) -> Result<()> {
@@ -906,12 +822,8 @@ impl TokTrie {
     pub fn add_bias(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, start: &[u8]) {
         // all prefixes of 'start' are also allowed
        if start.len() > 0 {
-            for len in 1..=start.len() {
-                let bytes = &start[0..len];
-                if let Some(tok) = self.token_id(bytes) {
-                    toks.allow_token(tok);
-                }
-            }
+            let mut fixed = FixedRecognizer::new(start);
+            self.add_bias(&mut fixed, toks, &[]);
         }
 
         let n = self.child_at_bytes(self.root(), start);
@@ -920,24 +832,32 @@ impl TokTrie {
         }
         let n = n.unwrap();
         r.trie_started("add_bias");
-        let next_pop = self.add_bias_inner(r, toks, n);
+        let (next_pop, nodes_walked) = self.add_bias_inner(r, toks, n);
         if start.len() == 0 {
             // if start was non-empty, trie_finished() is supposed to clean this up
             r.pop_bytes(next_pop);
         }
         r.trie_finished();
+        r.save_stats(nodes_walked);
         // revert the fake token
         let defl_tok = self.vocab_size() as u32;
         toks.disallow_token(defl_tok);
     }
 
     #[inline(never)]
-    fn add_bias_inner(&self, r: &mut impl Recognizer, toks: &mut SimpleVob, n: &TrieNode) -> usize {
+    fn add_bias_inner(
+        &self,
+        r: &mut impl Recognizer,
+        toks: &mut SimpleVob,
+        n: &TrieNode,
+    ) -> (usize, usize) {
         let defl_tok = self.vocab_size() as u32;
         let off = self.node_offset(n);
+        let total_nodes = n.subtree_size();
         let mut p = off + 1;
-        let endp = off + n.subtree_size();
+        let endp = off + total_nodes;
         let mut next_pop = 0;
+        let mut num_skip = 0;
         while p < endp {
             r.pop_bytes(next_pop);
             let n = &self.nodes[p];
@@ -951,11 +871,20 @@ impl TokTrie {
                 };
                 p += 1;
             } else {
-                p += n.subtree_size();
+                let subtree_size = n.subtree_size();
+                p += subtree_size;
+                // it's slightly faster to count skipped nodes than walked nodes
+                num_skip += subtree_size - 1;
                 next_pop = n.num_parents() - 1;
             }
         }
-        next_pop
+        (next_pop, total_nodes - num_skip)
+    }
+
+    pub fn all_tokens(&self) -> Vec<Vec<u8>> {
+        (0..self.vocab_size())
+            .map(|idx| self.token(idx as u32).to_vec())
+            .collect()
     }
 
     pub fn sorted_tokens(&self) -> Vec<(u32, Vec<u8>)> {
@@ -1122,19 +1051,23 @@ impl TrieHash {
         if word.len() == 0 {
             // Some tokenizers have duplicate tokens...
             // we just override
-            // assert!(self.token_id == NO_TOKEN);
+            assert!(self.token_id == NO_TOKEN);
             self.token_id = token_id;
         } else {
-            if self.children.len() == 0x100 {
-                // assert!(self.children[word[0] as usize].byte == word[0]);
-                self.children[word[0] as usize].insert(&word[1..], token_id);
-                return;
-            }
+            // if self.children.len() == 0x100 {
+            //     // assert!(self.children[word[0] as usize].byte == word[0]);
+            //     self.children[word[0] as usize].insert(&word[1..], token_id);
+            //     return;
+            // }
 
             for ch in &mut self.children {
                 if ch.byte == word[0] {
-                    ch.insert(&word[1..], token_id);
-                    return;
+                    if word.len() == 1 && ch.token_id != NO_TOKEN {
+                        // this is a duplicate token, proceed with adding a duplicate node
+                    } else {
+                        ch.insert(&word[1..], token_id);
+                        return;
+                    }
                 }
             }
 
@@ -1146,20 +1079,21 @@ impl TrieHash {
             // for cl100k threshold 60->15 nodes, 50->22, 40->45 30->94
             // for llama (32k) 50->5, 40->15
             // TODO remove this?
-            if self.children.len() > 250 {
-                let mut v2 = (0..=255).map(TrieHash::new).collect::<Vec<_>>();
-                for ch in self.children.drain(..) {
-                    let idx = ch.byte as usize;
-                    v2[idx] = ch;
-                }
-                self.children = v2;
-            }
+            // if self.children.len() > 250 {
+            //     let mut v2 = (0..=255).map(TrieHash::new).collect::<Vec<_>>();
+            //     for ch in self.children.drain(..) {
+            //         let idx = ch.byte as usize;
+            //         v2[idx] = ch;
+            //     }
+            //     self.children = v2;
+            // }
         }
     }
     fn serialize(&mut self, data: &mut Vec<TrieNode>, num_parents: u8) {
         let idx = data.len();
         let mut num_ch = self.children.len();
         data.push(TrieNode::new(self.byte, self.token_id, num_parents));
+        //self.children.reverse();
         self.children.sort_by_key(|e| e.byte);
         for entry in &mut self.children {
             num_ch -= 1;
@@ -1168,3 +1102,38 @@ impl TrieHash {
         data[idx].bits2 |= ((data.len() - idx) as u32) << 8;
     }
 }
+
+struct FixedRecognizer {
+    bytes: Vec<u8>,
+    bytes_ptr: usize,
+}
+
+impl FixedRecognizer {
+    fn new(bytes: &[u8]) -> FixedRecognizer {
+        FixedRecognizer {
+            bytes: bytes.to_vec(),
+            bytes_ptr: 0,
+        }
+    }
+}
+
+impl Recognizer for FixedRecognizer {
+    fn collapse(&mut self) {}
+    fn trie_finished(&mut self) {}
+    fn special_allowed(&mut self, _: SpecialToken) -> bool {
+        false
+    }
+
+    fn pop_bytes(&mut self, num: usize) {
+        self.bytes_ptr -= num;
+    }
+
+    fn try_push_byte(&mut self, byte: u8) -> bool {
+        if self.bytes_ptr < self.bytes.len() && self.bytes[self.bytes_ptr] == byte {
+            self.bytes_ptr += 1;
+            true
+        } else {
+            false
+        }
+    }
+}
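For reference, a self-contained illustration of the byte-prefix walk that the new `FixedRecognizer` performs: `add_bias` now re-enters the generic trie walk with a recognizer that accepts exactly the bytes of `start`, so every token that is a prefix of `start` gets allowed without the old per-prefix `token_id` lookups. The struct below is a hypothetical standalone copy of that logic, not the `toktrie` type itself:

```rust
// Hypothetical standalone copy of the FixedRecognizer byte-prefix logic.
struct FixedBytes {
    bytes: Vec<u8>,
    ptr: usize,
}

impl FixedBytes {
    fn try_push_byte(&mut self, b: u8) -> bool {
        // Accept `b` only if it continues the fixed prefix.
        if self.ptr < self.bytes.len() && self.bytes[self.ptr] == b {
            self.ptr += 1;
            true
        } else {
            false
        }
    }

    fn pop_bytes(&mut self, n: usize) {
        // Backtrack, mirroring Recognizer::pop_bytes during the trie walk.
        self.ptr -= n;
    }
}

fn main() {
    let mut f = FixedBytes { bytes: b"abc".to_vec(), ptr: 0 };
    assert!(f.try_push_byte(b'a'));  // "a" is a prefix of "abc" -> its token stays allowed
    assert!(f.try_push_byte(b'b'));  // "ab" as well
    assert!(!f.try_push_byte(b'x')); // "abx" is not, so that trie branch is pruned
    f.pop_bytes(2);                  // back to the root
    assert_eq!(f.ptr, 0);
}
```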