Skip to content

Commit

Permalink
Implement ASSERT_LE_FELT+ hints (#240)
Browse files Browse the repository at this point in the history
* Add IdsManager.GetConst

* Integrate into logic

* Add utils to help with fetching constants

* Add SetupConstantsForTest

* Add comments

* Guard error case

* Fix typo

* Fix util

* Add hint code

* Implement FeltFromBigInt

* Implement ASSERT_LE_FELT hint

* Add comment

* Add test

* Add the 3 assert_le_felt_exclued hints

* Add tests

* Fix identifier

* Add integration test

* Fix type const in tests

* Fix typo

* Fix test

* Fix bug

* Remove debug print

* Fix test name

* Format files

---------

Co-authored-by: Mariano Nicolini <[email protected]>
  • Loading branch information
fmoletta and entropidelic authored Sep 19, 2023
1 parent ecf7a77 commit 6f45961
Show file tree
Hide file tree
Showing 9 changed files with 327 additions and 2 deletions.
10 changes: 10 additions & 0 deletions cairo_programs/assert_le_felt.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

%builtins range_check
from starkware.cairo.common.math import assert_le_felt

func main{range_check_ptr: felt}() {
assert_le_felt(1, 2);
assert_le_felt(-2, -1);
assert_le_felt(2, -1);
return ();
}
4 changes: 2 additions & 2 deletions cmd/cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import (

func handleCommands(ctx *cli.Context) error {
programPath := ctx.Args().First()

layout := ctx.String("layout")
if layout == "" {
layout = "plain"
}

cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, ProofMode: ctx.Bool("proof_mode"), Layout: layout}
cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, ProofMode: ctx.Bool("proof_mode"), Layout: layout}

cairoRunner, err := cairo_run.CairoRun(programPath, cairoRunConfig)
if err != nil {
Expand Down
8 changes: 8 additions & 0 deletions pkg/hints/hint_codes/math_hint_codes.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ const ASSERT_NOT_EQUAL = "from starkware.cairo.lang.vm.relocatable import Reloca

const SQRT = "from starkware.python.math_utils import isqrt\nvalue = ids.value % PRIME\nassert value < 2 ** 250, f\"value={value} is outside of the range [0, 2**250).\"\nassert 2 ** 250 < PRIME\nids.root = isqrt(value)"

const ASSERT_LE_FELT = "import itertools\n\nfrom starkware.cairo.common.math_utils import assert_integer\nassert_integer(ids.a)\nassert_integer(ids.b)\na = ids.a % PRIME\nb = ids.b % PRIME\nassert a <= b, f'a = {a} is not less than or equal to b = {b}.'\n\n# Find an arc less than PRIME / 3, and another less than PRIME / 2.\nlengths_and_indices = [(a, 0), (b - a, 1), (PRIME - 1 - b, 2)]\nlengths_and_indices.sort()\nassert lengths_and_indices[0][0] <= PRIME // 3 and lengths_and_indices[1][0] <= PRIME // 2\nexcluded = lengths_and_indices[2][1]\n\nmemory[ids.range_check_ptr + 1], memory[ids.range_check_ptr + 0] = (\n divmod(lengths_and_indices[0][0], ids.PRIME_OVER_3_HIGH))\nmemory[ids.range_check_ptr + 3], memory[ids.range_check_ptr + 2] = (\n divmod(lengths_and_indices[1][0], ids.PRIME_OVER_2_HIGH))"

const ASSERT_LE_FELT_EXCLUDED_0 = "memory[ap] = 1 if excluded != 0 else 0"

const ASSERT_LE_FELT_EXCLUDED_1 = "memory[ap] = 1 if excluded != 1 else 0"

const ASSERT_LE_FELT_EXCLUDED_2 = "assert excluded == 2"

const ASSERT_250_BITS = "from starkware.cairo.common.math_utils import as_int\n\n# Correctness check.\nvalue = as_int(ids.value, PRIME) % PRIME\nassert value < ids.UPPER_BOUND, f'{value} is outside of the range [0, 2**250).'\n\n# Calculation for the assertion.\nids.high, ids.low = divmod(ids.value, ids.SHIFT)"

const SPLIT_FELT = "from starkware.cairo.common.math_utils import assert_integer\nassert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128\nassert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW\nassert_integer(ids.value)\nids.low = ids.value & ((1 << 128) - 1)\nids.high = ids.value >> 128"
8 changes: 8 additions & 0 deletions pkg/hints/hint_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any,
return memcpy_enter_scope(data.Ids, vm, execScopes)
case VM_ENTER_SCOPE:
return vm_enter_scope(execScopes)
case ASSERT_LE_FELT:
return assertLeFelt(data.Ids, vm, execScopes, constants)
case ASSERT_LE_FELT_EXCLUDED_0:
return assertLeFeltExcluded0(vm, execScopes)
case ASSERT_LE_FELT_EXCLUDED_1:
return assertLeFeltExcluded1(vm, execScopes)
case ASSERT_LE_FELT_EXCLUDED_2:
return assertLeFeltExcluded2(vm, execScopes)
case IS_NN:
return isNN(data.Ids, vm)
case IS_NN_OUT_OF_RANGE:
Expand Down
123 changes: 123 additions & 0 deletions pkg/hints/math_hints.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package hints

import (
"math/big"

"github.com/lambdaclass/cairo-vm.go/pkg/builtins"
. "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils"
"github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks"
. "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks"
. "github.com/lambdaclass/cairo-vm.go/pkg/math_utils"
. "github.com/lambdaclass/cairo-vm.go/pkg/types"
. "github.com/lambdaclass/cairo-vm.go/pkg/vm"
. "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory"
"github.com/pkg/errors"
Expand Down Expand Up @@ -184,6 +187,126 @@ func sqrt(ids IdsManager, vm *VirtualMachine) error {
return nil
}

func assertLeFelt(ids IdsManager, vm *VirtualMachine, scopes *ExecutionScopes, constants *map[string]Felt) error {
// Fetch constants
primeOver3HighFelt, err := ids.GetConst("PRIME_OVER_3_HIGH", constants)
if err != nil {
return err
}
primeOver3High := primeOver3HighFelt.ToBigInt()
primeOver2HighFelt, err := ids.GetConst("PRIME_OVER_2_HIGH", constants)
if err != nil {
return err
}
primeOver2High := primeOver2HighFelt.ToBigInt()
// Fetch ids variables
aFelt, err := ids.GetFelt("a", vm)
if err != nil {
return err
}
a := aFelt.ToBigInt()
bFelt, err := ids.GetFelt("b", vm)
if err != nil {
return err
}
b := bFelt.ToBigInt()
rangeCheckPtr, err := ids.GetRelocatable("range_check_ptr", vm)
if err != nil {
return err
}
// Hint Logic
cairoPrime, _ := new(big.Int).SetString(CAIRO_PRIME_HEX, 0)
halfPrime := new(big.Int).Div(cairoPrime, new(big.Int).SetUint64(2))
thirdOfPrime := new(big.Int).Div(cairoPrime, new(big.Int).SetUint64(3))
if a.Cmp(b) == 1 {
return errors.Errorf("Assertion failed, %v, is not less or equal to %v", a, b)
}
arc1 := new(big.Int).Sub(b, a)
arc2 := new(big.Int).Sub(new(big.Int).Sub(cairoPrime, (big.NewInt(1))), b)

// Split lengthsAndIndices array into lenght & idxs array and mantain the same order between them
lengths := []*big.Int{a, arc1, arc2}
idxs := []int{0, 1, 2}
// Sort lengths & idxs by lengths
for i := 0; i < 3; i++ {
for j := i; j > 0 && lengths[j-1].Cmp(lengths[j]) == 1; j-- {
lengths[j], lengths[j-1] = lengths[j-1], lengths[j]
idxs[j], idxs[j-1] = idxs[j-1], idxs[j]
}
}

if lengths[0].Cmp(thirdOfPrime) == 1 || lengths[1].Cmp(halfPrime) == 1 {
return errors.Errorf("Arc too big, %v must be <= %v and %v <= %v", lengths[0], thirdOfPrime, lengths[1], halfPrime)
}
excluded := idxs[2]
scopes.AssignOrUpdateVariable("excluded", excluded)
q_0, r_0 := new(big.Int).DivMod(lengths[0], primeOver3High, primeOver3High)
q_1, r_1 := new(big.Int).DivMod(lengths[1], primeOver2High, primeOver2High)

// Insert values into range_check_ptr
data := []MaybeRelocatable{
*NewMaybeRelocatableFelt(FeltFromBigInt(r_0)),
*NewMaybeRelocatableFelt(FeltFromBigInt(q_0)),
*NewMaybeRelocatableFelt(FeltFromBigInt(r_1)),
*NewMaybeRelocatableFelt(FeltFromBigInt(q_1)),
}
_, err = vm.Segments.LoadData(rangeCheckPtr, &data)

return err
}

// "memory[ap] = 1 if excluded != 0 else 0"
func assertLeFeltExcluded0(vm *VirtualMachine, scopes *ExecutionScopes) error {
// Fetch scope var
excludedAny, err := scopes.Get("excluded")
if err != nil {
return err
}
excluded, ok := excludedAny.(int)
if !ok {
return errors.New("excluded not in scope")
}
if excluded == 0 {
return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltZero()))
}
return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltOne()))
}

// "memory[ap] = 1 if excluded != 1 else 0"
func assertLeFeltExcluded1(vm *VirtualMachine, scopes *ExecutionScopes) error {
// Fetch scope var
excludedAny, err := scopes.Get("excluded")
if err != nil {
return err
}
excluded, ok := excludedAny.(int)
if !ok {
return errors.New("excluded not in scope")
}
if excluded == 1 {
return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltZero()))
}
return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltOne()))
}

// "assert excluded == 2"
func assertLeFeltExcluded2(vm *VirtualMachine, scopes *ExecutionScopes) error {
// Fetch scope var
excludedAny, err := scopes.Get("excluded")
if err != nil {
return err
}
excluded, ok := excludedAny.(int)
if !ok {
return errors.New("excluded not in scope")
}
if excluded != 2 {
return errors.New("Assertion Failed: excluded == 2")
}

return nil
}

// Implements hint:
//
// from starkware.cairo.common.math_utils import as_int
Expand Down
140 changes: 140 additions & 0 deletions pkg/hints/math_hints_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
. "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils"
"github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks"
. "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks"
. "github.com/lambdaclass/cairo-vm.go/pkg/types"
. "github.com/lambdaclass/cairo-vm.go/pkg/vm"
"github.com/lambdaclass/cairo-vm.go/pkg/vm/memory"
. "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory"
Expand Down Expand Up @@ -395,6 +396,145 @@ func TestSqrtOk(t *testing.T) {
}
}

func TestAssertLeFeltOk(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
vm.Segments.AddSegment()
scopes := NewExecutionScopes()
idsManager := SetupIdsForTest(
map[string][]*MaybeRelocatable{
"a": {NewMaybeRelocatableFelt(FeltOne())},
"b": {NewMaybeRelocatableFelt(FeltFromUint64(2))},
"range_check_ptr": {NewMaybeRelocatableRelocatable(NewRelocatable(1, 0))},
},
vm,
)
constants := SetupConstantsForTest(map[string]Felt{
"PRIME_OVER_3_HIGH": FeltFromHex("4000000000000088000000000000001"),
"PRIME_OVER_2_HIGH": FeltFromHex("2AAAAAAAAAAAAB05555555555555556"),
},
&idsManager,
)
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Ids: idsManager,
Code: ASSERT_LE_FELT,
})
err := hintProcessor.ExecuteHint(vm, &hintData, &constants, scopes)
if err != nil {
t.Errorf("ASSERT_LE_FELT hint failed with error: %s", err)
}
}

func TestAssertLeFeltExcluded0Zero(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
scopes := NewExecutionScopes()
scopes.AssignOrUpdateVariable("excluded", int(0))
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Code: ASSERT_LE_FELT_EXCLUDED_0,
})
err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes)
if err != nil {
t.Errorf("ASSERT_LE_FELT_EXCLUDED_0 hint test failed with error %s", err)
}
// Check the value of memory[ap]
val, err := vm.Segments.Memory.GetFelt(vm.RunContext.Ap)
if err != nil || !val.IsZero() {
t.Error("Wrong/No value inserted into ap")
}
}

func TestAssertLeFeltExcluded0One(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
scopes := NewExecutionScopes()
scopes.AssignOrUpdateVariable("excluded", int(1))
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Code: ASSERT_LE_FELT_EXCLUDED_0,
})
err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes)
if err != nil {
t.Errorf("ASSERT_LE_FELT_EXCLUDED_0 hint test failed with error %s", err)
}
// Check the value of memory[ap]
val, err := vm.Segments.Memory.GetFelt(vm.RunContext.Ap)
if err != nil || val != FeltOne() {
t.Error("Wrong/No value inserted into ap")
}
}

func TestAssertLeFeltExcluded1Zero(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
scopes := NewExecutionScopes()
scopes.AssignOrUpdateVariable("excluded", int(1))
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Code: ASSERT_LE_FELT_EXCLUDED_1,
})
err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes)
if err != nil {
t.Errorf("ASSERT_LE_FELT_EXCLUDED_1 hint test failed with error %s", err)
}
// Check the value of memory[ap]
val, err := vm.Segments.Memory.GetFelt(vm.RunContext.Ap)
if err != nil || !val.IsZero() {
t.Error("Wrong/No value inserted into ap")
}
}

func TestAssertLeFeltExcluded1One(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
scopes := NewExecutionScopes()
scopes.AssignOrUpdateVariable("excluded", int(0))
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Code: ASSERT_LE_FELT_EXCLUDED_1,
})
err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes)
if err != nil {
t.Errorf("ASSERT_LE_FELT_EXCLUDED_1 hint test failed with error %s", err)
}
// Check the value of memory[ap]
val, err := vm.Segments.Memory.GetFelt(vm.RunContext.Ap)
if err != nil || val != FeltOne() {
t.Error("Wrong/No value inserted into ap")
}
}

func TestAssertLeFeltExcluded2Ok(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
scopes := NewExecutionScopes()
scopes.AssignOrUpdateVariable("excluded", int(2))
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Code: ASSERT_LE_FELT_EXCLUDED_2,
})
err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes)
if err != nil {
t.Errorf("ASSERT_LE_FELT_EXCLUDED_2 hint test failed with error %s", err)
}
}

func TestAssertLeFeltExcluded2Err(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
scopes := NewExecutionScopes()
scopes.AssignOrUpdateVariable("excluded", int(0))
hintProcessor := CairoVmHintProcessor{}
hintData := any(HintData{
Code: ASSERT_LE_FELT_EXCLUDED_2,
})
err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes)
if err == nil {
t.Errorf("ASSERT_LE_FELT_EXCLUDED_2 hint test should have failed")
}
}
func TestAssert250BitHintSuccess(t *testing.T) {
vm := NewVirtualMachine()
vm.Segments.AddSegment()
Expand Down
12 changes: 12 additions & 0 deletions pkg/lambdaworks/lambdaworks.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,18 @@ func (f Felt) ToBigInt() *big.Int {
return new(big.Int).SetBytes(f.ToBeBytes()[:32])
}

func FeltFromBigInt(n *big.Int) Felt {
// Perform modulo prime
prime, _ := new(big.Int).SetString(CAIRO_PRIME_HEX, 0)
if n.Cmp(prime) != -1 {
n = new(big.Int).Mod(n, prime)
}
bytes := n.Bytes()
var bytes32 [32]byte
copy(bytes32[:], bytes)
return FeltFromLeBytes(&bytes32)
}

const CAIRO_PRIME_HEX = "0x800000000000011000000000000000000000000000000000000000000000001"
const SIGNED_FELT_MAX_HEX = "0x400000000000008800000000000000000000000000000000000000000000000"

Expand Down
16 changes: 16 additions & 0 deletions pkg/lambdaworks/lambdaworks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,22 @@ func TestToBigInt(t *testing.T) {
}
}

func TestFromBigInt(t *testing.T) {
expectedFelt := lambdaworks.FeltFromUint64(26)
bigInt := new(big.Int).SetUint64(26)
if !reflect.DeepEqual(lambdaworks.FeltFromBigInt(bigInt), expectedFelt) {
t.Errorf("TestToBigInt failed. Expected: %v, Got: %v", 26, lambdaworks.FeltFromBigInt(bigInt))
}
}

func TestFromBigIntPrime(t *testing.T) {
expectedFelt := lambdaworks.FeltFromDecString("0")
bigInt, _ := new(big.Int).SetString(lambdaworks.CAIRO_PRIME_HEX, 0)
if !reflect.DeepEqual(lambdaworks.FeltFromBigInt(bigInt), expectedFelt) {
t.Errorf("TestToBigInt failed. Expected: PRIME, Got: %v", lambdaworks.FeltFromBigInt(bigInt))
}
}

func TestToSignedNegative(t *testing.T) {
felt := lambdaworks.FeltFromDecString("-1")
bigInt := felt.ToSigned()
Expand Down
Loading

0 comments on commit 6f45961

Please sign in to comment.