Skip to content

Commit

Permalink
bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Nov 26, 2024
1 parent 498ac9c commit 6eaa15f
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 15 deletions.
10 changes: 9 additions & 1 deletion parser/src/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,13 @@ impl Constraint {
/// It only returns 'STOP' if previous compute_mask() already returned 'STOP'
/// (in which case there's little point calling commit_token()).
pub fn commit_token(&mut self, sampled_token: Option<TokenId>) -> Result<CommitResult> {
loginfo!(self.parser.logger, "\ncommit_token({:?})", sampled_token);
loginfo!(
self.parser.logger,
"\ncommit_token({})",
sampled_token
.map(|t| self.parser.token_env.tok_trie().token_dbg(t))
.unwrap_or("None".to_string())
);

// ensure!(
// self.step_arg.is_none(),
Expand Down Expand Up @@ -199,6 +205,7 @@ impl Constraint {
let mut bt = self.parser.consume_token(t)?;
let mut tokens = vec![t];
if bt > 0 {
loginfo!(self.parser.logger, "backtrack sampled");
tokens.clear();
bt -= 1;
}
Expand All @@ -207,6 +214,7 @@ impl Constraint {
}

if self.parser.check_stop()? {
loginfo!(self.parser.logger, "set pending stop");
self.pending_stop = true;
}

Expand Down
21 changes: 13 additions & 8 deletions parser/src/tokenparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,16 +551,21 @@ impl TokenParser {
if self.parser.scan_eos() {
// it got scanned correctly, so we remove it
infoln!(self, "scanned eos_token");
if self.inference_caps.backtrack {
return Ok(1);
} else {
warn!(self, "can't backtrack over eos_token");
// if self.inference_caps.backtrack {
// return Ok(1);
// } else {
// warn!(self, "can't backtrack over eos_token");
// return Ok(0);
// }
// don't backtrack it for now, fails tests
return Ok(0);
} else {
let accepting = self.is_accepting();
infoln!(self, "didn't scan eos_token; accept={}", accepting);
if accepting {
self.llm_tokens.push(token);
return Ok(0);
}
} else {
infoln!(self, "didn't scan eos_token; saving");
// TODO this will probably fail to apply, which will cause
// the parser to stop, which is approx. correct but ugly
}
}

Expand Down
4 changes: 2 additions & 2 deletions sample_parser/data/rfc.lark
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

start: root
nl: "\n"?
text: ( /[^>"<&]/ | "&amp;"| "&lt;"| "&gt;"| "&quot;"| "&apos;"| "&#x" /[0-9a-fA-F]+/ ";"
TEXT: ( /[^>"<&]/ | "&amp;"| "&lt;"| "&gt;"| "&quot;"| "&apos;"| "&#x" /[0-9a-fA-F]+/ ";"
| "&#" /[0-9]+/ ";")*
// text: TEXT
text: TEXT
root: element1
element1: "<rfc" ( (" number=\"" text "\"") | (" obsoletes=\"" text "\"") | (" updates=\"" text "\"") | (" category=\"" text "\"") | (" mode=\"" text "\"") | (" consensus=\"" "\"") | (" seriesNo=\"" text "\"") | (" ipr=\"" text "\"") | (" iprExtract=\"" text "\"") | (" submissionType=\"" "\"") | (" docName=\"" text "\"") | (" sortRefs=\"" "\"") | (" symRefs=\"" "\"") | (" tocInclude=\"" "\"") | (" tocDepth=\"" text "\"") | (" prepTime=\"" text "\"") | (" indexInclude=\"" "\"") | (" version=\"" text "\"") | (" scripts=\"" text "\"") | (" expiresDate=\"" text "\""))* ">" nl ( ( ( ( ( "" | ( element2 )+ ) ) element3 ) element4 ) ( "" | element5 ) ) "</rfc>" nl
element2: "<link" ( (" href=\"" text "\"") | (" rel=\"" text "\""))* ">" nl "</link>" nl
Expand Down
15 changes: 11 additions & 4 deletions sample_parser/src/grammar_tester.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ fn check_grammar(
output: &[&str],
temp: f32,
) {
println!("\nChecking grammar");

let parser = TokenParser::from_llguidance_json(
tok_env.clone(),
grammar,
Expand Down Expand Up @@ -84,13 +82,17 @@ fn check_grammar(
bt = res.backtrack;
toks = res.ff_tokens.clone();
if toks.is_empty() || toks[0] != tok {
if output[idx + 1].starts_with("1↶") {
if idx + 1 < output.len() && output[idx + 1].starts_with("1↶") {
// fast-forward with fake backtrack
assert!(bt == 0 || res.ff_tokens.is_empty());
bt = 1;
// go to forced byte checking
} else {
panic!("Expected token {} got {}", tok, toks[0]);
if toks.is_empty() {
panic!("Expected {}; got nothing", tok);
} else {
panic!("Expected token {} got {}", tok, toks[0]);
}
}
} else if toks.len() > 1 {
// we got fast-forwarded to the next entry,
Expand Down Expand Up @@ -180,6 +182,7 @@ lazy_static! {

fn check_lark_grammar_prompt(lark: &str, prompt_str: &str, output: &[&str]) {
let grm = TopLevelGrammar::from_lark(lark.to_string());
println!("\nChecking grammar:\n{}\nagainst: {:?}", lark, output);
check_grammar(&TOK_ENV, prompt_str, grm, output, 0.0);
}

Expand All @@ -205,6 +208,10 @@ fn check_lark_grammar_nested(lark: &str, sub_lark: &str, output: &[&str]) {
let mut sub_grm = GrammarWithLexer::from_lark(sub_lark.to_string());
sub_grm.name = Some("sub".to_string());
top_grm.grammars.push(sub_grm);
println!(
"\nChecking nested grammars:\n{}\nNested:\n{}\nagainst: {:?}",
lark, sub_lark, output
);
check_grammar(&TOK_ENV, "", top_grm, output, temp);
}

Expand Down

0 comments on commit 6eaa15f

Please sign in to comment.