allow for not compiling-in json schema validation
mmoskal committed Oct 21, 2024
1 parent 255f7c3 commit 10abb44
Showing 6 changed files with 206 additions and 15 deletions.
22 changes: 14 additions & 8 deletions parser/src/json.rs
@@ -16,7 +16,6 @@ use crate::{
 #[derive(Debug, Default, Clone)]
 pub struct JsonCompileOptions {
     pub compact: bool,
-    pub validate: bool,
 }

 fn to_compact_json(target: &serde_json::Value) -> String {
@@ -123,7 +122,14 @@ macro_rules! cache {
 impl JsonCompileOptions {
     pub fn json_to_llg(&self, schema: &Value) -> Result<TopLevelGrammar> {
         let mut compiler = Compiler::new(self.clone());
-        compiler.run(schema)?;
+        compiler.validate(schema)?;
+        compiler.execute(schema)?;
         compiler.builder.finalize()
     }
+
+    pub fn json_to_llg_no_validate(&self, schema: &Value) -> Result<TopLevelGrammar> {
+        let mut compiler = Compiler::new(self.clone());
+        compiler.execute(schema)?;
+        compiler.builder.finalize()
+    }
 }
@@ -236,13 +242,13 @@ impl Compiler {
         }
     }

-    pub fn run(&mut self, schema: &Value) -> Result<()> {
-        if self.options.validate {
-            SCHEMA_VALIDATOR
-                .validate(schema)
-                .map_err(|mut e| anyhow!("Invalid schema: {}", e.next().unwrap()))?;
-        }
+    pub fn validate(&mut self, schema: &Value) -> Result<()> {
+        SCHEMA_VALIDATOR
+            .validate(schema)
+            .map_err(|mut e| anyhow!("Invalid schema: {}", e.next().unwrap()))
+    }
+
+    pub fn execute(&mut self, schema: &Value) -> Result<()> {
         self.builder.add_grammar(GrammarWithLexer {
             greedy_skip_rx: if self.options.compact {
                 None
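
For orientation, a rough usage sketch of the API split above: json_to_llg validates the schema and then compiles it, while the new json_to_llg_no_validate compiles without invoking the schema validator (the binary-size motivation is inferred from the commit message and the size-checking commands added to run.sh below). The compile_schema helper and its anyhow error handling are illustrative assumptions, not part of this commit:

// Illustrative sketch (not part of this commit): choose at the call site
// whether to run JSON-schema validation before compiling to an LLG grammar.
use llguidance_parser::{api::TopLevelGrammar, JsonCompileOptions};
use serde_json::Value;

fn compile_schema(schema_text: &str, validate: bool) -> anyhow::Result<TopLevelGrammar> {
    let schema: Value = serde_json::from_str(schema_text)?;
    let opts = JsonCompileOptions { compact: false };
    if validate {
        // json_to_llg() runs the schema validator first, then compiles
        opts.json_to_llg(&schema)
    } else {
        // json_to_llg_no_validate() skips validation; callers that never reference
        // json_to_llg()/validate() presumably let the validator be dropped from the build
        opts.json_to_llg_no_validate(&schema)
    }
}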
4 changes: 4 additions & 0 deletions sample_parser/Cargo.toml
@@ -17,3 +17,7 @@ path = "src/sample_parser.rs"
 [[bin]]
 name = "schema_tester"
 path = "src/schema_tester.rs"
+
+[[bin]]
+name = "minimal"
+path = "src/minimal.rs"
8 changes: 6 additions & 2 deletions sample_parser/run.sh
@@ -1,4 +1,8 @@
 #!/bin/sh

-# cargo run data/blog.schema.ll.json data/blog.sample.json
-cargo run data/blog.schema.json data/blog.sample.json
+cargo run data/blog.schema.ll.json data/blog.sample.json
+# cargo run data/blog.schema.json data/blog.sample.json
+# cargo run --release --bin minimal data/blog.schema.json data/blog.sample.json
+# mkdir -p tmp
+# strip -o tmp/minimal ../../target/release/minimal
+# ls -l ../../target/release/minimal tmp/minimal
181 changes: 181 additions & 0 deletions sample_parser/src/minimal.rs
@@ -0,0 +1,181 @@
use std::{env, fs::File, hint::black_box, io::Read, sync::Arc, vec};

use llguidance_parser::{
    api::{ParserLimits, TopLevelGrammar},
    toktrie::{InferenceCapabilities, TokEnv, TokRxInfo, TokTrie, TokenId, TokenizerEnv},
    Constraint, JsonCompileOptions, TokenParser,
};

struct SingleByteTokenizer {
    tok_trie: TokTrie,
}

impl SingleByteTokenizer {
    fn new() -> Self {
        let mut words = (0..=255).map(|x| vec![x]).collect::<Vec<_>>();
        words.push("<eos>".as_bytes().to_vec());
        let info = TokRxInfo {
            vocab_size: words.len() as u32,
            tok_eos: words.len() as u32 - 1,
            tok_bos: None,
            tok_pad: None,
            tok_unk: None,
            tok_end_of_turn: None,
        };
        let tok_trie = TokTrie::from(&info, &words);
        SingleByteTokenizer { tok_trie }
    }

    fn to_env(self) -> TokEnv {
        Arc::new(self)
    }
}

impl TokenizerEnv for SingleByteTokenizer {
    fn stop(&self) -> ! {
        panic!("stop called")
    }

    fn tok_trie(&self) -> &TokTrie {
        &self.tok_trie
    }

    fn tokenize_bytes(&self, s: &[u8]) -> Vec<TokenId> {
        self.tok_trie.greedy_tokenize(s)
    }
}

fn main() {
    let args: Vec<String> = env::args().collect();
    if args.len() != 3 {
        eprintln!("Usage: {} <schema.ll.json> <sample.json>", args[0]);
        std::process::exit(1);
    }

    let schema_file = read_file_to_string(&args[1]);
    let schema: TopLevelGrammar = if args[1].ends_with(".ll.json") {
        serde_json::from_str(&schema_file).expect("Invalid JSON in schema")
    } else if args[1].ends_with(".schema.json") {
        let opts = JsonCompileOptions {
            compact: false,
        };
        let val = serde_json::from_str(&schema_file).expect("Invalid JSON in schema");
        opts.json_to_llg_no_validate(&val)
            .expect("Failed to convert JSON to LLG")
    } else {
        panic!("Unknown schema file extension")
    };
    let obj_str = read_file_to_string(&args[2]);

    let tok_env: TokEnv = SingleByteTokenizer::new().to_env();

    let tokens = tok_env.tokenize(&obj_str);

    // set to 2 for more output; 1 is warnings only
    let stderr_log_level = 1;

    // typically set to 2, to send info-level output to the user
    let buffer_log_level = 2;

    let parser = TokenParser::from_llguidance_json(
        tok_env.clone(),
        schema,
        llguidance_parser::Logger::new(buffer_log_level, stderr_log_level),
        InferenceCapabilities {
            ff_tokens: true,  // can the engine append multiple tokens?
            backtrack: false, // can the engine remove generated tokens?

            conditional_ff_tokens: false, // not used
            fork: false,                  // not used
        },
        ParserLimits::default(),
        vec![],
    )
    .unwrap();
    let mut constraint = Constraint::new(parser);

    // enable sending parser results back via the logs (constraint.flush_logs())
    constraint.log_json_progress = true;

    let trie = tok_env.tok_trie();

    eprintln!("Parsing tokens: {}", trie.tokens_dbg(&tokens));

    let mut idx = 0;
    while idx < tokens.len() {
        let res = constraint.compute_mask().unwrap();

        if res.is_stop() {
            // stop sequence
            break;
        }

        let sampled_token = if let Some(mask) = &res.sample_mask {
            // Simulate sampling - it should use the mask and temperature
            black_box(mask);
            black_box(constraint.temperature);
            let sampled_token = tokens[idx];

            println!(
                "SAMPLE {}: {} {}",
                idx,
                sampled_token,
                tok_env.tok_trie().token_dbg(sampled_token)
            );
            Some(sampled_token)
        } else {
            // sampling not required
            println!("NO SAMPLE");
            None
        };

        let splice = constraint.commit_token(sampled_token).unwrap();
        if splice.stop {
            // stop sequence
            break;
        }

        assert!(splice.backtrack == 0); // we didn't allow backtracking in InferenceCaps

        // The splice contains the tokens (possibly more than one since we enabled ff_tokens
        // in InferenceCaps) that the parser wants to append to the output.

        // if this fails, our test data is broken
        if tokens[idx..idx + splice.ff_tokens.len()] != splice.ff_tokens {
            panic!(
                "BAD TEST: ff_tokens mismatch:\n{}\n{}",
                trie.tokens_dbg(&tokens[idx..idx + splice.ff_tokens.len()]),
                trie.tokens_dbg(&splice.ff_tokens)
            );
        }

        if splice.ff_tokens.len() > 1 {
            println!("FF: {}", trie.tokens_dbg(&splice.ff_tokens));
        }

        idx += splice.ff_tokens.len();

        // send output to the user
        send_output(&constraint.flush_logs());
    }

    // flush any output
    send_output(&constraint.flush_logs());
    // the stop reason should likely also be sent to the user
    println!("Stop reason: {:?}", constraint.parser.stop_reason());
}

fn read_file_to_string(filename: &str) -> String {
    let mut file = File::open(filename).expect("Unable to open file");
    let mut content = String::new();
    file.read_to_string(&mut content)
        .expect("Unable to read file");
    content
}

fn send_output(user_output: &str) {
    // enable if you want to see the output
    if false {
        println!("{}", user_output);
    }
}
1 change: 0 additions & 1 deletion sample_parser/src/sample_parser.rs
@@ -19,7 +19,6 @@ fn main() {
     } else if args[1].ends_with(".schema.json") {
         let opts = JsonCompileOptions {
             compact: false,
-            validate: true,
         };
         let val = serde_json::from_str(&schema_file).expect("Invalid JSON in schema");
         opts.json_to_llg(&val)
5 changes: 1 addition & 4 deletions sample_parser/src/schema_tester.rs
@@ -9,10 +9,7 @@ use serde_json::Value;

 fn test_file(tok_env: TokEnv, file: &str) {
     let schema_file = read_file_to_string(file);
-    let opts = JsonCompileOptions {
-        compact: false,
-        validate: true,
-    };
+    let opts = JsonCompileOptions { compact: false };
     let val: Value = serde_json::from_str(&schema_file).expect("Invalid JSON in schema");

     if schema_file.len() < 512 && val["$ref"].is_string() {
