Skip to content

Commit

Permalink
Merge pull request #54 from Gateway-DAO/10d9e/feat/add-range-expressions
Browse files Browse the repository at this point in the history
10d9e/feat/add range expressions
  • Loading branch information
10d9e authored Oct 30, 2024
2 parents 95dd59b + 84f9e10 commit 039e3bf
Show file tree
Hide file tree
Showing 10 changed files with 904 additions and 279 deletions.
289 changes: 223 additions & 66 deletions circuit_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ extern crate proc_macro;
use core::panic;

use proc_macro::TokenStream;
use quote::format_ident;
use quote::quote;
use syn::ExprAssign;
use syn::ExprBlock;
use syn::ExprMatch;
use quote::{format_ident, quote};
use std::collections::HashSet;
use syn::{
parse_macro_input, BinOp, Expr, ExprBinary, ExprIf, ExprUnary, FnArg, ItemFn, Lit, Pat, PatType,
parse_macro_input, BinOp, Expr, ExprAssign, ExprBinary, ExprBlock, ExprIf, ExprLet, ExprMatch,
ExprReference, ExprUnary, FnArg, ItemFn, Lit, Pat, PatType,
};

#[proc_macro_attribute]
Expand Down Expand Up @@ -57,6 +55,13 @@ fn generate_macro(item: TokenStream, mode: &str) -> TokenStream {
let mut constants = vec![];
let transformed_block = modify_body(*input_fn.block, &mut constants);

// remove duplicates
let mut seen = HashSet::new();
let constants: Vec<proc_macro2::TokenStream> = constants
.into_iter()
.filter(|item| seen.insert(item.to_string()))
.collect();

// Collect parameter names dynamically
let param_names: Vec<_> = inputs
.iter()
Expand Down Expand Up @@ -201,8 +206,18 @@ fn replace_expressions(expr: Expr, constants: &mut Vec<proc_macro2::TokenStream>
Expr::Assign(ExprAssign { left, right, .. }) => {
let left_expr = replace_expressions(*left, constants);
let right_expr = replace_expressions(*right, constants);
syn::parse_quote! {
#left_expr = #right_expr.clone()

match right_expr {
Expr::Reference(ExprReference { .. }) => {
syn::parse_quote! {
#left_expr = &#right_expr.clone()
}
}
_ => {
syn::parse_quote! {
#left_expr = #right_expr.clone()
}
}
}
}
// return statement
Expand All @@ -221,9 +236,16 @@ fn replace_expressions(expr: Expr, constants: &mut Vec<proc_macro2::TokenStream>
}) => {
let value = lit_bool.value;
let const_var = format_ident!("const_{}", value as u128);
constants.push(quote! {
let #const_var = &context.input::<N>(&#value.into()).clone();
});

if value {
constants.push(quote! {
let #const_var = &context.input::<N>(&1_u128.into()).clone();
});
} else {
constants.push(quote! {
let #const_var = &context.input::<N>(&0_u128.into()).clone();
});
}
syn::parse_quote! {#const_var}
}
// integer literal - handle as a constant in the circuit context
Expand Down Expand Up @@ -586,6 +608,7 @@ fn replace_expressions(expr: Expr, constants: &mut Vec<proc_macro2::TokenStream>
}}
}

/*
Expr::If(ExprIf {
cond,
then_branch,
Expand All @@ -595,76 +618,210 @@ fn replace_expressions(expr: Expr, constants: &mut Vec<proc_macro2::TokenStream>
let cond_expr = replace_expressions(*cond, constants);
let then_block = modify_body(then_branch, constants);
if let Some((_, else_expr)) = else_branch {
match *else_expr {
Expr::If(else_if) => {
let else_if_expr = replace_expressions(Expr::If(else_if), constants);
syn::parse_quote! {{
let if_true = #then_block;
let if_false = #else_if_expr;
let cond = #cond_expr;
context.mux(&cond.into(), &if_true, &if_false)
}}
}
_ => {
let else_block = modify_body(syn::parse_quote! { #else_expr }, constants);
syn::parse_quote! {{
let if_true = #then_block;
let if_false = #else_block;
let cond = #cond_expr;
context.mux(&cond.into(), &if_true, &if_false)
}}
}
}
// If there's an explicit else block, use it; otherwise, continue with remaining expressions
let else_expr = if let Some((_, else_expr)) = else_branch {
replace_expressions(*else_expr, constants)
} else {
// Placeholder for remaining function body as the fall-through `else` case
//syn::parse_quote! { context.input::<N>(&0u128.into()) }
panic!("else branch is required");
/*
syn::parse_quote! {
{
let if_true = #then_block;
let cond = #cond_expr;
let if_false = context.len();
&context.mux(&cond.into(), &if_true.into(), &if_false.into());
};
// Generate code for conditional execution and fall-through
syn::parse_quote! {{
let cond = #cond_expr;
let if_true = #then_block;
let if_false = #else_expr;
context.mux(&cond.into(), &if_true, &if_false)
}}
}
*/
Expr::If(ExprIf {
cond,
then_branch,
else_branch,
..
}) => {
// Check if `cond` is an `if let` with a range pattern
let cond_expr = match *cond {
Expr::Let(ExprLet { pat, expr, .. }) => {
match &*pat {
// Handle inclusive range pattern (e.g., 1..=5)
syn::Pat::Range(syn::PatRange {
start: Some(start),
end: Some(end),
limits: syn::RangeLimits::Closed(_),
..
}) => {
let start_expr = replace_expressions(*start.clone(), constants);
let end_expr = replace_expressions(*end.clone(), constants);
let input_expr = replace_expressions(*expr, constants);

// Inclusive range with embedded `let` statements for `lhs` and `rhs`
syn::parse_quote! {{
let lhs = &context.ge(&#input_expr.into(), &#start_expr.into()).into();
let rhs = &context.le(&#input_expr.into(), &#end_expr.into()).into();
context.and(lhs, rhs)
}}
}
// Handle exclusive range pattern (e.g., 1..10)
syn::Pat::Range(syn::PatRange {
start: Some(start),
end: Some(end),
limits: syn::RangeLimits::HalfOpen(_),
..
}) => {
let start_expr = replace_expressions(*start.clone(), constants);
let end_expr = replace_expressions(*end.clone(), constants);
let input_expr = replace_expressions(*expr, constants);

// Exclusive range with embedded `let` statements for `lhs` and `rhs`
syn::parse_quote! {{
let lhs = &context.ge(&#input_expr.into(), &#start_expr.into()).into();
let rhs = &context.lt(&#input_expr.into(), &#end_expr.into()).into();
context.and(lhs, rhs)
}}
}
// Handle single literal pattern, e.g., `if let 5 = n`
syn::Pat::Lit(lit) => {
let lit_expr = replace_expressions(Expr::Lit(lit.clone()), constants);
let input_expr = replace_expressions(*expr, constants);

syn::parse_quote! {
context.eq(&#input_expr.into(), &#lit_expr.into())
}
}
_ => panic!(
"Unsupported pattern in if let: expected a range or literal pattern."
),
}
}
*/
}
ref _other => {
replace_expressions(*cond, constants) // Fallback for non-let conditions
}
};

let then_block = modify_body(then_branch, constants);

// Check if an `else` branch exists, as it's required.
let else_expr = if let Some((_, else_expr)) = else_branch {
replace_expressions(*else_expr, constants)
} else {
panic!("else branch is required for range if let");
};

// Generate code for conditional execution and chaining
syn::parse_quote! {{
let cond = #cond_expr;
let if_true = #then_block;
let if_false = #else_expr;
context.mux(&cond.into(), &if_true, &if_false)
}}
}

// support match arms with mux and other operations
// Support match arms with mux and other operations
Expr::Match(ExprMatch { expr, arms, .. }) => {
let match_expr = replace_expressions(*expr, constants);

// Define an input variable to use in range proof processing
let input = syn::Ident::new("input", proc_macro2::Span::call_site());
let input_binding = quote! { let #input = #match_expr; };

// Process each arm, building up the conditional chain
let arm_exprs = arms.into_iter().rev().fold(None, |acc, arm| {
let pat = arm.pat;
let body_expr = replace_expressions(*arm.body, constants);

// Create conditional expression for each arm
let cond_expr =
replace_expressions(syn::parse_quote! { #match_expr == #pat }, constants);

Some(if let Some(else_expr) = acc {
let else_expr = replace_expressions(else_expr, constants);

syn::parse_quote! {{
let if_true = { #body_expr };
let if_false = { #else_expr };
let cond = { #cond_expr };
context.mux(&cond.into(), &if_true, &if_false)
}}
} else {
syn::parse_quote! {{
{ #body_expr }
}}
})
});
let arm_exprs = arms
.into_iter()
.rev()
.fold(None as Option<Expr>, |acc, arm| {
let pat = arm.pat;
let body_expr = replace_expressions(*arm.body, constants);

// Create conditional expression for each arm, handling ranges
let cond_expr = match &pat {
// Handle inclusive range pattern (start..=end)
syn::Pat::Range(syn::PatRange {
start: Some(start),
end: Some(end),
limits: syn::RangeLimits::Closed(_),
..
}) => {
let start = replace_expressions(*start.clone(), constants);
let end = replace_expressions(*end.clone(), constants);
quote! {
let lhs = &context.ge(&#input.into(), &#start.into()).into();
let rhs = &context.le(&#input.into(), &#end.into()).into();
context.and(
lhs,
rhs
)
}
}
// Handle exclusive range pattern (start..end)
syn::Pat::Range(syn::PatRange {
start: Some(start),
end: Some(end),
limits: syn::RangeLimits::HalfOpen(_),
..
}) => {
let start = replace_expressions(*start.clone(), constants);
let end = replace_expressions(*end.clone(), constants);
quote! {
let lhs = &context.ge(&#input.into(), &#start.into()).into();
let rhs = &context.lt(&#input.into(), &#end.into()).into();
context.and(
lhs,
rhs
)
}
}
// Handle single value pattern (e.g., `5`)
syn::Pat::Lit(lit) => {
let lit_expr =
replace_expressions(syn::Expr::Lit(lit.clone()), constants);
quote! {
context.eq(&#input.into(), &#lit_expr.into())
}
}

syn::Pat::Ident(pat) => {
// Create conditional expression for each arm
let cond_expr = replace_expressions(
syn::parse_quote! { #match_expr == #pat },
constants,
);

syn::parse_quote! {{
{ #cond_expr }
}}
}
// Handle the wildcard pattern `_` as default/fallback case
syn::Pat::Wild(_) => quote! { true },
other => panic!("{:?}: Unsupported pattern in match arm", other),
};

// Chain the condition with the body, selecting based on condition
Some(if let Some(else_expr) = acc {
syn::parse_quote! {{
let if_true = { #body_expr };
let if_false = { #else_expr };
let cond = { #cond_expr };
context.mux(&cond.into(), &if_true, &if_false)
}}
} else {
syn::parse_quote! {{
{ #body_expr }
}}
})
});

match arm_exprs {
Some(result) => result,
Some(result) => syn::parse_quote! {{
#input_binding // Bind `input` at the beginning
#result // Process the chained expressions
}},
None => panic!("Match expression requires at least one arm"),
}
}

other => other,
}
}
17 changes: 12 additions & 5 deletions compute/examples/access_control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,27 @@ fn access_control(role: u8) -> u8 {
let SENSITIVE_DATA = 1; // 1 represents sensitive data, accessible only to certain roles.
let PATIENT_NOTES = 2; // 2 represents patient notes, which may have broader access.

let ADMIN_ROLE = 1; // Role identifier for Admin
let DOCTOR_ROLE = 2; // Role identifier for Doctor
let NURSE_ROLE = 3; // Role identifier for Nurse

// Use a match expression to determine access based on the provided role.
match role {
let determined_role = match role {
// Case for Admin role, which we assume has the highest level of access.
1 => {
ADMIN_ROLE => {
// Admin role (encoded as 1) gets full access to the most sensitive data.
// The return value is `SENSITIVE_DATA`, indicating unrestricted access to it.
SENSITIVE_DATA
}
// Case for Doctor role, which has limited access.
2 => {
DOCTOR_ROLE => {
// Doctor role (encoded as 2) has partial access to both SENSITIVE_DATA and PATIENT_NOTES.
// Using the `+` operator, we perform a bitwise AND on `PATIENT_NOTES` and `SENSITIVE_DATA`.
// This allows the doctor role to have limited, controlled access to data while preserving privacy.
PATIENT_NOTES + SENSITIVE_DATA
}
// Case for Nurse role, which has only patient notes access.
3 => {
NURSE_ROLE => {
// Nurse role (encoded as 3) can view only patient notes.
// The function returns `PATIENT_NOTES`, granting access exclusively to this data type.
PATIENT_NOTES
Expand All @@ -38,7 +42,10 @@ fn access_control(role: u8) -> u8 {
// Returning `0` signifies no access to any sensitive or restricted data.
0
}
}
};

// Return the determined role access level.
determined_role
}

/// Main function to simulate a real-world scenario and demonstrate the access control function.
Expand Down
Loading

0 comments on commit 039e3bf

Please sign in to comment.