diff --git a/sample_parser/src/json_schema_testsuite.rs b/sample_parser/src/json_schema_testsuite.rs index 6ac00b5..4bc5921 100644 --- a/sample_parser/src/json_schema_testsuite.rs +++ b/sample_parser/src/json_schema_testsuite.rs @@ -6,7 +6,7 @@ use llguidance::{ Constraint, JsonCompileOptions, TokenParser, }; use serde::{Deserialize, Serialize}; -use serde_json::Value; +use serde_json::{json, Value}; use std::{env, fs::File, io::Read, time::Duration, vec}; #[derive(Debug, Serialize, Deserialize)] @@ -23,6 +23,29 @@ struct JsonTestSequence { valid: bool, } +fn round_float_to_int(v: &Value) -> Value { + match v { + Value::Number(q) => { + if let Some(n) = q.as_f64() { + if n.floor() == n { + json!(n as i64) + } else { + v.clone() + } + } else { + v.clone() + } + } + Value::Array(a) => Value::Array(a.iter().map(round_float_to_int).collect()), + Value::Object(o) => Value::Object( + o.iter() + .map(|(k, v)| (k.clone(), round_float_to_int(v))) + .collect(), + ), + _ => v.clone(), + } +} + impl JsonTestSequence { fn run_for( &self, @@ -45,7 +68,7 @@ impl JsonTestSequence { if self.valid { bail!("premature stop in valid test"); } else { - bail!("premature stop in invalid test"); // ?? + return Ok(()); } } @@ -103,10 +126,12 @@ impl JsonTestSequence { bail!("unexpected end of test"); } } else { - bail!( - "unexpected end of test for invalid test (accept={})", - accept - ); + if accept { + bail!("unexpected end of test for invalid test (accept)"); + } else { + // this is in fact correct - we forced EOS + Ok(()) + } } } @@ -129,7 +154,7 @@ impl JsonTestSequence { )?; let constraint = Constraint::new(parser); - let obj_str = serde_json::to_string_pretty(&self.data).unwrap(); + let obj_str = serde_json::to_string_pretty(&round_float_to_int(&self.data)).unwrap(); match self.run_for(stats, &obj_str, tok_env, constraint) { Ok(_) => Ok(()), Err(e) => {