From 1cdb01750ee2fea1443bc43a6172b22c0ee55945 Mon Sep 17 00:00:00 2001 From: fmoletta <99273364+fmoletta@users.noreply.github.com> Date: Wed, 20 Sep 2023 23:27:02 +0300 Subject: [PATCH] Implement `UNSAFE_KECCAK_FINALIZE` hint (#271) * Implement UnsafeKeccak * Update dependencies * Add unit test * Fix hash * Add unit tests * Add integration test * Add missing file * Add constant + GetStructFieldRelocatable + start hint * Add test file * Add MemorySegmentManager.GetFeltRange * Progress * Finish hint * Add unit test * Add integration test --- cairo_programs/unsafe_keccak_finalize.cairo | 28 +++++++++++++ pkg/hints/hint_codes/keccak_hint_codes.go | 2 + pkg/hints/hint_processor.go | 2 + pkg/hints/hint_utils/ids_manager.go | 29 +++++++++++++ pkg/hints/keccak_hints.go | 46 +++++++++++++++++++++ pkg/hints/keccak_hints_test.go | 43 +++++++++++++++++++ pkg/vm/cairo_run/cairo_run_test.go | 10 ++--- 7 files changed, 155 insertions(+), 5 deletions(-) create mode 100644 cairo_programs/unsafe_keccak_finalize.cairo diff --git a/cairo_programs/unsafe_keccak_finalize.cairo b/cairo_programs/unsafe_keccak_finalize.cairo new file mode 100644 index 00000000..6cf89f88 --- /dev/null +++ b/cairo_programs/unsafe_keccak_finalize.cairo @@ -0,0 +1,28 @@ +%builtins output + +from starkware.cairo.common.alloc import alloc +from starkware.cairo.common.serialize import serialize_word +from starkware.cairo.common.keccak import unsafe_keccak_finalize, KeccakState +from starkware.cairo.common.uint256 import Uint256 + +func main{output_ptr: felt*}() { + alloc_locals; + + let (data: felt*) = alloc(); + + assert data[0] = 0; + assert data[1] = 1; + assert data[2] = 2; + + let keccak_state = KeccakState(start_ptr=data, end_ptr=data + 2); + + let res: Uint256 = unsafe_keccak_finalize(keccak_state); + + assert res.low = 17219183504112405672555532996650339574; + assert res.high = 235346966651632113557018504892503714354; + + serialize_word(res.low); + serialize_word(res.high); + + return (); +} diff --git a/pkg/hints/hint_codes/keccak_hint_codes.go b/pkg/hints/hint_codes/keccak_hint_codes.go index 2db053ca..c9cc1ea1 100644 --- a/pkg/hints/hint_codes/keccak_hint_codes.go +++ b/pkg/hints/hint_codes/keccak_hint_codes.go @@ -1,3 +1,5 @@ package hint_codes const UNSAFE_KECCAK = "from eth_hash.auto import keccak\n\ndata, length = ids.data, ids.length\n\nif '__keccak_max_size' in globals():\n assert length <= __keccak_max_size, \\\n f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\\n f'Got: length={length}.'\n\nkeccak_input = bytearray()\nfor word_i, byte_i in enumerate(range(0, length, 16)):\n word = memory[data + word_i]\n n_bytes = min(16, length - byte_i)\n assert 0 <= word < 2 ** (8 * n_bytes)\n keccak_input += word.to_bytes(n_bytes, 'big')\n\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')" + +const UNSAFE_KECCAK_FINALIZE = "from eth_hash.auto import keccak\nkeccak_input = bytearray()\nn_elms = ids.keccak_state.end_ptr - ids.keccak_state.start_ptr\nfor word in memory.get_range(ids.keccak_state.start_ptr, n_elms):\n keccak_input += word.to_bytes(16, 'big')\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')" diff --git a/pkg/hints/hint_processor.go b/pkg/hints/hint_processor.go index b571bc83..75cbc952 100644 --- a/pkg/hints/hint_processor.go +++ b/pkg/hints/hint_processor.go @@ -108,6 +108,8 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any, return vm_enter_scope(execScopes) case UNSAFE_KECCAK: return unsafeKeccak(data.Ids, vm, *execScopes) + case UNSAFE_KECCAK_FINALIZE: + return unsafeKeccakFinalize(data.Ids, vm) case UNSIGNED_DIV_REM: return unsignedDivRem(data.Ids, vm) case SIGNED_DIV_REM: diff --git a/pkg/hints/hint_utils/ids_manager.go b/pkg/hints/hint_utils/ids_manager.go index adfb163e..ef2550b5 100644 --- a/pkg/hints/hint_utils/ids_manager.go +++ b/pkg/hints/hint_utils/ids_manager.go @@ -166,6 +166,35 @@ func (ids *IdsManager) GetStructFieldFelt(name string, field_off uint, vm *Virtu return lambdaworks.Felt{}, ErrUnknownIdentifier(name) } +/* + Returns the value of an ids' field (given that the identifier is a sruct) as a Relocatable + For example: + + struct shelter { + cats cat* + dogs dog* + } + + to access each struct field, cats will be field 0 and dogs will be field 1, so to access them we can use: + ids_cats := ids.GetStructFieldFelt("shelter", 0, vm) or ids_cats := ids.Get("shelter", vm) + ids_dogs := ids.GetStructFieldFelt("shelter", 1, vm) +*/ +func (ids *IdsManager) GetStructFieldRelocatable(name string, field_off uint, vm *VirtualMachine) (Relocatable, error) { + reference, ok := ids.References[name] + if ok { + val, ok := getStructFieldFromReference(&reference, field_off, ids.HintApTracking, vm) + if ok { + rel, is_rel := val.GetRelocatable() + if !is_rel { + return Relocatable{}, errors.Errorf("Identifier %s is not a Relocatable", name) + } + return rel, nil + } + } + + return Relocatable{}, ErrUnknownIdentifier(name) +} + /* Inserts value into an ids' field (given that the identifier is a sruct) For example: diff --git a/pkg/hints/keccak_hints.go b/pkg/hints/keccak_hints.go index 615e5d06..969ccdba 100644 --- a/pkg/hints/keccak_hints.go +++ b/pkg/hints/keccak_hints.go @@ -71,3 +71,49 @@ func unsafeKeccak(ids IdsManager, vm *VirtualMachine, scopes ExecutionScopes) er } return ids.Insert("low", NewMaybeRelocatableFelt(low), vm) } + +func unsafeKeccakFinalize(ids IdsManager, vm *VirtualMachine) error { + // Fetch ids variables + startPtr, err := ids.GetStructFieldRelocatable("keccak_state", 0, vm) + if err != nil { + return err + } + endPtr, err := ids.GetStructFieldRelocatable("keccak_state", 1, vm) + if err != nil { + return err + } + + // Hint Logic + nElemsFelt, err := endPtr.Sub(startPtr) + if err != nil { + return err + } + nElems, err := nElemsFelt.ToU64() + if err != nil { + return err + } + inputFelts, err := vm.Segments.GetFeltRange(startPtr, uint(nElems)) + if err != nil { + return err + } + inputBytes := make([]byte, 0, 16*nElems) + for i := 0; i < int(nElems); i++ { + inputBytes = append(inputBytes, inputFelts[i].ToBeBytes()[16:]...) + } + + hasher := keccak.New256() + hasher.Write(inputBytes) + resBytes := hasher.Sum(nil) + + highBytes := append([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, resBytes[:16]...) + lowBytes := append([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, resBytes[16:32]...) + + high := FeltFromBeBytes((*[32]byte)(highBytes)) + low := FeltFromBeBytes((*[32]byte)(lowBytes)) + + err = ids.Insert("high", NewMaybeRelocatableFelt(high), vm) + if err != nil { + return err + } + return ids.Insert("low", NewMaybeRelocatableFelt(low), vm) +} diff --git a/pkg/hints/keccak_hints_test.go b/pkg/hints/keccak_hints_test.go index 020412b9..6f20d8e4 100644 --- a/pkg/hints/keccak_hints_test.go +++ b/pkg/hints/keccak_hints_test.go @@ -123,3 +123,46 @@ func TestUnsafeKeccakInvalidWordSize(t *testing.T) { t.Errorf("UNSAFE_KECCAK hint test should have failed") } } + +func TestUnsafeKeccakFinalizeOk(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + inputStart := vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "keccak_state": { + NewMaybeRelocatableRelocatable(inputStart), + NewMaybeRelocatableRelocatable(inputStart.AddUint(2)), + }, + "high": {nil}, + "low": {nil}, + }, + vm, + ) + // Insert keccak input into memory + input := []MaybeRelocatable{ + *NewMaybeRelocatableFelt(FeltZero()), + *NewMaybeRelocatableFelt(FeltOne()), + } + vm.Segments.LoadData(inputStart, &input) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: UNSAFE_KECCAK_FINALIZE, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, nil) + if err != nil { + t.Errorf("UNSAFE_KECCAK_FINALIZE hint test failed with error %s", err) + } + // Check ids values + high, err := idsManager.GetFelt("high", vm) + expectedHigh := FeltFromDecString("235346966651632113557018504892503714354") + if err != nil || high != expectedHigh { + t.Errorf("Wrong/No ids.high.\n Expected %s, got %s.", expectedHigh.ToHexString(), high.ToHexString()) + } + low, err := idsManager.GetFelt("low", vm) + expectedLow := FeltFromDecString("17219183504112405672555532996650339574") + if err != nil || low != expectedLow { + t.Errorf("Wrong/No ids.low\n Expected %s, got %s.", expectedLow.ToHexString(), low.ToHexString()) + } +} diff --git a/pkg/vm/cairo_run/cairo_run_test.go b/pkg/vm/cairo_run/cairo_run_test.go index 0de36ae1..aefaa7d0 100644 --- a/pkg/vm/cairo_run/cairo_run_test.go +++ b/pkg/vm/cairo_run/cairo_run_test.go @@ -132,11 +132,11 @@ func TestSqrtHint(t *testing.T) { } func TestUnsafeKeccak(t *testing.T) { - cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, Layout: "all_cairo", ProofMode: false} - _, err := cairo_run.CairoRun("../../../cairo_programs/unsafe_keccak.json", cairoRunConfig) - if err != nil { - t.Errorf("Program execution failed with error: %s", err) - } + testProgram("unsafe_keccak", t) +} + +func TestUnsafeKeccakFinalize(t *testing.T) { + testProgram("unsafe_keccak_finalize", t) } func TestUnsignedDivRemHint(t *testing.T) {