Skip to content

Commit

Permalink
Implement felt_div libfunc (#389)
Browse files Browse the repository at this point in the history
* Initial progress

* Fix bug

* Kill two blocks with one arith::select

* Apply modulo to result

* Remove comment

* Add proptests
  • Loading branch information
fmoletta authored Dec 18, 2023
1 parent 1b4d1d5 commit 57c04b3
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 6 deletions.
190 changes: 187 additions & 3 deletions src/libfuncs/felt252.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ use cairo_lang_sierra::{
program_registry::ProgramRegistry,
};
use melior::{
dialect::arith::{self, CmpiPredicate},
dialect::{
arith::{self, CmpiPredicate},
cf,
},
ir::{
attribute::IntegerAttribute, r#type::IntegerType, Attribute, Block, Location, Value,
ValueLike,
Expand Down Expand Up @@ -197,8 +200,189 @@ where
result
}
Felt252BinaryOperator::Div => {
// TODO: Implement `felt252_div` and `felt252_div_const`.
todo!("Implement `felt252_div` and `felt252_div_const`")
// The extended euclidean algorithm calculates the greatest common divisor of two integers,
// as well as the bezout coefficients x and y such that for inputs a and b, ax+by=gcd(a,b)
// We use this in felt division to find the modular inverse of a given number
// If a is the number we're trying to find the inverse of, we can do
// ax+y*PRIME=gcd(a,PRIME)=1 => ax = 1 (mod PRIME)
// Hence for input a, we return x
// The input MUST be non-zero
// See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm
let start_block = helper.append_block(Block::new(&[(i512, location)]));
let loop_block = helper.append_block(Block::new(&[
(i512, location),
(i512, location),
(i512, location),
(i512, location),
]));
let negative_check_block = helper.append_block(Block::new(&[]));
// Block containing final result
let inverse_result_block = helper.append_block(Block::new(&[(i512, location)]));
// Egcd works by calculating a series of remainders, each the remainder of dividing the previous two
// For the initial setup, r0 = PRIME, r1 = a
// This order is chosen because if we reverse them, then the first iteration will just swap them
let prev_remainder = start_block
.append_operation(arith::constant(context, attr_prime_i512, location))
.result(0)?
.into();
let remainder = start_block.argument(0)?.into();
// Similarly we'll calculate another series which starts 0,1,... and from which we will retrieve the modular inverse of a
let prev_inverse = start_block
.append_operation(arith::constant(
context,
IntegerAttribute::new(0, i512).into(),
location,
))
.result(0)?
.into();
let inverse = start_block
.append_operation(arith::constant(
context,
IntegerAttribute::new(1, i512).into(),
location,
))
.result(0)?
.into();
start_block.append_operation(cf::br(
loop_block,
&[prev_remainder, remainder, prev_inverse, inverse],
location,
));

//---Loop body---
// Arguments are rem_(i-1), rem, inv_(i-1), inv
let prev_remainder = loop_block.argument(0)?.into();
let remainder = loop_block.argument(1)?.into();
let prev_inverse = loop_block.argument(2)?.into();
let inverse = loop_block.argument(3)?.into();

// First calculate q = rem_(i-1)/rem_i, rounded down
let quotient = loop_block
.append_operation(arith::divui(prev_remainder, remainder, location))
.result(0)?
.into();
// Then r_(i+1) = r_(i-1) - q * r_i, and inv_(i+1) = inv_(i-1) - q * inv_i
let rem_times_quo = loop_block
.append_operation(arith::muli(remainder, quotient, location))
.result(0)?
.into();
let inv_times_quo = loop_block
.append_operation(arith::muli(inverse, quotient, location))
.result(0)?
.into();
let next_remainder = loop_block
.append_operation(arith::subi(prev_remainder, rem_times_quo, location))
.result(0)?
.into();
let next_inverse = loop_block
.append_operation(arith::subi(prev_inverse, inv_times_quo, location))
.result(0)?
.into();

// If r_(i+1) is 0, then inv_i is the inverse
let zero = loop_block
.append_operation(arith::constant(
context,
IntegerAttribute::new(0, i512).into(),
location,
))
.result(0)?
.into();
let next_remainder_eq_zero = loop_block
.append_operation(arith::cmpi(
context,
CmpiPredicate::Eq,
next_remainder,
zero,
location,
))
.result(0)?
.into();
loop_block.append_operation(cf::cond_br(
context,
next_remainder_eq_zero,
negative_check_block,
loop_block,
&[],
&[remainder, next_remainder, inverse, next_inverse],
location,
));

// egcd sometimes returns a negative number for the inverse,
// in such cases we must simply wrap it around back into [0, PRIME)
// this suffices because |inv_i| <= divfloor(PRIME,2)
let zero = negative_check_block
.append_operation(arith::constant(
context,
IntegerAttribute::new(0, i512).into(),
location,
))
.result(0)?
.into();

let is_negative = negative_check_block
.append_operation(arith::cmpi(
context,
CmpiPredicate::Slt,
inverse,
zero,
location,
))
.result(0)?
.into();
// if the inverse is < 0, add PRIME
let prime = negative_check_block
.append_operation(arith::constant(context, attr_prime_i512, location))
.result(0)?
.into();
let wrapped_inverse = negative_check_block
.append_operation(arith::addi(inverse, prime, location))
.result(0)?
.into();
let inverse = negative_check_block
.append_operation(arith::select(
is_negative,
wrapped_inverse,
inverse,
location,
))
.result(0)?
.into();
negative_check_block.append_operation(cf::br(
inverse_result_block,
&[inverse],
location,
));

// Div Logic Start
// Fetch operands
let lhs = entry
.append_operation(arith::extui(lhs, i512, location))
.result(0)?
.into();
let rhs = entry
.append_operation(arith::extui(rhs, i512, location))
.result(0)?
.into();
// Calculate inverse of rhs, callling the inverse implementation's starting block
entry.append_operation(cf::br(start_block, &[rhs], location));
// Fetch the inverse result from the result block
let inverse = inverse_result_block.argument(0)?.into();
// Peform lhs * (1/ rhs)
let result = inverse_result_block
.append_operation(arith::muli(lhs, inverse, location))
.result(0)?
.into();
// Apply modulo and convert result to felt252
mlir_asm! { context, inverse_result_block, location =>
; result_mod = "arith.remui"(result, prime) : (i512, i512) -> i512
; is_out_of_range = "arith.cmpi"(result, prime) { "predicate" = attr_cmp_uge } : (i512, i512) -> bool_ty

; result = "arith.select"(is_out_of_range, result_mod, result) : (bool_ty, i512, i512) -> i512
; result = "arith.trunci"(result) : (i512) -> felt252_ty
}
inverse_result_block.append_operation(helper.br(0, &[result], location));
return Ok(());
}
};

Expand Down
2 changes: 1 addition & 1 deletion tests/cases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod common;
#[test_case("tests/cases/felt_ops/felt_is_zero.cairo")]
#[test_case("tests/cases/felt_ops/mul.cairo")]
#[test_case("tests/cases/felt_ops/negation.cairo")]
#[test_case("tests/cases/felt_ops/div.cairo" => ignore["not implemented yet"])]
#[test_case("tests/cases/felt_ops/div.cairo")]
// generic tests
#[test_case("tests/cases/fib_counter.cairo")]
#[test_case("tests/cases/fib_local.cairo")]
Expand Down
28 changes: 26 additions & 2 deletions tests/felt252.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::common::{any_felt, load_cairo, run_native_program, run_vm_program};
use crate::common::{any_felt, load_cairo, nonzero_felt, run_native_program, run_vm_program};
use cairo_felt::Felt252 as DeprecatedFelt;
use cairo_lang_runner::{Arg, SierraCasmRunner};
use cairo_lang_sierra::program::Program;
Expand Down Expand Up @@ -31,7 +31,11 @@ lazy_static! {
}
};

// TODO: Add test program for `felt252_div`.
static ref FELT252_DIV: (String, Program, SierraCasmRunner) = load_cairo! {
fn run_test(lhs: felt252, rhs: felt252) -> felt252 {
felt252_div(lhs, rhs.try_into().unwrap())
}
};

// TODO: Add test program for `felt252_add_const`.
// TODO: Add test program for `felt252_sub_const`.
Expand Down Expand Up @@ -114,4 +118,24 @@ proptest! {
&result_native,
)?;
}

#[test]
fn felt_div_proptest(a in any_felt(), b in nonzero_felt()) {
let program = &FELT252_DIV;
let result_vm = run_vm_program(
program,
"run_test",
&[Arg::Value(DeprecatedFelt::from_bytes_be(&a.clone().to_bytes_be())), Arg::Value(DeprecatedFelt::from_bytes_be(&b.clone().to_bytes_be()))],
Some(GAS),
)
.unwrap();
let result_native = run_native_program(program, "run_test", &[JITValue::Felt252(a), JITValue::Felt252(b)]);

compare_outputs(
&program.1,
&program.2.find_function("run_test").unwrap().id,
&result_vm,
&result_native,
)?;
}
}

0 comments on commit 57c04b3

Please sign in to comment.