From b23f5387c5d9f2a2ec212a6c07b950b61417770a Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Fri, 13 Sep 2024 16:51:41 +0200 Subject: [PATCH 01/17] WIP arith --- pil/fork_0/pil/operations.pil | 56 ++++ pil/fork_0/pil/zisk.pil | 14 +- state-machines/arith/pil/arith.pil | 259 ++++++++++++++++++ state-machines/arith/pil/arith_32.pil | 160 +++++++++++ state-machines/arith/pil/arith_3264.pil | 0 state-machines/arith/pil/arith_64.pil | 0 state-machines/arith/pil/arith_mul_32.pil | 90 ++++++ state-machines/arith/pil/arith_mul_64.pil | 177 ++++++++++++ .../arith/pil/arith_range_table.pil | 12 + state-machines/arith/pil/arith_table.pil | 100 +++++++ 10 files changed, 865 insertions(+), 3 deletions(-) create mode 100644 pil/fork_0/pil/operations.pil delete mode 100644 state-machines/arith/pil/arith_3264.pil delete mode 100644 state-machines/arith/pil/arith_64.pil create mode 100644 state-machines/arith/pil/arith_mul_32.pil create mode 100644 state-machines/arith/pil/arith_mul_64.pil create mode 100644 state-machines/arith/pil/arith_range_table.pil create mode 100644 state-machines/arith/pil/arith_table.pil diff --git a/pil/fork_0/pil/operations.pil b/pil/fork_0/pil/operations.pil new file mode 100644 index 00000000..aa6f5d09 --- /dev/null +++ b/pil/fork_0/pil/operations.pil @@ -0,0 +1,56 @@ +const int OPERATION_BUS_ID = 90; + +const int OP_FLAG = 0x00; +const int OP_COPYB = 0x01; +const int OP_SIGNEXTEND_B = 0x02; +const int OP_SIGNEXTEND_H = 0x03; +const int OP_SIGNEXTEND_W = 0x04; +const int OP_ADD = 0x10; +const int OP_ADD_W = 0x14; +const int OP_SUB = 0x20; +const int OP_SUB_W = 0x24; +const int OP_SLL = 0x30; +const int OP_SLL_W = 0x34; +const int OP_SRA = 0x40; +const int OP_SRL = 0x41; +const int OP_SRA_W = 0x44; +const int OP_SRL_W = 0x45; +const int OP_EQ = 0x50; +const int OP_EQ_W = 0x54; +const int OP_LTU = 0x60; +const int OP_LT = 0x61; +const int OP_LTU_W = 0x64; +const int OP_LT_W = 0x65; +const int OP_LEU = 0x70; +const int OP_LE = 0x71; +const int OP_LEU_W = 0x74; +const int OP_LE_W = 0x75; +const int OP_AND = 0x80; +const int OP_OR = 0x90; +const int OP_XOR = 0xA0; +const int OP_MULU = 0xB0; +const int OP_MUL = 0xB1; +const int OP_MUL_W = 0xB5; +const int OP_MULUH = 0xB8; +const int OP_MULH = 0xB9; +const int OP_MULSUH = 0xBB; +const int OP_DIVU = 0xC0; +const int OP_DIV = 0xC1; +const int OP_DIVU_W = 0xC4; +const int OP_DIV_W = 0xC5; +const int OP_REMU = 0xC8; +const int OP_REM = 0xC9; +const int OP_REMU_W = 0xCC; +const int OP_REM_W = 0xCD; +const int OP_MINU = 0xD0; +const int OP_MIN = 0xD1; +const int OP_MINU_W = 0xD4; +const int OP_MIN_W = 0xD5; +const int OP_MAXU = 0xE0; +const int OP_MAX = 0xE1; +const int OP_MAXU_W = 0xE4; +const int OP_MAX_W = 0xE5; + +function verify_operation(expr opid, expr sel, expr a[], expr b[], expr c[], expr flag) { + lookup_assume(OPERATION_BUS_ID, cols:[operation, ...a, ...b, ...c, flag], sel:sel); +} diff --git a/pil/fork_0/pil/zisk.pil b/pil/fork_0/pil/zisk.pil index 70a48515..b5ce6cde 100644 --- a/pil/fork_0/pil/zisk.pil +++ b/pil/fork_0/pil/zisk.pil @@ -5,14 +5,17 @@ require "binary/pil/binary.pil" require "binary/pil/binary_table.pil" require "binary/pil/binary_extension.pil" require "binary/pil/binary_extension_table.pil" +require "arith/pil/arith.pil" const int OPERATION_BUS_ID = 5000; +const int DEFAULT_N = 2**21; + airgroup Main { - Main(N: 2**21, RC: 2, operation_bus_id: OPERATION_BUS_ID); + Main(N: DEFAULT_N, RC: 2, operation_bus_id: OPERATION_BUS_ID); } airgroup Binary { - Binary(N: 2**21, operation_bus_id: OPERATION_BUS_ID); + Binary(N: DEFAULT_N, operation_bus_id: OPERATION_BUS_ID); } @@ -22,9 +25,14 @@ airgroup BinaryTable { airgroup BinaryExtension { - BinaryExtension(N: 2**21, operation_bus_id: OPERATION_BUS_ID); + BinaryExtension(N: DEFAULT_N, operation_bus_id: OPERATION_BUS_ID); } airgroup BinaryExtensionTable { BinaryExtensionTable(disable_fixed: 1); } + +airgroup Arith { + Arith(N: DEFAULT_N, operation_bus_id: OPERATION_BUS_ID); +} + diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index e69de29b..60191088 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -0,0 +1,259 @@ +require "std_lookup.pil" +require "std_range_check.pil" +require "operations.pil" +// require "arith_table.pil" + +// generic 64 u64 mul_u64 32 *u32 +// witness 45 41 30 26 27 13 +// lookups 3 3 3 2 3 3 +// range_checks 16+7 16+7 16+7 16+7 8+3 7+2 +// ---------------------------------------------------------- +// TOTAL 123 119 108 101 69 61 +// +// (*) unsigned 32 bit operations only divu_w, remu_w + +airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_result = 0) { + + // NOTE: + // Divisions and remainders by 0 are done by QuickOps + + const int CHUNK_SIZE = 2**16; + const int CHUNKS_INPUT = 4; + const int CHUNKS_OP = CHUNKS_INPUT * 2; + + col witness carry[CHUNKS_OP - 1]; + col witness a[CHUNKS_INPUT]; + col witness b[CHUNKS_INPUT]; + col witness c[CHUNKS_INPUT]; + col witness d[CHUNKS_INPUT]; + + col witness na; // a is negative + col witness nb; // b is negative + col witness nr; // rem is negative + col witness np; // prod is negative + col witness na32; // a is 32-bit negative, 31th bit is 1. + col witness nd32; // d is 32-bit negative, 31th bit is 1. + + col witness m32; // 32 bits operation + col witness div; // division operation (div,rem) + + col witness fab; // fab, to decrease degree of intermediate products a * b + // fab = 1 if sign of a,b are the same + // fab = -1 if sign of a,b are different + + if (!dual_result) { + col witness air.secondary_res; // op_index: 0 => first result, 1 => second result; + secondary_res * (secondary_res - 1) === 0; + } else { + const expr air.secondary_res = 0; + } + + // factor ab € {-1, 1} + fab === 1 - 2 * na - 2 * nb + 4 * na * nb; + + const expr eq[CHUNKS_OP]; + + eq[0] = fab * a[0] * b[0] + - c[0] + + 2 * np * c[0] + + div * d[0] + - 2 * nr * d[0]; + + eq[1] = fab * a[1] * b[0] + + fab * a[0] * b[1] + - c[1] + + 2 * np * c[1] + + div * d[1] + - 2 * nr * d[1]; + + eq[2] = fab * a[2] * b[0] + + fab * a[1] * b[1] + + fab * a[0] * b[2] + - c[2] + + 2 * np * c[2] + + div * d[2] + - 2 * nr * d[2] + - np * div * m32 + + nr * m32; + + eq[3] = fab * a[3] * b[0] + + fab * a[2] * b[1] + + fab * a[1] * b[2] + + fab * a[0] * b[3] + - c[3] + + 2 * np * c[3] + + div * d[3] + - 2 * nr * d[3]; + + eq[4] = fab * a[3] * b[1] + + fab * a[2] * b[2] + + fab * a[1] * b[3] + + na * b[0] * (1 - 2 * nb) + + nb * a[0] * (1 - 2 * na) + - np * div // \ + + np * m32 // np * (div ^ m32) + - 2 * div * m32 * np // / + + nr * (1 - m32) + - d[0] * (1 - div) + + 2 * np * d[0] * (1 - div); + + eq[5] = fab * a[3] * b[2] + + fab * a[2] * b[3] + + nb * a[1] * (1 - 2 * na) + + na * b[1] * (1 - 2 * nb) + - d[1] * (1 - div) + + 2 * np * d[1] * (1 - div); + + eq[6] = fab * a[3] * b[3] + + nb * a[2] * (1 - 2 * na) + + na * b[2] * (1 - 2 * nb) + - d[2] * (1 - div) + + 2 * np * d[2] * (1 - div); + + eq[7] = CHUNK_SIZE * na * nb + + na * b[3] * (1 - 2 * nb) + + nb * a[3] * (1 - 2 * na) + - CHUNK_SIZE * np * (1 - div) * (1 - m32) + - d[3] * (1 - div) + + 2 * np * d[3] * (1 - div); + + eq[0] - carry[0] * CHUNK_SIZE === 0; + for (int index = 1; index < (CHUNKS_OP - 1); ++index) { + eq[index] + carry[index-1] - carry[index] * CHUNK_SIZE === 0; + } + eq[CHUNKS_OP-1] + carry[CHUNKS_OP-2] === 0; + + // binary contraint + div * (1 - div) === 0; + m32 * (1 - m32) === 0; + na * (1 - na) === 0; + nb * (1 - nb) === 0; + nr * (1 - nr) === 0; + np * (1 - np) === 0; + na32 * (1 - na32) === 0; + nd32 * (1 - nd32) === 0; + + col witness op; + + // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 + // ---------------------------------------------------------------------------------- + // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 + // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 =0 d3 =0 =0 a3, d3 + // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 + // 0 1 1 1 x mul_w *n/a* (0xb6,0xb7) a1 b1 =0 d3 c1 =0 d3, a1,b1,c1 + // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 + // 1 0 1 1 div rem (0xba,0xbb) a3 b3 d3 c3 =0 =0 a3,b3,c3,d3 + // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 c1 d1 c1,d1 + // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 d1 c1 c1 d1 a1,b1,c1,d1 + + // (*) removed combinations of flags div,m32,sa,sb did allow combinations div, m32, sa, sb + // see 5 previous constraints. + // =0 means forced to zero by previous constraints + // comm = commutative (trivial: commutative operations) + + col witness bus_a_low; + bus_a_low === div * (c[0] - a[0]) + + a[0] + + CHUNK_SIZE * div * (c[1] - a[1]) + + CHUNK_SIZE * a[1]; + + col witness bus_a_high; + bus_a_high === (1 - m32) * (div * (c[2] - a[2]) + + a[2] + + CHUNK_SIZE * div * (c[3] - a[3]) + + CHUNK_SIZE * a[3]); + + + const expr bus_b_low = b[0] + CHUNK_SIZE * b[1]; + + // TODO: na32 and nd32 only valid on 32 bit operations + // TODO: m32 === 0 ==> b[2],a[2],b[3],a[3] === 0 avoid two witness + col witness bus_b_high; + bus_b_high === (1 - m32) * b[2] + (1 - m32) * CHUNK_SIZE * b[3]; + + const expr res2_low = d[0] + CHUNK_SIZE * d[1]; + const expr res2_high = d[2] + CHUNK_SIZE * d[3] + nd32 * 0xFFFFFFFF; + + if (dual_result) { + // theorical cost: 4 columns + col witness multiplicity_2; + lookup_proves(operation_bus_id, [op+1, bus_a_low, bus_a_high, bus_b_low, bus_b_high, res2_low, res2_high, 0], mul: multiplicity_2); + } + + if (dual_result) { + const expr air.res1_low = a[0] + c[0] + CHUNK_SIZE * a[1] + CHUNK_SIZE * c[1] - bus_a_low; + col witness air.res1_high; + res1_high === (1 - m32) * (div * (a[2] - c[2]) + c[2] + CHUNK_SIZE * div * (a[3] - c[3]) + CHUNK_SIZE * c[3]) + div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF; + } else { + col witness air.res1_low; + res1_low === secondary_res * res2_low - (1 - secondary_res) * (a[0] + c[0] + CHUNK_SIZE * a[1] + CHUNK_SIZE * c[1] - bus_a_low); + + col witness air.div64; + div64 === (1 - m32) * div; + + col witness air.res1_high; + // res1_high === secondary_res * res2_high + (1 - secondary_res) * ((1 - m32) * (div * (a[2] - c[2]) + c[2] + 2**16 * div * (a[3] - c[3]) + 2**16 * c[3]) + div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF); + res1_high === secondary_res * res2_high + (1 - secondary_res) * (div64 * (a[2] - c[2]) + (1 - m32) * c[2] + CHUNK_SIZE * div64 * (a[3] - c[3]) + (1 - m32) * 2**16 * c[3] + div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF); + } + + + col witness multiplicity; + + lookup_proves(operation_bus_id, [op + secondary_res, + bus_a_low, bus_a_high, + bus_b_low, bus_b_high, + res1_low, res1_high, +// secondary_res * (res2_low - res1_low) + res1_low, +// secondary_res * (res2_high - res1_high) + res1_high, + 0], mul: multiplicity); + + + // TODO: review + lookup_assumes(operation_bus_id, [OP_LT, res2_low, res2_high, bus_b_low, bus_b_high, 0, 1, 1], sel: div); + + for (int index = 0; index < length(carry); ++index) { + range_check(colu: carry[index], min:-2**20, max: 2**20-1); // TODO: review carry range + } + + // loop for range checks index 0, 2 + for (int index = 0; index < 2; ++index) { + range_check(colu: a[2 * index], min:0, max: CHUNK_SIZE - 1); + range_check(colu: b[2 * index], min:0, max: CHUNK_SIZE - 1); + range_check(colu: c[2 * index], min:0, max: CHUNK_SIZE - 1); + range_check(colu: d[2 * index], min:0, max: CHUNK_SIZE - 1); + } + + col witness range_a1; + col witness range_b1; + col witness range_c1; + col witness range_d1; + + col witness range_a3; + col witness range_b3; + col witness range_c3; + col witness range_d3; + + // verify values of range_xy € {0,1,2} => these constraints not generate + // intermediate columns + range_a1 * (1 - range_a1) * (2 - range_a1) === 0; + range_b1 * (1 - range_b1) * (2 - range_b1) === 0; + range_c1 * (1 - range_c1) * (2 - range_c1) === 0; + range_d1 * (1 - range_d1) * (2 - range_d1) === 0; + range_a3 * (1 - range_a3) * (2 - range_a3) === 0; + range_b3 * (1 - range_b3) * (2 - range_b3) === 0; + range_c3 * (1 - range_c3) * (2 - range_c3) === 0; + range_d3 * (1 - range_d3) * (2 - range_d3) === 0; + + lookup_assumes(ARITH_TABLE_ID, cols: [ op, m32 + 2 * div + 4 * na + 8 * nb + 16 * nr + 32 * np + 64 * na32 + 128 * nd32 + + 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1 + + 2**16 * range_a3 + 2**18 * range_b3 + 2**20 * range_c3 + 2**22 * range_d3]); + + lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_a1, a[1]]); + lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_b1, b[1]]); + lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_c1, c[1]]); + lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_d1, d[1]]); + lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_a3, a[3]]); + lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_b3, b[3]]); + lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_c3, c[3]]); + lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_d3, d[3]]); +} \ No newline at end of file diff --git a/state-machines/arith/pil/arith_32.pil b/state-machines/arith/pil/arith_32.pil index e69de29b..d012e38d 100644 --- a/state-machines/arith/pil/arith_32.pil +++ b/state-machines/arith/pil/arith_32.pil @@ -0,0 +1,160 @@ +require "std_lookup.pil" +require "std_range_check.pil" +require "operations.pil" +require "arith_table.pil" + +airtemplate Arith32(int N = 2**10, const int dual_result = 0) { + + // NOTE: + // Divisions and remainders by 0 are done by QuickOps + + col witness carry[3]; + col witness a[2]; + col witness b[2]; + col witness c[2]; + col witness d[2]; + + col witness na; // a is negative + col witness nb; // b is negative + col witness nr; // rem is negative + col witness np; // prod is negative + col witness na32; // a is 32-bit negative, 31th bit is 1. + col witness nd32; // d is 32-bit negative, 31th bit is 1. + + col witness div; // division operation (div,rem) + + col witness fab; // fab, to decrease degree of intermediate products a * b + // fab = 1 if sign of a,b are the same + // fab = -1 if sign of a,b are different + + if (!dual_result) { + col witness air.secondary_res; // op_index: 0 => first result, 1 => second result; + secondary_res * (secondary_res - 1) === 0; + } else { + const expr air.secondary_res = 0; + } + + fab === 1 - 2 * na - 2 * nb + 4 * na * nb; + + const expr eq[8]; + + eq[0] = fab * a[0] * b[0] + - c[0] + + 2 * np * c[0] + + div * d[0] + - 2 * nr * d[0]; + + eq[1] = fab * a[1] * b[0] + + fab * a[0] * b[1] + - c[1] + + 2 * np * c[1] + + div * d[1] + - 2 * nr * d[1]; + + eq[2] = fab * a[1] * b[1] + - np * div + + nr; + + // TODO: review !!!!! + eq[3] = 2**16 * na * nb; + + eq[0] - carry[0] * 2**16 === 0; + eq[1] + carry[0] - carry[1] * 2**16 === 0; + eq[2] + carry[1] - carry[2] * 2**16 === 0; + eq[3] + carry[2] === 0; + + // binary contraint + div * (1 - div) === 0; + na * (1 - na) === 0; + nb * (1 - nb) === 0; + nr * (1 - nr) === 0; + np * (1 - np) === 0; + na32 * (1 - na32) === 0; + nd32 * (1 - nd32) === 0; + + col witness op; + + // div sa sb comm primary secondary opcodes na nb nr np na32 nd32 + // ------------------------------------------------------------------------------ + // 0 1 1 x mul_w *n/a* (0xb6,0xb7) a1 b1 0 d3 c1 0 d3, a1,b1,c1 + // 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 d1 c1 c1 d1 a1,b1,c1,d1 + + // (*) removed combinations of flags div,sa,sb did allow combinations div, sa, sb + // comm = commutative (trivial: commutative operations) + + col witness bus_a_low; + bus_a_low === div * (c[0] - a[0]) + + a[0] + + 2**16 * div * (c[1] - a[1]) + + 2**16 * a[1]; + + const expr bus_a_high = 0; + + + const expr bus_b_low = b[0] + 2**16 * b[1]; + + // TODO: na32 and nd32 only valid on 32 bit operations + // TODO: m32 === 0 ==> b[2],a[2],b[3],a[3] === 0 avoid two witness + const expr bus_b_high = 0; + + const expr res2_low = d[0] + 2**16 * d[1]; + const expr res2_high = nd32 * 0xFFFFFFFF; + + if (dual_result) { + // theorical cost: 4 columns + col witness multiplicity_2; + lookup_proves(OPERATION_BUS_ID, [op+1, bus_a_low, bus_a_high, bus_b_low, bus_b_high, res2_low, res2_high, 0], mul: multiplicity_2); + } + + if (dual_result) { + const expr air.res1_low = a[0] + c[0] + 2**16 * a[1] + 2**16 * c[1] - bus_a_low; + col witness air.res1_high; + res1_high === div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF; + } else { + col witness air.res1_low; + res1_low === secondary_res * res2_low - (1 - secondary_res) * (a[0] + c[0] + 2**16 * a[1] + 2**16 * c[1] - bus_a_low); + + col witness air.res1_high; + res1_high === secondary_res * res2_high + (1 - secondary_res) * (div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF); + } + + + col witness multiplicity; + + lookup_proves(OPERATION_BUS_ID, [op + secondary_res, + bus_a_low, bus_a_high, + bus_b_low, bus_b_high, + res1_low, res1_high, + 0], mul: multiplicity); + + + // TODO: review + lookup_assumes(OPERATION_BUS_ID, [OP_LT, res2_low, res2_high, bus_b_low, bus_b_high, 0, 1, 1], sel: div); + + for (int index = 0; index < length(carry); ++index) { + range_check(colu: carry[index], min:-2**20, max: 2**20-1); // TODO: review range + } + + range_check(colu: a[0], min:0, max: 2**16-1); + range_check(colu: b[0], min:0, max: 2**16-1); + range_check(colu: c[0], min:0, max: 2**16-1); + range_check(colu: d[0], min:0, max: 2**16-1); + + col witness range_a1; + col witness range_b1; + col witness range_c1; + col witness range_d1; + + lookup_assumes(ARITH_TABLE_ID, cols: [ op, 1 + 2 * div + 4 * na + 8 * nb + 16 * nr + 32 * np + 64 * na32 + 128 * nd32 + + 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1]); + + range_a1 * (1 - range_a1) * (2 - range_a1) === 0; + range_b1 * (1 - range_b1) * (2 - range_b1) === 0; + range_c1 * (1 - range_c1) * (2 - range_c1) === 0; + range_d1 * (1 - range_d1) * (2 - range_d1) === 0; + + lookup_assumes(QUICK_RANGE_TABLE_ID, [range_a1, a[1]]); + lookup_assumes(QUICK_RANGE_TABLE_ID, [range_b1, b[1]]); + lookup_assumes(QUICK_RANGE_TABLE_ID, [range_c1, c[1]]); + lookup_assumes(QUICK_RANGE_TABLE_ID, [range_d1, d[1]]); +} \ No newline at end of file diff --git a/state-machines/arith/pil/arith_3264.pil b/state-machines/arith/pil/arith_3264.pil deleted file mode 100644 index e69de29b..00000000 diff --git a/state-machines/arith/pil/arith_64.pil b/state-machines/arith/pil/arith_64.pil deleted file mode 100644 index e69de29b..00000000 diff --git a/state-machines/arith/pil/arith_mul_32.pil b/state-machines/arith/pil/arith_mul_32.pil new file mode 100644 index 00000000..91c2ec0c --- /dev/null +++ b/state-machines/arith/pil/arith_mul_32.pil @@ -0,0 +1,90 @@ +require "std_lookup.pil" +require "std_range_check.pil" +require "operations.pil" +require "arith_table.pil" + +airtemplate ArithMul32(int N = 2**10, const int operation_bus_id) { + + const int CHUNK_SIZE = 2**16; + const int CHUNKS_INPUT = 2; + const int CHUNKS_OP = CHUNKS_INPUT * 2; + + col witness carry[CHUNKS_OP - 1]; + col witness a[CHUNKS_INPUT]; + col witness b[CHUNKS_INPUT]; + col witness c[CHUNKS_INPUT]; + col witness d[CHUNKS_INPUT]; + + col witness na; // a is negative + col witness nb; // b is negative + col witness np; // prod is negative + col witness nd32; // d is 32-bit negative, 31th bit is 1. + + col witness fab; // fab, to decrease degree of intermediate products a * b + // fab = 1 if sign of a,b are the same + // fab = -1 if sign of a,b are different + // factor ab € {-1, 1} + fab === 1 - 2 * na - 2 * nb + 4 * na * nb; + + const expr eq[CHUNKS_OP]; + + eq[0] = fab * a[0] * b[0] + - c[0] + + 2 * np * c[0]; + + eq[1] = fab * a[1] * b[0] + + fab * a[0] * b[1] + - c[1] + + 2 * np * c[1]; + + eq[2] = fab * a[1] * b[1]; + + // TODO: review !!!!! + eq[3] = 2**16 * na * nb; + + eq[0] - carry[0] * CHUNK_SIZE === 0; + for (int index = 1; index < (CHUNKS_OP - 1); ++index) { + eq[index] + carry[index-1] - carry[index] * CHUNK_SIZE === 0; + } + eq[CHUNKS_OP-1] + carry[CHUNKS_OP-2] === 0; + + // binary contraint + na * (1 - na) === 0; + nb * (1 - nb) === 0; + np * (1 - np) === 0; + nd32 * (1 - nd32) === 0; + + np === na + nb - 2 * na * nb; + + const expr bus_a_low = a[0] + 2**16 * a[1]; + const expr bus_a_high = 0; + + const expr bus_b_low = b[0] + CHUNK_SIZE * b[1]; + const expr bus_b_high = 0; + + const expr res1_low = c[0] + CHUNK_SIZE * + CHUNK_SIZE * c[1]; + const expr res1_high = nd32 * 0xFFFFFFFF; + + col witness multiplicity; + + lookup_proves(operation_bus_id, [OP_MUL_W, + bus_a_low, bus_a_high, + bus_b_low, bus_b_high, + res1_low, res1_high, + 0], mul: multiplicity); + + + for (int index = 0; index < length(carry); ++index) { + range_check(colu: carry[index], min:-2**20, max: 2**20-1); // TODO: review range + } + + range_check(colu: a[0], min:0, max: CHUNK_SIZE-1); + range_check(colu: b[0], min:0, max: CHUNK_SIZE-1); + range_check(colu: c[0], min:0, max: CHUNK_SIZE-1); + range_check(colu: d[0], min:0, max: CHUNK_SIZE-1); + range_check(colu: c[1], min:0, max: CHUNK_SIZE-1); + + lookup_assumes(QUICK_RANGE_TABLE_ID, [1 + na, a[1]]); + lookup_assumes(QUICK_RANGE_TABLE_ID, [1 + nb, b[1]]); + lookup_assumes(QUICK_RANGE_TABLE_ID, [1 + np, d[1]]); +} \ No newline at end of file diff --git a/state-machines/arith/pil/arith_mul_64.pil b/state-machines/arith/pil/arith_mul_64.pil new file mode 100644 index 00000000..03ca6ec4 --- /dev/null +++ b/state-machines/arith/pil/arith_mul_64.pil @@ -0,0 +1,177 @@ +require "std_lookup.pil" +require "std_range_check.pil" +require "operations.pil" +require "arith_table.pil" + +airtemplate ArithMul64(int N = 2**18, const int operation_bus_id, const int dual_result = 0) { + + // NOTE: + // Divisions and remainders by 0 are done by QuickOps + + const int CHUNK_SIZE = 2**16; + const int CHUNKS = 8; + + col witness carry[CHUNKS - 1]; + col witness a[4]; + col witness b[4]; + col witness c[4]; + col witness d[4]; + + col witness na; // a is negative + col witness nb; // b is negative + col witness np; // prod is negative + + col witness fab; // fab, to decrease degree of intermediate products a * b + // fab = 1 if sign of a,b are the same + // fab = -1 if sign of a,b are different + + if (!dual_result) { + col witness air.secondary_res; // op_index: 0 => first result, 1 => second result; + secondary_res * (secondary_res - 1) === 0; + } else { + const expr air.secondary_res = 0; + } + + // factor ab € {-1, 1} + fab === 1 - 2 * na - 2 * nb + 4 * na * nb; + + const expr eq[CHUNKS]; + + eq[0] = fab * a[0] * b[0] + - c[0] + + 2 * np * c[0]; + + eq[1] = fab * a[1] * b[0] + + fab * a[0] * b[1] + - c[1] + + 2 * np * c[1]; + + eq[2] = fab * a[2] * b[0] + + fab * a[1] * b[1] + + fab * a[0] * b[2] + - c[2] + + 2 * np * c[2]; + + eq[3] = fab * a[3] * b[0] + + fab * a[2] * b[1] + + fab * a[1] * b[2] + + fab * a[0] * b[3] + - c[3] + + 2 * np * c[3]; + + eq[4] = fab * a[3] * b[1] + + fab * a[2] * b[2] + + fab * a[1] * b[3] + + na * b[0] * (1 - 2 * nb) + + nb * a[0] * (1 - 2 * na) + - d[0] + + 2 * np * d[0]; + + eq[5] = fab * a[3] * b[2] + + fab * a[2] * b[3] + + nb * a[1] * (1 - 2 * na) + + na * b[1] * (1 - 2 * nb) + - d[1] + + 2 * np * d[1]; + + eq[6] = fab * a[3] * b[3] + + nb * a[2] * (1 - 2 * na) + + na * b[2] * (1 - 2 * nb) + - d[2] + + 2 * np * d[2]; + + eq[7] = CHUNK_SIZE * na * nb + + na * b[3] * (1 - 2 * nb) + + nb * a[3] * (1 - 2 * na) + - CHUNK_SIZE * np + - d[3] + + 2 * np * d[3]; + + eq[0] - carry[0] * CHUNK_SIZE === 0; + for (int index = 1; index < (CHUNKS - 1); ++index) { + eq[index] + carry[index-1] - carry[index] * CHUNK_SIZE === 0; + } + + // binary contraint + na * (1 - na) === 0; + nb * (1 - nb) === 0; + np * (1 - np) === 0; + + col witness op; + + // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 + // ---------------------------------------------------------------------------------- + // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 + // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 =0 d3 =0 =0 a3, d3 + // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 + + // (*) removed combinations of flags div,m32,sa,sb did allow combinations div, m32, sa, sb + // see 5 previous constraints. + // =0 means forced to zero by previous constraints + // comm = commutative (trivial: commutative operations) + + const expr bus_a_low = a[0] + CHUNK_SIZE * a[1]; + const expr bus_a_high = a[2] + CHUNK_SIZE * a[3]; + + + const expr bus_b_low = b[0] + CHUNK_SIZE * b[1]; + const expr bus_b_high = b[2] + CHUNK_SIZE * b[3]; + + const expr res2_low = d[0] + CHUNK_SIZE * d[1]; + const expr res2_high = d[2] + CHUNK_SIZE * d[3]; + + if (dual_result) { + // theorical cost: 4 columns + col witness multiplicity_2; + lookup_proves(operation_bus_id, [op+1, bus_a_low, bus_a_high, bus_b_low, bus_b_high, res2_low, res2_high, 0], mul: multiplicity_2); + + const expr air.res1_low = a[0] + c[0] + CHUNK_SIZE * a[1] + CHUNK_SIZE * c[1] - bus_a_low; + const expr air.res1_high = c[2] + CHUNK_SIZE * c[3]; + } else { + col witness air.res1_low; + res1_low === secondary_res * res2_low - (1 - secondary_res) * (a[0] + c[0] + CHUNK_SIZE * a[1] + CHUNK_SIZE * c[1] - bus_a_low); + + col witness air.res1_high; + // res1_high === secondary_res * res2_high + (1 - secondary_res) * ((1 - m32) * (div * (a[2] - c[2]) + c[2] + 2**16 * div * (a[3] - c[3]) + 2**16 * c[3]) + div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF); + res1_high === secondary_res * res2_high + (1 - secondary_res) * (c[2] + CHUNK_SIZE * c[3]); + } + + + col witness multiplicity; + + lookup_proves(operation_bus_id, [op + secondary_res, + bus_a_low, bus_a_high, + bus_b_low, bus_b_high, + res1_low, res1_high, +// secondary_res * (res2_low - res1_low) + res1_low, +// secondary_res * (res2_high - res1_high) + res1_high, + 0], mul: multiplicity); + + for (int index = 0; index < length(carry); ++index) { + range_check(colu: carry[index], min:-2**20, max: 2**20-1); // TODO: review range + } + + // loop for range checks index 0, 2 + for (int index = 0; index < 3; ++index) { + range_check(colu: a[index], min:0, max: CHUNK_SIZE-1); + range_check(colu: b[index], min:0, max: CHUNK_SIZE-1); + range_check(colu: c[index], min:0, max: CHUNK_SIZE-1); + range_check(colu: d[index], min:0, max: CHUNK_SIZE-1); + } + + range_check(colu: c[3], min:0, max: 2**16-1); + + col witness range_a3; + col witness range_b3; + col witness range_d3; + + lookup_assumes(ARITH_TABLE_ID, cols: [ op, 4 * na + 8 * nb + 32 * np + 2**16 * range_a3 + 2**18 * range_b3 + 2**22 * range_d3]); + + range_a3 * (1 - range_a3) * (2 - range_a3) === 0; + range_b3 * (1 - range_b3) * (2 - range_b3) === 0; + range_d3 * (1 - range_d3) * (2 - range_d3) === 0; + + lookup_assumes(QUICK_RANGE_TABLE_ID, [range_a3, a[3]]); + lookup_assumes(QUICK_RANGE_TABLE_ID, [range_b3, b[3]]); + lookup_assumes(QUICK_RANGE_TABLE_ID, [range_d3, d[3]]); +} \ No newline at end of file diff --git a/state-machines/arith/pil/arith_range_table.pil b/state-machines/arith/pil/arith_range_table.pil new file mode 100644 index 00000000..48cdb666 --- /dev/null +++ b/state-machines/arith/pil/arith_range_table.pil @@ -0,0 +1,12 @@ +require "std_lookup.pil" +require "operations.pil" + +const int ARITH_RANGE_TABLE_ID = 330; + +airtemplate ArithRangeTable(int N = 2**17) { + + col fixed RANGES = [0:2**16,1:2**15,2:2**15]; + col fixed VALUES = [0..2**16-1]..; + + lookup_proves(ARITH_TABLE_ID, [RANGES, VALUES], multiplicity); +} \ No newline at end of file diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil new file mode 100644 index 00000000..61e86400 --- /dev/null +++ b/state-machines/arith/pil/arith_table.pil @@ -0,0 +1,100 @@ +require "std_lookup.pil" +require "operations.pil" + +const int ARITH_TABLE_ID = 330; + +airtemplate ArithTable(int N = 2**8) { + + // NOTE: + // Divisions and remainders by 0 are done by QuickOps + + int na; // a is negative + int nb; // b is negative + int nr; // rem is negative + int np; // prod is negative + int na32; // a is 32-bit negative, 31th bit is 1. + int nd32; // d is 32-bit negative, 31th bit is 1. + + int m32; // 32 bits operation + int div; // division operation (div,rem) + int sa; + int sb; + + // negative a,c,d,na32,nd32 must be 0 if no signed_a + // na * (1 - sa) === 0; + // nr * (1 - sa) === 0; + // nr * (1 - div) === 0; + // np * (1 - sa) === 0; + // na32 * (1 - sa) === 0; + // nd32 * (1 - sa) === 0; + + // negative b must be 0 if no signed_b + // nb * (1 - sb) === 0; + + // na32, nd32 only available when 32 bits operation + // na32 * (1 - m32) === 0; + // nd32 * (1 - m32) === 0; + + // nr, nd32 only could be one 1 in divisions + // nr * (1 - div) === 0; + // nd32 * (1 - div) === 0; + + // if sb === 1 then sa must be 1, not allowed sa = 0, sb = 1 + // sb * (1 - sa) === 0; + // m32 * (sa - sb) === 0; + // div * (sa - sb) === 0; + // (1 - div) * m32 * (1 - sa) === 0; + // (1 - div) * m32 * (1 - sb) === 0; + + int op; + op = 0xb0 + 2 * (sa + sb + m32) + 2 * div * (m32 - sa + 4); + + // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 + // ---------------------------------------------------------------------------------- + // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 + // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 =0 d3 =0 =0 a3, d3 + // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 + // 0 1 1 1 x mul_w *n/a* (0xb6,0xb7) a1 b1 =0 d3 c1 =0 d3, a1,b1,c1 + // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 + // 1 0 1 1 div rem (0xba,0xbb) a3 b3 d3 c3 =0 =0 a3,b3,c3,d3 + // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 c1 d1 c1,d1 + // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 d1 c1 c1 d1 a1,b1,c1,d1 + + // (*) removed combinations of flags div,m32,sa,sb did allow combinations div, m32, sa, sb + // see 5 previous constraints. + // =0 means forced to zero by previous constraints + // comm = commutative (trivial: commutative operations) + + // positive_a1 = m32 * sa * (1 - na); + // negative_a1 = m32 * sa * na; + + // positive_a3 = (1-m32) * sa * (1 - na); + // negative_a3 = (1-m32) * sa * na; + + // positive_b1 = m32 * sa * (1 - nb); + // negative_b1 = m32 * sa * nb; + + // positive_b3 = (1-m32) * sb * (1 - nb); + // negative_b3 = (1-m32) * sb * nb; + + // positive_c1 = div * m32 * sa * (1 - np) + div * m32 * (1 - na32) + (1 - div) * m32 * (1 - na32); + // negative_c1 = div * m32 * sa * np + div * m32 * na32 + (1 - div) * m32 * na32; + + // positive_c3 = div * (1-m32) * sa * (1 - np); + // negative_c3 = div * (1-m32) * sa * np; + + // positive_d1 = div * m32 * (1 - nd32) + div * m32 * sa * (1 - nr); + // negative_d1 = div * m32 * nd32 + div * m32 * sa * nr; + + // positive_d3 = (1-div) * sa * (1 - np) + div * (1-m32) * sa * (1 - nr); + // negative_d3 = (1-div) * sa * np + div * (1-m32) * sa * nr; + + // TODO: correct values + + + col fixed OP = [0..10]...; + col fixed FLAGS_AND_RANGES = [1,0...]; + col witness multiplicity; + + lookup_proves(ARITH_TABLE_ID, mul: multiplicity, cols: [OP, FLAGS_AND_RANGES]); +} \ No newline at end of file From 7b1c281fed7d6c4a8dda9588490b8aeb56b1e10a Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Mon, 16 Sep 2024 17:19:35 +0200 Subject: [PATCH 02/17] WIP updating pils and complete arith table --- pil/fork_0/pil/operations.pil | 6 - pil/fork_0/pil/zisk.pil | 7 + pil/fork_0/pil/zisk.pilout | Bin 24060707 -> 24926215 bytes pil/fork_0/src/pil_helpers/pilout.rs | 38 ++- pil/fork_0/src/pil_helpers/traces.rs | 18 +- state-machines/arith/pil/arith.pil | 19 +- .../arith/pil/arith_range_table.pil | 7 +- state-machines/arith/pil/arith_table.pil | 223 ++++++++++++------ 8 files changed, 210 insertions(+), 108 deletions(-) diff --git a/pil/fork_0/pil/operations.pil b/pil/fork_0/pil/operations.pil index aa6f5d09..d857a854 100644 --- a/pil/fork_0/pil/operations.pil +++ b/pil/fork_0/pil/operations.pil @@ -1,5 +1,3 @@ -const int OPERATION_BUS_ID = 90; - const int OP_FLAG = 0x00; const int OP_COPYB = 0x01; const int OP_SIGNEXTEND_B = 0x02; @@ -50,7 +48,3 @@ const int OP_MAXU = 0xE0; const int OP_MAX = 0xE1; const int OP_MAXU_W = 0xE4; const int OP_MAX_W = 0xE5; - -function verify_operation(expr opid, expr sel, expr a[], expr b[], expr c[], expr flag) { - lookup_assume(OPERATION_BUS_ID, cols:[operation, ...a, ...b, ...c, flag], sel:sel); -} diff --git a/pil/fork_0/pil/zisk.pil b/pil/fork_0/pil/zisk.pil index b5ce6cde..793a9227 100644 --- a/pil/fork_0/pil/zisk.pil +++ b/pil/fork_0/pil/zisk.pil @@ -36,3 +36,10 @@ airgroup Arith { Arith(N: DEFAULT_N, operation_bus_id: OPERATION_BUS_ID); } +airgroup ArithTable { + ArithTable(); +} + +airgroup ArithRangeTable { + ArithRangeTable(); +} diff --git a/pil/fork_0/pil/zisk.pilout b/pil/fork_0/pil/zisk.pilout index a19e26a506c13f6cbbb9b54408075bbdd5004f33..918759047d13853217025cf96b362537870eb447 100644 GIT binary patch delta 213087 zcmeI3d3>Bz)yH!uU52NnO{dV7Da%l{B;7JIN!pCEq(Fg+1=$oT#7R=xU`cCQz`8xH zoC@kAf(k_uK^7Mf1zBWY6h(30cj~_HyQ1)(yFJe{nJ2~1`*}Z~^ZxO}pqYEmJ?GwY z&+mTk+}0aE_29HiZvW%7ZKNnitu%=y(-fLYdD?>lnnu%UPntn9sf}h)fo9WQv^UM6 zxipXVp(4$veQ7`1pAMiB9Y_o4AUc>1p+o5~YNrnBq=mGI7Sj@1N?lZ@3RS6_dZ?F{ z(Q@je8aT1m&z@pJ;Mq8HGK^g>!qC(+6DB07amrPJti zI)h$JFQGH(ELuY^rL*ZA8laa^od#)$hG{LGOY3MojnH}Ya@s)W(0%<jqqoxKbOpVQ z-cIkJE9olQNmtW5>0NXUy_>G3_t158J-wH1pd0BXx|!Zb@23yY2k92Nl|Dqb(TC|H zbUWQaAEi6#F1nlUp^wqMbRXSM572}35Iszf(4+JieVjf)kJBgVQ}hIVnm$8M(r4*& z^m+OMeUZLIU#73nSLti?b@~Qq>3j5jdWv?@59o*VBl`ZfKAeoMcj-_sxHkMt+{GyR4BN`Irj(?95+^e=jv{!Ra(XXx2WRispuQ&Fpm zCaGw$il(S&s*3U|+CxQwil(V(x{CHx(F_&MR8gCXW~r#4qS-3iOGSIDXpV~Js%V~y z_EAw$Me|j(uZs3l(f%qrKt&}L9jKxODmq9-2dn516&;Q&CMt&r{L!Rdl$Dj!@B&DmqF<{VH0a zqN7!GjEas`(MlB^r=sIkbb^XjspthNI#ES0RMBb`ous0ZRrDeiouZ;sRdkw)PFK+x zDtfVsUZSEiRdkk$)~M*EDmq(5=cs5vMK4oPT}6W`8dA})iq@*=TotWT(Rvk)sOUTu zy<9~bRCKxqobA%lbC8;HvodZQHh~PF1eVF3bw@ zE2iXCTWMZPVRFzKC{-yG+ft2_R^{?7mD$Bvf$7c_to_`1EtQ#tJ#C{ZXscYaRXC5$ zW)!9ey1&&}O-igLrF`k2+GvVvGXn!VaIzUVwNfbfi5qGp*0vnzWSXS=nvaD=Hk*QzDS=U&Ea@6<$+6y$A)&@6XuDh0dGS1!Mg zvz?w=$vhLN7^iQ(vuR_S-KKB9FlFiUvaj*Be`?xE{`N{mJ;2zNnj@ssx4;E316}$K zGy@N+6oNLp{)3$nXEBvim6g*&%-BOylS^{C2M_mU+U{KHs>(`fL;gFRaXP0RCeXsn z!Abq*@{8TrJyV%kWX3Lu_hYHEnZ-7Z`w_CT)Ad|VS^E9xGTthwnI!wMcPeVt*maMK zzyZ?x(dz=3fraTd|2<~pvQz|~lQZO4aZ>wS+kQH#+sx%=Y%MjG!=>$r!cQpF!($WAAH?w{@}iVf{JGuTQye z)T`I#h#7cZrk&BS)8@;K-G=6ArNuw$=C23x9-e80ve7l_Gt5Ul-+0?(?k@Jm(m2u~ zQy&wW^<(k6USU3O-uautu2RUnH6JgC{cbUzFKnJ=`T%{Ui=}a`FHa2<(97{#w$*rh zRUG^x^ZC`ju`s4cnP<-5Tsi; z{-x;{i9EDkn__5rjEnK+B*vR;j3|lmM%(fhmAy%`i=n^w^2NbS`K>%yFM=0;nH_vN z4<3K()umt)*qVQv&!4_~WY&Vu!P~9FcO>(>BD76ca*Tkhh@)QR-Ej$|3Sx4)Ap8t> z+R<01Gnicx@A6*tCMP8^FJ!m&vaFfbv`&3qjLY}6Nxt80!@nn4&o#E?y7YvSE!E{X zE<~u>GWPU_B*uHK#~YIv*V~qxCfrlMrF{1=Tlzj8>$db}JNEtYmWF}y-(=Ez`az#A zeSg{Qp59_T-kQws1GeQu$(G*cz40cdwsZ>5FN{U*r}$wz{3FU(`fj&LdOv*K-Qm6J z1SVUWq+4(GxO6vVJJcn2SCZ~KZS=d7<$Tn(+>_qY`T@W>t*$j9sO`dL1>70vq!Aiqa4>&!-J||*8f<>NPULug;{a%rNxxe zCz8P*x9%Q~ODVgOKIL~w-+P)XiS-vx@Cobj(zNzQbUXv?o;Sn7+RvFtOC4w}*0 z^lKTWUY0gIe=?TeP&;4Ix@%6Dgq!brzLnu)Umwmy<4?xNcetbr{?&k=0G!7B#mNTr zN&FtWbtm-q{T@u67g{}!Q>m_<#4~?8cAhmheKqc~Tl<4#v!AjpKg<+?8S23KBWw6$ zo>3UT+bnR4`boy}^=2pY;ctiIj{L@z{ybUfFZ@b19Q1{J{6cm4CZ73QhV!hC1;5f? z+3Ef|S?Mor%WpDQYKFR%{?;1)j)%Hy#LsNY@9iq{q1n)b+}i&z!P<3^=<@oLUycFM z_#;JUNq=e3mVZ$Nzf@DFy6vxibn2cI=KpLb^fxylGoORqeEy!fKyyvtQnFds_fPHp zpX?w$ZpE+#1Qbvlm3$=MN$s;|)HBIeg>NKGY4r zB-gY){q4{ktZ_a)^xv3EF;Oc%A4~}*JZi+oegWd6ytOV6Y?# zuKxUci*@0r`_>E>>D6B0SF4Bn#9wab^tR@u@=DorN^LBZf;+zT*L$y!cUTWs@|HLi zz1_B4HNkv*&0n2suK9QRIhqL0-%jiAU7X=Ces$I7q+LHB#Ia`H&EuQ$evKV}ZN_}e z#yhoM=VSZRTbIc7jSfS5Q-6>3_+DPKJ0~u)EjJ`tc#Cbh(MGjXJ&Y%4_omFe!zF2lVHmo8IK=NS$dD!(tG@bo#NxU+N$F5K`yTfD^6Iw@}yE% zt<+U2m+JaYZ{xTEe~F zj(ciPsN&;_L(VBYq)g08tyC#hDw#pcm2!z=@jr2F=c~S^>um3LQ2)NzS`6vmIlP`* zXu@yuW~Ph#tl{v;h0fKm9>NPq)-&dZ_%7$WJ~pQZ>$B*svSrO4Y3;Yc=J@)$j^-zhbsemdqdAk{ebWE*V;*?Xs3^#%C|t3~5|3 zXD8iIj#KbduK3*?r{Ji<*K%ezyKANHQg64}&2W4{rc;o)r_AuR2jq~}9*5X%NIiF$SbFHRkduyfM zQn~Wn37z01;^^sg#o7%g%4v4GI9utQ+BCSFj@IJ{ylI^cr!SwXI$_>2*~YNRln-is z(%HN;)N7paxYg9CKI#22&nKo|r|xqT8r_ua7S}h=%z9^&FcnOnq)DfCW#$|^XE)BF zIdhv#8@DoE!;m?WyW_omELVIu7qsTzsS5YxLbJQeYNchRGUKbBfynfUWJ9F)EWNS3 zXK{|xHryU>P2-jghOWczYr~$IjfnlFc2`$GgQeMBzh%Q*Ax)`CZBaUl<+DTM$}Y=g zJe|gFt&wy?Pn^bUa>c848gEyH%k7!7yjEIXDlhAOt}LdLX!u)Slg=Xlo~$~lRGq|z zRZp0oCd+a8;bSu{!EpCw=3GsfpPupg>5cPqTdsJE&d+C6;Rc(Zz8c?~DqRg#X8EIU zV&zJYXTP~SczmhGqF{iIIb(J|*I%zra> z#CL{r4?CJ5-Bo|P=i@(F*`^NjWurC!ZB^*E8DdtJd)ZG@vTd%b-F`xKKL1m$_@j6B z-Pjo)&FrXBo+#Zt-q~<^=SKNcN9`U}yi@V!{ZtihwYTpwpHb!I{8hrc3md-W-C~KQE(J97nqKPw%vSSnWhD$h0*CxQ9hmHaeR=o{uVYE z8xZ)fK&|;ds=|YA>U?}~7AoBn)Jw0=gq`{Nk#*~N>M7Ln9GX`Wvi$fi7f*=~yksf- zVYf@ECgWJ%%_}jDwe6}H@+SO46<%rTqr$STu&le6bHVEKf81Bn@-km&dL@GcZY7ON zvDr5FcAjEmpZ)69K-}8bRq-pj6n5kaPuhG|_)Msj!&0Spf`s;UmH7K`+0{V9w_?AW z`X17_eVV@JO2boX7?RGuE)rKSq5CqLt`}PZjehO7G`CH0ZE~dA$;H*Yu1j-;<7{^I zcV#}G-1n=d;c>-Ic!XzTRb(z8-u1%gkeYyIF+a2mN)# z<;H(UYb*&nfw<3RTe}FmV`=JJNbJPb^X@8ZNZ(qLs|tIIk*)KYUPrn_b%wp+a5Wrf z$s4+^xR!JO!Cc`SyPW~9Egihbid$RUW&-j1 z(`|7i1^{0|s$JQFsC0KVMjC92G+1tq#BFX#cf^615S3mI)KdyG*TiRB!@=|pI3|G4OJGZSuw*ghN zP4@oho0_v%ulh(2#gP~U%&+PMe64hMH%1z?k<3Gz4K!Gd1L^x$-00Tyn2%KDlbY`{ zReoX1-X^}Wrq?vo6lo}l6bIrq7wD6@LTHP;%7-yuc&c1}*^zo1BMmo28qSEsZNy&>ED{# za*+20U$?aM2CX-@bn4%Z$ZX49&d(zH-z)jwbN?RXMivIS^PX+` z{NGN7Zw7jXZZ`k7lj-FD3xfXBNBajbE&37N{K295Lt59U|FV<*%Y7}iD^&3^{Zgw6 zm)Hw}c{Sg8#q?c^#vf7xTXZei+iY*Iy?~nXT5|pc8@7yW+AuOSvSq7-Q*#HusTa3D zQw?UCx;#an+P*k{^`DI9ZBw@87i`U-AIJ{YUl!i2Nh}k|!YU|B|OK?*F*|NU9(258yBOE2%*67yOk} zAowe(e()Fkl~lh9kbmS~@&v^F2lpRI1>*iAseaslz+XxAQ=0>NMKS5krCucZ3H zU+`B_{dj){f5Bf#1%kieucQLOUrF_Yzu>Q=`tkk@{(`@f3Iu<_Ur7amzmn<)f5Bf# z_2c~+`~`m{6$t)c{&t_zV6@DiHhyewc{AN&P>CDo7jXYd#Nl~f@33;s$f5d4)? zKlls&N~$04&)_fkE2%*67yOk}Aowe(e()Fkl~h08pTS@7S5krCFZe5|K=4;m{opV7 zE2)0GKZC#EucQLOU+`B_f#9#C`oUlDS5p0We+GZSUr7amzu>Q=0>NKN^@G3QucZ3% z{tW(tzmf_Bf5Bf#1%khl>IZ+pUrF`j{TciPeybQ=0>NMKS5krCucZ3HU+`B_{r`{N zpL2nd_z6X#CP71qMJM5NeEy5h0QpD$B?}PsFM9wW|H!}O2?+j5p1yqllnsyj5ArX1 z0Hgk83lRB7{v{6peE*fa{pAAWANiL&0rCA;^7K`>|4SA>Z$0WC^)E>v-2Y_{0OTL} zmplP+|Cc;{@%;<;A4&D&{t5nqzmf_Bf5Bf#1%khl>IZ+pUrF`j@1KCb;IE_t!C&xK zQi0&Fr24^M@K;j(cz*_e!Cy%Qg1_LeqyoWTN%e!j;IE|m@%{|{g1?dq1b@L_NdybQ=0>NMKS5krCucZ3HU+`B_{dj){ zf5Bf#1%kieucQLOUrF_Yzu>Q=`tkk@{(`@f3Iu<_Ur7amzmn<)f5Bf#_2c~+`~`m{ z6$t)c{&t_zV6@DiHhyewc{AN&P>CDo7jXYd#Nl~f@33;s$f5d4)?Klls&N~$04 z&)_fkE2%*67yOk}Aowe(e()Fkl~h08pTS@7S5krCFZe5|K=4;m{opV7E2)0GKZC#E zucQLOU+`B_f#9#C`oUlDS5p0We+GZSUr7amzu>Q=0>NKN^@G3QucZ3%{tW(tzmf_B zf5Bf#1%khl>IZ+pUrF`j{TciPeQkTB)nF^NQ)a7Ok(3jci%JbkoR&rRKMRvHIw`wF6t~ zgB#W^;nrHYx74pqPSNeBndb8^*sx_}(}t0uku6(0+jrNyxZKg+($#X-VZq#;Pfgx6 z_mPX|Z&cS$4bR#?m}bT;8LE$sZEc%8ucg$=mTI|Lp{A~CIja~n&8q&8J={q=~i0mrm_|W^Hbd`PwYSL z)Perh#|)g(zxudiPi>nLvNP4Lob{q$Ut>MN3DY*UA@rfaoblGqt+q?%kYRU5|B0t_ L_^I=zrNaLYm)mc) delta 878 zcmV~$1#}Pu07cRBM+_JZqecuEF}fK&I!2Gq(ai`6VTAAq^-G5VPD<$p1t%j%DJYGC z(jg7f3f? za+8O=72uC@_aen3(PH>W6`HfSY<_u>!$9XPr zk>B})OI+qp{^AN(xyE&FaFbiy<_>rHn|s{n0S|e^W1jGoXZ*u+UhtAvyygvW`IrBA z$NQPGxu4u$9v}~t2g!rwA@Wcq c#40=_AR;tACOS^R@sT013QJ6iic@0X|1V! { - main_first_segment: F, main_last_segment: F, main_segment: F, a: [F; 2], b: [F; 2], c: [F; 2], last_c: [F; 2], flag: F, pc: F, a_src_imm: F, a_src_mem: F, a_offset_imm0: F, sp: F, a_src_sp: F, a_use_sp_imm1: F, a_src_step: F, b_src_imm: F, b_src_mem: F, b_offset_imm0: F, b_use_sp_imm1: F, b_src_ind: F, ind_width: F, is_external_op: F, op: F, store_ra: F, store_mem: F, store_ind: F, store_offset: F, set_pc: F, store_use_sp: F, set_sp: F, inc_sp: F, jmp_offset1: F, jmp_offset2: F, end: F, m32: F, + main_first_segment: F, main_last_segment: F, main_segment: F, a: [F; 2], b: [F; 2], c: [F; 2], last_c: [F; 2], flag: F, pc: F, a_src_imm: F, a_src_mem: F, a_offset_imm0: F, sp: F, a_src_sp: F, a_use_sp_imm1: F, a_src_step: F, b_src_imm: F, b_src_mem: F, b_offset_imm0: F, b_use_sp_imm1: F, b_src_ind: F, ind_width: F, is_external_op: F, op: F, store_ra: F, store_mem: F, store_ind: F, store_offset: F, set_pc: F, store_use_sp: F, set_sp: F, inc_sp: F, jmp_offset1: F, jmp_offset2: F, end: F, m32: F, }); trace!(Binary0Row, Binary0Trace { - m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 9], use_last_carry: F, multiplicity: F, + m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 9], use_last_carry: F, multiplicity: F, }); trace!(BinaryTable0Row, BinaryTable0Trace { @@ -16,9 +16,21 @@ trace!(BinaryTable0Row, BinaryTable0Trace { }); trace!(BinaryExtension0Row, BinaryExtension0Trace { - m_op: F, mode8: F, mode16: F, mode32: F, in1: [F; 8], in2_low: F, out: [F; 8], free_in2: [F; 4], multiplicity: F, + m_op: F, mode8: F, mode16: F, mode32: F, in1: [F; 8], in2_low: F, out: [F; 8], free_in2: [F; 4], multiplicity: F, }); trace!(BinaryExtensionTable0Row, BinaryExtensionTable0Trace { multiplicity: F, }); + +trace!(Arith0Row, Arith0Trace { + carry: [F; 7], a: [F; 4], b: [F; 4], c: [F; 4], d: [F; 4], na: F, nb: F, nr: F, np: F, na32: F, nd32: F, m32: F, div: F, fab: F, secondary_res: F, op: F, bus_a_low: F, bus_a_high: F, bus_b_high: F, res1_low: F, div64: F, res1_high: F, multiplicity: F, range_a1: F, range_b1: F, range_c1: F, range_d1: F, range_a3: F, range_b3: F, range_c3: F, range_d3: F, +}); + +trace!(ArithTable0Row, ArithTable0Trace { + multiplicity: F, +}); + +trace!(ArithRangeTable0Row, ArithRangeTable0Trace { + multiplicity: F, +}); diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index 60191088..5bcfac7f 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -1,7 +1,8 @@ require "std_lookup.pil" require "std_range_check.pil" require "operations.pil" -// require "arith_table.pil" +require "arith_table.pil" +require "arith_range_table.pil" // generic 64 u64 mul_u64 32 *u32 // witness 45 41 30 26 27 13 @@ -248,12 +249,12 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1 + 2**16 * range_a3 + 2**18 * range_b3 + 2**20 * range_c3 + 2**22 * range_d3]); - lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_a1, a[1]]); - lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_b1, b[1]]); - lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_c1, c[1]]); - lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_d1, d[1]]); - lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_a3, a[3]]); - lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_b3, b[3]]); - lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_c3, c[3]]); - lookup_assumes(AIRTH_RANGE_TABLE_ID, [range_d3, d[3]]); + lookup_assumes(ARITH_RANGE_TABLE_ID, [range_a1, a[1]]); + lookup_assumes(ARITH_RANGE_TABLE_ID, [range_b1, b[1]]); + lookup_assumes(ARITH_RANGE_TABLE_ID, [range_c1, c[1]]); + lookup_assumes(ARITH_RANGE_TABLE_ID, [range_d1, d[1]]); + lookup_assumes(ARITH_RANGE_TABLE_ID, [range_a3, a[3]]); + lookup_assumes(ARITH_RANGE_TABLE_ID, [range_b3, b[3]]); + lookup_assumes(ARITH_RANGE_TABLE_ID, [range_c3, c[3]]); + lookup_assumes(ARITH_RANGE_TABLE_ID, [range_d3, d[3]]); } \ No newline at end of file diff --git a/state-machines/arith/pil/arith_range_table.pil b/state-machines/arith/pil/arith_range_table.pil index 48cdb666..331c7465 100644 --- a/state-machines/arith/pil/arith_range_table.pil +++ b/state-machines/arith/pil/arith_range_table.pil @@ -6,7 +6,12 @@ const int ARITH_RANGE_TABLE_ID = 330; airtemplate ArithRangeTable(int N = 2**17) { col fixed RANGES = [0:2**16,1:2**15,2:2**15]; - col fixed VALUES = [0..2**16-1]..; + col fixed VALUES = [0..2**16-1]...; + + col witness multiplicity; lookup_proves(ARITH_TABLE_ID, [RANGES, VALUES], multiplicity); + + // REMOVE + multiplicity * (multiplicity - 1) === 0; } \ No newline at end of file diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil index 61e86400..42b30f8e 100644 --- a/state-machines/arith/pil/arith_table.pil +++ b/state-machines/arith/pil/arith_table.pil @@ -1,53 +1,13 @@ require "std_lookup.pil" -require "operations.pil" const int ARITH_TABLE_ID = 330; -airtemplate ArithTable(int N = 2**8) { +airtemplate ArithTable(int N = 2**6) { - // NOTE: - // Divisions and remainders by 0 are done by QuickOps - - int na; // a is negative - int nb; // b is negative - int nr; // rem is negative - int np; // prod is negative - int na32; // a is 32-bit negative, 31th bit is 1. - int nd32; // d is 32-bit negative, 31th bit is 1. - - int m32; // 32 bits operation - int div; // division operation (div,rem) - int sa; - int sb; - - // negative a,c,d,na32,nd32 must be 0 if no signed_a - // na * (1 - sa) === 0; - // nr * (1 - sa) === 0; - // nr * (1 - div) === 0; - // np * (1 - sa) === 0; - // na32 * (1 - sa) === 0; - // nd32 * (1 - sa) === 0; - - // negative b must be 0 if no signed_b - // nb * (1 - sb) === 0; - - // na32, nd32 only available when 32 bits operation - // na32 * (1 - m32) === 0; - // nd32 * (1 - m32) === 0; - - // nr, nd32 only could be one 1 in divisions - // nr * (1 - div) === 0; - // nd32 * (1 - div) === 0; - - // if sb === 1 then sa must be 1, not allowed sa = 0, sb = 1 - // sb * (1 - sa) === 0; - // m32 * (sa - sb) === 0; - // div * (sa - sb) === 0; - // (1 - div) * m32 * (1 - sa) === 0; - // (1 - div) * m32 * (1 - sb) === 0; - - int op; - op = 0xb0 + 2 * (sa + sb + m32) + 2 * div * (m32 - sa + 4); + + // TABLE + // op + // m32|div|na|nb|nr|np|na32|nd32|range_a1(*)|range_b1(*)|range_c1(*)|range_d1(*)|range_a3(*)|range_b3(*)|range_c3(*)|range_d3(*) // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 // ---------------------------------------------------------------------------------- @@ -59,42 +19,149 @@ airtemplate ArithTable(int N = 2**8) { // 1 0 1 1 div rem (0xba,0xbb) a3 b3 d3 c3 =0 =0 a3,b3,c3,d3 // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 c1 d1 c1,d1 // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 d1 c1 c1 d1 a1,b1,c1,d1 - - // (*) removed combinations of flags div,m32,sa,sb did allow combinations div, m32, sa, sb - // see 5 previous constraints. - // =0 means forced to zero by previous constraints - // comm = commutative (trivial: commutative operations) - - // positive_a1 = m32 * sa * (1 - na); - // negative_a1 = m32 * sa * na; - - // positive_a3 = (1-m32) * sa * (1 - na); - // negative_a3 = (1-m32) * sa * na; - - // positive_b1 = m32 * sa * (1 - nb); - // negative_b1 = m32 * sa * nb; - - // positive_b3 = (1-m32) * sb * (1 - nb); - // negative_b3 = (1-m32) * sb * nb; - - // positive_c1 = div * m32 * sa * (1 - np) + div * m32 * (1 - na32) + (1 - div) * m32 * (1 - na32); - // negative_c1 = div * m32 * sa * np + div * m32 * na32 + (1 - div) * m32 * na32; - - // positive_c3 = div * (1-m32) * sa * (1 - np); - // negative_c3 = div * (1-m32) * sa * np; - - // positive_d1 = div * m32 * (1 - nd32) + div * m32 * sa * (1 - nr); - // negative_d1 = div * m32 * nd32 + div * m32 * sa * nr; - - // positive_d3 = (1-div) * sa * (1 - np) + div * (1-m32) * sa * (1 - nr); - // negative_d3 = (1-div) * sa * np + div * (1-m32) * sa * nr; + + const int OPS[14] = [0xb0, 0xb1, 0xb3, 0xb4, 0xb5, 0xb6, 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf]; + + col fixed OP; + col fixed FLAGS_AND_RANGES; + + int index = 0; + int size = 0; + while (index < N) { + for (int iop = 0; iop < length(OPS); ++iop) { + int opcode = OPS[iop]; + int m32 = 0; // 32 bits operation + int div = 0; // division operation (div,rem) + int sa = 0; + int sb = 0; + + switch (opcode & 0xFE) { + case 0xb3: // mulsuh + sa = 1; + case 0xb4: // mul, mulh + sa = 1; + sb = 1; + case 0xb6: // mul_w + m32 = 1; + sa = 1; + case 0xb8: // divu, remu + div = 1; + case 0xba: // div, rem + sa = 1; + sb = 1; + div = 1; + case 0xbc: // divu_w, remu_w + div = 1; + m32 = 1; + case 0xbe: // div_w, rem_w + sa = 1; + sb = 1; + div = 1; + m32 = 1; + } + + // CASES: + // sa = 0 sb = 0 => [a >= 0, b >= 0] + // sa = 1 sb = 0 => [a >= 0, b >= 0], [a < 0, b >= 0] + // sa = 1 sb = 1 => [a >= 0, b >= 0], [a < 0, b >= 0], [a >= 0, b < 0], [a < 0, b < 0] - // TODO: correct values + int cases = 1 + sa + sb + sa * sb; + + for (int icase = 0; icase < cases; ++icase) { + int na = 0; // a is negative + int nb = 0; // b is negative + int nr = 0; // rem is negative + int np = 0; // prod is negative + int na32 = 0; // a is 32-bit negative, 31th bit is 1. + int nd32 = 0; // d is 32-bit negative, 31th bit is 1. + switch (icase) { + case 1: + na = 1; + case 2: + nb = 1; + case 3: + na = 1; + nb = 1; + } + np = na + nb - na * nb; + nr = div ? na : 0; + na32 = m32 ? na : 0; + nd32 = m32 ? nr : 0; + + // negative a,c,d,na32,nd32 must be 0 if no signed_a + // na * (1 - sa) === 0; + // nr * (1 - sa) === 0; + // nr * (1 - div) === 0; + // np * (1 - sa) === 0; + // na32 * (1 - sa) === 0; + // nd32 * (1 - sa) === 0; + + // negative b must be 0 if no signed_b + // nb * (1 - sb) === 0; + + // na32, nd32 only available when 32 bits operation + // na32 * (1 - m32) === 0; + // nd32 * (1 - m32) === 0; + + // nr, nd32 only could be one 1 in divisions + // nr * (1 - div) === 0; + // nd32 * (1 - div) === 0; + + // if sb === 1 then sa must be 1, not allowed sa = 0, sb = 1 + // sb * (1 - sa) === 0; + // m32 * (sa - sb) === 0; + // div * (sa - sb) === 0; + // (1 - div) * m32 * (1 - sa) === 0; + // (1 - div) * m32 * (1 - sb) === 0; + + // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 + // ---------------------------------------------------------------------------------- + // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 + // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 =0 d3 =0 =0 a3, d3 + // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 + // 0 1 1 1 x mul_w *n/a* (0xb6,0xb7) a1 b1 =0 d3 c1 =0 d3, a1,b1,c1 + // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 + // 1 0 1 1 div rem (0xba,0xbb) a3 b3 d3 c3 =0 =0 a3,b3,c3,d3 + // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 c1 d1 c1,d1 + // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 d1 c1 c1 d1 a1,b1,c1,d1 + + int range_a1 = m32 * sa ? 1 + na : 0; + int range_b1 = m32 * sb ? 1 + nb : 0; + + int range_c1 = 0; + if (m32) { + if (div) { + range_c1 = np || na32 ? 2 : 1; + } else { + range_c1 = 1 + na32; + } + } + int range_d1 = m32 * div ? (((np * sa) || nd32) ? 1:2) : 0; + + int range_a3 = (1 - m32) * sa ? 1 + na : 0; + int range_b3 = (1 - m32) * sb ? 1 + na : 0; + int range_c3 = div * (1 - m32) * sa ? 1 + np : 0; + int range_d3 = div * (1 - m32) * sa ? 1 + np : 0; + + OP[index] = opcode; + FLAGS_AND_RANGES[index] = m32 + 2 * div + 4 * na + 8 * nb + 16 * nr + 32 * np + 64 * na32 + 128 * nd32 + + 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1 + + 2**16 * range_a3 + 2**18 * range_b3 + 2**20 * range_c3 + 2**22 * range_d3; + + index = index + 1; + if (index == N) break; + } + if (index == N) break; + } + if (size == 0) size = index; + } + + println("ARITH_TABLE SIZE: ", size); - - col fixed OP = [0..10]...; - col fixed FLAGS_AND_RANGES = [1,0...]; col witness multiplicity; lookup_proves(ARITH_TABLE_ID, mul: multiplicity, cols: [OP, FLAGS_AND_RANGES]); + + // REMOVE + multiplicity * (multiplicity - 1) === 0; } \ No newline at end of file From 1e2556b9aa11d9b02cb20212c106bd264ba71dd4 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Wed, 18 Sep 2024 14:33:42 +0200 Subject: [PATCH 03/17] wip aritmethic state machines --- Cargo.lock | 2 + pil/v0.1/pil/zisk.pil | 20 +-- pil/v0.1/pil/zisk.pilout | Bin 24926215 -> 24926182 bytes pil/v0.1/src/pil_helpers/pilout.rs | 16 +-- pil/v0.1/src/pil_helpers/traces.rs | 8 +- state-machines/arith/Cargo.toml | 2 + state-machines/arith/pil/arith.pil | 6 + .../arith/pil/arith_range_table.pil | 6 +- state-machines/arith/pil/arith_table.pil | 7 +- state-machines/arith/src/arith.rs | 119 +++++++++++------ state-machines/arith/src/arith_32.rs | 35 +++-- state-machines/arith/src/arith_full.rs | 122 ++++++++++++++++++ state-machines/arith/src/arith_mul_32.rs | 106 +++++++++++++++ .../src/{arith_64.rs => arith_mul_64.rs} | 34 +++-- state-machines/arith/src/arith_range_table.rs | 106 +++++++++++++++ .../src/{arith_3264.rs => arith_table.rs} | 37 +++--- state-machines/arith/src/lib.rs | 14 +- state-machines/main/src/main_sm.rs | 4 +- witness-computation/src/zisk_lib.rs | 35 ++--- 19 files changed, 545 insertions(+), 134 deletions(-) create mode 100644 state-machines/arith/src/arith_full.rs create mode 100644 state-machines/arith/src/arith_mul_32.rs rename state-machines/arith/src/{arith_64.rs => arith_mul_64.rs} (69%) create mode 100644 state-machines/arith/src/arith_range_table.rs rename state-machines/arith/src/{arith_3264.rs => arith_table.rs} (70%) diff --git a/Cargo.lock b/Cargo.lock index a9f00cc4..9ac0e14b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1994,12 +1994,14 @@ name = "sm-arith" version = "0.1.0" dependencies = [ "log", + "p3-field", "proofman", "proofman-common", "proofman-macros", "rayon", "sm-common", "zisk-core", + "zisk-pil", ] [[package]] diff --git a/pil/v0.1/pil/zisk.pil b/pil/v0.1/pil/zisk.pil index 793a9227..f54c8866 100644 --- a/pil/v0.1/pil/zisk.pil +++ b/pil/v0.1/pil/zisk.pil @@ -8,14 +8,14 @@ require "binary/pil/binary_extension_table.pil" require "arith/pil/arith.pil" const int OPERATION_BUS_ID = 5000; -const int DEFAULT_N = 2**21; +const int DEFAULT_ROWS = 2**21; airgroup Main { - Main(N: DEFAULT_N, RC: 2, operation_bus_id: OPERATION_BUS_ID); + Main(DEFAULT_ROWS, RC: 2, operation_bus_id: OPERATION_BUS_ID); } airgroup Binary { - Binary(N: DEFAULT_N, operation_bus_id: OPERATION_BUS_ID); + Binary(DEFAULT_ROWS, operation_bus_id: OPERATION_BUS_ID); } @@ -25,7 +25,7 @@ airgroup BinaryTable { airgroup BinaryExtension { - BinaryExtension(N: DEFAULT_N, operation_bus_id: OPERATION_BUS_ID); + BinaryExtension(DEFAULT_ROWS, operation_bus_id: OPERATION_BUS_ID); } airgroup BinaryExtensionTable { @@ -33,13 +33,5 @@ airgroup BinaryExtensionTable { } airgroup Arith { - Arith(N: DEFAULT_N, operation_bus_id: OPERATION_BUS_ID); -} - -airgroup ArithTable { - ArithTable(); -} - -airgroup ArithRangeTable { - ArithRangeTable(); -} + instance_arith(DEFAULT_ROWS, bus_id: OPERATION_BUS_ID); +} \ No newline at end of file diff --git a/pil/v0.1/pil/zisk.pilout b/pil/v0.1/pil/zisk.pilout index 918759047d13853217025cf96b362537870eb447..863780f7ed87229a65df96486cdc57f3fc2ea209 100644 GIT binary patch delta 1143 zcmYL@Wqj6E7>0N6WyD@aj2a6VqsCxkY;=!ij2hiJMvok+PQ4V+n@|kQS5Yw$>;RpL z-GzaLg@u8g==+Bcez~6e{GRKa`)t~OB7AJ$iEsrXmJn5lqbk*?P7UG-QY(34*DrVo87pdbAiz(58um>~?MkYN-toDqy<6r(9-3}YF`cqTBBNla!6Q<=te zW-yak%w`UADPbO^%x3`$S;S(Nu#_^Ev78mGWEHC^XANsv$9gufkux}xv#8)~&f#3( zJkDnmo7utzT!`7qHny{ai@2DbT*9SX#^vnd3a;cTuI3u9CBk)F&u(tuMs8vcH**WO zavQgE2Yb1beeCBh?&couqGCJYJq4Pn0LgljSM$RC$^_U7jJ&lxNAa7SHW|maABi5KSQcP;fE58&23Qqf zb%62!YXYndur9!QRURqYpu&^k-Aay}?%Z{%vpkd^ZP_|uqf(V$c_{je;&3Q4KIwFE zi@Z8$!eM4<3(fKXg2@6%ity1S(OPDpaK!VIm|_of_06iDXi! zMQ!R(msHZIM>-i~QlADiq!Ep2LQ|U2oEEesi&kXQnl`kh9qq}X10Cr^XS&dpZgi&y zJ?TYn`p}nt^k)DA8N^_QFqB~oX9T&7B#%*yW(;E)M?T}3z(gi7nJG+V8q=A;xRlGdoGZAJyU{ogAi&ySSTsxR?95pCdfLgOu|S5Az6*@)(cv1W)o5PxB1V@*L0e z0x$9sFY^ko@*1!625<5fZ}SfC@*eN=0Uz=aAM**H@)@6VlrQ*_uQ?av!;`+)wT=50D4SgXF>T5P7IPOdc+e zkaOjca-KX&9xacN$I9d6e0jV)L7pg2k|)bk { carry: [F; 7], a: [F; 4], b: [F; 4], c: [F; 4], d: [F; 4], na: F, nb: F, nr: F, np: F, na32: F, nd32: F, m32: F, div: F, fab: F, secondary_res: F, op: F, bus_a_low: F, bus_a_high: F, bus_b_high: F, res1_low: F, div64: F, res1_high: F, multiplicity: F, range_a1: F, range_b1: F, range_c1: F, range_d1: F, range_a3: F, range_b3: F, range_c3: F, range_d3: F, }); -trace!(ArithTable0Row, ArithTable0Trace { - multiplicity: F, +trace!(ArithTable1Row, ArithTable1Trace { + multiplicity2: F, }); -trace!(ArithRangeTable0Row, ArithRangeTable0Trace { - multiplicity: F, +trace!(ArithRangeTable2Row, ArithRangeTable2Trace { + multiplicity3: F, }); diff --git a/state-machines/arith/Cargo.toml b/state-machines/arith/Cargo.toml index 20379c5c..faa984da 100644 --- a/state-machines/arith/Cargo.toml +++ b/state-machines/arith/Cargo.toml @@ -5,8 +5,10 @@ edition = "2021" [dependencies] zisk-core = { path = "../../core" } +zisk-pil = { path="../../pil/v0.1" } sm-common = { path = "../common" } +p3-field = { workspace=true } proofman-common = { workspace = true } proofman-macros = { workspace = true } proofman = { workspace = true } diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index 5bcfac7f..12666864 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -257,4 +257,10 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu lookup_assumes(ARITH_RANGE_TABLE_ID, [range_b3, b[3]]); lookup_assumes(ARITH_RANGE_TABLE_ID, [range_c3, c[3]]); lookup_assumes(ARITH_RANGE_TABLE_ID, [range_d3, d[3]]); +} + +function instance_arith(const int rows = 2**21, const int bus_id) { + Arith(rows, operation_bus_id: bus_id); + ArithTable(); + ArithRangeTable(); } \ No newline at end of file diff --git a/state-machines/arith/pil/arith_range_table.pil b/state-machines/arith/pil/arith_range_table.pil index 331c7465..be88ff3b 100644 --- a/state-machines/arith/pil/arith_range_table.pil +++ b/state-machines/arith/pil/arith_range_table.pil @@ -8,10 +8,10 @@ airtemplate ArithRangeTable(int N = 2**17) { col fixed RANGES = [0:2**16,1:2**15,2:2**15]; col fixed VALUES = [0..2**16-1]...; - col witness multiplicity; + col witness multiplicity3; - lookup_proves(ARITH_TABLE_ID, [RANGES, VALUES], multiplicity); + lookup_proves(ARITH_TABLE_ID, [RANGES, VALUES], multiplicity3); // REMOVE - multiplicity * (multiplicity - 1) === 0; + multiplicity3 * (multiplicity3 - 1) === 0; } \ No newline at end of file diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil index 42b30f8e..0753feff 100644 --- a/state-machines/arith/pil/arith_table.pil +++ b/state-machines/arith/pil/arith_table.pil @@ -3,7 +3,6 @@ require "std_lookup.pil" const int ARITH_TABLE_ID = 330; airtemplate ArithTable(int N = 2**6) { - // TABLE // op @@ -158,10 +157,10 @@ airtemplate ArithTable(int N = 2**6) { println("ARITH_TABLE SIZE: ", size); - col witness multiplicity; + col witness multiplicity2; - lookup_proves(ARITH_TABLE_ID, mul: multiplicity, cols: [OP, FLAGS_AND_RANGES]); + lookup_proves(ARITH_TABLE_ID, mul: multiplicity2, cols: [OP, FLAGS_AND_RANGES]); // REMOVE - multiplicity * (multiplicity - 1) === 0; + multiplicity2 * (multiplicity2 - 1) === 0; } \ No newline at end of file diff --git a/state-machines/arith/src/arith.rs b/state-machines/arith/src/arith.rs index f3fed723..8228e1a6 100644 --- a/state-machines/arith/src/arith.rs +++ b/state-machines/arith/src/arith.rs @@ -3,17 +3,19 @@ use std::sync::{ Arc, Mutex, }; -use crate::{Arith3264SM, Arith32SM, Arith64SM}; +use p3_field::AbstractField; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; use rayon::Scope; use sm_common::{OpResult, Provable, ThreadController}; use zisk_core::{opcode_execute, ZiskRequiredOperation}; +use crate::{Arith32SM, ArithFullSM, ArithMul32SM, ArithMul64SM, ArithRangeTableSM, ArithTableSM}; + const PROVE_CHUNK_SIZE: usize = 1 << 12; #[allow(dead_code)] -pub struct ArithSM { +pub struct ArithSM { // Count of registered predecessors registered_predecessors: AtomicU32, @@ -21,60 +23,103 @@ pub struct ArithSM { threads_controller: Arc, // Inputs - inputs32: Mutex>, - inputs64: Mutex>, + inputs: Mutex>, + inputs_32: Mutex>, + inputs_mul_32: Mutex>, + inputs_mul_64: Mutex>, // Secondary State machines - arith32_sm: Arc, - arith64_sm: Arc, - arith3264_sm: Arc, + arith_32_sm: Arc>, + arith_mul_32_sm: Arc>, + arith_mul_64_sm: Arc>, + arith_full_sm: Arc>, + arith_range_table_sm: Arc>, + arith_table_sm: Arc>, } -impl ArithSM { - pub fn new( - wcm: &mut WitnessManager, - arith32_sm: Arc, - arith64_sm: Arc, - arith3264_sm: Arc, - ) -> Arc { +impl ArithSM { + pub fn new(wcm: &mut WitnessManager) -> Arc { + // TODO: change this call, for calls to WitnessManager to obtain from airGroupId and airIds + // ON each SM, not need pass to the constructor + let arith_full_ids = ArithSM::::get_ids_by_name("Arith"); + let arith_32_ids = ArithSM::::get_ids_by_name("Arith32"); + let arith_mul_32_ids = ArithSM::::get_ids_by_name("ArithMul32"); + let arith_mul_64_ids = ArithSM::::get_ids_by_name("ArithMul64"); + let arith_range_table_ids = ArithSM::::get_ids_by_name("ArithRangeTable"); + let arith_table_ids = ArithSM::::get_ids_by_name("ArithTable"); + let arith_sm = Self { registered_predecessors: AtomicU32::new(0), threads_controller: Arc::new(ThreadController::new()), - inputs32: Mutex::new(Vec::new()), - inputs64: Mutex::new(Vec::new()), - arith32_sm, - arith64_sm, - arith3264_sm, + inputs: Mutex::new(Vec::new()), + inputs_32: Mutex::new(Vec::new()), + inputs_mul_32: Mutex::new(Vec::new()), + inputs_mul_64: Mutex::new(Vec::new()), + arith_full_sm: ArithFullSM::new(wcm, arith_full_ids.0, &[arith_full_ids.1]), + arith_32_sm: Arith32SM::new(wcm, arith_32_ids.0, &[arith_32_ids.1]), + arith_mul_32_sm: ArithMul32SM::new(wcm, arith_mul_32_ids.0, &[arith_mul_32_ids.1]), + arith_mul_64_sm: ArithMul64SM::new(wcm, arith_mul_64_ids.0, &[arith_mul_64_ids.1]), + arith_range_table_sm: ArithRangeTableSM::new( + wcm, + arith_range_table_ids.0, + &[arith_range_table_ids.1], + ), + arith_table_sm: ArithTableSM::new(wcm, arith_table_ids.0, &[arith_table_ids.1]), }; let arith_sm = Arc::new(arith_sm); wcm.register_component(arith_sm.clone(), None, None); - arith_sm.arith32_sm.register_predecessor(); - arith_sm.arith64_sm.register_predecessor(); - arith_sm.arith3264_sm.register_predecessor(); + arith_sm.arith_32_sm.register_predecessor(); + arith_sm.arith_mul_32_sm.register_predecessor(); + arith_sm.arith_mul_64_sm.register_predecessor(); + arith_sm.arith_full_sm.register_predecessor(); arith_sm } + pub fn get_ids_by_name(name: &str) -> (usize, usize) { + const ARITH_AIRGROUP_ID: usize = 1; + if name == "Arith" { + return (ARITH_AIRGROUP_ID, 10); + } else if name == "Arith32" { + return (ARITH_AIRGROUP_ID, 11); + } else if name == "ArithMul64" { + return (ARITH_AIRGROUP_ID, 12); + } else if name == "ArithMul32" { + return (ARITH_AIRGROUP_ID, 13); + } else if name == "AirthRangeTable" { + return (ARITH_AIRGROUP_ID, 14); + } else if name == "ArithTable" { + return (ARITH_AIRGROUP_ID, 15); + } + return (0, 0); + } + pub fn register_predecessor(&self) { self.registered_predecessors.fetch_add(1, Ordering::SeqCst); } pub fn unregister_predecessor(&self, scope: &Scope) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); + as Provable>::prove( + self, + &[], + true, + scope, + ); self.threads_controller.wait_for_threads(); - self.arith3264_sm.unregister_predecessor(scope); - self.arith64_sm.unregister_predecessor(scope); - self.arith32_sm.unregister_predecessor(scope); + self.arith_32_sm.unregister_predecessor(scope); + self.arith_mul_32_sm.unregister_predecessor(scope); + self.arith_mul_64_sm.unregister_predecessor(scope); + self.arith_full_sm.unregister_predecessor(scope); } } } -impl WitnessComponent for ArithSM { +impl WitnessComponent for ArithSM { fn calculate_witness( &self, _stage: u32, @@ -86,7 +131,9 @@ impl WitnessComponent for ArithSM { } } -impl Provable for ArithSM { +impl Provable + for ArithSM +{ fn calculate( &self, operation: ZiskRequiredOperation, @@ -99,8 +146,8 @@ impl Provable for ArithSM { let mut _inputs32 = Vec::new(); let mut _inputs64 = Vec::new(); - let operations32 = Arith32SM::operations(); - let operations64 = Arith64SM::operations(); + let operations64 = ArithMul64SM::::operations(); + let operations32 = Arith32SM::::operations(); // TODO Split the operations into 32 and 64 bit operations in parallel for operation in operations { @@ -114,8 +161,8 @@ impl Provable for ArithSM { } // TODO When drain is true, drain remaining inputs to the 3264 bits state machine - - let mut inputs32 = self.inputs32.lock().unwrap(); + /* + let mut inputs32 = self.inputs_32.lock().unwrap(); inputs32.extend(_inputs32); while inputs32.len() >= PROVE_CHUNK_SIZE || (drain && !inputs32.is_empty()) { @@ -125,7 +172,7 @@ impl Provable for ArithSM { let num_drained32 = std::cmp::min(PROVE_CHUNK_SIZE, inputs32.len()); let drained_inputs32 = inputs32.drain(..num_drained32).collect::>(); - let arith32_sm_cloned = self.arith32_sm.clone(); + let arith32_sm_cloned = self.arith_32_sm.clone(); self.threads_controller.add_working_thread(); let thread_controller = self.threads_controller.clone(); @@ -138,7 +185,7 @@ impl Provable for ArithSM { } drop(inputs32); - let mut inputs64 = self.inputs64.lock().unwrap(); + let mut inputs64 = self.inputs_mul_64.lock().unwrap(); inputs64.extend(_inputs64); while inputs64.len() >= PROVE_CHUNK_SIZE || (drain && !inputs64.is_empty()) { @@ -148,7 +195,7 @@ impl Provable for ArithSM { let num_drained64 = std::cmp::min(PROVE_CHUNK_SIZE, inputs64.len()); let drained_inputs64 = inputs64.drain(..num_drained64).collect::>(); - let arith64_sm_cloned = self.arith64_sm.clone(); + let arith64_sm_cloned = self.arith_mul_64_sm.clone(); self.threads_controller.add_working_thread(); let thread_controller = self.threads_controller.clone(); @@ -159,7 +206,7 @@ impl Provable for ArithSM { thread_controller.remove_working_thread(); }); } - drop(inputs64); + drop(inputs64);*/ } fn calculate_prove( diff --git a/state-machines/arith/src/arith_32.rs b/state-machines/arith/src/arith_32.rs index 2840923b..50e5b0ff 100644 --- a/state-machines/arith/src/arith_32.rs +++ b/state-machines/arith/src/arith_32.rs @@ -3,6 +3,7 @@ use std::sync::{ Arc, Mutex, }; +use p3_field::AbstractField; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; use rayon::Scope; @@ -11,23 +12,28 @@ use zisk_core::{opcode_execute, ZiskRequiredOperation}; const PROVE_CHUNK_SIZE: usize = 1 << 12; -pub struct Arith32SM { +pub struct Arith32SM { // Count of registered predecessors registered_predecessors: AtomicU32, // Inputs inputs: Mutex>, + + _phantom: std::marker::PhantomData, } -impl Arith32SM { - pub fn new(wcm: &mut WitnessManager, airgroup_id: usize, air_ids: &[usize]) -> Arc { - let arith32_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let arith32_sm = Arc::new(arith32_sm); +impl Arith32SM { + pub fn new(wcm: &mut WitnessManager, airgroup_id: usize, air_ids: &[usize]) -> Arc { + let _arith_32_sm = Self { + registered_predecessors: AtomicU32::new(0), + inputs: Mutex::new(Vec::new()), + _phantom: std::marker::PhantomData, + }; + let arith_32_sm = Arc::new(_arith_32_sm); - wcm.register_component(arith32_sm.clone(), Some(airgroup_id), Some(air_ids)); + wcm.register_component(arith_32_sm.clone(), Some(airgroup_id), Some(air_ids)); - arith32_sm + arith_32_sm } pub fn register_predecessor(&self) { @@ -36,7 +42,12 @@ impl Arith32SM { pub fn unregister_predecessor(&self, scope: &Scope) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); + as Provable>::prove( + self, + &[], + true, + scope, + ); } } @@ -45,7 +56,7 @@ impl Arith32SM { } } -impl WitnessComponent for Arith32SM { +impl WitnessComponent for Arith32SM { fn calculate_witness( &self, _stage: u32, @@ -57,7 +68,9 @@ impl WitnessComponent for Arith32SM { } } -impl Provable for Arith32SM { +impl Provable + for Arith32SM +{ fn calculate( &self, operation: ZiskRequiredOperation, diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs new file mode 100644 index 00000000..ae94d029 --- /dev/null +++ b/state-machines/arith/src/arith_full.rs @@ -0,0 +1,122 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use p3_field::AbstractField; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; +use rayon::Scope; +// use sm_common::{OpResult, Provable, ThreadController}; +use sm_common::{OpResult, Provable}; +use zisk_core::{opcode_execute, ZiskRequiredOperation}; +use zisk_pil::Arith0Row; + +const PROVE_CHUNK_SIZE: usize = 1 << 12; + +pub struct ArithFullSM { + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Thread controller to manage the execution of the state machines + // threads_controller: Arc, + + // Inputs + inputs: Mutex>, + + _phantom: std::marker::PhantomData, +} + +impl ArithFullSM { + pub fn new(wcm: &mut WitnessManager, airgroup_id: usize, air_ids: &[usize]) -> Arc { + let arith_full_sm = Self { + registered_predecessors: AtomicU32::new(0), + inputs: Mutex::new(Vec::new()), + _phantom: std::marker::PhantomData, + //threads_controller: Arc::new(ThreadController::new()), + }; + let arith_full_sm = Arc::new(arith_full_sm); + + wcm.register_component(arith_full_sm.clone(), Some(airgroup_id), Some(air_ids)); + + arith_full_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self, scope: &Scope) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + as Provable>::prove( + self, + &[], + true, + scope, + ); + } + } + pub fn process_slice(input: &Vec) -> Vec> { + let mut _trace: Vec> = Vec::new(); + _trace + } +} + +impl WitnessComponent for ArithFullSM { + fn calculate_witness( + &self, + _stage: u32, + _air_instance: Option, + _pctx: &mut ProofCtx, + _ectx: &ExecutionCtx, + _sctx: &SetupCtx, + ) { + } +} + +impl Provable + for ArithFullSM +{ + fn calculate( + &self, + operation: ZiskRequiredOperation, + ) -> Result> { + let result: OpResult = opcode_execute(operation.opcode, operation.a, operation.b); + Ok(result) + } + + fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { + if let Ok(mut inputs) = self.inputs.lock() { + inputs.extend_from_slice(operations); + + while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { + if drain && !inputs.is_empty() { + println!("Arith3264SM: Draining inputs3264"); + } + + // self.threads_controller.add_working_thread(); + // let thread_controller = self.threads_controller.clone(); + + let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); + let _drained_inputs = inputs.drain(..num_drained).collect::>(); + + scope.spawn(move |scope| { + let _trace = Self::process_slice(&_drained_inputs); + // thread_controller.remove_working_thread(); + // TODO! Implement prove drained_inputs (a chunk of operations) + }); + } + } + } + + fn calculate_prove( + &self, + operation: ZiskRequiredOperation, + drain: bool, + scope: &Scope, + ) -> Result> { + let result = self.calculate(operation.clone()); + self.prove(&[operation], drain, scope); + result + } +} diff --git a/state-machines/arith/src/arith_mul_32.rs b/state-machines/arith/src/arith_mul_32.rs new file mode 100644 index 00000000..1356b3b1 --- /dev/null +++ b/state-machines/arith/src/arith_mul_32.rs @@ -0,0 +1,106 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use p3_field::AbstractField; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; +use rayon::Scope; +use sm_common::{OpResult, Provable}; +use zisk_core::{opcode_execute, ZiskRequiredOperation}; + +const PROVE_CHUNK_SIZE: usize = 1 << 12; + +pub struct ArithMul32SM { + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Inputs + inputs: Mutex>, + + _phantom: std::marker::PhantomData, +} + +impl ArithMul32SM { + pub fn new(wcm: &mut WitnessManager, airgroup_id: usize, air_ids: &[usize]) -> Arc { + let arith_mul_32_sm = Self { + registered_predecessors: AtomicU32::new(0), + inputs: Mutex::new(Vec::new()), + _phantom: std::marker::PhantomData, + }; + let arith_mul_32_sm = Arc::new(arith_mul_32_sm); + + wcm.register_component(arith_mul_32_sm.clone(), Some(airgroup_id), Some(air_ids)); + + arith_mul_32_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self, scope: &Scope) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + as Provable>::prove( + self, + &[], + true, + scope, + ); + } + } + + pub fn operations() -> Vec { + // TODO: use constants + vec![0xb6, 0xb7, 0xbe, 0xbf] + } +} + +impl WitnessComponent for ArithMul32SM { + fn calculate_witness( + &self, + _stage: u32, + _air_instance: Option, + _pctx: &mut ProofCtx, + _ectx: &ExecutionCtx, + _sctx: &SetupCtx, + ) { + } +} + +impl Provable for ArithMul32SM { + fn calculate( + &self, + operation: ZiskRequiredOperation, + ) -> Result> { + let result: OpResult = opcode_execute(operation.opcode, operation.a, operation.b); + Ok(result) + } + + fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { + if let Ok(mut inputs) = self.inputs.lock() { + inputs.extend_from_slice(operations); + + while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { + let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); + let _drained_inputs = inputs.drain(..num_drained).collect::>(); + + scope.spawn(move |_| { + // TODO! Implement prove drained_inputs (a chunk of operations) + }); + } + } + } + + fn calculate_prove( + &self, + operation: ZiskRequiredOperation, + drain: bool, + scope: &Scope, + ) -> Result> { + let result = self.calculate(operation.clone()); + self.prove(&[operation], drain, scope); + result + } +} diff --git a/state-machines/arith/src/arith_64.rs b/state-machines/arith/src/arith_mul_64.rs similarity index 69% rename from state-machines/arith/src/arith_64.rs rename to state-machines/arith/src/arith_mul_64.rs index 699eaafb..0536d959 100644 --- a/state-machines/arith/src/arith_64.rs +++ b/state-machines/arith/src/arith_mul_64.rs @@ -3,6 +3,7 @@ use std::sync::{ Arc, Mutex, }; +use p3_field::AbstractField; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; use rayon::Scope; @@ -11,23 +12,28 @@ use zisk_core::{opcode_execute, ZiskRequiredOperation}; const PROVE_CHUNK_SIZE: usize = 1 << 12; -pub struct Arith64SM { +pub struct ArithMul64SM { // Count of registered predecessors registered_predecessors: AtomicU32, // Inputs inputs: Mutex>, + + _phantom: std::marker::PhantomData, } -impl Arith64SM { - pub fn new(wcm: &mut WitnessManager, airgroup_id: usize, air_ids: &[usize]) -> Arc { - let arith64_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let arith64_sm = Arc::new(arith64_sm); +impl ArithMul64SM { + pub fn new(wcm: &mut WitnessManager, airgroup_id: usize, air_ids: &[usize]) -> Arc { + let arith_mul_64_sm = Self { + registered_predecessors: AtomicU32::new(0), + inputs: Mutex::new(Vec::new()), + _phantom: std::marker::PhantomData, + }; + let arith_mul_64_sm = Arc::new(arith_mul_64_sm); - wcm.register_component(arith64_sm.clone(), Some(airgroup_id), Some(air_ids)); + wcm.register_component(arith_mul_64_sm.clone(), Some(airgroup_id), Some(air_ids)); - arith64_sm + arith_mul_64_sm } pub fn register_predecessor(&self) { @@ -36,16 +42,22 @@ impl Arith64SM { pub fn unregister_predecessor(&self, scope: &Scope) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove(self, &[], true, scope); + as Provable>::prove( + self, + &[], + true, + scope, + ); } } pub fn operations() -> Vec { + // TODO: use constants vec![0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb8, 0xb9, 0xba, 0xbb] } } -impl WitnessComponent for Arith64SM { +impl WitnessComponent for ArithMul64SM { fn calculate_witness( &self, _stage: u32, @@ -57,7 +69,7 @@ impl WitnessComponent for Arith64SM { } } -impl Provable for Arith64SM { +impl Provable for ArithMul64SM { fn calculate( &self, operation: ZiskRequiredOperation, diff --git a/state-machines/arith/src/arith_range_table.rs b/state-machines/arith/src/arith_range_table.rs new file mode 100644 index 00000000..1e438272 --- /dev/null +++ b/state-machines/arith/src/arith_range_table.rs @@ -0,0 +1,106 @@ +use std::sync::{ + atomic::{AtomicU32, Ordering}, + Arc, Mutex, +}; + +use p3_field::AbstractField; +use proofman::{WitnessComponent, WitnessManager}; +use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; +use rayon::Scope; +use sm_common::{OpResult, Provable}; +use zisk_core::{opcode_execute, ZiskRequiredOperation}; + +const PROVE_CHUNK_SIZE: usize = 1 << 12; + +pub struct ArithRangeTableSM { + // Count of registered predecessors + registered_predecessors: AtomicU32, + + // Inputs + inputs: Mutex>, + + _phantom: std::marker::PhantomData, +} + +impl ArithRangeTableSM { + pub fn new(wcm: &mut WitnessManager, airgroup_id: usize, air_ids: &[usize]) -> Arc { + let arith_range_table_sm = Self { + registered_predecessors: AtomicU32::new(0), + inputs: Mutex::new(Vec::new()), + _phantom: std::marker::PhantomData, + }; + let arith_range_table_sm = Arc::new(arith_range_table_sm); + + wcm.register_component(arith_range_table_sm.clone(), Some(airgroup_id), Some(air_ids)); + + arith_range_table_sm + } + + pub fn register_predecessor(&self) { + self.registered_predecessors.fetch_add(1, Ordering::SeqCst); + } + + pub fn unregister_predecessor(&self, scope: &Scope) { + if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { + as Provable>::prove( + self, + &[], + true, + scope, + ); + } + } + + pub fn operations() -> Vec { + // TODO: use constants + vec![0xb6, 0xb7, 0xbe, 0xbf] + } +} + +impl WitnessComponent for ArithRangeTableSM { + fn calculate_witness( + &self, + _stage: u32, + _air_instance: Option, + _pctx: &mut ProofCtx, + _ectx: &ExecutionCtx, + _sctx: &SetupCtx, + ) { + } +} + +impl Provable for ArithRangeTableSM { + fn calculate( + &self, + operation: ZiskRequiredOperation, + ) -> Result> { + let result: OpResult = opcode_execute(operation.opcode, operation.a, operation.b); + Ok(result) + } + + fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { + if let Ok(mut inputs) = self.inputs.lock() { + inputs.extend_from_slice(operations); + + while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { + let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); + let _drained_inputs = inputs.drain(..num_drained).collect::>(); + + scope.spawn(move |_| { + // TODO! Implement prove drained_inputs (a chunk of operations) + }); + } + } + } + + fn calculate_prove( + &self, + operation: ZiskRequiredOperation, + drain: bool, + scope: &Scope, + ) -> Result> { + let result = self.calculate(operation.clone()); + self.prove(&[operation], drain, scope); + result + } +} diff --git a/state-machines/arith/src/arith_3264.rs b/state-machines/arith/src/arith_table.rs similarity index 70% rename from state-machines/arith/src/arith_3264.rs rename to state-machines/arith/src/arith_table.rs index 2323319b..0d2a56e8 100644 --- a/state-machines/arith/src/arith_3264.rs +++ b/state-machines/arith/src/arith_table.rs @@ -3,6 +3,7 @@ use std::sync::{ Arc, Mutex, }; +use p3_field::AbstractField; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; use rayon::Scope; @@ -11,23 +12,28 @@ use zisk_core::{opcode_execute, ZiskRequiredOperation}; const PROVE_CHUNK_SIZE: usize = 1 << 12; -pub struct Arith3264SM { +pub struct ArithTableSM { // Count of registered predecessors registered_predecessors: AtomicU32, // Inputs inputs: Mutex>, + + _phantom: std::marker::PhantomData, } -impl Arith3264SM { - pub fn new(wcm: &mut WitnessManager, airgroup_id: usize, air_ids: &[usize]) -> Arc { - let arith3264_sm = - Self { registered_predecessors: AtomicU32::new(0), inputs: Mutex::new(Vec::new()) }; - let arith3264_sm = Arc::new(arith3264_sm); +impl ArithTableSM { + pub fn new(wcm: &mut WitnessManager, airgroup_id: usize, air_ids: &[usize]) -> Arc { + let _arith_table_sm = Self { + registered_predecessors: AtomicU32::new(0), + inputs: Mutex::new(Vec::new()), + _phantom: std::marker::PhantomData, + }; + let arith_table_sm = Arc::new(_arith_table_sm); - wcm.register_component(arith3264_sm.clone(), Some(airgroup_id), Some(air_ids)); + wcm.register_component(arith_table_sm.clone(), Some(airgroup_id), Some(air_ids)); - arith3264_sm + arith_table_sm } pub fn register_predecessor(&self) { @@ -36,7 +42,7 @@ impl Arith3264SM { pub fn unregister_predecessor(&self, scope: &Scope) { if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - >::prove( + as Provable>::prove( self, &[], true, @@ -44,9 +50,14 @@ impl Arith3264SM { ); } } + + pub fn operations() -> Vec { + // TODO: use constants + vec![0xb6, 0xb7, 0xbe, 0xbf] + } } -impl WitnessComponent for Arith3264SM { +impl WitnessComponent for ArithTableSM { fn calculate_witness( &self, _stage: u32, @@ -58,7 +69,7 @@ impl WitnessComponent for Arith3264SM { } } -impl Provable for Arith3264SM { +impl Provable for ArithTableSM { fn calculate( &self, operation: ZiskRequiredOperation, @@ -72,10 +83,6 @@ impl Provable for Arith3264SM { inputs.extend_from_slice(operations); while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - if drain && !inputs.is_empty() { - println!("Arith3264SM: Draining inputs3264"); - } - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); let _drained_inputs = inputs.drain(..num_drained).collect::>(); diff --git a/state-machines/arith/src/lib.rs b/state-machines/arith/src/lib.rs index 8297735d..77ca82de 100644 --- a/state-machines/arith/src/lib.rs +++ b/state-machines/arith/src/lib.rs @@ -1,11 +1,17 @@ mod arith; mod arith_32; -mod arith_3264; -mod arith_64; +mod arith_full; +mod arith_mul_32; +mod arith_mul_64; +mod arith_range_table; +mod arith_table; mod arith_traces; pub use arith::*; pub use arith_32::*; -pub use arith_3264::*; -pub use arith_64::*; +pub use arith_full::*; +pub use arith_mul_32::*; +pub use arith_mul_64::*; +pub use arith_range_table::*; +pub use arith_table::*; pub use arith_traces::*; diff --git a/state-machines/main/src/main_sm.rs b/state-machines/main/src/main_sm.rs index 2f03cf49..62f66106 100644 --- a/state-machines/main/src/main_sm.rs +++ b/state-machines/main/src/main_sm.rs @@ -54,7 +54,7 @@ pub struct MainSM { // State machines mem_sm: Arc, binary_sm: Arc, - arith_sm: Arc, + arith_sm: Arc>, } impl<'a, F: AbstractField + Default + Copy + Send + Sync + 'static> MainSM { @@ -83,7 +83,7 @@ impl<'a, F: AbstractField + Default + Copy + Send + Sync + 'static> MainSM { wcm: &mut WitnessManager, mem_sm: Arc, binary_sm: Arc, - arith_sm: Arc, + arith_sm: Arc>, airgroup_id: usize, air_ids: &[usize], ) -> Arc { diff --git a/witness-computation/src/zisk_lib.rs b/witness-computation/src/zisk_lib.rs index 55eb619f..81ece1b1 100644 --- a/witness-computation/src/zisk_lib.rs +++ b/witness-computation/src/zisk_lib.rs @@ -9,7 +9,7 @@ use p3_goldilocks::Goldilocks; use proofman::{WitnessLibrary, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx, WitnessPilout}; use proofman_util::{timer_start, timer_stop_and_log}; -use sm_arith::{Arith3264SM, Arith32SM, Arith64SM, ArithSM}; +use sm_arith::ArithSM; use sm_main::MainSM; use sm_mem::{MemAlignedSM, MemSM, MemUnalignedSM}; @@ -17,10 +17,10 @@ pub struct ZiskWitness { pub public_inputs_path: PathBuf, pub wcm: WitnessManager, // State machines - pub arith_sm: Arc, - pub arith_32_sm: Arc, - pub arith_64_sm: Arc, - pub arith_3264_sm: Arc, + pub arith_sm: Arc>, + /* pub arith_32_sm: Arc, + pub arith_64_sm: Arc, + pub arith_3264_sm: Arc,*/ pub binary_sm: Arc, pub binary_basic_sm: Arc, pub binary_extension_sm: Arc, @@ -53,10 +53,15 @@ impl ZiskWitness { pub const MEM_AIRGROUP_ID: usize = 100; pub const MEM_ALIGN_AIR_IDS: &[usize] = &[1]; pub const MEM_UNALIGNED_AIR_IDS: &[usize] = &[2, 3]; - pub const ARITH_AIRGROUP_ID: usize = 101; - pub const ARITH32_AIR_IDS: &[usize] = &[4, 5]; - pub const ARITH64_AIR_IDS: &[usize] = &[6]; - pub const ARITH3264_AIR_IDS: &[usize] = &[7]; + // pub const ARITH_AIRGROUP_ID: usize = 101; + + // pub const ARITH_32_AIR_IDS: &[usize] = &[4, 5]; + // pub const ARITH_MUL_64_AIR_IDS: &[usize] = &[6]; + // pub const ARITH_MUL_32_AIR_IDS: &[usize] = &[7]; + // pub const ARITH_FULL_AIR_IDS: &[usize] = &[8]; + // pub const ARITH_TABLE_AIR_IDS: &[usize] = &[9]; + // pub const ARITH_RANGE_TABLE_AIR_IDS: &[usize] = &[11]; + pub const QUICKOPS_AIRGROUP_ID: usize = 102; pub const QUICKOPS_AIR_IDS: &[usize] = &[10]; @@ -74,11 +79,7 @@ impl ZiskWitness { let binary_sm = BinarySM::new(&mut wcm, binary_basic_sm.clone(), binary_extension_sm.clone()); - let arith_32_sm = Arith32SM::new(&mut wcm, ARITH_AIRGROUP_ID, ARITH32_AIR_IDS); - let arith_64_sm = Arith64SM::new(&mut wcm, ARITH_AIRGROUP_ID, ARITH64_AIR_IDS); - let arith_3264_sm = Arith3264SM::new(&mut wcm, ARITH_AIRGROUP_ID, ARITH3264_AIR_IDS); - let arith_sm = - ArithSM::new(&mut wcm, arith_32_sm.clone(), arith_64_sm.clone(), arith_3264_sm.clone()); + let arith_sm = ArithSM::new(&mut wcm); let quickops_sm = QuickOpsSM::new(&mut wcm, QUICKOPS_AIRGROUP_ID, QUICKOPS_AIR_IDS); @@ -96,9 +97,9 @@ impl ZiskWitness { public_inputs_path, wcm, arith_sm, - arith_32_sm, - arith_64_sm, - arith_3264_sm, + /* arith_32_sm, + arith_64_sm: arith_mul_64_sm, + arith_3264_sm: arith_full_sm,*/ binary_sm, binary_basic_sm, binary_extension_sm, From 7149ae70b0c7f1af4ec3e18774cec6122b864761 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Thu, 19 Sep 2024 00:44:18 +0200 Subject: [PATCH 04/17] WIP arith witness comutation --- pil/v0.1/pil/zisk.pilout | Bin 24926182 -> 24926184 bytes state-machines/arith/pil/arith.pil | 23 ++- .../arith/pil/arith_range_table.pil | 13 +- state-machines/arith/pil/arith_table.pil | 29 +++- state-machines/arith/src/arith_full.rs | 27 +++- .../arith/src/arith_range_table_inputs.rs | 48 +++++++ state-machines/arith/src/arith_table.rs | 3 +- .../arith/src/arith_table_inputs.rs | 134 ++++++++++++++++++ state-machines/arith/src/arith_traces.rs | 6 - state-machines/arith/src/lib.rs | 6 +- 10 files changed, 251 insertions(+), 38 deletions(-) create mode 100644 state-machines/arith/src/arith_range_table_inputs.rs create mode 100644 state-machines/arith/src/arith_table_inputs.rs delete mode 100644 state-machines/arith/src/arith_traces.rs diff --git a/pil/v0.1/pil/zisk.pilout b/pil/v0.1/pil/zisk.pilout index 863780f7ed87229a65df96486cdc57f3fc2ea209..9567584fbbd502bf91a95d547eec06cdf4e7aaef 100644 GIT binary patch delta 1200 zcmbWwWq6il7{+nVcjWuhFg6y%sBJJ{)acG_ba#%1(G7zB70^=<12OQKfr(;aBD1@@ zTR=?g1fN&F_}SyW@85BLxUTy=a^OVlSok(0A|z0ja+Ie66{$p&7?r6)RjN^)8q}l~ zwMnE7NhDL36jDheoqE)#0S(EZ5shg=Q<{-U7TGkX1v#{&6|Kpo4Q**hdpeLuM>^4& zE_9_E-RVJ3deNIc^ravD8NfgWkxv1I3}y&J8OCr%Fp^P>W(;E)$9N_%kx5Ku3R9WJ zbY?JAz z2Y%!ye&!c`ZEag`6X| zlv~NI?*6xORj?&2u2)(TIwF?P bEShw>F{mpS7Q2Sd z?(W1!R1~o>kl(A9UjFHI&iR~kp69vFa}Mo08h0>U8xb+YQi;k`p(@p=PJ}q(sXjkX-sDZGnqvxvzfzO<}sfIEMyUjDPsvsS;lf!u##1*W({ju$9gt!GN*7V8##^B zIRiM8vpAbgY-S7RV9w<{wz7@$xq$6l$VFVtCG6l*F5_~pU?*1+a1~dxi)+};9`6{$M}`s_?&f-y1i69SP)?K^$w_ju+*nSLo5)S&W^!{mRc;}-lv~NIxu4u$9v}~t2g!rw zA@Wdpn4Bl)%LVdqd4xPt9wm>K3*|BLSb3a0UM`X+$P?vBakWIk;H$$ TQ9Dc($7V)S{wWH!gpraHu(-Tx diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index 12666864..6e291fd2 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -245,18 +245,17 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu range_c3 * (1 - range_c3) * (2 - range_c3) === 0; range_d3 * (1 - range_d3) * (2 - range_d3) === 0; - lookup_assumes(ARITH_TABLE_ID, cols: [ op, m32 + 2 * div + 4 * na + 8 * nb + 16 * nr + 32 * np + 64 * na32 + 128 * nd32 + - 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1 + - 2**16 * range_a3 + 2**18 * range_b3 + 2**20 * range_c3 + 2**22 * range_d3]); - - lookup_assumes(ARITH_RANGE_TABLE_ID, [range_a1, a[1]]); - lookup_assumes(ARITH_RANGE_TABLE_ID, [range_b1, b[1]]); - lookup_assumes(ARITH_RANGE_TABLE_ID, [range_c1, c[1]]); - lookup_assumes(ARITH_RANGE_TABLE_ID, [range_d1, d[1]]); - lookup_assumes(ARITH_RANGE_TABLE_ID, [range_a3, a[3]]); - lookup_assumes(ARITH_RANGE_TABLE_ID, [range_b3, b[3]]); - lookup_assumes(ARITH_RANGE_TABLE_ID, [range_c3, c[3]]); - lookup_assumes(ARITH_RANGE_TABLE_ID, [range_d3, d[3]]); + + arith_table_assumes(op, m32, div, na, nb, nr, np, na32, nd32, range_a1, range_b1, range_c1, range_d1, range_a3, range_b3, range_c3, range_d3); + + arith_range_table_assumes(range_a1, a[1]); + arith_range_table_assumes(range_b1, b[1]); + arith_range_table_assumes(range_c1, c[1]); + arith_range_table_assumes(range_d1, d[1]); + arith_range_table_assumes(range_a3, a[3]); + arith_range_table_assumes(range_b3, b[3]); + arith_range_table_assumes(range_c3, c[3]); + arith_range_table_assumes(range_d3, d[3]); } function instance_arith(const int rows = 2**21, const int bus_id) { diff --git a/state-machines/arith/pil/arith_range_table.pil b/state-machines/arith/pil/arith_range_table.pil index be88ff3b..3dc21961 100644 --- a/state-machines/arith/pil/arith_range_table.pil +++ b/state-machines/arith/pil/arith_range_table.pil @@ -8,10 +8,15 @@ airtemplate ArithRangeTable(int N = 2**17) { col fixed RANGES = [0:2**16,1:2**15,2:2**15]; col fixed VALUES = [0..2**16-1]...; - col witness multiplicity3; + col witness multiplicity; - lookup_proves(ARITH_TABLE_ID, [RANGES, VALUES], multiplicity3); + lookup_proves(ARITH_TABLE_ID, [RANGES, VALUES], multiplicity); // REMOVE - multiplicity3 * (multiplicity3 - 1) === 0; -} \ No newline at end of file + multiplicity * (multiplicity - 1) === 0; +} + +function arith_range_table_assumes(const expr range_type, const expr value) { + // TODO: define rule for empty rows + lookup_assumes(ARITH_RANGE_TABLE_ID, [range_type, value]); +} diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil index 0753feff..01d82006 100644 --- a/state-machines/arith/pil/arith_table.pil +++ b/state-machines/arith/pil/arith_table.pil @@ -35,7 +35,7 @@ airtemplate ArithTable(int N = 2**6) { int sb = 0; switch (opcode & 0xFE) { - case 0xb3: // mulsuh + case 0xb2: // mulsuh sa = 1; case 0xb4: // mul, mulh sa = 1; @@ -43,6 +43,7 @@ airtemplate ArithTable(int N = 2**6) { case 0xb6: // mul_w m32 = 1; sa = 1; + sb = 1; case 0xb8: // divu, remu div = 1; case 0xba: // div, rem @@ -66,6 +67,8 @@ airtemplate ArithTable(int N = 2**6) { int cases = 1 + sa + sb + sa * sb; + println("#ARITH_TABLE", opcode, index, cases); + for (int icase = 0; icase < cases; ++icase) { int na = 0; // a is negative int nb = 0; // b is negative @@ -155,12 +158,26 @@ airtemplate ArithTable(int N = 2**6) { if (size == 0) size = index; } - println("ARITH_TABLE SIZE: ", size); + println("ARITH_TABLE SIZE: ", size); + println("ARITH_FLAGS: ", FLAGS_AND_RANGES); + for (index = 0; index < size; ++index) { + println(FLAGS_AND_RANGES[index]); + } - col witness multiplicity2; + col witness multiplicity; - lookup_proves(ARITH_TABLE_ID, mul: multiplicity2, cols: [OP, FLAGS_AND_RANGES]); + lookup_proves(ARITH_TABLE_ID, mul: multiplicity, cols: [OP, FLAGS_AND_RANGES]); // REMOVE - multiplicity2 * (multiplicity2 - 1) === 0; -} \ No newline at end of file + multiplicity * (multiplicity - 1) === 0; +} + +function arith_table_assumes( const expr op, const expr m32, const expr div, const expr na, const expr nb, + const expr nr, const expr np, const expr na32, const expr nd32, + const expr range_a1, const expr range_b1, const expr range_c1, const expr range_d1, + const expr range_a3, const expr range_b3, const expr range_c3, const expr range_d3) { + // TODO: define rule for empty rows + lookup_assumes(ARITH_TABLE_ID, cols: [ op, m32 + 2 * div + 4 * na + 8 * nb + 16 * nr + 32 * np + 64 * na32 + 128 * nd32 + + 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1 + + 2**16 * range_a3 + 2**18 * range_b3 + 2**20 * range_c3 + 2**22 * range_d3]); +} diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs index ae94d029..a0c17465 100644 --- a/state-machines/arith/src/arith_full.rs +++ b/state-machines/arith/src/arith_full.rs @@ -3,12 +3,15 @@ use std::sync::{ Arc, Mutex, }; +use crate::{ + arith_table_inputs, ArithRangeTableInputs, ArithRangeTableSM, ArithSM, ArithTableInputs, + ArithTableSM, +}; use p3_field::AbstractField; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; use rayon::Scope; -// use sm_common::{OpResult, Provable, ThreadController}; -use sm_common::{OpResult, Provable}; +use sm_common::{OpResult, Provable, ThreadController}; use zisk_core::{opcode_execute, ZiskRequiredOperation}; use zisk_pil::Arith0Row; @@ -56,8 +59,14 @@ impl ArithFullSM { ); } } - pub fn process_slice(input: &Vec) -> Vec> { + pub fn process_slice( + input: &Vec, + range_table_inputs: &mut ArithRangeTableInputs, + table_inputs: &mut ArithTableInputs, + ) -> Vec> { let mut _trace: Vec> = Vec::new(); + range_table_inputs.push(0, 0); + table_inputs.fast_push(0, 0, 0); _trace } } @@ -91,7 +100,7 @@ impl Provable= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { if drain && !inputs.is_empty() { - println!("Arith3264SM: Draining inputs3264"); + println!("ArithFullSM: Draining inputs"); } // self.threads_controller.add_working_thread(); @@ -100,8 +109,14 @@ impl Provable>(); - scope.spawn(move |scope| { - let _trace = Self::process_slice(&_drained_inputs); + scope.spawn(move |_| { + let mut arith_range_table_inputs = ArithRangeTableInputs::::new(); + let mut arith_table_inputs = ArithTableInputs::::new(); + let _trace = Self::process_slice( + &_drained_inputs, + &mut arith_range_table_inputs, + &mut arith_table_inputs, + ); // thread_controller.remove_working_thread(); // TODO! Implement prove drained_inputs (a chunk of operations) }); diff --git a/state-machines/arith/src/arith_range_table_inputs.rs b/state-machines/arith/src/arith_range_table_inputs.rs new file mode 100644 index 00000000..85737b0c --- /dev/null +++ b/state-machines/arith/src/arith_range_table_inputs.rs @@ -0,0 +1,48 @@ +use std::ops::Add; + +const ARITH_RANGE_TABLE_SIZE: usize = 2 << 17; + +pub struct ArithRangeTableInputs { + multiplicity: [u32; ARITH_RANGE_TABLE_SIZE], + _phantom: std::marker::PhantomData, +} + +impl Add for ArithRangeTableInputs { + type Output = Self; + fn add(self, other: Self) -> Self { + let mut result = Self::new(); + for i in 0..ARITH_RANGE_TABLE_SIZE { + result.multiplicity[i] = self.multiplicity[i] + other.multiplicity[i]; + } + result + } +} + +impl ArithRangeTableInputs { + pub fn new() -> Self { + Self { multiplicity: [0; ARITH_RANGE_TABLE_SIZE], _phantom: std::marker::PhantomData } + } + pub fn clear(&mut self) { + self.multiplicity = [0; ARITH_RANGE_TABLE_SIZE]; + } + pub fn push(&mut self, range_id: u8, value: u64) { + Self::check_value(range_id, value); + self.fast_push(range_id, value); + } + fn get_row(range_id: u8, value: u64) -> usize { + usize::try_from(value + if range_id > 0 { 2 << 16 } else { 0 }).unwrap() % + ARITH_RANGE_TABLE_SIZE + } + fn check_value(range_id: u8, value: u64) { + match range_id { + 0 => assert!(value <= 0xFFFF), + 1 => assert!(value <= 0x7FFF), + 2 => assert!(value <= 0xFFFF && value >= 0x8000), + _ => assert!(false), + }; + } + + pub fn fast_push(&mut self, op: u8, value: u64) { + self.multiplicity[Self::get_row(op, value)] += 1; + } +} diff --git a/state-machines/arith/src/arith_table.rs b/state-machines/arith/src/arith_table.rs index 0d2a56e8..f8d4e8fc 100644 --- a/state-machines/arith/src/arith_table.rs +++ b/state-machines/arith/src/arith_table.rs @@ -6,7 +6,7 @@ use std::sync::{ use p3_field::AbstractField; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; +use rayon::{vec, Scope}; use sm_common::{OpResult, Provable}; use zisk_core::{opcode_execute, ZiskRequiredOperation}; @@ -50,7 +50,6 @@ impl ArithTableSM { ); } } - pub fn operations() -> Vec { // TODO: use constants vec![0xb6, 0xb7, 0xbe, 0xbf] diff --git a/state-machines/arith/src/arith_table_inputs.rs b/state-machines/arith/src/arith_table_inputs.rs new file mode 100644 index 00000000..881aa5e4 --- /dev/null +++ b/state-machines/arith/src/arith_table_inputs.rs @@ -0,0 +1,134 @@ +use std::ops::Add; + +const ARITH_TABLE_SIZE: usize = 36; +pub struct ArithTableInputs { + multiplicity: [u32; ARITH_TABLE_SIZE], + _phantom: std::marker::PhantomData, +} + +impl Add for ArithTableInputs { + type Output = Self; + fn add(self, other: Self) -> Self { + let mut result = Self::new(); + for i in 0..ARITH_TABLE_SIZE { + result.multiplicity[i] = self.multiplicity[i] + other.multiplicity[i]; + } + result + } +} + +impl ArithTableInputs { + const FLAGS_AND_RANGES: [u32; ARITH_TABLE_SIZE] = [ + 0x000000, 0x000000, 0x010000, 0x020024, 0x050000, 0x0A0024, 0x050028, 0x0A002C, 0x050000, + 0x0A0024, 0x050028, 0x0A002C, 0x001501, 0x002665, 0x001929, 0x002A6D, 0x000002, 0x000002, + 0x550002, 0xAA0036, 0xA5002A, 0xAA003E, 0x550002, 0xAA0036, 0xA5002A, 0xAA003E, 0x009003, + 0x009003, 0x009503, 0x0066F7, 0x00692B, 0x006AFF, 0x009503, 0x0066F7, 0x00692B, 0x006AFF, + ]; + pub fn new() -> Self { + Self { multiplicity: [0; ARITH_TABLE_SIZE], _phantom: std::marker::PhantomData } + } + pub fn clear(&mut self) { + self.multiplicity = [0; ARITH_TABLE_SIZE]; + } + pub fn push( + &mut self, + op: u8, + m32: u32, + div: u32, + na: u32, + nb: u32, + nr: u32, + np: u32, + na32: u32, + nd32: u32, + range_a1: u32, + range_b1: u32, + range_c1: u32, + range_d1: u32, + range_a3: u32, + range_b3: u32, + range_c3: u32, + range_d3: u32, + ) { + // TODO: in debug mode + let flags = Self::values_to_flags( + m32, div, na, nb, nr, np, na32, nd32, range_a1, range_b1, range_c1, range_d1, range_a3, + range_b3, range_c3, range_d3, + ); + let variants = Self::get_variants(op); + let row_offset = nb * 2 + nb; + let row: usize = Self::get_row(op, na, nb); + + assert!(row_offset < variants); + assert!(Self::FLAGS_AND_RANGES[row] == flags); + + self.multiplicity[row] += 1; + } + fn get_variants(op: u8) -> u32 { + match op { + 0xb0 | 0xb1 | 0xb8 | 0xb9 | 0xbc | 0xbd => 1, // mulu|muluh|divu|remu|divu_w|remu_w + 0xb3 => 2, // mulsuh + 0xb4 | 0xb5 | 0xb6 | 0xba | 0xbb | 0xbe | 0xbf => 4, /* mul|mulh|mul_w|div|rem|div_w|rem_w */ + _ => panic!("Invalid opcode"), + } + } + fn get_offset(op: u8) -> u32 { + match op { + 0xb0 => 0, // mulu + 0xb1 => 1, // muluh + 0xb3 => 2, // mulsuh + 0xb4 => 4, // mul + 0xb5 => 8, // mulh + 0xb6 => 12, // mul_w + 0xb8 => 16, // divu + 0xb9 => 17, // remu + 0xba => 18, // div + 0xbb => 22, // rem + 0xbc => 26, // divu_w + 0xbd => 27, // remu_w + 0xbe => 28, // div_w + 0xbf => 32, // rem_w + _ => panic!("Invalid opcode"), + } + } + fn get_row(op: u8, na: u32, nb: u32) -> usize { + usize::try_from(Self::get_offset(op) + na + 2 * nb).unwrap() % ARITH_TABLE_SIZE + } + pub fn fast_push(&mut self, op: u8, na: u32, nb: u32) { + self.multiplicity[Self::get_row(op, na, nb)] += 1; + } + fn values_to_flags( + m32: u32, + div: u32, + na: u32, + nb: u32, + nr: u32, + np: u32, + na32: u32, + nd32: u32, + range_a1: u32, + range_b1: u32, + range_c1: u32, + range_d1: u32, + range_a3: u32, + range_b3: u32, + range_c3: u32, + range_d3: u32, + ) -> u32 { + m32 + 0x000002 * div + + 0x000004 * na + + 0x000008 * nb + + 0x000010 * nr + + 0x000020 * np + + 0x000040 * na32 + + 0x000080 * nd32 + + 0x000100 * range_a1 + + 0x000400 * range_b1 + + 0x001000 * range_c1 + + 0x004000 * range_d1 + + 0x010000 * range_a3 + + 0x040000 * range_b3 + + 0x100000 * range_c3 + + 0x400000 * range_d3 + } +} diff --git a/state-machines/arith/src/arith_traces.rs b/state-machines/arith/src/arith_traces.rs deleted file mode 100644 index 20562a6e..00000000 --- a/state-machines/arith/src/arith_traces.rs +++ /dev/null @@ -1,6 +0,0 @@ -use proofman_common as common; -pub use proofman_macros::trace; - -trace!(Arith320Row, Arith320Trace { fake: F }); -trace!(Arith640Row, Arith640Trace { fake: F }); -trace!(Arith32640Row, Arith32640Trace { fake: F }); diff --git a/state-machines/arith/src/lib.rs b/state-machines/arith/src/lib.rs index 77ca82de..2eaeece8 100644 --- a/state-machines/arith/src/lib.rs +++ b/state-machines/arith/src/lib.rs @@ -4,8 +4,9 @@ mod arith_full; mod arith_mul_32; mod arith_mul_64; mod arith_range_table; +mod arith_range_table_inputs; mod arith_table; -mod arith_traces; +mod arith_table_inputs; pub use arith::*; pub use arith_32::*; @@ -13,5 +14,6 @@ pub use arith_full::*; pub use arith_mul_32::*; pub use arith_mul_64::*; pub use arith_range_table::*; +pub use arith_range_table_inputs::*; pub use arith_table::*; -pub use arith_traces::*; +pub use arith_table_inputs::*; From 5eb8afb52a53a857e4593a0373c540380d137c85 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Mon, 14 Oct 2024 20:44:59 +0200 Subject: [PATCH 05/17] update cargo files --- Cargo.lock | 57 +++++++++++++++++---------------- state-machines/arith/Cargo.toml | 1 - 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 47b3a10f..ae5ccc5a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -198,9 +198,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.28" +version = "1.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e80e3b6a3ab07840e1cae9b0666a63970dc28e8ed5ffbcdacbfc760c281bfc1" +checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945" dependencies = [ "jobserver", "libc", @@ -915,9 +915,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] @@ -1322,7 +1322,7 @@ dependencies = [ [[package]] name = "pil-std-lib" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#cc62d85b4ce9a59d9d2aaf1c51bae0f97480f4d3" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#5cb4395c280e8b3309c1ac3286c094d3e3fd970a" dependencies = [ "log", "num-bigint", @@ -1340,7 +1340,7 @@ dependencies = [ [[package]] name = "pilout" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#cc62d85b4ce9a59d9d2aaf1c51bae0f97480f4d3" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#5cb4395c280e8b3309c1ac3286c094d3e3fd970a" dependencies = [ "bytes", "log", @@ -1460,7 +1460,7 @@ dependencies = [ [[package]] name = "proofman" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#cc62d85b4ce9a59d9d2aaf1c51bae0f97480f4d3" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#5cb4395c280e8b3309c1ac3286c094d3e3fd970a" dependencies = [ "colored", "env_logger", @@ -1470,6 +1470,7 @@ dependencies = [ "p3-goldilocks", "pilout", "proofman-common", + "proofman-hints", "proofman-starks-lib-c", "proofman-util", "stark", @@ -1479,7 +1480,7 @@ dependencies = [ [[package]] name = "proofman-common" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#cc62d85b4ce9a59d9d2aaf1c51bae0f97480f4d3" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#5cb4395c280e8b3309c1ac3286c094d3e3fd970a" dependencies = [ "env_logger", "log", @@ -1497,7 +1498,7 @@ dependencies = [ [[package]] name = "proofman-hints" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#cc62d85b4ce9a59d9d2aaf1c51bae0f97480f4d3" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#5cb4395c280e8b3309c1ac3286c094d3e3fd970a" dependencies = [ "p3-field", "proofman-common", @@ -1507,7 +1508,7 @@ dependencies = [ [[package]] name = "proofman-macros" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#cc62d85b4ce9a59d9d2aaf1c51bae0f97480f4d3" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#5cb4395c280e8b3309c1ac3286c094d3e3fd970a" dependencies = [ "proc-macro2", "quote", @@ -1517,7 +1518,7 @@ dependencies = [ [[package]] name = "proofman-starks-lib-c" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#cc62d85b4ce9a59d9d2aaf1c51bae0f97480f4d3" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#5cb4395c280e8b3309c1ac3286c094d3e3fd970a" dependencies = [ "log", ] @@ -1525,7 +1526,7 @@ dependencies = [ [[package]] name = "proofman-util" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#cc62d85b4ce9a59d9d2aaf1c51bae0f97480f4d3" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#5cb4395c280e8b3309c1ac3286c094d3e3fd970a" dependencies = [ "colored", "sysinfo", @@ -2120,7 +2121,7 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stark" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#cc62d85b4ce9a59d9d2aaf1c51bae0f97480f4d3" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#5cb4395c280e8b3309c1ac3286c094d3e3fd970a" dependencies = [ "log", "p3-field", @@ -2415,7 +2416,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#cc62d85b4ce9a59d9d2aaf1c51bae0f97480f4d3" +source = "git+https://github.com/0xPolygonHermez/pil2-proofman.git?branch=develop#5cb4395c280e8b3309c1ac3286c094d3e3fd970a" dependencies = [ "proofman-starks-lib-c", ] @@ -2544,9 +2545,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -2555,9 +2556,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", @@ -2570,9 +2571,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -2582,9 +2583,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2592,9 +2593,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", @@ -2605,9 +2606,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "wasm-streams" @@ -2624,9 +2625,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/state-machines/arith/Cargo.toml b/state-machines/arith/Cargo.toml index f0197da1..23f8d0b2 100644 --- a/state-machines/arith/Cargo.toml +++ b/state-machines/arith/Cargo.toml @@ -5,7 +5,6 @@ edition = "2021" [dependencies] zisk-core = { path = "../../core" } -zisk-pil = { path="../../pil/v0.1" } sm-common = { path = "../common" } zisk-pil = { path = "../../pil" } From 5cbf94747a88a13d6738c2cc685bf55e98a68cd4 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Mon, 14 Oct 2024 23:11:42 +0000 Subject: [PATCH 06/17] WIP arith, fix errors after merge from develop --- pil/src/pil_helpers/pilout.rs | 42 ++++++- pil/src/pil_helpers/traces.rs | 18 ++- pil/zisk.pil | 15 ++- state-machines/arith/pil/arith.pil | 101 ++++++++-------- .../arith/pil/arith_range_table.pil | 5 +- state-machines/arith/pil/arith_table.pil | 55 +++++---- state-machines/arith/src/arith.rs | 108 ++++++------------ state-machines/arith/src/arith_full.rs | 4 +- state-machines/arith/src/arith_mul_32.rs | 2 - state-machines/arith/src/arith_mul_64.rs | 4 +- state-machines/arith/src/arith_range_table.rs | 5 +- state-machines/arith/src/arith_table.rs | 4 +- 12 files changed, 193 insertions(+), 170 deletions(-) diff --git a/pil/src/pil_helpers/pilout.rs b/pil/src/pil_helpers/pilout.rs index 3032eb27..0dc67191 100644 --- a/pil/src/pil_helpers/pilout.rs +++ b/pil/src/pil_helpers/pilout.rs @@ -8,20 +8,34 @@ pub const PILOUT_HASH: &[u8] = b"Zisk-hash"; pub const MAIN_AIRGROUP_ID: usize = 0; -pub const BINARY_AIRGROUP_ID: usize = 1; +pub const ARITH_AIRGROUP_ID: usize = 1; -pub const BINARY_TABLE_AIRGROUP_ID: usize = 2; +pub const ARITH_TABLE_AIRGROUP_ID: usize = 2; -pub const BINARY_EXTENSION_AIRGROUP_ID: usize = 3; +pub const ARITH_RANGE_TABLE_AIRGROUP_ID: usize = 3; -pub const BINARY_EXTENSION_TABLE_AIRGROUP_ID: usize = 4; +pub const BINARY_AIRGROUP_ID: usize = 4; -pub const SPECIFIED_RANGES_AIRGROUP_ID: usize = 5; +pub const BINARY_TABLE_AIRGROUP_ID: usize = 5; + +pub const BINARY_EXTENSION_AIRGROUP_ID: usize = 6; + +pub const BINARY_EXTENSION_TABLE_AIRGROUP_ID: usize = 7; + +pub const U_16_AIR_AIRGROUP_ID: usize = 8; + +pub const SPECIFIED_RANGES_AIRGROUP_ID: usize = 9; //AIR CONSTANTS pub const MAIN_AIR_IDS: &[usize] = &[0]; +pub const ARITH_AIR_IDS: &[usize] = &[0]; + +pub const ARITH_TABLE_AIR_IDS: &[usize] = &[0]; + +pub const ARITH_RANGE_TABLE_AIR_IDS: &[usize] = &[0]; + pub const BINARY_AIR_IDS: &[usize] = &[0]; pub const BINARY_TABLE_AIR_IDS: &[usize] = &[0]; @@ -30,6 +44,8 @@ pub const BINARY_EXTENSION_AIR_IDS: &[usize] = &[0]; pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[0]; +pub const U_16_AIR_AIR_IDS: &[usize] = &[0]; + pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[0]; pub struct Pilout; @@ -42,6 +58,18 @@ impl Pilout { air_group.add_air(Some("Main"), 2097152); + let air_group = pilout.add_air_group(Some("Arith")); + + air_group.add_air(Some("Arith"), 262144); + + let air_group = pilout.add_air_group(Some("ArithTable")); + + air_group.add_air(Some("ArithTable"), 64); + + let air_group = pilout.add_air_group(Some("ArithRangeTable")); + + air_group.add_air(Some("ArithRangeTable"), 131072); + let air_group = pilout.add_air_group(Some("Binary")); air_group.add_air(Some("Binary"), 2097152); @@ -58,6 +86,10 @@ impl Pilout { air_group.add_air(Some("BinaryExtensionTable"), 4194304); + let air_group = pilout.add_air_group(Some("U16Air")); + + air_group.add_air(Some("U16Air"), 65536); + let air_group = pilout.add_air_group(Some("SpecifiedRanges")); air_group.add_air(Some("SpecifiedRanges"), 16777216); diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 58f42ff9..5e3fb452 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -7,6 +7,18 @@ trace!(Main0Row, Main0Trace { main_first_segment: F, main_last_segment: F, main_segment: F, a: [F; 2], b: [F; 2], c: [F; 2], last_c: [F; 2], flag: F, pc: F, a_src_imm: F, a_src_mem: F, a_offset_imm0: F, a_imm1: F, a_src_step: F, b_src_imm: F, b_src_mem: F, b_offset_imm0: F, b_imm1: F, b_src_ind: F, ind_width: F, is_external_op: F, op: F, store_ra: F, store_mem: F, store_ind: F, store_offset: F, set_pc: F, jmp_offset1: F, jmp_offset2: F, end: F, m32: F, operation_bus_enabled: F, }); +trace!(Arith0Row, Arith0Trace { + carry: [F; 7], a: [F; 4], b: [F; 4], c: [F; 4], d: [F; 4], na: F, nb: F, nr: F, np: F, na32: F, nd32: F, m32: F, div: F, fab: F, debug_main_step: F, secondary_res: F, op: F, bus_a_low: F, bus_a_high: F, bus_b_high: F, res1_low: F, div64: F, res1_high: F, multiplicity: F, range_a1: F, range_b1: F, range_c1: F, range_d1: F, range_a3: F, range_b3: F, range_c3: F, range_d3: F, +}); + +trace!(ArithTable0Row, ArithTable0Trace { + multiplicity: F, +}); + +trace!(ArithRangeTable0Row, ArithRangeTable0Trace { + multiplicity: F, +}); + trace!(Binary0Row, Binary0Trace { m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, multiplicity: F, main_step: F, }); @@ -23,6 +35,10 @@ trace!(BinaryExtensionTable0Row, BinaryExtensionTable0Trace { multiplicity: F, }); +trace!(U16Air0Row, U16Air0Trace { + mul: F, +}); + trace!(SpecifiedRanges0Row, SpecifiedRanges0Trace { - mul: [F; 1], + mul: [F; 2], }); diff --git a/pil/zisk.pil b/pil/zisk.pil index 59903524..5cb4f363 100644 --- a/pil/zisk.pil +++ b/pil/zisk.pil @@ -5,6 +5,7 @@ require "binary/pil/binary.pil" require "binary/pil/binary_table.pil" require "binary/pil/binary_extension.pil" require "binary/pil/binary_extension_table.pil" +require "arith/pil/arith.pil" // require "mem/pil/mem.pil" const int OPERATION_BUS_ID = 5000; @@ -16,6 +17,18 @@ airgroup Main { // Mem(N: 2**21, RC: 2); // } +airgroup Arith { + Arith(operation_bus_id: OPERATION_BUS_ID); +} + +airgroup ArithTable { + ArithTable(); +} + +airgroup ArithRangeTable { + ArithRangeTable(); +} + airgroup Binary { Binary(N: 2**21, operation_bus_id: OPERATION_BUS_ID); } @@ -30,4 +43,4 @@ airgroup BinaryExtension { airgroup BinaryExtensionTable { BinaryExtensionTable(disable_fixed: 0); -} +} \ No newline at end of file diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index 6e291fd2..11f90d20 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -5,9 +5,9 @@ require "arith_table.pil" require "arith_range_table.pil" // generic 64 u64 mul_u64 32 *u32 -// witness 45 41 30 26 27 13 -// lookups 3 3 3 2 3 3 -// range_checks 16+7 16+7 16+7 16+7 8+3 7+2 +// witness 45 41 30 26 27 13 +// lookups 3 3 3 2 3 3 +// range_checks 16+7 16+7 16+7 16+7 8+3 7+2 // ---------------------------------------------------------- // TOTAL 123 119 108 101 69 61 // @@ -38,13 +38,15 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu col witness m32; // 32 bits operation col witness div; // division operation (div,rem) - col witness fab; // fab, to decrease degree of intermediate products a * b + col witness fab; // fab, to decrease degree of intermediate products a * b // fab = 1 if sign of a,b are the same // fab = -1 if sign of a,b are different + col witness debug_main_step; // only for debug + if (!dual_result) { col witness air.secondary_res; // op_index: 0 => first result, 1 => second result; - secondary_res * (secondary_res - 1) === 0; + secondary_res * (secondary_res - 1) === 0; } else { const expr air.secondary_res = 0; } @@ -52,44 +54,44 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu // factor ab € {-1, 1} fab === 1 - 2 * na - 2 * nb + 4 * na * nb; - const expr eq[CHUNKS_OP]; + const expr eq[CHUNKS_OP]; - eq[0] = fab * a[0] * b[0] + eq[0] = fab * a[0] * b[0] - c[0] + 2 * np * c[0] - + div * d[0] + + div * d[0] - 2 * nr * d[0]; - eq[1] = fab * a[1] * b[0] + eq[1] = fab * a[1] * b[0] + fab * a[0] * b[1] - - c[1] + - c[1] + 2 * np * c[1] - + div * d[1] + + div * d[1] - 2 * nr * d[1]; eq[2] = fab * a[2] * b[0] - + fab * a[1] * b[1] + + fab * a[1] * b[1] + fab * a[0] * b[2] - - c[2] + - c[2] + 2 * np * c[2] - + div * d[2] + + div * d[2] - 2 * nr * d[2] - np * div * m32 + nr * m32; - + eq[3] = fab * a[3] * b[0] + fab * a[2] * b[1] + fab * a[1] * b[2] + fab * a[0] * b[3] - - c[3] + - c[3] + 2 * np * c[3] - + div * d[3] + + div * d[3] - 2 * nr * d[3]; - + eq[4] = fab * a[3] * b[1] + fab * a[2] * b[2] + fab * a[1] * b[3] - + na * b[0] * (1 - 2 * nb) + + na * b[0] * (1 - 2 * nb) + nb * a[0] * (1 - 2 * na) - np * div // \ + np * m32 // np * (div ^ m32) @@ -102,26 +104,26 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu + fab * a[2] * b[3] + nb * a[1] * (1 - 2 * na) + na * b[1] * (1 - 2 * nb) - - d[1] * (1 - div) + - d[1] * (1 - div) + 2 * np * d[1] * (1 - div); - + eq[6] = fab * a[3] * b[3] + nb * a[2] * (1 - 2 * na) + na * b[2] * (1 - 2 * nb) - - d[2] * (1 - div) + - d[2] * (1 - div) + 2 * np * d[2] * (1 - div); - + eq[7] = CHUNK_SIZE * na * nb + na * b[3] * (1 - 2 * nb) + nb * a[3] * (1 - 2 * na) - CHUNK_SIZE * np * (1 - div) * (1 - m32) - - d[3] * (1 - div) + - d[3] * (1 - div) + 2 * np * d[3] * (1 - div); eq[0] - carry[0] * CHUNK_SIZE === 0; for (int index = 1; index < (CHUNKS_OP - 1); ++index) { eq[index] + carry[index-1] - carry[index] * CHUNK_SIZE === 0; - } + } eq[CHUNKS_OP-1] + carry[CHUNKS_OP-2] === 0; // binary contraint @@ -134,34 +136,34 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu na32 * (1 - na32) === 0; nd32 * (1 - nd32) === 0; - col witness op; + col witness op; - // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 + // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 // ---------------------------------------------------------------------------------- - // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 + // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 =0 d3 =0 =0 a3, d3 - // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 + // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 // 0 1 1 1 x mul_w *n/a* (0xb6,0xb7) a1 b1 =0 d3 c1 =0 d3, a1,b1,c1 - // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 - // 1 0 1 1 div rem (0xba,0xbb) a3 b3 d3 c3 =0 =0 a3,b3,c3,d3 + // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 + // 1 0 1 1 div rem (0xba,0xbb) a3 b3 d3 c3 =0 =0 a3,b3,c3,d3 // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 c1 d1 c1,d1 // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 d1 c1 c1 d1 a1,b1,c1,d1 - // (*) removed combinations of flags div,m32,sa,sb did allow combinations div, m32, sa, sb + // (*) removed combinations of flags div,m32,sa,sb did allow combinations div, m32, sa, sb // see 5 previous constraints. // =0 means forced to zero by previous constraints // comm = commutative (trivial: commutative operations) - col witness bus_a_low; - bus_a_low === div * (c[0] - a[0]) - + a[0] - + CHUNK_SIZE * div * (c[1] - a[1]) + col witness bus_a_low; + bus_a_low === div * (c[0] - a[0]) + + a[0] + + CHUNK_SIZE * div * (c[1] - a[1]) + CHUNK_SIZE * a[1]; col witness bus_a_high; - bus_a_high === (1 - m32) * (div * (c[2] - a[2]) - + a[2] - + CHUNK_SIZE * div * (c[3] - a[3]) + bus_a_high === (1 - m32) * (div * (c[2] - a[2]) + + a[2] + + CHUNK_SIZE * div * (c[3] - a[3]) + CHUNK_SIZE * a[3]); @@ -172,13 +174,13 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu col witness bus_b_high; bus_b_high === (1 - m32) * b[2] + (1 - m32) * CHUNK_SIZE * b[3]; - const expr res2_low = d[0] + CHUNK_SIZE * d[1]; + const expr res2_low = d[0] + CHUNK_SIZE * d[1]; const expr res2_high = d[2] + CHUNK_SIZE * d[3] + nd32 * 0xFFFFFFFF; if (dual_result) { // theorical cost: 4 columns col witness multiplicity_2; - lookup_proves(operation_bus_id, [op+1, bus_a_low, bus_a_high, bus_b_low, bus_b_high, res2_low, res2_high, 0], mul: multiplicity_2); + lookup_proves(operation_bus_id, [debug_main_step, op+1, bus_a_low, bus_a_high, bus_b_low, bus_b_high, res2_low, res2_high, 0], mul: multiplicity_2); } if (dual_result) { @@ -188,7 +190,7 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu } else { col witness air.res1_low; res1_low === secondary_res * res2_low - (1 - secondary_res) * (a[0] + c[0] + CHUNK_SIZE * a[1] + CHUNK_SIZE * c[1] - bus_a_low); - + col witness air.div64; div64 === (1 - m32) * div; @@ -196,21 +198,22 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu // res1_high === secondary_res * res2_high + (1 - secondary_res) * ((1 - m32) * (div * (a[2] - c[2]) + c[2] + 2**16 * div * (a[3] - c[3]) + 2**16 * c[3]) + div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF); res1_high === secondary_res * res2_high + (1 - secondary_res) * (div64 * (a[2] - c[2]) + (1 - m32) * c[2] + CHUNK_SIZE * div64 * (a[3] - c[3]) + (1 - m32) * 2**16 * c[3] + div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF); } - + col witness multiplicity; - lookup_proves(operation_bus_id, [op + secondary_res, - bus_a_low, bus_a_high, - bus_b_low, bus_b_high, + lookup_proves(operation_bus_id, [debug_main_step, + op + secondary_res, + bus_a_low, bus_a_high, + bus_b_low, bus_b_high, res1_low, res1_high, -// secondary_res * (res2_low - res1_low) + res1_low, +// secondary_res * (res2_low - res1_low) + res1_low, // secondary_res * (res2_high - res1_high) + res1_high, 0], mul: multiplicity); // TODO: review - lookup_assumes(operation_bus_id, [OP_LT, res2_low, res2_high, bus_b_low, bus_b_high, 0, 1, 1], sel: div); + lookup_assumes(operation_bus_id, [debug_main_step, OP_LT, res2_low, res2_high, bus_b_low, bus_b_high, 0, 1, 1], sel: div); for (int index = 0; index < length(carry); ++index) { range_check(colu: carry[index], min:-2**20, max: 2**20-1); // TODO: review carry range @@ -222,7 +225,7 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu range_check(colu: b[2 * index], min:0, max: CHUNK_SIZE - 1); range_check(colu: c[2 * index], min:0, max: CHUNK_SIZE - 1); range_check(colu: d[2 * index], min:0, max: CHUNK_SIZE - 1); - } + } col witness range_a1; col witness range_b1; diff --git a/state-machines/arith/pil/arith_range_table.pil b/state-machines/arith/pil/arith_range_table.pil index 3dc21961..39c15fe6 100644 --- a/state-machines/arith/pil/arith_range_table.pil +++ b/state-machines/arith/pil/arith_range_table.pil @@ -4,16 +4,13 @@ require "operations.pil" const int ARITH_RANGE_TABLE_ID = 330; airtemplate ArithRangeTable(int N = 2**17) { - + col fixed RANGES = [0:2**16,1:2**15,2:2**15]; col fixed VALUES = [0..2**16-1]...; col witness multiplicity; lookup_proves(ARITH_TABLE_ID, [RANGES, VALUES], multiplicity); - - // REMOVE - multiplicity * (multiplicity - 1) === 0; } function arith_range_table_assumes(const expr range_type, const expr value) { diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil index 01d82006..0aed5155 100644 --- a/state-machines/arith/pil/arith_table.pil +++ b/state-machines/arith/pil/arith_table.pil @@ -3,22 +3,22 @@ require "std_lookup.pil" const int ARITH_TABLE_ID = 330; airtemplate ArithTable(int N = 2**6) { - + // TABLE - // op + // op // m32|div|na|nb|nr|np|na32|nd32|range_a1(*)|range_b1(*)|range_c1(*)|range_d1(*)|range_a3(*)|range_b3(*)|range_c3(*)|range_d3(*) - // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 + // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 // ---------------------------------------------------------------------------------- - // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 + // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 =0 d3 =0 =0 a3, d3 - // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 + // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 // 0 1 1 1 x mul_w *n/a* (0xb6,0xb7) a1 b1 =0 d3 c1 =0 d3, a1,b1,c1 - // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 - // 1 0 1 1 div rem (0xba,0xbb) a3 b3 d3 c3 =0 =0 a3,b3,c3,d3 + // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 + // 1 0 1 1 div rem (0xba,0xbb) a3 b3 d3 c3 =0 =0 a3,b3,c3,d3 // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 c1 d1 c1,d1 // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 d1 c1 c1 d1 a1,b1,c1,d1 - + const int OPS[14] = [0xb0, 0xb1, 0xb3, 0xb4, 0xb5, 0xb6, 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf]; col fixed OP; @@ -36,7 +36,7 @@ airtemplate ArithTable(int N = 2**6) { switch (opcode & 0xFE) { case 0xb2: // mulsuh - sa = 1; + sa = 1; case 0xb4: // mul, mulh sa = 1; sb = 1; @@ -57,14 +57,14 @@ airtemplate ArithTable(int N = 2**6) { sa = 1; sb = 1; div = 1; - m32 = 1; + m32 = 1; } // CASES: // sa = 0 sb = 0 => [a >= 0, b >= 0] // sa = 1 sb = 0 => [a >= 0, b >= 0], [a < 0, b >= 0] // sa = 1 sb = 1 => [a >= 0, b >= 0], [a < 0, b >= 0], [a >= 0, b < 0], [a < 0, b < 0] - + int cases = 1 + sa + sb + sa * sb; println("#ARITH_TABLE", opcode, index, cases); @@ -115,15 +115,15 @@ airtemplate ArithTable(int N = 2**6) { // div * (sa - sb) === 0; // (1 - div) * m32 * (1 - sa) === 0; // (1 - div) * m32 * (1 - sb) === 0; - - // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 + + // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 // ---------------------------------------------------------------------------------- - // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 + // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 =0 d3 =0 =0 a3, d3 - // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 + // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 // 0 1 1 1 x mul_w *n/a* (0xb6,0xb7) a1 b1 =0 d3 c1 =0 d3, a1,b1,c1 - // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 - // 1 0 1 1 div rem (0xba,0xbb) a3 b3 d3 c3 =0 =0 a3,b3,c3,d3 + // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 + // 1 0 1 1 div rem (0xba,0xbb) a3 b3 d3 c3 =0 =0 a3,b3,c3,d3 // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 c1 d1 c1,d1 // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 d1 c1 c1 d1 a1,b1,c1,d1 @@ -147,18 +147,18 @@ airtemplate ArithTable(int N = 2**6) { OP[index] = opcode; FLAGS_AND_RANGES[index] = m32 + 2 * div + 4 * na + 8 * nb + 16 * nr + 32 * np + 64 * na32 + 128 * nd32 + - 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1 + + 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1 + 2**16 * range_a3 + 2**18 * range_b3 + 2**20 * range_c3 + 2**22 * range_d3; - + index = index + 1; - if (index == N) break; + if (index == N) break; } - if (index == N) break; + if (index == N) break; } if (size == 0) size = index; } - println("ARITH_TABLE SIZE: ", size); + println("ARITH_TABLE SIZE: ", size); println("ARITH_FLAGS: ", FLAGS_AND_RANGES); for (index = 0; index < size; ++index) { println(FLAGS_AND_RANGES[index]); @@ -167,17 +167,14 @@ airtemplate ArithTable(int N = 2**6) { col witness multiplicity; lookup_proves(ARITH_TABLE_ID, mul: multiplicity, cols: [OP, FLAGS_AND_RANGES]); - - // REMOVE - multiplicity * (multiplicity - 1) === 0; } -function arith_table_assumes( const expr op, const expr m32, const expr div, const expr na, const expr nb, - const expr nr, const expr np, const expr na32, const expr nd32, - const expr range_a1, const expr range_b1, const expr range_c1, const expr range_d1, +function arith_table_assumes( const expr op, const expr m32, const expr div, const expr na, const expr nb, + const expr nr, const expr np, const expr na32, const expr nd32, + const expr range_a1, const expr range_b1, const expr range_c1, const expr range_d1, const expr range_a3, const expr range_b3, const expr range_c3, const expr range_d3) { // TODO: define rule for empty rows lookup_assumes(ARITH_TABLE_ID, cols: [ op, m32 + 2 * div + 4 * na + 8 * nb + 16 * nr + 32 * np + 64 * na32 + 128 * nd32 + - 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1 + + 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1 + 2**16 * range_a3 + 2**18 * range_b3 + 2**20 * range_c3 + 2**22 * range_d3]); } diff --git a/state-machines/arith/src/arith.rs b/state-machines/arith/src/arith.rs index cf4b3017..57ec63fe 100644 --- a/state-machines/arith/src/arith.rs +++ b/state-machines/arith/src/arith.rs @@ -9,8 +9,13 @@ use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; use rayon::Scope; use sm_common::{OpResult, Provable, ThreadController}; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; +use zisk_pil::{ + ARITH_AIRGROUP_ID, ARITH_AIR_IDS, ARITH_RANGE_TABLE_AIRGROUP_ID, ARITH_RANGE_TABLE_AIR_IDS, + ARITH_TABLE_AIRGROUP_ID, ARITH_TABLE_AIR_IDS, +}; -use crate::{Arith32SM, ArithFullSM, ArithMul32SM, ArithMul64SM, ArithRangeTableSM, ArithTableSM}; +// use crate::{Arith32SM, ArithFullSM, ArithMul32SM, ArithMul64SM, ArithRangeTableSM, ArithTableSM}; +use crate::{ArithFullSM, ArithRangeTableSM, ArithTableSM}; const PROVE_CHUNK_SIZE: usize = 1 << 12; @@ -24,14 +29,14 @@ pub struct ArithSM { // Inputs inputs: Mutex>, - inputs_32: Mutex>, - inputs_mul_32: Mutex>, - inputs_mul_64: Mutex>, + // inputs_32: Mutex>, + // inputs_mul_32: Mutex>, + // inputs_mul_64: Mutex>, // Secondary State machines - arith_32_sm: Arc>, - arith_mul_32_sm: Arc>, - arith_mul_64_sm: Arc>, + // arith_32_sm: Arc>, + // arith_mul_32_sm: Arc>, + // arith_mul_64_sm: Arc>, arith_full_sm: Arc>, arith_range_table_sm: Arc>, arith_table_sm: Arc>, @@ -39,71 +44,34 @@ pub struct ArithSM { impl ArithSM { pub fn new(wcm: Arc>) -> Arc { - // TODO: change this call, for calls to WitnessManager to obtain from airGroupId and airIds - // ON each SM, not need pass to the constructor - let arith_full_ids = ArithSM::::get_ids_by_name("Arith"); - let arith_32_ids = ArithSM::::get_ids_by_name("Arith32"); - let arith_mul_32_ids = ArithSM::::get_ids_by_name("ArithMul32"); - let arith_mul_64_ids = ArithSM::::get_ids_by_name("ArithMul64"); - let arith_range_table_ids = ArithSM::::get_ids_by_name("ArithRangeTable"); - let arith_table_ids = ArithSM::::get_ids_by_name("ArithTable"); - let arith_sm = Self { registered_predecessors: AtomicU32::new(0), threads_controller: Arc::new(ThreadController::new()), inputs: Mutex::new(Vec::new()), - inputs_32: Mutex::new(Vec::new()), - inputs_mul_32: Mutex::new(Vec::new()), - inputs_mul_64: Mutex::new(Vec::new()), - arith_full_sm: ArithFullSM::new(wcm.clone(), arith_full_ids.0, &[arith_full_ids.1]), - arith_32_sm: Arith32SM::new(wcm.clone(), arith_32_ids.0, &[arith_32_ids.1]), - arith_mul_32_sm: ArithMul32SM::new( - wcm.clone(), - arith_mul_32_ids.0, - &[arith_mul_32_ids.1], - ), - arith_mul_64_sm: ArithMul64SM::new( + arith_full_sm: ArithFullSM::new(wcm.clone(), ARITH_AIRGROUP_ID, ARITH_AIR_IDS), + arith_range_table_sm: ArithRangeTableSM::new( wcm.clone(), - arith_mul_64_ids.0, - &[arith_mul_64_ids.1], + ARITH_RANGE_TABLE_AIRGROUP_ID, + ARITH_RANGE_TABLE_AIR_IDS, ), - arith_range_table_sm: ArithRangeTableSM::new( + arith_table_sm: ArithTableSM::new( wcm.clone(), - arith_range_table_ids.0, - &[arith_range_table_ids.1], + ARITH_TABLE_AIRGROUP_ID, + ARITH_TABLE_AIR_IDS, ), - arith_table_sm: ArithTableSM::new(wcm.clone(), arith_table_ids.0, &[arith_table_ids.1]), }; let arith_sm = Arc::new(arith_sm); wcm.register_component(arith_sm.clone(), None, None); - arith_sm.arith_32_sm.register_predecessor(); - arith_sm.arith_mul_32_sm.register_predecessor(); - arith_sm.arith_mul_64_sm.register_predecessor(); + // arith_sm.arith_32_sm.register_predecessor(); + // arith_sm.arith_mul_32_sm.register_predecessor(); + // arith_sm.arith_mul_64_sm.register_predecessor(); arith_sm.arith_full_sm.register_predecessor(); arith_sm } - pub fn get_ids_by_name(name: &str) -> (usize, usize) { - const ARITH_AIRGROUP_ID: usize = 1; - if name == "Arith" { - return (ARITH_AIRGROUP_ID, 10); - } else if name == "Arith32" { - return (ARITH_AIRGROUP_ID, 11); - } else if name == "ArithMul64" { - return (ARITH_AIRGROUP_ID, 12); - } else if name == "ArithMul32" { - return (ARITH_AIRGROUP_ID, 13); - } else if name == "AirthRangeTable" { - return (ARITH_AIRGROUP_ID, 14); - } else if name == "ArithTable" { - return (ARITH_AIRGROUP_ID, 15); - } - return (0, 0); - } - pub fn register_predecessor(&self) { self.registered_predecessors.fetch_add(1, Ordering::SeqCst); } @@ -119,9 +87,9 @@ impl ArithSM { self.threads_controller.wait_for_threads(); - self.arith_32_sm.unregister_predecessor(scope); - self.arith_mul_32_sm.unregister_predecessor(scope); - self.arith_mul_64_sm.unregister_predecessor(scope); + // self.arith_32_sm.unregister_predecessor(scope); + // self.arith_mul_32_sm.unregister_predecessor(scope); + // self.arith_mul_64_sm.unregister_predecessor(scope); self.arith_full_sm.unregister_predecessor(scope); } } @@ -149,22 +117,22 @@ impl Provable for ArithSM { } fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { - let mut _inputs32 = Vec::new(); - let mut _inputs64 = Vec::new(); + // let mut _inputs32 = Vec::new(); + // let mut _inputs64 = Vec::new(); - let operations64 = ArithMul64SM::::operations(); - let operations32 = Arith32SM::::operations(); + // let operations64 = ArithMul64SM::::operations(); + // let operations32 = Arith32SM::::operations(); // TODO Split the operations into 32 and 64 bit operations in parallel - for operation in operations { - if operations32.contains(&operation.opcode) { - _inputs32.push(operation.clone()); - } else if operations64.contains(&operation.opcode) { - _inputs64.push(operation.clone()); - } else { - panic!("ArithSM: Operator {:x} not found", operation.opcode); - } - } + // for operation in operations { + // if operations32.contains(&operation.opcode) { + // _inputs32.push(operation.clone()); + // } else if operations64.contains(&operation.opcode) { + // _inputs64.push(operation.clone()); + // } else { + // panic!("ArithSM: Operator {:x} not found", operation.opcode); + // } + // } // TODO When drain is true, drain remaining inputs to the 3264 bits state machine /* diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs index 9fa53028..79e83fe2 100644 --- a/state-machines/arith/src/arith_full.rs +++ b/state-machines/arith/src/arith_full.rs @@ -12,7 +12,7 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; use rayon::Scope; use sm_common::{OpResult, Provable, ThreadController}; -use zisk_core::{opcode_execute, ZiskRequiredOperation}; +use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; use zisk_pil::Arith0Row; const PROVE_CHUNK_SIZE: usize = 1 << 12; @@ -88,7 +88,7 @@ impl Provable for ArithFullSM { &self, operation: ZiskRequiredOperation, ) -> Result> { - let result: OpResult = opcode_execute(operation.opcode, operation.a, operation.b); + let result: OpResult = ZiskOp::execute(operation.opcode, operation.a, operation.b); Ok(result) } diff --git a/state-machines/arith/src/arith_mul_32.rs b/state-machines/arith/src/arith_mul_32.rs index 286f375f..ba8ee057 100644 --- a/state-machines/arith/src/arith_mul_32.rs +++ b/state-machines/arith/src/arith_mul_32.rs @@ -9,8 +9,6 @@ use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; use rayon::Scope; use sm_common::{OpResult, Provable}; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; - -use p3_field::Field; use zisk_pil::{ARITH3264_AIR_IDS, ARITH_AIRGROUP_ID}; const PROVE_CHUNK_SIZE: usize = 1 << 12; diff --git a/state-machines/arith/src/arith_mul_64.rs b/state-machines/arith/src/arith_mul_64.rs index 03401fdf..a925a6c4 100644 --- a/state-machines/arith/src/arith_mul_64.rs +++ b/state-machines/arith/src/arith_mul_64.rs @@ -8,7 +8,7 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; use rayon::Scope; use sm_common::{OpResult, Provable}; -use zisk_core::{opcode_execute, ZiskRequiredOperation}; +use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; const PROVE_CHUNK_SIZE: usize = 1 << 12; @@ -74,7 +74,7 @@ impl Provable for ArithMul64SM { &self, operation: ZiskRequiredOperation, ) -> Result> { - let result: OpResult = opcode_execute(operation.opcode, operation.a, operation.b); + let result: OpResult = ZiskOp::execute(operation.opcode, operation.a, operation.b); Ok(result) } diff --git a/state-machines/arith/src/arith_range_table.rs b/state-machines/arith/src/arith_range_table.rs index d8570686..13f25fa0 100644 --- a/state-machines/arith/src/arith_range_table.rs +++ b/state-machines/arith/src/arith_range_table.rs @@ -8,8 +8,7 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; use rayon::Scope; use sm_common::{OpResult, Provable}; -use zisk_core::{opcode_execute, ZiskRequiredOperation}; - +use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; const PROVE_CHUNK_SIZE: usize = 1 << 12; pub struct ArithRangeTableSM { @@ -74,7 +73,7 @@ impl Provable for ArithRangeTableSM { &self, operation: ZiskRequiredOperation, ) -> Result> { - let result: OpResult = opcode_execute(operation.opcode, operation.a, operation.b); + let result: OpResult = ZiskOp::execute(operation.opcode, operation.a, operation.b); Ok(result) } diff --git a/state-machines/arith/src/arith_table.rs b/state-machines/arith/src/arith_table.rs index c7807b25..88fda306 100644 --- a/state-machines/arith/src/arith_table.rs +++ b/state-machines/arith/src/arith_table.rs @@ -8,7 +8,7 @@ use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; use rayon::Scope; use sm_common::{OpResult, Provable}; -use zisk_core::{opcode_execute, ZiskRequiredOperation}; +use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; const PROVE_CHUNK_SIZE: usize = 1 << 12; @@ -73,7 +73,7 @@ impl Provable for ArithTableSM { &self, operation: ZiskRequiredOperation, ) -> Result> { - let result: OpResult = opcode_execute(operation.opcode, operation.a, operation.b); + let result: OpResult = ZiskOp::execute(operation.opcode, operation.a, operation.b); Ok(result) } From 4f534bd2ba9751c89d9216bae92a2515ec736d13 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Wed, 16 Oct 2024 21:58:31 +0000 Subject: [PATCH 07/17] WIP arith: adding test of helpers --- state-machines/arith/pil/arith.pil | 2 +- state-machines/arith/pil/arith_table.pil | 8 +- state-machines/arith/src/arith.rs | 26 +- state-machines/arith/src/arith_full.rs | 43 +-- state-machines/arith/src/arith_helpers.rs | 410 ++++++++++++++++++++++ state-machines/arith/src/arith_mul_32.rs | 1 + state-machines/arith/src/lib.rs | 1 + 7 files changed, 451 insertions(+), 40 deletions(-) create mode 100644 state-machines/arith/src/arith_helpers.rs diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index 11f90d20..ab5c7891 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -249,7 +249,7 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu range_d3 * (1 - range_d3) * (2 - range_d3) === 0; - arith_table_assumes(op, m32, div, na, nb, nr, np, na32, nd32, range_a1, range_b1, range_c1, range_d1, range_a3, range_b3, range_c3, range_d3); + arith_table_assumes(op, m32, div, na, nb, np, nr, na32, nd32, range_a1, range_b1, range_c1, range_d1, range_a3, range_b3, range_c3, range_d3); arith_range_table_assumes(range_a1, a[1]); arith_range_table_assumes(range_b1, b[1]); diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil index 0aed5155..e9a48d26 100644 --- a/state-machines/arith/pil/arith_table.pil +++ b/state-machines/arith/pil/arith_table.pil @@ -72,8 +72,8 @@ airtemplate ArithTable(int N = 2**6) { for (int icase = 0; icase < cases; ++icase) { int na = 0; // a is negative int nb = 0; // b is negative - int nr = 0; // rem is negative int np = 0; // prod is negative + int nr = 0; // rem is negative int na32 = 0; // a is 32-bit negative, 31th bit is 1. int nd32 = 0; // d is 32-bit negative, 31th bit is 1. switch (icase) { @@ -146,7 +146,7 @@ airtemplate ArithTable(int N = 2**6) { int range_d3 = div * (1 - m32) * sa ? 1 + np : 0; OP[index] = opcode; - FLAGS_AND_RANGES[index] = m32 + 2 * div + 4 * na + 8 * nb + 16 * nr + 32 * np + 64 * na32 + 128 * nd32 + + FLAGS_AND_RANGES[index] = m32 + 2 * div + 4 * na + 8 * nb + 16 * np + 32 * nr + 64 * na32 + 128 * nd32 + 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1 + 2**16 * range_a3 + 2**18 * range_b3 + 2**20 * range_c3 + 2**22 * range_d3; @@ -170,11 +170,11 @@ airtemplate ArithTable(int N = 2**6) { } function arith_table_assumes( const expr op, const expr m32, const expr div, const expr na, const expr nb, - const expr nr, const expr np, const expr na32, const expr nd32, + const expr np, const expr nr, const expr na32, const expr nd32, const expr range_a1, const expr range_b1, const expr range_c1, const expr range_d1, const expr range_a3, const expr range_b3, const expr range_c3, const expr range_d3) { // TODO: define rule for empty rows - lookup_assumes(ARITH_TABLE_ID, cols: [ op, m32 + 2 * div + 4 * na + 8 * nb + 16 * nr + 32 * np + 64 * na32 + 128 * nd32 + + lookup_assumes(ARITH_TABLE_ID, cols: [ op, m32 + 2 * div + 4 * na + 8 * nb + 16 * np + 32 * nr + 64 * na32 + 128 * nd32 + 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1 + 2**16 * range_a3 + 2**18 * range_b3 + 2**20 * range_c3 + 2**22 * range_d3]); } diff --git a/state-machines/arith/src/arith.rs b/state-machines/arith/src/arith.rs index 57ec63fe..051686d4 100644 --- a/state-machines/arith/src/arith.rs +++ b/state-machines/arith/src/arith.rs @@ -38,27 +38,33 @@ pub struct ArithSM { // arith_mul_32_sm: Arc>, // arith_mul_64_sm: Arc>, arith_full_sm: Arc>, - arith_range_table_sm: Arc>, arith_table_sm: Arc>, + arith_range_table_sm: Arc>, } impl ArithSM { pub fn new(wcm: Arc>) -> Arc { + let arith_table_sm = + ArithTableSM::new(wcm.clone(), ARITH_TABLE_AIRGROUP_ID, ARITH_TABLE_AIR_IDS); + let arith_range_table_sm = ArithRangeTableSM::new( + wcm.clone(), + ARITH_RANGE_TABLE_AIRGROUP_ID, + ARITH_RANGE_TABLE_AIR_IDS, + ); + let arith_sm = Self { registered_predecessors: AtomicU32::new(0), threads_controller: Arc::new(ThreadController::new()), inputs: Mutex::new(Vec::new()), - arith_full_sm: ArithFullSM::new(wcm.clone(), ARITH_AIRGROUP_ID, ARITH_AIR_IDS), - arith_range_table_sm: ArithRangeTableSM::new( - wcm.clone(), - ARITH_RANGE_TABLE_AIRGROUP_ID, - ARITH_RANGE_TABLE_AIR_IDS, - ), - arith_table_sm: ArithTableSM::new( + arith_full_sm: ArithFullSM::new( wcm.clone(), - ARITH_TABLE_AIRGROUP_ID, - ARITH_TABLE_AIR_IDS, + arith_table_sm.clone(), + arith_range_table_sm.clone(), + ARITH_AIRGROUP_ID, + ARITH_AIR_IDS, ), + arith_table_sm, + arith_range_table_sm, }; let arith_sm = Arc::new(arith_sm); diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs index 79e83fe2..a6995893 100644 --- a/state-machines/arith/src/arith_full.rs +++ b/state-machines/arith/src/arith_full.rs @@ -22,21 +22,29 @@ pub struct ArithFullSM { registered_predecessors: AtomicU32, // Thread controller to manage the execution of the state machines - // threads_controller: Arc, + threads_controller: Arc, // Inputs inputs: Mutex>, - - _phantom: std::marker::PhantomData, + arith_table_sm: Arc>, + arith_range_table_sm: Arc>, } impl ArithFullSM { - pub fn new(wcm: Arc>, airgroup_id: usize, air_ids: &[usize]) -> Arc { + const MY_NAME: &'static str = "Arith "; + pub fn new( + wcm: Arc>, + arith_table_sm: Arc>, + arith_range_table_sm: Arc>, + airgroup_id: usize, + air_ids: &[usize], + ) -> Arc { let arith_full_sm = Self { registered_predecessors: AtomicU32::new(0), + threads_controller: Arc::new(ThreadController::new()), inputs: Mutex::new(Vec::new()), - _phantom: std::marker::PhantomData, - //threads_controller: Arc::new(ThreadController::new()), + arith_table_sm, + arith_range_table_sm, }; let arith_full_sm = Arc::new(arith_full_sm); @@ -57,6 +65,10 @@ impl ArithFullSM { true, scope, ); + self.threads_controller.wait_for_threads(); + + self.arith_table_sm.unregister_predecessor(scope); + self.arith_range_table_sm.unregister_predecessor(scope); } } pub fn process_slice( @@ -84,14 +96,6 @@ impl WitnessComponent for ArithFullSM { } impl Provable for ArithFullSM { - fn calculate( - &self, - operation: ZiskRequiredOperation, - ) -> Result> { - let result: OpResult = ZiskOp::execute(operation.opcode, operation.a, operation.b); - Ok(result) - } - fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { if let Ok(mut inputs) = self.inputs.lock() { inputs.extend_from_slice(operations); @@ -121,15 +125,4 @@ impl Provable for ArithFullSM { } } } - - fn calculate_prove( - &self, - operation: ZiskRequiredOperation, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - self.prove(&[operation], drain, scope); - result - } } diff --git a/state-machines/arith/src/arith_helpers.rs b/state-machines/arith/src/arith_helpers.rs new file mode 100644 index 00000000..3f7f9635 --- /dev/null +++ b/state-machines/arith/src/arith_helpers.rs @@ -0,0 +1,410 @@ +const MULU: u8 = 0xb0; +const MULUH: u8 = 0xb1; +const MULSUH: u8 = 0xb3; +const MUL: u8 = 0xb4; +const MULH: u8 = 0xb5; +const MUL_W: u8 = 0xb6; +const DIVU: u8 = 0xb8; +const REMU: u8 = 0xb9; +const DIV: u8 = 0xba; +const REM: u8 = 0xbb; +const DIVU_W: u8 = 0xbc; +const REMU_W: u8 = 0xbd; +const DIV_W: u8 = 0xbe; +const REM_W: u8 = 0xbf; + +pub trait ArithHelpers { + fn calculate_flags_and_ranges( + a: u64, + b: u64, + op: u8, + div: &mut u64, + m32: &mut u64, + na: &mut u64, + nb: &mut u64, + nr: &mut u64, + np: &mut u64, + na32: &mut u64, + nd32: &mut u64, + ) -> [u64; 8] { + let mut range_a1: u64 = 0; + let mut range_b1: u64 = 0; + let mut range_c1: u64 = 0; + let mut range_d1: u64 = 0; + let mut range_a3: u64 = 0; + let mut range_b3: u64 = 0; + let mut range_c3: u64 = 0; + let mut range_d3: u64 = 0; + + // direct table opcode(14), signed 2 or 4 cases (0,na,nb,na+nb) + // 6 * 1 + 7 * 4 + 1 * 2 = 36 entries, + // no compacted => 16 x 4 = 64, key = (op - 0xb0) * 4 + na * 2 + nb + // output: div, m32, sa, sb, nr, np, na, na32, nd32, range x 2 x 4 + + // alternative: switch operation, + + let mut sa: u64 = 0; + let mut sb: u64 = 0; + + match op { + MULU | MULUH => {} + MULSUH => { + sa = 1; + } + MUL | MULH => { + sa = 1; + sb = 1; + } + MUL_W => { + *m32 = 1; + sa = 1; + sb = 1; + } + DIVU | REMU => { + *div = 1; + } + DIV | REM => { + sa = 1; + sb = 1; + *div = 1; + } + DIVU_W | REMU_W => { + // divu_w, remu_w + *div = 1; + *m32 = 1; + } + DIV_W | REM_W => { + // div_w, rem_w + sa = 1; + sb = 1; + *div = 1; + *m32 = 1; + } + _ => { + panic!("Invalid opcode"); + } + } + *na = if sa == 1 && (a as i64) < 0 { 1 } else { 0 }; + *nb = if sb == 1 && (b as i64) < 0 { 1 } else { 0 }; + *np = *na ^ *nb; + *nr = if *div == 1 { *na } else { 0 }; + *na32 = if *m32 == 1 { *na } else { 0 }; + *nd32 = if *m32 == 1 { *nr } else { 0 }; + + if *m32 == 1 { + range_a1 = sa + *na; + range_b1 = sb + *nb; + + if *div == 1 { + range_c1 = if *np == 1 || *na32 == 1 { 2 } else { 1 }; + range_d1 = if (*np == 1 && sa == 1) || *nd32 == 1 { 1 } else { 2 }; + } else { + range_c1 = 1 + *na32; + } + } else { + // m32 = 0 + range_b3 = if sb == 1 { 1 + *na } else { 0 }; + if sa == 1 { + // !m32 && sa + range_a3 = 1 + *na; + if *div == 1 { + // !m32 && sa && div + range_c3 = 1 + *np; + range_d3 = range_c3; + } + } + } + + [range_a1, range_b1, range_c1, range_d1, range_a3, range_b3, range_c3, range_d3] + } + /* + fn calculate_flags( + &self, + op: u8, + a: u64, + b: u64, + na: &mut i64, + nb: &mut i64, + nr: &mut i64, + np: &mut i64, + na32: &mut i64, + nd32: &mut i64, + m32: &mut i64, + div: &mut i64, + fab: &mut i64, + ) -> [u64; 8] { + let MUL_W = 1; + match (op) { + MUL_W=> { + let na = if (a as i32) < 0 { 1 } else { 0 }; + let nb = if (b as i32) < 0 { 1 } else { 0 }; + let c = (a as i32 * b as i32); + let nc = if c < 0 { 1 } else { 0 }; + } + MULSUH => { + let na = if (a as i64) < 0 { 1 } else { 0 }; + let _na = input.a & (2n**63n) ? 1n : 0n; + let _a = _na ? 2n ** 64n - a : a; + let _prod = _a * b; + let _nc = _prod && _na; + + _prod = _nc ? 2n**128n - _prod : _prod; + c = _prod & (2n**64n - 1n); + d = _prod >> 64n; + // console.log(input.c.toString(16), c.toString(16)); + break; + } + case 'divu': + case 'divu_w': { + this.log(opdef.n,a,b); + const div = a / b; + const rem = a % b; + c = a; + a = div; + d = rem; + break; + } + case 'div': { + this.log('div',a,b); + let _na = input.a & (2n**63n) ? 1n : 0n; + let _a = _na ? 2n ** 64n - a : a; + let _nb = input.b & (2n**63n) ? 1n : 0n; + let _b = _nb ? 2n ** 64n - b : b; + const div = _a / _b; + const rem = _a % _b; + c = a; + a = (div && _na ^ _nb) ? 2n**64n - div : div; + d = (rem && _na) ? 2n**64n - rem : rem; + break; + } + case 'div_w': { + this.log('div_w',a,b); + let _na = input.a & (2n**31n) ? 1n : 0n; + let _a = _na ? 2n ** 32n - a : a; + let _nb = input.b & (2n**31n) ? 1n : 0n; + let _b = _nb ? 2n ** 32n - b : b; + this.log([_a,_b].map(x => x.toString(16)).join(' ')); + const div = _a / _b; + const rem = _a % _b; + this.log(div, rem, _na, _nb) + c = a; + a = (div && (_na ^ _nb)) ? 2n**32n - div : div; + d = (rem && _na) ? 2n**32n - rem : rem; + this.log('[a,b,c,d]='+[a,b,c,d].map(x => x.toString(16)).join(' ')); + break; + } + } + if (m32) { + this.log(opdef.a_signed, opdef.b_signed, a.toString(16), (a & 0x80000000n).toString(16)); + a = (opdef.a_signed && a & 0x80000000n) ? a | 0xFFFFFFFF00000000n : a; + b = (opdef.b_signed && b & 0x80000000n) ? b | 0xFFFFFFFF00000000n : b; + } + + return [a,b,c,d]; + [0, 0, 0, 0, 0, 0, 0, 0] + } */ + fn calculate_chunks( + &self, + a: [i64; 4], + b: [i64; 4], + c: [i64; 4], + d: [i64; 4], + div: i64, + fab: i64, + na: i64, + nb: i64, + np: i64, + nr: i64, + m32: i64, + ) -> [i64; 8] { + // TODO: unroll this function in variants (div,m32) and (na,nb,nr,np) + // div, m32, na, nb === f(div,m32,na,nb) => fa, nb, nr + // unroll means 16 variants ==> but more performance + + let mut chunks: [i64; 8] = [0, 0, 0, 0, 0, 0, 0, 0]; + + chunks[0] = fab * a[0] * b[0] // chunk9 + - c[0] + + 2 * np * c[0] + + div * d[0] + - 2 * nr * d[0]; + + chunks[1] = fab * a[1] * b[0] // chunk1 + + fab * a[0] * b[1] + - c[1] + + 2 * np * c[1] + + div * d[1] + - 2 * nr * d[1]; + + chunks[2] = fab * a[2] * b[0] // chunk2 + + fab * a[1] * b[1] + + fab * a[0] * b[2] + - c[2] + + (2 * np) * c[2] + + div * d[2] + - 2 * nr * d[2] + - np * div * m32 + + nr * m32; + + chunks[3] = fab * a[3] * b[0] // chunk3 + + fab * a[2] * b[1] + + fab * a[1] * b[2] + + fab * a[0] * b[3] + - c[3] + + 2 * np * c[3] + + div * d[3] + - 2 * nr * d[3]; + + chunks[4] = fab * a[3] * b[1] // chunk4 + + fab * a[2] * b[2] + + fab * a[1] * b[3] + + b[0] * na * (1 - 2 * nb) + + a[0] * nb * (1 - 2 * na) + - np * div + + m32 + - 2 * div * m32 + + nr * (1 - m32) + - d[0] * (1 - div) + + d[0] * 2 * np * (1 - div); + + chunks[5] = fab * a[3] * b[2] // chunk5 + + fab * a[2] * b[3] + + a[1] * nb * (1 - 2 * na) + + b[1] * na * (1 - 2 * nb) + - d[1] * (1 - div) + + d[1] * 2 * np * (1 - div); + + chunks[6] = fab as i64 * a[3] * b[3] // chunk6 + + a[2] * nb * (1 - 2 * na) + + b[2] * na * (1 - 2 * nb) + - d[2] * (1 - div) + + d[2] * 2 * np * (1 - div); + + chunks[7] = 0x10000 * na * nb // chunk7 + + b[3] * na * (1 - 2 * nb) + + a[3] * nb * (1 - 2 * na) + - 0x10000 * np * (1 - div) * (1 - m32) + - d[3] * (1 - div) + + d[3] * 2 * np * (1 - div); + + chunks + } + fn me() -> i32 { + 13 + } +} + +#[test] +fn test_calculate_range_checks() { + struct TestArithHelpers {} + impl ArithHelpers for TestArithHelpers {} + + const MIN_N_64: u64 = 0x8000_0000_0000_0000; + const MAX_P_64: u64 = 0x7FFF_FFFF_FFFF_FFFF; + const MAX_64: u64 = 0xFFFF_FFFF_FFFF_FFFF; + + const ALL: u64 = 0x0033; + const ALL_P_64: u64 = 0x0034; + const ALL_N_64: u64 = 0x0035; + + const END: u64 = 0x0036; + const ALL_P_64_VALUES: [u64; 5] = [0, 1, MAX_P_64, END, 0]; + const ALL_N_64_VALUES: [u64; 5] = [MIN_N_64, MAX_64, END, 0, 0]; + const ALL_64_VALUES: [u64; 5] = [0, 1, MAX_P_64, MAX_64, MIN_N_64]; + + const F_M32: u64 = 0x0001; + const F_DIV: u64 = 0x0002; + const F_NA: u64 = 0x0004; + const F_NB: u64 = 0x0008; + const F_NP: u64 = 0x0010; + const F_NR: u64 = 0x0020; + const F_NA32: u64 = 0x0040; + const F_ND32: u64 = 0x0080; + + struct TestParams { + op: u8, + a: u64, + b: u64, + flags: u64, + } + + // NOTE: update TEST_COUNT with number of tests, ALL,ALL => 3*3 = 9 + const TEST_COUNT: u32 = 20; + + let tests = [ + // flags: div, m32, sa, sb, na, nr, np, np, na32, nd32 + TestParams { op: MULU, a: ALL, b: ALL, flags: 0 }, + TestParams { op: MULUH, a: ALL, b: ALL, flags: 0 }, + TestParams { op: MULSUH, a: ALL_P_64, b: ALL, flags: 0 }, + TestParams { op: MULSUH, a: ALL_N_64, b: ALL, flags: F_NA + F_NP }, + TestParams { op: MUL_W, a: ALL_P_64, b: ALL_P_64, flags: F_M32 }, + TestParams { op: MUL_W, a: ALL_N_64, b: ALL_P_64, flags: F_M32 + F_NA + F_NP }, + TestParams { op: MUL_W, a: ALL_P_64, b: ALL_N_64, flags: F_M32 + F_NB + F_NP }, + TestParams { op: MUL_W, a: ALL_N_64, b: ALL_N_64, flags: F_M32 + F_NA + F_NB }, + TestParams { op: DIV, a: 0, b: 0, flags: F_DIV }, + TestParams { op: DIV, a: MIN_N_64, b: MAX_P_64, flags: F_DIV + F_NA + F_NP + F_NR }, + ]; + + let mut count = 0; + let mut index: u32 = 0; + for test in tests { + let a_values = if test.a == ALL { + ALL_64_VALUES + } else if test.a == ALL_N_64 { + ALL_N_64_VALUES + } else if test.a == ALL_P_64 { + ALL_P_64_VALUES + } else { + [test.a, END, 0, 0, 0] + }; + for a in a_values { + if a == END { + break; + }; + let b_values = if test.b == ALL { + ALL_64_VALUES + } else if test.b == ALL_N_64 { + ALL_N_64_VALUES + } else if test.b == ALL_P_64 { + ALL_P_64_VALUES + } else { + [test.b, END, 0, 0, 0] + }; + for b in b_values { + if b == END { + break; + }; + let mut div: u64 = 0; + let mut m32: u64 = 0; + let mut na: u64 = 0; + let mut nb: u64 = 0; + let mut nr: u64 = 0; + let mut np: u64 = 0; + let mut na32: u64 = 0; + let mut nd32: u64 = 0; + + TestArithHelpers::calculate_flags_and_ranges( + a, b, test.op, &mut div, &mut m32, &mut na, &mut nb, &mut nr, &mut np, + &mut na32, &mut nd32, + ); + let flags = + m32 + div * 2 + na * 4 + nb * 8 + np * 16 + nr * 32 + na32 * 64 + nd32 * 128; + + assert_eq!( + flags, + test.flags, + "testing #{} op:0x{:x} with a:0x{:X} b:0x{:X} flags:{:b} vs {:b} [div, m32, sa, sb, na, nb, np, nr, na32, nd32]", + index, + test.op, + a, + b, + flags, + test.flags, + ); + count += 1; + } + } + index += 1; + } + assert_eq!(count, TEST_COUNT, "Number of tests not matching"); +} diff --git a/state-machines/arith/src/arith_mul_32.rs b/state-machines/arith/src/arith_mul_32.rs index ba8ee057..4d883440 100644 --- a/state-machines/arith/src/arith_mul_32.rs +++ b/state-machines/arith/src/arith_mul_32.rs @@ -10,6 +10,7 @@ use rayon::Scope; use sm_common::{OpResult, Provable}; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; use zisk_pil::{ARITH3264_AIR_IDS, ARITH_AIRGROUP_ID}; + const PROVE_CHUNK_SIZE: usize = 1 << 12; pub struct ArithMul32SM { diff --git a/state-machines/arith/src/lib.rs b/state-machines/arith/src/lib.rs index 2eaeece8..4e3603bd 100644 --- a/state-machines/arith/src/lib.rs +++ b/state-machines/arith/src/lib.rs @@ -1,6 +1,7 @@ mod arith; mod arith_32; mod arith_full; +mod arith_helpers; mod arith_mul_32; mod arith_mul_64; mod arith_range_table; From 20f6f52d0507dd9e7dba24889cf67198e139efbc Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Mon, 21 Oct 2024 22:03:18 +0000 Subject: [PATCH 08/17] WIP arith --- .gitignore | 3 +- pil/zisk.pil | 4 +- state-machines/arith/pil/arith.pil | 266 ++++---- .../arith/pil/arith_range_table.pil | 8 +- state-machines/arith/pil/arith_table.pil | 14 +- state-machines/arith/src/arith_helpers.rs | 592 +++++++++++++++--- .../binary/pil/binary_extension.pil | 7 +- 7 files changed, 633 insertions(+), 261 deletions(-) diff --git a/.gitignore b/.gitignore index 52498727..63d76060 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ /build /proofs *.pilout -/tmp \ No newline at end of file +/tmp +*.log \ No newline at end of file diff --git a/pil/zisk.pil b/pil/zisk.pil index 5cb4f363..e76efaa8 100644 --- a/pil/zisk.pil +++ b/pil/zisk.pil @@ -28,7 +28,7 @@ airgroup ArithTable { airgroup ArithRangeTable { ArithRangeTable(); } - +/* airgroup Binary { Binary(N: 2**21, operation_bus_id: OPERATION_BUS_ID); } @@ -43,4 +43,4 @@ airgroup BinaryExtension { airgroup BinaryExtensionTable { BinaryExtensionTable(disable_fixed: 0); -} \ No newline at end of file +}*/ \ No newline at end of file diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index ab5c7891..cdd83cdf 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -4,16 +4,11 @@ require "operations.pil" require "arith_table.pil" require "arith_range_table.pil" -// generic 64 u64 mul_u64 32 *u32 -// witness 45 41 30 26 27 13 -// lookups 3 3 3 2 3 3 -// range_checks 16+7 16+7 16+7 16+7 8+3 7+2 -// ---------------------------------------------------------- -// TOTAL 123 119 108 101 69 61 -// -// (*) unsigned 32 bit operations only divu_w, remu_w +// full mul_64 full_32 mul_32 +// TOTAL 88 77 57 44 airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_result = 0) { + // TODO: const int enable_div = 1, const int enable_32_bits = 1, const int enable_64_bits = 1 // NOTE: // Divisions and remainders by 0 are done by QuickOps @@ -32,8 +27,7 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu col witness nb; // b is negative col witness nr; // rem is negative col witness np; // prod is negative - col witness na32; // a is 32-bit negative, 31th bit is 1. - col witness nd32; // d is 32-bit negative, 31th bit is 1. + col witness sext; // sign extend for 32 bits result col witness m32; // 32 bits operation col witness div; // division operation (div,rem) @@ -44,81 +38,81 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu col witness debug_main_step; // only for debug - if (!dual_result) { - col witness air.secondary_res; // op_index: 0 => first result, 1 => second result; - secondary_res * (secondary_res - 1) === 0; - } else { - const expr air.secondary_res = 0; - } + col witness secondary_res; // op_index: 0 => first result, 1 => second result; + secondary_res * (secondary_res - 1) === 0; // factor ab € {-1, 1} fab === 1 - 2 * na - 2 * nb + 4 * na * nb; const expr eq[CHUNKS_OP]; - eq[0] = fab * a[0] * b[0] + // NOTE: Equations with m32 for multiplication not exists, because mul m32 it's an unsigned operation. + // In internal equations, it's same than unsigned mul 64 where high part of a and b are zero + + eq[0] = fab * a[0] * b[0] // 3 degree - c[0] + 2 * np * c[0] + div * d[0] - 2 * nr * d[0]; - eq[1] = fab * a[1] * b[0] - + fab * a[0] * b[1] + eq[1] = fab * a[1] * b[0] // 3 degree + + fab * a[0] * b[1] // 3 degree - c[1] + 2 * np * c[1] + div * d[1] - 2 * nr * d[1]; - eq[2] = fab * a[2] * b[0] - + fab * a[1] * b[1] - + fab * a[0] * b[2] + eq[2] = fab * a[2] * b[0] // 3 degree + + fab * a[1] * b[1] // 3 degree + + fab * a[0] * b[2] // 3 degree - c[2] + 2 * np * c[2] + div * d[2] - 2 * nr * d[2] - - np * div * m32 - + nr * m32; + - np * div * m32 // 3 degree + + nr * div * m32; // 3 degree - eq[3] = fab * a[3] * b[0] - + fab * a[2] * b[1] - + fab * a[1] * b[2] - + fab * a[0] * b[3] + eq[3] = fab * a[3] * b[0] // 3 degree + + fab * a[2] * b[1] // 3 degree + + fab * a[1] * b[2] // 3 degree + + fab * a[0] * b[3] // 3 degree - c[3] + 2 * np * c[3] + div * d[3] - 2 * nr * d[3]; - eq[4] = fab * a[3] * b[1] - + fab * a[2] * b[2] - + fab * a[1] * b[3] - + na * b[0] * (1 - 2 * nb) - + nb * a[0] * (1 - 2 * na) - - np * div // \ - + np * m32 // np * (div ^ m32) - - 2 * div * m32 * np // / - + nr * (1 - m32) + eq[4] = fab * a[3] * b[1] // 3 degree + + fab * a[2] * b[2] // 3 degree + + fab * a[1] * b[3] // 3 degree + + na * b[0] * (1 - 2 * nb) // 3 degree + + nb * a[0] * (1 - 2 * na) // 3 degree + - np * div // | + + np * div * m32 // 3 degree | np * (div ^ m32) + - 2 * div * m32 * np // 3 degree | + // + nr * (1 - m32) * div // 3 degree - d[0] * (1 - div) - + 2 * np * d[0] * (1 - div); + + 2 * np * d[0] * (1 - div); // 3 degree - eq[5] = fab * a[3] * b[2] - + fab * a[2] * b[3] + eq[5] = fab * a[3] * b[2] // 3 degree + + fab * a[2] * b[3] // 3 degree + nb * a[1] * (1 - 2 * na) + na * b[1] * (1 - 2 * nb) - d[1] * (1 - div) + 2 * np * d[1] * (1 - div); - eq[6] = fab * a[3] * b[3] - + nb * a[2] * (1 - 2 * na) - + na * b[2] * (1 - 2 * nb) + eq[6] = fab * a[3] * b[3] // 3 degree + + nb * a[2] * (1 - 2 * na) // 3 degree + + na * b[2] * (1 - 2 * nb) // 3 degree - d[2] * (1 - div) - + 2 * np * d[2] * (1 - div); + + 2 * np * d[2] * (1 - div); // 3 degree eq[7] = CHUNK_SIZE * na * nb - + na * b[3] * (1 - 2 * nb) - + nb * a[3] * (1 - 2 * na) - - CHUNK_SIZE * np * (1 - div) * (1 - m32) + + na * b[3] * (1 - 2 * nb) // 3 degree + + nb * a[3] * (1 - 2 * na) // 3 degree + // - CHUNK_SIZE * np * (1 - div) * (1 - m32) // 3 degree + - CHUNK_SIZE * np * (1 - div) - d[3] * (1 - div) - + 2 * np * d[3] * (1 - div); + + 2 * np * d[3] * (1 - div); // 3 degree eq[0] - carry[0] * CHUNK_SIZE === 0; for (int index = 1; index < (CHUNKS_OP - 1); ++index) { @@ -133,72 +127,53 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu nb * (1 - nb) === 0; nr * (1 - nr) === 0; np * (1 - np) === 0; - na32 * (1 - na32) === 0; - nd32 * (1 - nd32) === 0; + sext * (1 - sext) === 0; col witness op; - // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 - // ---------------------------------------------------------------------------------- - // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 - // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 =0 d3 =0 =0 a3, d3 - // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 - // 0 1 1 1 x mul_w *n/a* (0xb6,0xb7) a1 b1 =0 d3 c1 =0 d3, a1,b1,c1 - // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 - // 1 0 1 1 div rem (0xba,0xbb) a3 b3 d3 c3 =0 =0 a3,b3,c3,d3 - // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 c1 d1 c1,d1 - // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 d1 c1 c1 d1 a1,b1,c1,d1 + // div m32 sa sb primary secondary opcodes na nb np nr sext + // ----------------------------------------------------------------------------- + // 0 0 0 0 mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 + // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 d3 =0 =0 =0 a3, d3 + // 0 0 1 1 mul mulh (0xb4,0xb5) a3 b3 d3 =0 =0 =0 a3,b3, d3 + // 0 1 0 0 mul_w *n/a* (0xb6,0xb7) =0 =0 =0 =0 c1 =0 a1,b1,c1 + // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 + // 1 0 1 1 div rem (0xba,0xbb) a3 b3 c3 d3 =0 =0 a3,b3,c3,d3 + // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 c1 d1 c1,d1 + // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 c1 d1 c1 d1 a1,b1,c1,d1 // (*) removed combinations of flags div,m32,sa,sb did allow combinations div, m32, sa, sb // see 5 previous constraints. // =0 means forced to zero by previous constraints - // comm = commutative (trivial: commutative operations) + + // bus result mul div + // -------------------------------- + // primary c a + // secondary d d col witness bus_a_low; - bus_a_low === div * (c[0] - a[0]) - + a[0] - + CHUNK_SIZE * div * (c[1] - a[1]) - + CHUNK_SIZE * a[1]; + bus_a_low === div * (c[0] + c[1] * CHUNK_SIZE) + (1 - div) * (a[0] + a[1] * CHUNK_SIZE); col witness bus_a_high; - bus_a_high === (1 - m32) * (div * (c[2] - a[2]) - + a[2] - + CHUNK_SIZE * div * (c[3] - a[3]) - + CHUNK_SIZE * a[3]); + bus_a_high === div * (c[2] + c[2] * CHUNK_SIZE) + (1 - div) * (a[2] + a[3] * CHUNK_SIZE); + m32 * (1 - bus_a_high) === 0; const expr bus_b_low = b[0] + CHUNK_SIZE * b[1]; + const expr bus_b_high = b[2] + CHUNK_SIZE * b[3]; - // TODO: na32 and nd32 only valid on 32 bit operations - // TODO: m32 === 0 ==> b[2],a[2],b[3],a[3] === 0 avoid two witness - col witness bus_b_high; - bus_b_high === (1 - m32) * b[2] + (1 - m32) * CHUNK_SIZE * b[3]; + m32 * (1 - b[2]) === 0; + m32 * (1 - b[3]) === 0; const expr res2_low = d[0] + CHUNK_SIZE * d[1]; - const expr res2_high = d[2] + CHUNK_SIZE * d[3] + nd32 * 0xFFFFFFFF; - - if (dual_result) { - // theorical cost: 4 columns - col witness multiplicity_2; - lookup_proves(operation_bus_id, [debug_main_step, op+1, bus_a_low, bus_a_high, bus_b_low, bus_b_high, res2_low, res2_high, 0], mul: multiplicity_2); - } + const expr res2_high = d[2] + CHUNK_SIZE * d[3]; - if (dual_result) { - const expr air.res1_low = a[0] + c[0] + CHUNK_SIZE * a[1] + CHUNK_SIZE * c[1] - bus_a_low; - col witness air.res1_high; - res1_high === (1 - m32) * (div * (a[2] - c[2]) + c[2] + CHUNK_SIZE * div * (a[3] - c[3]) + CHUNK_SIZE * c[3]) + div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF; - } else { - col witness air.res1_low; - res1_low === secondary_res * res2_low - (1 - secondary_res) * (a[0] + c[0] + CHUNK_SIZE * a[1] + CHUNK_SIZE * c[1] - bus_a_low); - - col witness air.div64; - div64 === (1 - m32) * div; - - col witness air.res1_high; - // res1_high === secondary_res * res2_high + (1 - secondary_res) * ((1 - m32) * (div * (a[2] - c[2]) + c[2] + 2**16 * div * (a[3] - c[3]) + 2**16 * c[3]) + div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF); - res1_high === secondary_res * res2_high + (1 - secondary_res) * (div64 * (a[2] - c[2]) + (1 - m32) * c[2] + CHUNK_SIZE * div64 * (a[3] - c[3]) + (1 - m32) * 2**16 * c[3] + div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF); - } + col witness res_low; + res_low === secondary_res * res2_low + (1 - secondary_res) * (a[0] + c[0] + CHUNK_SIZE * (a[1] + c[1]) - bus_a_low); + col witness res_high; + res_high === (1 - m32) * (secondary_res * res2_high + (1 - secondary_res) * (a[2] + c[2] + CHUNK_SIZE * (a[3] + c[3]) - bus_a_high)) + + sext * 0xFFFFFFFF; col witness multiplicity; @@ -206,9 +181,7 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu op + secondary_res, bus_a_low, bus_a_high, bus_b_low, bus_b_high, - res1_low, res1_high, -// secondary_res * (res2_low - res1_low) + res1_low, -// secondary_res * (res2_high - res1_high) + res1_high, + res_low, res_high, 0], mul: multiplicity); @@ -219,50 +192,63 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu range_check(colu: carry[index], min:-2**20, max: 2**20-1); // TODO: review carry range } + // mul a * b = c + d * 2^64 + // div a * b + d = c (a <=> c) + + // range_ab / range_cd + // + // a3 a1 b3 b1 + // rid c3 c1 d3 d1 range 2^16 2^15 notes + // --- -- -- -- -- ----- ---- ---- ------------------------- + // 0 F F F F ab cd 4 0 + // 1 F F + F cd 3 1 b3 sign => a3 sign + // 2 F F - F cd 3 1 b3 sign => a3 sign + // 3 + F F F ab 3 1 c3 sign => d3 sign + // 4 + F + F ab cd 2 2 + // 5 + F - F ab cd 2 2 + // 6 - F F F ab 3 1 c3 sign => d3 sign + // 7 - F + F ab cd 2 2 + // 8 - F - F ab cd 2 2 + // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign + // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign + // 11 F + F F cd 3 1 a1 sign <=> b1 sign + // 12 F + F + ab cd 2 2 + // 13 F + F - ab cd 2 2 + // 14 F - F F cd 3 1 a1 sign <=> b1 sign + // 15 F - F + ab cd 2 2 + // 16 F - F - ab cd 2 2 + // ---- ---- + // 38 22 = 60 + // + // F: [0..2^16-1] +:[0..2^15-1] -:[2^15..2^16-1] + // + // 22 * 2^15 + 38 * 2^16 = (11+38) * 2^16 = 49 * 2^16 < 2^6 * 2^16 ==> 22 bits + + col witness range_ab; + col witness range_cd; + + arith_table_assumes(op, m32, div, na, nb, np, nr, sext, range_ab, range_cd); + + // 0 - a1/c1 + // 1 - b1/d1 + // 2 - a3/c3 + // 3 - b3/d3 + + arith_range_table_assumes(range_ab, 0, a[1]); + arith_range_table_assumes(range_ab, 1, b[1]); + arith_range_table_assumes(range_cd, 0, c[1]); + arith_range_table_assumes(range_cd, 1, d[1]); + arith_range_table_assumes(range_ab, 2, a[3]); + arith_range_table_assumes(range_ab, 3, b[3]); + arith_range_table_assumes(range_cd, 2, c[3]); + arith_range_table_assumes(range_cd, 3, d[3]); + // loop for range checks index 0, 2 for (int index = 0; index < 2; ++index) { - range_check(colu: a[2 * index], min:0, max: CHUNK_SIZE - 1); - range_check(colu: b[2 * index], min:0, max: CHUNK_SIZE - 1); - range_check(colu: c[2 * index], min:0, max: CHUNK_SIZE - 1); - range_check(colu: d[2 * index], min:0, max: CHUNK_SIZE - 1); + arith_range_table_assumes(0, 0, a[2 * index]); + arith_range_table_assumes(0, 0, b[2 * index]); + arith_range_table_assumes(0, 0, c[2 * index]); + arith_range_table_assumes(0, 0, d[2 * index]); } - col witness range_a1; - col witness range_b1; - col witness range_c1; - col witness range_d1; - - col witness range_a3; - col witness range_b3; - col witness range_c3; - col witness range_d3; - - // verify values of range_xy € {0,1,2} => these constraints not generate - // intermediate columns - range_a1 * (1 - range_a1) * (2 - range_a1) === 0; - range_b1 * (1 - range_b1) * (2 - range_b1) === 0; - range_c1 * (1 - range_c1) * (2 - range_c1) === 0; - range_d1 * (1 - range_d1) * (2 - range_d1) === 0; - range_a3 * (1 - range_a3) * (2 - range_a3) === 0; - range_b3 * (1 - range_b3) * (2 - range_b3) === 0; - range_c3 * (1 - range_c3) * (2 - range_c3) === 0; - range_d3 * (1 - range_d3) * (2 - range_d3) === 0; - - - arith_table_assumes(op, m32, div, na, nb, np, nr, na32, nd32, range_a1, range_b1, range_c1, range_d1, range_a3, range_b3, range_c3, range_d3); - - arith_range_table_assumes(range_a1, a[1]); - arith_range_table_assumes(range_b1, b[1]); - arith_range_table_assumes(range_c1, c[1]); - arith_range_table_assumes(range_d1, d[1]); - arith_range_table_assumes(range_a3, a[3]); - arith_range_table_assumes(range_b3, b[3]); - arith_range_table_assumes(range_c3, c[3]); - arith_range_table_assumes(range_d3, d[3]); } - -function instance_arith(const int rows = 2**21, const int bus_id) { - Arith(rows, operation_bus_id: bus_id); - ArithTable(); - ArithRangeTable(); -} \ No newline at end of file diff --git a/state-machines/arith/pil/arith_range_table.pil b/state-machines/arith/pil/arith_range_table.pil index 39c15fe6..7a03c0f9 100644 --- a/state-machines/arith/pil/arith_range_table.pil +++ b/state-machines/arith/pil/arith_range_table.pil @@ -5,15 +5,17 @@ const int ARITH_RANGE_TABLE_ID = 330; airtemplate ArithRangeTable(int N = 2**17) { + // TODO: update values col fixed RANGES = [0:2**16,1:2**15,2:2**15]; + col fixed INDEX = [0..2]...; col fixed VALUES = [0..2**16-1]...; col witness multiplicity; - lookup_proves(ARITH_TABLE_ID, [RANGES, VALUES], multiplicity); + lookup_proves(ARITH_RANGE_TABLE_ID, [RANGES, INDEX, VALUES], multiplicity); } -function arith_range_table_assumes(const expr range_type, const expr value) { +function arith_range_table_assumes(const expr range_type, const int index, const expr value) { // TODO: define rule for empty rows - lookup_assumes(ARITH_RANGE_TABLE_ID, [range_type, value]); + lookup_assumes(ARITH_RANGE_TABLE_ID, [range_type, index, value]); } diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil index e9a48d26..95d3032c 100644 --- a/state-machines/arith/pil/arith_table.pil +++ b/state-machines/arith/pil/arith_table.pil @@ -1,6 +1,6 @@ require "std_lookup.pil" -const int ARITH_TABLE_ID = 330; +const int ARITH_TABLE_ID = 331; airtemplate ArithTable(int N = 2**6) { @@ -166,15 +166,13 @@ airtemplate ArithTable(int N = 2**6) { col witness multiplicity; - lookup_proves(ARITH_TABLE_ID, mul: multiplicity, cols: [OP, FLAGS_AND_RANGES]); + // TODO: + lookup_proves(ARITH_TABLE_ID, mul: multiplicity, cols: [OP, FLAGS_AND_RANGES, 0, 0]); } function arith_table_assumes( const expr op, const expr m32, const expr div, const expr na, const expr nb, - const expr np, const expr nr, const expr na32, const expr nd32, - const expr range_a1, const expr range_b1, const expr range_c1, const expr range_d1, - const expr range_a3, const expr range_b3, const expr range_c3, const expr range_d3) { + const expr np, const expr nr, const expr sext, const expr range_ab, const expr range_cd) { // TODO: define rule for empty rows - lookup_assumes(ARITH_TABLE_ID, cols: [ op, m32 + 2 * div + 4 * na + 8 * nb + 16 * np + 32 * nr + 64 * na32 + 128 * nd32 + - 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1 + - 2**16 * range_a3 + 2**18 * range_b3 + 2**20 * range_c3 + 2**22 * range_d3]); + lookup_assumes(ARITH_TABLE_ID, cols: [ op, m32 + 2 * div + 4 * na + 8 * nb + 16 * np + 32 * nr + 64 * sext, + range_ab, range_cd]); } diff --git a/state-machines/arith/src/arith_helpers.rs b/state-machines/arith/src/arith_helpers.rs index 3f7f9635..e41459eb 100644 --- a/state-machines/arith/src/arith_helpers.rs +++ b/state-machines/arith/src/arith_helpers.rs @@ -13,20 +13,18 @@ const REMU_W: u8 = 0xbd; const DIV_W: u8 = 0xbe; const REM_W: u8 = 0xbf; +const FLAG_NAMES: [&str; 7] = ["m32", "div", "na", "nb", "np", "nr", "sext"]; + pub trait ArithHelpers { - fn calculate_flags_and_ranges( - a: u64, - b: u64, - op: u8, - div: &mut u64, - m32: &mut u64, - na: &mut u64, - nb: &mut u64, - nr: &mut u64, - np: &mut u64, - na32: &mut u64, - nd32: &mut u64, - ) -> [u64; 8] { + fn calculate_flags_and_ranges(a: u64, b: u64, op: u8) -> [u64; 10] { + let mut m32: u64 = 0; + let mut div: u64 = 0; + let mut na: u64 = 0; + let mut nb: u64 = 0; + let mut np: u64 = 0; + let mut nr: u64 = 0; + let mut sext: u64 = 0; + let mut range_a1: u64 = 0; let mut range_b1: u64 = 0; let mut range_c1: u64 = 0; @@ -43,79 +41,194 @@ pub trait ArithHelpers { // alternative: switch operation, - let mut sa: u64 = 0; - let mut sb: u64 = 0; + let mut sa = false; + let mut sb = false; + let mut rem32 = false; match op { - MULU | MULUH => {} + MULU => {} + MULUH => {} MULSUH => { - sa = 1; + sa = true; } - MUL | MULH => { - sa = 1; - sb = 1; + MUL => { + sa = true; + sb = true; + } + MULH => { + sa = true; + sb = true; } MUL_W => { - *m32 = 1; - sa = 1; - sb = 1; + m32 = 1; + sext = if ((a * b) & 0xFFFF_FFFF) & 0x8000_0000 != 0 { 1 } else { 0 }; + } + DIVU => { + div = 1; + } + REMU => { + div = 1; } - DIVU | REMU => { - *div = 1; + DIV => { + sa = true; + sb = true; + div = 1; } - DIV | REM => { - sa = 1; - sb = 1; - *div = 1; + REM => { + sa = true; + sb = true; + div = 1; } - DIVU_W | REMU_W => { + DIVU_W => { // divu_w, remu_w - *div = 1; - *m32 = 1; + div = 1; + m32 = 1; + sext = if ((a as u32 / b as u32) as i32) < 0 { 1 } else { 0 }; + } + REMU_W => { + // divu_w, remu_w + div = 1; + m32 = 1; + rem32 = true; + sext = if ((a as u32 % b as u32) as i32) < 0 { 1 } else { 0 }; + } + DIV_W => { + // div_w, rem_w + sa = true; + sb = true; + div = 1; + m32 = 1; + sext = if (a as i32 / b as i32) < 0 { 1 } else { 0 }; } - DIV_W | REM_W => { + REM_W => { // div_w, rem_w - sa = 1; - sb = 1; - *div = 1; - *m32 = 1; + sa = true; + sb = true; + div = 1; + m32 = 1; + rem32 = true; + sext = if (a as i32 % b as i32) < 0 { 1 } else { 0 }; } _ => { panic!("Invalid opcode"); } } - *na = if sa == 1 && (a as i64) < 0 { 1 } else { 0 }; - *nb = if sb == 1 && (b as i64) < 0 { 1 } else { 0 }; - *np = *na ^ *nb; - *nr = if *div == 1 { *na } else { 0 }; - *na32 = if *m32 == 1 { *na } else { 0 }; - *nd32 = if *m32 == 1 { *nr } else { 0 }; - - if *m32 == 1 { - range_a1 = sa + *na; - range_b1 = sb + *nb; - - if *div == 1 { - range_c1 = if *np == 1 || *na32 == 1 { 2 } else { 1 }; - range_d1 = if (*np == 1 && sa == 1) || *nd32 == 1 { 1 } else { 2 }; + if sa { + na = if m32 == 1 { + if (a as i32) < 0 { + 1 + } else { + 0 + } + } else { + if (a as i64) < 0 { + 1 + } else { + 0 + } + } + } + if sb { + nb = if m32 == 1 { + if (b as i32) < 0 { + 1 + } else { + 0 + } } else { - range_c1 = 1 + *na32; + if (b as i64) < 0 { + 1 + } else { + 0 + } } + } + + np = na ^ nb; + nr = if div == 1 { na } else { 0 }; + + if m32 == 1 { + // mulw, divu_w, remu_w, div_w, rem_w + range_a1 = if sa { 1 + na } else { 0 }; + range_b1 = if sb { 1 + nb } else { 0 }; + range_c1 = if !rem32 { + sext + 1 + } else if sa { + 1 + np + } else { + 0 + }; + range_d1 = if rem32 { + sext + 1 + } else if sa { + 1 + nr + } else { + 0 + }; } else { - // m32 = 0 - range_b3 = if sb == 1 { 1 + *na } else { 0 }; - if sa == 1 { - // !m32 && sa - range_a3 = 1 + *na; - if *div == 1 { - // !m32 && sa && div - range_c3 = 1 + *np; - range_d3 = range_c3; + // mulu, muluh, mulsuh, mul, mulh, div, rem, divu, remu + if sa { + // mulsuh, mul, mulh, div, rem + range_a3 = 1 + na; + if div == 1 { + // div, rem + range_c3 = 1 + np; + range_d3 = 1 + nr; + } else { + range_d3 = 1 + np; } } + // sb => mul, mulh, div, rem + range_b3 = if sb { 1 + nb } else { 0 }; } - [range_a1, range_b1, range_c1, range_d1, range_a3, range_b3, range_c3, range_d3] + // range_ab / range_cd + // + // a3 a1 b3 b1 + // rid c3 c1 d3 d1 range 2^16 2^15 notes + // --- -- -- -- -- ----- ---- ---- ------------------------- + // 0 F F F F ab cd 4 0 + // 1 F F + F cd 3 1 b3 sign => a3 sign + // 2 F F - F cd 3 1 b3 sign => a3 sign + // 3 + F F F ab 3 1 c3 sign => d3 sign + // 4 + F + F ab cd 2 2 + // 5 + F - F ab cd 2 2 + // 6 - F F F ab 3 1 c3 sign => d3 sign + // 7 - F + F ab cd 2 2 + // 8 - F - F ab cd 2 2 + // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign + // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign + // 11 F + F F cd 3 1 a1 sign <=> b1 sign + // 12 F + F + ab cd 2 2 + // 13 F + F - ab cd 2 2 + // 14 F - F F cd 3 1 a1 sign <=> b1 sign + // 15 F - F + ab cd 2 2 + // 16 F - F - ab cd 2 2 + + assert!(range_a1 == 0 || range_a3 == 0, "range_a1:{} range_a3:{}", range_a1, range_a3); + assert!(range_b1 == 0 || range_b3 == 0, "range_b1:{} range_b3:{}", range_b1, range_b3); + assert!(range_c1 == 0 || range_c3 == 0, "range_c1:{} range_c3:{}", range_c1, range_c3); + assert!(range_d1 == 0 || range_d3 == 0, "range_d1:{} range_d3:{}", range_d1, range_d3); + + let range_ab = (range_a3 + range_a1) * 3 + + range_b3 + + range_b1 + + if (range_a1 + range_b1) > 0 { 8 } else { 0 }; + + let range_cd = (range_c3 + range_c1) * 3 + + range_d3 + + range_d1 + + if (range_c1 + range_d1) > 0 { 8 } else { 0 }; + + let ranges = range_a3 * 1000_0000 + + range_b3 * 100_0000 + + range_c3 * 10_0000 + + range_d3 * 1000 + + range_a1 * 1000 + + range_b1 * 100 + + range_c1 * 10 + + range_d1; + [m32, div, na, nb, np, nr, sext, range_ab, range_cd, ranges] } /* fn calculate_flags( @@ -209,13 +322,13 @@ pub trait ArithHelpers { b: [i64; 4], c: [i64; 4], d: [i64; 4], + m32: i64, div: i64, - fab: i64, na: i64, nb: i64, np: i64, nr: i64, - m32: i64, + fab: i64, ) -> [i64; 8] { // TODO: unroll this function in variants (div,m32) and (na,nb,nr,np) // div, m32, na, nb === f(div,m32,na,nb) => fa, nb, nr @@ -294,6 +407,23 @@ pub trait ArithHelpers { } } +fn flags_to_strings(mut flags: u64, flag_names: &[&str]) -> String { + let mut res = String::new(); + + for flag_name in flag_names { + if (flags & 1u64) != 0 { + if !res.is_empty() { + res = res + ","; + } + res = res + *flag_name; + } + flags >>= 1; + if flags == 0 { + break; + }; + } + res +} #[test] fn test_calculate_range_checks() { struct TestArithHelpers {} @@ -301,16 +431,17 @@ fn test_calculate_range_checks() { const MIN_N_64: u64 = 0x8000_0000_0000_0000; const MAX_P_64: u64 = 0x7FFF_FFFF_FFFF_FFFF; + const MAX_P_32: u64 = 0x0000_0000_FFFF_FFFF; const MAX_64: u64 = 0xFFFF_FFFF_FFFF_FFFF; - const ALL: u64 = 0x0033; + const ALL_64: u64 = 0x0033; const ALL_P_64: u64 = 0x0034; const ALL_N_64: u64 = 0x0035; const END: u64 = 0x0036; - const ALL_P_64_VALUES: [u64; 5] = [0, 1, MAX_P_64, END, 0]; - const ALL_N_64_VALUES: [u64; 5] = [MIN_N_64, MAX_64, END, 0, 0]; - const ALL_64_VALUES: [u64; 5] = [0, 1, MAX_P_64, MAX_64, MIN_N_64]; + const ALL_P_64_VALUES: [u64; 6] = [0, 1, MAX_P_32, MAX_P_64, 0, END]; + const ALL_N_64_VALUES: [u64; 6] = [MIN_N_64, MAX_64, END, 0, 0, 0]; + const ALL_64_VALUES: [u64; 6] = [0, 1, MAX_P_64, MAX_64, MIN_N_64, MAX_P_32]; const F_M32: u64 = 0x0001; const F_DIV: u64 = 0x0002; @@ -318,88 +449,341 @@ fn test_calculate_range_checks() { const F_NB: u64 = 0x0008; const F_NP: u64 = 0x0010; const F_NR: u64 = 0x0020; - const F_NA32: u64 = 0x0040; - const F_ND32: u64 = 0x0080; + const F_SEXT: u64 = 0x0040; + + // range_ab / range_cd + // + // a3 a1 b3 b1 + // rid c3 c1 d3 d1 range 2^16 2^15 notes + // --- -- -- -- -- ----- ---- ---- ------------------------- + + const R_FF: u64 = 0; // 0 F F F F ab cd 4 0 + const R_3FP: u64 = 1; // 1 F F + F cd 3 1 b3 sign => a3 sign + const R_3FN: u64 = 2; // 2 F F - F cd 3 1 b3 sign => a3 sign + const R_3PF: u64 = 3; // 3 + F F F ab 3 1 c3 sign => d3 sign + const R_3PP: u64 = 4; // 4 + F + F ab cd 2 2 + const R_3PN: u64 = 5; // 5 + F - F ab cd 2 2 + const R_3NF: u64 = 6; // 6 - F F F ab 3 1 c3 sign => d3 sign + const R_3NP: u64 = 7; // 7 - F + F ab cd 2 2 + const R_3NN: u64 = 8; // 8 - F - F ab cd 2 2 + const R_1FP: u64 = 9; // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign + const R_1FN: u64 = 10; // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign + const R_1PF: u64 = 11; // 11 F + F F cd 3 1 a1 sign <=> b1 sign + const R_1PP: u64 = 12; // 12 F + F + ab cd 2 2 + const R_1PN: u64 = 13; // 13 F + F - ab cd 2 2 + const R_1NF: u64 = 14; // 14 F - F F cd 3 1 a1 sign <=> b1 sign + const R_1NP: u64 = 15; // 15 F - F + ab cd 2 2 + const R_1NN: u64 = 16; // 16 F - F - ab cd 2 2 struct TestParams { op: u8, a: u64, b: u64, flags: u64, + range_ab: u64, + range_cd: u64, } // NOTE: update TEST_COUNT with number of tests, ALL,ALL => 3*3 = 9 - const TEST_COUNT: u32 = 20; + const TEST_COUNT: u32 = 295; let tests = [ - // flags: div, m32, sa, sb, na, nr, np, np, na32, nd32 - TestParams { op: MULU, a: ALL, b: ALL, flags: 0 }, - TestParams { op: MULUH, a: ALL, b: ALL, flags: 0 }, - TestParams { op: MULSUH, a: ALL_P_64, b: ALL, flags: 0 }, - TestParams { op: MULSUH, a: ALL_N_64, b: ALL, flags: F_NA + F_NP }, - TestParams { op: MUL_W, a: ALL_P_64, b: ALL_P_64, flags: F_M32 }, - TestParams { op: MUL_W, a: ALL_N_64, b: ALL_P_64, flags: F_M32 + F_NA + F_NP }, - TestParams { op: MUL_W, a: ALL_P_64, b: ALL_N_64, flags: F_M32 + F_NB + F_NP }, - TestParams { op: MUL_W, a: ALL_N_64, b: ALL_N_64, flags: F_M32 + F_NA + F_NB }, - TestParams { op: DIV, a: 0, b: 0, flags: F_DIV }, - TestParams { op: DIV, a: MIN_N_64, b: MAX_P_64, flags: F_DIV + F_NA + F_NP + F_NR }, + // 0 - MULU + TestParams { + op: MULU, + a: ALL_64, + b: ALL_64, + flags: 0x0000, + range_ab: R_FF, + range_cd: R_FF, + }, + // 1 - MULU + TestParams { + op: MULUH, + a: ALL_64, + b: ALL_64, + flags: 0x0000, + range_ab: R_FF, + range_cd: R_FF, + }, + // 2 - MULSHU + TestParams { + op: MULSUH, + a: ALL_P_64, + b: ALL_64, + flags: 0x0000, + range_ab: R_3PF, + range_cd: R_3FP, + }, + // 3 - MULSHU + TestParams { + op: MULSUH, + a: ALL_N_64, + b: ALL_64, + flags: F_NA + F_NP, + range_ab: R_3NF, + range_cd: R_3FN, + }, + // 4 - MUL + TestParams { + op: MUL, + a: ALL_P_64, + b: ALL_P_64, + flags: 0, + range_ab: R_3PP, + range_cd: R_3FP, + }, + // 5 - MUL + TestParams { + op: MUL, + a: ALL_N_64, + b: ALL_N_64, + flags: F_NA + F_NB, + range_ab: R_3NN, + range_cd: R_3FP, + }, + // 6 - MUL + TestParams { + op: MUL, + a: ALL_N_64, + b: ALL_P_64, + flags: F_NA + F_NP, + range_ab: R_3NP, + range_cd: R_3FN, + }, + // 7 - MUL + TestParams { + op: MUL, + a: ALL_P_64, + b: ALL_N_64, + flags: F_NB + F_NP, + range_ab: R_3PN, + range_cd: R_3FN, + }, + // 8 - MULH + TestParams { + op: MULH, + a: ALL_P_64, + b: ALL_P_64, + flags: 0, + range_ab: R_3PP, + range_cd: R_3FP, + }, + // 9 - MULH + TestParams { + op: MULH, + a: ALL_N_64, + b: ALL_N_64, + flags: F_NA + F_NB, + range_ab: R_3NN, + range_cd: R_3FP, + }, + // 10 - MULH + TestParams { + op: MULH, + a: ALL_N_64, + b: ALL_P_64, + flags: F_NA + F_NP, + range_ab: R_3NP, + range_cd: R_3FN, + }, + // 11 - MULH + TestParams { + op: MULH, + a: ALL_P_64, + b: ALL_N_64, + flags: F_NB + F_NP, + range_ab: R_3PN, + range_cd: R_3FN, + }, + // 12 - MULW + TestParams { + op: MUL_W, + a: 0x0000_0000, + b: 0x0000_0000, + flags: F_M32, + range_ab: R_FF, + range_cd: R_1PF, + }, + // 13 - MUL: 0x00000002 (+/32 bits) * 0x40000000 (+/32 bits) = 0x80000000 (-/32 bits) + TestParams { + op: MUL_W, + a: 0x0000_0002, + b: 0x4000_0000, + flags: F_M32 + F_SEXT, + range_ab: R_FF, + range_cd: R_1NF, + }, + // 14 - MUL + TestParams { + op: MUL_W, + a: 0x0000_0002, + b: 0x8000_0000, + flags: F_M32, + range_ab: R_FF, + range_cd: R_1PF, + }, + // 15 - MUL + TestParams { + op: MUL_W, + a: 0xFFFF_FFFF, + b: 1, + flags: F_M32 + F_SEXT, + range_ab: R_FF, + range_cd: R_1NF, + }, + // 16 - MUL + TestParams { + op: MUL_W, + a: 0xFFFF_FFFF, + b: 0x0000_00000, + flags: F_M32, + range_ab: R_FF, + range_cd: R_1PF, + }, + // 17 - MUL + TestParams { + op: MUL_W, + a: 0x7FFF_FFFF, + b: 2, + flags: F_M32 + F_SEXT, + range_ab: R_FF, + range_cd: R_1NF, + }, + // 18 - MUL + TestParams { + op: MUL_W, + a: 0xBFFF_FFFF, + b: 0x0000_0002, + flags: F_M32, + range_ab: R_FF, + range_cd: R_1PF, + }, + // 19 - MUL: 0xFFFF_FFFF * 0xFFFF_FFFF = 0xFFFF_FFFE_0000_0001 + TestParams { + op: MUL_W, + a: 0xFFFF_FFFF, + b: 0xFFFF_FFFF, + flags: F_M32, + range_ab: R_FF, + range_cd: R_1PF, + }, + // 20 - MUL: 0xFFFF_FFFF * 0x0FFF_FFFF = 0x0FFF_FFFE_F000_0001 + TestParams { + op: MUL_W, + a: 0xFFFF_FFFF, + b: 0x0FFF_FFFF, + flags: F_M32 + F_SEXT, + range_ab: R_FF, + range_cd: R_1NF, + }, + // 21 - MUL: 0x8000_0000 * 0x8000_0000 = 0x4000_0000_0000_0000 + TestParams { + op: MUL_W, + a: 0x8000_0000, + b: 0x8000_0000, + flags: F_M32, + range_ab: R_FF, + range_cd: R_1PF, + }, + // 22 - DIVU + TestParams { op: DIVU, a: ALL_64, b: ALL_64, flags: F_DIV, range_ab: R_FF, range_cd: R_FF }, + // 23 - REMU + TestParams { op: DIVU, a: ALL_64, b: ALL_64, flags: F_DIV, range_ab: R_FF, range_cd: R_FF }, + // 24 - DIV + TestParams { + op: DIV, + a: MAX_P_64, + b: MAX_P_64, + flags: F_DIV, + range_ab: R_3PP, + range_cd: R_3PP, + }, + // 25 - DIV + TestParams { + op: DIV, + a: MIN_N_64, + b: MAX_P_64, + flags: F_DIV + F_NA + F_NP + F_NR, + range_ab: R_3NP, + range_cd: R_3NN, + }, + // 26 - DIV + TestParams { + op: DIV, + a: MAX_P_64, + b: MIN_N_64, + flags: F_DIV + F_NB + F_NP, + range_ab: R_3PN, + range_cd: R_3NP, + }, + // 27 - DIV + TestParams { + op: DIV, + a: MIN_N_64, + b: MIN_N_64, + flags: F_DIV + F_NA + F_NB + F_NR, + range_ab: R_3NN, + range_cd: R_3PN, + }, + // REM + // DIVU_W + // REMU_W + // DIV_W + // REM_W ]; let mut count = 0; let mut index: u32 = 0; for test in tests { - let a_values = if test.a == ALL { + let a_values = if test.a == ALL_64 { ALL_64_VALUES } else if test.a == ALL_N_64 { ALL_N_64_VALUES } else if test.a == ALL_P_64 { ALL_P_64_VALUES } else { - [test.a, END, 0, 0, 0] + [test.a, END, 0, 0, 0, 0] }; for a in a_values { if a == END { break; }; - let b_values = if test.b == ALL { + let b_values = if test.b == ALL_64 { ALL_64_VALUES } else if test.b == ALL_N_64 { ALL_N_64_VALUES } else if test.b == ALL_P_64 { ALL_P_64_VALUES } else { - [test.b, END, 0, 0, 0] + [test.b, END, 0, 0, 0, 0] }; for b in b_values { if b == END { break; }; - let mut div: u64 = 0; - let mut m32: u64 = 0; - let mut na: u64 = 0; - let mut nb: u64 = 0; - let mut nr: u64 = 0; - let mut np: u64 = 0; - let mut na32: u64 = 0; - let mut nd32: u64 = 0; - - TestArithHelpers::calculate_flags_and_ranges( - a, b, test.op, &mut div, &mut m32, &mut na, &mut nb, &mut nr, &mut np, - &mut na32, &mut nd32, - ); - let flags = - m32 + div * 2 + na * 4 + nb * 8 + np * 16 + nr * 32 + na32 * 64 + nd32 * 128; + let [m32, div, na, nb, np, nr, sext, range_ab, range_cd, ranges] = + TestArithHelpers::calculate_flags_and_ranges(a, b, test.op); + + let flags = m32 + div * 2 + na * 4 + nb * 8 + np * 16 + nr * 32 + sext * 64; assert_eq!( - flags, - test.flags, - "testing #{} op:0x{:x} with a:0x{:X} b:0x{:X} flags:{:b} vs {:b} [div, m32, sa, sb, na, nb, np, nr, na32, nd32]", + [flags, range_ab, range_cd], + [test.flags, test.range_ab, test.range_cd], + "testing #{} op:0x{:x} with a:0x{:X} b:0x{:X} flags:{:b}[{}]/{:b}[{}] range_ab:{}/{} range_cd:{}/{} ranges:{}", index, test.op, a, b, flags, + flags_to_strings(flags, &FLAG_NAMES), test.flags, + flags_to_strings(test.flags, &FLAG_NAMES), + range_ab, + test.range_ab, + range_cd, + test.range_cd, + ranges ); count += 1; } diff --git a/state-machines/binary/pil/binary_extension.pil b/state-machines/binary/pil/binary_extension.pil index 2e1587f4..a456a3ac 100644 --- a/state-machines/binary/pil/binary_extension.pil +++ b/state-machines/binary/pil/binary_extension.pil @@ -90,7 +90,7 @@ airtemplate BinaryExtension(const int N = 2**18, const int operation_bus_id = BI expr in1_high = in1[4] + in1[5]*2**8 + in1[6]*2**16 + in1[7]*2**24; col witness main_step; - col witness multiplicity; +// col witness multiplicity; lookup_proves( operation_bus_id, [ @@ -104,8 +104,9 @@ airtemplate BinaryExtension(const int N = 2**18, const int operation_bus_id = BI out[0][1] + out[1][1] + out[2][1] + out[3][1] + out[4][1] + out[5][1] + out[6][1] + out[7][1], 0 ], - multiplicity + 1 +// multiplicity ); range_check(colu: in2[0], min: 0, max: 2**24-1, sel: op_is_shift); -} \ No newline at end of file +} From 80af8efc6667b6371b864e7cdc24c0d7009451ab Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Tue, 22 Oct 2024 04:05:56 +0000 Subject: [PATCH 09/17] WIP arith --- state-machines/arith/pil/arith.pil | 2 +- state-machines/arith/pil/arith_table.pil | 12 +- state-machines/arith/src/arith_helpers.rs | 552 +++++++++++++++++----- 3 files changed, 446 insertions(+), 120 deletions(-) diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index cdd83cdf..6d4ad043 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -227,7 +227,7 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu col witness range_ab; col witness range_cd; - arith_table_assumes(op, m32, div, na, nb, np, nr, sext, range_ab, range_cd); + arith_table_assumes(op, m32, div, na, nb, np, nr, sext, secondary_res, range_ab, range_cd); // 0 - a1/c1 // 1 - b1/d1 diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil index 95d3032c..d2225d24 100644 --- a/state-machines/arith/pil/arith_table.pil +++ b/state-machines/arith/pil/arith_table.pil @@ -170,9 +170,13 @@ airtemplate ArithTable(int N = 2**6) { lookup_proves(ARITH_TABLE_ID, mul: multiplicity, cols: [OP, FLAGS_AND_RANGES, 0, 0]); } -function arith_table_assumes( const expr op, const expr m32, const expr div, const expr na, const expr nb, - const expr np, const expr nr, const expr sext, const expr range_ab, const expr range_cd) { +function arith_table_assumes( const expr op, const expr flag_m32, const expr flag_div, const expr flag_na, + const expr flag_nb, const expr flag_np, const expr flag_nr, const expr flag_sext, + const expr flag_secondary_res, const expr range_ab, const expr range_cd) { + + // TODO: #pragma binary flag_m32 => check any constraint on compilation time // TODO: define rule for empty rows - lookup_assumes(ARITH_TABLE_ID, cols: [ op, m32 + 2 * div + 4 * na + 8 * nb + 16 * np + 32 * nr + 64 * sext, - range_ab, range_cd]); + lookup_assumes(ARITH_TABLE_ID, cols: [ op, flag_m32 + 2 * flag_div + 4 * flag_na + 8 * flag_nb + + 16 * flag_np + 32 * flag_nr + 64 * flag_sext + + 128 * flag_secondary_res, range_ab, range_cd]); } diff --git a/state-machines/arith/src/arith_helpers.rs b/state-machines/arith/src/arith_helpers.rs index e41459eb..edb90a47 100644 --- a/state-machines/arith/src/arith_helpers.rs +++ b/state-machines/arith/src/arith_helpers.rs @@ -13,10 +13,10 @@ const REMU_W: u8 = 0xbd; const DIV_W: u8 = 0xbe; const REM_W: u8 = 0xbf; -const FLAG_NAMES: [&str; 7] = ["m32", "div", "na", "nb", "np", "nr", "sext"]; +const FLAG_NAMES: [&str; 8] = ["m32", "div", "na", "nb", "np", "nr", "sext", "sec"]; pub trait ArithHelpers { - fn calculate_flags_and_ranges(a: u64, b: u64, op: u8) -> [u64; 10] { + fn calculate_flags_and_ranges(a: u64, b: u64, op: u8) -> [u64; 11] { let mut m32: u64 = 0; let mut div: u64 = 0; let mut na: u64 = 0; @@ -24,6 +24,7 @@ pub trait ArithHelpers { let mut np: u64 = 0; let mut nr: u64 = 0; let mut sext: u64 = 0; + let mut secondary_res: u64 = 0; let mut range_a1: u64 = 0; let mut range_b1: u64 = 0; @@ -47,9 +48,12 @@ pub trait ArithHelpers { match op { MULU => {} - MULUH => {} + MULUH => { + secondary_res = 1; + } MULSUH => { sa = true; + secondary_res = 1; } MUL => { sa = true; @@ -58,6 +62,7 @@ pub trait ArithHelpers { MULH => { sa = true; sb = true; + secondary_res = 1; } MUL_W => { m32 = 1; @@ -68,6 +73,7 @@ pub trait ArithHelpers { } REMU => { div = 1; + secondary_res = 1; } DIV => { sa = true; @@ -78,6 +84,7 @@ pub trait ArithHelpers { sa = true; sb = true; div = 1; + secondary_res = 1; } DIVU_W => { // divu_w, remu_w @@ -91,6 +98,7 @@ pub trait ArithHelpers { m32 = 1; rem32 = true; sext = if ((a as u32 % b as u32) as i32) < 0 { 1 } else { 0 }; + secondary_res = 1; } DIV_W => { // div_w, rem_w @@ -108,6 +116,7 @@ pub trait ArithHelpers { m32 = 1; rem32 = true; sext = if (a as i32 % b as i32) < 0 { 1 } else { 0 }; + secondary_res = 1; } _ => { panic!("Invalid opcode"); @@ -144,7 +153,10 @@ pub trait ArithHelpers { } } - np = na ^ nb; + // a == 0 || b == 0 => np == 0 ==> how was a signed operation + // after that sign of np was verified with range check. + + np = if (a != 0) && (b != 0) { na ^ nb } else { 0 }; nr = if div == 1 { na } else { 0 }; if m32 == 1 { @@ -228,7 +240,7 @@ pub trait ArithHelpers { + range_b1 * 100 + range_c1 * 10 + range_d1; - [m32, div, na, nb, np, nr, sext, range_ab, range_cd, ranges] + [m32, div, na, nb, np, nr, sext, secondary_res, range_ab, range_cd, ranges] } /* fn calculate_flags( @@ -424,57 +436,251 @@ fn flags_to_strings(mut flags: u64, flag_names: &[&str]) -> String { } res } + +const F_M32: u64 = 0x0001; +const F_DIV: u64 = 0x0002; +const F_NA: u64 = 0x0004; +const F_NB: u64 = 0x0008; +const F_NP: u64 = 0x0010; +const F_NR: u64 = 0x0020; +const F_SEXT: u64 = 0x0040; +const F_SEC: u64 = 0x0080; + +// range_ab / range_cd +// +// a3 a1 b3 b1 +// rid c3 c1 d3 d1 range 2^16 2^15 notes +// --- -- -- -- -- ----- ---- ---- ------------------------- + +const R_FF: u64 = 0; // 0 F F F F ab cd 4 0 +const R_3FP: u64 = 1; // 1 F F + F cd 3 1 b3 sign => a3 sign +const R_3FN: u64 = 2; // 2 F F - F cd 3 1 b3 sign => a3 sign +const R_3PF: u64 = 3; // 3 + F F F ab 3 1 c3 sign => d3 sign +const R_3PP: u64 = 4; // 4 + F + F ab cd 2 2 +const R_3PN: u64 = 5; // 5 + F - F ab cd 2 2 +const R_3NF: u64 = 6; // 6 - F F F ab 3 1 c3 sign => d3 sign +const R_3NP: u64 = 7; // 7 - F + F ab cd 2 2 +const R_3NN: u64 = 8; // 8 - F - F ab cd 2 2 +const R_1FP: u64 = 9; // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign +const R_1FN: u64 = 10; // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign +const R_1PF: u64 = 11; // 11 F + F F cd 3 1 a1 sign <=> b1 sign +const R_1PP: u64 = 12; // 12 F + F + ab cd 2 2 +const R_1PN: u64 = 13; // 13 F + F - ab cd 2 2 +const R_1NF: u64 = 14; // 14 F - F F cd 3 1 a1 sign <=> b1 sign +const R_1NP: u64 = 15; // 15 F - F + ab cd 2 2 +const R_1NN: u64 = 16; // 16 F - F - ab cd 2 2 + +const MIN_N_64: u64 = 0x8000_0000_0000_0000; +const MIN_N_32: u64 = 0x0000_0000_8000_0000; +const MAX_P_64: u64 = 0x7FFF_FFFF_FFFF_FFFF; +const MAX_P_32: u64 = 0x0000_0000_7FFF_FFFF; +const MAX_32: u64 = 0x0000_0000_FFFF_FFFF; +const MAX_64: u64 = 0xFFFF_FFFF_FFFF_FFFF; + +// value cannot used as specific cases +const ALL_64: u64 = 0x0033; +const ALL_NZ_64: u64 = 0x0034; +const ALL_P_64: u64 = 0x0035; +const ALL_NZ_P_64: u64 = 0x0036; +const ALL_N_64: u64 = 0x0037; + +const ALL_32: u64 = 0x0043; +const ALL_NZ_32: u64 = 0x0044; +const ALL_P_32: u64 = 0x0045; +const ALL_N_32: u64 = 0x0046; +const ALL_NZ_P_32: u64 = 0x0047; + +const VALUES_END: u64 = 0x004D; + +fn get_test_values(value: u64) -> [u64; 16] { + match value { + ALL_64 => [ + 0, + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + MAX_32 + 1, + MAX_P_64 - 1, + MAX_P_64, + MAX_64 - 1, + MIN_N_64, + MIN_N_64 + 1, + MAX_64, + ], + ALL_NZ_64 => [ + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + MAX_32 + 1, + MAX_P_64 - 1, + MAX_P_64, + MAX_64 - 1, + MIN_N_64, + MIN_N_64 + 1, + MAX_64, + VALUES_END, + ], + ALL_P_64 => [ + 0, + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + MAX_32 + 1, + MAX_P_64 - 1, + MAX_P_64, + VALUES_END, + 0, + 0, + 0, + ], + ALL_NZ_P_64 => [ + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + MAX_32 + 1, + MAX_P_64 - 1, + MAX_P_64, + VALUES_END, + 0, + 0, + 0, + 0, + ], + ALL_N_64 => [ + MIN_N_64, + MIN_N_64 + 1, + MIN_N_64 + 2, + MIN_N_64 + 3, + 0x8000_0000_7FFF_FFFF, + 0x8FFF_FFFF_7FFF_FFFF, + 0xEFFF_FFFF_FFFF_FFFF, + MAX_64 - 3, + MAX_64 - 2, + MAX_64 - 1, + MAX_64, + VALUES_END, + 0, + 0, + 0, + 0, + ], + ALL_32 => [ + 0, + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + VALUES_END, + 0, + 0, + 0, + 0, + 0, + 0, + ], + ALL_32 => [ + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + VALUES_END, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + ALL_P_32 => [ + 0, + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + VALUES_END, + 0, + 0, + 0, + 0, + 0, + 0, + ], + ALL_NZ_P_32 => [ + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + VALUES_END, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + ALL_N_32 => [ + MIN_N_32, + MIN_N_32 + 1, + MIN_N_32 + 2, + MIN_N_32 + 3, + MAX_32 - 3, + MAX_32 - 2, + MAX_32 - 1, + MAX_32, + VALUES_END, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + _ => [value, VALUES_END, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + } +} #[test] fn test_calculate_range_checks() { struct TestArithHelpers {} impl ArithHelpers for TestArithHelpers {} - - const MIN_N_64: u64 = 0x8000_0000_0000_0000; - const MAX_P_64: u64 = 0x7FFF_FFFF_FFFF_FFFF; - const MAX_P_32: u64 = 0x0000_0000_FFFF_FFFF; - const MAX_64: u64 = 0xFFFF_FFFF_FFFF_FFFF; - - const ALL_64: u64 = 0x0033; - const ALL_P_64: u64 = 0x0034; - const ALL_N_64: u64 = 0x0035; - - const END: u64 = 0x0036; - const ALL_P_64_VALUES: [u64; 6] = [0, 1, MAX_P_32, MAX_P_64, 0, END]; - const ALL_N_64_VALUES: [u64; 6] = [MIN_N_64, MAX_64, END, 0, 0, 0]; - const ALL_64_VALUES: [u64; 6] = [0, 1, MAX_P_64, MAX_64, MIN_N_64, MAX_P_32]; - - const F_M32: u64 = 0x0001; - const F_DIV: u64 = 0x0002; - const F_NA: u64 = 0x0004; - const F_NB: u64 = 0x0008; - const F_NP: u64 = 0x0010; - const F_NR: u64 = 0x0020; - const F_SEXT: u64 = 0x0040; - - // range_ab / range_cd - // - // a3 a1 b3 b1 - // rid c3 c1 d3 d1 range 2^16 2^15 notes - // --- -- -- -- -- ----- ---- ---- ------------------------- - - const R_FF: u64 = 0; // 0 F F F F ab cd 4 0 - const R_3FP: u64 = 1; // 1 F F + F cd 3 1 b3 sign => a3 sign - const R_3FN: u64 = 2; // 2 F F - F cd 3 1 b3 sign => a3 sign - const R_3PF: u64 = 3; // 3 + F F F ab 3 1 c3 sign => d3 sign - const R_3PP: u64 = 4; // 4 + F + F ab cd 2 2 - const R_3PN: u64 = 5; // 5 + F - F ab cd 2 2 - const R_3NF: u64 = 6; // 6 - F F F ab 3 1 c3 sign => d3 sign - const R_3NP: u64 = 7; // 7 - F + F ab cd 2 2 - const R_3NN: u64 = 8; // 8 - F - F ab cd 2 2 - const R_1FP: u64 = 9; // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign - const R_1FN: u64 = 10; // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign - const R_1PF: u64 = 11; // 11 F + F F cd 3 1 a1 sign <=> b1 sign - const R_1PP: u64 = 12; // 12 F + F + ab cd 2 2 - const R_1PN: u64 = 13; // 13 F + F - ab cd 2 2 - const R_1NF: u64 = 14; // 14 F - F F cd 3 1 a1 sign <=> b1 sign - const R_1NP: u64 = 15; // 15 F - F + ab cd 2 2 - const R_1NN: u64 = 16; // 16 F - F - ab cd 2 2 - struct TestParams { op: u8, a: u64, @@ -485,8 +691,9 @@ fn test_calculate_range_checks() { } // NOTE: update TEST_COUNT with number of tests, ALL,ALL => 3*3 = 9 - const TEST_COUNT: u32 = 295; + const TEST_COUNT: u32 = 2472; + // NOTE: use 0x0000_0000 instead of 0, to avoid auto-format in one line, 0 is too short. let tests = [ // 0 - MULU TestParams { @@ -497,12 +704,12 @@ fn test_calculate_range_checks() { range_ab: R_FF, range_cd: R_FF, }, - // 1 - MULU + // 1 - MULUH TestParams { op: MULUH, a: ALL_64, b: ALL_64, - flags: 0x0000, + flags: F_SEC, range_ab: R_FF, range_cd: R_FF, }, @@ -511,7 +718,7 @@ fn test_calculate_range_checks() { op: MULSUH, a: ALL_P_64, b: ALL_64, - flags: 0x0000, + flags: F_SEC, range_ab: R_3PF, range_cd: R_3FP, }, @@ -519,12 +726,21 @@ fn test_calculate_range_checks() { TestParams { op: MULSUH, a: ALL_N_64, - b: ALL_64, - flags: F_NA + F_NP, + b: ALL_NZ_64, + flags: F_NA + F_NP + F_SEC, range_ab: R_3NF, range_cd: R_3FN, }, - // 4 - MUL + // 4 - MULSHU + TestParams { + op: MULSUH, + a: ALL_N_64, + b: 0x0000_0000, + flags: F_NA + F_SEC, + range_ab: R_3NF, + range_cd: R_3FP, + }, + // 5 - MUL TestParams { op: MUL, a: ALL_P_64, @@ -533,7 +749,7 @@ fn test_calculate_range_checks() { range_ab: R_3PP, range_cd: R_3FP, }, - // 5 - MUL + // 6 - MUL TestParams { op: MUL, a: ALL_N_64, @@ -542,61 +758,97 @@ fn test_calculate_range_checks() { range_ab: R_3NN, range_cd: R_3FP, }, - // 6 - MUL + // 7 - MUL TestParams { op: MUL, a: ALL_N_64, - b: ALL_P_64, + b: ALL_NZ_P_64, flags: F_NA + F_NP, range_ab: R_3NP, range_cd: R_3FN, }, - // 7 - MUL + // 8 - MUL TestParams { op: MUL, - a: ALL_P_64, + a: ALL_N_64, + b: 0x0000_0000, + flags: F_NA, + range_ab: R_3NP, + range_cd: R_3FP, + }, + // 9 - MUL + TestParams { + op: MUL, + a: ALL_NZ_P_64, b: ALL_N_64, flags: F_NB + F_NP, range_ab: R_3PN, range_cd: R_3FN, }, - // 8 - MULH + // 10 - MUL + TestParams { + op: MUL, + a: 0x0000_0000, + b: ALL_N_64, + flags: F_NB, + range_ab: R_3PN, + range_cd: R_3FP, + }, + // 11 - MULH TestParams { op: MULH, a: ALL_P_64, b: ALL_P_64, - flags: 0, + flags: F_SEC, range_ab: R_3PP, range_cd: R_3FP, }, - // 9 - MULH + // 12 - MULH TestParams { op: MULH, a: ALL_N_64, b: ALL_N_64, - flags: F_NA + F_NB, + flags: F_NA + F_NB + F_SEC, range_ab: R_3NN, range_cd: R_3FP, }, - // 10 - MULH + // 13 - MULH TestParams { op: MULH, a: ALL_N_64, - b: ALL_P_64, - flags: F_NA + F_NP, + b: ALL_NZ_P_64, + flags: F_NA + F_NP + F_SEC, range_ab: R_3NP, range_cd: R_3FN, }, - // 11 - MULH + // 14 - MULH TestParams { op: MULH, - a: ALL_P_64, + a: ALL_N_64, + b: 0x0000_00000, + flags: F_NA + F_SEC, + range_ab: R_3NP, + range_cd: R_3FP, + }, + // 15 - MULH + TestParams { + op: MULH, + a: ALL_NZ_P_64, b: ALL_N_64, - flags: F_NB + F_NP, + flags: F_NB + F_NP + F_SEC, range_ab: R_3PN, range_cd: R_3FN, }, - // 12 - MULW + // 16 - MULH + TestParams { + op: MULH, + a: 0x0000_0000, + b: ALL_N_64, + flags: F_NB + F_SEC, + range_ab: R_3PN, + range_cd: R_3FP, + }, + // 17 - MUL_W TestParams { op: MUL_W, a: 0x0000_0000, @@ -605,7 +857,7 @@ fn test_calculate_range_checks() { range_ab: R_FF, range_cd: R_1PF, }, - // 13 - MUL: 0x00000002 (+/32 bits) * 0x40000000 (+/32 bits) = 0x80000000 (-/32 bits) + // 18 - MUL_W: 0x00000002 (+/32 bits) * 0x40000000 (+/32 bits) = 0x80000000 (-/32 bits) TestParams { op: MUL_W, a: 0x0000_0002, @@ -614,7 +866,7 @@ fn test_calculate_range_checks() { range_ab: R_FF, range_cd: R_1NF, }, - // 14 - MUL + // 19 - MUL_W TestParams { op: MUL_W, a: 0x0000_0002, @@ -623,7 +875,7 @@ fn test_calculate_range_checks() { range_ab: R_FF, range_cd: R_1PF, }, - // 15 - MUL + // 20 - MUL_W TestParams { op: MUL_W, a: 0xFFFF_FFFF, @@ -632,7 +884,7 @@ fn test_calculate_range_checks() { range_ab: R_FF, range_cd: R_1NF, }, - // 16 - MUL + // 21 - MUL_W TestParams { op: MUL_W, a: 0xFFFF_FFFF, @@ -641,7 +893,7 @@ fn test_calculate_range_checks() { range_ab: R_FF, range_cd: R_1PF, }, - // 17 - MUL + // 22 - MUL_W TestParams { op: MUL_W, a: 0x7FFF_FFFF, @@ -650,7 +902,7 @@ fn test_calculate_range_checks() { range_ab: R_FF, range_cd: R_1NF, }, - // 18 - MUL + // 23 - MUL_W TestParams { op: MUL_W, a: 0xBFFF_FFFF, @@ -659,7 +911,7 @@ fn test_calculate_range_checks() { range_ab: R_FF, range_cd: R_1PF, }, - // 19 - MUL: 0xFFFF_FFFF * 0xFFFF_FFFF = 0xFFFF_FFFE_0000_0001 + // 24 - MUL_W: 0xFFFF_FFFF * 0xFFFF_FFFF = 0xFFFF_FFFE_0000_0001 TestParams { op: MUL_W, a: 0xFFFF_FFFF, @@ -668,7 +920,7 @@ fn test_calculate_range_checks() { range_ab: R_FF, range_cd: R_1PF, }, - // 20 - MUL: 0xFFFF_FFFF * 0x0FFF_FFFF = 0x0FFF_FFFE_F000_0001 + // 25 - MUL_W: 0xFFFF_FFFF * 0x0FFF_FFFF = 0x0FFF_FFFE_F000_0001 TestParams { op: MUL_W, a: 0xFFFF_FFFF, @@ -677,7 +929,7 @@ fn test_calculate_range_checks() { range_ab: R_FF, range_cd: R_1NF, }, - // 21 - MUL: 0x8000_0000 * 0x8000_0000 = 0x4000_0000_0000_0000 + // 26 - MUL_W: 0x8000_0000 * 0x8000_0000 = 0x4000_0000_0000_0000 TestParams { op: MUL_W, a: 0x8000_0000, @@ -686,11 +938,25 @@ fn test_calculate_range_checks() { range_ab: R_FF, range_cd: R_1PF, }, - // 22 - DIVU - TestParams { op: DIVU, a: ALL_64, b: ALL_64, flags: F_DIV, range_ab: R_FF, range_cd: R_FF }, - // 23 - REMU - TestParams { op: DIVU, a: ALL_64, b: ALL_64, flags: F_DIV, range_ab: R_FF, range_cd: R_FF }, - // 24 - DIV + // 27 - DIVU + TestParams { + op: DIVU, + a: ALL_64, + b: ALL_64, + flags: F_DIV + 0, + range_ab: R_FF, + range_cd: R_FF, + }, + // 28 - REMU + TestParams { + op: REMU, + a: ALL_64, + b: ALL_64, + flags: F_DIV + F_SEC, + range_ab: R_FF, + range_cd: R_FF, + }, + // 29 - DIV TestParams { op: DIV, a: MAX_P_64, @@ -699,7 +965,7 @@ fn test_calculate_range_checks() { range_ab: R_3PP, range_cd: R_3PP, }, - // 25 - DIV + // 30 - DIV TestParams { op: DIV, a: MIN_N_64, @@ -708,7 +974,7 @@ fn test_calculate_range_checks() { range_ab: R_3NP, range_cd: R_3NN, }, - // 26 - DIV + // 31 - DIV TestParams { op: DIV, a: MAX_P_64, @@ -717,7 +983,7 @@ fn test_calculate_range_checks() { range_ab: R_3PN, range_cd: R_3NP, }, - // 27 - DIV + // 32 - DIV TestParams { op: DIV, a: MIN_N_64, @@ -726,7 +992,78 @@ fn test_calculate_range_checks() { range_ab: R_3NN, range_cd: R_3PN, }, - // REM + // 33 - DIV + TestParams { + op: DIV, + a: 0x0000_0000, + b: MAX_P_64, + flags: F_DIV, + range_ab: R_3PP, + range_cd: R_3PP, + }, + // 34 - DIV + TestParams { + op: DIV, + a: 0x0000_0000, + b: MIN_N_64, + flags: F_DIV + F_NB, + range_ab: R_3PN, + range_cd: R_3PP, + }, + // 35 - REM + TestParams { + op: REM, + a: MAX_P_64, + b: MAX_P_64, + flags: F_DIV + F_SEC, + range_ab: R_3PP, + range_cd: R_3PP, + }, + // 36 - REM + TestParams { + op: REM, + a: MIN_N_64, + b: MAX_P_64, + flags: F_DIV + F_NA + F_NP + F_NR + F_SEC, + range_ab: R_3NP, + range_cd: R_3NN, + }, + // 37 - REM + TestParams { + op: REM, + a: MAX_P_64, + b: MIN_N_64, + flags: F_DIV + F_NB + F_NP + F_SEC, + range_ab: R_3PN, + range_cd: R_3NP, + }, + // 38 - REM + TestParams { + op: REM, + a: MIN_N_64, + b: MIN_N_64, + flags: F_DIV + F_NA + F_NB + F_NR + F_SEC, + range_ab: R_3NN, + range_cd: R_3PN, + }, + // 39 - REM + TestParams { + op: REM, + a: 0x0000_0000, + b: MAX_P_64, + flags: F_DIV + F_SEC, + range_ab: R_3PP, + range_cd: R_3PP, + }, + // 40 - REM + TestParams { + op: REM, + a: 0x0000_0000, + b: MIN_N_64, + flags: F_DIV + F_NB + F_SEC, + range_ab: R_3PN, + range_cd: R_3PP, + }, // DIVU_W // REMU_W // DIV_W @@ -736,36 +1073,21 @@ fn test_calculate_range_checks() { let mut count = 0; let mut index: u32 = 0; for test in tests { - let a_values = if test.a == ALL_64 { - ALL_64_VALUES - } else if test.a == ALL_N_64 { - ALL_N_64_VALUES - } else if test.a == ALL_P_64 { - ALL_P_64_VALUES - } else { - [test.a, END, 0, 0, 0, 0] - }; + let a_values = get_test_values(test.a); for a in a_values { - if a == END { + if a == VALUES_END { break; }; - let b_values = if test.b == ALL_64 { - ALL_64_VALUES - } else if test.b == ALL_N_64 { - ALL_N_64_VALUES - } else if test.b == ALL_P_64 { - ALL_P_64_VALUES - } else { - [test.b, END, 0, 0, 0, 0] - }; + let b_values = get_test_values(test.b); for b in b_values { - if b == END { + if b == VALUES_END { break; }; - let [m32, div, na, nb, np, nr, sext, range_ab, range_cd, ranges] = + let [m32, div, na, nb, np, nr, sext, sec, range_ab, range_cd, ranges] = TestArithHelpers::calculate_flags_and_ranges(a, b, test.op); - let flags = m32 + div * 2 + na * 4 + nb * 8 + np * 16 + nr * 32 + sext * 64; + let flags = + m32 + div * 2 + na * 4 + nb * 8 + np * 16 + nr * 32 + sext * 64 + sec * 128; assert_eq!( [flags, range_ab, range_cd], From f0600c050b0b9dbe8b0ec040fb70654de65e2c1d Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Tue, 22 Oct 2024 22:58:18 +0000 Subject: [PATCH 10/17] WIP arith tests --- state-machines/arith/pil/arith.pil | 28 +- state-machines/arith/src/arith_helpers.rs | 512 ++++++++++++++++++---- 2 files changed, 451 insertions(+), 89 deletions(-) diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index 6d4ad043..50811880 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -142,14 +142,30 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 c1 d1 c1,d1 // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 c1 d1 c1 d1 a1,b1,c1,d1 + + // div m32 sa sb primary secondary opcodes na nb np nr sext(c) + // ----------------------------------------------------------------------------- + // 0 0 0 0 mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 + // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 d3 =0 =0 =0 a3, d3 + // 0 0 1 1 mul mulh (0xb4,0xb5) a3 b3 d3 =0 =0 =0 a3,b3, d3 + // 0 1 0 0 mul_w *n/a* (0xb6,0xb7) =0 =0 =0 =0 c1 =0 + + // div m32 sa sb primary secondary opcodes na nb np nr sext(a,d)(*2) + // ------------------------------------------------------------------------------ + // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 + // 1 0 1 1 div rem (0xba,0xbb) a3 b3 c3 d3 =0 =0 a3,b3,c3,d3 + // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 a1 d1 a1 ,d1 + // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 c1 d1 a1 d1 a1,b1,c1,d1 + // (*) removed combinations of flags div,m32,sa,sb did allow combinations div, m32, sa, sb + // (*2) sext affects to 32 bits result (bus), but in divisions a is used as result // see 5 previous constraints. // =0 means forced to zero by previous constraints - // bus result mul div - // -------------------------------- - // primary c a - // secondary d d + // bus result primary secondary + // ---------------------------------- + // mul (mulh) c d + // div (remu) a d col witness bus_a_low; bus_a_low === div * (c[0] + c[1] * CHUNK_SIZE) + (1 - div) * (a[0] + a[1] * CHUNK_SIZE); @@ -211,10 +227,10 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu // 8 - F - F ab cd 2 2 // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign - // 11 F + F F cd 3 1 a1 sign <=> b1 sign + // 11 F + F F ab cd 3 1 *a1 for sext/divu // 12 F + F + ab cd 2 2 // 13 F + F - ab cd 2 2 - // 14 F - F F cd 3 1 a1 sign <=> b1 sign + // 14 F - F F ab cd 3 1 *a1 for sext/divu // 15 F - F + ab cd 2 2 // 16 F - F - ab cd 2 2 // ---- ---- diff --git a/state-machines/arith/src/arith_helpers.rs b/state-machines/arith/src/arith_helpers.rs index edb90a47..11d99b01 100644 --- a/state-machines/arith/src/arith_helpers.rs +++ b/state-machines/arith/src/arith_helpers.rs @@ -1,3 +1,5 @@ +use zisk_core::zisk_ops::*; + const MULU: u8 = 0xb0; const MULUH: u8 = 0xb1; const MULSUH: u8 = 0xb3; @@ -16,11 +18,151 @@ const REM_W: u8 = 0xbf; const FLAG_NAMES: [&str; 8] = ["m32", "div", "na", "nb", "np", "nr", "sext", "sec"]; pub trait ArithHelpers { - fn calculate_flags_and_ranges(a: u64, b: u64, op: u8) -> [u64; 11] { + fn sign32(abs_value: u64, negative: bool) -> u64 { + assert!(0xFFFF_FFFF >= abs_value, "abs_value:0x{0:X}({0}) is too big", abs_value); + if negative { + (0xFFFF_FFFF - abs_value) + 1 + } else { + abs_value + } + } + fn sign64(abs_value: u64, negative: bool) -> u64 { + if negative { + (0xFFFF_FFFF_FFFF_FFFF - abs_value) + 1 + } else { + abs_value + } + } + fn sign128(abs_value: u128, negative: bool) -> u128 { + if negative { + (0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF - abs_value) + 1 + } else { + abs_value + } + } + fn abs32(value: u64) -> [u64; 2] { + let negative = if (value & 0x8000_0000) != 0 { 1 } else { 0 }; + let abs_value = if negative == 1 { (0xFFFF_FFFF - value) + 1 } else { value }; + // println!( + // "value:0x{0:X}({0}) abs_value:0x{1:X}({1}) negative:{2}", + // value, abs_value, negative + // ); + [abs_value, negative] + } + fn abs64(value: u64) -> [u64; 2] { + let negative = if (value & 0x8000_0000_0000_0000) != 0 { 1 } else { 0 }; + let abs_value = if negative == 1 { (0xFFFF_FFFF_FFFF_FFFF - value) + 1 } else { value }; + [abs_value, negative] + } + fn calculate_mul_w(a: u64, b: u64) -> u64 { + let [abs_a, na] = Self::abs32(a); + let [abs_b, nb] = Self::abs32(b); + // println!( + // "a:0x{0:X}({0}) b:0x{1:X}({1}) abs_a:0x{2:X}({2}) na:{3} abs_b:{4:X}({4}) nb:{5}", + // a, b, abs_a, na, abs_b, nb + // ); + let abs_c = abs_a * abs_b; + let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; + Self::sign64(abs_c, nc == 1) + } + + fn calculate_mulsu(a: u64, b: u64) -> [u64; 2] { + let [abs_a, na] = Self::abs64(a); + let abs_c = abs_a as u128 * b as u128; + let nc = if na == 1 && abs_c != 0 { 1 } else { 0 }; + let c = Self::sign128(abs_c, nc == 1); + [c as u64, (c >> 64) as u64] + } + + fn calculate_mul(a: u64, b: u64) -> [u64; 2] { + let [abs_a, na] = Self::abs64(a); + let [abs_b, nb] = Self::abs64(b); + let abs_c = abs_a as u128 * abs_b as u128; + let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; + let c = Self::sign128(abs_c, nc == 1); + [c as u64, (c >> 64) as u64] + } + + fn calculate_div(a: u64, b: u64) -> u64 { + let [abs_a, na] = Self::abs64(a); + let [abs_b, nb] = Self::abs64(b); + let abs_c = abs_a / abs_b; + let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; + Self::sign64(abs_c, nc == 1) + } + + fn calculate_rem(a: u64, b: u64) -> u64 { + let [abs_a, na] = Self::abs64(a); + let [abs_b, _nb] = Self::abs64(b); + let abs_c = abs_a % abs_b; + let nc = if na == 1 && abs_c != 0 { 1 } else { 0 }; + Self::sign64(abs_c, nc == 1) + } + + fn calculate_div_w(a: u64, b: u64) -> u64 { + let [abs_a, na] = Self::abs32(a); + let [abs_b, nb] = Self::abs32(b); + let abs_c = abs_a / abs_b; + let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; + Self::sign32(abs_c, nc == 1) + } + + fn calculate_rem_w(a: u64, b: u64) -> u64 { + let [abs_a, na] = Self::abs32(a); + let [abs_b, _nb] = Self::abs32(b); + let abs_c = abs_a % abs_b; + let nc = if na == 1 && abs_c != 0 { 1 } else { 0 }; + Self::sign32(abs_c, nc == 1) + } + + fn calculate_emulator_res(op: u8, a: u64, b: u64) -> (u64, bool) { + match op { + MULU => return op_mulu(a, b), + MULUH => return op_muluh(a, b), + MULSUH => return op_mulsuh(a, b), + MUL => return op_mul(a, b), + MULH => return op_mulh(a, b), + MUL_W => return op_mul_w(a, b), + DIVU => return op_divu(a, b), + REMU => return op_remu(a, b), + DIVU_W => return op_divu_w(a, b), + REMU_W => return op_remu_w(a, b), + DIV => return op_div(a, b), + REM => return op_rem(a, b), + DIV_W => return op_div_w(a, b), + REM_W => return op_rem_w(a, b), + _ => { + panic!("Invalid opcode"); + } + } + } + + fn calculate_abcd_from_ab(op: u8, a: u64, b: u64) -> [u64; 4] { + match op { + MULU | MULUH => { + let c: u128 = a as u128 * b as u128; + [a, b, c as u64, (c >> 64) as u64] + } + MULSUH => { + let [c, d] = Self::calculate_mulsu(a, b); + [a, b, c, d] + } + MUL | MULH => { + let [c, d] = Self::calculate_mul(a, b); + [a, b, c, d] + } + MUL_W => [a, b, Self::calculate_mul_w(a, b), 0], + DIVU | REMU | DIVU_W | REMU_W => [a / b, b, a, a % b], + DIV | REM => [Self::calculate_div(a, b), b, a, Self::calculate_rem(a, b)], + DIV_W | REM_W => [Self::calculate_div_w(a, b), b, a, Self::calculate_rem_w(a, b)], + _ => { + panic!("Invalid opcode"); + } + } + } + fn calculate_flags_and_ranges(op: u8, a: u64, b: u64, c: u64, d: u64) -> [u64; 11] { let mut m32: u64 = 0; let mut div: u64 = 0; - let mut na: u64 = 0; - let mut nb: u64 = 0; let mut np: u64 = 0; let mut nr: u64 = 0; let mut sext: u64 = 0; @@ -44,7 +186,7 @@ pub trait ArithHelpers { let mut sa = false; let mut sb = false; - let mut rem32 = false; + let mut rem = false; match op { MULU => {} @@ -70,9 +212,11 @@ pub trait ArithHelpers { } DIVU => { div = 1; + assert!(b != 0, "Error on DIVU a:{:x}({}) b:{:x}({})", a, b, a, b); } REMU => { div = 1; + rem = true; secondary_res = 1; } DIV => { @@ -83,6 +227,7 @@ pub trait ArithHelpers { REM => { sa = true; sb = true; + rem = true; div = 1; secondary_res = 1; } @@ -90,14 +235,16 @@ pub trait ArithHelpers { // divu_w, remu_w div = 1; m32 = 1; - sext = if ((a as u32 / b as u32) as i32) < 0 { 1 } else { 0 }; + // use a in bus + sext = if (a & 0x8000_0000) != 0 { 1 } else { 0 }; } REMU_W => { // divu_w, remu_w div = 1; m32 = 1; - rem32 = true; - sext = if ((a as u32 % b as u32) as i32) < 0 { 1 } else { 0 }; + rem = true; + // use d in bus + sext = if (d & 0x8000_0000) != 0 { 1 } else { 0 }; secondary_res = 1; } DIV_W => { @@ -106,7 +253,8 @@ pub trait ArithHelpers { sb = true; div = 1; m32 = 1; - sext = if (a as i32 / b as i32) < 0 { 1 } else { 0 }; + // use a in bus + sext = if (a & 0x8000_0000) != 0 { 1 } else { 0 }; } REM_W => { // div_w, rem_w @@ -114,63 +262,49 @@ pub trait ArithHelpers { sb = true; div = 1; m32 = 1; - rem32 = true; - sext = if (a as i32 % b as i32) < 0 { 1 } else { 0 }; + rem = true; + // use d in bus + sext = if (d & 0x8000_0000) != 0 { 1 } else { 0 }; secondary_res = 1; } _ => { panic!("Invalid opcode"); } } - if sa { - na = if m32 == 1 { - if (a as i32) < 0 { - 1 - } else { - 0 - } - } else { - if (a as i64) < 0 { - 1 - } else { - 0 - } - } - } - if sb { - nb = if m32 == 1 { - if (b as i32) < 0 { - 1 - } else { - 0 - } - } else { - if (b as i64) < 0 { - 1 - } else { - 0 - } - } - } + let sign_mask: u64 = if m32 == 1 { 0x8000_0000 } else { 0x8000_0000_0000_0000 }; + let na = if sa && (a & sign_mask) != 0 { 1 } else { 0 }; + let nb = if sb && (b & sign_mask) != 0 { 1 } else { 0 }; + // a sign => b sign + let nc = if sa && (c & sign_mask) != 0 { 1 } else { 0 }; // a == 0 || b == 0 => np == 0 ==> how was a signed operation // after that sign of np was verified with range check. - - np = if (a != 0) && (b != 0) { na ^ nb } else { 0 }; - nr = if div == 1 { na } else { 0 }; - + if div == 1 { + np = if c != 0 { nc ^ nb } else { 0 }; + nr = if d != 0 { nc } else { 0 }; + } else { + np = if (c != 0) || (d != 0) { na ^ nb } else { 0 }; + nr = 0; + } if m32 == 1 { // mulw, divu_w, remu_w, div_w, rem_w - range_a1 = if sa { 1 + na } else { 0 }; + range_a1 = if sa { + 1 + na + } else if div == 1 && !rem { + 1 + sext + } else { + 0 + }; range_b1 = if sb { 1 + nb } else { 0 }; - range_c1 = if !rem32 { + // m32 && div == 0 => mulw + range_c1 = if div == 0 { sext + 1 } else if sa { 1 + np } else { 0 }; - range_d1 = if rem32 { + range_d1 = if rem { sext + 1 } else if sa { 1 + nr @@ -414,9 +548,6 @@ pub trait ArithHelpers { chunks } - fn me() -> i32 { - 13 - } } fn flags_to_strings(mut flags: u64, flag_names: &[&str]) -> String { @@ -602,7 +733,7 @@ fn get_test_values(value: u64) -> [u64; 16] { 0, 0, ], - ALL_32 => [ + ALL_NZ_32 => [ 1, 2, 3, @@ -625,28 +756,29 @@ fn get_test_values(value: u64) -> [u64; 16] { 1, 2, 3, + 0x0000_7FFF, + 0x0000_FFFF, + MAX_P_32 - 1, + MAX_P_32, MAX_P_32 - 1, MAX_P_32, - MIN_N_32, - MAX_32 - 1, - MAX_32, VALUES_END, 0, 0, 0, 0, 0, - 0, ], ALL_NZ_P_32 => [ 1, 2, 3, + 0x0000_7FFF, + 0x0000_FFFF, + MAX_P_32 - 1, + MAX_P_32, MAX_P_32 - 1, MAX_P_32, - MIN_N_32, - MAX_32 - 1, - MAX_32, VALUES_END, 0, 0, @@ -654,7 +786,6 @@ fn get_test_values(value: u64) -> [u64; 16] { 0, 0, 0, - 0, ], ALL_N_32 => [ MIN_N_32, @@ -691,7 +822,7 @@ fn test_calculate_range_checks() { } // NOTE: update TEST_COUNT with number of tests, ALL,ALL => 3*3 = 9 - const TEST_COUNT: u32 = 2472; + const TEST_COUNT: u32 = 2510; // NOTE: use 0x0000_0000 instead of 0, to avoid auto-format in one line, 0 is too short. let tests = [ @@ -713,7 +844,7 @@ fn test_calculate_range_checks() { range_ab: R_FF, range_cd: R_FF, }, - // 2 - MULSHU + // 2 - MULSUH TestParams { op: MULSUH, a: ALL_P_64, @@ -722,7 +853,7 @@ fn test_calculate_range_checks() { range_ab: R_3PF, range_cd: R_3FP, }, - // 3 - MULSHU + // 3 - MULSUH TestParams { op: MULSUH, a: ALL_N_64, @@ -731,7 +862,7 @@ fn test_calculate_range_checks() { range_ab: R_3NF, range_cd: R_3FN, }, - // 4 - MULSHU + // 4 - MULSUH TestParams { op: MULSUH, a: ALL_N_64, @@ -942,7 +1073,7 @@ fn test_calculate_range_checks() { TestParams { op: DIVU, a: ALL_64, - b: ALL_64, + b: ALL_NZ_64, flags: F_DIV + 0, range_ab: R_FF, range_cd: R_FF, @@ -951,7 +1082,7 @@ fn test_calculate_range_checks() { TestParams { op: REMU, a: ALL_64, - b: ALL_64, + b: ALL_NZ_64, flags: F_DIV + F_SEC, range_ab: R_FF, range_cd: R_FF, @@ -988,9 +1119,9 @@ fn test_calculate_range_checks() { op: DIV, a: MIN_N_64, b: MIN_N_64, - flags: F_DIV + F_NA + F_NB + F_NR, - range_ab: R_3NN, - range_cd: R_3PN, + flags: F_DIV + F_NB, + range_ab: R_3PN, + range_cd: R_3PP, }, // 33 - DIV TestParams { @@ -1042,9 +1173,9 @@ fn test_calculate_range_checks() { op: REM, a: MIN_N_64, b: MIN_N_64, - flags: F_DIV + F_NA + F_NB + F_NR + F_SEC, - range_ab: R_3NN, - range_cd: R_3PN, + flags: F_DIV + F_NB + F_SEC, + range_ab: R_3PN, + range_cd: R_3PP, }, // 39 - REM TestParams { @@ -1064,27 +1195,216 @@ fn test_calculate_range_checks() { range_ab: R_3PN, range_cd: R_3PP, }, - // DIVU_W - // REMU_W - // DIV_W - // REM_W + // 41 - DIVU_W + TestParams { + op: DIVU_W, + a: 0xFFFF_FFFF, + b: 0x0000_0001, + flags: F_DIV + F_M32 + F_SEXT, + range_ab: R_1NF, + range_cd: R_FF, + }, + // 42 - DIVU_W + TestParams { + op: DIVU_W, + a: ALL_NZ_32, + b: 0x0000_00002, + flags: F_DIV + F_M32, + range_ab: R_1PF, + range_cd: R_FF, + }, + // 43 - DIVU_W + TestParams { + op: DIVU_W, + a: ALL_NZ_32, + b: MAX_32, + flags: F_DIV + F_M32, + range_ab: R_1PF, + range_cd: R_FF, + }, + // 44 - DIVU_W + TestParams { + op: DIVU_W, + a: 0, + b: ALL_NZ_32, + flags: F_DIV + F_M32, + range_ab: R_1PF, + range_cd: R_FF, + }, + // 45 - REMU_W + TestParams { + op: REMU_W, + a: 0xFFFF_FFFF, + b: 0x0000_0001, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 46 - REMU_W + TestParams { + op: REMU_W, + a: ALL_32, + b: 0x0000_00002, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 47 - REMU_W + TestParams { + op: REMU_W, + a: ALL_NZ_P_32, + b: MAX_32, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 48 - REMU_W + TestParams { + op: REMU_W, + a: ALL_32, + b: 0x8000_0000, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 49 - REMU_W + TestParams { + op: REMU_W, + a: 0, + b: ALL_NZ_32, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 50 - REMU_W + TestParams { + op: REMU_W, + a: 0xFFFF_FFFE, + b: 0xFFFF_FFFF, + flags: F_DIV + F_M32 + F_SEXT + F_SEC, + range_ab: R_FF, + range_cd: R_1FN, + }, + // 51 - REMU_W + TestParams { + op: REMU_W, + a: 0xFFFF_FFFE, + b: 0xFFFF_FFFE, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 52 - REMU_W + TestParams { + op: REMU_W, + a: 0x8000_0000, + b: 0x8000_0001, + flags: F_DIV + F_M32 + F_SEXT + F_SEC, + range_ab: R_FF, + range_cd: R_1FN, + }, + // 53 - REMU_W + TestParams { + op: REMU_W, + a: 0x8000_0001, + b: 0x8000_0000, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 54 - REMU_W + TestParams { + op: REMU_W, + a: 0xFFFF_FFFF, + b: 0x0000_0003, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 55 - DIV_W (-1/1=-1 REM:0) + TestParams { + op: DIV_W, + a: 0xFFFF_FFFF, + b: 0x0000_0001, + flags: F_DIV + F_NA + F_NP + F_M32 + F_SEXT, + range_ab: R_1NP, + range_cd: R_1NP, + }, + // 56 - REM_W !!! + TestParams { + op: REM_W, + a: 0xFFFF_FFFF, + b: 0x0000_0001, + flags: F_DIV + F_NA + F_NP + F_M32 + F_SEC, + range_ab: R_1NP, + range_cd: R_1NP, + }, + // 57 - DIV_W <====== + TestParams { + op: DIV_W, + a: 0xFFFF_FFFF, + b: 0x0000_0002, + flags: F_DIV + F_NP + F_NR + F_M32, + range_ab: R_1PP, + range_cd: R_1NN, + }, + // 58 - REM_W + TestParams { + op: REM_W, + a: 0xFFFF_FFFF, + b: 0x0000_0002, + flags: F_DIV + F_NP + F_NR + F_M32 + F_SEC + F_SEXT, + range_ab: R_1PP, + range_cd: R_1NN, + }, ]; let mut count = 0; let mut index: u32 = 0; + + #[derive(Debug, PartialEq)] + struct TestDone { + op: u8, + a: u64, + b: u64, + index: u32, + offset: u32, + } + + let mut tests_done: Vec = Vec::new(); for test in tests { let a_values = get_test_values(test.a); - for a in a_values { - if a == VALUES_END { + let mut offset = 0; + for _a in a_values { + if _a == VALUES_END { break; - }; + } let b_values = get_test_values(test.b); - for b in b_values { - if b == VALUES_END { + for _b in b_values { + if _b == VALUES_END { break; - }; + } + let test_info = TestDone { op: test.op, a: _a, b: _b, index, offset }; + let previous = tests_done + .iter() + .find(|&x| x.op == test_info.op && x.a == test_info.a && x.b == test_info.b); + match previous { + Some(e) => { + println!( + "\x1B[35mDuplicated TEST #{} op:0x{:x} a:0x{:X} b:0x{:X} offset:{}\x1B[0m", + e.index, e.op, e.a, e.b, e.offset + ); + } + None => { + tests_done.push(test_info); + } + } + println!("testing #{} op:0x{:x} with _a:0x{:X} _b:0x{:X}", index, test.op, _a, _b); + let (emu_c, emu_flag) = TestArithHelpers::calculate_emulator_res(test.op, _a, _b); + let [a, b, c, d] = TestArithHelpers::calculate_abcd_from_ab(test.op, _a, _b); + let [m32, div, na, nb, np, nr, sext, sec, range_ab, range_cd, ranges] = - TestArithHelpers::calculate_flags_and_ranges(a, b, test.op); + TestArithHelpers::calculate_flags_and_ranges(test.op, a, b, c, d); let flags = m32 + div * 2 + na * 4 + nb * 8 + np * 16 + nr * 32 + sext * 64 + sec * 128; @@ -1092,11 +1412,36 @@ fn test_calculate_range_checks() { assert_eq!( [flags, range_ab, range_cd], [test.flags, test.range_ab, test.range_cd], - "testing #{} op:0x{:x} with a:0x{:X} b:0x{:X} flags:{:b}[{}]/{:b}[{}] range_ab:{}/{} range_cd:{}/{} ranges:{}", + "testing #{} op:0x{:x} with _a:0x{:X} _b:0x{:X} a:0x{:X} b:0x{:X} c:0x{:X} d:0x{:X} EMU:0x{:X} flags:{:b}[{}]/{:b}[{}] range_ab:{}/{} range_cd:{}/{} ranges:{}", + index, + test.op, + _a, + _b, + a, + b, + c, + d, + emu_c, + flags, + flags_to_strings(flags, &FLAG_NAMES), + test.flags, + flags_to_strings(test.flags, &FLAG_NAMES), + range_ab, + test.range_ab, + range_cd, + test.range_cd, + ranges + ); + println!("testing #{} op:0x{:x} with _a:0x{:X} _b:0x{:X} a:0x{:X} b:0x{:X} c:0x{:X} d:0x{:X} EMU:0x{:X} flags:{:b}[{}]/{:b}[{}] range_ab:{}/{} range_cd:{}/{} ranges:{}", index, test.op, + _a, + _b, a, b, + c, + d, + emu_c, flags, flags_to_strings(flags, &FLAG_NAMES), test.flags, @@ -1107,6 +1452,7 @@ fn test_calculate_range_checks() { test.range_cd, ranges ); + offset += 1; count += 1; } } From ce017b39884a8af9b7de0b8cdabe63118d44a4c4 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Thu, 24 Oct 2024 06:13:14 +0000 Subject: [PATCH 11/17] WIP, helpers test ok peding update pil --- state-machines/arith/pil/arith.pil | 103 +++++--- state-machines/arith/src/arith_helpers.rs | 306 +++++++++++++--------- 2 files changed, 244 insertions(+), 165 deletions(-) diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index 50811880..9bcdef1f 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -49,48 +49,71 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu // NOTE: Equations with m32 for multiplication not exists, because mul m32 it's an unsigned operation. // In internal equations, it's same than unsigned mul 64 where high part of a and b are zero - eq[0] = fab * a[0] * b[0] // 3 degree - - c[0] - + 2 * np * c[0] - + div * d[0] - - 2 * nr * d[0]; - - eq[1] = fab * a[1] * b[0] // 3 degree - + fab * a[0] * b[1] // 3 degree - - c[1] - + 2 * np * c[1] - + div * d[1] - - 2 * nr * d[1]; - - eq[2] = fab * a[2] * b[0] // 3 degree - + fab * a[1] * b[1] // 3 degree - + fab * a[0] * b[2] // 3 degree - - c[2] - + 2 * np * c[2] - + div * d[2] - - 2 * nr * d[2] - - np * div * m32 // 3 degree - + nr * div * m32; // 3 degree - - eq[3] = fab * a[3] * b[0] // 3 degree - + fab * a[2] * b[1] // 3 degree - + fab * a[1] * b[2] // 3 degree - + fab * a[0] * b[3] // 3 degree - - c[3] - + 2 * np * c[3] - + div * d[3] - - 2 * nr * d[3]; - - eq[4] = fab * a[3] * b[1] // 3 degree - + fab * a[2] * b[2] // 3 degree - + fab * a[1] * b[3] // 3 degree - + na * b[0] * (1 - 2 * nb) // 3 degree + // abs(x) x >= 0 ➜ nx == 0 ➜ x + // x < 0 ➜ nx == 1 ➜ 2^64 - x + // + // abs(x,nx) = nx * (2^64 - 2 * x) + x = 2^64 * nx - 2 * nx * x + x + // + // chunk[0] = x[0] - 2 * nx + x[0] // 2^0 + // chunk[1] = x[1] - 2 * nx + x[1] // 2^16 + // chunk[2] = x[2] - 2 * nx + x[2] // 2^24 + // chunk[3] = x[3] - 2 * nx + x[3] // 2^48 + // chunk[4] = nx // 2^64 + // + // or chunk[3] = x[3] - 2 * nx + x[3] + 2^16 * nx + // chunk[4] = 0 + // + // dual use of d, on multiplication d is high part of result, while in division d + // is the remainder. Selector of these two uses is div or nr (because nr = 0 for div = 0) + // + // div = 0 ➜ a * b = 2^64 * d + c ➜ a * b - 2^64 * d - c === 0 + // div = 1 ➜ a * b + d = c ➜ a * b - c + d === 0 + // + // eq = a * b + c - div * d - (1 - div) * 2^64 * d + + eq[0] = fab * a[0] * b[0] + - c[0] // ⎫ np == 0 ➜ - c + + 2 * np * c[0] // ⎭ np == 1 ➜ - c + 2c = c + + div * d[0] // ⎫ div == 0 ➜ nr = 0 ➜ 0 + - 2 * nr * d[0]; // ⎥ div == 1 and nr == 0 ➜ d + // ⎭ div == 1 and nr == 1 ➜ d - 2d = -d + + eq[1] = fab * a[1] * b[0] + + fab * a[0] * b[1] + - c[1] // ⎫ np == 0 ➜ - c + + 2 * np * c[1] // ⎭ np == 1 ➜ c + + div * d[1] // ⎫ div == 1 ➜ d or -d + - 2 * nr * d[1]; // ⎭ div == 0 ➜ 0 + + eq[2] = fab * a[2] * b[0] + + fab * a[1] * b[1] + + fab * a[0] * b[2] + - c[2] // ⎫ np == 0 ➜ - c + + 2 * np * c[2] // ⎭ np == 1 ➜ c + + div * d[2] // ⎫ div == 1 ➜ d or -d + - 2 * nr * d[2] // ⎭ div == 0 ➜ 0 + - np * div * m32 // m32 == 1 and np == 1 ➜ -2^32 (global) or -1 (in 3rd chunk) + + nr * m32; // m32 == 1 and nr == 1 ➜ div == 1 ➜ 2^32 (global) or 1 (in 3rd chunk) + + eq[3] = fab * a[3] * b[0] + + fab * a[2] * b[1] + + fab * a[1] * b[2] + + fab * a[0] * b[3] // NOTE: m32 => high part is 0 + - c[3] // ⎫ np == 0 ➜ - c + + 2 * np * c[3] // ⎭ np == 1 ➜ c + + div * d[3] // ⎫ div == 1 ➜ d or -d + - 2 * nr * d[3]; // ⎭ div == 0 ➜ 0 + + eq[4] = fab * a[3] * b[1] + + fab * a[2] * b[2] + + fab * a[1] * b[3] + + na * b[0] * (1 - 2 * nb) // nb == 1 ➜ degree + nb * a[0] * (1 - 2 * na) // 3 degree - np * div // | + np * div * m32 // 3 degree | np * (div ^ m32) - 2 * div * m32 * np // 3 degree | - // + nr * (1 - m32) * div // 3 degree - - d[0] * (1 - div) + + nr * (1 - m32) // 3 degree + - d[0] * (1 - div) // 3 degree + 2 * np * d[0] * (1 - div); // 3 degree eq[5] = fab * a[3] * b[2] // 3 degree @@ -109,8 +132,8 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu eq[7] = CHUNK_SIZE * na * nb + na * b[3] * (1 - 2 * nb) // 3 degree + nb * a[3] * (1 - 2 * na) // 3 degree - // - CHUNK_SIZE * np * (1 - div) * (1 - m32) // 3 degree - - CHUNK_SIZE * np * (1 - div) + - CHUNK_SIZE * np * (1 - div) * (1 - m32) // 3 degree + // - CHUNK_SIZE * np * (1 - div) - d[3] * (1 - div) + 2 * np * d[3] * (1 - div); // 3 degree diff --git a/state-machines/arith/src/arith_helpers.rs b/state-machines/arith/src/arith_helpers.rs index 11d99b01..f17f73e9 100644 --- a/state-machines/arith/src/arith_helpers.rs +++ b/state-machines/arith/src/arith_helpers.rs @@ -34,11 +34,13 @@ pub trait ArithHelpers { } } fn sign128(abs_value: u128, negative: bool) -> u128 { - if negative { + let res = if negative { (0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF - abs_value) + 1 } else { abs_value - } + }; + println!("sign128({:X},{})={:X}", abs_value, negative, res); + res } fn abs32(value: u64) -> [u64; 2] { let negative = if (value & 0x8000_0000) != 0 { 1 } else { 0 }; @@ -55,15 +57,16 @@ pub trait ArithHelpers { [abs_value, negative] } fn calculate_mul_w(a: u64, b: u64) -> u64 { - let [abs_a, na] = Self::abs32(a); - let [abs_b, nb] = Self::abs32(b); + a * b + // let [abs_a, na] = Self::abs32(a); + // let [abs_b, nb] = Self::abs32(b); + // let abs_c = abs_a * abs_b; + // let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; // println!( - // "a:0x{0:X}({0}) b:0x{1:X}({1}) abs_a:0x{2:X}({2}) na:{3} abs_b:{4:X}({4}) nb:{5}", - // a, b, abs_a, na, abs_b, nb + // "a:0x{0:X}({0}) b:0x{1:X}({1}) abs_a:0x{2:X}({2}) na:{3} abs_b:0x{4:X}({4}) nb:{5} abs_b:0x{6:X}({6}) nc:{7}", + // a, b, abs_a, na, abs_b, nb, abs_c, nc // ); - let abs_c = abs_a * abs_b; - let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; - Self::sign64(abs_c, nc == 1) + // Self::sign64(abs_c, nc == 1) } fn calculate_mulsu(a: u64, b: u64) -> [u64; 2] { @@ -77,6 +80,10 @@ pub trait ArithHelpers { fn calculate_mul(a: u64, b: u64) -> [u64; 2] { let [abs_a, na] = Self::abs64(a); let [abs_b, nb] = Self::abs64(b); + println!( + "mul(a:0x{:X}, b:0x{:X} abs_a:0x{:X} na:{} abs_b:0x{:X} nb:{}", + a, b, abs_a, na, abs_b, nb, + ); let abs_c = abs_a as u128 * abs_b as u128; let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; let c = Self::sign128(abs_c, nc == 1); @@ -272,18 +279,22 @@ pub trait ArithHelpers { } } let sign_mask: u64 = if m32 == 1 { 0x8000_0000 } else { 0x8000_0000_0000_0000 }; + let sign_c_mask: u64 = + if m32 == 1 && div == 1 { 0x8000_0000 } else { 0x8000_0000_0000_0000 }; let na = if sa && (a & sign_mask) != 0 { 1 } else { 0 }; let nb = if sb && (b & sign_mask) != 0 { 1 } else { 0 }; // a sign => b sign - let nc = if sa && (c & sign_mask) != 0 { 1 } else { 0 }; + let nc = if sa && (c & sign_c_mask) != 0 { 1 } else { 0 }; + let nd = if sa && (d & sign_mask) != 0 { 1 } else { 0 }; // a == 0 || b == 0 => np == 0 ==> how was a signed operation // after that sign of np was verified with range check. + // TODO: review if secure if div == 1 { - np = if c != 0 { nc ^ nb } else { 0 }; - nr = if d != 0 { nc } else { 0 }; + np = nc; //if c != 0 { na ^ nb } else { 0 }; + nr = nd; } else { - np = if (c != 0) || (d != 0) { na ^ nb } else { 0 }; + np = if m32 == 1 { nc } else { nd }; // if (c != 0) || (d != 0) { na ^ nb } else { 0 } nr = 0; } if m32 == 1 { @@ -376,94 +387,8 @@ pub trait ArithHelpers { + range_d1; [m32, div, na, nb, np, nr, sext, secondary_res, range_ab, range_cd, ranges] } - /* - fn calculate_flags( - &self, - op: u8, - a: u64, - b: u64, - na: &mut i64, - nb: &mut i64, - nr: &mut i64, - np: &mut i64, - na32: &mut i64, - nd32: &mut i64, - m32: &mut i64, - div: &mut i64, - fab: &mut i64, - ) -> [u64; 8] { - let MUL_W = 1; - match (op) { - MUL_W=> { - let na = if (a as i32) < 0 { 1 } else { 0 }; - let nb = if (b as i32) < 0 { 1 } else { 0 }; - let c = (a as i32 * b as i32); - let nc = if c < 0 { 1 } else { 0 }; - } - MULSUH => { - let na = if (a as i64) < 0 { 1 } else { 0 }; - let _na = input.a & (2n**63n) ? 1n : 0n; - let _a = _na ? 2n ** 64n - a : a; - let _prod = _a * b; - let _nc = _prod && _na; - _prod = _nc ? 2n**128n - _prod : _prod; - c = _prod & (2n**64n - 1n); - d = _prod >> 64n; - // console.log(input.c.toString(16), c.toString(16)); - break; - } - case 'divu': - case 'divu_w': { - this.log(opdef.n,a,b); - const div = a / b; - const rem = a % b; - c = a; - a = div; - d = rem; - break; - } - case 'div': { - this.log('div',a,b); - let _na = input.a & (2n**63n) ? 1n : 0n; - let _a = _na ? 2n ** 64n - a : a; - let _nb = input.b & (2n**63n) ? 1n : 0n; - let _b = _nb ? 2n ** 64n - b : b; - const div = _a / _b; - const rem = _a % _b; - c = a; - a = (div && _na ^ _nb) ? 2n**64n - div : div; - d = (rem && _na) ? 2n**64n - rem : rem; - break; - } - case 'div_w': { - this.log('div_w',a,b); - let _na = input.a & (2n**31n) ? 1n : 0n; - let _a = _na ? 2n ** 32n - a : a; - let _nb = input.b & (2n**31n) ? 1n : 0n; - let _b = _nb ? 2n ** 32n - b : b; - this.log([_a,_b].map(x => x.toString(16)).join(' ')); - const div = _a / _b; - const rem = _a % _b; - this.log(div, rem, _na, _nb) - c = a; - a = (div && (_na ^ _nb)) ? 2n**32n - div : div; - d = (rem && _na) ? 2n**32n - rem : rem; - this.log('[a,b,c,d]='+[a,b,c,d].map(x => x.toString(16)).join(' ')); - break; - } - } - if (m32) { - this.log(opdef.a_signed, opdef.b_signed, a.toString(16), (a & 0x80000000n).toString(16)); - a = (opdef.a_signed && a & 0x80000000n) ? a | 0xFFFFFFFF00000000n : a; - b = (opdef.b_signed && b & 0x80000000n) ? b | 0xFFFFFFFF00000000n : b; - } - - return [a,b,c,d]; - [0, 0, 0, 0, 0, 0, 0, 0] - } */ fn calculate_chunks( - &self, a: [i64; 4], b: [i64; 4], c: [i64; 4], @@ -482,7 +407,10 @@ pub trait ArithHelpers { let mut chunks: [i64; 8] = [0, 0, 0, 0, 0, 0, 0, 0]; - chunks[0] = fab * a[0] * b[0] // chunk9 + let na_fb = na * (1 - 2 * nb); + let nb_fa = nb * (1 - 2 * na); + + chunks[0] = fab * a[0] * b[0] // chunk0 - c[0] + 2 * np * c[0] + div * d[0] @@ -498,17 +426,21 @@ pub trait ArithHelpers { chunks[2] = fab * a[2] * b[0] // chunk2 + fab * a[1] * b[1] + fab * a[0] * b[2] + + a[0] * nb_fa * m32 + + b[0] * na_fb * m32 - c[2] - + (2 * np) * c[2] + + 2 * np * c[2] + div * d[2] - 2 * nr * d[2] - np * div * m32 - + nr * m32; + + nr * m32; // div == 0 ==> nr = 0 chunks[3] = fab * a[3] * b[0] // chunk3 + fab * a[2] * b[1] + fab * a[1] * b[2] + fab * a[0] * b[3] + + a[1] * nb_fa * m32 + + b[1] * na_fb * m32 - c[3] + 2 * np * c[3] + div * d[3] @@ -517,37 +449,155 @@ pub trait ArithHelpers { chunks[4] = fab * a[3] * b[1] // chunk4 + fab * a[2] * b[2] + fab * a[1] * b[3] - + b[0] * na * (1 - 2 * nb) - + a[0] * nb * (1 - 2 * na) - - np * div - + m32 - - 2 * div * m32 - + nr * (1 - m32) - - d[0] * (1 - div) - + d[0] * 2 * np * (1 - div); + + na * nb * m32 + // + b[0] * na * (1 - 2 * nb) + // + a[0] * nb * (1 - 2 * na) + + b[0] * na_fb * (1 - m32) + + a[0] * nb_fa * (1 - m32) + // high bits ^^^ + // - np * div + // + np * div * m32 + // - 2 * div * m32 * np + - np * m32 * (1 - div) // + - np * (1 - m32) * div // 2^64 (np) + + nr * (1 - m32) // 2^64 (nr) + // high part d + - d[0] * (1 - div) // m32 == 1 and div == 0 => d = 0 + + 2 * np * d[0] * (1 - div); // chunks[5] = fab * a[3] * b[2] // chunk5 + fab * a[2] * b[3] - + a[1] * nb * (1 - 2 * na) - + b[1] * na * (1 - 2 * nb) + + a[1] * nb_fa * (1 - m32) + + b[1] * na_fb * (1 - m32) - d[1] * (1 - div) + d[1] * 2 * np * (1 - div); chunks[6] = fab as i64 * a[3] * b[3] // chunk6 - + a[2] * nb * (1 - 2 * na) - + b[2] * na * (1 - 2 * nb) + + a[2] * nb_fa * (1 - m32) + + b[2] * na_fb * (1 - m32) - d[2] * (1 - div) + d[2] * 2 * np * (1 - div); - chunks[7] = 0x10000 * na * nb // chunk7 - + b[3] * na * (1 - 2 * nb) - + a[3] * nb * (1 - 2 * na) + // 0x4000_0000_0000_0000__8000_0000_0000_0000 + chunks[7] = 0x10000 * na * nb * (1 - m32) // chunk7 + + a[3] * nb_fa * (1 - m32) + + b[3] * na_fb * (1 - m32) - 0x10000 * np * (1 - div) * (1 - m32) - d[3] * (1 - div) + d[3] * 2 * np * (1 - div); chunks } + fn u64_to_chunks(a: u64) -> [i64; 4] { + [ + (a & 0xFFFF) as i64, + ((a >> 16) & 0xFFFF) as i64, + ((a >> 32) & 0xFFFF) as i64, + ((a >> 48) & 0xFFFF) as i64, + ] + } + fn execute_chunks( + a: u64, + b: u64, + c: u64, + d: u64, + m32: u64, + div: u64, + na: u64, + nb: u64, + np: u64, + nr: u64, + ) -> bool { + let fab: i64 = 1 - 2 * na as i64 - 2 * nb as i64 + 4 * na as i64 * nb as i64; + let a_chunks = Self::u64_to_chunks(a); + let b_chunks = Self::u64_to_chunks(b); + let c_chunks = Self::u64_to_chunks(c); + let d_chunks = Self::u64_to_chunks(d); + println!( + "A: 0x{0:>04X} \x1B[32m{0:>5}\x1B[0m|0x{1:>04X} \x1B[32m{1:>5}\x1B[0m|0x{2::>04X} \x1B[32m{2:>5}\x1B[0m|0x{3:>04X} \x1B[32m{3:>5}\x1B[0m|", + a_chunks[0], a_chunks[1], a_chunks[2], a_chunks[3] + ); + println!( + "B: 0x{0:>04X} \x1B[32m{0:>5}\x1B[0m|0x{1:>04X} \x1B[32m{1:>5}\x1B[0m|0x{2::>04X} \x1B[32m{2:>5}\x1B[0m|0x{3:>04X} \x1B[32m{3:>5}\x1B[0m|", + b_chunks[0], b_chunks[1], b_chunks[2], b_chunks[3] + ); + println!( + "C: 0x{0:>04X} \x1B[32m{0:>5}\x1B[0m|0x{1:>04X} \x1B[32m{1:>5}\x1B[0m|0x{2::>04X} \x1B[32m{2:>5}\x1B[0m|0x{3:>04X} \x1B[32m{3:>5}\x1B[0m|", + c_chunks[0], c_chunks[1], c_chunks[2], c_chunks[3] + ); + println!( + "D: 0x{0:>04X} \x1B[32m{0:>5}\x1B[0m|0x{1:>04X} \x1B[32m{1:>5}\x1B[0m|0x{2::>04X} \x1B[32m{2:>5}\x1B[0m|0x{3:>04X} \x1B[32m{3:>5}\x1B[0m|", + d_chunks[0], d_chunks[1], d_chunks[2], d_chunks[3] + ); + + let mut chunks = Self::calculate_chunks( + a_chunks, b_chunks, c_chunks, d_chunks, m32 as i64, div as i64, na as i64, nb as i64, + np as i64, nr as i64, fab, + ); + let mut carry: i64 = 0; + println!( + "0x{0:X}({0}),0x{1:X}({1}),0x{2:X}({2}),0x{3:X}({3}),0x{4:X}({4}),0x{5:X}({5}),0x{6:X}{6},0x{7:X}({7}) fab:{8:X}", + chunks[0], chunks[1], chunks[2], chunks[3], chunks[4], chunks[5], chunks[6], chunks[7], fab + ); + let mut carrys: [i64; 8] = [0, 0, 0, 0, 0, 0, 0, 0]; + for _index in 0..8 { + println!( + "APPLY CARRY:{0} CHUNK[{1}]:{2:X} ({2}) {3:X}({3})", + carry, + _index, + chunks[_index], + chunks[_index] + carry + ); + let chunk_value = chunks[_index] + carry; + carry = chunk_value / 0x10000; + chunks[_index] = chunk_value - carry * 0x10000; + carrys[_index] = carry; + } + println!( + "CARRY 0x{0:X}({0}),0x{1:X}({1}),0x{2:X}({2}),0x{3:X}({3}),0x{4:X}({4}),0x{5:X}({5}),0x{6:X}{6},0x{7:X}({7}) fab:{8:X}", + carrys[0], carrys[1], carrys[2], carrys[3], carrys[4], carrys[5], carrys[6], carrys[7], fab + ); + println!( + "0x{:X},0x{:X},0x{:X},0x{:X},0x{:X},0x{:X},0x{:X},0x{:X} carry:0x{:X}", + chunks[0], + chunks[1], + chunks[2], + chunks[3], + chunks[4], + chunks[5], + chunks[6], + chunks[7], + carry + ); + println!( + "{} {} {} {} {} {} {} {} {}", + chunks[0], + chunks[1], + chunks[2], + chunks[3], + chunks[4], + chunks[5], + chunks[6], + chunks[7], + carry + ); + if chunks[0] != 0 + || chunks[1] != 0 + || chunks[2] != 0 + || chunks[3] != 0 + || chunks[4] != 0 + || chunks[5] != 0 + || chunks[6] != 0 + || chunks[7] != 0 + || carry != 0 + { + println!("[\x1B[31mFAIL\x1B[0m]"); + false + } else { + println!("[\x1B[32mOK\x1B[0m]"); + true + } + } } fn flags_to_strings(mut flags: u64, flag_names: &[&str]) -> String { @@ -1110,18 +1160,18 @@ fn test_calculate_range_checks() { op: DIV, a: MAX_P_64, b: MIN_N_64, - flags: F_DIV + F_NB + F_NP, + flags: F_DIV + F_NB, // a/b = 0 ➜ np = 0 range_ab: R_3PN, - range_cd: R_3NP, + range_cd: R_3PP, }, // 32 - DIV TestParams { op: DIV, a: MIN_N_64, b: MIN_N_64, - flags: F_DIV + F_NB, + flags: F_DIV + F_NB + F_NP, // a/b = 1 ➜ 1 * b_neg ➜ np = 1 range_ab: R_3PN, - range_cd: R_3PP, + range_cd: R_3NP, }, // 33 - DIV TestParams { @@ -1164,18 +1214,18 @@ fn test_calculate_range_checks() { op: REM, a: MAX_P_64, b: MIN_N_64, - flags: F_DIV + F_NB + F_NP + F_SEC, + flags: F_DIV + F_NB + F_SEC, range_ab: R_3PN, - range_cd: R_3NP, + range_cd: R_3PP, }, // 38 - REM TestParams { op: REM, a: MIN_N_64, b: MIN_N_64, - flags: F_DIV + F_NB + F_SEC, + flags: F_DIV + F_NB + F_NP + F_SEC, range_ab: R_3PN, - range_cd: R_3PP, + range_cd: R_3NP, }, // 39 - REM TestParams { @@ -1372,6 +1422,7 @@ fn test_calculate_range_checks() { } let mut tests_done: Vec = Vec::new(); + let mut errors = 0; for test in tests { let a_values = get_test_values(test.a); let mut offset = 0; @@ -1452,11 +1503,16 @@ fn test_calculate_range_checks() { test.range_cd, ranges ); + if !TestArithHelpers::execute_chunks(a, b, c, d, m32, div, na, nb, np, nr) { + errors += 1; + println!("TOTAL ERRORS: {}", errors); + } offset += 1; count += 1; } } index += 1; } + println!("TOTAL ERRORS: {}", errors); assert_eq!(count, TEST_COUNT, "Number of tests not matching"); } From adf8e5b4161bdd5c9397ad19c7c1873046d00918 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Thu, 24 Oct 2024 11:00:25 +0000 Subject: [PATCH 12/17] WIP arith helpers pass bus tests --- state-machines/arith/pil/arith.pil | 55 ++++++++------ state-machines/arith/src/arith_helpers.rs | 87 +++++++++++++++++++++-- 2 files changed, 116 insertions(+), 26 deletions(-) diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index 9bcdef1f..bbf8a2b8 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -36,6 +36,9 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu // fab = 1 if sign of a,b are the same // fab = -1 if sign of a,b are different + col witness na_fb; + col witness nb_fa; + col witness debug_main_step; // only for debug col witness secondary_res; // op_index: 0 => first result, 1 => second result; @@ -43,6 +46,8 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu // factor ab € {-1, 1} fab === 1 - 2 * na - 2 * nb + 4 * na * nb; + na_fb === na * (1 - 2 * nb); + nb_fa === nb * (1 - 2 * na); const expr eq[CHUNKS_OP]; @@ -88,17 +93,21 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu eq[2] = fab * a[2] * b[0] + fab * a[1] * b[1] + fab * a[0] * b[2] - - c[2] // ⎫ np == 0 ➜ - c - + 2 * np * c[2] // ⎭ np == 1 ➜ c - + div * d[2] // ⎫ div == 1 ➜ d or -d - - 2 * nr * d[2] // ⎭ div == 0 ➜ 0 - - np * div * m32 // m32 == 1 and np == 1 ➜ -2^32 (global) or -1 (in 3rd chunk) - + nr * m32; // m32 == 1 and nr == 1 ➜ div == 1 ➜ 2^32 (global) or 1 (in 3rd chunk) + + a[0] * nb_fa * m32 // ⎫ sign contribution when m32 + + b[0] * na_fb * m32 // ⎭ + - c[2] // ⎫ np == 0 ➜ - c + + 2 * np * c[2] // ⎭ np == 1 ➜ c + + div * d[2] // ⎫ div == 1 ➜ d or -d + - 2 * nr * d[2] // ⎭ div == 0 ➜ 0 + - np * div * m32 // m32 == 1 and np == 1 ➜ -2^32 (global) or -1 (in 3rd chunk) + + nr * m32; // m32 == 1 and nr == 1 ➜ div == 1 ➜ 2^32 (global) or 1 (in 3rd chunk) eq[3] = fab * a[3] * b[0] + fab * a[2] * b[1] + fab * a[1] * b[2] + fab * a[0] * b[3] // NOTE: m32 => high part is 0 + + a[1] * nb_fa * m32 // ⎫ sign contribution when m32 + + b[1] * na_fb * m32 // ⎭ - c[3] // ⎫ np == 0 ➜ - c + 2 * np * c[3] // ⎭ np == 1 ➜ c + div * d[3] // ⎫ div == 1 ➜ d or -d @@ -107,31 +116,35 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu eq[4] = fab * a[3] * b[1] + fab * a[2] * b[2] + fab * a[1] * b[3] - + na * b[0] * (1 - 2 * nb) // nb == 1 ➜ degree - + nb * a[0] * (1 - 2 * na) // 3 degree - - np * div // | - + np * div * m32 // 3 degree | np * (div ^ m32) - - 2 * div * m32 * np // 3 degree | - + nr * (1 - m32) // 3 degree + + na * nb * m32 + // + b[0] * na * (1 - 2 * nb) + // + a[0] * nb * (1 - 2 * na) + + b[0] * na_fb * (1 - m32) + + a[0] * nb_fa * (1 - m32) + + - np * m32 * (1 - div) // + - np * (1 - m32) * div // 2^64 (np) + + nr * (1 - m32) // 2^64 (nr) + - d[0] * (1 - div) // 3 degree + 2 * np * d[0] * (1 - div); // 3 degree eq[5] = fab * a[3] * b[2] // 3 degree + fab * a[2] * b[3] // 3 degree - + nb * a[1] * (1 - 2 * na) - + na * b[1] * (1 - 2 * nb) + + a[1] * nb_fa * (1 - m32) + + b[1] * na_fb * (1 - m32) - d[1] * (1 - div) - + 2 * np * d[1] * (1 - div); + + d[1] * 2 * np * (1 - div); eq[6] = fab * a[3] * b[3] // 3 degree - + nb * a[2] * (1 - 2 * na) // 3 degree - + na * b[2] * (1 - 2 * nb) // 3 degree + + a[2] * nb_fa * (1 - m32) + + b[2] * na_fb * (1 - m32) - d[2] * (1 - div) + 2 * np * d[2] * (1 - div); // 3 degree - eq[7] = CHUNK_SIZE * na * nb - + na * b[3] * (1 - 2 * nb) // 3 degree - + nb * a[3] * (1 - 2 * na) // 3 degree + eq[7] = CHUNK_SIZE * na * nb * (1 - m32) + + a[3] * nb_fa * (1 - m32) + + b[3] * na_fb * (1 - m32) - CHUNK_SIZE * np * (1 - div) * (1 - m32) // 3 degree // - CHUNK_SIZE * np * (1 - div) - d[3] * (1 - div) @@ -194,7 +207,7 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu bus_a_low === div * (c[0] + c[1] * CHUNK_SIZE) + (1 - div) * (a[0] + a[1] * CHUNK_SIZE); col witness bus_a_high; - bus_a_high === div * (c[2] + c[2] * CHUNK_SIZE) + (1 - div) * (a[2] + a[3] * CHUNK_SIZE); + bus_a_high === div * (c[2] + c[3] * CHUNK_SIZE) + (1 - div) * (a[2] + a[3] * CHUNK_SIZE); m32 * (1 - bus_a_high) === 0; diff --git a/state-machines/arith/src/arith_helpers.rs b/state-machines/arith/src/arith_helpers.rs index f17f73e9..83d775f1 100644 --- a/state-machines/arith/src/arith_helpers.rs +++ b/state-machines/arith/src/arith_helpers.rs @@ -400,6 +400,8 @@ pub trait ArithHelpers { np: i64, nr: i64, fab: i64, + secondary_res: i64, + sext: i64, ) -> [i64; 8] { // TODO: unroll this function in variants (div,m32) and (na,nb,nr,np) // div, m32, na, nb === f(div,m32,na,nb) => fa, nb, nr @@ -507,6 +509,9 @@ pub trait ArithHelpers { nb: u64, np: u64, nr: u64, + secondary_res: u64, + sext: u64, + bus: [u64; 8], ) -> bool { let fab: i64 = 1 - 2 * na as i64 - 2 * nb as i64 + 4 * na as i64 * nb as i64; let a_chunks = Self::u64_to_chunks(a); @@ -531,8 +536,19 @@ pub trait ArithHelpers { ); let mut chunks = Self::calculate_chunks( - a_chunks, b_chunks, c_chunks, d_chunks, m32 as i64, div as i64, na as i64, nb as i64, - np as i64, nr as i64, fab, + a_chunks, + b_chunks, + c_chunks, + d_chunks, + m32 as i64, + div as i64, + na as i64, + nb as i64, + np as i64, + nr as i64, + fab, + secondary_res as i64, + sext as i64, ); let mut carry: i64 = 0; println!( @@ -581,7 +597,7 @@ pub trait ArithHelpers { chunks[7], carry ); - if chunks[0] != 0 + let mut passed = if chunks[0] != 0 || chunks[1] != 0 || chunks[2] != 0 || chunks[3] != 0 @@ -596,7 +612,56 @@ pub trait ArithHelpers { } else { println!("[\x1B[32mOK\x1B[0m]"); true - } + }; + const CHUNK_SIZE: i64 = 0x10000; + let bus_a_low: i64 = div as i64 * (c_chunks[0] + c_chunks[1] * CHUNK_SIZE) + + (1 - div as i64) * (a_chunks[0] + a_chunks[1] * CHUNK_SIZE); + let bus_a_high: i64 = div as i64 * (c_chunks[2] + c_chunks[3] * CHUNK_SIZE) + + (1 - div as i64) * (a_chunks[2] + a_chunks[3] * CHUNK_SIZE); + + let bus_b_low: i64 = b_chunks[0] + CHUNK_SIZE * b_chunks[1]; + let bus_b_high: i64 = b_chunks[2] + CHUNK_SIZE * b_chunks[3]; + + let res2_low: i64 = d_chunks[0] + CHUNK_SIZE * d_chunks[1]; + let res2_high: i64 = d_chunks[2] + CHUNK_SIZE * d_chunks[3]; + + let res_low: i64 = secondary_res as i64 * res2_low + + (1 - secondary_res as i64) + * (a_chunks[0] + c_chunks[0] + CHUNK_SIZE * (a_chunks[1] + c_chunks[1]) + - bus_a_low); + println!( + "RES_LOW: 0x{0:X}({0}) 0x{1:X}({1}) 0x{2:X}({2})", + res_low, + a_chunks[2] + c_chunks[2] + CHUNK_SIZE * (a_chunks[3] + c_chunks[3]), + bus_a_high + ); + let res_high: i64 = (1 - m32 as i64) + * (secondary_res as i64 * res2_high + + (1 - secondary_res as i64) + * ((a_chunks[2] + c_chunks[2] + CHUNK_SIZE * (a_chunks[3] + c_chunks[3])) + - bus_a_high)) + + sext as i64 * 0xFFFFFFFF; + passed = passed + && if bus[1] != bus_a_low as u64 + || bus[2] != bus_a_high as u64 + || bus[3] != bus_b_low as u64 + || bus[4] != bus_b_high as u64 + || bus[5] != res_low as u64 + || bus[6] != res_high as u64 + { + println!("0x{0:X} ({0}) vs 0x{1:X} ({1})", bus[1], bus_a_low); + println!("0x{0:X} ({0}) vs 0x{1:X} ({1})", bus[2], bus_a_high); + println!("0x{0:X} ({0}) vs 0x{1:X} ({1})", bus[3], bus_b_low); + println!("0x{0:X} ({0}) vs 0x{1:X} ({1})", bus[4], bus_b_high); + println!("0x{0:X} ({0}) vs 0x{1:X} ({1})", bus[5], res_low); + println!("0x{0:X} ({0}) vs 0x{1:X} ({1})", bus[6], res_high); + println!("[\x1B[31mFAIL BUS\x1B[0m]"); + false + } else { + println!("[\x1B[32mOK BUS\x1B[0m]"); + true + }; + passed } } @@ -1503,7 +1568,19 @@ fn test_calculate_range_checks() { test.range_cd, ranges ); - if !TestArithHelpers::execute_chunks(a, b, c, d, m32, div, na, nb, np, nr) { + let bus: [u64; 8] = [ + test.op as u64, + _a & 0xFFFF_FFFF, + _a >> 32, + _b & 0xFFFF_FFFF, + _b >> 32, + emu_c & 0xFFFF_FFFF, + emu_c >> 32, + if emu_flag { 1 } else { 0 }, + ]; + if !TestArithHelpers::execute_chunks( + a, b, c, d, m32, div, na, nb, np, nr, sec, sext, bus, + ) { errors += 1; println!("TOTAL ERRORS: {}", errors); } From 327846b0c979bedfe415fb897ae23bba6881d063 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Tue, 29 Oct 2024 23:45:36 +0000 Subject: [PATCH 13/17] WIP update pil of arith_table, add method to calculate row of table --- state-machines/arith/pil/arith_table.pil | 328 ++++++++++++++-------- state-machines/arith/src/arith_helpers.rs | 49 +++- 2 files changed, 252 insertions(+), 125 deletions(-) diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil index d2225d24..9cdcd894 100644 --- a/state-machines/arith/pil/arith_table.pil +++ b/state-machines/arith/pil/arith_table.pil @@ -2,130 +2,201 @@ require "std_lookup.pil" const int ARITH_TABLE_ID = 331; -airtemplate ArithTable(int N = 2**6) { +airtemplate ArithTable(int N = 2**8, int generate_table = 1) { // TABLE - // op - // m32|div|na|nb|nr|np|na32|nd32|range_a1(*)|range_b1(*)|range_c1(*)|range_d1(*)|range_a3(*)|range_b3(*)|range_c3(*)|range_d3(*) - - // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 - // ---------------------------------------------------------------------------------- - // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 - // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 =0 d3 =0 =0 a3, d3 - // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 - // 0 1 1 1 x mul_w *n/a* (0xb6,0xb7) a1 b1 =0 d3 c1 =0 d3, a1,b1,c1 - // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 - // 1 0 1 1 div rem (0xba,0xbb) a3 b3 d3 c3 =0 =0 a3,b3,c3,d3 - // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 c1 d1 c1,d1 - // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 d1 c1 c1 d1 a1,b1,c1,d1 + // op,m32|div|na|nb|np|nr|sext,range_ab,range_cd + + // div m32 sa sb primary secondary opcodes na nb np nr sext(c) + // ----------------------------------------------------------------------------- + // 0 0 0 0 mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 + // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 d3 =0 =0 =0 a3, d3 + // 0 0 1 1 mul mulh (0xb4,0xb5) a3 b3 d3 =0 =0 =0 a3,b3, d3 + // 0 1 0 0 mul_w *n/a* (0xb6,0xb7) =0 =0 =0 =0 c1 =0 + + // div m32 sa sb primary secondary opcodes na nb np nr sext(a,d)(*2) + // ------------------------------------------------------------------------------ + // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 + // 1 0 1 1 div rem (0xba,0xbb) a3 b3 c3 d3 =0 =0 a3,b3,c3,d3 + // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 a1 d1 a1 ,d1 + // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 c1 d1 a1 d1 a1,b1,c1,d1 const int OPS[14] = [0xb0, 0xb1, 0xb3, 0xb4, 0xb5, 0xb6, 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf]; col fixed OP; - col fixed FLAGS_AND_RANGES; + col fixed FLAGS; + col fixed RANGE_AB; + col fixed RANGE_CD; int index = 0; - int size = 0; - while (index < N) { - for (int iop = 0; iop < length(OPS); ++iop) { - int opcode = OPS[iop]; - int m32 = 0; // 32 bits operation - int div = 0; // division operation (div,rem) - int sa = 0; - int sb = 0; - - switch (opcode & 0xFE) { - case 0xb2: // mulsuh - sa = 1; - case 0xb4: // mul, mulh - sa = 1; - sb = 1; - case 0xb6: // mul_w - m32 = 1; - sa = 1; - sb = 1; - case 0xb8: // divu, remu - div = 1; - case 0xba: // div, rem - sa = 1; - sb = 1; - div = 1; - case 0xbc: // divu_w, remu_w - div = 1; - m32 = 1; - case 0xbe: // div_w, rem_w - sa = 1; - sb = 1; - div = 1; - m32 = 1; - } + int aborted = 0; + + if (generate_table) { + int air.op2row[512]; + for (int i = 0; i < 512; ++i) { + op2row[i] = -1; + } + } - // CASES: - // sa = 0 sb = 0 => [a >= 0, b >= 0] - // sa = 1 sb = 0 => [a >= 0, b >= 0], [a < 0, b >= 0] - // sa = 1 sb = 1 => [a >= 0, b >= 0], [a < 0, b >= 0], [a >= 0, b < 0], [a < 0, b < 0] + for (int opcode = 0xb0; opcode <= 0xbf; ++opcode) { + if (opcode == 0xb2 || opcode == 0xb7) { + continue; + } + int m32 = 0; // 32 bits operation + int div = 0; // division operation (div,rem) + int sa = 0; + int sb = 0; + int secondary = 0; + string opname = ""; + switch (opcode) { + case 0xb0: + opname = "mulu"; + case 0xb1: + opname = "mulh"; + secondary = 1; + case 0xb3: + opname = "mulsuh"; + sa = 1; + case 0xb4: + opname = "mul"; + sa = 1; + sb = 1; + case 0xb5: + opname = "mulh"; + sa = 1; + sb = 1; + secondary = 1; + case 0xb6: + opname = "mul_w"; + m32 = 1; + sa = 1; + sb = 1; + case 0xb8: + opname = "divu"; + div = 1; + case 0xb9: + opname = "remu"; + div = 1; + secondary = 1; + case 0xba: + opname = "div"; + sa = 1; + sb = 1; + div = 1; + case 0xbb: + opname = "rem"; + sa = 1; + sb = 1; + div = 1; + secondary = 1; + case 0xbc: + opname = "divu_w"; + div = 1; + m32 = 1; + case 0xbd: + opname = "remu_w"; + div = 1; + m32 = 1; + secondary = 1; + case 0xbe: + opname = "div_w"; + sa = 1; + sb = 1; + div = 1; + m32 = 1; + case 0xbf: + opname = "rem_w"; + sa = 1; + sb = 1; + div = 1; + m32 = 1; + secondary = 1; + } - int cases = 1 + sa + sb + sa * sb; + for (int icase = 0; icase < 32; ++icase) { + int na = 0x01 & icase ? 1 : 0; + int nb = 0x02 & icase ? 1 : 0; + int np = 0x04 & icase ? 1 : 0; + int nr = 0x08 & icase ? 1 : 0; + int sext = 0x10 & icase ? 1 : 0; - println("#ARITH_TABLE", opcode, index, cases); + if (sext && !m32) continue; + if (nr && !div) continue; + if (na && !sa) continue; + if (nb && !sb) continue; + if (np && !sa && !sb) continue; + if (nr && !sa && !sb) continue; + if (np && !na && !nb && !div) continue; + if (np && na && nb) continue; + if (!np & nr) continue; + int range_a1 = m32 * sa ? 1 + na : 0; + int range_b1 = m32 * sb ? 1 + nb : 0; + + int range_c1 = 0; + if (m32) { + if (div) { + // range_c1 = np || na32 ? 2 : 1; + } else { + // range_c1 = 1 + na32; + } + } + int range_d1 = m32 * div ? (((np * sa) || sext) ? 1:2) : 0; + + int range_a3 = (1 - m32) * sa ? 1 + na : 0; + int range_b3 = (1 - m32) * sb ? 1 + na : 0; + int range_c3 = div * (1 - m32) * sa ? 1 + np : 0; + int range_d3 = div * (1 - m32) * sa ? 1 + np : 0; + + if (generate_table) { + op2row[(opcode - 0xb0) * 32 + icase] = index; + } + println(`==> #${index} op:${opname} [${opcode}] na:${na} nb:${nb} np:${np} nr:${nr} sext:${sext} sa:${sa} sb:${sb} secondary:${secondary}`); + ++index; + } + } +/* + int cases = 1 + 2 * sa + 2 * sb + sa * sb; + for (int sext = 0; sext < (1 + m32); ++sext) { for (int icase = 0; icase < cases; ++icase) { - int na = 0; // a is negative - int nb = 0; // b is negative - int np = 0; // prod is negative - int nr = 0; // rem is negative - int na32 = 0; // a is 32-bit negative, 31th bit is 1. - int nd32 = 0; // d is 32-bit negative, 31th bit is 1. + + int na = 0; // a is negative + int nb = 0; // b is negative + int a_is_zero = 0; // a is zero + int b_is_zero = 0; // b is zero + int np = 0; // prod is negative + int nr = 0; // rem is negative + int abort_case = 0; // if abort copy values of row 0 switch (icase) { + // case 0: [a >= 0, b >= 0] case 1: + // [a < 0, b > 0] na = 1; case 2: - nb = 1; + // [a < 0, b = 0] + na = 1; + b_is_zero = 1; case 3: + // [a > 0, b < 0] + nb = 1; + case 4: + // [a = 0, b < 0] + nb = 1; + a_is_zero = 1; + case 5: + // [a < 0, b < 0] na = 1; nb = 1; } + if ((div && b_is_zero) || (sext && (a_is_zero || b_is_zero))) { + abort_case = 1; + } np = na + nb - na * nb; + if (np && (a_is_zero || b_is_zero)) { + // - * 0 = 0 (no negative) + np = 0; + } nr = div ? na : 0; - na32 = m32 ? na : 0; - nd32 = m32 ? nr : 0; - - // negative a,c,d,na32,nd32 must be 0 if no signed_a - // na * (1 - sa) === 0; - // nr * (1 - sa) === 0; - // nr * (1 - div) === 0; - // np * (1 - sa) === 0; - // na32 * (1 - sa) === 0; - // nd32 * (1 - sa) === 0; - - // negative b must be 0 if no signed_b - // nb * (1 - sb) === 0; - - // na32, nd32 only available when 32 bits operation - // na32 * (1 - m32) === 0; - // nd32 * (1 - m32) === 0; - - // nr, nd32 only could be one 1 in divisions - // nr * (1 - div) === 0; - // nd32 * (1 - div) === 0; - - // if sb === 1 then sa must be 1, not allowed sa = 0, sb = 1 - // sb * (1 - sa) === 0; - // m32 * (sa - sb) === 0; - // div * (sa - sb) === 0; - // (1 - div) * m32 * (1 - sa) === 0; - // (1 - div) * m32 * (1 - sb) === 0; - - // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 - // ---------------------------------------------------------------------------------- - // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 - // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 =0 d3 =0 =0 a3, d3 - // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 - // 0 1 1 1 x mul_w *n/a* (0xb6,0xb7) a1 b1 =0 d3 c1 =0 d3, a1,b1,c1 - // 1 0 0 0 divu remu (0xb8,0xb9) =0 =0 =0 =0 =0 =0 - // 1 0 1 1 div rem (0xba,0xbb) a3 b3 d3 c3 =0 =0 a3,b3,c3,d3 - // 1 1 0 0 divu_w remu_w (0xbc,0xbd) =0 =0 =0 =0 c1 d1 c1,d1 - // 1 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 d1 c1 c1 d1 a1,b1,c1,d1 int range_a1 = m32 * sa ? 1 + na : 0; int range_b1 = m32 * sb ? 1 + nb : 0; @@ -133,41 +204,68 @@ airtemplate ArithTable(int N = 2**6) { int range_c1 = 0; if (m32) { if (div) { - range_c1 = np || na32 ? 2 : 1; + // range_c1 = np || na32 ? 2 : 1; } else { - range_c1 = 1 + na32; + // range_c1 = 1 + na32; } } - int range_d1 = m32 * div ? (((np * sa) || nd32) ? 1:2) : 0; + int range_d1 = m32 * div ? (((np * sa) || sext) ? 1:2) : 0; int range_a3 = (1 - m32) * sa ? 1 + na : 0; int range_b3 = (1 - m32) * sb ? 1 + na : 0; int range_c3 = div * (1 - m32) * sa ? 1 + np : 0; int range_d3 = div * (1 - m32) * sa ? 1 + np : 0; - OP[index] = opcode; - FLAGS_AND_RANGES[index] = m32 + 2 * div + 4 * na + 8 * nb + 16 * np + 32 * nr + 64 * na32 + 128 * nd32 + - 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1 + - 2**16 * range_a3 + 2**18 * range_b3 + 2**20 * range_c3 + 2**22 * range_d3; + + if (abort_case) { + println(`ABORT op:${opcode} sa:${sa} sb:${sb} a_is_zero:${a_is_zero} b_is_zero:${b_is_zero} na:${na} nb:${nb} np:${np} nr:${nr} sext:${sext} secondary:${secondary}`); + OP[index] = OP[0]; + FLAGS[index] = FLAGS[0]; + RANGE_AB[index] = RANGE_AB[0]; + RANGE_CD[index] = RANGE_CD[0]; + ++aborted; + } else { + println(`ADD op:${opcode} sa:${sa} sb:${sb} a_is_zero:${a_is_zero} b_is_zero:${b_is_zero} na:${na} nb:${nb} np:${np} nr:${nr} sext:${sext} secondary:${secondary}`); + OP[index] = opcode; + FLAGS[index] = m32 + 2 * div + 4 * na + 8 * nb + 16 * np + 32 * nr + 64 * sext + 128 * secondary; + RANGE_AB[index] = 0; + RANGE_CD[index] = 0; + } index = index + 1; - if (index == N) break; } - if (index == N) break; } - if (size == 0) size = index; - } + }*/ + const int size = index; println("ARITH_TABLE SIZE: ", size); - println("ARITH_FLAGS: ", FLAGS_AND_RANGES); + println("ARITH_TABLE ABORTED: ", aborted); + println("ARITH_FLAGS: ", FLAGS); + if (generate_table) { + println("let arith_table_rows: [i16; 512] = [", op2row, "];"); + } + return; for (index = 0; index < size; ++index) { - println(FLAGS_AND_RANGES[index]); + println(FLAGS[index]); } + // padding repeat first row + + const int padding_op = OP[0]; + const int padding_flags = FLAGS[0]; + const int padding_range_ab = RANGE_AB[0]; + const int padding_range_cd = RANGE_CD[0]; + + for (index = size; index < N; ++index) { + OP[index] = padding_op; + FLAGS[index] = padding_flags; + RANGE_AB[index] = padding_range_ab; + RANGE_CD[index] = padding_range_cd; + } col witness multiplicity; // TODO: - lookup_proves(ARITH_TABLE_ID, mul: multiplicity, cols: [OP, FLAGS_AND_RANGES, 0, 0]); + lookup_proves(ARITH_TABLE_ID, mul: multiplicity, cols: [OP, FLAGS, RANGE_AB, RANGE_CD]); } function arith_table_assumes( const expr op, const expr flag_m32, const expr flag_div, const expr flag_na, diff --git a/state-machines/arith/src/arith_helpers.rs b/state-machines/arith/src/arith_helpers.rs index 83d775f1..6b5cfdf8 100644 --- a/state-machines/arith/src/arith_helpers.rs +++ b/state-machines/arith/src/arith_helpers.rs @@ -18,6 +18,38 @@ const REM_W: u8 = 0xbf; const FLAG_NAMES: [&str; 8] = ["m32", "div", "na", "nb", "np", "nr", "sext", "sec"]; pub trait ArithHelpers { + fn get_row(op: u8, na: u64, nb: u64, np: u64, nr: u64, sext: u64) -> i16 { + static arith_table_rows: [i16; 512] = [ + 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, 2, 3, -1, -1, -1, 4, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 5, 6, 7, 8, -1, + 9, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, 11, 12, 13, 14, -1, 15, 16, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 17, 18, 19, 20, -1, 21, 22, + -1, -1, -1, -1, -1, -1, -1, -1, -1, 23, 24, 25, 26, -1, 27, 28, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 29, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, 30, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 31, 32, 33, 34, 35, 36, 37, -1, -1, -1, -1, + -1, 38, 39, 40, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 41, + 42, 43, 44, 45, 46, 47, -1, -1, -1, -1, -1, 48, 49, 50, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, 51, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, 52, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 53, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 54, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, 62, 63, 64, + -1, 65, 66, 67, 68, 69, 70, 71, -1, -1, -1, -1, -1, 72, 73, 74, -1, 75, 76, 77, 78, 79, + 80, 81, -1, -1, -1, -1, -1, 82, 83, 84, -1, 85, 86, 87, 88, 89, 90, 91, -1, -1, -1, -1, + -1, 92, 93, 94, -1, + ]; + + let index = (op - 0xb0) as u64 * 32 + na + nb * 2 + np * 4 + nr * 8 + sext * 16; + arith_table_rows[index as usize] + } + // arith_sm fn sign32(abs_value: u64, negative: bool) -> u64 { assert!(0xFFFF_FFFF >= abs_value, "abs_value:0x{0:X}({0}) is too big", abs_value); if negative { @@ -57,16 +89,7 @@ pub trait ArithHelpers { [abs_value, negative] } fn calculate_mul_w(a: u64, b: u64) -> u64 { - a * b - // let [abs_a, na] = Self::abs32(a); - // let [abs_b, nb] = Self::abs32(b); - // let abs_c = abs_a * abs_b; - // let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; - // println!( - // "a:0x{0:X}({0}) b:0x{1:X}({1}) abs_a:0x{2:X}({2}) na:{3} abs_b:0x{4:X}({4}) nb:{5} abs_b:0x{6:X}({6}) nc:{7}", - // a, b, abs_a, na, abs_b, nb, abs_c, nc - // ); - // Self::sign64(abs_c, nc == 1) + (a & 0xFFFF_FFFF) * (b & 0xFFFF_FFFF) } fn calculate_mulsu(a: u64, b: u64) -> [u64; 2] { @@ -1525,6 +1548,11 @@ fn test_calculate_range_checks() { let flags = m32 + div * 2 + na * 4 + nb * 8 + np * 16 + nr * 32 + sext * 64 + sec * 128; + let row = TestArithHelpers::get_row(test.op, na, nb, np, nr, sext); + println!( + "#{} op:0x{:x} na:{} nb:{} np:{} nr:{} sext:{}", + row, test.op, na, nb, np, nr, sext + ); assert_eq!( [flags, range_ab, range_cd], [test.flags, test.range_ab, test.range_cd], @@ -1568,6 +1596,7 @@ fn test_calculate_range_checks() { test.range_cd, ranges ); + assert_ne!(row, -1); let bus: [u64; 8] = [ test.op as u64, _a & 0xFFFF_FFFF, From 334a41aad23260ec021c2da11fa67029c9f37aaf Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Wed, 30 Oct 2024 20:20:18 +0000 Subject: [PATCH 14/17] WIP arith helper tests --- state-machines/arith/src/arith_helpers.rs | 61 ++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/state-machines/arith/src/arith_helpers.rs b/state-machines/arith/src/arith_helpers.rs index 6b5cfdf8..eac6ffdb 100644 --- a/state-machines/arith/src/arith_helpers.rs +++ b/state-machines/arith/src/arith_helpers.rs @@ -190,6 +190,24 @@ pub trait ArithHelpers { } } } + fn decode_one_range(range_xy: u64) -> [u64; 4] { + if range_xy == 9 { + [0, 0, 0, 0] + } else if range_xy > 9 { + let x = (range_xy - 8) / 3; + let y = (range_xy - 8) % 3; + [0, 0, x, y] + } else { + let x = range_xy / 3; + let y = range_xy % 3; + [x, y, 0, 0] + } + } + fn decode_ranges(range_ab: u64, range_cd: u64) -> [u64; 8] { + let ab = Self::decode_one_range(range_ab); + let cd = Self::decode_one_range(range_cd); + [ab[0], ab[1], cd[0], cd[1], ab[2], ab[3], cd[2], cd[3]] + } fn calculate_flags_and_ranges(op: u8, a: u64, b: u64, c: u64, d: u64) -> [u64; 11] { let mut m32: u64 = 0; let mut div: u64 = 0; @@ -534,6 +552,8 @@ pub trait ArithHelpers { nr: u64, secondary_res: u64, sext: u64, + range_ab: u64, + range_cd: u64, bus: [u64; 8], ) -> bool { let fab: i64 = 1 - 2 * na as i64 - 2 * nb as i64 + 4 * na as i64 * nb as i64; @@ -684,8 +704,47 @@ pub trait ArithHelpers { println!("[\x1B[32mOK BUS\x1B[0m]"); true }; + // check all chunks and carries + let carry_min_value: i64 = -0x0F_FFFF; + let carry_max_value: i64 = 0x0F_FFFF; + for index in 0..8 { + passed = passed + && if carrys[index] > carry_max_value || carrys[index] < carry_min_value { + println!("[\x1B[31mFAIL CARRY RANGE CHECK\x1B[0m]"); + false + } else { + println!("[\x1B[32mOK CARRY RANGE CHECK\x1B[0m]"); + true + }; + } + let ranges = Self::decode_ranges(range_ab, range_cd); + Self::check_range(0, a_chunks[0]); + Self::check_range(0, b_chunks[0]); + Self::check_range(0, c_chunks[0]); + Self::check_range(0, d_chunks[0]); + + Self::check_range(ranges[4], a_chunks[1]); + Self::check_range(ranges[5], b_chunks[1]); + Self::check_range(ranges[6], c_chunks[1]); + Self::check_range(ranges[7], d_chunks[1]); + + Self::check_range(0, a_chunks[2]); + Self::check_range(0, b_chunks[2]); + Self::check_range(0, c_chunks[2]); + Self::check_range(0, d_chunks[2]); + + Self::check_range(ranges[0], a_chunks[3]); + Self::check_range(ranges[1], b_chunks[3]); + Self::check_range(ranges[2], c_chunks[3]); + Self::check_range(ranges[3], d_chunks[3]); + passed } + fn check_range(range_id: u64, value: i64) { + assert!(range_id != 0 || (value >= 0 && value <= 0xFFFF)); + assert!(range_id != 1 || (value >= 0 && value <= 0x7FFF)); + assert!(range_id != 2 || (value >= 0x8000 && value <= 0xFFFF)); + } } fn flags_to_strings(mut flags: u64, flag_names: &[&str]) -> String { @@ -1608,7 +1667,7 @@ fn test_calculate_range_checks() { if emu_flag { 1 } else { 0 }, ]; if !TestArithHelpers::execute_chunks( - a, b, c, d, m32, div, na, nb, np, nr, sec, sext, bus, + a, b, c, d, m32, div, na, nb, np, nr, sec, sext, range_ab, range_cd, bus, ) { errors += 1; println!("TOTAL ERRORS: {}", errors); From 54362c300fdc903bc202e391e92a597088bec659 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Mon, 4 Nov 2024 11:23:48 +0000 Subject: [PATCH 15/17] WIP arith - test ok --- pil/src/pil_helpers/pilout.rs | 14 +- pil/src/pil_helpers/traces.rs | 8 +- pil/zisk.pil | 5 +- state-machines/arith/pil/arith.pil | 137 +- .../arith/pil/arith_range_table.pil | 76 +- state-machines/arith/pil/arith_table.pil | 166 +- state-machines/arith/src/arith_constants.rs | 14 + state-machines/arith/src/arith_full.rs | 70 +- state-machines/arith/src/arith_helpers.rs | 1683 ----------------- state-machines/arith/src/arith_operation.rs | 609 ++++++ .../arith/src/arith_operation_test.rs | 1115 +++++++++++ .../arith/src/arith_range_table_helpers.rs | 133 ++ .../arith/src/arith_table_helpers.rs | 71 + state-machines/arith/src/lib.rs | 10 +- .../binary/pil/binary_extension.pil | 9 +- 15 files changed, 2228 insertions(+), 1892 deletions(-) create mode 100644 state-machines/arith/src/arith_constants.rs delete mode 100644 state-machines/arith/src/arith_helpers.rs create mode 100644 state-machines/arith/src/arith_operation.rs create mode 100644 state-machines/arith/src/arith_operation_test.rs create mode 100644 state-machines/arith/src/arith_range_table_helpers.rs create mode 100644 state-machines/arith/src/arith_table_helpers.rs diff --git a/pil/src/pil_helpers/pilout.rs b/pil/src/pil_helpers/pilout.rs index 0dc67191..31a83907 100644 --- a/pil/src/pil_helpers/pilout.rs +++ b/pil/src/pil_helpers/pilout.rs @@ -22,9 +22,7 @@ pub const BINARY_EXTENSION_AIRGROUP_ID: usize = 6; pub const BINARY_EXTENSION_TABLE_AIRGROUP_ID: usize = 7; -pub const U_16_AIR_AIRGROUP_ID: usize = 8; - -pub const SPECIFIED_RANGES_AIRGROUP_ID: usize = 9; +pub const SPECIFIED_RANGES_AIRGROUP_ID: usize = 8; //AIR CONSTANTS @@ -44,8 +42,6 @@ pub const BINARY_EXTENSION_AIR_IDS: &[usize] = &[0]; pub const BINARY_EXTENSION_TABLE_AIR_IDS: &[usize] = &[0]; -pub const U_16_AIR_AIR_IDS: &[usize] = &[0]; - pub const SPECIFIED_RANGES_AIR_IDS: &[usize] = &[0]; pub struct Pilout; @@ -64,11 +60,11 @@ impl Pilout { let air_group = pilout.add_air_group(Some("ArithTable")); - air_group.add_air(Some("ArithTable"), 64); + air_group.add_air(Some("ArithTable"), 128); let air_group = pilout.add_air_group(Some("ArithRangeTable")); - air_group.add_air(Some("ArithRangeTable"), 131072); + air_group.add_air(Some("ArithRangeTable"), 4194304); let air_group = pilout.add_air_group(Some("Binary")); @@ -86,10 +82,6 @@ impl Pilout { air_group.add_air(Some("BinaryExtensionTable"), 4194304); - let air_group = pilout.add_air_group(Some("U16Air")); - - air_group.add_air(Some("U16Air"), 65536); - let air_group = pilout.add_air_group(Some("SpecifiedRanges")); air_group.add_air(Some("SpecifiedRanges"), 16777216); diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 5e3fb452..60516af5 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -8,7 +8,7 @@ trace!(Main0Row, Main0Trace { }); trace!(Arith0Row, Arith0Trace { - carry: [F; 7], a: [F; 4], b: [F; 4], c: [F; 4], d: [F; 4], na: F, nb: F, nr: F, np: F, na32: F, nd32: F, m32: F, div: F, fab: F, debug_main_step: F, secondary_res: F, op: F, bus_a_low: F, bus_a_high: F, bus_b_high: F, res1_low: F, div64: F, res1_high: F, multiplicity: F, range_a1: F, range_b1: F, range_c1: F, range_d1: F, range_a3: F, range_b3: F, range_c3: F, range_d3: F, + carry: [F; 7], a: [F; 4], b: [F; 4], c: [F; 4], d: [F; 4], na: F, nb: F, nr: F, np: F, sext: F, m32: F, div: F, fab: F, na_fb: F, nb_fa: F, debug_main_step: F, main_div: F, main_mul: F, signed: F, op: F, bus_res1: F, multiplicity: F, range_ab: F, range_cd: F, }); trace!(ArithTable0Row, ArithTable0Trace { @@ -35,10 +35,6 @@ trace!(BinaryExtensionTable0Row, BinaryExtensionTable0Trace { multiplicity: F, }); -trace!(U16Air0Row, U16Air0Trace { - mul: F, -}); - trace!(SpecifiedRanges0Row, SpecifiedRanges0Trace { - mul: [F; 2], + mul: [F; 1], }); diff --git a/pil/zisk.pil b/pil/zisk.pil index e76efaa8..ce861830 100644 --- a/pil/zisk.pil +++ b/pil/zisk.pil @@ -9,6 +9,7 @@ require "arith/pil/arith.pil" // require "mem/pil/mem.pil" const int OPERATION_BUS_ID = 5000; + airgroup Main { Main(N: 2**21, RC: 2, operation_bus_id: OPERATION_BUS_ID); } @@ -28,7 +29,7 @@ airgroup ArithTable { airgroup ArithRangeTable { ArithRangeTable(); } -/* + airgroup Binary { Binary(N: 2**21, operation_bus_id: OPERATION_BUS_ID); } @@ -43,4 +44,4 @@ airgroup BinaryExtension { airgroup BinaryExtensionTable { BinaryExtensionTable(disable_fixed: 0); -}*/ \ No newline at end of file +} \ No newline at end of file diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index bbf8a2b8..968c94e8 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -7,6 +7,8 @@ require "arith_range_table.pil" // full mul_64 full_32 mul_32 // TOTAL 88 77 57 44 +const int OP_LT_ABS = 0x9F; + airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_result = 0) { // TODO: const int enable_div = 1, const int enable_32_bits = 1, const int enable_64_bits = 1 @@ -40,9 +42,17 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu col witness nb_fa; col witness debug_main_step; // only for debug - - col witness secondary_res; // op_index: 0 => first result, 1 => second result; - secondary_res * (secondary_res - 1) === 0; +/* + col witness secondary; // op_index: 0 => first result, 1 => second result; + secondary * (secondary - 1) === 0; +*/ + col witness main_div; + col witness main_mul; + col witness signed; + main_div * (main_div - 1) === 0; + main_mul * (main_mul - 1) === 0; + main_mul * main_div === 0; + signed * (1 - signed) === 0; // factor ab € {-1, 1} fab === 1 - 2 * na - 2 * nb + 4 * na * nb; @@ -203,6 +213,7 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu // mul (mulh) c d // div (remu) a d +/* col witness bus_a_low; bus_a_low === div * (c[0] + c[1] * CHUNK_SIZE) + (1 - div) * (a[0] + a[1] * CHUNK_SIZE); @@ -214,23 +225,22 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu const expr bus_b_low = b[0] + CHUNK_SIZE * b[1]; const expr bus_b_high = b[2] + CHUNK_SIZE * b[3]; - m32 * (1 - b[2]) === 0; - m32 * (1 - b[3]) === 0; + m32 * (1 - bus_b_high) === 0; const expr res2_low = d[0] + CHUNK_SIZE * d[1]; const expr res2_high = d[2] + CHUNK_SIZE * d[3]; col witness res_low; - res_low === secondary_res * res2_low + (1 - secondary_res) * (a[0] + c[0] + CHUNK_SIZE * (a[1] + c[1]) - bus_a_low); + res_low === secondary * res2_low + (1 - secondary) * (a[0] + c[0] + CHUNK_SIZE * (a[1] + c[1]) - bus_a_low); col witness res_high; - res_high === (1 - m32) * (secondary_res * res2_high + (1 - secondary_res) * (a[2] + c[2] + CHUNK_SIZE * (a[3] + c[3]) - bus_a_high)) + res_high === (1 - m32) * (secondary * res2_high + (1 - secondary) * (a[2] + c[2] + CHUNK_SIZE * (a[3] + c[3]) - bus_a_high)) + sext * 0xFFFFFFFF; col witness multiplicity; lookup_proves(operation_bus_id, [debug_main_step, - op + secondary_res, + op + secondary, bus_a_low, bus_a_high, bus_b_low, bus_b_high, res_low, res_high, @@ -238,69 +248,76 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu // TODO: review - lookup_assumes(operation_bus_id, [debug_main_step, OP_LT, res2_low, res2_high, bus_b_low, bus_b_high, 0, 1, 1], sel: div); + lookup_assumes(operation_bus_id, [debug_main_step, sign * (OP_LT_ABS - OP_LT) + OP_LT, + res2_low, res2_high + m32 * nr * 0xFFFFFFFF, + bus_b_low, bus_b_high + m32 * nb * 0xFFFFFFFF, + 1, 0, 1], sel: div); - for (int index = 0; index < length(carry); ++index) { - range_check(colu: carry[index], min:-2**20, max: 2**20-1); // TODO: review carry range - } +*/ + const expr secondary = 1 - main_mul - main_div; + const expr bus_a0 = div * (c[0] + c[1] * CHUNK_SIZE) + (1 - div) * (a[0] + a[1] * CHUNK_SIZE); + const expr bus_a1 = div * (c[2] + c[3] * CHUNK_SIZE) + (1 - div) * (a[2] + a[3] * CHUNK_SIZE); - // mul a * b = c + d * 2^64 - // div a * b + d = c (a <=> c) + const expr bus_b0 = b[0] + CHUNK_SIZE * b[1]; + const expr bus_b1 = b[2] + CHUNK_SIZE * b[3]; + + const expr bus_res0 = secondary * (d[0] + CHUNK_SIZE * d[1]) + + main_mul * (c[0] + c[1] * CHUNK_SIZE) + + main_div * (a[0] + a[1] * CHUNK_SIZE); + + const expr bus_res1_64 = (secondary * (d[2] + CHUNK_SIZE * d[3]) + + main_mul * (c[2] + c[3] * CHUNK_SIZE) + + main_div * (a[2] + a[3] * CHUNK_SIZE)); + col witness bus_res1; + + bus_res1 === sext * 0xFFFFFFFF + (1 - m32) * bus_res1_64; + + m32 * (1 - bus_a1) === 0; + m32 * (1 - bus_b1) === 0; - // range_ab / range_cd - // - // a3 a1 b3 b1 - // rid c3 c1 d3 d1 range 2^16 2^15 notes - // --- -- -- -- -- ----- ---- ---- ------------------------- - // 0 F F F F ab cd 4 0 - // 1 F F + F cd 3 1 b3 sign => a3 sign - // 2 F F - F cd 3 1 b3 sign => a3 sign - // 3 + F F F ab 3 1 c3 sign => d3 sign - // 4 + F + F ab cd 2 2 - // 5 + F - F ab cd 2 2 - // 6 - F F F ab 3 1 c3 sign => d3 sign - // 7 - F + F ab cd 2 2 - // 8 - F - F ab cd 2 2 - // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign - // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign - // 11 F + F F ab cd 3 1 *a1 for sext/divu - // 12 F + F + ab cd 2 2 - // 13 F + F - ab cd 2 2 - // 14 F - F F ab cd 3 1 *a1 for sext/divu - // 15 F - F + ab cd 2 2 - // 16 F - F - ab cd 2 2 - // ---- ---- - // 38 22 = 60 - // - // F: [0..2^16-1] +:[0..2^15-1] -:[2^15..2^16-1] - // - // 22 * 2^15 + 38 * 2^16 = (11+38) * 2^16 = 49 * 2^16 < 2^6 * 2^16 ==> 22 bits + + col witness multiplicity; + + lookup_proves(operation_bus_id, [debug_main_step, + op + secondary, + bus_a0, bus_a1, + bus_b0, bus_b1, + bus_res0, bus_res1, + 0], mul: multiplicity); + + lookup_assumes(operation_bus_id, [debug_main_step, signed * (OP_LT_ABS - OP_LT) + OP_LT, + (d[0] + CHUNK_SIZE * d[1]), (d[2] + CHUNK_SIZE * d[3]) + m32 * nr * 0xFFFFFFFF, + (b[0] + CHUNK_SIZE * b[1]), (b[2] + CHUNK_SIZE * b[3]) + m32 * nb * 0xFFFFFFFF, + 1, 0, 1], sel: div); + + for (int index = 0; index < length(carry); ++index) { + arith_range_table_assumes(ARITH_RANGE_CARRY, carry[index]); // TODO: review carry range + } col witness range_ab; col witness range_cd; - arith_table_assumes(op, m32, div, na, nb, np, nr, sext, secondary_res, range_ab, range_cd); + arith_table_assumes(op, m32, div, na, nb, np, nr, sext, main_mul, main_div, signed, range_ab, range_cd); - // 0 - a1/c1 - // 1 - b1/d1 - // 2 - a3/c3 - // 3 - b3/d3 + const expr range_a = range_ab; + const expr range_b = range_ab + 26; + const expr range_c = range_cd + 17; + const expr range_d = range_cd + 9; - arith_range_table_assumes(range_ab, 0, a[1]); - arith_range_table_assumes(range_ab, 1, b[1]); - arith_range_table_assumes(range_cd, 0, c[1]); - arith_range_table_assumes(range_cd, 1, d[1]); - arith_range_table_assumes(range_ab, 2, a[3]); - arith_range_table_assumes(range_ab, 3, b[3]); - arith_range_table_assumes(range_cd, 2, c[3]); - arith_range_table_assumes(range_cd, 3, d[3]); + arith_range_table_assumes(range_a, a[1]); + arith_range_table_assumes(range_b, b[1]); + arith_range_table_assumes(range_c, c[1]); + arith_range_table_assumes(range_d, d[1]); + arith_range_table_assumes(range_a, a[3]); + arith_range_table_assumes(range_b, b[3]); + arith_range_table_assumes(range_c, c[3]); + arith_range_table_assumes(range_d, d[3]); // loop for range checks index 0, 2 for (int index = 0; index < 2; ++index) { - arith_range_table_assumes(0, 0, a[2 * index]); - arith_range_table_assumes(0, 0, b[2 * index]); - arith_range_table_assumes(0, 0, c[2 * index]); - arith_range_table_assumes(0, 0, d[2 * index]); + arith_range_table_assumes(ARITH_RANGE_16_BITS, a[2 * index]); + arith_range_table_assumes(ARITH_RANGE_16_BITS, b[2 * index]); + arith_range_table_assumes(ARITH_RANGE_16_BITS, c[2 * index]); + arith_range_table_assumes(ARITH_RANGE_16_BITS, d[2 * index]); } - } diff --git a/state-machines/arith/pil/arith_range_table.pil b/state-machines/arith/pil/arith_range_table.pil index 7a03c0f9..42369708 100644 --- a/state-machines/arith/pil/arith_range_table.pil +++ b/state-machines/arith/pil/arith_range_table.pil @@ -2,20 +2,80 @@ require "std_lookup.pil" require "operations.pil" const int ARITH_RANGE_TABLE_ID = 330; +const int ARITH_RANGE_CARRY = 100; +const int ARITH_RANGE_16_BITS = 0; -airtemplate ArithRangeTable(int N = 2**17) { +airtemplate ArithRangeTable(int N = 2**22) { + + // a3 a1 b3 b1 + // rid c3 c1 d3 d1 range 2^16 2^15 notes + // --- -- -- -- -- ----- ---- ---- ------------------------- + // 0 F F F F ab cd 4 0 + // 1 F F + F cd 3 1 b3 sign => a3 sign + // 2 F F - F cd 3 1 b3 sign => a3 sign + // 3 + F F F ab 3 1 c3 sign => d3 sign + // 4 + F + F ab cd 2 2 + // 5 + F - F ab cd 2 2 + // 6 - F F F ab 3 1 c3 sign => d3 sign + // 7 - F + F ab cd 2 2 + // 8 - F - F ab cd 2 2 + // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign + // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign + // 11 F + F F ab cd 3 1 *a1 for sext/divu + // 12 F + F + ab cd 2 2 + // 13 F + F - ab cd 2 2 + // 14 F - F F ab cd 3 1 *a1 for sext/divu + // 15 F - F + ab cd 2 2 + // 16 F - F - ab cd 2 2 + // ---- ---- + + // COL COMPRESSION + // + // + // + // 0: F F F + + + - - - F F F F F F F F offset: 0 + // 1: F F F F F F F F F F F + + + - - - offset: 26 + // 2: F + - F + - F + - F F F F F F F F offset: 17 + // 3: F F F F F F F F F + - F + - F + - offset: 9 + // -------------------------------------------------------------------------------------- + // F F F + + + - - - F F F F F F F F F + - F + - F + - F F F F F F F F F F F + + + - - - + // + + // 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 4 4 4 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 + // + // F F F + + + - - - F F F F F F F F F + - F + - F + - F F F F F F F F F F F + + + - - - + // + // 25:FULL + 9:POS + 9:NEG = 34 * 2^16 = 2^21 + 2^17 + // + // [range, 0] => [range] + // [range, 1] => [range + 26] + // [range, 2] => [range + 17] + // [range, 3] => [range + 9] + // + // [-(2^19+2^18+2^16-1)...(2^19+2^18+2^16)] range check carry + + const int FULL = 2**16; + const int POS = 2**15; + const int NEG = 2**15; + + col fixed RANGE_ID = [0:FULL..2:FULL, 9:FULL..17:FULL, 20:FULL, 23:FULL, 26:FULL..36:FULL, // 25 FULL + 3:POS..5:POS, 18:POS, 21:POS, 24:POS, 37:POS..39:POS, // 9 POS + 6:NEG..8:NEG, 19:NEG, 22:NEG, 25:NEG, 40:NEG..42:NEG, // 9 NEG + ARITH_RANGE_CARRY...]; + + col fixed RANGE_VALUES = [[0x0000..0xFFFF]:25, + [0x0000..0x7FFF]:9, + [0x8000..0xFFFF]:9, + [-0xEFFFF..0xF0000]]; - // TODO: update values - col fixed RANGES = [0:2**16,1:2**15,2:2**15]; - col fixed INDEX = [0..2]...; - col fixed VALUES = [0..2**16-1]...; col witness multiplicity; - lookup_proves(ARITH_RANGE_TABLE_ID, [RANGES, INDEX, VALUES], multiplicity); + lookup_proves(ARITH_RANGE_TABLE_ID, [RANGE_ID, RANGE_VALUES], multiplicity); } -function arith_range_table_assumes(const expr range_type, const int index, const expr value) { +function arith_range_table_assumes(const expr range_type, const expr value, const expr sel = 1) { // TODO: define rule for empty rows - lookup_assumes(ARITH_RANGE_TABLE_ID, [range_type, index, value]); + lookup_assumes(ARITH_RANGE_TABLE_ID, [range_type, value], sel:sel); } diff --git a/state-machines/arith/pil/arith_table.pil b/state-machines/arith/pil/arith_table.pil index 9cdcd894..c174b1fe 100644 --- a/state-machines/arith/pil/arith_table.pil +++ b/state-machines/arith/pil/arith_table.pil @@ -2,7 +2,7 @@ require "std_lookup.pil" const int ARITH_TABLE_ID = 331; -airtemplate ArithTable(int N = 2**8, int generate_table = 1) { +airtemplate ArithTable(int N = 2**7, int generate_table = 1) { // TABLE // op,m32|div|na|nb|np|nr|sext,range_ab,range_cd @@ -46,14 +46,15 @@ airtemplate ArithTable(int N = 2**8, int generate_table = 1) { int div = 0; // division operation (div,rem) int sa = 0; int sb = 0; - int secondary = 0; + int main_mul = 0; + int main_div = 0; string opname = ""; switch (opcode) { case 0xb0: opname = "mulu"; + main_mul = 1; case 0xb1: opname = "mulh"; - secondary = 1; case 0xb3: opname = "mulsuh"; sa = 1; @@ -61,56 +62,57 @@ airtemplate ArithTable(int N = 2**8, int generate_table = 1) { opname = "mul"; sa = 1; sb = 1; + main_mul = 1; case 0xb5: opname = "mulh"; sa = 1; sb = 1; - secondary = 1; case 0xb6: opname = "mul_w"; m32 = 1; sa = 1; sb = 1; + main_mul = 1; case 0xb8: opname = "divu"; div = 1; + main_div = 1; case 0xb9: opname = "remu"; div = 1; - secondary = 1; case 0xba: opname = "div"; sa = 1; sb = 1; div = 1; + main_div = 1; case 0xbb: opname = "rem"; sa = 1; sb = 1; div = 1; - secondary = 1; case 0xbc: opname = "divu_w"; div = 1; m32 = 1; + main_div = 1; case 0xbd: opname = "remu_w"; div = 1; m32 = 1; - secondary = 1; case 0xbe: opname = "div_w"; sa = 1; sb = 1; div = 1; m32 = 1; + main_div = 1; case 0xbf: opname = "rem_w"; sa = 1; sb = 1; div = 1; m32 = 1; - secondary = 1; } for (int icase = 0; icase < 32; ++icase) { @@ -130,123 +132,66 @@ airtemplate ArithTable(int N = 2**8, int generate_table = 1) { if (np && na && nb) continue; if (!np & nr) continue; - int range_a1 = m32 * sa ? 1 + na : 0; - int range_b1 = m32 * sb ? 1 + nb : 0; - + int range_a1 = 0; + int range_b1 = 0; int range_c1 = 0; + int range_d1 = 0; + int range_a3 = 0; + int range_b3 = 0; + int range_c3 = 0; + int range_d3 = 0; + if (m32) { - if (div) { - // range_c1 = np || na32 ? 2 : 1; - } else { - // range_c1 = 1 + na32; + if (sa) { + range_a1 = 1 + na; + } else if (main_div) { + range_a1 = 1 + sext; } - } - int range_d1 = m32 * div ? (((np * sa) || sext) ? 1:2) : 0; - - int range_a3 = (1 - m32) * sa ? 1 + na : 0; - int range_b3 = (1 - m32) * sb ? 1 + na : 0; - int range_c3 = div * (1 - m32) * sa ? 1 + np : 0; - int range_d3 = div * (1 - m32) * sa ? 1 + np : 0; - - if (generate_table) { - op2row[(opcode - 0xb0) * 32 + icase] = index; - } - println(`==> #${index} op:${opname} [${opcode}] na:${na} nb:${nb} np:${np} nr:${nr} sext:${sext} sa:${sa} sb:${sb} secondary:${secondary}`); - ++index; - } - } -/* - int cases = 1 + 2 * sa + 2 * sb + sa * sb; - for (int sext = 0; sext < (1 + m32); ++sext) { - for (int icase = 0; icase < cases; ++icase) { - - int na = 0; // a is negative - int nb = 0; // b is negative - int a_is_zero = 0; // a is zero - int b_is_zero = 0; // b is zero - int np = 0; // prod is negative - int nr = 0; // rem is negative - int abort_case = 0; // if abort copy values of row 0 - switch (icase) { - // case 0: [a >= 0, b >= 0] - case 1: - // [a < 0, b > 0] - na = 1; - case 2: - // [a < 0, b = 0] - na = 1; - b_is_zero = 1; - case 3: - // [a > 0, b < 0] - nb = 1; - case 4: - // [a = 0, b < 0] - nb = 1; - a_is_zero = 1; - case 5: - // [a < 0, b < 0] - na = 1; - nb = 1; + if (sb) { + range_b1 = 1 + nb; } - if ((div && b_is_zero) || (sext && (a_is_zero || b_is_zero))) { - abort_case = 1; + if (!div) { + range_c1 = sext + 1; + } else if (sa) { + range_c1 = 1 + np; } - np = na + nb - na * nb; - if (np && (a_is_zero || b_is_zero)) { - // - * 0 = 0 (no negative) - np = 0; + if (div && !main_div) { + range_c1 = sext + 1; + } else if (sa) { + range_c1 = 1 + nr; } - nr = div ? na : 0; - - int range_a1 = m32 * sa ? 1 + na : 0; - int range_b1 = m32 * sb ? 1 + nb : 0; - - int range_c1 = 0; - if (m32) { + } else { + if (sa) { + range_a3 = 1 + na; if (div) { - // range_c1 = np || na32 ? 2 : 1; + range_c3 = 1 + np; + range_d3 = 1 + nr; } else { - // range_c1 = 1 + na32; + range_d3 = 1 + np; } } - int range_d1 = m32 * div ? (((np * sa) || sext) ? 1:2) : 0; - - int range_a3 = (1 - m32) * sa ? 1 + na : 0; - int range_b3 = (1 - m32) * sb ? 1 + na : 0; - int range_c3 = div * (1 - m32) * sa ? 1 + np : 0; - int range_d3 = div * (1 - m32) * sa ? 1 + np : 0; - - - if (abort_case) { - println(`ABORT op:${opcode} sa:${sa} sb:${sb} a_is_zero:${a_is_zero} b_is_zero:${b_is_zero} na:${na} nb:${nb} np:${np} nr:${nr} sext:${sext} secondary:${secondary}`); - OP[index] = OP[0]; - FLAGS[index] = FLAGS[0]; - RANGE_AB[index] = RANGE_AB[0]; - RANGE_CD[index] = RANGE_CD[0]; - ++aborted; - } else { - println(`ADD op:${opcode} sa:${sa} sb:${sb} a_is_zero:${a_is_zero} b_is_zero:${b_is_zero} na:${na} nb:${nb} np:${np} nr:${nr} sext:${sext} secondary:${secondary}`); - OP[index] = opcode; - FLAGS[index] = m32 + 2 * div + 4 * na + 8 * nb + 16 * np + 32 * nr + 64 * sext + 128 * secondary; - RANGE_AB[index] = 0; - RANGE_CD[index] = 0; + if (sb) { + range_b3 = 1 + nb; } + } + int signed = sa || sb ? 1 : 0; + OP[index] = opcode; + FLAGS[index] = m32 + 2 * div + 4 * na + 8 * nb + 16 * np + 32 * nr + 64 * sext + + 128 * main_mul + 256 * main_div + 512 * signed; + RANGE_AB[index] = range_a3 * 3 + range_b3 + m32 * 8 + range_a1 * 3 + range_b1; + RANGE_CD[index] = range_c3 * 3 + range_d3 + m32 * 8 + range_c1 * 3 + range_d1; - index = index + 1; + if (generate_table) { + op2row[(opcode - 0xb0) * 32 + icase] = index; } + ++index; } - }*/ + } const int size = index; println("ARITH_TABLE SIZE: ", size); - println("ARITH_TABLE ABORTED: ", aborted); - println("ARITH_FLAGS: ", FLAGS); if (generate_table) { - println("let arith_table_rows: [i16; 512] = [", op2row, "];"); - } - return; - for (index = 0; index < size; ++index) { - println(FLAGS[index]); + println("static ARITH_TABLE_ROWS: [i16; 512] = [", op2row, "];"); } // padding repeat first row @@ -264,17 +209,18 @@ airtemplate ArithTable(int N = 2**8, int generate_table = 1) { } col witness multiplicity; - // TODO: lookup_proves(ARITH_TABLE_ID, mul: multiplicity, cols: [OP, FLAGS, RANGE_AB, RANGE_CD]); } function arith_table_assumes( const expr op, const expr flag_m32, const expr flag_div, const expr flag_na, const expr flag_nb, const expr flag_np, const expr flag_nr, const expr flag_sext, - const expr flag_secondary_res, const expr range_ab, const expr range_cd) { + const expr flag_main_mul, const expr flag_main_div, const expr flag_signed, + const expr range_ab, const expr range_cd) { // TODO: #pragma binary flag_m32 => check any constraint on compilation time // TODO: define rule for empty rows lookup_assumes(ARITH_TABLE_ID, cols: [ op, flag_m32 + 2 * flag_div + 4 * flag_na + 8 * flag_nb + 16 * flag_np + 32 * flag_nr + 64 * flag_sext + - 128 * flag_secondary_res, range_ab, range_cd]); + 128 * flag_main_mul + 256 * flag_main_div + 512 * flag_signed, + range_ab, range_cd]); } diff --git a/state-machines/arith/src/arith_constants.rs b/state-machines/arith/src/arith_constants.rs new file mode 100644 index 00000000..4a7af91a --- /dev/null +++ b/state-machines/arith/src/arith_constants.rs @@ -0,0 +1,14 @@ +pub const MULU: u8 = 0xb0; +pub const MULUH: u8 = 0xb1; +pub const MULSUH: u8 = 0xb3; +pub const MUL: u8 = 0xb4; +pub const MULH: u8 = 0xb5; +pub const MUL_W: u8 = 0xb6; +pub const DIVU: u8 = 0xb8; +pub const REMU: u8 = 0xb9; +pub const DIV: u8 = 0xba; +pub const REM: u8 = 0xbb; +pub const DIVU_W: u8 = 0xbc; +pub const REMU_W: u8 = 0xbd; +pub const DIV_W: u8 = 0xbe; +pub const REM_W: u8 = 0xbf; diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs index a6995893..b0c30880 100644 --- a/state-machines/arith/src/arith_full.rs +++ b/state-machines/arith/src/arith_full.rs @@ -4,8 +4,8 @@ use std::sync::{ }; use crate::{ - arith_table_inputs, ArithRangeTableInputs, ArithRangeTableSM, ArithSM, ArithTableInputs, - ArithTableSM, + arith_table_inputs, ArithOperation, ArithRangeTableInputs, ArithRangeTableSM, ArithSM, + ArithTableInputs, ArithTableSM, }; use p3_field::Field; use proofman::{WitnessComponent, WitnessManager}; @@ -76,10 +76,68 @@ impl ArithFullSM { range_table_inputs: &mut ArithRangeTableInputs, table_inputs: &mut ArithTableInputs, ) -> Vec> { - let mut _trace: Vec> = Vec::new(); - range_table_inputs.push(0, 0); - table_inputs.fast_push(0, 0, 0); - _trace + let mut traces: Vec> = Vec::new(); + let mut aop = ArithOperation::new(); + for input in input.iter() { + aop.calculate(input.opcode, input.a, input.b); + let mut t: Arith0Row = Default::default(); + for i in 0..4 { + t.a[i] = F::from_canonical_u64(aop.a[i]); + t.b[i] = F::from_canonical_u64(aop.b[i]); + t.c[i] = F::from_canonical_u64(aop.c[i]); + t.d[i] = F::from_canonical_u64(aop.d[i]); + // arith_operation.a[i]; + } + // range_table_inputs.push(0, 0); + // table_inputs.fast_push(0, 0, 0); + t.m32 = F::from_bool(aop.m32); + t.div = F::from_bool(aop.div); + t.na = F::from_bool(aop.na); + t.nb = F::from_bool(aop.nb); + t.np = F::from_bool(aop.np); + t.nr = F::from_bool(aop.nr); + t.signed = F::from_bool(aop.signed); + t.main_mul = F::from_bool(aop.main_mul); + t.main_div = F::from_bool(aop.main_div); + t.sext = F::from_bool(aop.sext); + t.multiplicity = F::one(); + + t.fab = if aop.na != aop.nb { F::neg_one() } else { F::one() }; + // na * (1 - 2 * nb); + t.na_fb = if aop.na { + if aop.nb { + F::neg_one() + } else { + F::one() + } + } else { + F::zero() + }; + t.nb_fa = if aop.nb { + if aop.na { + F::neg_one() + } else { + F::one() + } + } else { + F::zero() + }; + t.bus_res1 = F::from_canonical_u64( + if aop.sext { 0xFFFFFFFF } else { 0 } + + if aop.main_mul { + aop.c[2] + aop.c[3] << 16 + } else if aop.main_div { + aop.a[2] + aop.a[3] << 16 + } else { + aop.d[2] + aop.d[3] << 16 + }, + ); + + traces.push(t); + } + // range_table_inputs.push(0, 0); + //table_inputs.fast_push(0, 0, 0); + traces } } diff --git a/state-machines/arith/src/arith_helpers.rs b/state-machines/arith/src/arith_helpers.rs deleted file mode 100644 index eac6ffdb..00000000 --- a/state-machines/arith/src/arith_helpers.rs +++ /dev/null @@ -1,1683 +0,0 @@ -use zisk_core::zisk_ops::*; - -const MULU: u8 = 0xb0; -const MULUH: u8 = 0xb1; -const MULSUH: u8 = 0xb3; -const MUL: u8 = 0xb4; -const MULH: u8 = 0xb5; -const MUL_W: u8 = 0xb6; -const DIVU: u8 = 0xb8; -const REMU: u8 = 0xb9; -const DIV: u8 = 0xba; -const REM: u8 = 0xbb; -const DIVU_W: u8 = 0xbc; -const REMU_W: u8 = 0xbd; -const DIV_W: u8 = 0xbe; -const REM_W: u8 = 0xbf; - -const FLAG_NAMES: [&str; 8] = ["m32", "div", "na", "nb", "np", "nr", "sext", "sec"]; - -pub trait ArithHelpers { - fn get_row(op: u8, na: u64, nb: u64, np: u64, nr: u64, sext: u64) -> i16 { - static arith_table_rows: [i16; 512] = [ - 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, 2, 3, -1, -1, -1, 4, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 5, 6, 7, 8, -1, - 9, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, 11, 12, 13, 14, -1, 15, 16, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 17, 18, 19, 20, -1, 21, 22, - -1, -1, -1, -1, -1, -1, -1, -1, -1, 23, 24, 25, 26, -1, 27, 28, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 29, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, 30, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 31, 32, 33, 34, 35, 36, 37, -1, -1, -1, -1, - -1, 38, 39, 40, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 41, - 42, 43, 44, 45, 46, 47, -1, -1, -1, -1, -1, 48, 49, 50, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, 51, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, 52, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 53, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 54, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, 62, 63, 64, - -1, 65, 66, 67, 68, 69, 70, 71, -1, -1, -1, -1, -1, 72, 73, 74, -1, 75, 76, 77, 78, 79, - 80, 81, -1, -1, -1, -1, -1, 82, 83, 84, -1, 85, 86, 87, 88, 89, 90, 91, -1, -1, -1, -1, - -1, 92, 93, 94, -1, - ]; - - let index = (op - 0xb0) as u64 * 32 + na + nb * 2 + np * 4 + nr * 8 + sext * 16; - arith_table_rows[index as usize] - } - // arith_sm - fn sign32(abs_value: u64, negative: bool) -> u64 { - assert!(0xFFFF_FFFF >= abs_value, "abs_value:0x{0:X}({0}) is too big", abs_value); - if negative { - (0xFFFF_FFFF - abs_value) + 1 - } else { - abs_value - } - } - fn sign64(abs_value: u64, negative: bool) -> u64 { - if negative { - (0xFFFF_FFFF_FFFF_FFFF - abs_value) + 1 - } else { - abs_value - } - } - fn sign128(abs_value: u128, negative: bool) -> u128 { - let res = if negative { - (0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF - abs_value) + 1 - } else { - abs_value - }; - println!("sign128({:X},{})={:X}", abs_value, negative, res); - res - } - fn abs32(value: u64) -> [u64; 2] { - let negative = if (value & 0x8000_0000) != 0 { 1 } else { 0 }; - let abs_value = if negative == 1 { (0xFFFF_FFFF - value) + 1 } else { value }; - // println!( - // "value:0x{0:X}({0}) abs_value:0x{1:X}({1}) negative:{2}", - // value, abs_value, negative - // ); - [abs_value, negative] - } - fn abs64(value: u64) -> [u64; 2] { - let negative = if (value & 0x8000_0000_0000_0000) != 0 { 1 } else { 0 }; - let abs_value = if negative == 1 { (0xFFFF_FFFF_FFFF_FFFF - value) + 1 } else { value }; - [abs_value, negative] - } - fn calculate_mul_w(a: u64, b: u64) -> u64 { - (a & 0xFFFF_FFFF) * (b & 0xFFFF_FFFF) - } - - fn calculate_mulsu(a: u64, b: u64) -> [u64; 2] { - let [abs_a, na] = Self::abs64(a); - let abs_c = abs_a as u128 * b as u128; - let nc = if na == 1 && abs_c != 0 { 1 } else { 0 }; - let c = Self::sign128(abs_c, nc == 1); - [c as u64, (c >> 64) as u64] - } - - fn calculate_mul(a: u64, b: u64) -> [u64; 2] { - let [abs_a, na] = Self::abs64(a); - let [abs_b, nb] = Self::abs64(b); - println!( - "mul(a:0x{:X}, b:0x{:X} abs_a:0x{:X} na:{} abs_b:0x{:X} nb:{}", - a, b, abs_a, na, abs_b, nb, - ); - let abs_c = abs_a as u128 * abs_b as u128; - let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; - let c = Self::sign128(abs_c, nc == 1); - [c as u64, (c >> 64) as u64] - } - - fn calculate_div(a: u64, b: u64) -> u64 { - let [abs_a, na] = Self::abs64(a); - let [abs_b, nb] = Self::abs64(b); - let abs_c = abs_a / abs_b; - let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; - Self::sign64(abs_c, nc == 1) - } - - fn calculate_rem(a: u64, b: u64) -> u64 { - let [abs_a, na] = Self::abs64(a); - let [abs_b, _nb] = Self::abs64(b); - let abs_c = abs_a % abs_b; - let nc = if na == 1 && abs_c != 0 { 1 } else { 0 }; - Self::sign64(abs_c, nc == 1) - } - - fn calculate_div_w(a: u64, b: u64) -> u64 { - let [abs_a, na] = Self::abs32(a); - let [abs_b, nb] = Self::abs32(b); - let abs_c = abs_a / abs_b; - let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; - Self::sign32(abs_c, nc == 1) - } - - fn calculate_rem_w(a: u64, b: u64) -> u64 { - let [abs_a, na] = Self::abs32(a); - let [abs_b, _nb] = Self::abs32(b); - let abs_c = abs_a % abs_b; - let nc = if na == 1 && abs_c != 0 { 1 } else { 0 }; - Self::sign32(abs_c, nc == 1) - } - - fn calculate_emulator_res(op: u8, a: u64, b: u64) -> (u64, bool) { - match op { - MULU => return op_mulu(a, b), - MULUH => return op_muluh(a, b), - MULSUH => return op_mulsuh(a, b), - MUL => return op_mul(a, b), - MULH => return op_mulh(a, b), - MUL_W => return op_mul_w(a, b), - DIVU => return op_divu(a, b), - REMU => return op_remu(a, b), - DIVU_W => return op_divu_w(a, b), - REMU_W => return op_remu_w(a, b), - DIV => return op_div(a, b), - REM => return op_rem(a, b), - DIV_W => return op_div_w(a, b), - REM_W => return op_rem_w(a, b), - _ => { - panic!("Invalid opcode"); - } - } - } - - fn calculate_abcd_from_ab(op: u8, a: u64, b: u64) -> [u64; 4] { - match op { - MULU | MULUH => { - let c: u128 = a as u128 * b as u128; - [a, b, c as u64, (c >> 64) as u64] - } - MULSUH => { - let [c, d] = Self::calculate_mulsu(a, b); - [a, b, c, d] - } - MUL | MULH => { - let [c, d] = Self::calculate_mul(a, b); - [a, b, c, d] - } - MUL_W => [a, b, Self::calculate_mul_w(a, b), 0], - DIVU | REMU | DIVU_W | REMU_W => [a / b, b, a, a % b], - DIV | REM => [Self::calculate_div(a, b), b, a, Self::calculate_rem(a, b)], - DIV_W | REM_W => [Self::calculate_div_w(a, b), b, a, Self::calculate_rem_w(a, b)], - _ => { - panic!("Invalid opcode"); - } - } - } - fn decode_one_range(range_xy: u64) -> [u64; 4] { - if range_xy == 9 { - [0, 0, 0, 0] - } else if range_xy > 9 { - let x = (range_xy - 8) / 3; - let y = (range_xy - 8) % 3; - [0, 0, x, y] - } else { - let x = range_xy / 3; - let y = range_xy % 3; - [x, y, 0, 0] - } - } - fn decode_ranges(range_ab: u64, range_cd: u64) -> [u64; 8] { - let ab = Self::decode_one_range(range_ab); - let cd = Self::decode_one_range(range_cd); - [ab[0], ab[1], cd[0], cd[1], ab[2], ab[3], cd[2], cd[3]] - } - fn calculate_flags_and_ranges(op: u8, a: u64, b: u64, c: u64, d: u64) -> [u64; 11] { - let mut m32: u64 = 0; - let mut div: u64 = 0; - let mut np: u64 = 0; - let mut nr: u64 = 0; - let mut sext: u64 = 0; - let mut secondary_res: u64 = 0; - - let mut range_a1: u64 = 0; - let mut range_b1: u64 = 0; - let mut range_c1: u64 = 0; - let mut range_d1: u64 = 0; - let mut range_a3: u64 = 0; - let mut range_b3: u64 = 0; - let mut range_c3: u64 = 0; - let mut range_d3: u64 = 0; - - // direct table opcode(14), signed 2 or 4 cases (0,na,nb,na+nb) - // 6 * 1 + 7 * 4 + 1 * 2 = 36 entries, - // no compacted => 16 x 4 = 64, key = (op - 0xb0) * 4 + na * 2 + nb - // output: div, m32, sa, sb, nr, np, na, na32, nd32, range x 2 x 4 - - // alternative: switch operation, - - let mut sa = false; - let mut sb = false; - let mut rem = false; - - match op { - MULU => {} - MULUH => { - secondary_res = 1; - } - MULSUH => { - sa = true; - secondary_res = 1; - } - MUL => { - sa = true; - sb = true; - } - MULH => { - sa = true; - sb = true; - secondary_res = 1; - } - MUL_W => { - m32 = 1; - sext = if ((a * b) & 0xFFFF_FFFF) & 0x8000_0000 != 0 { 1 } else { 0 }; - } - DIVU => { - div = 1; - assert!(b != 0, "Error on DIVU a:{:x}({}) b:{:x}({})", a, b, a, b); - } - REMU => { - div = 1; - rem = true; - secondary_res = 1; - } - DIV => { - sa = true; - sb = true; - div = 1; - } - REM => { - sa = true; - sb = true; - rem = true; - div = 1; - secondary_res = 1; - } - DIVU_W => { - // divu_w, remu_w - div = 1; - m32 = 1; - // use a in bus - sext = if (a & 0x8000_0000) != 0 { 1 } else { 0 }; - } - REMU_W => { - // divu_w, remu_w - div = 1; - m32 = 1; - rem = true; - // use d in bus - sext = if (d & 0x8000_0000) != 0 { 1 } else { 0 }; - secondary_res = 1; - } - DIV_W => { - // div_w, rem_w - sa = true; - sb = true; - div = 1; - m32 = 1; - // use a in bus - sext = if (a & 0x8000_0000) != 0 { 1 } else { 0 }; - } - REM_W => { - // div_w, rem_w - sa = true; - sb = true; - div = 1; - m32 = 1; - rem = true; - // use d in bus - sext = if (d & 0x8000_0000) != 0 { 1 } else { 0 }; - secondary_res = 1; - } - _ => { - panic!("Invalid opcode"); - } - } - let sign_mask: u64 = if m32 == 1 { 0x8000_0000 } else { 0x8000_0000_0000_0000 }; - let sign_c_mask: u64 = - if m32 == 1 && div == 1 { 0x8000_0000 } else { 0x8000_0000_0000_0000 }; - let na = if sa && (a & sign_mask) != 0 { 1 } else { 0 }; - let nb = if sb && (b & sign_mask) != 0 { 1 } else { 0 }; - // a sign => b sign - let nc = if sa && (c & sign_c_mask) != 0 { 1 } else { 0 }; - let nd = if sa && (d & sign_mask) != 0 { 1 } else { 0 }; - - // a == 0 || b == 0 => np == 0 ==> how was a signed operation - // after that sign of np was verified with range check. - // TODO: review if secure - if div == 1 { - np = nc; //if c != 0 { na ^ nb } else { 0 }; - nr = nd; - } else { - np = if m32 == 1 { nc } else { nd }; // if (c != 0) || (d != 0) { na ^ nb } else { 0 } - nr = 0; - } - if m32 == 1 { - // mulw, divu_w, remu_w, div_w, rem_w - range_a1 = if sa { - 1 + na - } else if div == 1 && !rem { - 1 + sext - } else { - 0 - }; - range_b1 = if sb { 1 + nb } else { 0 }; - // m32 && div == 0 => mulw - range_c1 = if div == 0 { - sext + 1 - } else if sa { - 1 + np - } else { - 0 - }; - range_d1 = if rem { - sext + 1 - } else if sa { - 1 + nr - } else { - 0 - }; - } else { - // mulu, muluh, mulsuh, mul, mulh, div, rem, divu, remu - if sa { - // mulsuh, mul, mulh, div, rem - range_a3 = 1 + na; - if div == 1 { - // div, rem - range_c3 = 1 + np; - range_d3 = 1 + nr; - } else { - range_d3 = 1 + np; - } - } - // sb => mul, mulh, div, rem - range_b3 = if sb { 1 + nb } else { 0 }; - } - - // range_ab / range_cd - // - // a3 a1 b3 b1 - // rid c3 c1 d3 d1 range 2^16 2^15 notes - // --- -- -- -- -- ----- ---- ---- ------------------------- - // 0 F F F F ab cd 4 0 - // 1 F F + F cd 3 1 b3 sign => a3 sign - // 2 F F - F cd 3 1 b3 sign => a3 sign - // 3 + F F F ab 3 1 c3 sign => d3 sign - // 4 + F + F ab cd 2 2 - // 5 + F - F ab cd 2 2 - // 6 - F F F ab 3 1 c3 sign => d3 sign - // 7 - F + F ab cd 2 2 - // 8 - F - F ab cd 2 2 - // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign - // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign - // 11 F + F F cd 3 1 a1 sign <=> b1 sign - // 12 F + F + ab cd 2 2 - // 13 F + F - ab cd 2 2 - // 14 F - F F cd 3 1 a1 sign <=> b1 sign - // 15 F - F + ab cd 2 2 - // 16 F - F - ab cd 2 2 - - assert!(range_a1 == 0 || range_a3 == 0, "range_a1:{} range_a3:{}", range_a1, range_a3); - assert!(range_b1 == 0 || range_b3 == 0, "range_b1:{} range_b3:{}", range_b1, range_b3); - assert!(range_c1 == 0 || range_c3 == 0, "range_c1:{} range_c3:{}", range_c1, range_c3); - assert!(range_d1 == 0 || range_d3 == 0, "range_d1:{} range_d3:{}", range_d1, range_d3); - - let range_ab = (range_a3 + range_a1) * 3 - + range_b3 - + range_b1 - + if (range_a1 + range_b1) > 0 { 8 } else { 0 }; - - let range_cd = (range_c3 + range_c1) * 3 - + range_d3 - + range_d1 - + if (range_c1 + range_d1) > 0 { 8 } else { 0 }; - - let ranges = range_a3 * 1000_0000 - + range_b3 * 100_0000 - + range_c3 * 10_0000 - + range_d3 * 1000 - + range_a1 * 1000 - + range_b1 * 100 - + range_c1 * 10 - + range_d1; - [m32, div, na, nb, np, nr, sext, secondary_res, range_ab, range_cd, ranges] - } - - fn calculate_chunks( - a: [i64; 4], - b: [i64; 4], - c: [i64; 4], - d: [i64; 4], - m32: i64, - div: i64, - na: i64, - nb: i64, - np: i64, - nr: i64, - fab: i64, - secondary_res: i64, - sext: i64, - ) -> [i64; 8] { - // TODO: unroll this function in variants (div,m32) and (na,nb,nr,np) - // div, m32, na, nb === f(div,m32,na,nb) => fa, nb, nr - // unroll means 16 variants ==> but more performance - - let mut chunks: [i64; 8] = [0, 0, 0, 0, 0, 0, 0, 0]; - - let na_fb = na * (1 - 2 * nb); - let nb_fa = nb * (1 - 2 * na); - - chunks[0] = fab * a[0] * b[0] // chunk0 - - c[0] - + 2 * np * c[0] - + div * d[0] - - 2 * nr * d[0]; - - chunks[1] = fab * a[1] * b[0] // chunk1 - + fab * a[0] * b[1] - - c[1] - + 2 * np * c[1] - + div * d[1] - - 2 * nr * d[1]; - - chunks[2] = fab * a[2] * b[0] // chunk2 - + fab * a[1] * b[1] - + fab * a[0] * b[2] - + a[0] * nb_fa * m32 - + b[0] * na_fb * m32 - - c[2] - + 2 * np * c[2] - + div * d[2] - - 2 * nr * d[2] - - np * div * m32 - + nr * m32; // div == 0 ==> nr = 0 - - chunks[3] = fab * a[3] * b[0] // chunk3 - + fab * a[2] * b[1] - + fab * a[1] * b[2] - + fab * a[0] * b[3] - + a[1] * nb_fa * m32 - + b[1] * na_fb * m32 - - c[3] - + 2 * np * c[3] - + div * d[3] - - 2 * nr * d[3]; - - chunks[4] = fab * a[3] * b[1] // chunk4 - + fab * a[2] * b[2] - + fab * a[1] * b[3] - + na * nb * m32 - // + b[0] * na * (1 - 2 * nb) - // + a[0] * nb * (1 - 2 * na) - + b[0] * na_fb * (1 - m32) - + a[0] * nb_fa * (1 - m32) - // high bits ^^^ - // - np * div - // + np * div * m32 - // - 2 * div * m32 * np - - np * m32 * (1 - div) // - - np * (1 - m32) * div // 2^64 (np) - + nr * (1 - m32) // 2^64 (nr) - // high part d - - d[0] * (1 - div) // m32 == 1 and div == 0 => d = 0 - + 2 * np * d[0] * (1 - div); // - - chunks[5] = fab * a[3] * b[2] // chunk5 - + fab * a[2] * b[3] - + a[1] * nb_fa * (1 - m32) - + b[1] * na_fb * (1 - m32) - - d[1] * (1 - div) - + d[1] * 2 * np * (1 - div); - - chunks[6] = fab as i64 * a[3] * b[3] // chunk6 - + a[2] * nb_fa * (1 - m32) - + b[2] * na_fb * (1 - m32) - - d[2] * (1 - div) - + d[2] * 2 * np * (1 - div); - - // 0x4000_0000_0000_0000__8000_0000_0000_0000 - chunks[7] = 0x10000 * na * nb * (1 - m32) // chunk7 - + a[3] * nb_fa * (1 - m32) - + b[3] * na_fb * (1 - m32) - - 0x10000 * np * (1 - div) * (1 - m32) - - d[3] * (1 - div) - + d[3] * 2 * np * (1 - div); - - chunks - } - fn u64_to_chunks(a: u64) -> [i64; 4] { - [ - (a & 0xFFFF) as i64, - ((a >> 16) & 0xFFFF) as i64, - ((a >> 32) & 0xFFFF) as i64, - ((a >> 48) & 0xFFFF) as i64, - ] - } - fn execute_chunks( - a: u64, - b: u64, - c: u64, - d: u64, - m32: u64, - div: u64, - na: u64, - nb: u64, - np: u64, - nr: u64, - secondary_res: u64, - sext: u64, - range_ab: u64, - range_cd: u64, - bus: [u64; 8], - ) -> bool { - let fab: i64 = 1 - 2 * na as i64 - 2 * nb as i64 + 4 * na as i64 * nb as i64; - let a_chunks = Self::u64_to_chunks(a); - let b_chunks = Self::u64_to_chunks(b); - let c_chunks = Self::u64_to_chunks(c); - let d_chunks = Self::u64_to_chunks(d); - println!( - "A: 0x{0:>04X} \x1B[32m{0:>5}\x1B[0m|0x{1:>04X} \x1B[32m{1:>5}\x1B[0m|0x{2::>04X} \x1B[32m{2:>5}\x1B[0m|0x{3:>04X} \x1B[32m{3:>5}\x1B[0m|", - a_chunks[0], a_chunks[1], a_chunks[2], a_chunks[3] - ); - println!( - "B: 0x{0:>04X} \x1B[32m{0:>5}\x1B[0m|0x{1:>04X} \x1B[32m{1:>5}\x1B[0m|0x{2::>04X} \x1B[32m{2:>5}\x1B[0m|0x{3:>04X} \x1B[32m{3:>5}\x1B[0m|", - b_chunks[0], b_chunks[1], b_chunks[2], b_chunks[3] - ); - println!( - "C: 0x{0:>04X} \x1B[32m{0:>5}\x1B[0m|0x{1:>04X} \x1B[32m{1:>5}\x1B[0m|0x{2::>04X} \x1B[32m{2:>5}\x1B[0m|0x{3:>04X} \x1B[32m{3:>5}\x1B[0m|", - c_chunks[0], c_chunks[1], c_chunks[2], c_chunks[3] - ); - println!( - "D: 0x{0:>04X} \x1B[32m{0:>5}\x1B[0m|0x{1:>04X} \x1B[32m{1:>5}\x1B[0m|0x{2::>04X} \x1B[32m{2:>5}\x1B[0m|0x{3:>04X} \x1B[32m{3:>5}\x1B[0m|", - d_chunks[0], d_chunks[1], d_chunks[2], d_chunks[3] - ); - - let mut chunks = Self::calculate_chunks( - a_chunks, - b_chunks, - c_chunks, - d_chunks, - m32 as i64, - div as i64, - na as i64, - nb as i64, - np as i64, - nr as i64, - fab, - secondary_res as i64, - sext as i64, - ); - let mut carry: i64 = 0; - println!( - "0x{0:X}({0}),0x{1:X}({1}),0x{2:X}({2}),0x{3:X}({3}),0x{4:X}({4}),0x{5:X}({5}),0x{6:X}{6},0x{7:X}({7}) fab:{8:X}", - chunks[0], chunks[1], chunks[2], chunks[3], chunks[4], chunks[5], chunks[6], chunks[7], fab - ); - let mut carrys: [i64; 8] = [0, 0, 0, 0, 0, 0, 0, 0]; - for _index in 0..8 { - println!( - "APPLY CARRY:{0} CHUNK[{1}]:{2:X} ({2}) {3:X}({3})", - carry, - _index, - chunks[_index], - chunks[_index] + carry - ); - let chunk_value = chunks[_index] + carry; - carry = chunk_value / 0x10000; - chunks[_index] = chunk_value - carry * 0x10000; - carrys[_index] = carry; - } - println!( - "CARRY 0x{0:X}({0}),0x{1:X}({1}),0x{2:X}({2}),0x{3:X}({3}),0x{4:X}({4}),0x{5:X}({5}),0x{6:X}{6},0x{7:X}({7}) fab:{8:X}", - carrys[0], carrys[1], carrys[2], carrys[3], carrys[4], carrys[5], carrys[6], carrys[7], fab - ); - println!( - "0x{:X},0x{:X},0x{:X},0x{:X},0x{:X},0x{:X},0x{:X},0x{:X} carry:0x{:X}", - chunks[0], - chunks[1], - chunks[2], - chunks[3], - chunks[4], - chunks[5], - chunks[6], - chunks[7], - carry - ); - println!( - "{} {} {} {} {} {} {} {} {}", - chunks[0], - chunks[1], - chunks[2], - chunks[3], - chunks[4], - chunks[5], - chunks[6], - chunks[7], - carry - ); - let mut passed = if chunks[0] != 0 - || chunks[1] != 0 - || chunks[2] != 0 - || chunks[3] != 0 - || chunks[4] != 0 - || chunks[5] != 0 - || chunks[6] != 0 - || chunks[7] != 0 - || carry != 0 - { - println!("[\x1B[31mFAIL\x1B[0m]"); - false - } else { - println!("[\x1B[32mOK\x1B[0m]"); - true - }; - const CHUNK_SIZE: i64 = 0x10000; - let bus_a_low: i64 = div as i64 * (c_chunks[0] + c_chunks[1] * CHUNK_SIZE) - + (1 - div as i64) * (a_chunks[0] + a_chunks[1] * CHUNK_SIZE); - let bus_a_high: i64 = div as i64 * (c_chunks[2] + c_chunks[3] * CHUNK_SIZE) - + (1 - div as i64) * (a_chunks[2] + a_chunks[3] * CHUNK_SIZE); - - let bus_b_low: i64 = b_chunks[0] + CHUNK_SIZE * b_chunks[1]; - let bus_b_high: i64 = b_chunks[2] + CHUNK_SIZE * b_chunks[3]; - - let res2_low: i64 = d_chunks[0] + CHUNK_SIZE * d_chunks[1]; - let res2_high: i64 = d_chunks[2] + CHUNK_SIZE * d_chunks[3]; - - let res_low: i64 = secondary_res as i64 * res2_low - + (1 - secondary_res as i64) - * (a_chunks[0] + c_chunks[0] + CHUNK_SIZE * (a_chunks[1] + c_chunks[1]) - - bus_a_low); - println!( - "RES_LOW: 0x{0:X}({0}) 0x{1:X}({1}) 0x{2:X}({2})", - res_low, - a_chunks[2] + c_chunks[2] + CHUNK_SIZE * (a_chunks[3] + c_chunks[3]), - bus_a_high - ); - let res_high: i64 = (1 - m32 as i64) - * (secondary_res as i64 * res2_high - + (1 - secondary_res as i64) - * ((a_chunks[2] + c_chunks[2] + CHUNK_SIZE * (a_chunks[3] + c_chunks[3])) - - bus_a_high)) - + sext as i64 * 0xFFFFFFFF; - passed = passed - && if bus[1] != bus_a_low as u64 - || bus[2] != bus_a_high as u64 - || bus[3] != bus_b_low as u64 - || bus[4] != bus_b_high as u64 - || bus[5] != res_low as u64 - || bus[6] != res_high as u64 - { - println!("0x{0:X} ({0}) vs 0x{1:X} ({1})", bus[1], bus_a_low); - println!("0x{0:X} ({0}) vs 0x{1:X} ({1})", bus[2], bus_a_high); - println!("0x{0:X} ({0}) vs 0x{1:X} ({1})", bus[3], bus_b_low); - println!("0x{0:X} ({0}) vs 0x{1:X} ({1})", bus[4], bus_b_high); - println!("0x{0:X} ({0}) vs 0x{1:X} ({1})", bus[5], res_low); - println!("0x{0:X} ({0}) vs 0x{1:X} ({1})", bus[6], res_high); - println!("[\x1B[31mFAIL BUS\x1B[0m]"); - false - } else { - println!("[\x1B[32mOK BUS\x1B[0m]"); - true - }; - // check all chunks and carries - let carry_min_value: i64 = -0x0F_FFFF; - let carry_max_value: i64 = 0x0F_FFFF; - for index in 0..8 { - passed = passed - && if carrys[index] > carry_max_value || carrys[index] < carry_min_value { - println!("[\x1B[31mFAIL CARRY RANGE CHECK\x1B[0m]"); - false - } else { - println!("[\x1B[32mOK CARRY RANGE CHECK\x1B[0m]"); - true - }; - } - let ranges = Self::decode_ranges(range_ab, range_cd); - Self::check_range(0, a_chunks[0]); - Self::check_range(0, b_chunks[0]); - Self::check_range(0, c_chunks[0]); - Self::check_range(0, d_chunks[0]); - - Self::check_range(ranges[4], a_chunks[1]); - Self::check_range(ranges[5], b_chunks[1]); - Self::check_range(ranges[6], c_chunks[1]); - Self::check_range(ranges[7], d_chunks[1]); - - Self::check_range(0, a_chunks[2]); - Self::check_range(0, b_chunks[2]); - Self::check_range(0, c_chunks[2]); - Self::check_range(0, d_chunks[2]); - - Self::check_range(ranges[0], a_chunks[3]); - Self::check_range(ranges[1], b_chunks[3]); - Self::check_range(ranges[2], c_chunks[3]); - Self::check_range(ranges[3], d_chunks[3]); - - passed - } - fn check_range(range_id: u64, value: i64) { - assert!(range_id != 0 || (value >= 0 && value <= 0xFFFF)); - assert!(range_id != 1 || (value >= 0 && value <= 0x7FFF)); - assert!(range_id != 2 || (value >= 0x8000 && value <= 0xFFFF)); - } -} - -fn flags_to_strings(mut flags: u64, flag_names: &[&str]) -> String { - let mut res = String::new(); - - for flag_name in flag_names { - if (flags & 1u64) != 0 { - if !res.is_empty() { - res = res + ","; - } - res = res + *flag_name; - } - flags >>= 1; - if flags == 0 { - break; - }; - } - res -} - -const F_M32: u64 = 0x0001; -const F_DIV: u64 = 0x0002; -const F_NA: u64 = 0x0004; -const F_NB: u64 = 0x0008; -const F_NP: u64 = 0x0010; -const F_NR: u64 = 0x0020; -const F_SEXT: u64 = 0x0040; -const F_SEC: u64 = 0x0080; - -// range_ab / range_cd -// -// a3 a1 b3 b1 -// rid c3 c1 d3 d1 range 2^16 2^15 notes -// --- -- -- -- -- ----- ---- ---- ------------------------- - -const R_FF: u64 = 0; // 0 F F F F ab cd 4 0 -const R_3FP: u64 = 1; // 1 F F + F cd 3 1 b3 sign => a3 sign -const R_3FN: u64 = 2; // 2 F F - F cd 3 1 b3 sign => a3 sign -const R_3PF: u64 = 3; // 3 + F F F ab 3 1 c3 sign => d3 sign -const R_3PP: u64 = 4; // 4 + F + F ab cd 2 2 -const R_3PN: u64 = 5; // 5 + F - F ab cd 2 2 -const R_3NF: u64 = 6; // 6 - F F F ab 3 1 c3 sign => d3 sign -const R_3NP: u64 = 7; // 7 - F + F ab cd 2 2 -const R_3NN: u64 = 8; // 8 - F - F ab cd 2 2 -const R_1FP: u64 = 9; // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign -const R_1FN: u64 = 10; // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign -const R_1PF: u64 = 11; // 11 F + F F cd 3 1 a1 sign <=> b1 sign -const R_1PP: u64 = 12; // 12 F + F + ab cd 2 2 -const R_1PN: u64 = 13; // 13 F + F - ab cd 2 2 -const R_1NF: u64 = 14; // 14 F - F F cd 3 1 a1 sign <=> b1 sign -const R_1NP: u64 = 15; // 15 F - F + ab cd 2 2 -const R_1NN: u64 = 16; // 16 F - F - ab cd 2 2 - -const MIN_N_64: u64 = 0x8000_0000_0000_0000; -const MIN_N_32: u64 = 0x0000_0000_8000_0000; -const MAX_P_64: u64 = 0x7FFF_FFFF_FFFF_FFFF; -const MAX_P_32: u64 = 0x0000_0000_7FFF_FFFF; -const MAX_32: u64 = 0x0000_0000_FFFF_FFFF; -const MAX_64: u64 = 0xFFFF_FFFF_FFFF_FFFF; - -// value cannot used as specific cases -const ALL_64: u64 = 0x0033; -const ALL_NZ_64: u64 = 0x0034; -const ALL_P_64: u64 = 0x0035; -const ALL_NZ_P_64: u64 = 0x0036; -const ALL_N_64: u64 = 0x0037; - -const ALL_32: u64 = 0x0043; -const ALL_NZ_32: u64 = 0x0044; -const ALL_P_32: u64 = 0x0045; -const ALL_N_32: u64 = 0x0046; -const ALL_NZ_P_32: u64 = 0x0047; - -const VALUES_END: u64 = 0x004D; - -fn get_test_values(value: u64) -> [u64; 16] { - match value { - ALL_64 => [ - 0, - 1, - 2, - 3, - MAX_P_32 - 1, - MAX_P_32, - MIN_N_32, - MAX_32 - 1, - MAX_32, - MAX_32 + 1, - MAX_P_64 - 1, - MAX_P_64, - MAX_64 - 1, - MIN_N_64, - MIN_N_64 + 1, - MAX_64, - ], - ALL_NZ_64 => [ - 1, - 2, - 3, - MAX_P_32 - 1, - MAX_P_32, - MIN_N_32, - MAX_32 - 1, - MAX_32, - MAX_32 + 1, - MAX_P_64 - 1, - MAX_P_64, - MAX_64 - 1, - MIN_N_64, - MIN_N_64 + 1, - MAX_64, - VALUES_END, - ], - ALL_P_64 => [ - 0, - 1, - 2, - 3, - MAX_P_32 - 1, - MAX_P_32, - MIN_N_32, - MAX_32 - 1, - MAX_32, - MAX_32 + 1, - MAX_P_64 - 1, - MAX_P_64, - VALUES_END, - 0, - 0, - 0, - ], - ALL_NZ_P_64 => [ - 1, - 2, - 3, - MAX_P_32 - 1, - MAX_P_32, - MIN_N_32, - MAX_32 - 1, - MAX_32, - MAX_32 + 1, - MAX_P_64 - 1, - MAX_P_64, - VALUES_END, - 0, - 0, - 0, - 0, - ], - ALL_N_64 => [ - MIN_N_64, - MIN_N_64 + 1, - MIN_N_64 + 2, - MIN_N_64 + 3, - 0x8000_0000_7FFF_FFFF, - 0x8FFF_FFFF_7FFF_FFFF, - 0xEFFF_FFFF_FFFF_FFFF, - MAX_64 - 3, - MAX_64 - 2, - MAX_64 - 1, - MAX_64, - VALUES_END, - 0, - 0, - 0, - 0, - ], - ALL_32 => [ - 0, - 1, - 2, - 3, - MAX_P_32 - 1, - MAX_P_32, - MIN_N_32, - MAX_32 - 1, - MAX_32, - VALUES_END, - 0, - 0, - 0, - 0, - 0, - 0, - ], - ALL_NZ_32 => [ - 1, - 2, - 3, - MAX_P_32 - 1, - MAX_P_32, - MIN_N_32, - MAX_32 - 1, - MAX_32, - VALUES_END, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ], - ALL_P_32 => [ - 0, - 1, - 2, - 3, - 0x0000_7FFF, - 0x0000_FFFF, - MAX_P_32 - 1, - MAX_P_32, - MAX_P_32 - 1, - MAX_P_32, - VALUES_END, - 0, - 0, - 0, - 0, - 0, - ], - ALL_NZ_P_32 => [ - 1, - 2, - 3, - 0x0000_7FFF, - 0x0000_FFFF, - MAX_P_32 - 1, - MAX_P_32, - MAX_P_32 - 1, - MAX_P_32, - VALUES_END, - 0, - 0, - 0, - 0, - 0, - 0, - ], - ALL_N_32 => [ - MIN_N_32, - MIN_N_32 + 1, - MIN_N_32 + 2, - MIN_N_32 + 3, - MAX_32 - 3, - MAX_32 - 2, - MAX_32 - 1, - MAX_32, - VALUES_END, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ], - _ => [value, VALUES_END, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - } -} -#[test] -fn test_calculate_range_checks() { - struct TestArithHelpers {} - impl ArithHelpers for TestArithHelpers {} - struct TestParams { - op: u8, - a: u64, - b: u64, - flags: u64, - range_ab: u64, - range_cd: u64, - } - - // NOTE: update TEST_COUNT with number of tests, ALL,ALL => 3*3 = 9 - const TEST_COUNT: u32 = 2510; - - // NOTE: use 0x0000_0000 instead of 0, to avoid auto-format in one line, 0 is too short. - let tests = [ - // 0 - MULU - TestParams { - op: MULU, - a: ALL_64, - b: ALL_64, - flags: 0x0000, - range_ab: R_FF, - range_cd: R_FF, - }, - // 1 - MULUH - TestParams { - op: MULUH, - a: ALL_64, - b: ALL_64, - flags: F_SEC, - range_ab: R_FF, - range_cd: R_FF, - }, - // 2 - MULSUH - TestParams { - op: MULSUH, - a: ALL_P_64, - b: ALL_64, - flags: F_SEC, - range_ab: R_3PF, - range_cd: R_3FP, - }, - // 3 - MULSUH - TestParams { - op: MULSUH, - a: ALL_N_64, - b: ALL_NZ_64, - flags: F_NA + F_NP + F_SEC, - range_ab: R_3NF, - range_cd: R_3FN, - }, - // 4 - MULSUH - TestParams { - op: MULSUH, - a: ALL_N_64, - b: 0x0000_0000, - flags: F_NA + F_SEC, - range_ab: R_3NF, - range_cd: R_3FP, - }, - // 5 - MUL - TestParams { - op: MUL, - a: ALL_P_64, - b: ALL_P_64, - flags: 0, - range_ab: R_3PP, - range_cd: R_3FP, - }, - // 6 - MUL - TestParams { - op: MUL, - a: ALL_N_64, - b: ALL_N_64, - flags: F_NA + F_NB, - range_ab: R_3NN, - range_cd: R_3FP, - }, - // 7 - MUL - TestParams { - op: MUL, - a: ALL_N_64, - b: ALL_NZ_P_64, - flags: F_NA + F_NP, - range_ab: R_3NP, - range_cd: R_3FN, - }, - // 8 - MUL - TestParams { - op: MUL, - a: ALL_N_64, - b: 0x0000_0000, - flags: F_NA, - range_ab: R_3NP, - range_cd: R_3FP, - }, - // 9 - MUL - TestParams { - op: MUL, - a: ALL_NZ_P_64, - b: ALL_N_64, - flags: F_NB + F_NP, - range_ab: R_3PN, - range_cd: R_3FN, - }, - // 10 - MUL - TestParams { - op: MUL, - a: 0x0000_0000, - b: ALL_N_64, - flags: F_NB, - range_ab: R_3PN, - range_cd: R_3FP, - }, - // 11 - MULH - TestParams { - op: MULH, - a: ALL_P_64, - b: ALL_P_64, - flags: F_SEC, - range_ab: R_3PP, - range_cd: R_3FP, - }, - // 12 - MULH - TestParams { - op: MULH, - a: ALL_N_64, - b: ALL_N_64, - flags: F_NA + F_NB + F_SEC, - range_ab: R_3NN, - range_cd: R_3FP, - }, - // 13 - MULH - TestParams { - op: MULH, - a: ALL_N_64, - b: ALL_NZ_P_64, - flags: F_NA + F_NP + F_SEC, - range_ab: R_3NP, - range_cd: R_3FN, - }, - // 14 - MULH - TestParams { - op: MULH, - a: ALL_N_64, - b: 0x0000_00000, - flags: F_NA + F_SEC, - range_ab: R_3NP, - range_cd: R_3FP, - }, - // 15 - MULH - TestParams { - op: MULH, - a: ALL_NZ_P_64, - b: ALL_N_64, - flags: F_NB + F_NP + F_SEC, - range_ab: R_3PN, - range_cd: R_3FN, - }, - // 16 - MULH - TestParams { - op: MULH, - a: 0x0000_0000, - b: ALL_N_64, - flags: F_NB + F_SEC, - range_ab: R_3PN, - range_cd: R_3FP, - }, - // 17 - MUL_W - TestParams { - op: MUL_W, - a: 0x0000_0000, - b: 0x0000_0000, - flags: F_M32, - range_ab: R_FF, - range_cd: R_1PF, - }, - // 18 - MUL_W: 0x00000002 (+/32 bits) * 0x40000000 (+/32 bits) = 0x80000000 (-/32 bits) - TestParams { - op: MUL_W, - a: 0x0000_0002, - b: 0x4000_0000, - flags: F_M32 + F_SEXT, - range_ab: R_FF, - range_cd: R_1NF, - }, - // 19 - MUL_W - TestParams { - op: MUL_W, - a: 0x0000_0002, - b: 0x8000_0000, - flags: F_M32, - range_ab: R_FF, - range_cd: R_1PF, - }, - // 20 - MUL_W - TestParams { - op: MUL_W, - a: 0xFFFF_FFFF, - b: 1, - flags: F_M32 + F_SEXT, - range_ab: R_FF, - range_cd: R_1NF, - }, - // 21 - MUL_W - TestParams { - op: MUL_W, - a: 0xFFFF_FFFF, - b: 0x0000_00000, - flags: F_M32, - range_ab: R_FF, - range_cd: R_1PF, - }, - // 22 - MUL_W - TestParams { - op: MUL_W, - a: 0x7FFF_FFFF, - b: 2, - flags: F_M32 + F_SEXT, - range_ab: R_FF, - range_cd: R_1NF, - }, - // 23 - MUL_W - TestParams { - op: MUL_W, - a: 0xBFFF_FFFF, - b: 0x0000_0002, - flags: F_M32, - range_ab: R_FF, - range_cd: R_1PF, - }, - // 24 - MUL_W: 0xFFFF_FFFF * 0xFFFF_FFFF = 0xFFFF_FFFE_0000_0001 - TestParams { - op: MUL_W, - a: 0xFFFF_FFFF, - b: 0xFFFF_FFFF, - flags: F_M32, - range_ab: R_FF, - range_cd: R_1PF, - }, - // 25 - MUL_W: 0xFFFF_FFFF * 0x0FFF_FFFF = 0x0FFF_FFFE_F000_0001 - TestParams { - op: MUL_W, - a: 0xFFFF_FFFF, - b: 0x0FFF_FFFF, - flags: F_M32 + F_SEXT, - range_ab: R_FF, - range_cd: R_1NF, - }, - // 26 - MUL_W: 0x8000_0000 * 0x8000_0000 = 0x4000_0000_0000_0000 - TestParams { - op: MUL_W, - a: 0x8000_0000, - b: 0x8000_0000, - flags: F_M32, - range_ab: R_FF, - range_cd: R_1PF, - }, - // 27 - DIVU - TestParams { - op: DIVU, - a: ALL_64, - b: ALL_NZ_64, - flags: F_DIV + 0, - range_ab: R_FF, - range_cd: R_FF, - }, - // 28 - REMU - TestParams { - op: REMU, - a: ALL_64, - b: ALL_NZ_64, - flags: F_DIV + F_SEC, - range_ab: R_FF, - range_cd: R_FF, - }, - // 29 - DIV - TestParams { - op: DIV, - a: MAX_P_64, - b: MAX_P_64, - flags: F_DIV, - range_ab: R_3PP, - range_cd: R_3PP, - }, - // 30 - DIV - TestParams { - op: DIV, - a: MIN_N_64, - b: MAX_P_64, - flags: F_DIV + F_NA + F_NP + F_NR, - range_ab: R_3NP, - range_cd: R_3NN, - }, - // 31 - DIV - TestParams { - op: DIV, - a: MAX_P_64, - b: MIN_N_64, - flags: F_DIV + F_NB, // a/b = 0 ➜ np = 0 - range_ab: R_3PN, - range_cd: R_3PP, - }, - // 32 - DIV - TestParams { - op: DIV, - a: MIN_N_64, - b: MIN_N_64, - flags: F_DIV + F_NB + F_NP, // a/b = 1 ➜ 1 * b_neg ➜ np = 1 - range_ab: R_3PN, - range_cd: R_3NP, - }, - // 33 - DIV - TestParams { - op: DIV, - a: 0x0000_0000, - b: MAX_P_64, - flags: F_DIV, - range_ab: R_3PP, - range_cd: R_3PP, - }, - // 34 - DIV - TestParams { - op: DIV, - a: 0x0000_0000, - b: MIN_N_64, - flags: F_DIV + F_NB, - range_ab: R_3PN, - range_cd: R_3PP, - }, - // 35 - REM - TestParams { - op: REM, - a: MAX_P_64, - b: MAX_P_64, - flags: F_DIV + F_SEC, - range_ab: R_3PP, - range_cd: R_3PP, - }, - // 36 - REM - TestParams { - op: REM, - a: MIN_N_64, - b: MAX_P_64, - flags: F_DIV + F_NA + F_NP + F_NR + F_SEC, - range_ab: R_3NP, - range_cd: R_3NN, - }, - // 37 - REM - TestParams { - op: REM, - a: MAX_P_64, - b: MIN_N_64, - flags: F_DIV + F_NB + F_SEC, - range_ab: R_3PN, - range_cd: R_3PP, - }, - // 38 - REM - TestParams { - op: REM, - a: MIN_N_64, - b: MIN_N_64, - flags: F_DIV + F_NB + F_NP + F_SEC, - range_ab: R_3PN, - range_cd: R_3NP, - }, - // 39 - REM - TestParams { - op: REM, - a: 0x0000_0000, - b: MAX_P_64, - flags: F_DIV + F_SEC, - range_ab: R_3PP, - range_cd: R_3PP, - }, - // 40 - REM - TestParams { - op: REM, - a: 0x0000_0000, - b: MIN_N_64, - flags: F_DIV + F_NB + F_SEC, - range_ab: R_3PN, - range_cd: R_3PP, - }, - // 41 - DIVU_W - TestParams { - op: DIVU_W, - a: 0xFFFF_FFFF, - b: 0x0000_0001, - flags: F_DIV + F_M32 + F_SEXT, - range_ab: R_1NF, - range_cd: R_FF, - }, - // 42 - DIVU_W - TestParams { - op: DIVU_W, - a: ALL_NZ_32, - b: 0x0000_00002, - flags: F_DIV + F_M32, - range_ab: R_1PF, - range_cd: R_FF, - }, - // 43 - DIVU_W - TestParams { - op: DIVU_W, - a: ALL_NZ_32, - b: MAX_32, - flags: F_DIV + F_M32, - range_ab: R_1PF, - range_cd: R_FF, - }, - // 44 - DIVU_W - TestParams { - op: DIVU_W, - a: 0, - b: ALL_NZ_32, - flags: F_DIV + F_M32, - range_ab: R_1PF, - range_cd: R_FF, - }, - // 45 - REMU_W - TestParams { - op: REMU_W, - a: 0xFFFF_FFFF, - b: 0x0000_0001, - flags: F_DIV + F_M32 + F_SEC, - range_ab: R_FF, - range_cd: R_1FP, - }, - // 46 - REMU_W - TestParams { - op: REMU_W, - a: ALL_32, - b: 0x0000_00002, - flags: F_DIV + F_M32 + F_SEC, - range_ab: R_FF, - range_cd: R_1FP, - }, - // 47 - REMU_W - TestParams { - op: REMU_W, - a: ALL_NZ_P_32, - b: MAX_32, - flags: F_DIV + F_M32 + F_SEC, - range_ab: R_FF, - range_cd: R_1FP, - }, - // 48 - REMU_W - TestParams { - op: REMU_W, - a: ALL_32, - b: 0x8000_0000, - flags: F_DIV + F_M32 + F_SEC, - range_ab: R_FF, - range_cd: R_1FP, - }, - // 49 - REMU_W - TestParams { - op: REMU_W, - a: 0, - b: ALL_NZ_32, - flags: F_DIV + F_M32 + F_SEC, - range_ab: R_FF, - range_cd: R_1FP, - }, - // 50 - REMU_W - TestParams { - op: REMU_W, - a: 0xFFFF_FFFE, - b: 0xFFFF_FFFF, - flags: F_DIV + F_M32 + F_SEXT + F_SEC, - range_ab: R_FF, - range_cd: R_1FN, - }, - // 51 - REMU_W - TestParams { - op: REMU_W, - a: 0xFFFF_FFFE, - b: 0xFFFF_FFFE, - flags: F_DIV + F_M32 + F_SEC, - range_ab: R_FF, - range_cd: R_1FP, - }, - // 52 - REMU_W - TestParams { - op: REMU_W, - a: 0x8000_0000, - b: 0x8000_0001, - flags: F_DIV + F_M32 + F_SEXT + F_SEC, - range_ab: R_FF, - range_cd: R_1FN, - }, - // 53 - REMU_W - TestParams { - op: REMU_W, - a: 0x8000_0001, - b: 0x8000_0000, - flags: F_DIV + F_M32 + F_SEC, - range_ab: R_FF, - range_cd: R_1FP, - }, - // 54 - REMU_W - TestParams { - op: REMU_W, - a: 0xFFFF_FFFF, - b: 0x0000_0003, - flags: F_DIV + F_M32 + F_SEC, - range_ab: R_FF, - range_cd: R_1FP, - }, - // 55 - DIV_W (-1/1=-1 REM:0) - TestParams { - op: DIV_W, - a: 0xFFFF_FFFF, - b: 0x0000_0001, - flags: F_DIV + F_NA + F_NP + F_M32 + F_SEXT, - range_ab: R_1NP, - range_cd: R_1NP, - }, - // 56 - REM_W !!! - TestParams { - op: REM_W, - a: 0xFFFF_FFFF, - b: 0x0000_0001, - flags: F_DIV + F_NA + F_NP + F_M32 + F_SEC, - range_ab: R_1NP, - range_cd: R_1NP, - }, - // 57 - DIV_W <====== - TestParams { - op: DIV_W, - a: 0xFFFF_FFFF, - b: 0x0000_0002, - flags: F_DIV + F_NP + F_NR + F_M32, - range_ab: R_1PP, - range_cd: R_1NN, - }, - // 58 - REM_W - TestParams { - op: REM_W, - a: 0xFFFF_FFFF, - b: 0x0000_0002, - flags: F_DIV + F_NP + F_NR + F_M32 + F_SEC + F_SEXT, - range_ab: R_1PP, - range_cd: R_1NN, - }, - ]; - - let mut count = 0; - let mut index: u32 = 0; - - #[derive(Debug, PartialEq)] - struct TestDone { - op: u8, - a: u64, - b: u64, - index: u32, - offset: u32, - } - - let mut tests_done: Vec = Vec::new(); - let mut errors = 0; - for test in tests { - let a_values = get_test_values(test.a); - let mut offset = 0; - for _a in a_values { - if _a == VALUES_END { - break; - } - let b_values = get_test_values(test.b); - for _b in b_values { - if _b == VALUES_END { - break; - } - let test_info = TestDone { op: test.op, a: _a, b: _b, index, offset }; - let previous = tests_done - .iter() - .find(|&x| x.op == test_info.op && x.a == test_info.a && x.b == test_info.b); - match previous { - Some(e) => { - println!( - "\x1B[35mDuplicated TEST #{} op:0x{:x} a:0x{:X} b:0x{:X} offset:{}\x1B[0m", - e.index, e.op, e.a, e.b, e.offset - ); - } - None => { - tests_done.push(test_info); - } - } - println!("testing #{} op:0x{:x} with _a:0x{:X} _b:0x{:X}", index, test.op, _a, _b); - let (emu_c, emu_flag) = TestArithHelpers::calculate_emulator_res(test.op, _a, _b); - let [a, b, c, d] = TestArithHelpers::calculate_abcd_from_ab(test.op, _a, _b); - - let [m32, div, na, nb, np, nr, sext, sec, range_ab, range_cd, ranges] = - TestArithHelpers::calculate_flags_and_ranges(test.op, a, b, c, d); - - let flags = - m32 + div * 2 + na * 4 + nb * 8 + np * 16 + nr * 32 + sext * 64 + sec * 128; - - let row = TestArithHelpers::get_row(test.op, na, nb, np, nr, sext); - println!( - "#{} op:0x{:x} na:{} nb:{} np:{} nr:{} sext:{}", - row, test.op, na, nb, np, nr, sext - ); - assert_eq!( - [flags, range_ab, range_cd], - [test.flags, test.range_ab, test.range_cd], - "testing #{} op:0x{:x} with _a:0x{:X} _b:0x{:X} a:0x{:X} b:0x{:X} c:0x{:X} d:0x{:X} EMU:0x{:X} flags:{:b}[{}]/{:b}[{}] range_ab:{}/{} range_cd:{}/{} ranges:{}", - index, - test.op, - _a, - _b, - a, - b, - c, - d, - emu_c, - flags, - flags_to_strings(flags, &FLAG_NAMES), - test.flags, - flags_to_strings(test.flags, &FLAG_NAMES), - range_ab, - test.range_ab, - range_cd, - test.range_cd, - ranges - ); - println!("testing #{} op:0x{:x} with _a:0x{:X} _b:0x{:X} a:0x{:X} b:0x{:X} c:0x{:X} d:0x{:X} EMU:0x{:X} flags:{:b}[{}]/{:b}[{}] range_ab:{}/{} range_cd:{}/{} ranges:{}", - index, - test.op, - _a, - _b, - a, - b, - c, - d, - emu_c, - flags, - flags_to_strings(flags, &FLAG_NAMES), - test.flags, - flags_to_strings(test.flags, &FLAG_NAMES), - range_ab, - test.range_ab, - range_cd, - test.range_cd, - ranges - ); - assert_ne!(row, -1); - let bus: [u64; 8] = [ - test.op as u64, - _a & 0xFFFF_FFFF, - _a >> 32, - _b & 0xFFFF_FFFF, - _b >> 32, - emu_c & 0xFFFF_FFFF, - emu_c >> 32, - if emu_flag { 1 } else { 0 }, - ]; - if !TestArithHelpers::execute_chunks( - a, b, c, d, m32, div, na, nb, np, nr, sec, sext, range_ab, range_cd, bus, - ) { - errors += 1; - println!("TOTAL ERRORS: {}", errors); - } - offset += 1; - count += 1; - } - } - index += 1; - } - println!("TOTAL ERRORS: {}", errors); - assert_eq!(count, TEST_COUNT, "Number of tests not matching"); -} diff --git a/state-machines/arith/src/arith_operation.rs b/state-machines/arith/src/arith_operation.rs new file mode 100644 index 00000000..20ffde86 --- /dev/null +++ b/state-machines/arith/src/arith_operation.rs @@ -0,0 +1,609 @@ +use crate::{arith_constants::*, arith_range_table_helpers::*}; +use std::fmt; +use zisk_core::zisk_ops::*; + +pub struct ArithOperation { + pub op: u8, + pub input_a: u64, + pub input_b: u64, + pub a: [u64; 4], + pub b: [u64; 4], + pub c: [u64; 4], + pub d: [u64; 4], + pub carry: [i64; 7], + pub m32: bool, + pub div: bool, + pub na: bool, + pub nb: bool, + pub np: bool, + pub nr: bool, + pub sext: bool, + pub main_mul: bool, + pub main_div: bool, + pub signed: bool, + pub range_ab: u8, + pub range_cd: u8, +} + +impl fmt::Debug for ArithOperation { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut flags = String::new(); + if self.m32 { + flags += "m32 " + }; + if self.div { + flags += "div " + }; + if self.na { + flags += "na " + }; + if self.nb { + flags += "nb " + }; + if self.np { + flags += "np " + }; + if self.nr { + flags += "nr " + }; + if self.sext { + flags += "sext " + }; + if self.main_mul { + flags += "main_mul " + }; + if self.main_div { + flags += "main_div " + }; + if self.signed { + flags += "signed " + }; + write!(f, "operation 0x{:x} flags={}\n", self.op, flags)?; + write!(f, "input_a: 0x{0:x}({0})\n", self.input_a)?; + write!(f, "input_b: 0x{0:x}({0})\n", self.input_b)?; + self.dump_chunks(f, "a", &self.a)?; + self.dump_chunks(f, "b", &self.b)?; + self.dump_chunks(f, "c", &self.c)?; + self.dump_chunks(f, "d", &self.d)?; + write!( + f, + "carry: [0x{0:X}({0}), 0x{1:X}({1}), 0x{2:X}({2}), 0x{3:X}({3}), 0x{4:X}({4}), 0x{5:X}({5}), 0x{6:X}({6})]\n", + self.carry[0], self.carry[1], self.carry[2], self.carry[3], self.carry[4], self.carry[5], self.carry[6] + )?; + write!( + f, + "range_ab: 0x{0:X} {1}, range_cd:0x{2:X} {3}\n", + self.range_ab, + AirthRangeTableHelpers::get_range_name(self.range_ab), + self.range_cd, + AirthRangeTableHelpers::get_range_name(self.range_cd) + ) + } +} + +impl ArithOperation { + fn dump_chunks(&self, f: &mut fmt::Formatter, name: &str, value: &[u64; 4]) -> fmt::Result { + write!( + f, + "{0}: [0x{1:X}({1}), 0x{2:X}({2}), 0x{3:X}({3}), 0x{4:X}({4})]\n", + name, value[0], value[1], value[2], value[3] + ) + } + pub fn new() -> Self { + Self { + op: 0, + input_a: 0, + input_b: 0, + a: [0, 0, 0, 0], + b: [0, 0, 0, 0], + c: [0, 0, 0, 0], + d: [0, 0, 0, 0], + carry: [0, 0, 0, 0, 0, 0, 0], + m32: false, + div: false, + na: false, + nb: false, + np: false, + nr: false, + sext: false, + main_mul: false, + main_div: false, + signed: false, + range_ab: 0, + range_cd: 0, + } + } + pub fn calculate(&mut self, op: u8, input_a: u64, input_b: u64) { + self.op = op; + self.input_a = input_a; + self.input_b = input_b; + let [a, b, c, d] = Self::calculate_abcd_from_ab(op, input_a, input_b); + self.a = Self::u64_to_chunks(a); + self.b = Self::u64_to_chunks(b); + self.c = Self::u64_to_chunks(c); + self.d = Self::u64_to_chunks(d); + self.update_flags_and_ranges(op, a, b, c, d); + let chunks = self.calculate_chunks(); + self.update_carries(&chunks); + } + fn update_carries(&mut self, chunks: &[i64; 8]) { + for i in 0..8 { + let chunk_value = chunks[i] + if i > 0 { self.carry[i - 1] } else { 0 }; + if i >= 7 { + continue; + } + self.carry[i] = chunk_value / 0x10000; + } + } + fn sign32(abs_value: u64, negative: bool) -> u64 { + assert!(0xFFFF_FFFF >= abs_value, "abs_value:0x{0:X}({0}) is too big", abs_value); + if negative { + (0xFFFF_FFFF - abs_value) + 1 + } else { + abs_value + } + } + + fn sign64(abs_value: u64, negative: bool) -> u64 { + if negative { + (0xFFFF_FFFF_FFFF_FFFF - abs_value) + 1 + } else { + abs_value + } + } + fn sign128(abs_value: u128, negative: bool) -> u128 { + let res = if negative { + (0xFFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF - abs_value) + 1 + } else { + abs_value + }; + // println!("sign128({:X},{})={:X}", abs_value, negative, res); + res + } + fn abs32(value: u64) -> [u64; 2] { + let negative = if (value & 0x8000_0000) != 0 { 1 } else { 0 }; + let abs_value = if negative == 1 { (0xFFFF_FFFF - value) + 1 } else { value }; + [abs_value, negative] + } + fn abs64(value: u64) -> [u64; 2] { + let negative = if (value & 0x8000_0000_0000_0000) != 0 { 1 } else { 0 }; + let abs_value = if negative == 1 { (0xFFFF_FFFF_FFFF_FFFF - value) + 1 } else { value }; + [abs_value, negative] + } + fn calculate_mul_w(a: u64, b: u64) -> u64 { + (a & 0xFFFF_FFFF) * (b & 0xFFFF_FFFF) + } + + fn calculate_mulsu(a: u64, b: u64) -> [u64; 2] { + let [abs_a, na] = Self::abs64(a); + let abs_c = abs_a as u128 * b as u128; + let nc = if na == 1 && abs_c != 0 { 1 } else { 0 }; + let c = Self::sign128(abs_c, nc == 1); + [c as u64, (c >> 64) as u64] + } + + fn calculate_mul(a: u64, b: u64) -> [u64; 2] { + let [abs_a, na] = Self::abs64(a); + let [abs_b, nb] = Self::abs64(b); + let abs_c = abs_a as u128 * abs_b as u128; + let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; + let c = Self::sign128(abs_c, nc == 1); + [c as u64, (c >> 64) as u64] + } + + fn calculate_div(a: u64, b: u64) -> u64 { + let [abs_a, na] = Self::abs64(a); + let [abs_b, nb] = Self::abs64(b); + let abs_c = abs_a / abs_b; + let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; + Self::sign64(abs_c, nc == 1) + } + + fn calculate_rem(a: u64, b: u64) -> u64 { + let [abs_a, na] = Self::abs64(a); + let [abs_b, _nb] = Self::abs64(b); + let abs_c = abs_a % abs_b; + let nc = if na == 1 && abs_c != 0 { 1 } else { 0 }; + Self::sign64(abs_c, nc == 1) + } + + fn calculate_div_w(a: u64, b: u64) -> u64 { + let [abs_a, na] = Self::abs32(a); + let [abs_b, nb] = Self::abs32(b); + let abs_c = abs_a / abs_b; + let nc = if na != nb && abs_c != 0 { 1 } else { 0 }; + Self::sign32(abs_c, nc == 1) + } + + fn calculate_rem_w(a: u64, b: u64) -> u64 { + let [abs_a, na] = Self::abs32(a); + let [abs_b, _nb] = Self::abs32(b); + let abs_c = abs_a % abs_b; + let nc = if na == 1 && abs_c != 0 { 1 } else { 0 }; + Self::sign32(abs_c, nc == 1) + } + + fn calculate_abcd_from_ab(op: u8, a: u64, b: u64) -> [u64; 4] { + match op { + MULU | MULUH => { + let c: u128 = a as u128 * b as u128; + [a, b, c as u64, (c >> 64) as u64] + } + MULSUH => { + let [c, d] = Self::calculate_mulsu(a, b); + [a, b, c, d] + } + MUL | MULH => { + let [c, d] = Self::calculate_mul(a, b); + [a, b, c, d] + } + MUL_W => [a, b, Self::calculate_mul_w(a, b), 0], + DIVU | REMU | DIVU_W | REMU_W => [a / b, b, a, a % b], + DIV | REM => [Self::calculate_div(a, b), b, a, Self::calculate_rem(a, b)], + DIV_W | REM_W => [Self::calculate_div_w(a, b), b, a, Self::calculate_rem_w(a, b)], + _ => { + panic!("Invalid opcode"); + } + } + } + fn update_flags_and_ranges(&mut self, op: u8, a: u64, b: u64, c: u64, d: u64) { + self.m32 = false; + self.div = false; + self.np = false; + self.nr = false; + self.sext = false; + self.main_mul = false; + self.main_div = false; + self.signed = false; + + let mut range_a1: u8 = 0; + let mut range_b1: u8 = 0; + let mut range_c1: u8 = 0; + let mut range_d1: u8 = 0; + let mut range_a3: u8 = 0; + let mut range_b3: u8 = 0; + let mut range_c3: u8 = 0; + let mut range_d3: u8 = 0; + + // direct table opcode(14), signed 2 or 4 cases (0,na,nb,na+nb) + // 6 * 1 + 7 * 4 + 1 * 2 = 36 entries, + // no compacted => 16 x 4 = 64, key = (op - 0xb0) * 4 + na * 2 + nb + // output: div, m32, sa, sb, nr, np, na, na32, nd32, range x 2 x 4 + + // alternative: switch operation, + + let mut sa = false; + let mut sb = false; + let mut rem = false; + + match op { + MULU => { + self.main_mul = true; + } + MULUH => {} + MULSUH => { + sa = true; + } + MUL => { + sa = true; + sb = true; + self.main_mul = true; + } + MULH => { + sa = true; + sb = true; + } + MUL_W => { + self.m32 = true; + self.sext = ((a * b) & 0xFFFF_FFFF) & 0x8000_0000 != 0; + self.main_mul = true; + } + DIVU => { + self.div = true; + self.main_div = true; + assert!(b != 0, "Error on DIVU a:{:x}({}) b:{:x}({})", a, b, a, b); + } + REMU => { + self.div = true; + rem = true; + } + DIV => { + sa = true; + sb = true; + self.div = true; + self.main_div = true; + } + REM => { + sa = true; + sb = true; + rem = true; + self.div = true; + } + DIVU_W => { + // divu_w, remu_w + self.div = true; + self.m32 = true; + // use a in bus + self.sext = (a & 0x8000_0000) != 0; + self.main_div = true; + } + REMU_W => { + // divu_w, remu_w + self.div = true; + self.m32 = true; + rem = true; + // use d in bus + self.sext = (d & 0x8000_0000) != 0; + } + DIV_W => { + // div_w, rem_w + sa = true; + sb = true; + self.div = true; + self.m32 = true; + // use a in bus + self.sext = (a & 0x8000_0000) != 0; + self.main_div = true; + } + REM_W => { + // div_w, rem_w + sa = true; + sb = true; + self.div = true; + self.m32 = true; + rem = true; + // use d in bus + self.sext = (d & 0x8000_0000) != 0; + } + _ => { + panic!("Invalid opcode"); + } + } + self.signed = sa || sb; + + let sign_mask: u64 = if self.m32 { 0x8000_0000 } else { 0x8000_0000_0000_0000 }; + let sign_c_mask: u64 = + if self.m32 && self.div { 0x8000_0000 } else { 0x8000_0000_0000_0000 }; + self.na = sa && (a & sign_mask) != 0; + self.nb = sb && (b & sign_mask) != 0; + // a sign => b sign + let nc = sa && (c & sign_c_mask) != 0; + let nd = sa && (d & sign_mask) != 0; + + // a == 0 || b == 0 => np == 0 ==> how was a signed operation + // after that sign of np was verified with range check. + // TODO: review if secure + if self.div { + self.np = nc; //if c != 0 { na ^ nb } else { 0 }; + self.nr = nd; + } else { + self.np = if self.m32 { nc } else { nd }; // if (c != 0) || (d != 0) { na ^ nb } else { 0 } + self.nr = false; + } + if self.m32 { + // mulw, divu_w, remu_w, div_w, rem_w + range_a1 = if sa { + if self.na { + 2 + } else { + 1 + } + } else if self.div && !rem { + if self.sext { + 2 + } else { + 1 + } + } else { + 0 + }; + range_b1 = if sb { + if self.nb { + 2 + } else { + 1 + } + } else { + 0 + }; + // m32 && div == 0 => mulw + range_c1 = if !self.div { + if self.sext { + 2 + } else { + 1 + } + } else if sa { + if self.np { + 2 + } else { + 1 + } + } else { + 0 + }; + range_d1 = if rem { + if self.sext { + 2 + } else { + 1 + } + } else if sa { + if self.nr { + 2 + } else { + 1 + } + } else { + 0 + }; + } else { + // mulu, muluh, mulsuh, mul, mulh, div, rem, divu, remu + if sa { + // mulsuh, mul, mulh, div, rem + range_a3 = if self.na { 2 } else { 1 }; + if self.div { + // div, rem + range_c3 = if self.np { 2 } else { 1 }; + range_d3 = if self.nr { 2 } else { 1 } + } else { + range_d3 = if self.np { 2 } else { 1 } + } + } + // sb => mul, mulh, div, rem + range_b3 = if sb { + if self.nb { + 2 + } else { + 1 + } + } else { + 0 + }; + } + + // range_ab / range_cd + // + // a3 a1 b3 b1 + // rid c3 c1 d3 d1 range 2^16 2^15 notes + // --- -- -- -- -- ----- ---- ---- ------------------------- + // 0 F F F F ab cd 4 0 + // 1 F F + F cd 3 1 b3 sign => a3 sign + // 2 F F - F cd 3 1 b3 sign => a3 sign + // 3 + F F F ab 3 1 c3 sign => d3 sign + // 4 + F + F ab cd 2 2 + // 5 + F - F ab cd 2 2 + // 6 - F F F ab 3 1 c3 sign => d3 sign + // 7 - F + F ab cd 2 2 + // 8 - F - F ab cd 2 2 + // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign + // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign + // 11 F + F F cd 3 1 a1 sign <=> b1 sign + // 12 F + F + ab cd 2 2 + // 13 F + F - ab cd 2 2 + // 14 F - F F cd 3 1 a1 sign <=> b1 sign + // 15 F - F + ab cd 2 2 + // 16 F - F - ab cd 2 2 + + assert!(range_a1 == 0 || range_a3 == 0, "range_a1:{} range_a3:{}", range_a1, range_a3); + assert!(range_b1 == 0 || range_b3 == 0, "range_b1:{} range_b3:{}", range_b1, range_b3); + assert!(range_c1 == 0 || range_c3 == 0, "range_c1:{} range_c3:{}", range_c1, range_c3); + assert!(range_d1 == 0 || range_d3 == 0, "range_d1:{} range_d3:{}", range_d1, range_d3); + + self.range_ab = (range_a3 + range_a1) * 3 + + range_b3 + + range_b1 + + if (range_a1 + range_b1) > 0 { 8 } else { 0 }; + + self.range_cd = (range_c3 + range_c1) * 3 + + range_d3 + + range_d1 + + if (range_c1 + range_d1) > 0 { 8 } else { 0 }; + } + + pub fn calculate_chunks(&self) -> [i64; 8] { + // TODO: unroll this function in variants (div,m32) and (na,nb,nr,np) + // div, m32, na, nb === f(div,m32,na,nb) => fa, nb, nr + // unroll means 16 variants ==> but more performance + + let mut chunks: [i64; 8] = [0, 0, 0, 0, 0, 0, 0, 0]; + + let fab = if self.na != self.nb { -1 } else { 1 }; + + let a = [self.a[0] as i64, self.a[1] as i64, self.a[2] as i64, self.a[3] as i64]; + let b = [self.b[0] as i64, self.b[1] as i64, self.b[2] as i64, self.b[3] as i64]; + let c = [self.c[0] as i64, self.c[1] as i64, self.c[2] as i64, self.c[3] as i64]; + let d = [self.d[0] as i64, self.d[1] as i64, self.d[2] as i64, self.d[3] as i64]; + + let na = self.na as i64; + let nb = self.nb as i64; + let np = self.np as i64; + let nr = self.nr as i64; + let m32 = self.m32 as i64; + let div = self.div as i64; + + let na_fb = na * (1 - 2 * nb); + let nb_fa = nb * (1 - 2 * na); + + chunks[0] = fab * a[0] * b[0] // chunk0 + - c[0] + + 2 * np * c[0] + + div * d[0] + - 2 * nr * d[0]; + + chunks[1] = fab * a[1] * b[0] // chunk1 + + fab * a[0] * b[1] + - c[1] + + 2 * np * c[1] + + div * d[1] + - 2 * nr * d[1]; + + chunks[2] = fab * a[2] * b[0] // chunk2 + + fab * a[1] * b[1] + + fab * a[0] * b[2] + + a[0] * nb_fa * m32 + + b[0] * na_fb * m32 + - c[2] + + 2 * np * c[2] + + div * d[2] + - 2 * nr * d[2] + - np * div * m32 + + nr * m32; // div == 0 ==> nr = 0 + + chunks[3] = fab * a[3] * b[0] // chunk3 + + fab * a[2] * b[1] + + fab * a[1] * b[2] + + fab * a[0] * b[3] + + a[1] * nb_fa * m32 + + b[1] * na_fb * m32 + - c[3] + + 2 * np * c[3] + + div * d[3] + - 2 * nr * d[3]; + + chunks[4] = fab * a[3] * b[1] // chunk4 + + fab * a[2] * b[2] + + fab * a[1] * b[3] + + na * nb * m32 + // + b[0] * na * (1 - 2 * nb) + // + a[0] * nb * (1 - 2 * na) + + b[0] * na_fb * (1 - m32) + + a[0] * nb_fa * (1 - m32) + // high bits ^^^ + // - np * div + // + np * div * m32 + // - 2 * div * m32 * np + - np * m32 * (1 - div) // + - np * (1 - m32) * div // 2^64 (np) + + nr * (1 - m32) // 2^64 (nr) + // high part d + - d[0] * (1 - div) // m32 == 1 and div == 0 => d = 0 + + 2 * np * d[0] * (1 - div); // + + chunks[5] = fab * a[3] * b[2] // chunk5 + + fab * a[2] * b[3] + + a[1] * nb_fa * (1 - m32) + + b[1] * na_fb * (1 - m32) + - d[1] * (1 - div) + + d[1] * 2 * np * (1 - div); + + chunks[6] = fab as i64 * a[3] * b[3] // chunk6 + + a[2] * nb_fa * (1 - m32) + + b[2] * na_fb * (1 - m32) + - d[2] * (1 - div) + + d[2] * 2 * np * (1 - div); + + // 0x4000_0000_0000_0000__8000_0000_0000_0000 + chunks[7] = 0x10000 * na * nb * (1 - m32) // chunk7 + + a[3] * nb_fa * (1 - m32) + + b[3] * na_fb * (1 - m32) + - 0x10000 * np * (1 - div) * (1 - m32) + - d[3] * (1 - div) + + d[3] * 2 * np * (1 - div); + + chunks + } + fn u64_to_chunks(a: u64) -> [u64; 4] { + [a & 0xFFFF, (a >> 16) & 0xFFFF, (a >> 32) & 0xFFFF, (a >> 48) & 0xFFFF] + } +} diff --git a/state-machines/arith/src/arith_operation_test.rs b/state-machines/arith/src/arith_operation_test.rs new file mode 100644 index 00000000..c76d56f1 --- /dev/null +++ b/state-machines/arith/src/arith_operation_test.rs @@ -0,0 +1,1115 @@ +use zisk_core::zisk_ops::*; + +use crate::{arith_constants::*, ArithOperation}; + +const FLAG_NAMES: [&str; 8] = ["m32", "div", "na", "nb", "np", "nr", "sext", "sec"]; + +struct TestParams { + op: u8, + a: u64, + b: u64, + flags: u64, + range_ab: u64, + range_cd: u64, +} +// NOTE: update TEST_COUNT with number of tests, ALL,ALL => 3*3 = 9 +const TEST_COUNT: u32 = 2510; + +const F_M32: u64 = 0x0001; +const F_DIV: u64 = 0x0002; +const F_NA: u64 = 0x0004; +const F_NB: u64 = 0x0008; +const F_NP: u64 = 0x0010; +const F_NR: u64 = 0x0020; +const F_SEXT: u64 = 0x0040; +const F_SEC: u64 = 0x0080; + +// range_ab / range_cd +// +// a3 a1 b3 b1 +// rid c3 c1 d3 d1 range 2^16 2^15 notes +// --- -- -- -- -- ----- ---- ---- ------------------------- + +const R_FF: u64 = 0; // 0 F F F F ab cd 4 0 +const R_3FP: u64 = 1; // 1 F F + F cd 3 1 b3 sign => a3 sign +const R_3FN: u64 = 2; // 2 F F - F cd 3 1 b3 sign => a3 sign +const R_3PF: u64 = 3; // 3 + F F F ab 3 1 c3 sign => d3 sign +const R_3PP: u64 = 4; // 4 + F + F ab cd 2 2 +const R_3PN: u64 = 5; // 5 + F - F ab cd 2 2 +const R_3NF: u64 = 6; // 6 - F F F ab 3 1 c3 sign => d3 sign +const R_3NP: u64 = 7; // 7 - F + F ab cd 2 2 +const R_3NN: u64 = 8; // 8 - F - F ab cd 2 2 +const R_1FP: u64 = 9; // 9 F F F + cd a1 sign <=> b1 sign / d1 sign => c1 sign +const R_1FN: u64 = 10; // 10 F F F - cd a1 sign <=> b1 sign / d1 sign => c1 sign +const R_1PF: u64 = 11; // 11 F + F F cd 3 1 a1 sign <=> b1 sign +const R_1PP: u64 = 12; // 12 F + F + ab cd 2 2 +const R_1PN: u64 = 13; // 13 F + F - ab cd 2 2 +const R_1NF: u64 = 14; // 14 F - F F cd 3 1 a1 sign <=> b1 sign +const R_1NP: u64 = 15; // 15 F - F + ab cd 2 2 +const R_1NN: u64 = 16; // 16 F - F - ab cd 2 2 + +const MIN_N_64: u64 = 0x8000_0000_0000_0000; +const MIN_N_32: u64 = 0x0000_0000_8000_0000; +const MAX_P_64: u64 = 0x7FFF_FFFF_FFFF_FFFF; +const MAX_P_32: u64 = 0x0000_0000_7FFF_FFFF; +const MAX_32: u64 = 0x0000_0000_FFFF_FFFF; +const MAX_64: u64 = 0xFFFF_FFFF_FFFF_FFFF; + +// value cannot used as specific cases +const ALL_64: u64 = 0x0033; +const ALL_NZ_64: u64 = 0x0034; +const ALL_P_64: u64 = 0x0035; +const ALL_NZ_P_64: u64 = 0x0036; +const ALL_N_64: u64 = 0x0037; + +const ALL_32: u64 = 0x0043; +const ALL_NZ_32: u64 = 0x0044; +const ALL_P_32: u64 = 0x0045; +const ALL_N_32: u64 = 0x0046; +const ALL_NZ_P_32: u64 = 0x0047; + +const VALUES_END: u64 = 0x004D; + +struct ArithOperationTest { + count: u32, + ok: u32, + fail: u32, + fail_range_check: u32, + fail_table: u32, + fail_bus: u32, + fail_by_op: [u32; 16], +} + +impl ArithOperationTest { + // NOTE: use 0x0000_0000 instead of 0, to avoid auto-format in one line, 0 is too short. + pub fn new() -> Self { + ArithOperationTest { + count: 0, + ok: 0, + fail: 0, + fail_range_check: 0, + fail_table: 0, + fail_bus: 0, + fail_by_op: [0; 16], + } + } + fn test(&mut self) { + let mut count = 0; + let mut index: u32 = 0; + + #[derive(Debug, PartialEq)] + struct TestDone { + op: u8, + a: u64, + b: u64, + index: u32, + offset: u32, + } + + let tests = [ + // 0 - MULU + TestParams { + op: MULU, + a: ALL_64, + b: ALL_64, + flags: 0x0000, + range_ab: R_FF, + range_cd: R_FF, + }, + // 1 - MULUH + TestParams { + op: MULUH, + a: ALL_64, + b: ALL_64, + flags: F_SEC, + range_ab: R_FF, + range_cd: R_FF, + }, + // 2 - MULSUH + TestParams { + op: MULSUH, + a: ALL_P_64, + b: ALL_64, + flags: F_SEC, + range_ab: R_3PF, + range_cd: R_3FP, + }, + // 3 - MULSUH + TestParams { + op: MULSUH, + a: ALL_N_64, + b: ALL_NZ_64, + flags: F_NA + F_NP + F_SEC, + range_ab: R_3NF, + range_cd: R_3FN, + }, + // 4 - MULSUH + TestParams { + op: MULSUH, + a: ALL_N_64, + b: 0x0000_0000, + flags: F_NA + F_SEC, + range_ab: R_3NF, + range_cd: R_3FP, + }, + // 5 - MUL + TestParams { + op: MUL, + a: ALL_P_64, + b: ALL_P_64, + flags: 0, + range_ab: R_3PP, + range_cd: R_3FP, + }, + // 6 - MUL + TestParams { + op: MUL, + a: ALL_N_64, + b: ALL_N_64, + flags: F_NA + F_NB, + range_ab: R_3NN, + range_cd: R_3FP, + }, + // 7 - MUL + TestParams { + op: MUL, + a: ALL_N_64, + b: ALL_NZ_P_64, + flags: F_NA + F_NP, + range_ab: R_3NP, + range_cd: R_3FN, + }, + // 8 - MUL + TestParams { + op: MUL, + a: ALL_N_64, + b: 0x0000_0000, + flags: F_NA, + range_ab: R_3NP, + range_cd: R_3FP, + }, + // 9 - MUL + TestParams { + op: MUL, + a: ALL_NZ_P_64, + b: ALL_N_64, + flags: F_NB + F_NP, + range_ab: R_3PN, + range_cd: R_3FN, + }, + // 10 - MUL + TestParams { + op: MUL, + a: 0x0000_0000, + b: ALL_N_64, + flags: F_NB, + range_ab: R_3PN, + range_cd: R_3FP, + }, + // 11 - MULH + TestParams { + op: MULH, + a: ALL_P_64, + b: ALL_P_64, + flags: F_SEC, + range_ab: R_3PP, + range_cd: R_3FP, + }, + // 12 - MULH + TestParams { + op: MULH, + a: ALL_N_64, + b: ALL_N_64, + flags: F_NA + F_NB + F_SEC, + range_ab: R_3NN, + range_cd: R_3FP, + }, + // 13 - MULH + TestParams { + op: MULH, + a: ALL_N_64, + b: ALL_NZ_P_64, + flags: F_NA + F_NP + F_SEC, + range_ab: R_3NP, + range_cd: R_3FN, + }, + // 14 - MULH + TestParams { + op: MULH, + a: ALL_N_64, + b: 0x0000_00000, + flags: F_NA + F_SEC, + range_ab: R_3NP, + range_cd: R_3FP, + }, + // 15 - MULH + TestParams { + op: MULH, + a: ALL_NZ_P_64, + b: ALL_N_64, + flags: F_NB + F_NP + F_SEC, + range_ab: R_3PN, + range_cd: R_3FN, + }, + // 16 - MULH + TestParams { + op: MULH, + a: 0x0000_0000, + b: ALL_N_64, + flags: F_NB + F_SEC, + range_ab: R_3PN, + range_cd: R_3FP, + }, + // 17 - MUL_W + TestParams { + op: MUL_W, + a: 0x0000_0000, + b: 0x0000_0000, + flags: F_M32, + range_ab: R_FF, + range_cd: R_1PF, + }, + // 18 - MUL_W: 0x00000002 (+/32 bits) * 0x40000000 (+/32 bits) = 0x80000000 (-/32 bits) + TestParams { + op: MUL_W, + a: 0x0000_0002, + b: 0x4000_0000, + flags: F_M32 + F_SEXT, + range_ab: R_FF, + range_cd: R_1NF, + }, + // 19 - MUL_W + TestParams { + op: MUL_W, + a: 0x0000_0002, + b: 0x8000_0000, + flags: F_M32, + range_ab: R_FF, + range_cd: R_1PF, + }, + // 20 - MUL_W + TestParams { + op: MUL_W, + a: 0xFFFF_FFFF, + b: 1, + flags: F_M32 + F_SEXT, + range_ab: R_FF, + range_cd: R_1NF, + }, + // 21 - MUL_W + TestParams { + op: MUL_W, + a: 0xFFFF_FFFF, + b: 0x0000_00000, + flags: F_M32, + range_ab: R_FF, + range_cd: R_1PF, + }, + // 22 - MUL_W + TestParams { + op: MUL_W, + a: 0x7FFF_FFFF, + b: 2, + flags: F_M32 + F_SEXT, + range_ab: R_FF, + range_cd: R_1NF, + }, + // 23 - MUL_W + TestParams { + op: MUL_W, + a: 0xBFFF_FFFF, + b: 0x0000_0002, + flags: F_M32, + range_ab: R_FF, + range_cd: R_1PF, + }, + // 24 - MUL_W: 0xFFFF_FFFF * 0xFFFF_FFFF = 0xFFFF_FFFE_0000_0001 + TestParams { + op: MUL_W, + a: 0xFFFF_FFFF, + b: 0xFFFF_FFFF, + flags: F_M32, + range_ab: R_FF, + range_cd: R_1PF, + }, + // 25 - MUL_W: 0xFFFF_FFFF * 0x0FFF_FFFF = 0x0FFF_FFFE_F000_0001 + TestParams { + op: MUL_W, + a: 0xFFFF_FFFF, + b: 0x0FFF_FFFF, + flags: F_M32 + F_SEXT, + range_ab: R_FF, + range_cd: R_1NF, + }, + // 26 - MUL_W: 0x8000_0000 * 0x8000_0000 = 0x4000_0000_0000_0000 + TestParams { + op: MUL_W, + a: 0x8000_0000, + b: 0x8000_0000, + flags: F_M32, + range_ab: R_FF, + range_cd: R_1PF, + }, + // 27 - DIVU + TestParams { + op: DIVU, + a: ALL_64, + b: ALL_NZ_64, + flags: F_DIV + 0, + range_ab: R_FF, + range_cd: R_FF, + }, + // 28 - REMU + TestParams { + op: REMU, + a: ALL_64, + b: ALL_NZ_64, + flags: F_DIV + F_SEC, + range_ab: R_FF, + range_cd: R_FF, + }, + // 29 - DIV + TestParams { + op: DIV, + a: MAX_P_64, + b: MAX_P_64, + flags: F_DIV, + range_ab: R_3PP, + range_cd: R_3PP, + }, + // 30 - DIV + TestParams { + op: DIV, + a: MIN_N_64, + b: MAX_P_64, + flags: F_DIV + F_NA + F_NP + F_NR, + range_ab: R_3NP, + range_cd: R_3NN, + }, + // 31 - DIV + TestParams { + op: DIV, + a: MAX_P_64, + b: MIN_N_64, + flags: F_DIV + F_NB, // a/b = 0 ➜ np = 0 + range_ab: R_3PN, + range_cd: R_3PP, + }, + // 32 - DIV + TestParams { + op: DIV, + a: MIN_N_64, + b: MIN_N_64, + flags: F_DIV + F_NB + F_NP, // a/b = 1 ➜ 1 * b_neg ➜ np = 1 + range_ab: R_3PN, + range_cd: R_3NP, + }, + // 33 - DIV + TestParams { + op: DIV, + a: 0x0000_0000, + b: MAX_P_64, + flags: F_DIV, + range_ab: R_3PP, + range_cd: R_3PP, + }, + // 34 - DIV + TestParams { + op: DIV, + a: 0x0000_0000, + b: MIN_N_64, + flags: F_DIV + F_NB, + range_ab: R_3PN, + range_cd: R_3PP, + }, + // 35 - REM + TestParams { + op: REM, + a: MAX_P_64, + b: MAX_P_64, + flags: F_DIV + F_SEC, + range_ab: R_3PP, + range_cd: R_3PP, + }, + // 36 - REM + TestParams { + op: REM, + a: MIN_N_64, + b: MAX_P_64, + flags: F_DIV + F_NA + F_NP + F_NR + F_SEC, + range_ab: R_3NP, + range_cd: R_3NN, + }, + // 37 - REM + TestParams { + op: REM, + a: MAX_P_64, + b: MIN_N_64, + flags: F_DIV + F_NB + F_SEC, + range_ab: R_3PN, + range_cd: R_3PP, + }, + // 38 - REM + TestParams { + op: REM, + a: MIN_N_64, + b: MIN_N_64, + flags: F_DIV + F_NB + F_NP + F_SEC, + range_ab: R_3PN, + range_cd: R_3NP, + }, + // 39 - REM + TestParams { + op: REM, + a: 0x0000_0000, + b: MAX_P_64, + flags: F_DIV + F_SEC, + range_ab: R_3PP, + range_cd: R_3PP, + }, + // 40 - REM + TestParams { + op: REM, + a: 0x0000_0000, + b: MIN_N_64, + flags: F_DIV + F_NB + F_SEC, + range_ab: R_3PN, + range_cd: R_3PP, + }, + // 41 - DIVU_W + TestParams { + op: DIVU_W, + a: 0xFFFF_FFFF, + b: 0x0000_0001, + flags: F_DIV + F_M32 + F_SEXT, + range_ab: R_1NF, + range_cd: R_FF, + }, + // 42 - DIVU_W + TestParams { + op: DIVU_W, + a: ALL_NZ_32, + b: 0x0000_00002, + flags: F_DIV + F_M32, + range_ab: R_1PF, + range_cd: R_FF, + }, + // 43 - DIVU_W + TestParams { + op: DIVU_W, + a: ALL_NZ_32, + b: MAX_32, + flags: F_DIV + F_M32, + range_ab: R_1PF, + range_cd: R_FF, + }, + // 44 - DIVU_W + TestParams { + op: DIVU_W, + a: 0, + b: ALL_NZ_32, + flags: F_DIV + F_M32, + range_ab: R_1PF, + range_cd: R_FF, + }, + // 45 - REMU_W + TestParams { + op: REMU_W, + a: 0xFFFF_FFFF, + b: 0x0000_0001, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 46 - REMU_W + TestParams { + op: REMU_W, + a: ALL_32, + b: 0x0000_00002, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 47 - REMU_W + TestParams { + op: REMU_W, + a: ALL_NZ_P_32, + b: MAX_32, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 48 - REMU_W + TestParams { + op: REMU_W, + a: ALL_32, + b: 0x8000_0000, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 49 - REMU_W + TestParams { + op: REMU_W, + a: 0, + b: ALL_NZ_32, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 50 - REMU_W + TestParams { + op: REMU_W, + a: 0xFFFF_FFFE, + b: 0xFFFF_FFFF, + flags: F_DIV + F_M32 + F_SEXT + F_SEC, + range_ab: R_FF, + range_cd: R_1FN, + }, + // 51 - REMU_W + TestParams { + op: REMU_W, + a: 0xFFFF_FFFE, + b: 0xFFFF_FFFE, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 52 - REMU_W + TestParams { + op: REMU_W, + a: 0x8000_0000, + b: 0x8000_0001, + flags: F_DIV + F_M32 + F_SEXT + F_SEC, + range_ab: R_FF, + range_cd: R_1FN, + }, + // 53 - REMU_W + TestParams { + op: REMU_W, + a: 0x8000_0001, + b: 0x8000_0000, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 54 - REMU_W + TestParams { + op: REMU_W, + a: 0xFFFF_FFFF, + b: 0x0000_0003, + flags: F_DIV + F_M32 + F_SEC, + range_ab: R_FF, + range_cd: R_1FP, + }, + // 55 - DIV_W (-1/1=-1 REM:0) + TestParams { + op: DIV_W, + a: 0xFFFF_FFFF, + b: 0x0000_0001, + flags: F_DIV + F_NA + F_NP + F_M32 + F_SEXT, + range_ab: R_1NP, + range_cd: R_1NP, + }, + // 56 - REM_W !!! + TestParams { + op: REM_W, + a: 0xFFFF_FFFF, + b: 0x0000_0001, + flags: F_DIV + F_NA + F_NP + F_M32 + F_SEC, + range_ab: R_1NP, + range_cd: R_1NP, + }, + // 57 - DIV_W <====== + TestParams { + op: DIV_W, + a: 0xFFFF_FFFF, + b: 0x0000_0002, + flags: F_DIV + F_NP + F_NR + F_M32, + range_ab: R_1PP, + range_cd: R_1NN, + }, + // 58 - REM_W + TestParams { + op: REM_W, + a: 0xFFFF_FFFF, + b: 0x0000_0002, + flags: F_DIV + F_NP + F_NR + F_M32 + F_SEC + F_SEXT, + range_ab: R_1PP, + range_cd: R_1NN, + }, + ]; + + let mut tests_done: Vec = Vec::new(); + let mut errors = 0; + for test in tests { + let a_values = Self::get_test_values(test.a); + let mut offset = 0; + for _a in a_values { + if _a == VALUES_END { + break; + } + let b_values = Self::get_test_values(test.b); + for _b in b_values { + if _b == VALUES_END { + break; + } + let test_info = TestDone { op: test.op, a: _a, b: _b, index, offset }; + let previous = tests_done.iter().find(|&x| { + x.op == test_info.op && x.a == test_info.a && x.b == test_info.b + }); + match previous { + Some(e) => { + println!( + "\x1B[35mDuplicated TEST #{} op:0x{:x} a:0x{:X} b:0x{:X} offset:{}\x1B[0m", + e.index, e.op, e.a, e.b, e.offset + ); + } + None => { + tests_done.push(test_info); + } + } + println!( + "testing #{} op:0x{:x} with _a:0x{:X} _b:0x{:X}", + index, test.op, _a, _b + ); + let (emu_c, emu_flag) = Self::calculate_emulator_res(test.op, _a, _b); + self.test_operation( + test.op, + _a, + _b, + emu_c, + emu_flag, + test.range_ab, + test.range_cd, + test.flags, + ); + offset += 1; + count += 1; + } + } + index += 1; + } + println!("TOTAL ERRORS: {}", self.fail); + assert_eq!(count, TEST_COUNT, "Number of tests not matching"); + } + + fn calculate_emulator_res(op: u8, a: u64, b: u64) -> (u64, bool) { + match op { + MULU => return op_mulu(a, b), + MULUH => return op_muluh(a, b), + MULSUH => return op_mulsuh(a, b), + MUL => return op_mul(a, b), + MULH => return op_mulh(a, b), + MUL_W => return op_mul_w(a, b), + DIVU => return op_divu(a, b), + REMU => return op_remu(a, b), + DIVU_W => return op_divu_w(a, b), + REMU_W => return op_remu_w(a, b), + DIV => return op_div(a, b), + REM => return op_rem(a, b), + DIV_W => return op_div_w(a, b), + REM_W => return op_rem_w(a, b), + _ => { + panic!("Invalid opcode"); + } + } + } + + fn decode_one_range(range_xy: u64) -> [u64; 4] { + if range_xy == 9 { + [0, 0, 0, 0] + } else if range_xy > 9 { + let x = (range_xy - 8) / 3; + let y = (range_xy - 8) % 3; + [0, 0, x, y] + } else { + let x = range_xy / 3; + let y = range_xy % 3; + [x, y, 0, 0] + } + } + fn decode_ranges(range_ab: u64, range_cd: u64) -> [u64; 8] { + let ab = Self::decode_one_range(range_ab); + let cd = Self::decode_one_range(range_cd); + [ab[0], ab[1], cd[0], cd[1], ab[2], ab[3], cd[2], cd[3]] + } + fn dump_test( + &mut self, + index: u32, + op: u8, + a: u64, + b: u64, + c: u64, + flag: bool, + range_ab: u64, + range_cd: u64, + flags: u64, + aop: &ArithOperation, + ) { + println!("{:#?}", aop); + } + fn test_operation( + &mut self, + op: u8, + a: u64, + b: u64, + c: u64, + flag: bool, + range_ab: u64, + range_cd: u64, + flags: u64, + ) { + let mut aop = ArithOperation::new(); + aop.calculate(op, a, b); + let chunks = aop.calculate_chunks(); + for i in 0..8 { + let carry_in = if i > 0 { aop.carry[i - 1] } else { 0 }; + let carry_out = if i < 7 { aop.carry[i] } else { 0 }; + let res = chunks[i] + carry_in - 0x10000 * carry_out; + if res != 0 { + println!("{:#?}", aop); + + self.fail += 1; + self.fail_by_op[(op - 0xb0) as usize] += 1; + println!("\x1B[31mFAIL: 0x{4:X}({4})!= 0 chunks[{0}]=0x{1:X}({1}) carry_in: 0x{2:x},{2} carry_out: 0x{3:x},{3} failed\x1B[0m", + i, + chunks[i], + carry_in, + carry_out, + res); + } + } + // println!( + // "CARRY 0x{0:X}({0}),0x{1:X}({1}),0x{2:X}({2}),0x{3:X}({3}),0x{4:X}({4}),0x{5:X}({5}),0x{6:X}{6},0x{7:X}({7}) fab:{8:X}", + // carrys[0], carrys[1], carrys[2], carrys[3], carrys[4], carrys[5], carrys[6], carrys[7], fab + // ); + + const CHUNK_SIZE: u64 = 0x10000; + let bus_a_low: u64 = aop.div as u64 * (aop.c[0] + aop.c[1] * CHUNK_SIZE) + + (1 - aop.div as u64) * (aop.a[0] + aop.a[1] * CHUNK_SIZE); + let bus_a_high: u64 = aop.div as u64 * (aop.c[2] + aop.c[3] * CHUNK_SIZE) + + (1 - aop.div as u64) * (aop.a[2] + aop.a[3] * CHUNK_SIZE); + + let bus_b_low: u64 = aop.b[0] + CHUNK_SIZE * aop.b[1]; + let bus_b_high: u64 = aop.b[2] + CHUNK_SIZE * aop.b[3]; + + let res2_low: u64 = aop.d[0] + CHUNK_SIZE * aop.d[1]; + let res2_high: u64 = aop.d[2] + CHUNK_SIZE * aop.d[3]; + + let secondary_res: u64 = if aop.main_mul || aop.main_div { 0 } else { 1 }; + /* let bus_res_low: u64 = secondary_res * res2_low + + (1 - secondary_res) + * (aop.a[0] + aop.c[0] + CHUNK_SIZE * (aop.a[1] + aop.c[1]) - bus_a_low); + + let bus_res_high: u64 = (1 - aop.m32 as u64) + * (secondary_res * res2_high + + (1 - secondary_res) + * ((aop.a[2] + aop.c[2] + CHUNK_SIZE * (aop.a[3] + aop.c[3])) - bus_a_high)) + + aop.sext as u64 * 0xFFFFFFFF;*/ + + let bus_res_low = secondary_res * (aop.d[0] + aop.d[1] * CHUNK_SIZE) + + aop.main_mul as u64 * (aop.c[0] + aop.c[1] * CHUNK_SIZE) + + aop.main_div as u64 * (aop.a[0] + aop.a[1] * CHUNK_SIZE); + + let bus_res_high_64 = secondary_res * (aop.d[2] + aop.d[3] * CHUNK_SIZE) + + aop.main_mul as u64 * (aop.c[2] + aop.c[3] * CHUNK_SIZE) + + aop.main_div as u64 * (aop.a[2] + aop.a[3] * CHUNK_SIZE); + + let bus_res_high = aop.sext as u64 * 0xFFFF_FFFF + (1 - aop.m32 as u64) * bus_res_high_64; + + let expected_a_low = a & 0xFFFF_FFFF; + let expected_a_high = (a >> 32) & 0xFFFF_FFFF; + let expected_b_low = b & 0xFFFF_FFFF; + let expected_b_high = (b >> 32) & 0xFFFF_FFFF; + let expected_res_low = c & 0xFFFF_FFFF; + let expected_res_high = (c >> 32) & 0xFFFF_FFFF; + + assert_eq!( + bus_a_low, expected_a_low, + "bus_a_low: 0x{0:X}({0}) vs 0x{1:X}({1}) (expected)", + bus_a_low, expected_a_low + ); + assert_eq!( + bus_a_high, expected_a_high, + "bus_a_high: 0x{0:X}({0}) vs 0x{1:X}({1}) (expected)", + bus_a_high, expected_a_high + ); + assert_eq!( + bus_b_low, expected_b_low, + "bus_b_low: 0x{0:X}({0}) vs 0x{1:X}({1}) (expected)", + bus_b_low, expected_b_low + ); + assert_eq!( + bus_b_high, expected_b_high, + "bus_b_high: 0x{0:X}({0}) vs 0x{1:X}({1}) (expected)", + bus_b_high, expected_b_high + ); + assert_eq!( + bus_res_low, expected_res_low, + "bus_c_low: 0x{0:X}({0}) vs 0x{1:X}({1}) (expected)", + bus_res_low, expected_res_low + ); + assert_eq!( + bus_res_high, expected_res_high, + "bus_c_high: 0x{0:X}({0}) vs 0x{1:X}({1}) (expected)", + bus_res_high, expected_res_high + ); + // check all chunks and carries + let carry_min_value: i64 = -0xEFFFF; + let carry_max_value: i64 = 0xF0000; + for i in 0..7 { + assert!(aop.carry[i] >= carry_min_value); + assert!(aop.carry[i] <= carry_max_value); + } + + let ranges = Self::decode_ranges(range_ab, range_cd); + Self::check_range(0, aop.a[0]); + Self::check_range(0, aop.b[0]); + Self::check_range(0, aop.c[0]); + Self::check_range(0, aop.d[0]); + + Self::check_range(ranges[4], aop.a[1]); + Self::check_range(ranges[5], aop.b[1]); + Self::check_range(ranges[6], aop.c[1]); + Self::check_range(ranges[7], aop.d[1]); + + Self::check_range(0, aop.a[2]); + Self::check_range(0, aop.b[2]); + Self::check_range(0, aop.c[2]); + Self::check_range(0, aop.d[2]); + + Self::check_range(ranges[0], aop.a[3]); + Self::check_range(ranges[1], aop.b[3]); + Self::check_range(ranges[2], aop.c[3]); + Self::check_range(ranges[3], aop.d[3]); + } + fn print_chunks(label: &str, chunks: [u64; 4]) { + println!( + "{0}: 0x{1:>04X} \x1B[32m{1:>5}\x1B[0m|0x{2:>04X} \x1B[32m{2:>5}\x1B[0m|0x{3::>04X} \x1B[32m{3:>5}\x1B[0m|0x{4:>04X} \x1B[32m{4:>5}\x1B[0m|", + label, chunks[0], chunks[1], chunks[2], chunks[3] + ); + } + fn check_range(range_id: u64, value: u64) { + assert!(range_id != 0 || (value >= 0 && value <= 0xFFFF)); + assert!(range_id != 1 || (value >= 0 && value <= 0x7FFF)); + assert!(range_id != 2 || (value >= 0x8000 && value <= 0xFFFF)); + } + + fn flags_to_strings(mut flags: u64, flag_names: &[&str]) -> String { + let mut res = String::new(); + + for flag_name in flag_names { + if (flags & 1u64) != 0 { + if !res.is_empty() { + res = res + ","; + } + res = res + *flag_name; + } + flags >>= 1; + if flags == 0 { + break; + }; + } + res + } + + fn get_test_values(value: u64) -> [u64; 16] { + match value { + ALL_64 => [ + 0, + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + MAX_32 + 1, + MAX_P_64 - 1, + MAX_P_64, + MAX_64 - 1, + MIN_N_64, + MIN_N_64 + 1, + MAX_64, + ], + ALL_NZ_64 => [ + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + MAX_32 + 1, + MAX_P_64 - 1, + MAX_P_64, + MAX_64 - 1, + MIN_N_64, + MIN_N_64 + 1, + MAX_64, + VALUES_END, + ], + ALL_P_64 => [ + 0, + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + MAX_32 + 1, + MAX_P_64 - 1, + MAX_P_64, + VALUES_END, + 0, + 0, + 0, + ], + ALL_NZ_P_64 => [ + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + MAX_32 + 1, + MAX_P_64 - 1, + MAX_P_64, + VALUES_END, + 0, + 0, + 0, + 0, + ], + ALL_N_64 => [ + MIN_N_64, + MIN_N_64 + 1, + MIN_N_64 + 2, + MIN_N_64 + 3, + 0x8000_0000_7FFF_FFFF, + 0x8FFF_FFFF_7FFF_FFFF, + 0xEFFF_FFFF_FFFF_FFFF, + MAX_64 - 3, + MAX_64 - 2, + MAX_64 - 1, + MAX_64, + VALUES_END, + 0, + 0, + 0, + 0, + ], + ALL_32 => [ + 0, + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + VALUES_END, + 0, + 0, + 0, + 0, + 0, + 0, + ], + ALL_NZ_32 => [ + 1, + 2, + 3, + MAX_P_32 - 1, + MAX_P_32, + MIN_N_32, + MAX_32 - 1, + MAX_32, + VALUES_END, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + ALL_P_32 => [ + 0, + 1, + 2, + 3, + 0x0000_7FFF, + 0x0000_FFFF, + MAX_P_32 - 1, + MAX_P_32, + MAX_P_32 - 1, + MAX_P_32, + VALUES_END, + 0, + 0, + 0, + 0, + 0, + ], + ALL_NZ_P_32 => [ + 1, + 2, + 3, + 0x0000_7FFF, + 0x0000_FFFF, + MAX_P_32 - 1, + MAX_P_32, + MAX_P_32 - 1, + MAX_P_32, + VALUES_END, + 0, + 0, + 0, + 0, + 0, + 0, + ], + ALL_N_32 => [ + MIN_N_32, + MIN_N_32 + 1, + MIN_N_32 + 2, + MIN_N_32 + 3, + MAX_32 - 3, + MAX_32 - 2, + MAX_32 - 1, + MAX_32, + VALUES_END, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + _ => [value, VALUES_END, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + } + } +} + +#[test] +fn test() { + let mut test = ArithOperationTest::new(); + test.test(); + for i in 0..16 { + if test.fail_by_op[i] == 0 { + continue; + } + println!("fail_by_op[0x{:X}]: {}", i + 0xb0, test.fail_by_op[i]); + } + assert_eq!(test.fail, 0); +} diff --git a/state-machines/arith/src/arith_range_table_helpers.rs b/state-machines/arith/src/arith_range_table_helpers.rs new file mode 100644 index 00000000..be68d777 --- /dev/null +++ b/state-machines/arith/src/arith_range_table_helpers.rs @@ -0,0 +1,133 @@ +use std::ops::Add; + +const ROWS: usize = 1 << 22; +const FULL: u8 = 0x00; +const POS: u8 = 0x01; +const NEG: u8 = 0x02; +pub struct AirthRangeTableHelpers; + +const RANGES: [u8; 43] = [ + FULL, FULL, FULL, POS, POS, POS, NEG, NEG, NEG, FULL, FULL, FULL, FULL, FULL, FULL, FULL, FULL, + FULL, POS, NEG, FULL, POS, NEG, FULL, POS, NEG, FULL, FULL, FULL, FULL, FULL, FULL, FULL, FULL, + FULL, FULL, FULL, POS, POS, POS, NEG, NEG, NEG, +]; +const OFFSETS: [usize; 43] = [ + 0, 2, 4, 50, 51, 52, 59, 60, 61, 6, 8, 10, 12, 14, 16, 18, 20, 22, 53, 62, 24, 54, 63, 26, 55, + 64, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 56, 57, 58, 65, 66, 67, +]; + +impl AirthRangeTableHelpers { + pub fn get_range_name(range_index: u8) -> &'static str { + match range_index { + 0 => "F F F F", + 1 => "F F + F", + 2 => "F F - F", + 3 => "+ F F F", + 4 => "+ F + F", + 5 => "+ F - F", + 6 => "- F F F", + 7 => "- F + F", + 8 => "- F - F", + 9 => "F F F +", + 10 => "F F F -", + 11 => "F + F F", + 12 => "F + F +", + 13 => "F + F -", + 14 => "F - F F", + 15 => "F - F +", + 16 => "F - F -", + _ => panic!("Invalid range index"), + } + } + pub fn get_row_chunk_range_check(range_index: u8, value: i64) -> usize { + // F F F + + + - - - F F F F F F F F F + - F + - F + - F F F F F F F F F F F + + + - - - + let range_type = RANGES[range_index as usize]; + assert!(range_index < 43); + assert!(value >= if range_type == NEG { -0xFFFF } else { 0 }); + assert!( + value + <= match range_type { + FULL => 0xFFFF, + POS => -1, + NEG => 0x7FFF, + _ => panic!("Invalid range type"), + } + ); + OFFSETS[range_index as usize] * 0x8000 + + if range_type == NEG { 0x8000 + value } else { value } as usize + } + pub fn get_row_carry_range_check(value: i64) -> usize { + assert!(value >= -0xEFFFF); + assert!(value <= 0xF0000); + (0x220000 + 0xEFFFF + value) as usize + } +} +struct AirthRangeTableMultiplicity { + multiplicity: [u64; ROWS], +} +impl AirthRangeTableMultiplicity { + fn new() -> Self { + AirthRangeTableMultiplicity { multiplicity: [0; ROWS] } + } + fn use_chunk_range_check(&self, range_id: u8, value: i64) { + let row = AirthRangeTableHelpers::get_row_chunk_range_check(range_id, value); + self.multiplicity[row as usize]; + } + fn use_carry_range_check(&self, value: i64) { + let row = AirthRangeTableHelpers::get_row_carry_range_check(value); + self.multiplicity[row as usize]; + } +} + +impl Add for AirthRangeTableMultiplicity { + type Output = Self; + + fn add(self, other: Self) -> Self { + let mut result = AirthRangeTableMultiplicity::new(); + for i in 0..ROWS { + result.multiplicity[i] = self.multiplicity[i] + other.multiplicity[i]; + } + result + } +} + +#[cfg(generate_code_arith_range_table)] +fn generate_table() { + let pattern = "FFF+++---FFFFFFFFF+-F+-F+-FFFFFFFFFFF+++---"; + // let mut ranges = [0u8; 43]; + let mut ranges = String::new(); + let mut offsets = [0usize; 43]; + let mut offset = 0; + for range_loop in [FULL, POS, NEG] { + let mut index = 0; + for c in pattern.chars() { + if c == ' ' || c == '_' { + continue; + } + let range_id = match c { + 'F' => FULL, + '+' => POS, + '-' => NEG, + _ => panic!("Invalid character in pattern"), + }; + if range_loop == FULL { + if index > 0 { + ranges.push_str(", ") + } + ranges.push_str(match range_id { + FULL => "FULL", + POS => "POS", + _ => "NEG", + }); + // ranges[index] = range_id + } + if range_loop == range_id { + offsets[index] = offset; + offset = offset + if range_loop == FULL { 2 } else { 1 }; + } + index += 1; + } + } + println!("const RANGES: [u8; 43] = [{}];", ranges); + println!("const OFFSETS: [usize; 43] = {:?};", offsets); +} diff --git a/state-machines/arith/src/arith_table_helpers.rs b/state-machines/arith/src/arith_table_helpers.rs new file mode 100644 index 00000000..04a89931 --- /dev/null +++ b/state-machines/arith/src/arith_table_helpers.rs @@ -0,0 +1,71 @@ +use std::ops::Add; + +const ROWS: usize = 95; +const FIRST_OP: u8 = 0xb0; + +struct AirthTableHelpers; + +impl AirthTableHelpers { + fn get_row(op: u8, na: u64, nb: u64, np: u64, nr: u64, sext: u64) -> usize { + static ARITH_TABLE_ROWS: [i16; 512] = [ + 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, 2, 3, -1, -1, -1, 4, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 5, 6, 7, 8, -1, + 9, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, 11, 12, 13, 14, -1, 15, 16, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 17, 18, 19, 20, -1, 21, 22, + -1, -1, -1, -1, -1, -1, -1, -1, -1, 23, 24, 25, 26, -1, 27, 28, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 29, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, 30, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 31, 32, 33, 34, 35, 36, 37, -1, -1, -1, -1, + -1, 38, 39, 40, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 41, + 42, 43, 44, 45, 46, 47, -1, -1, -1, -1, -1, 48, 49, 50, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, 51, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, 52, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 53, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 54, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, 62, 63, 64, + -1, 65, 66, 67, 68, 69, 70, 71, -1, -1, -1, -1, -1, 72, 73, 74, -1, 75, 76, 77, 78, 79, + 80, 81, -1, -1, -1, -1, -1, 82, 83, 84, -1, 85, 86, 87, 88, 89, 90, 91, -1, -1, -1, -1, + -1, 92, 93, 94, -1, + ]; + + let index = (op - FIRST_OP) as u64 * 32 + na + nb * 2 + np * 4 + nr * 8 + sext * 16; + let row = ARITH_TABLE_ROWS[index as usize]; + assert!(row >= 0); + row as usize + } + fn get_max_row() -> usize { + ROWS - 1 + } +} + +struct AirthTableMultiplicity { + multiplicity: [u64; ROWS], +} + +impl AirthTableMultiplicity { + fn new() -> Self { + AirthTableMultiplicity { multiplicity: [0; ROWS] } + } + fn add_use(&self, op: u8, na: u64, nb: u64, np: u64, nr: u64, sext: u64) { + let row = AirthTableHelpers::get_row(op, na, nb, np, nr, sext); + self.multiplicity[row as usize]; + } +} + +impl Add for AirthTableMultiplicity { + type Output = Self; + + fn add(self, other: Self) -> Self { + let mut result = AirthTableMultiplicity::new(); + for i in 0..ROWS { + result.multiplicity[i] = self.multiplicity[i] + other.multiplicity[i]; + } + result + } +} diff --git a/state-machines/arith/src/lib.rs b/state-machines/arith/src/lib.rs index 4e3603bd..fab4cf43 100644 --- a/state-machines/arith/src/lib.rs +++ b/state-machines/arith/src/lib.rs @@ -1,20 +1,28 @@ mod arith; mod arith_32; +mod arith_constants; mod arith_full; -mod arith_helpers; mod arith_mul_32; mod arith_mul_64; +mod arith_operation; +mod arith_operation_test; mod arith_range_table; +mod arith_range_table_helpers; mod arith_range_table_inputs; mod arith_table; +mod arith_table_helpers; mod arith_table_inputs; pub use arith::*; pub use arith_32::*; +pub use arith_constants::*; pub use arith_full::*; pub use arith_mul_32::*; pub use arith_mul_64::*; +pub use arith_operation::*; pub use arith_range_table::*; +pub use arith_range_table_helpers::*; pub use arith_range_table_inputs::*; pub use arith_table::*; +pub use arith_table_helpers::*; pub use arith_table_inputs::*; diff --git a/state-machines/binary/pil/binary_extension.pil b/state-machines/binary/pil/binary_extension.pil index a456a3ac..c26594d8 100644 --- a/state-machines/binary/pil/binary_extension.pil +++ b/state-machines/binary/pil/binary_extension.pil @@ -44,7 +44,7 @@ x in1[x] out[x][0] out[x][1] 1 0x22 0x00220000 0x00000000 2 0x33 0x33000000 0x00000000 3 0x44 0x00000000 0x00000044 -4 0x55 0x00000000 0x00000000 (since 0x44 & 0x80 = 0, we stop here and set the remaining bytes to 0x00) +4 0x55 0x00000000 0x00000000 (since 0x44 & 0x80 = 0, we stop here and set the remaining bytes to 0x00) 5 0x66 0x00000000 0x00000000 (bytes of in1 are ignored from here) 6 0x77 0x00000000 0x00000000 7 0x88 0x00000000 0x00000000 @@ -72,7 +72,7 @@ airtemplate BinaryExtension(const int N = 2**18, const int operation_bus_id = BI const int bits = 64; const int bytes = bits / 8; - col witness op; + col witness op; col witness in1[bytes]; col witness in2_low; // Note: if in2_low∊[0,2^5-1], else in2_low∊[0,2^6-1] (checked by the table) col witness out[bytes][2]; @@ -90,7 +90,7 @@ airtemplate BinaryExtension(const int N = 2**18, const int operation_bus_id = BI expr in1_high = in1[4] + in1[5]*2**8 + in1[6]*2**16 + in1[7]*2**24; col witness main_step; -// col witness multiplicity; + col witness multiplicity; lookup_proves( operation_bus_id, [ @@ -104,8 +104,7 @@ airtemplate BinaryExtension(const int N = 2**18, const int operation_bus_id = BI out[0][1] + out[1][1] + out[2][1] + out[3][1] + out[4][1] + out[5][1] + out[6][1] + out[7][1], 0 ], - 1 -// multiplicity + multiplicity ); range_check(colu: in2[0], min: 0, max: 2**24-1, sel: op_is_shift); From 8766d06786cba32089a8ead9cf5da7b3e0374387 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Mon, 4 Nov 2024 11:55:31 +0000 Subject: [PATCH 16/17] arith, clean obsolete code --- state-machines/arith/pil/arith_32.pil | 160 ---------------- state-machines/arith/pil/arith_mul_32.pil | 90 --------- state-machines/arith/pil/arith_mul_64.pil | 177 ------------------ state-machines/arith/src/arith.rs | 15 -- state-machines/arith/src/arith_32.rs | 115 ------------ state-machines/arith/src/arith_full.rs | 16 +- state-machines/arith/src/arith_mul_32.rs | 117 ------------ state-machines/arith/src/arith_mul_64.rs | 108 ----------- .../arith/src/arith_range_table_helpers.rs | 16 +- .../arith/src/arith_range_table_inputs.rs | 48 ----- .../arith/src/arith_table_helpers.rs | 20 +- .../arith/src/arith_table_inputs.rs | 134 ------------- state-machines/arith/src/lib.rs | 10 - 13 files changed, 26 insertions(+), 1000 deletions(-) delete mode 100644 state-machines/arith/pil/arith_32.pil delete mode 100644 state-machines/arith/pil/arith_mul_32.pil delete mode 100644 state-machines/arith/pil/arith_mul_64.pil delete mode 100644 state-machines/arith/src/arith_32.rs delete mode 100644 state-machines/arith/src/arith_mul_32.rs delete mode 100644 state-machines/arith/src/arith_mul_64.rs delete mode 100644 state-machines/arith/src/arith_range_table_inputs.rs delete mode 100644 state-machines/arith/src/arith_table_inputs.rs diff --git a/state-machines/arith/pil/arith_32.pil b/state-machines/arith/pil/arith_32.pil deleted file mode 100644 index d012e38d..00000000 --- a/state-machines/arith/pil/arith_32.pil +++ /dev/null @@ -1,160 +0,0 @@ -require "std_lookup.pil" -require "std_range_check.pil" -require "operations.pil" -require "arith_table.pil" - -airtemplate Arith32(int N = 2**10, const int dual_result = 0) { - - // NOTE: - // Divisions and remainders by 0 are done by QuickOps - - col witness carry[3]; - col witness a[2]; - col witness b[2]; - col witness c[2]; - col witness d[2]; - - col witness na; // a is negative - col witness nb; // b is negative - col witness nr; // rem is negative - col witness np; // prod is negative - col witness na32; // a is 32-bit negative, 31th bit is 1. - col witness nd32; // d is 32-bit negative, 31th bit is 1. - - col witness div; // division operation (div,rem) - - col witness fab; // fab, to decrease degree of intermediate products a * b - // fab = 1 if sign of a,b are the same - // fab = -1 if sign of a,b are different - - if (!dual_result) { - col witness air.secondary_res; // op_index: 0 => first result, 1 => second result; - secondary_res * (secondary_res - 1) === 0; - } else { - const expr air.secondary_res = 0; - } - - fab === 1 - 2 * na - 2 * nb + 4 * na * nb; - - const expr eq[8]; - - eq[0] = fab * a[0] * b[0] - - c[0] - + 2 * np * c[0] - + div * d[0] - - 2 * nr * d[0]; - - eq[1] = fab * a[1] * b[0] - + fab * a[0] * b[1] - - c[1] - + 2 * np * c[1] - + div * d[1] - - 2 * nr * d[1]; - - eq[2] = fab * a[1] * b[1] - - np * div - + nr; - - // TODO: review !!!!! - eq[3] = 2**16 * na * nb; - - eq[0] - carry[0] * 2**16 === 0; - eq[1] + carry[0] - carry[1] * 2**16 === 0; - eq[2] + carry[1] - carry[2] * 2**16 === 0; - eq[3] + carry[2] === 0; - - // binary contraint - div * (1 - div) === 0; - na * (1 - na) === 0; - nb * (1 - nb) === 0; - nr * (1 - nr) === 0; - np * (1 - np) === 0; - na32 * (1 - na32) === 0; - nd32 * (1 - nd32) === 0; - - col witness op; - - // div sa sb comm primary secondary opcodes na nb nr np na32 nd32 - // ------------------------------------------------------------------------------ - // 0 1 1 x mul_w *n/a* (0xb6,0xb7) a1 b1 0 d3 c1 0 d3, a1,b1,c1 - // 1 1 1 div_w rem_w (0xbe,0xbf) a1 b1 d1 c1 c1 d1 a1,b1,c1,d1 - - // (*) removed combinations of flags div,sa,sb did allow combinations div, sa, sb - // comm = commutative (trivial: commutative operations) - - col witness bus_a_low; - bus_a_low === div * (c[0] - a[0]) - + a[0] - + 2**16 * div * (c[1] - a[1]) - + 2**16 * a[1]; - - const expr bus_a_high = 0; - - - const expr bus_b_low = b[0] + 2**16 * b[1]; - - // TODO: na32 and nd32 only valid on 32 bit operations - // TODO: m32 === 0 ==> b[2],a[2],b[3],a[3] === 0 avoid two witness - const expr bus_b_high = 0; - - const expr res2_low = d[0] + 2**16 * d[1]; - const expr res2_high = nd32 * 0xFFFFFFFF; - - if (dual_result) { - // theorical cost: 4 columns - col witness multiplicity_2; - lookup_proves(OPERATION_BUS_ID, [op+1, bus_a_low, bus_a_high, bus_b_low, bus_b_high, res2_low, res2_high, 0], mul: multiplicity_2); - } - - if (dual_result) { - const expr air.res1_low = a[0] + c[0] + 2**16 * a[1] + 2**16 * c[1] - bus_a_low; - col witness air.res1_high; - res1_high === div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF; - } else { - col witness air.res1_low; - res1_low === secondary_res * res2_low - (1 - secondary_res) * (a[0] + c[0] + 2**16 * a[1] + 2**16 * c[1] - bus_a_low); - - col witness air.res1_high; - res1_high === secondary_res * res2_high + (1 - secondary_res) * (div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF); - } - - - col witness multiplicity; - - lookup_proves(OPERATION_BUS_ID, [op + secondary_res, - bus_a_low, bus_a_high, - bus_b_low, bus_b_high, - res1_low, res1_high, - 0], mul: multiplicity); - - - // TODO: review - lookup_assumes(OPERATION_BUS_ID, [OP_LT, res2_low, res2_high, bus_b_low, bus_b_high, 0, 1, 1], sel: div); - - for (int index = 0; index < length(carry); ++index) { - range_check(colu: carry[index], min:-2**20, max: 2**20-1); // TODO: review range - } - - range_check(colu: a[0], min:0, max: 2**16-1); - range_check(colu: b[0], min:0, max: 2**16-1); - range_check(colu: c[0], min:0, max: 2**16-1); - range_check(colu: d[0], min:0, max: 2**16-1); - - col witness range_a1; - col witness range_b1; - col witness range_c1; - col witness range_d1; - - lookup_assumes(ARITH_TABLE_ID, cols: [ op, 1 + 2 * div + 4 * na + 8 * nb + 16 * nr + 32 * np + 64 * na32 + 128 * nd32 + - 2**8 * range_a1 + 2**10 * range_b1 + 2**12 * range_c1 + 2**14 * range_d1]); - - range_a1 * (1 - range_a1) * (2 - range_a1) === 0; - range_b1 * (1 - range_b1) * (2 - range_b1) === 0; - range_c1 * (1 - range_c1) * (2 - range_c1) === 0; - range_d1 * (1 - range_d1) * (2 - range_d1) === 0; - - lookup_assumes(QUICK_RANGE_TABLE_ID, [range_a1, a[1]]); - lookup_assumes(QUICK_RANGE_TABLE_ID, [range_b1, b[1]]); - lookup_assumes(QUICK_RANGE_TABLE_ID, [range_c1, c[1]]); - lookup_assumes(QUICK_RANGE_TABLE_ID, [range_d1, d[1]]); -} \ No newline at end of file diff --git a/state-machines/arith/pil/arith_mul_32.pil b/state-machines/arith/pil/arith_mul_32.pil deleted file mode 100644 index 91c2ec0c..00000000 --- a/state-machines/arith/pil/arith_mul_32.pil +++ /dev/null @@ -1,90 +0,0 @@ -require "std_lookup.pil" -require "std_range_check.pil" -require "operations.pil" -require "arith_table.pil" - -airtemplate ArithMul32(int N = 2**10, const int operation_bus_id) { - - const int CHUNK_SIZE = 2**16; - const int CHUNKS_INPUT = 2; - const int CHUNKS_OP = CHUNKS_INPUT * 2; - - col witness carry[CHUNKS_OP - 1]; - col witness a[CHUNKS_INPUT]; - col witness b[CHUNKS_INPUT]; - col witness c[CHUNKS_INPUT]; - col witness d[CHUNKS_INPUT]; - - col witness na; // a is negative - col witness nb; // b is negative - col witness np; // prod is negative - col witness nd32; // d is 32-bit negative, 31th bit is 1. - - col witness fab; // fab, to decrease degree of intermediate products a * b - // fab = 1 if sign of a,b are the same - // fab = -1 if sign of a,b are different - // factor ab € {-1, 1} - fab === 1 - 2 * na - 2 * nb + 4 * na * nb; - - const expr eq[CHUNKS_OP]; - - eq[0] = fab * a[0] * b[0] - - c[0] - + 2 * np * c[0]; - - eq[1] = fab * a[1] * b[0] - + fab * a[0] * b[1] - - c[1] - + 2 * np * c[1]; - - eq[2] = fab * a[1] * b[1]; - - // TODO: review !!!!! - eq[3] = 2**16 * na * nb; - - eq[0] - carry[0] * CHUNK_SIZE === 0; - for (int index = 1; index < (CHUNKS_OP - 1); ++index) { - eq[index] + carry[index-1] - carry[index] * CHUNK_SIZE === 0; - } - eq[CHUNKS_OP-1] + carry[CHUNKS_OP-2] === 0; - - // binary contraint - na * (1 - na) === 0; - nb * (1 - nb) === 0; - np * (1 - np) === 0; - nd32 * (1 - nd32) === 0; - - np === na + nb - 2 * na * nb; - - const expr bus_a_low = a[0] + 2**16 * a[1]; - const expr bus_a_high = 0; - - const expr bus_b_low = b[0] + CHUNK_SIZE * b[1]; - const expr bus_b_high = 0; - - const expr res1_low = c[0] + CHUNK_SIZE * + CHUNK_SIZE * c[1]; - const expr res1_high = nd32 * 0xFFFFFFFF; - - col witness multiplicity; - - lookup_proves(operation_bus_id, [OP_MUL_W, - bus_a_low, bus_a_high, - bus_b_low, bus_b_high, - res1_low, res1_high, - 0], mul: multiplicity); - - - for (int index = 0; index < length(carry); ++index) { - range_check(colu: carry[index], min:-2**20, max: 2**20-1); // TODO: review range - } - - range_check(colu: a[0], min:0, max: CHUNK_SIZE-1); - range_check(colu: b[0], min:0, max: CHUNK_SIZE-1); - range_check(colu: c[0], min:0, max: CHUNK_SIZE-1); - range_check(colu: d[0], min:0, max: CHUNK_SIZE-1); - range_check(colu: c[1], min:0, max: CHUNK_SIZE-1); - - lookup_assumes(QUICK_RANGE_TABLE_ID, [1 + na, a[1]]); - lookup_assumes(QUICK_RANGE_TABLE_ID, [1 + nb, b[1]]); - lookup_assumes(QUICK_RANGE_TABLE_ID, [1 + np, d[1]]); -} \ No newline at end of file diff --git a/state-machines/arith/pil/arith_mul_64.pil b/state-machines/arith/pil/arith_mul_64.pil deleted file mode 100644 index 03ca6ec4..00000000 --- a/state-machines/arith/pil/arith_mul_64.pil +++ /dev/null @@ -1,177 +0,0 @@ -require "std_lookup.pil" -require "std_range_check.pil" -require "operations.pil" -require "arith_table.pil" - -airtemplate ArithMul64(int N = 2**18, const int operation_bus_id, const int dual_result = 0) { - - // NOTE: - // Divisions and remainders by 0 are done by QuickOps - - const int CHUNK_SIZE = 2**16; - const int CHUNKS = 8; - - col witness carry[CHUNKS - 1]; - col witness a[4]; - col witness b[4]; - col witness c[4]; - col witness d[4]; - - col witness na; // a is negative - col witness nb; // b is negative - col witness np; // prod is negative - - col witness fab; // fab, to decrease degree of intermediate products a * b - // fab = 1 if sign of a,b are the same - // fab = -1 if sign of a,b are different - - if (!dual_result) { - col witness air.secondary_res; // op_index: 0 => first result, 1 => second result; - secondary_res * (secondary_res - 1) === 0; - } else { - const expr air.secondary_res = 0; - } - - // factor ab € {-1, 1} - fab === 1 - 2 * na - 2 * nb + 4 * na * nb; - - const expr eq[CHUNKS]; - - eq[0] = fab * a[0] * b[0] - - c[0] - + 2 * np * c[0]; - - eq[1] = fab * a[1] * b[0] - + fab * a[0] * b[1] - - c[1] - + 2 * np * c[1]; - - eq[2] = fab * a[2] * b[0] - + fab * a[1] * b[1] - + fab * a[0] * b[2] - - c[2] - + 2 * np * c[2]; - - eq[3] = fab * a[3] * b[0] - + fab * a[2] * b[1] - + fab * a[1] * b[2] - + fab * a[0] * b[3] - - c[3] - + 2 * np * c[3]; - - eq[4] = fab * a[3] * b[1] - + fab * a[2] * b[2] - + fab * a[1] * b[3] - + na * b[0] * (1 - 2 * nb) - + nb * a[0] * (1 - 2 * na) - - d[0] - + 2 * np * d[0]; - - eq[5] = fab * a[3] * b[2] - + fab * a[2] * b[3] - + nb * a[1] * (1 - 2 * na) - + na * b[1] * (1 - 2 * nb) - - d[1] - + 2 * np * d[1]; - - eq[6] = fab * a[3] * b[3] - + nb * a[2] * (1 - 2 * na) - + na * b[2] * (1 - 2 * nb) - - d[2] - + 2 * np * d[2]; - - eq[7] = CHUNK_SIZE * na * nb - + na * b[3] * (1 - 2 * nb) - + nb * a[3] * (1 - 2 * na) - - CHUNK_SIZE * np - - d[3] - + 2 * np * d[3]; - - eq[0] - carry[0] * CHUNK_SIZE === 0; - for (int index = 1; index < (CHUNKS - 1); ++index) { - eq[index] + carry[index-1] - carry[index] * CHUNK_SIZE === 0; - } - - // binary contraint - na * (1 - na) === 0; - nb * (1 - nb) === 0; - np * (1 - np) === 0; - - col witness op; - - // div m32 sa sb comm primary secondary opcodes na nb nr np na32 nd32 - // ---------------------------------------------------------------------------------- - // 0 0 0 0 x mulu muluh (0xb0,0xb1) =0 =0 =0 =0 =0 =0 - // 0 0 1 0 *n/a* mulsuh (0xb2,0xb3) a3 =0 =0 d3 =0 =0 a3, d3 - // 0 0 1 1 x mul mulh (0xb4,0xb5) a3 b3 =0 d3 =0 =0 a3,b3, d3 - - // (*) removed combinations of flags div,m32,sa,sb did allow combinations div, m32, sa, sb - // see 5 previous constraints. - // =0 means forced to zero by previous constraints - // comm = commutative (trivial: commutative operations) - - const expr bus_a_low = a[0] + CHUNK_SIZE * a[1]; - const expr bus_a_high = a[2] + CHUNK_SIZE * a[3]; - - - const expr bus_b_low = b[0] + CHUNK_SIZE * b[1]; - const expr bus_b_high = b[2] + CHUNK_SIZE * b[3]; - - const expr res2_low = d[0] + CHUNK_SIZE * d[1]; - const expr res2_high = d[2] + CHUNK_SIZE * d[3]; - - if (dual_result) { - // theorical cost: 4 columns - col witness multiplicity_2; - lookup_proves(operation_bus_id, [op+1, bus_a_low, bus_a_high, bus_b_low, bus_b_high, res2_low, res2_high, 0], mul: multiplicity_2); - - const expr air.res1_low = a[0] + c[0] + CHUNK_SIZE * a[1] + CHUNK_SIZE * c[1] - bus_a_low; - const expr air.res1_high = c[2] + CHUNK_SIZE * c[3]; - } else { - col witness air.res1_low; - res1_low === secondary_res * res2_low - (1 - secondary_res) * (a[0] + c[0] + CHUNK_SIZE * a[1] + CHUNK_SIZE * c[1] - bus_a_low); - - col witness air.res1_high; - // res1_high === secondary_res * res2_high + (1 - secondary_res) * ((1 - m32) * (div * (a[2] - c[2]) + c[2] + 2**16 * div * (a[3] - c[3]) + 2**16 * c[3]) + div * na32 * 0xFFFFFFFF + (1 - div) * nd32 * 0xFFFFFFFF); - res1_high === secondary_res * res2_high + (1 - secondary_res) * (c[2] + CHUNK_SIZE * c[3]); - } - - - col witness multiplicity; - - lookup_proves(operation_bus_id, [op + secondary_res, - bus_a_low, bus_a_high, - bus_b_low, bus_b_high, - res1_low, res1_high, -// secondary_res * (res2_low - res1_low) + res1_low, -// secondary_res * (res2_high - res1_high) + res1_high, - 0], mul: multiplicity); - - for (int index = 0; index < length(carry); ++index) { - range_check(colu: carry[index], min:-2**20, max: 2**20-1); // TODO: review range - } - - // loop for range checks index 0, 2 - for (int index = 0; index < 3; ++index) { - range_check(colu: a[index], min:0, max: CHUNK_SIZE-1); - range_check(colu: b[index], min:0, max: CHUNK_SIZE-1); - range_check(colu: c[index], min:0, max: CHUNK_SIZE-1); - range_check(colu: d[index], min:0, max: CHUNK_SIZE-1); - } - - range_check(colu: c[3], min:0, max: 2**16-1); - - col witness range_a3; - col witness range_b3; - col witness range_d3; - - lookup_assumes(ARITH_TABLE_ID, cols: [ op, 4 * na + 8 * nb + 32 * np + 2**16 * range_a3 + 2**18 * range_b3 + 2**22 * range_d3]); - - range_a3 * (1 - range_a3) * (2 - range_a3) === 0; - range_b3 * (1 - range_b3) * (2 - range_b3) === 0; - range_d3 * (1 - range_d3) * (2 - range_d3) === 0; - - lookup_assumes(QUICK_RANGE_TABLE_ID, [range_a3, a[3]]); - lookup_assumes(QUICK_RANGE_TABLE_ID, [range_b3, b[3]]); - lookup_assumes(QUICK_RANGE_TABLE_ID, [range_d3, d[3]]); -} \ No newline at end of file diff --git a/state-machines/arith/src/arith.rs b/state-machines/arith/src/arith.rs index 051686d4..1bde17f2 100644 --- a/state-machines/arith/src/arith.rs +++ b/state-machines/arith/src/arith.rs @@ -14,7 +14,6 @@ use zisk_pil::{ ARITH_TABLE_AIRGROUP_ID, ARITH_TABLE_AIR_IDS, }; -// use crate::{Arith32SM, ArithFullSM, ArithMul32SM, ArithMul64SM, ArithRangeTableSM, ArithTableSM}; use crate::{ArithFullSM, ArithRangeTableSM, ArithTableSM}; const PROVE_CHUNK_SIZE: usize = 1 << 12; @@ -29,14 +28,6 @@ pub struct ArithSM { // Inputs inputs: Mutex>, - // inputs_32: Mutex>, - // inputs_mul_32: Mutex>, - // inputs_mul_64: Mutex>, - - // Secondary State machines - // arith_32_sm: Arc>, - // arith_mul_32_sm: Arc>, - // arith_mul_64_sm: Arc>, arith_full_sm: Arc>, arith_table_sm: Arc>, arith_range_table_sm: Arc>, @@ -70,9 +61,6 @@ impl ArithSM { wcm.register_component(arith_sm.clone(), None, None); - // arith_sm.arith_32_sm.register_predecessor(); - // arith_sm.arith_mul_32_sm.register_predecessor(); - // arith_sm.arith_mul_64_sm.register_predecessor(); arith_sm.arith_full_sm.register_predecessor(); arith_sm @@ -93,9 +81,6 @@ impl ArithSM { self.threads_controller.wait_for_threads(); - // self.arith_32_sm.unregister_predecessor(scope); - // self.arith_mul_32_sm.unregister_predecessor(scope); - // self.arith_mul_64_sm.unregister_predecessor(scope); self.arith_full_sm.unregister_predecessor(scope); } } diff --git a/state-machines/arith/src/arith_32.rs b/state-machines/arith/src/arith_32.rs deleted file mode 100644 index b8d6df1a..00000000 --- a/state-machines/arith/src/arith_32.rs +++ /dev/null @@ -1,115 +0,0 @@ -use std::{ - fmt::Error, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, - }, -}; - -use p3_field::Field; -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{OpResult, Provable}; -use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; -use zisk_pil::{ARITH32_AIR_IDS, ARITH_AIRGROUP_ID}; - -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -pub struct Arith32SM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, - - _phantom: std::marker::PhantomData, -} - -impl Arith32SM { - pub fn new(wcm: Arc>) -> Arc { - let _arith_32_sm = Self { - registered_predecessors: AtomicU32::new(0), - inputs: Mutex::new(Vec::new()), - _phantom: std::marker::PhantomData, - }; - let arith_32_sm = Arc::new(_arith_32_sm); - - wcm.register_component(arith_32_sm.clone(), Some(ARITH_AIRGROUP_ID), Some(ARITH32_AIR_IDS)); - - arith_32_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - as Provable>::prove( - self, - &[], - true, - scope, - ); - } - } - - pub fn operations() -> Vec { - vec![0xb6, 0xb7, 0xbe, 0xbf] - } -} - -impl WitnessComponent for Arith32SM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, - ) { - } -} - -impl Provable for Arith32SM { - fn calculate( - &self, - operation: ZiskRequiredOperation, - ) -> Result> { - let result: OpResult = ZiskOp::execute( - ZiskOp::try_from_code(operation.opcode).map_err(|_| Error)?.code(), - operation.a, - operation.b, - ); - Ok(result) - } - - fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! Implement prove drained_inputs (a chunk of operations) - }); - } - } - } - - fn calculate_prove( - &self, - operation: ZiskRequiredOperation, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs index b0c30880..4172854b 100644 --- a/state-machines/arith/src/arith_full.rs +++ b/state-machines/arith/src/arith_full.rs @@ -4,8 +4,8 @@ use std::sync::{ }; use crate::{ - arith_table_inputs, ArithOperation, ArithRangeTableInputs, ArithRangeTableSM, ArithSM, - ArithTableInputs, ArithTableSM, + ArithOperation, ArithRangeTableInputs, ArithRangeTableSM, ArithSM, ArithTableInputs, + ArithTableSM, }; use p3_field::Field; use proofman::{WitnessComponent, WitnessManager}; @@ -73,8 +73,8 @@ impl ArithFullSM { } pub fn process_slice( input: &Vec, - range_table_inputs: &mut ArithRangeTableInputs, - table_inputs: &mut ArithTableInputs, + range_table_inputs: &mut ArithRangeTableInputs, + table_inputs: &mut ArithTableInputs, ) -> Vec> { let mut traces: Vec> = Vec::new(); let mut aop = ArithOperation::new(); @@ -170,12 +170,12 @@ impl Provable for ArithFullSM { let _drained_inputs = inputs.drain(..num_drained).collect::>(); scope.spawn(move |_| { - let mut arith_range_table_inputs = ArithRangeTableInputs::::new(); - let mut arith_table_inputs = ArithTableInputs::::new(); + let mut arith_range_table = ArithRangeTableInputs::new(); + let mut arith_table = ArithTableInputs::new(); let _trace = Self::process_slice( &_drained_inputs, - &mut arith_range_table_inputs, - &mut arith_table_inputs, + &mut arith_range_table, + &mut arith_table, ); // thread_controller.remove_working_thread(); // TODO! Implement prove drained_inputs (a chunk of operations) diff --git a/state-machines/arith/src/arith_mul_32.rs b/state-machines/arith/src/arith_mul_32.rs deleted file mode 100644 index 4d883440..00000000 --- a/state-machines/arith/src/arith_mul_32.rs +++ /dev/null @@ -1,117 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use p3_field::Field; -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{OpResult, Provable}; -use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; -use zisk_pil::{ARITH3264_AIR_IDS, ARITH_AIRGROUP_ID}; - -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -pub struct ArithMul32SM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, - - _phantom: std::marker::PhantomData, -} - -impl ArithMul32SM { - pub fn new(wcm: Arc>) -> Arc { - let arith_mul_32_sm = Self { - registered_predecessors: AtomicU32::new(0), - inputs: Mutex::new(Vec::new()), - _phantom: std::marker::PhantomData, - }; - let arith_mul_32_sm = Arc::new(arith_mul_32_sm); - - wcm.register_component( - arith_mul_32_sm.clone(), - Some(ARITH_AIRGROUP_ID), - Some(ARITH3264_AIR_IDS), - ); - - arith_mul_32_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - as Provable>::prove( - self, - &[], - true, - scope, - ); - } - } - - pub fn operations() -> Vec { - // TODO: use constants - vec![0xb6, 0xb7, 0xbe, 0xbf] - } -} - -impl WitnessComponent for ArithMul32SM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, - ) { - } -} - -impl Provable for ArithMul32SM { - fn calculate( - &self, - operation: ZiskRequiredOperation, - ) -> Result> { - let result: OpResult = ZiskOp::execute(operation.opcode, operation.a, operation.b); - Ok(result) - } - - fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - if drain && !inputs.is_empty() { - // println!("Arith3264SM: Draining inputs3264"); - } - - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! Implement prove drained_inputs (a chunk of operations) - }); - } - } - } - - fn calculate_prove( - &self, - operation: ZiskRequiredOperation, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} diff --git a/state-machines/arith/src/arith_mul_64.rs b/state-machines/arith/src/arith_mul_64.rs deleted file mode 100644 index a925a6c4..00000000 --- a/state-machines/arith/src/arith_mul_64.rs +++ /dev/null @@ -1,108 +0,0 @@ -use std::sync::{ - atomic::{AtomicU32, Ordering}, - Arc, Mutex, -}; - -use p3_field::Field; -use proofman::{WitnessComponent, WitnessManager}; -use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; -use rayon::Scope; -use sm_common::{OpResult, Provable}; -use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; - -const PROVE_CHUNK_SIZE: usize = 1 << 12; - -pub struct ArithMul64SM { - // Count of registered predecessors - registered_predecessors: AtomicU32, - - // Inputs - inputs: Mutex>, - - _phantom: std::marker::PhantomData, -} - -impl ArithMul64SM { - pub fn new(wcm: Arc>, airgroup_id: usize, air_ids: &[usize]) -> Arc { - let arith_mul_64_sm = Self { - registered_predecessors: AtomicU32::new(0), - inputs: Mutex::new(Vec::new()), - _phantom: std::marker::PhantomData, - }; - let arith_mul_64_sm = Arc::new(arith_mul_64_sm); - - wcm.register_component(arith_mul_64_sm.clone(), Some(airgroup_id), Some(air_ids)); - - arith_mul_64_sm - } - - pub fn register_predecessor(&self) { - self.registered_predecessors.fetch_add(1, Ordering::SeqCst); - } - - pub fn unregister_predecessor(&self, scope: &Scope) { - if self.registered_predecessors.fetch_sub(1, Ordering::SeqCst) == 1 { - as Provable>::prove( - self, - &[], - true, - scope, - ); - } - } - - pub fn operations() -> Vec { - // TODO: use constants - vec![0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb8, 0xb9, 0xba, 0xbb] - } -} - -impl WitnessComponent for ArithMul64SM { - fn calculate_witness( - &self, - _stage: u32, - _air_instance: Option, - _pctx: Arc>, - _ectx: Arc, - _sctx: Arc, - ) { - } -} - -impl Provable for ArithMul64SM { - fn calculate( - &self, - operation: ZiskRequiredOperation, - ) -> Result> { - let result: OpResult = ZiskOp::execute(operation.opcode, operation.a, operation.b); - Ok(result) - } - - fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { - if let Ok(mut inputs) = self.inputs.lock() { - inputs.extend_from_slice(operations); - - while inputs.len() >= PROVE_CHUNK_SIZE || (drain && !inputs.is_empty()) { - let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); - let _drained_inputs = inputs.drain(..num_drained).collect::>(); - - scope.spawn(move |_| { - // TODO! Implement prove drained_inputs (a chunk of operations) - }); - } - } - } - - fn calculate_prove( - &self, - operation: ZiskRequiredOperation, - drain: bool, - scope: &Scope, - ) -> Result> { - let result = self.calculate(operation.clone()); - - self.prove(&[operation], drain, scope); - - result - } -} diff --git a/state-machines/arith/src/arith_range_table_helpers.rs b/state-machines/arith/src/arith_range_table_helpers.rs index be68d777..f02e4ae5 100644 --- a/state-machines/arith/src/arith_range_table_helpers.rs +++ b/state-machines/arith/src/arith_range_table_helpers.rs @@ -62,28 +62,28 @@ impl AirthRangeTableHelpers { (0x220000 + 0xEFFFF + value) as usize } } -struct AirthRangeTableMultiplicity { +pub struct ArithRangeTableInputs { multiplicity: [u64; ROWS], } -impl AirthRangeTableMultiplicity { - fn new() -> Self { - AirthRangeTableMultiplicity { multiplicity: [0; ROWS] } +impl ArithRangeTableInputs { + pub fn new() -> Self { + ArithRangeTableInputs { multiplicity: [0; ROWS] } } - fn use_chunk_range_check(&self, range_id: u8, value: i64) { + pub fn use_chunk_range_check(&self, range_id: u8, value: i64) { let row = AirthRangeTableHelpers::get_row_chunk_range_check(range_id, value); self.multiplicity[row as usize]; } - fn use_carry_range_check(&self, value: i64) { + pub fn use_carry_range_check(&self, value: i64) { let row = AirthRangeTableHelpers::get_row_carry_range_check(value); self.multiplicity[row as usize]; } } -impl Add for AirthRangeTableMultiplicity { +impl Add for ArithRangeTableInputs { type Output = Self; fn add(self, other: Self) -> Self { - let mut result = AirthRangeTableMultiplicity::new(); + let mut result = ArithRangeTableInputs::new(); for i in 0..ROWS { result.multiplicity[i] = self.multiplicity[i] + other.multiplicity[i]; } diff --git a/state-machines/arith/src/arith_range_table_inputs.rs b/state-machines/arith/src/arith_range_table_inputs.rs deleted file mode 100644 index 8985026b..00000000 --- a/state-machines/arith/src/arith_range_table_inputs.rs +++ /dev/null @@ -1,48 +0,0 @@ -use std::ops::Add; - -const ARITH_RANGE_TABLE_SIZE: usize = 2 << 17; - -pub struct ArithRangeTableInputs { - multiplicity: [u32; ARITH_RANGE_TABLE_SIZE], - _phantom: std::marker::PhantomData, -} - -impl Add for ArithRangeTableInputs { - type Output = Self; - fn add(self, other: Self) -> Self { - let mut result = Self::new(); - for i in 0..ARITH_RANGE_TABLE_SIZE { - result.multiplicity[i] = self.multiplicity[i] + other.multiplicity[i]; - } - result - } -} - -impl ArithRangeTableInputs { - pub fn new() -> Self { - Self { multiplicity: [0; ARITH_RANGE_TABLE_SIZE], _phantom: std::marker::PhantomData } - } - pub fn clear(&mut self) { - self.multiplicity = [0; ARITH_RANGE_TABLE_SIZE]; - } - pub fn push(&mut self, range_id: u8, value: u64) { - Self::check_value(range_id, value); - self.fast_push(range_id, value); - } - fn get_row(range_id: u8, value: u64) -> usize { - usize::try_from(value + if range_id > 0 { 2 << 16 } else { 0 }).unwrap() - % ARITH_RANGE_TABLE_SIZE - } - fn check_value(range_id: u8, value: u64) { - match range_id { - 0 => assert!(value <= 0xFFFF), - 1 => assert!(value <= 0x7FFF), - 2 => assert!(value <= 0xFFFF && value >= 0x8000), - _ => assert!(false), - }; - } - - pub fn fast_push(&mut self, op: u8, value: u64) { - self.multiplicity[Self::get_row(op, value)] += 1; - } -} diff --git a/state-machines/arith/src/arith_table_helpers.rs b/state-machines/arith/src/arith_table_helpers.rs index 04a89931..ee1683dc 100644 --- a/state-machines/arith/src/arith_table_helpers.rs +++ b/state-machines/arith/src/arith_table_helpers.rs @@ -3,10 +3,10 @@ use std::ops::Add; const ROWS: usize = 95; const FIRST_OP: u8 = 0xb0; -struct AirthTableHelpers; +pub struct AirthTableHelpers; impl AirthTableHelpers { - fn get_row(op: u8, na: u64, nb: u64, np: u64, nr: u64, sext: u64) -> usize { + pub fn get_row(op: u8, na: u64, nb: u64, np: u64, nr: u64, sext: u64) -> usize { static ARITH_TABLE_ROWS: [i16; 512] = [ 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, @@ -39,30 +39,30 @@ impl AirthTableHelpers { assert!(row >= 0); row as usize } - fn get_max_row() -> usize { + pub fn get_max_row() -> usize { ROWS - 1 } } -struct AirthTableMultiplicity { +pub struct ArithTableInputs { multiplicity: [u64; ROWS], } -impl AirthTableMultiplicity { - fn new() -> Self { - AirthTableMultiplicity { multiplicity: [0; ROWS] } +impl ArithTableInputs { + pub fn new() -> Self { + ArithTableInputs { multiplicity: [0; ROWS] } } - fn add_use(&self, op: u8, na: u64, nb: u64, np: u64, nr: u64, sext: u64) { + pub fn add_use(&self, op: u8, na: u64, nb: u64, np: u64, nr: u64, sext: u64) { let row = AirthTableHelpers::get_row(op, na, nb, np, nr, sext); self.multiplicity[row as usize]; } } -impl Add for AirthTableMultiplicity { +impl Add for ArithTableInputs { type Output = Self; fn add(self, other: Self) -> Self { - let mut result = AirthTableMultiplicity::new(); + let mut result = ArithTableInputs::new(); for i in 0..ROWS { result.multiplicity[i] = self.multiplicity[i] + other.multiplicity[i]; } diff --git a/state-machines/arith/src/arith_table_inputs.rs b/state-machines/arith/src/arith_table_inputs.rs deleted file mode 100644 index e1a1e89e..00000000 --- a/state-machines/arith/src/arith_table_inputs.rs +++ /dev/null @@ -1,134 +0,0 @@ -use std::ops::Add; - -const ARITH_TABLE_SIZE: usize = 36; -pub struct ArithTableInputs { - multiplicity: [u32; ARITH_TABLE_SIZE], - _phantom: std::marker::PhantomData, -} - -impl Add for ArithTableInputs { - type Output = Self; - fn add(self, other: Self) -> Self { - let mut result = Self::new(); - for i in 0..ARITH_TABLE_SIZE { - result.multiplicity[i] = self.multiplicity[i] + other.multiplicity[i]; - } - result - } -} - -impl ArithTableInputs { - const FLAGS_AND_RANGES: [u32; ARITH_TABLE_SIZE] = [ - 0x000000, 0x000000, 0x010000, 0x020024, 0x050000, 0x0A0024, 0x050028, 0x0A002C, 0x050000, - 0x0A0024, 0x050028, 0x0A002C, 0x001501, 0x002665, 0x001929, 0x002A6D, 0x000002, 0x000002, - 0x550002, 0xAA0036, 0xA5002A, 0xAA003E, 0x550002, 0xAA0036, 0xA5002A, 0xAA003E, 0x009003, - 0x009003, 0x009503, 0x0066F7, 0x00692B, 0x006AFF, 0x009503, 0x0066F7, 0x00692B, 0x006AFF, - ]; - pub fn new() -> Self { - Self { multiplicity: [0; ARITH_TABLE_SIZE], _phantom: std::marker::PhantomData } - } - pub fn clear(&mut self) { - self.multiplicity = [0; ARITH_TABLE_SIZE]; - } - pub fn push( - &mut self, - op: u8, - m32: u32, - div: u32, - na: u32, - nb: u32, - nr: u32, - np: u32, - na32: u32, - nd32: u32, - range_a1: u32, - range_b1: u32, - range_c1: u32, - range_d1: u32, - range_a3: u32, - range_b3: u32, - range_c3: u32, - range_d3: u32, - ) { - // TODO: in debug mode - let flags = Self::values_to_flags( - m32, div, na, nb, nr, np, na32, nd32, range_a1, range_b1, range_c1, range_d1, range_a3, - range_b3, range_c3, range_d3, - ); - let variants = Self::get_variants(op); - let row_offset = nb * 2 + nb; - let row: usize = Self::get_row(op, na, nb); - - assert!(row_offset < variants); - assert!(Self::FLAGS_AND_RANGES[row] == flags); - - self.multiplicity[row] += 1; - } - fn get_variants(op: u8) -> u32 { - match op { - 0xb0 | 0xb1 | 0xb8 | 0xb9 | 0xbc | 0xbd => 1, // mulu|muluh|divu|remu|divu_w|remu_w - 0xb3 => 2, // mulsuh - 0xb4 | 0xb5 | 0xb6 | 0xba | 0xbb | 0xbe | 0xbf => 4, /* mul|mulh|mul_w|div|rem|div_w|rem_w */ - _ => panic!("Invalid opcode"), - } - } - fn get_offset(op: u8) -> u32 { - match op { - 0xb0 => 0, // mulu - 0xb1 => 1, // muluh - 0xb3 => 2, // mulsuh - 0xb4 => 4, // mul - 0xb5 => 8, // mulh - 0xb6 => 12, // mul_w - 0xb8 => 16, // divu - 0xb9 => 17, // remu - 0xba => 18, // div - 0xbb => 22, // rem - 0xbc => 26, // divu_w - 0xbd => 27, // remu_w - 0xbe => 28, // div_w - 0xbf => 32, // rem_w - _ => panic!("Invalid opcode"), - } - } - fn get_row(op: u8, na: u32, nb: u32) -> usize { - usize::try_from(Self::get_offset(op) + na + 2 * nb).unwrap() % ARITH_TABLE_SIZE - } - pub fn fast_push(&mut self, op: u8, na: u32, nb: u32) { - self.multiplicity[Self::get_row(op, na, nb)] += 1; - } - fn values_to_flags( - m32: u32, - div: u32, - na: u32, - nb: u32, - nr: u32, - np: u32, - na32: u32, - nd32: u32, - range_a1: u32, - range_b1: u32, - range_c1: u32, - range_d1: u32, - range_a3: u32, - range_b3: u32, - range_c3: u32, - range_d3: u32, - ) -> u32 { - m32 + 0x000002 * div - + 0x000004 * na - + 0x000008 * nb - + 0x000010 * nr - + 0x000020 * np - + 0x000040 * na32 - + 0x000080 * nd32 - + 0x000100 * range_a1 - + 0x000400 * range_b1 - + 0x001000 * range_c1 - + 0x004000 * range_d1 - + 0x010000 * range_a3 - + 0x040000 * range_b3 - + 0x100000 * range_c3 - + 0x400000 * range_d3 - } -} diff --git a/state-machines/arith/src/lib.rs b/state-machines/arith/src/lib.rs index fab4cf43..348c6ea5 100644 --- a/state-machines/arith/src/lib.rs +++ b/state-machines/arith/src/lib.rs @@ -1,28 +1,18 @@ mod arith; -mod arith_32; mod arith_constants; mod arith_full; -mod arith_mul_32; -mod arith_mul_64; mod arith_operation; mod arith_operation_test; mod arith_range_table; mod arith_range_table_helpers; -mod arith_range_table_inputs; mod arith_table; mod arith_table_helpers; -mod arith_table_inputs; pub use arith::*; -pub use arith_32::*; pub use arith_constants::*; pub use arith_full::*; -pub use arith_mul_32::*; -pub use arith_mul_64::*; pub use arith_operation::*; pub use arith_range_table::*; pub use arith_range_table_helpers::*; -pub use arith_range_table_inputs::*; pub use arith_table::*; pub use arith_table_helpers::*; -pub use arith_table_inputs::*; From a50c88604036dc4ff4001e1a6f3d6614d813e69b Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Mon, 4 Nov 2024 14:22:55 +0000 Subject: [PATCH 17/17] add multiplicity on arith tables --- state-machines/arith/Cargo.toml | 1 + state-machines/arith/pil/arith.pil | 39 ++++++------ .../arith/pil/arith_range_table.pil | 8 +-- state-machines/arith/src/arith.rs | 60 +++---------------- state-machines/arith/src/arith_full.rs | 48 +++++++++++++-- state-machines/arith/src/arith_operation.rs | 4 +- .../arith/src/arith_operation_test.rs | 3 - .../arith/src/arith_range_table_helpers.rs | 27 +++++---- state-machines/arith/src/arith_table.rs | 7 ++- .../arith/src/arith_table_helpers.rs | 22 +++++-- state-machines/arith/src/lib.rs | 4 +- 11 files changed, 120 insertions(+), 103 deletions(-) diff --git a/state-machines/arith/Cargo.toml b/state-machines/arith/Cargo.toml index 23f8d0b2..36b43c80 100644 --- a/state-machines/arith/Cargo.toml +++ b/state-machines/arith/Cargo.toml @@ -18,4 +18,5 @@ rayon = { workspace = true } [features] default = [] +generate_code_arith_range_table = [] no_lib_link = ["proofman-common/no_lib_link", "proofman/no_lib_link"] \ No newline at end of file diff --git a/state-machines/arith/pil/arith.pil b/state-machines/arith/pil/arith.pil index 968c94e8..18f45d92 100644 --- a/state-machines/arith/pil/arith.pil +++ b/state-machines/arith/pil/arith.pil @@ -258,14 +258,14 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu const expr bus_a0 = div * (c[0] + c[1] * CHUNK_SIZE) + (1 - div) * (a[0] + a[1] * CHUNK_SIZE); const expr bus_a1 = div * (c[2] + c[3] * CHUNK_SIZE) + (1 - div) * (a[2] + a[3] * CHUNK_SIZE); - const expr bus_b0 = b[0] + CHUNK_SIZE * b[1]; - const expr bus_b1 = b[2] + CHUNK_SIZE * b[3]; + const expr bus_b0 = b[0] + b[1] * CHUNK_SIZE; + const expr bus_b1 = b[2] + b[3] * CHUNK_SIZE; - const expr bus_res0 = secondary * (d[0] + CHUNK_SIZE * d[1]) + + const expr bus_res0 = secondary * (d[0] + d[1] * CHUNK_SIZE) + main_mul * (c[0] + c[1] * CHUNK_SIZE) + main_div * (a[0] + a[1] * CHUNK_SIZE); - const expr bus_res1_64 = (secondary * (d[2] + CHUNK_SIZE * d[3]) + + const expr bus_res1_64 = (secondary * (d[2] + d[3] * CHUNK_SIZE) + main_mul * (c[2] + c[3] * CHUNK_SIZE) + main_div * (a[2] + a[3] * CHUNK_SIZE)); col witness bus_res1; @@ -299,19 +299,24 @@ airtemplate Arith(int N = 2**18, const int operation_bus_id, const int dual_resu arith_table_assumes(op, m32, div, na, nb, np, nr, sext, main_mul, main_div, signed, range_ab, range_cd); - const expr range_a = range_ab; - const expr range_b = range_ab + 26; - const expr range_c = range_cd + 17; - const expr range_d = range_cd + 9; - - arith_range_table_assumes(range_a, a[1]); - arith_range_table_assumes(range_b, b[1]); - arith_range_table_assumes(range_c, c[1]); - arith_range_table_assumes(range_d, d[1]); - arith_range_table_assumes(range_a, a[3]); - arith_range_table_assumes(range_b, b[3]); - arith_range_table_assumes(range_c, c[3]); - arith_range_table_assumes(range_d, d[3]); + const expr range_a3 = range_ab; + const expr range_a1 = range_ab + 26; + const expr range_b3 = range_ab + 17; + const expr range_b1 = range_ab + 9; + + const expr range_c3 = range_cd; + const expr range_c1 = range_cd + 26; + const expr range_d3 = range_cd + 17; + const expr range_d1 = range_cd + 9; + + arith_range_table_assumes(range_a1, a[1]); + arith_range_table_assumes(range_b1, b[1]); + arith_range_table_assumes(range_c1, c[1]); + arith_range_table_assumes(range_d1, d[1]); + arith_range_table_assumes(range_a3, a[3]); + arith_range_table_assumes(range_b3, b[3]); + arith_range_table_assumes(range_c3, c[3]); + arith_range_table_assumes(range_d3, d[3]); // loop for range checks index 0, 2 for (int index = 0; index < 2; ++index) { diff --git a/state-machines/arith/pil/arith_range_table.pil b/state-machines/arith/pil/arith_range_table.pil index 42369708..89caba0e 100644 --- a/state-machines/arith/pil/arith_range_table.pil +++ b/state-machines/arith/pil/arith_range_table.pil @@ -48,10 +48,10 @@ airtemplate ArithRangeTable(int N = 2**22) { // // 25:FULL + 9:POS + 9:NEG = 34 * 2^16 = 2^21 + 2^17 // - // [range, 0] => [range] - // [range, 1] => [range + 26] - // [range, 2] => [range + 17] - // [range, 3] => [range + 9] + // a3 c3 [range, 0] => [range] + // a1 c1 [range, 1] => [range + 26] + // b3 d3 [range, 2] => [range + 17] + // b1 d1 [range, 3] => [range + 9] // // [-(2^19+2^18+2^16-1)...(2^19+2^18+2^16)] range check carry diff --git a/state-machines/arith/src/arith.rs b/state-machines/arith/src/arith.rs index 1bde17f2..3f06df2a 100644 --- a/state-machines/arith/src/arith.rs +++ b/state-machines/arith/src/arith.rs @@ -108,70 +108,24 @@ impl Provable for ArithSM { } fn prove(&self, operations: &[ZiskRequiredOperation], drain: bool, scope: &Scope) { - // let mut _inputs32 = Vec::new(); - // let mut _inputs64 = Vec::new(); - - // let operations64 = ArithMul64SM::::operations(); - // let operations32 = Arith32SM::::operations(); - - // TODO Split the operations into 32 and 64 bit operations in parallel - // for operation in operations { - // if operations32.contains(&operation.opcode) { - // _inputs32.push(operation.clone()); - // } else if operations64.contains(&operation.opcode) { - // _inputs64.push(operation.clone()); - // } else { - // panic!("ArithSM: Operator {:x} not found", operation.opcode); - // } - // } - - // TODO When drain is true, drain remaining inputs to the 3264 bits state machine - /* - let mut inputs32 = self.inputs_32.lock().unwrap(); - inputs32.extend(_inputs32); - - while inputs32.len() >= PROVE_CHUNK_SIZE || (drain && !inputs32.is_empty()) { - if drain && !inputs32.is_empty() { - // println!("ArithSM: Draining inputs32"); + while operations.len() >= PROVE_CHUNK_SIZE || (drain && !operations.is_empty()) { + if drain && !operations.is_empty() { + // println!("ArithSM: Draining inputs"); } - let num_drained32 = std::cmp::min(PROVE_CHUNK_SIZE, inputs32.len()); - let drained_inputs32 = inputs32.drain(..num_drained32).collect::>(); - let arith32_sm_cloned = self.arith_32_sm.clone(); + let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, operations.len()); + let drained_inputs = operations[..num_drained].to_vec(); + let arith_full_sm_cloned = self.arith_full_sm.clone(); self.threads_controller.add_working_thread(); let thread_controller = self.threads_controller.clone(); scope.spawn(move |scope| { - arith32_sm_cloned.prove(&drained_inputs32, drain, scope); + arith_full_sm_cloned.prove(&drained_inputs, drain, scope); thread_controller.remove_working_thread(); }); } - drop(inputs32); - - let mut inputs64 = self.inputs_mul_64.lock().unwrap(); - inputs64.extend(_inputs64); - - while inputs64.len() >= PROVE_CHUNK_SIZE || (drain && !inputs64.is_empty()) { - if drain && !inputs64.is_empty() { - // println!("ArithSM: Draining inputs64"); - } - - let num_drained64 = std::cmp::min(PROVE_CHUNK_SIZE, inputs64.len()); - let drained_inputs64 = inputs64.drain(..num_drained64).collect::>(); - let arith64_sm_cloned = self.arith_mul_64_sm.clone(); - - self.threads_controller.add_working_thread(); - let thread_controller = self.threads_controller.clone(); - - scope.spawn(move |scope| { - arith64_sm_cloned.prove(&drained_inputs64, drain, scope); - - thread_controller.remove_working_thread(); - }); - } - drop(inputs64);*/ } fn calculate_prove( diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs index 4172854b..e02cfc4f 100644 --- a/state-machines/arith/src/arith_full.rs +++ b/state-machines/arith/src/arith_full.rs @@ -4,8 +4,7 @@ use std::sync::{ }; use crate::{ - ArithOperation, ArithRangeTableInputs, ArithRangeTableSM, ArithSM, ArithTableInputs, - ArithTableSM, + ArithOperation, ArithRangeTableInputs, ArithRangeTableSM, ArithTableInputs, ArithTableSM, }; use p3_field::Field; use proofman::{WitnessComponent, WitnessManager}; @@ -15,8 +14,17 @@ use sm_common::{OpResult, Provable, ThreadController}; use zisk_core::{zisk_ops::ZiskOp, ZiskRequiredOperation}; use zisk_pil::Arith0Row; -const PROVE_CHUNK_SIZE: usize = 1 << 12; +fn i64_to_u64_field(value: i64) -> u64 { + const PRIME_MINUS_ONE: u64 = 0xFFFF_FFFF_0000_0000; + if value >= 0 { + value as u64 + } else { + PRIME_MINUS_ONE - (0xFFFF_FFFF_FFFF_FFFF - value as u64) + } +} +const PROVE_CHUNK_SIZE: usize = 1 << 12; +const PRIME: u64 = 0xFFFF_FFFF_0000_0001; pub struct ArithFullSM { // Count of registered predecessors registered_predecessors: AtomicU32, @@ -81,13 +89,38 @@ impl ArithFullSM { for input in input.iter() { aop.calculate(input.opcode, input.a, input.b); let mut t: Arith0Row = Default::default(); - for i in 0..4 { + for i in [0, 2] { t.a[i] = F::from_canonical_u64(aop.a[i]); t.b[i] = F::from_canonical_u64(aop.b[i]); t.c[i] = F::from_canonical_u64(aop.c[i]); t.d[i] = F::from_canonical_u64(aop.d[i]); + range_table_inputs.use_chunk_range_check(0, aop.a[i]); + range_table_inputs.use_chunk_range_check(0, aop.b[i]); + range_table_inputs.use_chunk_range_check(0, aop.c[i]); + range_table_inputs.use_chunk_range_check(0, aop.d[i]); // arith_operation.a[i]; } + for i in [1, 3] { + t.a[i] = F::from_canonical_u64(aop.a[i]); + t.b[i] = F::from_canonical_u64(aop.b[i]); + t.c[i] = F::from_canonical_u64(aop.c[i]); + t.d[i] = F::from_canonical_u64(aop.d[i]); + // arith_operation.a[i]; + } + range_table_inputs.use_chunk_range_check(aop.range_ab, aop.a[3]); + range_table_inputs.use_chunk_range_check(aop.range_ab + 26, aop.a[1]); + range_table_inputs.use_chunk_range_check(aop.range_ab + 17, aop.b[3]); + range_table_inputs.use_chunk_range_check(aop.range_ab + 9, aop.b[1]); + + range_table_inputs.use_chunk_range_check(aop.range_cd, aop.c[3]); + range_table_inputs.use_chunk_range_check(aop.range_cd + 26, aop.c[1]); + range_table_inputs.use_chunk_range_check(aop.range_cd + 17, aop.d[3]); + range_table_inputs.use_chunk_range_check(aop.range_cd + 9, aop.d[1]); + + for i in 0..7 { + t.carry[i] = F::from_canonical_u64(i64_to_u64_field(aop.carry[i])); + range_table_inputs.use_carry_range_check(aop.carry[i]); + } // range_table_inputs.push(0, 0); // table_inputs.fast_push(0, 0, 0); t.m32 = F::from_bool(aop.m32); @@ -102,6 +135,8 @@ impl ArithFullSM { t.sext = F::from_bool(aop.sext); t.multiplicity = F::one(); + table_inputs.add_use(aop.op, aop.na, aop.nb, aop.np, aop.nr, aop.sext); + t.fab = if aop.na != aop.nb { F::neg_one() } else { F::one() }; // na * (1 - 2 * nb); t.na_fb = if aop.na { @@ -169,6 +204,9 @@ impl Provable for ArithFullSM { let num_drained = std::cmp::min(PROVE_CHUNK_SIZE, inputs.len()); let _drained_inputs = inputs.drain(..num_drained).collect::>(); + let mut all_arith_range_table = Mutex::new(ArithRangeTableInputs::new()); + let mut all_arith_table = Mutex::new(ArithTableInputs::new()); + scope.spawn(move |_| { let mut arith_range_table = ArithRangeTableInputs::new(); let mut arith_table = ArithTableInputs::new(); @@ -177,6 +215,8 @@ impl Provable for ArithFullSM { &mut arith_range_table, &mut arith_table, ); + all_arith_range_table.lock().unwrap().update_with(&arith_range_table); + all_arith_table.lock().unwrap().update_with(&arith_table); // thread_controller.remove_working_thread(); // TODO! Implement prove drained_inputs (a chunk of operations) }); diff --git a/state-machines/arith/src/arith_operation.rs b/state-machines/arith/src/arith_operation.rs index 20ffde86..5333b6c1 100644 --- a/state-machines/arith/src/arith_operation.rs +++ b/state-machines/arith/src/arith_operation.rs @@ -74,9 +74,9 @@ impl fmt::Debug for ArithOperation { f, "range_ab: 0x{0:X} {1}, range_cd:0x{2:X} {3}\n", self.range_ab, - AirthRangeTableHelpers::get_range_name(self.range_ab), + ArithRangeTableHelpers::get_range_name(self.range_ab), self.range_cd, - AirthRangeTableHelpers::get_range_name(self.range_cd) + ArithRangeTableHelpers::get_range_name(self.range_cd) ) } } diff --git a/state-machines/arith/src/arith_operation_test.rs b/state-machines/arith/src/arith_operation_test.rs index c76d56f1..ebb8fb22 100644 --- a/state-machines/arith/src/arith_operation_test.rs +++ b/state-machines/arith/src/arith_operation_test.rs @@ -794,9 +794,6 @@ impl ArithOperationTest { let bus_b_low: u64 = aop.b[0] + CHUNK_SIZE * aop.b[1]; let bus_b_high: u64 = aop.b[2] + CHUNK_SIZE * aop.b[3]; - let res2_low: u64 = aop.d[0] + CHUNK_SIZE * aop.d[1]; - let res2_high: u64 = aop.d[2] + CHUNK_SIZE * aop.d[3]; - let secondary_res: u64 = if aop.main_mul || aop.main_div { 0 } else { 1 }; /* let bus_res_low: u64 = secondary_res * res2_low + (1 - secondary_res) diff --git a/state-machines/arith/src/arith_range_table_helpers.rs b/state-machines/arith/src/arith_range_table_helpers.rs index f02e4ae5..52016d7a 100644 --- a/state-machines/arith/src/arith_range_table_helpers.rs +++ b/state-machines/arith/src/arith_range_table_helpers.rs @@ -4,7 +4,7 @@ const ROWS: usize = 1 << 22; const FULL: u8 = 0x00; const POS: u8 = 0x01; const NEG: u8 = 0x02; -pub struct AirthRangeTableHelpers; +pub struct ArithRangeTableHelpers; const RANGES: [u8; 43] = [ FULL, FULL, FULL, POS, POS, POS, NEG, NEG, NEG, FULL, FULL, FULL, FULL, FULL, FULL, FULL, FULL, @@ -16,7 +16,7 @@ const OFFSETS: [usize; 43] = [ 64, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 56, 57, 58, 65, 66, 67, ]; -impl AirthRangeTableHelpers { +impl ArithRangeTableHelpers { pub fn get_range_name(range_index: u8) -> &'static str { match range_index { 0 => "F F F F", @@ -39,22 +39,22 @@ impl AirthRangeTableHelpers { _ => panic!("Invalid range index"), } } - pub fn get_row_chunk_range_check(range_index: u8, value: i64) -> usize { + pub fn get_row_chunk_range_check(range_index: u8, value: u64) -> usize { // F F F + + + - - - F F F F F F F F F + - F + - F + - F F F F F F F F F F F + + + - - - let range_type = RANGES[range_index as usize]; assert!(range_index < 43); - assert!(value >= if range_type == NEG { -0xFFFF } else { 0 }); + assert!(value >= if range_type == NEG { 0x8000 } else { 0 }); assert!( value <= match range_type { FULL => 0xFFFF, - POS => -1, - NEG => 0x7FFF, + POS => 0x7FFF, + NEG => 0xFFFF, _ => panic!("Invalid range type"), } ); OFFSETS[range_index as usize] * 0x8000 - + if range_type == NEG { 0x8000 + value } else { value } as usize + + if range_type == NEG { value - 0x8000 } else { value } as usize } pub fn get_row_carry_range_check(value: i64) -> usize { assert!(value >= -0xEFFFF); @@ -69,14 +69,19 @@ impl ArithRangeTableInputs { pub fn new() -> Self { ArithRangeTableInputs { multiplicity: [0; ROWS] } } - pub fn use_chunk_range_check(&self, range_id: u8, value: i64) { - let row = AirthRangeTableHelpers::get_row_chunk_range_check(range_id, value); + pub fn use_chunk_range_check(&self, range_id: u8, value: u64) { + let row = ArithRangeTableHelpers::get_row_chunk_range_check(range_id, value); self.multiplicity[row as usize]; } pub fn use_carry_range_check(&self, value: i64) { - let row = AirthRangeTableHelpers::get_row_carry_range_check(value); + let row = ArithRangeTableHelpers::get_row_carry_range_check(value); self.multiplicity[row as usize]; } + pub fn update_with(&mut self, other: &Self) { + for i in 0..ROWS { + self.multiplicity[i] += other.multiplicity[i]; + } + } } impl Add for ArithRangeTableInputs { @@ -91,7 +96,7 @@ impl Add for ArithRangeTableInputs { } } -#[cfg(generate_code_arith_range_table)] +#[cfg(feature = "generate_code_arith_range_table")] fn generate_table() { let pattern = "FFF+++---FFFFFFFFF+-F+-F+-FFFFFFFFFFF+++---"; // let mut ranges = [0u8; 43]; diff --git a/state-machines/arith/src/arith_table.rs b/state-machines/arith/src/arith_table.rs index 88fda306..f9d4d4b0 100644 --- a/state-machines/arith/src/arith_table.rs +++ b/state-machines/arith/src/arith_table.rs @@ -3,6 +3,7 @@ use std::sync::{ Arc, Mutex, }; +use crate::arith_constants::*; use p3_field::Field; use proofman::{WitnessComponent, WitnessManager}; use proofman_common::{ExecutionCtx, ProofCtx, SetupCtx}; @@ -51,8 +52,10 @@ impl ArithTableSM { } } pub fn operations() -> Vec { - // TODO: use constants - vec![0xb6, 0xb7, 0xbe, 0xbf] + vec![ + MULU, MULUH, MULSUH, MUL, MULH, MUL_W, DIVU, REMU, DIV, REM, DIVU_W, REMU_W, DIV_W, + REM_W, + ] } } diff --git a/state-machines/arith/src/arith_table_helpers.rs b/state-machines/arith/src/arith_table_helpers.rs index ee1683dc..76b06c31 100644 --- a/state-machines/arith/src/arith_table_helpers.rs +++ b/state-machines/arith/src/arith_table_helpers.rs @@ -3,10 +3,10 @@ use std::ops::Add; const ROWS: usize = 95; const FIRST_OP: u8 = 0xb0; -pub struct AirthTableHelpers; +pub struct ArithTableHelpers; -impl AirthTableHelpers { - pub fn get_row(op: u8, na: u64, nb: u64, np: u64, nr: u64, sext: u64) -> usize { +impl ArithTableHelpers { + pub fn get_row(op: u8, na: bool, nb: bool, np: bool, nr: bool, sext: bool) -> usize { static ARITH_TABLE_ROWS: [i16; 512] = [ 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, @@ -34,7 +34,12 @@ impl AirthTableHelpers { -1, 92, 93, 94, -1, ]; - let index = (op - FIRST_OP) as u64 * 32 + na + nb * 2 + np * 4 + nr * 8 + sext * 16; + let index = (op - FIRST_OP) as u64 * 32 + + na as u64 + + nb as u64 * 2 + + np as u64 * 4 + + nr as u64 * 8 + + sext as u64 * 16; let row = ARITH_TABLE_ROWS[index as usize]; assert!(row >= 0); row as usize @@ -52,10 +57,15 @@ impl ArithTableInputs { pub fn new() -> Self { ArithTableInputs { multiplicity: [0; ROWS] } } - pub fn add_use(&self, op: u8, na: u64, nb: u64, np: u64, nr: u64, sext: u64) { - let row = AirthTableHelpers::get_row(op, na, nb, np, nr, sext); + pub fn add_use(&self, op: u8, na: bool, nb: bool, np: bool, nr: bool, sext: bool) { + let row = ArithTableHelpers::get_row(op, na, nb, np, nr, sext); self.multiplicity[row as usize]; } + pub fn update_with(&mut self, other: &Self) { + for i in 0..ROWS { + self.multiplicity[i] += other.multiplicity[i]; + } + } } impl Add for ArithTableInputs { diff --git a/state-machines/arith/src/lib.rs b/state-machines/arith/src/lib.rs index 348c6ea5..018e78cb 100644 --- a/state-machines/arith/src/lib.rs +++ b/state-machines/arith/src/lib.rs @@ -2,12 +2,14 @@ mod arith; mod arith_constants; mod arith_full; mod arith_operation; -mod arith_operation_test; mod arith_range_table; mod arith_range_table_helpers; mod arith_table; mod arith_table_helpers; +#[cfg(test)] +mod arith_operation_test; + pub use arith::*; pub use arith_constants::*; pub use arith_full::*;