Skip to content

Commit

Permalink
idx_abs_max: add customization point kokkos#96
Browse files Browse the repository at this point in the history
Note the different behavior with kokkos-kernels,
tracked in issue kokkos#114.
  • Loading branch information
fnrizzi committed Oct 21, 2021
1 parent 21fbd5e commit bb58ad4
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 4 deletions.
1 change: 1 addition & 0 deletions examples/kokkos-based/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
linalg_add_example(add_kokkos)
linalg_add_example(dot_kokkos)
linalg_add_example(dotc_kokkos)
linalg_add_example(idx_abs_max_kokkos)
linalg_add_example(simple_scale_kokkos)
linalg_add_example(matrix_vector_product_kokkos)
42 changes: 42 additions & 0 deletions examples/kokkos-based/idx_abs_max_kokkos.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include <experimental/linalg>
#include <iostream>

int main(int argc, char* argv[])
{
std::cout << "idx_abs_max example: calling kokkos-kernels" << std::endl;

std::size_t N = 10;
Kokkos::initialize(argc,argv);
{
using value_type = double;

Kokkos::View<value_type*> a_view("A",N);
value_type* a_ptr = a_view.data();

// Requires CTAD working, GCC 11.1 works but some others are buggy
// std::experimental::mdspan a(a_ptr,N);
using extents_type = std::experimental::extents<std::experimental::dynamic_extent>;
std::experimental::mdspan<value_type, extents_type> a(a_ptr,N);
a(0) = 0.5;
a(1) = 0.2;
a(2) = 0.1;
a(3) = 0.4;
a(4) = -0.8;
a(5) = -1.7;
a(6) = -0.3;
a(7) = 0.5;
a(8) = -1.7;
a(9) = -0.9;

namespace stdla = std::experimental::linalg;

// This goes to the base implementation
const auto idx = stdla::idx_abs_max(std::execution::seq, a);
printf("Seq result = %i\n", idx);

// This forwards to KokkosKernels (https://github.com/kokkos/kokkos-kernels
const auto idx_kk = stdla::idx_abs_max(KokkosKernelsSTD::kokkos_exec<>(), a);
printf("Kokkos result = %i\n", idx_kk);
}
Kokkos::finalize();
}
50 changes: 47 additions & 3 deletions include/experimental/__p1673_bits/blas1_vector_idx_abs_max.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,33 @@ namespace experimental {
inline namespace __p1673_version_0 {
namespace linalg {

// begin anonymous namespace
namespace {

template <class Exec, class v_t, class = void>
struct is_custom_idx_abs_max_avail : std::false_type {};

template <class Exec, class v_t>
struct is_custom_idx_abs_max_avail<
Exec, v_t,
std::enable_if_t<
std::is_integral<
decltype(idx_abs_max(std::declval<Exec>(),
std::declval<v_t>()
)
)
>::value
>
>
{
static constexpr bool value = !std::is_same<Exec,std::experimental::linalg::impl::inline_exec_t>::value;
};

template<class ElementType,
extents<>::size_type ext0,
class Layout,
class Accessor>
extents<>::size_type idx_abs_max(
extents<>::size_type idx_abs_max_default_impl(
std::experimental::mdspan<ElementType, std::experimental::extents<ext0>, Layout, Accessor> v)
{
using std::abs;
Expand All @@ -73,16 +95,38 @@ extents<>::size_type idx_abs_max(
return maxInd; // FIXME check for NaN "never less than" stuff
}

} // end anonymous namespace

template<class ExecutionPolicy,
class ElementType,
extents<>::size_type ext0,
class Layout,
class Accessor>
extents<>::size_type idx_abs_max(
ExecutionPolicy&& /* exec */,
ExecutionPolicy&& exec,
std::experimental::mdspan<ElementType, std::experimental::extents<ext0>, Layout, Accessor> v)
{
constexpr bool use_custom = is_custom_idx_abs_max_avail<
decltype(execpolicy_mapper(exec)), decltype(v)
>::value;

if constexpr(use_custom){
using return_type = extents<>::size_type;
return return_type(idx_abs_max(execpolicy_mapper(exec), v));
}
else{
return idx_abs_max_default_impl(v);
}
}

template<class ElementType,
extents<>::size_type ext0,
class Layout,
class Accessor>
extents<>::size_type idx_abs_max(
std::experimental::mdspan<ElementType, std::experimental::extents<ext0>, Layout, Accessor> v)
{
return idx_abs_max(v);
return idx_abs_max(std::experimental::linalg::impl::default_exec_t(), v);
}

} // end namespace linalg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ struct default_exec_t {};
}


#if defined(LINALG_ENABLE_KOKKOS) || defined(LINALG_ENABLE_KOKKOS_DEFAULT)
#if defined(LINALG_ENABLE_KOKKOS) && defined(LINALG_ENABLE_KOKKOS_DEFAULT)
#include <experimental/__p1673_bits/kokkos-kernels/exec_policy_wrapper_kk.hpp>
#endif

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

#ifndef LINALG_TPLIMPLEMENTATIONS_INCLUDE_EXPERIMENTAL_P1673_BITS_KOKKOSKERNELS_IDX_ABS_MAX_HPP_
#define LINALG_TPLIMPLEMENTATIONS_INCLUDE_EXPERIMENTAL_P1673_BITS_KOKKOSKERNELS_IDX_ABS_MAX_HPP_

#include <KokkosBlas1_iamax.hpp>

namespace KokkosKernelsSTD {

template<class ExecSpace,
class ElementType,
std::experimental::extents<>::size_type ext0,
class Layout,
class Accessor>
auto idx_abs_max(kokkos_exec<ExecSpace>,
std::experimental::mdspan<ElementType, std::experimental::extents<ext0>, Layout, Accessor> v)
{
// note that -1 here, this is related to:
// https://github.com/kokkos/stdBLAS/issues/114

return KokkosBlas::iamax(Impl::mdspan_to_view(v))-1;
}

}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
#include "__p1673_bits/kokkos-kernels/blas1_dot_kk.hpp"
#include "__p1673_bits/kokkos-kernels/blas1_add_kk.hpp"
#include "__p1673_bits/kokkos-kernels/blas1_scale_kk.hpp"
#include "__p1673_bits/kokkos-kernels/blas1_idx_abs_max_kk.hpp"
#include "__p1673_bits/kokkos-kernels/blas2_matrix_vector_product_kk.hpp"

0 comments on commit bb58ad4

Please sign in to comment.