From 6f4596196dda78fda3aa91eb7fe3c0e05562b485 Mon Sep 17 00:00:00 2001 From: fmoletta <99273364+fmoletta@users.noreply.github.com> Date: Wed, 20 Sep 2023 00:21:03 +0300 Subject: [PATCH] Implement `ASSERT_LE_FELT+` hints (#240) * Add IdsManager.GetConst * Integrate into logic * Add utils to help with fetching constants * Add SetupConstantsForTest * Add comments * Guard error case * Fix typo * Fix util * Add hint code * Implement FeltFromBigInt * Implement ASSERT_LE_FELT hint * Add comment * Add test * Add the 3 assert_le_felt_exclued hints * Add tests * Fix identifier * Add integration test * Fix type const in tests * Fix typo * Fix test * Fix bug * Remove debug print * Fix test name * Format files --------- Co-authored-by: Mariano Nicolini --- cairo_programs/assert_le_felt.cairo | 10 ++ cmd/cli/main.go | 4 +- pkg/hints/hint_codes/math_hint_codes.go | 8 ++ pkg/hints/hint_processor.go | 8 ++ pkg/hints/math_hints.go | 123 +++++++++++++++++++++ pkg/hints/math_hints_test.go | 140 ++++++++++++++++++++++++ pkg/lambdaworks/lambdaworks.go | 12 ++ pkg/lambdaworks/lambdaworks_test.go | 16 +++ pkg/vm/cairo_run/cairo_run_test.go | 8 ++ 9 files changed, 327 insertions(+), 2 deletions(-) create mode 100644 cairo_programs/assert_le_felt.cairo diff --git a/cairo_programs/assert_le_felt.cairo b/cairo_programs/assert_le_felt.cairo new file mode 100644 index 00000000..df623734 --- /dev/null +++ b/cairo_programs/assert_le_felt.cairo @@ -0,0 +1,10 @@ + +%builtins range_check +from starkware.cairo.common.math import assert_le_felt + +func main{range_check_ptr: felt}() { + assert_le_felt(1, 2); + assert_le_felt(-2, -1); + assert_le_felt(2, -1); + return (); +} diff --git a/cmd/cli/main.go b/cmd/cli/main.go index f671fb79..db77623b 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -11,13 +11,13 @@ import ( func handleCommands(ctx *cli.Context) error { programPath := ctx.Args().First() - + layout := ctx.String("layout") if layout == "" { layout = "plain" } - cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, ProofMode: ctx.Bool("proof_mode"), Layout: layout} + cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, ProofMode: ctx.Bool("proof_mode"), Layout: layout} cairoRunner, err := cairo_run.CairoRun(programPath, cairoRunConfig) if err != nil { diff --git a/pkg/hints/hint_codes/math_hint_codes.go b/pkg/hints/hint_codes/math_hint_codes.go index 803e5626..6bfdc9e3 100644 --- a/pkg/hints/hint_codes/math_hint_codes.go +++ b/pkg/hints/hint_codes/math_hint_codes.go @@ -21,6 +21,14 @@ const ASSERT_NOT_EQUAL = "from starkware.cairo.lang.vm.relocatable import Reloca const SQRT = "from starkware.python.math_utils import isqrt\nvalue = ids.value % PRIME\nassert value < 2 ** 250, f\"value={value} is outside of the range [0, 2**250).\"\nassert 2 ** 250 < PRIME\nids.root = isqrt(value)" +const ASSERT_LE_FELT = "import itertools\n\nfrom starkware.cairo.common.math_utils import assert_integer\nassert_integer(ids.a)\nassert_integer(ids.b)\na = ids.a % PRIME\nb = ids.b % PRIME\nassert a <= b, f'a = {a} is not less than or equal to b = {b}.'\n\n# Find an arc less than PRIME / 3, and another less than PRIME / 2.\nlengths_and_indices = [(a, 0), (b - a, 1), (PRIME - 1 - b, 2)]\nlengths_and_indices.sort()\nassert lengths_and_indices[0][0] <= PRIME // 3 and lengths_and_indices[1][0] <= PRIME // 2\nexcluded = lengths_and_indices[2][1]\n\nmemory[ids.range_check_ptr + 1], memory[ids.range_check_ptr + 0] = (\n divmod(lengths_and_indices[0][0], ids.PRIME_OVER_3_HIGH))\nmemory[ids.range_check_ptr + 3], memory[ids.range_check_ptr + 2] = (\n divmod(lengths_and_indices[1][0], ids.PRIME_OVER_2_HIGH))" + +const ASSERT_LE_FELT_EXCLUDED_0 = "memory[ap] = 1 if excluded != 0 else 0" + +const ASSERT_LE_FELT_EXCLUDED_1 = "memory[ap] = 1 if excluded != 1 else 0" + +const ASSERT_LE_FELT_EXCLUDED_2 = "assert excluded == 2" + const ASSERT_250_BITS = "from starkware.cairo.common.math_utils import as_int\n\n# Correctness check.\nvalue = as_int(ids.value, PRIME) % PRIME\nassert value < ids.UPPER_BOUND, f'{value} is outside of the range [0, 2**250).'\n\n# Calculation for the assertion.\nids.high, ids.low = divmod(ids.value, ids.SHIFT)" const SPLIT_FELT = "from starkware.cairo.common.math_utils import assert_integer\nassert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128\nassert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW\nassert_integer(ids.value)\nids.low = ids.value & ((1 << 128) - 1)\nids.high = ids.value >> 128" diff --git a/pkg/hints/hint_processor.go b/pkg/hints/hint_processor.go index feda7fb3..b877bd1f 100644 --- a/pkg/hints/hint_processor.go +++ b/pkg/hints/hint_processor.go @@ -100,6 +100,14 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any, return memcpy_enter_scope(data.Ids, vm, execScopes) case VM_ENTER_SCOPE: return vm_enter_scope(execScopes) + case ASSERT_LE_FELT: + return assertLeFelt(data.Ids, vm, execScopes, constants) + case ASSERT_LE_FELT_EXCLUDED_0: + return assertLeFeltExcluded0(vm, execScopes) + case ASSERT_LE_FELT_EXCLUDED_1: + return assertLeFeltExcluded1(vm, execScopes) + case ASSERT_LE_FELT_EXCLUDED_2: + return assertLeFeltExcluded2(vm, execScopes) case IS_NN: return isNN(data.Ids, vm) case IS_NN_OUT_OF_RANGE: diff --git a/pkg/hints/math_hints.go b/pkg/hints/math_hints.go index 3c963ee0..17ed3643 100644 --- a/pkg/hints/math_hints.go +++ b/pkg/hints/math_hints.go @@ -1,11 +1,14 @@ package hints import ( + "math/big" + "github.com/lambdaclass/cairo-vm.go/pkg/builtins" . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" . "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" . "github.com/lambdaclass/cairo-vm.go/pkg/math_utils" + . "github.com/lambdaclass/cairo-vm.go/pkg/types" . "github.com/lambdaclass/cairo-vm.go/pkg/vm" . "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" "github.com/pkg/errors" @@ -184,6 +187,126 @@ func sqrt(ids IdsManager, vm *VirtualMachine) error { return nil } +func assertLeFelt(ids IdsManager, vm *VirtualMachine, scopes *ExecutionScopes, constants *map[string]Felt) error { + // Fetch constants + primeOver3HighFelt, err := ids.GetConst("PRIME_OVER_3_HIGH", constants) + if err != nil { + return err + } + primeOver3High := primeOver3HighFelt.ToBigInt() + primeOver2HighFelt, err := ids.GetConst("PRIME_OVER_2_HIGH", constants) + if err != nil { + return err + } + primeOver2High := primeOver2HighFelt.ToBigInt() + // Fetch ids variables + aFelt, err := ids.GetFelt("a", vm) + if err != nil { + return err + } + a := aFelt.ToBigInt() + bFelt, err := ids.GetFelt("b", vm) + if err != nil { + return err + } + b := bFelt.ToBigInt() + rangeCheckPtr, err := ids.GetRelocatable("range_check_ptr", vm) + if err != nil { + return err + } + // Hint Logic + cairoPrime, _ := new(big.Int).SetString(CAIRO_PRIME_HEX, 0) + halfPrime := new(big.Int).Div(cairoPrime, new(big.Int).SetUint64(2)) + thirdOfPrime := new(big.Int).Div(cairoPrime, new(big.Int).SetUint64(3)) + if a.Cmp(b) == 1 { + return errors.Errorf("Assertion failed, %v, is not less or equal to %v", a, b) + } + arc1 := new(big.Int).Sub(b, a) + arc2 := new(big.Int).Sub(new(big.Int).Sub(cairoPrime, (big.NewInt(1))), b) + + // Split lengthsAndIndices array into lenght & idxs array and mantain the same order between them + lengths := []*big.Int{a, arc1, arc2} + idxs := []int{0, 1, 2} + // Sort lengths & idxs by lengths + for i := 0; i < 3; i++ { + for j := i; j > 0 && lengths[j-1].Cmp(lengths[j]) == 1; j-- { + lengths[j], lengths[j-1] = lengths[j-1], lengths[j] + idxs[j], idxs[j-1] = idxs[j-1], idxs[j] + } + } + + if lengths[0].Cmp(thirdOfPrime) == 1 || lengths[1].Cmp(halfPrime) == 1 { + return errors.Errorf("Arc too big, %v must be <= %v and %v <= %v", lengths[0], thirdOfPrime, lengths[1], halfPrime) + } + excluded := idxs[2] + scopes.AssignOrUpdateVariable("excluded", excluded) + q_0, r_0 := new(big.Int).DivMod(lengths[0], primeOver3High, primeOver3High) + q_1, r_1 := new(big.Int).DivMod(lengths[1], primeOver2High, primeOver2High) + + // Insert values into range_check_ptr + data := []MaybeRelocatable{ + *NewMaybeRelocatableFelt(FeltFromBigInt(r_0)), + *NewMaybeRelocatableFelt(FeltFromBigInt(q_0)), + *NewMaybeRelocatableFelt(FeltFromBigInt(r_1)), + *NewMaybeRelocatableFelt(FeltFromBigInt(q_1)), + } + _, err = vm.Segments.LoadData(rangeCheckPtr, &data) + + return err +} + +// "memory[ap] = 1 if excluded != 0 else 0" +func assertLeFeltExcluded0(vm *VirtualMachine, scopes *ExecutionScopes) error { + // Fetch scope var + excludedAny, err := scopes.Get("excluded") + if err != nil { + return err + } + excluded, ok := excludedAny.(int) + if !ok { + return errors.New("excluded not in scope") + } + if excluded == 0 { + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltZero())) + } + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltOne())) +} + +// "memory[ap] = 1 if excluded != 1 else 0" +func assertLeFeltExcluded1(vm *VirtualMachine, scopes *ExecutionScopes) error { + // Fetch scope var + excludedAny, err := scopes.Get("excluded") + if err != nil { + return err + } + excluded, ok := excludedAny.(int) + if !ok { + return errors.New("excluded not in scope") + } + if excluded == 1 { + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltZero())) + } + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltOne())) +} + +// "assert excluded == 2" +func assertLeFeltExcluded2(vm *VirtualMachine, scopes *ExecutionScopes) error { + // Fetch scope var + excludedAny, err := scopes.Get("excluded") + if err != nil { + return err + } + excluded, ok := excludedAny.(int) + if !ok { + return errors.New("excluded not in scope") + } + if excluded != 2 { + return errors.New("Assertion Failed: excluded == 2") + } + + return nil +} + // Implements hint: // // from starkware.cairo.common.math_utils import as_int diff --git a/pkg/hints/math_hints_test.go b/pkg/hints/math_hints_test.go index 641a64f2..888e695e 100644 --- a/pkg/hints/math_hints_test.go +++ b/pkg/hints/math_hints_test.go @@ -9,6 +9,7 @@ import ( . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" . "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" + . "github.com/lambdaclass/cairo-vm.go/pkg/types" . "github.com/lambdaclass/cairo-vm.go/pkg/vm" "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" . "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" @@ -395,6 +396,145 @@ func TestSqrtOk(t *testing.T) { } } +func TestAssertLeFeltOk(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + vm.Segments.AddSegment() + scopes := NewExecutionScopes() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "a": {NewMaybeRelocatableFelt(FeltOne())}, + "b": {NewMaybeRelocatableFelt(FeltFromUint64(2))}, + "range_check_ptr": {NewMaybeRelocatableRelocatable(NewRelocatable(1, 0))}, + }, + vm, + ) + constants := SetupConstantsForTest(map[string]Felt{ + "PRIME_OVER_3_HIGH": FeltFromHex("4000000000000088000000000000001"), + "PRIME_OVER_2_HIGH": FeltFromHex("2AAAAAAAAAAAAB05555555555555556"), + }, + &idsManager, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: ASSERT_LE_FELT, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, &constants, scopes) + if err != nil { + t.Errorf("ASSERT_LE_FELT hint failed with error: %s", err) + } +} + +func TestAssertLeFeltExcluded0Zero(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + scopes := NewExecutionScopes() + scopes.AssignOrUpdateVariable("excluded", int(0)) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Code: ASSERT_LE_FELT_EXCLUDED_0, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("ASSERT_LE_FELT_EXCLUDED_0 hint test failed with error %s", err) + } + // Check the value of memory[ap] + val, err := vm.Segments.Memory.GetFelt(vm.RunContext.Ap) + if err != nil || !val.IsZero() { + t.Error("Wrong/No value inserted into ap") + } +} + +func TestAssertLeFeltExcluded0One(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + scopes := NewExecutionScopes() + scopes.AssignOrUpdateVariable("excluded", int(1)) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Code: ASSERT_LE_FELT_EXCLUDED_0, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("ASSERT_LE_FELT_EXCLUDED_0 hint test failed with error %s", err) + } + // Check the value of memory[ap] + val, err := vm.Segments.Memory.GetFelt(vm.RunContext.Ap) + if err != nil || val != FeltOne() { + t.Error("Wrong/No value inserted into ap") + } +} + +func TestAssertLeFeltExcluded1Zero(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + scopes := NewExecutionScopes() + scopes.AssignOrUpdateVariable("excluded", int(1)) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Code: ASSERT_LE_FELT_EXCLUDED_1, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("ASSERT_LE_FELT_EXCLUDED_1 hint test failed with error %s", err) + } + // Check the value of memory[ap] + val, err := vm.Segments.Memory.GetFelt(vm.RunContext.Ap) + if err != nil || !val.IsZero() { + t.Error("Wrong/No value inserted into ap") + } +} + +func TestAssertLeFeltExcluded1One(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + scopes := NewExecutionScopes() + scopes.AssignOrUpdateVariable("excluded", int(0)) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Code: ASSERT_LE_FELT_EXCLUDED_1, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("ASSERT_LE_FELT_EXCLUDED_1 hint test failed with error %s", err) + } + // Check the value of memory[ap] + val, err := vm.Segments.Memory.GetFelt(vm.RunContext.Ap) + if err != nil || val != FeltOne() { + t.Error("Wrong/No value inserted into ap") + } +} + +func TestAssertLeFeltExcluded2Ok(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + scopes := NewExecutionScopes() + scopes.AssignOrUpdateVariable("excluded", int(2)) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Code: ASSERT_LE_FELT_EXCLUDED_2, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("ASSERT_LE_FELT_EXCLUDED_2 hint test failed with error %s", err) + } +} + +func TestAssertLeFeltExcluded2Err(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + scopes := NewExecutionScopes() + scopes.AssignOrUpdateVariable("excluded", int(0)) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Code: ASSERT_LE_FELT_EXCLUDED_2, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err == nil { + t.Errorf("ASSERT_LE_FELT_EXCLUDED_2 hint test should have failed") + } +} func TestAssert250BitHintSuccess(t *testing.T) { vm := NewVirtualMachine() vm.Segments.AddSegment() diff --git a/pkg/lambdaworks/lambdaworks.go b/pkg/lambdaworks/lambdaworks.go index e4ffd4b9..d66e7b24 100644 --- a/pkg/lambdaworks/lambdaworks.go +++ b/pkg/lambdaworks/lambdaworks.go @@ -283,6 +283,18 @@ func (f Felt) ToBigInt() *big.Int { return new(big.Int).SetBytes(f.ToBeBytes()[:32]) } +func FeltFromBigInt(n *big.Int) Felt { + // Perform modulo prime + prime, _ := new(big.Int).SetString(CAIRO_PRIME_HEX, 0) + if n.Cmp(prime) != -1 { + n = new(big.Int).Mod(n, prime) + } + bytes := n.Bytes() + var bytes32 [32]byte + copy(bytes32[:], bytes) + return FeltFromLeBytes(&bytes32) +} + const CAIRO_PRIME_HEX = "0x800000000000011000000000000000000000000000000000000000000000001" const SIGNED_FELT_MAX_HEX = "0x400000000000008800000000000000000000000000000000000000000000000" diff --git a/pkg/lambdaworks/lambdaworks_test.go b/pkg/lambdaworks/lambdaworks_test.go index 8116378d..bffb7e7a 100644 --- a/pkg/lambdaworks/lambdaworks_test.go +++ b/pkg/lambdaworks/lambdaworks_test.go @@ -70,6 +70,22 @@ func TestToBigInt(t *testing.T) { } } +func TestFromBigInt(t *testing.T) { + expectedFelt := lambdaworks.FeltFromUint64(26) + bigInt := new(big.Int).SetUint64(26) + if !reflect.DeepEqual(lambdaworks.FeltFromBigInt(bigInt), expectedFelt) { + t.Errorf("TestToBigInt failed. Expected: %v, Got: %v", 26, lambdaworks.FeltFromBigInt(bigInt)) + } +} + +func TestFromBigIntPrime(t *testing.T) { + expectedFelt := lambdaworks.FeltFromDecString("0") + bigInt, _ := new(big.Int).SetString(lambdaworks.CAIRO_PRIME_HEX, 0) + if !reflect.DeepEqual(lambdaworks.FeltFromBigInt(bigInt), expectedFelt) { + t.Errorf("TestToBigInt failed. Expected: PRIME, Got: %v", lambdaworks.FeltFromBigInt(bigInt)) + } +} + func TestToSignedNegative(t *testing.T) { felt := lambdaworks.FeltFromDecString("-1") bigInt := felt.ToSigned() diff --git a/pkg/vm/cairo_run/cairo_run_test.go b/pkg/vm/cairo_run/cairo_run_test.go index 325d75db..12fabd3e 100644 --- a/pkg/vm/cairo_run/cairo_run_test.go +++ b/pkg/vm/cairo_run/cairo_run_test.go @@ -210,6 +210,14 @@ func TestSqrtHint(t *testing.T) { } } +func TestAssertLeFelt(t *testing.T) { + cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, Layout: "all_cairo", ProofMode: false} + _, err := cairo_run.CairoRun("../../../cairo_programs/assert_le_felt.json", cairoRunConfig) + if err != nil { + t.Errorf("Program execution failed with error: %s", err) + } +} + func TestMathCmp(t *testing.T) { cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, Layout: "all_cairo", ProofMode: false} _, err := cairo_run.CairoRun("../../../cairo_programs/math_cmp.json", cairoRunConfig)