From 3ab71d9d3a8dd02ea764d7bbd7e4f4e6ab22143d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Javier=20Rodr=C3=ADguez=20Chatruc?= <49622509+jrchatruc@users.noreply.github.com> Date: Tue, 19 Sep 2023 15:29:24 -0300 Subject: [PATCH] Assert 250 Bits and SPLIT_FELT (#233) * [WIP] Assert 250 Bits and SPLIT_FELT * Add tests for ASSERT_250_BIT * Implement SPLIT_FELT * Add tests for SPLIT_FELT * Incorporate the GetConst function * Skip checking for the common library path first * Remove redundant nil check --- .../assert_250_bit_element_array.cairo | 31 +++ cairo_programs/split_felt.cairo | 42 ++++ pkg/hints/hint_codes/math_hint_codes.go | 4 + pkg/hints/hint_processor.go | 4 + pkg/hints/math_hints.go | 96 ++++++++ pkg/hints/math_hints_test.go | 229 ++++++++++++++++++ pkg/vm/cairo_run/cairo_run_test.go | 16 ++ pkg/vm/program.go | 2 +- pkg/vm/program_test.go | 4 +- 9 files changed, 425 insertions(+), 3 deletions(-) create mode 100644 cairo_programs/assert_250_bit_element_array.cairo create mode 100644 cairo_programs/split_felt.cairo diff --git a/cairo_programs/assert_250_bit_element_array.cairo b/cairo_programs/assert_250_bit_element_array.cairo new file mode 100644 index 00000000..b4307e26 --- /dev/null +++ b/cairo_programs/assert_250_bit_element_array.cairo @@ -0,0 +1,31 @@ +%builtins range_check + +from starkware.cairo.common.math import assert_250_bit +from starkware.cairo.common.alloc import alloc + +func assert_250_bit_element_array{range_check_ptr: felt}( + array: felt*, array_length: felt, iterator: felt +) { + if (iterator == array_length) { + return (); + } + assert_250_bit(array[iterator]); + return assert_250_bit_element_array(array, array_length, iterator + 1); +} + +func fill_array(array: felt*, base: felt, step: felt, array_length: felt, iterator: felt) { + if (iterator == array_length) { + return (); + } + assert array[iterator] = base + step * iterator; + return fill_array(array, base, step, array_length, iterator + 1); +} + +func main{range_check_ptr: felt}() { + alloc_locals; + tempvar array_length = 10; + let (array: felt*) = alloc(); + fill_array(array, 70000000000000000000, 300000000000000000, array_length, 0); + assert_250_bit_element_array(array, array_length, 0); + return (); +} diff --git a/cairo_programs/split_felt.cairo b/cairo_programs/split_felt.cairo new file mode 100644 index 00000000..3790297d --- /dev/null +++ b/cairo_programs/split_felt.cairo @@ -0,0 +1,42 @@ +%builtins range_check + +from starkware.cairo.common.math import assert_le +from starkware.cairo.common.math import split_felt + +func split_felt_manual_implemetation{range_check_ptr}(value) -> (high: felt, low: felt) { + // Note: the following code works because PRIME - 1 is divisible by 2**128. + const MAX_HIGH = (-1) / 2 ** 128; + const MAX_LOW = 0; + + // Guess the low and high parts of the integer. + let low = [range_check_ptr]; + let high = [range_check_ptr + 1]; + let range_check_ptr = range_check_ptr + 2; + + %{ + from starkware.cairo.common.math_utils import assert_integer + assert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128 + assert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW + assert_integer(ids.value) + ids.low = ids.value & ((1 << 128) - 1) + ids.high = ids.value >> 128 + %} + + assert value = high * (2 ** 128) + low; + if (high == MAX_HIGH) { + assert_le(low, MAX_LOW); + } else { + assert_le(high, MAX_HIGH - 1); + } + return (high=high, low=low); +} + +func main{range_check_ptr: felt}() { + let (m, n) = split_felt_manual_implemetation(5784800237655953878877368326340059594760); + assert m = 17; + assert n = 8; + let (x, y) = split_felt(5784800237655953878877368326340059594760); + assert x = 17; + assert y = 8; + return (); +} diff --git a/pkg/hints/hint_codes/math_hint_codes.go b/pkg/hints/hint_codes/math_hint_codes.go index 0919825e..e6fb345e 100644 --- a/pkg/hints/hint_codes/math_hint_codes.go +++ b/pkg/hints/hint_codes/math_hint_codes.go @@ -18,3 +18,7 @@ else: const ASSERT_NOT_EQUAL = "from starkware.cairo.lang.vm.relocatable import RelocatableValue\nboth_ints = isinstance(ids.a, int) and isinstance(ids.b, int)\nboth_relocatable = (\n isinstance(ids.a, RelocatableValue) and isinstance(ids.b, RelocatableValue) and\n ids.a.segment_index == ids.b.segment_index)\nassert both_ints or both_relocatable, \\\n f'assert_not_equal failed: non-comparable values: {ids.a}, {ids.b}.'\nassert (ids.a - ids.b) % PRIME != 0, f'assert_not_equal failed: {ids.a} = {ids.b}.'" const SQRT = "from starkware.python.math_utils import isqrt\nvalue = ids.value % PRIME\nassert value < 2 ** 250, f\"value={value} is outside of the range [0, 2**250).\"\nassert 2 ** 250 < PRIME\nids.root = isqrt(value)" + +const ASSERT_250_BITS = "from starkware.cairo.common.math_utils import as_int\n\n# Correctness check.\nvalue = as_int(ids.value, PRIME) % PRIME\nassert value < ids.UPPER_BOUND, f'{value} is outside of the range [0, 2**250).'\n\n# Calculation for the assertion.\nids.high, ids.low = divmod(ids.value, ids.SHIFT)" + +const SPLIT_FELT = "from starkware.cairo.common.math_utils import assert_integer\nassert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128\nassert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW\nassert_integer(ids.value)\nids.low = ids.value & ((1 << 128) - 1)\nids.high = ids.value >> 128" diff --git a/pkg/hints/hint_processor.go b/pkg/hints/hint_processor.go index ccc26810..fb025a07 100644 --- a/pkg/hints/hint_processor.go +++ b/pkg/hints/hint_processor.go @@ -74,6 +74,10 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any, return memcpy_enter_scope(data.Ids, vm, execScopes) case VM_ENTER_SCOPE: return vm_enter_scope(execScopes) + case ASSERT_250_BITS: + return Assert250Bit(data.Ids, vm, constants) + case SPLIT_FELT: + return SplitFelt(data.Ids, vm, constants) default: return errors.Errorf("Unknown Hint: %s", data.Code) } diff --git a/pkg/hints/math_hints.go b/pkg/hints/math_hints.go index c56c0a18..b039b0a2 100644 --- a/pkg/hints/math_hints.go +++ b/pkg/hints/math_hints.go @@ -151,3 +151,99 @@ func sqrt(ids IdsManager, vm *VirtualMachine) error { ids.Insert("root", NewMaybeRelocatableFelt(root_felt), vm) return nil } + +// Implements hint: +// +// from starkware.cairo.common.math_utils import as_int +// # Correctness check. +// value = as_int(ids.value, PRIME) % PRIME +// assert value < ids.UPPER_BOUND, f'{value} is outside of the range [0, 2**250).' +// # Calculation for the assertion. +// ids.high, ids.low = divmod(ids.value, ids.SHIFT) +func Assert250Bit(ids IdsManager, vm *VirtualMachine, constants *map[string]Felt) error { + upperBound, err := ids.GetConst("UPPER_BOUND", constants) + if err != nil { + return err + } + + shift, err := ids.GetConst("SHIFT", constants) + if err != nil { + return err + } + + value, err := ids.GetFelt("value", vm) + + if err != nil { + return err + } + + if Felt.Cmp(value, upperBound) == 1 { + return errors.New("Value outside of 250 bit Range") + } + + high, low := value.DivRem(shift) + + err = ids.Insert("high", NewMaybeRelocatableFelt(high), vm) + if err != nil { + return err + } + + err = ids.Insert("low", NewMaybeRelocatableFelt(low), vm) + if err != nil { + return err + } + + return nil +} + +// Implements hint: +// +// %{ +// from starkware.cairo.common.math_utils import assert_integer +// assert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128 +// assert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW +// assert_integer(ids.value) +// ids.low = ids.value & ((1 << 128) - 1) +// ids.high = ids.value >> 128 +// +// %} +func SplitFelt(ids IdsManager, vm *VirtualMachine, constants *map[string]Felt) error { + maxHigh, err := ids.GetConst("MAX_HIGH", constants) + if err != nil { + return err + } + + maxLow, err := ids.GetConst("MAX_LOW", constants) + if err != nil { + return err + } + + if maxHigh.Bits() > 128 || maxLow.Bits() > 128 { + return errors.New("Assertion Failed: assert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128") + } + + twoToTheOneTwentyEight := lambdaworks.FeltOne().Shl(128) + if lambdaworks.FeltFromDecString("-1") != maxHigh.Mul(twoToTheOneTwentyEight).Add(maxLow) { + return errors.New("Assertion Failed: assert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW") + } + + value, err := ids.GetFelt("value", vm) + if err != nil { + return err + } + + low := value.And(twoToTheOneTwentyEight.Sub(lambdaworks.FeltOne())) + high := value.Shr(128) + + err = ids.Insert("high", NewMaybeRelocatableFelt(high), vm) + if err != nil { + return err + } + + err = ids.Insert("low", NewMaybeRelocatableFelt(low), vm) + if err != nil { + return err + } + + return nil +} diff --git a/pkg/hints/math_hints_test.go b/pkg/hints/math_hints_test.go index 54672499..5f3184ba 100644 --- a/pkg/hints/math_hints_test.go +++ b/pkg/hints/math_hints_test.go @@ -6,6 +6,7 @@ import ( . "github.com/lambdaclass/cairo-vm.go/pkg/hints" . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_codes" . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" . "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" . "github.com/lambdaclass/cairo-vm.go/pkg/vm" . "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" @@ -317,3 +318,231 @@ func TestSqrtOk(t *testing.T) { t.Errorf("Expected sqrt(9) == 3. Got: %v", root) } } + +func TestAssert250BitHintSuccess(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "value": {NewMaybeRelocatableFelt(FeltFromUint64(3))}, + "high": {nil}, + "low": {nil}, + }, + vm, + ) + + hintProcessor := CairoVmHintProcessor{} + constants := SetupConstantsForTest(map[string]Felt{ + "UPPER_BOUND": lambdaworks.FeltFromUint64(10), + "SHIFT": lambdaworks.FeltFromUint64(1), + }, + &idsManager, + ) + + hintData := any(HintData{ + Ids: idsManager, + Code: ASSERT_250_BITS, + }) + + err := hintProcessor.ExecuteHint(vm, &hintData, &constants, nil) + if err != nil { + t.Errorf("ASSERT_250_BIT hint failed with error %s", err) + } + + high, err := idsManager.GetFelt("high", vm) + if err != nil { + t.Errorf("failed to get high: %s", err) + } + + low, err := idsManager.GetFelt("low", vm) + if err != nil { + t.Errorf("failed to get low: %s", err) + } + + if high != FeltFromUint64(3) { + t.Errorf("Expected high == 3. Got: %v", high) + } + + if low != FeltFromUint64(0) { + t.Errorf("Expected low == 0. Got: %v", low) + } +} + +func TestAssert250BitHintFail(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "value": {NewMaybeRelocatableFelt(FeltFromUint64(20))}, + "high": {nil}, + "low": {nil}, + }, + vm, + ) + + hintProcessor := CairoVmHintProcessor{} + constants := SetupConstantsForTest(map[string]Felt{ + "UPPER_BOUND": lambdaworks.FeltFromUint64(10), + "SHIFT": lambdaworks.FeltFromUint64(1), + }, + &idsManager, + ) + + hintData := any(HintData{ + Ids: idsManager, + Code: ASSERT_250_BITS, + }) + + err := hintProcessor.ExecuteHint(vm, &hintData, &constants, nil) + if err == nil { + t.Errorf("ASSERT_250_BIT hint should have failed with Value outside of 250 bit error") + } +} + +func TestSplitFeltAssertPrimeFailure(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "value": {NewMaybeRelocatableFelt(FeltFromUint64(1))}, + "high": {nil}, + "low": {nil}, + }, + vm, + ) + + hintProcessor := CairoVmHintProcessor{} + constants := SetupConstantsForTest(map[string]Felt{ + "MAX_HIGH": lambdaworks.FeltFromHex("0xffffffffffffffffffffffffffffffff"), + "MAX_LOW": lambdaworks.FeltFromHex("0xffffffffffffffffffffffffffffffff"), + }, + &idsManager, + ) + + hintData := any(HintData{ + Ids: idsManager, + Code: SPLIT_FELT, + }) + + err := hintProcessor.ExecuteHint(vm, &hintData, &constants, nil) + if err == nil { + t.Errorf("SPLIT_FELT hint should have failed with assert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW error") + } +} + +func TestSplitFeltAssertMaxHighFailedAssertion(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "value": {NewMaybeRelocatableFelt(FeltFromUint64(1))}, + "high": {nil}, + "low": {nil}, + }, + vm, + ) + + hintProcessor := CairoVmHintProcessor{} + constants := SetupConstantsForTest(map[string]Felt{ + "MAX_HIGH": lambdaworks.FeltFromHex("0xffffffffffffffffffffffffffffffffffff"), + "MAX_LOW": lambdaworks.FeltFromHex("0xffffffffffffffffffffffffffffffff"), + }, + &idsManager, + ) + + hintData := any(HintData{ + Ids: idsManager, + Code: SPLIT_FELT, + }) + + err := hintProcessor.ExecuteHint(vm, &hintData, &constants, nil) + if err == nil { + t.Errorf("SPLIT_FELT hint should have failed with assert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128") + } +} + +func TestSplitFeltAssertMaxLowFailedAssertion(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "value": {NewMaybeRelocatableFelt(FeltFromUint64(1))}, + "high": {nil}, + "low": {nil}, + }, + vm, + ) + + hintProcessor := CairoVmHintProcessor{} + constants := SetupConstantsForTest(map[string]Felt{ + "MAX_HIGH": lambdaworks.FeltFromHex("0xffffffffffffffffffffffffffffffff"), + "MAX_LOW": lambdaworks.FeltFromHex("0xffffffffffffffffffffffffffffffffffff"), + }, + &idsManager, + ) + + hintData := any(HintData{ + Ids: idsManager, + Code: SPLIT_FELT, + }) + + err := hintProcessor.ExecuteHint(vm, &hintData, &constants, nil) + if err == nil { + t.Errorf("SPLIT_FELT hint should have failed with assert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128") + } +} + +func TestSplitFeltSuccess(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + + firstLimb := lambdaworks.FeltFromUint64(1) + secondLimb := lambdaworks.FeltFromUint64(2) + thirdLimb := lambdaworks.FeltFromUint64(3) + fourthLimb := lambdaworks.FeltFromUint64(4) + value := fourthLimb.Or(thirdLimb.Shl(64).Or(secondLimb.Shl(128).Or(firstLimb.Shl(192)))) + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "value": {NewMaybeRelocatableFelt(value)}, + "high": {nil}, + "low": {nil}, + }, + vm, + ) + + hintProcessor := CairoVmHintProcessor{} + constants := SetupConstantsForTest(map[string]Felt{ + "MAX_HIGH": lambdaworks.FeltFromDecString("10633823966279327296825105735305134080"), + "MAX_LOW": lambdaworks.FeltFromUint64(0), + }, + &idsManager, + ) + + hintData := any(HintData{ + Ids: idsManager, + Code: SPLIT_FELT, + }) + + err := hintProcessor.ExecuteHint(vm, &hintData, &constants, nil) + if err != nil { + t.Errorf("SPLIT_FELT hint failed with error %s", err) + } + + high, err := idsManager.GetFelt("high", vm) + if err != nil { + t.Errorf("failed to get high: %s", err) + } + + low, err := idsManager.GetFelt("low", vm) + if err != nil { + t.Errorf("failed to get low: %s", err) + } + + if high != firstLimb.Shl(64).Or(secondLimb) { + t.Errorf("Expected high == 335438970432432812899076431678123043273. Got: %v", high) + } + + if low != thirdLimb.Shl(64).Or(fourthLimb) { + t.Errorf("Expected low == 0. Got: %v", low) + } +} diff --git a/pkg/vm/cairo_run/cairo_run_test.go b/pkg/vm/cairo_run/cairo_run_test.go index 3182e252..c5985bc3 100644 --- a/pkg/vm/cairo_run/cairo_run_test.go +++ b/pkg/vm/cairo_run/cairo_run_test.go @@ -202,3 +202,19 @@ func TestSqrtHint(t *testing.T) { t.Errorf("Program execution failed with error: %s", err) } } + +func TestAssert250BitHint(t *testing.T) { + cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, Layout: "all_cairo", ProofMode: false} + _, err := cairo_run.CairoRun("../../../cairo_programs/assert_250_bit_element_array.json", cairoRunConfig) + if err != nil { + t.Errorf("Program execution failed with error: %s", err) + } +} + +func TestSplitFeltHint(t *testing.T) { + cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, Layout: "all_cairo", ProofMode: false} + _, err := cairo_run.CairoRun("../../../cairo_programs/split_felt.json", cairoRunConfig) + if err != nil { + t.Errorf("Program execution failed with error: %s", err) + } +} diff --git a/pkg/vm/program.go b/pkg/vm/program.go index 125860fb..b72707ee 100644 --- a/pkg/vm/program.go +++ b/pkg/vm/program.go @@ -64,7 +64,7 @@ func DeserializeProgramJson(compiledProgram parser.CompiledJson) Program { func (p *Program) ExtractConstants() map[string]lambdaworks.Felt { constants := make(map[string]lambdaworks.Felt) for name, identifier := range p.Identifiers { - if identifier.Type == "constant" { + if identifier.Type == "const" { constants[name] = identifier.Value } } diff --git a/pkg/vm/program_test.go b/pkg/vm/program_test.go index 95371ce6..bf8d4fff 100644 --- a/pkg/vm/program_test.go +++ b/pkg/vm/program_test.go @@ -34,11 +34,11 @@ func TestExtractConstants(t *testing.T) { }, "A": { Value: lambdaworks.FeltFromUint64(7), - Type: "constant", + Type: "const", }, "B": { Value: lambdaworks.FeltFromUint64(17), - Type: "constant", + Type: "const", }, }, }