From 64aeeeb9d3fcf6198f2e13d3774de1efe3d66676 Mon Sep 17 00:00:00 2001 From: Wojciech Zmuda Date: Wed, 27 Mar 2024 22:44:03 +0100 Subject: [PATCH] prover: use Keccak256 implementation from Gnark In `keccak` package add a wrapper converting []frontend.Variable to []uints.U8 expected by NewLegacykeccak256. Use the wrapper in insertion and deletion circuits instead of our Keccak256 implementation. Remove the existing Keccak implementation. Add tests of the new implementation. Signed-off-by: Wojciech Zmuda --- prover/deletion_circuit.go | 5 +- prover/insertion_circuit.go | 5 +- prover/keccak/constants.go | 48 ----- prover/keccak/keccak.go | 345 +++-------------------------------- prover/keccak/keccak_test.go | 183 ++++++++----------- 5 files changed, 109 insertions(+), 477 deletions(-) delete mode 100644 prover/keccak/constants.go diff --git a/prover/deletion_circuit.go b/prover/deletion_circuit.go index b81209c..acee799 100644 --- a/prover/deletion_circuit.go +++ b/prover/deletion_circuit.go @@ -49,7 +49,10 @@ func (circuit *DeletionMbuCircuit) Define(api frontend.API) error { bits_post := abstractor.Call1(api, ToReducedBigEndian{Variable: circuit.PostRoot, Size: 256}) bits = append(bits, bits_post...) - hash := keccak.NewKeccak256(api, circuit.BatchSize*32+2*256, bits...) + hash, err := keccak.Keccak256(api, bits) + if err != nil { + return err + } sum := abstractor.Call(api, FromBinaryBigEndian{Variable: hash}) // The same endianness conversion has been performed in the hash generation diff --git a/prover/insertion_circuit.go b/prover/insertion_circuit.go index aaba8e4..eab909a 100644 --- a/prover/insertion_circuit.go +++ b/prover/insertion_circuit.go @@ -50,7 +50,10 @@ func (circuit *InsertionMbuCircuit) Define(api frontend.API) error { bits = append(bits, bits_id...) } - hash := keccak.NewKeccak256(api, (circuit.BatchSize+2)*256+32, bits...) + hash, err := keccak.Keccak256(api, bits) + if err != nil { + return err + } sum := abstractor.Call(api, FromBinaryBigEndian{Variable: hash}) // The same endianness conversion has been performed in the hash generation diff --git a/prover/keccak/constants.go b/prover/keccak/constants.go deleted file mode 100644 index b80559c..0000000 --- a/prover/keccak/constants.go +++ /dev/null @@ -1,48 +0,0 @@ -package keccak - -import ( - "github.com/consensys/gnark/frontend" -) - -var RC = [24][64]frontend.Variable{ - toBits(0x0000000000000001), - toBits(0x0000000000008082), - toBits(0x800000000000808A), - toBits(0x8000000080008000), - toBits(0x000000000000808B), - toBits(0x0000000080000001), - toBits(0x8000000080008081), - toBits(0x8000000000008009), - toBits(0x000000000000008A), - toBits(0x0000000000000088), - toBits(0x0000000080008009), - toBits(0x000000008000000A), - toBits(0x000000008000808B), - toBits(0x800000000000008B), - toBits(0x8000000000008089), - toBits(0x8000000000008003), - toBits(0x8000000000008002), - toBits(0x8000000000000080), - toBits(0x000000000000800A), - toBits(0x800000008000000A), - toBits(0x8000000080008081), - toBits(0x8000000000008080), - toBits(0x0000000080000001), - toBits(0x8000000080008008), -} - -var R = [5][5]int{ - {0, 36, 3, 41, 18}, - {1, 44, 10, 45, 2}, - {62, 6, 43, 15, 61}, - {28, 55, 25, 21, 56}, - {27, 20, 39, 8, 14}, -} - -func toBits(a uint64) [64]frontend.Variable { - var b [64]frontend.Variable - for i := 0; i < 64; i += 1 { - b[i] = (a >> i) & 1 - } - return b -} diff --git a/prover/keccak/keccak.go b/prover/keccak/keccak.go index 2434b96..3ad885a 100644 --- a/prover/keccak/keccak.go +++ b/prover/keccak/keccak.go @@ -1,337 +1,38 @@ package keccak import ( - "math" - "github.com/consensys/gnark/frontend" - "github.com/reilabs/gnark-lean-extractor/v3/abstractor" + "github.com/consensys/gnark/std/hash/sha3" + "github.com/consensys/gnark/std/math/uints" ) -// Implemention of the Keccak in gnark following the specification of the Keccak team -// https://keccak.team/keccak_specs_summary.html - -const laneSize = 64 -const stateSize = 5 -const blockSize = 1088 -const domainSeparatorSize = 8 - -func NewKeccak256(api frontend.API, inputSize int, data ...frontend.Variable) []frontend.Variable { - hash := abstractor.Call1(api, KeccakGadget{ - InputSize: inputSize, - InputData: data, - OutputSize: 256, - Rounds: 24, - BlockSize: blockSize, - RotationOffsets: R, - RoundConstants: RC, - Domain: 0x01, - }) - return hash -} - -func NewSHA3_256(api frontend.API, inputSize int, data ...frontend.Variable) []frontend.Variable { - hash := abstractor.Call1(api, KeccakGadget{ - InputSize: inputSize, - InputData: data, - OutputSize: 256, - Rounds: 24, - BlockSize: blockSize, - RotationOffsets: R, - RoundConstants: RC, - Domain: 0x06, - }) - return hash -} - -func allZeroes(v []frontend.Variable) bool { - for _, v := range v { - if v != 0 { - return false - } - } - return true -} - -type KeccakRound struct { - A [][][]frontend.Variable - RC [laneSize]frontend.Variable - RotationOffsets [5][5]int -} - -func (g KeccakRound) DefineGadget(api frontend.API) interface{} { - // C[x] = A[x,0] xor A[x,1] xor A[x,2] xor A[x,3] xor A[x,4], for x in 0…4 - C := make([][]frontend.Variable, stateSize) - for i := 0; i < int(stateSize); i++ { - C[i] = make([]frontend.Variable, laneSize) - } - - for x := 0; x < stateSize; x += 1 { - C[x] = abstractor.Call1(api, Xor5{g.A[x][0], g.A[x][1], g.A[x][2], g.A[x][3], g.A[x][4]}) - } - - // D[x] = C[x-1] xor rot(C[x+1],1), for x in 0…4 - D := make([][]frontend.Variable, stateSize) - for i := 0; i < int(stateSize); i++ { - D[i] = make([]frontend.Variable, laneSize) - } - for x := 0; x < stateSize; x += 1 { - tmp := abstractor.Call1(api, Rot{C[(x+1)%stateSize], 1}) - D[x] = abstractor.Call1(api, Xor{C[(x+4)%stateSize], tmp}) - } - - // A[x,y] = A[x,y] xor D[x], for x in 0…4 and y in 0…4 - for x := 0; x < stateSize; x += 1 { - for y := 0; y < stateSize; y += 1 { - g.A[x][y] = abstractor.Call1(api, Xor{g.A[x][y], D[x]}) - } - } - - // B[y,2*x+3*y] = rot(A[x,y], r[x,y]), for (x,y) in (0…4,0…4) - B := make([][][]frontend.Variable, stateSize) - for x := 0; x < int(stateSize); x++ { - B[x] = make([][]frontend.Variable, stateSize) - for y := 0; y < int(stateSize); y++ { - B[x][y] = make([]frontend.Variable, laneSize) - } - } - for x := 0; x < stateSize; x += 1 { - for y := 0; y < stateSize; y += 1 { - B[y][(2*x+3*y)%stateSize] = abstractor.Call1(api, Rot{g.A[x][y], g.RotationOffsets[x][y]}) - } - } - - // A[x,y] = B[x,y] xor ((not B[x+1,y]) and B[x+2,y]), for x in 0…4 and y in 0…4 - for x := 0; x < stateSize; x += 1 { - for y := 0; y < stateSize; y += 1 { - left := abstractor.Call1(api, Not{B[(x+1)%stateSize][y]}) - right := B[(x+2)%stateSize][y] - tmp := abstractor.Call1(api, And{left, right}) - g.A[x][y] = abstractor.Call1(api, Xor{B[x][y], tmp}) - } - } - - // A[0,0] = A[0,0] xor RC - g.A[0][0] = abstractor.Call1(api, Xor{g.A[0][0], g.RC[:]}) - - return g.A -} - -type KeccakF struct { - A [][][]frontend.Variable - Rounds int - RotationOffsets [5][5]int - RoundConstants [24][64]frontend.Variable -} - -func (g KeccakF) DefineGadget(api frontend.API) interface{} { - for i := 0; i < g.Rounds; i += 1 { - g.A = abstractor.Call3(api, KeccakRound{ - A: g.A, - RC: g.RoundConstants[i], - RotationOffsets: g.RotationOffsets, - }) - } - return g.A -} - -type KeccakGadget struct { - InputSize int - InputData []frontend.Variable - OutputSize int - Rounds int - BlockSize int - RotationOffsets [5][5]int - RoundConstants [24][64]frontend.Variable - Domain int -} - -func (g KeccakGadget) DefineGadget(api frontend.API) interface{} { - // Padding - paddedSize := int(math.Ceil(float64(g.InputSize+domainSeparatorSize)/float64(g.BlockSize))) * g.BlockSize - if len(g.InputData) == 0 { - paddedSize = g.BlockSize - } - - P := make([]frontend.Variable, paddedSize) - for i := 0; i < len(g.InputData); i += 1 { - P[i] = g.InputData[i] - } - - // write domain separator - for i := 0; i < domainSeparatorSize; i += 1 { - P[i+len(g.InputData)] = (g.Domain >> i) & 1 - } - - // fill with zero bytes - for i := len(g.InputData) + domainSeparatorSize; i < len(P); i += 1 { - P[i] = 0 - } - - tmp := make([]frontend.Variable, len(P)) - for i := 0; i < len(P)-1; i += 1 { - tmp[i] = 0 +func Keccak256(api frontend.API, data []frontend.Variable) (hash []frontend.Variable, err error) { + // Pad bits with frontend.Variable(0) until len(data) is a multiple of 8 + if len(data) % 8 != 0 { + padSize := 8 - (len(data) % 8) + data = append(data, make([]frontend.Variable, padSize)...) } - // set last byte to 0x80 - tmp[len(P)-1] = 1 - for i := 0; i < len(P); i += 1 { - if tmp[i] != 0 { - P[i] = api.Xor(P[i], tmp[i]) - } + // Convert bits to slice of uint8 + var input []uints.U8 + for i := 0; i < len(data); i += 8 { + byteSlice := data[i : i+8] + byteFv := api.FromBinary(byteSlice...) + byteU8 := uints.U8{Val: byteFv} + input = append(input, byteU8) } - // Initialization - S := make([][][]frontend.Variable, stateSize) - for x := 0; x < int(stateSize); x++ { - S[x] = make([][]frontend.Variable, stateSize) - for y := 0; y < int(stateSize); y++ { - S[x][y] = make([]frontend.Variable, laneSize) - } + h, err := sha3.NewLegacyKeccak256(api) + if err != nil { + return nil, err } + h.Write(input) - for i := 0; i < stateSize; i += 1 { - for j := 0; j < stateSize; j += 1 { - for k := 0; k < laneSize; k += 1 { - S[i][j][k] = 0 - } - } + // Convert slice of uint8 to one variable + for _, sumByte := range h.Sum() { + sumBits := api.ToBinary(sumByte.Val, 8) + hash = append(hash, sumBits...) } - // Absorbing phase - for i := 0; i < len(P); i += g.BlockSize { - for x := 0; x < stateSize; x += 1 { - for y := 0; y < stateSize; y += 1 { - if x+5*y < g.BlockSize/laneSize { - //var Pi [laneSize]frontend.Variable - Pi := make([]frontend.Variable, laneSize) - copy(Pi[:], P[i+(x+5*y)*laneSize:i+(x+5*y+1)*laneSize]) - if allZeroes(S[x][y]) { - S[x][y] = Pi - continue - } - if allZeroes(Pi) { - continue - } - S[x][y] = abstractor.Call1(api, Xor{S[x][y], Pi}) - } - } - } - S = abstractor.Call3(api, KeccakF{ - A: S, - Rounds: g.Rounds, - RotationOffsets: g.RotationOffsets, - RoundConstants: g.RoundConstants, - }) - } - - // Squeezing phase - var Z []frontend.Variable - i := 0 - for i < g.OutputSize { - for x := 0; x < stateSize; x += 1 { - for y := 0; y < stateSize; y += 1 { - if i < g.OutputSize && x+5*y < g.BlockSize/laneSize { - Z = append(Z, S[y][x][:]...) - i += laneSize - } - } - } - if i < g.OutputSize-laneSize { - S = abstractor.Call3(api, KeccakF{ - A: S, - Rounds: g.Rounds, - RotationOffsets: g.RotationOffsets, - RoundConstants: g.RoundConstants, - }) - } - } - - return Z -} - -/////////////////////////////////////////////////////////////////////////////////////////// -/// Helpers for various binary operations -/////////////////////////////////////////////////////////////////////////////////////////// - -type Xor5Round struct { - A frontend.Variable - B frontend.Variable - C frontend.Variable - D frontend.Variable - E frontend.Variable -} - -func (g Xor5Round) DefineGadget(api frontend.API) interface{} { - tmp_ab := api.Xor(g.A, g.B) - tmp_abc := api.Xor(g.C, tmp_ab) - tmp_abcd := api.Xor(g.D, tmp_abc) - xor := api.Xor(g.E, tmp_abcd) - return xor -} - -type Xor5 struct { - A []frontend.Variable - B []frontend.Variable - C []frontend.Variable - D []frontend.Variable - E []frontend.Variable -} - -func (g Xor5) DefineGadget(api frontend.API) interface{} { - var c [laneSize]frontend.Variable - for i := 0; i < len(g.A); i += 1 { - c[i] = abstractor.Call(api, Xor5Round{g.A[i], g.B[i], g.C[i], g.D[i], g.E[i]}) - } - return c[:] -} - -type Xor struct { - A []frontend.Variable - B []frontend.Variable -} - -func (g Xor) DefineGadget(api frontend.API) interface{} { - var c [laneSize]frontend.Variable - for i := 0; i < len(g.A); i += 1 { - c[i] = api.Xor(g.A[i], g.B[i]) - } - return c[:] -} - -type Rot struct { - A []frontend.Variable - R int -} - -func (g Rot) DefineGadget(api frontend.API) interface{} { - var c [laneSize]frontend.Variable - for i := 0; i < len(g.A); i += 1 { - c[i] = g.A[(i+(laneSize-g.R))%len(g.A)] - } - return c[:] -} - -type And struct { - A []frontend.Variable - B []frontend.Variable -} - -func (g And) DefineGadget(api frontend.API) interface{} { - var c [laneSize]frontend.Variable - for i := 0; i < len(g.A); i += 1 { - c[i] = api.And(g.A[i], g.B[i]) - } - return c[:] -} - -type Not struct { - A []frontend.Variable -} - -func (g Not) DefineGadget(api frontend.API) interface{} { - var c [laneSize]frontend.Variable - for i := 0; i < len(g.A); i += 1 { - c[i] = api.Sub(1, g.A[i]) - } - return c[:] + return hash, nil } diff --git a/prover/keccak/keccak_test.go b/prover/keccak/keccak_test.go index c326e73..1c9fe11 100644 --- a/prover/keccak/keccak_test.go +++ b/prover/keccak/keccak_test.go @@ -1,127 +1,100 @@ package keccak import ( - "math/big" + "encoding/hex" + "fmt" + "strings" "testing" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/test" ) -type TestKeccakCircuit1 struct { - Input [8]frontend.Variable `gnark:"input"` - Hash frontend.Variable `gnark:",public"` +type testCircuit struct { + In []frontend.Variable + Expected []frontend.Variable } -func (circuit *TestKeccakCircuit1) Define(api frontend.API) error { - hash := NewKeccak256(api, len(circuit.Input), circuit.Input[:]...) - sum := api.FromBinary(hash...) - api.AssertIsEqual(circuit.Hash, sum) - return nil -} - -type TestKeccakCircuit2 struct { - Input [0]frontend.Variable `gnark:"input"` - Hash frontend.Variable `gnark:",public"` -} - -func (circuit *TestKeccakCircuit2) Define(api frontend.API) error { - hash := NewKeccak256(api, 0) - sum := api.FromBinary(hash...) - api.AssertIsEqual(circuit.Hash, sum) - return nil -} - -type TestKeccakCircuitBlockSize struct { - Input [blockSize]frontend.Variable `gnark:"input"` - Hash frontend.Variable `gnark:",public"` -} - -func (circuit *TestKeccakCircuitBlockSize) Define(api frontend.API) error { - hash := NewKeccak256(api, len(circuit.Input), circuit.Input[:]...) - sum := api.FromBinary(hash...) - api.AssertIsEqual(circuit.Hash, sum) - return nil -} - -type TestSHACircuit struct { - Input [0]frontend.Variable `gnark:"input"` - Hash frontend.Variable `gnark:",public"` -} +func (c *testCircuit) Define(api frontend.API) error { + sum, err := Keccak256(api, c.In) + if err != nil { + return err + } -func (circuit *TestSHACircuit) Define(api frontend.API) error { - hash := NewSHA3_256(api, 0) - sum := api.FromBinary(hash...) - api.AssertIsEqual(circuit.Hash, sum) + for i := range c.Expected { + api.AssertIsEqual(c.Expected[i], sum[i]) + } return nil } -func TestKeccak(t *testing.T) { - assert := test.NewAssert(t) - - // Keccak: hash zero byte - var circuit1 TestKeccakCircuit1 - assert.ProverSucceeded(&circuit1, &TestKeccakCircuit1{ - Input: [8]frontend.Variable{0, 0, 0, 0, 0, 0, 0, 0}, - Hash: bigIntLE("0xbc36789e7a1e281436464229828f817d6612f7b477d66591ff96a9e064bcc98a"), - }, test.WithBackends(backend.GROTH16), test.WithCurves(ecc.BN254)) - - // Keccak: hash empty input - var circuit2 TestKeccakCircuit2 - assert.ProverSucceeded(&circuit2, &TestKeccakCircuit2{ - Input: [0]frontend.Variable{}, - Hash: bigIntLE("0xc5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470"), - }, test.WithBackends(backend.GROTH16), test.WithCurves(ecc.BN254)) - - // Keccak: input equal block size - var circuit3 TestKeccakCircuitBlockSize - var inputArray [1088]frontend.Variable - fillArray(1, inputArray[:]) - assert.ProverSucceeded( - &circuit3, &TestKeccakCircuitBlockSize{ - Input: inputArray, - Hash: bigIntLE("0x2d417340362cd4144efbf52adc1bfb7a4b40254f55f3b0f09efa6a1ef299b51a"), - }, test.WithBackends(backend.GROTH16), test.WithCurves(ecc.BN254), - ) - - // SHA3: hash empty input - var circuit4 TestSHACircuit - assert.ProverSucceeded(&circuit4, &TestSHACircuit{ - Input: [0]frontend.Variable{}, - Hash: bigIntLE("0xa7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a"), - }, test.WithBackends(backend.GROTH16), test.WithCurves(ecc.BN254)) - -} - -// we need to feed in the hash in little endian -func bigIntLE(s string) big.Int { - var bi big.Int - bi.SetString(s, 0) - - b := bi.Bytes() - for i := 0; i < len(b)/2; i++ { - b[i], b[len(b)-i-1] = b[len(b)-i-1], b[i] +func TestKeccak256(t *testing.T) { + // Helper function to convert a hex string to []frontend.Variable, 1 bit per element + hexToBits := func(hexStr string) ([]frontend.Variable, error) { + bytes, err := hex.DecodeString(hexStr) + if err != nil { + return nil, err + } + vars := make([]frontend.Variable, len(bytes)*8) + for i, b := range bytes { + for j := 0; j < 8; j++ { + vars[i*8+j] = frontend.Variable((b >> j) & 1) + } + } + return vars, nil } - bi.SetBytes(b) - - // Reduce the number by BN254 group order, because the circuit does the same - modulus, ok := new(big.Int).SetString( - "21888242871839275222246405745257275088548364400416034343698204186575808495617", 10, - ) - if !ok { - panic("can't set big int to BN254 group order") + // Helper function to generate hex string in the format of b repeated count times + var repeatHex = func(b byte, count int) string { + hexStr := fmt.Sprintf("%02x", b) + return strings.Repeat(hexStr, count) } - bi.Mod(&bi, modulus) - return bi -} + // Table driven test cases + testCases := []struct { + input string + expected string + }{ + // Test vectors from https://bob.nem.ninja/test-vectors.html + {"", "C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470"}, + {"CC", "EEAD6DBFC7340A56CAEDC044696A168870549A6A7F6F56961E84A54BD9970B8A"}, + {"41FB", "A8EACEDA4D47B3281A795AD9E1EA2122B407BAF9AABCB9E18B5717B7873537D2"}, + {"1F877C", "627D7BC1491B2AB127282827B8DE2D276B13D7D70FB4C5957FDF20655BC7AC30"}, + {"C1ECFDFC", "B149E766D7612EAF7D55F74E1A4FDD63709A8115B14F61FCD22AA4ABC8B8E122"}, + {"9F2FCC7C90DE090D6B87CD7E9718C1EA6CB21118FC2D5DE9F97E5DB6AC1E9C10", + "24DD2EE02482144F539F810D2CAA8A7B75D0FA33657E47932122D273C3F6F6D1"}, + + // Other test vectors verified against https://emn178.github.io/online-tools/keccak_256.html + {"00", "bc36789e7a1e281436464229828f817d6612f7b477d66591ff96a9e064bcc98a"}, + {repeatHex(0x00, 8), "011b4d03dd8c01f1049143cf9c4c817e4b167f1d1b83e5c6f0f10d89ba1e7bce"}, + {repeatHex(0xaa, 50), "04b992b0fda7cc35cb6c2ae5423b463e8f519efd70d8bab8394c1cd42839c2e2"}, + {repeatHex(0xfd, 640), "e1eb3e4b14a80a72d3decb952300b2efe0616341e0e55d98f60669873eb43d4d"}, + {repeatHex(0xcb, 1088), "5d89305ddc9240e623acebd3050d80102a35e7be3023314aff13bb0be19c0653"}, + } -// Fill an array with a specific value. -func fillArray(value int, inputArray []frontend.Variable) { - for i := range inputArray { - inputArray[i] = value + for _, tc := range testCases { + in, err := hexToBits(tc.input) + if err != nil { + t.Fatal(err) + } + + expected, err := hexToBits(tc.expected) + if err != nil { + t.Fatal(err) + } + + circuit := &testCircuit{ + In: in, + Expected: expected, + } + + witness := &testCircuit{ + In: in, + Expected: expected, + } + + if err := test.IsSolved(circuit, witness, ecc.BN254.ScalarField()); err != nil { + t.Fatalf("Test case with input '%s' failed: %s", tc.input, err) + } } -} +} \ No newline at end of file