diff --git a/cairo_programs/proof_programs/usort.cairo b/cairo_programs/proof_programs/usort.cairo new file mode 120000 index 00000000..4a7e19e6 --- /dev/null +++ b/cairo_programs/proof_programs/usort.cairo @@ -0,0 +1 @@ +../usort.cairo \ No newline at end of file diff --git a/cairo_programs/usort.cairo b/cairo_programs/usort.cairo new file mode 100644 index 00000000..e5859b29 --- /dev/null +++ b/cairo_programs/usort.cairo @@ -0,0 +1,22 @@ +%builtins range_check +from starkware.cairo.common.usort import usort +from starkware.cairo.common.alloc import alloc + +func main{range_check_ptr}() -> () { + alloc_locals; + let (input_array: felt*) = alloc(); + assert input_array[0] = 2; + assert input_array[1] = 1; + assert input_array[2] = 0; + + let (output_len, output, multiplicities) = usort(input_len=3, input=input_array); + + assert output_len = 3; + assert output[0] = 0; + assert output[1] = 1; + assert output[2] = 2; + assert multiplicities[0] = 1; + assert multiplicities[1] = 1; + assert multiplicities[2] = 1; + return (); +} diff --git a/pkg/hints/hint_codes/usort_hint_codes.go b/pkg/hints/hint_codes/usort_hint_codes.go new file mode 100644 index 00000000..7dd3e944 --- /dev/null +++ b/pkg/hints/hint_codes/usort_hint_codes.go @@ -0,0 +1,32 @@ +package hint_codes + +const USORT_ENTER_SCOPE = "vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))" + +const USORT_BODY = `from collections import defaultdict + +input_ptr = ids.input +input_len = int(ids.input_len) +if __usort_max_size is not None: + assert input_len <= __usort_max_size, ( + f"usort() can only be used with input_len<={__usort_max_size}. " + f"Got: input_len={input_len}." + ) + +positions_dict = defaultdict(list) +for i in range(input_len): + val = memory[input_ptr + i] + positions_dict[val].append(i) + +output = sorted(positions_dict.keys()) +ids.output_len = len(output) +ids.output = segments.gen_arg(output) +ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])` + +const USORT_VERIFY = `last_pos = 0 +positions = positions_dict[ids.value][::-1]` + +const USORT_VERIFY_MULTIPLICITY_ASSERT = "assert len(positions) == 0" + +const USORT_VERIFY_MULTIPLICITY_BODY = `current_pos = positions.pop() +ids.next_item_index = current_pos - last_pos +last_pos = current_pos + 1` diff --git a/pkg/hints/hint_processor.go b/pkg/hints/hint_processor.go index 6ed7562f..a43c76ef 100644 --- a/pkg/hints/hint_processor.go +++ b/pkg/hints/hint_processor.go @@ -106,6 +106,16 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any, return memset_step_loop(data.Ids, vm, execScopes, "continue_loop") case VM_ENTER_SCOPE: return vm_enter_scope(execScopes) + case USORT_ENTER_SCOPE: + return usortEnterScope(execScopes) + case USORT_BODY: + return usortBody(data.Ids, execScopes, vm) + case USORT_VERIFY: + return usortVerify(data.Ids, execScopes, vm) + case USORT_VERIFY_MULTIPLICITY_ASSERT: + return usortVerifyMultiplicityAssert(execScopes) + case USORT_VERIFY_MULTIPLICITY_BODY: + return usortVerifyMultiplicityBody(data.Ids, execScopes, vm) case SET_ADD: return setAdd(data.Ids, vm) case FIND_ELEMENT: diff --git a/pkg/hints/usort_hints.go b/pkg/hints/usort_hints.go new file mode 100644 index 00000000..1dc0f364 --- /dev/null +++ b/pkg/hints/usort_hints.go @@ -0,0 +1,266 @@ +package hints + +import ( + "fmt" + "sort" + + "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" + "github.com/lambdaclass/cairo-vm.go/pkg/types" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm" + "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" + + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + "github.com/pkg/errors" +) + +// SortFelt implements sort.Interface for []lambdaworks.Felt +type SortFelt []lambdaworks.Felt + +func (s SortFelt) Len() int { return len(s) } +func (s SortFelt) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s SortFelt) Less(i, j int) bool { + a, b := s[i], s[j] + + return a.Cmp(b) == -1 +} + +// Implements hint: +// %{ vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size'))) %} +func usortEnterScope(executionScopes *types.ExecutionScopes) error { + usort_max_size_interface, err := executionScopes.Get("usort_max_size") + + if err != nil { + executionScopes.EnterScope(make(map[string]interface{})) + return nil + } + + usort_max_size, cast_ok := usort_max_size_interface.(uint64) + + if !cast_ok { + return errors.New("Error casting usort_max_size into a uint64") + } + + scope := make(map[string]interface{}) + scope["usort_max_size"] = usort_max_size + executionScopes.EnterScope(scope) + + return nil +} + +func usortBody(ids IdsManager, executionScopes *types.ExecutionScopes, vm *VirtualMachine) error { + + input_ptr, err := ids.GetRelocatable("input", vm) + if err != nil { + return err + } + + input_len, err := ids.GetFelt("input_len", vm) + + if err != nil { + return err + } + input_len_u64, err := input_len.ToU64() + + if err != nil { + return err + } + + usort_max_size, err := executionScopes.Get("usort_max_size") + + if err == nil { + usort_max_size_u64, cast_ok := usort_max_size.(uint64) + + if !cast_ok { + return errors.New("Error casting usort_max_size into a uint64") + } + + if input_len_u64 > usort_max_size_u64 { + return errors.New(fmt.Sprintf("usort() can only be used with input_len<= %v. Got: input_len=%v.", usort_max_size_u64, input_len_u64)) + } + } + + positions_dict := make(map[lambdaworks.Felt][]uint64) + + for i := uint64(0); i < input_len_u64; i++ { + + val, err := vm.Segments.Memory.GetFelt(input_ptr.AddUint(uint(i))) + + if err != nil { + return err + } + + positions_dict[val] = append(positions_dict[val], i) + } + executionScopes.AssignOrUpdateVariable("positions_dict", positions_dict) + + output := make([]lambdaworks.Felt, 0, len(positions_dict)) + + for key := range positions_dict { + output = append(output, key) + } + + sort.Sort(SortFelt(output)) + + output_len := memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(uint64((len(output))))) + err = ids.Insert("output_len", output_len, vm) + + if err != nil { + return err + } + + output_base := vm.Segments.AddSegment() + + for i := range output { + err = vm.Segments.Memory.Insert(output_base.AddUint(uint(i)), memory.NewMaybeRelocatableFelt(output[i])) + + if err != nil { + return err + } + } + + multiplicities_base := vm.Segments.AddSegment() + + multiplicities := make([]uint64, 0, len(output)) + + for key := range output { + multiplicities = append(multiplicities, uint64(len(positions_dict[output[key]]))) + } + + for i := range multiplicities { + err = vm.Segments.Memory.Insert(multiplicities_base.AddUint(uint(i)), memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(multiplicities[i]))) + + if err != nil { + return err + } + } + + err = ids.Insert("output", memory.NewMaybeRelocatableRelocatable(output_base), vm) + + if err != nil { + return err + } + + err = ids.Insert("multiplicities", memory.NewMaybeRelocatableRelocatable(multiplicities_base), vm) + + if err != nil { + return err + } + + return nil +} + +// Implements hint: +// +// %{ +// last_pos = 0 +// positions = positions_dict[ids.value][::-1] +// %} +func usortVerify(ids IdsManager, executionScopes *types.ExecutionScopes, vm *VirtualMachine) error { + + executionScopes.AssignOrUpdateVariable("last_pos", uint64(0)) + + positions_dict_interface, err := executionScopes.Get("positions_dict") + + if err != nil { + return err + } + + positions_dict, cast_ok := positions_dict_interface.(map[lambdaworks.Felt][]uint64) + + if !cast_ok { + return errors.New("Error casting positions_dict") + } + + value, err := ids.GetFelt("value", vm) + if err != nil { + return err + } + + if err != nil { + return err + } + + positions := positions_dict[value] + + for i, j := 0, len(positions)-1; i < j; i, j = i+1, j-1 { + positions[i], positions[j] = positions[j], positions[i] + } + + executionScopes.AssignOrUpdateVariable("positions", positions) + + return nil +} + +// Implements hint: +// %{ assert len(positions) == 0 %} +func usortVerifyMultiplicityAssert(executionScopes *types.ExecutionScopes) error { + + positions_interface, err := executionScopes.Get("positions") + + if err != nil { + return err + } + + positions, cast_ok := positions_interface.([]uint64) + + if !cast_ok { + return errors.New("Error casting positions to []uint64") + } + + if len(positions) != 0 { + return errors.New("Assertion failed: len(positions) == 0") + } + + return nil + +} + +// Implements hint: +// +// %{ +// current_pos = positions.pop() +// ids.next_item_index = current_pos - last_pos +// last_pos = current_pos + 1 +// %} +func usortVerifyMultiplicityBody(ids IdsManager, executionScopes *types.ExecutionScopes, vm *VirtualMachine) error { + + positions_interface, err := executionScopes.Get("positions") + + if err != nil { + return err + } + + positions, cast_ok := positions_interface.([]uint64) + + if !cast_ok { + return errors.New("Error casting positions to []uint64") + } + + last_pos_interface, err := executionScopes.Get("last_pos") + + if err != nil { + return err + } + + last_pos, cast_ok := last_pos_interface.(uint64) + + if !cast_ok { + return errors.New("Error casting last_pos to uint64") + } + + current_pos := positions[len(positions)-1] + + executionScopes.AssignOrUpdateVariable("positions", positions[:len(positions)-1]) + + next_item_index := current_pos - last_pos + + err = ids.Insert("next_item_index", memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(next_item_index)), vm) + + if err != nil { + return err + } + + executionScopes.AssignOrUpdateVariable("last_pos", current_pos+1) + + return nil +} diff --git a/pkg/hints/usort_hints_test.go b/pkg/hints/usort_hints_test.go new file mode 100644 index 00000000..c390869e --- /dev/null +++ b/pkg/hints/usort_hints_test.go @@ -0,0 +1,240 @@ +package hints_test + +import ( + "reflect" + "sort" + "testing" + + . "github.com/lambdaclass/cairo-vm.go/pkg/hints" + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_codes" + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" + "github.com/lambdaclass/cairo-vm.go/pkg/types" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" +) + +func TestSortFeltArray(t *testing.T) { + array := []lambdaworks.Felt{lambdaworks.FeltFromUint(6), lambdaworks.FeltFromUint(0), lambdaworks.FeltFromUint(100), lambdaworks.FeltFromUint(1), lambdaworks.FeltFromUint(50)} + + sort.Sort(SortFelt(array)) + + sortedarray := []lambdaworks.Felt{lambdaworks.FeltFromUint(0), lambdaworks.FeltFromUint(1), lambdaworks.FeltFromUint(6), lambdaworks.FeltFromUint(50), lambdaworks.FeltFromUint(100)} + + if !reflect.DeepEqual(array, sortedarray) { + t.Errorf("Error sorting felt array") + } + +} + +func TestUsortWithMaxSize(t *testing.T) { + vm := NewVirtualMachine() + scopes := types.NewExecutionScopes() + scopes.AssignOrUpdateVariable("usort_max_size", uint64(1)) + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{}, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_ENTER_SCOPE, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("USORT_ENTER_SCOPE hint execution failed") + } + + usort_max_size_interface, err := scopes.Get("usort_max_size") + + if err != nil { + t.Errorf("Error assigning usort_max_size") + } + + usort_max_size := usort_max_size_interface.(uint64) + + if usort_max_size != uint64(1) { + t.Errorf("Error assigning usort_max_size") + } + +} +func TestUsortOutOfRange(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + vm.Segments.AddSegment() + vm.Segments.AddSegment() + scopes := types.NewExecutionScopes() + scopes.AssignOrUpdateVariable("usort_max_size", uint64(1)) + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "input": {NewMaybeRelocatableRelocatable(NewRelocatable(2, 1))}, + "input_len": {NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(5))}, + }, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_BODY, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err == nil { + t.Errorf("USORT_BODY hint should have failed") + } + +} + +func TestUsortVerify(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + vm.Segments.AddSegment() + vm.Segments.AddSegment() + scopes := types.NewExecutionScopes() + positions_dict := make(map[lambdaworks.Felt][]uint64) + positions_dict[lambdaworks.FeltFromUint64(0)] = []uint64{2} + positions_dict[lambdaworks.FeltFromUint64(1)] = []uint64{1} + positions_dict[lambdaworks.FeltFromUint64(2)] = []uint64{0} + + scopes.AssignOrUpdateVariable("positions_dict", positions_dict) + + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "value": {NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(0))}, + }, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_VERIFY, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("USORT_VERIFY failed") + } + + positions_interface, err := scopes.Get("positions") + + if err != nil { + t.Errorf("Error assigning positions_interface") + } + + positions := positions_interface.([]uint64) + + if !reflect.DeepEqual(positions, []uint64{2}) { + t.Errorf("Error assigning positions") + } + + last_pos_interface, err := scopes.Get("last_pos") + + if err != nil { + t.Errorf("Error assigning last_pos") + } + + last_pos := last_pos_interface.(uint64) + + if last_pos != uint64(0) { + t.Errorf("Error assigning last_pos") + } + +} + +func TestUsortVerifyMultiplicityAssert(t *testing.T) { + vm := NewVirtualMachine() + scopes := types.NewExecutionScopes() + + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{}, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_VERIFY_MULTIPLICITY_ASSERT, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err == nil { + t.Errorf("USORT_VERIFY_MULTIPLICITY_ASSERT should have failed") + } + + positions := []uint64{0} + + scopes.AssignOrUpdateVariable("positions", positions) + + err = hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err == nil { + t.Errorf("USORT_VERIFY_MULTIPLICITY_ASSERT should have failed") + } + + positions = []uint64{} + + scopes.AssignOrUpdateVariable("positions", positions) + + err = hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("USORT_VERIFY_MULTIPLICITY_ASSERT failed") + } + +} + +func TestUsortVerifyMultiplicityBody(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + scopes := types.NewExecutionScopes() + + scopes.AssignOrUpdateVariable("positions", []uint64{1, 0, 4, 7, 10}) + scopes.AssignOrUpdateVariable("last_pos", uint64(3)) + + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "next_item_index": {nil}, + }, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_VERIFY_MULTIPLICITY_BODY, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("USORT_VERIFY_MULTIPLICITY_BODY failed") + } + + // Check scopes variables + positions_interface, err := scopes.Get("positions") + + if err != nil { + t.Errorf("Error assigning positions_interface") + } + + positions := positions_interface.([]uint64) + + if !reflect.DeepEqual(positions, []uint64{1, 0, 4, 7}) { + t.Errorf("Error assigning positions") + } + + last_pos_interface, err := scopes.Get("last_pos") + + if err != nil { + t.Errorf("Error assigning last_pos") + } + + last_pos := last_pos_interface.(uint64) + + if last_pos != uint64(11) { + t.Errorf("Error assigning last_pos") + } + + // Check VM inserts + next_item_index, err := idsManager.GetFelt("next_item_index", vm) + + if err != nil { + t.Errorf("Error assigning next_item_index") + } + + if next_item_index != lambdaworks.FeltFromUint(7) { + t.Errorf("Error assigning next_item_index") + } + +} diff --git a/pkg/vm/cairo_run/cairo_run_test.go b/pkg/vm/cairo_run/cairo_run_test.go index 3f54eacd..f2921821 100644 --- a/pkg/vm/cairo_run/cairo_run_test.go +++ b/pkg/vm/cairo_run/cairo_run_test.go @@ -309,6 +309,14 @@ func TestSplitFeltHint(t *testing.T) { testProgram("split_felt", t) } +func TestUsort(t *testing.T) { + testProgram("usort", t) +} + +func TestUsortProofMode(t *testing.T) { + testProgramProof("usort", t) +} + func TestSplitFeltHintProofMode(t *testing.T) { testProgramProof("split_felt", t) }