From 331737f9dc17de31fdf4f1664b0a92eb952b7083 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Sun, 5 Jan 2025 22:18:09 +0100
Subject: [PATCH] use parser factory in test_ll

---
 sample_parser/src/lib.rs | 51 ++++++++++++++++++++--------------------
 1 file changed, 25 insertions(+), 26 deletions(-)

diff --git a/sample_parser/src/lib.rs b/sample_parser/src/lib.rs
index 8fd9441..b9fb404 100644
--- a/sample_parser/src/lib.rs
+++ b/sample_parser/src/lib.rs
@@ -1,8 +1,9 @@
 use lazy_static::lazy_static;
 use llguidance::{
-    api::{GrammarWithLexer, ParserLimits, TopLevelGrammar},
+    api::{GrammarWithLexer, TopLevelGrammar},
+    earley::SlicedBiasComputer,
     toktrie::{InferenceCapabilities, TokEnv, TokenId},
-    Constraint, TokenParser,
+    Constraint, ParserFactory,
 };
 
 /// Check that the grammar generates the expected output.
@@ -17,29 +18,17 @@ use llguidance::{
 /// These tests are "recorded" by passing "test_trace": true in the llguidance
 /// request and post-processing.
 fn check_grammar(
-    tok_env: &TokEnv,
+    factory: &ParserFactory,
     prompt_str: &str,
     grammar: TopLevelGrammar,
     output: &[&str],
     temp: f32,
 ) -> Constraint {
-    let parser = TokenParser::from_llguidance_json(
-        tok_env.clone(),
-        grammar,
-        llguidance::Logger::new(0, 2),
-        InferenceCapabilities {
-            ff_tokens: true,  // can the engine append multiple tokens?
-            backtrack: true,  // can the engine remove generated tokens?
-
-            conditional_ff_tokens: false, // not used
-            fork: false,                  // not used
-        },
-        ParserLimits::default(),
-        vec![],
-    )
-    .unwrap();
+    let parser = factory.create_parser(grammar).unwrap();
     let mut constraint = Constraint::new(parser);
 
+    let tok_env = factory.tok_env();
+
     let prompt = constraint.process_prompt(tok_env.tokenize(prompt_str));
     check_eq(tok_env, "prompt", &prompt, output[0]);
 
@@ -192,17 +181,27 @@ fn tokenize_trace(tok_env: &TokEnv, s: &str) -> Vec {
 }
 
 lazy_static! {
-    static ref TOK_ENV: TokEnv = {
-        toktrie_hf_tokenizers::ByteTokenizerEnv::from_name("microsoft/Phi-3.5-mini-instruct", None)
+    static ref PARSER_FACTORY: ParserFactory = {
+        let env = toktrie_hf_tokenizers::ByteTokenizerEnv::from_name("microsoft/Phi-3.5-mini-instruct", None)
             .unwrap()
-            .to_env()
+            .to_env();
+        let mut fact = ParserFactory::new(&env,
+            InferenceCapabilities {
+                ff_tokens: true,  // can the engine append multiple tokens?
+                backtrack: true,  // can the engine remove generated tokens?
+                conditional_ff_tokens: false, // not used
+                fork: false,                  // not used
+            }, &SlicedBiasComputer::general_slices()).unwrap();
+        fact.set_stderr_log_level(2);
+        fact.set_buffer_log_level(0);
+        fact
     };
 }
 
 pub fn check_lark_grammar_prompt(lark: &str, prompt_str: &str, output: &[&str]) -> Constraint {
     let grm = TopLevelGrammar::from_lark(lark.to_string());
     println!("\nChecking grammar:\n{}\nagainst: {:?}", lark, output);
-    check_grammar(&TOK_ENV, prompt_str, grm, output, 0.0)
+    check_grammar(&PARSER_FACTORY, prompt_str, grm, output, 0.0)
 }
 
 pub fn check_lark_grammar(lark: &str, output: &[&str]) -> Constraint {
@@ -231,7 +230,7 @@ pub fn check_lark_grammar_nested(lark: &str, sub_lark: &str, output: &[&str]) ->
         "\nChecking nested grammars:\n{}\nNested:\n{}\nagainst: {:?}",
         lark, sub_lark, output
     );
-    check_grammar(&TOK_ENV, "", top_grm, output, temp)
+    check_grammar(&PARSER_FACTORY, "", top_grm, output, temp)
 }
 
 pub fn check_lark_json(lark: &str, json_schema: serde_json::Value, output: &[&str]) -> Constraint {
@@ -244,7 +243,7 @@ pub fn check_lark_json(lark: &str, json_schema: serde_json::Value, output: &[&st
         "\nChecking lark+json:\n{}\nNested:\n{}\nagainst: {:?}",
         lark, schema_str, output
     );
-    check_grammar(&TOK_ENV, "", top_grm, output, 0.0)
+    check_grammar(&PARSER_FACTORY, "", top_grm, output, 0.0)
 }
 
 pub fn check_capture(c: &Constraint, name: &str, expected: &str) {
@@ -257,7 +256,7 @@ pub fn check_capture(c: &Constraint, name: &str, expected: &str) {
 }
 
 pub fn print_tokenized(s: &str) {
-    let trie = TOK_ENV.tok_trie();
-    let tokens = TOK_ENV.tokenize(s);
+    let trie = PARSER_FACTORY.tok_env().tok_trie();
+    let tokens = PARSER_FACTORY.tok_env().tokenize(s);
     println!("{:?}", trie.test_trace_tokens(&tokens));
 }