Skip to content

Commit

Permalink
imax() for neon64
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikhail Katliar committed Oct 25, 2024
1 parent b8f16d3 commit c0dda32
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions include/blast/math/simd/arch/Neon64.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,44 @@ namespace blast

/**
 * @brief Horizontal maximum of a 4-lane float batch, together with the index paired with it.
 *
 * The reduction is done in two rotate/compare/blend steps: lanes are first compared with
 * their neighbor (rotation by 1), then the partial results are compared with the ones two
 * lanes away (rotation by 2). After the second step every lane has seen all four inputs.
 *
 * @param x values to be reduced
 * @param idx indices associated with the lanes of @a x; blended with the same masks as @a x
 *
 * @return tuple (max, index). Both comparisons are strict (x > y), so on a tie the rotated
 * operand is selected; consequently, when the maximum occurs in more than one lane of @a x,
 * different lanes of the returned index batch may hold different (equally valid) indices.
 *
 * NOTE(review): behavior for NaN inputs depends on the semantics of vcgtq_f32 with NaN
 * operands and is not handled explicitly here -- confirm before relying on it.
 */
template <typename Arch>
requires std::is_base_of_v<xsimd::neon64, Arch>
inline std::tuple<xsimd::batch<float, Arch>, xsimd::batch<std::int32_t, Arch>> imax(xsimd::batch<float, Arch> const& x, xsimd::batch<std::int32_t, Arch> const& idx) noexcept
{
    // Step 1: compare each lane with its right neighbor (with wrap-around).
    float32x4_t const y1 = vextq_f32(x, x, 1);      // [x[1], x[2], x[3], x[0]]
    int32x4_t const iy1 = vextq_s32(idx, idx, 1);   // [idx[1], idx[2], idx[3], idx[0]]

    uint32x4_t const mask1 = vcgtq_f32(x, y1);      // all-ones lanes where x > y1
    float32x4_t const max1 = vbslq_f32(mask1, x, y1);   // [max(x[0], x[1]), max(x[1], x[2]), max(x[2], x[3]), max(x[3], x[0])]
    int32x4_t const idx1 = vbslq_s32(mask1, idx, iy1);  // index travels with the selected value

    // Step 2: compare the pairwise maxima with the ones two lanes away.
    float32x4_t const y2 = vextq_f32(max1, max1, 2);    // [max1[2], max1[3], max1[0], max1[1]]
    int32x4_t const iy2 = vextq_s32(idx1, idx1, 2);     // [idx1[2], idx1[3], idx1[0], idx1[1]]

    uint32x4_t const mask2 = vcgtq_f32(max1, y2);       // all-ones lanes where max1 > y2
    float32x4_t const max2 = vbslq_f32(mask2, max1, y2);
    int32x4_t const idx2 = vbslq_s32(mask2, idx1, iy2);

    // Every lane of max2 now covers all four inputs; idx2 holds the index chosen per lane.
    return {max2, idx2};
}


/**
 * @brief Horizontal maximum of a 2-lane double batch, together with the index paired with it.
 *
 * The two lanes are swapped, compared with the original, and the larger value (with its
 * index) is selected in each lane, so both lanes end up covering both inputs.
 *
 * @param x values to be reduced
 * @param idx indices associated with the lanes of @a x; blended with the same mask as @a x
 *
 * @return tuple (max, index). The comparison is strict (x > y), so on a tie each lane selects
 * the swapped operand; when both lanes of @a x are equal, the two lanes of the returned index
 * batch may therefore hold different (equally valid) indices.
 *
 * NOTE(review): behavior for NaN inputs depends on the semantics of vcgtq_f64 with NaN
 * operands and is not handled explicitly here -- confirm before relying on it.
 */
template <typename Arch>
requires std::is_base_of_v<xsimd::neon64, Arch>
inline std::tuple<xsimd::batch<double, Arch>, xsimd::batch<std::int64_t, Arch>> imax(xsimd::batch<double, Arch> const& x, xsimd::batch<std::int64_t, Arch> const& idx) noexcept
{
    // Swap the two lanes: [x[1], x[0]] and [idx[1], idx[0]].
    float64x2_t const y = vextq_f64(x, x, 1);
    int64x2_t const iy = vextq_s64(idx, idx, 1);

    // All-ones lanes where the original value is strictly greater than the swapped one.
    uint64x2_t const mask = vcgtq_f64(x, y);

    // Select the larger value and carry its index along with it.
    float64x2_t const m = vbslq_f64(mask, x, y);
    int64x2_t const im = vbslq_s64(mask, idx, iy);

    return {m, im};
}
}

0 comments on commit c0dda32

Please sign in to comment.