
Commit: re-work how tokens are validated
mmoskal committed Jan 16, 2025
1 parent 7057c3a · commit bfc5e8a
Showing 4 changed files with 86 additions and 142 deletions.
parser/src/earley/parser.rs: 57 additions & 93 deletions
@@ -16,7 +16,7 @@ use derivre::{AlphabetInfo, RegexAst, StateID};
 use hashbrown::HashSet;
 use instant::Instant;
 use serde::{Deserialize, Serialize};
-use toktrie::{parse_numeric_token, Recognizer, SimpleVob, TokEnv, TokTrie, INVALID_TOKEN};
+use toktrie::{Recognizer, SimpleVob, TokEnv, TokTrie, TokenId, INVALID_TOKEN};
 
 use crate::{
     api::{ParserLimits, StopReason},
@@ -845,103 +845,67 @@ impl ParserState {
         }
     }
 
-    pub fn validate_bytes(&mut self, tok_bytes: &[u8], check_eos: bool) -> usize {
+    pub fn validate_tokens(&mut self, tokens: &[TokenId]) -> usize {
         self.assert_definitive();
-        let applied_idx = self.byte_to_token_idx.len();
-        let mut prefix_len = 0;
-        let tok_bytes = if applied_idx < self.bytes.len() {
-            let bytes_left = self.bytes.len() - applied_idx;
-            prefix_len = std::cmp::min(tok_bytes.len(), bytes_left);
-            if self.bytes[applied_idx..applied_idx + prefix_len] != tok_bytes[..prefix_len] {
-                // find common prefix
-                let mut i = 0;
-                while i < prefix_len && self.bytes[applied_idx + i] == tok_bytes[i] {
-                    i += 1;
-                }
-                return i;
-            }
-            // there are still pending bytes after applying tok_bytes
-            // do not check for eos
-            if bytes_left > prefix_len {
-                return prefix_len;
-            } else {
-                // otherwise, process the remaining bytes (could be 0)
-                // as speculative
-                &tok_bytes[prefix_len..]
-            }
-        } else {
-            tok_bytes
-        };
-
-        // fast path
-        if tok_bytes.is_empty() && !check_eos {
-            return prefix_len;
-        }
-
-        self.run_speculative("validate_bytes", |state| {
-            let mut r = ParserRecognizer { state };
-            let mut idx = 0;
-            while idx < tok_bytes.len() {
-                let b = tok_bytes[idx];
-                if b == TokTrie::SPECIAL_TOKEN_MARKER {
-                    if !r.state.flush_lexer() {
-                        break;
-                    }
-                    if let Some((n_bytes, token_id)) = parse_numeric_token(&tok_bytes[(idx + 1)..])
-                    {
-                        if r.state
-                            .token_range_lexemes()
-                            .iter()
-                            .any(|r| r.contains_token(token_id))
-                        {
-                            for b in &tok_bytes[idx..(idx + n_bytes + 1)] {
-                                let ok = r.try_push_byte(*b);
-                                assert!(ok);
-                            }
-                            idx += n_bytes + 1;
-                            continue;
-                        }
-                    }
-                    // if we failed to account for the whole token, stop
-                    break;
-                }
-                if !r.try_push_byte(b) {
-                    if r.state.flush_lexer() {
-                        let range_specs = r.state.token_range_lexemes();
-                        if range_specs.len() > 0 {
-                            let toks = r.state.tok_env.tok_trie().all_prefixes(&tok_bytes[idx..]);
-                            let mut found_numeric = false;
-                            for spec in range_specs {
-                                if let Some(&t) = toks.iter().find(|t| spec.contains_token(**t)) {
-                                    let n_bytes = r.state.tok_env.tok_trie().token(t).len();
-                                    let sidx = spec.idx;
-                                    let ok = r
-                                        .state
-                                        .add_numeric_token(sidx, &tok_bytes[idx..(idx + n_bytes)]);
-                                    if ok.is_ok() {
-                                        idx += n_bytes;
-                                        found_numeric = true;
-                                        break;
-                                    }
-                                    break;
-                                }
-                            }
-                            if found_numeric {
-                                continue;
-                            }
-                        }
-                    }
-                    return prefix_len + idx;
-                }
-                idx += 1;
-            }
-            prefix_len += idx;
-            if check_eos {
-                if state.is_accepting_inner() {
-                    prefix_len += 1;
-                }
-            }
-            prefix_len
+        self.run_speculative("validate_tokens", |state| {
+            let mut applied_idx = state.byte_to_token_idx.len();
+            let eos = state.tok_env.tok_trie().eos_token();
+            let mut r = ParserRecognizer { state };
+            'token: for (tidx, &tok) in tokens.iter().enumerate() {
+                if tok == eos {
+                    if r.state.is_accepting_inner() {
+                        return tidx + 1;
+                    } else {
+                        return tidx;
+                    }
+                }
+
+                let token_bytes = r.state.tok_env.tok_trie().decode_raw(&[tok]);
+
+                'byte: for (bidx, &b) in token_bytes.iter().enumerate() {
+                    if applied_idx < r.state.bytes.len() {
+                        if r.state.bytes[applied_idx] == b {
+                            applied_idx += 1;
+                        } else {
+                            return tidx;
+                        }
+                    } else {
+                        // never push FF
+                        if b != TokTrie::SPECIAL_TOKEN_MARKER && r.try_push_byte(b) {
+                            // normal path
+                            continue 'byte;
+                        }
+
+                        if bidx != 0 {
+                            // not at the start of the token, bail
+                            return tidx;
+                        }
+
+                        if !r.state.flush_lexer() {
+                            // we need to flush lexer before checking for special/numeric tokens
+                            return tidx;
+                        }
+
+                        for spec in r.state.token_range_lexemes() {
+                            if spec.contains_token(tok) {
+                                let numeric_bytes =
+                                    r.state.tok_env.tok_trie().decode_as_special(tok);
+                                let ok = r.state.add_numeric_token(spec.idx, &numeric_bytes);
+                                if ok.is_ok() {
+                                    continue 'token;
+                                } else {
+                                    unreachable!(); // ???
+                                }
+                            }
+                        }
+
+                        // didn't find numeric token
+                        return tidx;
+                    }
+                }
+            }
+
+            tokens.len() // all ok!
         })
     }
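
Read as pseudocode, the reworked loop validates one token at a time: each token's bytes must first reproduce any bytes already pending in the parser's buffer, and only the remainder is pushed speculatively; a special/numeric token is accepted whole when one of the active token-range lexemes contains it. A minimal standalone sketch of the byte-matching core, with hypothetical names and a closure standing in for ParserRecognizer::try_push_byte (the special-token fallback is omitted):

    // Sketch only, not part of this commit.
    fn count_valid_tokens(
        pending: &[u8],                       // bytes already in the parser buffer
        tokens: &[Vec<u8>],                   // decoded bytes of each candidate token
        mut try_push: impl FnMut(u8) -> bool, // speculative byte push into the parser
    ) -> usize {
        let mut applied_idx = 0;
        for (tidx, tok_bytes) in tokens.iter().enumerate() {
            for &b in tok_bytes {
                if applied_idx < pending.len() {
                    // the token must reproduce bytes the parser already holds
                    if pending[applied_idx] == b {
                        applied_idx += 1;
                    } else {
                        return tidx; // mismatch; only tokens before tidx are valid
                    }
                } else if !try_push(b) {
                    return tidx; // parser rejected this byte
                }
            }
        }
        tokens.len() // all tokens accounted for
    }

For example, with pending bytes b"fo", the token sequence ["fo", "o"] validates fully, while ["fx"] stops at 0.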

@@ -2445,15 +2409,15 @@ impl Parser {
         r
     }
 
-    /// Returns how many bytes can be applied.
-    pub fn validate_bytes(&mut self, tok_bytes: &[u8], check_eos: bool) -> usize {
+    /// Returns how many tokens can be applied.
+    pub fn validate_tokens(&mut self, tokens: &[TokenId]) -> usize {
         self.with_shared(|state| {
-            let r = state.validate_bytes(tok_bytes, check_eos);
+            let r = state.validate_tokens(tokens);
             debug!(
-                "validate_bytes: {:?} -> {}/{}",
-                String::from_utf8_lossy(tok_bytes),
+                "validate_tokens: {} -> {}/{}",
+                state.tok_env.tok_trie().tokens_dbg(tokens),
                 r,
-                tok_bytes.len()
+                tokens.len()
             );
             r
         })
parser/src/tokenparser.rs: 3 additions & 49 deletions
@@ -284,14 +284,7 @@ impl TokenParser {
 
     pub fn validate_token(&mut self, token: TokenId) -> Result<bool> {
         self.check_initialized("validate_token")?;
-        if token == self.eos_token {
-            Ok(self.parser.validate_bytes(&[], true) > 0)
-        } else {
-            let bytes = self.tok_trie().decode_raw(&[token]);
-            let n_valid = self.parser.validate_bytes(&bytes, false);
-            assert!(n_valid <= bytes.len());
-            Ok(n_valid == bytes.len())
-        }
+        self.validate_tokens_raw(&[token]).map(|n| n > 0)
     }
 
     /// Returns how many of the passed tokens can be accepted by the parser.
@@ -304,47 +297,8 @@ impl TokenParser {
             return Ok(0);
         }
 
-        if tokens.len() == 1 {
-            return if self.validate_token(tokens[0])? {
-                Ok(1)
-            } else {
-                Ok(0)
-            };
-        }
-
-        let mut final_eos = false;
-        let tokens = if tokens.last() == Some(&self.eos_token) {
-            final_eos = true;
-            &tokens[..tokens.len() - 1]
-        } else {
-            tokens
-        };
-
-        let bytes = self.tok_trie().decode_raw(tokens);
-        let n_valid = self.parser.validate_bytes(&bytes, final_eos);
-
-        if final_eos && n_valid == bytes.len() + 1 {
-            return Ok(tokens.len() + 1);
-        }
-
-        assert!(n_valid <= bytes.len());
-
-        // fast paths
-        if n_valid == bytes.len() {
-            return Ok(tokens.len());
-        }
-        if n_valid == 0 {
-            return Ok(0);
-        }
-
-        let mut byte_ptr = 0;
-        for (token_ptr, tok) in tokens.iter().enumerate() {
-            byte_ptr += self.tok_trie().token_len(*tok);
-            if byte_ptr > n_valid {
-                return Ok(token_ptr);
-            }
-        }
-        Ok(tokens.len())
+        let n_valid = self.parser.validate_tokens(tokens);
+        Ok(n_valid)
     }
 
     fn anyhow_error(&self) -> anyhow::Error {
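
With the byte-level bookkeeping now inside the parser, the token-parser side reduces to a single call. A hypothetical call site (the tokenizer and parser setup is assumed, not part of this diff):

    // Sketch only: `tok_env` and `token_parser` are assumed to be set up elsewhere.
    let tokens: Vec<TokenId> = tok_env.tokenize("zott");
    let n_valid = token_parser.validate_tokens_raw(&tokens)?;
    let all_ok = n_valid == tokens.len(); // every token fit the grammar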
sample_parser/tests/test_ll.rs: 18 additions
@@ -691,4 +691,22 @@ fn test_ll_numeric_token_for_text() {
         "#,
         &["", "✖<|assistant|>✖f‧foo‧✖bar‧long‧✖<|system|>‧cat‧≺EOS≻"],
     );
+
+    check_lark_grammar(
+        r#"start: f | foo | bar
+           f: <[29730]> <[105]>
+           foo: <[29730]> <[5431]>
+           bar: <[29842]>
+        "#,
+        &["", "zott‧foo"],
+    );
+
+    // check_lark_grammar(
+    //     r#"start: f | foo | bar
+    //        f: <[29730]> <[105]> <[29659]>
+    //        foo: <[29730]> <[5431]> <[29659]>
+    //        bar: <[29842]>
+    //     "#,
+    //     &["", "zott‧foo", "coded"],
+    // );
 }
toktrie/src/toktree.rs: 8 additions
@@ -491,6 +491,14 @@ impl TokTrie {
         bytes
     }
 
+    pub fn decode_as_special(&self, tok: TokenId) -> Vec<u8> {
+        let mut res = Vec::new();
+        res.reserve(9);
+        res.push(TokTrie::SPECIAL_TOKEN_MARKER);
+        res.extend_from_slice(format!("[{}]", tok).as_bytes());
+        res
+    }
+
     pub fn decode_raw(&self, tokens: &[TokenId]) -> Vec<u8> {
         let mut res = Vec::new();
         res.reserve(tokens.len() * 6 + 32); // approximately
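
The new decode_as_special helper renders a token id in the same marker form that decode_raw uses for special tokens, which is what the numeric-token path in validate_tokens feeds to add_numeric_token. A sketch of the expected output, assuming SPECIAL_TOKEN_MARKER is the 0xFF byte referenced by the "never push FF" comment above:

    // Sketch only: `trie: &TokTrie` is assumed to be in scope.
    let bytes = trie.decode_as_special(29730);
    assert_eq!(bytes[0], TokTrie::SPECIAL_TOKEN_MARKER); // leading marker byte
    assert_eq!(&bytes[1..], b"[29730]");                 // token id in brackets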
