Skip to content

Commit

Permalink
[JSON] Allow custom retrievers (#106)
Browse files Browse the repository at this point in the history
Allow `referencing::Retrieve` implementations to be passed to JSON compiler
  • Loading branch information
hudson-ai authored Jan 9, 2025
1 parent 9e65e50 commit c81c41d
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 6 deletions.
24 changes: 22 additions & 2 deletions parser/src/json/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use anyhow::{anyhow, Context, Result};
use derivre::{JsonQuoteOptions, RegexAst};
use hashbrown::HashMap;
use indexmap::IndexMap;
use referencing::Retrieve;
use serde_json::{json, Value};

use super::numeric::{check_number_bounds, rx_float_range, rx_int_range, Decimal};
use super::schema::{build_schema, Schema};
use super::schema::{build_schema, RetrieveWrapper, Schema};

use crate::{
api::{GrammarWithLexer, RegexSpec, TopLevelGrammar},
Expand All @@ -22,6 +23,7 @@ pub struct JsonCompileOptions {
pub key_separator: String,
pub whitespace_flexible: bool,
pub coerce_one_of: bool,
pub retriever: Option<RetrieveWrapper>,
}

fn json_dumps(target: &serde_json::Value) -> String {
Expand Down Expand Up @@ -67,11 +69,28 @@ impl Default for JsonCompileOptions {
key_separator: ":".to_string(),
whitespace_flexible: true,
coerce_one_of: false,
retriever: None,
}
}
}

impl JsonCompileOptions {
pub fn new(
item_separator: String,
key_separator: String,
whitespace_flexible: bool,
coerce_one_of: bool,
retriever: Option<std::rc::Rc<dyn Retrieve>>,
) -> Self {
Self {
item_separator,
key_separator,
whitespace_flexible,
coerce_one_of,
retriever: retriever.map(RetrieveWrapper::new),
}
}

pub fn json_to_llg(&self, schema: Value) -> Result<TopLevelGrammar> {
let mut compiler = Compiler::new(self.clone());
#[cfg(feature = "jsonschema_validation")]
Expand Down Expand Up @@ -117,7 +136,8 @@ impl Compiler {
..GrammarWithLexer::default()
});

let (compiled_schema, definitions) = build_schema(schema)?;
let (compiled_schema, definitions) =
build_schema(schema, self.options.retriever.as_deref())?;

let root = self.gen_json(&compiled_schema)?;
self.builder.set_start_node(root);
Expand Down
110 changes: 106 additions & 4 deletions parser/src/json/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use anyhow::{anyhow, bail, Result};
use derivre::RegexAst;
use hashbrown::{HashMap, HashSet};
use indexmap::{IndexMap, IndexSet};
use referencing::{Draft, Registry, Resolver, ResourceRef};
use referencing::{Draft, Registry, Resolver, Resource, ResourceRef, Retrieve};
use serde_json::Value;
use std::{cell::RefCell, mem, rc::Rc};
use std::{any::type_name_of_val, cell::RefCell, mem, rc::Rc};

use super::formats::lookup_format;
use super::numeric::Decimal;
Expand Down Expand Up @@ -550,11 +550,33 @@ impl<'a> Context<'a> {
}
}

#[derive(Clone)]
pub struct RetrieveWrapper(pub Rc<dyn Retrieve>);
impl RetrieveWrapper {
pub fn new(retrieve: Rc<dyn Retrieve>) -> Self {
RetrieveWrapper(retrieve)
}
}
impl std::ops::Deref for RetrieveWrapper {
type Target = dyn Retrieve;
fn deref(&self) -> &Self::Target {
self.0.as_ref()
}
}
impl std::fmt::Debug for RetrieveWrapper {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", type_name_of_val(&self.0))
}
}

fn draft_for(value: &Value) -> Draft {
DEFAULT_DRAFT.detect(value).unwrap_or(DEFAULT_DRAFT)
}

pub fn build_schema(contents: Value) -> Result<(Schema, HashMap<String, Schema>)> {
pub fn build_schema(
contents: Value,
retriever: Option<&dyn Retrieve>,
) -> Result<(Schema, HashMap<String, Schema>)> {
if let Some(b) = contents.as_bool() {
if b {
return Ok((Schema::Any, HashMap::new()));
Expand All @@ -567,7 +589,17 @@ pub fn build_schema(contents: Value) -> Result<(Schema, HashMap<String, Schema>)
let resource = draft.create_resource(contents);
let base_uri = resource.id().unwrap_or(DEFAULT_ROOT_URI).to_string();

let registry = Registry::try_new(&base_uri, resource)?;
let registry = {
// Weirdly no apparent way to instantiate a new registry with a retriever, so we need to
// make an empty one and then add the retriever + resource that may depend on said retriever
let empty_registry =
Registry::try_from_resources(std::iter::empty::<(String, Resource)>())?;
empty_registry.try_with_resource_and_retriever(
&base_uri,
resource,
retriever.unwrap_or(&referencing::DefaultRetriever),
)?
};

let resolver = registry.try_resolver(&base_uri)?;
let ctx = Context {
Expand Down Expand Up @@ -1175,3 +1207,73 @@ fn opt_min<T: PartialOrd>(a: Option<T>, b: Option<T>) -> Option<T> {
(None, None) => None,
}
}

#[cfg(test)]
mod test_retriever {
use super::{build_schema, Schema};
use referencing::{Retrieve, Uri};
use serde_json::{json, Value};
use std::fmt;

#[derive(Debug, Clone)]
struct TestRetrieverError(String);
impl fmt::Display for TestRetrieverError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Could not retrieve URI: {}", self.0)
}
}
impl std::error::Error for TestRetrieverError {}

struct TestRetriever {
schemas: std::collections::HashMap<String, serde_json::Value>,
}
impl Retrieve for TestRetriever {
fn retrieve(
&self,
uri: &Uri<&str>,
) -> Result<Value, Box<dyn std::error::Error + Send + Sync>> {
let key = uri.as_str();
match self.schemas.get(key) {
Some(schema) => Ok(schema.clone()),
None => Err(Box::new(TestRetrieverError(key.to_string()))),
}
}
}

#[test]
fn test_retriever() {
let key: &str = "http://example.com/schema";

let schema = json!({
"$ref": key
});
let retriever = TestRetriever {
schemas: vec![(
key.to_string(),
json!({
"type": "string"
}),
)]
.into_iter()
.collect(),
};
let (schema, defs) = build_schema(schema, Some(&retriever)).unwrap();
match schema {
Schema::Ref { uri } => {
assert_eq!(uri, key);
}
_ => panic!("Unexpected schema: {:?}", schema),
}
assert_eq!(defs.len(), 1);
let val = defs.get(key).unwrap();
// poor-man's partial_eq
match val {
Schema::String {
min_length: 0,
max_length: None,
regex: None,
} => {}
_ => panic!("Unexpected schema: {:?}", val),
}
}
}
1 change: 1 addition & 0 deletions python_ext/src/py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ impl JsonCompiler {
key_separator: self.key_separator.clone(),
whitespace_flexible: self.whitespace_flexible,
coerce_one_of: self.coerce_one_of,
retriever: None,
};
let grammar = compile_options.json_to_llg(schema).map_err(val_error)?;
serde_json::to_string(&grammar).map_err(val_error)
Expand Down

0 comments on commit c81c41d

Please sign in to comment.