Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement DICT_SQUASH_COPY_DICT, DICT_SQUASH_UPDATE_PTR & DICT_NEW #231

Merged
merged 10 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions cairo_programs/dict_squash.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
%builtins range_check

from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.dict import dict_write, dict_update, dict_squash
from starkware.cairo.common.default_dict import default_dict_new

func main{range_check_ptr}() -> () {
let (dict_start) = default_dict_new(17);
let dict_end = dict_start;
dict_write{dict_ptr=dict_end}(0, 1);
dict_write{dict_ptr=dict_end}(1, 10);
dict_write{dict_ptr=dict_end}(2, -2);
dict_update{dict_ptr=dict_end}(0, 1, 2);
dict_update{dict_ptr=dict_end}(0, 2, 3);
dict_update{dict_ptr=dict_end}(0, 3, 4);
dict_update{dict_ptr=dict_end}(1, 10, 15);
dict_update{dict_ptr=dict_end}(1, 15, 20);
dict_update{dict_ptr=dict_end}(1, 20, 25);
dict_update{dict_ptr=dict_end}(2, -2, -4);
dict_update{dict_ptr=dict_end}(2, -4, -8);
dict_update{dict_ptr=dict_end}(2, -8, -16);
let (squashed_dict_start, squashed_dict_end) = dict_squash{range_check_ptr=range_check_ptr}(
dict_start, dict_end
);
assert squashed_dict_end[0] = DictAccess(key=0, prev_value=1, new_value=4);
assert squashed_dict_end[1] = DictAccess(key=1, prev_value=10, new_value=25);
assert squashed_dict_end[2] = DictAccess(key=2, prev_value=-2, new_value=-16);
return ();
}
6 changes: 6 additions & 0 deletions pkg/hints/dict_hint_codes.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@ const SQUASH_DICT_INNER_LEN_ASSERT = "assert len(current_access_indices) == 0"
const SQUASH_DICT_INNER_USED_ACCESSES_ASSERT = "assert ids.n_used_accesses == len(access_indices[key])"

const SQUASH_DICT_INNER_NEXT_KEY = "assert len(keys) > 0, 'No keys left but remaining_accesses > 0.'\nids.next_key = key = keys.pop()"

const DICT_SQUASH_COPY_DICT = "# Prepare arguments for dict_new. In particular, the same dictionary values should be copied\n# to the new (squashed) dictionary.\nvm_enter_scope({\n # Make __dict_manager accessible.\n '__dict_manager': __dict_manager,\n # Create a copy of the dict, in case it changes in the future.\n 'initial_dict': dict(__dict_manager.get_dict(ids.dict_accesses_end)),\n})"

const DICT_SQUASH_UPDATE_PTR = "# Update the DictTracker's current_ptr to point to the end of the squashed dict.\n__dict_manager.get_tracker(ids.squashed_dict_start).current_ptr = \\\n ids.squashed_dict_end.address_"

const DICT_NEW = "if '__dict_manager' not in globals():\n from starkware.cairo.common.dict import DictManager\n __dict_manager = DictManager()\n\nmemory[ap] = __dict_manager.new_dict(segments, initial_dict)\ndel initial_dict"
67 changes: 67 additions & 0 deletions pkg/hints/dict_hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,70 @@ func dictUpdate(ids IdsManager, scopes *ExecutionScopes, vm *VirtualMachine) err
tracker.CurrentPtr.Offset += DICT_ACCESS_SIZE
return nil
}

func dictSquashCopyDict(ids IdsManager, scopes *ExecutionScopes, vm *VirtualMachine) error {
// Extract Variables
dictManager, ok := FetchDictManager(scopes)
if !ok {
return errors.New("Variable __dict_manager not present in current execution scope")
}
dictAccessEnd, err := ids.GetRelocatable("dict_accesses_end", vm)
if err != nil {
return err
}
// Hint logic
tracker, err := dictManager.GetTracker(dictAccessEnd)
if err != nil {
return err
}
initialDict := tracker.CopyDictionary()
scopes.EnterScope(map[string]interface{}{
"__dict_manager": dictManager,
"initial_dict": initialDict,
})
return nil
}

func dictSquashUpdatePtr(ids IdsManager, scopes *ExecutionScopes, vm *VirtualMachine) error {
// Extract Variables
dictManager, ok := FetchDictManager(scopes)
if !ok {
return errors.New("Variable __dict_manager not present in current execution scope")
}
squashedDictStart, err := ids.GetRelocatable("squashed_dict_start", vm)
if err != nil {
return err
}
squashedDictEnd, err := ids.GetRelocatable("squashed_dict_end", vm)
if err != nil {
return err
}
// Hint logic
tracker, err := dictManager.GetTracker(squashedDictStart)
if err != nil {
return err
}
tracker.CurrentPtr = squashedDictEnd
return nil
}

func dictNew(ids IdsManager, scopes *ExecutionScopes, vm *VirtualMachine) error {
// Fetch scope variables
initialDictAny, err := scopes.Get("initial_dict")
if err != nil {
return err
}
initialDict, ok := initialDictAny.(map[memory.MaybeRelocatable]memory.MaybeRelocatable)
if !ok {
return errors.New("initial_dict not in scope")
}
// Hint Logic
dictManager, ok := FetchDictManager(scopes)
if !ok {
newDictManager := NewDictManager()
dictManager = &newDictManager
scopes.AssignOrUpdateVariable("__dict_manager", dictManager)
}
dict_ptr := dictManager.NewDictionary(&initialDict, vm)
return vm.Segments.Memory.Insert(vm.RunContext.Ap, memory.NewMaybeRelocatableRelocatable(dict_ptr))
}
237 changes: 236 additions & 1 deletion pkg/hints/dict_hints_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package hints_test

import (
"reflect"
"testing"

. "github.com/lambdaclass/cairo-vm.go/pkg/hints"
Expand Down Expand Up @@ -70,7 +71,7 @@ func TestDefaultDictNewHasManager(t *testing.T) {
if err != nil {
t.Errorf("DEFAULT_DICT_NEW hint test failed with error %s", err)
}
// Check that the manager wan't replaced by a new one
// Check that the manager wasn't replaced by a new one
dictManagerPtr, ok := FetchDictManager(scopes)
if !ok || dictManagerPtr != dictManagerRef {
t.Error("DEFAULT_DICT_NEW DictManager replaced")
Expand Down Expand Up @@ -413,3 +414,237 @@ func TestDictUpdateErr(t *testing.T) {
t.Error("DICT_UPDATE hint test should have failed")
}
}

func TestDictSqushCopyDictOkEmptyDict(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
scopes := types.NewExecutionScopes()

// Create dictManager & add it to scope
dictManager := dict_manager.NewDictManager()
dictManagerRef := &dictManager
initialDict := map[MaybeRelocatable]MaybeRelocatable{}
dict_ptr := dictManager.NewDictionary(&initialDict, vm)
scopes.AssignOrUpdateVariable("__dict_manager", dictManagerRef)

idsManager := SetupIdsForTest(
map[string][]*MaybeRelocatable{
"dict_accesses_end": {NewMaybeRelocatableRelocatable(dict_ptr)},
},
vm,
)
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Ids: idsManager,
Code: DICT_SQUASH_COPY_DICT,
})
err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes)
if err != nil {
t.Errorf("DICT_SQUASH_COPY_DICT hint test failed with error %s", err)
}
// Check new scope
new_scope, _ := scopes.GetLocalVariables()
if !reflect.DeepEqual(new_scope, map[string]interface{}{
"__dict_manager": dictManagerRef,
"initial_dict": initialDict,
}) {
t.Errorf("DICT_SQUASH_COPY_DICT hint test wrong new sope created")
}
}

func TestDictSqushCopyDictOkNonEmptyDict(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
scopes := types.NewExecutionScopes()

// Create dictManager & add it to scope
dictManager := dict_manager.NewDictManager()
dictManagerRef := &dictManager
initialDict := map[MaybeRelocatable]MaybeRelocatable{
*NewMaybeRelocatableFelt(FeltZero()): *NewMaybeRelocatableFelt(FeltOne()),
*NewMaybeRelocatableFelt(FeltOne()): *NewMaybeRelocatableFelt(FeltOne()),
}
dict_ptr := dictManager.NewDictionary(&initialDict, vm)
scopes.AssignOrUpdateVariable("__dict_manager", dictManagerRef)

idsManager := SetupIdsForTest(
map[string][]*MaybeRelocatable{
"dict_accesses_end": {NewMaybeRelocatableRelocatable(dict_ptr)},
},
vm,
)
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Ids: idsManager,
Code: DICT_SQUASH_COPY_DICT,
})
err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes)
if err != nil {
t.Errorf("DICT_SQUASH_COPY_DICT hint test failed with error %s", err)
}
// Check new scope
new_scope, _ := scopes.GetLocalVariables()
if !reflect.DeepEqual(new_scope, map[string]interface{}{
"__dict_manager": dictManagerRef,
"initial_dict": initialDict,
}) {
t.Errorf("DICT_SQUASH_COPY_DICT hint test wrong new sope created")
}
}

func TestDictSquashUpdatePtrOk(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
scopes := types.NewExecutionScopes()
initialDict := make(map[MaybeRelocatable]MaybeRelocatable)
// Create dictManager & add it to scope
dictManager := dict_manager.NewDictManager()
dict_ptr := dictManager.NewDictionary(&initialDict, vm)
// Keep a reference to the tracker to check that it was updated after the hint
tracker, _ := dictManager.GetTracker(dict_ptr)
scopes.AssignOrUpdateVariable("__dict_manager", &dictManager)

idsManager := SetupIdsForTest(
map[string][]*MaybeRelocatable{
"squashed_dict_start": {NewMaybeRelocatableRelocatable(dict_ptr)},
"squashed_dict_end": {NewMaybeRelocatableRelocatable(dict_ptr.AddUint(5))},
},
vm,
)
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Ids: idsManager,
Code: DICT_SQUASH_UPDATE_PTR,
})
err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes)
if err != nil {
t.Errorf("DICT_SQUASH_UPDATE_PTR hint test failed with error %s", err)
}
// Check updated ptr
if tracker.CurrentPtr != dict_ptr.AddUint(5) {
t.Error("DICT_SQUASH_UPDATE_PTR hint test failed: Wrong updated tracker.CurrentPtr")
}
}

func TestDictSquashUpdatePtrMismatchedPtr(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
scopes := types.NewExecutionScopes()
initialDict := make(map[MaybeRelocatable]MaybeRelocatable)
// Create dictManager & add it to scope
dictManager := dict_manager.NewDictManager()
dict_ptr := dictManager.NewDictionary(&initialDict, vm)
scopes.AssignOrUpdateVariable("__dict_manager", &dictManager)

idsManager := SetupIdsForTest(
map[string][]*MaybeRelocatable{
"squashed_dict_start": {NewMaybeRelocatableRelocatable(dict_ptr.AddUint(3))},
"squashed_dict_end": {NewMaybeRelocatableRelocatable(dict_ptr.AddUint(5))},
},
vm,
)
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Ids: idsManager,
Code: DICT_SQUASH_UPDATE_PTR,
})
err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes)
if err == nil {
t.Errorf("DICT_SQUASH_UPDATE_PTR hint test should have failed")
}
}

func TestDictNewCreateManager(t *testing.T) {
vm := NewVirtualMachine()
scopes := types.NewExecutionScopes()
initialDict := make(map[MaybeRelocatable]MaybeRelocatable)
scopes.AssignOrUpdateVariable("initial_dict", initialDict)
vm.Segments.AddSegment()
idsManager := SetupIdsForTest(
map[string][]*MaybeRelocatable{},
vm,
)
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Ids: idsManager,
Code: DICT_NEW,
})
// Advance AP so that values don't clash with FP-based ids
vm.RunContext.Ap = NewRelocatable(0, 5)
err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes)
if err != nil {
t.Errorf("DICT_NEW hint test failed with error %s", err)
}
// Check that a manager was created in the scope
_, ok := FetchDictManager(scopes)
if !ok {
t.Error("DICT_NEW No DictManager created")
}
// Check that the correct base was inserted into ap
val, _ := vm.Segments.Memory.Get(vm.RunContext.Ap)
if val == nil || *val != *NewMaybeRelocatableRelocatable(NewRelocatable(1, 0)) {
t.Error("DICT_NEW Wrong/No base inserted into ap")
}
}

func TestDictNewHasManager(t *testing.T) {
vm := NewVirtualMachine()
scopes := types.NewExecutionScopes()
// Create initialDict & dictManager & add them to scope
initialDict := make(map[MaybeRelocatable]MaybeRelocatable)
scopes.AssignOrUpdateVariable("initial_dict", initialDict)
dictManager := dict_manager.NewDictManager()
dictManagerRef := &dictManager
scopes.AssignOrUpdateVariable("__dict_manager", dictManagerRef)
vm.Segments.AddSegment()
idsManager := SetupIdsForTest(
map[string][]*MaybeRelocatable{},
vm,
)
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Ids: idsManager,
Code: DICT_NEW,
})
// Advance AP so that values don't clash with FP-based ids
vm.RunContext.Ap = NewRelocatable(0, 5)
err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes)
if err != nil {
t.Errorf("DICT_NEW hint test failed with error %s", err)
}
// Check that the manager wasn't replaced by a new one
dictManagerPtr, ok := FetchDictManager(scopes)
if !ok || dictManagerPtr != dictManagerRef {
t.Error("DICT_NEW DictManager replaced")
}
// Check that the correct base was inserted into ap
val, _ := vm.Segments.Memory.Get(vm.RunContext.Ap)
if val == nil || *val != *NewMaybeRelocatableRelocatable(NewRelocatable(1, 0)) {
t.Error("DICT_NEW Wrong/No base inserted into ap")
}
}

func TestDictNewHasManagerNoInitialDict(t *testing.T) {
vm := NewVirtualMachine()
scopes := types.NewExecutionScopes()
// Create dictManager & add it to scope
dictManager := dict_manager.NewDictManager()
dictManagerRef := &dictManager
scopes.AssignOrUpdateVariable("__dict_manager", dictManagerRef)
vm.Segments.AddSegment()
idsManager := SetupIdsForTest(
map[string][]*MaybeRelocatable{},
vm,
)
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Ids: idsManager,
Code: DICT_NEW,
})
// Advance AP so that values don't clash with FP-based ids
vm.RunContext.Ap = NewRelocatable(0, 5)
err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes)
if err == nil {
t.Errorf("DICT_NEW hint test should have failed")
}
}
6 changes: 6 additions & 0 deletions pkg/hints/hint_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any,
return squashDictInnerUsedAccessesAssert(data.Ids, execScopes, vm)
case SQUASH_DICT_INNER_NEXT_KEY:
return squashDictInnerNextKey(data.Ids, execScopes, vm)
case DICT_SQUASH_COPY_DICT:
return dictSquashCopyDict(data.Ids, execScopes, vm)
case DICT_SQUASH_UPDATE_PTR:
return dictSquashUpdatePtr(data.Ids, execScopes, vm)
case DICT_NEW:
return dictNew(data.Ids, execScopes, vm)
case VM_EXIT_SCOPE:
return vm_exit_scope(execScopes)
case ASSERT_NOT_EQUAL:
Expand Down
8 changes: 8 additions & 0 deletions pkg/vm/cairo_run/cairo_run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,11 @@ func TestSquashDict(t *testing.T) {
t.Errorf("Program execution failed with error: %s", err)
}
}

func TestDictSquash(t *testing.T) {
cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, Layout: "small", ProofMode: false}
_, err := cairo_run.CairoRun("../../../cairo_programs/dict_squash.json", cairoRunConfig)
if err != nil {
t.Errorf("Program execution failed with error: %s", err)
}
}