diff --git a/biscuit-auth/examples/testcases.rs b/biscuit-auth/examples/testcases.rs index 05f44e44..648e2656 100644 --- a/biscuit-auth/examples/testcases.rs +++ b/biscuit-auth/examples/testcases.rs @@ -10,9 +10,14 @@ use biscuit::macros::*; use biscuit::Authorizer; use biscuit::{builder::*, builder_ext::*, Biscuit}; use biscuit::{KeyPair, PrivateKey, PublicKey}; +use biscuit_auth::builder; +use biscuit_auth::datalog::ExternFunc; +use biscuit_auth::datalog::RunLimits; use prost::Message; use rand::prelude::*; use serde::Serialize; +use std::collections::HashMap; +use std::sync::Arc; use std::{ collections::{BTreeMap, BTreeSet}, fs::File, @@ -157,6 +162,9 @@ fn run(target: String, root_key: Option, test: bool, json: bool) { add_test_result(&mut results, type_of(&target, &root, test)); add_test_result(&mut results, array_map(&target, &root, test)); + + add_test_result(&mut results, ffi(&target, &root, test)); + if json { let s = serde_json::to_string_pretty(&TestCases { root_private_key: hex::encode(root.private().to_bytes()), @@ -297,6 +305,15 @@ enum AuthorizerResult { } fn validate_token(root: &KeyPair, data: &[u8], authorizer_code: &str) -> Validation { + validate_token_with_limits(root, data, authorizer_code, RunLimits::default()) +} + +fn validate_token_with_limits( + root: &KeyPair, + data: &[u8], + authorizer_code: &str, + run_limits: RunLimits, +) -> Validation { let token = match Biscuit::from(&data[..], &root.public()) { Ok(t) => t, Err(e) => { @@ -331,7 +348,7 @@ fn validate_token(root: &KeyPair, data: &[u8], authorizer_code: &str) -> Validat } }; - let res = authorizer.authorize(); + let res = authorizer.authorize_with_limits(run_limits); //println!("authorizer world:\n{}", authorizer.print_world()); let (_, _, _, policies) = authorizer.dump(); let snapshot = authorizer.snapshot().unwrap(); @@ -2269,6 +2286,56 @@ fn array_map(target: &str, root: &KeyPair, test: bool) -> TestResult { } } +fn ffi(target: &str, root: &KeyPair, test: bool) -> TestResult { + let mut rng: StdRng = SeedableRng::seed_from_u64(1234); + let title = "test ffi calls (v6 blocks)".to_string(); + let filename = "test035_ffi".to_string(); + let token; + + let biscuit = + biscuit!(r#"check if true.extern::test(), "a".extern::test("a") == "equal strings""#) + .build_with_rng(&root, SymbolTable::default(), &mut rng) + .unwrap(); + token = print_blocks(&biscuit); + + let data = write_or_load_testcase(target, &filename, root, &biscuit, test); + + let mut validations = BTreeMap::new(); + validations.insert( + "".to_string(), + validate_token_with_limits( + root, + &data[..], + "allow if true", + RunLimits { + extern_funcs: HashMap::from([( + "test".to_string(), + ExternFunc::new(Arc::new(|left, right| match (left, right) { + (t, None) => Ok(t), + (builder::Term::Str(left), Some(builder::Term::Str(right))) + if left == right => + { + Ok(builder::Term::Str("equal strings".to_string())) + } + (builder::Term::Str(_), Some(builder::Term::Str(_))) => { + Ok(builder::Term::Str("different strings".to_string())) + } + _ => Err("unsupported operands".to_string()), + })), + )]), + ..Default::default() + }, + ), + ); + + TestResult { + title, + filename, + token, + validations, + } +} + fn print_blocks(token: &Biscuit) -> Vec { let mut v = Vec::new(); diff --git a/biscuit-auth/samples/README.md b/biscuit-auth/samples/README.md index 0a69b184..64cede9c 100644 --- a/biscuit-auth/samples/README.md +++ b/biscuit-auth/samples/README.md @@ -3139,3 +3139,51 @@ World { result: `Ok(0)` + +------------------------------ + +## test ffi calls (v6 blocks): test035_ffi.bc +### token + +authority: +symbols: ["test", "a", "equal strings"] + +public keys: [] + +``` +check if true.extern::test(), "a".extern::test("a") == "equal strings"; +``` + +### validation + +authorizer code: +``` +allow if true; +``` + +revocation ids: +- `faf26fe6f5dfa08c114a0a29321405b6fb7be79b0d80694d27925f7deb01effe5707600e42fd74f9a1d2920466446d51949155f4548f0fd68f3e9326c7e12404` + +authorizer world: +``` +World { + facts: [] + rules: [] + checks: [ + Checks { + origin: Some( + 0, + ), + checks: [ + "check if true.extern::test(), \"a\".extern::test(\"a\") == \"equal strings\"", + ], + }, +] + policies: [ + "allow if true", +] +} +``` + +result: `Ok(0)` + diff --git a/biscuit-auth/samples/samples.json b/biscuit-auth/samples/samples.json index 1f2e3c3d..4877aac5 100644 --- a/biscuit-auth/samples/samples.json +++ b/biscuit-auth/samples/samples.json @@ -2913,6 +2913,48 @@ ] } } + }, + { + "title": "test ffi calls (v6 blocks)", + "filename": "test035_ffi.bc", + "token": [ + { + "symbols": [ + "test", + "a", + "equal strings" + ], + "public_keys": [], + "external_key": null, + "code": "check if true.extern::test(), \"a\".extern::test(\"a\") == \"equal strings\";\n" + } + ], + "validations": { + "": { + "world": { + "facts": [], + "rules": [], + "checks": [ + { + "origin": 0, + "checks": [ + "check if true.extern::test(), \"a\".extern::test(\"a\") == \"equal strings\"" + ] + } + ], + "policies": [ + "allow if true" + ] + }, + "result": { + "Ok": 0 + }, + "authorizer_code": "allow if true;\n", + "revocation_ids": [ + "faf26fe6f5dfa08c114a0a29321405b6fb7be79b0d80694d27925f7deb01effe5707600e42fd74f9a1d2920466446d51949155f4548f0fd68f3e9326c7e12404" + ] + } + } } ] } diff --git a/biscuit-auth/samples/test035_ffi.bc b/biscuit-auth/samples/test035_ffi.bc new file mode 100644 index 00000000..d5bb3a8d Binary files /dev/null and b/biscuit-auth/samples/test035_ffi.bc differ diff --git a/biscuit-auth/src/datalog/expression.rs b/biscuit-auth/src/datalog/expression.rs index 67d9e6f8..ecb47a28 100644 --- a/biscuit-auth/src/datalog/expression.rs +++ b/biscuit-auth/src/datalog/expression.rs @@ -1,10 +1,55 @@ -use crate::error; +use crate::{builder, error}; -use super::{MapKey, Term}; +use super::{MapKey, SymbolIndex, Term}; use super::{SymbolTable, TemporarySymbolTable}; use regex::Regex; -use std::collections::{HashMap, HashSet}; -use std::convert::TryFrom; +use std::sync::Arc; +use std::{ + collections::{HashMap, HashSet}, + convert::TryFrom, +}; + +#[derive(Clone)] +pub struct ExternFunc( + pub Arc< + dyn Fn(builder::Term, Option) -> Result + Send + Sync, + >, +); + +impl std::fmt::Debug for ExternFunc { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "") + } +} + +impl ExternFunc { + pub fn new( + f: Arc< + dyn Fn(builder::Term, Option) -> Result + + Send + + Sync, + >, + ) -> Self { + Self(f) + } + + pub fn call( + &self, + symbols: &mut TemporarySymbolTable, + name: &str, + left: Term, + right: Option, + ) -> Result { + let left = builder::Term::from_datalog(left, symbols)?; + let right = right + .map(|right| builder::Term::from_datalog(right, symbols)) + .transpose()?; + match self.0(left, right) { + Ok(t) => Ok(t.to_datalog(symbols)), + Err(e) => Err(error::Expression::ExternEvalError(name.to_string(), e)), + } + } +} #[derive(Debug, Clone, PartialEq, Hash, Eq)] pub struct Expression { @@ -26,6 +71,7 @@ pub enum Unary { Parens, Length, TypeOf, + Ffi(SymbolIndex), } impl Unary { @@ -33,6 +79,7 @@ impl Unary { &self, value: Term, symbols: &mut TemporarySymbolTable, + extern_funcs: &HashMap, ) -> Result { match (self, value) { (Unary::Negate, Term::Bool(b)) => Ok(Term::Bool(!b)), @@ -61,6 +108,16 @@ impl Unary { let sym = symbols.insert(type_string); Ok(Term::Str(sym)) } + (Unary::Ffi(name), i) => { + let name = symbols + .get_symbol(*name) + .ok_or(error::Expression::UnknownSymbol(*name))? + .to_owned(); + let fun = extern_funcs + .get(&name) + .ok_or(error::Expression::UndefinedExtern(name.to_owned()))?; + fun.call(symbols, &name, i, None) + } _ => { //println!("unexpected value type on the stack"); Err(error::Expression::InvalidType) @@ -68,12 +125,15 @@ impl Unary { } } - pub fn print(&self, value: String, _symbols: &SymbolTable) -> String { + pub fn print(&self, value: String, symbols: &SymbolTable) -> String { match self { Unary::Negate => format!("!{}", value), Unary::Parens => format!("({})", value), Unary::Length => format!("{}.length()", value), Unary::TypeOf => format!("{}.type()", value), + Unary::Ffi(name) => { + format!("{value}.extern::{}()", symbols.print_symbol_default(*name)) + } } } } @@ -109,6 +169,7 @@ pub enum Binary { All, Any, Get, + Ffi(SymbolIndex), } impl Binary { @@ -119,18 +180,19 @@ impl Binary { params: &[u32], values: &mut HashMap, symbols: &mut TemporarySymbolTable, + extern_func: &HashMap, ) -> Result { match (self, left, params) { // boolean (Binary::LazyOr, Term::Bool(true), []) => Ok(Term::Bool(true)), (Binary::LazyOr, Term::Bool(false), []) => { let e = Expression { ops: right.clone() }; - e.evaluate(values, symbols) + e.evaluate(values, symbols, extern_func) } (Binary::LazyAnd, Term::Bool(false), []) => Ok(Term::Bool(false)), (Binary::LazyAnd, Term::Bool(true), []) => { let e = Expression { ops: right.clone() }; - e.evaluate(values, symbols) + e.evaluate(values, symbols, extern_func) } // set @@ -138,7 +200,7 @@ impl Binary { for value in set_values.iter() { values.insert(*param, value.clone()); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(true) => {} @@ -152,7 +214,7 @@ impl Binary { for value in set_values.iter() { values.insert(*param, value.clone()); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(false) => {} @@ -168,7 +230,7 @@ impl Binary { for value in array.iter() { values.insert(*param, value.clone()); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(true) => {} @@ -182,7 +244,7 @@ impl Binary { for value in array.iter() { values.insert(*param, value.clone()); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(false) => {} @@ -203,7 +265,7 @@ impl Binary { values.insert(*param, Term::Array(vec![key, value.clone()])); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(true) => {} @@ -222,7 +284,7 @@ impl Binary { values.insert(*param, Term::Array(vec![key, value.clone()])); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(false) => {} @@ -240,6 +302,7 @@ impl Binary { left: Term, right: Term, symbols: &mut TemporarySymbolTable, + extern_funcs: &HashMap, ) -> Result { match (self, left, right) { // integer @@ -438,9 +501,22 @@ impl Binary { None => Ok(Term::Null), }, + // heterogeneous equals catch all (Binary::HeterogeneousEqual, _, _) => Ok(Term::Bool(false)), (Binary::HeterogeneousNotEqual, _, _) => Ok(Term::Bool(true)), + // FFI + (Binary::Ffi(name), left, right) => { + let name = symbols + .get_symbol(*name) + .ok_or(error::Expression::UnknownSymbol(*name))? + .to_owned(); + let fun = extern_funcs + .get(&name) + .ok_or(error::Expression::UndefinedExtern(name.to_owned()))?; + fun.call(symbols, &name, left, Some(right)) + } + _ => { //println!("unexpected value type on the stack"); Err(error::Expression::InvalidType) @@ -448,7 +524,7 @@ impl Binary { } } - pub fn print(&self, left: String, right: String, _symbols: &SymbolTable) -> String { + pub fn print(&self, left: String, right: String, symbols: &SymbolTable) -> String { match self { Binary::LessThan => format!("{} < {}", left, right), Binary::GreaterThan => format!("{} > {}", left, right), @@ -478,6 +554,10 @@ impl Binary { Binary::All => format!("{left}.all({right})"), Binary::Any => format!("{left}.any({right})"), Binary::Get => format!("{left}.get({right})"), + Binary::Ffi(name) => format!( + "{left}.extern::{}({right})", + symbols.print_symbol_default(*name) + ), } } } @@ -493,6 +573,7 @@ impl Expression { &self, values: &HashMap, symbols: &mut TemporarySymbolTable, + extern_funcs: &HashMap, ) -> Result { let mut stack: Vec = Vec::new(); @@ -508,19 +589,24 @@ impl Expression { } }, Op::Value(term) => stack.push(StackElem::Term(term.clone())), - Op::Unary(unary) => match stack.pop() { - Some(StackElem::Term(term)) => { - stack.push(StackElem::Term(unary.evaluate(term, symbols)?)) - } - _ => { - return Err(error::Expression::InvalidStack); + Op::Unary(unary) => { + match stack.pop() { + Some(StackElem::Term(term)) => stack.push(StackElem::Term( + unary.evaluate(term, symbols, extern_funcs)?, + )), + _ => { + return Err(error::Expression::InvalidStack); + } } - }, + } Op::Binary(binary) => match (stack.pop(), stack.pop()) { (Some(StackElem::Term(right_term)), Some(StackElem::Term(left_term))) => stack - .push(StackElem::Term( - binary.evaluate(left_term, right_term, symbols)?, - )), + .push(StackElem::Term(binary.evaluate( + left_term, + right_term, + symbols, + extern_funcs, + )?)), ( Some(StackElem::Closure(params, right_ops)), Some(StackElem::Term(left_term)), @@ -541,6 +627,7 @@ impl Expression { ¶ms, &mut values, symbols, + extern_funcs, )?)) } @@ -610,7 +697,7 @@ impl Expression { #[cfg(test)] mod tests { - use std::collections::BTreeSet; + use std::collections::{BTreeMap, BTreeSet}; use super::*; use crate::datalog::{MapKey, SymbolTable, TemporarySymbolTable}; @@ -638,7 +725,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); } @@ -668,7 +755,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&HashMap::new(), &mut tmp_symbols); + let res = e.evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Integer(expected))); } } @@ -685,7 +772,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Err(error::Expression::DivideByZero)); let ops = vec![ @@ -696,7 +783,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Err(error::Expression::Overflow)); let ops = vec![ @@ -707,7 +794,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Err(error::Expression::Overflow)); let ops = vec![ @@ -718,7 +805,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Err(error::Expression::Overflow)); } @@ -785,7 +872,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); } } @@ -809,7 +896,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); } } @@ -833,7 +920,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(result))); } } @@ -877,7 +964,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(*result))); } } @@ -916,7 +1003,8 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - e.evaluate(&values, &mut tmp_symbols).unwrap_err(); + e.evaluate(&values, &mut tmp_symbols, &Default::default()) + .unwrap_err(); } } } @@ -941,7 +1029,9 @@ mod tests { ]; let e2 = Expression { ops: ops1 }; - let res2 = e2.evaluate(&HashMap::new(), &mut symbols).unwrap(); + let res2 = e2 + .evaluate(&HashMap::new(), &mut symbols, &Default::default()) + .unwrap(); assert_eq!(res2, Term::Bool(true)); } @@ -959,7 +1049,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); let ops2 = vec![ @@ -977,7 +1069,9 @@ mod tests { let e2 = Expression { ops: ops2 }; println!("{:?}", e2.print(&symbols)); - let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res2 = e2 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res2, Term::Bool(false)); let ops3 = vec![ @@ -988,7 +1082,9 @@ mod tests { let e3 = Expression { ops: ops3 }; println!("{:?}", e3.print(&symbols)); - let err3 = e3.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap_err(); + let err3 = e3 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap_err(); assert_eq!(err3, error::Expression::InvalidType); } @@ -1013,7 +1109,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); let ops2 = vec![ @@ -1031,7 +1129,9 @@ mod tests { let e2 = Expression { ops: ops2 }; println!("{:?}", e2.print(&symbols)); - let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res2 = e2 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res2, Term::Bool(false)); let ops3 = vec![ @@ -1042,7 +1142,9 @@ mod tests { let e3 = Expression { ops: ops3 }; println!("{:?}", e3.print(&symbols)); - let err3 = e3.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap_err(); + let err3 = e3 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap_err(); assert_eq!(err3, error::Expression::InvalidType); } @@ -1088,7 +1190,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{}", e1.print(&symbols).unwrap()); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); } @@ -1115,7 +1219,7 @@ mod tests { let mut values = HashMap::new(); values.insert(p, Term::Null); - let res1 = e1.evaluate(&values, &mut tmp_symbols); + let res1 = e1.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res1, Err(error::Expression::ShadowedVariable)); let mut symbols = SymbolTable::new(); @@ -1157,7 +1261,7 @@ mod tests { let e2 = Expression { ops: ops2 }; println!("{}", e2.print(&symbols).unwrap()); - let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols); + let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()); assert_eq!(res2, Err(error::Expression::ShadowedVariable)); } @@ -1173,7 +1277,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); let ops = vec![ @@ -1184,7 +1288,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); let ops = vec![ @@ -1195,7 +1299,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); let ops = vec![ @@ -1206,7 +1310,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); let ops = vec![ @@ -1221,7 +1325,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); let ops = vec![ @@ -1236,7 +1340,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); let ops = vec![ @@ -1251,7 +1355,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); let ops = vec![ @@ -1266,7 +1370,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); // get @@ -1282,7 +1386,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Integer(1))); // get out of bounds @@ -1298,7 +1402,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Null)); // all @@ -1318,7 +1422,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); // any @@ -1337,7 +1443,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(false)); } @@ -1371,7 +1479,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); let ops = vec![ @@ -1395,7 +1503,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); let ops = vec![ @@ -1414,7 +1522,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); let ops = vec![ @@ -1433,7 +1541,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); // get @@ -1453,7 +1561,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Integer(0))); let ops = vec![ @@ -1472,7 +1580,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Integer(1))); // get non existing key @@ -1492,7 +1600,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Null)); let ops = vec![ @@ -1511,7 +1619,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Null)); // all @@ -1540,7 +1648,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); // any @@ -1569,7 +1679,154 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); } + #[test] + fn ffi() { + let mut symbols = SymbolTable::new(); + let i = symbols.insert("test"); + let j = symbols.insert("TeSt"); + let test_bin = symbols.insert("test_bin"); + let test_un = symbols.insert("test_un"); + let test_closure = symbols.insert("test_closure"); + let test_fn = symbols.insert("test_fn"); + let id_fn = symbols.insert("id"); + let mut tmp_symbols = TemporarySymbolTable::new(&symbols); + let ops = vec![ + Op::Value(Term::Integer(60)), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::Ffi(test_bin)), + Op::Value(Term::Str(i)), + Op::Value(Term::Str(j)), + Op::Binary(Binary::Ffi(test_bin)), + Op::Binary(Binary::And), + Op::Value(Term::Integer(42)), + Op::Unary(Unary::Ffi(test_un)), + Op::Binary(Binary::And), + Op::Value(Term::Integer(42)), + Op::Unary(Unary::Ffi(test_closure)), + Op::Binary(Binary::And), + Op::Value(Term::Str(i)), + Op::Unary(Unary::Ffi(test_closure)), + Op::Binary(Binary::And), + Op::Value(Term::Integer(42)), + Op::Unary(Unary::Ffi(test_fn)), + Op::Binary(Binary::And), + Op::Value(Term::Integer(42)), + Op::Unary(Unary::Ffi(id_fn)), + Op::Value(Term::Integer(42)), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Str(i)), + Op::Unary(Unary::Ffi(id_fn)), + Op::Value(Term::Str(i)), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Bool(true)), + Op::Unary(Unary::Ffi(id_fn)), + Op::Value(Term::Bool(true)), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Date(0)), + Op::Unary(Unary::Ffi(id_fn)), + Op::Value(Term::Date(0)), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Bytes(vec![42])), + Op::Unary(Unary::Ffi(id_fn)), + Op::Value(Term::Bytes(vec![42])), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Null), + Op::Unary(Unary::Ffi(id_fn)), + Op::Value(Term::Null), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Array(vec![Term::Null])), + Op::Unary(Unary::Ffi(id_fn)), + Op::Value(Term::Array(vec![Term::Null])), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Set(BTreeSet::from([Term::Null]))), + Op::Unary(Unary::Ffi(id_fn)), + Op::Value(Term::Set(BTreeSet::from([Term::Null]))), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + Op::Value(Term::Map(BTreeMap::from([ + (MapKey::Integer(42), Term::Null), + (MapKey::Str(i), Term::Null), + ]))), + Op::Unary(Unary::Ffi(id_fn)), + Op::Value(Term::Map(BTreeMap::from([ + (MapKey::Integer(42), Term::Null), + (MapKey::Str(i), Term::Null), + ]))), + Op::Binary(Binary::HeterogeneousEqual), + Op::Binary(Binary::And), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let mut extern_funcs: HashMap = Default::default(); + extern_funcs.insert( + "test_bin".to_owned(), + ExternFunc::new(Arc::new(|left, right| match (left, right) { + (builder::Term::Integer(left), Some(builder::Term::Integer(right))) => { + println!("{left} {right}"); + Ok(builder::Term::Bool((left % 60) == (right % 60))) + } + (builder::Term::Str(left), Some(builder::Term::Str(right))) => { + println!("{left} {right}"); + Ok(builder::Term::Bool( + left.to_lowercase() == right.to_lowercase(), + )) + } + _ => Err("Expected two strings or two integers".to_string()), + })), + ); + extern_funcs.insert( + "test_un".to_owned(), + ExternFunc::new(Arc::new(|left, right| match (&left, &right) { + (builder::Term::Integer(left), None) => Ok(builder::boolean(*left == 42)), + _ => { + println!("{left:?}, {right:?}"); + Err("expecting a single integer".to_string()) + } + })), + ); + extern_funcs.insert( + "id".to_string(), + ExternFunc::new(Arc::new(|left, right| match (left, right) { + (a, None) => Ok(a), + _ => Err("expecting a single value".to_string()), + })), + ); + let closed_over_int = 42; + let closed_over_string = "test".to_string(); + extern_funcs.insert( + "test_closure".to_owned(), + ExternFunc::new(Arc::new(move |left, right| match (&left, &right) { + (builder::Term::Integer(left), None) => { + Ok(builder::boolean(*left == closed_over_int)) + } + (builder::Term::Str(left), None) => { + Ok(builder::boolean(left == &closed_over_string)) + } + _ => { + println!("{left:?}, {right:?}"); + Err("expecting a single integer".to_string()) + } + })), + ); + extern_funcs.insert("test_fn".to_owned(), ExternFunc::new(Arc::new(toto))); + let res = e.evaluate(&values, &mut tmp_symbols, &extern_funcs); + assert_eq!(res, Ok(Term::Bool(true))); + } + + fn toto(_left: builder::Term, _right: Option) -> Result { + Ok(builder::Term::Bool(true)) + } } diff --git a/biscuit-auth/src/datalog/mod.rs b/biscuit-auth/src/datalog/mod.rs index f722f00c..4ddf5bcb 100644 --- a/biscuit-auth/src/datalog/mod.rs +++ b/biscuit-auth/src/datalog/mod.rs @@ -138,6 +138,7 @@ impl Rule { facts: IT, rule_origin: usize, symbols: &'a SymbolTable, + extern_funcs: &'a HashMap, ) -> impl Iterator> + 'a where IT: Iterator + Clone + 'a, @@ -149,7 +150,7 @@ impl Rule { .map(move |(origin, variables)| { let mut temporary_symbols = TemporarySymbolTable::new(symbols); for e in self.expressions.iter() { - match e.evaluate(&variables, &mut temporary_symbols) { + match e.evaluate(&variables, &mut temporary_symbols, extern_funcs) { Ok(Term::Bool(true)) => {} Ok(Term::Bool(false)) => return Ok((origin, variables, false)), Ok(_) => return Err(error::Expression::InvalidType), @@ -194,9 +195,10 @@ impl Rule { origin: usize, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { let fact_it = facts.iterator(scope); - let mut it = self.apply(fact_it, origin, symbols); + let mut it = self.apply(fact_it, origin, symbols, extern_funcs); let next = it.next(); match next { @@ -211,6 +213,7 @@ impl Rule { facts: &FactSet, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { let fact_it = facts.iterator(scope); let variables = MatchedVariables::new(self.variables_set()); @@ -221,7 +224,7 @@ impl Rule { let mut temporary_symbols = TemporarySymbolTable::new(symbols); for e in self.expressions.iter() { - match e.evaluate(&variables, &mut temporary_symbols) { + match e.evaluate(&variables, &mut temporary_symbols, extern_funcs) { Ok(Term::Bool(true)) => {} Ok(Term::Bool(false)) => { //println!("expr returned {:?}", res); @@ -619,7 +622,7 @@ impl World { for (scope, rules) in self.rules.inner.iter() { let it = self.facts.iterator(scope); for (origin, rule) in rules { - for res in rule.apply(it.clone(), *origin, symbols) { + for res in rule.apply(it.clone(), *origin, symbols, &limits.extern_funcs) { match res { Ok((origin, fact)) => { new_facts.insert(&origin, fact); @@ -690,11 +693,12 @@ impl World { origin: usize, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { let mut new_facts = FactSet::default(); let it = self.facts.iterator(scope); //new_facts.extend(rule.apply(it, origin, symbols)); - for res in rule.apply(it.clone(), origin, symbols) { + for res in rule.apply(it.clone(), origin, symbols, extern_funcs) { match res { Ok((origin, fact)) => { new_facts.insert(&origin, fact); @@ -714,8 +718,9 @@ impl World { origin: usize, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { - rule.find_match(&self.facts, origin, scope, symbols) + rule.find_match(&self.facts, origin, scope, symbols, extern_funcs) } pub fn query_match_all( @@ -723,8 +728,9 @@ impl World { rule: Rule, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { - rule.check_match_all(&self.facts, scope, symbols) + rule.check_match_all(&self.facts, scope, symbols, extern_funcs) } } @@ -737,6 +743,8 @@ pub struct RunLimits { pub max_iterations: u64, /// maximum execution time pub max_time: Duration, + + pub extern_funcs: HashMap, } impl std::default::Default for RunLimits { @@ -745,6 +753,7 @@ impl std::default::Default for RunLimits { max_facts: 1000, max_iterations: 100, max_time: Duration::from_millis(1), + extern_funcs: Default::default(), } } } @@ -981,7 +990,7 @@ fn contains_v3_3_op(expressions: &[Expression]) -> bool { expression.ops.iter().any(|op| match op { Op::Value(term) => contains_v3_3_term(term), Op::Closure(_, _) => true, - Op::Unary(Unary::TypeOf) => true, + Op::Unary(unary) => matches!(unary, Unary::TypeOf | Unary::Ffi(_)), Op::Binary(binary) => matches!( binary, Binary::HeterogeneousEqual @@ -990,8 +999,8 @@ fn contains_v3_3_op(expressions: &[Expression]) -> bool { | Binary::LazyOr | Binary::All | Binary::Any + | Binary::Ffi(_) ), - _ => false, }) }) } @@ -1047,7 +1056,8 @@ mod tests { println!("symbols: {:?}", syms); println!("testing r1: {}", syms.print_rule(&r1)); - let query_rule_result = w.query_rule(r1, 0, &[0].iter().collect(), &syms); + let query_rule_result = + w.query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()); println!("grandparents query_rules: {:?}", query_rule_result); println!("current facts: {:?}", w.facts); @@ -1092,6 +1102,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1109,7 +1120,8 @@ mod tests { ), 0, &[0].iter().collect(), - &syms + &syms, + &Default::default() ) ); println!( @@ -1125,7 +1137,8 @@ mod tests { ), 0, &[0].iter().collect(), - &syms + &syms, + &Default::default() ) ); w.add_fact(&[0].iter().collect(), fact(parent, &[&c, &e])); @@ -1143,6 +1156,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); println!("grandparents after inserting parent(C, E): {:?}", res); @@ -1218,6 +1232,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1267,6 +1282,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1353,6 +1369,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap() .iter_all() @@ -1433,7 +1450,9 @@ mod tests { ); println!("testing r1: {}", syms.print_rule(&r1)); - let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); for (_, fact) in res.iter_all() { println!("\t{}", syms.print_fact(fact)); } @@ -1471,7 +1490,9 @@ mod tests { ); println!("testing r2: {}", syms.print_rule(&r2)); - let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r2, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); for (_, fact) in res.iter_all() { println!("\t{}", syms.print_fact(fact)); } @@ -1534,6 +1555,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1585,6 +1607,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1630,6 +1653,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1675,6 +1699,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1698,6 +1723,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1740,7 +1766,9 @@ mod tests { println!("world:\n{}\n", syms.print_world(&w)); println!("\ntesting r1: {}\n", syms.print_rule(&r1)); - let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); for (_, fact) in res.iter_all() { println!("\t{}", syms.print_fact(fact)); } @@ -1779,7 +1807,9 @@ mod tests { ); println!("world:\n{}\n", syms.print_world(&w)); println!("\ntesting r1: {}\n", syms.print_rule(&r1)); - let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); println!("generated facts:"); for (_, fact) in res.iter_all() { @@ -1795,7 +1825,9 @@ mod tests { let r2 = rule(check, &[&read], &[pred(operation, &[&read])]); println!("world:\n{}\n", syms.print_world(&w)); println!("\ntesting r2: {}\n", syms.print_rule(&r2)); - let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r2, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); println!("generated facts:"); for (_, fact) in res.iter_all() { diff --git a/biscuit-auth/src/error.rs b/biscuit-auth/src/error.rs index 984369c5..5690e3eb 100644 --- a/biscuit-auth/src/error.rs +++ b/biscuit-auth/src/error.rs @@ -250,6 +250,10 @@ pub enum Expression { InvalidStack, #[error("Shadowed variable")] ShadowedVariable, + #[error("Undefined extern func: {0}")] + UndefinedExtern(String), + #[error("Error while evaluating extern func {0}: {1}")] + ExternEvalError(String, String), } /// runtime limits errors diff --git a/biscuit-auth/src/format/convert.rs b/biscuit-auth/src/format/convert.rs index c829c7b5..df8fa290 100644 --- a/biscuit-auth/src/format/convert.rs +++ b/biscuit-auth/src/format/convert.rs @@ -668,7 +668,12 @@ pub mod v2 { Unary::Parens => Kind::Parens, Unary::Length => Kind::Length, Unary::TypeOf => Kind::TypeOf, + Unary::Ffi(_) => Kind::Ffi, } as i32, + ffi_name: match u { + Unary::Ffi(name) => Some(name.to_owned()), + _ => None, + }, }) } Op::Binary(b) => { @@ -704,7 +709,12 @@ pub mod v2 { Binary::All => Kind::All, Binary::Any => Kind::Any, Binary::Get => Kind::Get, + Binary::Ffi(_) => Kind::Ffi, } as i32, + ffi_name: match b { + Binary::Ffi(name) => Some(name.to_owned()), + _ => None, + }, }) } Op::Closure(params, ops) => schema::op::Content::Closure(schema::OpClosure { @@ -728,54 +738,86 @@ pub mod v2 { use schema::{op, op_binary, op_unary}; Ok(match op.content.as_ref() { Some(op::Content::Value(id)) => Op::Value(proto_id_to_token_term(id)?), - Some(op::Content::Unary(u)) => match op_unary::Kind::from_i32(u.kind) { - Some(op_unary::Kind::Negate) => Op::Unary(Unary::Negate), - Some(op_unary::Kind::Parens) => Op::Unary(Unary::Parens), - Some(op_unary::Kind::Length) => Op::Unary(Unary::Length), - Some(op_unary::Kind::TypeOf) => Op::Unary(Unary::TypeOf), - None => { - return Err(error::Format::DeserializationError( - "deserialization error: unary operation is empty".to_string(), - )) - } - }, - Some(op::Content::Binary(b)) => match op_binary::Kind::from_i32(b.kind) { - Some(op_binary::Kind::LessThan) => Op::Binary(Binary::LessThan), - Some(op_binary::Kind::GreaterThan) => Op::Binary(Binary::GreaterThan), - Some(op_binary::Kind::LessOrEqual) => Op::Binary(Binary::LessOrEqual), - Some(op_binary::Kind::GreaterOrEqual) => Op::Binary(Binary::GreaterOrEqual), - Some(op_binary::Kind::Equal) => Op::Binary(Binary::Equal), - Some(op_binary::Kind::Contains) => Op::Binary(Binary::Contains), - Some(op_binary::Kind::Prefix) => Op::Binary(Binary::Prefix), - Some(op_binary::Kind::Suffix) => Op::Binary(Binary::Suffix), - Some(op_binary::Kind::Regex) => Op::Binary(Binary::Regex), - Some(op_binary::Kind::Add) => Op::Binary(Binary::Add), - Some(op_binary::Kind::Sub) => Op::Binary(Binary::Sub), - Some(op_binary::Kind::Mul) => Op::Binary(Binary::Mul), - Some(op_binary::Kind::Div) => Op::Binary(Binary::Div), - Some(op_binary::Kind::And) => Op::Binary(Binary::And), - Some(op_binary::Kind::Or) => Op::Binary(Binary::Or), - Some(op_binary::Kind::Intersection) => Op::Binary(Binary::Intersection), - Some(op_binary::Kind::Union) => Op::Binary(Binary::Union), - Some(op_binary::Kind::BitwiseAnd) => Op::Binary(Binary::BitwiseAnd), - Some(op_binary::Kind::BitwiseOr) => Op::Binary(Binary::BitwiseOr), - Some(op_binary::Kind::BitwiseXor) => Op::Binary(Binary::BitwiseXor), - Some(op_binary::Kind::NotEqual) => Op::Binary(Binary::NotEqual), - Some(op_binary::Kind::HeterogeneousEqual) => Op::Binary(Binary::HeterogeneousEqual), - Some(op_binary::Kind::HeterogeneousNotEqual) => { - Op::Binary(Binary::HeterogeneousNotEqual) + Some(op::Content::Unary(u)) => { + match (op_unary::Kind::from_i32(u.kind), u.ffi_name.as_ref()) { + (Some(op_unary::Kind::Negate), None) => Op::Unary(Unary::Negate), + (Some(op_unary::Kind::Parens), None) => Op::Unary(Unary::Parens), + (Some(op_unary::Kind::Length), None) => Op::Unary(Unary::Length), + (Some(op_unary::Kind::TypeOf), None) => Op::Unary(Unary::TypeOf), + (Some(op_unary::Kind::Ffi), Some(n)) => Op::Unary(Unary::Ffi(*n)), + (Some(op_unary::Kind::Ffi), None) => { + return Err(error::Format::DeserializationError( + "deserialization error: missing ffi name".to_string(), + )) + } + (Some(_), Some(_)) => { + return Err(error::Format::DeserializationError( + "deserialization error: ffi name set on a regular unary operation" + .to_string(), + )) + } + (None, _) => { + return Err(error::Format::DeserializationError( + "deserialization error: unary operation is empty".to_string(), + )) + } } - Some(op_binary::Kind::LazyAnd) => Op::Binary(Binary::LazyAnd), - Some(op_binary::Kind::LazyOr) => Op::Binary(Binary::LazyOr), - Some(op_binary::Kind::All) => Op::Binary(Binary::All), - Some(op_binary::Kind::Any) => Op::Binary(Binary::Any), - Some(op_binary::Kind::Get) => Op::Binary(Binary::Get), - None => { - return Err(error::Format::DeserializationError( - "deserialization error: binary operation is empty".to_string(), - )) + } + Some(op::Content::Binary(b)) => { + match (op_binary::Kind::from_i32(b.kind), b.ffi_name.as_ref()) { + (Some(op_binary::Kind::LessThan), None) => Op::Binary(Binary::LessThan), + (Some(op_binary::Kind::GreaterThan), None) => Op::Binary(Binary::GreaterThan), + (Some(op_binary::Kind::LessOrEqual), None) => Op::Binary(Binary::LessOrEqual), + (Some(op_binary::Kind::GreaterOrEqual), None) => { + Op::Binary(Binary::GreaterOrEqual) + } + (Some(op_binary::Kind::Equal), None) => Op::Binary(Binary::Equal), + (Some(op_binary::Kind::Contains), None) => Op::Binary(Binary::Contains), + (Some(op_binary::Kind::Prefix), None) => Op::Binary(Binary::Prefix), + (Some(op_binary::Kind::Suffix), None) => Op::Binary(Binary::Suffix), + (Some(op_binary::Kind::Regex), None) => Op::Binary(Binary::Regex), + (Some(op_binary::Kind::Add), None) => Op::Binary(Binary::Add), + (Some(op_binary::Kind::Sub), None) => Op::Binary(Binary::Sub), + (Some(op_binary::Kind::Mul), None) => Op::Binary(Binary::Mul), + (Some(op_binary::Kind::Div), None) => Op::Binary(Binary::Div), + (Some(op_binary::Kind::And), None) => Op::Binary(Binary::And), + (Some(op_binary::Kind::Or), None) => Op::Binary(Binary::Or), + (Some(op_binary::Kind::Intersection), None) => Op::Binary(Binary::Intersection), + (Some(op_binary::Kind::Union), None) => Op::Binary(Binary::Union), + (Some(op_binary::Kind::BitwiseAnd), None) => Op::Binary(Binary::BitwiseAnd), + (Some(op_binary::Kind::BitwiseOr), None) => Op::Binary(Binary::BitwiseOr), + (Some(op_binary::Kind::BitwiseXor), None) => Op::Binary(Binary::BitwiseXor), + (Some(op_binary::Kind::NotEqual), None) => Op::Binary(Binary::NotEqual), + (Some(op_binary::Kind::HeterogeneousEqual), None) => { + Op::Binary(Binary::HeterogeneousEqual) + } + (Some(op_binary::Kind::HeterogeneousNotEqual), None) => { + Op::Binary(Binary::HeterogeneousNotEqual) + } + (Some(op_binary::Kind::LazyAnd), None) => Op::Binary(Binary::LazyAnd), + (Some(op_binary::Kind::LazyOr), None) => Op::Binary(Binary::LazyOr), + (Some(op_binary::Kind::All), None) => Op::Binary(Binary::All), + (Some(op_binary::Kind::Any), None) => Op::Binary(Binary::Any), + (Some(op_binary::Kind::Get), None) => Op::Binary(Binary::Get), + (Some(op_binary::Kind::Ffi), Some(n)) => Op::Binary(Binary::Ffi(*n)), + (Some(op_binary::Kind::Ffi), None) => { + return Err(error::Format::DeserializationError( + "deserialization error: missing ffi name".to_string(), + )) + } + (Some(_), Some(_)) => { + return Err(error::Format::DeserializationError( + "deserialization error: ffi name set on a regular binary operation" + .to_string(), + )) + } + (None, _) => { + return Err(error::Format::DeserializationError( + "deserialization error: binary operation is empty".to_string(), + )) + } } - }, + } Some(op::Content::Closure(op_closure)) => Op::Closure( op_closure.params.clone(), op_closure diff --git a/biscuit-auth/src/format/schema.proto b/biscuit-auth/src/format/schema.proto index 8e716f6e..c71108d8 100644 --- a/biscuit-auth/src/format/schema.proto +++ b/biscuit-auth/src/format/schema.proto @@ -147,9 +147,11 @@ message OpUnary { Parens = 1; Length = 2; TypeOf = 3; + Ffi = 4; } required Kind kind = 1; + optional uint64 ffiName = 2; } message OpBinary { @@ -182,9 +184,11 @@ message OpBinary { All = 25; Any = 26; Get = 27; + Ffi = 28; } required Kind kind = 1; + optional uint64 ffiName = 2; } message OpClosure { diff --git a/biscuit-auth/src/format/schema.rs b/biscuit-auth/src/format/schema.rs index 98a00bd7..ede2c0b2 100644 --- a/biscuit-auth/src/format/schema.rs +++ b/biscuit-auth/src/format/schema.rs @@ -233,6 +233,8 @@ pub mod op { pub struct OpUnary { #[prost(enumeration="op_unary::Kind", required, tag="1")] pub kind: i32, + #[prost(uint64, optional, tag="2")] + pub ffi_name: ::core::option::Option, } /// Nested message and enum types in `OpUnary`. pub mod op_unary { @@ -243,12 +245,15 @@ pub mod op_unary { Parens = 1, Length = 2, TypeOf = 3, + Ffi = 4, } } #[derive(Clone, PartialEq, ::prost::Message)] pub struct OpBinary { #[prost(enumeration="op_binary::Kind", required, tag="1")] pub kind: i32, + #[prost(uint64, optional, tag="2")] + pub ffi_name: ::core::option::Option, } /// Nested message and enum types in `OpBinary`. pub mod op_binary { @@ -283,6 +288,7 @@ pub mod op_binary { All = 25, Any = 26, Get = 27, + Ffi = 28, } } #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/biscuit-auth/src/parser.rs b/biscuit-auth/src/parser.rs index 59510afb..77f0c3ef 100644 --- a/biscuit-auth/src/parser.rs +++ b/biscuit-auth/src/parser.rs @@ -383,7 +383,11 @@ mod tests { println!("print: {}", e.print(&syms).unwrap()); let h = HashMap::new(); let result = e - .evaluate(&h, &mut TemporarySymbolTable::new(&syms)) + .evaluate( + &h, + &mut TemporarySymbolTable::new(&syms), + &Default::default(), + ) .unwrap(); println!("evaluates to: {:?}", result); @@ -414,7 +418,11 @@ mod tests { println!("print: {}", e.print(&syms).unwrap()); let h = HashMap::new(); let result = e - .evaluate(&h, &mut TemporarySymbolTable::new(&syms)) + .evaluate( + &h, + &mut TemporarySymbolTable::new(&syms), + &Default::default(), + ) .unwrap(); println!("evaluates to: {:?}", result); diff --git a/biscuit-auth/src/token/authorizer.rs b/biscuit-auth/src/token/authorizer.rs index 5d9a2e99..00080b16 100644 --- a/biscuit-auth/src/token/authorizer.rs +++ b/biscuit-auth/src/token/authorizer.rs @@ -469,10 +469,15 @@ impl Authorizer { &self.public_key_to_block_id, ); + let extern_binary = limits.extern_funcs.clone(); self.world.run_with_limits(&self.symbols, limits)?; - let res = self - .world - .query_rule(rule, usize::MAX, &rule_trusted_origins, &self.symbols)?; + let res = self.world.query_rule( + rule, + usize::MAX, + &rule_trusted_origins, + &self.symbols, + &extern_binary, + )?; res.inner .into_iter() @@ -552,6 +557,7 @@ impl Authorizer { rule: datalog::Rule, limits: AuthorizerLimits, ) -> Result, error::Token> { + let extern_binary = limits.extern_funcs.clone(); self.world.run_with_limits(&self.symbols, limits)?; let rule_trusted_origins = if rule.scopes.is_empty() { @@ -568,9 +574,13 @@ impl Authorizer { ) }; - let res = self - .world - .query_rule(rule, 0, &rule_trusted_origins, &self.symbols)?; + let res = self.world.query_rule( + rule, + 0, + &rule_trusted_origins, + &self.symbols, + &extern_binary, + )?; let r: HashSet<_> = res.into_iter().map(|(_, fact)| fact).collect(); @@ -741,16 +751,20 @@ impl Authorizer { usize::MAX, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, + )?, + CheckKind::All => self.world.query_match_all( + query, + &rule_trusted_origins, + &self.symbols, + &limits.extern_funcs, )?, - CheckKind::All => { - self.world - .query_match_all(query, &rule_trusted_origins, &self.symbols)? - } CheckKind::Reject => !self.world.query_match( query, usize::MAX, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, }; @@ -799,17 +813,20 @@ impl Authorizer { 0, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, CheckKind::All => self.world.query_match_all( query.clone(), &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, CheckKind::Reject => !self.world.query_match( query.clone(), 0, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, }; @@ -849,6 +866,7 @@ impl Authorizer { usize::MAX, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?; let now = Instant::now(); @@ -898,17 +916,20 @@ impl Authorizer { i + 1, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, CheckKind::All => self.world.query_match_all( query.clone(), &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, CheckKind::Reject => !self.world.query_match( query.clone(), i + 1, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, }; diff --git a/biscuit-auth/src/token/authorizer/snapshot.rs b/biscuit-auth/src/token/authorizer/snapshot.rs index 373aff9f..247dd890 100644 --- a/biscuit-auth/src/token/authorizer/snapshot.rs +++ b/biscuit-auth/src/token/authorizer/snapshot.rs @@ -31,6 +31,7 @@ impl super::Authorizer { max_facts: limits.max_facts, max_iterations: limits.max_iterations, max_time: Duration::from_nanos(limits.max_time), + extern_funcs: Default::default(), }; let execution_time = Duration::from_nanos(execution_time); diff --git a/biscuit-auth/src/token/builder.rs b/biscuit-auth/src/token/builder.rs index 97922451..248afe6e 100644 --- a/biscuit-auth/src/token/builder.rs +++ b/biscuit-auth/src/token/builder.rs @@ -1,7 +1,7 @@ //! helper functions and structure to create tokens and blocks use super::{default_symbol_table, Biscuit, Block}; use crate::crypto::{KeyPair, PublicKey}; -use crate::datalog::{self, get_schema_version, SymbolTable}; +use crate::datalog::{self, get_schema_version, SymbolTable, TemporarySymbolTable}; use crate::error; use crate::token::builder_ext::BuilderExt; use biscuit_parser::parser::parse_block_source; @@ -17,7 +17,10 @@ use std::{ }; // reexport those because the builder uses the same definitions -pub use crate::datalog::{Binary, Expression as DatalogExpression, Op as DatalogOp, Unary}; +pub use crate::datalog::{ + Binary as DatalogBinary, Expression as DatalogExpression, Op as DatalogOp, + Unary as DatalogUnary, +}; /// creates a Block content to append to an existing token #[derive(Clone, Debug, Default)] @@ -419,6 +422,50 @@ pub trait Convert: Sized { } } +/// Builder for a unary operation +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Unary { + Negate, + Parens, + Length, + TypeOf, + Ffi(String), +} + +/// Builder for a binary operation +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Binary { + LessThan, + GreaterThan, + LessOrEqual, + GreaterOrEqual, + Equal, + Contains, + Prefix, + Suffix, + Regex, + Add, + Sub, + Mul, + Div, + And, + Or, + Intersection, + Union, + BitwiseAnd, + BitwiseOr, + BitwiseXor, + NotEqual, + HeterogeneousEqual, + HeterogeneousNotEqual, + LazyAnd, + LazyOr, + All, + Any, + Get, + Ffi(String), +} + /// Builder for a Datalog value #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Term { @@ -519,6 +566,98 @@ pub enum MapKey { Parameter(String), } +impl Term { + pub fn to_datalog(self, symbols: &mut TemporarySymbolTable) -> datalog::Term { + match self { + Term::Variable(s) => datalog::Term::Variable(symbols.insert(&s) as u32), + Term::Integer(i) => datalog::Term::Integer(i), + Term::Str(s) => datalog::Term::Str(symbols.insert(&s)), + Term::Date(d) => datalog::Term::Date(d), + Term::Bytes(s) => datalog::Term::Bytes(s), + Term::Bool(b) => datalog::Term::Bool(b), + Term::Set(s) => { + datalog::Term::Set(s.into_iter().map(|i| i.to_datalog(symbols)).collect()) + } + Term::Null => datalog::Term::Null, + Term::Array(a) => { + datalog::Term::Array(a.into_iter().map(|i| i.to_datalog(symbols)).collect()) + } + Term::Map(m) => datalog::Term::Map( + m.into_iter() + .map(|(k, i)| { + ( + match k { + MapKey::Integer(i) => datalog::MapKey::Integer(i), + MapKey::Str(s) => datalog::MapKey::Str(symbols.insert(&s)), + // The error is caught in the `add_xxx` functions, so this should + // not happen™ + MapKey::Parameter(s) => panic!("Remaining parameter {}", &s), + }, + i.to_datalog(symbols), + ) + }) + .collect(), + ), + // The error is caught in the `add_xxx` functions, so this should + // not happen™ + Term::Parameter(s) => panic!("Remaining parameter {}", &s), + } + } + + pub fn from_datalog( + term: datalog::Term, + symbols: &TemporarySymbolTable, + ) -> Result { + Ok(match term { + datalog::Term::Variable(s) => Term::Variable( + symbols + .get_symbol(s as u64) + .ok_or(error::Expression::UnknownVariable(s))? + .to_string(), + ), + datalog::Term::Integer(i) => Term::Integer(i), + datalog::Term::Str(s) => Term::Str( + symbols + .get_symbol(s) + .ok_or(error::Expression::UnknownSymbol(s))? + .to_string(), + ), + datalog::Term::Date(d) => Term::Date(d), + datalog::Term::Bytes(s) => Term::Bytes(s), + datalog::Term::Bool(b) => Term::Bool(b), + datalog::Term::Set(s) => Term::Set( + s.into_iter() + .map(|i| Self::from_datalog(i, symbols)) + .collect::>()?, + ), + datalog::Term::Null => Term::Null, + datalog::Term::Array(a) => Term::Array( + a.into_iter() + .map(|i| Self::from_datalog(i, symbols)) + .collect::>()?, + ), + datalog::Term::Map(m) => Term::Map( + m.into_iter() + .map(|(k, i)| { + Ok(( + match k { + datalog::MapKey::Integer(i) => MapKey::Integer(i), + datalog::MapKey::Str(s) => MapKey::Str( + symbols + .get_symbol(s) + .ok_or(error::Expression::UnknownSymbol(s))? + .to_string(), + ), + }, + Self::from_datalog(i, symbols)?, + )) + }) + .collect::>()?, + ), + }) + } +} + impl Convert for Term { fn convert(&self, symbols: &mut SymbolTable) -> datalog::Term { match self { @@ -1099,8 +1238,8 @@ impl Convert for Op { fn convert(&self, symbols: &mut SymbolTable) -> datalog::Op { match self { Op::Value(t) => datalog::Op::Value(t.convert(symbols)), - Op::Unary(u) => datalog::Op::Unary(u.clone()), - Op::Binary(b) => datalog::Op::Binary(b.clone()), + Op::Unary(u) => datalog::Op::Unary(u.convert(symbols)), + Op::Binary(b) => datalog::Op::Binary(b.convert(symbols)), Op::Closure(ps, os) => datalog::Op::Closure( ps.iter().map(|p| symbols.insert(p) as u32).collect(), os.iter().map(|o| o.convert(symbols)).collect(), @@ -1111,8 +1250,8 @@ impl Convert for Op { fn convert_from(op: &datalog::Op, symbols: &SymbolTable) -> Result { Ok(match op { datalog::Op::Value(t) => Op::Value(Term::convert_from(t, symbols)?), - datalog::Op::Unary(u) => Op::Unary(u.clone()), - datalog::Op::Binary(b) => Op::Binary(b.clone()), + datalog::Op::Unary(u) => Op::Unary(Unary::convert_from(u, symbols)?), + datalog::Op::Binary(b) => Op::Binary(Binary::convert_from(b, symbols)?), datalog::Op::Closure(ps, os) => Op::Closure( ps.iter() .map(|p| symbols.print_symbol(*p as u64)) @@ -1138,6 +1277,28 @@ impl From for Op { } } +impl Convert for Unary { + fn convert(&self, symbols: &mut SymbolTable) -> datalog::Unary { + match self { + Unary::Negate => datalog::Unary::Negate, + Unary::Parens => datalog::Unary::Parens, + Unary::Length => datalog::Unary::Length, + Unary::TypeOf => datalog::Unary::TypeOf, + Unary::Ffi(n) => datalog::Unary::Ffi(symbols.insert(n)), + } + } + + fn convert_from(f: &datalog::Unary, symbols: &SymbolTable) -> Result { + match f { + datalog::Unary::Negate => Ok(Unary::Negate), + datalog::Unary::Parens => Ok(Unary::Parens), + datalog::Unary::Length => Ok(Unary::Length), + datalog::Unary::TypeOf => Ok(Unary::TypeOf), + datalog::Unary::Ffi(i) => Ok(Unary::Ffi(symbols.print_symbol(*i)?)), + } + } +} + impl From for Unary { fn from(unary: biscuit_parser::builder::Unary) -> Self { match unary { @@ -1145,6 +1306,77 @@ impl From for Unary { biscuit_parser::builder::Unary::Parens => Unary::Parens, biscuit_parser::builder::Unary::Length => Unary::Length, biscuit_parser::builder::Unary::TypeOf => Unary::TypeOf, + biscuit_parser::builder::Unary::Ffi(name) => Unary::Ffi(name), + } + } +} + +impl Convert for Binary { + fn convert(&self, symbols: &mut SymbolTable) -> datalog::Binary { + match self { + Binary::LessThan => datalog::Binary::LessThan, + Binary::GreaterThan => datalog::Binary::GreaterThan, + Binary::LessOrEqual => datalog::Binary::LessOrEqual, + Binary::GreaterOrEqual => datalog::Binary::GreaterOrEqual, + Binary::Equal => datalog::Binary::Equal, + Binary::Contains => datalog::Binary::Contains, + Binary::Prefix => datalog::Binary::Prefix, + Binary::Suffix => datalog::Binary::Suffix, + Binary::Regex => datalog::Binary::Regex, + Binary::Add => datalog::Binary::Add, + Binary::Sub => datalog::Binary::Sub, + Binary::Mul => datalog::Binary::Mul, + Binary::Div => datalog::Binary::Div, + Binary::And => datalog::Binary::And, + Binary::Or => datalog::Binary::Or, + Binary::Intersection => datalog::Binary::Intersection, + Binary::Union => datalog::Binary::Union, + Binary::BitwiseAnd => datalog::Binary::BitwiseAnd, + Binary::BitwiseOr => datalog::Binary::BitwiseOr, + Binary::BitwiseXor => datalog::Binary::BitwiseXor, + Binary::NotEqual => datalog::Binary::NotEqual, + Binary::HeterogeneousEqual => datalog::Binary::HeterogeneousEqual, + Binary::HeterogeneousNotEqual => datalog::Binary::HeterogeneousNotEqual, + Binary::LazyAnd => datalog::Binary::LazyAnd, + Binary::LazyOr => datalog::Binary::LazyOr, + Binary::All => datalog::Binary::All, + Binary::Any => datalog::Binary::Any, + Binary::Get => datalog::Binary::Get, + Binary::Ffi(n) => datalog::Binary::Ffi(symbols.insert(n)), + } + } + + fn convert_from(f: &datalog::Binary, symbols: &SymbolTable) -> Result { + match f { + datalog::Binary::LessThan => Ok(Binary::LessThan), + datalog::Binary::GreaterThan => Ok(Binary::GreaterThan), + datalog::Binary::LessOrEqual => Ok(Binary::LessOrEqual), + datalog::Binary::GreaterOrEqual => Ok(Binary::GreaterOrEqual), + datalog::Binary::Equal => Ok(Binary::Equal), + datalog::Binary::Contains => Ok(Binary::Contains), + datalog::Binary::Prefix => Ok(Binary::Prefix), + datalog::Binary::Suffix => Ok(Binary::Suffix), + datalog::Binary::Regex => Ok(Binary::Regex), + datalog::Binary::Add => Ok(Binary::Add), + datalog::Binary::Sub => Ok(Binary::Sub), + datalog::Binary::Mul => Ok(Binary::Mul), + datalog::Binary::Div => Ok(Binary::Div), + datalog::Binary::And => Ok(Binary::And), + datalog::Binary::Or => Ok(Binary::Or), + datalog::Binary::Intersection => Ok(Binary::Intersection), + datalog::Binary::Union => Ok(Binary::Union), + datalog::Binary::BitwiseAnd => Ok(Binary::BitwiseAnd), + datalog::Binary::BitwiseOr => Ok(Binary::BitwiseOr), + datalog::Binary::BitwiseXor => Ok(Binary::BitwiseXor), + datalog::Binary::NotEqual => Ok(Binary::NotEqual), + datalog::Binary::HeterogeneousEqual => Ok(Binary::HeterogeneousEqual), + datalog::Binary::HeterogeneousNotEqual => Ok(Binary::HeterogeneousNotEqual), + datalog::Binary::LazyAnd => Ok(Binary::LazyAnd), + datalog::Binary::LazyOr => Ok(Binary::LazyOr), + datalog::Binary::All => Ok(Binary::All), + datalog::Binary::Any => Ok(Binary::Any), + datalog::Binary::Get => Ok(Binary::Get), + datalog::Binary::Ffi(i) => Ok(Binary::Ffi(symbols.print_symbol(*i)?)), } } } @@ -1180,6 +1412,7 @@ impl From for Binary { biscuit_parser::builder::Binary::All => Binary::All, biscuit_parser::builder::Binary::Any => Binary::Any, biscuit_parser::builder::Binary::Get => Binary::Get, + biscuit_parser::builder::Binary::Ffi(name) => Binary::Ffi(name), } } } diff --git a/biscuit-auth/src/token/mod.rs b/biscuit-auth/src/token/mod.rs index a6498699..66d43acf 100644 --- a/biscuit-auth/src/token/mod.rs +++ b/biscuit-auth/src/token/mod.rs @@ -36,7 +36,7 @@ pub const MAX_SCHEMA_VERSION: u32 = 6; pub const DATALOG_3_1: u32 = 4; /// starting version for 3rd party blocks (datalog 3.2) pub const DATALOG_3_2: u32 = 5; -/// starting version for datalog 3.3 features (reject if, closures, array/map, null, …) +/// starting version for datalog 3.3 features (reject if, closures, array/map, null, external functions, …) pub const DATALOG_3_3: u32 = 6; /// some symbols are predefined and available in every implementation, to avoid diff --git a/biscuit-auth/tests/macros.rs b/biscuit-auth/tests/macros.rs index 86f0016a..96d6b5b0 100644 --- a/biscuit-auth/tests/macros.rs +++ b/biscuit-auth/tests/macros.rs @@ -34,6 +34,14 @@ check if "my_value".starts_with("my"); check if {false, true}.any($p -> true); "#, ); + + let b = block!(r#"check if "test".extern::toto() && "test".extern::test("test");"#); + + assert_eq!( + b.to_string(), + r#"check if "test".extern::toto() && "test".extern::test("test"); +"# + ); } #[test] diff --git a/biscuit-parser/src/builder.rs b/biscuit-parser/src/builder.rs index 5bf0634e..4a993157 100644 --- a/biscuit-parser/src/builder.rs +++ b/biscuit-parser/src/builder.rs @@ -284,6 +284,7 @@ pub enum Unary { Parens, Length, TypeOf, + Ffi(String), } #[derive(Debug, Clone, PartialEq, Eq)] @@ -316,6 +317,7 @@ pub enum Binary { All, Any, Get, + Ffi(String), } #[cfg(feature = "datalog-macro")] @@ -339,10 +341,11 @@ impl ToTokens for Op { impl ToTokens for Unary { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { tokens.extend(match self { - Unary::Negate => quote! {::biscuit_auth::datalog::Unary::Negate }, - Unary::Parens => quote! {::biscuit_auth::datalog::Unary::Parens }, - Unary::Length => quote! {::biscuit_auth::datalog::Unary::Length }, - Unary::TypeOf => quote! {::biscuit_auth::datalog::Unary::TypeOf }, + Unary::Negate => quote! {::biscuit_auth::builder::Unary::Negate }, + Unary::Parens => quote! {::biscuit_auth::builder::Unary::Parens }, + Unary::Length => quote! {::biscuit_auth::builder::Unary::Length }, + Unary::TypeOf => quote! {::biscuit_auth::builder::Unary::TypeOf }, + Unary::Ffi(name) => quote! {::biscuit_auth::builder::Unary::Ffi(#name.to_string()) }, }); } } @@ -351,38 +354,39 @@ impl ToTokens for Unary { impl ToTokens for Binary { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { tokens.extend(match self { - Binary::LessThan => quote! { ::biscuit_auth::datalog::Binary::LessThan }, - Binary::GreaterThan => quote! { ::biscuit_auth::datalog::Binary::GreaterThan }, - Binary::LessOrEqual => quote! { ::biscuit_auth::datalog::Binary::LessOrEqual }, - Binary::GreaterOrEqual => quote! { ::biscuit_auth::datalog::Binary::GreaterOrEqual }, - Binary::Equal => quote! { ::biscuit_auth::datalog::Binary::Equal }, - Binary::Contains => quote! { ::biscuit_auth::datalog::Binary::Contains }, - Binary::Prefix => quote! { ::biscuit_auth::datalog::Binary::Prefix }, - Binary::Suffix => quote! { ::biscuit_auth::datalog::Binary::Suffix }, - Binary::Regex => quote! { ::biscuit_auth::datalog::Binary::Regex }, - Binary::Add => quote! { ::biscuit_auth::datalog::Binary::Add }, - Binary::Sub => quote! { ::biscuit_auth::datalog::Binary::Sub }, - Binary::Mul => quote! { ::biscuit_auth::datalog::Binary::Mul }, - Binary::Div => quote! { ::biscuit_auth::datalog::Binary::Div }, - Binary::And => quote! { ::biscuit_auth::datalog::Binary::And }, - Binary::Or => quote! { ::biscuit_auth::datalog::Binary::Or }, - Binary::Intersection => quote! { ::biscuit_auth::datalog::Binary::Intersection }, - Binary::Union => quote! { ::biscuit_auth::datalog::Binary::Union }, - Binary::BitwiseAnd => quote! { ::biscuit_auth::datalog::Binary::BitwiseAnd }, - Binary::BitwiseOr => quote! { ::biscuit_auth::datalog::Binary::BitwiseOr }, - Binary::BitwiseXor => quote! { ::biscuit_auth::datalog::Binary::BitwiseXor }, - Binary::NotEqual => quote! { ::biscuit_auth::datalog::Binary::NotEqual }, + Binary::LessThan => quote! { ::biscuit_auth::builder::Binary::LessThan }, + Binary::GreaterThan => quote! { ::biscuit_auth::builder::Binary::GreaterThan }, + Binary::LessOrEqual => quote! { ::biscuit_auth::builder::Binary::LessOrEqual }, + Binary::GreaterOrEqual => quote! { ::biscuit_auth::builder::Binary::GreaterOrEqual }, + Binary::Equal => quote! { ::biscuit_auth::builder::Binary::Equal }, + Binary::Contains => quote! { ::biscuit_auth::builder::Binary::Contains }, + Binary::Prefix => quote! { ::biscuit_auth::builder::Binary::Prefix }, + Binary::Suffix => quote! { ::biscuit_auth::builder::Binary::Suffix }, + Binary::Regex => quote! { ::biscuit_auth::builder::Binary::Regex }, + Binary::Add => quote! { ::biscuit_auth::builder::Binary::Add }, + Binary::Sub => quote! { ::biscuit_auth::builder::Binary::Sub }, + Binary::Mul => quote! { ::biscuit_auth::builder::Binary::Mul }, + Binary::Div => quote! { ::biscuit_auth::builder::Binary::Div }, + Binary::And => quote! { ::biscuit_auth::builder::Binary::And }, + Binary::Or => quote! { ::biscuit_auth::builder::Binary::Or }, + Binary::Intersection => quote! { ::biscuit_auth::builder::Binary::Intersection }, + Binary::Union => quote! { ::biscuit_auth::builder::Binary::Union }, + Binary::BitwiseAnd => quote! { ::biscuit_auth::builder::Binary::BitwiseAnd }, + Binary::BitwiseOr => quote! { ::biscuit_auth::builder::Binary::BitwiseOr }, + Binary::BitwiseXor => quote! { ::biscuit_auth::builder::Binary::BitwiseXor }, + Binary::NotEqual => quote! { ::biscuit_auth::builder::Binary::NotEqual }, Binary::HeterogeneousEqual => { - quote! { ::biscuit_auth::datalog::Binary::HeterogeneousEqual} + quote! { ::biscuit_auth::builder::Binary::HeterogeneousEqual} } Binary::HeterogeneousNotEqual => { - quote! { ::biscuit_auth::datalog::Binary::HeterogeneousNotEqual} + quote! { ::biscuit_auth::builder::Binary::HeterogeneousNotEqual} } - Binary::LazyAnd => quote! { ::biscuit_auth::datalog::Binary::LazyAnd }, - Binary::LazyOr => quote! { ::biscuit_auth::datalog::Binary::LazyOr }, - Binary::All => quote! { ::biscuit_auth::datalog::Binary::All }, - Binary::Any => quote! { ::biscuit_auth::datalog::Binary::Any }, - Binary::Get => quote! { ::biscuit_auth::datalog::Binary::Get }, + Binary::LazyAnd => quote! { ::biscuit_auth::builder::Binary::LazyAnd }, + Binary::LazyOr => quote! { ::biscuit_auth::builder::Binary::LazyOr }, + Binary::All => quote! { ::biscuit_auth::builder::Binary::All }, + Binary::Any => quote! { ::biscuit_auth::builder::Binary::Any }, + Binary::Get => quote! { ::biscuit_auth::builder::Binary::Get }, + Binary::Ffi(name) => quote! {::biscuit_auth::builder::Binary::Ffi(#name.to_string()) }, }); } } diff --git a/biscuit-parser/src/parser.rs b/biscuit-parser/src/parser.rs index 0b44f108..a810b085 100644 --- a/biscuit-parser/src/parser.rs +++ b/biscuit-parser/src/parser.rs @@ -497,6 +497,16 @@ fn binary_op_7(i: &str) -> IResult<&str, builder::Binary, Error> { alt((value(Binary::Mul, tag("*")), value(Binary::Div, tag("/"))))(i) } +fn extern_un(i: &str) -> IResult<&str, builder::Unary, Error> { + let (i, func) = preceded(tag("extern::"), name)(i)?; + Ok((i, builder::Unary::Ffi(func.to_string()))) +} + +fn extern_bin(i: &str) -> IResult<&str, builder::Binary, Error> { + let (i, func) = preceded(tag("extern::"), name)(i)?; + Ok((i, builder::Binary::Ffi(func.to_string()))) +} + fn binary_op_8(i: &str) -> IResult<&str, builder::Binary, Error> { use builder::Binary; @@ -510,6 +520,7 @@ fn binary_op_8(i: &str) -> IResult<&str, builder::Binary, Error> { value(Binary::All, tag("all")), value(Binary::Any, tag("any")), value(Binary::Get, tag("get")), + extern_bin, ))(i) } @@ -720,6 +731,7 @@ fn unary_method(i: &str) -> IResult<&str, builder::Unary, Error> { let (i, op) = alt(( value(Unary::Length, tag("length")), value(Unary::TypeOf, tag("type")), + extern_un, ))(i)?; let (i, _) = char('(')(i)?; @@ -2609,6 +2621,31 @@ mod tests { Op::Value(array(h.clone())), Op::Value(var("0")), Op::Binary(Binary::Contains), + ] + )) + ) + } + + #[test] + fn extern_funcs() { + use builder::{int, Binary, Op}; + + assert_eq!( + super::expr("2.extern::toto()").map(|(i, o)| (i, o.opcodes())), + Ok(( + "", + vec![Op::Value(int(2)), Op::Unary(Unary::Ffi("toto".to_string()))], + )) + ); + + assert_eq!( + super::expr("2.extern::toto(3)").map(|(i, o)| (i, o.opcodes())), + Ok(( + "", + vec![ + Op::Value(int(2)), + Op::Value(int(3)), + Op::Binary(Binary::Ffi("toto".to_string())), ], )) );