Skip to content

Commit

Permalink
Add pow & sqrt hints (#224)
Browse files Browse the repository at this point in the history
* Add needed args in ExecuteHint

* Add pow hint

* Finish unit test

* Add sqrt hint

* Add cairo config to tests

---------

Co-authored-by: juan.mv <[email protected]>
Co-authored-by: Mariano A. Nicolini <[email protected]>
  • Loading branch information
3 people authored Sep 18, 2023
1 parent 9887892 commit 4135dd1
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 0 deletions.
18 changes: 18 additions & 0 deletions cairo_programs/pow.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
%builtins range_check

from starkware.cairo.common.pow import pow

func main{range_check_ptr: felt}() {
let (x) = pow(2, 3);
assert x = 8;
let (y) = pow(10, 6);
assert y = 1000000;
let (z) = pow(152, 25);
assert z = 3516330588649452857943715400722794159857838650852114432;
let (u) = pow(-2, 3);
assert (u) = -8;
let (v) = pow(-25, 31);
assert (v) = -21684043449710088680149056017398834228515625;

return ();
}
16 changes: 16 additions & 0 deletions cairo_programs/sqrt.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
%builtins range_check

from starkware.cairo.common.math import sqrt

func main{range_check_ptr: felt}() {
let result_a = sqrt(0);
assert result_a = 0;

let result_b = sqrt(2402);
assert result_b = 49;

let result_c = sqrt(361850278866613121369732278309507010562);
assert result_c = 19022362599493605525;

return ();
}
4 changes: 4 additions & 0 deletions pkg/hints/hint_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any,
return vm_exit_scope(execScopes)
case ASSERT_NOT_EQUAL:
return assert_not_equal(data.Ids, vm)
case POW:
return pow(data.Ids, vm)
case SQRT:
return sqrt(data.Ids, vm)
case MEMCPY_ENTER_SCOPE:
return memcpy_enter_scope(data.Ids, vm, execScopes)
case VM_ENTER_SCOPE:
Expand Down
2 changes: 2 additions & 0 deletions pkg/hints/math_hint_codes.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ const IS_POSITIVE = "from starkware.cairo.common.math_utils import is_positive\n
const ASSERT_NOT_ZERO = "from starkware.cairo.common.math_utils import assert_integer\nassert_integer(ids.value)\nassert ids.value % PRIME != 0, f'assert_not_zero failed: {ids.value} = 0.'"

const ASSERT_NOT_EQUAL = "from starkware.cairo.lang.vm.relocatable import RelocatableValue\nboth_ints = isinstance(ids.a, int) and isinstance(ids.b, int)\nboth_relocatable = (\n isinstance(ids.a, RelocatableValue) and isinstance(ids.b, RelocatableValue) and\n ids.a.segment_index == ids.b.segment_index)\nassert both_ints or both_relocatable, \\\n f'assert_not_equal failed: non-comparable values: {ids.a}, {ids.b}.'\nassert (ids.a - ids.b) % PRIME != 0, f'assert_not_equal failed: {ids.a} = {ids.b}.'"

const SQRT = "from starkware.python.math_utils import isqrt\nvalue = ids.value % PRIME\nassert value < 2 ** 250, f\"value={value} is outside of the range [0, 2**250).\"\nassert 2 ** 250 < PRIME\nids.root = isqrt(value)"
29 changes: 29 additions & 0 deletions pkg/hints/math_hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/lambdaclass/cairo-vm.go/pkg/builtins"
. "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils"
. "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks"
. "github.com/lambdaclass/cairo-vm.go/pkg/math_utils"
. "github.com/lambdaclass/cairo-vm.go/pkg/vm"
. "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory"
"github.com/pkg/errors"
Expand Down Expand Up @@ -89,3 +90,31 @@ func assert_not_equal(ids IdsManager, vm *VirtualMachine) error {
}
return nil
}

/*
Implements the hint:
from starkware.python.math_utils import isqrt
value = ids.value % PRIME
assert value < 2 ** 250, f"value={value} is outside of the range [0, 2**250)."
assert 2 ** 250 < PRIME
ids.root = isqrt(value)
*/
func sqrt(ids IdsManager, vm *VirtualMachine) error {
value, err := ids.GetFelt("value", vm)
if err != nil {
return err
}

if value.Bits() >= 250 {
return errors.Errorf("Value: %v is outside of the range [0, 2**250)", value)
}

root_big, err := ISqrt(value.ToBigInt())
if err != nil {
return err
}
root_felt := FeltFromDecString(root_big.String())
ids.Insert("root", NewMaybeRelocatableFelt(root_felt), vm)
return nil
}
29 changes: 29 additions & 0 deletions pkg/hints/math_hints_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,32 @@ func TestAssertNotEqualHintOkRelocatables(t *testing.T) {
t.Errorf("ASSERT_NOT_EQUAL hint failed with error: %s", err)
}
}

func TestSqrtOk(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
idsManager := SetupIdsForTest(
map[string][]*MaybeRelocatable{
"value": {NewMaybeRelocatableFelt(FeltFromDecString("9"))},
"root": {nil},
},
vm,
)
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Ids: idsManager,
Code: SQRT,
})
err := hintProcessor.ExecuteHint(vm, &hintData, nil, nil)
if err != nil {
t.Errorf("SQRT hint failed with error: %s", err)
}

root, err := idsManager.GetFelt("root", vm)
if err != nil {
t.Errorf("failed to get root: %s", err)
}
if root != FeltFromUint64(3) {
t.Errorf("Expected sqrt(9) == 3. Got: %v", root)
}
}
3 changes: 3 additions & 0 deletions pkg/hints/pow_hint_codes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package hints

const POW = "ids.locs.bit = (ids.prev_locs.exp % PRIME) & 1"
22 changes: 22 additions & 0 deletions pkg/hints/pow_hints.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package hints

import (
. "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils"
. "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks"
. "github.com/lambdaclass/cairo-vm.go/pkg/vm"
. "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory"
)

// Implements hint:
// %{ ids.locs.bit = (ids.prev_locs.exp % PRIME) & 1 %}
func pow(ids IdsManager, vm *VirtualMachine) error {
prev_locs_exp_addr, err := ids.GetAddr("prev_locs", vm)
prev_locs_exp, _ := vm.Segments.Memory.GetFelt(prev_locs_exp_addr.AddUint(4))

if err != nil {
return err
}

ids.Insert("locs", NewMaybeRelocatableFelt(prev_locs_exp.And(FeltOne())), vm)
return nil
}
42 changes: 42 additions & 0 deletions pkg/hints/pow_hints_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package hints_test

import (
. "github.com/lambdaclass/cairo-vm.go/pkg/hints"
. "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils"
. "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks"
. "github.com/lambdaclass/cairo-vm.go/pkg/vm"
. "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory"
"testing"
)

func TestPowHintOk(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
vm.Segments.Memory.Insert(NewRelocatable(0, 4), NewMaybeRelocatableFelt(FeltFromUint64(5)))
idsManager := SetupIdsForTest(
map[string][]*MaybeRelocatable{
"prev_locs": {NewMaybeRelocatableRelocatable(NewRelocatable(0, 0))},
"locs": {nil},
},
vm,
)
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Ids: idsManager,
Code: POW,
})

err := hintProcessor.ExecuteHint(vm, &hintData, nil, nil)
if err != nil {
t.Errorf("POW hint test failed with error %s", err)
}

locs, err := idsManager.GetFelt("locs", vm)
if err != nil {
t.Errorf("Failed to get locs.bit with error: %s", err)
}

if locs != FeltFromUint64(1) {
t.Errorf("locs.bit: %d != 1", locs)
}
}
16 changes: 16 additions & 0 deletions pkg/vm/cairo_run/cairo_run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,19 @@ func TestAssertNotEqualHint(t *testing.T) {
t.Errorf("Program execution failed with error: %s", err)
}
}

func TestPowHint(t *testing.T) {
cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, Layout: "all_cairo", ProofMode: false}
_, err := cairo_run.CairoRun("../../../cairo_programs/pow.json", cairoRunConfig)
if err != nil {
t.Errorf("Program execution failed with error: %s", err)
}
}

func TestSqrtHint(t *testing.T) {
cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, Layout: "all_cairo", ProofMode: false}
_, err := cairo_run.CairoRun("../../../cairo_programs/sqrt.json", cairoRunConfig)
if err != nil {
t.Errorf("Program execution failed with error: %s", err)
}
}

0 comments on commit 4135dd1

Please sign in to comment.