From 830601705db50b1f174ff407661b4326b70f9655 Mon Sep 17 00:00:00 2001 From: "d.levin256@gmail.com" Date: Mon, 13 Jan 2025 11:46:05 +0100 Subject: [PATCH] Improve dft performance on arm64 --- src/dft/bitrev.hpp | 15 +++++++++++++++ src/dft/fft-impl.hpp | 32 ++++++++++++++++++++++++-------- tests/dft_test.cpp | 8 +++++--- 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/src/dft/bitrev.hpp b/src/dft/bitrev.hpp index f22dfe54..f5faf7d1 100644 --- a/src/dft/bitrev.hpp +++ b/src/dft/bitrev.hpp @@ -49,15 +49,22 @@ constexpr inline static size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev template CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x) { +#ifdef CMT_ARCH_NEON + return __builtin_bitreverse32(x) >> (32 - Bits); +#else if constexpr (Bits > bitrev_table_log2N) return bitreverse(x); return data::bitrev_table[x] >> (bitrev_table_log2N - Bits); +#endif } template CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x, size_t bits, cbool_t) { +#ifdef CMT_ARCH_NEON + return __builtin_bitreverse32(x) >> (32 - bits); +#else if constexpr (use_table) { return data::bitrev_table[x] >> (bitrev_table_log2N - bits); @@ -66,10 +73,17 @@ CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x, size_t bits, cbool_t(x) >> (32 - bits); } +#endif } CMT_GNU_CONSTEXPR inline u32 dig4rev_using_table(u32 x, size_t bits) { +#ifdef CMT_ARCH_NEON + x = __builtin_bitreverse32(x); + x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1)); + x = x >> (32 - bits); + return x; +#else if (bits > bitrev_table_log2N) { if (bits <= 16) @@ -82,6 +96,7 @@ CMT_GNU_CONSTEXPR inline u32 dig4rev_using_table(u32 x, size_t bits) x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1)); x = x >> (bitrev_table_log2N - bits); return x; +#endif } template diff --git a/src/dft/fft-impl.hpp b/src/dft/fft-impl.hpp index 234ed70c..d176b283 100644 --- a/src/dft/fft-impl.hpp +++ b/src/dft/fft-impl.hpp @@ -52,22 +52,30 @@ template inline std::bitset fft_algorithm_selection; template <> -inline std::bitset fft_algorithm_selection{ (1ull << 15) - 1 }; +inline std::bitset fft_algorithm_selection{ +#ifdef CMT_ARCH_NEON + 0 +#else + (1ull << 15) - 1 +#endif +}; template <> inline std::bitset fft_algorithm_selection{ 0 }; template -constexpr bool inline use_autosort(size_t log2n) +inline bool use_autosort(size_t log2n) { return fft_algorithm_selection[log2n]; } +#ifndef CMT_ARCH_NEON +#define KFR_AUTOSORT_FOR_2048 #define KFR_AUTOSORT_FOR_128D #define KFR_AUTOSORT_FOR_256D #define KFR_AUTOSORT_FOR_512 #define KFR_AUTOSORT_FOR_1024 -#define KFR_AUTOSORT_FOR_2048 +#endif #ifdef CMT_ARCH_AVX template <> @@ -855,7 +863,11 @@ template struct fft_config { constexpr static inline const bool recursion = true; - constexpr static inline const bool prefetch = true; +#ifdef CMT_ARCH_NEON + constexpr static inline const bool prefetch = false; +#else + constexpr static inline const bool prefetch = true; +#endif constexpr static inline const size_t process_width = const_max(static_cast(1), vector_capacity / 16); }; @@ -1606,7 +1618,7 @@ struct fft_specialization : fft_final_stage_impl { fft_final_stage_impl::template do_execute(out, in, nullptr); if (this->need_reorder) - fft_reorder(out, 10, cfalse); + fft_reorder(out, csize_t<10>{}, cbool_t{}); } }; #endif @@ -1649,8 +1661,6 @@ struct fft_specialization : dft_stage radix8_autosort_pass_last(256, csize, no, no, no, cbool, out, out, tw); } }; - -#else #endif template @@ -1768,7 +1778,13 @@ KFR_INTRINSIC void init_fft(dft_plan* self, size_t size, dft_order) { const size_t log2n = ilog2(size); cswitch( - csizes_t<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11>(), log2n, + csizes_t<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 +#ifdef KFR_AUTOSORT_FOR_2048 + , + 11 +#endif + >(), + log2n, [&](auto log2n) { (void)log2n; diff --git a/tests/dft_test.cpp b/tests/dft_test.cpp index ccd382e2..93df83c3 100644 --- a/tests/dft_test.cpp +++ b/tests/dft_test.cpp @@ -33,11 +33,13 @@ constexpr ctypes_t dft_float_types{}; constexpr ctypes_t dft_float_types{}; #endif -#if defined(CMT_ARCH_X86) && !defined(KFR_NO_PERF_TESTS) +#if !defined(KFR_NO_PERF_TESTS) static void full_barrier() { -#ifdef CMT_COMPILER_GNU +#if defined(CMT_ARCH_NEON) + asm volatile("dmb ish" ::: "memory"); +#elif defined(CMT_COMPILER_GNU) asm volatile("mfence" ::: "memory"); #else _ReadWriteBarrier(); @@ -235,7 +237,7 @@ TEST(fft_accuracy) if (is_even(size)) { - index_t csize = dft_plan_real::complex_size_for(size, dft_pack_format::CCs); + index_t csize = dft_plan_real::complex_size_for(size, dft_pack_format::CCs); univector in = truncate(gen_random_range(gen, -1.0, +1.0), size); univector> out = truncate(dimensions<1>(scalar(qnan)), csize);