From de232ea59b8417d4de1fe50cc63694b5858d1ccc Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Wed, 11 Sep 2024 14:09:19 -0700
Subject: [PATCH] more sensible commit_token() result

---
 parser/src/constraint.rs | 69 +++++++++++++++++++++++++++++++++++++---
 parser/src/ffi.rs        | 35 ++++++++++++--------
 parser/src/lib.rs        |  2 +-
 rust/src/py.rs           |  5 ++-
 4 files changed, 88 insertions(+), 23 deletions(-)

diff --git a/parser/src/constraint.rs b/parser/src/constraint.rs
index 02c918ed..f7d38462 100644
--- a/parser/src/constraint.rs
+++ b/parser/src/constraint.rs
@@ -18,6 +18,36 @@ pub struct Constraint {
     started: bool,
 }
 
+#[derive(Debug, Clone, Default)]
+pub struct CommitResult {
+    pub stop: bool,
+    pub backtrack: u32,
+    pub ff_tokens: Vec<TokenId>,
+}
+
+impl CommitResult {
+    pub fn stop() -> Self {
+        Self {
+            stop: true,
+            backtrack: 0,
+            ff_tokens: vec![],
+        }
+    }
+
+    pub fn from_step_result(res: &StepResult) -> Self {
+        let mut r = CommitResult {
+            stop: res.is_stop(),
+            backtrack: 0,
+            ff_tokens: vec![],
+        };
+        if let Some(s) = res.unconditional_splice() {
+            r.backtrack = s.backtrack;
+            r.ff_tokens = s.ff_tokens.clone();
+        }
+        r
+    }
+}
+
 impl Constraint {
     /// Construct a state machine for a sequence constraint.
     pub fn new(parser: TokenParser) -> Self {
@@ -60,6 +90,21 @@ impl Constraint {
         self.parser.process_prompt(prompt)
     }
 
+    /// This can be called before the first get_mask() to walk forward the
+    /// parser with tokens generated in some previous run.
+    pub fn force_tokens(&mut self, tokens: &[TokenId]) -> Result<()> {
+        ensure!(
+            self.step_arg.is_none() || self.step_arg.as_ref().unwrap().tokens.is_empty(),
+            "force_tokens() called twice"
+        );
+        self.step_arg = Some(StepArg {
+            backtrack: 0,
+            tokens: tokens.to_vec(),
+            sampled: None,
+        });
+        Ok(())
+    }
+
     /// This computes token sampling mask.
     /// It typically takes up to a millisecond for a 100k tokenizer.
     /// It will return an error when the order of calls is violated.
@@ -93,9 +138,17 @@
         Ok(&self.last_res)
     }
+    pub fn step_result(&self) -> &StepResult {
+        &self.last_res
+    }
+
+    fn res_commit_result(&mut self) -> Result<CommitResult> {
+        Ok(CommitResult::from_step_result(&self.last_res))
+    }
+
     /// This commits the sampled token (if any), and sees if this forces any more tokens
     /// on the output (if ff_tokens are enabled in InferenceCapabilities).
-    pub fn commit_token(&mut self, sampled_token: Option<TokenId>) -> Result<&StepResult> {
+    pub fn commit_token(&mut self, sampled_token: Option<TokenId>) -> Result<CommitResult> {
         ensure!(
             self.step_arg.is_none(),
             "commit_token() called twice or without compute_bias()"
         );
@@ -103,7 +156,7 @@ impl Constraint {
 
         // if last result was to stop or to unconditionally splice, we're done already
         if self.last_res.is_stop() {
-            return Ok(&self.last_res);
+            return self.res_commit_result();
        }
 
         if let Some(splice) = self.last_res.unconditional_splice() {
@@ -113,7 +166,7 @@ impl Constraint {
 
             // prepare argument for the next step
             self.step_arg = Some(StepArg::from_splice(splice, sampled_token));
-            return Ok(&self.last_res);
+            return self.res_commit_result();
         }
 
         // otherwise, append the sampled token and see if more tokens can be forced
@@ -139,7 +192,7 @@ impl Constraint {
         if !self.parser.inference_caps.ff_tokens {
             self.step_arg = Some(StepArg::from_sampled_token(sampled_token));
             self.last_res = StepResult::splice(0, vec![sampled_token]);
-            return Ok(&self.last_res);
+            return self.res_commit_result();
         }
 
         // now, advance the parser with the sampled token - this should be very quick
@@ -175,7 +228,7 @@ impl Constraint {
         }
 
         self.last_res = StepResult::splice(splice.backtrack, splice.ff_tokens.clone());
-        Ok(&self.last_res)
+        return self.res_commit_result();
     }
 
     /// This returns parser outputs to be passed back to the user.
@@ -190,4 +243,10 @@ impl Constraint {
     pub fn flush_logs(&mut self) -> String {
         self.parser.logger.get_and_clear_logs()
     }
+
+    // Utility functions
+
+    pub fn tok_trie(&self) -> &toktrie::TokTrie {
+        self.parser.token_env.tok_trie()
+    }
 }
diff --git a/parser/src/ffi.rs b/parser/src/ffi.rs
index 67943a8e..061c2fbb 100644
--- a/parser/src/ffi.rs
+++ b/parser/src/ffi.rs
@@ -8,7 +8,7 @@ use toktrie::{InferenceCapabilities, TokEnv, TokRxInfo, TokTrie, TokenizerEnv};
 
 use crate::{
     api::{ParserLimits, TopLevelGrammar},
-    Constraint, Logger, TokenParser,
+    CommitResult, Constraint, Logger, TokenParser,
 };
 
 struct CTokenizerInner {
@@ -151,6 +151,7 @@ pub struct LlgConstraint {
     local_error: Option<String>,
     last_logs: String,
     constraint: Option<Constraint>,
+    last_commit_result: CommitResult,
 }
 
 #[repr(C)]
@@ -176,6 +177,21 @@ pub struct LlgCommitResult {
     pub is_stop: bool,
 }
 
+impl LlgCommitResult {
+    pub fn from_commit_result(r: &CommitResult) -> Self {
+        let len = r.ff_tokens.len() as u32;
+        LlgCommitResult {
+            tokens: if len == 0 {
+                std::ptr::null()
+            } else {
+                r.ff_tokens.as_ptr()
+            },
+            n_tokens: len,
+            is_stop: r.stop,
+        }
+    }
+}
+
 fn new_constraint(init: &LlgConstraintInit, grammar_json: *const c_char) -> Result<Constraint> {
     let grammar_json = unsafe { CStr::from_ptr(grammar_json) }
         .to_str()
@@ -250,6 +266,7 @@ pub extern "C" fn llg_new_constraint(
         local_error: None,
         constraint: None,
         last_logs: "\x00".to_string(),
+        last_commit_result: CommitResult::default(),
     };
 
     match new_constraint(init, grammar_json) {
@@ -312,19 +329,9 @@ pub extern "C" fn llg_commit_token(
         };
         match constraint.commit_token(token) {
             Ok(r) => {
-                let res = if let Some(s) = r.unconditional_splice() {
-                    LlgCommitResult {
-                        tokens: s.ff_tokens.as_ptr(),
-                        n_tokens: s.ff_tokens.len() as u32,
-                        is_stop: r.is_stop(),
-                    }
-                } else {
-                    LlgCommitResult {
-                        tokens: std::ptr::null(),
-                        n_tokens: 0,
-                        is_stop: r.is_stop(),
-                    }
-                };
+                // store it, so it survives until the next call to llg_*()
+                cc.last_commit_result = r;
+                let res = LlgCommitResult::from_commit_result(&cc.last_commit_result);
                 unsafe { *res_p = res };
             }
             Err(e) => cc.set_error(&e.to_string()),
diff --git a/parser/src/lib.rs b/parser/src/lib.rs
index 2b505152..79b00470 100644
--- a/parser/src/lib.rs
+++ b/parser/src/lib.rs
@@ -8,7 +8,7 @@ pub mod output;
 pub use toktrie;
 
 mod constraint;
-pub use constraint::Constraint;
+pub use constraint::{CommitResult, Constraint};
 
 mod logging;
 pub use logging::Logger;
diff --git a/rust/src/py.rs b/rust/src/py.rs
index fbc02aa7..7dfb05a8 100644
--- a/rust/src/py.rs
+++ b/rust/src/py.rs
@@ -107,13 +107,12 @@ impl LLInterpreter {
 
     fn advance_parser(&mut self, sampled_token: Option<TokenId>) -> PyResult<(u32, Vec<TokenId>)> {
         let pres = self.inner.commit_token(sampled_token).map_err(val_error)?;
-        if pres.is_stop() {
+        if pres.stop {
             // let the next mid_process() call handle it
             return Ok((0, vec![]));
         }
 
-        let splice = pres.unconditional_splice().unwrap();
-        Ok((splice.backtrack, splice.ff_tokens.clone()))
+        Ok((pres.backtrack, pres.ff_tokens))
     }
 
     fn post_process(&mut self, sampled_token: Option<TokenId>) -> PyResult<(u32, Vec<TokenId>)> {
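
Usage sketch (reviewer note, not part of the patch): with commit_token() now
returning plain data instead of &StepResult, a decode loop no longer needs to
inspect unconditional_splice() itself. A minimal sketch, assuming the crate is
consumed as llguidance_parser, that the mask-computing method is named
compute_mask() (the doc comment above only says it "computes token sampling
mask"), and that StepResult/TokenId come from toktrie; decode_loop and the
sample callback are hypothetical:

    use anyhow::Result;
    use llguidance_parser::{CommitResult, Constraint};
    use toktrie::{StepResult, TokenId};

    // Hypothetical decode loop: `sample` stands in for the engine's
    // masked-sampling step; it may return None when nothing was sampled.
    fn decode_loop(
        constraint: &mut Constraint,
        mut sample: impl FnMut(&StepResult) -> Option<TokenId>,
        output: &mut Vec<TokenId>,
    ) -> Result<()> {
        loop {
            let mask = constraint.compute_mask()?; // assumed method name
            let sampled = sample(mask);
            let res: CommitResult = constraint.commit_token(sampled)?;
            if res.stop {
                return Ok(());
            }
            // CommitResult flattens the splice into plain data: drop
            // `backtrack` tokens from the tail, then append the forced tokens.
            let keep = output.len().saturating_sub(res.backtrack as usize);
            output.truncate(keep);
            output.extend_from_slice(&res.ff_tokens);
        }
    }

This mirrors what the rust/src/py.rs hunk now does: check pres.stop, then use
(pres.backtrack, pres.ff_tokens) directly.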