From 48bcba44c1c91c1f8df0e735c127f868aa5d59d9 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Fri, 4 Oct 2024 17:19:40 +0000
Subject: [PATCH] allow intersection/negation in lexemes (proper relevance checks)

---
 parser/Cargo.toml              |  4 +--
 parser/src/earley/lexer.rs     |  6 ++--
 parser/src/earley/lexerspec.rs |  5 ++-
 parser/src/earley/parser.rs    |  4 +--
 parser/src/earley/regexvec.rs  | 58 ++++++++++++++++++++++++++++++----
 rust/Cargo.lock                |  4 +--
 sample_parser/Cargo.toml       |  2 +-
 7 files changed, 66 insertions(+), 17 deletions(-)

diff --git a/parser/Cargo.toml b/parser/Cargo.toml
index cd8ca503..cd952673 100644
--- a/parser/Cargo.toml
+++ b/parser/Cargo.toml
@@ -4,8 +4,8 @@ version = "0.2.0"
 edition = "2021"
 
 [dependencies]
-toktrie = { git = "https://github.com/microsoft/toktrie", rev = "8828701d3b1c743472fe61bdf6dab12cdd726ab4" }
-derivre = { git = "https://github.com/microsoft/derivre", rev = "424ec3bd1f711add6aeab1711108b63abe856d0c" }
+toktrie = { git = "https://github.com/microsoft/toktrie", rev = "5e7013ad05081e918809d4ecebb33db7c4aabc69" }
+derivre = { git = "https://github.com/microsoft/derivre", rev = "02ee497e6e404a0b402b4f68a9abf599d22ed2ed" }
 serde = { version = "1.0.192", features = ["derive"] }
 serde_json = { version = "1.0.108", features = ["preserve_order"] }
 anyhow = "1.0.75"
diff --git a/parser/src/earley/lexer.rs b/parser/src/earley/lexer.rs
index cd867fa5..00f8e1ba 100644
--- a/parser/src/earley/lexer.rs
+++ b/parser/src/earley/lexer.rs
@@ -2,6 +2,8 @@ use anyhow::Result;
 use std::fmt::Debug;
 use toktrie::SimpleVob;
 
+use crate::api::ParserLimits;
+
 use super::{
     lexerspec::{LexemeIdx, LexerSpec},
     regexvec::{NextByte, RegexVec, StateDesc},
@@ -41,8 +43,8 @@ pub enum LexerResult {
 }
 
 impl Lexer {
-    pub fn from(spec: &LexerSpec) -> Result<Self> {
-        let dfa = spec.to_regex_vec();
+    pub fn from(spec: &LexerSpec, limits: &mut ParserLimits) -> Result<Self> {
+        let dfa = spec.to_regex_vec(limits)?;
 
         debug!("lexer: {:?}\n ==> dfa: {:?}", spec, dfa);
 
diff --git a/parser/src/earley/lexerspec.rs b/parser/src/earley/lexerspec.rs
index 43c2006d..c40a94fe 100644
--- a/parser/src/earley/lexerspec.rs
+++ b/parser/src/earley/lexerspec.rs
@@ -3,6 +3,8 @@ use derivre::{ExprRef, JsonQuoteOptions, RegexAst, RegexBuilder};
 use std::{fmt::Debug, hash::Hash};
 use toktrie::{bytes::limit_str, SimpleVob};
 
+use crate::api::ParserLimits;
+
 use super::regexvec::RegexVec;
 
 #[derive(Clone)]
@@ -115,7 +117,7 @@ impl LexerSpec {
             .is_nullable(self.lexemes[idx.0].compiled_rx)
     }
 
-    pub fn to_regex_vec(&self) -> RegexVec {
+    pub fn to_regex_vec(&self, limits: &mut ParserLimits) -> Result<RegexVec> {
         // TODO
         // Find all non-contextual lexemes that are literals (we call them 'keywords')
         // This assumes that this is the only possible conflict in the lexer that we want to catch.
@@ -127,6 +129,7 @@ impl LexerSpec {
             self.regex_builder.exprset(),
             &rx_list,
             Some(self.lazy_lexemes()),
+            limits,
         )
     }
 
diff --git a/parser/src/earley/parser.rs b/parser/src/earley/parser.rs
index 514167b1..c6f11ab6 100644
--- a/parser/src/earley/parser.rs
+++ b/parser/src/earley/parser.rs
@@ -378,10 +378,10 @@ impl ParserState {
     fn new(
         grammar: Arc<CGrammar>,
         options: GenGrammarOptions,
-        limits: ParserLimits,
+        mut limits: ParserLimits,
     ) -> Result<(Self, Lexer)> {
         let start = grammar.start();
-        let mut lexer = Lexer::from(grammar.lexer_spec())?;
+        let mut lexer = Lexer::from(grammar.lexer_spec(), &mut limits)?;
         let scratch = Scratch::new(Arc::clone(&grammar));
         let lexer_state = lexer.a_dead_state(); // placeholder
         let mut r = ParserState {
diff --git a/parser/src/earley/regexvec.rs b/parser/src/earley/regexvec.rs
index b91ae567..e8f6533c 100644
--- a/parser/src/earley/regexvec.rs
+++ b/parser/src/earley/regexvec.rs
@@ -1,10 +1,12 @@
-use anyhow::Result;
+use anyhow::{bail, Result};
 use derivre::raw::{DerivCache, ExprSet, NextByteCache, RelevanceCache, VecHashCons};
 use std::{fmt::Debug, u64};
 use toktrie::SimpleVob;
 
 pub use derivre::{AlphabetInfo, ExprRef, NextByte, StateID};
 
+use crate::api::ParserLimits;
+
 #[derive(Clone)]
 pub struct RegexVec {
     exprs: ExprSet,
@@ -59,7 +61,10 @@ impl RegexVec {
     pub fn initial_state(&mut self, selected: &SimpleVob) -> StateID {
         let mut vec_desc = vec![];
         for idx in selected.iter() {
-            Self::push_rx(&mut vec_desc, idx as usize, self.rx_list[idx as usize]);
+            let rx = self.rx_list[idx as usize];
+            if rx != ExprRef::NO_MATCH {
+                Self::push_rx(&mut vec_desc, idx as usize, rx);
+            }
         }
         self.insert_state(vec_desc)
     }
@@ -333,16 +338,41 @@ impl RegexVec {
         exprset: &ExprSet,
         rx_list: &[ExprRef],
         lazy: Option<SimpleVob>,
-    ) -> Self {
-        let (alpha, exprset, rx_list) = AlphabetInfo::from_exprset(exprset, rx_list);
+        limits: &mut ParserLimits,
+    ) -> Result<Self> {
+        let (alpha, mut exprset, mut rx_list) = AlphabetInfo::from_exprset(exprset, rx_list);
         let num_ast_nodes = exprset.len();
-        let rx_sets = StateID::new_hash_cons();
+        let fuel0 = limits.initial_lexer_fuel;
+        let mut relevance = RelevanceCache::new();
+        for idx in 0..rx_list.len() {
+            let c0 = exprset.cost();
+            match relevance.is_non_empty_limited(
+                &mut exprset,
+                rx_list[idx],
+                limits.initial_lexer_fuel,
+            ) {
+                Ok(true) => {}
+                Ok(false) => {
+                    rx_list[idx] = ExprRef::NO_MATCH;
+                }
+                Err(_) => {
+                    bail!(
+                        "fuel exhausted when checking relevance of lexemes ({})",
+                        fuel0
+                    );
+                }
+            }
+            limits.initial_lexer_fuel = limits
+                .initial_lexer_fuel
+                .saturating_sub(exprset.cost() - c0);
+        }
+        let rx_sets = StateID::new_hash_cons();
         let mut r = RegexVec {
             deriv: DerivCache::new(),
             next_byte: NextByteCache::new(),
-            relevance: RelevanceCache::new(),
+            relevance,
             lazy: lazy.unwrap_or_else(|| SimpleVob::alloc(rx_list.len())),
             exprs: exprset,
             alpha,
@@ -364,7 +394,7 @@ impl RegexVec {
         // in fact, transition from MISSING and DEAD should both lead to DEAD
         r.state_table.fill(StateID::DEAD);
         assert!(r.alpha.len() > 0);
-        r
+        Ok(r)
     }
 
     fn append_state(&mut self, state_desc: StateDesc) {
@@ -439,6 +469,20 @@ impl RegexVec {
         for (idx, e) in iter_state(&self.rx_sets, state) {
             let d = self.deriv.derivative(&mut self.exprs, e, b);
+
+            let fuel = self.fuel.saturating_sub(self.exprs.cost() - c0);
+            let d = match self
+                .relevance
+                .is_non_empty_limited(&mut self.exprs, d, fuel)
+            {
+                Ok(true) => d,
+                Ok(false) => ExprRef::NO_MATCH,
+                Err(_) => {
+                    self.fuel = 0; // just in case
+                    break;
+                }
+            };
+
             state_size += 1;
             if d != ExprRef::NO_MATCH {
                 Self::push_rx(&mut vec_desc, idx, d);
diff --git a/rust/Cargo.lock b/rust/Cargo.lock
index 3a5006d5..3ff22b68 100644
--- a/rust/Cargo.lock
+++ b/rust/Cargo.lock
@@ -218,7 +218,7 @@ dependencies = [
 [[package]]
 name = "derivre"
 version = "0.1.0"
-source = "git+https://github.com/microsoft/derivre?rev=424ec3bd1f711add6aeab1711108b63abe856d0c#424ec3bd1f711add6aeab1711108b63abe856d0c"
+source = "git+https://github.com/microsoft/derivre?rev=02ee497e6e404a0b402b4f68a9abf599d22ed2ed#02ee497e6e404a0b402b4f68a9abf599d22ed2ed"
 dependencies = [
  "ahash",
  "anyhow",
@@ -908,7 +908,7 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
 [[package]]
 name = "toktrie"
 version = "0.1.0"
-source = "git+https://github.com/microsoft/toktrie?rev=8828701d3b1c743472fe61bdf6dab12cdd726ab4#8828701d3b1c743472fe61bdf6dab12cdd726ab4"
+source = "git+https://github.com/microsoft/toktrie?rev=5e7013ad05081e918809d4ecebb33db7c4aabc69#5e7013ad05081e918809d4ecebb33db7c4aabc69"
 dependencies = [
  "anyhow",
  "bytemuck",
diff --git a/sample_parser/Cargo.toml b/sample_parser/Cargo.toml
index faf5fae3..186c3e83 100644
--- a/sample_parser/Cargo.toml
+++ b/sample_parser/Cargo.toml
@@ -6,7 +6,7 @@ default-run = "sample_parser"
 
 [dependencies]
 llguidance_parser = { path = "../parser" }
-toktrie_hf_tokenizers = { git = "https://github.com/microsoft/toktrie", rev = "8828701d3b1c743472fe61bdf6dab12cdd726ab4" }
+toktrie_hf_tokenizers = { git = "https://github.com/microsoft/toktrie", rev = "5e7013ad05081e918809d4ecebb33db7c4aabc69" }
 serde_json = "1.0.128"
 anyhow = "1.0.87"
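
Both new relevance checks in regexvec.rs follow one fuel-accounting pattern:
snapshot `exprset.cost()` before the check, pass the remaining budget as the
limit of `is_non_empty_limited`, then charge the cost delta against the budget
with `saturating_sub`. At construction time an exhausted budget is a hard
`bail!` (an undecided lexeme would be unsafe to keep), while during state
transitions it only kills the state being built. The sketch below is a
minimal, self-contained illustration of that pattern, not part of the patch;
`FuelMeter`, `checked_relevance`, and `prune_empty` are hypothetical names
standing in for the derivre machinery:

    use anyhow::{bail, Result};

    struct FuelMeter {
        fuel: u64,
    }

    impl FuelMeter {
        // Charge the cost incurred by one check, saturating at zero,
        // like `limits.initial_lexer_fuel.saturating_sub(cost - c0)`.
        fn charge(&mut self, cost_before: u64, cost_after: u64) {
            self.fuel = self.fuel.saturating_sub(cost_after - cost_before);
        }
    }

    // Stand-in for `RelevanceCache::is_non_empty_limited`: Ok(true) if the
    // regex can match something, Ok(false) if it is provably empty (e.g. an
    // intersection like `A & ~A`), Err(()) if fuel ran out mid-check.
    fn checked_relevance(work: u64, fuel: u64) -> Result<bool, ()> {
        if work > fuel {
            Err(())
        } else {
            Ok(work != 0) // toy criterion; the real check walks derivatives
        }
    }

    // Construction-time pass mirroring the loop added above: prune empty
    // lexemes, hard-fail on fuel exhaustion.
    fn prune_empty(lexeme_work: &[u64], meter: &mut FuelMeter) -> Result<Vec<bool>> {
        let fuel0 = meter.fuel;
        let mut alive = Vec::with_capacity(lexeme_work.len());
        for &work in lexeme_work {
            let c0 = 0; // the real code snapshots exprset.cost() here
            match checked_relevance(work, meter.fuel) {
                Ok(keep) => alive.push(keep), // false ~ ExprRef::NO_MATCH
                Err(_) => bail!(
                    "fuel exhausted when checking relevance of lexemes ({})",
                    fuel0
                ),
            }
            meter.charge(c0, work);
        }
        Ok(alive)
    }

Pruning provably-empty lexemes to `ExprRef::NO_MATCH` up front is what makes
the new filter in `initial_state` effective: lexemes that can never match are
skipped when the start state is assembled, so intersection/negation patterns
that reduce to the empty language cost nothing during lexing.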