From 525c27805c26463cad2889e9fe10fd9535a2454c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Delafargue?= Date: Sat, 8 Jun 2024 15:21:03 +0200 Subject: [PATCH] Support for closures (#202) This introduces the closure operations to the Biscuit language, first with the `.all()` and `.any()` operations to add conditions on the elements of a set. It is now possible to use expressions with the following format: ``` check if [1,2,3].all($p -> $p > 0); check if [1,2,3].any($p -> $p > 2); ``` Co-authored-by: Geoffroy Couprie --- biscuit-auth/examples/testcases.rs | 78 +++- biscuit-auth/samples/README.md | 133 +++++- biscuit-auth/samples/samples.json | 105 ++++- biscuit-auth/samples/test017_expressions.bc | Bin 1799 -> 1687 bytes .../samples/test027_integer_wraparound.bc | Bin 329 -> 293 bytes .../samples/test032_laziness_closures.bc | Bin 0 -> 758 bytes biscuit-auth/src/datalog/expression.rs | 388 +++++++++++++++++- biscuit-auth/src/datalog/mod.rs | 30 +- biscuit-auth/src/error.rs | 2 + biscuit-auth/src/format/convert.rs | 237 ++++++----- biscuit-auth/src/format/schema.proto | 12 +- biscuit-auth/src/format/schema.rs | 15 +- biscuit-auth/src/parser.rs | 33 +- biscuit-auth/src/token/builder.rs | 20 + biscuit-auth/tests/macros.rs | 6 +- biscuit-parser/src/builder.rs | 15 + biscuit-parser/src/parser.rs | 82 +++- 17 files changed, 947 insertions(+), 209 deletions(-) create mode 100644 biscuit-auth/samples/test032_laziness_closures.bc diff --git a/biscuit-auth/examples/testcases.rs b/biscuit-auth/examples/testcases.rs index 822c9c28..e0558a1b 100644 --- a/biscuit-auth/examples/testcases.rs +++ b/biscuit-auth/examples/testcases.rs @@ -152,6 +152,8 @@ fn run(target: String, root_key: Option, test: bool, json: bool) { add_test_result(&mut results, heterogeneous_equal(&target, &root, test)); + add_test_result(&mut results, closures(&target, &root, test)); + if json { let s = serde_json::to_string_pretty(&TestCases { root_private_key: hex::encode(root.private().to_bytes()), @@ -1277,12 +1279,6 @@ fn expressions(target: &str, root: &KeyPair, test: bool) -> TestResult { check if true; //boolean false and negation check if !false; - //boolean and - check if !false && true; - //boolean or - check if false || true; - //boolean parens - check if (true || false) && true; // boolean strict equality check if true === true; check if false === false; @@ -1303,7 +1299,7 @@ fn expressions(target: &str, root: &KeyPair, test: bool) -> TestResult { check if 1 + 2 * 3 - 4 /2 === 5; // string prefix and suffix - check if "hello world".starts_with("hello") && "hello world".ends_with("world"); + check if "hello world".starts_with("hello"), "hello world".ends_with("world"); // string regex check if "aaabde".matches("a*c?.e"); // string contains @@ -1893,12 +1889,9 @@ fn integer_wraparound(target: &str, root: &KeyPair, test: bool) -> TestResult { let biscuit = biscuit!( r#" - // integer overflows must abort evaluating the whole expression - // todo update this test when integer overflows abort - // the whole datalog evaluation - check if true || 10000000000 * 10000000000 != 0; - check if true || 9223372036854775807 + 1 != 0; - check if true || -9223372036854775808 - 1 != 0; + check if 10000000000 * 10000000000 != 0; + check if 9223372036854775807 + 1 != 0; + check if -9223372036854775808 - 1 != 0; "# ) .build_with_rng(&root, SymbolTable::default(), &mut rng) @@ -2079,6 +2072,65 @@ fn heterogeneous_equal(target: &str, root: &KeyPair, test: bool) -> TestResult { } } +fn closures(target: &str, root: &KeyPair, test: bool) -> TestResult { + let mut rng: StdRng = SeedableRng::seed_from_u64(1234); + let title = "test laziness and closures".to_string(); + let filename = "test032_laziness_closures".to_string(); + let token; + + let biscuit = biscuit!( + r#" + // boolean and + check if !false && true; + // boolean or + check if false || true; + // boolean parens + check if (true || false) && true; + // boolean and laziness + check if !(false && "x".intersection("x")); + // boolean or laziness + check if true || "x".intersection("x"); + // all + check if [1,2,3].all($p -> $p > 0); + // all + check if ![1,2,3].all($p -> $p == 2); + // any + check if [1,2,3].any($p -> $p > 2); + // any + check if ![1,2,3].any($p -> $p > 3); + // nested closures + check if [1,2,3].any($p -> $p > 1 && [3,4,5].any($q -> $p == $q)); + "# + ) + .build_with_rng(&root, SymbolTable::default(), &mut rng) + .unwrap(); + + token = print_blocks(&biscuit); + + let data = write_or_load_testcase(target, &filename, root, &biscuit, test); + + let mut validations = BTreeMap::new(); + validations.insert( + "".to_string(), + validate_token(root, &data[..], "allow if true"), + ); + validations.insert( + "shadowing".to_string(), + validate_token( + root, + &data[..], + "allow if [true].any($p -> [true].all($p -> $p))", + ), + ); + + TestResult { + title, + filename, + token, + validations, + } +} + fn print_blocks(token: &Biscuit) -> Vec { let mut v = Vec::new(); diff --git a/biscuit-auth/samples/README.md b/biscuit-auth/samples/README.md index 141a045f..667cbedb 100644 --- a/biscuit-auth/samples/README.md +++ b/biscuit-auth/samples/README.md @@ -1221,9 +1221,6 @@ public keys: [] ``` check if true; check if !false; -check if !false && true; -check if false || true; -check if (true || false) && true; check if true === true; check if false === false; check if 1 < 2; @@ -1234,7 +1231,7 @@ check if 2 >= 1; check if 2 >= 2; check if 3 === 3; check if 1 + 2 * 3 - 4 / 2 === 5; -check if "hello world".starts_with("hello") && "hello world".ends_with("world"); +check if "hello world".starts_with("hello"), "hello world".ends_with("world"); check if "aaabde".matches("a*c?.e"); check if "aaabde".contains("abd"); check if "aaabde" === "aaa" + "b" + "de"; @@ -1270,7 +1267,7 @@ allow if true; ``` revocation ids: -- `3d5b23b502b3dd920bfb68b9039164d1563bb8927210166fa5c17f41b76b31bb957bc2ed3318452958f658baa2d398fe4cf25c58a27e6c8bc42c9702c8aa1b0c` +- `d0420227266e3583a42dfaa0e38550d99f681d150dd18856f3af9a697bc9c5c8bf06b4b0fe5b9df0377d1b963574e2fd210a0a76a8b0756a65f640c602bebd07` authorizer world: ``` @@ -1284,15 +1281,13 @@ World { ), checks: [ "check if !false", - "check if !false && true", "check if \"aaabde\" === \"aaa\" + \"b\" + \"de\"", "check if \"aaabde\".contains(\"abd\")", "check if \"aaabde\".matches(\"a*c?.e\")", "check if \"abcD12\" === \"abcD12\"", "check if \"abcD12\".length() === 6", - "check if \"hello world\".starts_with(\"hello\") && \"hello world\".ends_with(\"world\")", + "check if \"hello world\".starts_with(\"hello\"), \"hello world\".ends_with(\"world\")", "check if \"é\".length() === 2", - "check if (true || false) && true", "check if 1 + 2 * 3 - 4 / 2 === 5", "check if 1 < 2", "check if 1 <= 1", @@ -1320,7 +1315,6 @@ World { "check if [false, true].contains(true)", "check if [hex:12ab, hex:34de].contains(hex:34de)", "check if false === false", - "check if false || true", "check if hex:12ab === hex:12ab", "check if true", "check if true === true", @@ -2263,9 +2257,9 @@ symbols: [] public keys: [] ``` -check if true || 10000000000 * 10000000000 != 0; -check if true || 9223372036854775807 + 1 != 0; -check if true || -9223372036854775808 - 1 != 0; +check if 10000000000 * 10000000000 != 0; +check if 9223372036854775807 + 1 != 0; +check if -9223372036854775808 - 1 != 0; ``` ### validation @@ -2276,7 +2270,7 @@ allow if true; ``` revocation ids: -- `a57be539aae237040fe6c2c28c4263516147c9f0d1d7ba88a385f1574f504c544164a2c747efd8b30eaab9d351c383cc1875642f173546d5f4b53b2220c87a0a` +- `365092619226161cf3973343f02c829fe05ab2b0d01f09555272348c9fcce041846be6159badd643aee108c9ce735ca8d12a009979c46b6e2c46e7999824c008` authorizer world: ``` @@ -2289,9 +2283,9 @@ World { 0, ), checks: [ - "check if true || -9223372036854775808 - 1 != 0", - "check if true || 10000000000 * 10000000000 != 0", - "check if true || 9223372036854775807 + 1 != 0", + "check if -9223372036854775808 - 1 != 0", + "check if 10000000000 * 10000000000 != 0", + "check if 9223372036854775807 + 1 != 0", ], }, ] @@ -2765,3 +2759,110 @@ World { result: `Err(FailedLogic(Unauthorized { policy: Allow(0), checks: [Block(FailedBlockCheck { block_id: 0, check_id: 0, rule: "check if fact(1, $value), 1 == $value" })] }))` + +------------------------------ + +## test laziness and closures: test032_laziness_closures.bc +### token + +authority: +symbols: ["x", "p", "q"] + +public keys: [] + +``` +check if !false && true; +check if false || true; +check if (true || false) && true; +check if !(false && "x".intersection("x")); +check if true || "x".intersection("x"); +check if [1, 2, 3].all($p -> $p > 0); +check if ![1, 2, 3].all($p -> $p == 2); +check if [1, 2, 3].any($p -> $p > 2); +check if ![1, 2, 3].any($p -> $p > 3); +check if [1, 2, 3].any($p -> $p > 1 && [3, 4, 5].any($q -> $p == $q)); +``` + +### validation + +authorizer code: +``` +allow if true; +``` + +revocation ids: +- `65e4da4fa213559d3b1097424504d2c9daeb28b4db51c49254852b6f57dc55e200f2f977b459f0c35e17c3c06394bfcaf5db7106e23bb2a623f48c4b84649a0b` + +authorizer world: +``` +World { + facts: [] + rules: [] + checks: [ + Checks { + origin: Some( + 0, + ), + checks: [ + "check if !(false && \"x\".intersection(\"x\"))", + "check if ![1, 2, 3].all($p -> $p == 2)", + "check if ![1, 2, 3].any($p -> $p > 3)", + "check if !false && true", + "check if (true || false) && true", + "check if [1, 2, 3].all($p -> $p > 0)", + "check if [1, 2, 3].any($p -> $p > 1 && [3, 4, 5].any($q -> $p == $q))", + "check if [1, 2, 3].any($p -> $p > 2)", + "check if false || true", + "check if true || \"x\".intersection(\"x\")", + ], + }, +] + policies: [ + "allow if true", +] +} +``` + +result: `Ok(0)` +### validation for "shadowing" + +authorizer code: +``` +allow if [true].any($p -> [true].all($p -> $p)); +``` + +revocation ids: +- `65e4da4fa213559d3b1097424504d2c9daeb28b4db51c49254852b6f57dc55e200f2f977b459f0c35e17c3c06394bfcaf5db7106e23bb2a623f48c4b84649a0b` + +authorizer world: +``` +World { + facts: [] + rules: [] + checks: [ + Checks { + origin: Some( + 0, + ), + checks: [ + "check if !(false && \"x\".intersection(\"x\"))", + "check if ![1, 2, 3].all($p -> $p == 2)", + "check if ![1, 2, 3].any($p -> $p > 3)", + "check if !false && true", + "check if (true || false) && true", + "check if [1, 2, 3].all($p -> $p > 0)", + "check if [1, 2, 3].any($p -> $p > 1 && [3, 4, 5].any($q -> $p == $q))", + "check if [1, 2, 3].any($p -> $p > 2)", + "check if false || true", + "check if true || \"x\".intersection(\"x\")", + ], + }, +] + policies: [ + "allow if [true].any($p -> [true].all($p -> $p))", +] +} +``` + +result: `Err(Execution(ShadowedVariable))` + diff --git a/biscuit-auth/samples/samples.json b/biscuit-auth/samples/samples.json index b2fbf2bd..82856d84 100644 --- a/biscuit-auth/samples/samples.json +++ b/biscuit-auth/samples/samples.json @@ -1245,7 +1245,7 @@ ], "public_keys": [], "external_key": null, - "code": "check if true;\ncheck if !false;\ncheck if !false && true;\ncheck if false || true;\ncheck if (true || false) && true;\ncheck if true === true;\ncheck if false === false;\ncheck if 1 < 2;\ncheck if 2 > 1;\ncheck if 1 <= 2;\ncheck if 1 <= 1;\ncheck if 2 >= 1;\ncheck if 2 >= 2;\ncheck if 3 === 3;\ncheck if 1 + 2 * 3 - 4 / 2 === 5;\ncheck if \"hello world\".starts_with(\"hello\") && \"hello world\".ends_with(\"world\");\ncheck if \"aaabde\".matches(\"a*c?.e\");\ncheck if \"aaabde\".contains(\"abd\");\ncheck if \"aaabde\" === \"aaa\" + \"b\" + \"de\";\ncheck if \"abcD12\" === \"abcD12\";\ncheck if \"abcD12\".length() === 6;\ncheck if \"é\".length() === 2;\ncheck if 2019-12-04T09:46:41Z < 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z > 2019-12-04T09:46:41Z;\ncheck if 2019-12-04T09:46:41Z <= 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z >= 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z >= 2019-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z >= 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z === 2020-12-04T09:46:41Z;\ncheck if hex:12ab === hex:12ab;\ncheck if [1, 2].contains(2);\ncheck if [2019-12-04T09:46:41Z, 2020-12-04T09:46:41Z].contains(2020-12-04T09:46:41Z);\ncheck if [false, true].contains(true);\ncheck if [\"abc\", \"def\"].contains(\"abc\");\ncheck if [hex:12ab, hex:34de].contains(hex:34de);\ncheck if [1, 2].contains([2]);\ncheck if [1, 2] === [1, 2];\ncheck if [1, 2].intersection([2, 3]) === [2];\ncheck if [1, 2].union([2, 3]) === [1, 2, 3];\ncheck if [1, 2, 3].intersection([1, 2]).contains(1);\ncheck if [1, 2, 3].intersection([1, 2]).length() === 2;\n" + "code": "check if true;\ncheck if !false;\ncheck if true === true;\ncheck if false === false;\ncheck if 1 < 2;\ncheck if 2 > 1;\ncheck if 1 <= 2;\ncheck if 1 <= 1;\ncheck if 2 >= 1;\ncheck if 2 >= 2;\ncheck if 3 === 3;\ncheck if 1 + 2 * 3 - 4 / 2 === 5;\ncheck if \"hello world\".starts_with(\"hello\"), \"hello world\".ends_with(\"world\");\ncheck if \"aaabde\".matches(\"a*c?.e\");\ncheck if \"aaabde\".contains(\"abd\");\ncheck if \"aaabde\" === \"aaa\" + \"b\" + \"de\";\ncheck if \"abcD12\" === \"abcD12\";\ncheck if \"abcD12\".length() === 6;\ncheck if \"é\".length() === 2;\ncheck if 2019-12-04T09:46:41Z < 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z > 2019-12-04T09:46:41Z;\ncheck if 2019-12-04T09:46:41Z <= 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z >= 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z >= 2019-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z >= 2020-12-04T09:46:41Z;\ncheck if 2020-12-04T09:46:41Z === 2020-12-04T09:46:41Z;\ncheck if hex:12ab === hex:12ab;\ncheck if [1, 2].contains(2);\ncheck if [2019-12-04T09:46:41Z, 2020-12-04T09:46:41Z].contains(2020-12-04T09:46:41Z);\ncheck if [false, true].contains(true);\ncheck if [\"abc\", \"def\"].contains(\"abc\");\ncheck if [hex:12ab, hex:34de].contains(hex:34de);\ncheck if [1, 2].contains([2]);\ncheck if [1, 2] === [1, 2];\ncheck if [1, 2].intersection([2, 3]) === [2];\ncheck if [1, 2].union([2, 3]) === [1, 2, 3];\ncheck if [1, 2, 3].intersection([1, 2]).contains(1);\ncheck if [1, 2, 3].intersection([1, 2]).length() === 2;\n" } ], "validations": { @@ -1258,15 +1258,13 @@ "origin": 0, "checks": [ "check if !false", - "check if !false && true", "check if \"aaabde\" === \"aaa\" + \"b\" + \"de\"", "check if \"aaabde\".contains(\"abd\")", "check if \"aaabde\".matches(\"a*c?.e\")", "check if \"abcD12\" === \"abcD12\"", "check if \"abcD12\".length() === 6", - "check if \"hello world\".starts_with(\"hello\") && \"hello world\".ends_with(\"world\")", + "check if \"hello world\".starts_with(\"hello\"), \"hello world\".ends_with(\"world\")", "check if \"é\".length() === 2", - "check if (true || false) && true", "check if 1 + 2 * 3 - 4 / 2 === 5", "check if 1 < 2", "check if 1 <= 1", @@ -1294,7 +1292,6 @@ "check if [false, true].contains(true)", "check if [hex:12ab, hex:34de].contains(hex:34de)", "check if false === false", - "check if false || true", "check if hex:12ab === hex:12ab", "check if true", "check if true === true" @@ -1310,7 +1307,7 @@ }, "authorizer_code": "allow if true;\n", "revocation_ids": [ - "3d5b23b502b3dd920bfb68b9039164d1563bb8927210166fa5c17f41b76b31bb957bc2ed3318452958f658baa2d398fe4cf25c58a27e6c8bc42c9702c8aa1b0c" + "d0420227266e3583a42dfaa0e38550d99f681d150dd18856f3af9a697bc9c5c8bf06b4b0fe5b9df0377d1b963574e2fd210a0a76a8b0756a65f640c602bebd07" ] } } @@ -2098,7 +2095,7 @@ "symbols": [], "public_keys": [], "external_key": null, - "code": "check if true || 10000000000 * 10000000000 != 0;\ncheck if true || 9223372036854775807 + 1 != 0;\ncheck if true || -9223372036854775808 - 1 != 0;\n" + "code": "check if 10000000000 * 10000000000 != 0;\ncheck if 9223372036854775807 + 1 != 0;\ncheck if -9223372036854775808 - 1 != 0;\n" } ], "validations": { @@ -2110,9 +2107,9 @@ { "origin": 0, "checks": [ - "check if true || -9223372036854775808 - 1 != 0", - "check if true || 10000000000 * 10000000000 != 0", - "check if true || 9223372036854775807 + 1 != 0" + "check if -9223372036854775808 - 1 != 0", + "check if 10000000000 * 10000000000 != 0", + "check if 9223372036854775807 + 1 != 0" ] } ], @@ -2127,7 +2124,7 @@ }, "authorizer_code": "allow if true;\n", "revocation_ids": [ - "a57be539aae237040fe6c2c28c4263516147c9f0d1d7ba88a385f1574f504c544164a2c747efd8b30eaab9d351c383cc1875642f173546d5f4b53b2220c87a0a" + "365092619226161cf3973343f02c829fe05ab2b0d01f09555272348c9fcce041846be6159badd643aee108c9ce735ca8d12a009979c46b6e2c46e7999824c008" ] } } @@ -2606,6 +2603,92 @@ ] } } + }, + { + "title": "test laziness and closures", + "filename": "test032_laziness_closures.bc", + "token": [ + { + "symbols": [ + "x", + "p", + "q" + ], + "public_keys": [], + "external_key": null, + "code": "check if !false && true;\ncheck if false || true;\ncheck if (true || false) && true;\ncheck if !(false && \"x\".intersection(\"x\"));\ncheck if true || \"x\".intersection(\"x\");\ncheck if [1, 2, 3].all($p -> $p > 0);\ncheck if ![1, 2, 3].all($p -> $p == 2);\ncheck if [1, 2, 3].any($p -> $p > 2);\ncheck if ![1, 2, 3].any($p -> $p > 3);\ncheck if [1, 2, 3].any($p -> $p > 1 && [3, 4, 5].any($q -> $p == $q));\n" + } + ], + "validations": { + "": { + "world": { + "facts": [], + "rules": [], + "checks": [ + { + "origin": 0, + "checks": [ + "check if !(false && \"x\".intersection(\"x\"))", + "check if ![1, 2, 3].all($p -> $p == 2)", + "check if ![1, 2, 3].any($p -> $p > 3)", + "check if !false && true", + "check if (true || false) && true", + "check if [1, 2, 3].all($p -> $p > 0)", + "check if [1, 2, 3].any($p -> $p > 1 && [3, 4, 5].any($q -> $p == $q))", + "check if [1, 2, 3].any($p -> $p > 2)", + "check if false || true", + "check if true || \"x\".intersection(\"x\")" + ] + } + ], + "policies": [ + "allow if true" + ] + }, + "result": { + "Ok": 0 + }, + "authorizer_code": "allow if true;\n", + "revocation_ids": [ + "65e4da4fa213559d3b1097424504d2c9daeb28b4db51c49254852b6f57dc55e200f2f977b459f0c35e17c3c06394bfcaf5db7106e23bb2a623f48c4b84649a0b" + ] + }, + "shadowing": { + "world": { + "facts": [], + "rules": [], + "checks": [ + { + "origin": 0, + "checks": [ + "check if !(false && \"x\".intersection(\"x\"))", + "check if ![1, 2, 3].all($p -> $p == 2)", + "check if ![1, 2, 3].any($p -> $p > 3)", + "check if !false && true", + "check if (true || false) && true", + "check if [1, 2, 3].all($p -> $p > 0)", + "check if [1, 2, 3].any($p -> $p > 1 && [3, 4, 5].any($q -> $p == $q))", + "check if [1, 2, 3].any($p -> $p > 2)", + "check if false || true", + "check if true || \"x\".intersection(\"x\")" + ] + } + ], + "policies": [ + "allow if [true].any($p -> [true].all($p -> $p))" + ] + }, + "result": { + "Err": { + "Execution": "ShadowedVariable" + } + }, + "authorizer_code": "allow if [true].any($p -> [true].all($p -> $p));\n", + "revocation_ids": [ + "65e4da4fa213559d3b1097424504d2c9daeb28b4db51c49254852b6f57dc55e200f2f977b459f0c35e17c3c06394bfcaf5db7106e23bb2a623f48c4b84649a0b" + ] + } + } } ] } diff --git a/biscuit-auth/samples/test017_expressions.bc b/biscuit-auth/samples/test017_expressions.bc index 1f3234c07694321afef884cd0b547baa27d64096..7be7a4fda92f3feecff08e8c3a68d1c53bea2100 100644 GIT binary patch delta 143 zcmZqYo6gHB^nr(~bs}rc#qOR?Yy~bkE+!6XDG4qXE+zv8E*2ps4j>6+Gjg#=F>&x3NpVSl zRS3aUz~%Uibhxy@aw;$#5Ho=K_`oIs&48=s-PqvCXl}-30@k6&#mdDj(ZB(w8acpr zu_22!LB!Z0uGsvC(VWG>Hd=Wr)8@OAxPNEtWS*FEG0b|$q#^;a{G|u$9k*v2?w(qG o=&iAYt7gQvh+T^=&-mx_DJEi3T~7BAo#{*`R!Q?rE@o2!0FwDK%K!iX diff --git a/biscuit-auth/samples/test027_integer_wraparound.bc b/biscuit-auth/samples/test027_integer_wraparound.bc index 5f0ff2e73c827e18368706e6d0ac09c27e1b73ab..0d01dd5682cf1555fe07ceffc6f5fa4e0bbc5a45 100644 GIT binary patch delta 126 zcmX@fw3JCm=pQ53WJU>ABRwu1E+!6XDYc0bYFtRX5Fk&VOBXDsK5>Hj#Jw60W&x8D zC#i|ae4cLX{6VK_{)4DZ8!pImh6WXx^vpl=z_BI!ndt1b*PPcqlv@vMQqr0Fcxw)c^nh delta 163 zcmZ3=bdpIzXb}_Fc18(SBP%WoE+!6XDI+cxE+zxUi9%{(Tr5&d9DEouPH-7(E=#Z; zW2l}YbvB5ei8D1EmR3KtT=mGDh5y;1Lp@H(fr;)XKU}=Nt7CEN$8i4upAg5CMaSLW k-`LEzYUky^!_8+TN>lX3P2H}3*=nt%aH5K9G9#l30LXDJ5dZ)H diff --git a/biscuit-auth/samples/test032_laziness_closures.bc b/biscuit-auth/samples/test032_laziness_closures.bc new file mode 100644 index 0000000000000000000000000000000000000000..b95a36efbf740a585241ac3209c0028f44883484 GIT binary patch literal 758 zcmWeS&&u_Lg^RI*i?M)=GT|};%jiqK72YQjI6jUz^aY71i1LDc(|AZ7=f4xh?%)0 zl*Bk1!S;bDpq~X8gg~}J+$d?}z-0&4YeAG=Ca_*nu!|)jVM(Ix5ECG_OOa?hGm^`t zpthHDm4e-$kH!6=O2KG;XM~s_WU6E=1Pd!>V3>fyg;hvWNt~kz9#+7Z0`WkBB?=~` zz@#|X^HM@891KDV0-?tPHZBZ^G+!@jeoSqbQS!_f<7Msl>mD<4`EP%YduS6W_cOT>NML;$je BY`*{i literal 0 HcmV?d00001 diff --git a/biscuit-auth/src/datalog/expression.rs b/biscuit-auth/src/datalog/expression.rs index 5dce998a..c07e42e4 100644 --- a/biscuit-auth/src/datalog/expression.rs +++ b/biscuit-auth/src/datalog/expression.rs @@ -3,7 +3,7 @@ use crate::error; use super::Term; use super::{SymbolTable, TemporarySymbolTable}; use regex::Regex; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; #[derive(Debug, Clone, PartialEq, Hash, Eq)] pub struct Expression { @@ -15,6 +15,7 @@ pub enum Op { Value(Term), Unary(Unary), Binary(Binary), + Closure(Vec, Vec), } /// Unary operation code @@ -82,9 +83,63 @@ pub enum Binary { NotEqual, HeterogeneousEqual, HeterogeneousNotEqual, + LazyAnd, + LazyOr, + All, + Any, } impl Binary { + fn evaluate_with_closure( + &self, + left: Term, + right: Vec, + params: &[u32], + values: &mut HashMap, + symbols: &mut TemporarySymbolTable, + ) -> Result { + match (self, left, params) { + (Binary::LazyOr, Term::Bool(true), []) => Ok(Term::Bool(true)), + (Binary::LazyOr, Term::Bool(false), []) => { + let e = Expression { ops: right.clone() }; + e.evaluate(values, symbols) + } + (Binary::LazyAnd, Term::Bool(false), []) => Ok(Term::Bool(false)), + (Binary::LazyAnd, Term::Bool(true), []) => { + let e = Expression { ops: right.clone() }; + e.evaluate(values, symbols) + } + (Binary::All, Term::Set(set_values), [param]) => { + for value in set_values.iter() { + values.insert(*param, value.clone()); + let e = Expression { ops: right.clone() }; + let result = e.evaluate(values, symbols); + values.remove(param); + match result? { + Term::Bool(true) => {} + Term::Bool(false) => return Ok(Term::Bool(false)), + _ => return Err(error::Expression::InvalidType), + }; + } + Ok(Term::Bool(true)) + } + (Binary::Any, Term::Set(set_values), [param]) => { + for value in set_values.iter() { + values.insert(*param, value.clone()); + let e = Expression { ops: right.clone() }; + let result = e.evaluate(values, symbols); + values.remove(param); + match result? { + Term::Bool(false) => {} + Term::Bool(true) => return Ok(Term::Bool(true)), + _ => return Err(error::Expression::InvalidType), + }; + } + Ok(Term::Bool(false)) + } + (_, _, _) => Err(error::Expression::InvalidType), + } + } fn evaluate( &self, left: Term, @@ -245,12 +300,8 @@ impl Binary { (Binary::NotEqual | Binary::HeterogeneousNotEqual, Term::Null, Term::Null) => { Ok(Term::Bool(false)) } - (Binary::HeterogeneousNotEqual, Term::Null, _) => { - Ok(Term::Bool(true)) - } - (Binary::HeterogeneousNotEqual, _, Term::Null) => { - Ok(Term::Bool(true)) - } + (Binary::HeterogeneousNotEqual, Term::Null, _) => Ok(Term::Bool(true)), + (Binary::HeterogeneousNotEqual, _, Term::Null) => Ok(Term::Bool(true)), (Binary::HeterogeneousEqual, _, _) => Ok(Term::Bool(false)), (Binary::HeterogeneousNotEqual, _, _) => Ok(Term::Bool(true)), @@ -280,58 +331,98 @@ impl Binary { Binary::Sub => format!("{} - {}", left, right), Binary::Mul => format!("{} * {}", left, right), Binary::Div => format!("{} / {}", left, right), - Binary::And => format!("{} && {}", left, right), - Binary::Or => format!("{} || {}", left, right), + Binary::And => format!("{} &&! {}", left, right), + Binary::Or => format!("{} ||! {}", left, right), Binary::Intersection => format!("{}.intersection({})", left, right), Binary::Union => format!("{}.union({})", left, right), Binary::BitwiseAnd => format!("{} & {}", left, right), Binary::BitwiseOr => format!("{} | {}", left, right), Binary::BitwiseXor => format!("{} ^ {}", left, right), + Binary::LazyAnd => format!("{left} && {right}"), + Binary::LazyOr => format!("{left} || {right}"), + Binary::All => format!("{left}.all({right})"), + Binary::Any => format!("{left}.any({right})"), } } } +#[derive(Clone, Debug)] +enum StackElem { + Closure(Vec, Vec), + Term(Term), +} + impl Expression { pub fn evaluate( &self, values: &HashMap, symbols: &mut TemporarySymbolTable, ) -> Result { - let mut stack: Vec = Vec::new(); + let mut stack: Vec = Vec::new(); for op in self.ops.iter() { - //println!("op: {:?}\t| stack: {:?}", op, stack); + // println!("op: {:?}\t| stack: {:?}", op, stack); + match op { Op::Value(Term::Variable(i)) => match values.get(i) { - Some(term) => stack.push(term.clone()), + Some(term) => stack.push(StackElem::Term(term.clone())), None => { //println!("unknown variable {}", i); return Err(error::Expression::UnknownVariable(*i)); } }, - Op::Value(term) => stack.push(term.clone()), + Op::Value(term) => stack.push(StackElem::Term(term.clone())), Op::Unary(unary) => match stack.pop() { - None => { - //println!("expected a value on the stack"); + Some(StackElem::Term(term)) => { + stack.push(StackElem::Term(unary.evaluate(term, symbols)?)) + } + _ => { return Err(error::Expression::InvalidStack); } - Some(term) => stack.push(unary.evaluate(term, symbols)?), }, Op::Binary(binary) => match (stack.pop(), stack.pop()) { - (Some(right_term), Some(left_term)) => { - stack.push(binary.evaluate(left_term, right_term, symbols)?) + (Some(StackElem::Term(right_term)), Some(StackElem::Term(left_term))) => stack + .push(StackElem::Term( + binary.evaluate(left_term, right_term, symbols)?, + )), + ( + Some(StackElem::Closure(params, right_ops)), + Some(StackElem::Term(left_term)), + ) => { + if values + .keys() + .collect::>() + .intersection(¶ms.iter().collect()) + .next() + .is_some() + { + return Err(error::Expression::ShadowedVariable); + } + let mut values = values.clone(); + stack.push(StackElem::Term(binary.evaluate_with_closure( + left_term, + right_ops, + ¶ms, + &mut values, + symbols, + )?)) } _ => { - //println!("expected two values on the stack"); return Err(error::Expression::InvalidStack); } }, + Op::Closure(params, ops) => { + stack.push(StackElem::Closure(params.clone(), ops.clone())); + } } } if stack.len() == 1 { - Ok(stack.remove(0)) + match stack.remove(0) { + StackElem::Term(t) => Ok(t), + _ => Err(error::Expression::InvalidStack), + } } else { Err(error::Expression::InvalidStack) } @@ -352,6 +443,24 @@ impl Expression { (Some(right), Some(left)) => stack.push(binary.print(left, right, symbols)), _ => return None, }, + Op::Closure(params, ops) => { + let exp_body = Expression { ops: ops.clone() }; + let body = match exp_body.print(symbols) { + Some(c) => c, + _ => return None, + }; + + if params.is_empty() { + stack.push(body); + } else { + let param_group = params + .iter() + .map(|s| symbols.print_term(&Term::Variable(*s))) + .collect::>() + .join(", "); + stack.push(format!("{param_group} -> {body}")); + } + } } } @@ -676,4 +785,243 @@ mod tests { } } } + + #[test] + fn laziness() { + let symbols = SymbolTable::new(); + let mut symbols = TemporarySymbolTable::new(&symbols); + + let ops1 = vec![ + Op::Value(Term::Bool(false)), + Op::Closure( + vec![], + vec![ + Op::Value(Term::Bool(true)), + Op::Closure(vec![], vec![Op::Value(Term::Bool(true))]), + Op::Binary(Binary::LazyAnd), + ], + ), + Op::Binary(Binary::LazyOr), + ]; + let e2 = Expression { ops: ops1 }; + + let res2 = e2.evaluate(&HashMap::new(), &mut symbols).unwrap(); + assert_eq!(res2, Term::Bool(true)); + } + + #[test] + fn any() { + let mut symbols = SymbolTable::new(); + let p = symbols.insert("param") as u32; + let mut tmp_symbols = TemporarySymbolTable::new(&symbols); + + let ops1 = vec![ + Op::Value(Term::Set([Term::Bool(false), Term::Bool(true)].into())), + Op::Closure(vec![p], vec![Op::Value(Term::Variable(p))]), + Op::Binary(Binary::Any), + ]; + let e1 = Expression { ops: ops1 }; + println!("{:?}", e1.print(&symbols)); + + let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + assert_eq!(res1, Term::Bool(true)); + + let ops2 = vec![ + Op::Value(Term::Set([Term::Integer(1), Term::Integer(2)].into())), + Op::Closure( + vec![p], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::LessThan), + ], + ), + Op::Binary(Binary::Any), + ]; + let e2 = Expression { ops: ops2 }; + println!("{:?}", e2.print(&symbols)); + + let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + assert_eq!(res2, Term::Bool(false)); + + let ops3 = vec![ + Op::Value(Term::Set([Term::Integer(1), Term::Integer(2)].into())), + Op::Closure(vec![p], vec![Op::Value(Term::Integer(0))]), + Op::Binary(Binary::Any), + ]; + let e3 = Expression { ops: ops3 }; + println!("{:?}", e3.print(&symbols)); + + let err3 = e3.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap_err(); + assert_eq!(err3, error::Expression::InvalidType); + } + + #[test] + fn all() { + let mut symbols = SymbolTable::new(); + let p = symbols.insert("param") as u32; + let mut tmp_symbols = TemporarySymbolTable::new(&symbols); + + let ops1 = vec![ + Op::Value(Term::Set([Term::Integer(1), Term::Integer(2)].into())), + Op::Closure( + vec![p], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::GreaterThan), + ], + ), + Op::Binary(Binary::All), + ]; + let e1 = Expression { ops: ops1 }; + println!("{:?}", e1.print(&symbols)); + + let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + assert_eq!(res1, Term::Bool(true)); + + let ops2 = vec![ + Op::Value(Term::Set([Term::Integer(1), Term::Integer(2)].into())), + Op::Closure( + vec![p], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::LessThan), + ], + ), + Op::Binary(Binary::All), + ]; + let e2 = Expression { ops: ops2 }; + println!("{:?}", e2.print(&symbols)); + + let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + assert_eq!(res2, Term::Bool(false)); + + let ops3 = vec![ + Op::Value(Term::Set([Term::Integer(1), Term::Integer(2)].into())), + Op::Closure(vec![p], vec![Op::Value(Term::Integer(0))]), + Op::Binary(Binary::All), + ]; + let e3 = Expression { ops: ops3 }; + println!("{:?}", e3.print(&symbols)); + + let err3 = e3.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap_err(); + assert_eq!(err3, error::Expression::InvalidType); + } + + #[test] + fn nested_closures() { + let mut symbols = SymbolTable::new(); + let p = symbols.insert("p") as u32; + let q = symbols.insert("q") as u32; + let mut tmp_symbols = TemporarySymbolTable::new(&symbols); + + let ops1 = vec![ + Op::Value(Term::Set( + [Term::Integer(1), Term::Integer(2), Term::Integer(3)].into(), + )), + Op::Closure( + vec![p], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Integer(1)), + Op::Binary(Binary::GreaterThan), + Op::Closure( + vec![], + vec![ + Op::Value(Term::Set( + [Term::Integer(3), Term::Integer(4), Term::Integer(5)].into(), + )), + Op::Closure( + vec![q], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Variable(q)), + Op::Binary(Binary::Equal), + ], + ), + Op::Binary(Binary::Any), + ], + ), + Op::Binary(Binary::LazyAnd), + ], + ), + Op::Binary(Binary::Any), + ]; + let e1 = Expression { ops: ops1 }; + println!("{}", e1.print(&symbols).unwrap()); + + let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + assert_eq!(res1, Term::Bool(true)); + } + + #[test] + fn variable_shadowing() { + let mut symbols = SymbolTable::new(); + let p = symbols.insert("param") as u32; + let mut tmp_symbols = TemporarySymbolTable::new(&symbols); + + let ops1 = vec![ + Op::Value(Term::Set([Term::Integer(1), Term::Integer(2)].into())), + Op::Closure( + vec![p], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::GreaterThan), + ], + ), + Op::Binary(Binary::All), + ]; + let e1 = Expression { ops: ops1 }; + println!("{:?}", e1.print(&symbols)); + + let mut values = HashMap::new(); + values.insert(p, Term::Null); + let res1 = e1.evaluate(&values, &mut tmp_symbols); + assert_eq!(res1, Err(error::Expression::ShadowedVariable)); + + let mut symbols = SymbolTable::new(); + let p = symbols.insert("p") as u32; + let mut tmp_symbols = TemporarySymbolTable::new(&symbols); + + let ops2 = vec![ + Op::Value(Term::Set( + [Term::Integer(1), Term::Integer(2), Term::Integer(3)].into(), + )), + Op::Closure( + vec![p], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Integer(1)), + Op::Binary(Binary::GreaterThan), + Op::Closure( + vec![], + vec![ + Op::Value(Term::Set( + [Term::Integer(3), Term::Integer(4), Term::Integer(5)].into(), + )), + Op::Closure( + vec![p], + vec![ + Op::Value(Term::Variable(p)), + Op::Value(Term::Variable(p)), + Op::Binary(Binary::Equal), + ], + ), + Op::Binary(Binary::Any), + ], + ), + Op::Binary(Binary::LazyAnd), + ], + ), + Op::Binary(Binary::Any), + ]; + let e2 = Expression { ops: ops2 }; + println!("{}", e2.print(&symbols).unwrap()); + + let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols); + assert_eq!(res2, Err(error::Expression::ShadowedVariable)); + } } diff --git a/biscuit-auth/src/datalog/mod.rs b/biscuit-auth/src/datalog/mod.rs index 71cf84cd..8b954e6b 100644 --- a/biscuit-auth/src/datalog/mod.rs +++ b/biscuit-auth/src/datalog/mod.rs @@ -878,7 +878,7 @@ impl SchemaVersion { } } else if version < 5 && self.contains_v5 { Err(error::Format::DeserializationError( - "v5 blocks must not have reject if".to_string(), + "v3 or v4 blocks must not have v5 features".to_string(), )) } else { Ok(()) @@ -918,7 +918,7 @@ pub fn get_schema_version( .any(|query| contains_v4_op(&query.expressions)) }); - // null + // null, heterogeneous equals, closures if !contains_v5 { contains_v5 = rules.iter().any(|rule| { contains_v5_predicate(&rule.head) @@ -967,10 +967,16 @@ fn contains_v5_op(expressions: &[Expression]) -> bool { expressions.iter().any(|expression| { expression.ops.iter().any(|op| match op { Op::Value(term) => contains_v5_term(term), - Op::Binary(binary) => match binary { - Binary::HeterogeneousEqual | Binary::HeterogeneousNotEqual => true, - _ => false, - }, + Op::Closure(_, _) => true, + Op::Binary(binary) => matches!( + binary, + Binary::HeterogeneousEqual + | Binary::HeterogeneousNotEqual + | Binary::LazyAnd + | Binary::LazyOr + | Binary::All + | Binary::Any + ), _ => false, }) }) @@ -1049,10 +1055,14 @@ mod tests { println!("adding r2: {}", syms.print_rule(&r2)); w.add_rule(0, &[0].iter().collect(), r2); - w.run_with_limits(&syms, RunLimits { - max_time: Duration::from_secs(10), - ..Default::default() - }).unwrap(); + w.run_with_limits( + &syms, + RunLimits { + max_time: Duration::from_secs(10), + ..Default::default() + }, + ) + .unwrap(); println!("parents:"); let res = w diff --git a/biscuit-auth/src/error.rs b/biscuit-auth/src/error.rs index 35856121..984369c5 100644 --- a/biscuit-auth/src/error.rs +++ b/biscuit-auth/src/error.rs @@ -248,6 +248,8 @@ pub enum Expression { DivideByZero, #[error("Wrong number of elements on stack")] InvalidStack, + #[error("Shadowed variable")] + ShadowedVariable, } /// runtime limits errors diff --git a/biscuit-auth/src/format/convert.rs b/biscuit-auth/src/format/convert.rs index 57402e52..590e0f7c 100644 --- a/biscuit-auth/src/format/convert.rs +++ b/biscuit-auth/src/format/convert.rs @@ -588,126 +588,145 @@ pub mod v2 { } } + fn token_op_to_proto_op(op: &Op) -> schema::Op { + let content = match op { + Op::Value(i) => schema::op::Content::Value(token_term_to_proto_id(i)), + Op::Unary(u) => { + use schema::op_unary::Kind; + + schema::op::Content::Unary(schema::OpUnary { + kind: match u { + Unary::Negate => Kind::Negate, + Unary::Parens => Kind::Parens, + Unary::Length => Kind::Length, + } as i32, + }) + } + Op::Binary(b) => { + use schema::op_binary::Kind; + + schema::op::Content::Binary(schema::OpBinary { + kind: match b { + Binary::LessThan => Kind::LessThan, + Binary::GreaterThan => Kind::GreaterThan, + Binary::LessOrEqual => Kind::LessOrEqual, + Binary::GreaterOrEqual => Kind::GreaterOrEqual, + Binary::Equal => Kind::Equal, + Binary::Contains => Kind::Contains, + Binary::Prefix => Kind::Prefix, + Binary::Suffix => Kind::Suffix, + Binary::Regex => Kind::Regex, + Binary::Add => Kind::Add, + Binary::Sub => Kind::Sub, + Binary::Mul => Kind::Mul, + Binary::Div => Kind::Div, + Binary::And => Kind::And, + Binary::Or => Kind::Or, + Binary::Intersection => Kind::Intersection, + Binary::Union => Kind::Union, + Binary::BitwiseAnd => Kind::BitwiseAnd, + Binary::BitwiseOr => Kind::BitwiseOr, + Binary::BitwiseXor => Kind::BitwiseXor, + Binary::NotEqual => Kind::NotEqual, + Binary::HeterogeneousEqual => Kind::HeterogeneousEqual, + Binary::HeterogeneousNotEqual => Kind::HeterogeneousNotEqual, + Binary::LazyAnd => Kind::LazyAnd, + Binary::LazyOr => Kind::LazyOr, + Binary::All => Kind::All, + Binary::Any => Kind::Any, + } as i32, + }) + } + Op::Closure(params, ops) => schema::op::Content::Closure(schema::OpClosure { + params: params.clone(), + ops: ops.iter().map(token_op_to_proto_op).collect(), + }), + }; + + schema::Op { + content: Some(content), + } + } + pub fn token_expression_to_proto_expression(input: &Expression) -> schema::ExpressionV2 { schema::ExpressionV2 { - ops: input - .ops - .iter() - .map(|op| { - let content = match op { - Op::Value(i) => schema::op::Content::Value(token_term_to_proto_id(i)), - Op::Unary(u) => { - use schema::op_unary::Kind; - - schema::op::Content::Unary(schema::OpUnary { - kind: match u { - Unary::Negate => Kind::Negate, - Unary::Parens => Kind::Parens, - Unary::Length => Kind::Length, - } as i32, - }) - } - Op::Binary(b) => { - use schema::op_binary::Kind; - - schema::op::Content::Binary(schema::OpBinary { - kind: match b { - Binary::LessThan => Kind::LessThan, - Binary::GreaterThan => Kind::GreaterThan, - Binary::LessOrEqual => Kind::LessOrEqual, - Binary::GreaterOrEqual => Kind::GreaterOrEqual, - Binary::Equal => Kind::Equal, - Binary::Contains => Kind::Contains, - Binary::Prefix => Kind::Prefix, - Binary::Suffix => Kind::Suffix, - Binary::Regex => Kind::Regex, - Binary::Add => Kind::Add, - Binary::Sub => Kind::Sub, - Binary::Mul => Kind::Mul, - Binary::Div => Kind::Div, - Binary::And => Kind::And, - Binary::Or => Kind::Or, - Binary::Intersection => Kind::Intersection, - Binary::Union => Kind::Union, - Binary::BitwiseAnd => Kind::BitwiseAnd, - Binary::BitwiseOr => Kind::BitwiseOr, - Binary::BitwiseXor => Kind::BitwiseXor, - Binary::NotEqual => Kind::NotEqual, - Binary::HeterogeneousEqual => Kind::HeterogeneousEqual, - Binary::HeterogeneousNotEqual => Kind::HeterogeneousNotEqual, - } as i32, - }) - } - }; - - schema::Op { - content: Some(content), - } - }) - .collect(), + ops: input.ops.iter().map(token_op_to_proto_op).collect(), } } + fn proto_op_to_token_op(op: &schema::Op) -> Result { + use schema::{op, op_binary, op_unary}; + Ok(match op.content.as_ref() { + Some(op::Content::Value(id)) => Op::Value(proto_id_to_token_term(id)?), + Some(op::Content::Unary(u)) => match op_unary::Kind::from_i32(u.kind) { + Some(op_unary::Kind::Negate) => Op::Unary(Unary::Negate), + Some(op_unary::Kind::Parens) => Op::Unary(Unary::Parens), + Some(op_unary::Kind::Length) => Op::Unary(Unary::Length), + None => { + return Err(error::Format::DeserializationError( + "deserialization error: unary operation is empty".to_string(), + )) + } + }, + Some(op::Content::Binary(b)) => match op_binary::Kind::from_i32(b.kind) { + Some(op_binary::Kind::LessThan) => Op::Binary(Binary::LessThan), + Some(op_binary::Kind::GreaterThan) => Op::Binary(Binary::GreaterThan), + Some(op_binary::Kind::LessOrEqual) => Op::Binary(Binary::LessOrEqual), + Some(op_binary::Kind::GreaterOrEqual) => Op::Binary(Binary::GreaterOrEqual), + Some(op_binary::Kind::Equal) => Op::Binary(Binary::Equal), + Some(op_binary::Kind::Contains) => Op::Binary(Binary::Contains), + Some(op_binary::Kind::Prefix) => Op::Binary(Binary::Prefix), + Some(op_binary::Kind::Suffix) => Op::Binary(Binary::Suffix), + Some(op_binary::Kind::Regex) => Op::Binary(Binary::Regex), + Some(op_binary::Kind::Add) => Op::Binary(Binary::Add), + Some(op_binary::Kind::Sub) => Op::Binary(Binary::Sub), + Some(op_binary::Kind::Mul) => Op::Binary(Binary::Mul), + Some(op_binary::Kind::Div) => Op::Binary(Binary::Div), + Some(op_binary::Kind::And) => Op::Binary(Binary::And), + Some(op_binary::Kind::Or) => Op::Binary(Binary::Or), + Some(op_binary::Kind::Intersection) => Op::Binary(Binary::Intersection), + Some(op_binary::Kind::Union) => Op::Binary(Binary::Union), + Some(op_binary::Kind::BitwiseAnd) => Op::Binary(Binary::BitwiseAnd), + Some(op_binary::Kind::BitwiseOr) => Op::Binary(Binary::BitwiseOr), + Some(op_binary::Kind::BitwiseXor) => Op::Binary(Binary::BitwiseXor), + Some(op_binary::Kind::NotEqual) => Op::Binary(Binary::NotEqual), + Some(op_binary::Kind::HeterogeneousEqual) => Op::Binary(Binary::HeterogeneousEqual), + Some(op_binary::Kind::HeterogeneousNotEqual) => { + Op::Binary(Binary::HeterogeneousNotEqual) + } + Some(op_binary::Kind::LazyAnd) => Op::Binary(Binary::LazyAnd), + Some(op_binary::Kind::LazyOr) => Op::Binary(Binary::LazyOr), + Some(op_binary::Kind::All) => Op::Binary(Binary::All), + Some(op_binary::Kind::Any) => Op::Binary(Binary::Any), + None => { + return Err(error::Format::DeserializationError( + "deserialization error: binary operation is empty".to_string(), + )) + } + }, + Some(op::Content::Closure(op_closure)) => Op::Closure( + op_closure.params.clone(), + op_closure + .ops + .iter() + .map(proto_op_to_token_op) + .collect::>()?, + ), + None => { + return Err(error::Format::DeserializationError( + "deserialization error: operation is empty".to_string(), + )) + } + }) + } + pub fn proto_expression_to_token_expression( input: &schema::ExpressionV2, ) -> Result { - use schema::{op, op_binary, op_unary}; let mut ops = Vec::new(); for op in input.ops.iter() { - let translated = match op.content.as_ref() { - Some(op::Content::Value(id)) => Op::Value(proto_id_to_token_term(id)?), - Some(op::Content::Unary(u)) => match op_unary::Kind::from_i32(u.kind) { - Some(op_unary::Kind::Negate) => Op::Unary(Unary::Negate), - Some(op_unary::Kind::Parens) => Op::Unary(Unary::Parens), - Some(op_unary::Kind::Length) => Op::Unary(Unary::Length), - None => { - return Err(error::Format::DeserializationError( - "deserialization error: unary operation is empty".to_string(), - )) - } - }, - Some(op::Content::Binary(b)) => match op_binary::Kind::from_i32(b.kind) { - Some(op_binary::Kind::LessThan) => Op::Binary(Binary::LessThan), - Some(op_binary::Kind::GreaterThan) => Op::Binary(Binary::GreaterThan), - Some(op_binary::Kind::LessOrEqual) => Op::Binary(Binary::LessOrEqual), - Some(op_binary::Kind::GreaterOrEqual) => Op::Binary(Binary::GreaterOrEqual), - Some(op_binary::Kind::Equal) => Op::Binary(Binary::Equal), - Some(op_binary::Kind::Contains) => Op::Binary(Binary::Contains), - Some(op_binary::Kind::Prefix) => Op::Binary(Binary::Prefix), - Some(op_binary::Kind::Suffix) => Op::Binary(Binary::Suffix), - Some(op_binary::Kind::Regex) => Op::Binary(Binary::Regex), - Some(op_binary::Kind::Add) => Op::Binary(Binary::Add), - Some(op_binary::Kind::Sub) => Op::Binary(Binary::Sub), - Some(op_binary::Kind::Mul) => Op::Binary(Binary::Mul), - Some(op_binary::Kind::Div) => Op::Binary(Binary::Div), - Some(op_binary::Kind::And) => Op::Binary(Binary::And), - Some(op_binary::Kind::Or) => Op::Binary(Binary::Or), - Some(op_binary::Kind::Intersection) => Op::Binary(Binary::Intersection), - Some(op_binary::Kind::Union) => Op::Binary(Binary::Union), - Some(op_binary::Kind::BitwiseAnd) => Op::Binary(Binary::BitwiseAnd), - Some(op_binary::Kind::BitwiseOr) => Op::Binary(Binary::BitwiseOr), - Some(op_binary::Kind::BitwiseXor) => Op::Binary(Binary::BitwiseXor), - Some(op_binary::Kind::NotEqual) => Op::Binary(Binary::NotEqual), - Some(op_binary::Kind::HeterogeneousEqual) => { - Op::Binary(Binary::HeterogeneousEqual) - } - Some(op_binary::Kind::HeterogeneousNotEqual) => { - Op::Binary(Binary::HeterogeneousNotEqual) - } - None => { - return Err(error::Format::DeserializationError( - "deserialization error: binary operation is empty".to_string(), - )) - } - }, - None => { - return Err(error::Format::DeserializationError( - "deserialization error: operation is empty".to_string(), - )) - } - }; - ops.push(translated); + ops.push(proto_op_to_token_op(op)?); } Ok(Expression { ops }) diff --git a/biscuit-auth/src/format/schema.proto b/biscuit-auth/src/format/schema.proto index 88d0d3b0..349bfb41 100644 --- a/biscuit-auth/src/format/schema.proto +++ b/biscuit-auth/src/format/schema.proto @@ -115,6 +115,7 @@ message Op { TermV2 value = 1; OpUnary unary = 2; OpBinary Binary = 3; + OpClosure closure = 4; } } @@ -153,11 +154,20 @@ message OpBinary { NotEqual = 20; HeterogeneousEqual = 21; HeterogeneousNotEqual = 22; + LazyAnd = 23; + LazyOr = 24; + All = 25; + Any = 26; } required Kind kind = 1; } +message OpClosure { + repeated uint32 params = 1; + repeated Op ops = 2; +} + message Policy { enum Kind { Allow = 0; @@ -232,4 +242,4 @@ message SnapshotBlock { repeated CheckV2 checks_v2 = 5; repeated Scope scope = 6; optional PublicKey externalKey = 7; -} \ No newline at end of file +} diff --git a/biscuit-auth/src/format/schema.rs b/biscuit-auth/src/format/schema.rs index bc59730f..58e7769a 100644 --- a/biscuit-auth/src/format/schema.rs +++ b/biscuit-auth/src/format/schema.rs @@ -176,7 +176,7 @@ pub struct ExpressionV2 { } #[derive(Clone, PartialEq, ::prost::Message)] pub struct Op { - #[prost(oneof="op::Content", tags="1, 2, 3")] + #[prost(oneof="op::Content", tags="1, 2, 3, 4")] pub content: ::core::option::Option, } /// Nested message and enum types in `Op`. @@ -189,6 +189,8 @@ pub mod op { Unary(super::OpUnary), #[prost(message, tag="3")] Binary(super::OpBinary), + #[prost(message, tag="4")] + Closure(super::OpClosure), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -239,9 +241,20 @@ pub mod op_binary { NotEqual = 20, HeterogeneousEqual = 21, HeterogeneousNotEqual = 22, + LazyAnd = 23, + LazyOr = 24, + All = 25, + Any = 26, } } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct OpClosure { + #[prost(uint32, repeated, packed="false", tag="1")] + pub params: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag="2")] + pub ops: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct Policy { #[prost(message, repeated, tag="1")] pub queries: ::prost::alloc::vec::Vec, diff --git a/biscuit-auth/src/parser.rs b/biscuit-auth/src/parser.rs index bf0efd5c..59510afb 100644 --- a/biscuit-auth/src/parser.rs +++ b/biscuit-auth/src/parser.rs @@ -31,6 +31,7 @@ mod tests { Value(builder::Term), Unary(builder::Op, Box), Binary(builder::Op, Box, Box), + Closure(Vec, Box), } impl Expr { @@ -52,6 +53,11 @@ mod tests { right.into_opcodes(v); v.push(op); } + Expr::Closure(params, body) => { + let mut ops = vec![]; + body.into_opcodes(&mut ops); + v.push(builder::Op::Closure(params, ops)); + } } } } @@ -68,6 +74,9 @@ mod tests { Box::new((*expr1).into()), Box::new((*expr2).into()), ), + biscuit_parser::parser::Expr::Closure(params, body) => { + Expr::Closure(params, Box::new((*body).into())) + } } } } @@ -319,22 +328,28 @@ mod tests { Ok(( " ", Expr::Binary( - Op::Binary(Binary::And), + Op::Binary(Binary::LazyAnd), Box::new(Expr::Binary( - Op::Binary(Binary::And), + Op::Binary(Binary::LazyAnd), Box::new(Expr::Binary( Op::Binary(Binary::LessThan), Box::new(Expr::Value(int(2))), Box::new(Expr::Value(var("test"))), )), - Box::new(Expr::Binary( - Op::Binary(Binary::Prefix), - Box::new(Expr::Value(var("var2"))), - Box::new(Expr::Value(string("test"))), - )), + Box::new(Expr::Closure( + vec![], + Box::new(Expr::Binary( + Op::Binary(Binary::Prefix), + Box::new(Expr::Value(var("var2"))), + Box::new(Expr::Value(string("test"))) + ),) + )) )), - Box::new(Expr::Value(Term::Bool(true))), - ) + Box::new(Expr::Closure( + vec![], + Box::new(Expr::Value(Term::Bool(true))) + )), + ), )) ); diff --git a/biscuit-auth/src/token/builder.rs b/biscuit-auth/src/token/builder.rs index 2ed0fb49..e63fc72e 100644 --- a/biscuit-auth/src/token/builder.rs +++ b/biscuit-auth/src/token/builder.rs @@ -910,6 +910,7 @@ pub enum Op { Value(Term), Unary(Unary), Binary(Binary), + Closure(Vec, Vec), } impl Convert for Op { @@ -918,6 +919,10 @@ impl Convert for Op { Op::Value(t) => datalog::Op::Value(t.convert(symbols)), Op::Unary(u) => datalog::Op::Unary(u.clone()), Op::Binary(b) => datalog::Op::Binary(b.clone()), + Op::Closure(ps, os) => datalog::Op::Closure( + ps.iter().map(|p| symbols.insert(p) as u32).collect(), + os.iter().map(|o| o.convert(symbols)).collect(), + ), } } @@ -926,6 +931,14 @@ impl Convert for Op { datalog::Op::Value(t) => Op::Value(Term::convert_from(t, symbols)?), datalog::Op::Unary(u) => Op::Unary(u.clone()), datalog::Op::Binary(b) => Op::Binary(b.clone()), + datalog::Op::Closure(ps, os) => Op::Closure( + ps.iter() + .map(|p| symbols.print_symbol(*p as u64)) + .collect::>()?, + os.iter() + .map(|o| Op::convert_from(o, symbols)) + .collect::>()?, + ), }) } } @@ -936,6 +949,9 @@ impl From for Op { biscuit_parser::builder::Op::Value(t) => Op::Value(t.into()), biscuit_parser::builder::Op::Unary(u) => Op::Unary(u.into()), biscuit_parser::builder::Op::Binary(b) => Op::Binary(b.into()), + biscuit_parser::builder::Op::Closure(ps, os) => { + Op::Closure(ps, os.into_iter().map(|o| o.into()).collect()) + } } } } @@ -976,6 +992,10 @@ impl From for Binary { biscuit_parser::builder::Binary::NotEqual => Binary::NotEqual, biscuit_parser::builder::Binary::HeterogeneousEqual => Binary::HeterogeneousEqual, biscuit_parser::builder::Binary::HeterogeneousNotEqual => Binary::HeterogeneousNotEqual, + biscuit_parser::builder::Binary::LazyAnd => Binary::LazyAnd, + biscuit_parser::builder::Binary::LazyOr => Binary::LazyOr, + biscuit_parser::builder::Binary::All => Binary::All, + biscuit_parser::builder::Binary::Any => Binary::Any, } } } diff --git a/biscuit-auth/tests/macros.rs b/biscuit-auth/tests/macros.rs index 49244655..f38a14ea 100644 --- a/biscuit-auth/tests/macros.rs +++ b/biscuit-auth/tests/macros.rs @@ -12,8 +12,9 @@ fn block_macro() { let my_key = "my_value"; let mut b = block!( r#"fact("test", hex:aabbcc, [true], {my_key}, {term_set}); - rule($0, true) <- fact($0, $1, $2, {my_key}); + rule($0, true) <- fact($0, $1, $2, {my_key}), true || false; check if {my_key}.starts_with("my"); + check if [true,false].any($p -> true); "#, ); @@ -24,8 +25,9 @@ fn block_macro() { b.to_string(), r#"fact("test", hex:aabbcc, [true], "my_value", [0]); appended(true); -rule($0, true) <- fact($0, $1, $2, "my_value"); +rule($0, true) <- fact($0, $1, $2, "my_value"), true || false; check if "my_value".starts_with("my"); +check if [false, true].any($p -> true); "#, ); } diff --git a/biscuit-parser/src/builder.rs b/biscuit-parser/src/builder.rs index 9fd85673..bf632b31 100644 --- a/biscuit-parser/src/builder.rs +++ b/biscuit-parser/src/builder.rs @@ -185,6 +185,7 @@ pub enum Op { Value(Term), Unary(Unary), Binary(Binary), + Closure(Vec, Vec), } #[derive(Debug, Clone, PartialEq, Eq)] @@ -219,6 +220,10 @@ pub enum Binary { NotEqual, HeterogeneousEqual, HeterogeneousNotEqual, + LazyAnd, + LazyOr, + All, + Any, } #[cfg(feature = "datalog-macro")] @@ -228,6 +233,12 @@ impl ToTokens for Op { Op::Value(t) => quote! { ::biscuit_auth::builder::Op::Value(#t) }, Op::Unary(u) => quote! { ::biscuit_auth::builder::Op::Unary(#u) }, Op::Binary(b) => quote! { ::biscuit_auth::builder::Op::Binary(#b) }, + Op::Closure(params, os) => quote! { + ::biscuit_auth::builder::Op::Closure( + <[String]>::into_vec(Box::new([#(#params.to_string()),*])), + <[::biscuit_auth::builder::Op]>::into_vec(Box::new([#(#os),*])) + ) + }, }); } } @@ -274,6 +285,10 @@ impl ToTokens for Binary { Binary::HeterogeneousNotEqual => { quote! { ::biscuit_auth::datalog::Binary::HeterogeneousNotEqual} } + Binary::LazyAnd => quote! { ::biscuit_auth::datalog::Binary::LazyAnd }, + Binary::LazyOr => quote! { ::biscuit_auth::datalog::Binary::LazyOr }, + Binary::All => quote! { ::biscuit_auth::datalog::Binary::All }, + Binary::Any => quote! { ::biscuit_auth::datalog::Binary::Any }, }); } } diff --git a/biscuit-parser/src/parser.rs b/biscuit-parser/src/parser.rs index 6176ed39..43c37987 100644 --- a/biscuit-parser/src/parser.rs +++ b/biscuit-parser/src/parser.rs @@ -388,6 +388,7 @@ pub enum Expr { Value(builder::Term), Unary(builder::Op, Box), Binary(builder::Op, Box, Box), + Closure(Vec, Box), } impl Expr { @@ -409,6 +410,11 @@ impl Expr { right.into_opcodes(v); v.push(op); } + Expr::Closure(params, expr) => { + let mut ops = vec![]; + expr.into_opcodes(&mut ops); + v.push(builder::Op::Closure(params, ops)) + } } } } @@ -441,12 +447,12 @@ fn unary_parens(i: &str) -> IResult<&str, Expr, Error> { fn binary_op_0(i: &str) -> IResult<&str, builder::Binary, Error> { use builder::Binary; - value(Binary::Or, tag("||"))(i) + value(Binary::LazyOr, tag("||"))(i) } fn binary_op_1(i: &str) -> IResult<&str, builder::Binary, Error> { use builder::Binary; - value(Binary::And, tag("&&"))(i) + value(Binary::LazyAnd, tag("&&"))(i) } fn binary_op_2(i: &str) -> IResult<&str, builder::Binary, Error> { @@ -498,6 +504,8 @@ fn binary_op_8(i: &str) -> IResult<&str, builder::Binary, Error> { value(Binary::Regex, tag("matches")), value(Binary::Intersection, tag("intersection")), value(Binary::Union, tag("union")), + value(Binary::All, tag("all")), + value(Binary::Any, tag("any")), ))(i) } @@ -510,7 +518,14 @@ fn expr_term(i: &str) -> IResult<&str, Expr, Error> { fn fold_exprs(initial: Expr, remainder: Vec<(builder::Binary, Expr)>) -> Expr { remainder.into_iter().fold(initial, |acc, pair| { let (op, expr) = pair; - Expr::Binary(builder::Op::Binary(op), Box::new(acc), Box::new(expr)) + match op { + builder::Binary::LazyAnd | builder::Binary::LazyOr => Expr::Binary( + builder::Op::Binary(op), + Box::new(acc), + Box::new(Expr::Closure(vec![], Box::new(expr))), + ), + _ => Expr::Binary(builder::Op::Binary(op), Box::new(acc), Box::new(expr)), + } }) } @@ -637,10 +652,24 @@ fn expr9(i: &str) -> IResult<&str, Expr, Error> { let bin_result = binary_method(i); let un_result = unary_method(i); match (bin_result, un_result) { - (Ok((i, (op, arg))), _) => { + (Ok((i, (op, params, arg))), _) => { input = i; - initial = - Expr::Binary(builder::Op::Binary(op), Box::new(initial), Box::new(arg)); + match params { + Some(params) => { + initial = Expr::Binary( + builder::Op::Binary(op), + Box::new(initial), + Box::new(Expr::Closure(params, Box::new(arg))), + ); + } + None => { + initial = Expr::Binary( + builder::Op::Binary(op), + Box::new(initial), + Box::new(arg), + ); + } + } } (_, Ok((i, op))) => { input = i; @@ -655,17 +684,31 @@ fn expr9(i: &str) -> IResult<&str, Expr, Error> { } } -fn binary_method(i: &str) -> IResult<&str, (builder::Binary, Expr), Error> { +fn binary_method(i: &str) -> IResult<&str, (builder::Binary, Option>, Expr), Error> { let (i, op) = binary_op_8(i)?; let (i, _) = char('(')(i)?; let (i, _) = space0(i)?; // we only support a single argument for now - let (i, arg) = expr(i)?; - let (i, _) = space0(i)?; - let (i, _) = char(')')(i)?; + match op { + builder::Binary::All | builder::Binary::Any => { + let (i, param) = preceded(char('$'), name)(i)?; + let (i, _) = space0(i)?; + let (i, _) = tag("->")(i)?; + let (i, _) = space0(i)?; + let (i, arg) = expr(i)?; + let (i, _) = space0(i)?; + let (i, _) = char(')')(i)?; + Ok((i, (op, Some(vec![param.to_owned()]), arg))) + } + _ => { + let (i, arg) = expr(i)?; + let (i, _) = space0(i)?; + let (i, _) = char(')')(i)?; - Ok((i, (op, arg))) + Ok((i, (op, None, arg))) + } + } } fn unary_method(i: &str) -> IResult<&str, builder::Unary, Error> { @@ -1483,8 +1526,8 @@ mod tests { vec![ Op::Value(boolean(false)), Op::Unary(Unary::Negate), - Op::Value(boolean(true)), - Op::Binary(Binary::And), + Op::Closure(vec![], vec![Op::Value(boolean(true)),]), + Op::Binary(Binary::LazyAnd), ], )) ); @@ -1495,10 +1538,15 @@ mod tests { "", vec![ Op::Value(boolean(true)), - Op::Value(boolean(true)), - Op::Value(boolean(true)), - Op::Binary(Binary::And), - Op::Binary(Binary::Or), + Op::Closure( + vec![], + vec![ + Op::Value(boolean(true)), + Op::Closure(vec![], vec![Op::Value(boolean(true)),]), + Op::Binary(Binary::LazyAnd), + ] + ), + Op::Binary(Binary::LazyOr), ], )) );