Skip to content

Commit

Permalink
Implement UNSAFE_KECCAK_FINALIZE hint (#271)
Browse files Browse the repository at this point in the history
* Implement UnsafeKeccak

* Update dependencies

* Add unit test

* Fix hash

* Add unit tests

* Add integration test

* Add missing file

* Add constant + GetStructFieldRelocatable + start hint

* Add test file

* Add MemorySegmentManager.GetFeltRange

* Progress

* Finish hint

* Add unit test

* Add integration test
  • Loading branch information
fmoletta authored Sep 20, 2023
1 parent 5cfe5fe commit 1cdb017
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 5 deletions.
28 changes: 28 additions & 0 deletions cairo_programs/unsafe_keccak_finalize.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
%builtins output

from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.serialize import serialize_word
from starkware.cairo.common.keccak import unsafe_keccak_finalize, KeccakState
from starkware.cairo.common.uint256 import Uint256

func main{output_ptr: felt*}() {
alloc_locals;

let (data: felt*) = alloc();

assert data[0] = 0;
assert data[1] = 1;
assert data[2] = 2;

let keccak_state = KeccakState(start_ptr=data, end_ptr=data + 2);

let res: Uint256 = unsafe_keccak_finalize(keccak_state);

assert res.low = 17219183504112405672555532996650339574;
assert res.high = 235346966651632113557018504892503714354;

serialize_word(res.low);
serialize_word(res.high);

return ();
}
2 changes: 2 additions & 0 deletions pkg/hints/hint_codes/keccak_hint_codes.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
package hint_codes

const UNSAFE_KECCAK = "from eth_hash.auto import keccak\n\ndata, length = ids.data, ids.length\n\nif '__keccak_max_size' in globals():\n assert length <= __keccak_max_size, \\\n f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\\n f'Got: length={length}.'\n\nkeccak_input = bytearray()\nfor word_i, byte_i in enumerate(range(0, length, 16)):\n word = memory[data + word_i]\n n_bytes = min(16, length - byte_i)\n assert 0 <= word < 2 ** (8 * n_bytes)\n keccak_input += word.to_bytes(n_bytes, 'big')\n\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"

const UNSAFE_KECCAK_FINALIZE = "from eth_hash.auto import keccak\nkeccak_input = bytearray()\nn_elms = ids.keccak_state.end_ptr - ids.keccak_state.start_ptr\nfor word in memory.get_range(ids.keccak_state.start_ptr, n_elms):\n keccak_input += word.to_bytes(16, 'big')\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"
2 changes: 2 additions & 0 deletions pkg/hints/hint_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any,
return vm_enter_scope(execScopes)
case UNSAFE_KECCAK:
return unsafeKeccak(data.Ids, vm, *execScopes)
case UNSAFE_KECCAK_FINALIZE:
return unsafeKeccakFinalize(data.Ids, vm)
case UNSIGNED_DIV_REM:
return unsignedDivRem(data.Ids, vm)
case SIGNED_DIV_REM:
Expand Down
29 changes: 29 additions & 0 deletions pkg/hints/hint_utils/ids_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,35 @@ func (ids *IdsManager) GetStructFieldFelt(name string, field_off uint, vm *Virtu
return lambdaworks.Felt{}, ErrUnknownIdentifier(name)
}

/*
Returns the value of an ids' field (given that the identifier is a sruct) as a Relocatable
For example:
struct shelter {
cats cat*
dogs dog*
}
to access each struct field, cats will be field 0 and dogs will be field 1, so to access them we can use:
ids_cats := ids.GetStructFieldFelt("shelter", 0, vm) or ids_cats := ids.Get("shelter", vm)
ids_dogs := ids.GetStructFieldFelt("shelter", 1, vm)
*/
func (ids *IdsManager) GetStructFieldRelocatable(name string, field_off uint, vm *VirtualMachine) (Relocatable, error) {
reference, ok := ids.References[name]
if ok {
val, ok := getStructFieldFromReference(&reference, field_off, ids.HintApTracking, vm)
if ok {
rel, is_rel := val.GetRelocatable()
if !is_rel {
return Relocatable{}, errors.Errorf("Identifier %s is not a Relocatable", name)
}
return rel, nil
}
}

return Relocatable{}, ErrUnknownIdentifier(name)
}

/*
Inserts value into an ids' field (given that the identifier is a sruct)
For example:
Expand Down
46 changes: 46 additions & 0 deletions pkg/hints/keccak_hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,49 @@ func unsafeKeccak(ids IdsManager, vm *VirtualMachine, scopes ExecutionScopes) er
}
return ids.Insert("low", NewMaybeRelocatableFelt(low), vm)
}

func unsafeKeccakFinalize(ids IdsManager, vm *VirtualMachine) error {
// Fetch ids variables
startPtr, err := ids.GetStructFieldRelocatable("keccak_state", 0, vm)
if err != nil {
return err
}
endPtr, err := ids.GetStructFieldRelocatable("keccak_state", 1, vm)
if err != nil {
return err
}

// Hint Logic
nElemsFelt, err := endPtr.Sub(startPtr)
if err != nil {
return err
}
nElems, err := nElemsFelt.ToU64()
if err != nil {
return err
}
inputFelts, err := vm.Segments.GetFeltRange(startPtr, uint(nElems))
if err != nil {
return err
}
inputBytes := make([]byte, 0, 16*nElems)
for i := 0; i < int(nElems); i++ {
inputBytes = append(inputBytes, inputFelts[i].ToBeBytes()[16:]...)
}

hasher := keccak.New256()
hasher.Write(inputBytes)
resBytes := hasher.Sum(nil)

highBytes := append([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, resBytes[:16]...)
lowBytes := append([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, resBytes[16:32]...)

high := FeltFromBeBytes((*[32]byte)(highBytes))
low := FeltFromBeBytes((*[32]byte)(lowBytes))

err = ids.Insert("high", NewMaybeRelocatableFelt(high), vm)
if err != nil {
return err
}
return ids.Insert("low", NewMaybeRelocatableFelt(low), vm)
}
43 changes: 43 additions & 0 deletions pkg/hints/keccak_hints_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,46 @@ func TestUnsafeKeccakInvalidWordSize(t *testing.T) {
t.Errorf("UNSAFE_KECCAK hint test should have failed")
}
}

func TestUnsafeKeccakFinalizeOk(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
inputStart := vm.Segments.AddSegment()
idsManager := SetupIdsForTest(
map[string][]*MaybeRelocatable{
"keccak_state": {
NewMaybeRelocatableRelocatable(inputStart),
NewMaybeRelocatableRelocatable(inputStart.AddUint(2)),
},
"high": {nil},
"low": {nil},
},
vm,
)
// Insert keccak input into memory
input := []MaybeRelocatable{
*NewMaybeRelocatableFelt(FeltZero()),
*NewMaybeRelocatableFelt(FeltOne()),
}
vm.Segments.LoadData(inputStart, &input)
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Ids: idsManager,
Code: UNSAFE_KECCAK_FINALIZE,
})
err := hintProcessor.ExecuteHint(vm, &hintData, nil, nil)
if err != nil {
t.Errorf("UNSAFE_KECCAK_FINALIZE hint test failed with error %s", err)
}
// Check ids values
high, err := idsManager.GetFelt("high", vm)
expectedHigh := FeltFromDecString("235346966651632113557018504892503714354")
if err != nil || high != expectedHigh {
t.Errorf("Wrong/No ids.high.\n Expected %s, got %s.", expectedHigh.ToHexString(), high.ToHexString())
}
low, err := idsManager.GetFelt("low", vm)
expectedLow := FeltFromDecString("17219183504112405672555532996650339574")
if err != nil || low != expectedLow {
t.Errorf("Wrong/No ids.low\n Expected %s, got %s.", expectedLow.ToHexString(), low.ToHexString())
}
}
10 changes: 5 additions & 5 deletions pkg/vm/cairo_run/cairo_run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,11 @@ func TestSqrtHint(t *testing.T) {
}

func TestUnsafeKeccak(t *testing.T) {
cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, Layout: "all_cairo", ProofMode: false}
_, err := cairo_run.CairoRun("../../../cairo_programs/unsafe_keccak.json", cairoRunConfig)
if err != nil {
t.Errorf("Program execution failed with error: %s", err)
}
testProgram("unsafe_keccak", t)
}

func TestUnsafeKeccakFinalize(t *testing.T) {
testProgram("unsafe_keccak_finalize", t)
}

func TestUnsignedDivRemHint(t *testing.T) {
Expand Down

0 comments on commit 1cdb017

Please sign in to comment.