Skip to content

Commit

Permalink
Refactored sint64 to use the BlockExt trait (#665)
Browse files Browse the repository at this point in the history
* Refactored enum sint64 to use the BlockExt trait

* fix

---------

Co-authored-by: Iñaki Garay <[email protected]>
  • Loading branch information
tcoratger and igaray authored Jun 11, 2024
1 parent f5d0483 commit 8a3502b
Showing 1 changed file with 92 additions and 207 deletions.
299 changes: 92 additions & 207 deletions src/libfuncs/sint64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,7 @@ use melior::{
arith::{self, CmpiPredicate},
cf, llvm,
},
ir::{
attribute::{DenseI64ArrayAttribute, IntegerAttribute},
operation::OperationBuilder,
r#type::IntegerType,
Attribute, Block, Location, Value, ValueLike,
},
ir::{operation::OperationBuilder, r#type::IntegerType, Block, Location, Value, ValueLike},
Context,
};
use starknet_types_core::felt::Felt;
Expand Down Expand Up @@ -123,59 +118,34 @@ pub fn build_operation<'ctx, 'this>(
false,
);

let result = entry
.append_operation(
OperationBuilder::new(op_name, location)
.add_operands(&[lhs, rhs])
.add_results(&[result_type])
.build()?,
)
.result(0)?
.into();

let op_result = entry
.append_operation(llvm::extract_value(
context,
result,
DenseI64ArrayAttribute::new(context, &[0]),
values_type,
location,
))
.result(0)?
.into();
let result = entry.append_op_result(
OperationBuilder::new(op_name, location)
.add_operands(&[lhs, rhs])
.add_results(&[result_type])
.build()?,
)?;

let op_result = entry.extract_value(context, location, result, values_type, 0)?;

// Create a const operation to get the 0 value to compare against
let zero_const = entry
.append_operation(arith::constant(
context,
IntegerAttribute::new(values_type, 0.into()).into(),
location,
))
.result(0)?
.into();
let zero_const = entry.const_int_from_type(context, location, 0, values_type)?;
// Check if the result is positive
let is_positive = entry
.append_operation(arith::cmpi(
context,
CmpiPredicate::Sge,
op_result,
zero_const,
location,
))
.result(0)?
.into();
let is_positive = entry.append_op_result(arith::cmpi(
context,
CmpiPredicate::Sge,
op_result,
zero_const,
location,
))?;

// Check overflow flag
let op_overflow = entry
.append_operation(llvm::extract_value(
context,
result,
DenseI64ArrayAttribute::new(context, &[1]),
IntegerType::new(context, 1).into(),
location,
))
.result(0)?
.into();
let op_overflow = entry.extract_value(
context,
location,
result,
IntegerType::new(context, 1).into(),
1,
)?;

let block_not_overflow = helper.append_block(Block::new(&[]));
let block_overflow = helper.append_block(Block::new(&[]));
Expand Down Expand Up @@ -246,23 +216,15 @@ pub fn build_is_zero<'ctx, 'this>(
) -> Result<()> {
let arg0: Value = entry.argument(0)?.into();

let op = entry.append_operation(arith::constant(
let const_0 = entry.const_int_from_type(context, location, 0, arg0.r#type())?;

let condition = entry.append_op_result(arith::cmpi(
context,
IntegerAttribute::new(arg0.r#type(), 0).into(),
CmpiPredicate::Eq,
arg0,
const_0,
location,
));
let const_0 = op.result(0)?.into();

let condition = entry
.append_operation(arith::cmpi(
context,
CmpiPredicate::Eq,
arg0,
const_0,
location,
))
.result(0)?
.into();
))?;

entry.append_operation(helper.cond_br(context, condition, [0, 1], [&[], &[arg0]], location));

Expand All @@ -289,19 +251,10 @@ pub fn build_widemul<'ctx, 'this>(
let lhs: Value = entry.argument(0)?.into();
let rhs: Value = entry.argument(1)?.into();

let lhs = entry
.append_operation(arith::extsi(lhs, target_type, location))
.result(0)?
.into();
let rhs = entry
.append_operation(arith::extsi(rhs, target_type, location))
.result(0)?
.into();
let lhs = entry.append_op_result(arith::extsi(lhs, target_type, location))?;
let rhs = entry.append_op_result(arith::extsi(rhs, target_type, location))?;

let result = entry
.append_operation(arith::muli(lhs, rhs, location))
.result(0)?
.into();
let result = entry.append_op_result(arith::muli(lhs, rhs, location))?;

entry.append_operation(helper.br(0, &[result], location));
Ok(())
Expand All @@ -326,10 +279,7 @@ pub fn build_to_felt252<'ctx, 'this>(
)?;
let value: Value = entry.argument(0)?.into();

let result = entry
.append_operation(arith::extui(value, felt252_ty, location))
.result(0)?
.into();
let result = entry.append_op_result(arith::extui(value, felt252_ty, location))?;

entry.append_operation(helper.br(0, &[result], location));

Expand Down Expand Up @@ -366,59 +316,32 @@ pub fn build_from_felt252<'ctx, 'this>(
&info.branch_signatures()[0].vars[1].ty,
)?;

let const_max = entry
.append_operation(arith::constant(
context,
Attribute::parse(context, &format!("{} : {}", i64::MAX, felt252_ty))
.ok_or(Error::ParseAttributeError)?,
location,
))
.result(0)?
.into();

let const_min = entry
.append_operation(arith::constant(
context,
Attribute::parse(context, &format!("{} : {}", i64::MIN, felt252_ty))
.ok_or(Error::ParseAttributeError)?,
location,
))
.result(0)?
.into();
let const_max = entry.const_int_from_type(context, location, i64::MAX, felt252_ty)?;
let const_min = entry.const_int_from_type(context, location, i64::MIN, felt252_ty)?;

let mut block = entry;

// make unsigned felt into signed felt
// felt > half prime = negative
let value = {
let attr_halfprime_i252 = Attribute::parse(
let half_prime: melior::ir::Value = block.const_int_from_type(
context,
&format!(
"{} : {}",
metadata
.get::<PrimeModuloMeta<Felt>>()
.ok_or(Error::MissingMetadata)?
.prime()
.shr(1),
felt252_ty
),
)
.ok_or(Error::ParseAttributeError)?;
let half_prime: melior::ir::Value = block
.append_operation(arith::constant(context, attr_halfprime_i252, location))
.result(0)?
.into();

let is_felt_neg = block
.append_operation(arith::cmpi(
context,
CmpiPredicate::Ugt,
value,
half_prime,
location,
))
.result(0)?
.into();
location,
metadata
.get::<PrimeModuloMeta<Felt>>()
.ok_or(Error::MissingMetadata)?
.prime()
.shr(1),
felt252_ty,
)?;

let is_felt_neg = block.append_op_result(arith::cmpi(
context,
CmpiPredicate::Ugt,
value,
half_prime,
location,
))?;

let is_neg_block = helper.append_block(Block::new(&[]));
let is_not_neg_block = helper.append_block(Block::new(&[]));
Expand All @@ -435,45 +358,24 @@ pub fn build_from_felt252<'ctx, 'this>(
));

{
let prime = is_neg_block
.append_operation(arith::constant(
context,
Attribute::parse(
context,
&format!(
"{} : {}",
metadata
.get::<PrimeModuloMeta<Felt>>()
.ok_or(Error::MissingMetadata)?
.prime(),
felt252_ty
),
)
.ok_or(Error::ParseAttributeError)?,
location,
))
.result(0)?
.into();

let mut src_value_is_neg: melior::ir::Value = is_neg_block
.append_operation(arith::subi(prime, value, location))
.result(0)?
.into();

let kneg1 = is_neg_block
.append_operation(arith::constant(
context,
Attribute::parse(context, &format!("-1 : {}", felt252_ty))
.ok_or(Error::ParseAttributeError)?,
location,
))
.result(0)?
.into();

src_value_is_neg = is_neg_block
.append_operation(arith::muli(src_value_is_neg, kneg1, location))
.result(0)?
.into();
let prime = is_neg_block.const_int_from_type(
context,
location,
metadata
.get::<PrimeModuloMeta<Felt>>()
.ok_or(Error::MissingMetadata)?
.prime()
.clone(),
felt252_ty,
)?;

let mut src_value_is_neg: melior::ir::Value =
is_neg_block.append_op_result(arith::subi(prime, value, location))?;

let kneg1 = is_neg_block.const_int_from_type(context, location, -1, felt252_ty)?;

src_value_is_neg =
is_neg_block.append_op_result(arith::muli(src_value_is_neg, kneg1, location))?;

is_neg_block.append_operation(cf::br(final_block, &[src_value_is_neg], location));
}
Expand All @@ -485,32 +387,23 @@ pub fn build_from_felt252<'ctx, 'this>(
block.argument(0)?.into()
};

let is_smaller_eq = block
.append_operation(arith::cmpi(
context,
CmpiPredicate::Sle,
value,
const_max,
location,
))
.result(0)?
.into();
let is_smaller_eq = block.append_op_result(arith::cmpi(
context,
CmpiPredicate::Sle,
value,
const_max,
location,
))?;

let is_bigger_eq = block
.append_operation(arith::cmpi(
context,
CmpiPredicate::Sge,
value,
const_min,
location,
))
.result(0)?
.into();
let is_bigger_eq = block.append_op_result(arith::cmpi(
context,
CmpiPredicate::Sge,
value,
const_min,
location,
))?;

let is_ok = block
.append_operation(arith::andi(is_smaller_eq, is_bigger_eq, location))
.result(0)?
.into();
let is_ok = block.append_op_result(arith::andi(is_smaller_eq, is_bigger_eq, location))?;

let block_success = helper.append_block(Block::new(&[]));
let block_failure = helper.append_block(Block::new(&[]));
Expand All @@ -525,10 +418,7 @@ pub fn build_from_felt252<'ctx, 'this>(
location,
));

let value = block_success
.append_operation(arith::trunci(value, result_ty, location))
.result(0)?
.into();
let value = block_success.append_op_result(arith::trunci(value, result_ty, location))?;

block_success.append_operation(helper.br(0, &[range_check, value], location));
block_failure.append_operation(helper.br(1, &[range_check], location));
Expand All @@ -552,15 +442,10 @@ pub fn build_diff<'ctx, 'this>(
let rhs: Value = entry.argument(2)?.into();

// Check if lhs >= rhs
let is_ge = entry
.append_operation(arith::cmpi(context, CmpiPredicate::Sge, lhs, rhs, location))
.result(0)?
.into();

let result = entry
.append_operation(arith::subi(lhs, rhs, location))
.result(0)?
.into();
let is_ge =
entry.append_op_result(arith::cmpi(context, CmpiPredicate::Sge, lhs, rhs, location))?;

let result = entry.append_op_result(arith::subi(lhs, rhs, location))?;

entry.append_operation(helper.cond_br(
context,
Expand Down

0 comments on commit 8a3502b

Please sign in to comment.