Skip to content

Commit

Permalink
Merge pull request #229 from biscuit-auth/fix-param-substitution-in-c…
Browse files Browse the repository at this point in the history
…losures

fix: recursively collect and apply parameters in closures
  • Loading branch information
divarvel authored Oct 22, 2024
2 parents 911ebe4 + 35edbd8 commit c2dd7a4
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 14 deletions.
64 changes: 53 additions & 11 deletions biscuit-auth/src/token/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,43 @@ pub enum Op {
Closure(Vec<String>, Vec<Op>),
}

impl Op {
fn collect_parameters(&self, parameters: &mut HashMap<String, Option<Term>>) {
match self {
Op::Value(Term::Parameter(ref name)) => {
parameters.insert(name.to_owned(), None);
}
Op::Closure(_, ops) => {
for op in ops {
op.collect_parameters(parameters);
}
}
_ => {}
}
}

fn apply_parameters(self, parameters: &HashMap<String, Option<Term>>) -> Self {
match self {
Op::Value(Term::Parameter(ref name)) => {
if let Some(Some(t)) = parameters.get(name) {
Op::Value(t.clone())
} else {
self

Check warning on line 937 in biscuit-auth/src/token/builder.rs

View check run for this annotation

Codecov / codecov/patch

biscuit-auth/src/token/builder.rs#L937

Added line #L937 was not covered by tests
}
}
Op::Value(_) => self,
Op::Unary(_) => self,
Op::Binary(_) => self,
Op::Closure(args, mut ops) => Op::Closure(
args,
ops.drain(..)
.map(|op| op.apply_parameters(parameters))
.collect(),
),
}
}
}

impl Convert<datalog::Op> for Op {
fn convert(&self, symbols: &mut SymbolTable) -> datalog::Op {
match self {
Expand Down Expand Up @@ -1036,9 +1073,7 @@ impl Rule {

for expression in &expressions {
for op in &expression.ops {
if let Op::Value(Term::Parameter(name)) = &op {
parameters.insert(name.to_string(), None);
}
op.collect_parameters(&mut parameters);
}
}

Expand Down Expand Up @@ -1282,14 +1317,7 @@ impl Rule {
expression.ops = expression
.ops
.drain(..)
.map(|op| {
if let Op::Value(Term::Parameter(name)) = &op {
if let Some(Some(term)) = parameters.get(name) {
return Op::Value(term.clone());
}
}
op
})
.map(|op| op.apply_parameters(&parameters))
.collect();
}
}
Expand Down Expand Up @@ -2382,6 +2410,20 @@ mod tests {
assert_eq!(s, "fact($var1, \"hello\", [0]) <- f1($var1, $var3), f2(\"hello\", $var3, 1), $var3.starts_with(\"hello\")");
}

#[test]
fn set_closure_parameters() {
let mut rule = Rule::try_from("fact(true) <- false || {p1}").unwrap();
rule.set_lenient("p1", true).unwrap();
println!("{rule:?}");
let s = rule.to_string();
assert_eq!(s, "fact(true) <- false || true");

let mut rule = Rule::try_from("fact(true) <- false || {p1}").unwrap();
rule.set("p1", true).unwrap();
let s = rule.to_string();
assert_eq!(s, "fact(true) <- false || true");
}

#[test]
fn set_rule_scope_parameters() {
let pubkey = PublicKey::from_bytes(
Expand Down
20 changes: 17 additions & 3 deletions biscuit-parser/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,22 @@ pub enum Op {
Closure(Vec<String>, Vec<Op>),
}

impl Op {
fn collect_parameters(&self, parameters: &mut HashMap<String, Option<Term>>) {
match self {
Op::Value(Term::Parameter(ref name)) => {
parameters.insert(name.to_owned(), None);
}
Op::Closure(_, ops) => {
for op in ops {
op.collect_parameters(parameters);
}
}
_ => {}
}
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Unary {
Negate,
Expand Down Expand Up @@ -332,9 +348,7 @@ impl Rule {

for expression in &expressions {
for op in &expression.ops {
if let Op::Value(Term::Parameter(name)) = &op {
parameters.insert(name.to_string(), None);
}
op.collect_parameters(&mut parameters);
}
}

Expand Down

0 comments on commit c2dd7a4

Please sign in to comment.