diff --git a/cairo_programs/div_mod_n.cairo b/cairo_programs/div_mod_n.cairo new file mode 100644 index 00000000..4dbe3c82 --- /dev/null +++ b/cairo_programs/div_mod_n.cairo @@ -0,0 +1,129 @@ +%builtins range_check + +from starkware.cairo.common.cairo_secp.bigint import BigInt3, nondet_bigint3, BASE, bigint_mul +from starkware.cairo.common.cairo_secp.constants import BETA, N0, N1, N2 + +// Source: https://github.com/myBraavos/efficient-secp256r1/blob/73cca4d53730cb8b2dcf34e36c7b8f34b96b3230/src/secp256r1/signature.cairo + +// Computes a * b^(-1) modulo the size of the elliptic curve (N). +// +// Prover assumptions: +// * All the limbs of a are in the range (-2 ** 210.99, 2 ** 210.99). +// * All the limbs of b are in the range (-2 ** 124.99, 2 ** 124.99). +// * b is in the range [0, 2 ** 256). +// +// Soundness assumptions: +// * The limbs of a are in the range (-2 ** 249, 2 ** 249). +// * The limbs of b are in the range (-2 ** 159.83, 2 ** 159.83). +func div_mod_n{range_check_ptr}(a: BigInt3, b: BigInt3) -> (res: BigInt3) { + %{ + from starkware.cairo.common.cairo_secp.secp_utils import N, pack + from starkware.python.math_utils import div_mod, safe_div + + a = pack(ids.a, PRIME) + b = pack(ids.b, PRIME) + value = res = div_mod(a, b, N) + %} + let (res) = nondet_bigint3(); + + %{ value = k_plus_one = safe_div(res * b - a, N) + 1 %} + let (k_plus_one) = nondet_bigint3(); + let k = BigInt3(d0=k_plus_one.d0 - 1, d1=k_plus_one.d1, d2=k_plus_one.d2); + + let (res_b) = bigint_mul(res, b); + let n = BigInt3(N0, N1, N2); + let (k_n) = bigint_mul(k, n); + + // We should now have res_b = k_n + a. Since the numbers are in unreduced form, + // we should handle the carry. + + tempvar carry1 = (res_b.d0 - k_n.d0 - a.d0) / BASE; + assert [range_check_ptr + 0] = carry1 + 2 ** 127; + + tempvar carry2 = (res_b.d1 - k_n.d1 - a.d1 + carry1) / BASE; + assert [range_check_ptr + 1] = carry2 + 2 ** 127; + + tempvar carry3 = (res_b.d2 - k_n.d2 - a.d2 + carry2) / BASE; + assert [range_check_ptr + 2] = carry3 + 2 ** 127; + + tempvar carry4 = (res_b.d3 - k_n.d3 + carry3) / BASE; + assert [range_check_ptr + 3] = carry4 + 2 ** 127; + + assert res_b.d4 - k_n.d4 + carry4 = 0; + + let range_check_ptr = range_check_ptr + 4; + + return (res=res); +} + +func div_mod_n_alt{range_check_ptr}(a: BigInt3, b: BigInt3) -> (res: BigInt3) { + // just used to import N + %{ + from starkware.cairo.common.cairo_secp.secp_utils import N, pack + from starkware.python.math_utils import div_mod, safe_div + + a = pack(ids.a, PRIME) + b = pack(ids.b, PRIME) + value = res = div_mod(a, b, N) + %} + + %{ + from starkware.cairo.common.cairo_secp.secp_utils import pack + from starkware.python.math_utils import div_mod, safe_div + + a = pack(ids.a, PRIME) + b = pack(ids.b, PRIME) + value = res = div_mod(a, b, N) + %} + let (res) = nondet_bigint3(); + + %{ value = k_plus_one = safe_div(res * b - a, N) + 1 %} + let (k_plus_one) = nondet_bigint3(); + let k = BigInt3(d0=k_plus_one.d0 - 1, d1=k_plus_one.d1, d2=k_plus_one.d2); + + let (res_b) = bigint_mul(res, b); + let n = BigInt3(N0, N1, N2); + let (k_n) = bigint_mul(k, n); + + tempvar carry1 = (res_b.d0 - k_n.d0 - a.d0) / BASE; + assert [range_check_ptr + 0] = carry1 + 2 ** 127; + + tempvar carry2 = (res_b.d1 - k_n.d1 - a.d1 + carry1) / BASE; + assert [range_check_ptr + 1] = carry2 + 2 ** 127; + + tempvar carry3 = (res_b.d2 - k_n.d2 - a.d2 + carry2) / BASE; + assert [range_check_ptr + 2] = carry3 + 2 ** 127; + + tempvar carry4 = (res_b.d3 - k_n.d3 + carry3) / BASE; + assert [range_check_ptr + 3] = carry4 + 2 ** 127; + + assert res_b.d4 - k_n.d4 + carry4 = 0; + + let range_check_ptr = range_check_ptr + 4; + + return (res=res); +} + +func test_div_mod_n{range_check_ptr: felt}() { + let a: BigInt3 = BigInt3(100, 99, 98); + let b: BigInt3 = BigInt3(10, 9, 8); + + let (res) = div_mod_n(a, b); + + assert res = BigInt3( + 3413472211745629263979533, 17305268010345238170172332, 11991751872105858217578135 + ); + + // test alternative hint + let (res_alt) = div_mod_n_alt(a, b); + + assert res_alt = res; + + return (); +} + +func main{range_check_ptr: felt}() { + test_div_mod_n(); + + return (); +} diff --git a/pkg/builtins/ec_op.go b/pkg/builtins/ec_op.go index 0e973a0a..9053c25c 100644 --- a/pkg/builtins/ec_op.go +++ b/pkg/builtins/ec_op.go @@ -5,7 +5,6 @@ import ( "math/big" "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" - "github.com/lambdaclass/cairo-vm.go/pkg/math_utils" "github.com/lambdaclass/cairo-vm.go/pkg/utils" "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" "github.com/pkg/errors" @@ -263,7 +262,7 @@ func LineSlope(point_a PartialSumB, point_b DoublePointB, prime big.Int) (big.In n := new(big.Int).Sub(&point_a.Y, &point_b.Y) m := new(big.Int).Sub(&point_a.X, &point_b.X) - z, err := math_utils.DivMod(n, m, &prime) + z, err := utils.DivMod(n, m, &prime) if err != nil { return big.Int{}, err } @@ -299,7 +298,7 @@ func EcDoubleSlope(point DoublePointB, alpha big.Int, prime big.Int) (big.Int, e n.Add(n, &alpha) m := new(big.Int).Mul(&point.Y, big.NewInt(2)) - z, err := math_utils.DivMod(n, m, &prime) + z, err := utils.DivMod(n, m, &prime) if err != nil { return big.Int{}, err diff --git a/pkg/hints/hint_codes/signature_hint_codes.go b/pkg/hints/hint_codes/signature_hint_codes.go new file mode 100644 index 00000000..37b37051 --- /dev/null +++ b/pkg/hints/hint_codes/signature_hint_codes.go @@ -0,0 +1,21 @@ +package hint_codes + +const DIV_MOD_N_PACKED_DIVMOD_V1 = `from starkware.cairo.common.cairo_secp.secp_utils import N, pack +from starkware.python.math_utils import div_mod, safe_div + +a = pack(ids.a, PRIME) +b = pack(ids.b, PRIME) +value = res = div_mod(a, b, N)` + +const DIV_MOD_N_PACKED_DIVMOD_EXTERNAL_N = `from starkware.cairo.common.cairo_secp.secp_utils import pack +from starkware.python.math_utils import div_mod, safe_div + +a = pack(ids.a, PRIME) +b = pack(ids.b, PRIME) +value = res = div_mod(a, b, N)` + +const DIV_MOD_N_SAFE_DIV = "value = k = safe_div(res * b - a, N)" + +const DIV_MOD_N_SAFE_DIV_PLUS_ONE = "value = k_plus_one = safe_div(res * b - a, N) + 1" + +const XS_SAFE_DIV = "value = k = safe_div(res * s - x, N)" diff --git a/pkg/hints/hint_processor.go b/pkg/hints/hint_processor.go index ccac7bd2..9685de6e 100644 --- a/pkg/hints/hint_processor.go +++ b/pkg/hints/hint_processor.go @@ -188,6 +188,16 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any, return splitInt(data.Ids, vm) case SPLIT_INT_ASSERT_RANGE: return splitIntAssertRange(data.Ids, vm) + case DIV_MOD_N_PACKED_DIVMOD_V1: + return divModNPackedDivMod(data.Ids, vm, execScopes) + case DIV_MOD_N_PACKED_DIVMOD_EXTERNAL_N: + return divModNPackedDivModExternalN(data.Ids, vm, execScopes) + case XS_SAFE_DIV: + return divModNSafeDiv(data.Ids, execScopes, "x", "s", false) + case DIV_MOD_N_SAFE_DIV: + return divModNSafeDiv(data.Ids, execScopes, "a", "b", false) + case DIV_MOD_N_SAFE_DIV_PLUS_ONE: + return divModNSafeDiv(data.Ids, execScopes, "a", "b", true) case VERIFY_ZERO_EXTERNAL_SECP: return verifyZeroWithExternalConst(*vm, *execScopes, data.Ids) case FAST_EC_ADD_ASSIGN_NEW_X: diff --git a/pkg/hints/hint_utils/bigint_utils.go b/pkg/hints/hint_utils/bigint_utils.go index 9a900443..2fac8054 100644 --- a/pkg/hints/hint_utils/bigint_utils.go +++ b/pkg/hints/hint_utils/bigint_utils.go @@ -96,15 +96,14 @@ func BigInt3FromBaseAddr(addr Relocatable, name string, vm *VirtualMachine) (Big } func BigInt3FromVarName(name string, ids IdsManager, vm *VirtualMachine) (BigInt3, error) { - bigIntAddr, err := ids.GetAddr(name, vm) - if err != nil { - return BigInt3{}, err - } + limbs, err := limbsFromVarName(3, name, ids, vm) + return BigInt3{Limbs: limbs}, err +} - bigInt, err := BigInt3FromBaseAddr(bigIntAddr, name, vm) - if err != nil { - return BigInt3{}, err - } +// Uint384 + +type Uint384 = BigInt3 - return bigInt, err +func Uint384FromVarName(name string, ids IdsManager, vm *VirtualMachine) (Uint384, error) { + return BigInt3FromVarName(name, ids, vm) } diff --git a/pkg/hints/hint_utils/secp_utils.go b/pkg/hints/hint_utils/secp_utils.go index a125823a..bec3b650 100644 --- a/pkg/hints/hint_utils/secp_utils.go +++ b/pkg/hints/hint_utils/secp_utils.go @@ -46,7 +46,7 @@ func Bigint3Split(integer big.Int) ([]big.Int, error) { for i := 0; i < 3; i++ { canonicalRepr[i] = *new(big.Int).And(&num, BASE_MINUS_ONE()) - num.Rsh(&num, 86) + num = *new(big.Int).Rsh(&num, 86) } if num.Cmp(big.NewInt(0)) != 0 { return nil, errors.New("HintError SecpSplitOutOfRange") diff --git a/pkg/hints/math_hints.go b/pkg/hints/math_hints.go index 03b239c3..a577131d 100644 --- a/pkg/hints/math_hints.go +++ b/pkg/hints/math_hints.go @@ -7,8 +7,8 @@ import ( . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" . "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" - . "github.com/lambdaclass/cairo-vm.go/pkg/math_utils" . "github.com/lambdaclass/cairo-vm.go/pkg/types" + . "github.com/lambdaclass/cairo-vm.go/pkg/utils" . "github.com/lambdaclass/cairo-vm.go/pkg/vm" . "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" "github.com/pkg/errors" diff --git a/pkg/hints/signature_hints.go b/pkg/hints/signature_hints.go new file mode 100644 index 00000000..37b57f7f --- /dev/null +++ b/pkg/hints/signature_hints.go @@ -0,0 +1,84 @@ +package hints + +import ( + "math/big" + + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + . "github.com/lambdaclass/cairo-vm.go/pkg/types" + "github.com/lambdaclass/cairo-vm.go/pkg/utils" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm" +) + +func divModNPacked(ids IdsManager, vm *VirtualMachine, scopes *ExecutionScopes, n *big.Int) error { + a, err := Uint384FromVarName("a", ids, vm) + if err != nil { + return err + } + b, err := Uint384FromVarName("b", ids, vm) + if err != nil { + return err + } + packedA := a.Pack86() + packedB := b.Pack86() + + val, err := utils.DivMod(&packedA, &packedB, n) + if err != nil { + return err + } + + scopes.AssignOrUpdateVariable("a", packedA) + scopes.AssignOrUpdateVariable("b", packedB) + scopes.AssignOrUpdateVariable("value", *val) + scopes.AssignOrUpdateVariable("res", *val) + + return nil +} + +func divModNPackedDivMod(ids IdsManager, vm *VirtualMachine, scopes *ExecutionScopes) error { + n, _ := new(big.Int).SetString("115792089237316195423570985008687907852837564279074904382605163141518161494337", 10) + scopes.AssignOrUpdateVariable("N", *n) + return divModNPacked(ids, vm, scopes, n) +} + +func divModNPackedDivModExternalN(ids IdsManager, vm *VirtualMachine, scopes *ExecutionScopes) error { + n, err := FetchScopeVar[big.Int]("N", scopes) + if err != nil { + return err + } + return divModNPacked(ids, vm, scopes, &n) +} + +func divModNSafeDiv(ids IdsManager, scopes *ExecutionScopes, aAlias string, bAlias string, addOne bool) error { + // Fetch scope variables + a, err := FetchScopeVar[big.Int](aAlias, scopes) + if err != nil { + return err + } + + b, err := FetchScopeVar[big.Int](bAlias, scopes) + if err != nil { + return err + } + + res, err := FetchScopeVar[big.Int]("res", scopes) + if err != nil { + return err + } + + n, err := FetchScopeVar[big.Int]("N", scopes) + if err != nil { + return err + } + + // Hint logic + value, err := utils.SafeDivBig(new(big.Int).Sub(new(big.Int).Mul(&res, &b), &a), &n) + if err != nil { + return err + } + if addOne { + value = new(big.Int).Add(value, big.NewInt(1)) + } + // Update scope + scopes.AssignOrUpdateVariable("value", *value) + return nil +} diff --git a/pkg/hints/signature_hints_test.go b/pkg/hints/signature_hints_test.go new file mode 100644 index 00000000..4c3de522 --- /dev/null +++ b/pkg/hints/signature_hints_test.go @@ -0,0 +1,159 @@ +package hints_test + +import ( + "math/big" + "testing" + + . "github.com/lambdaclass/cairo-vm.go/pkg/hints" + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_codes" + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + . "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" + . "github.com/lambdaclass/cairo-vm.go/pkg/types" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" +) + +func TestDivModNPackedDivMod(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "a": { + NewMaybeRelocatableFelt(FeltFromUint64(10)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + }, + "b": { + NewMaybeRelocatableFelt(FeltFromUint64(2)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + }, + }, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: DIV_MOD_N_PACKED_DIVMOD_V1, + }) + scopes := NewExecutionScopes() + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("DIV_MOD_N_PACKED_DIVMOD_V1 hint test failed with error %s", err) + } + // Check result in scope + expectedRes := big.NewInt(5) + + res, err := FetchScopeVar[big.Int]("res", scopes) + if err != nil || res.Cmp(expectedRes) != 0 { + t.Error("Wrong/No scope value res") + } + + val, err := FetchScopeVar[big.Int]("value", scopes) + if err != nil || val.Cmp(expectedRes) != 0 { + t.Error("Wrong/No scope var value") + } +} + +func TestDivModNPackedDivModExternalN(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "a": { + NewMaybeRelocatableFelt(FeltFromUint64(20)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + }, + "b": { + NewMaybeRelocatableFelt(FeltFromUint64(2)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + }, + }, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: DIV_MOD_N_PACKED_DIVMOD_EXTERNAL_N, + }) + scopes := NewExecutionScopes() + scopes.AssignOrUpdateVariable("N", *big.NewInt(7)) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("DIV_MOD_N_PACKED_DIVMOD_EXTERNAL_N hint test failed with error %s", err) + } + // Check result in scope + expectedRes := big.NewInt(3) + + res, err := FetchScopeVar[big.Int]("res", scopes) + if err != nil || res.Cmp(expectedRes) != 0 { + t.Error("Wrong/No scope value res") + } + + val, err := FetchScopeVar[big.Int]("value", scopes) + if err != nil || val.Cmp(expectedRes) != 0 { + t.Error("Wrong/No scope var value") + } +} + +func TestDivModSafeDivOk(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{}, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: DIV_MOD_N_SAFE_DIV, + }) + scopes := NewExecutionScopes() + scopes.AssignOrUpdateVariable("N", *big.NewInt(5)) + scopes.AssignOrUpdateVariable("a", *big.NewInt(10)) + scopes.AssignOrUpdateVariable("b", *big.NewInt(30)) + scopes.AssignOrUpdateVariable("res", *big.NewInt(2)) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("DIV_MOD_N_SAFE_DIV hint test failed with error %s", err) + } + // Check result in scope + expectedValue := big.NewInt(10) // (2 * 30 - 10) / 5 = 10 + + val, err := FetchScopeVar[big.Int]("value", scopes) + if err != nil || val.Cmp(expectedValue) != 0 { + t.Error("Wrong/No scope value val") + } +} + +func TestDivModSafeDivPlusOneOk(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{}, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: DIV_MOD_N_SAFE_DIV_PLUS_ONE, + }) + scopes := NewExecutionScopes() + scopes.AssignOrUpdateVariable("N", *big.NewInt(5)) + scopes.AssignOrUpdateVariable("a", *big.NewInt(10)) + scopes.AssignOrUpdateVariable("b", *big.NewInt(30)) + scopes.AssignOrUpdateVariable("res", *big.NewInt(2)) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("DIV_MOD_N_SAFE_DIV_PLUS_ONE hint test failed with error %s", err) + } + // Check result in scope + expectedValue := big.NewInt(11) // (2 * 30 - 10) / 5 + 1 = 11 + + val, err := FetchScopeVar[big.Int]("value", scopes) + if err != nil || val.Cmp(expectedValue) != 0 { + t.Error("Wrong/No scope value val") + } +} diff --git a/pkg/math_utils/utils.go b/pkg/math_utils/utils.go deleted file mode 100644 index 282ba4a3..00000000 --- a/pkg/math_utils/utils.go +++ /dev/null @@ -1,27 +0,0 @@ -package math_utils - -import ( - "github.com/pkg/errors" - "math/big" -) - -// Finds a nonnegative integer x < p such that (m * x) % p == n. -func DivMod(n *big.Int, m *big.Int, p *big.Int) (*big.Int, error) { - a := new(big.Int) - gcd := new(big.Int) - gcd.GCD(a, nil, m, p) - - if gcd.Cmp(big.NewInt(1)) != 0 { - return nil, errors.Errorf("gcd(%s, %s) != 1", m, p) - } - - return n.Mul(n, a).Mod(n, p), nil -} - -func ISqrt(x *big.Int) (*big.Int, error) { - if x.Sign() == -1 { - return nil, errors.Errorf("Expected x: %s to be non-negative", x) - } - res := new(big.Int) - return res.Sqrt(x), nil -} diff --git a/pkg/math_utils/utils_test.go b/pkg/math_utils/utils_test.go deleted file mode 100644 index c4eee853..00000000 --- a/pkg/math_utils/utils_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package math_utils_test - -import ( - "math/big" - "testing" - - . "github.com/lambdaclass/cairo-vm.go/pkg/math_utils" -) - -func TestDivModOk(t *testing.T) { - a := new(big.Int) - b := new(big.Int) - prime := new(big.Int) - expected := new(big.Int) - - a.SetString("11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788", 10) - b.SetString("4020711254448367604954374443741161860304516084891705811279711044808359405970", 10) - prime.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) - expected.SetString("2904750555256547440469454488220756360634457312540595732507835416669695939476", 10) - - num, err := DivMod(a, b, prime) - if err != nil { - t.Errorf("DivMod failed with error: %s", err) - } - if num.Cmp(expected) != 0 { - t.Errorf("Expected result: %s to be equal to %s", num, expected) - } -} - -func TestDivModMZeroFail(t *testing.T) { - a := new(big.Int) - b := new(big.Int) - prime := new(big.Int) - - a.SetString("11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788", 10) - prime.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) - - _, err := DivMod(a, b, prime) - if err == nil { - t.Errorf("DivMod expected to failed with gcd != 1") - } -} - -func TestDivModMEqPFail(t *testing.T) { - a := new(big.Int) - b := new(big.Int) - prime := new(big.Int) - - a.SetString("11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788", 10) - b.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) - prime.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) - - _, err := DivMod(a, b, prime) - if err == nil { - t.Errorf("DivMod expected to failed with gcd != 1") - } -} - -func TestIsSqrtOk(t *testing.T) { - x := new(big.Int) - y := new(big.Int) - x.SetString("4573659632505831259480", 10) - y.Mul(x, x) - - sqr_y, err := ISqrt(y) - if err != nil { - t.Errorf("ISqrt failed with error: %s", err) - } - if x.Cmp(sqr_y) != 0 { - t.Errorf("Failed to get square root of x^2, x: %s", x) - } -} - -func TestCalculateIsqrtA(t *testing.T) { - x := new(big.Int) - x.SetString("81", 10) - sqrt, err := ISqrt(x) - if err != nil { - t.Error("ISqrt failed") - } - - expected := new(big.Int) - expected.SetString("9", 10) - - if sqrt.Cmp(expected) != 0 { - t.Errorf("ISqrt failed, expected %d, got %d", expected, sqrt) - } -} - -func TestCalculateIsqrtB(t *testing.T) { - x := new(big.Int) - x.SetString("4573659632505831259480", 10) - square := new(big.Int) - square = square.Mul(x, x) - - sqrt, err := ISqrt(square) - if err != nil { - t.Error("ISqrt failed") - } - - if sqrt.Cmp(x) != 0 { - t.Errorf("ISqrt failed, expected %d, got %d", x, sqrt) - } -} - -func TestCalculateIsqrtC(t *testing.T) { - x := new(big.Int) - x.SetString("3618502788666131213697322783095070105623107215331596699973092056135872020481", 10) - square := new(big.Int) - square = square.Mul(x, x) - - sqrt, err := ISqrt(square) - if err != nil { - t.Error("ISqrt failed") - } - - if sqrt.Cmp(x) != 0 { - t.Errorf("ISqrt failed, expected %d, got %d", x, sqrt) - } -} - -func TestIsSqrtFail(t *testing.T) { - x := big.NewInt(-1) - - _, err := ISqrt(x) - if err == nil { - t.Errorf("expected ISqrt to fail") - } -} diff --git a/pkg/utils/math_utils.go b/pkg/utils/math_utils.go index b1054f70..7f011a6c 100644 --- a/pkg/utils/math_utils.go +++ b/pkg/utils/math_utils.go @@ -63,3 +63,49 @@ func SafeDivBig(x *big.Int, y *big.Int) (*big.Int, error) { } return q, nil } + +// Finds a nonnegative integer x < p such that (m * x) % p == n. +func DivMod(n *big.Int, m *big.Int, p *big.Int) (*big.Int, error) { + a, _, c := Igcdex(m, p) + if c.Cmp(big.NewInt(1)) != 0 { + return nil, errors.Errorf("Operation failed: divmod(%s, %s, %s), igcdex(%s, %s) != 1 ", n.Text(10), m.Text(10), p.Text(10), m.Text(10), p.Text(10)) + } + return new(big.Int).Mod(new(big.Int).Mul(n, a), p), nil +} + +func Igcdex(a *big.Int, b *big.Int) (*big.Int, *big.Int, *big.Int) { + zero := big.NewInt(0) + one := big.NewInt(1) + switch true { + case a.Cmp(zero) == 0 && b.Cmp(zero) == 0: + return zero, one, zero + case a.Cmp(zero) == 0: + return zero, big.NewInt(int64(a.Sign())), new(big.Int).Abs(b) + case b.Cmp(zero) == 0: + return big.NewInt(int64(a.Sign())), zero, new(big.Int).Abs(a) + default: + xSign := big.NewInt(int64(a.Sign())) + ySign := big.NewInt(int64(b.Sign())) + a = new(big.Int).Abs(a) + b = new(big.Int).Abs(b) + x, y, r, s := big.NewInt(1), big.NewInt(0), big.NewInt(0), big.NewInt(1) + for b.Cmp(zero) != 0 { + q, c := new(big.Int).DivMod(a, b, new(big.Int)) + x = new(big.Int).Sub(x, new(big.Int).Mul(q, r)) + y = new(big.Int).Sub(y, new(big.Int).Mul(q, s)) + + a, b, r, s, x, y = b, c, x, y, r, s + } + + return new(big.Int).Mul(x, xSign), new(big.Int).Mul(y, ySign), a + + } +} + +func ISqrt(x *big.Int) (*big.Int, error) { + if x.Sign() == -1 { + return nil, errors.Errorf("Expected x: %s to be non-negative", x) + } + res := new(big.Int) + return res.Sqrt(x), nil +} diff --git a/pkg/utils/math_utils_test.go b/pkg/utils/math_utils_test.go index e3b2a152..3b308322 100644 --- a/pkg/utils/math_utils_test.go +++ b/pkg/utils/math_utils_test.go @@ -44,3 +44,164 @@ func TestSafeDivBigErrZeroDivison(t *testing.T) { t.Error("SafeDivBig should have failed") } } + +func TestIgcdex11(t *testing.T) { + a := big.NewInt(1) + b := big.NewInt(1) + expectedX, expectedY, expectedZ := big.NewInt(0), big.NewInt(1), big.NewInt(1) + x, y, z := Igcdex(a, b) + if x.Cmp(expectedX) != 0 || y.Cmp(expectedY) != 0 || z.Cmp(expectedZ) != 0 { + t.Error("Wrong values returned by Igcdex") + } +} + +func TestIgcdex00(t *testing.T) { + a := big.NewInt(0) + b := big.NewInt(0) + expectedX, expectedY, expectedZ := big.NewInt(0), big.NewInt(1), big.NewInt(0) + x, y, z := Igcdex(a, b) + if x.Cmp(expectedX) != 0 || y.Cmp(expectedY) != 0 || z.Cmp(expectedZ) != 0 { + t.Error("Wrong values returned by Igcdex") + } +} + +func TestIgcdex10(t *testing.T) { + a := big.NewInt(1) + b := big.NewInt(0) + expectedX, expectedY, expectedZ := big.NewInt(1), big.NewInt(0), big.NewInt(1) + x, y, z := Igcdex(a, b) + if x.Cmp(expectedX) != 0 || y.Cmp(expectedY) != 0 || z.Cmp(expectedZ) != 0 { + t.Error("Wrong values returned by Igcdex") + } +} + +func TestIgcdex46(t *testing.T) { + a := big.NewInt(4) + b := big.NewInt(6) + expectedX, expectedY, expectedZ := big.NewInt(-1), big.NewInt(1), big.NewInt(2) + x, y, z := Igcdex(a, b) + if x.Cmp(expectedX) != 0 || y.Cmp(expectedY) != 0 || z.Cmp(expectedZ) != 0 { + t.Error("Wrong values returned by Igcdex") + } +} + +func TestDivModOk(t *testing.T) { + a := new(big.Int) + b := new(big.Int) + prime := new(big.Int) + expected := new(big.Int) + + a.SetString("11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788", 10) + b.SetString("4020711254448367604954374443741161860304516084891705811279711044808359405970", 10) + prime.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) + expected.SetString("2904750555256547440469454488220756360634457312540595732507835416669695939476", 10) + + num, err := DivMod(a, b, prime) + if err != nil { + t.Errorf("DivMod failed with error: %s", err) + } + if num.Cmp(expected) != 0 { + t.Errorf("Expected result: %s to be equal to %s", num, expected) + } +} + +func TestDivModMZeroFail(t *testing.T) { + a := new(big.Int) + b := new(big.Int) + prime := new(big.Int) + + a.SetString("11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788", 10) + prime.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) + + _, err := DivMod(a, b, prime) + if err == nil { + t.Errorf("DivMod expected to failed with gcd != 1") + } +} + +func TestDivModMEqPFail(t *testing.T) { + a := new(big.Int) + b := new(big.Int) + prime := new(big.Int) + + a.SetString("11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788", 10) + b.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) + prime.SetString("800000000000011000000000000000000000000000000000000000000000001", 16) + + _, err := DivMod(a, b, prime) + if err == nil { + t.Errorf("DivMod expected to failed with gcd != 1") + } +} + +func TestIsSqrtOk(t *testing.T) { + x := new(big.Int) + y := new(big.Int) + x.SetString("4573659632505831259480", 10) + y.Mul(x, x) + + sqr_y, err := ISqrt(y) + if err != nil { + t.Errorf("ISqrt failed with error: %s", err) + } + if x.Cmp(sqr_y) != 0 { + t.Errorf("Failed to get square root of x^2, x: %s", x) + } +} + +func TestCalculateIsqrtA(t *testing.T) { + x := new(big.Int) + x.SetString("81", 10) + sqrt, err := ISqrt(x) + if err != nil { + t.Error("ISqrt failed") + } + + expected := new(big.Int) + expected.SetString("9", 10) + + if sqrt.Cmp(expected) != 0 { + t.Errorf("ISqrt failed, expected %d, got %d", expected, sqrt) + } +} + +func TestCalculateIsqrtB(t *testing.T) { + x := new(big.Int) + x.SetString("4573659632505831259480", 10) + square := new(big.Int) + square = square.Mul(x, x) + + sqrt, err := ISqrt(square) + if err != nil { + t.Error("ISqrt failed") + } + + if sqrt.Cmp(x) != 0 { + t.Errorf("ISqrt failed, expected %d, got %d", x, sqrt) + } +} + +func TestCalculateIsqrtC(t *testing.T) { + x := new(big.Int) + x.SetString("3618502788666131213697322783095070105623107215331596699973092056135872020481", 10) + square := new(big.Int) + square = square.Mul(x, x) + + sqrt, err := ISqrt(square) + if err != nil { + t.Error("ISqrt failed") + } + + if sqrt.Cmp(x) != 0 { + t.Errorf("ISqrt failed, expected %d, got %d", x, sqrt) + } +} + +func TestIsSqrtFail(t *testing.T) { + x := big.NewInt(-1) + + _, err := ISqrt(x) + if err == nil { + t.Errorf("expected ISqrt to fail") + } +} diff --git a/pkg/vm/cairo_run/cairo_run_test.go b/pkg/vm/cairo_run/cairo_run_test.go index 845a3076..f0685afb 100644 --- a/pkg/vm/cairo_run/cairo_run_test.go +++ b/pkg/vm/cairo_run/cairo_run_test.go @@ -329,6 +329,10 @@ func TestSplitIntHintProofMode(t *testing.T) { testProgramProof("split_int", t) } +func TestDivModN(t *testing.T) { + testProgram("div_mod_n", t) +} + func TestEcDoubleAssign(t *testing.T) { testProgram("ec_double_assign", t) }