Skip to content

Commit

Permalink
Make some types explicit for columns. (powdr-labs#2014)
Browse files Browse the repository at this point in the history
Add some explicit types or explicit conversions to prepare for the
`ToCol` trait.
  • Loading branch information
chriseth authored Nov 1, 2024
1 parent 047f272 commit 8321a97
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 47 deletions.
74 changes: 44 additions & 30 deletions executor/src/constant_evaluator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ mod test {
let src = r#"
let N = 8;
namespace F(N);
col fixed LAST(i) { if i == N - 1 { 1 } else { 0 } };
col fixed LAST(i) { if i == N - 1_int { 1_fe } else { 0_fe } };
"#;
let analyzed = analyze_string(src);
assert_eq!(analyzed.degree(), 8);
Expand All @@ -127,8 +127,10 @@ mod test {
fn counter() {
let src = r#"
let N: int = 8;
namespace std::convert;
let fe = [];
namespace F(N);
pol constant EVEN(i) { 2 * (i - 1) + 4 };
pol constant EVEN(i) { std::convert::fe(2 * (i - 1) + 4_int) };
"#;
let analyzed = analyze_string(src);
assert_eq!(analyzed.degree(), 8);
Expand All @@ -146,8 +148,10 @@ mod test {
fn xor() {
let src = r#"
let N: int = 8;
namespace std::convert;
let fe = [];
namespace F(N);
pol constant X(i) { i ^ (i + 17) | 3 };
pol constant X(i) { std::convert::fe(i ^ (i + 17) | 3_int) };
"#;
let analyzed = analyze_string(src);
assert_eq!(analyzed.degree(), 8);
Expand All @@ -165,13 +169,16 @@ mod test {
fn match_expression() {
let src = r#"
let N: int = 8;
namespace std::convert;
let fe = [];
namespace F(N);
pol constant X(i) { match i {
let x: int -> fe = |i| std::convert::fe(match i {
0 => 7,
3 => 9,
5 => 2,
_ => 4,
} + 1 };
} + 1_int);
pol constant X(i) { x(i) };
"#;
let analyzed = analyze_string(src);
assert_eq!(analyzed.degree(), 8);
Expand All @@ -187,7 +194,7 @@ mod test {
let src = r#"
let N: int = 8;
namespace F(N);
let X: col = |i| if i < 3 { 7 } else { 9 };
let X: col = |i| if i < 3_int { 7_fe } else { 9 };
"#;
let analyzed = analyze_string(src);
assert_eq!(analyzed.degree(), 8);
Expand All @@ -202,9 +209,11 @@ mod test {
fn macro_directive() {
let src = r#"
let N: int = 8;
namespace std::convert;
let fe = [];
namespace F(N);
let minus_one: int -> int = |x| x - 1;
pol constant EVEN(i) { 2 * minus_one(i) + 2 };
pol constant EVEN(i) { std::convert::fe(2 * minus_one(i) + 2) };
"#;
let analyzed = analyze_string(src);
assert_eq!(analyzed.degree(), 8);
Expand All @@ -224,13 +233,14 @@ mod test {
let N = 16;
namespace std::convert(N);
let int = [];
let fe = [];
namespace F(N);
let seq_f = |i| i;
col fixed seq(i) { i };
col fixed double_plus_one(i) { std::convert::int(seq_f((2 * i) % N)) + 1 };
let seq_f: int -> int = |i| i;
col fixed seq(i) { std::convert::fe(seq_f(i)) };
col fixed double_plus_one(i) { std::convert::fe(std::convert::int(seq_f((2 * i) % N)) + 1) };
let half_nibble_f = |i| i & 0x7;
col fixed half_nibble(i) { half_nibble_f(i) };
col fixed doubled_half_nibble(i) { half_nibble_f(i / 2) };
col fixed half_nibble(i) { std::convert::fe(half_nibble_f(i)) };
col fixed doubled_half_nibble(i) { std::convert::fe(half_nibble_f(i / 2)) };
"#;
let analyzed = analyze_string(src);
assert_eq!(analyzed.degree(), 16);
Expand Down Expand Up @@ -333,15 +343,15 @@ mod test {
let inv = |i| N - i;
let a: int -> int = |i| [0, 1, 0, 1, 2, 1, 1, 1][i];
let b: int -> int = |i| [0, 0, 1, 1, 0, 5, 5, 5][i];
col fixed or(i) { if (a(i) != 0) || (b(i) != 0) { 1 } else { 0 } };
col fixed and(i) { if (a(i) != 0) && (b(i) != 0) { 1 } else { 0 } };
col fixed not(i) { if !(a(i) != 0) { 1 } else { 0 } };
col fixed less(i) { if id(i) < inv(i) { 1 } else { 0 } };
col fixed less_eq(i) { if id(i) <= inv(i) { 1 } else { 0 } };
col fixed eq(i) { if id(i) == inv(i) { 1 } else { 0 } };
col fixed not_eq(i) { if id(i) != inv(i) { 1 } else { 0 } };
col fixed greater(i) { if id(i) > inv(i) { 1 } else { 0 } };
col fixed greater_eq(i) { if id(i) >= inv(i) { 1 } else { 0 } };
col fixed or(i) { if (a(i) != 0) || (b(i) != 0) { 1_fe } else { 0_fe } };
col fixed and(i) { if (a(i) != 0) && (b(i) != 0) { 1_fe } else { 0_fe } };
col fixed not(i) { if !(a(i) != 0) { 1_fe } else { 0_fe } };
col fixed less(i) { if id(i) < inv(i) { 1_fe } else { 0_fe } };
col fixed less_eq(i) { if id(i) <= inv(i) { 1_fe } else { 0_fe } };
col fixed eq(i) { if id(i) == inv(i) { 1_fe } else { 0_fe } };
col fixed not_eq(i) { if id(i) != inv(i) { 1_fe } else { 0_fe } };
col fixed greater(i) { if id(i) > inv(i) { 1_fe } else { 0_fe } };
col fixed greater_eq(i) { if id(i) >= inv(i) { 1_fe } else { 0_fe } };
"#;
let analyzed = analyze_string(src);
assert_eq!(analyzed.degree(), 8);
Expand Down Expand Up @@ -444,7 +454,7 @@ mod test {
let src = r#"
let N: int = 10;
namespace F(N);
let x: col = |i| y(i) + 1;
let x: col = |i| { let t = y(i) + 1; 1_fe };
col fixed y = [1, 2, 3]*;
"#;
let analyzed = analyze_string(src);
Expand All @@ -456,9 +466,12 @@ mod test {
fn forward_reference_to_function() {
let src = r#"
let N: int = 4;
namespace std::convert(N);
let int = [];
let fe = [];
namespace F(N);
let x = |i| y(i) + 1;
let y = |i| i + 20;
let x: int -> fe = |i| std::convert::fe(y(i) + 1);
let y: int -> fe = |i| std::convert::fe(i + 20);
let X: col = x;
let Y: col = y;
"#;
Expand All @@ -483,7 +496,7 @@ mod test {
let int = [];
let fe = [];
namespace F(N);
let x: col = |i| (1 << (2000 + i)) >> 2000;
let x: col = |i| std::convert::fe((1 << (2000 + i)) >> 2000);
"#;
let analyzed = analyze_string(src);
assert_eq!(analyzed.degree(), 4);
Expand All @@ -503,7 +516,7 @@ mod test {
let fe = [];
namespace F(N);
let x_arr = [ 3 % 4, (-3) % 4, 3 % (-4), (-3) % (-4)];
let x: col = |i| 100 + x_arr[i];
let x: col = |i| std::convert::fe(100 + x_arr[i]);
"#;
let analyzed = analyze_string(src);
assert_eq!(analyzed.degree(), 4);
Expand Down Expand Up @@ -547,7 +560,7 @@ mod test {
let fe = || fe();
namespace F(4);
let<T: FromLiteral> seven: T = 7;
let a: col = |i| std::convert::fe(i + seven) + seven;
let a: col = |i| std::convert::fe(i + seven + 0_int) + seven;
"#;
let analyzed = analyze_string(src);
assert_eq!(analyzed.degree(), 4);
Expand All @@ -562,19 +575,20 @@ mod test {
fn do_not_add_constraint_for_empty_tuple() {
let input = r#"namespace N(4);
let f: -> () = || ();
let g: col = |i| {
let r: int -> fe = |i| {
// This returns an empty tuple, we check that this does not lead to
// a call to add_proof_items()
f();
i
7_fe
};
let g: col = r;
"#;
let analyzed = analyze_string(input);
assert_eq!(analyzed.degree(), 4);
let constants = generate(&analyzed);
assert_eq!(
constants[0],
("N::g".to_string(), convert([0, 1, 2, 3].to_vec()))
("N::g".to_string(), convert([7, 7, 7, 7].to_vec()))
);
}
}
12 changes: 8 additions & 4 deletions executor/src/witgen/global_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,12 @@ mod test {
#[test]
fn constraints_propagation() {
let pil_source = r"
namespace std::convert;
let fe = [];
namespace Global(2**20);
col fixed BYTE(i) { i & 0xff };
col fixed BYTE2(i) { i & 0xffff };
col fixed SHIFTED(i) { i & 0xff0 };
col fixed BYTE(i) { std::convert::fe(i & 0xff) };
col fixed BYTE2(i) { std::convert::fe(i & 0xffff) };
col fixed SHIFTED(i) { std::convert::fe(i & 0xff0) };
col witness A;
// A bit more complicated to see that the 'pattern matcher' works properly.
(1 - A + 0) * (A + 1 - 1) = 0;
Expand Down Expand Up @@ -508,8 +510,10 @@ namespace Global(2**20);
// incorrectly determined it to be a pure range constraint, but it would actually not
// be able to derive the full constraint.
let pil_source = r"
namespace std::convert;
let fe = [];
namespace Global(1024);
let bytes: col = |i| i % 256;
let bytes: col = |i| std::convert::fe(i % 256);
let X;
[ X * 4 ] in [ bytes ];
";
Expand Down
4 changes: 1 addition & 3 deletions jit-compiler/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,7 @@ impl<'a, T: FieldElement> CodeGenerator<'a, T> {
}
})
}
_ => Err(format!(
"Definition of this kind not supported: {symbol} - {definition}"
)),
_ => Err(format!("Definition of this kind not supported: {symbol}")),
}
}

Expand Down
17 changes: 11 additions & 6 deletions pil-analyzer/tests/condenser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,24 +238,29 @@ fn double_next() {

#[test]
fn new_fixed_column() {
let input = r#"namespace N(16);
let input = r#"
namespace std::convert;
let fe = 9;
namespace N(16);
let f = constr || {
let even: col = |i| i * 2;
let even: col = |i| std::convert::fe(i * 2_int);
even
};
let ev = f();
let x;
x = ev;
"#;
let formatted = analyze_string(input).to_string();
let expected = r#"namespace N(16);
let expected = r#"namespace std::convert;
let fe = 9;
namespace N(16);
let f: -> expr = constr || {
let even: col = |i| i * 2_int;
let even: col = |i| std::convert::fe::<int>(i * 2_int);
even
};
let ev: expr = N::f();
col witness x;
col fixed even(i) { i * 2_int };
col fixed even(i) { std::convert::fe::<int>(i * 2_int) };
N::x = N::even;
"#;
assert_eq!(formatted, expected);
Expand Down Expand Up @@ -1055,7 +1060,7 @@ pub fn at_next_stage_intermediate_and_fixed() {
std::prover::at_next_stage(constr || {
let b: inter = a * a;
let c;
let first: col = |i| if i == 0 { 1 } else { 0 };
let first: col = |i| if i == 0_int { 1_fe } else { 0 };
let d: inter = a + c;
c' = first;
});
Expand Down
4 changes: 2 additions & 2 deletions pil-analyzer/tests/side_effects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ fn new_wit_in_pure() {
}

#[test]
#[should_panic = "Tried to create a fixed column in a pure context: let x: col = |i| i;"]
#[should_panic = "Tried to create a fixed column in a pure context: let x: col = |i| i + 0_int;"]
fn new_fixed_in_pure() {
let input = r#"namespace N(16);
let new_col = || { let x: col = |i| i; x };
let new_col = || { let x: col = |i| i + 0_int; x };
"#;
analyze_string(input);
}
Expand Down
2 changes: 1 addition & 1 deletion std/machines/small_field/binary.asm
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ machine Binary8(byte_binary: ByteBinary) with

let operation_id;

let latch: col = |i| { 1 };
let latch: col = std::well_known::one;

let A1;
let A2;
Expand Down
5 changes: 4 additions & 1 deletion std/well_known.asm
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
/// Evaluates to 1 on the first row and 0 on all other rows.
/// Useful to define a fixed column of that property.
let is_first: int -> int = |i| if i == 0 { 1 } else { 0 };
let is_first: int -> int = |i| if i == 0 { 1 } else { 0 };

/// The constant one.
let one: int -> int = |i| 1;

0 comments on commit 8321a97

Please sign in to comment.