more sensible commit_token() result
mmoskal committed Sep 11, 2024
1 parent 92a77eb commit de232ea
Showing 4 changed files with 88 additions and 23 deletions.
69 changes: 64 additions & 5 deletions parser/src/constraint.rs
Expand Up @@ -18,6 +18,36 @@ pub struct Constraint {
started: bool,
}

#[derive(Debug, Clone, Default)]
pub struct CommitResult {
pub stop: bool,
pub backtrack: u32,
pub ff_tokens: Vec<TokenId>,
}

impl CommitResult {
pub fn stop() -> Self {
Self {
stop: true,
backtrack: 0,
ff_tokens: vec![],
}
}

pub fn from_step_result(res: &StepResult) -> Self {
let mut r = CommitResult {
stop: res.is_stop(),
backtrack: 0,
ff_tokens: vec![],
};
if let Some(s) = res.unconditional_splice() {
r.backtrack = s.backtrack;
r.ff_tokens = s.ff_tokens.clone();
}
r
}
}

impl Constraint {
/// Construct a state machine for a sequence constraint.
pub fn new(parser: TokenParser) -> Self {
@@ -60,6 +90,21 @@ impl Constraint {
self.parser.process_prompt(prompt)
}

/// This can be called before the first get_mask() to advance the
/// parser with tokens generated in some previous run.
pub fn force_tokens(&mut self, tokens: &[TokenId]) -> Result<()> {
ensure!(
self.step_arg.is_none() || self.step_arg.as_ref().unwrap().tokens.is_empty(),
"force_tokens() called twice"
);
self.step_arg = Some(StepArg {
backtrack: 0,
tokens: tokens.to_vec(),
sampled: None,
});
Ok(())
}

/// This computes the token sampling mask.
/// It typically takes up to a millisecond for a 100k tokenizer.
/// It will return an error when the order of calls is violated.
@@ -93,17 +138,25 @@ impl Constraint {
Ok(&self.last_res)
}

pub fn step_result(&self) -> &StepResult {
&self.last_res
}

fn res_commit_result(&mut self) -> Result<CommitResult> {
Ok(CommitResult::from_step_result(&self.last_res))
}

/// This commits the sampled token (if any), and sees if this forces any more tokens
/// on the output (if ff_tokens are enabled in InferenceCapabilities).
pub fn commit_token(&mut self, sampled_token: Option<TokenId>) -> Result<&StepResult> {
pub fn commit_token(&mut self, sampled_token: Option<TokenId>) -> Result<CommitResult> {
ensure!(
self.step_arg.is_none(),
"commit_token() called twice or without compute_bias()"
);

// if last result was to stop or to unconditionally splice, we're done already
if self.last_res.is_stop() {
return Ok(&self.last_res);
return self.res_commit_result();
}

if let Some(splice) = self.last_res.unconditional_splice() {
@@ -113,7 +166,7 @@

// prepare argument for the next step
self.step_arg = Some(StepArg::from_splice(splice, sampled_token));
return Ok(&self.last_res);
return self.res_commit_result();
}

// otherwise, append the sampled token and see if more tokens can be forced
@@ -139,7 +192,7 @@
if !self.parser.inference_caps.ff_tokens {
self.step_arg = Some(StepArg::from_sampled_token(sampled_token));
self.last_res = StepResult::splice(0, vec![sampled_token]);
return Ok(&self.last_res);
return self.res_commit_result();
}

// now, advance the parser with the sampled token - this should be very quick
@@ -175,7 +228,7 @@
}

self.last_res = StepResult::splice(splice.backtrack, splice.ff_tokens.clone());
Ok(&self.last_res)
return self.res_commit_result();
}

/// This returns parser outputs to be passed back to the user.
@@ -190,4 +243,10 @@
pub fn flush_logs(&mut self) -> String {
self.parser.logger.get_and_clear_logs()
}

// Utility functions

pub fn tok_trie(&self) -> &toktrie::TokTrie {
self.parser.token_env.tok_trie()
}
}
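
Below is a hedged caller-side sketch, not part of this commit, of how the new return type is meant to be consumed; the crate path and the wrapper function are assumptions, and only Constraint::commit_token plus the CommitResult fields come from the diff above. It mirrors the updated rust/src/py.rs binding further down.

// Sketch only -- not code from this commit; crate path is assumed.
use anyhow::Result;
use llguidance_parser::{toktrie::TokenId, Constraint};

/// One decode step after sampling: returns (backtrack count, forced tokens).
fn advance(constraint: &mut Constraint, sampled: Option<TokenId>) -> Result<(u32, Vec<TokenId>)> {
    let res = constraint.commit_token(sampled)?;
    if res.stop {
        // the sequence is finished; nothing to backtrack or force
        return Ok((0, vec![]));
    }
    // no unconditional_splice() unwrapping needed anymore: the backtrack
    // count and fast-forward tokens are plain fields on CommitResult
    Ok((res.backtrack, res.ff_tokens))
}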
35 changes: 21 additions & 14 deletions parser/src/ffi.rs
@@ -8,7 +8,7 @@ use toktrie::{InferenceCapabilities, TokEnv, TokRxInfo, TokTrie, TokenizerEnv};

use crate::{
api::{ParserLimits, TopLevelGrammar},
Constraint, Logger, TokenParser,
CommitResult, Constraint, Logger, TokenParser,
};

struct CTokenizerInner {
@@ -151,6 +151,7 @@ pub struct LlgConstraint {
local_error: Option<String>,
last_logs: String,
constraint: Option<Constraint>,
last_commit_result: CommitResult,
}

#[repr(C)]
@@ -176,6 +177,21 @@ pub struct LlgCommitResult {
pub is_stop: bool,
}

impl LlgCommitResult {
pub fn from_commit_result(r: &CommitResult) -> Self {
let len = r.ff_tokens.len() as u32;
LlgCommitResult {
tokens: if len == 0 {
std::ptr::null()
} else {
r.ff_tokens.as_ptr()
},
n_tokens: len,
is_stop: r.stop,
}
}
}

fn new_constraint(init: &LlgConstraintInit, grammar_json: *const c_char) -> Result<Constraint> {
let grammar_json = unsafe { CStr::from_ptr(grammar_json) }
.to_str()
@@ -250,6 +266,7 @@ pub extern "C" fn llg_new_constraint(
local_error: None,
constraint: None,
last_logs: "\x00".to_string(),
last_commit_result: CommitResult::default(),
};

match new_constraint(init, grammar_json) {
@@ -312,19 +329,9 @@ pub extern "C" fn llg_commit_token(
};
match constraint.commit_token(token) {
Ok(r) => {
let res = if let Some(s) = r.unconditional_splice() {
LlgCommitResult {
tokens: s.ff_tokens.as_ptr(),
n_tokens: s.ff_tokens.len() as u32,
is_stop: r.is_stop(),
}
} else {
LlgCommitResult {
tokens: std::ptr::null(),
n_tokens: 0,
is_stop: r.is_stop(),
}
};
// store it, so it survives until the next call to llg_*()
cc.last_commit_result = r;
let res = LlgCommitResult::from_commit_result(&cc.last_commit_result);
unsafe { *res_p = res };
}
Err(e) => cc.set_error(&e.to_string()),
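
The comment "store it, so it survives until the next call to llg_*()" is the key point on the FFI side: LlgCommitResult.tokens is a raw pointer into CommitResult.ff_tokens, so the vector has to stay alive after llg_commit_token returns. A minimal self-contained sketch of that ownership rule, using stand-in types rather than the actual FFI code:

// Sketch only -- simplified stand-ins for CommitResult / LlgCommitResult /
// LlgConstraint, just to illustrate the lifetime rule.
struct CommitResult {
    stop: bool,
    ff_tokens: Vec<u32>,
}

#[repr(C)]
struct CCommitResult {
    tokens: *const u32, // borrows ff_tokens; valid only while the owner lives
    n_tokens: u32,
    is_stop: bool,
}

struct Handle {
    last_commit_result: CommitResult,
}

impl Handle {
    fn commit(&mut self, r: CommitResult) -> CCommitResult {
        // move the result into the handle first, so the pointer below keeps
        // pointing at live memory after this function returns
        self.last_commit_result = r;
        let ff = &self.last_commit_result.ff_tokens;
        CCommitResult {
            tokens: if ff.is_empty() { std::ptr::null() } else { ff.as_ptr() },
            n_tokens: ff.len() as u32,
            is_stop: self.last_commit_result.stop,
        }
    }
}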
2 changes: 1 addition & 1 deletion parser/src/lib.rs
@@ -8,7 +8,7 @@ pub mod output;
pub use toktrie;

mod constraint;
pub use constraint::Constraint;
pub use constraint::{CommitResult, Constraint};

mod logging;
pub use logging::Logger;
5 changes: 2 additions & 3 deletions rust/src/py.rs
@@ -107,13 +107,12 @@ impl LLInterpreter {
fn advance_parser(&mut self, sampled_token: Option<TokenId>) -> PyResult<(u32, Vec<TokenId>)> {
let pres = self.inner.commit_token(sampled_token).map_err(val_error)?;

if pres.is_stop() {
if pres.stop {
// let the next mid_process() call handle it
return Ok((0, vec![]));
}

let splice = pres.unconditional_splice().unwrap();
Ok((splice.backtrack, splice.ff_tokens.clone()))
Ok((pres.backtrack, pres.ff_tokens))
}

fn post_process(&mut self, sampled_token: Option<TokenId>) -> PyResult<(u32, Vec<TokenId>)> {
