From 10abb4446b2585ffcb53f5aef268eb3740837284 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Mon, 21 Oct 2024 10:48:37 -0700
Subject: [PATCH] allow for not compiling-in json schema validation

---
 parser/src/json.rs                 |  22 ++--
 sample_parser/Cargo.toml           |   4 +
 sample_parser/run.sh               |   8 +-
 sample_parser/src/minimal.rs       | 181 +++++++++++++++++++++++++++++
 sample_parser/src/sample_parser.rs |   1 -
 sample_parser/src/schema_tester.rs |   5 +-
 6 files changed, 206 insertions(+), 15 deletions(-)
 create mode 100644 sample_parser/src/minimal.rs

diff --git a/parser/src/json.rs b/parser/src/json.rs
index ed6f2641..cd097981 100644
--- a/parser/src/json.rs
+++ b/parser/src/json.rs
@@ -16,7 +16,6 @@ use crate::{
 #[derive(Debug, Default, Clone)]
 pub struct JsonCompileOptions {
     pub compact: bool,
-    pub validate: bool,
 }
 
 fn to_compact_json(target: &serde_json::Value) -> String {
@@ -123,7 +122,14 @@ macro_rules! cache {
 impl JsonCompileOptions {
     pub fn json_to_llg(&self, schema: &Value) -> Result<TopLevelGrammar> {
         let mut compiler = Compiler::new(self.clone());
-        compiler.run(schema)?;
+        compiler.validate(schema)?;
+        compiler.execute(schema)?;
+        compiler.builder.finalize()
+    }
+
+    pub fn json_to_llg_no_validate(&self, schema: &Value) -> Result<TopLevelGrammar> {
+        let mut compiler = Compiler::new(self.clone());
+        compiler.execute(schema)?;
         compiler.builder.finalize()
     }
 }
@@ -236,13 +242,13 @@ impl Compiler {
         }
     }
 
-    pub fn run(&mut self, schema: &Value) -> Result<()> {
-        if self.options.validate {
-            SCHEMA_VALIDATOR
-                .validate(schema)
-                .map_err(|mut e| anyhow!("Invalid schema: {}", e.next().unwrap()))?;
-        }
+    pub fn validate(&mut self, schema: &Value) -> Result<()> {
+        SCHEMA_VALIDATOR
+            .validate(schema)
+            .map_err(|mut e| anyhow!("Invalid schema: {}", e.next().unwrap()))
+    }
 
+    pub fn execute(&mut self, schema: &Value) -> Result<()> {
         self.builder.add_grammar(GrammarWithLexer {
             greedy_skip_rx: if self.options.compact {
                 None
diff --git a/sample_parser/Cargo.toml b/sample_parser/Cargo.toml
index 186c3e83..12ca97a1 100644
--- a/sample_parser/Cargo.toml
+++ b/sample_parser/Cargo.toml
@@ -17,3 +17,7 @@ path = "src/sample_parser.rs"
 [[bin]]
 name = "schema_tester"
 path = "src/schema_tester.rs"
+
+[[bin]]
+name = "minimal"
+path = "src/minimal.rs"
diff --git a/sample_parser/run.sh b/sample_parser/run.sh
index b308435b..f10f1fa1 100755
--- a/sample_parser/run.sh
+++ b/sample_parser/run.sh
@@ -1,4 +1,8 @@
 #!/bin/sh
 
-# cargo run data/blog.schema.ll.json data/blog.sample.json
-cargo run data/blog.schema.json data/blog.sample.json
+cargo run data/blog.schema.ll.json data/blog.sample.json
+# cargo run data/blog.schema.json data/blog.sample.json
+# cargo run --release --bin minimal data/blog.schema.json data/blog.sample.json
+# mkdir -p tmp
+# strip -o tmp/minimal ../../target/release/minimal
+# ls -l ../../target/release/minimal tmp/minimal
diff --git a/sample_parser/src/minimal.rs b/sample_parser/src/minimal.rs
new file mode 100644
index 00000000..24eb25e3
--- /dev/null
+++ b/sample_parser/src/minimal.rs
@@ -0,0 +1,181 @@
+use std::{env, fs::File, hint::black_box, io::Read, sync::Arc, vec};
+
+use llguidance_parser::{
+    api::{ParserLimits, TopLevelGrammar},
+    toktrie::{InferenceCapabilities, TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv},
+    Constraint, JsonCompileOptions, TokenParser,
+};
+
+struct SingleByteTokenizer {
+    tok_trie: TokTrie,
+}
+
+impl SingleByteTokenizer {
+    fn new() -> Self {
+        let mut words = (0..=255).map(|x| vec![x]).collect::<Vec<Vec<u8>>>();
+        words.push("<EOS>".as_bytes().to_vec());
+        let info = TokRxInfo {
+            vocab_size: words.len() as u32,
+            tok_eos: words.len() as u32 - 1,
+            tok_bos: None,
+            tok_pad: None,
+            tok_unk: None,
+            tok_end_of_turn: None,
+        };
+        let tok_trie = TokTrie::from(&info, &words);
+        SingleByteTokenizer { tok_trie }
+    }
+
+    fn to_env(self) -> TokEnv {
+        Arc::new(self)
+    }
+}
+
+impl TokenizerEnv for SingleByteTokenizer {
+    fn stop(&self) -> ! {
+        panic!("stop called")
+    }
+
+    fn tok_trie(&self) -> &TokTrie {
+        &self.tok_trie
+    }
+
+    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
+        self.tok_trie.greedy_tokenize(s)
+    }
+}
+
+fn main() {
+    let args: Vec<String> = env::args().collect();
+    if args.len() != 3 {
+        eprintln!("Usage: {} <schema_file> <obj_file>", args[0]);
+        std::process::exit(1);
+    }
+
+    let schema_file = read_file_to_string(&args[1]);
+    let schema: TopLevelGrammar = if args[1].ends_with(".ll.json") {
+        serde_json::from_str(&schema_file).expect("Invalid JSON in schema")
+    } else if args[1].ends_with(".schema.json") {
+        let opts = JsonCompileOptions {
+            compact: false,
+        };
+        let val = serde_json::from_str(&schema_file).expect("Invalid JSON in schema");
+        opts.json_to_llg_no_validate(&val)
+            .expect("Failed to convert JSON to LLG")
+    } else {
+        panic!("Unknown schema file extension")
+    };
+    let obj_str = read_file_to_string(&args[2]);
+
+    let tok_env: TokEnv = SingleByteTokenizer::new().to_env();
+
+    let tokens = tok_env.tokenize(&obj_str);
+
+    // set to 2 for more output; 1 is warnings only
+    let stderr_log_level = 1;
+
+    // typically set to 2, to send info-level output to the user
+    let buffer_log_level = 2;
+
+    let parser = TokenParser::from_llguidance_json(
+        tok_env.clone(),
+        schema,
+        llguidance_parser::Logger::new(buffer_log_level, stderr_log_level),
+        InferenceCapabilities {
+            ff_tokens: true,  // can the engine append multiple tokens?
+            backtrack: false, // can the engine remove generated tokens?
+
+            conditional_ff_tokens: false, // not used
+            fork: false,                  // not used
+        },
+        ParserLimits::default(),
+        vec![],
+    )
+    .unwrap();
+    let mut constraint = Constraint::new(parser);
+
+    // enable sending parser results back via the logs (constraint.flush_logs())
+    constraint.log_json_progress = true;
+
+    let trie = tok_env.tok_trie();
+
+    eprintln!("Parsing tokens: {}", trie.tokens_dbg(&tokens));
+
+    let mut idx = 0;
+    while idx < tokens.len() {
+        let res = constraint.compute_mask().unwrap();
+
+        if res.is_stop() {
+            // stop sequence
+            break;
+        }
+
+        let sampled_token = if let Some(mask) = &res.sample_mask {
+            // Simulate sampling - it should use the mask and temperature
+            black_box(mask);
+            black_box(constraint.temperature);
+            let sampled_token = tokens[idx];
+
+            println!(
+                "SAMPLE {}: {} {}",
+                idx,
+                sampled_token,
+                tok_env.tok_trie().token_dbg(sampled_token)
+            );
+            Some(sampled_token)
+        } else {
+            // sampling not required
+            println!("NO SAMPLE");
+            None
+        };
+
+        let splice = constraint.commit_token(sampled_token).unwrap();
+        if splice.stop {
+            // stop sequence
+            break;
+        }
+
+        assert!(splice.backtrack == 0); // we didn't allow backtracking in InferenceCaps
+
+        // The splice contains the tokens (possibly more than one since we enabled ff_tokens
+        // in InferenceCaps) that the parser wants to append to the output.
+
+        // if this fails, our test data is broken
+        if tokens[idx..idx + splice.ff_tokens.len()] != splice.ff_tokens {
+            panic!(
+                "BAD TEST: ff_tokens mismatch:\n{}\n{}",
+                trie.tokens_dbg(&tokens[idx..idx + splice.ff_tokens.len()]),
+                trie.tokens_dbg(&splice.ff_tokens)
+            );
+        }
+
+        if splice.ff_tokens.len() > 1 {
+            println!("FF: {}", trie.tokens_dbg(&splice.ff_tokens));
+        }
+
+        idx += splice.ff_tokens.len();
+
+        // send output to the user
+        send_output(&constraint.flush_logs());
+    }
+
+    // flush any output
+    send_output(&constraint.flush_logs());
+    // the stop reason should likely also be sent to the user
+    println!("Stop reason: {:?}", constraint.parser.stop_reason());
+}
+
+fn read_file_to_string(filename: &str) -> String {
+    let mut file = File::open(filename).expect("Unable to open file");
+    let mut content = String::new();
+    file.read_to_string(&mut content)
+        .expect("Unable to read file");
+    content
+}
+
+fn send_output(user_output: &str) {
+    // enable if you want to see the output
+    if false {
+        println!("{}", user_output);
+    }
+}
diff --git a/sample_parser/src/sample_parser.rs b/sample_parser/src/sample_parser.rs
index 8b1106a9..a6310730 100644
--- a/sample_parser/src/sample_parser.rs
+++ b/sample_parser/src/sample_parser.rs
@@ -19,7 +19,6 @@ fn main() {
     } else if args[1].ends_with(".schema.json") {
         let opts = JsonCompileOptions {
             compact: false,
-            validate: true,
         };
         let val = serde_json::from_str(&schema_file).expect("Invalid JSON in schema");
         opts.json_to_llg(&val)
diff --git a/sample_parser/src/schema_tester.rs b/sample_parser/src/schema_tester.rs
index 1d1506a7..da7b8e93 100644
--- a/sample_parser/src/schema_tester.rs
+++ b/sample_parser/src/schema_tester.rs
@@ -9,10 +9,7 @@ use serde_json::Value;
 
 fn test_file(tok_env: TokEnv, file: &str) {
     let schema_file = read_file_to_string(file);
-    let opts = JsonCompileOptions {
-        compact: false,
-        validate: true,
-    };
+    let opts = JsonCompileOptions { compact: false };
     let val: Value = serde_json::from_str(&schema_file).expect("Invalid JSON in schema");
     if schema_file.len() < 512 && val["$ref"].is_string() {
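
Usage sketch (not part of the patch above): how a caller might choose between the two entry points split out in parser/src/json.rs. Only APIs visible in this diff are assumed (JsonCompileOptions, json_to_llg, json_to_llg_no_validate); the compile_schema helper and its signature are hypothetical.

    use anyhow::Result;
    use llguidance_parser::{api::TopLevelGrammar, JsonCompileOptions};
    use serde_json::Value;

    // Hypothetical helper: compile a JSON schema to an LLG grammar, with or
    // without first validating it against the JSON Schema meta-schema.
    fn compile_schema(schema_json: &str, validate: bool) -> Result<TopLevelGrammar> {
        let val: Value = serde_json::from_str(schema_json)?;
        let opts = JsonCompileOptions { compact: false };
        if validate {
            // runs Compiler::validate() and then Compiler::execute(),
            // matching the old run() behavior with `validate: true`
            opts.json_to_llg(&val)
        } else {
            // skips the validator entirely
            opts.json_to_llg_no_validate(&val)
        }
    }

Note that a binary calling compile_schema links both branches and therefore still compiles SCHEMA_VALIDATOR in; the binary-size win measured in run.sh comes from minimal.rs referencing only json_to_llg_no_validate, which lets the linker drop the validator and its dependencies.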