Skip to content

Commit

Permalink
Improve dft performance on arm64
Browse files Browse the repository at this point in the history
  • Loading branch information
dancazarin committed Jan 13, 2025
1 parent 6aea976 commit 8306017
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 11 deletions.
15 changes: 15 additions & 0 deletions src/dft/bitrev.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,22 @@ constexpr inline static size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev
/// Returns the low `Bits` bits of `x` in reversed bit order, right-aligned at bit 0.
///
/// @tparam Bits  Width, in bits, of the field to reverse.
/// @param  x     Value whose low `Bits` bits are reversed.
/// @return The bit-reversed field. On the NEON path bits of `x` at or above
///         `Bits` are shifted out and do not affect the result.
///
/// NEON builds use the compiler's 32-bit bit-reverse builtin instead of the
/// lookup table. Other builds index `data::bitrev_table` when the requested
/// width fits the table and defer to `bitreverse<Bits>` otherwise.
template <size_t Bits>
CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x)
{
#ifdef CMT_ARCH_NEON
    // NOTE(review): __builtin_bitreverse32 is documented as a Clang builtin;
    // confirm every compiler that defines CMT_ARCH_NEON provides it.
    if constexpr (Bits == 0)
    {
        // Shifting a 32-bit value right by 32 is undefined behaviour, and a
        // zero-width field has no bits to return.
        return 0;
    }
    else
    {
        // Reverse all 32 bits, then bring the reversed field back down to bit 0.
        return __builtin_bitreverse32(x) >> (32 - Bits);
    }
#else
    if constexpr (Bits > bitrev_table_log2N)
    {
        // Width exceeds what the table covers.
        return bitreverse<Bits>(x);
    }
    else
    {
        // Kept in the `else` branch so the shift is discarded (never instantiated
        // with a wrapped-around count) when Bits > bitrev_table_log2N.
        // Presumably table entries are bitrev_table_log2N-bit reversals, so the
        // shift narrows them to `Bits` bits; verify against the table generator.
        return data::bitrev_table[x] >> (bitrev_table_log2N - Bits);
    }
#endif
}

template <bool use_table>
CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x, size_t bits, cbool_t<use_table>)
{
#ifdef CMT_ARCH_NEON
return __builtin_bitreverse32(x) >> (32 - bits);
#else
if constexpr (use_table)
{
return data::bitrev_table[x] >> (bitrev_table_log2N - bits);
Expand All @@ -66,10 +73,17 @@ CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x, size_t bits, cbool_t<use_
{
return bitreverse<32>(x) >> (32 - bits);
}
#endif
}

CMT_GNU_CONSTEXPR inline u32 dig4rev_using_table(u32 x, size_t bits)
{
#ifdef CMT_ARCH_NEON
x = __builtin_bitreverse32(x);
x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1));
x = x >> (32 - bits);
return x;
#else
if (bits > bitrev_table_log2N)
{
if (bits <= 16)
Expand All @@ -82,6 +96,7 @@ CMT_GNU_CONSTEXPR inline u32 dig4rev_using_table(u32 x, size_t bits)
x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1));
x = x >> (bitrev_table_log2N - bits);
return x;
#endif
}

template <size_t log2n, size_t bitrev, typename T>
Expand Down
32 changes: 24 additions & 8 deletions src/dft/fft-impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,30 @@ template <typename T>
inline std::bitset<DFT_MAX_STAGES> fft_algorithm_selection;

template <>
inline std::bitset<DFT_MAX_STAGES> fft_algorithm_selection<float>{ (1ull << 15) - 1 };
inline std::bitset<DFT_MAX_STAGES> fft_algorithm_selection<float>{
#ifdef CMT_ARCH_NEON
0
#else
(1ull << 15) - 1
#endif
};

template <>
inline std::bitset<DFT_MAX_STAGES> fft_algorithm_selection<double>{ 0 };

template <typename T>
constexpr bool inline use_autosort(size_t log2n)
inline bool use_autosort(size_t log2n)
{
return fft_algorithm_selection<T>[log2n];
}

#ifndef CMT_ARCH_NEON
#define KFR_AUTOSORT_FOR_2048
#define KFR_AUTOSORT_FOR_128D
#define KFR_AUTOSORT_FOR_256D
#define KFR_AUTOSORT_FOR_512
#define KFR_AUTOSORT_FOR_1024
#define KFR_AUTOSORT_FOR_2048
#endif

#ifdef CMT_ARCH_AVX
template <>
Expand Down Expand Up @@ -855,7 +863,11 @@ template <typename T>
struct fft_config
{
constexpr static inline const bool recursion = true;
constexpr static inline const bool prefetch = true;
#ifdef CMT_ARCH_NEON
constexpr static inline const bool prefetch = false;
#else
constexpr static inline const bool prefetch = true;
#endif
constexpr static inline const size_t process_width =
const_max(static_cast<size_t>(1), vector_capacity<T> / 16);
};
Expand Down Expand Up @@ -1606,7 +1618,7 @@ struct fft_specialization<T, 10> : fft_final_stage_impl<T, false, 1024>
{
fft_final_stage_impl<T, false, 1024>::template do_execute<inverse>(out, in, nullptr);
if (this->need_reorder)
fft_reorder(out, 10, cfalse);
fft_reorder(out, csize_t<10>{}, cbool_t<always_br2>{});
}
};
#endif
Expand Down Expand Up @@ -1649,8 +1661,6 @@ struct fft_specialization<T, 11> : dft_stage<T>
radix8_autosort_pass_last(256, csize<width>, no, no, no, cbool<inverse>, out, out, tw);
}
};

#else
#endif

template <bool is_even, bool first, typename T, bool autosort>
Expand Down Expand Up @@ -1768,7 +1778,13 @@ KFR_INTRINSIC void init_fft(dft_plan<T>* self, size_t size, dft_order)
{
const size_t log2n = ilog2(size);
cswitch(
csizes_t<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11>(), log2n,
csizes_t<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
#ifdef KFR_AUTOSORT_FOR_2048
,
11
#endif
>(),
log2n,
[&](auto log2n)
{
(void)log2n;
Expand Down
8 changes: 5 additions & 3 deletions tests/dft_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ constexpr ctypes_t<float, double> dft_float_types{};
constexpr ctypes_t<float> dft_float_types{};
#endif

#if defined(CMT_ARCH_X86) && !defined(KFR_NO_PERF_TESTS)
#if !defined(KFR_NO_PERF_TESTS)

static void full_barrier()
{
#ifdef CMT_COMPILER_GNU
#if defined(CMT_ARCH_NEON)
asm volatile("dmb ish" ::: "memory");
#elif defined(CMT_COMPILER_GNU)
asm volatile("mfence" ::: "memory");
#else
_ReadWriteBarrier();
Expand Down Expand Up @@ -235,7 +237,7 @@ TEST(fft_accuracy)

if (is_even(size))
{
index_t csize = dft_plan_real<float_type>::complex_size_for(size, dft_pack_format::CCs);
index_t csize = dft_plan_real<float_type>::complex_size_for(size, dft_pack_format::CCs);
univector<float_type> in = truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size);

univector<complex<float_type>> out = truncate(dimensions<1>(scalar(qnan)), csize);
Expand Down

0 comments on commit 8306017

Please sign in to comment.