Skip to content

Commit

Permalink
Improvements to Arith16 machine (powdr-labs#1873)
Browse files Browse the repository at this point in the history
Applied the feedback that I gave in my review for powdr-labs#1790. See in-line
comments.
  • Loading branch information
georgwiese authored Oct 8, 2024
1 parent e9fc9d5 commit 01f90f1
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 62 deletions.
95 changes: 35 additions & 60 deletions std/machines/arith16.asm
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,43 @@ use std::convert::expr;
use std::prover::eval;
use std::prelude::Query;
use std::machines::range::Byte;
use std::machines::range::Byte2;

// Arithmetic machine, ported mainly from Polygon: https://github.com/0xPolygonHermez/zkevm-proverjs/blob/main/pil/arith.pil
// This machine supports eq0, which is the affine equation. Currently we only expose operations for mul and div.
machine Arith16(byte: Byte) with
// Arithmetic machine, inspired by Polygon's 256-Bit Arith machine: https://github.com/0xPolygonHermez/zkevm-proverjs/blob/main/pil/arith.pil
machine Arith16(byte: Byte, byte2: Byte2) with
latch: CLK8_7,
operation_id: operation_id,
operation_id: is_division,
// Allow this machine to be connected via a permutation
call_selectors: sel,
{
col witness operation_id;

// operation_id has to be either mul or div.
force_bool(operation_id);
col witness is_division;

// Computes x1 * y1 + x2, where all inputs / outputs are 32-bit words (represented as 16-bit limbs in big-endian order).
// More precisely, affine_256(x1, y1, x2) = (y2, y3), where x1 * y1 + x2 = 2**16 * y2 + y3

// x1 * y1 = y2 * 2**16 + y3
operation mul<0> x1c[1], x1c[0], y1c[1], y1c[0] -> y2c[1], y2c[0], y3c[1], y3c[0];

// Constrain that x2 = 0 when operation is mul.
array::new(4, |i| (1 - operation_id) * x2[i] = 0);
// More precisely, affine(x1, y1, x2) = (y2, y3), where x1 * y1 + x2 = 2**16 * y2 + y3
operation mul<0> x1c[1], x1c[0], x2c[1], x2c[0], y1c[1], y1c[0] -> y2c[1], y2c[0], y3c[1], y3c[0];

// y3 / x1 = y1 (remainder x2)
// WARNING: it's not constrained that remainder is less than the divisor.
// This is done in the main machine, e.g. our RISCV BabyBear machine, that uses this operation.
// WARNING: For division by zero, the quotient is unconstrained.
// Both need to be handled by any machine calling into this one.
operation div<1> y3c[1], y3c[0], x1c[1], x1c[0] -> y1c[1], y1c[0], x2c[1], x2c[0];

// Constrain that y2 = 0 when operation is div.
array::new(4, |i| operation_id * y2[i] = 0);
array::new(4, |i| is_division * y2[i] = 0);

// We need to provide hints for the quotient and remainder, because they are not unique under our current constraints.
// They are unique given additional main machine constraints, but it's still good to provide hints for the solver.
let quotient_hint = query |limb| match(eval(operation_id)) {
let quotient_hint = query |limb| match(eval(is_division)) {
1 => {
let y3 = y3_int();
let x1 = x1_int();
let quotient = y3 / x1;
Query::Hint(fe(select_limb(quotient, limb)))
if x1_int() == 0 {
// Quotient is unconstrained, use zero.
Query::Hint(0)
} else {
let y3 = y3_int();
let x1 = x1_int();
let quotient = y3 / x1;
Query::Hint(fe(select_limb(quotient, limb)))
}
},
_ => Query::None
};
Expand All @@ -60,12 +58,17 @@ machine Arith16(byte: Byte) with

let y1: expr[] = [y1_0, y1_1, y1_2, y1_3];

let remainder_hint = query |limb| match(eval(operation_id)) {
let remainder_hint = query |limb| match(eval(is_division)) {
1 => {
let y3 = y3_int();
let x1 = x1_int();
let remainder = y3 % x1;
Query::Hint(fe(select_limb(remainder, limb)))
if x1 == 0 {
// To satisfy x1 * y1 + x2 = y3, we need to set x2 = y3.
Query::Hint(fe(select_limb(y3, limb)))
} else {
let remainder = y3 % x1;
Query::Hint(fe(select_limb(remainder, limb)))
}
},
_ => Query::None
};
Expand Down Expand Up @@ -106,11 +109,7 @@ machine Arith16(byte: Byte) with
let CLK8: col[8] = array::new(8, |i| |row| if row % 8 == i { 1 } else { 0 });
let CLK8_7: expr = CLK8[7];

/****
*
* LATCH POLS: x1,y1,x2,y2,y3
*
*****/
// All inputs & outputs are kept constant within a block.

let fixed_inside_8_block = |e| unchanged_until(e, CLK8[7]);

Expand All @@ -120,22 +119,13 @@ machine Arith16(byte: Byte) with
array::map(y2, fixed_inside_8_block);
array::map(y3, fixed_inside_8_block);

/****
*
* RANGE CHECK x1,y1,x2,y2,y3
*
*****/
// All input & output limbs are range-constrained to be bytes.

link => byte.check(sum(4, |i| x1[i] * CLK8[i]) + sum(4, |i| y1[i] * CLK8[4 + i]));
link => byte.check(sum(4, |i| x2[i] * CLK8[i]) + sum(4, |i| y2[i] * CLK8[4 + i]));
link => byte.check(sum(4, |i| y3[i] * CLK8[i]));

/*******
*
* EQ0: A(x1) * B(y1) + C(x2) = D (y2) * 2 ** 16 + op (y3)
* x1 * y1 + x2 - y2 * 2**256 - y3 = 0
*
*******/
// Constrain x1 * y1 + x2 - y2 * 2**16 - y3 = 0

/// returns a(0) * b(0) + ... + a(n - 1) * b(n - 1)
let dot_prod = |n, a, b| sum(n, |i| a(i) * b(i));
Expand Down Expand Up @@ -163,27 +153,12 @@ machine Arith16(byte: Byte) with
- shift_right(y2f, 4)(nr)
- y3f(nr);

/*******
*
* Carry
*
*******/

pol witness carry_low, carry_high;
link => byte.check(carry_low);
link => byte.check(carry_high);

let carry = carry_high * 2**8 + carry_low;

// Carry: Constrained to be 16-Bit and zero in the first row of the block.
col witness carry;
link => byte2.check(carry);
carry * CLK8[0] = 0;

/*******
*
* Putting everything together
*
*******/

// Putting everything together
col eq0_sum = sum(8, |i| eq0(i) * CLK8[i]);

eq0_sum + carry = carry' * 2**8;
}
12 changes: 10 additions & 2 deletions test_data/std/arith16_test.asm
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::machines::arith16::Arith16;
use std::machines::range::Byte;
use std::machines::range::Byte2;

machine Main with degree: 65536 {
reg pc[@pc];
Expand All @@ -18,11 +19,12 @@ machine Main with degree: 65536 {
reg t_1_1;

Byte byte;
Byte2 byte2;

Arith16 arith(byte);
Arith16 arith(byte, byte2);

instr mul A0, A1, B0, B1 -> C0, C1, D0, D1
link ~> (C0, C1, D0, D1) = arith.mul(A0, A1, B0, B1);
link ~> (C0, C1, D0, D1) = arith.mul(A0, A1, 0, 0, B0, B1);

instr div A0, A1, B0, B1 -> C0, C1, D0, D1
link ~> (C0, C1, D0, D1) = arith.div(A0, A1, B0, B1);
Expand Down Expand Up @@ -58,5 +60,11 @@ machine Main with degree: 65536 {
// 0xffffeff / 0xfffff = 0xff (remainder 0xffffe)
t_0_0, t_0_1, t_1_0, t_1_1 <== div(0xfff, 0xfeff, 0xf, 0xffff);
assert_eq t_0_0, t_0_1, t_1_0, t_1_1, 0, 0xff, 0xf, 0xfffe;

// 0xabcdef01 / 0 = 0 (remainder 0xabcdef01)
// (note that the quotient is unconstrained though)
t_0_0, t_0_1, t_1_0, t_1_1 <== div(0xabcd, 0xef01, 0, 0);
assert_eq t_0_0, t_0_1, t_1_0, t_1_1, 0, 0, 0xabcd, 0xef01;

}
}

0 comments on commit 01f90f1

Please sign in to comment.