diff --git a/src/libfuncs/felt252.rs b/src/libfuncs/felt252.rs index 8aa049035..577a35d4e 100644 --- a/src/libfuncs/felt252.rs +++ b/src/libfuncs/felt252.rs @@ -22,7 +22,10 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use melior::{ - dialect::arith::{self, CmpiPredicate}, + dialect::{ + arith::{self, CmpiPredicate}, + cf, + }, ir::{ attribute::IntegerAttribute, r#type::IntegerType, Attribute, Block, Location, Value, ValueLike, @@ -197,8 +200,189 @@ where result } Felt252BinaryOperator::Div => { - // TODO: Implement `felt252_div` and `felt252_div_const`. - todo!("Implement `felt252_div` and `felt252_div_const`") + // The extended euclidean algorithm calculates the greatest common divisor of two integers, + // as well as the bezout coefficients x and y such that for inputs a and b, ax+by=gcd(a,b) + // We use this in felt division to find the modular inverse of a given number + // If a is the number we're trying to find the inverse of, we can do + // ax+y*PRIME=gcd(a,PRIME)=1 => ax = 1 (mod PRIME) + // Hence for input a, we return x + // The input MUST be non-zero + // See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm + let start_block = helper.append_block(Block::new(&[(i512, location)])); + let loop_block = helper.append_block(Block::new(&[ + (i512, location), + (i512, location), + (i512, location), + (i512, location), + ])); + let negative_check_block = helper.append_block(Block::new(&[])); + // Block containing final result + let inverse_result_block = helper.append_block(Block::new(&[(i512, location)])); + // Egcd works by calculating a series of remainders, each the remainder of dividing the previous two + // For the initial setup, r0 = PRIME, r1 = a + // This order is chosen because if we reverse them, then the first iteration will just swap them + let prev_remainder = start_block + .append_operation(arith::constant(context, attr_prime_i512, location)) + .result(0)? + .into(); + let remainder = start_block.argument(0)?.into(); + // Similarly we'll calculate another series which starts 0,1,... and from which we will retrieve the modular inverse of a + let prev_inverse = start_block + .append_operation(arith::constant( + context, + IntegerAttribute::new(0, i512).into(), + location, + )) + .result(0)? + .into(); + let inverse = start_block + .append_operation(arith::constant( + context, + IntegerAttribute::new(1, i512).into(), + location, + )) + .result(0)? + .into(); + start_block.append_operation(cf::br( + loop_block, + &[prev_remainder, remainder, prev_inverse, inverse], + location, + )); + + //---Loop body--- + // Arguments are rem_(i-1), rem, inv_(i-1), inv + let prev_remainder = loop_block.argument(0)?.into(); + let remainder = loop_block.argument(1)?.into(); + let prev_inverse = loop_block.argument(2)?.into(); + let inverse = loop_block.argument(3)?.into(); + + // First calculate q = rem_(i-1)/rem_i, rounded down + let quotient = loop_block + .append_operation(arith::divui(prev_remainder, remainder, location)) + .result(0)? + .into(); + // Then r_(i+1) = r_(i-1) - q * r_i, and inv_(i+1) = inv_(i-1) - q * inv_i + let rem_times_quo = loop_block + .append_operation(arith::muli(remainder, quotient, location)) + .result(0)? + .into(); + let inv_times_quo = loop_block + .append_operation(arith::muli(inverse, quotient, location)) + .result(0)? + .into(); + let next_remainder = loop_block + .append_operation(arith::subi(prev_remainder, rem_times_quo, location)) + .result(0)? + .into(); + let next_inverse = loop_block + .append_operation(arith::subi(prev_inverse, inv_times_quo, location)) + .result(0)? + .into(); + + // If r_(i+1) is 0, then inv_i is the inverse + let zero = loop_block + .append_operation(arith::constant( + context, + IntegerAttribute::new(0, i512).into(), + location, + )) + .result(0)? + .into(); + let next_remainder_eq_zero = loop_block + .append_operation(arith::cmpi( + context, + CmpiPredicate::Eq, + next_remainder, + zero, + location, + )) + .result(0)? + .into(); + loop_block.append_operation(cf::cond_br( + context, + next_remainder_eq_zero, + negative_check_block, + loop_block, + &[], + &[remainder, next_remainder, inverse, next_inverse], + location, + )); + + // egcd sometimes returns a negative number for the inverse, + // in such cases we must simply wrap it around back into [0, PRIME) + // this suffices because |inv_i| <= divfloor(PRIME,2) + let zero = negative_check_block + .append_operation(arith::constant( + context, + IntegerAttribute::new(0, i512).into(), + location, + )) + .result(0)? + .into(); + + let is_negative = negative_check_block + .append_operation(arith::cmpi( + context, + CmpiPredicate::Slt, + inverse, + zero, + location, + )) + .result(0)? + .into(); + // if the inverse is < 0, add PRIME + let prime = negative_check_block + .append_operation(arith::constant(context, attr_prime_i512, location)) + .result(0)? + .into(); + let wrapped_inverse = negative_check_block + .append_operation(arith::addi(inverse, prime, location)) + .result(0)? + .into(); + let inverse = negative_check_block + .append_operation(arith::select( + is_negative, + wrapped_inverse, + inverse, + location, + )) + .result(0)? + .into(); + negative_check_block.append_operation(cf::br( + inverse_result_block, + &[inverse], + location, + )); + + // Div Logic Start + // Fetch operands + let lhs = entry + .append_operation(arith::extui(lhs, i512, location)) + .result(0)? + .into(); + let rhs = entry + .append_operation(arith::extui(rhs, i512, location)) + .result(0)? + .into(); + // Calculate inverse of rhs, callling the inverse implementation's starting block + entry.append_operation(cf::br(start_block, &[rhs], location)); + // Fetch the inverse result from the result block + let inverse = inverse_result_block.argument(0)?.into(); + // Peform lhs * (1/ rhs) + let result = inverse_result_block + .append_operation(arith::muli(lhs, inverse, location)) + .result(0)? + .into(); + // Apply modulo and convert result to felt252 + mlir_asm! { context, inverse_result_block, location => + ; result_mod = "arith.remui"(result, prime) : (i512, i512) -> i512 + ; is_out_of_range = "arith.cmpi"(result, prime) { "predicate" = attr_cmp_uge } : (i512, i512) -> bool_ty + + ; result = "arith.select"(is_out_of_range, result_mod, result) : (bool_ty, i512, i512) -> i512 + ; result = "arith.trunci"(result) : (i512) -> felt252_ty + } + inverse_result_block.append_operation(helper.br(0, &[result], location)); + return Ok(()); } }; diff --git a/tests/cases.rs b/tests/cases.rs index da1f0ef86..863f4ca4b 100644 --- a/tests/cases.rs +++ b/tests/cases.rs @@ -11,7 +11,7 @@ mod common; #[test_case("tests/cases/felt_ops/felt_is_zero.cairo")] #[test_case("tests/cases/felt_ops/mul.cairo")] #[test_case("tests/cases/felt_ops/negation.cairo")] -#[test_case("tests/cases/felt_ops/div.cairo" => ignore["not implemented yet"])] +#[test_case("tests/cases/felt_ops/div.cairo")] // generic tests #[test_case("tests/cases/fib_counter.cairo")] #[test_case("tests/cases/fib_local.cairo")] diff --git a/tests/felt252.rs b/tests/felt252.rs index c11890a03..a3ba22c2e 100644 --- a/tests/felt252.rs +++ b/tests/felt252.rs @@ -1,4 +1,4 @@ -use crate::common::{any_felt, load_cairo, run_native_program, run_vm_program}; +use crate::common::{any_felt, load_cairo, nonzero_felt, run_native_program, run_vm_program}; use cairo_felt::Felt252 as DeprecatedFelt; use cairo_lang_runner::{Arg, SierraCasmRunner}; use cairo_lang_sierra::program::Program; @@ -31,7 +31,11 @@ lazy_static! { } }; - // TODO: Add test program for `felt252_div`. + static ref FELT252_DIV: (String, Program, SierraCasmRunner) = load_cairo! { + fn run_test(lhs: felt252, rhs: felt252) -> felt252 { + felt252_div(lhs, rhs.try_into().unwrap()) + } + }; // TODO: Add test program for `felt252_add_const`. // TODO: Add test program for `felt252_sub_const`. @@ -114,4 +118,24 @@ proptest! { &result_native, )?; } + + #[test] + fn felt_div_proptest(a in any_felt(), b in nonzero_felt()) { + let program = &FELT252_DIV; + let result_vm = run_vm_program( + program, + "run_test", + &[Arg::Value(DeprecatedFelt::from_bytes_be(&a.clone().to_bytes_be())), Arg::Value(DeprecatedFelt::from_bytes_be(&b.clone().to_bytes_be()))], + Some(GAS), + ) + .unwrap(); + let result_native = run_native_program(program, "run_test", &[JITValue::Felt252(a), JITValue::Felt252(b)]); + + compare_outputs( + &program.1, + &program.2.find_function("run_test").unwrap().id, + &result_vm, + &result_native, + )?; + } }