From d06887ac53a214b431270ec4d1eb8c02bbc6b203 Mon Sep 17 00:00:00 2001 From: jatin Date: Wed, 22 Nov 2023 03:07:29 -0800 Subject: [PATCH] Adding sin and cosine approximations --- include/math_approx/math_approx.hpp | 1 + include/math_approx/src/basic_math.hpp | 12 ++++++ include/math_approx/src/tanh_approx.hpp | 2 +- test/CMakeLists.txt | 2 + test/src/cos_approx_test.cpp | 49 +++++++++++++++++++++++-- test/src/sin_approx_test.cpp | 49 +++++++++++++++++++++++-- tools/plotter/plotter.cpp | 14 +++---- 7 files changed, 114 insertions(+), 15 deletions(-) diff --git a/include/math_approx/math_approx.hpp b/include/math_approx/math_approx.hpp index 734138b..6f5a076 100644 --- a/include/math_approx/math_approx.hpp +++ b/include/math_approx/math_approx.hpp @@ -8,3 +8,4 @@ namespace math_approx #include "src/tanh_approx.hpp" #include "src/sigmoid_approx.hpp" +#include "src/sin_approx.hpp" diff --git a/include/math_approx/src/basic_math.hpp b/include/math_approx/src/basic_math.hpp index e4b6be7..9a1f2a5 100644 --- a/include/math_approx/src/basic_math.hpp +++ b/include/math_approx/src/basic_math.hpp @@ -36,6 +36,12 @@ T rsqrt (T x) // return x * r; } +template +T select (bool q, T t, T f) +{ + return q ? t : f; +} + #if defined(XSIMD_HPP) template struct scalar_of> @@ -54,5 +60,11 @@ xsimd::batch rsqrt (xsimd::batch x) r *= (S) -0.5; return x * r; } + +template +xsimd::batch select (xsimd::batch_bool q, xsimd::batch t, xsimd::batch f) +{ + return xsimd::select (q, t, f); +} #endif } // namespace math_approx diff --git a/include/math_approx/src/tanh_approx.hpp b/include/math_approx/src/tanh_approx.hpp index 1bf932e..e4c8283 100644 --- a/include/math_approx/src/tanh_approx.hpp +++ b/include/math_approx/src/tanh_approx.hpp @@ -67,7 +67,7 @@ namespace tanh_detail template T tanh (T x) { - static_assert (order % 2 == 1 && order <= 11 && order >= 3, "Order must e an odd number within [3, 9]"); + static_assert (order % 2 == 1 && order <= 11 && order >= 3, "Order must e an odd number within [3, 11]"); T x_poly {}; if constexpr (order == 11) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 0be4d0b..d588dfa 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -30,3 +30,5 @@ endfunction(setup_catch_test) setup_catch_test(tanh_approx_test) setup_catch_test(sigmoid_approx_test) +setup_catch_test(sin_approx_test) +setup_catch_test(cos_approx_test) diff --git a/test/src/cos_approx_test.cpp b/test/src/cos_approx_test.cpp index 028cc36..c821b13 100644 --- a/test/src/cos_approx_test.cpp +++ b/test/src/cos_approx_test.cpp @@ -1,3 +1,46 @@ -// -// Created by jatin on 11/22/23. -// +#include "test_helpers.hpp" +#include +#include + +#include + +TEST_CASE ("Cosine Approx Test") +{ +#if ! defined(WIN32) + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-3f); +#else + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f); +#endif + const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x) + { return std::cos (x); }); + + const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound) + { + const auto y_approx = test_helpers::compute_all (all_floats, f_approx); + + const auto error = test_helpers::compute_error (y_exact, y_approx); + const auto max_error = test_helpers::abs_max (error); + + std::cout << max_error << std::endl; + REQUIRE (std::abs (max_error) < err_bound); + }; + + SECTION ("9th-Order") + { + test_approx ([] (auto x) + { return math_approx::cos<9> (x); }, + 7.0e-7f); + } + SECTION ("7th-Order") + { + test_approx ([] (auto x) + { return math_approx::cos<7> (x); }, + 1.8e-5f); + } + SECTION ("5th-Order") + { + test_approx ([] (auto x) + { return math_approx::cos<5> (x); }, + 7.5e-4f); + } +} diff --git a/test/src/sin_approx_test.cpp b/test/src/sin_approx_test.cpp index 028cc36..11cdaca 100644 --- a/test/src/sin_approx_test.cpp +++ b/test/src/sin_approx_test.cpp @@ -1,3 +1,46 @@ -// -// Created by jatin on 11/22/23. -// +#include "test_helpers.hpp" +#include +#include + +#include + +TEST_CASE ("Sine Approx Test") +{ +#if ! defined(WIN32) + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-3f); +#else + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-1f); +#endif + const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x) + { return std::sin (x); }); + + const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound) + { + const auto y_approx = test_helpers::compute_all (all_floats, f_approx); + + const auto error = test_helpers::compute_error (y_exact, y_approx); + const auto max_error = test_helpers::abs_max (error); + + // std::cout << max_error << std::endl; + REQUIRE (std::abs (max_error) < err_bound); + }; + + SECTION ("9th-Order") + { + test_approx ([] (auto x) + { return math_approx::sin<9> (x); }, + 8.5e-7f); + } + SECTION ("7th-Order") + { + test_approx ([] (auto x) + { return math_approx::sin<7> (x); }, + 1.8e-5f); + } + SECTION ("5th-Order") + { + test_approx ([] (auto x) + { return math_approx::sin<5> (x); }, + 7.5e-4f); + } +} diff --git a/tools/plotter/plotter.cpp b/tools/plotter/plotter.cpp index 3fab040..8ec6b9d 100644 --- a/tools/plotter/plotter.cpp +++ b/tools/plotter/plotter.cpp @@ -57,19 +57,17 @@ void plot_function (std::span all_floats, int main() { plt::figure(); - const auto range = std::make_pair (-10.0f, 10.0f); + const auto range = std::make_pair (-3.141f, 3.141f); static constexpr auto tol = 1.0e-2f; const auto all_floats = test_helpers::all_32_bit_floats (range.first, range.second, tol); const auto y_exact = test_helpers::compute_all (all_floats, [] (float x) - { return 1.0f / (1.0f + std::exp (-x)); }); + { return std::cos (x); }); - plot_error ( - all_floats, - y_exact, - [] (float x) - { return math_approx::sigmoid_exp<3> (x); }, - "Sigmoid-Exp-5"); + // // plot_error (all_floats, y_exact, [] (float x) { return math_approx::sin<5> (x); }, "Sin-5"); + // // plot_error (all_floats, y_exact, [] (float x) { return math_approx::sin<7> (x); }, "Sin-7"); + plot_ulp_error (all_floats, y_exact, [] (float x) { return math_approx::cos_mpi_pi<9> (x); }, "Cos-9"); + // plot_function (all_floats, [] (float x) { return math_approx::cos_mpi_pi<9> (x); }, "Cos-9"); plt::legend ({ { "loc", "upper right" } }); plt::xlim (range.first, range.second);