diff --git a/cpp/src/neighbors/detail/cagra/device_common.hpp b/cpp/src/neighbors/detail/cagra/device_common.hpp index 20545a225..73787ec90 100644 --- a/cpp/src/neighbors/detail/cagra/device_common.hpp +++ b/cpp/src/neighbors/detail/cagra/device_common.hpp @@ -210,12 +210,12 @@ template RAFT_DEVICE_INLINE_FUNCTION auto team_sum(T x, uint32_t team_size) -> T { switch (team_size) { - case 1: return team_sum<1>(x); - case 2: return team_sum<2>(x); - case 4: return team_sum<4>(x); - case 8: return team_sum<8>(x); - case 16: return team_sum<16>(x); - default: return team_sum<32>(x); + case 32: x += raft::shfl_xor(x, 16); + case 16: x += raft::shfl_xor(x, 8); + case 8: x += raft::shfl_xor(x, 4); + case 4: x += raft::shfl_xor(x, 2); + case 2: x += raft::shfl_xor(x, 1); + default: return x; } }