From e2eba4bd0883954c5eba7ace94b0358b80c00aa3 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 12:49:43 +0300 Subject: [PATCH 01/42] add xgboost regressor --- src/operators/ml.cairo | 3 +- src/operators/ml/xgboost_regressor.cairo | 1 + src/operators/ml/xgboost_regressor/core.cairo | 36 +++++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 src/operators/ml/xgboost_regressor.cairo create mode 100644 src/operators/ml/xgboost_regressor/core.cairo diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index b3eaef500..143cba273 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -1,6 +1,7 @@ mod tree_regressor; +mod xgboost_regressor; -use orion::operators::ml::tree_regressor::core::TreeRegressorTrait; +use orion::operators::ml::tree_regressor::core::{TreeRegressorTrait, TreeNode}; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp16x16::FP16x16TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp8x23::FP8x23TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp32x32::FP32x32TreeRegressor; diff --git a/src/operators/ml/xgboost_regressor.cairo b/src/operators/ml/xgboost_regressor.cairo new file mode 100644 index 000000000..ef33ab296 --- /dev/null +++ b/src/operators/ml/xgboost_regressor.cairo @@ -0,0 +1 @@ +mod core; \ No newline at end of file diff --git a/src/operators/ml/xgboost_regressor/core.cairo b/src/operators/ml/xgboost_regressor/core.cairo new file mode 100644 index 000000000..d1bc87e04 --- /dev/null +++ b/src/operators/ml/xgboost_regressor/core.cairo @@ -0,0 +1,36 @@ +use orion::operators::ml::{TreeNode, TreeRegressorTrait}; +use orion::numbers::FixedTrait; + + +trait XGBoostPredictorTrait { + fn predict(trees: Span>, features: Span, weights: Span) -> T; +} + +fn predict< + T, + MAG, + impl TFixed: FixedTrait, + impl TTreeRegressor: TreeRegressorTrait, + impl TMul: Mul, + impl TAddEq: AddEq, + impl TCopy: Copy, + impl TDrop: Drop, +>( + ref trees: Span>, ref features: Span, ref weights: Span +) -> T { + let mut sum_prediction: T = FixedTrait::ZERO(); + + loop { + match trees.pop_front() { + Option::Some(tree) => { + let mut tree = *tree; + sum_prediction += tree.predict(features) * *weights.pop_front().unwrap() + }, + Option::None(_) => { + break; + } + }; + }; + + sum_prediction +} From 62a4844936be048080db174e808c471edb1be5d6 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 15:04:43 +0300 Subject: [PATCH 02/42] implement tree classifier --- src/operators/ml.cairo | 1 + src/operators/ml/tree_classifier.cairo | 2 + src/operators/ml/tree_classifier/core.cairo | 94 +++++++++++++++++++ .../ml/tree_classifier/implementations.cairo | 4 + .../tree_classifier_fp16x16.cairo | 13 +++ .../tree_classifier_fp32x32.cairo | 13 +++ .../tree_classifier_fp64x64.cairo | 13 +++ .../tree_classifier_fp8x23.cairo | 13 +++ 8 files changed, 153 insertions(+) create mode 100644 src/operators/ml/tree_classifier.cairo create mode 100644 src/operators/ml/tree_classifier/core.cairo create mode 100644 src/operators/ml/tree_classifier/implementations.cairo create mode 100644 src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo create mode 100644 src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo create mode 100644 src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo create mode 100644 
src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index 143cba273..1f4e0dad4 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -1,4 +1,5 @@ mod tree_regressor; +mod tree_classifier; mod xgboost_regressor; use orion::operators::ml::tree_regressor::core::{TreeRegressorTrait, TreeNode}; diff --git a/src/operators/ml/tree_classifier.cairo b/src/operators/ml/tree_classifier.cairo new file mode 100644 index 000000000..2ab1a62ac --- /dev/null +++ b/src/operators/ml/tree_classifier.cairo @@ -0,0 +1,2 @@ +mod core; +mod implementations; \ No newline at end of file diff --git a/src/operators/ml/tree_classifier/core.cairo b/src/operators/ml/tree_classifier/core.cairo new file mode 100644 index 000000000..c8c4a2c7f --- /dev/null +++ b/src/operators/ml/tree_classifier/core.cairo @@ -0,0 +1,94 @@ +use orion::numbers::{FixedTrait}; + +#[derive(Copy, Drop)] +struct TreeNode { + left: Option>>, + right: Option>>, + split_feature: usize, + split_value: T, + prediction: T, + class_distribution: Span, // assuming class labels of type usize (span index), and probability as T. +} + +/// Trait +/// +/// predict - Given a set of features, predicts the target value using the constructed decision tree. +/// predict_proba - Given a set of features, predicts the probability of each X example being of a given class.. +trait TreeClassifierTrait { + fn predict(ref self: TreeNode, features: Span) -> T; + fn predict_proba(ref self: TreeNode, features: Span) -> Span; +} + +fn predict< + T, + MAG, + impl FFixedTrait: FixedTrait, + impl TPartialOrd: PartialOrd, + impl FCopy: Copy, + impl FDrop: Drop, +>( + ref self: TreeNode, features: Span +) -> T { + let mut current_node: TreeNode = self; + + loop { + match current_node.left { + Option::Some(left) => { + match current_node.right { + Option::Some(right) => { + if *features.at(current_node.split_feature) < current_node.split_value { + current_node = left.unbox(); + } else { + current_node = right.unbox(); + } + }, + Option::None(_) => { + break; + } + } + }, + Option::None(_) => { + break; + } + }; + }; + + current_node.prediction +} + +fn predict_proba< + T, + MAG, + impl FFixedTrait: FixedTrait, + impl TPartialOrd: PartialOrd, + impl FCopy: Copy, + impl FDrop: Drop, +>( + ref self: TreeNode, features: Span +) -> Span { + let mut current_node: TreeNode = self; + + loop { + match current_node.left { + Option::Some(left) => { + match current_node.right { + Option::Some(right) => { + if *features.at(current_node.split_feature) < current_node.split_value { + current_node = left.unbox(); + } else { + current_node = right.unbox(); + } + }, + Option::None(_) => { + break; + } + } + }, + Option::None(_) => { + break; + } + }; + }; + + current_node.class_distribution +} diff --git a/src/operators/ml/tree_classifier/implementations.cairo b/src/operators/ml/tree_classifier/implementations.cairo new file mode 100644 index 000000000..2421c7809 --- /dev/null +++ b/src/operators/ml/tree_classifier/implementations.cairo @@ -0,0 +1,4 @@ +mod tree_classifier_fp8x23; +mod tree_classifier_fp16x16; +mod tree_classifier_fp32x32; +mod tree_classifier_fp64x64; diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo new file mode 100644 index 000000000..579c18928 --- /dev/null +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo @@ -0,0 +1,13 @@ 
+use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core; +use orion::numbers::FP16x16; + +impl FP16x16TreeClassifier of TreeClassifierTrait { + fn predict(ref self: TreeNode, features: Span) -> FP16x16 { + core::predict(ref self, features) + } + + fn predict_proba(ref self: TreeNode, features: Span) -> Span { + core::predict_proba(ref self, features) + } +} diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo new file mode 100644 index 000000000..c10a0c82f --- /dev/null +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo @@ -0,0 +1,13 @@ +use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core; +use orion::numbers::{FP32x32, FP32x32Impl}; + +impl FP32x32TreeClassifier of TreeClassifierTrait { + fn predict(ref self: TreeNode, features: Span) -> FP32x32 { + core::predict(ref self, features) + } + + fn predict_proba(ref self: TreeNode, features: Span) -> Span { + core::predict_proba(ref self, features) + } +} diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo new file mode 100644 index 000000000..ce3e6541a --- /dev/null +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo @@ -0,0 +1,13 @@ +use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core; +use orion::numbers::{FP64x64, FP64x64Impl}; + +impl FP64x64TreeClassifier of TreeClassifierTrait { + fn predict(ref self: TreeNode, features: Span) -> FP64x64 { + core::predict(ref self, features) + } + + fn predict_proba(ref self: TreeNode, features: Span) -> Span { + core::predict_proba(ref self, features) + } +} diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo new file mode 100644 index 000000000..88eaf0fc4 --- /dev/null +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo @@ -0,0 +1,13 @@ +use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core; +use orion::numbers::FP8x23; + +impl FP8x23TreeClassifier of TreeClassifierTrait { + fn predict(ref self: TreeNode, features: Span) -> FP8x23 { + core::predict(ref self, features) + } + + fn predict_proba(ref self: TreeNode, features: Span) -> Span { + core::predict_proba(ref self, features) + } +} From 4e8d686d7eb483d9823913f13bdb4b1cafd65505 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 15:14:21 +0300 Subject: [PATCH 03/42] rename trees struct --- src/operators/ml.cairo | 8 +++++- src/operators/ml/tree_classifier/core.cairo | 18 ++++++------- .../tree_classifier_fp16x16.cairo | 6 ++--- .../tree_classifier_fp32x32.cairo | 6 ++--- .../tree_classifier_fp64x64.cairo | 6 ++--- .../tree_classifier_fp8x23.cairo | 6 ++--- src/operators/ml/tree_regressor/core.cairo | 26 +++++++++---------- .../tree_regressor_fp16x16.cairo | 6 ++--- .../tree_regressor_fp32x32.cairo | 6 ++--- .../tree_regressor_fp64x64.cairo | 6 ++--- .../tree_regressor_fp8x23.cairo | 6 ++--- src/operators/ml/xgboost_regressor/core.cairo | 6 ++--- 12 files changed, 56 
insertions(+), 50 deletions(-) diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index 1f4e0dad4..0a29e00c2 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -2,8 +2,14 @@ mod tree_regressor; mod tree_classifier; mod xgboost_regressor; -use orion::operators::ml::tree_regressor::core::{TreeRegressorTrait, TreeNode}; +use orion::operators::ml::tree_regressor::core::{TreeRegressorTrait, TreeRegressor}; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp16x16::FP16x16TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp8x23::FP8x23TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp32x32::FP32x32TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp64x64::FP64x64TreeRegressor; + +use orion::operators::ml::tree_classifier::core::{TreeClassifierTrait, TreeClassifier}; +use orion::operators::ml::tree_classifier::implementations::tree_classifier_fp16x16::FP16x16TreeClassifier; +use orion::operators::ml::tree_classifier::implementations::tree_classifier_fp8x23::FP8x23TreeClassifier; +use orion::operators::ml::tree_classifier::implementations::tree_classifier_fp32x32::FP32x32TreeClassifier; +use orion::operators::ml::tree_classifier::implementations::tree_classifier_fp64x64::FP64x64TreeClassifier; diff --git a/src/operators/ml/tree_classifier/core.cairo b/src/operators/ml/tree_classifier/core.cairo index c8c4a2c7f..669c23abd 100644 --- a/src/operators/ml/tree_classifier/core.cairo +++ b/src/operators/ml/tree_classifier/core.cairo @@ -1,9 +1,9 @@ use orion::numbers::{FixedTrait}; #[derive(Copy, Drop)] -struct TreeNode { - left: Option>>, - right: Option>>, +struct TreeClassifier { + left: Option>>, + right: Option>>, split_feature: usize, split_value: T, prediction: T, @@ -15,8 +15,8 @@ struct TreeNode { /// predict - Given a set of features, predicts the target value using the constructed decision tree. /// predict_proba - Given a set of features, predicts the probability of each X example being of a given class.. 
trait TreeClassifierTrait { - fn predict(ref self: TreeNode, features: Span) -> T; - fn predict_proba(ref self: TreeNode, features: Span) -> Span; + fn predict(ref self: TreeClassifier, features: Span) -> T; + fn predict_proba(ref self: TreeClassifier, features: Span) -> Span; } fn predict< @@ -27,9 +27,9 @@ fn predict< impl FCopy: Copy, impl FDrop: Drop, >( - ref self: TreeNode, features: Span + ref self: TreeClassifier, features: Span ) -> T { - let mut current_node: TreeNode = self; + let mut current_node: TreeClassifier = self; loop { match current_node.left { @@ -64,9 +64,9 @@ fn predict_proba< impl FCopy: Copy, impl FDrop: Drop, >( - ref self: TreeNode, features: Span + ref self: TreeClassifier, features: Span ) -> Span { - let mut current_node: TreeNode = self; + let mut current_node: TreeClassifier = self; loop { match current_node.left { diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo index 579c18928..1789c8a64 100644 --- a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp16x16.cairo @@ -1,13 +1,13 @@ -use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core::{TreeClassifier, TreeClassifierTrait}; use orion::operators::ml::tree_classifier::core; use orion::numbers::FP16x16; impl FP16x16TreeClassifier of TreeClassifierTrait { - fn predict(ref self: TreeNode, features: Span) -> FP16x16 { + fn predict(ref self: TreeClassifier, features: Span) -> FP16x16 { core::predict(ref self, features) } - fn predict_proba(ref self: TreeNode, features: Span) -> Span { + fn predict_proba(ref self: TreeClassifier, features: Span) -> Span { core::predict_proba(ref self, features) } } diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo index c10a0c82f..442fb100a 100644 --- a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp32x32.cairo @@ -1,13 +1,13 @@ -use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core::{TreeClassifier, TreeClassifierTrait}; use orion::operators::ml::tree_classifier::core; use orion::numbers::{FP32x32, FP32x32Impl}; impl FP32x32TreeClassifier of TreeClassifierTrait { - fn predict(ref self: TreeNode, features: Span) -> FP32x32 { + fn predict(ref self: TreeClassifier, features: Span) -> FP32x32 { core::predict(ref self, features) } - fn predict_proba(ref self: TreeNode, features: Span) -> Span { + fn predict_proba(ref self: TreeClassifier, features: Span) -> Span { core::predict_proba(ref self, features) } } diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo index ce3e6541a..61c9415ec 100644 --- a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp64x64.cairo @@ -1,13 +1,13 @@ -use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core::{TreeClassifier, TreeClassifierTrait}; use 
orion::operators::ml::tree_classifier::core; use orion::numbers::{FP64x64, FP64x64Impl}; impl FP64x64TreeClassifier of TreeClassifierTrait { - fn predict(ref self: TreeNode, features: Span) -> FP64x64 { + fn predict(ref self: TreeClassifier, features: Span) -> FP64x64 { core::predict(ref self, features) } - fn predict_proba(ref self: TreeNode, features: Span) -> Span { + fn predict_proba(ref self: TreeClassifier, features: Span) -> Span { core::predict_proba(ref self, features) } } diff --git a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo index 88eaf0fc4..01f548efe 100644 --- a/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo +++ b/src/operators/ml/tree_classifier/implementations/tree_classifier_fp8x23.cairo @@ -1,13 +1,13 @@ -use orion::operators::ml::tree_classifier::core::{TreeNode, TreeClassifierTrait}; +use orion::operators::ml::tree_classifier::core::{TreeClassifier, TreeClassifierTrait}; use orion::operators::ml::tree_classifier::core; use orion::numbers::FP8x23; impl FP8x23TreeClassifier of TreeClassifierTrait { - fn predict(ref self: TreeNode, features: Span) -> FP8x23 { + fn predict(ref self: TreeClassifier, features: Span) -> FP8x23 { core::predict(ref self, features) } - fn predict_proba(ref self: TreeNode, features: Span) -> Span { + fn predict_proba(ref self: TreeClassifier, features: Span) -> Span { core::predict_proba(ref self, features) } } diff --git a/src/operators/ml/tree_regressor/core.cairo b/src/operators/ml/tree_regressor/core.cairo index d81557dc2..ea4ac0271 100644 --- a/src/operators/ml/tree_regressor/core.cairo +++ b/src/operators/ml/tree_regressor/core.cairo @@ -3,9 +3,9 @@ use cubit::f64::procgen::rand::u64_between; use orion::numbers::{FixedTrait}; #[derive(Copy, Drop)] -struct TreeNode { - left: Option>>, - right: Option>>, +struct TreeRegressor { + left: Option>>, + right: Option>>, split_feature: usize, split_value: T, prediction: T, @@ -19,7 +19,7 @@ trait TreeRegressorTrait { /// # TreeRegressorTrait::fit /// /// ```rust - /// fn fit(data: Span>, target: Span, max_depth: usize, random_state: usize) -> TreeNode; + /// fn fit(data: Span>, target: Span, max_depth: usize, random_state: usize) -> TreeRegressor; /// ``` /// /// Builds a decision tree based on the provided data and target values up to a specified maximum depth. @@ -33,7 +33,7 @@ trait TreeRegressorTrait { /// /// ## Returns /// - /// A `TreeNode` representing the root of the constructed decision tree. + /// A `TreeRegressor` representing the root of the constructed decision tree. /// /// ## Type Constraints /// @@ -69,11 +69,11 @@ trait TreeRegressorTrait { /// fn fit( data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeNode; + ) -> TreeRegressor; /// # TreeRegressorTrait::predict /// /// ```rust - /// fn predict(ref self: TreeNode, features: Span) -> T; + /// fn predict(ref self: TreeRegressor, features: Span) -> T; /// ``` /// /// Predicts the target value for a set of features using the provided decision tree. 
@@ -124,7 +124,7 @@ trait TreeRegressorTrait { /// } /// ``` /// - fn predict(ref self: TreeNode, features: Span) -> T; + fn predict(ref self: TreeRegressor, features: Span) -> T; } fn predict< @@ -135,9 +135,9 @@ fn predict< impl FCopy: Copy, impl FDrop: Drop, >( - ref self: TreeNode, features: Span + ref self: TreeRegressor, features: Span ) -> T { - let mut current_node: TreeNode = self; + let mut current_node: TreeRegressor = self; loop { match current_node.left { @@ -363,7 +363,7 @@ fn fit< impl TDrop: Drop, >( data: Span>, target: Span, depth: usize, max_depth: usize, random_state: usize -) -> TreeNode { +) -> TreeRegressor { if depth == max_depth || data.len() < 2 { let mut total = FixedTrait::ZERO(); let mut target_copy = target; @@ -377,7 +377,7 @@ fn fit< } }; }; - return TreeNode { + return TreeRegressor { left: Option::None(()), right: Option::None(()), split_feature: 0, @@ -413,7 +413,7 @@ fn fit< }; }; - TreeNode { + TreeRegressor { left: Option::Some( BoxTrait::new( fit(left_data.span(), left_target.span(), depth + 1, max_depth, random_state) diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo index 3cb35ab13..7aeb6eb69 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo @@ -1,15 +1,15 @@ -use orion::operators::ml::tree_regressor::core::{TreeNode, TreeRegressorTrait}; +use orion::operators::ml::tree_regressor::core::{TreeRegressor, TreeRegressorTrait}; use orion::operators::ml::tree_regressor::core; use orion::numbers::FP16x16; impl FP16x16TreeRegressor of TreeRegressorTrait { fn fit( data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeNode { + ) -> TreeRegressor { core::fit(data, target, 0, max_depth, random_state) } - fn predict(ref self: TreeNode, features: Span) -> FP16x16 { + fn predict(ref self: TreeRegressor, features: Span) -> FP16x16 { core::predict(ref self, features) } } diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo index d1791a9c9..288d7e15d 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo @@ -1,15 +1,15 @@ -use orion::operators::ml::tree_regressor::core::{TreeNode, TreeRegressorTrait}; +use orion::operators::ml::tree_regressor::core::{TreeRegressor, TreeRegressorTrait}; use orion::operators::ml::tree_regressor::core; use orion::numbers::{FP32x32, FP32x32Impl}; impl FP32x32TreeRegressor of TreeRegressorTrait { fn fit( data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeNode { + ) -> TreeRegressor { core::fit(data, target, 0, max_depth, random_state) } - fn predict(ref self: TreeNode, features: Span) -> FP32x32 { + fn predict(ref self: TreeRegressor, features: Span) -> FP32x32 { core::predict(ref self, features) } } diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo index 54adb6ce4..9102428fc 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo @@ -1,15 +1,15 @@ -use 
orion::operators::ml::tree_regressor::core::{TreeNode, TreeRegressorTrait}; +use orion::operators::ml::tree_regressor::core::{TreeRegressor, TreeRegressorTrait}; use orion::operators::ml::tree_regressor::core; use orion::numbers::{FP64x64, FP64x64Impl}; impl FP64x64TreeRegressor of TreeRegressorTrait { fn fit( data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeNode { + ) -> TreeRegressor { core::fit(data, target, 0, max_depth, random_state) } - fn predict(ref self: TreeNode, features: Span) -> FP64x64 { + fn predict(ref self: TreeRegressor, features: Span) -> FP64x64 { core::predict(ref self, features) } } diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo index baf61096c..54c195704 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo @@ -1,15 +1,15 @@ -use orion::operators::ml::tree_regressor::core::{TreeNode, TreeRegressorTrait}; +use orion::operators::ml::tree_regressor::core::{TreeRegressor, TreeRegressorTrait}; use orion::operators::ml::tree_regressor::core; use orion::numbers::FP8x23; impl FP8x23TreeRegressor of TreeRegressorTrait { fn fit( data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeNode { + ) -> TreeRegressor { core::fit(data, target, 0, max_depth, random_state) } - fn predict(ref self: TreeNode, features: Span) -> FP8x23 { + fn predict(ref self: TreeRegressor, features: Span) -> FP8x23 { core::predict(ref self, features) } } diff --git a/src/operators/ml/xgboost_regressor/core.cairo b/src/operators/ml/xgboost_regressor/core.cairo index d1bc87e04..1fb84c8dc 100644 --- a/src/operators/ml/xgboost_regressor/core.cairo +++ b/src/operators/ml/xgboost_regressor/core.cairo @@ -1,9 +1,9 @@ -use orion::operators::ml::{TreeNode, TreeRegressorTrait}; +use orion::operators::ml::{TreeRegressor, TreeRegressorTrait}; use orion::numbers::FixedTrait; trait XGBoostPredictorTrait { - fn predict(trees: Span>, features: Span, weights: Span) -> T; + fn predict(trees: Span>, features: Span, weights: Span) -> T; } fn predict< @@ -16,7 +16,7 @@ fn predict< impl TCopy: Copy, impl TDrop: Drop, >( - ref trees: Span>, ref features: Span, ref weights: Span + ref trees: Span>, ref features: Span, ref weights: Span ) -> T { let mut sum_prediction: T = FixedTrait::ZERO(); From 3de69074da5a97f7b0f524c10e11098ca29db189 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 15:29:33 +0300 Subject: [PATCH 04/42] generate doc --- docgen/src/main.rs | 8 ++ docs/SUMMARY.md | 5 ++ .../tree-classifier/README.md | 23 ++++++ .../tree-classifier/tree.predict.md | 35 +++++++++ .../tree-classifier/tree.predict_proba.md | 38 +++++++++ .../tree-regressor/tree.fit.md | 4 +- .../tree-regressor/tree.predict.md | 4 +- src/operators/ml/tree_classifier/core.cairo | 77 ++++++++++++++++++- src/operators/ml/tree_regressor/core.cairo | 2 +- 9 files changed, 190 insertions(+), 6 deletions(-) create mode 100644 docs/framework/operators/machine-learning/tree-classifier/README.md create mode 100644 docs/framework/operators/machine-learning/tree-classifier/tree.predict.md create mode 100644 docs/framework/operators/machine-learning/tree-classifier/tree.predict_proba.md diff --git a/docgen/src/main.rs b/docgen/src/main.rs index e628cf980..b29e49d77 100644 --- a/docgen/src/main.rs +++ b/docgen/src/main.rs @@ -42,6 +42,14 @@ fn main() { let 
trait_name: &str = "TreeRegressorTrait"; doc_trait(trait_path, doc_path, label); doc_functions(trait_path, doc_path, trait_name, label); + + // TREE ClASSIFIER DOC + let trait_path = "src/operators/ml/tree_classifier/core.cairo"; + let doc_path = "docs/framework/operators/machine-learning/tree-classifier"; + let label = "tree"; + let trait_name: &str = "TreeClassifierTrait"; + doc_trait(trait_path, doc_path, label); + doc_functions(trait_path, doc_path, trait_name, label); } fn doc_trait(trait_path: &str, doc_path: &str, label: &str) { diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 8cdf6ddcb..520e39f9b 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -99,6 +99,11 @@ * [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md) * [tree.fit](framework/operators/machine-learning/tree-regressor/tree.fit.md) * [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md) + * [Tree Classifier](framework/operators/machine-learning/tree-classifier/README.md) + * [tree.predict](framework/operators/machine-learning/tree-classifier/tree.predict.md) + * [tree.predict_proba](framework/operators/machine-learning/tree-classifier/tree.predict_proba.md) + + ## 🏛 Hub diff --git a/docs/framework/operators/machine-learning/tree-classifier/README.md b/docs/framework/operators/machine-learning/tree-classifier/README.md new file mode 100644 index 000000000..8c371c996 --- /dev/null +++ b/docs/framework/operators/machine-learning/tree-classifier/README.md @@ -0,0 +1,23 @@ +# Tree Classifier + +`TreeClassifierTrait` provides a trait definition for decision tree classifier. This trait offers functionalities to build a decision tree and predict target values based on input features. + +```rust +use orion::operators::ml::TreeClassifierTrait; +``` + +### Data types + +Orion supports currently only fixed point data types for `TreeClassifierTrait`. + +| Data type | dtype | +| -------------------- | ------------------------------------------------------------- | +| Fixed point (signed) | `TreeClassifierTrait` | + +*** + +| function | description | +| --- | --- | +| [`tree.predict`](tree.predict.md) | Given a set of features, predicts the target value using the constructed decision tree. | +| [`tree.predict_proba`](tree.predict\_proba.md) | Predicts class probabilities based on feature data. | + diff --git a/docs/framework/operators/machine-learning/tree-classifier/tree.predict.md b/docs/framework/operators/machine-learning/tree-classifier/tree.predict.md new file mode 100644 index 000000000..efd46cddb --- /dev/null +++ b/docs/framework/operators/machine-learning/tree-classifier/tree.predict.md @@ -0,0 +1,35 @@ +# TreeClassifierTrait::predict + +```rust + fn predict(ref self: TreeClassifier, features: Span) -> T; +``` + +Predicts the target value for a set of features using the provided decision tree. + +## Args + +* `self`: A reference to the decision tree used for making the prediction. +* `features`: A span representing the features for which the prediction is to be made. + +## Returns + +The predicted target value. + +## Type Constraints + +Constrain input and output types to fixed point. 
+ +## Examples + +```rust +use orion::operators::ml::{FP16x16TreeClassifier, TreeClassifierTrait, TreeClassifier}; +use orion::numbers::{FP16x16, FixedTrait}; + +fn tree_classifier_example(tree: TreeClassifier) { + + tree.predict( + array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() + ); + +} +``` diff --git a/docs/framework/operators/machine-learning/tree-classifier/tree.predict_proba.md b/docs/framework/operators/machine-learning/tree-classifier/tree.predict_proba.md new file mode 100644 index 000000000..56afcd4e0 --- /dev/null +++ b/docs/framework/operators/machine-learning/tree-classifier/tree.predict_proba.md @@ -0,0 +1,38 @@ +# TreeClassifierTrait::predict_proba + +```rust + fn predict_proba(ref self: TreeClassifier, features: Span) -> Span; +``` + +Given a set of features, this method traverses the decision tree +represented by `self` and returns the class distribution (probabilities) +found in the leaf node that matches the provided features. The traversal +stops once a leaf node is reached in the decision tree. + +## Args + +* `self`: A reference to the decision tree used for making the prediction. +* `features`: A span representing the features for which the prediction is to be made. + +## Returns + +Returns a `Span` representing the class distribution at the leaf node. + +## Type Constraints + +Constrain input and output types to fixed points. + +## Examples + +```rust +use orion::operators::ml::{FP16x16TreeClassifier, TreeClassifierTrait, TreeClassifier}; +use orion::numbers::{FP16x16, FixedTrait}; + +fn tree_classifier_example(tree: TreeClassifier) { + + tree.predict_proba( + array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() + ); + +} +``` diff --git a/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md b/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md index e74929547..0ba61814d 100644 --- a/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md +++ b/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md @@ -1,7 +1,7 @@ # TreeRegressorTrait::fit ```rust - fn fit(data: Span>, target: Span, max_depth: usize, random_state: usize) -> TreeNode; + fn fit(data: Span>, target: Span, max_depth: usize, random_state: usize) -> TreeRegressor; ``` Builds a decision tree based on the provided data and target values up to a specified maximum depth. @@ -15,7 +15,7 @@ Builds a decision tree based on the provided data and target values up to a spec ## Returns -A `TreeNode` representing the root of the constructed decision tree. +A `TreeRegressor` representing the root of the constructed decision tree. ## Type Constraints diff --git a/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md b/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md index 28d4a027c..c76714d58 100644 --- a/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md +++ b/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md @@ -1,7 +1,7 @@ # TreeRegressorTrait::predict ```rust - fn predict(ref self: TreeNode, features: Span) -> T; + fn predict(ref self: TreeRegressor, features: Span) -> T; ``` Predicts the target value for a set of features using the provided decision tree. @@ -17,7 +17,7 @@ The predicted target value. ## Type Constraints -Constrain input and output types to fixed point tensors. +Constrain input and output types to fixed point. 
## Examples diff --git a/src/operators/ml/tree_classifier/core.cairo b/src/operators/ml/tree_classifier/core.cairo index 669c23abd..f4c948762 100644 --- a/src/operators/ml/tree_classifier/core.cairo +++ b/src/operators/ml/tree_classifier/core.cairo @@ -13,9 +13,84 @@ struct TreeClassifier { /// Trait /// /// predict - Given a set of features, predicts the target value using the constructed decision tree. -/// predict_proba - Given a set of features, predicts the probability of each X example being of a given class.. +/// predict_proba - Predicts class probabilities based on feature data. trait TreeClassifierTrait { + /// # TreeClassifierTrait::predict + /// + /// ```rust + /// fn predict(ref self: TreeClassifier, features: Span) -> T; + /// ``` + /// + /// Predicts the target value for a set of features using the provided decision tree. + /// + /// ## Args + /// + /// * `self`: A reference to the decision tree used for making the prediction. + /// * `features`: A span representing the features for which the prediction is to be made. + /// + /// ## Returns + /// + /// The predicted target value. + /// + /// ## Type Constraints + /// + /// Constrain input and output types to fixed point. + /// + /// ## Examples + /// + /// ```rust + /// use orion::operators::ml::{FP16x16TreeClassifier, TreeClassifierTrait, TreeClassifier}; + /// use orion::numbers::{FP16x16, FixedTrait}; + /// + /// fn tree_classifier_example(tree: TreeClassifier) { + /// + /// tree.predict( + /// array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() + /// ); + /// + /// } + /// ``` + /// fn predict(ref self: TreeClassifier, features: Span) -> T; + /// # TreeClassifierTrait::predict_proba + /// + /// ```rust + /// fn predict_proba(ref self: TreeClassifier, features: Span) -> Span; + /// ``` + /// + /// Given a set of features, this method traverses the decision tree + /// represented by `self` and returns the class distribution (probabilities) + /// found in the leaf node that matches the provided features. The traversal + /// stops once a leaf node is reached in the decision tree. + /// + /// ## Args + /// + /// * `self`: A reference to the decision tree used for making the prediction. + /// * `features`: A span representing the features for which the prediction is to be made. + /// + /// ## Returns + /// + /// Returns a `Span` representing the class distribution at the leaf node. + /// + /// ## Type Constraints + /// + /// Constrain input and output types to fixed points. + /// + /// ## Examples + /// + /// ```rust + /// use orion::operators::ml::{FP16x16TreeClassifier, TreeClassifierTrait, TreeClassifier}; + /// use orion::numbers::{FP16x16, FixedTrait}; + /// + /// fn tree_classifier_example(tree: TreeClassifier) { + /// + /// tree.predict_proba( + /// array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() + /// ); + /// + /// } + /// ``` + /// fn predict_proba(ref self: TreeClassifier, features: Span) -> Span; } diff --git a/src/operators/ml/tree_regressor/core.cairo b/src/operators/ml/tree_regressor/core.cairo index ea4ac0271..cb88de5b4 100644 --- a/src/operators/ml/tree_regressor/core.cairo +++ b/src/operators/ml/tree_regressor/core.cairo @@ -89,7 +89,7 @@ trait TreeRegressorTrait { /// /// ## Type Constraints /// - /// Constrain input and output types to fixed point tensors. + /// Constrain input and output types to fixed point. 
/// /// ## Examples /// From bcc412dda8acf30b6af26a1d939a606be19cf3ea Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 16:15:33 +0300 Subject: [PATCH 05/42] add xgboost implementations + docstrings --- src/operators/ml.cairo | 6 +++ src/operators/ml/xgboost_regressor.cairo | 3 +- src/operators/ml/xgboost_regressor/core.cairo | 53 +++++++++++++++++-- .../xgboost_regressor/implementations.cairo | 4 ++ .../xgboost_regressor_fp16x16.cairo | 12 +++++ .../xgboost_regressor_fp32x32.cairo | 12 +++++ .../xgboost_regressor_fp64x64.cairo | 12 +++++ .../xgboost_regressor_fp8x23.cairo | 12 +++++ 8 files changed, 110 insertions(+), 4 deletions(-) create mode 100644 src/operators/ml/xgboost_regressor/implementations.cairo create mode 100644 src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo create mode 100644 src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo create mode 100644 src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo create mode 100644 src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index 143cba273..47c6cfa32 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -6,3 +6,9 @@ use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp16x1 use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp8x23::FP8x23TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp32x32::FP32x32TreeRegressor; use orion::operators::ml::tree_regressor::implementations::tree_regressor_fp64x64::FP64x64TreeRegressor; + +use orion::operators::ml::xgboost_regressor::core::{XGBoostRegressorTrait}; +use orion::operators::ml::xgboost_regressor::implementations::xgboost_regressor_fp16x16::FP16x16XGBoostRegressor; +use orion::operators::ml::xgboost_regressor::implementations::xgboost_regressor_fp8x23::FP8x23XGBoostRegressor; +use orion::operators::ml::xgboost_regressor::implementations::xgboost_regressor_fp32x32::FP32x32XGBoostRegressor; +use orion::operators::ml::xgboost_regressor::implementations::xgboost_regressor_fp64x64::FP64x64XGBoostRegressor; diff --git a/src/operators/ml/xgboost_regressor.cairo b/src/operators/ml/xgboost_regressor.cairo index ef33ab296..2ab1a62ac 100644 --- a/src/operators/ml/xgboost_regressor.cairo +++ b/src/operators/ml/xgboost_regressor.cairo @@ -1 +1,2 @@ -mod core; \ No newline at end of file +mod core; +mod implementations; \ No newline at end of file diff --git a/src/operators/ml/xgboost_regressor/core.cairo b/src/operators/ml/xgboost_regressor/core.cairo index d1bc87e04..e0534d10b 100644 --- a/src/operators/ml/xgboost_regressor/core.cairo +++ b/src/operators/ml/xgboost_regressor/core.cairo @@ -1,9 +1,56 @@ use orion::operators::ml::{TreeNode, TreeRegressorTrait}; use orion::numbers::FixedTrait; - -trait XGBoostPredictorTrait { - fn predict(trees: Span>, features: Span, weights: Span) -> T; +/// Trait +/// +/// predict - Predicts the target value for a set of features using the provided ensemble of decision trees. +trait XGBoostRegressorTrait { + /// # XGBoostRegressorTrait::predict + /// + /// ```rust + /// fn predict(ref self: Span>, ref features: Span, ref weights: Span) -> T; + /// ``` + /// + /// Predicts the target value for a set of features using the provided ensemble of decision trees + /// and combining their results using given weights. 
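+ ///
+ /// Concretely, for trees `t_1 .. t_n` and weights `w_1 .. w_n`, the returned value is
+ /// the weighted sum `w_1 * t_1(features) + ... + w_n * t_n(features)` (an informal
+ /// sketch of the accumulation performed by `core::predict`).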
+ ///
+ /// ## Args
+ ///
+ /// * `self`: A reference to a span representing an ensemble of decision trees.
+ /// * `features`: A reference to a span representing the features for which the prediction is to be made.
+ /// * `weights`: A reference to a span representing the weights applied to the predictions from each tree.
+ ///
+ /// ## Returns
+ ///
+ /// The predicted target value.
+ ///
+ /// ## Type Constraints
+ ///
+ /// Constrain input and output types to fixed point.
+ ///
+ /// ## Examples
+ ///
+ /// ```rust
+ /// use orion::operators::ml::{FP16x16XGBoostRegressor, TreeRegressorTrait, TreeRegressor};
+ /// use orion::numbers::{FP16x16, FixedTrait};
+ ///
+ /// fn xgboost_regressor_example(trees: Span<TreeNode<FP16x16>>) {
+ ///
+ ///     let mut features = array![
+ ///         FixedTrait::new_unscaled(1, false),
+ ///         FixedTrait::new_unscaled(2, false),
+ ///     ].span();
+ ///
+ ///     let mut weights = array![
+ ///         FixedTrait::new(32768, false), // 0.5 in FP16x16; `new_unscaled` takes integer magnitudes only
+ ///         FixedTrait::new(32768, false)  // 0.5 in FP16x16
+ ///     ].span();
+ ///
+ ///     FP16x16XGBoostRegressor::predict(ref trees, ref features, ref weights);
+ /// }
+ /// ```
+ ///
+    fn predict(ref self: Span<TreeNode<T>>, ref features: Span<T>, ref weights: Span<T>) -> T;
 }
 
 fn predict<
diff --git a/src/operators/ml/xgboost_regressor/implementations.cairo b/src/operators/ml/xgboost_regressor/implementations.cairo
new file mode 100644
index 000000000..cd493cf91
--- /dev/null
+++ b/src/operators/ml/xgboost_regressor/implementations.cairo
@@ -0,0 +1,4 @@
+mod xgboost_regressor_fp8x23;
+mod xgboost_regressor_fp16x16;
+mod xgboost_regressor_fp32x32;
+mod xgboost_regressor_fp64x64;
diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo
new file mode 100644
index 000000000..41661711b
--- /dev/null
+++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo
@@ -0,0 +1,12 @@
+use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait};
+use orion::operators::ml::xgboost_regressor::core;
+use orion::operators::ml::FP16x16TreeRegressor;
+use orion::numbers::FP16x16;
+
+impl FP16x16XGBoostRegressor of XGBoostRegressorTrait<FP16x16> {
+    fn predict(
+        ref self: Span<TreeNode<FP16x16>>, ref features: Span<FP16x16>, ref weights: Span<FP16x16>
+    ) -> FP16x16 {
+        core::predict(ref self, ref features, ref weights)
+    }
+}
diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo
new file mode 100644
index 000000000..83eca88ca
--- /dev/null
+++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo
@@ -0,0 +1,12 @@
+use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait};
+use orion::operators::ml::xgboost_regressor::core;
+use orion::operators::ml::FP32x32TreeRegressor;
+use orion::numbers::{FP32x32, FP32x32Impl};
+
+impl FP32x32XGBoostRegressor of XGBoostRegressorTrait<FP32x32> {
+    fn predict(
+        ref self: Span<TreeNode<FP32x32>>, ref features: Span<FP32x32>, ref weights: Span<FP32x32>
+    ) -> FP32x32 {
+        core::predict(ref self, ref features, ref weights)
+    }
+}
diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo
new file mode 100644
index 000000000..21c967976
--- /dev/null
+++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo
@@ -0,0 +1,12 @@
+use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait};
+use orion::operators::ml::xgboost_regressor::core;
+use orion::operators::ml::FP64x64TreeRegressor;
+use orion::numbers::{FP64x64, FP64x64Impl};
+
+impl FP64x64XGBoostRegressor of XGBoostRegressorTrait<FP64x64> {
+    fn predict(
+        ref self: Span<TreeNode<FP64x64>>, ref features: Span<FP64x64>, ref weights: Span<FP64x64>
+    ) -> FP64x64 {
+        core::predict(ref self, ref features, ref weights)
+    }
+}
diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo
new file mode 100644
index 000000000..a011233bd
--- /dev/null
+++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo
@@ -0,0 +1,12 @@
+use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait};
+use orion::operators::ml::xgboost_regressor::core;
+use orion::operators::ml::FP8x23TreeRegressor;
+use orion::numbers::FP8x23;
+
+impl FP8x23XGBoostRegressor of XGBoostRegressorTrait<FP8x23> {
+    fn predict(
+        ref self: Span<TreeNode<FP8x23>>, ref features: Span<FP8x23>, ref weights: Span<FP8x23>
+    ) -> FP8x23 {
+        core::predict(ref self, ref features, ref weights)
+    }
+}

From 5dfd5c573448d513f9e8585d2ed15ac4234a06ee Mon Sep 17 00:00:00 2001
From: raphaelDkhn
Date: Mon, 16 Oct 2023 16:37:46 +0300
Subject: [PATCH 06/42] generate doc

---
 docgen/src/main.rs                            |  8 ++++
 docs/SUMMARY.md                               |  2 +
 .../xgboost-regressor/README.md               | 22 ++++++++++
 .../xgboost-regressor/xgboost.predict.md      | 44 +++++++++++++++++++
 4 files changed, 76 insertions(+)
 create mode 100644 docs/framework/operators/machine-learning/xgboost-regressor/README.md
 create mode 100644 docs/framework/operators/machine-learning/xgboost-regressor/xgboost.predict.md

diff --git a/docgen/src/main.rs b/docgen/src/main.rs
index e628cf980..15b962082 100644
--- a/docgen/src/main.rs
+++ b/docgen/src/main.rs
@@ -42,6 +42,14 @@ fn main() {
     let trait_name: &str = "TreeRegressorTrait";
     doc_trait(trait_path, doc_path, label);
     doc_functions(trait_path, doc_path, trait_name, label);
+
+    // XGBOOST REGRESSOR DOC
+    let trait_path = "src/operators/ml/xgboost_regressor/core.cairo";
+    let doc_path = "docs/framework/operators/machine-learning/xgboost-regressor";
+    let label = "xgboost";
+    let trait_name: &str = "XGBoostRegressorTrait";
+    doc_trait(trait_path, doc_path, label);
+    doc_functions(trait_path, doc_path, trait_name, label);
 }
 
 fn doc_trait(trait_path: &str, doc_path: &str, label: &str) {
diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md
index 8cdf6ddcb..473871fe3 100644
--- a/docs/SUMMARY.md
+++ b/docs/SUMMARY.md
@@ -99,6 +99,8 @@
 * [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md)
   * [tree.fit](framework/operators/machine-learning/tree-regressor/tree.fit.md)
   * [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md)
+* [XGBoost Regressor](framework/operators/machine-learning/xgboost-regressor/README.md)
+  * [xgboost.predict](framework/operators/machine-learning/xgboost-regressor/xgboost.predict.md)
 
 ## 🏛 Hub
 
diff --git a/docs/framework/operators/machine-learning/xgboost-regressor/README.md b/docs/framework/operators/machine-learning/xgboost-regressor/README.md
new file mode 100644
index 000000000..1187916e8
--- /dev/null
+++ b/docs/framework/operators/machine-learning/xgboost-regressor/README.md
@@ -0,0 +1,22 @@
+# XGBoost Regressor
+
+`XGBoostRegressorTrait` provides a trait definition for XGBoost regression. This trait offers functionalities to predict target values based on input features.
+
+```rust
+use orion::operators::ml::XGBoostRegressorTrait;
+```
+
+### Data types
+
+Orion currently supports only fixed point data types for `XGBoostRegressorTrait`.
+
+| Data type            | dtype                   |
+| -------------------- | ----------------------- |
+| Fixed point (signed) | `XGBoostRegressorTrait` |
+
+***
+
+| function | description |
+| --- | --- |
+| [`xgboost.predict`](xgboost.predict.md) | Predicts the target value for a set of features using the provided ensemble of decision trees. |
+
diff --git a/docs/framework/operators/machine-learning/xgboost-regressor/xgboost.predict.md b/docs/framework/operators/machine-learning/xgboost-regressor/xgboost.predict.md
new file mode 100644
index 000000000..ed7d7a31d
--- /dev/null
+++ b/docs/framework/operators/machine-learning/xgboost-regressor/xgboost.predict.md
@@ -0,0 +1,44 @@
+# XGBoostRegressorTrait::predict
+
+```rust
+ fn predict(ref self: Span<TreeNode<T>>, ref features: Span<T>, ref weights: Span<T>) -> T;
+```
+
+Predicts the target value for a set of features using the provided ensemble of decision trees
+and combining their results using given weights.
+
+## Args
+
+* `self`: A reference to a span representing an ensemble of decision trees.
+* `features`: A reference to a span representing the features for which the prediction is to be made.
+* `weights`: A reference to a span representing the weights applied to the predictions from each tree.
+
+## Returns
+
+The predicted target value.
+
+## Type Constraints
+
+Constrain input and output types to fixed point.
+
+## Examples
+
+```rust
+use orion::operators::ml::{FP16x16XGBoostRegressor, TreeRegressorTrait, TreeRegressor};
+use orion::numbers::{FP16x16, FixedTrait};
+
+fn xgboost_regressor_example(trees: Span<TreeNode<FP16x16>>) {
+
+    let mut features = array![
+        FixedTrait::new_unscaled(1, false),
+        FixedTrait::new_unscaled(2, false),
+    ].span();
+
+    let mut weights = array![
+        FixedTrait::new(32768, false), // 0.5 in FP16x16
+        FixedTrait::new(32768, false)  // 0.5 in FP16x16
+    ].span();
+
+    FP16x16XGBoostRegressor::predict(ref trees, ref features, ref weights);
+}
+```

From cb32695371cd5f6f7880907502f3c8da29adff1f Mon Sep 17 00:00:00 2001
From: raphaelDkhn
Date: Mon, 16 Oct 2023 17:02:08 +0300
Subject: [PATCH 07/42] rename trees

---
 .../xgboost_regressor_fp16x16.cairo | 4 ++--
 .../xgboost_regressor_fp32x32.cairo | 4 ++--
 .../xgboost_regressor_fp64x64.cairo | 4 ++--
 .../xgboost_regressor_fp8x23.cairo  | 4 ++--
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo
index 41661711b..e8202a8d1 100644
--- a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo
+++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp16x16.cairo
@@ -1,11 +1,11 @@
-use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait};
+use orion::operators::ml::xgboost_regressor::core::{TreeRegressor, XGBoostRegressorTrait};
 use orion::operators::ml::xgboost_regressor::core;
 use orion::operators::ml::FP16x16TreeRegressor;
 use orion::numbers::FP16x16;
 
 impl FP16x16XGBoostRegressor of XGBoostRegressorTrait<FP16x16> {
     fn predict(
-        ref self: Span<TreeNode<FP16x16>>, ref features: Span<FP16x16>, ref weights: Span<FP16x16>
+        ref self: Span<TreeRegressor<FP16x16>>, ref features: Span<FP16x16>, ref weights: Span<FP16x16>
     ) -> FP16x16 {
         core::predict(ref self, ref features, ref weights)
     }
diff --git 
a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo index 83eca88ca..6d266fce4 100644 --- a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo +++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp32x32.cairo @@ -1,11 +1,11 @@ -use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait}; +use orion::operators::ml::xgboost_regressor::core::{TreeRegressor, XGBoostRegressorTrait}; use orion::operators::ml::xgboost_regressor::core; use orion::operators::ml::FP32x32TreeRegressor; use orion::numbers::{FP32x32, FP32x32Impl}; impl FP32x32XGBoostRegressor of XGBoostRegressorTrait { fn predict( - ref self: Span>, ref features: Span, ref weights: Span + ref self: Span>, ref features: Span, ref weights: Span ) -> FP32x32 { core::predict(ref self, ref features, ref weights) } diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo index 21c967976..ff21c9860 100644 --- a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo +++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp64x64.cairo @@ -1,11 +1,11 @@ -use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait}; +use orion::operators::ml::xgboost_regressor::core::{TreeRegressor, XGBoostRegressorTrait}; use orion::operators::ml::xgboost_regressor::core; use orion::operators::ml::FP64x64TreeRegressor; use orion::numbers::{FP64x64, FP64x64Impl}; impl FP64x64XGBoostRegressor of XGBoostRegressorTrait { fn predict( - ref self: Span>, ref features: Span, ref weights: Span + ref self: Span>, ref features: Span, ref weights: Span ) -> FP64x64 { core::predict(ref self, ref features, ref weights) } diff --git a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo index a011233bd..ac2f1d3b5 100644 --- a/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo +++ b/src/operators/ml/xgboost_regressor/implementations/xgboost_regressor_fp8x23.cairo @@ -1,11 +1,11 @@ -use orion::operators::ml::xgboost_regressor::core::{TreeNode, XGBoostRegressorTrait}; +use orion::operators::ml::xgboost_regressor::core::{TreeRegressor, XGBoostRegressorTrait}; use orion::operators::ml::xgboost_regressor::core; use orion::operators::ml::FP8x23TreeRegressor; use orion::numbers::FP8x23; impl FP8x23XGBoostRegressor of XGBoostRegressorTrait { fn predict( - ref self: Span>, ref features: Span, ref weights: Span + ref self: Span>, ref features: Span, ref weights: Span ) -> FP8x23 { core::predict(ref self, ref features, ref weights) } From cb32695371cd5f6f7880907502f3c8da29adff1f Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 17:15:15 +0300 Subject: [PATCH 08/42] remove fit from TreeRegressor --- .../machine-learning/tree-regressor/README.md | 1 - .../tree-regressor/tree.fit.md | 50 --- .../tree-regressor/tree.predict.md | 26 +- src/operators/ml/tree_regressor/core.cairo | 347 +----------------- .../tree_regressor_fp16x16.cairo | 6 - .../tree_regressor_fp32x32.cairo | 6 - .../tree_regressor_fp64x64.cairo | 6 - .../tree_regressor_fp8x23.cairo | 6 - tests/src/ml/tree_regressor.cairo | 72 +--- 9 files changed, 9 insertions(+), 511 deletions(-) 
delete mode 100644 docs/framework/operators/machine-learning/tree-regressor/tree.fit.md diff --git a/docs/framework/operators/machine-learning/tree-regressor/README.md b/docs/framework/operators/machine-learning/tree-regressor/README.md index 7df2112c4..286587884 100644 --- a/docs/framework/operators/machine-learning/tree-regressor/README.md +++ b/docs/framework/operators/machine-learning/tree-regressor/README.md @@ -18,6 +18,5 @@ Orion supports currently only fixed point data types for `TreeRegressorTrait`. | function | description | | --- | --- | -| [`tree.fit`](tree.fit.md) | Constructs a decision tree regressor based on the provided data and target values. | | [`tree.predict`](tree.predict.md) | Given a set of features, predicts the target value using the constructed decision tree. | diff --git a/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md b/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md deleted file mode 100644 index 0ba61814d..000000000 --- a/docs/framework/operators/machine-learning/tree-regressor/tree.fit.md +++ /dev/null @@ -1,50 +0,0 @@ -# TreeRegressorTrait::fit - -```rust - fn fit(data: Span>, target: Span, max_depth: usize, random_state: usize) -> TreeRegressor; -``` - -Builds a decision tree based on the provided data and target values up to a specified maximum depth. - -## Args - -* `data`: A span of spans representing rows of features in the dataset. -* `target`: A span representing the target values corresponding to each row in the dataset. -* `max_depth`: The maximum depth of the decision tree. The tree stops growing once this depth is reached. -* `random_state`: It ensures that the tie-breaking is consistent across multiple runs, leading to reproducible results. - -## Returns - -A `TreeRegressor` representing the root of the constructed decision tree. - -## Type Constraints - -Constrain input and output types to fixed point tensors. - -## Examples - -```rust -use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait}; -use orion::numbers::{FP16x16, FixedTrait}; - -fn tree_regressor_example() { - - let data = array![ - array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(), - array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(), - array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(), - array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(), - ] - .span(); - - let target = array![ - FixedTrait::new_unscaled(2, false), - FixedTrait::new_unscaled(4, false), - FixedTrait::new_unscaled(6, false), - FixedTrait::new_unscaled(8, false) - ] - .span(); - - TreeRegressorTrait::fit(data, target, 3, 42); -} -``` diff --git a/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md b/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md index c76714d58..6281af625 100644 --- a/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md +++ b/docs/framework/operators/machine-learning/tree-regressor/tree.predict.md @@ -22,32 +22,14 @@ Constrain input and output types to fixed point. 
## Examples ```rust -use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait}; +use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait, TreeRegressor}; use orion::numbers::{FP16x16, FixedTrait}; -fn tree_regressor_example() { +fn tree_regressor_example(tree: TreeRegressor) { - let data = array![ - array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(), - array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(), - array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(), - array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(), - ] - .span(); - - let target = array![ - FixedTrait::new_unscaled(2, false), - FixedTrait::new_unscaled(4, false), - FixedTrait::new_unscaled(6, false), - FixedTrait::new_unscaled(8, false) - ] - .span(); - - let mut tree = TreeRegressorTrait::fit(data, target, 3); - - let prediction_1 = tree - .predict( + tree.predict( array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() ); + } ``` diff --git a/src/operators/ml/tree_regressor/core.cairo b/src/operators/ml/tree_regressor/core.cairo index cb88de5b4..1206d9094 100644 --- a/src/operators/ml/tree_regressor/core.cairo +++ b/src/operators/ml/tree_regressor/core.cairo @@ -13,63 +13,8 @@ struct TreeRegressor { /// Trait /// -/// fit - Constructs a decision tree regressor based on the provided data and target values. /// predict - Given a set of features, predicts the target value using the constructed decision tree. trait TreeRegressorTrait { - /// # TreeRegressorTrait::fit - /// - /// ```rust - /// fn fit(data: Span>, target: Span, max_depth: usize, random_state: usize) -> TreeRegressor; - /// ``` - /// - /// Builds a decision tree based on the provided data and target values up to a specified maximum depth. - /// - /// ## Args - /// - /// * `data`: A span of spans representing rows of features in the dataset. - /// * `target`: A span representing the target values corresponding to each row in the dataset. - /// * `max_depth`: The maximum depth of the decision tree. The tree stops growing once this depth is reached. - /// * `random_state`: It ensures that the tie-breaking is consistent across multiple runs, leading to reproducible results. - /// - /// ## Returns - /// - /// A `TreeRegressor` representing the root of the constructed decision tree. - /// - /// ## Type Constraints - /// - /// Constrain input and output types to fixed point tensors. 
- /// - /// ## Examples - /// - /// ```rust - /// use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait}; - /// use orion::numbers::{FP16x16, FixedTrait}; - /// - /// fn tree_regressor_example() { - /// - /// let data = array![ - /// array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(), - /// array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(), - /// array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(), - /// array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(), - /// ] - /// .span(); - /// - /// let target = array![ - /// FixedTrait::new_unscaled(2, false), - /// FixedTrait::new_unscaled(4, false), - /// FixedTrait::new_unscaled(6, false), - /// FixedTrait::new_unscaled(8, false) - /// ] - /// .span(); - /// - /// TreeRegressorTrait::fit(data, target, 3, 42); - /// } - /// ``` - /// - fn fit( - data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeRegressor; /// # TreeRegressorTrait::predict /// /// ```rust @@ -94,33 +39,15 @@ trait TreeRegressorTrait { /// ## Examples /// /// ```rust - /// use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait}; + /// use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait, TreeRegressor}; /// use orion::numbers::{FP16x16, FixedTrait}; /// - /// fn tree_regressor_example() { + /// fn tree_regressor_example(tree: TreeRegressor) { /// - /// let data = array![ - /// array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(), - /// array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(), - /// array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(), - /// array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(), - /// ] - /// .span(); - /// - /// let target = array![ - /// FixedTrait::new_unscaled(2, false), - /// FixedTrait::new_unscaled(4, false), - /// FixedTrait::new_unscaled(6, false), - /// FixedTrait::new_unscaled(8, false) - /// ] - /// .span(); - /// - /// let mut tree = TreeRegressorTrait::fit(data, target, 3); - /// - /// let prediction_1 = tree - /// .predict( + /// tree.predict( /// array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() /// ); + /// /// } /// ``` /// @@ -163,269 +90,3 @@ fn predict< current_node.prediction } - -fn mse< - T, - MAG, - impl FFixedTrait: FixedTrait, - impl TSub: Sub, - impl TAddEq: AddEq, - impl TDiv: Div, - impl U32IntoMAG: Into, - impl FeltTryIntoMAG: TryInto, - impl TCopy: Copy, - impl TDrop: Drop, ->( - y: Span, prediction: T -) -> T { - let mut sum_squared_error: T = FixedTrait::ZERO(); - - let mut y_copy = y; - loop { - match y_copy.pop_front() { - Option::Some(yi) => { - let error = *yi - prediction; - sum_squared_error += error - .pow(FixedTrait::new_unscaled(2.try_into().unwrap(), false)); - }, - Option::None(_) => { - break; - } - }; - }; - - sum_squared_error / FixedTrait::new_unscaled(y.len().into(), false) -} - -fn best_split< - T, - MAG, - impl FFixedTrait: FixedTrait, - impl TPartialOrd: PartialOrd, - impl TPartialEq: PartialEq, - impl TAddEq: AddEq, - impl TAdd: Add, - impl TSub: Sub, - impl TDiv: Div, - impl TMul: Mul, - impl U32IntoMAG: Into, - impl FeltTryIntoMAG: TryInto, - impl TCopy: Copy, - impl TDrop: Drop, ->( - data: Span>, target: Span, random_state: usize -) -> (usize, T, T) { - let mut best_mse = FixedTrait::MAX(); - let mut best_split_feature = 0; - 
let mut best_splits: Array<(usize, T, T)> = ArrayTrait::new(); - - let n_features: u32 = (*data[0]).len(); - - let mut feature = 0; - loop { - if feature == n_features { - break; - }; - - let mut unique_values = ArrayTrait::new(); - let mut data_copy = data; - loop { - match data_copy.pop_front() { - Option::Some(row) => { - unique_values.append(*row[feature]) - }, - Option::None(_) => { - break; - } - }; - }; - - let mut unique_values = unique_values.span(); - loop { - match unique_values.pop_front() { - Option::Some(value) => { - let mut left_target = ArrayTrait::new(); - let mut right_target = ArrayTrait::new(); - - let mut i = 0; - let mut target_copy = target; - loop { - match target_copy.pop_front() { - Option::Some(t) => { - if *(*data.at(i))[feature] < *value { - left_target.append(*t); - } else { - right_target.append(*t); - } - i += 1; - }, - Option::None(_) => { - break; - } - }; - }; - - if !left_target.is_empty() && !right_target.is_empty() { - let mut left_sum = FixedTrait::ZERO(); - let mut left_target_copy = left_target.span(); - loop { - match left_target_copy.pop_front() { - Option::Some(val) => { - left_sum += *val; - }, - Option::None(_) => { - break; - } - }; - }; - let left_target_as_fp: T = FixedTrait::new_unscaled( - left_target.len().into(), false - ); - let left_pred = left_sum / left_target_as_fp; - - let mut right_sum = FixedTrait::ZERO(); - let mut right_target_copy = right_target.span(); - loop { - match right_target_copy.pop_front() { - Option::Some(val) => { - right_sum += *val; - }, - Option::None(_) => { - break; - } - }; - }; - let right_target_as_fp: T = FixedTrait::new_unscaled( - right_target.len().into(), false - ); - let right_pred = right_sum / right_target_as_fp; - - let current_mse = (left_target_as_fp * mse(left_target.span(), left_pred)) - + (right_target_as_fp * mse(right_target.span(), right_pred)); - - if !(current_mse > best_mse) { - if current_mse < best_mse { - best_mse = current_mse; - best_splits = array![]; - } - - let mut total_sum = FixedTrait::ZERO(); - let mut target_copy = target; - loop { - match target_copy.pop_front() { - Option::Some(t) => { - total_sum += *t; - }, - Option::None(_) => { - break; - } - }; - }; - - let prediction = total_sum - / FixedTrait::new_unscaled(target.len().into(), false); - - best_splits.append((feature, *value, prediction)); - } - } - }, - Option::None(_) => { - break; - } - }; - }; - - feature += 1; - }; - - let random_idx: usize = u64_between(random_state.into(), 0, best_splits.len().into()) - .try_into() - .unwrap(); - let (best_split_feature, best_split_value, best_prediction) = *best_splits.at(random_idx); - - (best_split_feature, best_split_value, best_prediction) -} - -fn fit< - T, - MAG, - impl FFixedTrait: FixedTrait, - impl TPartialOrd: PartialOrd, - impl TPartialEq: PartialEq, - impl TAddEq: AddEq, - impl TAdd: Add, - impl TSub: Sub, - impl TDiv: Div, - impl TMul: Mul, - impl U32IntoMAG: Into, - impl FeltTryIntoMAG: TryInto, - impl TCopy: Copy, - impl TDrop: Drop, ->( - data: Span>, target: Span, depth: usize, max_depth: usize, random_state: usize -) -> TreeRegressor { - if depth == max_depth || data.len() < 2 { - let mut total = FixedTrait::ZERO(); - let mut target_copy = target; - loop { - match target_copy.pop_front() { - Option::Some(val) => { - total += *val; - }, - Option::None(_) => { - break; - } - }; - }; - return TreeRegressor { - left: Option::None(()), - right: Option::None(()), - split_feature: 0, - split_value: FixedTrait::ZERO(), - prediction: total / 
FixedTrait::new_unscaled(target.len().into(), false), - }; - } - - let (split_feature, split_value, prediction) = best_split(data, target, random_state); - let mut left_data = ArrayTrait::new(); - let mut left_target = ArrayTrait::new(); - - let mut right_data = ArrayTrait::new(); - let mut right_target = ArrayTrait::new(); - - let mut data_copy = data; - let mut i: usize = 0; - loop { - match data_copy.pop_front() { - Option::Some(row) => { - if *(*row).at(split_feature) < split_value { - left_data.append(row.clone()); - left_target.append(*target[i]) - } else { - right_data.append(row.clone()); - right_target.append(*target[i]) - } - i += 1 - }, - Option::None(_) => { - break; - } - }; - }; - - TreeRegressor { - left: Option::Some( - BoxTrait::new( - fit(left_data.span(), left_target.span(), depth + 1, max_depth, random_state) - ) - ), - right: Option::Some( - BoxTrait::new( - fit(right_data.span(), right_target.span(), depth + 1, max_depth, random_state) - ) - ), - split_feature, - split_value, - prediction, - } -} diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo index 7aeb6eb69..10fa1aa53 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp16x16.cairo @@ -3,12 +3,6 @@ use orion::operators::ml::tree_regressor::core; use orion::numbers::FP16x16; impl FP16x16TreeRegressor of TreeRegressorTrait { - fn fit( - data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeRegressor { - core::fit(data, target, 0, max_depth, random_state) - } - fn predict(ref self: TreeRegressor, features: Span) -> FP16x16 { core::predict(ref self, features) } diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo index 288d7e15d..72c5033c2 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp32x32.cairo @@ -3,12 +3,6 @@ use orion::operators::ml::tree_regressor::core; use orion::numbers::{FP32x32, FP32x32Impl}; impl FP32x32TreeRegressor of TreeRegressorTrait { - fn fit( - data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeRegressor { - core::fit(data, target, 0, max_depth, random_state) - } - fn predict(ref self: TreeRegressor, features: Span) -> FP32x32 { core::predict(ref self, features) } diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo index 9102428fc..4450f630c 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp64x64.cairo @@ -3,12 +3,6 @@ use orion::operators::ml::tree_regressor::core; use orion::numbers::{FP64x64, FP64x64Impl}; impl FP64x64TreeRegressor of TreeRegressorTrait { - fn fit( - data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeRegressor { - core::fit(data, target, 0, max_depth, random_state) - } - fn predict(ref self: TreeRegressor, features: Span) -> FP64x64 { core::predict(ref self, features) } diff --git a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo 
index 54c195704..f6b1361be 100644 --- a/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo +++ b/src/operators/ml/tree_regressor/implementations/tree_regressor_fp8x23.cairo @@ -3,12 +3,6 @@ use orion::operators::ml::tree_regressor::core; use orion::numbers::FP8x23; impl FP8x23TreeRegressor of TreeRegressorTrait { - fn fit( - data: Span>, target: Span, max_depth: usize, random_state: usize - ) -> TreeRegressor { - core::fit(data, target, 0, max_depth, random_state) - } - fn predict(ref self: TreeRegressor, features: Span) -> FP8x23 { core::predict(ref self, features) } diff --git a/tests/src/ml/tree_regressor.cairo b/tests/src/ml/tree_regressor.cairo index 057fa58b0..98a9a3265 100644 --- a/tests/src/ml/tree_regressor.cairo +++ b/tests/src/ml/tree_regressor.cairo @@ -1,71 +1 @@ -use orion::operators::ml::{FP16x16TreeRegressor, TreeRegressorTrait}; -use orion::operators::ml::tree_regressor::core::mse; -use orion::numbers::{FP16x16, FixedTrait}; - -#[test] -#[available_gas(2000000000000)] -fn test_mse() { - let mut y = array![ - FixedTrait::new_unscaled(2, false), - FixedTrait::new_unscaled(4, false), - FixedTrait::new_unscaled(6, false), - FixedTrait::new_unscaled(8, false) - ] - .span(); - - let prediction = FixedTrait::::new_unscaled(5, false); - let expected_mse = FixedTrait::::new_unscaled( - 5, false - ); // MSE = [(2-5)^2 + (4-5)^2 + (6-5)^2 + (8-5)^2] / 4 = 5 - - let computed_mse = mse(y, prediction); - assert(computed_mse == expected_mse, 'Failed mse'); -} - - -#[test] -#[available_gas(2000000000000)] -fn test_tree() { - let data = array![ - array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false)].span(), - array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span(), - array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span(), - array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span(), - ] - .span(); - - let target = array![ - FixedTrait::new_unscaled(2, false), - FixedTrait::new_unscaled(4, false), - FixedTrait::new_unscaled(6, false), - FixedTrait::new_unscaled(8, false) - ] - .span(); - - let mut tree = TreeRegressorTrait::fit(data, target, 3, 42); - - let prediction_1 = tree - .predict( - array![FixedTrait::new_unscaled(1, false), FixedTrait::new_unscaled(2, false),].span() - ); - - let prediction_2 = tree - .predict( - array![FixedTrait::new_unscaled(3, false), FixedTrait::new_unscaled(4, false)].span() - ); - - let prediction_3 = tree - .predict( - array![FixedTrait::new_unscaled(5, false), FixedTrait::new_unscaled(6, false)].span() - ); - - let prediction_4 = tree - .predict( - array![FixedTrait::new_unscaled(7, false), FixedTrait::new_unscaled(8, false)].span() - ); - - assert(prediction_1 == FixedTrait::::new_unscaled(2, false), 'should predict 2'); - assert(prediction_2 == FixedTrait::::new_unscaled(4, false), 'should predict 4'); - assert(prediction_3 == FixedTrait::::new_unscaled(6, false), 'should predict 6'); - assert(prediction_4 == FixedTrait::::new_unscaled(8, false), 'should predict 8'); -} +// TODO: make test once Tree transpilation implemented \ No newline at end of file From 7ab41bcf630ca0695c6d96f35cec2b8e8aeb3d8c Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 16 Oct 2023 18:29:03 +0300 Subject: [PATCH 09/42] remove fit from doc --- docs/SUMMARY.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 66647a7a3..07de3f249 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -97,7 +97,6 @@ * 
[nn.linear](framework/operators/neural-network/nn.linear.md) * [Machine Learning](framework/operators/machine-learning/README.md) * [Tree Regressor](framework/operators/machine-learning/tree-regressor/README.md) - * [tree.fit](framework/operators/machine-learning/tree-regressor/tree.fit.md) * [tree.predict](framework/operators/machine-learning/tree-regressor/tree.predict.md) * [Tree Classifier](framework/operators/machine-learning/tree-classifier/README.md) * [tree.predict](framework/operators/machine-learning/tree-classifier/tree.predict.md) From 5c98b64f45fabef22a167f562f9004b6f9032752 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 10:38:09 +0300 Subject: [PATCH 10/42] implement fp16x16wide --- src/numbers/fixed_point/implementations.cairo | 1 + .../implementations/fp16x16wide.cairo | 3 + .../implementations/fp16x16wide/core.cairo | 390 ++++++ .../implementations/fp16x16wide/helpers.cairo | 41 + .../implementations/fp16x16wide/math.cairo | 5 + .../fp16x16wide/math/comp.cairo | 76 + .../fp16x16wide/math/core.cairo | 659 +++++++++ .../fp16x16wide/math/hyp.cairo | 159 +++ .../fp16x16wide/math/lut.cairo | 1235 +++++++++++++++++ .../fp16x16wide/math/trig.cairo | 450 ++++++ 10 files changed, 3019 insertions(+) create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/core.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/helpers.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/hyp.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo create mode 100644 src/numbers/fixed_point/implementations/fp16x16wide/math/trig.cairo diff --git a/src/numbers/fixed_point/implementations.cairo b/src/numbers/fixed_point/implementations.cairo index 8b010e349..e6152e25a 100644 --- a/src/numbers/fixed_point/implementations.cairo +++ b/src/numbers/fixed_point/implementations.cairo @@ -2,3 +2,4 @@ mod fp8x23; mod fp16x16; mod fp64x64; mod fp32x32; +mod fp16x16wide; \ No newline at end of file diff --git a/src/numbers/fixed_point/implementations/fp16x16wide.cairo b/src/numbers/fixed_point/implementations/fp16x16wide.cairo new file mode 100644 index 000000000..e9acee340 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide.cairo @@ -0,0 +1,3 @@ +mod core; +mod math; +mod helpers; diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo new file mode 100644 index 000000000..01a1d8b8d --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo @@ -0,0 +1,390 @@ +use debug::PrintTrait; + +use option::OptionTrait; +use result::{ResultTrait, ResultTraitImpl}; +use traits::{TryInto, Into}; + +use orion::numbers::signed_integer::{i32::i32, i8::i8}; +use orion::numbers::fixed_point::core::FixedTrait; +use orion::numbers::fixed_point::implementations::fp16x16wide::math::{core, trig, hyp}; +use orion::numbers::fixed_point::utils; + +/// A struct representing a fixed point number. 
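+/// The value is stored as a raw magnitude `mag` scaled by ONE = 2 ** 16 plus a
+/// separate `sign` flag; the u64 backing gives this "wide" variant headroom for
+/// large intermediate results. Illustrative example (not from the original
+/// patch): 2.5 is encoded as FP16x16W { mag: 163840, sign: false },
+/// since 2.5 * 65536 = 163840.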
+#[derive(Serde, Copy, Drop)]
+struct FP16x16W {
+    mag: u64,
+    sign: bool
+}
+
+// CONSTANTS
+
+const TWO: u64 = 131072; // 2 ** 17
+const ONE: u64 = 65536; // 2 ** 16
+const HALF: u64 = 32768; // 2 ** 15
+const MAX: u64 = 2147483648; // 2 ** 31
+
+
+impl FP16x16WImpl of FixedTrait {
+    fn ZERO() -> FP16x16W {
+        return FP16x16W { mag: 0, sign: false };
+    }
+
+    fn ONE() -> FP16x16W {
+        return FP16x16W { mag: ONE, sign: false };
+    }
+
+    fn MAX() -> FP16x16W {
+        return FP16x16W { mag: MAX, sign: false };
+    }
+
+    fn new(mag: u64, sign: bool) -> FP16x16W {
+        return FP16x16W { mag: mag, sign: sign };
+    }
+
+    fn new_unscaled(mag: u64, sign: bool) -> FP16x16W {
+        return FP16x16W { mag: mag * ONE, sign: sign };
+    }
+
+    fn from_felt(val: felt252) -> FP16x16W {
+        let mag = integer::u64_try_from_felt252(utils::felt_abs(val)).unwrap();
+        return FixedTrait::new(mag, utils::felt_sign(val));
+    }
+
+    fn abs(self: FP16x16W) -> FP16x16W {
+        return core::abs(self);
+    }
+
+    fn acos(self: FP16x16W) -> FP16x16W {
+        return trig::acos_fast(self);
+    }
+
+    fn acos_fast(self: FP16x16W) -> FP16x16W {
+        return trig::acos_fast(self);
+    }
+
+    fn acosh(self: FP16x16W) -> FP16x16W {
+        return hyp::acosh(self);
+    }
+
+    fn asin(self: FP16x16W) -> FP16x16W {
+        return trig::asin_fast(self);
+    }
+
+    fn asin_fast(self: FP16x16W) -> FP16x16W {
+        return trig::asin_fast(self);
+    }
+
+    fn asinh(self: FP16x16W) -> FP16x16W {
+        return hyp::asinh(self);
+    }
+
+    fn atan(self: FP16x16W) -> FP16x16W {
+        return trig::atan_fast(self);
+    }
+
+    fn atan_fast(self: FP16x16W) -> FP16x16W {
+        return trig::atan_fast(self);
+    }
+
+    fn atanh(self: FP16x16W) -> FP16x16W {
+        return hyp::atanh(self);
+    }
+
+    fn ceil(self: FP16x16W) -> FP16x16W {
+        return core::ceil(self);
+    }
+
+    fn cos(self: FP16x16W) -> FP16x16W {
+        return trig::cos_fast(self);
+    }
+
+    fn cos_fast(self: FP16x16W) -> FP16x16W {
+        return trig::cos_fast(self);
+    }
+
+    fn cosh(self: FP16x16W) -> FP16x16W {
+        return hyp::cosh(self);
+    }
+
+    fn floor(self: FP16x16W) -> FP16x16W {
+        return core::floor(self);
+    }
+
+    // Calculates the natural exponent of x: e^x
+    fn exp(self: FP16x16W) -> FP16x16W {
+        return core::exp(self);
+    }
+
+    // Calculates the binary exponent of x: 2^x
+    fn exp2(self: FP16x16W) -> FP16x16W {
+        return core::exp2(self);
+    }
+
+    // Calculates the natural logarithm of x: ln(x)
+    // self must be greater than zero
+    fn ln(self: FP16x16W) -> FP16x16W {
+        return core::ln(self);
+    }
+
+    // Calculates the binary logarithm of x: log2(x)
+    // self must be greater than zero
+    fn log2(self: FP16x16W) -> FP16x16W {
+        return core::log2(self);
+    }
+
+    // Calculates the base 10 log of x: log10(x)
+    // self must be greater than zero
+    fn log10(self: FP16x16W) -> FP16x16W {
+        return core::log10(self);
+    }
+
+    // Calculates the value of x^y and checks for overflow before returning
+    // self is a fixed point value
+    // b is a fixed point value
+    fn pow(self: FP16x16W, b: FP16x16W) -> FP16x16W {
+        return core::pow(self, b);
+    }
+
+    fn round(self: FP16x16W) -> FP16x16W {
+        return core::round(self);
+    }
+
+    fn sin(self: FP16x16W) -> FP16x16W {
+        return trig::sin_fast(self);
+    }
+
+    fn sin_fast(self: FP16x16W) -> FP16x16W {
+        return trig::sin_fast(self);
+    }
+
+    fn sinh(self: FP16x16W) -> FP16x16W {
+        return hyp::sinh(self);
+    }
+
+    // Calculates the square root of a fixed point value
+    // x must be positive
+    fn sqrt(self: FP16x16W) -> FP16x16W {
+        return core::sqrt(self);
+    }
+
+    fn tan(self: FP16x16W) -> FP16x16W {
+        return trig::tan_fast(self);
+    }
+
+    fn tan_fast(self: FP16x16W) -> FP16x16W
{ + return trig::tan_fast(self); + } + + fn tanh(self: FP16x16W) -> FP16x16W { + return hyp::tanh(self); + } + + fn sign(self: FP16x16W) -> FP16x16W { + return core::sign(self); + } +} + + +impl FP16x16WPrint of PrintTrait { + fn print(self: FP16x16W) { + self.sign.print(); + self.mag.print(); + } +} + +// Into a raw felt without unscaling +impl FP16x16WIntoFelt252 of Into { + fn into(self: FP16x16W) -> felt252 { + let mag_felt = self.mag.into(); + + if self.sign { + return mag_felt * -1; + } else { + return mag_felt * 1; + } + } +} + +impl FP16x16WIntoI32 of Into { + fn into(self: FP16x16W) -> i32 { + _i32_into_fp(self) + } +} + +impl FP16x16WTryIntoI8 of TryInto { + fn try_into(self: FP16x16W) -> Option { + _i8_try_from_fp(self) + } +} + + +impl FP16x16WTryIntoU128 of TryInto { + fn try_into(self: FP16x16W) -> Option { + if self.sign { + return Option::None(()); + } else { + // Unscale the magnitude and round down + return Option::Some((self.mag / ONE).into()); + } + } +} + +impl FP16x16WTryIntoU64 of TryInto { + fn try_into(self: FP16x16W) -> Option { + if self.sign { + return Option::None(()); + } else { + // Unscale the magnitude and round down + return Option::Some((self.mag / ONE).into()); + } + } +} + +impl FP16x16WTryIntoU32 of TryInto { + fn try_into(self: FP16x16W) -> Option { + if self.sign { + return Option::None(()); + } else { + // Unscale the magnitude and round down + return Option::Some(self.mag / ONE); + } + } +} + +impl FP16x16WTryIntoU16 of TryInto { + fn try_into(self: FP16x16W) -> Option { + if self.sign { + Option::None(()) + } else { + // Unscale the magnitude and round down + return (self.mag / ONE).try_into(); + } + } +} + +impl FP16x16WTryIntoU8 of TryInto { + fn try_into(self: FP16x16W) -> Option { + if self.sign { + Option::None(()) + } else { + // Unscale the magnitude and round down + return (self.mag / ONE).try_into(); + } + } +} + +impl FP16x16WPartialEq of PartialEq { + #[inline(always)] + fn eq(lhs: @FP16x16W, rhs: @FP16x16W) -> bool { + return core::eq(lhs, rhs); + } + + #[inline(always)] + fn ne(lhs: @FP16x16W, rhs: @FP16x16W) -> bool { + return core::ne(lhs, rhs); + } +} + +impl FP16x16WAdd of Add { + fn add(lhs: FP16x16W, rhs: FP16x16W) -> FP16x16W { + return core::add(lhs, rhs); + } +} + +impl FP16x16WAddEq of AddEq { + #[inline(always)] + fn add_eq(ref self: FP16x16W, other: FP16x16W) { + self = Add::add(self, other); + } +} + +impl FP16x16WSub of Sub { + fn sub(lhs: FP16x16W, rhs: FP16x16W) -> FP16x16W { + return core::sub(lhs, rhs); + } +} + +impl FP16x16WSubEq of SubEq { + #[inline(always)] + fn sub_eq(ref self: FP16x16W, other: FP16x16W) { + self = Sub::sub(self, other); + } +} + +impl FP16x16WMul of Mul { + fn mul(lhs: FP16x16W, rhs: FP16x16W) -> FP16x16W { + return core::mul(lhs, rhs); + } +} + +impl FP16x16WMulEq of MulEq { + #[inline(always)] + fn mul_eq(ref self: FP16x16W, other: FP16x16W) { + self = Mul::mul(self, other); + } +} + +impl FP16x16WDiv of Div { + fn div(lhs: FP16x16W, rhs: FP16x16W) -> FP16x16W { + return core::div(lhs, rhs); + } +} + +impl FP16x16WDivEq of DivEq { + #[inline(always)] + fn div_eq(ref self: FP16x16W, other: FP16x16W) { + self = Div::div(self, other); + } +} + +impl FP16x16WPartialOrd of PartialOrd { + #[inline(always)] + fn ge(lhs: FP16x16W, rhs: FP16x16W) -> bool { + return core::ge(lhs, rhs); + } + + #[inline(always)] + fn gt(lhs: FP16x16W, rhs: FP16x16W) -> bool { + return core::gt(lhs, rhs); + } + + #[inline(always)] + fn le(lhs: FP16x16W, rhs: FP16x16W) -> bool { + return core::le(lhs, rhs); + } + + 
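+    // Each comparison delegates to math::core, which orders by sign first and
+    // then by magnitude, with the magnitude test flipped for negative values.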
#[inline(always)] + fn lt(lhs: FP16x16W, rhs: FP16x16W) -> bool { + return core::lt(lhs, rhs); + } +} + +impl FP16x16WNeg of Neg { + #[inline(always)] + fn neg(a: FP16x16W) -> FP16x16W { + return core::neg(a); + } +} + +impl FP16x16WRem of Rem { + #[inline(always)] + fn rem(lhs: FP16x16W, rhs: FP16x16W) -> FP16x16W { + return core::rem(lhs, rhs); + } +} + + +/// INTERNAL + +fn _i32_into_fp(x: FP16x16W) -> i32 { + i32 { mag: (x.mag / ONE).try_into().unwrap(), sign: x.sign } +} + +fn _i8_try_from_fp(x: FP16x16W) -> Option { + let unscaled_mag: Option = (x.mag / ONE).try_into(); + + match unscaled_mag { + Option::Some(val) => Option::Some(i8 { mag: unscaled_mag.unwrap(), sign: x.sign }), + Option::None(_) => Option::None(()) + } +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/helpers.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/helpers.cairo new file mode 100644 index 000000000..c2a65e156 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/helpers.cairo @@ -0,0 +1,41 @@ +use debug::PrintTrait; +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + HALF, ONE, TWO, FP16x16W, FP16x16WImpl, FP16x16WSub, FP16x16WDiv, FixedTrait, FP16x16WPrint +}; + +const DEFAULT_PRECISION: u64 = 7; // 1e-4 + +// To use `DEFAULT_PRECISION`, final arg is: `Option::None(())`. +// To use `custom_precision` of 430_u32: `Option::Some(430_u32)`. +fn assert_precise(result: FP16x16W, expected: felt252, msg: felt252, custom_precision: Option) { + let precision = match custom_precision { + Option::Some(val) => val, + Option::None(_) => DEFAULT_PRECISION, + }; + + let diff = (result - FixedTrait::from_felt(expected)).mag; + + if (diff > precision) { + result.print(); + assert(diff <= precision, msg); + } +} + +fn assert_relative( + result: FP16x16W, expected: felt252, msg: felt252, custom_precision: Option +) { + let precision = match custom_precision { + Option::Some(val) => val, + Option::None(_) => DEFAULT_PRECISION, + }; + + let diff = result - FixedTrait::from_felt(expected); + let rel_diff = (diff / result).mag; + + if (rel_diff > precision) { + result.print(); + assert(rel_diff <= precision, msg); + } +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math.cairo new file mode 100644 index 000000000..970c65f30 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math.cairo @@ -0,0 +1,5 @@ +mod core; +mod comp; +mod lut; +mod trig; +mod hyp; diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo new file mode 100644 index 000000000..63a3e4855 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/comp.cairo @@ -0,0 +1,76 @@ +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + FP16x16W, FixedTrait, FP16x16WImpl, FP16x16WPartialOrd, FP16x16WPartialEq +}; + +fn max(a: FP16x16W, b: FP16x16W) -> FP16x16W { + if (a >= b) { + return a; + } else { + return b; + } +} + +fn min(a: FP16x16W, b: FP16x16W) -> FP16x16W { + if (a <= b) { + return a; + } else { + return b; + } +} + +fn xor(a: FP16x16W, b: FP16x16W) -> bool { + if (a == FixedTrait::new(0, false) || b == FixedTrait::new(0, false)) && (a != b) { + return true; + } else { + return false; + } +} + +fn or(a: FP16x16W, b: FP16x16W) -> bool { + let zero = FixedTrait::new(0, false); + if a == zero && b == zero { + return false; + } else { + return 
true;
+    }
+}
+
+// Tests --------------------------------------------------------------------------------------------------------------
+
+#[test]
+fn test_max() {
+    let a = FixedTrait::new_unscaled(1, false);
+    let b = FixedTrait::new_unscaled(0, false);
+    let c = FixedTrait::new_unscaled(1, true);
+
+    assert(max(a, a) == a, 'max(a, a)');
+    assert(max(a, b) == a, 'max(a, b)');
+    assert(max(a, c) == a, 'max(a, c)');
+
+    assert(max(b, a) == a, 'max(b, a)');
+    assert(max(b, b) == b, 'max(b, b)');
+    assert(max(b, c) == b, 'max(b, c)');
+
+    assert(max(c, a) == a, 'max(c, a)');
+    assert(max(c, b) == b, 'max(c, b)');
+    assert(max(c, c) == c, 'max(c, c)');
+}
+
+#[test]
+fn test_min() {
+    let a = FixedTrait::new_unscaled(1, false);
+    let b = FixedTrait::new_unscaled(0, false);
+    let c = FixedTrait::new_unscaled(1, true);
+
+    assert(min(a, a) == a, 'min(a, a)');
+    assert(min(a, b) == b, 'min(a, b)');
+    assert(min(a, c) == c, 'min(a, c)');
+
+    assert(min(b, a) == b, 'min(b, a)');
+    assert(min(b, b) == b, 'min(b, b)');
+    assert(min(b, c) == c, 'min(b, c)');
+
+    assert(min(c, a) == c, 'min(c, a)');
+    assert(min(c, b) == c, 'min(c, b)');
+    assert(min(c, c) == c, 'min(c, c)');
+}
diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo
new file mode 100644
index 000000000..33c1c6d85
--- /dev/null
+++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo
@@ -0,0 +1,659 @@
+use core::debug::PrintTrait;
+use option::OptionTrait;
+use result::{ResultTrait, ResultTraitImpl};
+use traits::{Into, TryInto};
+use integer::{u64_safe_divmod, u64_as_non_zero, u64_wide_mul};
+
+use orion::numbers::fixed_point::implementations::fp16x16wide::core::{
+    HALF, ONE, MAX, FP16x16W, FP16x16WImpl, FP16x16WAdd, FP16x16WAddEq, FP16x16WSub, FP16x16WMul,
+    FP16x16WMulEq, FP16x16WTryIntoU128, FP16x16WPartialEq, FP16x16WPartialOrd, FP16x16WSubEq, FP16x16WNeg,
+    FP16x16WDiv, FP16x16WIntoFelt252, FixedTrait
+};
+use orion::numbers::fixed_point::implementations::fp16x16wide::math::lut;
+
+// PUBLIC
+
+fn abs(a: FP16x16W) -> FP16x16W {
+    return FixedTrait::new(a.mag, false);
+}
+
+fn add(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    if a.sign == b.sign {
+        return FixedTrait::new(a.mag + b.mag, a.sign);
+    }
+
+    if a.mag == b.mag {
+        return FixedTrait::ZERO();
+    }
+
+    if (a.mag > b.mag) {
+        return FixedTrait::new(a.mag - b.mag, a.sign);
+    } else {
+        return FixedTrait::new(b.mag - a.mag, b.sign);
+    }
+}
+
+fn ceil(a: FP16x16W) -> FP16x16W {
+    let (div, rem) = u64_safe_divmod(a.mag, u64_as_non_zero(ONE));
+
+    if rem == 0 {
+        return a;
+    } else if !a.sign {
+        return FixedTrait::new_unscaled(div + 1, false);
+    } else if div == 0 {
+        return FixedTrait::new_unscaled(0, false);
+    } else {
+        return FixedTrait::new_unscaled(div, true);
+    }
+}
+
+fn div(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    let a_u64 = integer::u64_wide_mul(a.mag, ONE);
+    let res_u64 = a_u64 / b.mag.into();
+
+    // Re-apply sign
+    return FixedTrait::new(res_u64.try_into().unwrap(), a.sign ^ b.sign);
+}
+
+fn eq(a: @FP16x16W, b: @FP16x16W) -> bool {
+    return (*a.mag == *b.mag) && (*a.sign == *b.sign);
+}
+
+// Calculates the natural exponent of x: e^x
+fn exp(a: FP16x16W) -> FP16x16W {
+    return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^16 ≈ 94548
+}
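+
+// Worked example (illustrative, not from the original patch): with
+// ONE = 2 ** 16, the constant 94548 above encodes log2(e) ≈ 1.442695, since
+// round(1.442695 * 65536) = 94548; exp(a) thus evaluates e^a as
+// 2^(a * log2(e)) and delegates to exp2 below.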
+
+// Calculates the binary exponent of x: 2^x
+fn exp2(a: FP16x16W) -> FP16x16W {
+    if (a.mag == 0) {
+        return FixedTrait::ONE();
+    }
+
+    let (int_part, frac_part) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE));
+    let int_res = FixedTrait::new_unscaled(lut::exp2(int_part), false);
+    let mut res_u = int_res;
+
+    if frac_part != 0 {
+        let frac = FixedTrait::new(frac_part, false);
+        let r7 = FixedTrait::new(1, false) * frac;
+        let r6 = (r7 + FixedTrait::new(10, false)) * frac;
+        let r5 = (r6 + FixedTrait::new(87, false)) * frac;
+        let r4 = (r5 + FixedTrait::new(630, false)) * frac;
+        let r3 = (r4 + FixedTrait::new(3638, false)) * frac;
+        let r2 = (r3 + FixedTrait::new(15743, false)) * frac;
+        let r1 = (r2 + FixedTrait::new(45426, false)) * frac;
+        res_u = res_u * (r1 + FixedTrait::ONE());
+    }
+
+    if (a.sign == true) {
+        return FixedTrait::ONE() / res_u;
+    } else {
+        return res_u;
+    }
+}
+
+fn exp2_int(exp: u64) -> FP16x16W {
+    return FixedTrait::new_unscaled(lut::exp2(exp), false);
+}
+
+fn floor(a: FP16x16W) -> FP16x16W {
+    let (div, rem) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE));
+
+    if rem == 0 {
+        return a;
+    } else if !a.sign {
+        return FixedTrait::new_unscaled(div, false);
+    } else {
+        return FixedTrait::new_unscaled(div + 1, true);
+    }
+}
+
+fn ge(a: FP16x16W, b: FP16x16W) -> bool {
+    if a.sign != b.sign {
+        return !a.sign;
+    } else {
+        return (a.mag == b.mag) || ((a.mag > b.mag) ^ a.sign);
+    }
+}
+
+fn gt(a: FP16x16W, b: FP16x16W) -> bool {
+    if a.sign != b.sign {
+        return !a.sign;
+    } else {
+        return (a.mag != b.mag) && ((a.mag > b.mag) ^ a.sign);
+    }
+}
+
+fn le(a: FP16x16W, b: FP16x16W) -> bool {
+    if a.sign != b.sign {
+        return a.sign;
+    } else {
+        return (a.mag == b.mag) || ((a.mag < b.mag) ^ a.sign);
+    }
+}
+
+// Calculates the natural logarithm of x: ln(x)
+// self must be greater than zero
+fn ln(a: FP16x16W) -> FP16x16W {
+    return FixedTrait::new(45426, false) * log2(a); // ln(2) = 0.693...
+}
+
+// Calculates the binary logarithm of x: log2(x)
+// self must be greater than zero
+fn log2(a: FP16x16W) -> FP16x16W {
+    assert(a.sign == false, 'must be positive');
+
+    if (a.mag == ONE) {
+        return FixedTrait::ZERO();
+    } else if (a.mag < ONE) {
+        // Compute true inverse binary log if 0 < x < 1
+        let div = FixedTrait::ONE() / a;
+        return -log2(div);
+    }
+
+    let whole = a.mag / ONE;
+    let (msb, div) = lut::msb(whole);
+
+    if a.mag == div * ONE {
+        return FixedTrait::new_unscaled(msb, false);
+    } else {
+        let norm = a / FixedTrait::new_unscaled(div, false);
+        let r8 = FixedTrait::new(596, true) * norm;
+        let r7 = (r8 + FixedTrait::new(8116, false)) * norm;
+        let r6 = (r7 + FixedTrait::new(49044, true)) * norm;
+        let r5 = (r6 + FixedTrait::new(172935, false)) * norm;
+        let r4 = (r5 + FixedTrait::new(394096, true)) * norm;
+        let r3 = (r4 + FixedTrait::new(608566, false)) * norm;
+        let r2 = (r3 + FixedTrait::new(655828, true)) * norm;
+        let r1 = (r2 + FixedTrait::new(534433, false)) * norm;
+        return r1 + FixedTrait::new(224487, true) + FixedTrait::new_unscaled(msb, false);
+    }
+}
+
+// Calculates the base 10 log of x: log10(x)
+// self must be greater than zero
+fn log10(a: FP16x16W) -> FP16x16W {
+    return FixedTrait::new(19728, false) * log2(a); // log10(2) = 0.301...
+}
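+
+// Worked example (illustrative, not from the original patch): both logs above
+// derive from log2 by change of base. 45426 ≈ ln(2) * 65536
+// (0.693147 * 65536 ≈ 45426), so ln(x) = ln(2) * log2(x), and
+// 19728 ≈ log10(2) * 65536 (0.301030 * 65536 ≈ 19728), so
+// log10(x) = log10(2) * log2(x).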
+
+fn lt(a: FP16x16W, b: FP16x16W) -> bool {
+    if a.sign != b.sign {
+        return a.sign;
+    } else {
+        return (a.mag != b.mag) && ((a.mag < b.mag) ^ a.sign);
+    }
+}
+
+fn mul(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    let prod_u128 = integer::u64_wide_mul(a.mag, b.mag);
+
+    // Re-apply sign
+    return FixedTrait::new((prod_u128 / ONE.into()).try_into().unwrap(), a.sign ^ b.sign);
+}
+
+fn ne(a: @FP16x16W, b: @FP16x16W) -> bool {
+    return (*a.mag != *b.mag) || (*a.sign != *b.sign);
+}
+
+fn neg(a: FP16x16W) -> FP16x16W {
+    if a.mag == 0 {
+        return a;
+    } else if !a.sign {
+        return FixedTrait::new(a.mag, !a.sign);
+    } else {
+        return FixedTrait::new(a.mag, false);
+    }
+}
+
+// Calculates the value of x^y and checks for overflow before returning
+// self is a fixed point value
+// b is a fixed point value
+fn pow(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    let (div, rem) = integer::u64_safe_divmod(b.mag, u64_as_non_zero(ONE));
+
+    // use the more performant integer pow when y is an int
+    if (rem == 0) {
+        return pow_int(a, b.mag / ONE, b.sign);
+    }
+
+    // x^y = exp(y*ln(x)) for x > 0; will error for x < 0
+    return exp(b * ln(a));
+}
+
+// Calculates the value of a^b and checks for overflow before returning
+fn pow_int(a: FP16x16W, b: u64, sign: bool) -> FP16x16W {
+    let mut x = a;
+    let mut n = b;
+
+    if sign == true {
+        x = FixedTrait::ONE() / x;
+    }
+
+    if n == 0 {
+        return FixedTrait::ONE();
+    }
+
+    let mut y = FixedTrait::ONE();
+    let two = integer::u64_as_non_zero(2);
+
+    loop {
+        if n <= 1 {
+            break;
+        }
+
+        let (div, rem) = integer::u64_safe_divmod(n, two);
+
+        if rem == 1 {
+            y = x * y;
+        }
+
+        x = x * x;
+        n = div;
+    };
+
+    return x * y;
+}
+
+fn rem(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    return a - floor(a / b) * b;
+}
+
+fn round(a: FP16x16W) -> FP16x16W {
+    let (div, rem) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE));
+
+    if (HALF <= rem) {
+        return FixedTrait::new_unscaled(div + 1, a.sign);
+    } else {
+        return FixedTrait::new_unscaled(div, a.sign);
+    }
+}
+
+// Calculates the square root of a fixed point value
+// x must be positive
+fn sqrt(a: FP16x16W) -> FP16x16W {
+    assert(a.sign == false, 'must be positive');
+
+    let root = integer::u64_sqrt(a.mag.into() * ONE.into());
+    return FixedTrait::new(root.into(), false);
+}
+
+fn sub(a: FP16x16W, b: FP16x16W) -> FP16x16W {
+    return add(a, -b);
+}
+
+fn sign(a: FP16x16W) -> FP16x16W {
+    if a.mag == 0 {
+        FixedTrait::new(0, false)
+    } else {
+        FixedTrait::new(ONE, a.sign)
+    }
+}
+
+// Tests --------------------------------------------------------------------------------------------------------------
+
+use orion::numbers::fixed_point::implementations::fp16x16wide::helpers::{
+    assert_precise, assert_relative
+};
+use orion::numbers::fixed_point::implementations::fp16x16wide::math::trig::{PI, HALF_PI};
+
+#[test]
+fn test_into() {
+    let a = FixedTrait::::new_unscaled(5, false);
+    assert(a.mag == 5 * ONE, 'invalid result');
+}
+
+#[test]
+fn test_try_into_u128() {
+    // Positive unscaled
+    let a = FixedTrait::::new_unscaled(5, false);
+    assert(a.try_into().unwrap() == 5_u128, 'invalid result');
+
+    // Positive scaled
+    let b = FixedTrait::::new(5 * ONE, false);
+    assert(b.try_into().unwrap() == 5_u128, 'invalid result');
+
+    // Zero
+    let d = FixedTrait::::new_unscaled(0, false);
+    assert(d.try_into().unwrap() == 0_u128, 'invalid result');
+}
+
+#[test]
+#[should_panic]
+fn test_negative_try_into_u128() {
+    let a = FixedTrait::::new_unscaled(1, true);
+    let a: u128 = a.try_into().unwrap();
+}
+
+#[test]
+#[available_gas(1000000)] +fn test_acos() { + let a = FixedTrait::::ONE(); + assert(a.acos().into() == 0, 'invalid one'); +} + +#[test] +#[available_gas(1000000)] +fn test_asin() { + let a = FixedTrait::ONE(); + assert_precise(a.asin(), HALF_PI.into(), 'invalid one', Option::None(())); // PI / 2 +} + +#[test] +#[available_gas(2000000)] +fn test_atan() { + let a = FixedTrait::new(2 * ONE, false); + assert_relative(a.atan(), 72558, 'invalid two', Option::None(())); +} + +#[test] +fn test_ceil() { + let a = FixedTrait::new(190054, false); // 2.9 + assert(ceil(a).mag == 3 * ONE, 'invalid pos decimal'); +} + +#[test] +fn test_floor() { + let a = FixedTrait::new(190054, false); // 2.9 + assert(floor(a).mag == 2 * ONE, 'invalid pos decimal'); +} + +#[test] +fn test_round() { + let a = FixedTrait::new(190054, false); // 2.9 + assert(round(a).mag == 3 * ONE, 'invalid pos decimal'); +} + +#[test] +#[should_panic] +fn test_sqrt_fail() { + let a = FixedTrait::new_unscaled(25, true); + sqrt(a); +} + +#[test] +fn test_sqrt() { + let mut a = FixedTrait::new_unscaled(0, false); + assert(sqrt(a).mag == 0, 'invalid zero root'); + a = FixedTrait::new_unscaled(25, false); + assert(sqrt(a).mag == 5 * ONE, 'invalid pos root'); +} + + +#[test] +#[available_gas(100000)] +fn test_msb() { + let a = FixedTrait::::new_unscaled(100, false); + let (msb, div) = lut::msb(a.mag / ONE); + assert(msb == 6, 'invalid msb'); + assert(div == 64, 'invalid msb ceil'); +} + +#[test] +#[available_gas(600000)] +fn test_pow() { + let a = FixedTrait::new_unscaled(3, false); + let b = FixedTrait::new_unscaled(4, false); + assert(pow(a, b).mag == 81 * ONE, 'invalid pos base power'); +} + +#[test] +#[available_gas(900000)] +fn test_pow_frac() { + let a = FixedTrait::new_unscaled(3, false); + let b = FixedTrait::new(32768, false); // 0.5 + assert_relative( + pow(a, b), 113512, 'invalid pos base power', Option::None(()) + ); // 1.7320508075688772 +} + +#[test] +#[available_gas(1000000)] +fn test_exp() { + let a = FixedTrait::new_unscaled(2, false); + assert_relative(exp(a), 484249, 'invalid exp of 2', Option::None(())); // 7.389056098793725 +} + +#[test] +#[available_gas(400000)] +fn test_exp2() { + let a = FixedTrait::new_unscaled(5, false); + assert(exp2(a).mag == 2097152, 'invalid exp2 of 2'); +} + +#[test] +#[available_gas(20000)] +fn test_exp2_int() { + assert(exp2_int(5).into() == 2097152, 'invalid exp2 of 2'); +} + +#[test] +#[available_gas(1000000)] +fn test_ln() { + let mut a = FixedTrait::new_unscaled(1, false); + assert(ln(a).mag == 0, 'invalid ln of 1'); + + a = FixedTrait::new(178145, false); + assert_relative(ln(a), ONE.into(), 'invalid ln of 2.7...', Option::None(())); +} + +#[test] +#[available_gas(1000000)] +fn test_log2() { + let mut a = FixedTrait::new_unscaled(32, false); + assert(log2(a) == FixedTrait::new_unscaled(5, false), 'invalid log2 32'); + + a = FixedTrait::new_unscaled(10, false); + assert_relative(log2(a), 217706, 'invalid log2 10', Option::None(())); // 3.321928094887362 +} + +#[test] +#[available_gas(1000000)] +fn test_log10() { + let a = FixedTrait::new_unscaled(100, false); + assert_relative(log10(a), 2 * ONE.into(), 'invalid log10', Option::None(())); +} + +#[test] +fn test_eq() { + let a = FixedTrait::new_unscaled(42, false); + let b = FixedTrait::new_unscaled(42, false); + let c = eq(@a, @b); + assert(c == true, 'invalid result'); +} + +#[test] +fn test_ne() { + let a = FixedTrait::new_unscaled(42, false); + let b = FixedTrait::new_unscaled(42, false); + let c = ne(@a, @b); + assert(c == false, 
'invalid result'); +} + +#[test] +fn test_add() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(2, false); + assert(add(a, b) == FixedTrait::new_unscaled(3, false), 'invalid result'); +} + +#[test] +fn test_add_eq() { + let mut a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(2, false); + a += b; + assert(a == FixedTrait::::new_unscaled(3, false), 'invalid result'); +} + +#[test] +fn test_sub() { + let a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, false); + let c = a - b; + assert(c == FixedTrait::::new_unscaled(3, false), 'false result invalid'); +} + +#[test] +fn test_sub_eq() { + let mut a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, false); + a -= b; + assert(a == FixedTrait::::new_unscaled(3, false), 'invalid result'); +} + +#[test] +#[available_gas(100000)] +fn test_mul_pos() { + let a = FP16x16W { mag: 190054, sign: false }; + let b = FP16x16W { mag: 190054, sign: false }; + let c = a * b; + assert(c.mag == 551155, 'invalid result'); +} + +#[test] +fn test_mul_neg() { + let a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, true); + let c = a * b; + assert(c == FixedTrait::::new_unscaled(10, true), 'invalid result'); +} + +#[test] +fn test_mul_eq() { + let mut a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, true); + a *= b; + assert(a == FixedTrait::::new_unscaled(10, true), 'invalid result'); +} + +#[test] +fn test_div() { + let a = FixedTrait::new_unscaled(10, false); + let b = FixedTrait::::new(190054, false); // 2.9 + let c = a / b; + assert(c.mag == 225986, 'invalid pos decimal'); // 3.4482758620689653 +} + +#[test] +fn test_le() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a <= a, 'a <= a'); + assert(a <= b == false, 'a <= b'); + assert(a <= c == false, 'a <= c'); + + assert(b <= a, 'b <= a'); + assert(b <= b, 'b <= b'); + assert(b <= c == false, 'b <= c'); + + assert(c <= a, 'c <= a'); + assert(c <= b, 'c <= b'); + assert(c <= c, 'c <= c'); +} + +#[test] +fn test_lt() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a < a == false, 'a < a'); + assert(a < b == false, 'a < b'); + assert(a < c == false, 'a < c'); + + assert(b < a, 'b < a'); + assert(b < b == false, 'b < b'); + assert(b < c == false, 'b < c'); + + assert(c < a, 'c < a'); + assert(c < b, 'c < b'); + assert(c < c == false, 'c < c'); +} + +#[test] +fn test_ge() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a >= a, 'a >= a'); + assert(a >= b, 'a >= b'); + assert(a >= c, 'a >= c'); + + assert(b >= a == false, 'b >= a'); + assert(b >= b, 'b >= b'); + assert(b >= c, 'b >= c'); + + assert(c >= a == false, 'c >= a'); + assert(c >= b == false, 'c >= b'); + assert(c >= c, 'c >= c'); +} + +#[test] +fn test_gt() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a > a == false, 'a > a'); + assert(a > b, 'a > b'); + assert(a > c, 'a > c'); + + assert(b > a == false, 'b > a'); + assert(b > b == false, 'b > b'); + assert(b > c, 'b > c'); + + assert(c > a == false, 'c > a'); + assert(c > b == false, 'c > b'); + assert(c > c == false, 'c > c'); 
+} + +#[test] +#[available_gas(1000000)] +fn test_cos() { + let a = FixedTrait::::new(HALF_PI, false); + assert(a.cos().into() == 0, 'invalid half pi'); +} + +#[test] +#[available_gas(1000000)] +fn test_sin() { + let a = FixedTrait::new(HALF_PI, false); + assert_precise(a.sin(), ONE.into(), 'invalid half pi', Option::None(())); +} + +#[test] +#[available_gas(2000000)] +fn test_tan() { + let a = FixedTrait::::new(HALF_PI / 2, false); + assert(a.tan().mag == 65536, 'invalid quarter pi'); +} + +#[test] +#[available_gas(2000000)] +fn test_sign() { + let a = FixedTrait::::new(0, false); + assert(a.sign().mag == 0 && !a.sign().sign, 'invalid sign (0, true)'); + + let a = FixedTrait::::new(HALF, true); + assert(a.sign().mag == ONE && a.sign().sign, 'invalid sign (HALF, true)'); + + let a = FixedTrait::::new(HALF, false); + assert(a.sign().mag == ONE && !a.sign().sign, 'invalid sign (HALF, false)'); + + let a = FixedTrait::::new(ONE, true); + assert(a.sign().mag == ONE && a.sign().sign, 'invalid sign (ONE, true)'); + + let a = FixedTrait::::new(ONE, false); + assert(a.sign().mag == ONE && !a.sign().sign, 'invalid sign (ONE, false)'); +} + +#[test] +#[should_panic] +#[available_gas(2000000)] +fn test_sign_fail() { + let a = FixedTrait::::new(HALF, true); + assert(a.sign().mag != ONE && !a.sign().sign, 'invalid sign (HALF, true)'); +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/hyp.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/hyp.cairo new file mode 100644 index 000000000..3286b6345 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/hyp.cairo @@ -0,0 +1,159 @@ +use core::debug::PrintTrait; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + HALF, ONE, TWO, FP16x16W, FP16x16WImpl, FP16x16WAdd, FP16x16WAddEq, FP16x16WSub, FP16x16WMul, + FP16x16WMulEq, FP16x16WTryIntoU128, FP16x16WPartialEq, FP16x16WPartialOrd, FP16x16WSubEq, FP16x16WNeg, + FP16x16WDiv, FP16x16WIntoFelt252, FixedTrait +}; + +// Calculates hyperbolic cosine of a (fixed point) +fn cosh(a: FP16x16W) -> FP16x16W { + let ea = a.exp(); + return (ea + (FixedTrait::ONE() / ea)) / FixedTrait::new(TWO, false); +} + +// Calculates hyperbolic sine of a (fixed point) +fn sinh(a: FP16x16W) -> FP16x16W { + let ea = a.exp(); + return (ea - (FixedTrait::ONE() / ea)) / FixedTrait::new(TWO, false); +} + +// Calculates hyperbolic tangent of a (fixed point) +fn tanh(a: FP16x16W) -> FP16x16W { + let ea = a.exp(); + let ea_i = FixedTrait::ONE() / ea; + return (ea - ea_i) / (ea + ea_i); +} + +// Calculates inverse hyperbolic cosine of a (fixed point) +fn acosh(a: FP16x16W) -> FP16x16W { + let root = (a * a - FixedTrait::ONE()).sqrt(); + return (a + root).ln(); +} + +// Calculates inverse hyperbolic sine of a (fixed point) +fn asinh(a: FP16x16W) -> FP16x16W { + let root = (a * a + FixedTrait::ONE()).sqrt(); + return (a + root).ln(); +} + +// Calculates inverse hyperbolic tangent of a (fixed point) +fn atanh(a: FP16x16W) -> FP16x16W { + let one = FixedTrait::ONE(); + let ln_arg = (one + a) / (one - a); + return ln_arg.ln() / FixedTrait::new(TWO, false); +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +use option::OptionTrait; +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp16x16wide::helpers::assert_precise; + +#[test] +#[available_gas(10000000)] +fn test_cosh() { + let a = FixedTrait::new(TWO, false); + assert_precise(cosh(a), 246550, 'invalid two', 
Option::None(())); // 3.76219569108363
+
+    let a = FixedTrait::ONE();
+    assert_precise(cosh(a), 101127, 'invalid one', Option::None(())); // 1.54308063481524
+
+    let a = FixedTrait::ZERO();
+    assert_precise(cosh(a), ONE.into(), 'invalid zero', Option::None(()));
+
+    let a = FixedTrait::new(ONE, true);
+    assert_precise(cosh(a), 101127, 'invalid neg one', Option::None(())); // 1.54308063481524
+
+    let a = FixedTrait::new(TWO, true);
+    assert_precise(cosh(a), 246568, 'invalid neg two', Option::None(())); // 3.76219569108363
+}
+
+#[test]
+#[available_gas(10000000)]
+fn test_sinh() {
+    let a = FixedTrait::new(TWO, false);
+    assert_precise(sinh(a), 237681, 'invalid two', Option::None(())); // 3.62686040784702
+
+    let a = FixedTrait::ONE();
+    assert_precise(sinh(a), 77018, 'invalid one', Option::None(())); // 1.17520119364380
+
+    let a = FixedTrait::ZERO();
+    assert(sinh(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(ONE, true);
+    assert_precise(sinh(a), -77018, 'invalid neg one', Option::None(())); // -1.17520119364380
+
+    let a = FixedTrait::new(TWO, true);
+    assert_precise(sinh(a), -237699, 'invalid neg two', Option::None(())); // -3.62686040784702
+}
+
+#[test]
+#[available_gas(10000000)]
+fn test_tanh() {
+    let a = FixedTrait::new(TWO, false);
+    assert_precise(tanh(a), 63179, 'invalid two', Option::None(())); // 0.96402758007582
+
+    let a = FixedTrait::ONE();
+    assert_precise(tanh(a), 49912, 'invalid one', Option::None(())); // 0.76159415595576
+
+    let a = FixedTrait::ZERO();
+    assert(tanh(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(ONE, true);
+    assert_precise(tanh(a), -49912, 'invalid neg one', Option::None(())); // -0.76159415595576
+
+    let a = FixedTrait::new(TWO, true);
+    assert_precise(tanh(a), -63179, 'invalid neg two', Option::None(())); // -0.96402758007582
+}
+
+#[test]
+#[available_gas(10000000)]
+fn test_acosh() {
+    let a = FixedTrait::new(246559, false); // 3.76219569108363
+    assert_precise(acosh(a), 131072, 'invalid two', Option::None(()));
+
+    let a = FixedTrait::new(101127, false); // 1.54308063481524
+    assert_precise(acosh(a), ONE.into(), 'invalid one', Option::None(()));
+
+    let a = FixedTrait::ONE(); // 1
+    assert(acosh(a).into() == 0, 'invalid zero');
+}
+
+#[test]
+#[available_gas(10000000)]
+fn test_asinh() {
+    let a = FixedTrait::new(237690, false); // 3.62686040784702
+    assert_precise(asinh(a), 131072, 'invalid two', Option::None(()));
+
+    let a = FixedTrait::new(77018, false); // 1.17520119364380
+    assert_precise(asinh(a), ONE.into(), 'invalid one', Option::None(()));
+
+    let a = FixedTrait::ZERO();
+    assert(asinh(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(77018, true); // -1.17520119364380
+    assert_precise(asinh(a), -ONE.into(), 'invalid neg one', Option::None(()));
+
+    let a = FixedTrait::new(237690, true); // -3.62686040784702
+    assert_precise(asinh(a), -131017, 'invalid neg two', Option::None(()));
+}
+
+#[test]
+#[available_gas(10000000)]
+fn test_atanh() {
+    let a = FixedTrait::new(58982, false); // 0.9
+    assert_precise(atanh(a), 96483, 'invalid 0.9', Option::None(())); // 1.47221948958322
+
+    let a = FixedTrait::new(HALF, false); // 0.5
+    assert_precise(atanh(a), 35999, 'invalid half', Option::None(())); // 0.54930614433405
+
+    let a = FixedTrait::ZERO();
+    assert(atanh(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(HALF, true); // -0.5
+    assert_precise(atanh(a), -35999, 'invalid neg half', Option::None(())); // -0.54930614433405
+
+    let a = FixedTrait::new(58982, true); // -0.9
+    assert_precise(atanh(a), -96483, 'invalid -0.9', Option::None(())); // -1.47221948958322
+}
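+
+// Worked example (illustrative, not from the original patch): with
+// ONE = 2 ** 16, tanh(1) ≈ 0.7615942 corresponds to
+// round(0.7615942 * 65536) = 49912, matching the magnitude asserted in
+// test_tanh above; other expected magnitudes differ slightly from the exact
+// values, presumably tracking the rounding error of the exp-based
+// implementations.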
diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo
new file mode 100644
index 000000000..e96b0d389
--- /dev/null
+++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/lut.cairo
@@ -0,0 +1,1235 @@
+// Calculates the most significant bit
+fn msb(whole: u64) -> (u64, u64) {
+    if whole < 256 {
+        if whole < 2 {
+            return (0, 1);
+        }
+        if whole < 4 {
+            return (1, 2);
+        }
+        if whole < 8 {
+            return (2, 4);
+        }
+        if whole < 16 {
+            return (3, 8);
+        }
+        if whole < 32 {
+            return (4, 16);
+        }
+        if whole < 64 {
+            return (5, 32);
+        }
+        if whole < 128 {
+            return (6, 64);
+        }
+        if whole < 256 {
+            return (7, 128);
+        }
+    } else if whole < 65536 {
+        if whole < 512 {
+            return (8, 256);
+        }
+        if whole < 1024 {
+            return (9, 512);
+        }
+        if whole < 2048 {
+            return (10, 1024);
+        }
+        if whole < 4096 {
+            return (11, 2048);
+        }
+        if whole < 8192 {
+            return (12, 4096);
+        }
+        if whole < 16384 {
+            return (13, 8192);
+        }
+        if whole < 32768 {
+            return (14, 16384);
+        }
+        if whole < 65536 {
+            return (15, 32768);
+        }
+    }
+
+    return (16, 65536);
+}
+
+fn exp2(exp: u64) -> u64 {
+    if exp <= 16 {
+        if exp == 0 {
+            return 1;
+        }
+        if exp == 1 {
+            return 2;
+        }
+        if exp == 2 {
+            return 4;
+        }
+        if exp == 3 {
+            return 8;
+        }
+        if exp == 4 {
+            return 16;
+        }
+        if exp == 5 {
+            return 32;
+        }
+        if exp == 6 {
+            return 64;
+        }
+        if exp == 7 {
+            return 128;
+        }
+        if exp == 8 {
+            return 256;
+        }
+        if exp == 9 {
+            return 512;
+        }
+        if exp == 10 {
+            return 1024;
+        }
+        if exp == 11 {
+            return 2048;
+        }
+        if exp == 12 {
+            return 4096;
+        }
+        if exp == 13 {
+            return 8192;
+        }
+        if exp == 14 {
+            return 16384;
+        }
+        if exp == 15 {
+            return 32768;
+        }
+        if exp == 16 {
+            return 65536;
+        }
+    }
+
+    return 65536;
+}
+
+fn sin(a: u64) -> (u64, u64, u64) {
+    let slot = a / 402;
+
+    if slot < 128 {
+        if slot < 64 {
+            if slot < 32 {
+                if slot < 16 {
+                    if slot == 0 {
+                        return (0, 0, 402);
+                    }
+                    if slot == 1 {
+                        return (402, 402, 804);
+                    }
+                    if slot == 2 {
+                        return (804, 804, 1206);
+                    }
+                    if slot == 3 {
+                        return (1206, 1206, 1608);
+                    }
+                    if slot == 4 {
+                        return (1608, 1608, 2010);
+                    }
+                    if slot == 5 {
+                        return (2011, 2010, 2412);
+                    }
+                    if slot == 6 {
+                        return (2413, 2412, 2814);
+                    }
+                    if slot == 7 {
+                        return (2815, 2814, 3216);
+                    }
+                    if slot == 8 {
+                        return (3217, 3216, 3617);
+                    }
+                    if slot == 9 {
+                        return (3619, 3617, 4019);
+                    }
+                    if slot == 10 {
+                        return (4023, 4019, 4420);
+                    }
+                    if slot == 11 {
+                        return (4423, 4420, 4821);
+                    }
+                    if slot == 12 {
+                        return (4825, 4821, 5222);
+                    }
+                    if slot == 13 {
+                        return (5228, 5222, 5623);
+                    }
+                    if slot == 14 {
+                        return (5630, 5623, 6023);
+                    }
+                    if slot == 15 {
+                        return (6032, 6023, 6424);
+                    }
+                } else {
+                    if slot == 16 {
+                        return (6434, 6424, 6824);
+                    }
+                    if slot == 17 {
+                        return (6836, 6824, 7224);
+                    }
+                    if slot == 18 {
+                        return (7238, 7224, 7623);
+                    }
+                    if slot == 19 {
+                        return (7640, 7623, 8022);
+                    }
+                    if slot == 20 {
+                        return (8042, 8022, 8421);
+                    }
+                    if slot == 21 {
+                        return (8445, 8421, 8820);
+                    }
+                    if slot == 22 {
+                        return (8847, 8820, 9218);
+                    }
+                    if slot == 23 {
+                        return (9249, 9218, 9616);
+                    }
+                    if slot == 24 {
+                        return (9651, 9616, 10014);
+                    }
+                    if slot == 25 {
+                        return (10053, 10014, 10411);
+                    }
+                    if slot == 26 {
+                        return (10455, 10411, 10808);
+                    }
+                    if slot == 27 {
+                        return (10857, 10808, 11204);
+                    }
+                    if slot == 28 {
+                        return (11259, 11204, 11600);
+                    }
+                    if slot == 29 {
+                        return (11662, 11600, 11996);
+                    }
+                    if slot == 30 {
+                        return (12064,
11996, 12391); + } + if slot == 31 { + return (12466, 12391, 12785); + } + } + } else { + if slot < 48 { + if slot == 32 { + return (12868, 12785, 13180); + } + if slot == 33 { + return (13270, 13180, 13573); + } + if slot == 34 { + return (13672, 13573, 13966); + } + if slot == 35 { + return (14074, 13966, 14359); + } + if slot == 36 { + return (14476, 14359, 14751); + } + if slot == 37 { + return (14879, 14751, 15143); + } + if slot == 38 { + return (15281, 15143, 15534); + } + if slot == 39 { + return (15683, 15534, 15924); + } + if slot == 40 { + return (16081, 15924, 16314); + } + if slot == 41 { + return (16487, 16314, 16703); + } + if slot == 42 { + return (16889, 16703, 17091); + } + if slot == 43 { + return (17291, 17091, 17479); + } + if slot == 44 { + return (17693, 17479, 17867); + } + if slot == 45 { + return (18096, 17867, 18253); + } + if slot == 46 { + return (18498, 18253, 18639); + } + if slot == 47 { + return (18900, 18639, 19024); + } + } else { + if slot == 48 { + return (19302, 19024, 19409); + } + if slot == 49 { + return (19704, 19409, 19792); + } + if slot == 50 { + return (20113, 19792, 20175); + } + if slot == 51 { + return (20508, 20175, 20557); + } + if slot == 52 { + return (20910, 20557, 20939); + } + if slot == 53 { + return (21313, 20939, 21320); + } + if slot == 54 { + return (21715, 21320, 21699); + } + if slot == 55 { + return (22117, 21699, 22078); + } + if slot == 56 { + return (22519, 22078, 22457); + } + if slot == 57 { + return (22921, 22457, 22834); + } + if slot == 58 { + return (23323, 22834, 23210); + } + if slot == 59 { + return (23725, 23210, 23586); + } + if slot == 60 { + return (24127, 23586, 23961); + } + if slot == 61 { + return (24530, 23961, 24335); + } + if slot == 62 { + return (24932, 24335, 24708); + } + if slot == 63 { + return (25334, 24708, 25080); + } + } + } + } else { + if slot < 96 { + if slot < 80 { + if slot == 64 { + return (25736, 25080, 25451); + } + if slot == 65 { + return (26138, 25451, 25821); + } + if slot == 66 { + return (26540, 25821, 26190); + } + if slot == 67 { + return (26942, 26190, 26558); + } + if slot == 68 { + return (27344, 26558, 26925); + } + if slot == 69 { + return (27747, 26925, 27291); + } + if slot == 70 { + return (28149, 27291, 27656); + } + if slot == 71 { + return (28551, 27656, 28020); + } + if slot == 72 { + return (28953, 28020, 28383); + } + if slot == 73 { + return (29355, 28383, 28745); + } + if slot == 74 { + return (29757, 28745, 29106); + } + if slot == 75 { + return (30159, 29106, 29466); + } + if slot == 76 { + return (30561, 29466, 29824); + } + if slot == 77 { + return (30964, 29824, 30182); + } + if slot == 78 { + return (31366, 30182, 30538); + } + if slot == 79 { + return (31768, 30538, 30893); + } + } else { + if slot == 80 { + return (32171, 30893, 31248); + } + if slot == 81 { + return (32572, 31248, 31600); + } + if slot == 82 { + return (32974, 31600, 31952); + } + if slot == 83 { + return (33376, 31952, 32303); + } + if slot == 84 { + return (33778, 32303, 32652); + } + if slot == 85 { + return (34181, 32652, 33000); + } + if slot == 86 { + return (34583, 33000, 33347); + } + if slot == 87 { + return (34985, 33347, 33692); + } + if slot == 88 { + return (35387, 33692, 34037); + } + if slot == 89 { + return (35789, 34037, 34380); + } + if slot == 90 { + return (36194, 34380, 34721); + } + if slot == 91 { + return (36593, 34721, 35062); + } + if slot == 92 { + return (36995, 35062, 35401); + } + if slot == 93 { + return (37398, 35401, 35738); + } + if slot == 94 { + return 
(37800, 35738, 36075); + } + if slot == 95 { + return (38202, 36075, 36410); + } + } + } else { + if slot < 112 { + if slot == 96 { + return (38604, 36410, 36744); + } + if slot == 97 { + return (39006, 36744, 37076); + } + if slot == 98 { + return (39408, 37076, 37407); + } + if slot == 99 { + return (39810, 37407, 37736); + } + if slot == 100 { + return (40227, 37736, 38064); + } + if slot == 101 { + return (40615, 38064, 38391); + } + if slot == 102 { + return (41017, 38391, 38716); + } + if slot == 103 { + return (41419, 38716, 39040); + } + if slot == 104 { + return (41821, 39040, 39362); + } + if slot == 105 { + return (42223, 39362, 39683); + } + if slot == 106 { + return (42625, 39683, 40002); + } + if slot == 107 { + return (43027, 40002, 40320); + } + if slot == 108 { + return (43429, 40320, 40636); + } + if slot == 109 { + return (43832, 40636, 40951); + } + if slot == 110 { + return (44234, 40951, 41264); + } + if slot == 111 { + return (44636, 41264, 41576); + } + } else { + if slot == 112 { + return (45038, 41576, 41886); + } + if slot == 113 { + return (45440, 41886, 42194); + } + if slot == 114 { + return (45842, 42194, 42501); + } + if slot == 115 { + return (46244, 42501, 42806); + } + if slot == 116 { + return (46646, 42806, 43110); + } + if slot == 117 { + return (47048, 43110, 43412); + } + if slot == 118 { + return (47451, 43412, 43713); + } + if slot == 119 { + return (47853, 43713, 44011); + } + if slot == 120 { + return (48252, 44011, 44308); + } + if slot == 121 { + return (48657, 44308, 44604); + } + if slot == 122 { + return (49059, 44604, 44898); + } + if slot == 123 { + return (49461, 44898, 45190); + } + if slot == 124 { + return (49863, 45190, 45480); + } + if slot == 125 { + return (50265, 45480, 45769); + } + if slot == 126 { + return (50668, 45769, 46056); + } + if slot == 127 { + return (51070, 46056, 46341); + } + } + } + } + } else { + if slot < 192 { + if slot < 160 { + if slot < 144 { + if slot == 128 { + return (51472, 46341, 46624); + } + if slot == 129 { + return (51874, 46624, 46906); + } + if slot == 130 { + return (52285, 46906, 47186); + } + if slot == 131 { + return (52678, 47186, 47464); + } + if slot == 132 { + return (53080, 47464, 47741); + } + if slot == 133 { + return (53482, 47741, 48015); + } + if slot == 134 { + return (53885, 48015, 48288); + } + if slot == 135 { + return (54287, 48288, 48559); + } + if slot == 136 { + return (54689, 48559, 48828); + } + if slot == 137 { + return (55091, 48828, 49095); + } + if slot == 138 { + return (55493, 49095, 49361); + } + if slot == 139 { + return (55895, 49361, 49624); + } + if slot == 140 { + return (56297, 49624, 49886); + } + if slot == 141 { + return (56699, 49886, 50146); + } + if slot == 142 { + return (57102, 50146, 50404); + } + if slot == 143 { + return (57504, 50404, 50660); + } + } else { + if slot == 144 { + return (57906, 50660, 50914); + } + if slot == 145 { + return (58308, 50914, 51166); + } + if slot == 146 { + return (58710, 51166, 51417); + } + if slot == 147 { + return (59112, 51417, 51665); + } + if slot == 148 { + return (59514, 51665, 51911); + } + if slot == 149 { + return (59916, 51911, 52156); + } + if slot == 150 { + return (60320, 52156, 52398); + } + if slot == 151 { + return (60721, 52398, 52639); + } + if slot == 152 { + return (61123, 52639, 52878); + } + if slot == 153 { + return (61525, 52878, 53114); + } + if slot == 154 { + return (61927, 53114, 53349); + } + if slot == 155 { + return (62329, 53349, 53581); + } + if slot == 156 { + return (62731, 53581, 
53812); + } + if slot == 157 { + return (63133, 53812, 54040); + } + if slot == 158 { + return (63536, 54040, 54267); + } + if slot == 159 { + return (63938, 54267, 54491); + } + if slot == 160 { + return (64343, 54491, 54714); + } + } + } else { + if slot < 176 { + if slot == 161 { + return (64742, 54714, 54934); + } + if slot == 162 { + return (65144, 54934, 55152); + } + if slot == 163 { + return (65546, 55152, 55368); + } + if slot == 164 { + return (65948, 55368, 55582); + } + if slot == 165 { + return (66350, 55582, 55794); + } + if slot == 166 { + return (66753, 55794, 56004); + } + if slot == 167 { + return (67155, 56004, 56212); + } + if slot == 168 { + return (67557, 56212, 56418); + } + if slot == 169 { + return (67959, 56418, 56621); + } + if slot == 170 { + return (68361, 56621, 56823); + } + if slot == 171 { + return (68763, 56823, 57022); + } + if slot == 172 { + return (69165, 57022, 57219); + } + if slot == 173 { + return (69567, 57219, 57414); + } + if slot == 174 { + return (69970, 57414, 57607); + } + if slot == 175 { + return (70372, 57607, 57798); + } + } else { + if slot == 176 { + return (70774, 57798, 57986); + } + if slot == 177 { + return (71176, 57986, 58172); + } + if slot == 178 { + return (71578, 58172, 58356); + } + if slot == 179 { + return (71980, 58356, 58538); + } + if slot == 180 { + return (72382, 58538, 58718); + } + if slot == 181 { + return (72784, 58718, 58896); + } + if slot == 182 { + return (73187, 58896, 59071); + } + if slot == 183 { + return (73589, 59071, 59244); + } + if slot == 184 { + return (73991, 59244, 59415); + } + if slot == 185 { + return (74393, 59415, 59583); + } + if slot == 186 { + return (74795, 59583, 59750); + } + if slot == 187 { + return (75197, 59750, 59914); + } + if slot == 188 { + return (75599, 59914, 60075); + } + if slot == 189 { + return (76001, 60075, 60235); + } + if slot == 190 { + return (76401, 60235, 60392); + } + if slot == 191 { + return (76806, 60392, 60547); + } + } + } + } else { + if slot < 224 { + if slot < 208 { + if slot == 192 { + return (77208, 60547, 60700); + } + if slot == 193 { + return (77610, 60700, 60851); + } + if slot == 194 { + return (78012, 60851, 60999); + } + if slot == 195 { + return (78414, 60999, 61145); + } + if slot == 196 { + return (78816, 61145, 61288); + } + if slot == 197 { + return (79218, 61288, 61429); + } + if slot == 198 { + return (79621, 61429, 61568); + } + if slot == 199 { + return (80023, 61568, 61705); + } + if slot == 200 { + return (80423, 61705, 61839); + } + if slot == 201 { + return (80827, 61839, 61971); + } + if slot == 202 { + return (81229, 61971, 62101); + } + if slot == 203 { + return (81631, 62101, 62228); + } + if slot == 204 { + return (82033, 62228, 62353); + } + if slot == 205 { + return (82435, 62353, 62476); + } + if slot == 206 { + return (82838, 62476, 62596); + } + if slot == 207 { + return (83240, 62596, 62714); + } + } else { + if slot == 208 { + return (83642, 62714, 62830); + } + if slot == 209 { + return (84044, 62830, 62943); + } + if slot == 210 { + return (84446, 62943, 63054); + } + if slot == 211 { + return (84848, 63054, 63162); + } + if slot == 212 { + return (85250, 63162, 63268); + } + if slot == 213 { + return (85652, 63268, 63372); + } + if slot == 214 { + return (86055, 63372, 63473); + } + if slot == 215 { + return (86457, 63473, 63572); + } + if slot == 216 { + return (86859, 63572, 63668); + } + if slot == 217 { + return (87261, 63668, 63763); + } + if slot == 218 { + return (87663, 63763, 63854); + } + if slot == 219 { + 
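+                    // Every branch of this lookup returns (start, low, high): the offset
+                    // (radians scaled by 2^16) where the slot begins, and sin at the
+                    // slot's lower and upper edges. trig::sin_fast interpolates linearly
+                    // between low and high; the quarter period [0, pi/2] is split into
+                    // 256 slots of width ~402 (= pi/2 * 2^16 / 256). For this slot, an
+                    // input of 88200 (~1.346 rad) blends 63854 and 63944 by
+                    // (88200 - 88065) / 402, giving ~63884 ~= sin(1.346) * 2^16.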
return (88065, 63854, 63944); + } + if slot == 220 { + return (88467, 63944, 64031); + } + if slot == 221 { + return (88869, 64031, 64115); + } + if slot == 222 { + return (89271, 64115, 64197); + } + if slot == 223 { + return (89674, 64197, 64277); + } + } + } else { + if slot < 240 { + if slot == 224 { + return (90076, 64277, 64354); + } + if slot == 225 { + return (90478, 64354, 64429); + } + if slot == 226 { + return (90880, 64429, 64501); + } + if slot == 227 { + return (91282, 64501, 64571); + } + if slot == 228 { + return (91684, 64571, 64639); + } + if slot == 229 { + return (92086, 64639, 64704); + } + if slot == 230 { + return (92491, 64704, 64766); + } + if slot == 231 { + return (92891, 64766, 64827); + } + if slot == 232 { + return (93293, 64827, 64884); + } + if slot == 233 { + return (93695, 64884, 64940); + } + if slot == 234 { + return (94097, 64940, 64993); + } + if slot == 235 { + return (94499, 64993, 65043); + } + if slot == 236 { + return (94901, 65043, 65091); + } + if slot == 237 { + return (95303, 65091, 65137); + } + if slot == 238 { + return (95705, 65137, 65180); + } + if slot == 239 { + return (96108, 65180, 65220); + } + } else { + if slot == 240 { + return (96514, 65220, 65259); + } + if slot == 241 { + return (96912, 65259, 65294); + } + if slot == 242 { + return (97314, 65294, 65328); + } + if slot == 243 { + return (97716, 65328, 65358); + } + if slot == 244 { + return (98118, 65358, 65387); + } + if slot == 245 { + return (98520, 65387, 65413); + } + if slot == 246 { + return (98922, 65413, 65436); + } + if slot == 247 { + return (99325, 65436, 65457); + } + if slot == 248 { + return (99727, 65457, 65476); + } + if slot == 249 { + return (100129, 65476, 65492); + } + if slot == 250 { + return (100531, 65492, 65505); + } + if slot == 251 { + return (100933, 65505, 65516); + } + if slot == 252 { + return (101335, 65516, 65525); + } + if slot == 253 { + return (101737, 65525, 65531); + } + if slot == 254 { + return (102139, 65531, 65535); + } + } + } + } + } + + return (102542, 65535, 65536); +} + +fn atan(a: u64) -> (u64, u64, u64) { + let slot = a / 459; + + if slot == 0 { + return (0, 0, 459); + } + if slot == 1 { + return (459, 459, 917); + } + if slot == 2 { + return (918, 917, 1376); + } + if slot == 3 { + return (1376, 1376, 1835); + } + if slot == 4 { + return (1835, 1835, 2293); + } + if slot == 5 { + return (2294, 2293, 2751); + } + if slot == 6 { + return (2753, 2751, 3209); + } + if slot == 7 { + return (3211, 3209, 3666); + } + if slot == 8 { + return (3670, 3666, 4123); + } + if slot == 9 { + return (4129, 4123, 4580); + } + if slot == 10 { + return (4591, 4580, 5036); + } + if slot == 11 { + return (5046, 5036, 5492); + } + if slot == 12 { + return (5505, 5492, 5947); + } + if slot == 13 { + return (5964, 5947, 6402); + } + if slot == 14 { + return (6423, 6402, 6856); + } + if slot == 15 { + return (6881, 6856, 7310); + } + if slot == 16 { + return (7340, 7310, 7762); + } + if slot == 17 { + return (7799, 7762, 8214); + } + if slot == 18 { + return (8258, 8214, 8665); + } + if slot == 19 { + return (8716, 8665, 9116); + } + if slot == 20 { + return (9181, 9116, 9565); + } + if slot == 21 { + return (9634, 9565, 10014); + } + if slot == 22 { + return (10093, 10014, 10462); + } + if slot == 23 { + return (10551, 10462, 10908); + } + if slot == 24 { + return (11010, 10908, 11354); + } + if slot == 25 { + return (11469, 11354, 11798); + } + if slot == 26 { + return (11928, 11798, 12242); + } + if slot == 27 { + return (12386, 12242, 12684); + } + 
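+    // Same (start, low, high) layout as lut::sin, but with atan values at the
+    // slot edges: ~100 slots of width 459 cover [0, ~0.7], since trig::atan_fast
+    // range-reduces any larger input before the lookup. For example, an input of
+    // 0.5 (32768) lands in slot 32768 / 459 = 71 -> (32571, 30228, 30595), and
+    // interpolation yields ~30386, i.e. atan(0.5) ~= 0.4636 in Q16.16.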
if slot == 28 { + return (12845, 12684, 13125); + } + if slot == 29 { + return (13304, 13125, 13565); + } + if slot == 30 { + return (13762, 13565, 14004); + } + if slot == 31 { + return (14221, 14004, 14442); + } + if slot == 32 { + return (14680, 14442, 14878); + } + if slot == 33 { + return (15139, 14878, 15313); + } + if slot == 34 { + return (15598, 15313, 15746); + } + if slot == 35 { + return (16056, 15746, 16178); + } + if slot == 36 { + return (16515, 16178, 16609); + } + if slot == 37 { + return (16974, 16609, 17038); + } + if slot == 38 { + return (17433, 17038, 17466); + } + if slot == 39 { + return (17891, 17466, 17892); + } + if slot == 40 { + return (18353, 17892, 18317); + } + if slot == 41 { + return (18809, 18317, 18740); + } + if slot == 42 { + return (19268, 18740, 19161); + } + if slot == 43 { + return (19726, 19161, 19581); + } + if slot == 44 { + return (20185, 19581, 19999); + } + if slot == 45 { + return (20644, 19999, 20416); + } + if slot == 46 { + return (21103, 20416, 20830); + } + if slot == 47 { + return (21561, 20830, 21243); + } + if slot == 48 { + return (22020, 21243, 21655); + } + if slot == 49 { + return (22479, 21655, 22064); + } + if slot == 50 { + return (22944, 22064, 22472); + } + if slot == 51 { + return (23396, 22472, 22878); + } + if slot == 52 { + return (23855, 22878, 23282); + } + if slot == 53 { + return (24314, 23282, 23685); + } + if slot == 54 { + return (24773, 23685, 24085); + } + if slot == 55 { + return (25231, 24085, 24484); + } + if slot == 56 { + return (25690, 24484, 24880); + } + if slot == 57 { + return (26149, 24880, 25275); + } + if slot == 58 { + return (26608, 25275, 25668); + } + if slot == 59 { + return (27066, 25668, 26059); + } + if slot == 60 { + return (27534, 26059, 26448); + } + if slot == 61 { + return (27984, 26448, 26835); + } + if slot == 62 { + return (28443, 26835, 27220); + } + if slot == 63 { + return (28901, 27220, 27603); + } + if slot == 64 { + return (29360, 27603, 27984); + } + if slot == 65 { + return (29819, 27984, 28363); + } + if slot == 66 { + return (30278, 28363, 28740); + } + if slot == 67 { + return (30736, 28740, 29115); + } + if slot == 68 { + return (31195, 29115, 29488); + } + if slot == 69 { + return (31654, 29488, 29859); + } + if slot == 70 { + return (32113, 29859, 30228); + } + if slot == 71 { + return (32571, 30228, 30595); + } + if slot == 72 { + return (33030, 30595, 30960); + } + if slot == 73 { + return (33489, 30960, 31323); + } + if slot == 74 { + return (33948, 31323, 31683); + } + if slot == 75 { + return (34406, 31683, 32042); + } + if slot == 76 { + return (34865, 32042, 32398); + } + if slot == 77 { + return (35324, 32398, 32753); + } + if slot == 78 { + return (35783, 32753, 33105); + } + if slot == 79 { + return (36241, 33105, 33455); + } + if slot == 80 { + return (36700, 33455, 33804); + } + if slot == 81 { + return (37159, 33804, 34150); + } + if slot == 82 { + return (37618, 34150, 34494); + } + if slot == 83 { + return (38076, 34494, 34836); + } + if slot == 84 { + return (38535, 34836, 35175); + } + if slot == 85 { + return (38994, 35175, 35513); + } + if slot == 86 { + return (39453, 35513, 35849); + } + if slot == 87 { + return (39911, 35849, 36183); + } + if slot == 88 { + return (40370, 36183, 36514); + } + if slot == 89 { + return (40829, 36514, 36843); + } + if slot == 90 { + return (41288, 36843, 37171); + } + if slot == 91 { + return (41746, 37171, 37496); + } + if slot == 92 { + return (42205, 37496, 37819); + } + if slot == 93 { + return (42664, 37819, 
38141); + } + if slot == 94 { + return (43123, 38141, 38460); + } + if slot == 95 { + return (43581, 38460, 38777); + } + if slot == 96 { + return (44040, 38777, 39092); + } + if slot == 97 { + return (44499, 39092, 39405); + } + if slot == 98 { + return (44958, 39405, 39716); + } + + return (45416, 39716, 40025); +} diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/trig.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/trig.cairo new file mode 100644 index 000000000..4c47eca5e --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/trig.cairo @@ -0,0 +1,450 @@ +use debug::PrintTrait; +use integer::{u64_safe_divmod, u64_as_non_zero}; +use option::OptionTrait; + +use orion::numbers::fixed_point::implementations::fp16x16wide::math::lut; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + HALF, ONE, TWO, FP16x16W, FP16x16WImpl, FP16x16WAdd, FP16x16WSub, FP16x16WMul, FP16x16WDiv, + FP16x16WIntoFelt252, FixedTrait +}; + +// CONSTANTS + +const TWO_PI: u64 = 411775; +const PI: u64 = 205887; +const HALF_PI: u64 = 102944; + +// PUBLIC + +// Calculates arccos(a) for -1 <= a <= 1 (fixed point) +// arccos(a) = arcsin(sqrt(1 - a^2)) - arctan identity has discontinuity at zero +fn acos(a: FP16x16W) -> FP16x16W { + let asin_arg = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + let asin_res = asin(asin_arg); + + if (a.sign) { + return FixedTrait::new(PI, false) - asin_res; + } else { + return asin_res; + } +} + +fn acos_fast(a: FP16x16W) -> FP16x16W { + let asin_arg = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + let asin_res = asin_fast(asin_arg); + + if (a.sign) { + return FixedTrait::new(PI, false) - asin_res; + } else { + return asin_res; + } +} + +// Calculates arcsin(a) for -1 <= a <= 1 (fixed point) +// arcsin(a) = arctan(a / sqrt(1 - a^2)) +fn asin(a: FP16x16W) -> FP16x16W { + if (a.mag == ONE) { + return FixedTrait::new(HALF_PI, a.sign); + } + + let div = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + return atan(a / div); +} + +fn asin_fast(a: FP16x16W) -> FP16x16W { + if (a.mag == ONE) { + return FixedTrait::new(HALF_PI, a.sign); + } + + let div = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + return atan_fast(a / div); +} + +// Calculates arctan(a) (fixed point) +// See https://stackoverflow.com/a/50894477 for range adjustments +fn atan(a: FP16x16W) -> FP16x16W { + let mut at = a.abs(); + let mut shift = false; + let mut invert = false; + + // Invert value when a > 1 + if (at.mag > ONE) { + at = FixedTrait::ONE() / at; + invert = true; + } + + // Account for lack of precision in polynomaial when a > 0.7 + if (at.mag > 45875) { + let sqrt3_3 = FixedTrait::new(37837, false); // sqrt(3) / 3 + at = (at - sqrt3_3) / (FixedTrait::ONE() + at * sqrt3_3); + shift = true; + } + + let r10 = FixedTrait::new(120, true) * at; + let r9 = (r10 + FixedTrait::new(3066, true)) * at; + let r8 = (r9 + FixedTrait::new(12727, false)) * at; + let r7 = (r8 + FixedTrait::new(17170, true)) * at; + let r6 = (r7 + FixedTrait::new(2865, false)) * at; + let r5 = (r6 + FixedTrait::new(12456, false)) * at; + let r4 = (r5 + FixedTrait::new(90, false)) * at; + let r3 = (r4 + FixedTrait::new(21852, true)) * at; + let r2 = r3 * at; + let mut res = (r2 + FixedTrait::new(65536, false)) * at; + + // Adjust for sign change, inversion, and shift + if (shift) { + res = res + FixedTrait::new(34315, false); // pi / 6 + } + + if (invert) { + res = res - FixedTrait::new(HALF_PI, false); + } + + return 
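+    // atan is odd, so the sign of the input is simply reattached to the
+    // magnitude computed for |a|. The r10..r2 chain above is a Horner-style
+    // evaluation of a degree-10 polynomial that is only accurate on the reduced
+    // range: `invert` exploits atan(a) = pi/2 - atan(1/a) for a > 1, and
+    // `shift` exploits atan(a) = pi/6 + atan((a - t) / (1 + a*t)) with
+    // t = sqrt(3)/3 for a > 0.7.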
FixedTrait::new(res.mag, a.sign); +} + + +fn atan_fast(a: FP16x16W) -> FP16x16W { + let mut at = a.abs(); + let mut shift = false; + let mut invert = false; + + // Invert value when a > 1 + if (at.mag > ONE) { + at = FixedTrait::ONE() / at; + invert = true; + } + + // Account for lack of precision in polynomaial when a > 0.7 + if (at.mag > 45875) { + let sqrt3_3 = FixedTrait::new(37837, false); // sqrt(3) / 3 + at = (at - sqrt3_3) / (FixedTrait::ONE() + at * sqrt3_3); + shift = true; + } + + let (start, low, high) = lut::atan(at.mag); + let partial_step = FixedTrait::new(at.mag - start, false) / FixedTrait::new(459, false); + let mut res = partial_step * FixedTrait::new(high - low, false) + FixedTrait::new(low, false); + + // Adjust for sign change, inversion, and shift + if (shift) { + res = res + FixedTrait::new(34315, false); // pi / 6 + } + + if (invert) { + res = res - FixedTrait::::new(HALF_PI, false); + } + + return FixedTrait::new(res.mag, a.sign); +} + +// Calculates cos(a) with a in radians (fixed point) +fn cos(a: FP16x16W) -> FP16x16W { + return sin(FixedTrait::new(HALF_PI, false) - a); +} + +fn cos_fast(a: FP16x16W) -> FP16x16W { + return sin_fast(FixedTrait::new(HALF_PI, false) - a); +} + +fn sin(a: FP16x16W) -> FP16x16W { + let a1 = a.mag % TWO_PI; + let (whole_rem, partial_rem) = u64_safe_divmod(a1, u64_as_non_zero(PI)); + let a2 = FixedTrait::new(partial_rem, false); + let partial_sign = whole_rem == 1; + + let loop_res = a2 * _sin_loop(a2, 7, FixedTrait::ONE()); + return FixedTrait::new(loop_res.mag, a.sign ^ partial_sign && loop_res.mag != 0); +} + +fn sin_fast(a: FP16x16W) -> FP16x16W { + let a1 = a.mag % TWO_PI; + let (whole_rem, mut partial_rem) = u64_safe_divmod(a1, u64_as_non_zero(PI)); + let partial_sign = whole_rem == 1; + + if partial_rem >= HALF_PI { + partial_rem = PI - partial_rem; + } + + let (start, low, high) = lut::sin(partial_rem); + let partial_step = FixedTrait::new(partial_rem - start, false) / FixedTrait::new(402, false); + let res = partial_step * (FixedTrait::new(high, false) - FixedTrait::new(low, false)) + + FixedTrait::::new(low, false); + + return FixedTrait::new(res.mag, a.sign ^ partial_sign && res.mag != 0); +} + +// Calculates tan(a) with a in radians (fixed point) +fn tan(a: FP16x16W) -> FP16x16W { + let sinx = sin(a); + let cosx = cos(a); + assert(cosx.mag != 0, 'tan undefined'); + return sinx / cosx; +} + +fn tan_fast(a: FP16x16W) -> FP16x16W { + let sinx = sin_fast(a); + let cosx = cos_fast(a); + assert(cosx.mag != 0, 'tan undefined'); + return sinx / cosx; +} + +// Helper function to calculate Taylor series for sin +fn _sin_loop(a: FP16x16W, i: u64, acc: FP16x16W) -> FP16x16W { + let div = (2 * i + 2) * (2 * i + 3); + let term = a * a * acc / FixedTrait::new_unscaled(div, false); + let new_acc = FixedTrait::ONE() - term; + + if (i == 0) { + return new_acc; + } + + return _sin_loop(a, i - 1, new_acc); +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp16x16wide::helpers::{ + assert_precise, assert_relative +}; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{FP16x16WPartialEq, FP16x16WPrint}; + +#[test] +#[available_gas(8000000)] +fn test_acos() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::ONE(); + assert(acos(a).into() == 0, 'invalid one'); + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(acos(a), 68629, 'invalid half', error); // 
1.3687308642680 + + let a = FixedTrait::ZERO(); + assert_relative(acos(a), HALF_PI.into(), 'invalid zero', Option::None(())); // PI / 2 + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(acos(a), 137258, 'invalid neg half', error); // 2.737461741902 + + let a = FixedTrait::new(ONE, true); + assert_relative(acos(a), PI.into(), 'invalid neg one', Option::None(())); // PI +} + +#[test] +#[available_gas(8000000)] +fn test_acos_fast() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::ONE(); + assert(acos_fast(a).into() == 0, 'invalid one'); + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(acos_fast(a), 68629, 'invalid half', error); // 1.3687308642680 + + let a = FixedTrait::ZERO(); + assert_relative(acos_fast(a), HALF_PI.into(), 'invalid zero', Option::None(())); // PI / 2 + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(acos_fast(a), 137258, 'invalid neg half', error); // 2.737461741902 + + let a = FixedTrait::new(ONE, true); + assert_relative(acos_fast(a), PI.into(), 'invalid neg one', Option::None(())); // PI +} + +#[test] +#[should_panic] +#[available_gas(8000000)] +fn test_acos_fail() { + let a = FixedTrait::new(2 * ONE, true); + acos(a); +} + +#[test] +#[available_gas(8000000)] +fn test_atan_fast() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::new(2 * ONE, false); + assert_relative(atan_fast(a), 72558, 'invalid two', error); + + let a = FixedTrait::ONE(); + assert_relative(atan_fast(a), 51472, 'invalid one', error); + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(atan_fast(a), 30386, 'invalid half', error); + + let a = FixedTrait::ZERO(); + assert(atan_fast(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(atan_fast(a), -30386, 'invalid neg half', error); + + let a = FixedTrait::new(ONE, true); + assert_relative(atan_fast(a), -51472, 'invalid neg one', error); + + let a = FixedTrait::new(2 * ONE, true); + assert_relative(atan_fast(a), -72558, 'invalid neg two', error); +} + +#[test] +#[available_gas(8000000)] +fn test_atan() { + let a = FixedTrait::new(2 * ONE, false); + assert_relative(atan(a), 72558, 'invalid two', Option::None(())); + + let a = FixedTrait::ONE(); + assert_relative(atan(a), 51472, 'invalid one', Option::None(())); + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(atan(a), 30386, 'invalid half', Option::None(())); + + let a = FixedTrait::ZERO(); + assert(atan(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(atan(a), -30386, 'invalid neg half', Option::None(())); + + let a = FixedTrait::new(ONE, true); + assert_relative(atan(a), -51472, 'invalid neg one', Option::None(())); + + let a = FixedTrait::new(2 * ONE, true); + assert_relative(atan(a), -72558, 'invalid neg two', Option::None(())); +} + +#[test] +#[available_gas(8000000)] +fn test_asin() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::ONE(); + assert_relative(asin(a), HALF_PI.into(), 'invalid one', Option::None(())); // PI / 2 + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(asin(a), 34315, 'invalid half', error); + + let a = FixedTrait::ZERO(); + assert_precise(asin(a), 0, 'invalid zero', Option::None(())); + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(asin(a), -34315, 'invalid neg half', error); + + let a = FixedTrait::new(ONE, true); + assert_relative(asin(a), -HALF_PI.into(), 'invalid neg one', Option::None(())); // -PI / 2 +} + +#[test] +#[should_panic] 
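+// asin is only defined for |a| <= 1: with a = 2, the term 1 - a * a is
+// negative, so the sqrt inside asin panics, which is exactly what this
+// test expects.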
+#[available_gas(8000000)] +fn test_asin_fail() { + let a = FixedTrait::new(2 * ONE, false); + asin(a); +} + +#[test] +#[available_gas(8000000)] +fn test_cos() { + let a = FixedTrait::new(HALF_PI, false); + assert(cos(a).into() == 0, 'invalid half pi'); + + let a = FixedTrait::new(HALF_PI / 2, false); + assert_relative(cos(a), 46341, 'invalid quarter pi', Option::None(())); // 0.55242717280199 + + let a = FixedTrait::new(PI, false); + assert_relative(cos(a), -1 * ONE.into(), 'invalid pi', Option::None(())); + + let a = FixedTrait::new(HALF_PI, true); + assert_precise(cos(a), 0, 'invalid neg half pi', Option::None(())); + + let a = FixedTrait::new_unscaled(17, false); + assert_relative(cos(a), -18033, 'invalid 17', Option::None(())); // -0.21497123284870 + + let a = FixedTrait::new_unscaled(17, true); + assert_relative(cos(a), -18033, 'invalid -17', Option::None(())); // -0.21497123284870 +} + +#[test] +#[available_gas(8000000)] +fn test_cos_fast() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::new(HALF_PI, false); + assert(cos_fast(a).into() == 0, 'invalid half pi'); + + let a = FixedTrait::new(HALF_PI / 2, false); + assert_precise(cos_fast(a), 46341, 'invalid quarter pi', error); // 0.55242717280199 + + let a = FixedTrait::new(PI, false); + assert_precise(cos_fast(a), -1 * ONE.into(), 'invalid pi', error); + + let a = FixedTrait::new(HALF_PI, true); + assert_precise(cos(a), 0, 'invalid neg half pi', Option::None(())); + + let a = FixedTrait::new_unscaled(17, false); + assert_precise(cos_fast(a), -18033, 'invalid 17', error); // -0.21497123284870 +} + +#[test] +#[available_gas(8000000)] +fn test_sin() { + let a = FixedTrait::new(HALF_PI, false); + assert_precise(sin(a), ONE.into(), 'invalid half pi', Option::None(())); + + let a = FixedTrait::new(HALF_PI / 2, false); + assert_precise(sin(a), 46341, 'invalid quarter pi', Option::None(())); // 0.55242717280199 + + let a = FixedTrait::new(PI, false); + assert(sin(a).into() == 0, 'invalid pi'); + + let a = FixedTrait::new(HALF_PI, true); + assert_precise( + sin(a), -ONE.into(), 'invalid neg half pi', Option::None(()) + ); // 0.78124999999529 + + let a = FixedTrait::new_unscaled(17, false); + assert_precise(sin(a), -63006, 'invalid 17', Option::None(())); // -0.75109179053073 + + let a = FixedTrait::new_unscaled(17, true); + assert_precise(sin(a), 63006, 'invalid -17', Option::None(())); // 0.75109179053073 +} + +#[test] +#[available_gas(8000000)] +fn test_sin_fast() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::new(HALF_PI, false); + assert_precise(sin_fast(a), ONE.into(), 'invalid half pi', error); + + let a = FixedTrait::new(HALF_PI / 2, false); + assert_precise(sin_fast(a), 46341, 'invalid quarter pi', error); // 0.55242717280199 + + let a = FixedTrait::new(PI, false); + assert(sin_fast(a).into() == 0, 'invalid pi'); + + let a = FixedTrait::new(HALF_PI, true); + assert_precise(sin_fast(a), -ONE.into(), 'invalid neg half pi', error); // 0.78124999999529 + + let a = FixedTrait::new_unscaled(17, false); + assert_precise(sin_fast(a), -63006, 'invalid 17', error); // -0.75109179053073 + + let a = FixedTrait::new_unscaled(17, true); + assert_precise(sin_fast(a), 63006, 'invalid -17', error); // 0.75109179053073 +} + +#[test] +#[available_gas(8000000)] +fn test_tan() { + let a = FixedTrait::new(HALF_PI / 2, false); + assert_precise(tan(a), ONE.into(), 'invalid quarter pi', Option::None(())); + + let a = FixedTrait::new(PI, false); + assert_precise(tan(a), 0, 'invalid pi', Option::None(())); + + let a = 
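+    // 17 rad exercises range reduction: tan has period pi and 17 - 5 * pi
+    // ~= 1.2920, so tan(17) ~= 3.4939, matching the 228990 / 2^16 asserted below.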
FixedTrait::new_unscaled(17, false); + assert_precise(tan(a), 228990, 'invalid 17', Option::None(())); // 3.3858731852805 + + let a = FixedTrait::new_unscaled(17, true); + assert_precise(tan(a), -228952, 'invalid -17', Option::None(())); // -3.3858731852805 +} From 27f3b6b5f3ab60225329a88411310900450d8535 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 10:51:37 +0300 Subject: [PATCH 11/42] add convertors --- .../implementations/fp16x16wide/core.cairo | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo index 01a1d8b8d..f12b96d9b 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/core.cairo @@ -5,7 +5,7 @@ use result::{ResultTrait, ResultTraitImpl}; use traits::{TryInto, Into}; use orion::numbers::signed_integer::{i32::i32, i8::i8}; -use orion::numbers::fixed_point::core::FixedTrait; +use orion::numbers::{fixed_point::core::FixedTrait, FP16x16}; use orion::numbers::fixed_point::implementations::fp16x16wide::math::{core, trig, hyp}; use orion::numbers::fixed_point::utils; @@ -211,6 +211,25 @@ impl FP16x16WIntoI32 of Into { } } +impl FP16x16IntoFP16x16W of Into { + fn into(self: FP16x16) -> FP16x16W { + FP16x16W { mag: self.mag.into(), sign: self.sign } + } +} + +impl FP16x16WTryIntoFP16x16 of TryInto { + fn try_into(self: FP16x16W) -> Option { + match self.mag.try_into() { + Option::Some(val) => { + Option::Some(FP16x16 { mag: val, sign: self.sign }) + }, + Option::None(_) => { + Option::None(()) + } + } + } +} + impl FP16x16WTryIntoI8 of TryInto { fn try_into(self: FP16x16W) -> Option { _i8_try_from_fp(self) From 53181cfad4a0250e8e97582c91a55972a96489fd Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 11:03:50 +0300 Subject: [PATCH 12/42] implement FP16x16WTensor --- src/numbers.cairo | 165 ++++++++ src/operators/tensor/implementations.cairo | 1 + .../implementations/tensor_fp16x16wide.cairo | 361 ++++++++++++++++++ 3 files changed, 527 insertions(+) create mode 100644 src/operators/tensor/implementations/tensor_fp16x16wide.cairo diff --git a/src/numbers.cairo b/src/numbers.cairo index 04ad6efa5..02dd5b344 100644 --- a/src/numbers.cairo +++ b/src/numbers.cairo @@ -378,6 +378,171 @@ impl FP16x16Number of NumberTrait { } } +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{FP16x16WImpl, FP16x16W}; +use orion::numbers::fixed_point::implementations::fp16x16wide::math::core as core_fp16x16wide; +use orion::numbers::fixed_point::implementations::fp16x16wide::math::comp as comp_fp16x16wide; + +impl FP16x16WNumber of NumberTrait { + fn new(mag: u64, sign: bool) -> FP16x16W { + FP16x16WImpl::new(mag, sign) + } + + fn new_unscaled(mag: u64, sign: bool) -> FP16x16W { + FP16x16WImpl::new_unscaled(mag, sign) + } + + fn from_felt(val: felt252) -> FP16x16W { + FP16x16WImpl::from_felt(val) + } + + fn ceil(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::ceil(self) + } + + fn exp(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::exp(self) + } + + fn exp2(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::exp2(self) + } + + fn floor(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::floor(self) + } + + fn ln(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::ln(self) + } + + fn log2(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::log2(self) + } + + fn log10(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::log10(self) + } + + fn pow(self: 
FP16x16W, b: FP16x16W) -> FP16x16W { + FP16x16WImpl::pow(self, b) + } + + fn round(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::round(self) + } + + fn sqrt(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::sqrt(self) + } + + fn acos(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::acos(self) + } + + fn asin(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::asin(self) + } + + fn atan(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::atan(self) + } + + fn cos(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::cos(self) + } + + fn sin(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::sin(self) + } + + fn tan(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::tan(self) + } + + fn acosh(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::acosh(self) + } + + fn asinh(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::asinh(self) + } + + fn atanh(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::atanh(self) + } + + fn cosh(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::cosh(self) + } + + fn sinh(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::sinh(self) + } + + fn tanh(self: FP16x16W) -> FP16x16W { + FP16x16WImpl::tanh(self) + } + + fn zero() -> FP16x16W { + FP16x16WImpl::ZERO() + } + fn is_zero(self: FP16x16W) -> bool { + core_fp16x16wide::eq(@self, @FP16x16WImpl::ZERO()) + } + + fn one() -> FP16x16W { + FP16x16WImpl::ONE() + } + + fn neg_one() -> FP16x16W { + FP16x16W { mag: core_fp16x16wide::ONE, sign: true } + } + + fn is_one(self: FP16x16W) -> bool { + core_fp16x16wide::eq(@self, @FP16x16WImpl::ONE()) + } + + fn abs(self: FP16x16W) -> FP16x16W { + core_fp16x16wide::abs(self) + } + + fn min_value() -> FP16x16W { + FP16x16W { mag: core_fp16x16wide::MAX, sign: true } + } + + fn max_value() -> FP16x16W { + FP16x16W { mag: core_fp16x16wide::MAX, sign: false } + } + + fn min(self: FP16x16W, other: FP16x16W) -> FP16x16W { + comp_fp16x16wide::min(self, other) + } + + fn max(self: FP16x16W, other: FP16x16W) -> FP16x16W { + comp_fp16x16wide::max(self, other) + } + + fn mag(self: FP16x16W) -> u64 { + self.mag + } + + fn is_neg(self: FP16x16W) -> bool { + self.sign + } + + fn xor(lhs: FP16x16W, rhs: FP16x16W) -> bool { + comp_fp16x16wide::xor(lhs, rhs) + } + + fn or(lhs: FP16x16W, rhs: FP16x16W) -> bool { + comp_fp16x16wide::or(lhs, rhs) + } + + fn sign(self: FP16x16W) -> FP16x16W { + core_fp16x16wide::sign(self) + } +} + use orion::numbers::fixed_point::implementations::fp64x64::core::{FP64x64Impl, FP64x64}; use orion::numbers::fixed_point::implementations::fp64x64::core as core_fp64x64; use orion::numbers::fixed_point::implementations::fp64x64::comp as comp_fp64x64; diff --git a/src/operators/tensor/implementations.cairo b/src/operators/tensor/implementations.cairo index a585b88a7..0df3dcdec 100644 --- a/src/operators/tensor/implementations.cairo +++ b/src/operators/tensor/implementations.cairo @@ -5,3 +5,4 @@ mod tensor_fp8x23; mod tensor_fp16x16; mod tensor_fp64x64; mod tensor_fp32x32; +mod tensor_fp16x16wide; \ No newline at end of file diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo new file mode 100644 index 000000000..0a89fe72d --- /dev/null +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -0,0 +1,361 @@ +use array::ArrayTrait; +use array::SpanTrait; +use option::OptionTrait; +use traits::{TryInto, Into}; + +use orion::numbers::fixed_point::core::FixedTrait; +use orion::operators::tensor::core::{ + new_tensor, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, +}; +use orion::operators::tensor::{math, 
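+// FP16x16W keeps the same Q16.16 interpretation as FP16x16 but backs the
+// magnitude with a u64 instead of a u32, so intermediate results (e.g. summed
+// exponentials) get extra headroom; widening via Into is lossless, while
+// narrowing back via TryInto fails if the magnitude no longer fits. This impl
+// mirrors FP16x16Tensor on the wide type.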
linalg, quantization, core}; +use orion::numbers::{i8, i32, NumberTrait, FP16x16W}; +use orion::operators::tensor::implementations::{tensor_i8::I8Tensor, tensor_u32::U32Tensor}; + +impl FP16x16WTensor of TensorTrait { + fn new(shape: Span, data: Span) -> Tensor { + new_tensor(shape, data) + } + + fn at(self: @Tensor, indices: Span) -> FP16x16W { + *at_tensor(self, indices) + } + + fn min(self: @Tensor) -> FP16x16W { + math::min::min_in_tensor::(*self.data) + } + + fn max(self: @Tensor) -> FP16x16W { + math::max::max_in_tensor(*self.data) + } + + fn stride(self: @Tensor) -> Span { + stride(*self.shape) + } + + fn ravel_index(self: @Tensor, indices: Span) -> usize { + ravel_index(*self.shape, indices) + } + + fn unravel_index(self: @Tensor, index: usize) -> Span { + unravel_index(index, *self.shape) + } + + fn reshape(self: @Tensor, target_shape: Span) -> Tensor { + reshape(self, target_shape) + } + + fn reduce_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_sum::reduce_sum(self, axis, keepdims) + } + + fn argmax( + self: @Tensor, + axis: usize, + keepdims: Option, + select_last_index: Option + ) -> Tensor { + math::argmax::argmax(self, axis, keepdims, select_last_index) + } + + fn argmin( + self: @Tensor, + axis: usize, + keepdims: Option, + select_last_index: Option + ) -> Tensor { + math::argmin::argmin(self, axis, keepdims, select_last_index) + } + + fn transpose(self: @Tensor, axes: Span) -> Tensor { + linalg::transpose::transpose(self, axes) + } + + fn matmul(self: @Tensor, other: @Tensor) -> Tensor { + linalg::matmul::matmul(self, other) + } + + fn exp(self: @Tensor) -> Tensor { + math::exp::exp(*self) + } + + fn log(self: @Tensor) -> Tensor { + math::log::log(*self) + } + + fn equal(self: @Tensor, other: @Tensor) -> Tensor { + math::equal::equal(self, other) + } + + fn greater(self: @Tensor, other: @Tensor) -> Tensor { + math::greater::greater(self, other) + } + + fn greater_equal(self: @Tensor, other: @Tensor) -> Tensor { + math::greater_equal::greater_equal(self, other) + } + + fn less(self: @Tensor, other: @Tensor) -> Tensor { + math::less::less(self, other) + } + + fn less_equal(self: @Tensor, other: @Tensor) -> Tensor { + math::less_equal::less_equal(self, other) + } + + fn abs(self: @Tensor) -> Tensor { + math::abs::abs(*self) + } + + fn ceil(self: @Tensor) -> Tensor { + math::ceil::ceil(*self) + } + + fn sin(self: @Tensor) -> Tensor { + math::sin::sin(*self) + } + + fn cos(self: @Tensor) -> Tensor { + math::cos::cos(*self) + } + + fn asin(self: @Tensor) -> Tensor { + math::asin::asin(*self) + } + + fn cumsum( + self: @Tensor, axis: usize, exclusive: Option, reverse: Option + ) -> Tensor { + math::cumsum::cumsum(self, axis, exclusive, reverse) + } + + fn flatten(self: @Tensor, axis: usize) -> Tensor { + math::flatten::flatten(self, axis) + } + + fn sinh(self: @Tensor) -> Tensor { + math::sinh::sinh(*self) + } + + fn tanh(self: @Tensor) -> Tensor { + math::tanh::tanh(*self) + } + + fn cosh(self: @Tensor) -> Tensor { + math::cosh::cosh(*self) + } + + fn acosh(self: @Tensor) -> Tensor { + math::acosh::acosh(*self) + } + + fn asinh(self: @Tensor) -> Tensor { + math::asinh::asinh(*self) + } + + fn atan(self: @Tensor) -> Tensor { + math::atan::atan(*self) + } + + fn xor(self: @Tensor, other: @Tensor) -> Tensor { + math::xor::xor(self, other) + } + + fn or(self: @Tensor, other: @Tensor) -> Tensor { + math::or::or(self, other) + } + + fn acos(self: @Tensor) -> Tensor { + math::acos::acos(*self) + } + + fn onehot( + self: @Tensor, depth: usize, axis: Option, 
values: Span + ) -> Tensor { + panic(array!['not supported!']) + } + + fn sqrt(self: @Tensor) -> Tensor { + math::sqrt::sqrt(*self) + } + + fn concat(tensors: Span>, axis: usize,) -> Tensor { + math::concat::concat(tensors, axis) + } + + fn quantize_linear( + self: @Tensor, y_scale: @Tensor, y_zero_point: @Tensor + ) -> Tensor:: { + quantization::quantize_linear::quantize_linear( + self, + y_scale, + y_zero_point, + NumberTrait::new_unscaled(128, true), + NumberTrait::new_unscaled(127, false) + ) + } + + fn dequantize_linear( + self: @Tensor, x_scale: @Tensor, x_zero_point: @Tensor + ) -> Tensor:: { + panic(array!['not supported!']) + } + + fn slice( + self: @Tensor, + starts: Span, + ends: Span, + axes: Option>, + steps: Option> + ) -> Tensor { + core::slice::(self, starts, ends, axes, steps) + } + + fn gather( + self: @Tensor, indices: Tensor, axis: Option + ) -> Tensor { + math::gather::gather(self, indices, axis) + } + + fn nonzero(self: @Tensor) -> Tensor { + core::nonzero(self) + } + + fn squeeze(self: @Tensor, axes: Option>) -> Tensor { + core::squeeze(self, axes) + } + + fn unsqueeze(self: @Tensor, axes: Span) -> Tensor { + core::unsqueeze(self, axes) + } + + fn sign(self: @Tensor) -> Tensor { + math::sign::sign(*self) + } + + fn clip( + self: @Tensor, min: Option, max: Option + ) -> Tensor { + core::clip(self, min, max) + } +} + +/// Implements addition for `Tensor` using the `Add` trait. +impl FP16x16WTensorAdd of Add> { + /// Adds two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise addition. + fn add(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::add(@lhs, @rhs) + } +} + +/// Implements subtraction for `Tensor` using the `Sub` trait. +impl FP16x16WTensorSub of Sub> { + /// Subtracts two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise subtraction. + fn sub(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::sub(@lhs, @rhs) + } +} + +/// Implements multiplication for `Tensor` using the `Mul` trait. +impl FP16x16WTensorMul of Mul> { + /// Multiplies two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise multiplication. + fn mul(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::mul(@lhs, @rhs) + } +} + +/// Implements division for `Tensor` using the `Div` trait. +impl FP16x16WTensorDiv of Div> { + /// Divides two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise division. + fn div(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::div(@lhs, @rhs) + } +} + +/// Implements partial equal for two `Tensor` using the `PartialEq` trait. 
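+/// Note that equality is approximate: `tensor_eq` compares elements through
+/// `relative_eq`, which tolerates a relative difference of up to
+/// PRECISION / 2^16 = 589 / 65536 ~= 0.9% (see the internals below).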
+impl FP16x16WTensorPartialEq of PartialEq> { + fn eq(lhs: @Tensor, rhs: @Tensor) -> bool { + tensor_eq(*lhs, *rhs) + } + + fn ne(lhs: @Tensor, rhs: @Tensor) -> bool { + !tensor_eq(*lhs, *rhs) + } +} + +impl U32TryIntoU32 of TryInto { + fn try_into(self: u32) -> Option { + Option::Some(self) + } +} + + +// Internals +const PRECISION: u64 = 589; // 0.009 + +fn relative_eq(lhs: @FP16x16W, rhs: @FP16x16W) -> bool { + let diff = *lhs - *rhs; + + let rel_diff = if *lhs.mag != 0 { + (diff / *lhs).mag + } else { + diff.mag + }; + + rel_diff <= PRECISION +} + + +fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { + let mut is_eq = true; + + loop { + if lhs.shape.len() == 0 || !is_eq { + break; + } + + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; + + if !is_eq { + return false; + } + + loop { + if lhs.data.len() == 0 || !is_eq { + break; + } + + is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; + + return is_eq; +} + From fc98e829f7f3c31e1e9e19b7b95d97a74eea68af Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 11:36:03 +0300 Subject: [PATCH 13/42] implement softmaxWide --- src/operators/nn/functional/softmax.cairo | 30 +++++++++++++- .../nn/implementations/nn_fp16x16.cairo | 8 +++- src/operators/tensor/math/arithmetic.cairo | 41 +++++++++++++++++++ src/operators/tensor/math/exp.cairo | 34 +++++++++++++++ 4 files changed, 111 insertions(+), 2 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 528856265..fdbb7054f 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -1,5 +1,6 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; - +use orion::operators::tensor::math::{exp::exp_upcast, arithmetic::div_downcast}; +use orion::numbers::fixed_point::core::FixedTrait; /// Cf: NNTrait::softmax docstring fn softmax< @@ -19,3 +20,30 @@ fn softmax< return softmax; } +/// Cf: NNTrait::softmax docstring +fn softmaxWide< + T, + TMAG, + W, + WMAG, + impl TTensor: TensorTrait, + impl WTensor: TensorTrait, + impl TDiv: Div, + impl TIntoW: Into, + impl WTryIntoT: TryInto, + impl TCopy: Copy, + impl TDrop: Drop, + impl WCopy: Copy, + impl WDrop: Drop, + impl TFixed: FixedTrait, + impl WFixed: FixedTrait, +>( + z: @Tensor, axis: usize +) -> Tensor { + let exp_tensor: Tensor = exp_upcast(*z); + let sum = exp_tensor.reduce_sum(axis, true); + let softmax: Tensor = div_downcast(@exp_tensor, @sum); + + return softmax; +} + diff --git a/src/operators/nn/implementations/nn_fp16x16.cairo b/src/operators/nn/implementations/nn_fp16x16.cairo index a0094de29..b940d8742 100644 --- a/src/operators/nn/implementations/nn_fp16x16.cairo +++ b/src/operators/nn/implementations/nn_fp16x16.cairo @@ -7,6 +7,12 @@ use orion::numbers::fixed_point::implementations::fp16x16::core::FP16x16; use orion::operators::tensor::implementations::tensor_fp16x16::{ FP16x16Tensor, FP16x16TensorDiv, FP16x16TensorAdd }; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ + FP16x16WImpl, FP16x16WTryIntoFP16x16, FP16x16W, FP16x16IntoFP16x16W +}; +use orion::operators::tensor::implementations::tensor_fp16x16wide::{ + FP16x16WTensor, FP16x16WTensorDiv, FP16x16WTensorAdd +}; impl FP16x16NN of NNTrait { fn relu(tensor: @Tensor) -> Tensor { @@ -18,7 +24,7 @@ impl FP16x16NN of NNTrait { } fn softmax(tensor: @Tensor, axis: usize) -> Tensor { - functional::softmax::softmax(tensor, axis) + functional::softmax::softmaxWide::(tensor, 
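+        // The softmax is evaluated in the wide FP16x16W type: exp_upcast widens
+        // each element to a u64-backed fixed point before exponentiating, so the
+        // intermediate products and the axis-wise sum of exponentials get
+        // headroom that the u32 magnitude of FP16x16 lacks; div_downcast then
+        // narrows the exponentials and their sum and performs the division in
+        // FP16x16, where the final probabilities (all <= 1) fit again.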
axis) } fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { diff --git a/src/operators/tensor/math/arithmetic.cairo b/src/operators/tensor/math/arithmetic.cairo index 075f565b8..06879a4af 100644 --- a/src/operators/tensor/math/arithmetic.cairo +++ b/src/operators/tensor/math/arithmetic.cairo @@ -304,3 +304,44 @@ fn saturated_div< return TensorTrait::::new(broadcasted_shape, result.span()); } + +fn div_downcast< + T, + D, + impl TTensor: TensorTrait, + impl DTensor: TensorTrait, + impl DDiv: Div, + impl TTryIntoD: TryInto, + impl TCopy: Copy, + impl TDrop: Drop, + impl DCopy: Copy, + impl DDrop: Drop +>( + self: @Tensor, other: @Tensor +) -> Tensor { + let broadcasted_shape = broadcast_shape(*self.shape, *other.shape); + let mut result = ArrayTrait::new(); + + let num_elements = len_from_shape(broadcasted_shape); + + let mut n: usize = 0; + loop { + let indices_broadcasted = unravel_index(n, broadcasted_shape); + + let indices_self = broadcast_index_mapping(*self.shape, indices_broadcasted); + let indices_other = broadcast_index_mapping(*other.shape, indices_broadcasted); + + result + .append( + (*(*self.data)[indices_self]).try_into().unwrap() + / (*(*other.data)[indices_other]).try_into().unwrap() + ); + + n += 1; + if n == num_elements { + break (); + }; + }; + + return TensorTrait::::new(broadcasted_shape, result.span()); +} diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 3ba1e97d7..5ba161030 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -34,3 +34,37 @@ fn exp< return TensorTrait::new(self.shape, result.span()); } + +/// Cf: TensorTrait::exp docstring +fn exp_upcast< + T, + MAG, + W, + WMAG, + impl TFixedTrait: FixedTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, + impl WFixedTrait: FixedTrait, + impl WTensor: TensorTrait, + impl WCopy: Copy, + impl WDrop: Drop, + impl TIntoW: Into, +>( + mut self: Tensor +) -> Tensor { + let mut result = ArrayTrait::new(); + + loop { + match self.data.pop_front() { + Option::Some(item) => { + result.append((TIntoW::into(*item)).exp()); + }, + Option::None(_) => { + break; + } + }; + }; + + return TensorTrait::new(self.shape, result.span()); +} From eb09f55a1ed5c06eac5a2df224f028f7be6c7b3b Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 11:46:59 +0300 Subject: [PATCH 14/42] implement softmaxWide2 --- src/operators/nn/functional/softmax.cairo | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index fdbb7054f..1d3c59090 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -47,3 +47,15 @@ fn softmaxWide< return softmax; } +use orion::numbers::{FP16x16, FP16x16W}; +use orion::operators::tensor::{ + implementations::tensor_fp16x16wide::{FP16x16WTensor, FP16x16WTensorDiv}, FP16x16Tensor +}; + +/// Cf: NNTrait::softmax docstring +fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { + let exp_tensor: Tensor = exp_upcast(*z); + let sum = exp_tensor.reduce_sum(axis, true); + let softmax = exp_tensor / sum; + return softmax; +} From 73e49ab080a350d6f041cc5c9aaa6ff5df9cf0e1 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 11:48:44 +0300 Subject: [PATCH 15/42] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo 
b/src/operators/nn/functional/softmax.cairo index 1d3c59090..81cfd6c4e 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -56,6 +56,6 @@ use orion::operators::tensor::{ fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); let sum = exp_tensor.reduce_sum(axis, true); - let softmax = exp_tensor / sum; - return softmax; + // let softmax = exp_tensor / sum; + return sum; } From a835f7375071d5a011b49918f19ff617d3c20734 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 11:51:50 +0300 Subject: [PATCH 16/42] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 81cfd6c4e..870db1611 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -55,7 +55,7 @@ use orion::operators::tensor::{ /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - let sum = exp_tensor.reduce_sum(axis, true); + // let sum = exp_tensor.reduce_sum(axis, true); // let softmax = exp_tensor / sum; - return sum; + return exp_tensor; } From 541f9c450e213e36296adac28e94071f052809b5 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 11:58:23 +0300 Subject: [PATCH 17/42] Update core.cairo --- src/numbers/fixed_point/implementations/fp16x16/math/core.cairo | 1 + 1 file changed, 1 insertion(+) diff --git a/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo index e113b97c7..fc05cd941 100644 --- a/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo @@ -61,6 +61,7 @@ fn eq(a: @FP16x16, b: @FP16x16) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16) -> FP16x16 { + a.print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } From 5f7e22e3f11e733e533d0ac506339f0b72f989d6 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 11:59:59 +0300 Subject: [PATCH 18/42] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 870db1611..a26288fc9 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -53,9 +53,10 @@ use orion::operators::tensor::{ }; /// Cf: NNTrait::softmax docstring -fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { - let exp_tensor: Tensor = exp_upcast(*z); +fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { + // let exp_tensor: Tensor = exp_upcast(*z); // let sum = exp_tensor.reduce_sum(axis, true); // let softmax = exp_tensor / sum; - return exp_tensor; + // return exp_tensor; + *z } From 1b1064e406979a2ebfdfba7d19d57833df69ef14 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:06:20 +0300 Subject: [PATCH 19/42] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index a26288fc9..f8d4f234c 100644 --- a/src/operators/nn/functional/softmax.cairo +++ 
b/src/operators/nn/functional/softmax.cairo @@ -54,7 +54,7 @@ use orion::operators::tensor::{ /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { - // let exp_tensor: Tensor = exp_upcast(*z); + let exp_tensor: Tensor = exp_upcast(*z); // let sum = exp_tensor.reduce_sum(axis, true); // let softmax = exp_tensor / sum; // return exp_tensor; From 397da5f139db5f08e043965f8a529f53782a79af Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:07:20 +0300 Subject: [PATCH 20/42] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index f8d4f234c..7088c5f40 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -55,7 +55,7 @@ use orion::operators::tensor::{ /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - // let sum = exp_tensor.reduce_sum(axis, true); + let sum = exp_tensor.reduce_sum(axis, true); // let softmax = exp_tensor / sum; // return exp_tensor; *z From 09a39488abc489cbda17b382abe8372b3c0683f9 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:10:08 +0300 Subject: [PATCH 21/42] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 7088c5f40..bc5dd4025 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -56,7 +56,7 @@ use orion::operators::tensor::{ fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); let sum = exp_tensor.reduce_sum(axis, true); - // let softmax = exp_tensor / sum; + let softmax = exp_tensor / sum; // return exp_tensor; *z } From 3bb5a5d30bf2bf10fdcc8228939731ca1e345b51 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:13:14 +0300 Subject: [PATCH 22/42] add print --- .../fixed_point/implementations/fp16x16/math/core.cairo | 1 - src/operators/nn/functional/softmax.cairo | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo index fc05cd941..e113b97c7 100644 --- a/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16/math/core.cairo @@ -61,7 +61,6 @@ fn eq(a: @FP16x16, b: @FP16x16) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16) -> FP16x16 { - a.print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index bc5dd4025..247f8af5f 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -51,12 +51,14 @@ use orion::numbers::{FP16x16, FP16x16W}; use orion::operators::tensor::{ implementations::tensor_fp16x16wide::{FP16x16WTensor, FP16x16WTensorDiv}, FP16x16Tensor }; +use debug::PrintTrait; /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); let sum = exp_tensor.reduce_sum(axis, true); - let softmax = exp_tensor / sum; + (sum.data.len()).print(); + // let 
softmax = exp_tensor / sum; // return exp_tensor; *z } From e881ab2899cab807b9c727ad964a0eeebd8501f1 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:14:46 +0300 Subject: [PATCH 23/42] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 247f8af5f..e9d968c8a 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -57,7 +57,7 @@ use debug::PrintTrait; fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); let sum = exp_tensor.reduce_sum(axis, true); - (sum.data.len()).print(); + (*sum.data.at(0)).print(); // let softmax = exp_tensor / sum; // return exp_tensor; *z From 803fbcecdbfcc3477ea5ada53c7b5dea1ba00d3f Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:17:11 +0300 Subject: [PATCH 24/42] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index e9d968c8a..6fdfd747b 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -56,8 +56,10 @@ use debug::PrintTrait; /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - let sum = exp_tensor.reduce_sum(axis, true); - (*sum.data.at(0)).print(); + (*exp_tensor.data.at(0)).print(); + + // let sum = exp_tensor.reduce_sum(axis, true); + // (*sum.data.at(0)).print(); // let softmax = exp_tensor / sum; // return exp_tensor; *z From 162f5288b7b8bbcbc16fa8039ab82d78126bc0e9 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:17:26 +0300 Subject: [PATCH 25/42] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 6fdfd747b..bbe9335ec 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -56,7 +56,7 @@ use debug::PrintTrait; /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - (*exp_tensor.data.at(0)).print(); + (exp_tensor.data.len()).print(); // let sum = exp_tensor.reduce_sum(axis, true); // (*sum.data.at(0)).print(); From febcae739c83581bff427efb7c6ffd5b30f1d6fb Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:20:34 +0300 Subject: [PATCH 26/42] Update softmax.cairo --- src/operators/nn/functional/softmax.cairo | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index bbe9335ec..dc7e0636e 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -56,7 +56,9 @@ use debug::PrintTrait; /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - (exp_tensor.data.len()).print(); + (*exp_tensor.data.at(0)).print(); + (*exp_tensor.data.at(1)).print(); + (*exp_tensor.data.at(2)).print(); // let sum = exp_tensor.reduce_sum(axis, true); // (*sum.data.at(0)).print(); From 05dbea58a527cebe2447ca8eaf4b010579637cb1 Mon Sep 17 
00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:23:08 +0300 Subject: [PATCH 27/42] fix exp --- src/operators/nn/functional/softmax.cairo | 7 ++++--- src/operators/tensor/math/exp.cairo | 5 +++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index dc7e0636e..2f251f1b0 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -37,6 +37,7 @@ fn softmaxWide< impl WDrop: Drop, impl TFixed: FixedTrait, impl WFixed: FixedTrait, + impl TPrint: PrintTrait >( z: @Tensor, axis: usize ) -> Tensor { @@ -56,9 +57,9 @@ use debug::PrintTrait; /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - (*exp_tensor.data.at(0)).print(); - (*exp_tensor.data.at(1)).print(); - (*exp_tensor.data.at(2)).print(); + // (*exp_tensor.data.at(0)).print(); + // (*exp_tensor.data.at(1)).print(); + // (*exp_tensor.data.at(2)).print(); // let sum = exp_tensor.reduce_sum(axis, true); // (*sum.data.at(0)).print(); diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 5ba161030..aeb620208 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -35,6 +35,8 @@ fn exp< return TensorTrait::new(self.shape, result.span()); } +use debug::PrintTrait; + /// Cf: TensorTrait::exp docstring fn exp_upcast< T, @@ -50,6 +52,7 @@ fn exp_upcast< impl WCopy: Copy, impl WDrop: Drop, impl TIntoW: Into, + impl TPrint: PrintTrait >( mut self: Tensor ) -> Tensor { @@ -58,6 +61,8 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { + (*item).print(); + result.append((TIntoW::into(*item)).exp()); }, Option::None(_) => { From 62edd56d83118f5edbd552ddd057d072ada335a3 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:26:05 +0300 Subject: [PATCH 28/42] fix exp --- src/operators/nn/functional/softmax.cairo | 3 ++- src/operators/tensor/math/exp.cairo | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 2f251f1b0..168fab483 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -37,7 +37,8 @@ fn softmaxWide< impl WDrop: Drop, impl TFixed: FixedTrait, impl WFixed: FixedTrait, - impl TPrint: PrintTrait + impl TPrint: PrintTrait, + impl WPrint: PrintTrait >( z: @Tensor, axis: usize ) -> Tensor { diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index aeb620208..0b3889511 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -52,7 +52,8 @@ fn exp_upcast< impl WCopy: Copy, impl WDrop: Drop, impl TIntoW: Into, - impl TPrint: PrintTrait + impl TPrint: PrintTrait, + impl WPrint: PrintTrait >( mut self: Tensor ) -> Tensor { @@ -61,7 +62,7 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { - (*item).print(); + (TIntoW::into(*item)).print(); result.append((TIntoW::into(*item)).exp()); }, From d6c86316c4d72deecc6a718c285a40a052308576 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:27:14 +0300 Subject: [PATCH 29/42] Update exp.cairo --- src/operators/tensor/math/exp.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 0b3889511..18ee0ccde 
100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -62,7 +62,7 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { - (TIntoW::into(*item)).print(); + (TIntoW::into(*item)).exp().print(); result.append((TIntoW::into(*item)).exp()); }, From dc2019f5f22210ae404b36004ab75f0cce58eccd Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:28:57 +0300 Subject: [PATCH 30/42] debugin' --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 1 + src/operators/tensor/math/exp.cairo | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index 33c1c6d85..a3ed48b4c 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,6 +61,7 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { + a.print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 18ee0ccde..92b60e2ac 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -62,8 +62,6 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { - (TIntoW::into(*item)).exp().print(); - result.append((TIntoW::into(*item)).exp()); }, Option::None(_) => { From 12ac782e299fc52c4769918f52fc306d93fb45c4 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:30:36 +0300 Subject: [PATCH 31/42] Update core.cairo --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index a3ed48b4c..6b6a74e07 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,7 +61,7 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { - a.print(); + (FixedTrait::new(94548, false) * a).print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } From ce0c01e4722b4fc37c2c85d243012763d88dfa82 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:31:54 +0300 Subject: [PATCH 32/42] Update core.cairo --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index 6b6a74e07..87a7c8706 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,7 +61,6 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { - (FixedTrait::new(94548, false) * a).print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } @@ -87,6 +86,8 @@ fn exp2(a: FP16x16W) -> FP16x16W { res_u = res_u * (r1 + FixedTrait::ONE()); } + res_u.print(); + if (a.sign == true) { return FixedTrait::ONE() / res_u; 
} else { From acd104362cd362a3a84624abaea00d94bf0656a7 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:34:46 +0300 Subject: [PATCH 33/42] Update core.cairo --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index 87a7c8706..bcd7fc8af 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,6 +61,7 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { + (exp2(FixedTrait::new(94548, false) * a)).print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } @@ -86,8 +87,6 @@ fn exp2(a: FP16x16W) -> FP16x16W { res_u = res_u * (r1 + FixedTrait::ONE()); } - res_u.print(); - if (a.sign == true) { return FixedTrait::ONE() / res_u; } else { From 715a497d7825e6163c449b39b153a18705faa23a Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:36:14 +0300 Subject: [PATCH 34/42] Update core.cairo --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index bcd7fc8af..2395e4c4a 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -6,8 +6,8 @@ use integer::{u64_safe_divmod, u64_as_non_zero, u64_wide_mul}; use orion::numbers::fixed_point::implementations::fp16x16wide::core::{ HALF, ONE, MAX, FP16x16W, FP16x16WImpl, FP16x16WAdd, FP16x16WAddEq, FP16x16WSub, FP16x16WMul, - FP16x16WMulEq, FP16x16WTryIntoU128, FP16x16WPartialEq, FP16x16WPartialOrd, FP16x16WSubEq, FP16x16WNeg, - FP16x16WDiv, FP16x16WIntoFelt252, FixedTrait + FP16x16WMulEq, FP16x16WTryIntoU128, FP16x16WPartialEq, FP16x16WPartialOrd, FP16x16WSubEq, + FP16x16WNeg, FP16x16WDiv, FP16x16WIntoFelt252, FixedTrait }; use orion::numbers::fixed_point::implementations::fp16x16wide::math::lut; @@ -61,7 +61,6 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { - (exp2(FixedTrait::new(94548, false) * a)).print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } @@ -88,6 +87,7 @@ fn exp2(a: FP16x16W) -> FP16x16W { } if (a.sign == true) { + (FixedTrait::ONE() / res_u).print(); return FixedTrait::ONE() / res_u; } else { return res_u; From 214a95f4800fdce9fcc5d224616cff4f5e77dcf7 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:37:16 +0300 Subject: [PATCH 35/42] Update core.cairo --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index 2395e4c4a..38db5b84d 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,6 +61,7 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { + a.sign.print(); 
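+ // exp(x) = 2^(x * log2(e)); for FP16x16W the scale is 2 ** 16, so the constant is log2(e) * 65536 ≈ 94548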
return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } @@ -87,7 +88,6 @@ fn exp2(a: FP16x16W) -> FP16x16W { } if (a.sign == true) { - (FixedTrait::ONE() / res_u).print(); return FixedTrait::ONE() / res_u; } else { return res_u; From 6a82ce9338f31a62d6ab59b96eec5867ba69d647 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:39:15 +0300 Subject: [PATCH 36/42] debbug --- .../fixed_point/implementations/fp16x16wide/math/core.cairo | 2 +- src/operators/tensor/math/exp.cairo | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index 38db5b84d..9ab7fcc4c 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,7 +61,7 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { - a.sign.print(); + // a.sign.print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 92b60e2ac..0b3889511 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -62,6 +62,8 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { + (TIntoW::into(*item)).print(); + result.append((TIntoW::into(*item)).exp()); }, Option::None(_) => { From 3b56c695818b5c23e042b8199dc0eee12f7d03be Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:39:51 +0300 Subject: [PATCH 37/42] Update exp.cairo --- src/operators/tensor/math/exp.cairo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 0b3889511..c4a9903c0 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -62,7 +62,7 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { - (TIntoW::into(*item)).print(); + ((*item)).print(); result.append((TIntoW::into(*item)).exp()); }, From 5267e358e77911e1c0bedae7887a5e7eab0e1387 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 22 Oct 2023 12:41:41 +0300 Subject: [PATCH 38/42] debug --- src/operators/nn/functional/softmax.cairo | 7 ++++--- src/operators/tensor/math/exp.cairo | 2 -- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 168fab483..59c66c3d2 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -58,9 +58,10 @@ use debug::PrintTrait; /// Cf: NNTrait::softmax docstring fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); - // (*exp_tensor.data.at(0)).print(); - // (*exp_tensor.data.at(1)).print(); - // (*exp_tensor.data.at(2)).print(); + + (*(*z).data.at(0)).print(); + (*(*z).data.at(1)).print(); + (*(*z).data.at(2)).print(); // let sum = exp_tensor.reduce_sum(axis, true); // (*sum.data.at(0)).print(); diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index c4a9903c0..92b60e2ac 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -62,8 +62,6 @@ fn exp_upcast< loop { match self.data.pop_front() { Option::Some(item) => { - ((*item)).print(); - 
result.append((TIntoW::into(*item)).exp()); }, Option::None(_) => { From 8b2276bab3f82b257959a3fededab5e46714a4a6 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 23 Oct 2023 10:25:23 +0300 Subject: [PATCH 39/42] clean --- .../fp16x16wide/math/core.cairo | 1 - src/operators/nn/functional/softmax.cairo | 30 ++----------------- src/operators/tensor/math/exp.cairo | 2 -- 3 files changed, 2 insertions(+), 31 deletions(-) diff --git a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo index 9ab7fcc4c..4654cd6ba 100644 --- a/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo +++ b/src/numbers/fixed_point/implementations/fp16x16wide/math/core.cairo @@ -61,7 +61,6 @@ fn eq(a: @FP16x16W, b: @FP16x16W) -> bool { // Calculates the natural exponent of x: e^x fn exp(a: FP16x16W) -> FP16x16W { - // a.sign.print(); return exp2(FixedTrait::new(94548, false) * a); // log2(e) * 2^23 ≈ 12102203 } diff --git a/src/operators/nn/functional/softmax.cairo b/src/operators/nn/functional/softmax.cairo index 59c66c3d2..81696ef22 100644 --- a/src/operators/nn/functional/softmax.cairo +++ b/src/operators/nn/functional/softmax.cairo @@ -15,9 +15,7 @@ fn softmax< ) -> Tensor { let exp_tensor = z.exp(); let sum = exp_tensor.reduce_sum(axis, true); - let softmax = exp_tensor / sum; - - return softmax; + exp_tensor / sum } /// Cf: NNTrait::softmax docstring @@ -37,35 +35,11 @@ fn softmaxWide< impl WDrop: Drop, impl TFixed: FixedTrait, impl WFixed: FixedTrait, - impl TPrint: PrintTrait, - impl WPrint: PrintTrait >( z: @Tensor, axis: usize ) -> Tensor { let exp_tensor: Tensor = exp_upcast(*z); let sum = exp_tensor.reduce_sum(axis, true); - let softmax: Tensor = div_downcast(@exp_tensor, @sum); - - return softmax; + div_downcast(@exp_tensor, @sum) } -use orion::numbers::{FP16x16, FP16x16W}; -use orion::operators::tensor::{ - implementations::tensor_fp16x16wide::{FP16x16WTensor, FP16x16WTensorDiv}, FP16x16Tensor -}; -use debug::PrintTrait; - -/// Cf: NNTrait::softmax docstring -fn softmaxWide2(z: @Tensor, axis: usize) -> Tensor { - let exp_tensor: Tensor = exp_upcast(*z); - - (*(*z).data.at(0)).print(); - (*(*z).data.at(1)).print(); - (*(*z).data.at(2)).print(); - - // let sum = exp_tensor.reduce_sum(axis, true); - // (*sum.data.at(0)).print(); - // let softmax = exp_tensor / sum; - // return exp_tensor; - *z -} diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index 92b60e2ac..83f79eac7 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -52,8 +52,6 @@ fn exp_upcast< impl WCopy: Copy, impl WDrop: Drop, impl TIntoW: Into, - impl TPrint: PrintTrait, - impl WPrint: PrintTrait >( mut self: Tensor ) -> Tensor { From f79e0f5efff681eca6e23d5313f23c3a5d92ffdb Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Mon, 23 Oct 2023 10:35:23 +0300 Subject: [PATCH 40/42] implement fp8x23wide --- src/numbers/fixed_point/implementations.cairo | 3 +- .../implementations/fp8x23wide.cairo | 4 + .../implementations/fp8x23wide/core.cairo | 378 +++++ .../implementations/fp8x23wide/helpers.cairo | 39 + .../implementations/fp8x23wide/math.cairo | 5 + .../fp8x23wide/math/comp.cairo | 76 + .../fp8x23wide/math/core.cairo | 660 +++++++++ .../implementations/fp8x23wide/math/hyp.cairo | 159 +++ .../implementations/fp8x23wide/math/lut.cairo | 1229 +++++++++++++++++ .../fp8x23wide/math/trig.cairo | 448 ++++++ 10 files changed, 3000 insertions(+), 1 deletion(-) create mode 
100644 src/numbers/fixed_point/implementations/fp8x23wide.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/core.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/helpers.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/math.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/math/comp.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/math/core.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/math/hyp.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/math/lut.cairo create mode 100644 src/numbers/fixed_point/implementations/fp8x23wide/math/trig.cairo diff --git a/src/numbers/fixed_point/implementations.cairo b/src/numbers/fixed_point/implementations.cairo index e6152e25a..d7617f9c8 100644 --- a/src/numbers/fixed_point/implementations.cairo +++ b/src/numbers/fixed_point/implementations.cairo @@ -2,4 +2,5 @@ mod fp8x23; mod fp16x16; mod fp64x64; mod fp32x32; -mod fp16x16wide; \ No newline at end of file +mod fp16x16wide; +mod fp8x23wide; \ No newline at end of file diff --git a/src/numbers/fixed_point/implementations/fp8x23wide.cairo b/src/numbers/fixed_point/implementations/fp8x23wide.cairo new file mode 100644 index 000000000..2cc1d5085 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide.cairo @@ -0,0 +1,4 @@ +mod core; +mod math; +mod helpers; + diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo new file mode 100644 index 000000000..36b64ce5e --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo @@ -0,0 +1,378 @@ +use debug::PrintTrait; + +use option::OptionTrait; +use result::{ResultTrait, ResultTraitImpl}; +use traits::{TryInto, Into}; + +use orion::numbers::signed_integer::{i32::i32, i8::i8}; +use orion::numbers::fixed_point::core::{FixedTrait}; +use orion::numbers::fixed_point::implementations::fp8x23wide::math::{core, trig, hyp}; +use orion::numbers::fixed_point::utils; + +/// A struct representing a fixed point number. 
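+/// The value is stored as a sign flag plus an unsigned magnitude scaled by 2 ** 23, +/// so with ONE = 8388608 the number 1.5 is encoded as FP8x23W { mag: 12582912, sign: false } +/// and -0.25 as FP8x23W { mag: 2097152, sign: true }.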
+#[derive(Serde, Copy, Drop)] +struct FP8x23W { + mag: u64, + sign: bool +} + +// CONSTANTS + +const TWO: u64 = 16777216; // 2 ** 24 +const ONE: u64 = 8388608; // 2 ** 23 +const HALF: u64 = 4194304; // 2 ** 22 +const MAX: u64 = 2147483648; // 2 ** 31 + + +impl FP8x23WImpl of FixedTrait { + fn ZERO() -> FP8x23W { + return FP8x23W { mag: 0, sign: false }; + } + + fn ONE() -> FP8x23W { + return FP8x23W { mag: ONE, sign: false }; + } + + fn MAX() -> FP8x23W { + return FP8x23W { mag: MAX, sign: false }; + } + + fn new(mag: u64, sign: bool) -> FP8x23W { + return FP8x23W { mag: mag, sign: sign }; + } + + fn new_unscaled(mag: u64, sign: bool) -> FP8x23W { + return FP8x23W { mag: mag * ONE, sign: sign }; + } + + fn from_felt(val: felt252) -> FP8x23W { + let mag = integer::u64_try_from_felt252(utils::felt_abs(val)).unwrap(); + return FixedTrait::new(mag, utils::felt_sign(val)); + } + + fn abs(self: FP8x23W) -> FP8x23W { + return core::abs(self); + } + + fn acos(self: FP8x23W) -> FP8x23W { + return trig::acos_fast(self); + } + + fn acos_fast(self: FP8x23W) -> FP8x23W { + return trig::acos_fast(self); + } + + fn acosh(self: FP8x23W) -> FP8x23W { + return hyp::acosh(self); + } + + fn asin(self: FP8x23W) -> FP8x23W { + return trig::asin_fast(self); + } + + fn asin_fast(self: FP8x23W) -> FP8x23W { + return trig::asin_fast(self); + } + + fn asinh(self: FP8x23W) -> FP8x23W { + return hyp::asinh(self); + } + + fn atan(self: FP8x23W) -> FP8x23W { + return trig::atan_fast(self); + } + + fn atan_fast(self: FP8x23W) -> FP8x23W { + return trig::atan_fast(self); + } + + fn atanh(self: FP8x23W) -> FP8x23W { + return hyp::atanh(self); + } + + fn ceil(self: FP8x23W) -> FP8x23W { + return core::ceil(self); + } + + fn cos(self: FP8x23W) -> FP8x23W { + return trig::cos_fast(self); + } + + fn cos_fast(self: FP8x23W) -> FP8x23W { + return trig::cos_fast(self); + } + + fn cosh(self: FP8x23W) -> FP8x23W { + return hyp::cosh(self); + } + + fn floor(self: FP8x23W) -> FP8x23W { + return core::floor(self); + } + + // Calculates the natural exponent of x: e^x + fn exp(self: FP8x23W) -> FP8x23W { + return core::exp(self); + } + + // Calculates the binary exponent of x: 2^x + fn exp2(self: FP8x23W) -> FP8x23W { + return core::exp2(self); + } + + // Calculates the natural logarithm of x: ln(x) + // self must be greater than zero + fn ln(self: FP8x23W) -> FP8x23W { + return core::ln(self); + } + + // Calculates the binary logarithm of x: log2(x) + // self must be greater than zero + fn log2(self: FP8x23W) -> FP8x23W { + return core::log2(self); + } + + // Calculates the base 10 log of x: log10(x) + // self must be greater than zero + fn log10(self: FP8x23W) -> FP8x23W { + return core::log10(self); + } + + // Calculates the value of x^y and checks for overflow before returning + // self is a fixed point value + // b is a fixed point value + fn pow(self: FP8x23W, b: FP8x23W) -> FP8x23W { + return core::pow(self, b); + } + + fn round(self: FP8x23W) -> FP8x23W { + return core::round(self); + } + + fn sin(self: FP8x23W) -> FP8x23W { + return trig::sin_fast(self); + } + + fn sin_fast(self: FP8x23W) -> FP8x23W { + return trig::sin_fast(self); + } + + fn sinh(self: FP8x23W) -> FP8x23W { + return hyp::sinh(self); + } + + // Calculates the square root of a fixed point value + // x must be positive + fn sqrt(self: FP8x23W) -> FP8x23W { + return core::sqrt(self); + } + + fn tan(self: FP8x23W) -> FP8x23W { + return trig::tan_fast(self); + } + + fn tan_fast(self: FP8x23W) -> FP8x23W { + return trig::tan_fast(self); + } + + fn tanh(self: FP8x23W)
-> FP8x23W { + return hyp::tanh(self); + } + + fn sign(self: FP8x23W) -> FP8x23W { + return core::sign(self); + } +} + + +impl FP8x23WPrint of PrintTrait { + fn print(self: FP8x23W) { + self.sign.print(); + self.mag.print(); + } +} + +// Into a raw felt without unscaling +impl FP8x23WIntoFelt252 of Into { + fn into(self: FP8x23W) -> felt252 { + let mag_felt = self.mag.into(); + + if self.sign { + return mag_felt * -1; + } else { + return mag_felt * 1; + } + } +} + +impl FP8x23WTryIntoU128 of TryInto { + fn try_into(self: FP8x23W) -> Option { + if self.sign { + return Option::None(()); + } else { + // Unscale the magnitude and round down + return Option::Some((self.mag / ONE).into()); + } + } +} + +impl FP8x23WTryIntoU64 of TryInto { + fn try_into(self: FP8x23W) -> Option { + if self.sign { + return Option::None(()); + } else { + // Unscale the magnitude and round down + return Option::Some((self.mag / ONE).into()); + } + } +} + + +impl FP8x23WTryIntoU16 of TryInto { + fn try_into(self: FP8x23W) -> Option { + if self.sign { + Option::None(()) + } else { + // Unscale the magnitude and round down + return (self.mag / ONE).try_into(); + } + } +} + +impl FP8x23WTryIntoU8 of TryInto { + fn try_into(self: FP8x23W) -> Option { + if self.sign { + Option::None(()) + } else { + // Unscale the magnitude and round down + return (self.mag / ONE).try_into(); + } + } +} + +impl FP8x23WIntoI32 of Into { + fn into(self: FP8x23W) -> i32 { + _i32_into_fp(self) + } +} + +impl FP8x23WTryIntoI8 of TryInto { + fn try_into(self: FP8x23W) -> Option { + _i8_try_from_fp(self) + } +} + +impl FP8x23WPartialEq of PartialEq { + #[inline(always)] + fn eq(lhs: @FP8x23W, rhs: @FP8x23W) -> bool { + return core::eq(lhs, rhs); + } + + #[inline(always)] + fn ne(lhs: @FP8x23W, rhs: @FP8x23W) -> bool { + return core::ne(lhs, rhs); + } +} + +impl FP8x23WAdd of Add { + fn add(lhs: FP8x23W, rhs: FP8x23W) -> FP8x23W { + return core::add(lhs, rhs); + } +} + +impl FP8x23WAddEq of AddEq { + #[inline(always)] + fn add_eq(ref self: FP8x23W, other: FP8x23W) { + self = Add::add(self, other); + } +} + +impl FP8x23WSub of Sub { + fn sub(lhs: FP8x23W, rhs: FP8x23W) -> FP8x23W { + return core::sub(lhs, rhs); + } +} + +impl FP8x23WSubEq of SubEq { + #[inline(always)] + fn sub_eq(ref self: FP8x23W, other: FP8x23W) { + self = Sub::sub(self, other); + } +} + +impl FP8x23WMul of Mul { + fn mul(lhs: FP8x23W, rhs: FP8x23W) -> FP8x23W { + return core::mul(lhs, rhs); + } +} + +impl FP8x23WMulEq of MulEq { + #[inline(always)] + fn mul_eq(ref self: FP8x23W, other: FP8x23W) { + self = Mul::mul(self, other); + } +} + +impl FP8x23WDiv of Div { + fn div(lhs: FP8x23W, rhs: FP8x23W) -> FP8x23W { + return core::div(lhs, rhs); + } +} + +impl FP8x23WDivEq of DivEq { + #[inline(always)] + fn div_eq(ref self: FP8x23W, other: FP8x23W) { + self = Div::div(self, other); + } +} + +impl FP8x23WPartialOrd of PartialOrd { + #[inline(always)] + fn ge(lhs: FP8x23W, rhs: FP8x23W) -> bool { + return core::ge(lhs, rhs); + } + + #[inline(always)] + fn gt(lhs: FP8x23W, rhs: FP8x23W) -> bool { + return core::gt(lhs, rhs); + } + + #[inline(always)] + fn le(lhs: FP8x23W, rhs: FP8x23W) -> bool { + return core::le(lhs, rhs); + } + + #[inline(always)] + fn lt(lhs: FP8x23W, rhs: FP8x23W) -> bool { + return core::lt(lhs, rhs); + } +} + +impl FP8x23WNeg of Neg { + #[inline(always)] + fn neg(a: FP8x23W) -> FP8x23W { + return core::neg(a); + } +} + +impl FP8x23WRem of Rem { + #[inline(always)] + fn rem(lhs: FP8x23W, rhs: FP8x23W) -> FP8x23W { + return core::rem(lhs, rhs); + } +} + +/// 
INTERNAL + +fn _i32_into_fp(x: FP8x23W) -> i32 { + i32 { mag: (x.mag / ONE).try_into().unwrap(), sign: x.sign } +} + +fn _i8_try_from_fp(x: FP8x23W) -> Option { + let unscaled_mag: Option = (x.mag / ONE).try_into(); + + match unscaled_mag { + Option::Some(val) => Option::Some(i8 { mag: unscaled_mag.unwrap(), sign: x.sign }), + Option::None(_) => Option::None(()) + } +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/helpers.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/helpers.cairo new file mode 100644 index 000000000..a627803be --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/helpers.cairo @@ -0,0 +1,39 @@ +use debug::PrintTrait; +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ + HALF, ONE, TWO, FP8x23W, FP8x23WSub, FP8x23WDiv, FixedTrait, FP8x23WPrint +}; + +const DEFAULT_PRECISION: u64 = 8; // 1e-6 + +// To use `DEFAULT_PRECISION`, final arg is: `Option::None(())`. +// To use `custom_precision` of 430_u64: `Option::Some(430_u64)`. +fn assert_precise(result: FP8x23W, expected: felt252, msg: felt252, custom_precision: Option) { + let precision = match custom_precision { + Option::Some(val) => val, + Option::None(_) => DEFAULT_PRECISION, + }; + + let diff = (result - FixedTrait::from_felt(expected)).mag; + + if (diff > precision) { + result.print(); + assert(diff <= precision, msg); + } +} + +fn assert_relative(result: FP8x23W, expected: felt252, msg: felt252, custom_precision: Option) { + let precision = match custom_precision { + Option::Some(val) => val, + Option::None(_) => DEFAULT_PRECISION, + }; + + let diff = result - FixedTrait::from_felt(expected); + let rel_diff = (diff / result).mag; + + if (rel_diff > precision) { + result.print(); + assert(rel_diff <= precision, msg); + } +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math.cairo new file mode 100644 index 000000000..970c65f30 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math.cairo @@ -0,0 +1,5 @@ +mod core; +mod comp; +mod lut; +mod trig; +mod hyp; diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math/comp.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math/comp.cairo new file mode 100644 index 000000000..95b329109 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math/comp.cairo @@ -0,0 +1,76 @@ +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ + FP8x23W, FixedTrait, FP8x23WPartialOrd, FP8x23WPartialEq +}; + +fn max(a: FP8x23W, b: FP8x23W) -> FP8x23W { + if (a >= b) { + return a; + } else { + return b; + } +} + +fn min(a: FP8x23W, b: FP8x23W) -> FP8x23W { + if (a <= b) { + return a; + } else { + return b; + } +} + +fn xor(a: FP8x23W, b: FP8x23W) -> bool { + if (a == FixedTrait::new(0, false) || b == FixedTrait::new(0, false)) && (a != b) { + return true; + } else { + return false; + } +} + +fn or(a: FP8x23W, b: FP8x23W) -> bool { + let zero = FixedTrait::new(0, false); + if a == zero && b == zero { + return false; + } else { + return true; + } +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +#[test] +fn test_max() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::new_unscaled(1, true); + + assert(max(a, a) == a, 'max(a, a)'); + assert(max(a, b) == a, 'max(a, b)'); + assert(max(a, c) == a, 'max(a, c)'); + + 
assert(max(b, a) == a, 'max(b, a)'); + assert(max(b, b) == b, 'max(b, b)'); + assert(max(b, c) == b, 'max(b, c)'); + + assert(max(c, a) == a, 'max(c, a)'); + assert(max(c, b) == b, 'max(c, b)'); + assert(max(c, c) == c, 'max(c, c)'); +} + +#[test] +fn test_min() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::new_unscaled(1, true); + + assert(min(a, a) == a, 'min(a, a)'); + assert(min(a, b) == b, 'min(a, b)'); + assert(min(a, c) == c, 'min(a, c)'); + + assert(min(b, a) == b, 'min(b, a)'); + assert(min(b, b) == b, 'min(b, b)'); + assert(min(b, c) == c, 'min(b, c)'); + + assert(min(c, a) == c, 'min(c, a)'); + assert(min(c, b) == c, 'min(c, b)'); + assert(min(c, c) == c, 'min(c, c)'); +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math/core.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math/core.cairo new file mode 100644 index 000000000..129ff02c8 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math/core.cairo @@ -0,0 +1,660 @@ +use core::debug::PrintTrait; +use option::OptionTrait; +use result::{ResultTrait, ResultTraitImpl}; +use traits::{Into, TryInto}; +use integer::{u64_safe_divmod, u64_as_non_zero, u64_wide_mul}; + +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ + HALF, ONE, MAX, FP8x23W, FP8x23WAdd, FP8x23WImpl, FP8x23WAddEq, FP8x23WSub, FP8x23WMul, + FP8x23WMulEq, FP8x23WTryIntoU128, FP8x23WPartialEq, FP8x23WPartialOrd, FP8x23WSubEq, FP8x23WNeg, + FP8x23WDiv, FP8x23WIntoFelt252, FixedTrait +}; +use orion::numbers::fixed_point::implementations::fp8x23wide::math::lut; + +// PUBLIC + +fn abs(a: FP8x23W) -> FP8x23W { + return FixedTrait::new(a.mag, false); +} + +fn add(a: FP8x23W, b: FP8x23W) -> FP8x23W { + if a.sign == b.sign { + return FixedTrait::new(a.mag + b.mag, a.sign); + } + + if a.mag == b.mag { + return FixedTrait::ZERO(); + } + + if (a.mag > b.mag) { + return FixedTrait::new(a.mag - b.mag, a.sign); + } else { + return FixedTrait::new(b.mag - a.mag, b.sign); + } +} + +fn ceil(a: FP8x23W) -> FP8x23W { + let (div, rem) = u64_safe_divmod(a.mag, u64_as_non_zero(ONE)); + + if rem == 0 { + return a; + } else if !a.sign { + return FixedTrait::new_unscaled(div + 1, false); + } else if div == 0 { + return FixedTrait::new_unscaled(0, false); + } else { + return FixedTrait::new_unscaled(div, true); + } +} + +fn div(a: FP8x23W, b: FP8x23W) -> FP8x23W { + let a_u64 = integer::u64_wide_mul(a.mag, ONE); + let res_u64 = a_u64 / b.mag.into(); + + // Re-apply sign + return FixedTrait::new(res_u64.try_into().unwrap(), a.sign ^ b.sign); +} + +fn eq(a: @FP8x23W, b: @FP8x23W) -> bool { + return (*a.mag == *b.mag) && (*a.sign == *b.sign); +} + +// Calculates the natural exponent of x: e^x +fn exp(a: FP8x23W) -> FP8x23W { + return exp2(FixedTrait::new(12102203, false) * a); // log2(e) * 2^23 ≈ 12102203 +} + +// Calculates the binary exponent of x: 2^x +fn exp2(a: FP8x23W) -> FP8x23W { + if (a.mag == 0) { + return FixedTrait::ONE(); + } + + let (int_part, frac_part) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE)); + let int_res = FixedTrait::new_unscaled(lut::exp2(int_part), false); + let mut res_u = int_res; + + if frac_part != 0 { + let frac = FixedTrait::new(frac_part, false); + let r8 = FixedTrait::new(19, false) * frac; + let r7 = (r8 + FixedTrait::new(105, false)) * frac; + let r6 = (r7 + FixedTrait::new(1324, false)) * frac; + let r5 = (r6 + FixedTrait::new(11159, false)) * frac; + let r4 = (r5 + FixedTrait::new(80695, false)) * frac; + let r3 
= (r4 + FixedTrait::new(465599, false)) * frac; + let r2 = (r3 + FixedTrait::new(2015166, false)) * frac; + let r1 = (r2 + FixedTrait::new(5814540, false)) * frac; + res_u = res_u * (r1 + FixedTrait::ONE()); + } + + if (a.sign == true) { + return FixedTrait::ONE() / res_u; + } else { + return res_u; + } +} + +fn exp2_int(exp: u64) -> FP8x23W { + return FixedTrait::new_unscaled(lut::exp2(exp), false); +} + +fn floor(a: FP8x23W) -> FP8x23W { + let (div, rem) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE)); + + if rem == 0 { + return a; + } else if !a.sign { + return FixedTrait::new_unscaled(div, false); + } else { + return FixedTrait::new_unscaled(div + 1, true); + } +} + +fn ge(a: FP8x23W, b: FP8x23W) -> bool { + if a.sign != b.sign { + return !a.sign; + } else { + return (a.mag == b.mag) || ((a.mag > b.mag) ^ a.sign); + } +} + +fn gt(a: FP8x23W, b: FP8x23W) -> bool { + if a.sign != b.sign { + return !a.sign; + } else { + return (a.mag != b.mag) && ((a.mag > b.mag) ^ a.sign); + } +} + +fn le(a: FP8x23W, b: FP8x23W) -> bool { + if a.sign != b.sign { + return a.sign; + } else { + return (a.mag == b.mag) || ((a.mag < b.mag) ^ a.sign); + } +} + +// Calculates the natural logarithm of x: ln(x) +// self must be greater than zero +fn ln(a: FP8x23W) -> FP8x23W { + return FixedTrait::new(5814540, false) * log2(a); // ln(2) = 0.693... +} + +// Calculates the binary logarithm of x: log2(x) +// self must be greater than zero +fn log2(a: FP8x23W) -> FP8x23W { + assert(a.sign == false, 'must be positive'); + + if (a.mag == ONE) { + return FixedTrait::ZERO(); + } else if (a.mag < ONE) { + // Compute true inverse binary log if 0 < x < 1 + let div = FixedTrait::ONE() / a; + return -log2(div); + } + + let whole = a.mag / ONE; + let (msb, div) = lut::msb(whole); + + if a.mag == div * ONE { + return FixedTrait::new_unscaled(msb, false); + } else { + let norm = a / FixedTrait::new_unscaled(div, false); + let r8 = FixedTrait::new(76243, true) * norm; + let r7 = (r8 + FixedTrait::new(1038893, false)) * norm; + let r6 = (r7 + FixedTrait::new(6277679, true)) * norm; + let r5 = (r6 + FixedTrait::new(22135645, false)) * norm; + let r4 = (r5 + FixedTrait::new(50444339, true)) * norm; + let r3 = (r4 + FixedTrait::new(77896489, false)) * norm; + let r2 = (r3 + FixedTrait::new(83945943, true)) * norm; + let r1 = (r2 + FixedTrait::new(68407458, false)) * norm; + return r1 + FixedTrait::new(28734280, true) + FixedTrait::new_unscaled(msb, false); + } +} + +// Calculates the base 10 log of x: log10(x) +// self must be greater than zero +fn log10(a: FP8x23W) -> FP8x23W { + return FixedTrait::new(2525223, false) * log2(a); // log10(2) = 0.301...
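+// e.g. log10(100) = log2(100) * log10(2) ≈ 6.6439 * 0.30103 ≈ 2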
+} + +fn lt(a: FP8x23W, b: FP8x23W) -> bool { + if a.sign != b.sign { + return a.sign; + } else { + return (a.mag != b.mag) && ((a.mag < b.mag) ^ a.sign); + } +} + +fn mul(a: FP8x23W, b: FP8x23W) -> FP8x23W { + let prod_u128 = integer::u64_wide_mul(a.mag, b.mag); + + // Re-apply sign + return FixedTrait::new((prod_u128 / ONE.into()).try_into().unwrap(), a.sign ^ b.sign); +} + +fn ne(a: @FP8x23W, b: @FP8x23W) -> bool { + return (*a.mag != *b.mag) || (*a.sign != *b.sign); +} + +fn neg(a: FP8x23W) -> FP8x23W { + if a.mag == 0 { + return a; + } else if !a.sign { + return FixedTrait::new(a.mag, !a.sign); + } else { + return FixedTrait::new(a.mag, false); + } +} + +// Calculates the value of x^y and checks for overflow before returning +// self is a fixed point value +// b is a fixed point value +fn pow(a: FP8x23W, b: FP8x23W) -> FP8x23W { + let (div, rem) = integer::u64_safe_divmod(b.mag, u64_as_non_zero(ONE)); + + // use the more performant integer pow when y is an int + if (rem == 0) { + return pow_int(a, b.mag / ONE, b.sign); + } + + // x^y = exp(y*ln(x)) for x > 0 will error for x < 0 + return exp(b * ln(a)); +} + +// Calculates the value of a^b and checks for overflow before returning +fn pow_int(a: FP8x23W, b: u64, sign: bool) -> FP8x23W { + let mut x = a; + let mut n = b; + + if sign == true { + x = FixedTrait::ONE() / x; + } + + if n == 0 { + return FixedTrait::ONE(); + } + + let mut y = FixedTrait::ONE(); + let two = integer::u64_as_non_zero(2); + + loop { + if n <= 1 { + break; + } + + let (div, rem) = integer::u64_safe_divmod(n, two); + + if rem == 1 { + y = x * y; + } + + x = x * x; + n = div; + }; + + return x * y; +} + +fn rem(a: FP8x23W, b: FP8x23W) -> FP8x23W { + return a - floor(a / b) * b; +} + +fn round(a: FP8x23W) -> FP8x23W { + let (div, rem) = integer::u64_safe_divmod(a.mag, u64_as_non_zero(ONE)); + + if (HALF <= rem) { + return FixedTrait::new_unscaled(div + 1, a.sign); + } else { + return FixedTrait::new_unscaled(div, a.sign); + } +} + +// Calculates the square root of a fixed point value +// x must be positive +fn sqrt(a: FP8x23W) -> FP8x23W { + assert(a.sign == false, 'must be positive'); + + let root = integer::u64_sqrt(a.mag.into() * ONE.into()); + return FixedTrait::new(root.into(), false); +} + +fn sub(a: FP8x23W, b: FP8x23W) -> FP8x23W { + return add(a, -b); +} + +fn sign(a: FP8x23W) -> FP8x23W { + if a.mag == 0 { + FixedTrait::new(0, false) + } else { + FixedTrait::new(ONE, a.sign) + } +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +use orion::numbers::fixed_point::implementations::fp8x23wide::helpers::{ + assert_precise, assert_relative +}; +use orion::numbers::fixed_point::implementations::fp8x23wide::math::trig::{PI, HALF_PI}; + +#[test] +fn test_into() { + let a = FixedTrait::::new_unscaled(5, false); + assert(a.mag == 5 * ONE, 'invalid result'); +} + +#[test] +fn test_try_into_u128() { + // Positive unscaled + let a = FixedTrait::::new_unscaled(5, false); + assert(a.try_into().unwrap() == 5_u128, 'invalid result'); + + // Positive scaled + let b = FixedTrait::::new(5 * ONE, false); + assert(b.try_into().unwrap() == 5_u128, 'invalid result'); + + // Zero + let d = FixedTrait::::new_unscaled(0, false); + assert(d.try_into().unwrap() == 0_u128, 'invalid result'); +} + +#[test] +#[should_panic] +fn test_negative_try_into_u128() { + let a = FixedTrait::::new_unscaled(1, true); + let a: u128 = a.try_into().unwrap(); +} + +#[test] +#[available_gas(1000000)] +fn test_acos() { + let
a = FixedTrait::::ONE(); + assert(a.acos().into() == 0, 'invalid one'); +} + +#[test] +#[available_gas(1000000)] +fn test_asin() { + let a = FixedTrait::ONE(); + assert_precise(a.asin(), HALF_PI.into(), 'invalid one', Option::None(())); // PI / 2 +} + +#[test] +#[available_gas(2000000)] +fn test_atan() { + let a = FixedTrait::new(2 * ONE, false); + assert_relative(a.atan(), 9287469, 'invalid two', Option::None(())); +} + +#[test] +fn test_ceil() { + let a = FixedTrait::new(24326963, false); // 2.9 + assert(ceil(a).mag == 3 * ONE, 'invalid pos decimal'); +} + +#[test] +fn test_floor() { + let a = FixedTrait::new(24326963, false); // 2.9 + assert(floor(a).mag == 2 * ONE, 'invalid pos decimal'); +} + +#[test] +fn test_round() { + let a = FixedTrait::new(24326963, false); // 2.9 + assert(round(a).mag == 3 * ONE, 'invalid pos decimal'); +} + +#[test] +#[should_panic] +fn test_sqrt_fail() { + let a = FixedTrait::new_unscaled(25, true); + sqrt(a); +} + +#[test] +fn test_sqrt() { + let mut a = FixedTrait::new_unscaled(0, false); + assert(sqrt(a).mag == 0, 'invalid zero root'); + a = FixedTrait::new_unscaled(25, false); + assert(sqrt(a).mag == 5 * ONE, 'invalid pos root'); +} + + +#[test] +#[available_gas(100000)] +fn test_msb() { + let a = FixedTrait::::new_unscaled(100, false); + let (msb, div) = lut::msb(a.mag / ONE); + assert(msb == 6, 'invalid msb'); + assert(div == 64, 'invalid msb ceil'); +} + +#[test] +#[available_gas(600000)] +fn test_pow() { + let a = FixedTrait::new_unscaled(3, false); + let b = FixedTrait::new_unscaled(4, false); + assert(pow(a, b).mag == 81 * ONE, 'invalid pos base power'); +} + +#[test] +#[available_gas(900000)] +fn test_pow_frac() { + let a = FixedTrait::new_unscaled(3, false); + let b = FixedTrait::new(4194304, false); // 0.5 + assert_relative( + pow(a, b), 14529495, 'invalid pos base power', Option::None(()) + ); // 1.7320508075688772 +} + +#[test] +#[available_gas(1000000)] +fn test_exp() { + let a = FixedTrait::new_unscaled(2, false); + assert_relative(exp(a), 61983895, 'invalid exp of 2', Option::None(())); // 7.389056098793725 +} + +#[test] +#[available_gas(400000)] +fn test_exp2() { + let a = FixedTrait::new_unscaled(5, false); + assert(exp2(a).mag == 268435456, 'invalid exp2 of 2'); +} + +#[test] +#[available_gas(20000)] +fn test_exp2_int() { + assert(exp2_int(5).into() == 268435456, 'invalid exp2 of 2'); +} + +#[test] +#[available_gas(1000000)] +fn test_ln() { + let mut a = FixedTrait::new_unscaled(1, false); + assert(ln(a).mag == 0, 'invalid ln of 1'); + + a = FixedTrait::new(22802601, false); + assert_relative(ln(a), ONE.into(), 'invalid ln of 2.7...', Option::None(())); +} + +#[test] +#[available_gas(1000000)] +fn test_log2() { + let mut a = FixedTrait::new_unscaled(32, false); + assert(log2(a) == FixedTrait::new_unscaled(5, false), 'invalid log2 32'); + + a = FixedTrait::new_unscaled(10, false); + assert_relative(log2(a), 27866353, 'invalid log2 10', Option::None(())); // 3.321928094887362 +} + +#[test] +#[available_gas(1000000)] +fn test_log10() { + let a = FixedTrait::new_unscaled(100, false); + assert_relative(log10(a), 2 * ONE.into(), 'invalid log10', Option::None(())); +} + +#[test] +fn test_eq() { + let a = FixedTrait::new_unscaled(42, false); + let b = FixedTrait::new_unscaled(42, false); + let c = eq(@a, @b); + assert(c == true, 'invalid result'); +} + +#[test] +fn test_ne() { + let a = FixedTrait::new_unscaled(42, false); + let b = FixedTrait::new_unscaled(42, false); + let c = ne(@a, @b); + assert(c == false, 'invalid result'); +} + +#[test] 
+fn test_add() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(2, false); + assert(add(a, b) == FixedTrait::new_unscaled(3, false), 'invalid result'); +} + +#[test] +fn test_add_eq() { + let mut a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(2, false); + a += b; + assert(a == FixedTrait::::new_unscaled(3, false), 'invalid result'); +} + +#[test] +fn test_sub() { + let a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, false); + let c = a - b; + assert(c == FixedTrait::::new_unscaled(3, false), 'false result invalid'); +} + +#[test] +fn test_sub_eq() { + let mut a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, false); + a -= b; + assert(a == FixedTrait::::new_unscaled(3, false), 'invalid result'); +} + +#[test] +#[available_gas(100000)] +fn test_mul_pos() { + let a = FP8x23W { mag: 24326963, sign: false }; + let b = FP8x23W { mag: 24326963, sign: false }; + let c = a * b; + assert(c.mag == 70548192, 'invalid result'); +} + +#[test] +fn test_mul_neg() { + let a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, true); + let c = a * b; + assert(c == FixedTrait::::new_unscaled(10, true), 'invalid result'); +} + +#[test] +fn test_mul_eq() { + let mut a = FixedTrait::new_unscaled(5, false); + let b = FixedTrait::new_unscaled(2, true); + a *= b; + assert(a == FixedTrait::::new_unscaled(10, true), 'invalid result'); +} + +#[test] +fn test_div() { + let a = FixedTrait::new_unscaled(10, false); + let b = FixedTrait::::new(24326963, false); // 2.9 + let c = a / b; + assert(c.mag == 28926234, 'invalid pos decimal'); // 3.4482758620689653 +} + +#[test] +fn test_le() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a <= a, 'a <= a'); + assert(a <= b == false, 'a <= b'); + assert(a <= c == false, 'a <= c'); + + assert(b <= a, 'b <= a'); + assert(b <= b, 'b <= b'); + assert(b <= c == false, 'b <= c'); + + assert(c <= a, 'c <= a'); + assert(c <= b, 'c <= b'); + assert(c <= c, 'c <= c'); +} + +#[test] +fn test_lt() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a < a == false, 'a < a'); + assert(a < b == false, 'a < b'); + assert(a < c == false, 'a < c'); + + assert(b < a, 'b < a'); + assert(b < b == false, 'b < b'); + assert(b < c == false, 'b < c'); + + assert(c < a, 'c < a'); + assert(c < b, 'c < b'); + assert(c < c == false, 'c < c'); +} + +#[test] +fn test_ge() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a >= a, 'a >= a'); + assert(a >= b, 'a >= b'); + assert(a >= c, 'a >= c'); + + assert(b >= a == false, 'b >= a'); + assert(b >= b, 'b >= b'); + assert(b >= c, 'b >= c'); + + assert(c >= a == false, 'c >= a'); + assert(c >= b == false, 'c >= b'); + assert(c >= c, 'c >= c'); +} + +#[test] +fn test_gt() { + let a = FixedTrait::new_unscaled(1, false); + let b = FixedTrait::new_unscaled(0, false); + let c = FixedTrait::::new_unscaled(1, true); + + assert(a > a == false, 'a > a'); + assert(a > b, 'a > b'); + assert(a > c, 'a > c'); + + assert(b > a == false, 'b > a'); + assert(b > b == false, 'b > b'); + assert(b > c, 'b > c'); + + assert(c > a == false, 'c > a'); + assert(c > b == false, 'c > b'); + assert(c > c == false, 'c > c'); +} + +#[test] 
+#[available_gas(1000000)] +fn test_cos() { + let a = FixedTrait::::new(HALF_PI, false); + assert(a.cos().into() == 0, 'invalid half pi'); +} + +#[test] +#[available_gas(1000000)] +fn test_sin() { + let a = FixedTrait::new(HALF_PI, false); + assert_precise(a.sin(), ONE.into(), 'invalid half pi', Option::None(())); +} + +#[test] +#[available_gas(2000000)] +fn test_tan() { + let a = FixedTrait::::new(HALF_PI / 2, false); + assert(a.tan().mag == 8388608, 'invalid quarter pi'); +} + +#[test] +#[available_gas(2000000)] +fn test_sign() { + let a = FixedTrait::::new(0, false); + assert(a.sign().mag == 0 && !a.sign().sign, 'invalid sign (0, true)'); + + let a = FixedTrait::::new(HALF, true); + assert(a.sign().mag == ONE && a.sign().sign, 'invalid sign (HALF, true)'); + + let a = FixedTrait::::new(HALF, false); + assert(a.sign().mag == ONE && !a.sign().sign, 'invalid sign (HALF, false)'); + + let a = FixedTrait::::new(ONE, true); + assert(a.sign().mag == ONE && a.sign().sign, 'invalid sign (ONE, true)'); + + let a = FixedTrait::::new(ONE, false); + assert(a.sign().mag == ONE && !a.sign().sign, 'invalid sign (ONE, false)'); +} + +#[test] +#[should_panic] +#[available_gas(2000000)] +fn test_sign_fail() { + let a = FixedTrait::::new(HALF, true); + assert(a.sign().mag != ONE && !a.sign().sign, 'invalid sign (HALF, true)'); +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math/hyp.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math/hyp.cairo new file mode 100644 index 000000000..ed9b66391 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math/hyp.cairo @@ -0,0 +1,159 @@ +use core::debug::PrintTrait; +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ + HALF, ONE, TWO, FP8x23W, FP8x23WImpl, FP8x23WAdd, FP8x23WAddEq, FP8x23WSub, FP8x23WMul, FP8x23WMulEq, + FP8x23WTryIntoU128, FP8x23WPartialEq, FP8x23WPartialOrd, FP8x23WSubEq, FP8x23WNeg, FP8x23WDiv, + FP8x23WIntoFelt252, FixedTrait +}; + +// Calculates hyperbolic cosine of a (fixed point) +fn cosh(a: FP8x23W) -> FP8x23W { + let ea = a.exp(); + return (ea + (FixedTrait::ONE() / ea)) / FixedTrait::new(TWO, false); +} + +// Calculates hyperbolic sine of a (fixed point) +fn sinh(a: FP8x23W) -> FP8x23W { + let ea = a.exp(); + return (ea - (FixedTrait::ONE() / ea)) / FixedTrait::new(TWO, false); +} + +// Calculates hyperbolic tangent of a (fixed point) +fn tanh(a: FP8x23W) -> FP8x23W { + let ea = a.exp(); + let ea_i = FixedTrait::ONE() / ea; + return (ea - ea_i) / (ea + ea_i); +} + +// Calculates inverse hyperbolic cosine of a (fixed point) +fn acosh(a: FP8x23W) -> FP8x23W { + let root = (a * a - FixedTrait::ONE()).sqrt(); + return (a + root).ln(); +} + +// Calculates inverse hyperbolic sine of a (fixed point) +fn asinh(a: FP8x23W) -> FP8x23W { + let root = (a * a + FixedTrait::ONE()).sqrt(); + return (a + root).ln(); +} + +// Calculates inverse hyperbolic tangent of a (fixed point) +fn atanh(a: FP8x23W) -> FP8x23W { + let one = FixedTrait::ONE(); + let ln_arg = (one + a) / (one - a); + return ln_arg.ln() / FixedTrait::new(TWO, false); +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +use option::OptionTrait; +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp8x23wide::helpers::assert_precise; + +#[test] +#[available_gas(10000000)] +fn test_cosh() { + let a = FixedTrait::new(TWO, false); + assert_precise(cosh(a), 31559585, 'invalid two', Option::None(())); // 3.762195691016423 + + let a = 
FixedTrait::ONE(); + assert_precise(cosh(a), 12944299, 'invalid one', Option::None(())); // 1.5430806347841253 + + let a = FixedTrait::ZERO(); + assert_precise(cosh(a), ONE.into(), 'invalid zero', Option::None(())); + + let a = FixedTrait::ONE(); + assert_precise(cosh(a), 12944299, 'invalid neg one', Option::None(())); // 1.5430806347841253 + + let a = FixedTrait::new(TWO, true); + assert_precise(cosh(a), 31559602, 'invalid neg two', Option::None(())); // 3.762195691016423 +} + +#[test] +#[available_gas(10000000)] +fn test_sinh() { + let a = FixedTrait::new(TWO, false); + assert_precise(sinh(a), 30424310, 'invalid two', Option::None(())); // 3.6268604077773023 + + let a = FixedTrait::ONE(); + assert_precise(sinh(a), 9858302, 'invalid one', Option::None(())); // 1.1752011936029418 + + let a = FixedTrait::ZERO(); + assert(sinh(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(ONE, true); + assert_precise(sinh(a), -9858302, 'invalid neg one', Option::None(())); // -1.1752011936029418 + + let a = FixedTrait::new(TWO, true); + assert_precise(sinh(a), -30424328, 'invalid neg two', Option::None(())); // -3.6268604077773023 +} + +#[test] +#[available_gas(10000000)] +fn test_tanh() { + let a = FixedTrait::new(TWO, false); + assert_precise(tanh(a), 8086849, 'invalid two', Option::None(())); // 0.9640275800745076 + + let a = FixedTrait::ONE(); + assert_precise(tanh(a), 6388715, 'invalid one', Option::None(())); // 0.7615941559446443 + + let a = FixedTrait::ZERO(); + assert(tanh(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(ONE, true); + assert_precise(tanh(a), -6388715, 'invalid neg one', Option::None(())); // -0.7615941559446443 + + let a = FixedTrait::new(TWO, true); + assert_precise(tanh(a), -8086849, 'invalid neg two', Option::None(())); // 0.9640275800745076 +} + +#[test] +#[available_gas(10000000)] +fn test_acosh() { + let a = FixedTrait::new(31559585, false); // 3.762195691016423 + assert_precise(acosh(a), 16777257, 'invalid two', Option::None(())); + + let a = FixedTrait::new(12944299, false); // 1.5430806347841253 + assert_precise(acosh(a), ONE.into(), 'invalid one', Option::None(())); + + let a = FixedTrait::ONE(); // 1 + assert(acosh(a).into() == 0, 'invalid zero'); +} + +#[test] +#[available_gas(10000000)] +fn test_asinh() { + let a = FixedTrait::new(30424310, false); // 3.6268604077773023 + assert_precise(asinh(a), 16777257, 'invalid two', Option::None(())); + + let a = FixedTrait::new(9858302, false); // 1.1752011936029418 + assert_precise(asinh(a), ONE.into(), 'invalid one', Option::None(())); + + let a = FixedTrait::ZERO(); + assert(asinh(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(9858302, true); // -1.1752011936029418 + assert_precise(asinh(a), -ONE.into(), 'invalid neg one', Option::None(())); + + let a = FixedTrait::new(30424310, true); // -3.6268604077773023 + assert_precise(asinh(a), -16777238, 'invalid neg two', Option::None(())); +} + +#[test] +#[available_gas(10000000)] +fn test_atanh() { + let a = FixedTrait::new(7549747, false); // 0.9 + assert_precise(atanh(a), 12349872, 'invalid 0.9', Option::None(())); // 1.4722194895832204 + + let a = FixedTrait::new(HALF, false); // 0.5 + assert_precise(atanh(a), 4607914, 'invalid half', Option::None(())); // 0.5493061443340548 + + let a = FixedTrait::ZERO(); + assert(atanh(a).into() == 0, 'invalid zero'); + + let a = FixedTrait::new(HALF, true); // 0.5 + assert_precise(atanh(a), -4607914, 'invalid neg half', Option::None(())); // 0.5493061443340548 + + let a = FixedTrait::new(7549747, true); 
// 0.9 + assert_precise(atanh(a), -12349872, 'invalid -0.9', Option::None(())); // 1.4722194895832204 +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math/lut.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math/lut.cairo new file mode 100644 index 000000000..157499b5b --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math/lut.cairo @@ -0,0 +1,1229 @@ +// Calculates the most significant bit +fn msb(whole: u64) -> (u64, u64) { + if whole < 256 { + if whole < 2 { + return (0, 1); + } + if whole < 4 { + return (1, 2); + } + if whole < 8 { + return (2, 4); + } + if whole < 16 { + return (3, 8); + } + if whole < 32 { + return (4, 16); + } + if whole < 64 { + return (5, 32); + } + if whole < 128 { + return (6, 64); + } + if whole < 256 { + return (7, 128); + } + } + + return (8, 256); +} + +fn exp2(exp: u64) -> u64 { + if exp <= 16 { + if exp == 0 { + return 1; + } + if exp == 1 { + return 2; + } + if exp == 2 { + return 4; + } + if exp == 3 { + return 8; + } + if exp == 4 { + return 16; + } + if exp == 5 { + return 32; + } + if exp == 6 { + return 64; + } + if exp == 7 { + return 128; + } + if exp == 8 { + return 256; + } + if exp == 9 { + return 512; + } + if exp == 10 { + return 1024; + } + if exp == 11 { + return 2048; + } + if exp == 12 { + return 4096; + } + if exp == 13 { + return 8192; + } + if exp == 14 { + return 16384; + } + if exp == 15 { + return 32768; + } + if exp == 16 { + return 65536; + } + } else if exp <= 32 { + if exp == 17 { + return 131072; + } + if exp == 18 { + return 262144; + } + if exp == 19 { + return 524288; + } + if exp == 20 { + return 1048576; + } + if exp == 21 { + return 2097152; + } + if exp == 22 { + return 4194304; + } + } + + return 8388608; +} + +fn sin(a: u64) -> (u64, u64, u64) { + let slot = a / 51472; + + if slot < 128 { + if slot < 64 { + if slot < 32 { + if slot < 16 { + if slot == 0 { + return (0, 0, 51472); + } + if slot == 1 { + return (51472, 51472, 102941); + } + if slot == 2 { + return (102944, 102941, 154407); + } + if slot == 3 { + return (154416, 154407, 205867); + } + if slot == 4 { + return (205887, 205867, 257319); + } + if slot == 5 { + return (257359, 257319, 308761); + } + if slot == 6 { + return (308831, 308761, 360192); + } + if slot == 7 { + return (360303, 360192, 411609); + } + if slot == 8 { + return (411775, 411609, 463011); + } + if slot == 9 { + return (463247, 463011, 514396); + } + if slot == 10 { + return (514723, 514396, 565761); + } + if slot == 11 { + return (566190, 565761, 617104); + } + if slot == 12 { + return (617662, 617104, 668425); + } + if slot == 13 { + return (669134, 668425, 719720); + } + if slot == 14 { + return (720606, 719720, 770988); + } + if slot == 15 { + return (772078, 770988, 822227); + } + } else { + if slot == 16 { + return (823550, 822227, 873436); + } + if slot == 17 { + return (875022, 873436, 924611); + } + if slot == 18 { + return (926493, 924611, 975751); + } + if slot == 19 { + return (977965, 975751, 1026855); + } + if slot == 20 { + return (1029437, 1026855, 1077920); + } + if slot == 21 { + return (1080909, 1077920, 1128945); + } + if slot == 22 { + return (1132381, 1128945, 1179927); + } + if slot == 23 { + return (1183853, 1179927, 1230864); + } + if slot == 24 { + return (1235324, 1230864, 1281756); + } + if slot == 25 { + return (1286796, 1281756, 1332599); + } + if slot == 26 { + return (1338268, 1332599, 1383392); + } + if slot == 27 { + return (1389740, 1383392, 1434132); + } + if slot == 28 { + return (1441212, 1434132, 
1484819); + } + if slot == 29 { + return (1492684, 1484819, 1535450); + } + if slot == 30 { + return (1544156, 1535450, 1586023); + } + if slot == 31 { + return (1595627, 1586023, 1636536); + } + } + } else { + if slot < 48 { + if slot == 32 { + return (1647099, 1636536, 1686988); + } + if slot == 33 { + return (1698571, 1686988, 1737376); + } + if slot == 34 { + return (1750043, 1737376, 1787699); + } + if slot == 35 { + return (1801515, 1787699, 1837954); + } + if slot == 36 { + return (1852987, 1837954, 1888141); + } + if slot == 37 { + return (1904459, 1888141, 1938256); + } + if slot == 38 { + return (1955930, 1938256, 1988298); + } + if slot == 39 { + return (2007402, 1988298, 2038265); + } + if slot == 40 { + return (2058871, 2038265, 2088156); + } + if slot == 41 { + return (2110346, 2088156, 2137968); + } + if slot == 42 { + return (2161818, 2137968, 2187700); + } + if slot == 43 { + return (2213290, 2187700, 2237349); + } + if slot == 44 { + return (2264762, 2237349, 2286914); + } + if slot == 45 { + return (2316233, 2286914, 2336392); + } + if slot == 46 { + return (2367705, 2336392, 2385783); + } + if slot == 47 { + return (2419177, 2385783, 2435084); + } + } else { + if slot == 48 { + return (2470649, 2435084, 2484294); + } + if slot == 49 { + return (2522121, 2484294, 2533410); + } + if slot == 50 { + return (2573593, 2533410, 2582430); + } + if slot == 51 { + return (2625065, 2582430, 2631353); + } + if slot == 52 { + return (2676536, 2631353, 2680177); + } + if slot == 53 { + return (2728008, 2680177, 2728901); + } + if slot == 54 { + return (2779480, 2728901, 2777521); + } + if slot == 55 { + return (2830952, 2777521, 2826037); + } + if slot == 56 { + return (2882424, 2826037, 2874446); + } + if slot == 57 { + return (2933896, 2874446, 2922748); + } + if slot == 58 { + return (2985368, 2922748, 2970939); + } + if slot == 59 { + return (3036839, 2970939, 3019018); + } + if slot == 60 { + return (3088311, 3019018, 3066984); + } + if slot == 61 { + return (3139783, 3066984, 3114834); + } + if slot == 62 { + return (3191255, 3114834, 3162567); + } + if slot == 63 { + return (3242727, 3162567, 3210181); + } + } + } + } else { + if slot < 96 { + if slot < 80 { + if slot == 64 { + return (3294199, 3210181, 3257674); + } + if slot == 65 { + return (3345671, 3257674, 3305045); + } + if slot == 66 { + return (3397142, 3305045, 3352291); + } + if slot == 67 { + return (3448614, 3352291, 3399411); + } + if slot == 68 { + return (3500086, 3399411, 3446402); + } + if slot == 69 { + return (3551558, 3446402, 3493264); + } + if slot == 70 { + return (3603030, 3493264, 3539995); + } + if slot == 71 { + return (3654502, 3539995, 3586592); + } + if slot == 72 { + return (3705973, 3586592, 3633054); + } + if slot == 73 { + return (3757445, 3633054, 3679380); + } + if slot == 74 { + return (3808917, 3679380, 3725567); + } + if slot == 75 { + return (3860389, 3725567, 3771613); + } + if slot == 76 { + return (3911861, 3771613, 3817518); + } + if slot == 77 { + return (3963333, 3817518, 3863279); + } + if slot == 78 { + return (4014805, 3863279, 3908894); + } + if slot == 79 { + return (4066276, 3908894, 3954362); + } + } else { + if slot == 80 { + return (4117751, 3954362, 3999682); + } + if slot == 81 { + return (4169220, 3999682, 4044851); + } + if slot == 82 { + return (4220692, 4044851, 4089867); + } + if slot == 83 { + return (4272164, 4089867, 4134730); + } + if slot == 84 { + return (4323636, 4134730, 4179437); + } + if slot == 85 { + return (4375108, 4179437, 4223986); + } + if slot == 
86 { + return (4426579, 4223986, 4268377); + } + if slot == 87 { + return (4478051, 4268377, 4312606); + } + if slot == 88 { + return (4529523, 4312606, 4356674); + } + if slot == 89 { + return (4580995, 4356674, 4400577); + } + if slot == 90 { + return (4632474, 4400577, 4444315); + } + if slot == 91 { + return (4683939, 4444315, 4487885); + } + if slot == 92 { + return (4735411, 4487885, 4531287); + } + if slot == 93 { + return (4786882, 4531287, 4574518); + } + if slot == 94 { + return (4838354, 4574518, 4617576); + } + if slot == 95 { + return (4889826, 4617576, 4660461); + } + } + } else { + if slot < 112 { + if slot == 96 { + return (4941298, 4660461, 4703170); + } + if slot == 97 { + return (4992770, 4703170, 4745702); + } + if slot == 98 { + return (5044242, 4745702, 4788056); + } + if slot == 99 { + return (5095714, 4788056, 4830229); + } + if slot == 100 { + return (5147227, 4830229, 4872221); + } + if slot == 101 { + return (5198657, 4872221, 4914029); + } + if slot == 102 { + return (5250129, 4914029, 4955652); + } + if slot == 103 { + return (5301601, 4955652, 4997088); + } + if slot == 104 { + return (5353073, 4997088, 5038336); + } + if slot == 105 { + return (5404545, 5038336, 5079395); + } + if slot == 106 { + return (5456017, 5079395, 5120262); + } + if slot == 107 { + return (5507488, 5120262, 5160937); + } + if slot == 108 { + return (5558960, 5160937, 5201417); + } + if slot == 109 { + return (5610432, 5201417, 5241701); + } + if slot == 110 { + return (5661904, 5241701, 5281788); + } + if slot == 111 { + return (5713376, 5281788, 5321677); + } + } else { + if slot == 112 { + return (5764848, 5321677, 5361364); + } + if slot == 113 { + return (5816320, 5361364, 5400850); + } + if slot == 114 { + return (5867791, 5400850, 5440133); + } + if slot == 115 { + return (5919263, 5440133, 5479211); + } + if slot == 116 { + return (5970735, 5479211, 5518082); + } + if slot == 117 { + return (6022207, 5518082, 5556746); + } + if slot == 118 { + return (6073679, 5556746, 5595201); + } + if slot == 119 { + return (6125151, 5595201, 5633445); + } + if slot == 120 { + return (6176622, 5633445, 5671477); + } + if slot == 121 { + return (6228094, 5671477, 5709295); + } + if slot == 122 { + return (6279566, 5709295, 5746898); + } + if slot == 123 { + return (6331038, 5746898, 5784285); + } + if slot == 124 { + return (6382510, 5784285, 5821455); + } + if slot == 125 { + return (6433982, 5821455, 5858405); + } + if slot == 126 { + return (6485454, 5858405, 5895134); + } + if slot == 127 { + return (6536925, 5895134, 5931642); + } + } + } + } + } else { + if slot < 192 { + if slot < 160 { + if slot < 144 { + if slot == 128 { + return (6588397, 5931642, 5967926); + } + if slot == 129 { + return (6639869, 5967926, 6003985); + } + if slot == 130 { + return (6691345, 6003985, 6039819); + } + if slot == 131 { + return (6742813, 6039819, 6075425); + } + if slot == 132 { + return (6794285, 6075425, 6110802); + } + if slot == 133 { + return (6845757, 6110802, 6145949); + } + if slot == 134 { + return (6897228, 6145949, 6180865); + } + if slot == 135 { + return (6948700, 6180865, 6215549); + } + if slot == 136 { + return (7000172, 6215549, 6249998); + } + if slot == 137 { + return (7051644, 6249998, 6284212); + } + if slot == 138 { + return (7103116, 6284212, 6318189); + } + if slot == 139 { + return (7154588, 6318189, 6351928); + } + if slot == 140 { + return (7206060, 6351928, 6385428); + } + if slot == 141 { + return (7257531, 6385428, 6418688); + } + if slot == 142 { + return (7309003, 
6418688, 6451706); + } + if slot == 143 { + return (7360475, 6451706, 6484482); + } + } else { + if slot == 144 { + return (7411947, 6484482, 6517013); + } + if slot == 145 { + return (7463419, 6517013, 6549299); + } + if slot == 146 { + return (7514891, 6549299, 6581338); + } + if slot == 147 { + return (7566363, 6581338, 6613129); + } + if slot == 148 { + return (7617834, 6613129, 6644672); + } + if slot == 149 { + return (7669306, 6644672, 6675964); + } + if slot == 150 { + return (7720780, 6675964, 6707005); + } + if slot == 151 { + return (7772250, 6707005, 6737793); + } + if slot == 152 { + return (7823722, 6737793, 6768328); + } + if slot == 153 { + return (7875194, 6768328, 6798608); + } + if slot == 154 { + return (7926666, 6798608, 6828632); + } + if slot == 155 { + return (7978137, 6828632, 6858399); + } + if slot == 156 { + return (8029609, 6858399, 6887907); + } + if slot == 157 { + return (8081081, 6887907, 6917156); + } + if slot == 158 { + return (8132553, 6917156, 6946145); + } + if slot == 159 { + return (8184025, 6946145, 6974873); + } + if slot == 160 { + return (8235503, 6974873, 7003337); + } + } + } else { + if slot < 176 { + if slot == 161 { + return (8286968, 7003337, 7031538); + } + if slot == 162 { + return (8338440, 7031538, 7059475); + } + if slot == 163 { + return (8389912, 7059475, 7087145); + } + if slot == 164 { + return (8441384, 7087145, 7114549); + } + if slot == 165 { + return (8492856, 7114549, 7141685); + } + if slot == 166 { + return (8544328, 7141685, 7168552); + } + if slot == 167 { + return (8595800, 7168552, 7195149); + } + if slot == 168 { + return (8647271, 7195149, 7221475); + } + if slot == 169 { + return (8698743, 7221475, 7247530); + } + if slot == 170 { + return (8750215, 7247530, 7273311); + } + if slot == 171 { + return (8801687, 7273311, 7298819); + } + if slot == 172 { + return (8853159, 7298819, 7324052); + } + if slot == 173 { + return (8904631, 7324052, 7349009); + } + if slot == 174 { + return (8956103, 7349009, 7373689); + } + if slot == 175 { + return (9007574, 7373689, 7398092); + } + } else { + if slot == 176 { + return (9059046, 7398092, 7422216); + } + if slot == 177 { + return (9110518, 7422216, 7446061); + } + if slot == 178 { + return (9161990, 7446061, 7469625); + } + if slot == 179 { + return (9213462, 7469625, 7492909); + } + if slot == 180 { + return (9264934, 7492909, 7515910); + } + if slot == 181 { + return (9316406, 7515910, 7538628); + } + if slot == 182 { + return (9367877, 7538628, 7561062); + } + if slot == 183 { + return (9419349, 7561062, 7583212); + } + if slot == 184 { + return (9470821, 7583212, 7605076); + } + if slot == 185 { + return (9522293, 7605076, 7626654); + } + if slot == 186 { + return (9573765, 7626654, 7647945); + } + if slot == 187 { + return (9625237, 7647945, 7668947); + } + if slot == 188 { + return (9676709, 7668947, 7689661); + } + if slot == 189 { + return (9728180, 7689661, 7710086); + } + if slot == 190 { + return (9779651, 7710086, 7730220); + } + if slot == 191 { + return (9831124, 7730220, 7750063); + } + } + } + } else { + if slot < 224 { + if slot < 208 { + if slot == 192 { + return (9882596, 7750063, 7769615); + } + if slot == 193 { + return (9934068, 7769615, 7788874); + } + if slot == 194 { + return (9985540, 7788874, 7807839); + } + if slot == 195 { + return (10037012, 7807839, 7826511); + } + if slot == 196 { + return (10088483, 7826511, 7844888); + } + if slot == 197 { + return (10139955, 7844888, 7862970); + } + if slot == 198 { + return (10191427, 7862970, 7880755); + } + 
if slot == 199 { + return (10242899, 7880755, 7898244); + } + if slot == 200 { + return (10294373, 7898244, 7915436); + } + if slot == 201 { + return (10345843, 7915436, 7932329); + } + if slot == 202 { + return (10397315, 7932329, 7948924); + } + if slot == 203 { + return (10448786, 7948924, 7965220); + } + if slot == 204 { + return (10500258, 7965220, 7981215); + } + if slot == 205 { + return (10551730, 7981215, 7996911); + } + if slot == 206 { + return (10603202, 7996911, 8012305); + } + if slot == 207 { + return (10654674, 8012305, 8027397); + } + } else { + if slot == 208 { + return (10706146, 8027397, 8042188); + } + if slot == 209 { + return (10757617, 8042188, 8056675); + } + if slot == 210 { + return (10809089, 8056675, 8070859); + } + if slot == 211 { + return (10860561, 8070859, 8084740); + } + if slot == 212 { + return (10912033, 8084740, 8098316); + } + if slot == 213 { + return (10963505, 8098316, 8111587); + } + if slot == 214 { + return (11014977, 8111587, 8124552); + } + if slot == 215 { + return (11066449, 8124552, 8137212); + } + if slot == 216 { + return (11117920, 8137212, 8149565); + } + if slot == 217 { + return (11169392, 8149565, 8161612); + } + if slot == 218 { + return (11220864, 8161612, 8173351); + } + if slot == 219 { + return (11272336, 8173351, 8184783); + } + if slot == 220 { + return (11323808, 8184783, 8195906); + } + if slot == 221 { + return (11375280, 8195906, 8206721); + } + if slot == 222 { + return (11426752, 8206721, 8217227); + } + if slot == 223 { + return (11478223, 8217227, 8227423); + } + } + } else { + if slot < 240 { + if slot == 224 { + return (11529695, 8227423, 8237310); + } + if slot == 225 { + return (11581167, 8237310, 8246887); + } + if slot == 226 { + return (11632639, 8246887, 8256153); + } + if slot == 227 { + return (11684111, 8256153, 8265108); + } + if slot == 228 { + return (11735583, 8265108, 8273752); + } + if slot == 229 { + return (11787055, 8273752, 8282085); + } + if slot == 230 { + return (11838531, 8282085, 8290105); + } + if slot == 231 { + return (11889998, 8290105, 8297814); + } + if slot == 232 { + return (11941470, 8297814, 8305210); + } + if slot == 233 { + return (11992942, 8305210, 8312294); + } + if slot == 234 { + return (12044414, 8312294, 8319064); + } + if slot == 235 { + return (12095886, 8319064, 8325522); + } + if slot == 236 { + return (12147358, 8325522, 8331666); + } + if slot == 237 { + return (12198829, 8331666, 8337496); + } + if slot == 238 { + return (12250301, 8337496, 8343012); + } + if slot == 239 { + return (12301773, 8343012, 8348215); + } + } else { + if slot == 240 { + return (12353244, 8348215, 8353102); + } + if slot == 241 { + return (12404717, 8353102, 8357676); + } + if slot == 242 { + return (12456189, 8357676, 8361935); + } + if slot == 243 { + return (12507661, 8361935, 8365879); + } + if slot == 244 { + return (12559132, 8365879, 8369508); + } + if slot == 245 { + return (12610604, 8369508, 8372822); + } + if slot == 246 { + return (12662076, 8372822, 8375820); + } + if slot == 247 { + return (12713548, 8375820, 8378504); + } + if slot == 248 { + return (12765020, 8378504, 8380871); + } + if slot == 249 { + return (12816492, 8380871, 8382924); + } + if slot == 250 { + return (12867964, 8382924, 8384660); + } + if slot == 251 { + return (12919435, 8384660, 8386082); + } + if slot == 252 { + return (12970907, 8386082, 8387187); + } + if slot == 253 { + return (13022379, 8387187, 8387976); + } + if slot == 254 { + return (13073851, 8387976, 8388450); + } + } + } + } + } + + return 
(13125323, 8388450, 8388608); +} + +fn atan(a: u64) -> (u64, u64, u64) { + let slot = a / 58720; + + if slot == 0 { + return (0, 0, 58719); + } + if slot == 1 { + return (58720, 58719, 117433); + } + if slot == 2 { + return (117441, 117433, 176135); + } + if slot == 3 { + return (176161, 176135, 234820); + } + if slot == 4 { + return (234881, 234820, 293481); + } + if slot == 5 { + return (293601, 293481, 352115); + } + if slot == 6 { + return (352322, 352115, 410713); + } + if slot == 7 { + return (411042, 410713, 469272); + } + if slot == 8 { + return (469762, 469272, 527785); + } + if slot == 9 { + return (528482, 527785, 586246); + } + if slot == 10 { + return (587201, 586246, 644651); + } + if slot == 11 { + return (645923, 644651, 702993); + } + if slot == 12 { + return (704643, 702993, 761267); + } + if slot == 13 { + return (763363, 761267, 819467); + } + if slot == 14 { + return (822084, 819467, 877588); + } + if slot == 15 { + return (880804, 877588, 935625); + } + if slot == 16 { + return (939524, 935625, 993572); + } + if slot == 17 { + return (998244, 993572, 1051424); + } + if slot == 18 { + return (1056965, 1051424, 1109175); + } + if slot == 19 { + return (1115685, 1109175, 1166821); + } + if slot == 20 { + return (1174411, 1166821, 1224357); + } + if slot == 21 { + return (1233125, 1224357, 1281776); + } + if slot == 22 { + return (1291846, 1281776, 1339075); + } + if slot == 23 { + return (1350566, 1339075, 1396248); + } + if slot == 24 { + return (1409286, 1396248, 1453290); + } + if slot == 25 { + return (1468006, 1453290, 1510197); + } + if slot == 26 { + return (1526727, 1510197, 1566964); + } + if slot == 27 { + return (1585447, 1566964, 1623585); + } + if slot == 28 { + return (1644167, 1623585, 1680058); + } + if slot == 29 { + return (1702887, 1680058, 1736376); + } + if slot == 30 { + return (1761612, 1736376, 1792537); + } + if slot == 31 { + return (1820328, 1792537, 1848534); + } + if slot == 32 { + return (1879048, 1848534, 1904364); + } + if slot == 33 { + return (1937768, 1904364, 1960024); + } + if slot == 34 { + return (1996489, 1960024, 2015508); + } + if slot == 35 { + return (2055209, 2015508, 2070813); + } + if slot == 36 { + return (2113929, 2070813, 2125935); + } + if slot == 37 { + return (2172649, 2125935, 2180869); + } + if slot == 38 { + return (2231370, 2180869, 2235613); + } + if slot == 39 { + return (2290090, 2235613, 2290163); + } + if slot == 40 { + return (2348813, 2290163, 2344515); + } + if slot == 41 { + return (2407530, 2344515, 2398665); + } + if slot == 42 { + return (2466251, 2398665, 2452611); + } + if slot == 43 { + return (2524971, 2452611, 2506348); + } + if slot == 44 { + return (2583691, 2506348, 2559875); + } + if slot == 45 { + return (2642412, 2559875, 2613187); + } + if slot == 46 { + return (2701132, 2613187, 2666281); + } + if slot == 47 { + return (2759852, 2666281, 2719156); + } + if slot == 48 { + return (2818572, 2719156, 2771807); + } + if slot == 49 { + return (2877293, 2771807, 2824233); + } + if slot == 50 { + return (2936014, 2824233, 2876431); + } + if slot == 51 { + return (2994733, 2876431, 2928397); + } + if slot == 52 { + return (3053453, 2928397, 2980130); + } + if slot == 53 { + return (3112174, 2980130, 3031628); + } + if slot == 54 { + return (3170894, 3031628, 3082888); + } + if slot == 55 { + return (3229614, 3082888, 3133907); + } + if slot == 56 { + return (3288334, 3133907, 3184685); + } + if slot == 57 { + return (3347055, 3184685, 3235218); + } + if slot == 58 { + return (3405775, 3235218, 
3285506); + } + if slot == 59 { + return (3464495, 3285506, 3335545); + } + if slot == 60 { + return (3523224, 3335545, 3385336); + } + if slot == 61 { + return (3581936, 3385336, 3434875); + } + if slot == 62 { + return (3640656, 3434875, 3484161); + } + if slot == 63 { + return (3699376, 3484161, 3533193); + } + if slot == 64 { + return (3758096, 3533193, 3581970); + } + if slot == 65 { + return (3816817, 3581970, 3630491); + } + if slot == 66 { + return (3875537, 3630491, 3678753); + } + if slot == 67 { + return (3934257, 3678753, 3726756); + } + if slot == 68 { + return (3992977, 3726756, 3774499); + } + if slot == 69 { + return (4051698, 3774499, 3821981); + } + if slot == 70 { + return (4110418, 3821981, 3869201); + } + if slot == 71 { + return (4169138, 3869201, 3916159); + } + if slot == 72 { + return (4227858, 3916159, 3962853); + } + if slot == 73 { + return (4286579, 3962853, 4009282); + } + if slot == 74 { + return (4345299, 4009282, 4055447); + } + if slot == 75 { + return (4404019, 4055447, 4101347); + } + if slot == 76 { + return (4462739, 4101347, 4146981); + } + if slot == 77 { + return (4521460, 4146981, 4192350); + } + if slot == 78 { + return (4580180, 4192350, 4237451); + } + if slot == 79 { + return (4638900, 4237451, 4282286); + } + if slot == 80 { + return (4697620, 4282286, 4326855); + } + if slot == 81 { + return (4756341, 4326855, 4371156); + } + if slot == 82 { + return (4815061, 4371156, 4415191); + } + if slot == 83 { + return (4873781, 4415191, 4458958); + } + if slot == 84 { + return (4932502, 4458958, 4502459); + } + if slot == 85 { + return (4991222, 4502459, 4545693); + } + if slot == 86 { + return (5049942, 4545693, 4588660); + } + if slot == 87 { + return (5108662, 4588660, 4631361); + } + if slot == 88 { + return (5167383, 4631361, 4673795); + } + if slot == 89 { + return (5226103, 4673795, 4715964); + } + if slot == 90 { + return (5284823, 4715964, 4757868); + } + if slot == 91 { + return (5343543, 4757868, 4799506); + } + if slot == 92 { + return (5402264, 4799506, 4840880); + } + if slot == 93 { + return (5460984, 4840880, 4881990); + } + if slot == 94 { + return (5519704, 4881990, 4922837); + } + if slot == 95 { + return (5578424, 4922837, 4963420); + } + if slot == 96 { + return (5637145, 4963420, 5003742); + } + if slot == 97 { + return (5695865, 5003742, 5043802); + } + if slot == 98 { + return (5754585, 5043802, 5083601); + } + + return (5813305, 5083601, 5123141); +} diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/math/trig.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/math/trig.cairo new file mode 100644 index 000000000..025b79bb2 --- /dev/null +++ b/src/numbers/fixed_point/implementations/fp8x23wide/math/trig.cairo @@ -0,0 +1,448 @@ +use debug::PrintTrait; +use integer::{u64_safe_divmod, u64_as_non_zero}; +use option::OptionTrait; + +use orion::numbers::fixed_point::implementations::fp8x23wide::math::lut; +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ + HALF, ONE, TWO, FP8x23W, FP8x23WImpl, FP8x23WAdd, FP8x23WSub, FP8x23WMul, FP8x23WDiv, + FP8x23WIntoFelt252, FixedTrait +}; + +// CONSTANTS + +const TWO_PI: u64 = 52707178; +const PI: u64 = 26353589; +const HALF_PI: u64 = 13176795; + +// PUBLIC + +// Calculates arccos(a) for -1 <= a <= 1 (fixed point) +// arccos(a) = arcsin(sqrt(1 - a^2)) - arctan identity has discontinuity at zero +fn acos(a: FP8x23W) -> FP8x23W { + let asin_arg = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1 + let asin_res = asin(asin_arg); + + if (a.sign) { + 
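// for a < 0: acos(a) = PI - acos(|a|), and asin_res computed above equals acos(|a|)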
return FixedTrait::new(PI, false) - asin_res;
+    } else {
+        return asin_res;
+    }
+}
+
+fn acos_fast(a: FP8x23W) -> FP8x23W {
+    let asin_arg = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1
+    let asin_res = asin_fast(asin_arg);
+
+    if (a.sign) {
+        return FixedTrait::new(PI, false) - asin_res;
+    } else {
+        return asin_res;
+    }
+}
+
+// Calculates arcsin(a) for -1 <= a <= 1 (fixed point)
+// arcsin(a) = arctan(a / sqrt(1 - a^2))
+fn asin(a: FP8x23W) -> FP8x23W {
+    if (a.mag == ONE) {
+        return FixedTrait::new(HALF_PI, a.sign);
+    }
+
+    let div = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1
+    return atan(a / div);
+}
+
+fn asin_fast(a: FP8x23W) -> FP8x23W {
+    if (a.mag == ONE) {
+        return FixedTrait::new(HALF_PI, a.sign);
+    }
+
+    let div = (FixedTrait::ONE() - a * a).sqrt(); // will fail if a > 1
+    return atan_fast(a / div);
+}
+
+// Calculates arctan(a) (fixed point)
+// See https://stackoverflow.com/a/50894477 for range adjustments
+fn atan(a: FP8x23W) -> FP8x23W {
+    let mut at = a.abs();
+    let mut shift = false;
+    let mut invert = false;
+
+    // Invert value when a > 1
+    if (at.mag > ONE) {
+        at = FixedTrait::ONE() / at;
+        invert = true;
+    }
+
+    // Account for lack of precision in polynomial when a > 0.7
+    if (at.mag > 5872026) {
+        let sqrt3_3 = FixedTrait::new(4843165, false); // sqrt(3) / 3
+        at = (at - sqrt3_3) / (FixedTrait::ONE() + at * sqrt3_3);
+        shift = true;
+    }
+
+    let r10 = FixedTrait::new(15363, true) * at;
+    let r9 = (r10 + FixedTrait::new(392482, true)) * at;
+    let r8 = (r9 + FixedTrait::new(1629064, false)) * at;
+    let r7 = (r8 + FixedTrait::new(2197820, true)) * at;
+    let r6 = (r7 + FixedTrait::new(366693, false)) * at;
+    let r5 = (r6 + FixedTrait::new(1594324, false)) * at;
+    let r4 = (r5 + FixedTrait::new(11519, false)) * at;
+    let r3 = (r4 + FixedTrait::new(2797104, true)) * at;
+    let r2 = (r3 + FixedTrait::new(34, false)) * at;
+    let mut res = (r2 + FixedTrait::new(8388608, false)) * at;
+
+    // Adjust for sign change, inversion, and shift
+    if (shift) {
+        res = res + FixedTrait::new(4392265, false); // pi / 6
+    }
+
+    if (invert) {
+        res = res - FixedTrait::new(HALF_PI, false);
+    }
+
+    return FixedTrait::new(res.mag, a.sign);
+}
+
+fn atan_fast(a: FP8x23W) -> FP8x23W {
+    let mut at = a.abs();
+    let mut shift = false;
+    let mut invert = false;
+
+    // Invert value when a > 1
+    if (at.mag > ONE) {
+        at = FixedTrait::ONE() / at;
+        invert = true;
+    }
+
+    // Account for lack of precision in polynomial when a > 0.7
+    if (at.mag > 5872026) {
+        let sqrt3_3 = FixedTrait::new(4843165, false); // sqrt(3) / 3
+        at = (at - sqrt3_3) / (FixedTrait::ONE() + at * sqrt3_3);
+        shift = true;
+    }
+
+    let (start, low, high) = lut::atan(at.mag);
+    let partial_step = FixedTrait::new(at.mag - start, false) / FixedTrait::new(58720, false);
+    let mut res = partial_step * FixedTrait::new(high - low, false) + FixedTrait::new(low, false);
+
+    // Adjust for sign change, inversion, and shift
+    if (shift) {
+        res = res + FixedTrait::new(4392265, false); // pi / 6
+    }
+
+    if (invert) {
+        res = res - FixedTrait::<FP8x23W>::new(HALF_PI, false);
+    }
+
+    return FixedTrait::new(res.mag, a.sign);
+}
+
+// Calculates cos(a) with a in radians (fixed point)
+fn cos(a: FP8x23W) -> FP8x23W {
+    return sin(FixedTrait::new(HALF_PI, false) - a);
+}
+
+fn cos_fast(a: FP8x23W) -> FP8x23W {
+    return sin_fast(FixedTrait::new(HALF_PI, false) - a);
+}
+
+fn sin(a: FP8x23W) -> FP8x23W {
+    let a1 = a.mag % TWO_PI;
+    let (whole_rem, partial_rem) = u64_safe_divmod(a1, u64_as_non_zero(PI));
+    let a2 =
FixedTrait::new(partial_rem, false); + let partial_sign = whole_rem == 1; + + let loop_res = a2 * _sin_loop(a2, 7, FixedTrait::ONE()); + return FixedTrait::new(loop_res.mag, a.sign ^ partial_sign && loop_res.mag != 0); +} + +fn sin_fast(a: FP8x23W) -> FP8x23W { + let a1 = a.mag % TWO_PI; + let (whole_rem, mut partial_rem) = u64_safe_divmod(a1, u64_as_non_zero(PI)); + let partial_sign = whole_rem == 1; + + if partial_rem >= HALF_PI { + partial_rem = PI - partial_rem; + } + + let (start, low, high) = lut::sin(partial_rem); + let partial_step = FixedTrait::new(partial_rem - start, false) / FixedTrait::new(51472, false); + let res = partial_step * (FixedTrait::new(high, false) - FixedTrait::new(low, false)) + + FixedTrait::::new(low, false); + + return FixedTrait::new(res.mag, a.sign ^ partial_sign && res.mag != 0); +} + +// Calculates tan(a) with a in radians (fixed point) +fn tan(a: FP8x23W) -> FP8x23W { + let sinx = sin(a); + let cosx = cos(a); + assert(cosx.mag != 0, 'tan undefined'); + return sinx / cosx; +} + +fn tan_fast(a: FP8x23W) -> FP8x23W { + let sinx = sin_fast(a); + let cosx = cos_fast(a); + assert(cosx.mag != 0, 'tan undefined'); + return sinx / cosx; +} + +// Helper function to calculate Taylor series for sin +fn _sin_loop(a: FP8x23W, i: u64, acc: FP8x23W) -> FP8x23W { + let div = (2 * i + 2) * (2 * i + 3); + let term = a * a * acc / FixedTrait::new_unscaled(div, false); + let new_acc = FixedTrait::ONE() - term; + + if (i == 0) { + return new_acc; + } + + return _sin_loop(a, i - 1, new_acc); +} + +// Tests -------------------------------------------------------------------------------------------------------------- + +use traits::Into; + +use orion::numbers::fixed_point::implementations::fp8x23wide::helpers::{ + assert_precise, assert_relative +}; + +#[test] +#[available_gas(3000000)] +fn test_acos() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::ONE(); + assert(acos(a).into() == 0, 'invalid one'); + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(acos(a), 8784530, 'invalid half', error); // 1.0471975506263043 + + let a = FixedTrait::ZERO(); + assert_relative(acos(a), HALF_PI.into(), 'invalid zero', Option::None(())); // PI / 2 + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(acos(a), 17569060, 'invalid neg half', error); // 2.094395102963489 + + let a = FixedTrait::new(ONE, true); + assert_relative(acos(a), PI.into(), 'invalid neg one', Option::None(())); // PI +} + +#[test] +#[available_gas(3000000)] +fn test_acos_fast() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::ONE(); + assert(acos_fast(a).into() == 0, 'invalid one'); + + let a = FixedTrait::new(ONE / 2, false); + assert_relative(acos_fast(a), 8784530, 'invalid half', error); // 1.0471975506263043 + + let a = FixedTrait::ZERO(); + assert_relative(acos_fast(a), HALF_PI.into(), 'invalid zero', Option::None(())); // PI / 2 + + let a = FixedTrait::new(ONE / 2, true); + assert_relative(acos_fast(a), 17569060, 'invalid neg half', error); // 2.094395102963489 + + let a = FixedTrait::new(ONE, true); + assert_relative(acos_fast(a), PI.into(), 'invalid neg one', Option::None(())); // PI +} + +#[test] +#[should_panic] +#[available_gas(1000000)] +fn test_acos_fail() { + let a = FixedTrait::new(2 * ONE, true); + acos(a); +} + +#[test] +#[available_gas(1400000)] +fn test_atan_fast() { + let error = Option::Some(84); // 1e-5 + + let a = FixedTrait::new(2 * ONE, false); + assert_relative(atan_fast(a), 9287437, 'invalid two', error); + + let a = FixedTrait::ONE(); + 
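// atan(1) = PI / 4 = 0.78539816...; in 8x23 fixed point: 0.78539816 * 2^23 ≈ 6588397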
assert_relative(atan_fast(a), 6588397, 'invalid one', error);
+
+    let a = FixedTrait::new(ONE / 2, false);
+    assert_relative(atan_fast(a), 3889358, 'invalid half', error);
+
+    let a = FixedTrait::ZERO();
+    assert(atan_fast(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(ONE / 2, true);
+    assert_relative(atan_fast(a), -3889358, 'invalid neg half', error);
+
+    let a = FixedTrait::new(ONE, true);
+    assert_relative(atan_fast(a), -6588397, 'invalid neg one', error);
+
+    let a = FixedTrait::new(2 * ONE, true);
+    assert_relative(atan_fast(a), -9287437, 'invalid neg two', error);
+}
+
+#[test]
+#[available_gas(2600000)]
+fn test_atan() {
+    let a = FixedTrait::new(2 * ONE, false);
+    assert_relative(atan(a), 9287437, 'invalid two', Option::None(()));
+
+    let a = FixedTrait::ONE();
+    assert_relative(atan(a), 6588397, 'invalid one', Option::None(()));
+
+    let a = FixedTrait::new(ONE / 2, false);
+    assert_relative(atan(a), 3889358, 'invalid half', Option::None(()));
+
+    let a = FixedTrait::ZERO();
+    assert(atan(a).into() == 0, 'invalid zero');
+
+    let a = FixedTrait::new(ONE / 2, true);
+    assert_relative(atan(a), -3889358, 'invalid neg half', Option::None(()));
+
+    let a = FixedTrait::new(ONE, true);
+    assert_relative(atan(a), -6588397, 'invalid neg one', Option::None(()));
+
+    let a = FixedTrait::new(2 * ONE, true);
+    assert_relative(atan(a), -9287437, 'invalid neg two', Option::None(()));
+}
+
+#[test]
+#[available_gas(3000000)]
+fn test_asin() {
+    let error = Option::Some(84); // 1e-5
+
+    let a = FixedTrait::ONE();
+    assert_relative(asin(a), HALF_PI.into(), 'invalid one', Option::None(())); // PI / 2
+
+    let a = FixedTrait::new(ONE / 2, false);
+    assert_relative(asin(a), 4392265, 'invalid half', error);
+
+    let a = FixedTrait::ZERO();
+    assert_precise(asin(a), 0, 'invalid zero', Option::None(()));
+
+    let a = FixedTrait::new(ONE / 2, true);
+    assert_relative(asin(a), -4392265, 'invalid neg half', error);
+
+    let a = FixedTrait::new(ONE, true);
+    assert_relative(asin(a), -HALF_PI.into(), 'invalid neg one', Option::None(())); // -PI / 2
+}
+
+#[test]
+#[should_panic]
+#[available_gas(1000000)]
+fn test_asin_fail() {
+    let a = FixedTrait::new(2 * ONE, false);
+    asin(a);
+}
+
+#[test]
+#[available_gas(6000000)]
+fn test_cos() {
+    let a = FixedTrait::new(HALF_PI, false);
+    assert(cos(a).into() == 0, 'invalid half pi');
+
+    let a = FixedTrait::new(HALF_PI / 2, false);
+    assert_relative(cos(a), 5931642, 'invalid quarter pi', Option::None(())); // 0.7071067811865475
+
+    let a = FixedTrait::new(PI, false);
+    assert_relative(cos(a), -1 * ONE.into(), 'invalid pi', Option::None(()));
+
+    let a = FixedTrait::new(HALF_PI, true);
+    assert_precise(cos(a), 0, 'invalid neg half pi', Option::None(()));
+
+    let a = FixedTrait::new_unscaled(17, false);
+    assert_relative(cos(a), -2308239, 'invalid 17', Option::None(())); // -0.2751631780463348
+
+    let a = FixedTrait::new_unscaled(17, true);
+    assert_relative(cos(a), -2308236, 'invalid -17', Option::None(())); // -0.2751631780463348
+}
+
+#[test]
+#[available_gas(6000000)]
+fn test_cos_fast() {
+    let error = Option::Some(84); // 1e-5
+
+    let a = FixedTrait::new(HALF_PI, false);
+    assert(cos_fast(a).into() == 0, 'invalid half pi');
+
+    let a = FixedTrait::new(HALF_PI / 2, false);
+    assert_precise(cos_fast(a), 5931642, 'invalid quarter pi', error); // 0.7071067811865475
+
+    let a = FixedTrait::new(PI, false);
+    assert_precise(cos_fast(a), -1 * ONE.into(), 'invalid pi', error);
+
+    let a = FixedTrait::new(HALF_PI, true);
+    assert_precise(cos_fast(a), 0, 'invalid neg
half pi', Option::None(()));
+
+    let a = FixedTrait::new_unscaled(17, false);
+    assert_precise(cos_fast(a), -2308239, 'invalid 17', error); // -0.2751631780463348
+}
+
+#[test]
+#[available_gas(6000000)]
+fn test_sin() {
+    let a = FixedTrait::new(HALF_PI, false);
+    assert_precise(sin(a), ONE.into(), 'invalid half pi', Option::None(()));
+
+    let a = FixedTrait::new(HALF_PI / 2, false);
+    assert_precise(sin(a), 5931642, 'invalid quarter pi', Option::None(())); // 0.7071067811865475
+
+    let a = FixedTrait::new(PI, false);
+    assert(sin(a).into() == 0, 'invalid pi');
+
+    let a = FixedTrait::new(HALF_PI, true);
+    assert_precise(
+        sin(a), -ONE.into(), 'invalid neg half pi', Option::None(())
+    ); // -0.9999999999939766
+
+    let a = FixedTrait::new_unscaled(17, false);
+    assert_precise(sin(a), -8064787, 'invalid 17', Option::None(())); // -0.9613974918793389
+
+    let a = FixedTrait::new_unscaled(17, true);
+    assert_precise(sin(a), 8064787, 'invalid -17', Option::None(())); // 0.9613974918793389
+}
+
+#[test]
+#[available_gas(1000000)]
+fn test_sin_fast() {
+    let error = Option::Some(84); // 1e-5
+
+    let a = FixedTrait::new(HALF_PI, false);
+    assert_precise(sin_fast(a), ONE.into(), 'invalid half pi', error);
+
+    let a = FixedTrait::new(HALF_PI / 2, false);
+    assert_precise(sin_fast(a), 5931642, 'invalid quarter pi', error); // 0.7071067811865475
+
+    let a = FixedTrait::new(PI, false);
+    assert(sin_fast(a).into() == 0, 'invalid pi');
+
+    let a = FixedTrait::new(HALF_PI, true);
+    assert_precise(sin_fast(a), -ONE.into(), 'invalid neg half pi', error); // -0.9999999999939766
+
+    let a = FixedTrait::new_unscaled(17, false);
+    assert_precise(sin_fast(a), -8064787, 'invalid 17', error); // -0.9613974918793389
+
+    let a = FixedTrait::new_unscaled(17, true);
+    assert_precise(sin_fast(a), 8064787, 'invalid -17', error); // 0.9613974918793389
+}
+
+#[test]
+#[available_gas(8000000)]
+fn test_tan() {
+    let a = FixedTrait::new(HALF_PI / 2, false);
+    assert_precise(tan(a), ONE.into(), 'invalid quarter pi', Option::None(()));
+
+    let a = FixedTrait::new(PI, false);
+    assert_precise(tan(a), 0, 'invalid pi', Option::None(()));
+
+    let a = FixedTrait::new_unscaled(17, false);
+    assert_precise(tan(a), 29309069, 'invalid 17', Option::None(())); // 3.493917677159002
+
+    let a = FixedTrait::new_unscaled(17, true);
+    assert_precise(tan(a), -29309106, 'invalid -17', Option::None(())); // -3.493917677159002
+}

From 8a55e5c730d6b2400165c517c251d6e3b603192b Mon Sep 17 00:00:00 2001
From: raphaelDkhn
Date: Mon, 23 Oct 2023 11:00:19 +0300
Subject: [PATCH 41/42] implement tensor_fp8x23wide

---
 src/numbers.cairo                             | 165 ++++++++
 .../implementations/fp8x23wide/core.cairo     |  21 +-
 .../nn/implementations/nn_fp8x23.cairo        |   6 +-
 src/operators/tensor/implementations.cairo    |   3 +-
 .../implementations/tensor_fp8x23wide.cairo   | 376 ++++++++++++++++++
 5 files changed, 568 insertions(+), 3 deletions(-)
 create mode 100644 src/operators/tensor/implementations/tensor_fp8x23wide.cairo

diff --git a/src/numbers.cairo b/src/numbers.cairo
index 02dd5b344..0a533e937 100644
--- a/src/numbers.cairo
+++ b/src/numbers.cairo
@@ -213,6 +213,171 @@ impl FP8x23Number of NumberTrait<FP8x23, u32> {
     }
 }

+use orion::numbers::fixed_point::implementations::fp8x23wide::core::{FP8x23WImpl, FP8x23W};
+use orion::numbers::fixed_point::implementations::fp8x23wide::math::core as core_fp8x23wide;
+use orion::numbers::fixed_point::implementations::fp8x23wide::math::comp as comp_fp8x23wide;
+
+impl FP8x23WNumber of NumberTrait<FP8x23W, u64> {
+    fn new(mag: u64, sign: bool) -> FP8x23W {
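// `new` takes `mag` as a raw 8x23 fixed-point magnitude (ONE = 2^23 = 8388608);
// `new_unscaled` below scales an integer value up by 2^23 first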
FP8x23WImpl::new(mag, sign) + } + + fn new_unscaled(mag: u64, sign: bool) -> FP8x23W { + FP8x23WImpl::new_unscaled(mag, sign) + } + + fn from_felt(val: felt252) -> FP8x23W { + FP8x23WImpl::from_felt(val) + } + + fn ceil(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::ceil(self) + } + + fn exp(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::exp(self) + } + + fn exp2(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::exp2(self) + } + + fn floor(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::floor(self) + } + + fn ln(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::ln(self) + } + + fn log2(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::log2(self) + } + + fn log10(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::log10(self) + } + + fn pow(self: FP8x23W, b: FP8x23W) -> FP8x23W { + FP8x23WImpl::pow(self, b) + } + + fn round(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::round(self) + } + + fn sqrt(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::sqrt(self) + } + + fn acos(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::acos(self) + } + + fn asin(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::asin(self) + } + + fn atan(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::atan(self) + } + + fn cos(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::cos(self) + } + + fn sin(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::sin(self) + } + + fn tan(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::tan(self) + } + + fn acosh(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::acosh(self) + } + + fn asinh(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::asinh(self) + } + + fn atanh(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::atanh(self) + } + + fn cosh(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::cosh(self) + } + + fn sinh(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::sinh(self) + } + + fn tanh(self: FP8x23W) -> FP8x23W { + FP8x23WImpl::tanh(self) + } + + fn zero() -> FP8x23W { + FP8x23WImpl::ZERO() + } + fn is_zero(self: FP8x23W) -> bool { + core_fp8x23wide::eq(@self, @FP8x23WImpl::ZERO()) + } + + fn one() -> FP8x23W { + FP8x23WImpl::ONE() + } + + fn neg_one() -> FP8x23W { + FP8x23W { mag: core_fp8x23wide::ONE, sign: true } + } + + fn is_one(self: FP8x23W) -> bool { + core_fp8x23wide::eq(@self, @FP8x23WImpl::ONE()) + } + + fn abs(self: FP8x23W) -> FP8x23W { + core_fp8x23wide::abs(self) + } + + fn min_value() -> FP8x23W { + FP8x23W { mag: core_fp8x23wide::MAX, sign: true } + } + + fn max_value() -> FP8x23W { + FP8x23W { mag: core_fp8x23wide::MAX, sign: false } + } + + fn min(self: FP8x23W, other: FP8x23W) -> FP8x23W { + comp_fp8x23wide::min(self, other) + } + + fn max(self: FP8x23W, other: FP8x23W) -> FP8x23W { + comp_fp8x23wide::max(self, other) + } + + fn mag(self: FP8x23W) -> u64 { + self.mag + } + + fn is_neg(self: FP8x23W) -> bool { + self.sign + } + + fn xor(lhs: FP8x23W, rhs: FP8x23W) -> bool { + comp_fp8x23wide::xor(lhs, rhs) + } + + fn or(lhs: FP8x23W, rhs: FP8x23W) -> bool { + comp_fp8x23wide::or(lhs, rhs) + } + + fn sign(self: FP8x23W) -> FP8x23W { + core_fp8x23wide::sign(self) + } +} + use orion::numbers::fixed_point::implementations::fp16x16::core::{FP16x16Impl, FP16x16}; use orion::numbers::fixed_point::implementations::fp16x16::math::core as core_fp16x16; use orion::numbers::fixed_point::implementations::fp16x16::math::comp as comp_fp16x16; diff --git a/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo b/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo index 36b64ce5e..3fe3cd3cb 100644 --- a/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo +++ b/src/numbers/fixed_point/implementations/fp8x23wide/core.cairo @@ -5,7 +5,7 @@ use result::{ResultTrait, 
ResultTraitImpl}; use traits::{TryInto, Into}; use orion::numbers::signed_integer::{i32::i32, i8::i8}; -use orion::numbers::fixed_point::core::{FixedTrait}; +use orion::numbers::{fixed_point::core::{FixedTrait}, FP8x23}; use orion::numbers::fixed_point::implementations::fp8x23wide::math::{core, trig, hyp}; use orion::numbers::fixed_point::utils; @@ -205,6 +205,25 @@ impl FP8x23WIntoFelt252 of Into { } } +impl FP8x23IntoFP8x23W of Into { + fn into(self: FP8x23) -> FP8x23W { + FP8x23W { mag: self.mag.into(), sign: self.sign } + } +} + +impl FP8x23WTryIntoFP8x23 of TryInto { + fn try_into(self: FP8x23W) -> Option { + match self.mag.try_into() { + Option::Some(val) => { + Option::Some(FP8x23 { mag: val, sign: self.sign }) + }, + Option::None(_) => { + Option::None(()) + } + } + } +} + impl FP8x23WTryIntoU128 of TryInto { fn try_into(self: FP8x23W) -> Option { if self.sign { diff --git a/src/operators/nn/implementations/nn_fp8x23.cairo b/src/operators/nn/implementations/nn_fp8x23.cairo index 305eeaba2..510f8cebd 100644 --- a/src/operators/nn/implementations/nn_fp8x23.cairo +++ b/src/operators/nn/implementations/nn_fp8x23.cairo @@ -7,6 +7,10 @@ use orion::numbers::fixed_point::implementations::fp8x23::core::FP8x23; use orion::operators::tensor::implementations::tensor_fp8x23::{ FP8x23Tensor, FP8x23TensorDiv, FP8x23TensorAdd }; +use orion::numbers::fixed_point::implementations::fp8x23wide::core::{ + FP8x23WImpl, FP8x23WTryIntoFP8x23, FP8x23W, FP8x23IntoFP8x23W +}; +use orion::operators::tensor::implementations::tensor_fp8x23wide::{FP8x23WTensor}; impl FP8x23NN of NNTrait { fn relu(tensor: @Tensor) -> Tensor { @@ -18,7 +22,7 @@ impl FP8x23NN of NNTrait { } fn softmax(tensor: @Tensor, axis: usize) -> Tensor { - functional::softmax::softmax(tensor, axis) + functional::softmax::softmaxWide::(tensor, axis) } fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { diff --git a/src/operators/tensor/implementations.cairo b/src/operators/tensor/implementations.cairo index 0df3dcdec..a96030744 100644 --- a/src/operators/tensor/implementations.cairo +++ b/src/operators/tensor/implementations.cairo @@ -5,4 +5,5 @@ mod tensor_fp8x23; mod tensor_fp16x16; mod tensor_fp64x64; mod tensor_fp32x32; -mod tensor_fp16x16wide; \ No newline at end of file +mod tensor_fp16x16wide; +mod tensor_fp8x23wide; \ No newline at end of file diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo new file mode 100644 index 000000000..91ca9338d --- /dev/null +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -0,0 +1,376 @@ +use array::ArrayTrait; +use array::SpanTrait; +use option::OptionTrait; +use traits::{TryInto, Into}; + +use orion::numbers::fixed_point::core::FixedTrait; +use orion::operators::tensor::core::{ + new_tensor, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, at_tensor, +}; +use orion::operators::tensor::{math, linalg, quantization, core}; +use orion::numbers::{i8, i32, NumberTrait, FP8x23W}; +use orion::operators::tensor::implementations::{tensor_i8::I8Tensor, tensor_u32::U32Tensor}; + +impl FP8x23WTensor of TensorTrait { + fn new(shape: Span, data: Span) -> Tensor { + new_tensor(shape, data) + } + + fn at(self: @Tensor, indices: Span) -> FP8x23W { + *at_tensor(self, indices) + } + + fn min(self: @Tensor) -> FP8x23W { + math::min::min_in_tensor::(*self.data) + } + + fn max(self: @Tensor) -> FP8x23W { + math::max::max_in_tensor(*self.data) + } + + fn stride(self: @Tensor) -> Span { + 
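// assuming the shared `stride` helper is row-major: e.g. shape [2, 3, 4] -> strides [12, 4, 1]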
stride(*self.shape) + } + + fn ravel_index(self: @Tensor, indices: Span) -> usize { + ravel_index(*self.shape, indices) + } + + fn unravel_index(self: @Tensor, index: usize) -> Span { + unravel_index(index, *self.shape) + } + + fn reshape(self: @Tensor, target_shape: Span) -> Tensor { + reshape(self, target_shape) + } + + fn reduce_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_sum::reduce_sum(self, axis, keepdims) + } + + fn argmax( + self: @Tensor, axis: usize, keepdims: Option, select_last_index: Option + ) -> Tensor { + math::argmax::argmax(self, axis, keepdims, select_last_index) + } + + fn argmin( + self: @Tensor, axis: usize, keepdims: Option, select_last_index: Option + ) -> Tensor { + math::argmin::argmin(self, axis, keepdims, select_last_index) + } + + fn transpose(self: @Tensor, axes: Span) -> Tensor { + linalg::transpose::transpose(self, axes) + } + + fn matmul(self: @Tensor, other: @Tensor) -> Tensor { + linalg::matmul::matmul(self, other) + } + + fn exp(self: @Tensor) -> Tensor { + math::exp::exp(*self) + } + + fn log(self: @Tensor) -> Tensor { + math::log::log(*self) + } + + fn equal(self: @Tensor, other: @Tensor) -> Tensor { + math::equal::equal(self, other) + } + + fn greater(self: @Tensor, other: @Tensor) -> Tensor { + math::greater::greater(self, other) + } + + fn greater_equal(self: @Tensor, other: @Tensor) -> Tensor { + math::greater_equal::greater_equal(self, other) + } + + fn less(self: @Tensor, other: @Tensor) -> Tensor { + math::less::less(self, other) + } + + fn less_equal(self: @Tensor, other: @Tensor) -> Tensor { + math::less_equal::less_equal(self, other) + } + + fn abs(self: @Tensor) -> Tensor { + math::abs::abs(*self) + } + + fn ceil(self: @Tensor) -> Tensor { + math::ceil::ceil(*self) + } + + fn sin(self: @Tensor) -> Tensor { + math::sin::sin(*self) + } + + fn cos(self: @Tensor) -> Tensor { + math::cos::cos(*self) + } + + fn asin(self: @Tensor) -> Tensor { + math::asin::asin(*self) + } + + fn cumsum( + self: @Tensor, axis: usize, exclusive: Option, reverse: Option + ) -> Tensor { + math::cumsum::cumsum(self, axis, exclusive, reverse) + } + + fn flatten(self: @Tensor, axis: usize) -> Tensor { + math::flatten::flatten(self, axis) + } + + fn sinh(self: @Tensor) -> Tensor { + math::sinh::sinh(*self) + } + + fn tanh(self: @Tensor) -> Tensor { + math::tanh::tanh(*self) + } + + fn cosh(self: @Tensor) -> Tensor { + math::cosh::cosh(*self) + } + + fn acosh(self: @Tensor) -> Tensor { + math::acosh::acosh(*self) + } + + fn asinh(self: @Tensor) -> Tensor { + math::asinh::asinh(*self) + } + + fn atan(self: @Tensor) -> Tensor { + math::atan::atan(*self) + } + + fn xor(self: @Tensor, other: @Tensor) -> Tensor { + math::xor::xor(self, other) + } + + fn or(self: @Tensor, other: @Tensor) -> Tensor { + math::or::or(self, other) + } + + fn acos(self: @Tensor) -> Tensor { + math::acos::acos(*self) + } + + fn onehot( + self: @Tensor, depth: usize, axis: Option, values: Span + ) -> Tensor { + panic(array!['not supported!']) + } + + fn sqrt(self: @Tensor) -> Tensor { + math::sqrt::sqrt(*self) + } + + fn concat(tensors: Span>, axis: usize,) -> Tensor { + math::concat::concat(tensors, axis) + } + + fn quantize_linear( + self: @Tensor, y_scale: @Tensor, y_zero_point: @Tensor + ) -> Tensor:: { + quantization::quantize_linear::quantize_linear( + self, + y_scale, + y_zero_point, + NumberTrait::new_unscaled(128, true), + NumberTrait::new_unscaled(127, false) + ) + } + + fn dequantize_linear( + self: @Tensor, x_scale: @Tensor, x_zero_point: @Tensor + ) -> 
Tensor:: { + panic(array!['not supported!']) + } + + fn slice( + self: @Tensor, + starts: Span, + ends: Span, + axes: Option>, + steps: Option> + ) -> Tensor { + core::slice::(self, starts, ends, axes, steps) + } + + fn gather( + self: @Tensor, indices: Tensor, axis: Option + ) -> Tensor { + math::gather::gather(self, indices, axis) + } + + fn nonzero(self: @Tensor) -> Tensor { + core::nonzero(self) + } + + fn squeeze(self: @Tensor, axes: Option>) -> Tensor { + core::squeeze(self, axes) + } + + fn unsqueeze(self: @Tensor, axes: Span) -> Tensor { + core::unsqueeze(self, axes) + } + + fn sign(self: @Tensor) -> Tensor { + math::sign::sign(*self) + } + + fn clip(self: @Tensor, min: Option, max: Option) -> Tensor { + core::clip(self, min, max) + } +} + +/// Implements addition for `Tensor` using the `Add` trait. +impl FP8x23WTensorAdd< + FP8x23W, + impl FP8x23WTensor: TensorTrait, + impl TAdd: Add, + impl TCopy: Copy, + impl TDrop: Drop +> of Add> { + /// Adds two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise addition. + fn add(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::add(@lhs, @rhs) + } +} + +/// Implements subtraction for `Tensor` using the `Sub` trait. +impl FP8x23WTensorSub< + FP8x23W, + impl FP8x23WTensor: TensorTrait, + impl TSub: Sub, + impl TCopy: Copy, + impl TDrop: Drop +> of Sub> { + /// Subtracts two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise subtraction. + fn sub(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::sub(@lhs, @rhs) + } +} + +/// Implements multiplication for `Tensor` using the `Mul` trait. +impl FP8x23WTensorMul< + FP8x23W, + impl FP8x23WTensor: TensorTrait, + impl TMul: Mul, + impl TCopy: Copy, + impl TDrop: Drop +> of Mul> { + /// Multiplies two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise multiplication. + fn mul(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::mul(@lhs, @rhs) + } +} + +/// Implements division for `Tensor` using the `Div` trait. +impl FP8x23WTensorDiv< + FP8x23W, + impl FP8x23WTensor: TensorTrait, + impl TDiv: Div, + impl TCopy: Copy, + impl TDrop: Drop +> of Div> { + /// Divides two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise division. + fn div(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::div(@lhs, @rhs) + } +} + +/// Implements partial equal for two `Tensor` using the `PartialEq` trait. 
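+/// Note: comparison is approximate rather than bitwise; it delegates to `tensor_eq`,
+/// which accepts elements within the relative tolerance `PRECISION` (~0.009) defined below.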
+impl FP8x23WTensorPartialEq of PartialEq<Tensor<FP8x23W>> {
+    fn eq(lhs: @Tensor<FP8x23W>, rhs: @Tensor<FP8x23W>) -> bool {
+        tensor_eq(*lhs, *rhs)
+    }
+
+    fn ne(lhs: @Tensor<FP8x23W>, rhs: @Tensor<FP8x23W>) -> bool {
+        !tensor_eq(*lhs, *rhs)
+    }
+}
+
+impl U32TryIntoU32 of TryInto<u64, u64> {
+    fn try_into(self: u64) -> Option<u64> {
+        Option::Some(self)
+    }
+}
+
+// Internals
+
+const PRECISION: u64 = 75497; // 0.009
+
+fn relative_eq(lhs: @FP8x23W, rhs: @FP8x23W) -> bool {
+    let diff = *lhs - *rhs;
+
+    let rel_diff = if *lhs.mag != 0 {
+        (diff / *lhs).mag
+    } else {
+        diff.mag
+    };
+
+    rel_diff <= PRECISION
+}
+
+fn tensor_eq(mut lhs: Tensor<FP8x23W>, mut rhs: Tensor<FP8x23W>,) -> bool {
+    let mut is_eq = true;
+
+    loop {
+        if lhs.shape.len() == 0 || !is_eq {
+            break;
+        }
+
+        is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap();
+    };
+
+    if !is_eq {
+        return false;
+    }
+
+    loop {
+        if lhs.data.len() == 0 || !is_eq {
+            break;
+        }
+
+        is_eq = relative_eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap());
+    };
+
+    return is_eq;
+}

From 19d5e99d5b7e487776d6c42dc5b7b4823514c0eb Mon Sep 17 00:00:00 2001
From: raphaelDkhn
Date: Mon, 23 Oct 2023 11:10:27 +0300
Subject: [PATCH 42/42] add logsoftmaxwide

---
 src/operators/nn/functional/logsoftmax.cairo | 28 +++++++++++++++++++
 .../nn/implementations/nn_fp16x16.cairo      |  2 +-
 .../nn/implementations/nn_fp8x23.cairo       |  2 +-
 3 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/src/operators/nn/functional/logsoftmax.cairo b/src/operators/nn/functional/logsoftmax.cairo
index 6d19cbb62..bd38d138c 100644
--- a/src/operators/nn/functional/logsoftmax.cairo
+++ b/src/operators/nn/functional/logsoftmax.cairo
@@ -2,6 +2,8 @@ use array::SpanTrait;
 use orion::numbers::NumberTrait;
 use orion::operators::tensor::core::{Tensor, TensorTrait};
+use orion::numbers::fixed_point::core::FixedTrait;
+use orion::operators::tensor::math::{exp::exp_upcast, arithmetic::div_downcast};

 /// Cf: NNTrait::logsoftmax docstring
 fn logsoftmax<
@@ -16,3 +18,29 @@ fn logsoftmax<
     return logsoftmax;
 }
+
+/// Cf: NNTrait::logsoftmax docstring
+fn logsoftmaxWide<
+    T,
+    TMAG,
+    W,
+    WMAG,
+    impl TTensor: TensorTrait<T>,
+    impl WTensor: TensorTrait<W>,
+    impl TDiv: Div<T>,
+    impl TIntoW: Into<T, W>,
+    impl WTryIntoT: TryInto<W, T>,
+    impl TCopy: Copy<T>,
+    impl TDrop: Drop<T>,
+    impl WCopy: Copy<W>,
+    impl WDrop: Drop<W>,
+    impl TFixed: FixedTrait<T, TMAG>,
+    impl WFixed: FixedTrait<W, WMAG>,
+>(
+    z: @Tensor<T>, axis: usize
+) -> Tensor<T> {
+    let exp_tensor: Tensor<W> = exp_upcast(*z);
+    let sum = exp_tensor.reduce_sum(axis, true);
+    let softmax = div_downcast(@exp_tensor, @sum);
+    softmax.log()
+}
\ No newline at end of file
diff --git a/src/operators/nn/implementations/nn_fp16x16.cairo b/src/operators/nn/implementations/nn_fp16x16.cairo
index b940d8742..de81cde6d 100644
--- a/src/operators/nn/implementations/nn_fp16x16.cairo
+++ b/src/operators/nn/implementations/nn_fp16x16.cairo
@@ -28,7 +28,7 @@ impl FP16x16NN of NNTrait<FP16x16> {
     }

     fn logsoftmax(tensor: @Tensor<FP16x16>, axis: usize) -> Tensor<FP16x16> {
-        functional::logsoftmax::logsoftmax(tensor, axis)
+        functional::logsoftmax::logsoftmaxWide::<FP16x16, u32, FP16x16W, u64>(tensor, axis)
     }

     fn softsign(tensor: @Tensor<FP16x16>) -> Tensor<FP16x16> {
diff --git a/src/operators/nn/implementations/nn_fp8x23.cairo b/src/operators/nn/implementations/nn_fp8x23.cairo
index 510f8cebd..d837b8fef 100644
--- a/src/operators/nn/implementations/nn_fp8x23.cairo
+++ b/src/operators/nn/implementations/nn_fp8x23.cairo
@@ -26,7 +26,7 @@ impl FP8x23NN of NNTrait<FP8x23> {
     }

     fn logsoftmax(tensor: @Tensor<FP8x23>, axis: usize) -> Tensor<FP8x23> {
-        functional::logsoftmax::logsoftmax(tensor, axis)
+        functional::logsoftmax::logsoftmaxWide::<FP8x23, u32, FP8x23W, u64>(tensor, axis)
     }

     fn softsign(tensor: @Tensor<FP8x23>) -> Tensor<FP8x23> {