From f5d0483c172e9fb39b6adaa3fb1228963b573b16 Mon Sep 17 00:00:00 2001 From: Thomas Coratger <60488569+tcoratger@users.noreply.github.com> Date: Tue, 11 Jun 2024 22:16:50 +0200 Subject: [PATCH] Refactored sint128 to use the BlockExt trait (#666) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Refactored enum sint128 to use the BlockExt trait * fix --------- Co-authored-by: IƱaki Garay --- src/libfuncs/sint128.rs | 284 +++++++++++++--------------------------- 1 file changed, 89 insertions(+), 195 deletions(-) diff --git a/src/libfuncs/sint128.rs b/src/libfuncs/sint128.rs index 209730273..deaa6e893 100644 --- a/src/libfuncs/sint128.rs +++ b/src/libfuncs/sint128.rs @@ -25,12 +25,7 @@ use melior::{ arith::{self, CmpiPredicate}, cf, llvm, }, - ir::{ - attribute::{DenseI64ArrayAttribute, IntegerAttribute}, - operation::OperationBuilder, - r#type::IntegerType, - Attribute, Block, Location, Value, ValueLike, - }, + ir::{operation::OperationBuilder, r#type::IntegerType, Block, Location, Value, ValueLike}, Context, }; use starknet_types_core::felt::Felt; @@ -122,59 +117,34 @@ pub fn build_operation<'ctx, 'this>( false, ); - let result = entry - .append_operation( - OperationBuilder::new(op_name, location) - .add_operands(&[lhs, rhs]) - .add_results(&[result_type]) - .build()?, - ) - .result(0)? - .into(); - - let op_result = entry - .append_operation(llvm::extract_value( - context, - result, - DenseI64ArrayAttribute::new(context, &[0]), - values_type, - location, - )) - .result(0)? - .into(); + let result = entry.append_op_result( + OperationBuilder::new(op_name, location) + .add_operands(&[lhs, rhs]) + .add_results(&[result_type]) + .build()?, + )?; + + let op_result = entry.extract_value(context, location, result, values_type, 0)?; // Create a const operation to get the 0 value to compare against - let zero_const = entry - .append_operation(arith::constant( - context, - IntegerAttribute::new(values_type, 0.into()).into(), - location, - )) - .result(0)? - .into(); + let zero_const = entry.const_int_from_type(context, location, 0, values_type)?; // Check if the result is positive - let is_positive = entry - .append_operation(arith::cmpi( - context, - CmpiPredicate::Sge, - op_result, - zero_const, - location, - )) - .result(0)? - .into(); + let is_positive = entry.append_op_result(arith::cmpi( + context, + CmpiPredicate::Sge, + op_result, + zero_const, + location, + ))?; // Check overflow flag - let op_overflow = entry - .append_operation(llvm::extract_value( - context, - result, - DenseI64ArrayAttribute::new(context, &[1]), - IntegerType::new(context, 1).into(), - location, - )) - .result(0)? - .into(); + let op_overflow = entry.extract_value( + context, + location, + result, + IntegerType::new(context, 1).into(), + 1, + )?; let block_not_overflow = helper.append_block(Block::new(&[])); let block_overflow = helper.append_block(Block::new(&[])); @@ -245,23 +215,15 @@ pub fn build_is_zero<'ctx, 'this>( ) -> Result<()> { let arg0: Value = entry.argument(0)?.into(); - let op = entry.append_operation(arith::constant( + let const_0 = entry.const_int_from_type(context, location, 0, arg0.r#type())?; + + let condition = entry.append_op_result(arith::cmpi( context, - IntegerAttribute::new(arg0.r#type(), 0).into(), + CmpiPredicate::Eq, + arg0, + const_0, location, - )); - let const_0 = op.result(0)?.into(); - - let condition = entry - .append_operation(arith::cmpi( - context, - CmpiPredicate::Eq, - arg0, - const_0, - location, - )) - .result(0)? - .into(); + ))?; entry.append_operation(helper.cond_br(context, condition, [0, 1], [&[], &[arg0]], location)); @@ -287,10 +249,7 @@ pub fn build_to_felt252<'ctx, 'this>( )?; let value: Value = entry.argument(0)?.into(); - let result = entry - .append_operation(arith::extui(value, felt252_ty, location)) - .result(0)? - .into(); + let result = entry.append_op_result(arith::extui(value, felt252_ty, location))?; entry.append_operation(helper.br(0, &[result], location)); @@ -327,59 +286,32 @@ pub fn build_from_felt252<'ctx, 'this>( &info.branch_signatures()[0].vars[1].ty, )?; - let const_max = entry - .append_operation(arith::constant( - context, - Attribute::parse(context, &format!("{} : {}", i128::MAX, felt252_ty)) - .ok_or(Error::ParseAttributeError)?, - location, - )) - .result(0)? - .into(); - - let const_min = entry - .append_operation(arith::constant( - context, - Attribute::parse(context, &format!("{} : {}", i128::MIN, felt252_ty)) - .ok_or(Error::ParseAttributeError)?, - location, - )) - .result(0)? - .into(); + let const_max = entry.const_int_from_type(context, location, i128::MAX, felt252_ty)?; + let const_min = entry.const_int_from_type(context, location, i128::MIN, felt252_ty)?; let mut block = entry; // make unsigned felt into signed felt // felt > half prime = negative let value = { - let attr_halfprime_i252 = Attribute::parse( + let half_prime: melior::ir::Value = block.const_int_from_type( context, - &format!( - "{} : {}", - metadata - .get::>() - .ok_or(Error::MissingMetadata)? - .prime() - .shr(1), - felt252_ty - ), - ) - .ok_or(Error::ParseAttributeError)?; - let half_prime: melior::ir::Value = block - .append_operation(arith::constant(context, attr_halfprime_i252, location)) - .result(0)? - .into(); - - let is_felt_neg = block - .append_operation(arith::cmpi( - context, - CmpiPredicate::Ugt, - value, - half_prime, - location, - )) - .result(0)? - .into(); + location, + metadata + .get::>() + .ok_or(Error::MissingMetadata)? + .prime() + .shr(1), + felt252_ty, + )?; + + let is_felt_neg = block.append_op_result(arith::cmpi( + context, + CmpiPredicate::Ugt, + value, + half_prime, + location, + ))?; let is_neg_block = helper.append_block(Block::new(&[])); let is_not_neg_block = helper.append_block(Block::new(&[])); @@ -396,45 +328,24 @@ pub fn build_from_felt252<'ctx, 'this>( )); { - let prime = is_neg_block - .append_operation(arith::constant( - context, - Attribute::parse( - context, - &format!( - "{} : {}", - metadata - .get::>() - .ok_or(Error::MissingMetadata)? - .prime(), - felt252_ty - ), - ) - .ok_or(Error::ParseAttributeError)?, - location, - )) - .result(0)? - .into(); - - let mut src_value_is_neg: melior::ir::Value = is_neg_block - .append_operation(arith::subi(prime, value, location)) - .result(0)? - .into(); - - let kneg1 = is_neg_block - .append_operation(arith::constant( - context, - Attribute::parse(context, &format!("-1 : {}", felt252_ty)) - .ok_or(Error::ParseAttributeError)?, - location, - )) - .result(0)? - .into(); - - src_value_is_neg = is_neg_block - .append_operation(arith::muli(src_value_is_neg, kneg1, location)) - .result(0)? - .into(); + let prime = is_neg_block.const_int_from_type( + context, + location, + metadata + .get::>() + .ok_or(Error::MissingMetadata)? + .prime() + .clone(), + felt252_ty, + )?; + + let mut src_value_is_neg: melior::ir::Value = + is_neg_block.append_op_result(arith::subi(prime, value, location))?; + + let kneg1 = is_neg_block.const_int_from_type(context, location, -1, felt252_ty)?; + + src_value_is_neg = + is_neg_block.append_op_result(arith::muli(src_value_is_neg, kneg1, location))?; is_neg_block.append_operation(cf::br(final_block, &[src_value_is_neg], location)); } @@ -446,32 +357,23 @@ pub fn build_from_felt252<'ctx, 'this>( block.argument(0)?.into() }; - let is_smaller_eq = block - .append_operation(arith::cmpi( - context, - CmpiPredicate::Sle, - value, - const_max, - location, - )) - .result(0)? - .into(); + let is_smaller_eq = block.append_op_result(arith::cmpi( + context, + CmpiPredicate::Sle, + value, + const_max, + location, + ))?; - let is_bigger_eq = block - .append_operation(arith::cmpi( - context, - CmpiPredicate::Sge, - value, - const_min, - location, - )) - .result(0)? - .into(); + let is_bigger_eq = block.append_op_result(arith::cmpi( + context, + CmpiPredicate::Sge, + value, + const_min, + location, + ))?; - let is_ok = block - .append_operation(arith::andi(is_smaller_eq, is_bigger_eq, location)) - .result(0)? - .into(); + let is_ok = block.append_op_result(arith::andi(is_smaller_eq, is_bigger_eq, location))?; let block_success = helper.append_block(Block::new(&[])); let block_failure = helper.append_block(Block::new(&[])); @@ -486,10 +388,7 @@ pub fn build_from_felt252<'ctx, 'this>( location, )); - let value = block_success - .append_operation(arith::trunci(value, result_ty, location)) - .result(0)? - .into(); + let value = block_success.append_op_result(arith::trunci(value, result_ty, location))?; block_success.append_operation(helper.br(0, &[range_check, value], location)); block_failure.append_operation(helper.br(1, &[range_check], location)); @@ -513,15 +412,10 @@ pub fn build_diff<'ctx, 'this>( let rhs: Value = entry.argument(2)?.into(); // Check if lhs >= rhs - let is_ge = entry - .append_operation(arith::cmpi(context, CmpiPredicate::Sge, lhs, rhs, location)) - .result(0)? - .into(); - - let result = entry - .append_operation(arith::subi(lhs, rhs, location)) - .result(0)? - .into(); + let is_ge = + entry.append_op_result(arith::cmpi(context, CmpiPredicate::Sge, lhs, rhs, location))?; + + let result = entry.append_op_result(arith::subi(lhs, rhs, location))?; entry.append_operation(helper.cond_br( context,