Skip to content

Commit

Permalink
json fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Sep 8, 2024
1 parent ccceace commit f25944b
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 20 deletions.
53 changes: 33 additions & 20 deletions parser/src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ fn limited_str(node: &Value) -> String {
fn validate_json_node_keys(node: &Value) -> Result<()> {
let node = node
.as_object()
.ok_or_else(|| anyhow!("Expected object as json schema, got: {}", limited_str(node)))?;
.ok_or_else(|| anyhow!("Expected object as json schema, got: {}", limited_str(node)))
.unwrap();

let typ = node.get("type").and_then(|v| v.as_str()).unwrap_or("");

Expand Down Expand Up @@ -268,6 +269,11 @@ impl Compiler {
}

fn gen_json(&mut self, json_schema: &Value) -> Result<NodeRef> {
if json_schema.as_bool() == Some(true) {
return Ok(self.gen_json_any());
}

// eprintln!("gen_json: {}", limited_str(json_schema));
validate_json_node_keys(json_schema)?;

// Process anyOf
Expand Down Expand Up @@ -379,6 +385,7 @@ impl Compiler {
fn gen_json_any(&mut self) -> NodeRef {
cache!(self.any_cache, {
let json_any = self.builder.placeholder();
self.any_cache = Some(json_any); // avoid infinite recursion
let all_jsons = json!([
{"type": "null"},
{"type": "boolean"},
Expand Down Expand Up @@ -528,31 +535,37 @@ impl Compiler {
max as usize
});

let item_schema_compiled = if item_schema_is_false {
None
if let Some(item_arr) = item_schema.as_array() {
for item in item_arr {
required_items.push(self.gen_json(item)?);
}
} else {
Some(self.gen_json(item_schema)?)
};

for i in 0..n_to_add {
let item = if i < prefix_items.len() {
self.gen_json(&prefix_items[i])?
} else if let Some(compiled) = &item_schema_compiled {
compiled.clone()
let item_schema_compiled = if item_schema_is_false {
None
} else {
break;
Some(self.gen_json(item_schema)?)
};

if i < min_items as usize {
required_items.push(item);
} else {
optional_items.push(item);
for i in 0..n_to_add {
let item = if i < prefix_items.len() {
self.gen_json(&prefix_items[i])?
} else if let Some(compiled) = &item_schema_compiled {
compiled.clone()
} else {
break;
};

if i < min_items as usize {
required_items.push(item);
} else {
optional_items.push(item);
}
}
}

if max_items.is_none() && !item_schema_is_false {
// Add an infinite tail of items
optional_items.push(self.sequence(item_schema_compiled.unwrap()));
if max_items.is_none() && !item_schema_is_false {
// Add an infinite tail of items
optional_items.push(self.sequence(item_schema_compiled.unwrap()));
}
}

let mut grammars: Vec<NodeRef> = vec![self.builder.string("[")];
Expand Down
5 changes: 5 additions & 0 deletions sample_parser/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
name = "sample_parser"
version = "0.1.0"
edition = "2021"
default-run = "sample_parser"

[dependencies]
llguidance_parser = { path = "../parser" }
Expand All @@ -12,3 +13,7 @@ anyhow = "1.0.87"
[[bin]]
name = "sample_parser"
path = "src/sample_parser.rs"

[[bin]]
name = "schema_tester"
path = "src/schema_tester.rs"
98 changes: 98 additions & 0 deletions sample_parser/src/schema_tester.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use std::{env, fs::File, io::Read, vec};

use llguidance_parser::{
api::ParserLimits,
toktrie::{InferenceCapabilities, TokEnv},
Constraint, JsonCompileOptions, TokenParser,
};
use serde_json::Value;

/// Compile the JSON schema in `file` to an llguidance grammar, build a
/// token parser for it, and compute the first token mask, reporting the
/// outcome for that file on stderr.
///
/// Every per-schema failure is printed (prefixed with the file name) and
/// the function returns, so a batch run over many schemas keeps going.
/// Only unreadable files / invalid JSON panic, matching the tool's
/// fail-fast stance on its own inputs.
fn test_file(tok_env: TokEnv, file: &str) {
    let schema_file = read_file_to_string(file);
    let opts = JsonCompileOptions {
        compact: false,
        validate: true,
    };
    let val: Value = serde_json::from_str(&schema_file).expect("Invalid JSON in schema");

    // Small files that are just a "$ref" to another document can't be
    // compiled standalone; skip them. (512-byte heuristic: a real schema
    // with an internal $ref is bigger than that.)
    if schema_file.len() < 512 && val["$ref"].is_string() {
        eprintln!("{} ref-only", file);
        return;
    }

    let schema = match opts.json_to_llg(&val) {
        Ok(schema) => schema,
        Err(e) => {
            eprintln!("{} Error: {}", file, e);
            return;
        }
    };

    let parser = TokenParser::from_llguidance_json(
        tok_env,
        schema,
        llguidance_parser::Logger::new(0, 1),
        InferenceCapabilities {
            ff_tokens: true,
            backtrack: false,
            conditional_ff_tokens: false,
            fork: false,
        },
        ParserLimits::default(),
        vec![],
    );

    match parser {
        Ok(parser) => {
            let mut constraint = Constraint::new(parser);
            // Don't unwrap here: a mask-computation failure on one schema
            // should be reported like every other error above, not abort
            // the whole batch.
            match constraint.compute_mask() {
                Ok(_) => eprintln!("{} OK", file),
                Err(e) => eprintln!("{} Error: {}", file, e),
            }
        }
        Err(e) => {
            eprintln!("{} Error: {}", file, e);
        }
    }
}

/// CLI entry point: each argument is either a `.json` schema file or a
/// directory whose top-level `.json` files are all tested.
fn main() {
    let args: Vec<String> = env::args().collect();
    if args.len() < 2 {
        eprintln!("Usage: {} <json-schema.json|folder>...", args[0]);
        std::process::exit(1);
    }

    // Expand arguments into a flat list of schema file paths.
    let mut files = vec![];
    for arg in &args[1..] {
        if arg.ends_with(".json") {
            files.push(arg.to_string());
        } else {
            let dir = std::fs::read_dir(arg).expect("Unable to read directory");
            for entry in dir {
                let entry = entry.expect("Unable to read entry");
                let path = entry.path();
                // Use extension()/to_string_lossy() instead of
                // to_str().unwrap(), which would panic on a non-UTF-8
                // file name encountered during the scan.
                if path.is_file() && path.extension().map_or(false, |e| e == "json") {
                    files.push(path.to_string_lossy().into_owned());
                }
            }
        }
    }

    // Any HF tokenizer would do here; Phi-3.5 is just a known-good default.
    let tok_env: TokEnv =
        toktrie_hf_tokenizers::ByteTokenizerEnv::from_name("microsoft/Phi-3.5-mini-instruct", None)
            .unwrap()
            .to_env();

    for file in files {
        test_file(tok_env.clone(), &file);
    }
}

/// Read the entire contents of `filename` as a UTF-8 string.
///
/// # Panics
/// Panics (with the file name and underlying I/O error) if the file
/// cannot be opened or read — acceptable for this CLI tool, where a
/// missing input file is a usage error.
fn read_file_to_string(filename: &str) -> String {
    // fs::read_to_string replaces the manual File::open + read_to_string
    // dance and pre-sizes the buffer from file metadata.
    std::fs::read_to_string(filename)
        .unwrap_or_else(|e| panic!("Unable to read file {}: {}", filename, e))
}

0 comments on commit f25944b

Please sign in to comment.