From cb7e483802b91a1f1331b1686ea69847ce495e50 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sat, 21 Oct 2023 19:04:09 +0300 Subject: [PATCH] implement FP16x16Wide --- src/numbers.cairo | 170 ++- src/numbers/fixed_point/implementations.cairo | 1 + .../implementations/fp16x16/core.cairo | 26 +- .../implementations/fp16x16wide.cairo | 3 + .../implementations/fp16x16wide/core.cairo | 421 ++++++ .../implementations/fp16x16wide/helpers.cairo | 41 + .../implementations/fp16x16wide/math.cairo | 5 + .../fp16x16wide/math/comp.cairo | 76 + .../fp16x16wide/math/core.cairo | 659 +++++++++ .../fp16x16wide/math/hyp.cairo | 159 +++ .../fp16x16wide/math/lut.cairo | 1235 +++++++++++++++++ .../fp16x16wide/math/trig.cairo | 450 ++++++ src/operators/nn/functional/softmax.cairo | 18 +- .../nn/implementations/nn_fp16x16.cairo | 6 +- src/operators/tensor.cairo | 6 + src/operators/tensor/implementations.cairo | 1 + .../implementations/tensor_fp16x16wide.cairo | 374 +++++ 17 files changed, 3619 insertions(+), 32 deletions(-) create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/core.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/helpers.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/hyp.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/trig.cairo create mode 100644 src/operators/tensor/implementations/tensor_fp16x16wide.cairo diff --git a/src/numbers.cairo b/src/numbers.cairo index d40589d38..a83365754 100644 --- a/src/numbers.cairo +++ b/src/numbers.cairo @@ -216,7 +216,7 @@ impl FP8x23Number of NumberTrait { } use orion::numbers::fixed_point::implementations::fp16x16::core::{ - FP16x16Impl, FP16x16, FP16x16IntoFP64x64 + FP16x16Impl, FP16x16, FP16x16IntoFP16x16Wide }; use orion::numbers::fixed_point::implementations::fp16x16::math::core as core_fp16x16; use orion::numbers::fixed_point::implementations::fp16x16::math::comp as comp_fp16x16; @@ -382,6 +382,174 @@ impl FP16x16Number of NumberTrait { } } +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + FP16x16Impl as FP16x16WideImpl, FP16x16Wide, FP16x16IntoFP64x64 as FP16x16WideIntoFP64x64 +}; +use orion::numbers::fixed_point::implementations::fp16x16wide::math::core as core_fp16x16wide; +use orion::numbers::fixed_point::implementations::fp16x16wide::math::comp as comp_fp16x16wide; + +impl FP16x16WideNumber of NumberTrait { + fn new(mag: u64, sign: bool) -> FP16x16Wide { + FP16x16WideImpl::new(mag, sign) + } + + fn new_unscaled(mag: u64, sign: bool) -> FP16x16Wide { + FP16x16WideImpl::new_unscaled(mag, sign) + } + + fn from_felt(val: felt252) -> FP16x16Wide { + FP16x16WideImpl::from_felt(val) + } + + fn ceil(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::ceil(self) + } + + fn exp(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::exp(self) + } + + fn exp2(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::exp2(self) + } + + fn floor(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::floor(self) + } + + fn ln(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::ln(self) + } + + fn log2(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::log2(self) + } + + fn log10(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::log10(self) + } + + fn pow(self: FP16x16Wide, b: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::pow(self, b) + } + + fn round(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::round(self) + } + + fn sqrt(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::sqrt(self) + } + + fn acos(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::acos(self) + } + + fn asin(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::asin(self) + } + + fn atan(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::atan(self) + } + + fn cos(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::cos(self) + } + + fn sin(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::sin(self) + } + + fn tan(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::tan(self) + } + + fn acosh(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::acosh(self) + } + + fn asinh(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::asinh(self) + } + + fn atanh(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::atanh(self) + } + + fn cosh(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::cosh(self) + } + + fn sinh(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::sinh(self) + } + + fn tanh(self: FP16x16Wide) -> FP16x16Wide { + FP16x16WideImpl::tanh(self) + } + + fn zero() -> FP16x16Wide { + FP16x16WideImpl::ZERO() + } + fn is_zero(self: FP16x16Wide) -> bool { + core_fp16x16wide::eq(@self, @FP16x16WideImpl::ZERO()) + } + + fn one() -> FP16x16Wide { + FP16x16WideImpl::ONE() + } + + fn neg_one() -> FP16x16Wide { + FP16x16Wide { mag: core_fp16x16wide::ONE, sign: true } + } + + fn is_one(self: FP16x16Wide) -> bool { + core_fp16x16wide::eq(@self, @FP16x16WideImpl::ONE()) + } + + fn abs(self: FP16x16Wide) -> FP16x16Wide { + core_fp16x16wide::abs(self) + } + + fn min_value() -> FP16x16Wide { + FP16x16Wide { mag: core_fp16x16wide::MAX, sign: true } + } + + fn max_value() -> FP16x16Wide { + FP16x16Wide { mag: core_fp16x16wide::MAX, sign: false } + } + + fn min(self: FP16x16Wide, other: FP16x16Wide) -> FP16x16Wide { + comp_fp16x16wide::min(self, other) + } + + fn max(self: FP16x16Wide, other: FP16x16Wide) -> FP16x16Wide { + comp_fp16x16wide::max(self, other) + } + + fn mag(self: FP16x16Wide) -> u64 { + self.mag + } + + fn is_neg(self: FP16x16Wide) -> bool { + self.sign + } + + fn xor(lhs: FP16x16Wide, rhs: FP16x16Wide) -> bool { + comp_fp16x16wide::xor(lhs, rhs) + } + + fn or(lhs: FP16x16Wide, rhs: FP16x16Wide) -> bool { + comp_fp16x16wide::or(lhs, rhs) + } + + fn sign(self: FP16x16Wide) -> FP16x16Wide { + core_fp16x16wide::sign(self) + } +} + + use orion::numbers::fixed_point::implementations::fp64x64::core::{FP64x64Impl, FP64x64}; use orion::numbers::fixed_point::implementations::fp64x64::core as core_fp64x64; use orion::numbers::fixed_point::implementations::fp64x64::comp as comp_fp64x64; diff --git a/src/numbers/fixed_point/implementations.cairo b/src/numbers/fixed_point/implementations.cairo index 8b010e349..f7c3d4c4c 100644 --- a/src/numbers/fixed_point/implementations.cairo +++ b/src/numbers/fixed_point/implementations.cairo @@ -2,3 +2,4 @@ mod fp8x23; mod fp16x16; mod fp64x64; mod fp32x32; +mod fp16x16wide; diff --git a/src/numbers/fixed_point/implementations/fp16x16/core.cairo b/src/numbers/fixed_point/implementations/fp16x16/core.cairo index 2c03f5bc4..7c2a90bea 100644 --- a/src/numbers/fixed_point/implementations/fp16x16/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16/core.cairo @@ -5,7 +5,7 @@ use result::{ResultTrait, ResultTraitImpl}; use traits::{TryInto, Into}; use orion::numbers::signed_integer::{i32::i32, i8::i8}; -use orion::numbers::{FP64x64, FP64x64Impl}; +use orion::numbers::{FP16x16Wide, FP16x16WideImpl}; use orion::numbers::fixed_point::core::FixedTrait; use orion::numbers::fixed_point::implementations::fp16x16::math::{core, trig, hyp}; use orion::numbers::fixed_point::utils; @@ -193,31 +193,19 @@ impl FP16x16Print of PrintTrait { } } -impl FP16x16IntoFP64x64 of Into { - fn into(self: FP16x16) -> FP64x64 { - return FP64x64 { mag: self.mag.into() * 281474976710656_u128, sign: self.sign }; +impl FP16x16IntoFP16x16Wide of Into { + fn into(self: FP16x16) -> FP16x16Wide { + return FP16x16Wide { mag: self.mag.into(), sign: self.sign }; } } -#[test] -fn test_fp16x16_into_fp32x32() { - let a = FP16x16Impl::new_unscaled(42, true); - let b: FP64x64 = a.into(); - assert(b == FP64x64Impl::new_unscaled(42, true), 'invalid conversion'); -} -impl FP64x64TryIntoFP16x16 of TryInto { - fn try_into(self: FP64x64) -> Option { - Option::Some(FP16x16 { mag: (self.mag / 281474976710656_u128).try_into().unwrap(), sign: self.sign }) +impl FP16x16WideTryIntoFP16x16 of TryInto { + fn try_into(self: FP16x16Wide) -> Option { + Option::Some(FP16x16 { mag: (self.mag).try_into().unwrap(), sign: self.sign }) } } -#[test] -fn test_fp32x32_try_into_fp16x16() { - let a = FP64x64Impl::new_unscaled(42, true); - let b: FP16x16 = a.try_into().unwrap(); - assert(b == FP16x16Impl::new_unscaled(42, true), 'invalid conversion'); -} // Into a raw felt without unscaling impl FP16x16IntoFelt252 of Into { diff --git a/src/numbers/fixed_point/implementations/fp16x16wide.cairo b/src/numbers/fixed_point/implementations/fp16x16wide.cairo new file mode 100644 index 000000000..e9acee340 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide.cairo @@ -0,0 +1,3 @@ +mod core; +mod math; +mod helpers; diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo new file mode 100644 index 000000000..1c7816632 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo @@ -0,0 +1,421 @@ +use debug::PrintTrait; + +use option::OptionTrait; +use result::{ResultTrait, ResultTraitImpl}; +use traits::{TryInto, Into}; + +use orion::numbers::signed_integer::{i32::i32, i8::i8}; +use orion::numbers::{FP64x64, FP64x64Impl}; +use orion::numbers::fixed_point::core::FixedTrait; +use orion::numbers::fixed_point::implementations::fp16x16wide::math::{core, trig, hyp}; +use orion::numbers::fixed_point::utils; + +/// A struct representing a fixed point number. +#[derive(Serde, Copy, Drop)] +struct FP16x16Wide { + mag: u64, + sign: bool +} + +// CONSTANTS + +const TWO: u64 = 131072; // 2 ** 17 +const ONE: u64 = 65536; // 2 ** 16 +const HALF: u64 = 32768; // 2 ** 15 +const MAX: u64 = 2147483648; // 2 ** 31 + + +impl FP16x16Impl of FixedTrait { + fn ZERO() -> FP16x16Wide { + return FP16x16Wide { mag: 0, sign: false }; + } + + fn ONE() -> FP16x16Wide { + return FP16x16Wide { mag: ONE, sign: false }; + } + + fn MAX() -> FP16x16Wide { + return FP16x16Wide { mag: MAX, sign: false }; + } + + fn new(mag: u64, sign: bool) -> FP16x16Wide { + return FP16x16Wide { mag: mag, sign: sign }; + } + + fn new_unscaled(mag: u64, sign: bool) -> FP16x16Wide { + return FP16x16Wide { mag: mag * ONE, sign: sign }; + } + + fn from_felt(val: felt252) -> FP16x16Wide { + let mag = integer::u64_try_from_felt252(utils::felt_abs(val)).unwrap(); + return FixedTrait::new(mag, utils::felt_sign(val)); + } + + fn abs(self: FP16x16Wide) -> FP16x16Wide { + return core::abs(self); + } + + fn acos(self: FP16x16Wide) -> FP16x16Wide { + return trig::acos_fast(self); + } + + fn acos_fast(self: FP16x16Wide) -> FP16x16Wide { + return trig::acos_fast(self); + } + + fn acosh(self: FP16x16Wide) -> FP16x16Wide { + return hyp::acosh(self); + } + + fn asin(self: FP16x16Wide) -> FP16x16Wide { + return trig::asin_fast(self); + } + + fn asin_fast(self: FP16x16Wide) -> FP16x16Wide { + return trig::asin_fast(self); + } + + fn asinh(self: FP16x16Wide) -> FP16x16Wide { + return hyp::asinh(self); + } + + fn atan(self: FP16x16Wide) -> FP16x16Wide { + return trig::atan_fast(self); + } + + fn atan_fast(self: FP16x16Wide) -> FP16x16Wide { + return trig::atan_fast(self); + } + + fn atanh(self: FP16x16Wide) -> FP16x16Wide { + return hyp::atanh(self); + } + + fn ceil(self: FP16x16Wide) -> FP16x16Wide { + return core::ceil(self); + } + + fn cos(self: FP16x16Wide) -> FP16x16Wide { + return trig::cos_fast(self); + } + + fn cos_fast(self: FP16x16Wide) -> FP16x16Wide { + return trig::cos_fast(self); + } + + fn cosh(self: FP16x16Wide) -> FP16x16Wide { + return hyp::cosh(self); + } + + fn floor(self: FP16x16Wide) -> FP16x16Wide { + return core::floor(self); + } + + // Calculates the natural exponent of x: e^x + fn exp(self: FP16x16Wide) -> FP16x16Wide { + return core::exp(self); + } + + // Calculates the binary exponent of x: 2^x + fn exp2(self: FP16x16Wide) -> FP16x16Wide { + return core::exp2(self); + } + + // Calculates the natural logarithm of x: ln(x) + // self must be greater than zero + fn ln(self: FP16x16Wide) -> FP16x16Wide { + return core::ln(self); + } + + // Calculates the binary logarithm of x: log2(x) + // self must be greather than zero + fn log2(self: FP16x16Wide) -> FP16x16Wide { + return core::log2(self); + } + + // Calculates the base 10 log of x: log10(x) + // self must be greater than zero + fn log10(self: FP16x16Wide) -> FP16x16Wide { + return core::log10(self); + } + + // Calclates the value of x^y and checks for overflow before returning + // self is a fixed point value + // b is a fixed point value + fn pow(self: FP16x16Wide, b: FP16x16Wide) -> FP16x16Wide { + return core::pow(self, b); + } + + fn round(self: FP16x16Wide) -> FP16x16Wide { + return core::round(self); + } + + fn sin(self: FP16x16Wide) -> FP16x16Wide { + return trig::sin_fast(self); + } + + fn sin_fast(self: FP16x16Wide) -> FP16x16Wide { + return trig::sin_fast(self); + } + + fn sinh(self: FP16x16Wide) -> FP16x16Wide { + return hyp::sinh(self); + } + + // Calculates the square root of a fixed point value + // x must be positive + fn sqrt(self: FP16x16Wide) -> FP16x16Wide { + return core::sqrt(self); + } + + fn tan(self: FP16x16Wide) -> FP16x16Wide { + return trig::tan_fast(self); + } + + fn tan_fast(self: FP16x16Wide) -> FP16x16Wide { + return trig::tan_fast(self); + } + + fn tanh(self: FP16x16Wide) -> FP16x16Wide { + return hyp::tanh(self); + } + + fn sign(self: FP16x16Wide) -> FP16x16Wide { + return core::sign(self); + } +} + + +impl FP16x16Print of PrintTrait { + fn print(self: FP16x16Wide) { + self.sign.print(); + self.mag.print(); + } +} + +impl FP16x16IntoFP64x64 of Into { + fn into(self: FP16x16Wide) -> FP64x64 { + return FP64x64 { mag: self.mag.into() * 281474976710656_u128, sign: self.sign }; + } +} + +#[test] +fn test_fp16x16_into_fp32x32() { + let a = FP16x16Impl::new_unscaled(42, true); + let b: FP64x64 = a.into(); + assert(b == FP64x64Impl::new_unscaled(42, true), 'invalid conversion'); +} + +impl FP64x64TryIntoFP16x16 of TryInto { + fn try_into(self: FP64x64) -> Option { + Option::Some( + FP16x16Wide { + mag: (self.mag / 281474976710656_u128).try_into().unwrap(), sign: self.sign + } + ) + } +} + +#[test] +fn test_fp32x32_try_into_fp16x16() { + let a = FP64x64Impl::new_unscaled(42, true); + let b: FP16x16Wide = a.try_into().unwrap(); + assert(b == FP16x16Impl::new_unscaled(42, true), 'invalid conversion'); +} + +// Into a raw felt without unscaling +impl FP16x16IntoFelt252 of Into { + fn into(self: FP16x16Wide) -> felt252 { + let mag_felt = self.mag.into(); + + if self.sign { + return mag_felt * -1; + } else { + return mag_felt * 1; + } + } +} + +impl FP16x16IntoI32 of Into { + fn into(self: FP16x16Wide) -> i32 { + _i32_into_fp(self) + } +} + +impl FP16x16TryIntoI8 of TryInto { + fn try_into(self: FP16x16Wide) -> Option { + _i8_try_from_fp(self) + } +} + + +impl FP16x16TryIntoU128 of TryInto { + fn try_into(self: FP16x16Wide) -> Option { + if self.sign { + return Option::None(()); + } else { + // Unscale the magnitude and round down + return Option::Some((self.mag / ONE).into()); + } + } +} + +impl FP16x16TryIntoU64 of TryInto { + fn try_into(self: FP16x16Wide) -> Option { + if self.sign { + return Option::None(()); + } else { + // Unscale the magnitude and round down + return Option::Some((self.mag / ONE).into()); + } + } +} + +impl FP16x16TryIntoU32 of TryInto { + fn try_into(self: FP16x16Wide) -> Option { + if self.sign { + return Option::None(()); + } else { + // Unscale the magnitude and round down + return (self.mag / ONE).try_into(); + } + } +} + +impl FP16x16TryIntoU16 of TryInto { + fn try_into(self: FP16x16Wide) -> Option { + if self.sign { + Option::None(()) + } else { + // Unscale the magnitude and round down + return (self.mag / ONE).try_into(); + } + } +} + +impl FP16x16TryIntoU8 of TryInto { + fn try_into(self: FP16x16Wide) -> Option { + if self.sign { + Option::None(()) + } else { + // Unscale the magnitude and round down + return (self.mag / ONE).try_into(); + } + } +} + +impl FP16x16PartialEq of PartialEq { + #[inline(always)] + fn eq(lhs: @FP16x16Wide, rhs: @FP16x16Wide) -> bool { + return core::eq(lhs, rhs); + } + + #[inline(always)] + fn ne(lhs: @FP16x16Wide, rhs: @FP16x16Wide) -> bool { + return core::ne(lhs, rhs); + } +} + +impl FP16x16Add of Add { + fn add(lhs: FP16x16Wide, rhs: FP16x16Wide) -> FP16x16Wide { + return core::add(lhs, rhs); + } +} + +impl FP16x16AddEq of AddEq { + #[inline(always)] + fn add_eq(ref self: FP16x16Wide, other: FP16x16Wide) { + self = Add::add(self, other); + } +} + +impl FP16x16Sub of Sub { + fn sub(lhs: FP16x16Wide, rhs: FP16x16Wide) -> FP16x16Wide { + return core::sub(lhs, rhs); + } +} + +impl FP16x16SubEq of SubEq { + #[inline(always)] + fn sub_eq(ref self: FP16x16Wide, other: FP16x16Wide) { + self = Sub::sub(self, other); + } +} + +impl FP16x16Mul of Mul { + fn mul(lhs: FP16x16Wide, rhs: FP16x16Wide) -> FP16x16Wide { + return core::mul(lhs, rhs); + } +} + +impl FP16x16MulEq of MulEq { + #[inline(always)] + fn mul_eq(ref self: FP16x16Wide, other: FP16x16Wide) { + self = Mul::mul(self, other); + } +} + +impl FP16x16Div of Div { + fn div(lhs: FP16x16Wide, rhs: FP16x16Wide) -> FP16x16Wide { + return core::div(lhs, rhs); + } +} + +impl FP16x16DivEq of DivEq { + #[inline(always)] + fn div_eq(ref self: FP16x16Wide, other: FP16x16Wide) { + self = Div::div(self, other); + } +} + +impl FP16x16PartialOrd of PartialOrd { + #[inline(always)] + fn ge(lhs: FP16x16Wide, rhs: FP16x16Wide) -> bool { + return core::ge(lhs, rhs); + } + + #[inline(always)] + fn gt(lhs: FP16x16Wide, rhs: FP16x16Wide) -> bool { + return core::gt(lhs, rhs); + } + + #[inline(always)] + fn le(lhs: FP16x16Wide, rhs: FP16x16Wide) -> bool { + return core::le(lhs, rhs); + } + + #[inline(always)] + fn lt(lhs: FP16x16Wide, rhs: FP16x16Wide) -> bool { + return core::lt(lhs, rhs); + } +} + +impl FP16x16Neg of Neg { + #[inline(always)] + fn neg(a: FP16x16Wide) -> FP16x16Wide { + return core::neg(a); + } +} + +impl FP16x16Rem of Rem { + #[inline(always)] + fn rem(lhs: FP16x16Wide, rhs: FP16x16Wide) -> FP16x16Wide { + return core::rem(lhs, rhs); + } +} + + +/// INTERNAL + +fn _i32_into_fp(x: FP16x16Wide) -> i32 { + i32 { mag: (x.mag / ONE).try_into().unwrap(), sign: x.sign } +} + +fn _i8_try_from_fp(x: FP16x16Wide) -> Option { + let unscaled_mag: Option = (x.mag / ONE).try_into(); + + match unscaled_mag { + Option::Some(val) => Option::Some(i8 { mag: unscaled_mag.unwrap(), sign: x.sign }), + Option::None(_) => Option::None(()) + } +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/helpers.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/helpers.cairo new file mode 100644 index 000000000..959f00d43 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/helpers.cairo @@ -0,0 +1,41 @@ +use debug::PrintTrait; +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + HALF, ONE, TWO, FP16x16Wide, FP16x16Impl, FP16x16Sub, FP16x16Div, FixedTrait, FP16x16Print +}; + +const DEFAULT_PRECISION: u64 = 7; // 1e-4 + +// To use `DEFAULT_PRECISION`, final arg is: `Option::None(())`. +// To use `custom_precision` of 430_u64: `Option::Some(430_u64)`. +fn assert_precise(result: FP16x16Wide, expected: felt252, msg: felt252, custom_precision: Option) { + let precision = match custom_precision { + Option::Some(val) => val, + Option::None(_) => DEFAULT_PRECISION, + }; + + let diff = (result - FixedTrait::from_felt(expected)).mag; + + if (diff > precision) { + result.print(); + assert(diff <= precision, msg); + } +} + +fn assert_relative( + result: FP16x16Wide, expected: felt252, msg: felt252, custom_precision: Option +) { + let precision = match custom_precision { + Option::Some(val) => val, + Option::None(_) => DEFAULT_PRECISION, + }; + + let diff = result - FixedTrait::from_felt(expected); + let rel_diff = (diff / result).mag; + + if (rel_diff > precision) { + result.print(); + assert(rel_diff <= precision, msg); + } +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math.cairo new file mode 100644 index 000000000..970c65f30 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math.cairo @@ -0,0 +1,5 @@ +mod core; +mod comp; +mod lut; +mod trig; +mod hyp; diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo new file mode 100644 index 000000000..26701f9d7 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo @@ -0,0 +1,76 @@ +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + FP16x16Wide, FixedTrait, FP16x16Impl, FP16x16PartialOrd, FP16x16PartialEq +}; + +fn max(a: FP16x16Wide, b: FP16x16Wide) -> FP16x16Wide { + if (a >= b) { + return a; + } else { + return b; + } +} + +fn min(a: FP16x16Wide, b: FP16x16Wide) -> FP16x16Wide { + if (a <= b) { + return a; + } else { + return b; + } +} + +fn xor(a: FP16x16Wide, b: FP16x16Wide) -> bool { + if (a == FixedTrait::new(0, false) || b == FixedTrait::new(0, false)) && (a != b) { + return true; + } else { + return false; + } +} + +fn or(a: FP16x16Wide, b: FP16x16Wide) -> bool { + let zero = FixedTrait::new(0, false); + if a == zero && b == zero { + return false; + } else { + return true; + } +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +#[test] +fn test_max() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::new_unscaled(1, true); + + assert(max(a, a) == a, 'max(a, a)'); + assert(max(a, b) == a, 'max(a, b)'); + assert(max(a, c) == a, 'max(a, c)'); + + assert(max(b, a) == a, 'max(b, a)'); + assert(max(b, b) == b, 'max(b, b)'); + assert(max(b, c) == b, 'max(b, c)'); + + assert(max(c, a) == a, 'max(c, a)'); + assert(max(c, b) == b, 'max(c, b)'); + assert(max(c, c) == c, 'max(c, c)'); +} + +#[test] +fn test_min() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::new_unscaled(1, true); + + assert(min(a, a) == a, 'min(a, a)'); + assert(min(a, b) == b, 'min(a, b)'); + assert(min(a, c) == c, 'min(a, c)'); + + assert(min(b, a) == b, 'min(b, a)'); + assert(min(b, b) == b, 'min(b, b)'); + assert(min(b, c) == c, 'min(b, c)'); + + assert(min(c, a) == c, 'min(c, a)'); + assert(min(c, b) == c, 'min(c, b)'); + assert(min(c, c) == c, 'min(c, c)'); +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo new file mode 100644 index 000000000..c80251639 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -0,0 +1,659 @@ +use core::debug::PrintTrait; +use option::OptionTrait; +use result::{ResultTrait, ResultTraitImpl}; +use traits::{Into, TryInto}; +use integer::{u64_safe_divmod, u64_as_non_zero, u64_wide_mul}; + +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + HALF, ONE, MAX, FP16x16Wide, FP16x16Impl, FP16x16Add, FP16x16AddEq, FP16x16Sub, FP16x16Mul, + FP16x16MulEq, FP16x16TryIntoU128, FP16x16PartialEq, FP16x16PartialOrd, FP16x16SubEq, FP16x16Neg, + FP16x16Div, FP16x16IntoFelt252, FixedTrait +}; +use orion::numbers::fixed_point::implementations::fp16x16wide::math::lut; + +// PUBLIC + +fn abs(a: FP16x16Wide) -> FP16x16Wide { + return FixedTrait::new(a.mag, false); +} + +fn add(a: FP16x16Wide, b: FP16x16Wide) -> FP16x16Wide { + if a.sign == b.sign { + return FixedTrait::new(a.mag + b.mag, a.sign); + } + + if a.mag == b.mag { + return FixedTrait::ZERO(); + } + + if (a.mag > b.mag) { + return FixedTrait::new(a.mag - b.mag, a.sign); + } else { + return FixedTrait::new(b.mag - a.mag, b.sign); + } +} + +fn ceil(a: FP16x16Wide) -> FP16x16Wide { + let (div, rem) = u64_safe_divmod(a.mag, u64_as_non_zero(ONE)); + + if rem == 0 { + return a; + } else if !a.sign { + return FixedTrait::new_unscaled(div + 1, false); + } else if div == 0 { + return FixedTrait::new_unscaled(0, false); + } else { + return FixedTrait::new_unscaled(div, true); + } +} + +fn div(a: FP16x16Wide, b: FP16x16Wide) -> FP16x16Wide { + let a_u64 = integer::u64_wide_mul(a.mag, ONE); + let res_u64 = a_u64 / b.mag.into(); + + // Re-apply sign + return FixedTrait::new(res_u64.try_into().unwrap(), a.sign ^ b.sign); +} + +fn eq(a: @FP16x16Wide, b: @FP16x16Wide) -> bool { + return (*a.mag == *b.mag) && (*a.sign == *b.sign); +} + +// Calculates the natural exponent of x: e^x +fn exp(a: FP16x16Wide) -> FP16x16Wide { + return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 +} + +// Calculates the binary exponent of x: 2^x +fn exp2(a: FP16x16Wide) -> FP16x16Wide { + if (a.mag == 0) { + return FixedTrait::ONE(); + } + + let (int_part, frac_part) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE)); + let int_res = FixedTrait::new_unscaled(lut::exp2(int_part), false); + let mut res_u = int_res; + + if frac_part != 0 { + let frac = FixedTrait::new(frac_part, false); + let r7 = FixedTrait::new(1, false) * frac; + let r6 = (r7 + FixedTrait::new(10, false)) * frac; + let r5 = (r6 + FixedTrait::new(87, false)) * frac; + let r4 = (r5 + FixedTrait::new(630, false)) * frac; + let r3 = (r4 + FixedTrait::new(3638, false)) * frac; + let r2 = (r3 + FixedTrait::new(15743, false)) * frac; + let r1 = (r2 + FixedTrait::new(45426, false)) * frac; + res_u = res_u * (r1 + FixedTrait::ONE()); + } + + if (a.sign == true) { + return FixedTrait::ONE() / res_u; + } else { + return res_u; + } +} + +fn exp2_int(exp: u64) -> FP16x16Wide { + return FixedTrait::new_unscaled(lut::exp2(exp), false); +} + +fn floor(a: FP16x16Wide) -> FP16x16Wide { + let (div, rem) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE)); + + if rem == 0 { + return a; + } else if !a.sign { + return FixedTrait::new_unscaled(div, false); + } else { + return FixedTrait::new_unscaled(div + 1, true); + } +} + +fn ge(a: FP16x16Wide, b: FP16x16Wide) -> bool { + if a.sign != b.sign { + return !a.sign; + } else { + return (a.mag == b.mag) || ((a.mag > b.mag) ^ a.sign); + } +} + +fn gt(a: FP16x16Wide, b: FP16x16Wide) -> bool { + if a.sign != b.sign { + return !a.sign; + } else { + return (a.mag != b.mag) && ((a.mag > b.mag) ^ a.sign); + } +} + +fn le(a: FP16x16Wide, b: FP16x16Wide) -> bool { + if a.sign != b.sign { + return a.sign; + } else { + return (a.mag == b.mag) || ((a.mag < b.mag) ^ a.sign); + } +} + +// Calculates the natural logarithm of x: ln(x) +// self must be greater than zero +fn ln(a: FP16x16Wide) -> FP16x16Wide { + return FixedTrait::new(45426, false) * log2(a); // ln(2) = 0.693... +} + +// Calculates the binary logarithm of x: log2(x) +// self must be greather than zero +fn log2(a: FP16x16Wide) -> FP16x16Wide { + assert(a.sign == false, 'must be positive'); + + if (a.mag == ONE) { + return FixedTrait::ZERO(); + } else if (a.mag < ONE) { + // Compute true inverse binary log if 0 < x < 1 + let div = FixedTrait::ONE() / a; + return -log2(div); + } + + let whole = a.mag / ONE; + let (msb, div) = lut::msb(whole); + + if a.mag == div * ONE { + return FixedTrait::new_unscaled(msb, false); + } else { + let norm = a / FixedTrait::new_unscaled(div, false); + let r8 = FixedTrait::new(596, true) * norm; + let r7 = (r8 + FixedTrait::new(8116, false)) * norm; + let r6 = (r7 + FixedTrait::new(49044, true)) * norm; + let r5 = (r6 + FixedTrait::new(172935, false)) * norm; + let r4 = (r5 + FixedTrait::new(394096, true)) * norm; + let r3 = (r4 + FixedTrait::new(608566, false)) * norm; + let r2 = (r3 + FixedTrait::new(655828, true)) * norm; + let r1 = (r2 + FixedTrait::new(534433, false)) * norm; + return r1 + FixedTrait::new(224487, true) + FixedTrait::new_unscaled(msb, false); + } +} + +// Calculates the base 10 log of x: log10(x) +// self must be greater than zero +fn log10(a: FP16x16Wide) -> FP16x16Wide { + return FixedTrait::new(19728, false) * log2(a); // log10(2) = 0.301... +} + +fn lt(a: FP16x16Wide, b: FP16x16Wide) -> bool { + if a.sign != b.sign { + return a.sign; + } else { + return (a.mag != b.mag) && ((a.mag < b.mag) ^ a.sign); + } +} + +fn mul(a: FP16x16Wide, b: FP16x16Wide) -> FP16x16Wide { + let prod_u128 = integer::u64_wide_mul(a.mag, b.mag); + + // Re-apply sign + return FixedTrait::new((prod_u128 / ONE.into()).try_into().unwrap(), a.sign ^ b.sign); +} + +fn ne(a: @FP16x16Wide, b: @FP16x16Wide) -> bool { + return (*a.mag != *b.mag) || (*a.sign != *b.sign); +} + +fn neg(a: FP16x16Wide) -> FP16x16Wide { + if a.mag == 0 { + return a; + } else if !a.sign { + return FixedTrait::new(a.mag, !a.sign); + } else { + return FixedTrait::new(a.mag, false); + } +} + +// Calclates the value of x^y and checks for overflow before returning +// self is a FP16x16Wide point value +// b is a FP16x16Wide point value +fn pow(a: FP16x16Wide, b: FP16x16Wide) -> FP16x16Wide { + let (div, rem) = integer::u64_safe_divmod(b.mag, u64_as_non_zero(ONE)); + + // use the more performant integer pow when y is an int + if (rem == 0) { + return pow_int(a, b.mag / ONE, b.sign); + } + + // x^y = exp(y*ln(x)) for x > 0 will error for x < 0 + return exp(b * ln(a)); +} + +// Calclates the value of a^b and checks for overflow before returning +fn pow_int(a: FP16x16Wide, b: u64, sign: bool) -> FP16x16Wide { + let mut x = a; + let mut n = b; + + if sign == true { + x = FixedTrait::ONE() / x; + } + + if n == 0 { + return FixedTrait::ONE(); + } + + let mut y = FixedTrait::ONE(); + let two = integer::u64_as_non_zero(2); + + loop { + if n <= 1 { + break; + } + + let (div, rem) = integer::u64_safe_divmod(n, two); + + if rem == 1 { + y = x * y; + } + + x = x * x; + n = div; + }; + + return x * y; +} + +fn rem(a: FP16x16Wide, b: FP16x16Wide) -> FP16x16Wide { + return a - floor(a / b) * b; +} + +fn round(a: FP16x16Wide) -> FP16x16Wide { + let (div, rem) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE)); + + if (HALF <= rem) { + return FixedTrait::new_unscaled(div + 1, a.sign); + } else { + return FixedTrait::new_unscaled(div, a.sign); + } +} + +// Calculates the square root of a FP16x16Wide point value +// x must be positive +fn sqrt(a: FP16x16Wide) -> FP16x16Wide { + assert(a.sign == false, 'must be positive'); + + let root = integer::u64_sqrt(a.mag.into() * ONE.into()); + return FixedTrait::new(root.into(), false); +} + +fn sub(a: FP16x16Wide, b: FP16x16Wide) -> FP16x16Wide { + return add(a, -b); +} + +fn sign(a: FP16x16Wide) -> FP16x16Wide { + if a.mag == 0 { + FixedTrait::new(0, false) + } else { + FixedTrait::new(ONE, a.sign) + } +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +use orion::numbers::fixed_point::implementations::fp16x16wide::helpers::{ + assert_precise, assert_relative +}; +use orion::numbers::fixed_point::implementations::fp16x16wide::math::trig::{PI, HALF_PI}; + +#[test] +fn test_into() { + let a = FixedTrait::::new_unscaled(5, false); + assert(a.mag == 5 * ONE, 'invalid result'); +} + +#[test] +fn test_try_into_u128() { + // Positive unscaled + let a = FixedTrait::::new_unscaled(5, false); + assert(a.try_into().unwrap() == 5_u128, 'invalid result'); + + // Positive scaled + let b = FixedTrait::::new(5 * ONE, false); + assert(b.try_into().unwrap() == 5_u128, 'invalid result'); + + // Zero + let d = FixedTrait::::new_unscaled(0, false); + assert(d.try_into().unwrap() == 0_u128, 'invalid result'); +} + +#[test] +#[should_panic] +fn test_negative_try_into_u128() { + let a = FixedTrait::::new_unscaled(1, true); + let a: u128 = a.try_into().unwrap(); +} + +#[test] +#[available_gas(1000000)] +fn test_acos() { + let a = FixedTrait::::ONE(); + assert(a.acos().into() == 0, 'invalid one'); +} + +#[test] +#[available_gas(1000000)] +fn test_asin() { + let a = FixedTrait::ONE(); + assert_precise(a.asin(), HALF_PI.into(), 'invalid one', Option::None(())); // PI / 2 +} + +#[test] +#[available_gas(2000000)] +fn test_atan() { + let a = FixedTrait::new(2 * ONE, false); + assert_relative(a.atan(), 72558, 'invalid two', Option::None(())); +} + +#[test] +fn test_ceil() { + let a = FixedTrait::new(190054, false); // 2.9 + assert(ceil(a).mag == 3 * ONE, 'invalid pos decimal'); +} + +#[test] +fn test_floor() { + let a = FixedTrait::new(190054, false); // 2.9 + assert(floor(a).mag == 2 * ONE, 'invalid pos decimal'); +} + +#[test] +fn test_round() { + let a = FixedTrait::new(190054, false); // 2.9 + assert(round(a).mag == 3 * ONE, 'invalid pos decimal'); +} + +#[test] +#[should_panic] +fn test_sqrt_fail() { + let a = FixedTrait::new_unscaled(25, true); + sqrt(a); +} + +#[test] +fn test_sqrt() { + let mut a = FixedTrait::new_unscaled(0, false); + assert(sqrt(a).mag == 0, 'invalid zero root'); + a = FixedTrait::new_unscaled(25, false); + assert(sqrt(a).mag == 5 * ONE, 'invalid pos root'); +} + + +#[test] +#[available_gas(100000)] +fn test_msb() { + let a = FixedTrait::::new_unscaled(100, false); + let (msb, div) = lut::msb(a.mag / ONE); + assert(msb == 6, 'invalid msb'); + assert(div == 64, 'invalid msb ceil'); +} + +#[test] +#[available_gas(600000)] +fn test_pow() { + let a = FixedTrait::new_unscaled(3, false); + let b = FixedTrait::new_unscaled(4, false); + assert(pow(a, b).mag == 81 * ONE, 'invalid pos base power'); +} + +#[test] +#[available_gas(900000)] +fn test_pow_frac() { + let a = FixedTrait::new_unscaled(3, false); + let b = FixedTrait::new(32768, false); // 0.5 + assert_relative( + pow(a, b), 113512, 'invalid pos base power', Option::None(()) + ); // 1.7320508075688772 +} + +#[test] +#[available_gas(1000000)] +fn test_exp() { + let a = FixedTrait::new_unscaled(2, false); + assert_relative(exp(a), 484249, 'invalid exp of 2', Option::None(())); // 7.389056098793725 +} + +#[test] +#[available_gas(400000)] +fn test_exp2() { + let a = FixedTrait::new_unscaled(5, false); + assert(exp2(a).mag == 2097152, 'invalid exp2 of 2'); +} + +#[test] +#[available_gas(20000)] +fn test_exp2_int() { + assert(exp2_int(5).into() == 2097152, 'invalid exp2 of 2'); +} + +#[test] +#[available_gas(1000000)] +fn test_ln() { + let mut a = FixedTrait::new_unscaled(1, false); + assert(ln(a).mag == 0, 'invalid ln of 1'); + + a = FixedTrait::new(178145, false); + assert_relative(ln(a), ONE.into(), 'invalid ln of 2.7...', Option::None(())); +} + +#[test] +#[available_gas(1000000)] +fn test_log2() { + let mut a = FixedTrait::new_unscaled(32, false); + assert(log2(a) == FixedTrait::new_unscaled(5, false), 'invalid log2 32'); + + a = FixedTrait::new_unscaled(10, false); + assert_relative(log2(a), 217706, 'invalid log2 10', Option::None(())); // 3.321928094887362 +} + +#[test] +#[available_gas(1000000)] +fn test_log10() { + let a = FixedTrait::new_unscaled(100, false); + assert_relative(log10(a), 2 * ONE.into(), 'invalid log10', Option::None(())); +} + +#[test] +fn test_eq() { + let a = FixedTrait::new_unscaled(42, false); + let b = FixedTrait::new_unscaled(42, false); + let c = eq(@a, @b); + assert(c == true, 'invalid result'); +} + +#[test] +fn test_ne() { + let a = FixedTrait::new_unscaled(42, false); + let b = FixedTrait::new_unscaled(42, false); + let c = ne(@a, @b); + assert(c == false, 'invalid result'); +} + +#[test] +fn test_add() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(2, false); + assert(add(a, b) == FixedTrait::new_unscaled(3, false), 'invalid result'); +} + +#[test] +fn test_add_eq() { + let mut a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(2, false); + a += b; + assert(a == FixedTrait::::new_unscaled(3, false), 'invalid result'); +} + +#[test] +fn test_sub() { + let a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, false); + let c = a - b; + assert(c == FixedTrait::::new_unscaled(3, false), 'false result invalid'); +} + +#[test] +fn test_sub_eq() { + let mut a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, false); + a -= b; + assert(a == FixedTrait::::new_unscaled(3, false), 'invalid result'); +} + +#[test] +#[available_gas(100000)] +fn test_mul_pos() { + let a = FP16x16Wide { mag: 190054, sign: false }; + let b = FP16x16Wide { mag: 190054, sign: false }; + let c = a * b; + assert(c.mag == 551155, 'invalid result'); +} + +#[test] +fn test_mul_neg() { + let a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, true); + let c = a * b; + assert(c == FixedTrait::::new_unscaled(10, true), 'invalid result'); +} + +#[test] +fn test_mul_eq() { + let mut a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, true); + a *= b; + assert(a == FixedTrait::::new_unscaled(10, true), 'invalid result'); +} + +#[test] +fn test_div() { + let a = FixedTrait::new_unscaled(10, false); + let b = FixedTrait::::new(190054, false); // 2.9 + let c = a / b; + assert(c.mag == 225986, 'invalid pos decimal'); // 3.4482758620689653 +} + +#[test] +fn test_le() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a <= a, 'a <= a'); + assert(a <= b == false, 'a <= b'); + assert(a <= c == false, 'a <= c'); + + assert(b <= a, 'b <= a'); + assert(b <= b, 'b <= b'); + assert(b <= c == false, 'b <= c'); + + assert(c <= a, 'c <= a'); + assert(c <= b, 'c <= b'); + assert(c <= c, 'c <= c'); +} + +#[test] +fn test_lt() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a < a == false, 'a < a'); + assert(a < b == false, 'a < b'); + assert(a < c == false, 'a < c'); + + assert(b < a, 'b < a'); + assert(b < b == false, 'b < b'); + assert(b < c == false, 'b < c'); + + assert(c < a, 'c < a'); + assert(c < b, 'c < b'); + assert(c < c == false, 'c < c'); +} + +#[test] +fn test_ge() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a >= a, 'a >= a'); + assert(a >= b, 'a >= b'); + assert(a >= c, 'a >= c'); + + assert(b >= a == false, 'b >= a'); + assert(b >= b, 'b >= b'); + assert(b >= c, 'b >= c'); + + assert(c >= a == false, 'c >= a'); + assert(c >= b == false, 'c >= b'); + assert(c >= c, 'c >= c'); +} + +#[test] +fn test_gt() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a > a == false, 'a > a'); + assert(a > b, 'a > b'); + assert(a > c, 'a > c'); + + assert(b > a == false, 'b > a'); + assert(b > b == false, 'b > b'); + assert(b > c, 'b > c'); + + assert(c > a == false, 'c > a'); + assert(c > b == false, 'c > b'); + assert(c > c == false, 'c > c'); +} + +#[test] +#[available_gas(1000000)] +fn test_cos() { + let a = FixedTrait::::new(HALF_PI, false); + assert(a.cos().into() == 0, 'invalid half pi'); +} + +#[test] +#[available_gas(1000000)] +fn test_sin() { + let a = FixedTrait::new(HALF_PI, false); + assert_precise(a.sin(), ONE.into(), 'invalid half pi', Option::None(())); +} + +#[test] +#[available_gas(2000000)] +fn test_tan() { + let a = FixedTrait::::new(HALF_PI / 2, false); + assert(a.tan().mag == 65536, 'invalid quarter pi'); +} + +#[test] +#[available_gas(2000000)] +fn test_sign() { + let a = FixedTrait::::new(0, false); + assert(a.sign().mag == 0 && !a.sign().sign, 'invalid sign (0, true)'); + + let a = FixedTrait::::new(HALF, true); + assert(a.sign().mag == ONE && a.sign().sign, 'invalid sign (HALF, true)'); + + let a = FixedTrait::::new(HALF, false); + assert(a.sign().mag == ONE && !a.sign().sign, 'invalid sign (HALF, false)'); + + let a = FixedTrait::::new(ONE, true); + assert(a.sign().mag == ONE && a.sign().sign, 'invalid sign (ONE, true)'); + + let a = FixedTrait::::new(ONE, false); + assert(a.sign().mag == ONE && !a.sign().sign, 'invalid sign (ONE, false)'); +} + +#[test] +#[should_panic] +#[available_gas(2000000)] +fn test_sign_fail() { + let a = FixedTrait::::new(HALF, true); + assert(a.sign().mag != ONE && !a.sign().sign, 'invalid sign (HALF, true)'); +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/hyp.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/hyp.cairo new file mode 100644 index 000000000..5f462faab --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/hyp.cairo @@ -0,0 +1,159 @@ +use core::debug::PrintTrait; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + HALF, ONE, TWO, FP16x16Wide, FP16x16Impl, FP16x16Add, FP16x16AddEq, FP16x16Sub, FP16x16Mul, + FP16x16MulEq, FP16x16TryIntoU128, FP16x16PartialEq, FP16x16PartialOrd, FP16x16SubEq, FP16x16Neg, + FP16x16Div, FP16x16IntoFelt252, FixedTrait +}; + +// Calculates hyperbolic cosine of a (fixed point) +fn cosh(a: FP16x16Wide) -> FP16x16Wide { + let ea = a.exp(); + return (ea + (FixedTrait::ONE() / ea)) / FixedTrait::new(TWO, false); +} + +// Calculates hyperbolic sine of a (fixed point) +fn sinh(a: FP16x16Wide) -> FP16x16Wide { + let ea = a.exp(); + return (ea - (FixedTrait::ONE() / ea)) / FixedTrait::new(TWO, false); +} + +// Calculates hyperbolic tangent of a (fixed point) +fn tanh(a: FP16x16Wide) -> FP16x16Wide { + let ea = a.exp(); + let ea_i = FixedTrait::ONE() / ea; + return (ea - ea_i) / (ea + ea_i); +} + +// Calculates inverse hyperbolic cosine of a (fixed point) +fn acosh(a: FP16x16Wide) -> FP16x16Wide { + let root = (a * a - FixedTrait::ONE()).sqrt(); + return (a + root).ln(); +} + +// Calculates inverse hyperbolic sine of a (fixed point) +fn asinh(a: FP16x16Wide) -> FP16x16Wide { + let root = (a * a + FixedTrait::ONE()).sqrt(); + return (a + root).ln(); +} + +// Calculates inverse hyperbolic tangent of a (fixed point) +fn atanh(a: FP16x16Wide) -> FP16x16Wide { + let one = FixedTrait::ONE(); + let ln_arg = (one + a) / (one - a); + return ln_arg.ln() / FixedTrait::new(TWO, false); +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +use option::OptionTrait; +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp16x16wide::helpers::assert_precise; + +#[test] +#[available_gas(10000000)] +fn test_cosh() { + let a = FixedTrait::new(TWO, false); + assert_precise(cosh(a), 246550, 'invalid two', Option::None(())); // 3.5954653836066 + + let a = FixedTrait::ONE(); + assert_precise(cosh(a), 101127, 'invalid one', Option::None(())); // 1.42428174592510 + + let a = FixedTrait::ZERO(); + assert_precise(cosh(a), ONE.into(), 'invalid zero', Option::None(())); + + let a = FixedTrait::ONE(); + assert_precise(cosh(a), 101127, 'invalid neg one', Option::None(())); // 1.42428174592510 + + let a = FixedTrait::new(TWO, true); + assert_precise(cosh(a), 246568, 'invalid neg two', Option::None(())); // 3.5954653836066 +} + +#[test] +#[available_gas(10000000)] +fn test_sinh() { + let a = FixedTrait::new(TWO, false); + assert_precise(sinh(a), 237681, 'invalid two', Option::None(())); // 3.48973469357602 + + let a = FixedTrait::ONE(); + assert_precise(sinh(a), 77018, 'invalid one', Option::None(())); // 1.13687593250230 + + let a = FixedTrait::ZERO(); + assert(sinh(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(ONE, true); + assert_precise(sinh(a), -77018, 'invalid neg one', Option::None(())); // -1.13687593250230 + + let a = FixedTrait::new(TWO, true); + assert_precise(sinh(a), -237699, 'invalid neg two', Option::None(())); // -3.48973469357602 +} + +#[test] +#[available_gas(10000000)] +fn test_tanh() { + let a = FixedTrait::new(TWO, false); + assert_precise(tanh(a), 63179, 'invalid two', Option::None(())); // 0.75314654693321 + + let a = FixedTrait::ONE(); + assert_precise(tanh(a), 49912, 'invalid one', Option::None(())); // 0.59499543433175 + + let a = FixedTrait::ZERO(); + assert(tanh(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(ONE, true); + assert_precise(tanh(a), -49912, 'invalid neg one', Option::None(())); // -0.59499543433175 + + let a = FixedTrait::new(TWO, true); + assert_precise(tanh(a), -63179, 'invalid neg two', Option::None(())); // 0.75314654693321 +} + +#[test] +#[available_gas(10000000)] +fn test_acosh() { + let a = FixedTrait::new(246559, false); // 3.5954653836066 + assert_precise(acosh(a), 131072, 'invalid two', Option::None(())); + + let a = FixedTrait::new(101127, false); // 1.42428174592510 + assert_precise(acosh(a), ONE.into(), 'invalid one', Option::None(())); + + let a = FixedTrait::ONE(); // 1 + assert(acosh(a).into() == 0, 'invalid zero'); +} + +#[test] +#[available_gas(10000000)] +fn test_asinh() { + let a = FixedTrait::new(237690, false); // 3.48973469357602 + assert_precise(asinh(a), 131072, 'invalid two', Option::None(())); + + let a = FixedTrait::new(77018, false); // 1.13687593250230 + assert_precise(asinh(a), ONE.into(), 'invalid one', Option::None(())); + + let a = FixedTrait::ZERO(); + assert(asinh(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(77018, true); // -1.13687593250230 + assert_precise(asinh(a), -ONE.into(), 'invalid neg one', Option::None(())); + + let a = FixedTrait::new(237690, true); // -3.48973469357602 + assert_precise(asinh(a), -131017, 'invalid neg two', Option::None(())); +} + +#[test] +#[available_gas(10000000)] +fn test_atanh() { + let a = FixedTrait::new(58982, false); // 0.9 + assert_precise(atanh(a), 96483, 'invalid 0.9', Option::None(())); // 1.36892147623689 + + let a = FixedTrait::new(HALF, false); // 0.5 + assert_precise(atanh(a), 35999, 'invalid half', Option::None(())); // 0.42914542526098 + + let a = FixedTrait::ZERO(); + assert(atanh(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(HALF, true); // 0.5 + assert_precise(atanh(a), -35999, 'invalid neg half', Option::None(())); // 0.42914542526098 + + let a = FixedTrait::new(58982, true); // 0.9 + assert_precise(atanh(a), -96483, 'invalid -0.9', Option::None(())); // 1.36892147623689 +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo new file mode 100644 index 000000000..e96b0d389 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo @@ -0,0 +1,1235 @@ +// Calculates the most significant bit +fn msb(whole: u64) -> (u64, u64) { + if whole < 256 { + if whole < 2 { + return (0, 1); + } + if whole < 4 { + return (1, 2); + } + if whole < 8 { + return (2, 4); + } + if whole < 16 { + return (3, 8); + } + if whole < 32 { + return (4, 16); + } + if whole < 64 { + return (5, 32); + } + if whole < 128 { + return (6, 64); + } + if whole < 256 { + return (7, 128); + } + } else if whole < 65536 { + if whole < 512 { + return (8, 256); + } + if whole < 1024 { + return (9, 512); + } + if whole < 2048 { + return (10, 1024); + } + if whole < 4096 { + return (11, 2048); + } + if whole < 8192 { + return (12, 4096); + } + if whole < 16384 { + return (13, 8192); + } + if whole < 32768 { + return (14, 16384); + } + if whole < 65536 { + return (15, 32768); + } + } + + return (16, 65536); +} + +fn exp2(exp: u64) -> u64 { + if exp <= 16 { + if exp == 0 { + return 1; + } + if exp == 1 { + return 2; + } + if exp == 2 { + return 4; + } + if exp == 3 { + return 8; + } + if exp == 4 { + return 16; + } + if exp == 5 { + return 32; + } + if exp == 6 { + return 64; + } + if exp == 7 { + return 128; + } + if exp == 8 { + return 256; + } + if exp == 9 { + return 512; + } + if exp == 10 { + return 1024; + } + if exp == 11 { + return 2048; + } + if exp == 12 { + return 4096; + } + if exp == 13 { + return 8192; + } + if exp == 14 { + return 16384; + } + if exp == 15 { + return 32768; + } + if exp == 16 { + return 65536; + } + } + + return 65536; +} + +fn sin(a: u64) -> (u64, u64, u64) { + let slot = a / 402; + + if slot < 128 { + if slot < 64 { + if slot < 32 { + if slot < 16 { + if slot == 0 { + return (0, 0, 402); + } + if slot == 1 { + return (402, 402, 804); + } + if slot == 2 { + return (804, 804, 1206); + } + if slot == 3 { + return (1206, 1206, 1608); + } + if slot == 4 { + return (1608, 1608, 2010); + } + if slot == 5 { + return (2011, 2010, 2412); + } + if slot == 6 { + return (2413, 2412, 2814); + } + if slot == 7 { + return (2815, 2814, 3216); + } + if slot == 8 { + return (3217, 3216, 3617); + } + if slot == 9 { + return (3619, 3617, 4019); + } + if slot == 10 { + return (4023, 4019, 4420); + } + if slot == 11 { + return (4423, 4420, 4821); + } + if slot == 12 { + return (4825, 4821, 5222); + } + if slot == 13 { + return (5228, 5222, 5623); + } + if slot == 14 { + return (5630, 5623, 6023); + } + if slot == 15 { + return (6032, 6023, 6424); + } + } else { + if slot == 16 { + return (6434, 6424, 6824); + } + if slot == 17 { + return (6836, 6824, 7224); + } + if slot == 18 { + return (7238, 7224, 7623); + } + if slot == 19 { + return (7640, 7623, 8022); + } + if slot == 20 { + return (8042, 8022, 8421); + } + if slot == 21 { + return (8445, 8421, 8820); + } + if slot == 22 { + return (8847, 8820, 9218); + } + if slot == 23 { + return (9249, 9218, 9616); + } + if slot == 24 { + return (9651, 9616, 10014); + } + if slot == 25 { + return (10053, 10014, 10411); + } + if slot == 26 { + return (10455, 10411, 10808); + } + if slot == 27 { + return (10857, 10808, 11204); + } + if slot == 28 { + return (11259, 11204, 11600); + } + if slot == 29 { + return (11662, 11600, 11996); + } + if slot == 30 { + return (12064, 11996, 12391); + } + if slot == 31 { + return (12466, 12391, 12785); + } + } + } else { + if slot < 48 { + if slot == 32 { + return (12868, 12785, 13180); + } + if slot == 33 { + return (13270, 13180, 13573); + } + if slot == 34 { + return (13672, 13573, 13966); + } + if slot == 35 { + return (14074, 13966, 14359); + } + if slot == 36 { + return (14476, 14359, 14751); + } + if slot == 37 { + return (14879, 14751, 15143); + } + if slot == 38 { + return (15281, 15143, 15534); + } + if slot == 39 { + return (15683, 15534, 15924); + } + if slot == 40 { + return (16081, 15924, 16314); + } + if slot == 41 { + return (16487, 16314, 16703); + } + if slot == 42 { + return (16889, 16703, 17091); + } + if slot == 43 { + return (17291, 17091, 17479); + } + if slot == 44 { + return (17693, 17479, 17867); + } + if slot == 45 { + return (18096, 17867, 18253); + } + if slot == 46 { + return (18498, 18253, 18639); + } + if slot == 47 { + return (18900, 18639, 19024); + } + } else { + if slot == 48 { + return (19302, 19024, 19409); + } + if slot == 49 { + return (19704, 19409, 19792); + } + if slot == 50 { + return (20113, 19792, 20175); + } + if slot == 51 { + return (20508, 20175, 20557); + } + if slot == 52 { + return (20910, 20557, 20939); + } + if slot == 53 { + return (21313, 20939, 21320); + } + if slot == 54 { + return (21715, 21320, 21699); + } + if slot == 55 { + return (22117, 21699, 22078); + } + if slot == 56 { + return (22519, 22078, 22457); + } + if slot == 57 { + return (22921, 22457, 22834); + } + if slot == 58 { + return (23323, 22834, 23210); + } + if slot == 59 { + return (23725, 23210, 23586); + } + if slot == 60 { + return (24127, 23586, 23961); + } + if slot == 61 { + return (24530, 23961, 24335); + } + if slot == 62 { + return (24932, 24335, 24708); + } + if slot == 63 { + return (25334, 24708, 25080); + } + } + } + } else { + if slot < 96 { + if slot < 80 { + if slot == 64 { + return (25736, 25080, 25451); + } + if slot == 65 { + return (26138, 25451, 25821); + } + if slot == 66 { + return (26540, 25821, 26190); + } + if slot == 67 { + return (26942, 26190, 26558); + } + if slot == 68 { + return (27344, 26558, 26925); + } + if slot == 69 { + return (27747, 26925, 27291); + } + if slot == 70 { + return (28149, 27291, 27656); + } + if slot == 71 { + return (28551, 27656, 28020); + } + if slot == 72 { + return (28953, 28020, 28383); + } + if slot == 73 { + return (29355, 28383, 28745); + } + if slot == 74 { + return (29757, 28745, 29106); + } + if slot == 75 { + return (30159, 29106, 29466); + } + if slot == 76 { + return (30561, 29466, 29824); + } + if slot == 77 { + return (30964, 29824, 30182); + } + if slot == 78 { + return (31366, 30182, 30538); + } + if slot == 79 { + return (31768, 30538, 30893); + } + } else { + if slot == 80 { + return (32171, 30893, 31248); + } + if slot == 81 { + return (32572, 31248, 31600); + } + if slot == 82 { + return (32974, 31600, 31952); + } + if slot == 83 { + return (33376, 31952, 32303); + } + if slot == 84 { + return (33778, 32303, 32652); + } + if slot == 85 { + return (34181, 32652, 33000); + } + if slot == 86 { + return (34583, 33000, 33347); + } + if slot == 87 { + return (34985, 33347, 33692); + } + if slot == 88 { + return (35387, 33692, 34037); + } + if slot == 89 { + return (35789, 34037, 34380); + } + if slot == 90 { + return (36194, 34380, 34721); + } + if slot == 91 { + return (36593, 34721, 35062); + } + if slot == 92 { + return (36995, 35062, 35401); + } + if slot == 93 { + return (37398, 35401, 35738); + } + if slot == 94 { + return (37800, 35738, 36075); + } + if slot == 95 { + return (38202, 36075, 36410); + } + } + } else { + if slot < 112 { + if slot == 96 { + return (38604, 36410, 36744); + } + if slot == 97 { + return (39006, 36744, 37076); + } + if slot == 98 { + return (39408, 37076, 37407); + } + if slot == 99 { + return (39810, 37407, 37736); + } + if slot == 100 { + return (40227, 37736, 38064); + } + if slot == 101 { + return (40615, 38064, 38391); + } + if slot == 102 { + return (41017, 38391, 38716); + } + if slot == 103 { + return (41419, 38716, 39040); + } + if slot == 104 { + return (41821, 39040, 39362); + } + if slot == 105 { + return (42223, 39362, 39683); + } + if slot == 106 { + return (42625, 39683, 40002); + } + if slot == 107 { + return (43027, 40002, 40320); + } + if slot == 108 { + return (43429, 40320, 40636); + } + if slot == 109 { + return (43832, 40636, 40951); + } + if slot == 110 { + return (44234, 40951, 41264); + } + if slot == 111 { + return (44636, 41264, 41576); + } + } else { + if slot == 112 { + return (45038, 41576, 41886); + } + if slot == 113 { + return (45440, 41886, 42194); + } + if slot == 114 { + return (45842, 42194, 42501); + } + if slot == 115 { + return (46244, 42501, 42806); + } + if slot == 116 { + return (46646, 42806, 43110); + } + if slot == 117 { + return (47048, 43110, 43412); + } + if slot == 118 { + return (47451, 43412, 43713); + } + if slot == 119 { + return (47853, 43713, 44011); + } + if slot == 120 { + return (48252, 44011, 44308); + } + if slot == 121 { + return (48657, 44308, 44604); + } + if slot == 122 { + return (49059, 44604, 44898); + } + if slot == 123 { + return (49461, 44898, 45190); + } + if slot == 124 { + return (49863, 45190, 45480); + } + if slot == 125 { + return (50265, 45480, 45769); + } + if slot == 126 { + return (50668, 45769, 46056); + } + if slot == 127 { + return (51070, 46056, 46341); + } + } + } + } + } else { + if slot < 192 { + if slot < 160 { + if slot < 144 { + if slot == 128 { + return (51472, 46341, 46624); + } + if slot == 129 { + return (51874, 46624, 46906); + } + if slot == 130 { + return (52285, 46906, 47186); + } + if slot == 131 { + return (52678, 47186, 47464); + } + if slot == 132 { + return (53080, 47464, 47741); + } + if slot == 133 { + return (53482, 47741, 48015); + } + if slot == 134 { + return (53885, 48015, 48288); + } + if slot == 135 { + return (54287, 48288, 48559); + } + if slot == 136 { + return (54689, 48559, 48828); + } + if slot == 137 { + return (55091, 48828, 49095); + } + if slot == 138 { + return (55493, 49095, 49361); + } + if slot == 139 { + return (55895, 49361, 49624); + } + if slot == 140 { + return (56297, 49624, 49886); + } + if slot == 141 { + return (56699, 49886, 50146); + } + if slot == 142 { + return (57102, 50146, 50404); + } + if slot == 143 { + return (57504, 50404, 50660); + } + } else { + if slot == 144 { + return (57906, 50660, 50914); + } + if slot == 145 { + return (58308, 50914, 51166); + } + if slot == 146 { + return (58710, 51166, 51417); + } + if slot == 147 { + return (59112, 51417, 51665); + } + if slot == 148 { + return (59514, 51665, 51911); + } + if slot == 149 { + return (59916, 51911, 52156); + } + if slot == 150 { + return (60320, 52156, 52398); + } + if slot == 151 { + return (60721, 52398, 52639); + } + if slot == 152 { + return (61123, 52639, 52878); + } + if slot == 153 { + return (61525, 52878, 53114); + } + if slot == 154 { + return (61927, 53114, 53349); + } + if slot == 155 { + return (62329, 53349, 53581); + } + if slot == 156 { + return (62731, 53581, 53812); + } + if slot == 157 { + return (63133, 53812, 54040); + } + if slot == 158 { + return (63536, 54040, 54267); + } + if slot == 159 { + return (63938, 54267, 54491); + } + if slot == 160 { + return (64343, 54491, 54714); + } + } + } else { + if slot < 176 { + if slot == 161 { + return (64742, 54714, 54934); + } + if slot == 162 { + return (65144, 54934, 55152); + } + if slot == 163 { + return (65546, 55152, 55368); + } + if slot == 164 { + return (65948, 55368, 55582); + } + if slot == 165 { + return (66350, 55582, 55794); + } + if slot == 166 { + return (66753, 55794, 56004); + } + if slot == 167 { + return (67155, 56004, 56212); + } + if slot == 168 { + return (67557, 56212, 56418); + } + if slot == 169 { + return (67959, 56418, 56621); + } + if slot == 170 { + return (68361, 56621, 56823); + } + if slot == 171 { + return (68763, 56823, 57022); + } + if slot == 172 { + return (69165, 57022, 57219); + } + if slot == 173 { + return (69567, 57219, 57414); + } + if slot == 174 { + return (69970, 57414, 57607); + } + if slot == 175 { + return (70372, 57607, 57798); + } + } else { + if slot == 176 { + return (70774, 57798, 57986); + } + if slot == 177 { + return (71176, 57986, 58172); + } + if slot == 178 { + return (71578, 58172, 58356); + } + if slot == 179 { + return (71980, 58356, 58538); + } + if slot == 180 { + return (72382, 58538, 58718); + } + if slot == 181 { + return (72784, 58718, 58896); + } + if slot == 182 { + return (73187, 58896, 59071); + } + if slot == 183 { + return (73589, 59071, 59244); + } + if slot == 184 { + return (73991, 59244, 59415); + } + if slot == 185 { + return (74393, 59415, 59583); + } + if slot == 186 { + return (74795, 59583, 59750); + } + if slot == 187 { + return (75197, 59750, 59914); + } + if slot == 188 { + return (75599, 59914, 60075); + } + if slot == 189 { + return (76001, 60075, 60235); + } + if slot == 190 { + return (76401, 60235, 60392); + } + if slot == 191 { + return (76806, 60392, 60547); + } + } + } + } else { + if slot < 224 { + if slot < 208 { + if slot == 192 { + return (77208, 60547, 60700); + } + if slot == 193 { + return (77610, 60700, 60851); + } + if slot == 194 { + return (78012, 60851, 60999); + } + if slot == 195 { + return (78414, 60999, 61145); + } + if slot == 196 { + return (78816, 61145, 61288); + } + if slot == 197 { + return (79218, 61288, 61429); + } + if slot == 198 { + return (79621, 61429, 61568); + } + if slot == 199 { + return (80023, 61568, 61705); + } + if slot == 200 { + return (80423, 61705, 61839); + } + if slot == 201 { + return (80827, 61839, 61971); + } + if slot == 202 { + return (81229, 61971, 62101); + } + if slot == 203 { + return (81631, 62101, 62228); + } + if slot == 204 { + return (82033, 62228, 62353); + } + if slot == 205 { + return (82435, 62353, 62476); + } + if slot == 206 { + return (82838, 62476, 62596); + } + if slot == 207 { + return (83240, 62596, 62714); + } + } else { + if slot == 208 { + return (83642, 62714, 62830); + } + if slot == 209 { + return (84044, 62830, 62943); + } + if slot == 210 { + return (84446, 62943, 63054); + } + if slot == 211 { + return (84848, 63054, 63162); + } + if slot == 212 { + return (85250, 63162, 63268); + } + if slot == 213 { + return (85652, 63268, 63372); + } + if slot == 214 { + return (86055, 63372, 63473); + } + if slot == 215 { + return (86457, 63473, 63572); + } + if slot == 216 { + return (86859, 63572, 63668); + } + if slot == 217 { + return (87261, 63668, 63763); + } + if slot == 218 { + return (87663, 63763, 63854); + } + if slot == 219 { + return (88065, 63854, 63944); + } + if slot == 220 { + return (88467, 63944, 64031); + } + if slot == 221 { + return (88869, 64031, 64115); + } + if slot == 222 { + return (89271, 64115, 64197); + } + if slot == 223 { + return (89674, 64197, 64277); + } + } + } else { + if slot < 240 { + if slot == 224 { + return (90076, 64277, 64354); + } + if slot == 225 { + return (90478, 64354, 64429); + } + if slot == 226 { + return (90880, 64429, 64501); + } + if slot == 227 { + return (91282, 64501, 64571); + } + if slot == 228 { + return (91684, 64571, 64639); + } + if slot == 229 { + return (92086, 64639, 64704); + } + if slot == 230 { + return (92491, 64704, 64766); + } + if slot == 231 { + return (92891, 64766, 64827); + } + if slot == 232 { + return (93293, 64827, 64884); + } + if slot == 233 { + return (93695, 64884, 64940); + } + if slot == 234 { + return (94097, 64940, 64993); + } + if slot == 235 { + return (94499, 64993, 65043); + } + if slot == 236 { + return (94901, 65043, 65091); + } + if slot == 237 { + return (95303, 65091, 65137); + } + if slot == 238 { + return (95705, 65137, 65180); + } + if slot == 239 { + return (96108, 65180, 65220); + } + } else { + if slot == 240 { + return (96514, 65220, 65259); + } + if slot == 241 { + return (96912, 65259, 65294); + } + if slot == 242 { + return (97314, 65294, 65328); + } + if slot == 243 { + return (97716, 65328, 65358); + } + if slot == 244 { + return (98118, 65358, 65387); + } + if slot == 245 { + return (98520, 65387, 65413); + } + if slot == 246 { + return (98922, 65413, 65436); + } + if slot == 247 { + return (99325, 65436, 65457); + } + if slot == 248 { + return (99727, 65457, 65476); + } + if slot == 249 { + return (100129, 65476, 65492); + } + if slot == 250 { + return (100531, 65492, 65505); + } + if slot == 251 { + return (100933, 65505, 65516); + } + if slot == 252 { + return (101335, 65516, 65525); + } + if slot == 253 { + return (101737, 65525, 65531); + } + if slot == 254 { + return (102139, 65531, 65535); + } + } + } + } + } + + return (102542, 65535, 65536); +} + +fn atan(a: u64) -> (u64, u64, u64) { + let slot = a / 459; + + if slot == 0 { + return (0, 0, 459); + } + if slot == 1 { + return (459, 459, 917); + } + if slot == 2 { + return (918, 917, 1376); + } + if slot == 3 { + return (1376, 1376, 1835); + } + if slot == 4 { + return (1835, 1835, 2293); + } + if slot == 5 { + return (2294, 2293, 2751); + } + if slot == 6 { + return (2753, 2751, 3209); + } + if slot == 7 { + return (3211, 3209, 3666); + } + if slot == 8 { + return (3670, 3666, 4123); + } + if slot == 9 { + return (4129, 4123, 4580); + } + if slot == 10 { + return (4591, 4580, 5036); + } + if slot == 11 { + return (5046, 5036, 5492); + } + if slot == 12 { + return (5505, 5492, 5947); + } + if slot == 13 { + return (5964, 5947, 6402); + } + if slot == 14 { + return (6423, 6402, 6856); + } + if slot == 15 { + return (6881, 6856, 7310); + } + if slot == 16 { + return (7340, 7310, 7762); + } + if slot == 17 { + return (7799, 7762, 8214); + } + if slot == 18 { + return (8258, 8214, 8665); + } + if slot == 19 { + return (8716, 8665, 9116); + } + if slot == 20 { + return (9181, 9116, 9565); + } + if slot == 21 { + return (9634, 9565, 10014); + } + if slot == 22 { + return (10093, 10014, 10462); + } + if slot == 23 { + return (10551, 10462, 10908); + } + if slot == 24 { + return (11010, 10908, 11354); + } + if slot == 25 { + return (11469, 11354, 11798); + } + if slot == 26 { + return (11928, 11798, 12242); + } + if slot == 27 { + return (12386, 12242, 12684); + } + if slot == 28 { + return (12845, 12684, 13125); + } + if slot == 29 { + return (13304, 13125, 13565); + } + if slot == 30 { + return (13762, 13565, 14004); + } + if slot == 31 { + return (14221, 14004, 14442); + } + if slot == 32 { + return (14680, 14442, 14878); + } + if slot == 33 { + return (15139, 14878, 15313); + } + if slot == 34 { + return (15598, 15313, 15746); + } + if slot == 35 { + return (16056, 15746, 16178); + } + if slot == 36 { + return (16515, 16178, 16609); + } + if slot == 37 { + return (16974, 16609, 17038); + } + if slot == 38 { + return (17433, 17038, 17466); + } + if slot == 39 { + return (17891, 17466, 17892); + } + if slot == 40 { + return (18353, 17892, 18317); + } + if slot == 41 { + return (18809, 18317, 18740); + } + if slot == 42 { + return (19268, 18740, 19161); + } + if slot == 43 { + return (19726, 19161, 19581); + } + if slot == 44 { + return (20185, 19581, 19999); + } + if slot == 45 { + return (20644, 19999, 20416); + } + if slot == 46 { + return (21103, 20416, 20830); + } + if slot == 47 { + return (21561, 20830, 21243); + } + if slot == 48 { + return (22020, 21243, 21655); + } + if slot == 49 { + return (22479, 21655, 22064); + } + if slot == 50 { + return (22944, 22064, 22472); + } + if slot == 51 { + return (23396, 22472, 22878); + } + if slot == 52 { + return (23855, 22878, 23282); + } + if slot == 53 { + return (24314, 23282, 23685); + } + if slot == 54 { + return (24773, 23685, 24085); + } + if slot == 55 { + return (25231, 24085, 24484); + } + if slot == 56 { + return (25690, 24484, 24880); + } + if slot == 57 { + return (26149, 24880, 25275); + } + if slot == 58 { + return (26608, 25275, 25668); + } + if slot == 59 { + return (27066, 25668, 26059); + } + if slot == 60 { + return (27534, 26059, 26448); + } + if slot == 61 { + return (27984, 26448, 26835); + } + if slot == 62 { + return (28443, 26835, 27220); + } + if slot == 63 { + return (28901, 27220, 27603); + } + if slot == 64 { + return (29360, 27603, 27984); + } + if slot == 65 { + return (29819, 27984, 28363); + } + if slot == 66 { + return (30278, 28363, 28740); + } + if slot == 67 { + return (30736, 28740, 29115); + } + if slot == 68 { + return (31195, 29115, 29488); + } + if slot == 69 { + return (31654, 29488, 29859); + } + if slot == 70 { + return (32113, 29859, 30228); + } + if slot == 71 { + return (32571, 30228, 30595); + } + if slot == 72 { + return (33030, 30595, 30960); + } + if slot == 73 { + return (33489, 30960, 31323); + } + if slot == 74 { + return (33948, 31323, 31683); + } + if slot == 75 { + return (34406, 31683, 32042); + } + if slot == 76 { + return (34865, 32042, 32398); + } + if slot == 77 { + return (35324, 32398, 32753); + } + if slot == 78 { + return (35783, 32753, 33105); + } + if slot == 79 { + return (36241, 33105, 33455); + } + if slot == 80 { + return (36700, 33455, 33804); + } + if slot == 81 { + return (37159, 33804, 34150); + } + if slot == 82 { + return (37618, 34150, 34494); + } + if slot == 83 { + return (38076, 34494, 34836); + } + if slot == 84 { + return (38535, 34836, 35175); + } + if slot == 85 { + return (38994, 35175, 35513); + } + if slot == 86 { + return (39453, 35513, 35849); + } + if slot == 87 { + return (39911, 35849, 36183); + } + if slot == 88 { + return (40370, 36183, 36514); + } + if slot == 89 { + return (40829, 36514, 36843); + } + if slot == 90 { + return (41288, 36843, 37171); + } + if slot == 91 { + return (41746, 37171, 37496); + } + if slot == 92 { + return (42205, 37496, 37819); + } + if slot == 93 { + return (42664, 37819, 38141); + } + if slot == 94 { + return (43123, 38141, 38460); + } + if slot == 95 { + return (43581, 38460, 38777); + } + if slot == 96 { + return (44040, 38777, 39092); + } + if slot == 97 { + return (44499, 39092, 39405); + } + if slot == 98 { + return (44958, 39405, 39716); + } + + return (45416, 39716, 40025); +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/trig.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/trig.cairo new file mode 100644 index 000000000..8ccc236b7 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/trig.cairo @@ -0,0 +1,450 @@ +use debug::PrintTrait; +use integer::{u64_safe_divmod, u64_as_non_zero}; +use option::OptionTrait; + +use orion::numbers::fixed_point::implementations::fp16x16wide::math::lut; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + HALF, ONE, TWO, FP16x16Wide, FP16x16Impl, FP16x16Add, FP16x16Sub, FP16x16Mul, FP16x16Div, + FP16x16IntoFelt252, FixedTrait +}; + +// CONSTANTS + +const TWO_PI: u64 = 411775; +const PI: u64 = 205887; +const HALF_PI: u64 = 102944; + +// PUBLIC + +// Calculates arccos(a) for -1 <= a <= 1 (fixed point) +// arccos(a) = arcsin(sqrt(1 - a^2)) - arctan identity has discontinuity at zero +fn acos(a: FP16x16Wide) -> FP16x16Wide { + let asin_arg = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + let asin_res = asin(asin_arg); + + if (a.sign) { + return FixedTrait::new(PI, false) - asin_res; + } else { + return asin_res; + } +} + +fn acos_fast(a: FP16x16Wide) -> FP16x16Wide { + let asin_arg = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + let asin_res = asin_fast(asin_arg); + + if (a.sign) { + return FixedTrait::new(PI, false) - asin_res; + } else { + return asin_res; + } +} + +// Calculates arcsin(a) for -1 <= a <= 1 (fixed point) +// arcsin(a) = arctan(a / sqrt(1 - a^2)) +fn asin(a: FP16x16Wide) -> FP16x16Wide { + if (a.mag == ONE) { + return FixedTrait::new(HALF_PI, a.sign); + } + + let div = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + return atan(a / div); +} + +fn asin_fast(a: FP16x16Wide) -> FP16x16Wide { + if (a.mag == ONE) { + return FixedTrait::new(HALF_PI, a.sign); + } + + let div = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + return atan_fast(a / div); +} + +// Calculates arctan(a) (fixed point) +// See https://stackoverflow.com/a/50894477 for range adjustments +fn atan(a: FP16x16Wide) -> FP16x16Wide { + let mut at = a.abs(); + let mut shift = false; + let mut invert = false; + + // Invert value when a > 1 + if (at.mag > ONE) { + at = FixedTrait::ONE() / at; + invert = true; + } + + // Account for lack of precision in polynomaial when a > 0.7 + if (at.mag > 45875) { + let sqrt3_3 = FixedTrait::new(37837, false); // sqrt(3) / 3 + at = (at - sqrt3_3) / (FixedTrait::ONE() + at * sqrt3_3); + shift = true; + } + + let r10 = FixedTrait::new(120, true) * at; + let r9 = (r10 + FixedTrait::new(3066, true)) * at; + let r8 = (r9 + FixedTrait::new(12727, false)) * at; + let r7 = (r8 + FixedTrait::new(17170, true)) * at; + let r6 = (r7 + FixedTrait::new(2865, false)) * at; + let r5 = (r6 + FixedTrait::new(12456, false)) * at; + let r4 = (r5 + FixedTrait::new(90, false)) * at; + let r3 = (r4 + FixedTrait::new(21852, true)) * at; + let r2 = r3 * at; + let mut res = (r2 + FixedTrait::new(65536, false)) * at; + + // Adjust for sign change, inversion, and shift + if (shift) { + res = res + FixedTrait::new(34315, false); // pi / 6 + } + + if (invert) { + res = res - FixedTrait::new(HALF_PI, false); + } + + return FixedTrait::new(res.mag, a.sign); +} + + +fn atan_fast(a: FP16x16Wide) -> FP16x16Wide { + let mut at = a.abs(); + let mut shift = false; + let mut invert = false; + + // Invert value when a > 1 + if (at.mag > ONE) { + at = FixedTrait::ONE() / at; + invert = true; + } + + // Account for lack of precision in polynomaial when a > 0.7 + if (at.mag > 45875) { + let sqrt3_3 = FixedTrait::new(37837, false); // sqrt(3) / 3 + at = (at - sqrt3_3) / (FixedTrait::ONE() + at * sqrt3_3); + shift = true; + } + + let (start, low, high) = lut::atan(at.mag); + let partial_step = FixedTrait::new(at.mag - start, false) / FixedTrait::new(459, false); + let mut res = partial_step * FixedTrait::new(high - low, false) + FixedTrait::new(low, false); + + // Adjust for sign change, inversion, and shift + if (shift) { + res = res + FixedTrait::new(34315, false); // pi / 6 + } + + if (invert) { + res = res - FixedTrait::::new(HALF_PI, false); + } + + return FixedTrait::new(res.mag, a.sign); +} + +// Calculates cos(a) with a in radians (fixed point) +fn cos(a: FP16x16Wide) -> FP16x16Wide { + return sin(FixedTrait::new(HALF_PI, false) - a); +} + +fn cos_fast(a: FP16x16Wide) -> FP16x16Wide { + return sin_fast(FixedTrait::new(HALF_PI, false) - a); +} + +fn sin(a: FP16x16Wide) -> FP16x16Wide { + let a1 = a.mag % TWO_PI; + let (whole_rem, partial_rem) = u64_safe_divmod(a1, u64_as_non_zero(PI)); + let a2 = FixedTrait::new(partial_rem, false); + let partial_sign = whole_rem == 1; + + let loop_res = a2 * _sin_loop(a2, 7, FixedTrait::ONE()); + return FixedTrait::new(loop_res.mag, a.sign ^ partial_sign && loop_res.mag != 0); +} + +fn sin_fast(a: FP16x16Wide) -> FP16x16Wide { + let a1 = a.mag % TWO_PI; + let (whole_rem, mut partial_rem) = u64_safe_divmod(a1, u64_as_non_zero(PI)); + let partial_sign = whole_rem == 1; + + if partial_rem >= HALF_PI { + partial_rem = PI - partial_rem; + } + + let (start, low, high) = lut::sin(partial_rem); + let partial_step = FixedTrait::new(partial_rem - start, false) / FixedTrait::new(402, false); + let res = partial_step * (FixedTrait::new(high, false) - FixedTrait::new(low, false)) + + FixedTrait::::new(low, false); + + return FixedTrait::new(res.mag, a.sign ^ partial_sign && res.mag != 0); +} + +// Calculates tan(a) with a in radians (fixed point) +fn tan(a: FP16x16Wide) -> FP16x16Wide { + let sinx = sin(a); + let cosx = cos(a); + assert(cosx.mag != 0, 'tan undefined'); + return sinx / cosx; +} + +fn tan_fast(a: FP16x16Wide) -> FP16x16Wide { + let sinx = sin_fast(a); + let cosx = cos_fast(a); + assert(cosx.mag != 0, 'tan undefined'); + return sinx / cosx; +} + +// Helper function to calculate Taylor series for sin +fn _sin_loop(a: FP16x16Wide, i: u64, acc: FP16x16Wide) -> FP16x16Wide { + let div = (2 * i + 2) * (2 * i + 3); + let term = a * a * acc / FixedTrait::new_unscaled(div, false); + let new_acc = FixedTrait::ONE() - term; + + if (i == 0) { + return new_acc; + } + + return _sin_loop(a, i - 1, new_acc); +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp16x16wide::helpers::{ + assert_precise, assert_relative +}; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{FP16x16PartialEq, FP16x16Print}; + +#[test] +#[available_gas(8000000)] +fn test_acos() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::ONE(); + assert(acos(a).into() == 0, 'invalid one'); + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(acos(a), 68629, 'invalid half', error); // 1.3687308642680 + + let a = FixedTrait::ZERO(); + assert_relative(acos(a), HALF_PI.into(), 'invalid zero', Option::None(())); // PI / 2 + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(acos(a), 137258, 'invalid neg half', error); // 2.737461741902 + + let a = FixedTrait::new(ONE, true); + assert_relative(acos(a), PI.into(), 'invalid neg one', Option::None(())); // PI +} + +#[test] +#[available_gas(8000000)] +fn test_acos_fast() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::ONE(); + assert(acos_fast(a).into() == 0, 'invalid one'); + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(acos_fast(a), 68629, 'invalid half', error); // 1.3687308642680 + + let a = FixedTrait::ZERO(); + assert_relative(acos_fast(a), HALF_PI.into(), 'invalid zero', Option::None(())); // PI / 2 + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(acos_fast(a), 137258, 'invalid neg half', error); // 2.737461741902 + + let a = FixedTrait::new(ONE, true); + assert_relative(acos_fast(a), PI.into(), 'invalid neg one', Option::None(())); // PI +} + +#[test] +#[should_panic] +#[available_gas(8000000)] +fn test_acos_fail() { + let a = FixedTrait::new(2 * ONE, true); + acos(a); +} + +#[test] +#[available_gas(8000000)] +fn test_atan_fast() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::new(2 * ONE, false); + assert_relative(atan_fast(a), 72558, 'invalid two', error); + + let a = FixedTrait::ONE(); + assert_relative(atan_fast(a), 51472, 'invalid one', error); + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(atan_fast(a), 30386, 'invalid half', error); + + let a = FixedTrait::ZERO(); + assert(atan_fast(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(atan_fast(a), -30386, 'invalid neg half', error); + + let a = FixedTrait::new(ONE, true); + assert_relative(atan_fast(a), -51472, 'invalid neg one', error); + + let a = FixedTrait::new(2 * ONE, true); + assert_relative(atan_fast(a), -72558, 'invalid neg two', error); +} + +#[test] +#[available_gas(8000000)] +fn test_atan() { + let a = FixedTrait::new(2 * ONE, false); + assert_relative(atan(a), 72558, 'invalid two', Option::None(())); + + let a = FixedTrait::ONE(); + assert_relative(atan(a), 51472, 'invalid one', Option::None(())); + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(atan(a), 30386, 'invalid half', Option::None(())); + + let a = FixedTrait::ZERO(); + assert(atan(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(atan(a), -30386, 'invalid neg half', Option::None(())); + + let a = FixedTrait::new(ONE, true); + assert_relative(atan(a), -51472, 'invalid neg one', Option::None(())); + + let a = FixedTrait::new(2 * ONE, true); + assert_relative(atan(a), -72558, 'invalid neg two', Option::None(())); +} + +#[test] +#[available_gas(8000000)] +fn test_asin() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::ONE(); + assert_relative(asin(a), HALF_PI.into(), 'invalid one', Option::None(())); // PI / 2 + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(asin(a), 34315, 'invalid half', error); + + let a = FixedTrait::ZERO(); + assert_precise(asin(a), 0, 'invalid zero', Option::None(())); + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(asin(a), -34315, 'invalid neg half', error); + + let a = FixedTrait::new(ONE, true); + assert_relative(asin(a), -HALF_PI.into(), 'invalid neg one', Option::None(())); // -PI / 2 +} + +#[test] +#[should_panic] +#[available_gas(8000000)] +fn test_asin_fail() { + let a = FixedTrait::new(2 * ONE, false); + asin(a); +} + +#[test] +#[available_gas(8000000)] +fn test_cos() { + let a = FixedTrait::new(HALF_PI, false); + assert(cos(a).into() == 0, 'invalid half pi'); + + let a = FixedTrait::new(HALF_PI / 2, false); + assert_relative(cos(a), 46341, 'invalid quarter pi', Option::None(())); // 0.55242717280199 + + let a = FixedTrait::new(PI, false); + assert_relative(cos(a), -1 * ONE.into(), 'invalid pi', Option::None(())); + + let a = FixedTrait::new(HALF_PI, true); + assert_precise(cos(a), 0, 'invalid neg half pi', Option::None(())); + + let a = FixedTrait::new_unscaled(17, false); + assert_relative(cos(a), -18033, 'invalid 17', Option::None(())); // -0.21497123284870 + + let a = FixedTrait::new_unscaled(17, true); + assert_relative(cos(a), -18033, 'invalid -17', Option::None(())); // -0.21497123284870 +} + +#[test] +#[available_gas(8000000)] +fn test_cos_fast() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::new(HALF_PI, false); + assert(cos_fast(a).into() == 0, 'invalid half pi'); + + let a = FixedTrait::new(HALF_PI / 2, false); + assert_precise(cos_fast(a), 46341, 'invalid quarter pi', error); // 0.55242717280199 + + let a = FixedTrait::new(PI, false); + assert_precise(cos_fast(a), -1 * ONE.into(), 'invalid pi', error); + + let a = FixedTrait::new(HALF_PI, true); + assert_precise(cos(a), 0, 'invalid neg half pi', Option::None(())); + + let a = FixedTrait::new_unscaled(17, false); + assert_precise(cos_fast(a), -18033, 'invalid 17', error); // -0.21497123284870 +} + +#[test] +#[available_gas(8000000)] +fn test_sin() { + let a = FixedTrait::new(HALF_PI, false); + assert_precise(sin(a), ONE.into(), 'invalid half pi', Option::None(())); + + let a = FixedTrait::new(HALF_PI / 2, false); + assert_precise(sin(a), 46341, 'invalid quarter pi', Option::None(())); // 0.55242717280199 + + let a = FixedTrait::new(PI, false); + assert(sin(a).into() == 0, 'invalid pi'); + + let a = FixedTrait::new(HALF_PI, true); + assert_precise( + sin(a), -ONE.into(), 'invalid neg half pi', Option::None(()) + ); // 0.78124999999529 + + let a = FixedTrait::new_unscaled(17, false); + assert_precise(sin(a), -63006, 'invalid 17', Option::None(())); // -0.75109179053073 + + let a = FixedTrait::new_unscaled(17, true); + assert_precise(sin(a), 63006, 'invalid -17', Option::None(())); // 0.75109179053073 +} + +#[test] +#[available_gas(8000000)] +fn test_sin_fast() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::new(HALF_PI, false); + assert_precise(sin_fast(a), ONE.into(), 'invalid half pi', error); + + let a = FixedTrait::new(HALF_PI / 2, false); + assert_precise(sin_fast(a), 46341, 'invalid quarter pi', error); // 0.55242717280199 + + let a = FixedTrait::new(PI, false); + assert(sin_fast(a).into() == 0, 'invalid pi'); + + let a = FixedTrait::new(HALF_PI, true); + assert_precise(sin_fast(a), -ONE.into(), 'invalid neg half pi', error); // 0.78124999999529 + + let a = FixedTrait::new_unscaled(17, false); + assert_precise(sin_fast(a), -63006, 'invalid 17', error); // -0.75109179053073 + + let a = FixedTrait::new_unscaled(17, true); + assert_precise(sin_fast(a), 63006, 'invalid -17', error); // 0.75109179053073 +} + +#[test] +#[available_gas(8000000)] +fn test_tan() { + let a = FixedTrait::new(HALF_PI / 2, false); + assert_precise(tan(a), ONE.into(), 'invalid quarter pi', Option::None(())); + + let a = FixedTrait::new(PI, false); + assert_precise(tan(a), 0, 'invalid pi', Option::None(())); + + let a = FixedTrait::new_unscaled(17, false); + assert_precise(tan(a), 228990, 'invalid 17', Option::None(())); // 3.3858731852805 + + let a = FixedTrait::new_unscaled(17, true); + assert_precise(tan(a), -228952, 'invalid -17', Option::None(())); // -3.3858731852805 +} diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index e0538a92a..70bdfe2cf 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -6,24 +6,24 @@ use orion::numbers::fixed_point::core::FixedTrait; fn softmax< T, TMAG, - U, - UMAG, - impl UFixedTrait: FixedTrait, + W, + WMAG, + impl WFixedTrait: FixedTrait, impl TTensor: TensorTrait, - impl UTensor: TensorTrait, + impl WTensor: TensorTrait, impl TTensorDiv: Div>, impl TDiv: Div, - impl TIntoU: Into, - impl UTtryIntoT: TryInto, + impl TIntoW: Into, + impl WTtryIntoT: TryInto, impl TCopy: Copy, impl TDrop: Drop, - impl UCopy: Copy, - impl UDrop: Drop, + impl WCopy: Copy, + impl WDrop: Drop, >( z: @Tensor, axis: usize, wide: bool ) -> Tensor { if wide { - let exp_tensor: Tensor = exp_upcast::(*z); + let exp_tensor: Tensor = exp_upcast::(*z); let sum = exp_tensor.reduce_sum(axis, true); let softmax: Tensor = div_downcast(@exp_tensor, @sum); diff --git a/src/operators/nn/implementations/nn_fp16x16.cairo b/src/operators/nn/implementations/nn_fp16x16.cairo index 46f8b7f6d..f5c27d6cc 100644 --- a/src/operators/nn/implementations/nn_fp16x16.cairo +++ b/src/operators/nn/implementations/nn_fp16x16.cairo @@ -3,8 +3,8 @@ use core::option::OptionTrait; use orion::operators::tensor::core::Tensor; use orion::operators::nn::core::NNTrait; use orion::operators::nn::functional; -use orion::numbers::{FP16x16, FP16x16IntoFP64x64, FP64x64, FP64x64Impl}; -use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorDiv, FP16x16TensorAdd, FP64x64Tensor}; +use orion::numbers::{FP16x16, FP16x16Wide, FP16x16IntoFP16x16Wide, FP64x64Impl}; +use orion::operators::tensor::{FP16x16Tensor, FP16x16WideTensor, FP16x16TensorDiv, FP16x16TensorAdd, FP64x64Tensor}; impl FP16x16NN of NNTrait { fn relu(tensor: @Tensor) -> Tensor { @@ -16,7 +16,7 @@ impl FP16x16NN of NNTrait { } fn softmax(tensor: @Tensor, axis: usize, wide: bool) -> Tensor { - functional::softmax::softmax::(tensor, axis, wide) + functional::softmax::softmax::(tensor, axis, wide) } fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { diff --git a/src/operators/tensor.cairo b/src/operators/tensor.cairo index d0feb8a32..fed10f3aa 100644 --- a/src/operators/tensor.cairo +++ b/src/operators/tensor.cairo @@ -17,6 +17,12 @@ use orion::operators::tensor::implementations::tensor_fp16x16::{ FP16x16TensorPartialEq, }; +use orion::operators::tensor::implementations::tensor_fp16x16wide::{ + FP16x16WideTensor, FP16x16TensorAdd as FP16x16WideTensorAdd, + FP16x16TensorSub as FP16x16WideTensorSub, FP16x16TensorMul as FP16x16WideTensorMul, + FP16x16TensorDiv as FP16x16WideTensorDiv, FP16x16TensorPartialEq as FP16x16WideTensorPartialEq, +}; + use orion::operators::tensor::implementations::tensor_fp32x32::{ FP32x32Tensor, FP32x32TensorAdd, FP32x32TensorSub, FP32x32TensorMul, FP32x32TensorDiv, FP32x32TensorPartialEq, diff --git a/src/operators/tensor/implementations.cairo b/src/operators/tensor/implementations.cairo index a585b88a7..4a5d2cab3 100644 --- a/src/operators/tensor/implementations.cairo +++ b/src/operators/tensor/implementations.cairo @@ -5,3 +5,4 @@ mod tensor_fp8x23; mod tensor_fp16x16; mod tensor_fp64x64; mod tensor_fp32x32; +mod tensor_fp16x16wide; diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo new file mode 100644 index 000000000..d7a2ee6e6 --- /dev/null +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -0,0 +1,374 @@ +use array::ArrayTrait; +use array::SpanTrait; +use option::OptionTrait; +use traits::{TryInto, Into}; + +use orion::numbers::fixed_point::core::FixedTrait; +use orion::operators::tensor::core::{ + new_tensor, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, +}; +use orion::operators::tensor::{math, linalg, quantization, core}; +use orion::numbers::{i8, i32, NumberTrait}; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{FP16x16Wide}; +use orion::operators::tensor::implementations::{tensor_i8::I8Tensor, tensor_u32::U32Tensor}; + +impl FP16x16WideTensor of TensorTrait { + fn new(shape: Span, data: Span) -> Tensor { + new_tensor(shape, data) + } + + fn at(self: @Tensor, indices: Span) -> FP16x16Wide { + *at_tensor(self, indices) + } + + fn min(self: @Tensor) -> FP16x16Wide { + math::min::min_in_tensor::(*self.data) + } + + fn max(self: @Tensor) -> FP16x16Wide { + math::max::max_in_tensor(*self.data) + } + + fn stride(self: @Tensor) -> Span { + stride(*self.shape) + } + + fn ravel_index(self: @Tensor, indices: Span) -> usize { + ravel_index(*self.shape, indices) + } + + fn unravel_index(self: @Tensor, index: usize) -> Span { + unravel_index(index, *self.shape) + } + + fn reshape(self: @Tensor, target_shape: Span) -> Tensor { + reshape(self, target_shape) + } + + fn reduce_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_sum::reduce_sum(self, axis, keepdims) + } + + fn argmax( + self: @Tensor, axis: usize, keepdims: Option, select_last_index: Option + ) -> Tensor { + math::argmax::argmax(self, axis, keepdims, select_last_index) + } + + fn argmin( + self: @Tensor, axis: usize, keepdims: Option, select_last_index: Option + ) -> Tensor { + math::argmin::argmin(self, axis, keepdims, select_last_index) + } + + fn transpose(self: @Tensor, axes: Span) -> Tensor { + linalg::transpose::transpose(self, axes) + } + + fn matmul(self: @Tensor, other: @Tensor) -> Tensor { + linalg::matmul::matmul(self, other) + } + + fn exp(self: @Tensor) -> Tensor { + math::exp::exp(*self) + } + + fn log(self: @Tensor) -> Tensor { + math::log::log(*self) + } + + fn equal(self: @Tensor, other: @Tensor) -> Tensor { + math::equal::equal(self, other) + } + + fn greater(self: @Tensor, other: @Tensor) -> Tensor { + math::greater::greater(self, other) + } + + fn greater_equal(self: @Tensor, other: @Tensor) -> Tensor { + math::greater_equal::greater_equal(self, other) + } + + fn less(self: @Tensor, other: @Tensor) -> Tensor { + math::less::less(self, other) + } + + fn less_equal(self: @Tensor, other: @Tensor) -> Tensor { + math::less_equal::less_equal(self, other) + } + + fn abs(self: @Tensor) -> Tensor { + math::abs::abs(*self) + } + + fn ceil(self: @Tensor) -> Tensor { + math::ceil::ceil(*self) + } + + fn sin(self: @Tensor) -> Tensor { + math::sin::sin(*self) + } + + fn cos(self: @Tensor) -> Tensor { + math::cos::cos(*self) + } + + fn asin(self: @Tensor) -> Tensor { + math::asin::asin(*self) + } + + fn cumsum( + self: @Tensor, axis: usize, exclusive: Option, reverse: Option + ) -> Tensor { + math::cumsum::cumsum(self, axis, exclusive, reverse) + } + + fn flatten(self: @Tensor, axis: usize) -> Tensor { + math::flatten::flatten(self, axis) + } + + fn sinh(self: @Tensor) -> Tensor { + math::sinh::sinh(*self) + } + + fn tanh(self: @Tensor) -> Tensor { + math::tanh::tanh(*self) + } + + fn cosh(self: @Tensor) -> Tensor { + math::cosh::cosh(*self) + } + + fn acosh(self: @Tensor) -> Tensor { + math::acosh::acosh(*self) + } + + fn asinh(self: @Tensor) -> Tensor { + math::asinh::asinh(*self) + } + + fn atan(self: @Tensor) -> Tensor { + math::atan::atan(*self) + } + + fn xor(self: @Tensor, other: @Tensor) -> Tensor { + math::xor::xor(self, other) + } + + fn or(self: @Tensor, other: @Tensor) -> Tensor { + math::or::or(self, other) + } + + fn acos(self: @Tensor) -> Tensor { + math::acos::acos(*self) + } + + fn onehot( + self: @Tensor, depth: usize, axis: Option, values: Span + ) -> Tensor { + panic(array!['not supported!']) + } + + fn sqrt(self: @Tensor) -> Tensor { + math::sqrt::sqrt(*self) + } + + fn concat(tensors: Span>, axis: usize,) -> Tensor { + math::concat::concat(tensors, axis) + } + + fn quantize_linear( + self: @Tensor, y_scale: @Tensor, y_zero_point: @Tensor + ) -> Tensor:: { + quantization::quantize_linear::quantize_linear( + self, + y_scale, + y_zero_point, + NumberTrait::new_unscaled(128, true), + NumberTrait::new_unscaled(127, false) + ) + } + + fn dequantize_linear( + self: @Tensor, x_scale: @Tensor, x_zero_point: @Tensor + ) -> Tensor:: { + panic(array!['not supported!']) + } + + fn slice( + self: @Tensor, + starts: Span, + ends: Span, + axes: Option>, + steps: Option> + ) -> Tensor { + core::slice::(self, starts, ends, axes, steps) + } + + fn gather( + self: @Tensor, indices: Tensor, axis: Option + ) -> Tensor { + math::gather::gather(self, indices, axis) + } + + fn nonzero(self: @Tensor) -> Tensor { + core::nonzero(self) + } + + fn squeeze(self: @Tensor, axes: Option>) -> Tensor { + core::squeeze(self, axes) + } + + fn unsqueeze(self: @Tensor, axes: Span) -> Tensor { + core::unsqueeze(self, axes) + } + + fn sign(self: @Tensor) -> Tensor { + math::sign::sign(*self) + } + + fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { + core::clip(self, min, max) + } +} + +/// Implements addition for `Tensor` using the `Add` trait. +impl FP16x16TensorAdd of Add> { + /// Adds two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise addition. + fn add(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::add(@lhs, @rhs) + } +} + +/// Implements subtraction for `Tensor` using the `Sub` trait. +impl FP16x16TensorSub of Sub> { + /// Subtracts two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise subtraction. + fn sub(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::sub(@lhs, @rhs) + } +} + +/// Implements multiplication for `Tensor` using the `Mul` trait. +impl FP16x16TensorMul of Mul> { + /// Multiplies two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise multiplication. + fn mul(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::mul(@lhs, @rhs) + } +} + +/// Implements division for `Tensor` using the `Div` trait. +impl FP16x16TensorDiv of Div> { + /// Divides two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise division. + fn div(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::div(@lhs, @rhs) + } +} + +/// Implements partial equal for two `Tensor` using the `PartialEq` trait. +impl FP16x16TensorPartialEq of PartialEq> { + fn eq(lhs: @Tensor, rhs: @Tensor) -> bool { + tensor_eq(*lhs, *rhs) + } + + fn ne(lhs: @Tensor, rhs: @Tensor) -> bool { + !tensor_eq(*lhs, *rhs) + } +} + +impl U32TryIntoU32 of TryInto { + fn try_into(self: u64) -> Option { + Option::Some(self) + } +} + +// impl TensorI8IntoTensorFP16x16 of Into, Tensor> { +// fn into(self: Tensor) -> Tensor { +// tensor_i8_to_tensor_fp16x16(@self) +// } +// } + + +// Internals +const PRECISION: u64 = 589; // 0.009 + +fn relative_eq(lhs: @FP16x16Wide, rhs: @FP16x16Wide) -> bool { + let diff = *lhs - *rhs; + + let rel_diff = if *lhs.mag != 0 { + (diff / *lhs).mag + } else { + diff.mag + }; + + rel_diff <= PRECISION +} + + +fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { + let mut is_eq = true; + + loop { + if lhs.shape.len() == 0 || !is_eq { + break; + } + + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; + + if !is_eq { + return false; + } + + loop { + if lhs.data.len() == 0 || !is_eq { + break; + } + + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; + + return is_eq; +} + +// fn tensor_i8_to_tensor_fp16x16(x: @Tensor) -> Tensor { +// let mut result_data = ArrayTrait::::new(); +// let mut data = *x.data; + +// loop { +// result_data.append((*data.pop_front().unwrap()).into()); + +// if data.len() == 0 { +// break (); +// }; +// }; + +// return TensorTrait::new(*x.shape, result_data.span()); +// }