diff --git a/cairo_programs/is_quad_residue.cairo b/cairo_programs/is_quad_residue.cairo new file mode 100644 index 00000000..7b78ec6e --- /dev/null +++ b/cairo_programs/is_quad_residue.cairo @@ -0,0 +1,43 @@ +%builtins output +from starkware.cairo.common.serialize import serialize_word +from starkware.cairo.common.math import is_quad_residue +from starkware.cairo.common.alloc import alloc + +func fill_array(array_start: felt*, iter: felt) -> () { + if (iter == 8) { + return (); + } + assert array_start[iter] = iter; + return fill_array(array_start, iter + 1); +} + +func check_quad_res{output_ptr: felt*}(inputs: felt*, expected: felt*, iter: felt) { + if (iter == 8) { + return (); + } + serialize_word(inputs[iter]); + serialize_word(expected[iter]); + + assert is_quad_residue(inputs[iter]) = expected[iter]; + return check_quad_res(inputs, expected, iter + 1); +} + +func main{output_ptr: felt*}() { + alloc_locals; + let (inputs: felt*) = alloc(); + fill_array(inputs, 0); + + let (expected: felt*) = alloc(); + assert expected[0] = 1; + assert expected[1] = 1; + assert expected[2] = 1; + assert expected[3] = 0; + assert expected[4] = 1; + assert expected[5] = 1; + assert expected[6] = 0; + assert expected[7] = 1; + + check_quad_res(inputs, expected, 0); + + return (); +} diff --git a/cairo_programs/proof_programs/is_quad_residue.cairo b/cairo_programs/proof_programs/is_quad_residue.cairo new file mode 100644 index 00000000..7b78ec6e --- /dev/null +++ b/cairo_programs/proof_programs/is_quad_residue.cairo @@ -0,0 +1,43 @@ +%builtins output +from starkware.cairo.common.serialize import serialize_word +from starkware.cairo.common.math import is_quad_residue +from starkware.cairo.common.alloc import alloc + +func fill_array(array_start: felt*, iter: felt) -> () { + if (iter == 8) { + return (); + } + assert array_start[iter] = iter; + return fill_array(array_start, iter + 1); +} + +func check_quad_res{output_ptr: felt*}(inputs: felt*, expected: felt*, iter: felt) { + if (iter == 8) { + return (); + } + serialize_word(inputs[iter]); + serialize_word(expected[iter]); + + assert is_quad_residue(inputs[iter]) = expected[iter]; + return check_quad_res(inputs, expected, iter + 1); +} + +func main{output_ptr: felt*}() { + alloc_locals; + let (inputs: felt*) = alloc(); + fill_array(inputs, 0); + + let (expected: felt*) = alloc(); + assert expected[0] = 1; + assert expected[1] = 1; + assert expected[2] = 1; + assert expected[3] = 0; + assert expected[4] = 1; + assert expected[5] = 1; + assert expected[6] = 0; + assert expected[7] = 1; + + check_quad_res(inputs, expected, 0); + + return (); +} diff --git a/pkg/hints/hint_processor.go b/pkg/hints/hint_processor.go index 3b932e25..63c78a8a 100644 --- a/pkg/hints/hint_processor.go +++ b/pkg/hints/hint_processor.go @@ -29,7 +29,7 @@ func (p *CairoVmHintProcessor) CompileHint(hintParams *parser.HintParams, refere name = split[len(split)-1] references[name] = ParseHintReference(referenceManager.References[n]) } - ids := NewIdsManager(references, hintParams.FlowTrackingData.APTracking) + ids := NewIdsManager(references, hintParams.FlowTrackingData.APTracking, hintParams.AccessibleScopes) return HintData{Ids: ids, Code: hintParams.Code}, nil } @@ -47,6 +47,8 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any, return is_positive(data.Ids, vm) case ASSERT_NOT_ZERO: return assert_not_zero(data.Ids, vm) + case IS_QUAD_RESIDUE: + return is_quad_residue(data.Ids, vm) case DEFAULT_DICT_NEW: return defaultDictNew(data.Ids, execScopes, vm) case DICT_READ: diff --git a/pkg/hints/hint_utils/ids_manager.go b/pkg/hints/hint_utils/ids_manager.go index ac7e8bd9..48affa6a 100644 --- a/pkg/hints/hint_utils/ids_manager.go +++ b/pkg/hints/hint_utils/ids_manager.go @@ -12,8 +12,9 @@ import ( // Identifier Manager // Provides methods that allow hints to interact with cairo variables given their identifier name type IdsManager struct { - References map[string]HintReference - HintApTracking parser.ApTrackingData + References map[string]HintReference + HintApTracking parser.ApTrackingData + AccessibleScopes []string } func ErrIdsManager(err error) error { @@ -28,13 +29,30 @@ func ErrIdentifierNotFelt(name string) error { return ErrIdsManager(errors.Errorf("Identifier %s is not a Felt", name)) } -func NewIdsManager(references map[string]HintReference, hintApTracking parser.ApTrackingData) IdsManager { +func NewIdsManager(references map[string]HintReference, hintApTracking parser.ApTrackingData, accessibleScopes []string) IdsManager { return IdsManager{ - References: references, - HintApTracking: hintApTracking, + References: references, + HintApTracking: hintApTracking, + AccessibleScopes: accessibleScopes, } } +// Fetches a constant used by the hint +// Searches inner modules first for name-matching constants +func (ids *IdsManager) GetConst(name string, constants *map[string]lambdaworks.Felt) (lambdaworks.Felt, error) { + // Hints should always have accessible scopes + if len(ids.AccessibleScopes) != 0 { + // Accessible scopes are listed from outer to inner + for i := len(ids.AccessibleScopes) - 1; i >= 0; i-- { + constant, ok := (*constants)[ids.AccessibleScopes[i]+"."+name] + if ok { + return constant, nil + } + } + } + return lambdaworks.FeltZero(), errors.Errorf("Missing constant %s", name) +} + // Inserts value into memory given its identifier name func (ids *IdsManager) Insert(name string, value *MaybeRelocatable, vm *VirtualMachine) error { diff --git a/pkg/hints/hint_utils/ids_manager_test.go b/pkg/hints/hint_utils/ids_manager_test.go index 6fd3f234..5dc69b12 100644 --- a/pkg/hints/hint_utils/ids_manager_test.go +++ b/pkg/hints/hint_utils/ids_manager_test.go @@ -285,3 +285,55 @@ func TestIdsManagerGetStructFieldTest(t *testing.T) { t.Errorf("IdsManager.GetStructFieldFelt returned wrong values") } } + +func TestIdsManagerGetConst(t *testing.T) { + ids := IdsManager{ + AccessibleScopes: []string{ + "starkware.cairo.common.math", + "starkware.cairo.common.math.assert_250_bit", + }, + } + upperBound := lambdaworks.FeltFromUint64(250) + constants := map[string]lambdaworks.Felt{ + "starkware.cairo.common.math.assert_250_bit.UPPER_BOUND": upperBound, + } + constant, err := ids.GetConst("UPPER_BOUND", &constants) + if err != nil || constant != upperBound { + t.Errorf("IdsManager.GetConst returned wrong/no constant") + } +} + +func TestIdsManagerGetConstPrioritizeInnerModule(t *testing.T) { + ids := IdsManager{ + AccessibleScopes: []string{ + "starkware.cairo.common.math", + "starkware.cairo.common.math.assert_250_bit", + }, + } + upperBound := lambdaworks.FeltFromUint64(250) + constants := map[string]lambdaworks.Felt{ + "starkware.cairo.common.math.assert_250_bit.UPPER_BOUND": upperBound, + "starkware.cairo.common.math.UPPER_BOUND": lambdaworks.FeltZero(), + } + constant, err := ids.GetConst("UPPER_BOUND", &constants) + if err != nil || constant != upperBound { + t.Errorf("IdsManager.GetConst returned wrong/no constant") + } +} + +func TestIdsManagerGetConstNoMConst(t *testing.T) { + ids := IdsManager{ + AccessibleScopes: []string{ + "starkware.cairo.common.math", + "starkware.cairo.common.math.assert_250_bit", + }, + } + lowerBound := lambdaworks.FeltFromUint64(250) + constants := map[string]lambdaworks.Felt{ + "starkware.cairo.common.math.assert_250_bit.LOWER_BOUND": lowerBound, + } + _, err := ids.GetConst("UPPER_BOUND", &constants) + if err == nil { + t.Errorf("IdsManager.GetConst should have failed") + } +} diff --git a/pkg/hints/hint_utils/testing_utils.go b/pkg/hints/hint_utils/testing_utils.go index ab2470cd..8c7f1f5c 100644 --- a/pkg/hints/hint_utils/testing_utils.go +++ b/pkg/hints/hint_utils/testing_utils.go @@ -1,6 +1,7 @@ package hint_utils import ( + "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" "github.com/lambdaclass/cairo-vm.go/pkg/parser" . "github.com/lambdaclass/cairo-vm.go/pkg/vm" "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" @@ -17,7 +18,7 @@ import ( // All references will be FP-based, so please don't update the value of FP after calling this function, // and make sure that the memory at fp's segment is clear from its current offset onwards func SetupIdsForTest(ids map[string][]*memory.MaybeRelocatable, vm *VirtualMachine) IdsManager { - manager := NewIdsManager(make(map[string]HintReference), parser.ApTrackingData{}) + manager := NewIdsManager(make(map[string]HintReference), parser.ApTrackingData{}, []string{}) base_addr := vm.RunContext.Fp current_offset := 0 for name, elems := range ids { @@ -43,3 +44,14 @@ func SetupIdsForTest(ids map[string][]*memory.MaybeRelocatable, vm *VirtualMachi } return manager } + +// Returns a constants map accoring to the new_constants map received +// Adds a path to each constant and a matching path to the hint's accessible scopes +func SetupConstantsForTest(new_constants map[string]lambdaworks.Felt, ids *IdsManager) map[string]lambdaworks.Felt { + constants := make(map[string]lambdaworks.Felt) + ids.AccessibleScopes = append(ids.AccessibleScopes, "path") + for name, constant := range new_constants { + constants["path."+name] = constant + } + return constants +} diff --git a/pkg/hints/hint_utils/testing_utils_test.go b/pkg/hints/hint_utils/testing_utils_test.go index a9c19945..4875a1a1 100644 --- a/pkg/hints/hint_utils/testing_utils_test.go +++ b/pkg/hints/hint_utils/testing_utils_test.go @@ -108,3 +108,24 @@ func TestSetupIdsForTestStructWithGap(t *testing.T) { t.Error("Failed to insert ids") } } + +func TestSetupConstantsForTest(t *testing.T) { + constA := FeltOne() + constB := FeltZero() + IdsManager := IdsManager{} + constants := SetupConstantsForTest(map[string]Felt{ + "A": constA, + "B": constB, + }, + &IdsManager, + ) + // Check that we can fetch the constants + a, err := IdsManager.GetConst("A", &constants) + if err != nil || a != constA { + t.Error("SetupConstantsForTest wrong/no A") + } + b, err := IdsManager.GetConst("B", &constants) + if err != nil || b != constB { + t.Error("SetupConstantsForTest wrong/no B") + } +} diff --git a/pkg/hints/math_hint_codes.go b/pkg/hints/math_hint_codes.go index fcec4449..553bcde1 100644 --- a/pkg/hints/math_hint_codes.go +++ b/pkg/hints/math_hint_codes.go @@ -6,6 +6,15 @@ const IS_POSITIVE = "from starkware.cairo.common.math_utils import is_positive\n const ASSERT_NOT_ZERO = "from starkware.cairo.common.math_utils import assert_integer\nassert_integer(ids.value)\nassert ids.value % PRIME != 0, f'assert_not_zero failed: {ids.value} = 0.'" +const IS_QUAD_RESIDUE = `from starkware.crypto.signature.signature import FIELD_PRIME +from starkware.python.math_utils import div_mod, is_quad_residue, sqrt + +x = ids.x +if is_quad_residue(x, FIELD_PRIME): + ids.y = sqrt(x, FIELD_PRIME) +else: + ids.y = sqrt(div_mod(x, 3, FIELD_PRIME), FIELD_PRIME)` + const ASSERT_NOT_EQUAL = "from starkware.cairo.lang.vm.relocatable import RelocatableValue\nboth_ints = isinstance(ids.a, int) and isinstance(ids.b, int)\nboth_relocatable = (\n isinstance(ids.a, RelocatableValue) and isinstance(ids.b, RelocatableValue) and\n ids.a.segment_index == ids.b.segment_index)\nassert both_ints or both_relocatable, \\\n f'assert_not_equal failed: non-comparable values: {ids.a}, {ids.b}.'\nassert (ids.a - ids.b) % PRIME != 0, f'assert_not_equal failed: {ids.a} = {ids.b}.'" const SQRT = "from starkware.python.math_utils import isqrt\nvalue = ids.value % PRIME\nassert value < 2 ** 250, f\"value={value} is outside of the range [0, 2**250).\"\nassert 2 ** 250 < PRIME\nids.root = isqrt(value)" diff --git a/pkg/hints/math_hints.go b/pkg/hints/math_hints.go index b303d95c..c56c0a18 100644 --- a/pkg/hints/math_hints.go +++ b/pkg/hints/math_hints.go @@ -3,6 +3,7 @@ package hints import ( "github.com/lambdaclass/cairo-vm.go/pkg/builtins" . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" . "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" . "github.com/lambdaclass/cairo-vm.go/pkg/math_utils" . "github.com/lambdaclass/cairo-vm.go/pkg/vm" @@ -65,6 +66,38 @@ func assert_not_zero(ids IdsManager, vm *VirtualMachine) error { return nil } +// Implements hint:from starkware.cairo.common.math.cairo +// +// %{ +// from starkware.crypto.signature.signature import FIELD_PRIME +// from starkware.python.math_utils import div_mod, is_quad_residue, sqrt +// +// x = ids.x +// if is_quad_residue(x, FIELD_PRIME): +// ids.y = sqrt(x, FIELD_PRIME) +// else: +// ids.y = sqrt(div_mod(x, 3, FIELD_PRIME), FIELD_PRIME) +// +// %} +func is_quad_residue(ids IdsManager, vm *VirtualMachine) error { + x, err := ids.GetFelt("x", vm) + if err != nil { + return err + } + if x.IsZero() || x.IsOne() { + ids.Insert("y", NewMaybeRelocatableFelt(x), vm) + + } else if x.Pow(SignedFeltMaxValue()) == FeltOne() { + num := x.Sqrt() + ids.Insert("y", NewMaybeRelocatableFelt(num), vm) + + } else { + num := (x.Div(lambdaworks.FeltFromUint64(3))).Sqrt() + ids.Insert("y", NewMaybeRelocatableFelt(num), vm) + } + return nil +} + func assert_not_equal(ids IdsManager, vm *VirtualMachine) error { // Extract Ids Variables a, err := ids.Get("a", vm) diff --git a/pkg/hints/memcpy_hints_test.go b/pkg/hints/memcpy_hints_test.go index 4005b51a..941216f3 100644 --- a/pkg/hints/memcpy_hints_test.go +++ b/pkg/hints/memcpy_hints_test.go @@ -7,6 +7,7 @@ import ( . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" . "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" . "github.com/lambdaclass/cairo-vm.go/pkg/types" + . "github.com/lambdaclass/cairo-vm.go/pkg/utils" . "github.com/lambdaclass/cairo-vm.go/pkg/vm" . "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" ) diff --git a/pkg/lambdaworks/lambdaworks.go b/pkg/lambdaworks/lambdaworks.go index 66b40b81..475f2717 100644 --- a/pkg/lambdaworks/lambdaworks.go +++ b/pkg/lambdaworks/lambdaworks.go @@ -144,6 +144,13 @@ func FeltOne() Felt { return fromC(result) } +// Gets the Signed Felt max value: 0x400000000000008800000000000000000000000000000000000000000000000 +func SignedFeltMaxValue() Felt { + var result C.felt_t + C.signed_felt_max_value(&result[0]) + return fromC(result) +} + func (f Felt) IsZero() bool { return f == FeltZero() } @@ -152,6 +159,10 @@ func (f Felt) IsPositive() bool { return !f.IsZero() } +func (f Felt) IsOne() bool { + return f == FeltOne() +} + // Writes the result variable with the sum of a and b felts. func (a Felt) Add(b Felt) Felt { var result C.felt_t @@ -246,6 +257,23 @@ func (a Felt) PowUint(p uint32) Felt { return fromC(result) } +func (a Felt) Pow(p Felt) Felt { + var result C.felt_t + var a_c C.felt_t = a.toC() + var p_c C.felt_t = p.toC() + + C.felt_pow(&a_c[0], &p_c[0], &result[0]) + return fromC(result) +} + +func (a Felt) Sqrt() Felt { + var result C.felt_t + var a_c C.felt_t = a.toC() + + C.felt_sqrt(&a_c[0], &result[0]) + return fromC(result) +} + func (a Felt) Shr(b uint) Felt { var result C.felt_t var a_c C.felt_t = a.toC() diff --git a/pkg/lambdaworks/lambdaworks_test.go b/pkg/lambdaworks/lambdaworks_test.go index 1c049d37..8116378d 100644 --- a/pkg/lambdaworks/lambdaworks_test.go +++ b/pkg/lambdaworks/lambdaworks_test.go @@ -466,6 +466,128 @@ func TestPow3(t *testing.T) { } } +func TestPowFelt(t *testing.T) { + felt_base := lambdaworks.FeltFromUint64(1233) + felt_exp := lambdaworks.FeltFromUint64(1233) + + expected := lambdaworks.FeltFromDecString("3418065535446855313238995939000463244303872344528900201124636596003468607918") + result := felt_base.Pow(felt_exp) + + if expected != result { + t.Errorf("TestPowFelt Failed, expecte: %v, got %v", expected, result) + } + + felt_base = lambdaworks.FeltFromDecString("12383109480418712378780123") + felt_exp = lambdaworks.FeltFromDecString("91872587643897123781098123") + + expected = lambdaworks.FeltFromDecString("2088955439096022421017346644949649198425019274657075865926754962561596407882") + result = felt_base.Pow(felt_exp) + + if expected != result { + t.Errorf("TestPowFelt Failed, expecte: %v, got %v", expected, result) + } + + felt_base = lambdaworks.FeltFromDecString("1480418712378780123123543345665445665445") + felt_exp = lambdaworks.FeltFromDecString("91872587643897345876123781098124353") + + expected = lambdaworks.FeltFromDecString("3250055959035395902088721634924698439245455440785258481507488871970708539723") + result = felt_base.Pow(felt_exp) + + if expected != result { + t.Errorf("TestPowFelt Failed, expecte: %v, got %v", expected, result) + } + + felt_base = lambdaworks.FeltFromDecString("3250055959035395902088721634924698439245455440785258481507488871970708539723") + felt_exp = lambdaworks.FeltFromDecString("2088955439096022421017346644949649198425019274657075865926754962561596407882") + + expected = lambdaworks.FeltFromDecString("2222900320242877003674481253396117682567674359625426155657415083745164507492") + result = felt_base.Pow(felt_exp) + + if expected != result { + t.Errorf("TestPowFelt Failed, expecte: %v, got %v", expected, result) + } + + felt_base = lambdaworks.FeltFromDecString("3") + felt_exp = lambdaworks.FeltFromDecString("1809251394333065606848661391547535052811553607665798349986546028067936010240") + + expected = lambdaworks.FeltFromDecString("3618502788666131213697322783095070105623107215331596699973092056135872020480") + result = felt_base.Pow(felt_exp) + + if expected != result { + t.Errorf("TestPowFelt Failed, expecte: %v, got %v", expected, result) + } + + felt_base = lambdaworks.FeltFromDecString("6") + felt_exp = lambdaworks.FeltFromDecString("1809251394333065606848661391547535052811553607665798349986546028067936010240") + + expected = lambdaworks.FeltFromDecString("3618502788666131213697322783095070105623107215331596699973092056135872020480") + result = felt_base.Pow(felt_exp) + + if expected != result { + t.Errorf("TestPowFelt Failed, expecte: %v, got %v", expected, result) + } +} + +func TestSqrt(t *testing.T) { + + sqrt := lambdaworks.FeltFromDecString("1").Sqrt() + expect_res := lambdaworks.FeltFromDecString("1") + + if sqrt != expect_res { + t.Errorf("TestSqrt Failed, expecte: %v, got %v", expect_res, sqrt) + } + + sqrt = lambdaworks.FeltFromDecString("2").Sqrt() + expect_res = lambdaworks.FeltFromDecString("1120755473020101814179135767224264702961552391386192943129361948990833801454") + + if sqrt != expect_res { + t.Errorf("TestSqrt Failed, expecte: %v, got %v", expect_res, sqrt) + } + + sqrt = lambdaworks.FeltFromDecString("231354855").Sqrt() + expect_res = lambdaworks.FeltFromDecString("1025311277904211196612478135732240927612998008429122495456758581279557012570") + + if sqrt != expect_res { + t.Errorf("TestSqrt Failed, expecte: %v, got %v", expect_res, sqrt) + } + + sqrt = lambdaworks.FeltFromDecString("2837690996375263304037947136281").Sqrt() + expect_res = lambdaworks.FeltFromDecString("1518810120662201067233534392916286105989903317885218522014371199182069054395") + + if sqrt != expect_res { + t.Errorf("TestSqrt Failed, expecte: %v, got %v", expect_res, sqrt) + } + + sqrt = lambdaworks.FeltFromDecString("2412335192444087475798215188730046737082071476887756022673614366244186268173").Sqrt() + expect_res = lambdaworks.FeltFromDecString("1600105265616524426130944162206101590464382512931039575828824593677922684056") + + if sqrt != expect_res { + t.Errorf("TestSqrt Failed, expecte: %v, got %v", expect_res, sqrt) + } + + sqrt = lambdaworks.FeltFromDecString("836397911567565091").Sqrt() + expect_res = lambdaworks.FeltFromDecString("1471326547166706568879530640427725594549306523774764149866072915947254299525") + + if sqrt != expect_res { + t.Errorf("TestSqrt Failed, expecte: %v, got %v", expect_res, sqrt) + } + + sqrt = lambdaworks.FeltFromDecString("1206167596222043737899107594365023368541035738443865566657697352047277454118").Sqrt() + expect_res = lambdaworks.FeltFromDecString("1052329372911162474471895538435386694104976874815914718986386439764768300074") + + if sqrt != expect_res { + t.Errorf("TestSqrt Failed, expecte: %v, got %v", expect_res, sqrt) + } + + sqrt = lambdaworks.FeltFromDecString("1206167596222043737899107594365023368541035738443865566948239979431619043114").Sqrt() + expect_res = lambdaworks.FeltFromDecString("139198744922466627270517589217125805480206233967015957629136270350373167196") + + if sqrt != expect_res { + t.Errorf("TestSqrt Failed, expecte: %v, got %v", expect_res, sqrt) + } + +} + func TestFeltNeg1ToString(t *testing.T) { f_neg_1 := lambdaworks.FeltFromDecString("-1") expected := "-1" @@ -520,3 +642,15 @@ func TestRelocatableToString(t *testing.T) { } } + +func TestSignedMaxValue(t *testing.T) { + + signed_max_value := lambdaworks.SignedFeltMaxValue() + str := signed_max_value.ToHexString() + expect_str := "0x400000000000008800000000000000000000000000000000000000000000000" + + if signed_max_value.ToHexString() != expect_str { + t.Errorf("TestSignedMaxValue Failed, expecte: %s, got %s", expect_str, str) + } + +} diff --git a/pkg/lambdaworks/lib/lambdaworks.h b/pkg/lambdaworks/lib/lambdaworks.h index 52982dd6..573a8825 100644 --- a/pkg/lambdaworks/lib/lambdaworks.h +++ b/pkg/lambdaworks/lib/lambdaworks.h @@ -43,6 +43,8 @@ void zero(felt_t result); /* Gets a felt_t representing 1 */ void one(felt_t result); +void signed_felt_max_value(felt_t result); + /* Writes the result variable with the sum of a and b felts. */ void add(felt_t a, felt_t b, felt_t result); @@ -73,6 +75,11 @@ void felt_shl(felt_t a, uint64_t num, felt_t result); /* writes the result variable with a.pow(num) */ void felt_pow_uint(felt_t a, uint32_t num, felt_t result); +/* writes the result variable with a.pow(exponent) */ +void felt_pow(felt_t a, felt_t p, felt_t result); + +void felt_sqrt(felt_t a, felt_t result); + /* returns the representation of a felt to string */ char *to_signed_felt(felt_t value); diff --git a/pkg/lambdaworks/lib/lambdaworks/src/lib.rs b/pkg/lambdaworks/lib/lambdaworks/src/lib.rs index f7f14612..b772d13d 100644 --- a/pkg/lambdaworks/lib/lambdaworks/src/lib.rs +++ b/pkg/lambdaworks/lib/lambdaworks/src/lib.rs @@ -132,6 +132,11 @@ pub extern "C" fn one(result: Limbs) { felt_to_limbs(Felt::one(), result) } +#[no_mangle] +pub extern "C" fn signed_felt_max_value(result: Limbs) { + felt_to_limbs(Felt::from_bytes_be(&*SIGNED_FELT_MAX.to_bytes_be()).unwrap(), result) +} + #[no_mangle] pub extern "C" fn add(a: Limbs, b: Limbs, result: Limbs) { felt_to_limbs(limbs_to_felt(a) + limbs_to_felt(b), result); @@ -204,6 +209,23 @@ pub extern "C" fn felt_pow_uint(a: Limbs, num: u32, result: Limbs) { felt_to_limbs(res, result) } +#[no_mangle] +pub extern "C" fn felt_pow(a: Limbs, exponent: Limbs, result: Limbs) { + let felt_a = limbs_to_felt(a); + let felt_exponent = limbs_to_felt(exponent).representative(); + let res = felt_a.pow(felt_exponent); + felt_to_limbs(res, result) +} + +#[no_mangle] +pub extern "C" fn felt_sqrt(a: Limbs,result: Limbs) { + let felt_a = limbs_to_felt(a); + + let (root_1, root_2) = felt_a.sqrt().unwrap(); + let res = root_1.representative().min(root_2.representative()); + felt_to_limbs(Felt::from(&res), result) +} + #[no_mangle] pub extern "C" fn to_signed_felt(value: Limbs) -> *mut c_char { let felt = limbs_to_felt(value).representative().to_bytes_le(); @@ -235,7 +257,7 @@ pub unsafe extern "C" fn free_string(ptr: *mut c_char) { pub extern "C" fn felt_shr(a: Limbs, b: usize, result: Limbs) { let felt_a = limbs_to_felt(a).representative(); - let res = felt_a << b; + let res = felt_a >> b; felt_to_limbs(Felt::from(&res), result) } diff --git a/pkg/vm/cairo_run/cairo_run_test.go b/pkg/vm/cairo_run/cairo_run_test.go index 5cf14aff..3182e252 100644 --- a/pkg/vm/cairo_run/cairo_run_test.go +++ b/pkg/vm/cairo_run/cairo_run_test.go @@ -147,6 +147,22 @@ func TestBitwiseRecursionProofMode(t *testing.T) { } } +func TestIsQuadResidueoHint(t *testing.T) { + cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, Layout: "small", ProofMode: false} + _, err := cairo_run.CairoRun("../../../cairo_programs/is_quad_residue.json", cairoRunConfig) + if err != nil { + t.Errorf("Program execution failed with error: %s", err) + } +} + +func TestIsQuadResidueoHintProofMode(t *testing.T) { + cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, Layout: "small", ProofMode: true} + _, err := cairo_run.CairoRun("../../../cairo_programs/proof_programs/is_quad_residue.json", cairoRunConfig) + if err != nil { + t.Errorf("Program execution failed with error: %s", err) + } +} + func TestDict(t *testing.T) { cairoRunConfig := cairo_run.CairoRunConfig{DisableTracePadding: false, Layout: "small", ProofMode: false} _, err := cairo_run.CairoRun("../../../cairo_programs/dict.json", cairoRunConfig)