diff --git a/include/blast/math/simd/arch/Neon64.hpp b/include/blast/math/simd/arch/Neon64.hpp index cd58be0..83c23cb 100644 --- a/include/blast/math/simd/arch/Neon64.hpp +++ b/include/blast/math/simd/arch/Neon64.hpp @@ -99,9 +99,26 @@ namespace blast template requires std::is_base_of_v - inline std::tuple, xsimd::batch> imax(xsimd::batch const& v1, xsimd::batch const& idx) noexcept + inline std::tuple, xsimd::batch> imax(xsimd::batch const& x, xsimd::batch const& idx) noexcept { - throw std::logic_error {"Not implemented"}; + // Step 1: Initial pairwise comparisons + float32x4_t const y1 = vextq_f32(x, x, 1); // Shift elements by 1: [x[1], x[2], x[3], x[0]] + int32x4_t const iy1 = vextq_s32(idx, idx, 1); // Shift idx by 1: [idx[1], idx[2], idx[3], idx[0]] + + uint32x4_t const mask1 = vcgtq_f32(x, y1); // Mask for x > y1 + float32x4_t const max1 = vbslq_f32(mask1, x, y1); // [max(x[0], x[1]), max(x[1], x[2]), max(x[2], x[3]), max(x[3], x[0])] + int32x4_t const idx1 = vbslq_s32(mask1, idx, iy1); // Blend idx and iy1 based on mask + + // Step 2: Second pairwise comparison on the result from Step 1 + float32x4_t const y2 = vextq_f32(max1, max1, 2); // Shift elements by 2: [max1[2], max1[3], max1[0], max1[1]] + int32x4_t const iy2 = vextq_s32(idx1, idx1, 2); // Shift idx1 by 2: [idx1[2], idx1[3], idx1[0], idx1[1]] + + uint32x4_t const mask2 = vcgtq_f32(max1, y2); // Mask for max1 > y2 + float32x4_t const max2 = vbslq_f32(mask2, max1, y2); // Blend max1 and y2 based on mask + int32x4_t const idx2 = vbslq_s32(mask2, idx1, iy2); // Blend idx1 and iy2 based on mask + + // Return the max value and corresponding index, which are now the same across all lanes + return {max2, idx2}; } @@ -109,6 +126,17 @@ namespace blast requires std::is_base_of_v inline std::tuple, xsimd::batch> imax(xsimd::batch const& x, xsimd::batch const& idx) noexcept { - throw std::logic_error {"Not implemented"}; + // Swap elements of x + float64x2_t const y = vextq_f64(x, x, 1); + int64x2_t const iy = vextq_s64(idx, idx, 1); + + // Compare + uint64x2_t const mask = vcgtq_f64(x, y); + + // Blend + float64x2_t const m = vbslq_f64(mask, x, y); + int64x2_t const im = vbslq_s64(mask, idx, iy); + + return {m, im}; } }