Skip to content

Commit

Permalink
uint256 hints (#279)
Browse files Browse the repository at this point in the history
* draft uint256 add

* implement hint uint256_add

* add tests

* add hint uint256_add_low

* implement split_64 hint

* change location of uint256 struct

* implement auxiliar methods to ids manager to insert u256 structs

* implement uint256sqrt hint

* add unit test sqrt

* fix unit test

* add unit tests uint256_sqrt

* implement hint uint256_signed_nn

* add tests

* implement UINT256_UNSIGNED_DIV_REM hint

* add tests

* implement hint and test

* implement hint

* add test and improve commit

* fix test

* add integration tests

* fix unit and integration tests

* improve code·

* add uint256 utils

* add test

* fix comments

* improve ToString method on Uint256

* fix test

* fix hint

* add extra test used to debug error

* add hint uint256_sub

* add newline
  • Loading branch information
toni-calvin authored Oct 5, 2023
1 parent 76d65f6 commit e5fcc25
Show file tree
Hide file tree
Showing 13 changed files with 1,488 additions and 3 deletions.
115 changes: 115 additions & 0 deletions cairo_programs/uint256.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
%builtins range_check

from starkware.cairo.common.uint256 import (
Uint256,
uint256_add,
split_64,
uint256_sqrt,
uint256_signed_nn,
uint256_unsigned_div_rem,
uint256_mul,
uint256_mul_div_mod
)
from starkware.cairo.common.alloc import alloc

func fill_array{range_check_ptr: felt}(
array: Uint256*, base: Uint256, step: Uint256, array_length: felt, iterator: felt
) {
if (iterator == array_length) {
return ();
}
let (res, carry_high) = uint256_add(step, base);
let (sqrt) = uint256_sqrt(res);

assert array[iterator] = sqrt;
return fill_array(array, base, array[iterator], array_length, iterator + 1);
}

func main{range_check_ptr: felt}() {
let x: Uint256 = Uint256(5, 2);
let y = Uint256(3, 7);
let (res, carry_high) = uint256_add(x, y);
assert res.low = 8;
assert res.high = 9;
assert carry_high = 0;

let (low, high) = split_64(850981239023189021389081239089023);
assert low = 7249717543555297151;
assert high = 46131785404667;

let (root) = uint256_sqrt(Uint256(17, 7));
assert root = Uint256(48805497317890012913, 0);

let (signed_nn) = uint256_signed_nn(Uint256(5, 2));
assert signed_nn = 1;
let (p) = uint256_signed_nn(Uint256(1, 170141183460469231731687303715884105728));
assert p = 0;
let (q) = uint256_signed_nn(Uint256(1, 170141183460469231731687303715884105727));
assert q = 1;

let (a_quotient, a_remainder) = uint256_unsigned_div_rem(Uint256(89, 72), Uint256(3, 7));
assert a_quotient = Uint256(10, 0);
assert a_remainder = Uint256(59, 2);

let (b_quotient, b_remainder) = uint256_unsigned_div_rem(
Uint256(-3618502788666131213697322783095070105282824848410658236509717448704103809099, 2),
Uint256(5, 2),
);
assert b_quotient = Uint256(1, 0);
assert b_remainder = Uint256(340282366920938463463374607431768211377, 0);

let (c_quotient, c_remainder) = uint256_unsigned_div_rem(
Uint256(340282366920938463463374607431768211455, 340282366920938463463374607431768211455),
Uint256(1, 0),
);

assert c_quotient = Uint256(340282366920938463463374607431768211455, 340282366920938463463374607431768211455);
assert c_remainder = Uint256(0, 0);

let (a_quotient_low, a_quotient_high, a_remainder) = uint256_mul_div_mod(
Uint256(89, 72),
Uint256(3, 7),
Uint256(107, 114),
);
assert a_quotient_low = Uint256(143276786071974089879315624181797141668, 4);
assert a_quotient_high = Uint256(0, 0);
assert a_remainder = Uint256(322372768661941702228460154409043568767, 101);

let (b_quotient_low, b_quotient_high, b_remainder) = uint256_mul_div_mod(
Uint256(-3618502788666131213697322783095070105282824848410658236509717448704103809099, 2),
Uint256(1, 1),
Uint256(5, 2),
);
assert b_quotient_low = Uint256(170141183460469231731687303715884105688, 1);
assert b_quotient_high = Uint256(0, 0);
assert b_remainder = Uint256(170141183460469231731687303715884105854, 1);

let (c_quotient_low, c_quotient_high, c_remainder) = uint256_mul_div_mod(
Uint256(340281070833283907490476236129005105807, 340282366920938463463374607431768211455),
Uint256(2447157533618445569039502, 0),
Uint256(0, 1),
);

assert c_quotient_low = Uint256(340282366920938463454053728725133866491, 2447157533618445569039501);
assert c_quotient_high = Uint256(0, 0);
assert c_remainder = Uint256(326588112914912836985603897252688572242, 0);

let (mult_low_a, mult_high_a) = uint256_mul(Uint256(59, 2), Uint256(10, 0));
assert mult_low_a = Uint256(590, 20);
assert mult_high_a = Uint256(0, 0);

let (mult_low_b: Uint256, mult_high_b: Uint256) = uint256_mul(
Uint256(271442546951262198976322048597925888860, 0),
Uint256(271442546951262198976322048597925888860, 0),
);
assert mult_low_b = Uint256(
42047520920204780886066537579778623760, 216529163594619381764978757921136443390
);
assert mult_high_b = Uint256(0, 0);

let array_length = 100;
let (sum_array: Uint256*) = alloc();
fill_array(sum_array, Uint256(57, 8), Uint256(17, 7), array_length, 0);

return ();
}
151 changes: 151 additions & 0 deletions cairo_programs/uint256_integration_tests.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
%builtins range_check bitwise

from starkware.cairo.common.uint256 import (
Uint256,
uint256_add,
split_64,
uint256_sqrt,
uint256_signed_nn,
uint256_unsigned_div_rem,
uint256_mul,
uint256_or,
uint256_reverse_endian,
)
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin

func fill_array(array_start: felt*, base: felt, step: felt, iter: felt, last: felt) -> () {
if (iter == last) {
return ();
}
assert array_start[iter] = base + step;
return fill_array(array_start, base + step, step, iter + 1, last);
}

func fill_uint256_array{range_check_ptr: felt}(
array: Uint256*, base: Uint256, step: Uint256, array_len: felt, iterator: felt
) {
if (iterator == array_len) {
return ();
}
let (res: Uint256, carry_high: felt) = uint256_add(step, base);

assert array[iterator] = res;
return fill_uint256_array(array, base, array[iterator], array_len, iterator + 1);
}

func test_sqrt{range_check_ptr}(
base_array: Uint256*, new_array: Uint256*, iter: felt, last: felt
) -> () {
alloc_locals;

if (iter == last) {
return ();
}

let res: Uint256 = uint256_sqrt(base_array[iter]);
assert new_array[iter] = res;

return test_sqrt(base_array, new_array, iter + 1, last);
}

func test_signed_nn{range_check_ptr}(
base_array: Uint256*, new_array: felt*, iter: felt, last: felt
) -> () {
alloc_locals;

if (iter == last) {
return ();
}

let res: felt = uint256_signed_nn(base_array[iter]);
assert res = 1;
assert new_array[iter] = res;

return test_signed_nn(base_array, new_array, iter + 1, last);
}

func test_unsigned_div_rem{range_check_ptr}(
base_array: Uint256*, new_array: Uint256*, iter: felt, last: felt
) -> () {
alloc_locals;

if (iter == last) {
return ();
}

let (quotient: Uint256, remainder: Uint256) = uint256_unsigned_div_rem(
base_array[iter], Uint256(7, 8)
);
assert new_array[(iter * 2)] = quotient;
assert new_array[(iter * 2) + 1] = remainder;

return test_unsigned_div_rem(base_array, new_array, iter + 1, last);
}

func test_split_64{range_check_ptr}(
base_array: felt*, new_array: felt*, iter: felt, last: felt
) -> () {
alloc_locals;

if (iter == last) {
return ();
}

let (low: felt, high: felt) = split_64(base_array[iter]);
assert new_array[(iter * 2)] = low;
assert new_array[(iter * 2) + 1] = high;
return test_split_64(base_array, new_array, iter + 1, last);
}

func test_integration{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
base_array: Uint256*, new_array: Uint256*, iter: felt, last: felt
) -> () {
alloc_locals;

if (iter == last) {
return ();
}

let (add: Uint256, carry_high: felt) = uint256_add(base_array[iter], base_array[iter + 1]);
let (quotient: Uint256, remainder: Uint256) = uint256_unsigned_div_rem(add, Uint256(5, 3));
let (low: Uint256, high: Uint256) = uint256_mul(quotient, remainder);
let (bitwise_or: Uint256) = uint256_or(low, high);
let (reverse_endian: Uint256) = uint256_reverse_endian(bitwise_or);
let (result: Uint256) = uint256_sqrt(reverse_endian);

assert new_array[iter] = result;
return test_integration(base_array, new_array, iter + 1, last);
}

func run_tests{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(array_len: felt) -> () {
alloc_locals;
let (uint256_array: Uint256*) = alloc();
fill_uint256_array(uint256_array, Uint256(57, 8), Uint256(57, 101), array_len, 0);

let (array_sqrt: Uint256*) = alloc();
test_sqrt(uint256_array, array_sqrt, 0, array_len);

let (array_signed_nn: felt*) = alloc();
test_signed_nn(uint256_array, array_signed_nn, 0, array_len);

let (array_unsigned_div_rem: Uint256*) = alloc();
test_unsigned_div_rem(uint256_array, array_unsigned_div_rem, 0, array_len);

let (felt_array: felt*) = alloc();
fill_array(felt_array, 0, 3, 0, array_len);

let (array_split_64: felt*) = alloc();
test_split_64(felt_array, array_split_64, 0, array_len);

let (array_test_integration: Uint256*) = alloc();
test_integration(uint256_array, array_test_integration, 0, array_len - 1);

return ();
}

func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() {
run_tests(10);

return ();
}
12 changes: 12 additions & 0 deletions cairo_programs/uint256_root.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
%builtins range_check bitwise
from starkware.cairo.common.uint256 import (
Uint256,
uint256_sqrt,
)
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin

func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() {
let n = Uint256(0, 157560248172239344387757911110183813120);
let res = uint256_sqrt(n);
return ();
}
28 changes: 28 additions & 0 deletions pkg/hints/hint_codes/uint256_hint_codes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package hint_codes

const UINT256_ADD = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0\nsum_high = ids.a.high + ids.b.high + ids.carry_low\nids.carry_high = 1 if sum_high >= ids.SHIFT else 0"
const UINT256_ADD_LOW = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0"
const SPLIT_64 = "ids.low = ids.a & ((1<<64) - 1)\nids.high = ids.a >> 64"
const UINT256_SQRT = "from starkware.python.math_utils import isqrt\nn = (ids.n.high << 128) + ids.n.low\nroot = isqrt(n)\nassert 0 <= root < 2 ** 128\nids.root.low = root\nids.root.high = 0"
const UINT256_SQRT_FELT = "from starkware.python.math_utils import isqrt\nn = (ids.n.high << 128) + ids.n.low\nroot = isqrt(n)\nassert 0 <= root < 2 ** 128\nids.root = root;"
const UINT256_SIGNED_NN = "memory[ap] = 1 if 0 <= (ids.a.high % PRIME) < 2 ** 127 else 0"
const UINT256_UNSIGNED_DIV_REM = "a = (ids.a.high << 128) + ids.a.low\ndiv = (ids.div.high << 128) + ids.div.low\nquotient, remainder = divmod(a, div)\n\nids.quotient.low = quotient & ((1 << 128) - 1)\nids.quotient.high = quotient >> 128\nids.remainder.low = remainder & ((1 << 128) - 1)\nids.remainder.high = remainder >> 128"
const UINT256_EXPANDED_UNSIGNED_DIV_REM = "a = (ids.a.high << 128) + ids.a.low\ndiv = (ids.div.b23 << 128) + ids.div.b01\nquotient, remainder = divmod(a, div)\n\nids.quotient.low = quotient & ((1 << 128) - 1)\nids.quotient.high = quotient >> 128\nids.remainder.low = remainder & ((1 << 128) - 1)\nids.remainder.high = remainder >> 128"
const UINT256_MUL_DIV_MOD = "a = (ids.a.high << 128) + ids.a.low\nb = (ids.b.high << 128) + ids.b.low\ndiv = (ids.div.high << 128) + ids.div.low\nquotient, remainder = divmod(a * b, div)\n\nids.quotient_low.low = quotient & ((1 << 128) - 1)\nids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)\nids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)\nids.quotient_high.high = quotient >> 384\nids.remainder.low = remainder & ((1 << 128) - 1)\nids.remainder.high = remainder >> 128"
const UINT256_SUB = `def split(num: int, num_bits_shift: int = 128, length: int = 2):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
num = num >> num_bits_shift
return tuple(a)
def pack(z, num_bits_shift: int = 128) -> int:
limbs = (z.low, z.high)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
a = pack(ids.a)
b = pack(ids.b)
res = (a - b)%2**256
res_split = split(res)
ids.res.low = res_split[0]
ids.res.high = res_split[1]`
20 changes: 20 additions & 0 deletions pkg/hints/hint_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,26 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any,
return splitInt(data.Ids, vm)
case SPLIT_INT_ASSERT_RANGE:
return splitIntAssertRange(data.Ids, vm)
case UINT256_ADD:
return uint256Add(data.Ids, vm, false)
case UINT256_ADD_LOW:
return uint256Add(data.Ids, vm, true)
case UINT256_SUB:
return uint256Sub(data.Ids, vm)
case SPLIT_64:
return split64(data.Ids, vm)
case UINT256_SQRT:
return uint256Sqrt(data.Ids, vm, false)
case UINT256_SQRT_FELT:
return uint256Sqrt(data.Ids, vm, true)
case UINT256_SIGNED_NN:
return uint256SignedNN(data.Ids, vm)
case UINT256_UNSIGNED_DIV_REM:
return uint256UnsignedDivRem(data.Ids, vm)
case UINT256_EXPANDED_UNSIGNED_DIV_REM:
return uint256ExpandedUnsignedDivRem(data.Ids, vm)
case UINT256_MUL_DIV_MOD:
return uint256MulDivMod(data.Ids, vm)
case DIV_MOD_N_PACKED_DIVMOD_V1:
return divModNPackedDivMod(data.Ids, vm, execScopes)
case DIV_MOD_N_PACKED_DIVMOD_EXTERNAL_N:
Expand Down
29 changes: 29 additions & 0 deletions pkg/hints/hint_utils/ids_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,22 @@ func (ids *IdsManager) GetFelt(name string, vm *VirtualMachine) (lambdaworks.Fel
return felt, nil
}

func (ids *IdsManager) GetUint256(name string, vm *VirtualMachine) (Uint256, error) {
lowAddr, err := ids.GetAddr(name, vm)
if err != nil {
return Uint256{}, err
}
low, err := vm.Segments.Memory.GetFelt(lowAddr)
if err != nil {
return Uint256{}, err
}
high, err := vm.Segments.Memory.GetFelt(lowAddr.AddUint(1))
if err != nil {
return Uint256{}, err
}
return Uint256{Low: low, High: high}, nil
}

// Returns the value of an identifier as a Relocatable
func (ids *IdsManager) GetRelocatable(name string, vm *VirtualMachine) (Relocatable, error) {
val, err := ids.Get(name, vm)
Expand Down Expand Up @@ -217,6 +233,19 @@ func (ids *IdsManager) InsertStructField(name string, field_off uint, value *May
return vm.Segments.Memory.Insert(addr.AddUint(field_off), value)
}

// Inserts Uint256 value into an ids field (given the identifier is a Uint256)
func (ids *IdsManager) InsertUint256(name string, val Uint256, vm *VirtualMachine) error {
baseAddr, err := ids.GetAddr(name, vm)
if err != nil {
return err
}
err = vm.Segments.Memory.Insert(baseAddr, NewMaybeRelocatableFelt(val.Low))
if err != nil {
return err
}
return vm.Segments.Memory.Insert(baseAddr.AddUint(1), NewMaybeRelocatableFelt(val.High))
}

// Inserts value into the address of the given identifier
func insertIdsFromReference(value *MaybeRelocatable, reference *HintReference, apTracking parser.ApTrackingData, vm *VirtualMachine) error {
addr, ok := getAddressFromReference(reference, apTracking, vm)
Expand Down
Loading

0 comments on commit e5fcc25

Please sign in to comment.